"""Wrapper for the force plot from the ``shap`` package.
Note:
The code and implementation were taken and adapted from the [SHAP package](https://github.com/shap/shap)
which is licensed under the [MIT license](https://github.com/shap/shap/blob/master/LICENSE).
"""
from __future__ import annotations
from typing import TYPE_CHECKING
import matplotlib.pyplot as plt
import numpy as np
from matplotlib import lines
from matplotlib.colors import LinearSegmentedColormap
from matplotlib.font_manager import FontProperties
from matplotlib.patches import PathPatch, Polygon
from matplotlib.path import Path
from .utils import abbreviate_feature_names, format_labels
if TYPE_CHECKING:
from matplotlib.axes import Axes
from matplotlib.figure import Figure
from shapiq.interaction_values import InteractionValues
__all__ = ["force_plot"]
def _create_bars(
out_value: float,
features: np.ndarray,
feature_type: str,
width_separators: float,
width_bar: float,
) -> tuple[list, list]:
rectangle_list = []
separator_list = []
pre_val = out_value
for index, feature_iteration in zip(range(len(features)), features, strict=False):
if feature_type == "positive":
left_bound = float(feature_iteration[0])
right_bound = pre_val
pre_val = left_bound
separator_indent = np.abs(width_separators)
separator_pos = left_bound
colors = ["#FF0D57", "#FFC3D5"]
else:
left_bound = pre_val
right_bound = float(feature_iteration[0])
pre_val = right_bound
separator_indent = -np.abs(width_separators)
separator_pos = right_bound
colors = ["#1E88E5", "#D1E6FA"]
# Create rectangle
if index == 0:
if feature_type == "positive":
points_rectangle = [
[left_bound, 0],
[right_bound, 0],
[right_bound, width_bar],
[left_bound, width_bar],
[left_bound + separator_indent, (width_bar / 2)],
]
else:
points_rectangle = [
[right_bound, 0],
[left_bound, 0],
[left_bound, width_bar],
[right_bound, width_bar],
[right_bound + separator_indent, (width_bar / 2)],
]
else:
points_rectangle = [
[left_bound, 0],
[right_bound, 0],
[right_bound + separator_indent * 0.90, (width_bar / 2)],
[right_bound, width_bar],
[left_bound, width_bar],
[left_bound + separator_indent * 0.90, (width_bar / 2)],
]
line = Polygon(
points_rectangle,
closed=True,
fill=True,
facecolor=colors[0],
linewidth=0,
)
rectangle_list += [line]
# Create separator
points_separator = [
[separator_pos, 0],
[separator_pos + separator_indent, (width_bar / 2)],
[separator_pos, width_bar],
]
line = Polygon(points_separator, closed=False, fill=None, edgecolor=colors[1], lw=3)
separator_list += [line]
return rectangle_list, separator_list
def _add_labels(
    fig: Figure,
    ax: Axes,
    out_value: float,
    features: np.ndarray,
    feature_type: str,
    offset_text: float,
    total_effect: float = 0,
    min_perc: float = 0.05,
    text_rotation: float = 0,
) -> tuple[Figure, Axes]:
    """Add the feature labels, their connector lines and the shading to the plot.

    Args:
        fig: Figure of the plot
        ax: Axes of the plot
        out_value: output value
        features: The values and names of the features
        feature_type: Indicating whether positive or negative features
        offset_text: value to offset name of the features
        total_effect: Total value of all features. Used to filter out features that do not
            contribute at least ``min_perc`` to the total effect. Defaults to 0 indicating
            that all features are shown.
        min_perc: minimal percentage of the total effect that a feature must contribute to be
            shown. Defaults to 0.05.
        text_rotation: Degree the text should be rotated. Defaults to 0.

    Returns:
        The ``fig`` and ``ax`` that were passed in.
    """
    is_positive = feature_type == "positive"
    # Define variables specific to positive and negative effect features
    if is_positive:
        text_color = "#FF0D57"
        alignment = "right"
        sign = 1
        shading_colors = np.array([(255, 13, 87), (255, 255, 255)]) / 255.0
    else:
        text_color = "#1E88E5"
        alignment = "left"
        sign = -1
        shading_colors = np.array([(30, 136, 229), (255, 255, 255)]) / 255.0
    va_alignment = "top" if text_rotation != 0 else "baseline"

    def _draw_connector(x_coords: list, y_coords: list) -> None:
        """Draw one unclipped, semi-transparent connector line onto ``ax``."""
        x, y = np.array([x_coords, y_coords])
        connector = lines.Line2D(x, y, lw=1.0, alpha=0.5, color=text_color)
        connector.set_clip_on(False)
        ax.add_line(connector)

    start_text = out_value
    pre_val = out_value
    box_end = out_value

    # Draw initial line
    if is_positive:
        _draw_connector([pre_val, pre_val], [0, -0.18])

    # A total effect of zero disables the filter. Guarding here avoids dividing by zero.
    abs_total = np.abs(total_effect)
    for feature in features:
        val = float(feature[0])
        # Stop labelling at the first feature that contributes less than ``min_perc`` of the
        # total effect.
        if abs_total > 0 and np.abs(val - pre_val) / abs_total < min_perc:
            break

        # Draw the label on the axes that was passed in.
        text_out_val = ax.text(
            start_text - sign * offset_text,
            -0.15,
            feature[1],
            fontsize=12,
            color=text_color,
            horizontalalignment=alignment,
            va=va_alignment,
            rotation=text_rotation,
        )
        text_out_val.set_bbox({"facecolor": "none", "edgecolor": "none"})

        # We need to draw the plot to be able to get the size of the text box
        fig.canvas.draw()
        box_size = text_out_val.get_bbox_patch().get_extents().transformed(ax.transData.inverted())  # type: ignore[union-attr]
        box_points = box_size.get_points()
        box_end_ = box_points[0][0] if is_positive else box_points[1][0]

        # Create end line
        if (sign * box_end_) > (sign * val):
            # The text box ends before the bar does: a straight vertical line is enough.
            _draw_connector([val, val], [0, -0.18])
            start_text = val
            box_end = val
        else:
            # The text box reaches past the bar: bend the line out to the end of the box.
            box_end = box_end_ - sign * offset_text
            _draw_connector([val, box_end, box_end], [0, -0.08, -0.18])
            start_text = box_end

        # Update previous value
        pre_val = val

    # Create the (invisible) outline that is used to clip the shading
    extent_shading = (out_value, box_end, 0, -0.31)
    outline = Path(
        [
            [out_value, 0],
            [pre_val, 0],
            [box_end, -0.08],
            [box_end, -0.2],
            [out_value, -0.2],
            [out_value, 0],
        ],
    )
    patch = PathPatch(outline, facecolor="none", edgecolor="none")
    ax.add_patch(patch)

    # Extend axis if needed
    lower_lim, upper_lim = ax.get_xlim()
    if box_end < lower_lim:
        ax.set_xlim(box_end, upper_lim)
    if box_end > upper_lim:
        ax.set_xlim(lower_lim, box_end)

    # Create shading
    cm = LinearSegmentedColormap.from_list("cm", shading_colors)
    _, z2 = np.meshgrid(np.linspace(0, 10), np.linspace(-10, 10))
    im = ax.imshow(
        z2,
        interpolation="quadric",
        cmap=cm,
        vmax=0.01,
        alpha=0.3,
        origin="lower",
        extent=extent_shading,
        clip_path=patch,
        clip_on=True,
        aspect="auto",
    )
    im.set_clip_path(patch)

    return fig, ax
