shapiq.plot

This module contains all plotting functions for the shapiq package.

shapiq.plot.network_plot(interaction_values=None, *, first_order_values=None, second_order_values=None, feature_names=None, feature_image_patches=None, feature_image_patches_size=0.2, center_image=None, center_image_size=0.6, draw_legend=True, center_text=None)

Draws the interaction network.

An interaction network is a graph whose nodes represent the features and whose edges represent the interactions between them. The width of an edge is proportional to the magnitude of its interaction value. An edge is red if the interaction value is positive and blue if it is negative. The interaction values should be derived from the n-Shapley interaction index (n-SII). Below is an example of an interaction network with an image in the center.

[Figure: example interaction network with an image in the center]
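The edge encoding described above (width from the magnitude of the value, color from its sign, and only the pairs above the diagonal of `second_order_values`) can be sketched with NumPy as follows. This is an illustrative sketch, not shapiq's internal drawing code: the helper `edge_styles`, the linear scaling to a hypothetical `max_width`, and skipping zero-valued pairs are all assumptions.

```python
import numpy as np


def edge_styles(second_order_values, max_width=5.0):
    """Illustrative edge encoding for an interaction network (hypothetical).

    Returns (i, j, width, color) for every pair above the diagonal whose value
    is non-zero. Width grows linearly with |value| (the largest magnitude gets
    `max_width`); positive values are "red", negative values are "blue".
    """
    values = np.asarray(second_order_values, dtype=float)
    n_features = values.shape[0]
    # k=1 excludes the diagonal; the lower triangle is never read.
    rows, cols = np.triu_indices(n_features, k=1)
    pair_values = values[rows, cols]
    max_abs = float(np.max(np.abs(pair_values))) if pair_values.size else 0.0
    edges = []
    for i, j, value in zip(rows, cols, pair_values):
        if value == 0.0:
            continue  # assumption: no interaction means no edge
        width = max_width * (abs(value) / max_abs)
        color = "red" if value > 0 else "blue"
        edges.append((int(i), int(j), float(width), color))
    return edges


matrix = np.array(
    [
        [9.0, 0.4, -0.2],  # the diagonal (9.0) is ignored
        [7.0, 9.0, 0.1],  # 7.0 sits below the diagonal and is ignored too
        [7.0, 7.0, 9.0],
    ]
)
edges = edge_styles(matrix)
# [(0, 1, 5.0, 'red'), (0, 2, 2.5, 'blue'), (1, 2, 1.25, 'red')]
```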
Parameters:
  • interaction_values (Optional[InteractionValues]) – The interaction values as an InteractionValues object.

  • first_order_values (Optional[ndarray[float]]) – The first-order n-SII values of shape (n_features,).

  • second_order_values (Optional[ndarray[float]]) – The second-order n-SII values of shape (n_features, n_features). Only the entries above the diagonal are used; the diagonal and the lower triangle are ignored.

  • feature_names (Optional[list[Any]]) – The feature names used for plotting. If no feature names are provided, the feature indices are used instead. Defaults to None.

  • feature_image_patches (Optional[dict[int, Image]]) – A dictionary containing the image patches to be displayed instead of the feature labels in the network. The keys are the feature indices and the values are the feature images. Defaults to None.

  • feature_image_patches_size (Union[float, dict[int, float], None]) – The size of the feature image patches. If a dictionary is provided, the keys are the feature indices and the values are the sizes of the corresponding image patches. Defaults to 0.2.

  • center_image (Optional[Image]) – The image to be displayed in the center of the network. Defaults to None.

  • center_image_size (Optional[float]) – The size of the center image. Defaults to 0.6.

  • draw_legend (bool) – Whether to draw the legend. Defaults to True.

  • center_text (Optional[str]) – The text to be displayed in the center of the network. Defaults to None.

Return type:

tuple[Figure, Axes]

Returns:

The figure and the axis containing the plot.
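If pairwise values are available as a dictionary keyed by feature-index pairs, they can be arranged into the (n_features, n_features) layout that `second_order_values` expects. The helper `pairs_to_upper_triangular` below is hypothetical and not part of shapiq; it writes every pair into the upper triangle and leaves the diagonal and lower triangle at zero, since those entries are ignored anyway.

```python
import numpy as np


def pairs_to_upper_triangular(
    pair_values: dict[tuple[int, int], float], n_features: int
) -> np.ndarray:
    """Hypothetical helper: arrange pairwise values for `second_order_values`.

    Each value is written above the diagonal at (min(i, j), max(i, j)), so
    (2, 0) and (0, 2) land in the same cell; a later duplicate overwrites
    an earlier one.
    """
    matrix = np.zeros((n_features, n_features), dtype=float)
    for (i, j), value in pair_values.items():
        if i == j:
            raise ValueError(f"({i}, {j}) is not a pair of distinct features")
        matrix[min(i, j), max(i, j)] = value
    return matrix


# Made-up values for three features, only to show the expected shapes.
first_order_values = np.array([0.8, -0.3, 0.5])  # shape (n_features,)
second_order_values = pairs_to_upper_triangular(
    {(0, 1): 0.4, (2, 0): -0.2, (1, 2): 0.1}, n_features=3
)  # shape (n_features, n_features), zeros on and below the diagonal
```

The resulting arrays can then be passed by keyword, for example `network_plot(first_order_values=first_order_values, second_order_values=second_order_values, feature_names=["a", "b", "c"])`; the values here are made up for illustration.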

shapiq.plot.stacked_bar_plot(feature_names, n_shapley_values_pos, n_shapley_values_neg, n_sii_max_order=None, title=None, xlabel=None, ylabel=None)

Plot the n-SII values for a given instance.

This stacked bar plot visualizes the amount of interaction between the features for a given instance. For each feature, the positive n-SII values are stacked above zero and the negative n-SII values below zero, with one segment per order; the colors represent the order of the n-SII values. For a detailed explanation of this plot, see this research paper.

An example of the plot is shown below.

[Figure: example stacked bar plot of n-SII values]
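The stacking can be pictured as running sums over the orders: each order's segment starts where the lower orders end, separately for the positive and the negative side. The sketch below computes those starting offsets for the dictionaries used in the example further down; `stack_offsets` is a hypothetical helper that only illustrates the idea and is not shapiq's implementation.

```python
import numpy as np


def stack_offsets(values_by_order):
    """Hypothetical helper: bottom offset of each order's bar segment.

    Orders are stacked in ascending order; each segment starts at the running
    sum of all lower orders for the same feature.
    """
    offsets = {}
    running = None
    for order in sorted(values_by_order):
        current = np.asarray(values_by_order[order], dtype=float)
        if running is None:
            running = np.zeros_like(current)
        offsets[order] = running.copy()
        running = running + current
    return offsets


# Same values as in the stacked_bar_plot example.
n_shapley_values_pos = {
    1: np.asarray([1, 0, 1.75]),
    2: np.asarray([0.25, 0.5, 0.75]),
    3: np.asarray([0.5, 0.25, 0.25]),
}
n_shapley_values_neg = {
    1: np.asarray([0, -1.5, 0]),
    2: np.asarray([-0.25, -0.5, -0.75]),
    3: np.asarray([-0.5, -0.25, -0.25]),
}
pos_offsets = stack_offsets(n_shapley_values_pos)  # grows upward from 0
neg_offsets = stack_offsets(n_shapley_values_neg)  # grows downward from 0
# pos_offsets[3] is [1.25, 0.5, 2.5]; neg_offsets[3] is [-0.25, -2.0, -0.75]
```

Offsets like these are what matplotlib's `Axes.bar` accepts through its `bottom` argument when drawing one segment per order.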
Parameters:
  • feature_names (list) – The names of the features.

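  • n_shapley_values_pos (dict) – The positive n-SII values, as a dictionary mapping each order to an array with one value per feature.

  • n_shapley_values_neg (dict) – The negative n-SII values, as a dictionary mapping each order to an array with one value per feature.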

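  • n_sii_max_order (Optional[int]) – The maximum order of the n-SII values. Defaults to None.

  • title (Optional[str]) – The title of the plot. Defaults to None.

  • xlabel (Optional[str]) – The label of the x-axis. Defaults to None.

  • ylabel (Optional[str]) – The label of the y-axis. Defaults to None.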

Returns:

A tuple containing the figure and the axis of the plot.

Return type:

tuple[matplotlib.figure.Figure, matplotlib.axes.Axes]
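The parameter descriptions leave the structure of the two dictionaries implicit; the example further down uses one NumPy array of length `len(feature_names)` per order. The hypothetical check below encodes that reading. Requiring both dictionaries to cover the same orders, checking the signs, and returning the highest order as a stand-in for `n_sii_max_order` are all assumptions, not documented shapiq behavior.

```python
import numpy as np


def check_stacked_bar_inputs(feature_names, n_shapley_values_pos, n_shapley_values_neg):
    """Hypothetical sanity check for the inputs of stacked_bar_plot.

    Returns the highest order found (assumed to match n_sii_max_order).
    """
    if set(n_shapley_values_pos) != set(n_shapley_values_neg):
        raise ValueError("both dictionaries must contain the same orders")
    expected_shape = (len(feature_names),)
    for order in n_shapley_values_pos:
        pos = np.asarray(n_shapley_values_pos[order], dtype=float)
        neg = np.asarray(n_shapley_values_neg[order], dtype=float)
        # one value per feature and order, on both sides
        if pos.shape != expected_shape or neg.shape != expected_shape:
            raise ValueError(f"order {order}: expected arrays of shape {expected_shape}")
        # positive dict should hold no negatives, negative dict no positives
        if np.any(pos < 0) or np.any(neg > 0):
            raise ValueError(f"order {order}: a value has the wrong sign")
    return max(n_shapley_values_pos)


feature_names = ["a", "b", "c"]
pos = {1: np.asarray([1, 0, 1.75]), 2: np.asarray([0.25, 0.5, 0.75])}
neg = {1: np.asarray([0, -1.5, 0]), 2: np.asarray([-0.25, -0.5, -0.75])}
max_order = check_stacked_bar_inputs(feature_names, pos, neg)  # 2
```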

Note

To change the figure size, font size, etc., use the [matplotlib parameters](https://matplotlib.org/stable/users/explain/customizing.html).
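The matplotlib parameters mentioned in the note can be applied temporarily with `matplotlib.pyplot.rc_context`, which restores the previous values when the block exits. In this sketch a plain figure stands in for a `stacked_bar_plot` call; whether a given shapiq plot respects a particular setting such as `figure.figsize` is an assumption, since the function may set some properties itself.

```python
import matplotlib

matplotlib.use("Agg")  # non-interactive backend so the sketch runs headless
import matplotlib.pyplot as plt

font_before = plt.rcParams["font.size"]
custom_style = {"figure.figsize": (8, 3), "font.size": 9}

with plt.rc_context(custom_style):
    # A stacked_bar_plot(...) call would go here; a bare figure stands in for
    # it so that this sketch does not depend on shapiq.
    fig, ax = plt.subplots()
    size_inside = tuple(fig.get_size_inches())  # (8.0, 3.0)
    font_inside = plt.rcParams["font.size"]  # 9.0

font_after = plt.rcParams["font.size"]  # restored to font_before
plt.close(fig)
```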

Example

>>> import matplotlib.pyplot as plt
>>> import numpy as np
>>> from shapiq.plot import stacked_bar_plot
>>> n_shapley_values_pos = {
...     1: np.asarray([1, 0, 1.75]),
...     2: np.asarray([0.25, 0.5, 0.75]),
...     3: np.asarray([0.5, 0.25, 0.25]),
... }
>>> n_shapley_values_neg = {
...     1: np.asarray([0, -1.5, 0]),
...     2: np.asarray([-0.25, -0.5, -0.75]),
...     3: np.asarray([-0.5, -0.25, -0.25]),
... }
>>> feature_names = ["a", "b", "c"]
>>> fig, axes = stacked_bar_plot(
...     feature_names=feature_names,
...     n_shapley_values_pos=n_shapley_values_pos,
...     n_shapley_values_neg=n_shapley_values_neg,
... )
>>> plt.show()

Modules

shapiq.plot.network

This module contains the network plots for the shapiq package.

shapiq.plot.stacked_bar

This module contains functions to plot the n-SII stacked bar charts.