Source code for shapiq.plot.waterfall

"""Wrapper for the waterfall plot from the ``shap`` package.

Note:
    Code and implementation was taken and adapted from the [SHAP package](https://github.com/shap/shap)
    which is licensed under the [MIT license](https://github.com/shap/shap/blob/master/LICENSE).

"""

from __future__ import annotations

from typing import TYPE_CHECKING

import matplotlib.pyplot as plt
import numpy as np
from matplotlib.transforms import ScaledTranslation

from ._config import BLUE, RED
from .utils import abbreviate_feature_names, format_labels, format_value

if TYPE_CHECKING:
    from matplotlib.axes import Axes

    from shapiq.interaction_values import InteractionValues

# Only the ``waterfall_plot`` wrapper is exported; ``_draw_waterfall_plot`` is an internal helper.
__all__ = ["waterfall_plot"]


def _draw_waterfall_plot(
    values: np.ndarray,
    base_values: float,
    feature_names: np.ndarray | list[str],
    *,
    max_display: int = 10,
    show: bool = False,
) -> Axes | None:
    """Draw a SHAP-style waterfall plot on the current pyplot figure.

    The contributions in ``values`` are sorted by decreasing absolute value and drawn as
    horizontal arrows, one per row, stacked so that the top row ends at
    ``base_values + values.sum()`` and the rows below step back towards ``base_values``.
    If there are more entries than ``max_display``, the smallest ones are merged into a single
    ``"<n> other features"`` row at the bottom.

    The function works on pyplot's global state: it draws into ``plt.gcf()`` / ``plt.gca()``,
    resizes the current figure, and calls ``plt.ioff()`` when ``show`` is ``False``.

    Note:
        This function was taken and adapted from the [SHAP package](https://github.com/shap/shap/blob/master/shap/plots/_waterfall.py)
        which is licensed under the [MIT license](https://github.com/shap/shap/blob/master/LICENSE).
        Do not use this function directly, use the ``waterfall_plot`` function instead.

    Args:
        values: The values to plot (one contribution per feature / interaction). ``.sum()`` is
            called on it, so it is expected to be a numpy array.
        base_values: The base value, i.e. the x position at which the waterfall starts.
        feature_names: The names of the features, indexed in the same order as ``values``. If a
            name contains ``"="``, the full string is drawn in gray and only the part after the
            last ``"="`` is drawn on top of it in black.
        max_display: The maximum number of rows to display (including the grouped row, if any).
        show: Whether to show the plot via ``plt.show()``.

    Returns:
        ``None`` if ``show`` is ``True``; otherwise the result of ``plt.gca()`` at the end of
        the function.

    """
    # Turn off interactive plot
    if show is False:
        plt.ioff()

    # init variables we use for tracking the plot locations
    num_features = min(max_display, len(values))
    row_height = 0.5
    # row indices counted from the top: rng[0] is the topmost row, rng[-1] is row 0 (bottom)
    rng = range(num_features - 1, -1, -1)
    # indices of ``values`` sorted by decreasing absolute value
    order = np.argsort(-np.abs(values))
    # per-bar geometry, kept separately for positive (red) and negative (blue) contributions
    pos_lefts = []
    pos_inds = []
    pos_widths = []
    # NOTE(review): the ``*_low`` / ``*_high`` lists are never filled in this function, so the
    # error-bar branches further down never execute.
    pos_low = []
    pos_high = []
    neg_lefts = []
    neg_inds = []
    neg_widths = []
    neg_low = []
    neg_high = []
    # running x position; starts at the final value and is walked back bar by bar
    loc = base_values + values.sum()
    yticklabels = ["" for _ in range(num_features + 1)]

    # size the plot based on how many features we are plotting
    plt.gcf().set_size_inches(8, num_features * row_height + 3.5)

    # see how many individual (vs. grouped at the end) features we are plotting
    num_individual = num_features if num_features == len(values) else num_features - 1

    # compute the locations of the individual features and plot the dashed connecting lines
    for i in range(num_individual):
        sval = values[order[i]]
        # after this subtraction ``loc`` is the x position where this bar starts
        loc -= sval
        if sval >= 0:
            pos_inds.append(rng[i])
            pos_widths.append(sval)
            pos_lefts.append(loc)
        else:
            neg_inds.append(rng[i])
            neg_widths.append(sval)
            neg_lefts.append(loc)
        # vertical dashed connector from this row down to the row below, at the bar's start.
        # NOTE(review): when no grouped row exists, the ``i + 4`` condition skips the connector
        # for the last four rows -- confirm that this offset is intended.
        if num_individual != num_features or i + 4 < num_individual:
            plt.plot(
                [loc, loc],
                [rng[i] - 1 - 0.4, rng[i] + 0.4],
                color="#bbbbbb",
                linestyle="--",
                linewidth=0.5,
                zorder=-1,
            )
        yticklabels[rng[i]] = str(feature_names[order[i]])

    # add a last grouped feature to represent the impact of all the features we didn't show
    if num_features < len(values):
        yticklabels[0] = f"{int(len(values) - num_features + 1)} other features"
        # ``loc`` has been reduced by every displayed value, so ``base_values - loc`` is the
        # negated sum of the values that were not drawn individually
        remaining_impact = base_values - loc
        if remaining_impact < 0:
            pos_inds.append(0)
            pos_widths.append(-remaining_impact)
            pos_lefts.append(loc + remaining_impact)
        else:
            neg_inds.append(0)
            neg_widths.append(-remaining_impact)
            neg_lefts.append(loc + remaining_impact)

    # all bar start and end points; their spread is the data width used for padding below
    points = (
        pos_lefts
        + list(np.array(pos_lefts) + np.array(pos_widths))
        + neg_lefts
        + list(np.array(neg_lefts) + np.array(neg_widths))
    )
    dataw = np.max(points) - np.min(points)

    # draw invisible bars just for sizing the axes
    # bars narrower than 1 (in data units) get extra padding of 10% of the data width
    label_padding = np.array([0.1 * dataw if w < 1 else 0 for w in pos_widths])
    plt.barh(
        pos_inds,
        np.array(pos_widths) + label_padding + 0.02 * dataw,
        left=np.array(pos_lefts) - 0.01 * dataw,
        color=RED.hex,
        alpha=0,
    )
    label_padding = np.array([-0.1 * dataw if -w < 1 else 0 for w in neg_widths])
    plt.barh(
        neg_inds,
        np.array(neg_widths) + label_padding - 0.02 * dataw,
        left=np.array(neg_lefts) + 0.01 * dataw,
        color=BLUE.hex,
        alpha=0,
    )

    # define variable we need for plotting the arrows
    head_length = 0.08
    bar_width = 0.8
    xlen = plt.xlim()[1] - plt.xlim()[0]
    fig = plt.gcf()
    ax = plt.gca()
    # axes extent mapped through the inverse dpi transform, so ``width`` is not in pixels
    bbox = ax.get_window_extent().transformed(fig.dpi_scale_trans.inverted())
    width = bbox.width
    # conversion factor from ``bbox`` width units to x-axis data units
    bbox_to_xscale = xlen / width
    hl_scaled = bbox_to_xscale * head_length
    dpi = fig.dpi
    renderer = fig.canvas.get_renderer()  # type: ignore[union-attr]

    # draw the positive arrows
    for i in range(len(pos_inds)):
        dist = pos_widths[i]
        # shaft length plus head length equals ``dist``, except for bars shorter than the head,
        # where the shaft is clamped to a tiny positive length
        arrow_obj = plt.arrow(
            pos_lefts[i],
            pos_inds[i],
            max(dist - hl_scaled, 0.000001),
            0,
            head_length=min(dist, hl_scaled),
            color=RED.hex,
            width=bar_width,
            head_width=bar_width,
        )

        # never taken as long as ``pos_low`` stays empty (see the note at its definition)
        if pos_low is not None and i < len(pos_low):
            plt.errorbar(
                pos_lefts[i] + pos_widths[i],
                pos_inds[i],
                xerr=np.array([[pos_widths[i] - pos_low[i]], [pos_high[i] - pos_widths[i]]]),
                ecolor=BLUE.hex,
            )

        # first try the value label in white, centered inside the arrow
        txt_obj = plt.text(
            pos_lefts[i] + 0.5 * dist,
            pos_inds[i],
            format_value(pos_widths[i], "%+0.02f"),
            horizontalalignment="center",
            verticalalignment="center",
            color="white",
            fontsize=12,
        )
        text_bbox = txt_obj.get_window_extent(renderer=renderer)
        arrow_bbox = arrow_obj.get_window_extent(renderer=renderer)

        # if the text overflows the arrow then draw it after the arrow
        if text_bbox.width > arrow_bbox.width:
            txt_obj.remove()

            # redraw in the bar color, left-aligned just past the arrow tip
            txt_obj = plt.text(
                pos_lefts[i] + (5 / 72) * bbox_to_xscale + dist,
                pos_inds[i],
                format_value(pos_widths[i], "%+0.02f"),
                horizontalalignment="left",
                verticalalignment="center",
                color=RED.hex,
                fontsize=12,
            )

    # draw the negative arrows
    # (mirror image of the positive case: widths are negative, so the arrows point left)
    for i in range(len(neg_inds)):
        dist = neg_widths[i]

        arrow_obj = plt.arrow(
            neg_lefts[i],
            neg_inds[i],
            -max(-dist - hl_scaled, 0.000001),
            0,
            head_length=min(-dist, hl_scaled),
            color=BLUE.hex,
            width=bar_width,
            head_width=bar_width,
        )

        # never taken as long as ``neg_low`` stays empty (see the note at ``pos_low``)
        if neg_low is not None and i < len(neg_low):
            plt.errorbar(
                neg_lefts[i] + neg_widths[i],
                neg_inds[i],
                xerr=np.array([[neg_widths[i] - neg_low[i]], [neg_high[i] - neg_widths[i]]]),
                ecolor=RED.hex,
            )

        txt_obj = plt.text(
            neg_lefts[i] + 0.5 * dist,
            neg_inds[i],
            format_value(neg_widths[i], "%+0.02f"),
            horizontalalignment="center",
            verticalalignment="center",
            color="white",
            fontsize=12,
        )
        text_bbox = txt_obj.get_window_extent(renderer=renderer)
        arrow_bbox = arrow_obj.get_window_extent(renderer=renderer)

        # if the text overflows the arrow then draw it after the arrow
        if text_bbox.width > arrow_bbox.width:
            txt_obj.remove()

            # redraw in the bar color, right-aligned just past the arrow tip
            plt.text(
                neg_lefts[i] - (5 / 72) * bbox_to_xscale + dist,
                neg_inds[i],
                format_value(neg_widths[i], "%+0.02f"),
                horizontalalignment="right",
                verticalalignment="center",
                color=BLUE.hex,
                fontsize=12,
            )

    # draw the y-ticks twice, once in gray and then again with just the feature names in black
    # The 1e-8 is so matplotlib 3.3 doesn't try and collapse the ticks
    ytick_pos = list(range(num_features)) + list(np.arange(num_features) + 1e-8)
    # first half: full labels; second half: only the text after the last "=" of each label
    plt.yticks(
        ytick_pos,
        yticklabels[:-1] + [label.split("=")[-1] for label in yticklabels[:-1]],
        fontsize=13,
    )

    # Check that the y-ticks are not drawn outside the plot
    # (rendered label width divided by the figure dpi, compared with the figure width in inches)
    max_label_width = (
        max([label.get_window_extent(renderer=renderer).width for label in ax.get_yticklabels()])
        / dpi
    )
    if max_label_width > 0.1 * fig.get_size_inches()[0]:
        # widen the figure so that the widest label takes up 10% of the figure width
        required_width = max_label_width / 0.1
        fig_height = fig.get_size_inches()[1]
        fig.set_size_inches(required_width, fig_height, forward=True)

    # put horizontal lines for each feature row
    for i in range(num_features):
        plt.axhline(i, color="#cccccc", lw=0.5, dashes=(1, 5), zorder=-1)

    # mark the prior expected value and the model prediction
    # the base-value line only spans the bottom ``1 / num_features`` fraction of the axes height
    # NOTE(review): ``num_features == 0`` would divide by zero here -- confirm callers never
    # pass ``max_display=0``.
    plt.axvline(
        base_values,
        0,
        1 / num_features,
        color="#bbbbbb",
        linestyle="--",
        linewidth=0.5,
        zorder=-1,
    )
    fx = base_values + values.sum()
    plt.axvline(fx, 0, 1, color="#bbbbbb", linestyle="--", linewidth=0.5, zorder=-1)

    # clean up the main axis
    plt.gca().xaxis.set_ticks_position("bottom")
    plt.gca().yaxis.set_ticks_position("none")
    plt.gca().spines["right"].set_visible(False)
    plt.gca().spines["top"].set_visible(False)
    plt.gca().spines["left"].set_visible(False)
    ax.tick_params(labelsize=13)

    # draw the E[f(X)] tick mark
    # (on a twin x-axis that is given the same x-limits as the main axis)
    xmin, xmax = ax.get_xlim()
    ax2 = ax.twiny()
    ax2.set_xlim(xmin, xmax)
    ax2.set_xticks(
        [base_values, base_values + 1e-8],
    )  # The 1e-8 is so matplotlib 3.3 doesn't try and collapse the ticks
    ax2.set_xticklabels(
        ["\n$E[f(X)]$", "\n$ = " + format_value(base_values, "%0.03f") + "$"],
        fontsize=12,
        ha="left",
    )
    ax2.spines["right"].set_visible(False)
    ax2.spines["top"].set_visible(False)
    ax2.spines["left"].set_visible(False)

    # draw the f(x) tick mark
    # (on a second twin x-axis, again with the main axis' x-limits)
    ax3 = ax2.twiny()
    ax3.set_xlim(xmin, xmax)
    # The 1e-8 is so matplotlib 3.3 doesn't try and collapse the ticks
    ax3.set_xticks([base_values + values.sum(), base_values + values.sum() + 1e-8])
    ax3.set_xticklabels(
        ["$f(x)$", "$ = " + format_value(fx, "%0.03f") + "$"],
        fontsize=12,
        ha="left",
    )
    # shift the two f(x) labels apart horizontally and gray out the "= value" part
    tick_labels = ax3.xaxis.get_majorticklabels()
    tick_labels[0].set_transform(
        tick_labels[0].get_transform() + ScaledTranslation(-10 / 72.0, 0, fig.dpi_scale_trans),
    )
    tick_labels[1].set_transform(
        tick_labels[1].get_transform() + ScaledTranslation(12 / 72.0, 0, fig.dpi_scale_trans),
    )
    tick_labels[1].set_color("#999999")
    ax3.spines["right"].set_visible(False)
    ax3.spines["top"].set_visible(False)
    ax3.spines["left"].set_visible(False)

    # adjust the position of the E[f(X)] = x.xx label
    tick_labels = ax2.xaxis.get_majorticklabels()
    tick_labels[0].set_transform(
        tick_labels[0].get_transform() + ScaledTranslation(-20 / 72.0, 0, fig.dpi_scale_trans),
    )
    tick_labels[1].set_transform(
        tick_labels[1].get_transform()
        + ScaledTranslation(22 / 72.0, -1 / 72.0, fig.dpi_scale_trans),
    )

    tick_labels[1].set_color("#999999")

    # color the y tick labels that have the feature values as gray
    # (these fall behind the black ones with just the feature name)
    tick_labels = ax.yaxis.get_majorticklabels()
    for i in range(num_features):
        tick_labels[i].set_color("#999999")

    if show:
        plt.show()
        return None
    # NOTE(review): ``plt.gca()`` is evaluated after the two ``twiny()`` calls; if matplotlib
    # makes a newly created twin the current axes, this returns ``ax3`` rather than ``ax`` --
    # confirm which axes callers expect.
    return plt.gca()