def _add_output_element(out_name: str, out_value: float, ax: Axes) -> None:
    """Add the grey line indicating the output value and its labels to the plot.

    Args:
        out_name: Name of the output value
        out_value: Value of the output
        ax: Axis of the plot

    Returns: Nothing
    """
    # Add output value
    x, y = np.array([[out_value, out_value], [0, 0.24]])
    line = lines.Line2D(x, y, lw=2.0, color="#F2F2F2")
    line.set_clip_on(False)
    ax.add_line(line)

    font = FontProperties()
    font.set_weight("bold")
    # Draw on the axes that was passed in instead of whatever axes is current.
    text_out_val = ax.text(
        out_value,
        0.25,
        f"{out_value:.2f}",
        fontproperties=font,
        fontsize=14,
        horizontalalignment="center",
    )
    text_out_val.set_bbox({"facecolor": "white", "edgecolor": "white"})

    text_out_name = ax.text(
        out_value,
        0.33,
        out_name,
        fontsize=12,
        alpha=0.5,
        horizontalalignment="center",
    )
    text_out_name.set_bbox({"facecolor": "white", "edgecolor": "white"})
def _add_base_value(base_value: float, ax: Axes) -> None:
    """Mark the base value with a short grey tick and a "base value" caption.

    Args:
        base_value: the base value of the game
        ax: Axes of the plot

    Returns: None
    """
    # Row 0 holds the x coordinates, row 1 the y coordinates of the tick.
    tick_coords = np.array([[base_value] * 2, [0.13, 0.25]])
    tick = lines.Line2D(tick_coords[0], tick_coords[1], lw=2.0, color="#F2F2F2")
    tick.set_clip_on(False)
    ax.add_line(tick)

    caption = ax.text(
        base_value,
        0.25,
        "base value",
        fontsize=12,
        alpha=1,
        horizontalalignment="center",
    )
    # White box behind the caption.
    white_box = {"facecolor": "white", "edgecolor": "white"}
    caption.set_bbox(white_box)
