shapiq.plot.stacked_bar

This module contains functions to plot n-SII stacked bar charts.

Functions

stacked_bar_plot(feature_names, ...[, ...])

Plot the n-SII values for a given instance.

shapiq.plot.stacked_bar.stacked_bar_plot(feature_names, n_shapley_values_pos, n_shapley_values_neg, n_sii_max_order=None, title=None, xlabel=None, ylabel=None)

Plot the n-SII values for a given instance.

This stacked bar plot visualizes how strongly the features of a given instance interact. Each feature is drawn as a stacked bar: its positive n-SII parts are stacked above zero and its negative parts below zero, and the colors encode the interaction order of each part. For a detailed explanation of this plot, see this research paper.
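To make the stacking concrete, here is a minimal sketch in plain matplotlib of how such a chart can be assembled from per-order positive and negative arrays. It is an illustration under assumptions, not shapiq's implementation: the function name `sketch_stacked_bars` is hypothetical, and the `viridis` colormap and legend layout are arbitrary choices.

```python
import matplotlib.pyplot as plt
import numpy as np


def sketch_stacked_bars(feature_names, values_pos, values_neg):
    """Hypothetical sketch: stack per-order parts, positive upward, negative downward."""
    fig, ax = plt.subplots()
    x = np.arange(len(feature_names))
    bottom_pos = np.zeros(len(feature_names))  # running top of the positive stack
    bottom_neg = np.zeros(len(feature_names))  # running bottom of the negative stack
    cmap = plt.get_cmap("viridis")
    orders = sorted(values_pos)
    for i, order in enumerate(orders):
        color = cmap(i / max(len(orders) - 1, 1))  # one color per interaction order
        pos = np.asarray(values_pos[order], dtype=float)
        neg = np.asarray(values_neg[order], dtype=float)
        ax.bar(x, pos, bottom=bottom_pos, color=color, label=f"order {order}")
        ax.bar(x, neg, bottom=bottom_neg, color=color)  # negative heights grow downward
        bottom_pos = bottom_pos + pos
        bottom_neg = bottom_neg + neg
    ax.axhline(0, color="black", linewidth=0.8)  # zero line separating the two stacks
    ax.set_xticks(x)
    ax.set_xticklabels(feature_names)
    ax.legend()
    return fig, ax
```

Keeping separate running offsets for the positive and negative stacks is what lets both signs share one bar per feature without cancelling each other out.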

An example of the plot is shown below.

[Figure: example n-SII stacked bar plot]
Parameters:
  • feature_names (list) – The names of the features.

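  • n_shapley_values_pos (dict) – The positive n-SII values, keyed by interaction order; each value is an array with one entry per feature.

  • n_shapley_values_neg (dict) – The negative n-SII values, in the same format as n_shapley_values_pos.

  • n_sii_max_order (int) – The maximum order of the n-SII values. Optional.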

  • title (str) – The title of the plot.

  • xlabel (str) – The label of the x-axis.

  • ylabel (str) – The label of the y-axis.
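As a hedged sketch of how inputs of this shape could be prepared, the hypothetical helper below assumes interaction scores are given as a dict from feature-index tuples to floats. It shares each score equally among the features of its coalition (the construction behind n-Shapley values) and accumulates positive and negative shares separately per order. shapiq may compute these inputs differently; treat this only as an illustration of the expected `{order: array}` format.

```python
import numpy as np


def split_interactions_by_order_and_sign(interactions, n_features, max_order):
    """Hypothetical helper: build {order: array} dicts of positive and negative parts.

    Each interaction score is divided equally among the features in its
    coalition; positive and negative shares are accumulated separately.
    """
    pos = {k: np.zeros(n_features) for k in range(1, max_order + 1)}
    neg = {k: np.zeros(n_features) for k in range(1, max_order + 1)}
    for coalition, score in interactions.items():
        order = len(coalition)
        if not 1 <= order <= max_order:
            continue  # skip the empty coalition and orders above max_order
        share = score / order
        target = pos if share >= 0 else neg
        for feature in coalition:
            target[order][feature] += share
    return pos, neg
```

The two returned dicts have the same keys and array shapes as `n_shapley_values_pos` and `n_shapley_values_neg` in the example at the end of this page.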

Returns:

A tuple containing the figure and the axis of the plot.

Return type:

tuple[matplotlib.figure.Figure, matplotlib.axes.Axes]

Note

To change the figure size, font size, etc., use the [matplotlib parameters](https://matplotlib.org/stable/users/explain/customizing.html).
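For example, `matplotlib.pyplot.rc_context` applies rcParams temporarily and restores the previous values afterwards. The sketch below assumes the plotting function creates its figure with pyplot defaults; `plt.subplots()` stands in for the `stacked_bar_plot` call so the snippet runs without shapiq.

```python
import matplotlib.pyplot as plt

# Override figure size and base font size only inside this block.
with plt.rc_context({"figure.figsize": (8, 4), "font.size": 12}):
    # Stand-in for: fig, ax = stacked_bar_plot(...)
    fig, ax = plt.subplots()
    ax.set_title("n-SII stacked bar plot")
```

Because the function returns the figure, calling `fig.set_size_inches(...)` on the result is another option when only the size needs to change.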

Example

>>> import matplotlib.pyplot as plt
>>> import numpy as np
>>> from shapiq.plot import stacked_bar_plot
>>> n_shapley_values_pos = {
...     1: np.asarray([1, 0, 1.75]),
...     2: np.asarray([0.25, 0.5, 0.75]),
...     3: np.asarray([0.5, 0.25, 0.25]),
... }
>>> n_shapley_values_neg = {
...     1: np.asarray([0, -1.5, 0]),
...     2: np.asarray([-0.25, -0.5, -0.75]),
...     3: np.asarray([-0.5, -0.25, -0.25]),
... }
>>> feature_names = ["a", "b", "c"]
>>> fig, axes = stacked_bar_plot(
...     feature_names=feature_names,
...     n_shapley_values_pos=n_shapley_values_pos,
...     n_shapley_values_neg=n_shapley_values_neg,
... )
>>> plt.show()