Source code for shapiq.plot.stacked_bar

"""This module contains functions to plot the n_sii stacked bar charts."""

from __future__ import annotations

import contextlib
from copy import deepcopy
from typing import TYPE_CHECKING, Any

import matplotlib.pyplot as plt
import numpy as np
from matplotlib.patches import Patch

from ._config import COLORS_K_SII

# Public API of this module: only the stacked bar plotting function is exported.
__all__ = ["stacked_bar_plot"]


if TYPE_CHECKING:
    from matplotlib.axes import Axes
    from matplotlib.figure import Figure

    from shapiq.interaction_values import InteractionValues


def stacked_bar_plot(
    interaction_values: InteractionValues,
    *,
    feature_names: list[Any] | None = None,
    max_order: int | None = None,
    title: str | None = None,
    xlabel: str | None = None,
    ylabel: str | None = None,
    show: bool = False,
) -> tuple[Figure, Axes] | None:
    """Plot interaction scores as a stacked bar chart.

    This stacked bar plot can be used to visualize the amount of interaction between the features
    for a given instance. For every feature, the interaction values are plotted as stacked bars:
    positive contributions are stacked upwards from zero and negative contributions downwards.
    The colors represent the order of the interaction values. For a detailed explanation of this
    plot, we refer to Bordt and von Luxburg (2023) [Bor23stk]_.

    An example of the plot is shown below.

    .. image:: /_static/stacked_bar_exampl.png
        :width: 400
        :align: center

    Args:
        interaction_values (InteractionValues): n-SII values as InteractionValues object.
        feature_names: The feature names used for plotting. Must contain exactly one name per
            feature. If no feature names are provided, the 1-based feature indices are used
            instead. Defaults to ``None``.
        max_order (int): The maximum interaction order to plot. Defaults to ``None``, in which
            case ``interaction_values.max_order`` is used.
        title (str): The title of the plot. No title is set if ``None``.
        xlabel (str): The label of the x-axis. Defaults to ``"features"`` if ``None``.
        ylabel (str): The label of the y-axis. Defaults to ``"SI values"`` if ``None``.
        show (bool): Whether to show the plot. Defaults to ``False``.

    Returns:
        tuple[matplotlib.figure.Figure, matplotlib.axes.Axes] | None: A tuple containing the
        figure and the axis of the plot if ``show`` is ``False``, otherwise ``None``.

    Raises:
        ValueError: If ``feature_names`` does not contain exactly one name per feature.

    Note:
        To change the figure size, font size, etc., use the
        `matplotlib parameters <https://matplotlib.org/stable/users/explain/customizing.html>`_.

    Example:
        >>> import numpy as np
        >>> import matplotlib.pyplot as plt
        >>> from shapiq.interaction_values import InteractionValues
        >>> from shapiq.plot import stacked_bar_plot
        >>> interaction_values = InteractionValues(
        ...    values=np.array([1, -1.5, 1.75, 0.25, -0.5, 0.75, 0.2]),
        ...    index="SII",
        ...    min_order=1,
        ...    max_order=3,
        ...    n_players=3,
        ...    baseline_value=0
        ... )
        >>> feature_names = ["a", "b", "c"]
        >>> fig, axes = stacked_bar_plot(
        ...     interaction_values=interaction_values,
        ...     feature_names=feature_names,
        ... )
        >>> plt.show()

    References:
        .. [Bor23stk] Bordt, M., and von Luxburg, U. (2023). From Shapley Values to Generalized
            Additive Models and back. Proceedings of The 26th International Conference on
            Artificial Intelligence and Statistics, PMLR 206:709-745.
            url: https://proceedings.mlr.press/v206/bordt23a.html

    """
    # sanitize inputs
    if max_order is None:
        max_order = interaction_values.max_order
    orders = range(1, max_order + 1)

    # Reduce each order's interaction tensor to one value per feature by summing over every axis
    # but the first. Positive and negative parts are kept apart so they can be stacked upwards
    # and downwards from zero, respectively.
    values_pos = np.array(
        [
            interaction_values.get_n_order_values(order)
            .clip(min=0)
            .sum(axis=tuple(range(1, order)))
            for order in orders
        ],
    )
    values_neg = np.array(
        [
            interaction_values.get_n_order_values(order)
            .clip(max=0)
            .sum(axis=tuple(range(1, order)))
            for order in orders
        ],
    )

    # get the number of features and the feature names
    n_features = len(values_pos[0])
    if feature_names is None:
        feature_names = [str(i + 1) for i in range(n_features)]
    x_tick_labels = list(feature_names)
    # Validate before creating the figure so a bad call does not leave an open figure behind.
    if len(x_tick_labels) != n_features:
        msg = (
            f"Expected {n_features} feature names (one per feature), "
            f"but got {len(x_tick_labels)}."
        )
        raise ValueError(msg)

    # Stacking geometry. Row k of each array belongs to interaction order k + 1:
    # - positive segment k spans [sum_{j<k} pos_j, sum_{j<=k} pos_j]
    # - negative segment k spans [sum_{j<=k} neg_j, sum_{j<k} neg_j]
    stacked_pos = np.cumsum(values_pos, axis=0)
    bottoms_pos = np.zeros_like(stacked_pos)
    bottoms_pos[1:] = stacked_pos[:-1]
    bottoms_neg = np.cumsum(values_neg, axis=0)

    fig, axis = plt.subplots()
    x = np.arange(n_features)

    # plot the bar segments
    for order_idx in range(len(values_pos)):
        color = COLORS_K_SII[order_idx]
        axis.bar(x, height=values_pos[order_idx], bottom=bottoms_pos[order_idx], color=color)
        axis.bar(x, height=abs(values_neg[order_idx]), bottom=bottoms_neg[order_idx], color=color)
    axis.axhline(y=0, color="black", linestyle="solid", linewidth=0.5)

    # add a legend to the plots
    legend_elements = [
        Patch(facecolor=COLORS_K_SII[order], edgecolor="black", label=f"Order {order + 1}")
        for order in range(max_order)
    ]
    axis.legend(handles=legend_elements, loc="upper center", ncol=min(max_order, 4))

    axis.set_xticks(x)
    axis.set_xticklabels(x_tick_labels, rotation=45, ha="right")
    axis.set_xlim(-0.5, n_features - 0.5)

    # The fully stacked bars are the extremes, because every positive part is >= 0 and every
    # negative part is <= 0. ``initial=0.0`` keeps zero inside the range and handles no features.
    y_min = np.min(bottoms_neg[-1], initial=0.0)
    y_max = np.max(stacked_pos[-1], initial=0.0)
    y_span = abs(y_max - y_min)
    if y_span == 0:
        # All-zero data would give identical (singular) y-limits; fall back to a unit span.
        y_span = 1.0
    # The extra headroom at the top leaves room for the legend.
    axis.set_ylim(y_min - y_span * 0.02, y_max + y_span * 0.3)

    # set title and labels if not provided
    if title is not None:
        axis.set_title(title)
    axis.set_xlabel("features" if xlabel is None else xlabel)
    axis.set_ylabel("SI values" if ylabel is None else ylabel)

    fig.tight_layout()
    if not show:
        return fig, axis
    plt.show()
    return None