def update_axis_limits(
    ax: Axes,
    total_pos: float,
    pos_features: np.ndarray,
    total_neg: float,
    neg_features: np.ndarray,
    base_value: float,
    out_value: float,
) -> None:
    """Adjust the axis limits and the tick/spine styling of the plot according to values.

    Args:
        ax: Axes of the plot
        total_pos: value of the total positive features
        pos_features: values and names of the positive features
        total_neg: value of the total negative features
        neg_features: values and names of the negative features
        base_value: the base value of the game
        out_value: the output value

    Returns: None
    """
    ax.set_ylim(-0.5, 0.15)

    # Pad both sides by 20% of the larger of the two totals.
    padding = np.max([np.abs(total_pos) * 0.2, np.abs(total_neg) * 0.2])
    if len(pos_features) > 0:
        min_x = min(np.min(pos_features[:, 0].astype(float)), base_value) - padding
    else:
        min_x = out_value - padding
    if len(neg_features) > 0:
        max_x = max(np.max(neg_features[:, 0].astype(float)), base_value) + padding
    else:
        max_x = out_value + padding
    ax.set_xlim(min_x, max_x)

    # Style the axes that was passed in rather than relying on it being the current axes.
    ax.tick_params(
        top=True,
        bottom=False,
        left=False,
        right=False,
        labelleft=False,
        labeltop=True,
        labelbottom=False,
    )
    ax.locator_params(axis="x", nbins=12)
    # Only the top spine stays visible.
    for key, spine in ax.spines.items():
        if key != "top":
            spine.set_visible(False)
def _split_features(
interaction_dictionary: dict[tuple[int, ...], float],
feature_to_names: dict[int, str],
out_value: float,
) -> tuple[np.ndarray, np.ndarray, float, float]:
"""Splits the features into positive and negative values.
Args:
interaction_dictionary: Dictionary containing the interaction values mapping from
feature indices to their values.
feature_to_names: Dictionary mapping feature indices to feature names.
out_value: The output value.
Returns:
tuple: A tuple containing the positive features, negative features, total positive value,
and total negative value.
"""
# split features into positive and negative values
pos_features, neg_features = [], []
for coaltion, value in interaction_dictionary.items():
if len(coaltion) == 0:
continue
label = format_labels(feature_to_names, coaltion)
if value >= 0:
pos_features.append([str(value), label])
elif value < 0:
neg_features.append([str(value), label])
# sort feature values descending according to (absolute) features values
pos_features = sorted(pos_features, key=lambda x: float(x[0]), reverse=True)
neg_features = sorted(neg_features, key=lambda x: float(x[0]), reverse=False)
pos_features = np.array(pos_features, dtype=object)
neg_features = np.array(neg_features, dtype=object)
# convert negative feature values to plot values
neg_val = out_value
for i in neg_features:
val = float(i[0])
neg_val = neg_val + np.abs(val)
i[0] = neg_val
if len(neg_features) > 0:
total_neg = np.max(neg_features[:, 0].astype(float)) - np.min(
neg_features[:, 0].astype(float),
)
else:
total_neg = 0
# convert positive feature values to plot values
pos_val = out_value
for i in pos_features:
val = float(i[0])
pos_val = pos_val - np.abs(val)
i[0] = pos_val
if len(pos_features) > 0:
total_pos = np.max(pos_features[:, 0].astype(float)) - np.min(
pos_features[:, 0].astype(float),
)
else:
total_pos = 0
return pos_features, neg_features, total_pos, total_neg
def _add_bars(
    ax: Axes,
    out_value: float,
    pos_features: np.ndarray,
    neg_features: np.ndarray,
) -> None:
    """Create the bars of both sides and attach them to the axes.

    Args:
        ax: Axes of the plot
        out_value: grand total value
        pos_features: positive features
        neg_features: negative features
    """
    bar_height = 0.1
    x_low, x_high = ax.get_xlim()
    # The separator notch is 1/200 of the currently visible x range.
    notch_width = (x_high - x_low) / 200

    # Negative shap values are drawn first, then the positive ones.
    sides = (("negative", neg_features), ("positive", pos_features))
    for side, side_features in sides:
        bars, separators = _create_bars(
            out_value,
            side_features,
            side,
            notch_width,
            bar_height,
        )
        for patch in (*bars, *separators):
            ax.add_patch(patch)
