# Source code for shapiq.plot.sentence

"""This module contains the sentence plot."""

from __future__ import annotations

import contextlib
from typing import TYPE_CHECKING

from matplotlib import pyplot as plt
from matplotlib.font_manager import FontProperties
from matplotlib.patches import FancyBboxPatch, PathPatch
from matplotlib.textpath import TextPath

from ._config import BLUE, RED

if TYPE_CHECKING:
    from collections.abc import Sequence

    from matplotlib.axes import Axes
    from matplotlib.figure import Figure

    from shapiq.interaction_values import InteractionValues


def _get_color_and_alpha(max_value: float, value: float) -> tuple[str, float]:
    """Gets the color and alpha value for an interaction value.

    Args:
        max_value: The scale that maps to full opacity. Only its magnitude matters.
        value: The interaction value to color.

    Returns:
        A tuple ``(color, alpha)``. ``color`` is ``RED.hex`` for non-negative values and
        ``BLUE.hex`` for negative values. ``alpha`` is ``|value / max_value|`` clipped to at
        most ``1.0``. If ``max_value`` is zero, ``alpha`` is ``0.0`` for a zero value and
        ``1.0`` otherwise.
    """
    color = RED.hex if value >= 0 else BLUE.hex
    if max_value == 0:
        # Avoid a ZeroDivisionError: any non-zero value exceeds a zero scale and saturates
        # (exactly like an over-scale value is clipped below), a zero value stays transparent.
        ratio = 0.0 if value == 0 else 1.0
    else:
        ratio = min(abs(value / max_value), 1.0)  # make ratio at most 1
    return color, ratio


