Source code for shapiq.plot.scatter

"""Scatter (a.k.a. dependence) plot for :class:`~shapiq.InteractionValues`.

Plots the per-sample interaction value of a chosen interaction tuple against
the value of one feature. For first-order interactions this matches
``shap.plots.scatter``; for higher-order interactions the x-axis is restricted
to a single feature (selected from the interaction tuple).
"""

from __future__ import annotations

from typing import TYPE_CHECKING

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd

from shapiq.interaction_values import InteractionValues, aggregate_interaction_values

from .beeswarm import _get_red_blue_cmap
from .utils import abbreviate_feature_names

if TYPE_CHECKING:
    from matplotlib.axes import Axes
    from matplotlib.figure import Figure


__all__ = ["scatter_plot"]


def _resolve_feature(
    feature: int | str | np.integer,
    name_to_idx: dict[str, int],
    n_players: int,
) -> int:
    """Resolves a feature identifier (index or name) to an integer index."""
    if isinstance(feature, int | np.integer) and not isinstance(feature, bool):
        idx = int(feature)
        if not 0 <= idx < n_players:
            error_message = f"Feature index {idx} out of range [0, {n_players})."
            raise ValueError(error_message)
        return idx
    if isinstance(feature, str):
        if feature not in name_to_idx:
            error_message = f"Unknown feature name: {feature!r}."
            raise ValueError(error_message)
        return name_to_idx[feature]
    error_message = f"Feature identifier must be int or str, got {type(feature).__name__}."
    raise TypeError(error_message)


def _resolve_interaction(
    interaction: tuple[int, ...] | tuple[str, ...] | int | str | None,
    interaction_values_list: list[InteractionValues],
    name_to_idx: dict[str, int],
    n_players: int,
) -> tuple[int, ...]:
    """Turns an ``interaction`` argument into a tuple of feature indices.

    Args:
        interaction: A single feature (``int``, ``numpy`` integer or ``str``), a
            tuple of features, or ``None`` to pick an interaction automatically.
        interaction_values_list: The per-sample interaction values; only used
            when ``interaction`` is ``None``.
        name_to_idx: Mapping from feature name to feature index.
        n_players: Number of features.

    Returns:
        For a single feature, a one-element tuple with its index. For a tuple,
        the resolved indices in ascending order. For ``None``, the key of the
        non-empty interaction with the largest value after aggregating the
        absolute interaction values with ``aggregation="mean"``.

    Raises:
        ValueError: If the tuple is empty, a member cannot be resolved, or no
            non-empty interaction exists for automatic selection.
        TypeError: If ``interaction`` is of an unsupported type.
    """
    if interaction is not None:
        # A lone identifier is shorthand for a main effect.
        if isinstance(interaction, (int, np.integer, str)):
            return (_resolve_feature(interaction, name_to_idx, n_players),)

        if not isinstance(interaction, tuple):
            error_message = (
                f"interaction must be a tuple, int, str, or None. Got {type(interaction).__name__}."
            )
            raise TypeError(error_message)

        members = [_resolve_feature(member, name_to_idx, n_players) for member in interaction]
        if not members:
            error_message = "interaction tuple must contain at least one feature."
            raise ValueError(error_message)
        return tuple(sorted(members))

    # No interaction requested: rank all non-empty interactions by their mean
    # absolute value across samples and take the top one.
    magnitudes = [abs(iv) for iv in interaction_values_list]
    mean_abs = aggregate_interaction_values(magnitudes, aggregation="mean")
    ranked = sorted(
        ((key, value) for key, value in mean_abs.interactions.items() if len(key) >= 1),
        key=lambda pair: pair[1],
        reverse=True,
    )
    if not ranked:
        error_message = "No non-empty interactions available to plot."
        raise ValueError(error_message)
    return ranked[0][0]


