"""This module contains the upset plot."""
from __future__ import annotations
from typing import TYPE_CHECKING
import matplotlib.pyplot as plt
from ._config import BLUE, RED
if TYPE_CHECKING:
from collections.abc import Sequence
from matplotlib.figure import Figure
from shapiq.interaction_values import InteractionValues
[docs]
def upset_plot(
    interaction_values: InteractionValues,
    *,
    n_interactions: int = 20,
    feature_names: Sequence[str] | None = None,
    color_matrix: bool = False,
    all_features: bool = True,
    figsize: tuple[float, float] | None = None,
    show: bool = False,
) -> Figure | None:
    """Plot the strongest interactions as an UpSet plot.

    UpSet plots [Lex14]_ can be used to visualize the interactions between features. The plot
    consists of two parts: the upper part shows the interaction values as bars, and the lower part
    shows the interactions as a matrix. Originally, the UpSet plot was introduced by Lex et al.
    (2014) [Lex14]_.

    For a more detailed explanation about the plots, see the references or the original
    `documentation <https://upset.app/>`_.

    An example of this plot is shown below.

    .. image:: /_static/images/upset_plot.png
        :width: 600
        :align: center

    Args:
        interaction_values: The interaction values as an ``InteractionValues`` object.
        n_interactions: The number of top interactions (ranked by absolute value) to plot.
            Defaults to ``20``. Note this number is completely arbitrary and can be adjusted to
            the user's needs. A value of ``0`` or less plots all interactions.
        feature_names: The names of the features, indexed by feature id. Defaults to ``None``.
            If ``None``, the features will be named with their index.
        color_matrix: Whether to color the matrix (red for positive values, blue for negative) or
            not (black). Defaults to ``False``.
        all_features: Whether to plot all ``n_players`` features or only the features that are
            present in the top interactions. Defaults to ``True``.
        figsize: The size of the figure. Defaults to ``None``. If ``None``, the size will be set
            automatically depending on the number of features. Either entry of the tuple may
            itself be ``None`` to have only that dimension chosen automatically.
        show: Whether to show the plot. Defaults to ``False``.

    Returns:
        If ``show`` is ``True``, the plot is displayed and the function returns ``None``.
        Otherwise, it returns the figure.

    Raises:
        ValueError: If there are no interactions to plot.

    References:
        .. [Lex14] Alexander Lex, Nils Gehlenborg, Hendrik Strobelt, Romain Vuillemot, Hanspeter
            Pfister. UpSet: Visualization of Intersecting Sets IEEE Transactions on Visualization
            and Computer Graphics (InfoVis), 20(12): 1983--1992, doi:10.1109/TVCG.2014.2346248,
            2014.

    """
    # prepare data ---------------------------------------------------------------------------------
    all_values = interaction_values.values
    # invert the lookup so that a position in the value array maps back to its interaction
    position_to_interaction: dict[int, tuple[int, ...]] = {
        position: interaction
        for interaction, position in interaction_values.interaction_lookup.items()
    }
    order = abs(all_values).argsort()[::-1]  # largest absolute value first
    if n_interactions > 0:
        order = order[:n_interactions]
    values = all_values[order]
    interactions: list[tuple[int, ...]] = [position_to_interaction[i] for i in order]
    if not interactions:
        # fail before a figure is created instead of crashing later with an open, empty figure
        msg = "Cannot create an upset plot: there are no interactions to plot."
        raise ValueError(msg)

    # prepare feature names ------------------------------------------------------------------------
    # keep the features in ascending order so that the row order of the matrix is deterministic
    # and does not depend on the iteration order of a set
    if all_features:
        features = list(range(interaction_values.n_players))
    else:
        features = sorted({feature for interaction in interactions for feature in interaction})
    n_features = len(features)
    # the first feature is placed in the top row of the matrix
    feature_pos = {feature: n_features - 1 - row for row, feature in enumerate(features)}
    if feature_names is None:
        row_labels = [f"Feature {feature}" for feature in features]
    else:
        row_labels = [feature_names[feature] for feature in features]

    # create figure --------------------------------------------------------------------------------
    height_upper, height_lower = 5, n_features * 0.75
    height = height_upper + height_lower
    if figsize is None:
        figsize = (10, height)
    else:
        # a ``None`` entry means "choose this dimension automatically"
        if figsize[1] is None:
            figsize = (figsize[0], height)
        if figsize[0] is None:
            figsize = (10, figsize[1])
    fig, axes = plt.subplots(
        2,
        1,
        figsize=figsize,
        gridspec_kw={"height_ratios": [height_upper, height_lower]},
        sharex=True,
    )
    ax_bars, ax_matrix = axes

    # plot one column (bar + matrix entries) per interaction ---------------------------------------
    all_rows = list(range(n_features))
    for x_pos, interaction in enumerate(interactions):
        value = values[x_pos]
        color = RED.hex if value >= 0 else BLUE.hex

        # upper part: bar with the value written next to it
        bar = ax_bars.bar(x_pos, value, color=color)
        ax_bars.bar_label(
            bar,
            [f"{value:.2f}"],
            label_type="edge",
            color="black",
            fontsize=12,
            padding=3,
        )

        # lower part: gray dots for every feature as the matrix background
        ax_matrix.plot(
            [x_pos] * n_features,
            all_rows,
            color="lightgray",
            marker="o",
            markersize=15,
            linewidth=0,
        )
        # lower part: connected dots for the features that belong to this interaction
        y_pos = [feature_pos[feature] for feature in interaction]
        ax_matrix.plot(
            [x_pos] * len(y_pos),
            y_pos,
            color=color if color_matrix else "black",
            marker="o",
            markersize=15,
            linewidth=1.5,
        )

    # beautify upper plot --------------------------------------------------------------------------
    lowest, highest = min(values), max(values)
    delta = (highest - lowest) * 0.1
    if delta > 0:
        ax_bars.set_ylim(lowest - delta, highest + delta)
    # otherwise all plotted values are identical; identical limits are invalid, so the
    # limits are left to matplotlib's autoscaling
    ax_bars.set_ylabel("Interaction Value")
    for side in ("top", "right", "bottom"):
        ax_bars.spines[side].set_visible(False)
    ax_bars.axhline(0, color="black", linewidth=0.5)  # add line at 0

    # beautify lower plot --------------------------------------------------------------------------
    ax_matrix.set_ylim(-1, n_features)
    ax_matrix.yaxis.set_ticks(range(n_features))
    # row 0 is the bottom row, which holds the last feature
    ax_matrix.set_yticklabels(row_labels[::-1])
    ax_matrix.tick_params(axis="y", length=0)  # remove y-ticks
    ax_matrix.set_xticks([])  # remove x-axis
    for side in ("top", "right", "bottom", "left"):
        ax_matrix.spines[side].set_visible(False)

    # background shading of every second row
    for row in range(0, n_features, 2):
        ax_matrix.axhspan(row - 0.5, row + 0.5, color="lightgray", alpha=0.25, zorder=0, lw=0)

    # adjust whitespace: no gap between the bars and the matrix
    fig.subplots_adjust(hspace=0.0)

    if not show:
        return fig
    plt.show()
    return None