def draw_higher_lower_element(
    out_value: float,
    offset_text: float,
) -> None:
    """Draw the "higher"/"lower" captions and their arrows around the output value.

    The text is placed with ``plt.text``, i.e. onto the current axes.

    Args:
        out_value: x position around which the indicator is centred.
        offset_text: Horizontal distance of the two captions from ``out_value``.
    """
    red, blue = "#FF0D57", "#1E88E5"
    # (x, y, text, colour, horizontal alignment) in drawing order
    elements = (
        (out_value - offset_text, 0.35, "higher", red, "right"),
        (out_value + offset_text, 0.35, "lower", blue, "left"),
        (out_value, 0.34, r"$\leftarrow$", blue, "center"),
        (out_value, 0.36, r"$\rightarrow$", red, "center"),
    )
    for x_pos, y_pos, caption, color, h_align in elements:
        plt.text(
            x_pos,
            y_pos,
            caption,
            fontsize=13,
            color=color,
            horizontalalignment=h_align,
        )
def _draw_force_plot(
    interaction_value: InteractionValues,
    feature_names: np.ndarray,
    *,
    figsize: tuple[int, int],
    min_perc: float = 0.05,
    draw_higher_lower: bool = True,
) -> Figure:
    """Draw the force plot.

    Note:
        The functionality was taken and adapted from the [SHAP package](https://github.com/shap/shap/blob/master/shap/plots/_force.py)
        which is licensed under the [MIT license](https://github.com/shap/shap/blob/master/LICENSE).
        Do not use this function directly, use the ``force_plot`` function instead.

    Args:
        interaction_value: The interaction values to be plotted.
        feature_names: The names of the features.
        figsize: The size of the figure.
        min_perc: minimal percentage of the total effect that a feature must contribute to be shown.
            Defaults to ``0.05``.
        draw_higher_lower: Whether to draw the higher and lower indicator. Defaults to ``True``.

    Returns:
        The figure of the plot.
    """
    # turn off interactive plot (NOTE: this changes matplotlib's global state)
    plt.ioff()

    # compute overall metrics
    base_value = float(interaction_value.baseline_value)
    # presumably ``values`` includes the baseline entry, making this the model output -- confirm
    out_value = np.sum(interaction_value.values)

    # split features into positive and negative values
    features_to_names = {i: str(name) for i, name in enumerate(feature_names)}
    pos_features, neg_features, total_pos, total_neg = _split_features(
        interaction_value.dict_values,
        features_to_names,
        out_value,
    )

    # the text offset is 4% of the combined width of both sides
    total_effect = np.abs(total_neg) + np.abs(total_pos)
    offset_text = total_effect * 0.04

    fig, ax = plt.subplots(figsize=figsize)

    # compute axis limit
    update_axis_limits(
        ax, total_pos, pos_features, total_neg, neg_features, base_value, out_value
    )

    # add the bars to the plot
    _add_bars(ax, out_value, pos_features, neg_features)

    # add labels, negative side first
    for side_features, side in ((neg_features, "negative"), (pos_features, "positive")):
        fig, ax = _add_labels(
            fig,
            ax,
            out_value,
            side_features,
            side,
            offset_text,
            total_effect,
            min_perc=min_perc,
            text_rotation=0,
        )

    # add higher and lower element
    if draw_higher_lower:
        draw_higher_lower_element(out_value, offset_text)

    # add label for base value
    _add_base_value(base_value, ax)

    # add output label (no name is shown for the output)
    _add_output_element("", out_value, ax)

    # fix the whitespace around the plot and return the figure created above
    fig.tight_layout()
    return fig
# Public entry point of this module (the only name listed in ``__all__``).
def force_plot(
    interaction_values: InteractionValues,
    *,
    feature_names: np.ndarray | list[str] | None = None,
    abbreviate: bool = True,
    show: bool = False,
    figsize: tuple[int, int] = (15, 4),
    draw_higher_lower: bool = True,
    contribution_threshold: float = 0.05,
) -> Figure | None:
    """Draws a force plot for the given interaction values.

    Args:
        interaction_values: The ``InteractionValues`` to be plotted.
        feature_names: The names of the features. If ``None``, the features are named by their
            index.
        abbreviate: Whether to abbreviate the feature names. Defaults to ``True``.
        show: Whether to show or return the plot. Defaults to ``False`` and returns the plot.
        figsize: The size of the figure. Defaults to ``(15, 4)``.
        draw_higher_lower: Whether to draw the higher and lower indicator. Defaults to ``True``.
        contribution_threshold: Define the minimum percentage of the total effect that a feature
            must contribute to be shown in the plot. Defaults to 0.05.

    Returns:
        The figure of the plot if ``show`` is ``False``. Otherwise the plot is shown and
        ``None`` is returned.
    """
    names = feature_names
    if names is None:
        # Fall back to the player indices as names.
        names = [str(player) for player in range(interaction_values.n_players)]
    if abbreviate:
        names = abbreviate_feature_names(names)

    figure = _draw_force_plot(
        interaction_values,
        np.array(names),
        figsize=figsize,
        draw_higher_lower=draw_higher_lower,
        min_perc=contribution_threshold,
    )
    if show:
        plt.show()
        return None
    return figure