def scatter_plot(
    interaction_values_list: list[InteractionValues],
    data: pd.DataFrame | np.ndarray,
    interaction: tuple[int, ...] | tuple[str, ...] | int | str | None = None,
    *,
    x_feature: int | str | None = None,
    color: int | str | None = None,
    feature_names: list[str] | None = None,
    abbreviate: bool = True,
    alpha: float = 0.8,
    dot_size: float = 16,
    jitter: float = 0.0,
    hist: bool = True,
    ax: Axes | None = None,
    show: bool = True,
) -> Axes | None:
    """Plots a scatter (dependence) plot of an interaction's per-sample value against one feature.

    Inspired by `SHAP's <https://shap.readthedocs.io/en/latest/example_notebooks/api_examples/plots/scatter.html>`_
    ``shap.plots.scatter``. For a first-order interaction ``(i,)`` the x-axis is feature ``i``'s
    value across samples and the y-axis is its Shapley value. For higher-order interactions like
    ``(i, j)`` the x-axis is the value of a single feature in the interaction (selected via
    ``x_feature``, defaulting to the first feature in the sorted tuple) and the y-axis is the
    higher-order interaction value.

    Args:
        interaction_values_list: A non-empty list of :class:`~shapiq.InteractionValues` objects,
            one per sample row of ``data``.
        data: The feature values for the samples, as a ``pandas.DataFrame`` or 2D ``numpy``
            array. Must have the same number of rows as ``interaction_values_list``. Values are
            converted to ``float``.
        interaction: Identifies the interaction to plot. Accepts an ``int`` or ``str`` (treated
            as a main effect single-element tuple), a tuple of feature indices like ``(0, 2)``,
            or a tuple of feature names like ``("MedInc", "Latitude")``. If ``None``, the
            globally most important interaction (by mean absolute aggregated value) is selected.
            Defaults to ``None``.
        x_feature: For higher-order interactions, which feature in ``interaction`` to place on
            the x-axis. Must be a member of ``interaction``. Ignored for first-order
            interactions. Defaults to the first feature in the sorted interaction tuple.
        color: Feature index or name used to color the points (with a red-blue colormap and a
            colorbar). If ``None`` (default), all points are drawn in a neutral color and no
            colorbar is shown. ``NaN`` color values render gray.
        feature_names: Names of the features. Defaults to ``["F0", "F1", ...]``.
        abbreviate: Whether to abbreviate feature names for axis labels. Defaults to ``True``.
        alpha: Transparency of the points, in ``(0, 1]``. Defaults to ``0.8``.
        dot_size: Size of the scatter points. Defaults to ``16``.
        jitter: If positive, adds Gaussian jitter to the plotted x-values, scaled to
            ``jitter * std(x_vals)``. The jitter uses a fixed random seed, so repeated calls
            produce the same picture. Useful for categorical or integer-valued features.
            Defaults to ``0.0`` (disabled).
        hist: Whether to draw a faint histogram of the x-axis feature's distribution along the
            bottom of the plot (SHAP-style). The bars share the main x-axis: no separate axes is
            created. Non-finite x-values are left out of the histogram. Defaults to ``True``.
        ax: ``matplotlib`` ``Axes`` object to plot on. If ``None``, a new figure and axes are
            created.
        show: Whether to call ``plt.show()`` at the end. If ``False``, returns the axes instead.
            Defaults to ``True``.

    Returns:
        The ``Axes`` object if ``show=False``, otherwise ``None``.

    Raises:
        ValueError: If inputs are inconsistent (empty list, length mismatch, ``data`` that is not
            two-dimensional or lacks a required feature column, unknown feature names or
            indices, an interaction tuple absent from every sample's lookup, an out-of-tuple
            ``x_feature``, or invalid numeric parameters).
        TypeError: If ``data`` is not a DataFrame or ndarray, or if a feature identifier has an
            unsupported type.
    """
    # ------------------------------------------------------------------ validation
    if not isinstance(interaction_values_list, list) or len(interaction_values_list) == 0:
        error_message = "interaction_values_list must be a non-empty list."
        raise ValueError(error_message)
    if not isinstance(data, pd.DataFrame) and not isinstance(data, np.ndarray):
        error_message = f"data must be a pandas DataFrame or a numpy array. Got: {type(data)}."
        raise TypeError(error_message)
    if len(interaction_values_list) != len(data):
        error_message = "Length of interaction_values_list must match number of rows in data."
        raise ValueError(error_message)
    if alpha <= 0 or alpha > 1:
        error_message = "alpha must be between 0 and 1."
        raise ValueError(error_message)
    if dot_size <= 0:
        error_message = "dot_size must be a positive value."
        raise ValueError(error_message)
    if jitter < 0:
        error_message = "jitter must be non-negative."
        raise ValueError(error_message)
    # Column indexing below needs a (samples, features) layout; reject anything else up front
    # instead of failing later with an opaque IndexError.
    if data.ndim != 2:
        error_message = (
            f"data must be two-dimensional (n_samples, n_features). Got {data.ndim} dimension(s)."
        )
        raise ValueError(error_message)

    # ------------------------------------------------------------------ feature names
    n_players = interaction_values_list[0].n_players
    if feature_names is None:
        feature_names_full = [f"F{i}" for i in range(n_players)]
    else:
        if len(feature_names) != n_players:
            error_message = "Length of feature_names must match n_players."
            raise ValueError(error_message)
        feature_names_full = list(feature_names)
    feature_names_display = (
        abbreviate_feature_names(feature_names_full) if abbreviate else list(feature_names_full)
    )
    # Lookups always use the full names; only axis/colorbar labels use the display names.
    name_to_idx = {n: i for i, n in enumerate(feature_names_full)}
    display_mapping = dict(enumerate(feature_names_display))

    # ------------------------------------------------------------------ what to plot
    interaction_tuple = _resolve_interaction(
        interaction, interaction_values_list, name_to_idx, n_players
    )
    if not any(interaction_tuple in iv.interaction_lookup for iv in interaction_values_list):
        error_message = f"Interaction {interaction_tuple} not found in InteractionValues lookup."
        raise ValueError(error_message)

    if len(interaction_tuple) == 1 or x_feature is None:
        x_idx = interaction_tuple[0]
    else:
        x_idx = _resolve_feature(x_feature, name_to_idx, n_players)
        if x_idx not in interaction_tuple:
            error_message = (
                f"x_feature {x_feature!r} must be a member of interaction {interaction_tuple}."
            )
            raise ValueError(error_message)

    color_idx: int | None = None
    if color is not None:
        color_idx = _resolve_feature(color, name_to_idx, n_players)

    x_numpy = data.to_numpy(dtype=float) if isinstance(data, pd.DataFrame) else data.astype(float)

    # The indices were validated against n_players, not against the actual width of ``data``.
    n_columns = x_numpy.shape[1]
    highest_idx = x_idx if color_idx is None else max(x_idx, color_idx)
    if highest_idx >= n_columns:
        error_message = (
            f"data has {n_columns} column(s) but feature index {highest_idx} is required."
        )
        raise ValueError(error_message)

    x_vals = x_numpy[:, x_idx]
    y_vals = np.array([iv[interaction_tuple] for iv in interaction_values_list], dtype=float)

    # ------------------------------------------------------------------ drawing
    if ax is None:
        _fig, ax = plt.subplots(figsize=(7, 5))
    fig: Figure = ax.get_figure()  # type: ignore[assignment]

    x_plot = x_vals
    if jitter > 0:
        std = float(np.nanstd(x_vals))
        if std > 0:
            # Fixed seed: the same inputs always yield the same jittered positions.
            rng = np.random.default_rng(0)
            x_plot = x_vals + rng.normal(0.0, jitter * std, size=x_vals.shape)

    n_samples = len(x_vals)
    rasterize = n_samples > 500
    sc = None
    if color_idx is None:
        ax.scatter(
            x_plot,
            y_vals,
            color="#1f77b4",
            s=dot_size,
            alpha=alpha,
            linewidth=0,
            rasterized=rasterize,
        )
    else:
        c_vals = x_numpy[:, color_idx]
        nan_mask = np.isnan(c_vals)
        valid_mask = ~nan_mask
        if nan_mask.any():
            # Samples without a color value are drawn gray and more transparent.
            ax.scatter(
                x_plot[nan_mask],
                y_vals[nan_mask],
                color="#777777",
                s=dot_size,
                alpha=alpha * 0.5,
                linewidth=0,
                rasterized=rasterize,
            )
        if valid_mask.any():
            valid_color_vals = c_vals[valid_mask]
            vmin = float(np.min(valid_color_vals))
            vmax = float(np.max(valid_color_vals))
            if vmin == vmax:
                # Widen a degenerate color range slightly.
                vmin -= 1e-9
                vmax += 1e-9
            sc = ax.scatter(
                x_plot[valid_mask],
                y_vals[valid_mask],
                c=valid_color_vals,
                cmap=_get_red_blue_cmap(),
                vmin=vmin,
                vmax=vmax,
                s=dot_size,
                alpha=alpha,
                linewidth=0,
                rasterized=rasterize,
            )

    if sc is not None and color_idx is not None:
        cb = fig.colorbar(sc, ax=ax, aspect=80)
        cb.set_label(display_mapping[color_idx], size=11, labelpad=0)
        cb.ax.tick_params(labelsize=10, length=0)
        cb.outline.set_visible(False)  # type: ignore[union-attr]

    # Only finite values can be binned: np.histogram raises on an infinite range, and the
    # decorative histogram should never be the reason the whole plot fails.
    valid_x_for_hist = x_vals[np.isfinite(x_vals)]
    draw_hist = (
        hist
        and len(valid_x_for_hist) >= 2
        and float(np.min(valid_x_for_hist)) < float(np.max(valid_x_for_hist))
    )
    if draw_hist:
        n_bins = min(50, max(10, len(valid_x_for_hist) // 2))
        counts, bin_edges = np.histogram(valid_x_for_hist, bins=n_bins)
        if counts.max() > 0:
            hist_band = 0.10  # bottom 10% of the existing plot area
            rel_heights = (counts / counts.max()) * hist_band
            widths = np.diff(bin_edges)
            # x in data coordinates, y in axes coordinates: bars stay pinned to the bottom
            # of the axes regardless of the y-range of the interaction values.
            ax.bar(
                bin_edges[:-1],
                rel_heights,
                width=widths,
                bottom=0.0,
                align="edge",
                color="#aaaaaa",
                alpha=0.4,
                edgecolor="none",
                zorder=-2,
                transform=ax.get_xaxis_transform(),
            )

    ax.axhline(0, color="#999999", linestyle="-", linewidth=1, zorder=1)
    ax.set_xlabel(display_mapping[x_idx], fontsize=12)
    index_name = interaction_values_list[0].index
    feature_label = ", ".join(display_mapping[f] for f in interaction_tuple)
    ax.set_ylabel(f"{index_name}({feature_label})", fontsize=12)
    ax.spines["top"].set_visible(False)
    ax.spines["right"].set_visible(False)

    # Lay out the figure that actually owns ``ax``. ``plt.tight_layout()`` would act on
    # pyplot's *current* figure, which may be a different one when the caller passed ``ax``.
    # Fall back to pyplot when the owner has no ``tight_layout`` of its own.
    apply_layout = getattr(fig, "tight_layout", None)
    if callable(apply_layout):
        apply_layout()
    else:
        plt.tight_layout()

    if not show:
        return ax
    plt.show()
    return None