def waterfall_plot(
    interaction_values: InteractionValues,
    *,
    feature_names: np.ndarray | list[str] | None = None,
    show: bool = False,
    max_display: int = 10,
    abbreviate: bool = True,
) -> Axes | None:
    """Draws a waterfall plot with the interaction values.

    The waterfall plot shows the individual contributions of the features to the interaction
    values. The plot is based on the waterfall plot from the
    `SHAP <https://github.com/shap/shap>`_ package.

    Args:
        interaction_values: The interaction values as an interaction object.
        feature_names: The names of the features. Defaults to ``None``, in which case the
            player indices are used as names. To display feature values alongside feature
            names, provide strings in the format ``"value=feature"`` (e.g., ``"25=Age"``). The
            plot will show the value in gray and the feature name in black.
        show: Whether to show the plot. Defaults to ``False``.
        max_display: The maximum number of interactions to display. Defaults to ``10``.
        abbreviate: Whether to abbreviate the feature names. Defaults to ``True``. Only applied
            when ``feature_names`` is provided.

    Returns:
        The plot if ``show`` is ``False``.

    Raises:
        ValueError: If ``interaction_values`` holds no interaction of order one or higher, i.e.
            there is nothing to draw besides the baseline.

    """
    if feature_names is None:
        feature_mapping = {i: str(i) for i in range(interaction_values.n_players)}
    else:
        if abbreviate:
            feature_names = abbreviate_feature_names(feature_names)
        feature_mapping = {i: feature_names[i] for i in range(interaction_values.n_players)}

    # Collect one label and one contribution per non-empty interaction. The empty interaction
    # is skipped here; the baseline is passed to the drawing routine separately below.
    labels = []
    contributions = []
    for feature_tuple, value in interaction_values.dict_values.items():
        if len(feature_tuple) > 0:
            labels.append(format_labels(feature_mapping, feature_tuple))
            contributions.append(float(value))

    # Fail early with a clear message instead of an opaque indexing error further down.
    if not labels:
        msg = "The interaction values contain no non-empty interactions to plot."
        raise ValueError(msg)

    return _draw_waterfall_plot(
        np.array(contributions, dtype=float),
        float(interaction_values.baseline_value),
        np.array(labels, dtype=object),
        max_display=max_display,
        show=show,
    )