def sentence_plot(
    interaction_values: InteractionValues,
    words: Sequence[str],
    *,
    connected_words: Sequence[tuple[str, str]] | None = None,
    chars_per_line: int = 35,
    font_family: str = "sans-serif",
    show: bool = False,
    max_score: float | None = None,
) -> tuple[Figure, Axes] | None:
    """Plots the first order effects (attributions) of a sentence or paragraph.

    An example of the plot is shown below.

    .. image:: /_static/sentence_plot_example.png
        :width: 400
        :align: center

    Args:
        interaction_values: The interaction values as an interaction object. The first-order
            value of player ``i`` is used as the attribution of ``words[i]``.
        words: The words of the sentence or a paragraph of text. Leading and trailing
            whitespace of each word is stripped before plotting.
        connected_words: A list of tuples with connected words. Defaults to ``None``. If two
            adjacent 'words' are connected, the plot will not add a space between them (e.g.,
            the parts "enjoy" and "able" would be connected to "enjoyable" with potentially
            different attributions for each part).
        chars_per_line: The maximum number of characters per line. Defaults to ``35``. Text
            beyond this limit is wrapped to the next line; a word that alone exceeds the limit
            is placed on its own line. A connected word that is wrapped to a new line
            receives a '-' in front of it.
        font_family: The font family used for the plot. Defaults to ``sans-serif``. For a list
            of available font families, see the matplotlib documentation of
            ``matplotlib.font_manager.FontProperties``. Note the plot is optimized for
            sans-serif.
        show: Whether to show the plot. Defaults to ``False``.
        max_score: The maximum score for the attributions to scale the colors and alpha
            values. This is useful if you want to compare the attributions of different
            sentences and both plots should have the same color scale. Defaults to ``None``,
            in which case the maximum absolute attribution is used.

    Returns:
        If ``show`` is ``True``, the function returns ``None``. Otherwise, it returns a tuple
        with the figure and the axis of the plot.

    Raises:
        ValueError: If ``words`` is empty or ``max_score`` is zero.

    Example:
        >>> import numpy as np
        >>> from matplotlib import pyplot as plt
        >>> from shapiq.interaction_values import InteractionValues
        >>> from shapiq.plot import sentence_plot
        >>> iv = InteractionValues(
        ...     values=np.array([0.45, 0.01, 0.67, -0.2, -0.05, 0.7, 0.1, -0.04, 0.56, 0.7]),
        ...     index="SV",
        ...     n_players=10,
        ...     min_order=1,
        ...     max_order=1,
        ...     estimated=False,
        ...     baseline_value=0.0,
        ... )
        >>> words = ["I", "really", "enjoy", "working", "with", "Shapley", "values", "in", "Python", "!"]
        >>> connected_words = [("Shapley", "values")]
        >>> fig, ax = sentence_plot(
        ...     iv, words, connected_words=connected_words, show=False, chars_per_line=100
        ... )
        >>> plt.show()

    .. image:: /_static/sentence_plot_connected_example.png
        :width: 300
        :align: center
    """
    # set all the size parameters
    fontsize = 20
    word_spacing = 15
    line_spacing = 10
    height_padding = 5
    width_padding = 5

    # clean the input
    words = [word.strip() for word in words]
    if not words:
        msg = "`words` must contain at least one word."
        raise ValueError(msg)
    # a set gives O(1) membership tests; tuple() also accepts pairs passed as lists
    connected_pairs = (
        set() if connected_words is None else {tuple(pair) for pair in connected_words}
    )
    attributions = [interaction_values[(i,)] for i in range(len(words))]

    # get the maximum score used to scale the colors and alpha values
    if max_score is None:
        max_abs_attribution = max(abs(value) for value in attributions)
        if max_abs_attribution == 0:
            # all attributions are zero: any positive scale renders them fully transparent
            # and avoids dividing by zero when computing the alpha values
            max_abs_attribution = 1.0
    else:
        if max_score == 0:
            msg = "`max_score` must be non-zero."
            raise ValueError(msg)
        max_abs_attribution = max_score

    # the font is identical for every word, so build it only once
    font = FontProperties(family=font_family, style="normal", size=fontsize, weight="normal")

    # create plot
    fig, ax = plt.subplots()
    max_x_pos = 0
    x_pos, y_pos = word_spacing, 0
    chars_in_line = 0
    for i, (_word, attribution) in enumerate(zip(words, attributions, strict=False)):
        word = _word

        # check if the word is connected to its neighbours; the first word has no
        # predecessor (words[-1] would wrap around to the last word)
        is_word_connected_second = i > 0 and (words[i - 1], word) in connected_pairs
        is_word_connected_first = False
        with contextlib.suppress(IndexError):  # the last word has no successor
            is_word_connected_first = (word, words[i + 1]) in connected_pairs

        # wrap to the next line if this word would overflow the current one; a word that
        # starts a line is never wrapped (this avoids an empty first line)
        if chars_in_line > 0 and chars_in_line + len(word) > chars_per_line:
            chars_in_line = 0
            x_pos = word_spacing
            y_pos -= fontsize + line_spacing
            if is_word_connected_second:
                word = "-" + word
        # count the word on the line it is actually drawn on
        chars_in_line += len(word)

        # connected words are drawn (almost) flush with their predecessor
        if is_word_connected_second:
            x_pos += 2

        # get the color and alpha value
        color, alpha = _get_color_and_alpha(max_abs_attribution, attribution)

        # get the text; use white text on strongly colored patches for readability
        text_color = "black" if alpha < 2 / 3 else "white"
        text_path = TextPath((x_pos, y_pos), word, prop=font)
        text_patch = PathPatch(text_path, facecolor=text_color, edgecolor="none")
        width_of_text = text_patch.get_window_extent().width

        # get dimensions for the explanation patch; connected words only get padding on
        # their outer sides so that the patches of a connected pair touch
        height_patch = fontsize + height_padding
        width_patch = width_of_text + 1
        y_pos_patch = y_pos - height_padding
        x_pos_patch = x_pos + 1
        if is_word_connected_first:
            x_pos_patch -= width_padding / 2
            width_patch += width_padding / 2
        elif is_word_connected_second:
            width_patch += width_padding / 2
        else:
            x_pos_patch -= width_padding / 2
            width_patch += width_padding

        # create the explanation patch (drawn behind the text)
        patch = FancyBboxPatch(
            xy=(x_pos_patch, y_pos_patch),
            width=width_patch,
            height=height_patch,
            color=color,
            alpha=alpha,
            zorder=-1,
            boxstyle="Round, pad=0, rounding_size=3",
        )

        # draw elements for the word
        ax.add_patch(patch)
        ax.add_artist(text_patch)

        # update the x position; no space is added before a connected successor
        x_pos += width_of_text + word_spacing
        max_x_pos = max(max_x_pos, x_pos)
        if is_word_connected_first:
            x_pos -= word_spacing

    # fix up the dimensions of the plot
    ax.set_xlim(0, max_x_pos)
    ax.set_ylim(y_pos - fontsize / 2, fontsize + fontsize / 2)
    width = max_x_pos
    height = fontsize + fontsize / 2 + abs(y_pos - fontsize / 2)
    fig.set_size_inches(width / 100, height / 100)

    # clean up the plot
    ax.axis("off")
    fig.subplots_adjust(left=0, right=1, top=1, bottom=0)

    # draw the plot
    if not show:
        return fig, ax
    plt.show()
    return None