shapiq.plot.scatter_plot

shapiq.plot.scatter_plot(interaction_values_list, data, interaction=None, *, x_feature=None, color=None, feature_names=None, abbreviate=True, alpha=0.8, dot_size=16, jitter=0.0, hist=True, ax=None, show=True)

Plots a scatter (dependence) plot of an interaction’s per-sample value against one feature.

Inspired by SHAP’s shap.plots.scatter. For a first-order interaction (i,) the x-axis is feature i’s value across samples and the y-axis is its Shapley value. For higher-order interactions like (i, j) the x-axis is the value of a single feature in the interaction (selected via x_feature, defaulting to the first feature in the sorted tuple) and the y-axis is the higher-order interaction value.
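The axis mapping described above can be sketched in plain Python. This is only an illustration under stated assumptions, not shapiq's implementation: each sample's interaction values are modelled as a plain dict from a sorted index tuple to a float (a stand-in for InteractionValues), dependence_points is a hypothetical helper name, and treating an interaction missing from a single sample as 0.0 is an assumption.

```python
def dependence_points(values_per_sample, rows, interaction, x_feature=None):
    """Return (xs, ys) for a dependence plot of one interaction.

    values_per_sample: list of dicts {sorted index tuple: float}, one per sample
        (a hypothetical stand-in for InteractionValues objects).
    rows: list of feature-value rows, same length as values_per_sample.
    """
    interaction = tuple(sorted(interaction))
    # Default x-axis feature: the first feature in the sorted tuple.
    if x_feature is None:
        x_feature = interaction[0]
    if x_feature not in interaction:
        raise ValueError("x_feature must be a member of interaction")
    # x-axis: the chosen feature's value in every sample.
    xs = [row[x_feature] for row in rows]
    # y-axis: the interaction's value in every sample
    # (assumption: a sample without this interaction contributes 0.0).
    ys = [values.get(interaction, 0.0) for values in values_per_sample]
    return xs, ys


rows = [[1.0, 10.0, 5.0], [2.0, 20.0, 6.0], [3.0, 30.0, 7.0]]
values_per_sample = [
    {(0,): 0.5, (0, 2): 0.1},
    {(0,): -0.2, (0, 2): 0.3},
    {(0,): 0.4, (0, 2): -0.1},
]

# First-order interaction (0,): feature 0 on x, its per-sample value on y.
xs_main, ys_main = dependence_points(values_per_sample, rows, (0,))
# Second-order interaction (0, 2), with feature 2 chosen for the x-axis.
xs_pair, ys_pair = dependence_points(values_per_sample, rows, (2, 0), x_feature=2)
```

For the first-order case the x-values are simply column 0 of the data; for the pair, x_feature=2 switches the x-axis to column 2 while the y-values stay those of the (0, 2) interaction.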

Parameters:
  • interaction_values_list (list[InteractionValues]) – A non-empty list of InteractionValues objects, one per sample row of data.

  • data (DataFrame | ndarray) – The feature values for the samples, as a pandas.DataFrame or 2D numpy array. Must have the same number of rows as interaction_values_list.

  • interaction (tuple[int, ...] | tuple[str, ...] | int | str | None) – Identifies the interaction to plot. Accepts an int or str (treated as a single-element tuple, i.e. a main effect), a tuple of feature indices like (0, 2), or a tuple of feature names like ("MedInc", "Latitude"). If None, the globally most important interaction (by mean absolute aggregated value) is selected. Defaults to None.

  • x_feature (int | str | None) – For higher-order interactions, which feature in interaction to place on the x-axis. Must be a member of interaction. Ignored for first-order interactions. Defaults to the first feature in the sorted interaction tuple.

  • color (int | str | None) – Feature index or name used to color the points (with a red-blue colormap and a colorbar). If None (default), all points are drawn in a neutral color and no colorbar is shown. NaN color values render gray.

  • feature_names (list[str] | None) – Names of the features. Defaults to ["F0", "F1", ...].

  • abbreviate (bool) – Whether to abbreviate feature names for axis labels. Defaults to True.

  • alpha (float) – Transparency of the points, in (0, 1]. Defaults to 0.8.

  • dot_size (float) – Size of the scatter points. Defaults to 16.

  • jitter (float) – If positive, adds Gaussian jitter to the plotted x-values, with standard deviation jitter * std(x_vals). Useful for categorical or integer-valued features, where many points would otherwise overlap. Defaults to 0.0 (disabled).

  • hist (bool) – Whether to draw a faint histogram of the x-axis feature’s distribution along the bottom of the plot (SHAP-style). The bars share the main x-axis; no separate Axes is created. Defaults to True.

  • ax (Axes | None) – matplotlib Axes object to plot on. If None, a new figure and axes are created.

  • show (bool) – Whether to call plt.show() at the end. If False, plt.show() is not called and the Axes is returned instead. Defaults to True.
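The jitter parameter above can be illustrated with a small standard-library sketch. It is not shapiq's code: jitter_x and its seed argument are hypothetical, and using the population standard deviation for std(x_vals) is an assumption.

```python
import random
import statistics


def jitter_x(x_vals, jitter, seed=0):
    """Add Gaussian noise with standard deviation jitter * std(x_vals)."""
    if jitter < 0:
        raise ValueError("jitter must be non-negative")
    # jitter == 0.0 disables the feature; fewer than two points have no spread.
    if jitter == 0 or len(x_vals) < 2:
        return list(x_vals)
    rng = random.Random(seed)  # fixed seed so the sketch is reproducible
    # Assumption: population standard deviation of the plotted x-values.
    scale = jitter * statistics.pstdev(x_vals)
    return [x + rng.gauss(0.0, scale) for x in x_vals]


# An integer-valued feature: without jitter, points stack on four x positions.
x_vals = [0, 0, 1, 1, 2, 2, 2, 3]
unchanged = jitter_x(x_vals, 0.0)
spread = jitter_x(x_vals, 0.1)
```

Because the noise is proportional to the feature's own spread, the same jitter value gives a visually similar amount of separation regardless of the feature's units.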

Return type:

Axes | None

Returns:

The Axes object if show=False, otherwise None.

Raises:
  • ValueError – If inputs are inconsistent (empty list, length mismatch, unknown feature names or indices, an interaction tuple absent from every sample’s lookup, an out-of-tuple x_feature, or invalid numeric parameters).

  • TypeError – If data is not a DataFrame or ndarray, or if a feature identifier has an unsupported type.
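Some of the ValueError conditions listed above can be sketched as explicit checks. This is a hedged illustration only: check_scatter_inputs and fails are hypothetical helpers, the error messages are invented, and the assumption that a negative jitter counts as an invalid numeric parameter is not confirmed by the text above.

```python
def check_scatter_inputs(n_values, n_rows, interaction, x_feature=None,
                         alpha=0.8, jitter=0.0):
    """Raise ValueError for a few of the inconsistencies listed above."""
    if n_values == 0:
        raise ValueError("interaction_values_list must be non-empty")
    if n_values != n_rows:
        raise ValueError(f"got {n_values} InteractionValues but {n_rows} data rows")
    if x_feature is not None and x_feature not in interaction:
        raise ValueError(f"x_feature {x_feature!r} is not in interaction {interaction!r}")
    if not 0.0 < alpha <= 1.0:
        raise ValueError("alpha must lie in (0, 1]")
    # Assumption: negative jitter is treated as invalid.
    if jitter < 0.0:
        raise ValueError("jitter must be non-negative")


def fails(**kwargs):
    """Return True if the check rejects the given arguments."""
    args = {"n_values": 3, "n_rows": 3, "interaction": (0, 2)}
    args.update(kwargs)
    try:
        check_scatter_inputs(**args)
    except ValueError:
        return True
    return False


length_mismatch = fails(n_rows=4)          # 3 values vs. 4 data rows
out_of_tuple = fails(x_feature=1)          # 1 is not in (0, 2)
valid_call = not fails(x_feature=2)        # 2 is a member of (0, 2)
```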