"""Central output container for interaction scores produced by approximators and explainers."""
from __future__ import annotations
import contextlib
import copy
import json
from pathlib import Path
from typing import TYPE_CHECKING
from warnings import warn
import numpy as np
from .game_theory.indices import (
AllIndices,
get_index_from_computation_index,
is_empty_value_the_baseline,
is_index_aggregated,
is_index_valid,
)
from .utils.sets import generate_interaction_lookup
if TYPE_CHECKING:
from collections.abc import Sequence
from typing import Any
from matplotlib.axes import Axes
from matplotlib.figure import Figure
from shapiq.typing import InteractionScores, JSONType
[docs]
class InteractionValues:
    """This class contains the interaction values as estimated by an approximator.

    The scores are stored in :attr:`interactions`, an insertion-ordered dictionary mapping each
    interaction (a tuple of player indices) to its score. The vectorized views :attr:`values` and
    :attr:`interaction_lookup` are rebuilt from that dictionary on every access, so the position
    of an interaction in ``values`` is its insertion position in ``interactions``.

    Attributes:
        values: The interaction values of the model in vectorized form.
        index: The interaction index estimated. All available indices are defined in
            ``ALL_AVAILABLE_INDICES``.
        max_order: The order of the approximation.
        n_players: The number of players.
        min_order: The minimum order of the approximation.
        interaction_lookup: A dictionary that maps interactions to their index in the values
            vector. If ``interaction_lookup`` is not provided at construction, it is computed from
            the ``n_players``, ``min_order``, and ``max_order`` parameters.
        estimated: Whether the interaction values are estimated or not. Defaults to ``True``.
        estimation_budget: The budget used for the estimation. Defaults to ``None``.
        baseline_value: The value of the baseline interaction also known as 'empty prediction' or
            ``'empty value'`` since it denotes the value of the empty coalition (empty set).
            Defaults to ``0.0``.

    Warns:
        UserWarning: If the index is not a valid index as defined in ``ALL_AVAILABLE_INDICES``.

    Raises:
        TypeError: If the baseline value is not a number.
    """

    # The interactions as a dictionary mapping interactions to their values. Its insertion order
    # defines the positions exposed through ``values`` and ``interaction_lookup``.
    interactions: InteractionScores

    def __init__(
        self,
        values: np.ndarray | InteractionScores,
        *,
        index: str,
        max_order: int,
        n_players: int,
        min_order: int,
        interaction_lookup: dict[tuple[int, ...], int] | None = None,
        estimated: bool = True,
        estimation_budget: int | None = None,
        baseline_value: float | np.number = 0.0,
        target_index: str | None = None,
    ) -> None:
        """Initialize the InteractionValues object.

        Args:
            values: The interaction values as a numpy array or a dictionary mapping interactions to
                their values.
            index: The index of the interaction values. This should be one of the indices defined
                in ALL_AVAILABLE_INDICES. It is used to determine how the interaction values are
                interpreted.
            max_order: The maximum order of the interactions.
            n_players: The number of players in the game.
            min_order: The minimum order of the interactions (required keyword argument).
            interaction_lookup: A dictionary mapping interactions to their index in the values
                vector. Defaults to None, which means it will be generated from the n_players,
                min_order, and max_order parameters.
            estimated: Whether the interaction values are estimated or not. Defaults to True.
            estimation_budget: The budget used for the estimation. Defaults to None.
            baseline_value: The baseline value of the interaction values, also known as the empty
                prediction or empty value. Defaults to ``0.0``.
            target_index: The index to which the InteractionValues should be finalized. Defaults
                to None, which means that target_index = index.

        Raises:
            TypeError: If ``baseline_value`` is not an ``int``, ``float`` or numpy number.
        """
        if not isinstance(baseline_value, (int, float, np.number)):
            msg = f"Baseline value must be provided as a number. Got {type(baseline_value)}."
            raise TypeError(msg)

        if not is_index_valid(index, raise_error=False):
            warn(
                f"Index `{index}` is not a valid interaction index. "
                f"Valid indices are: {', '.join(AllIndices)}.",
                stacklevel=2,
            )

        # The resolved index depends on max_order (see game_theory.indices).
        index = get_index_from_computation_index(index, max_order)
        if target_index is None:
            target_index = index

        interactions = _validate_and_return_interactions(
            values=values,
            interaction_lookup=interaction_lookup,
            n_players=n_players,
            min_order=min_order,
            max_order=max_order,
            baseline_value=baseline_value,
        )
        # The update may change the index, the minimum order AND the baseline value, so every
        # returned quantity (including baseline_value) must be stored.
        interactions, index, min_order, baseline_value = _update_interactions_for_index(
            interactions=interactions,
            index=index,
            target_index=target_index,
            min_order=min_order,
            max_order=max_order,
            baseline_value=baseline_value,
        )

        self.interactions = interactions
        self.index = index
        self.max_order = max_order
        self.n_players = n_players
        self.min_order = min_order
        self.estimated = estimated
        self.estimation_budget = estimation_budget
        self.baseline_value = baseline_value

    @property
    def dict_values(self) -> dict[tuple[int, ...], float]:
        """Getter for the dict directly mapping from all interactions to scores.

        Returns:
            The :attr:`interactions` dictionary itself (not a copy).
        """
        return self.interactions

    @property
    def values(self) -> np.ndarray:
        """Getter for the values of the InteractionValues object.

        A new array is built from :attr:`interactions` (in insertion order) on every access, so
        writing into the returned array does not modify this object.

        Returns:
            The values of the InteractionValues object as a numpy array.
        """
        return np.array(list(self.interactions.values()))

    @property
    def interaction_lookup(self) -> dict[tuple[int, ...], int]:
        """Getter for the interaction lookup of the InteractionValues object.

        Returns:
            A new dictionary mapping each interaction to its position in :attr:`values`.
        """
        return {interaction: position for position, interaction in enumerate(self.interactions)}

    def to_json_file(
        self,
        path: Path,
        *,
        desc: str | None = None,
        created_from: object | None = None,
        **kwargs: JSONType,
    ) -> None:
        """Saves the InteractionValues object to a JSON file.

        The file contains the metadata produced by ``make_file_metadata``, a ``"metadata"``
        section with the constructor arguments and a ``"data"`` section with the interactions.

        Args:
            path: The path to the JSON file.
            desc: A description of the InteractionValues object. Defaults to ``None``.
            created_from: An object from which the InteractionValues object was created. Defaults
                to ``None``.
            **kwargs: Additional parameters to store in the metadata of the JSON file.
        """
        from shapiq.utils.saving import (
            interactions_to_dict,
            make_file_metadata,
            save_json,
        )

        file_metadata = make_file_metadata(
            object_to_store=self,
            data_type="interaction_values",
            desc=desc,
            created_from=created_from,
            parameters=kwargs,
        )
        json_data = {
            **file_metadata,
            "metadata": {
                "n_players": self.n_players,
                "index": self.index,
                "max_order": self.max_order,
                "min_order": self.min_order,
                "estimated": self.estimated,
                "estimation_budget": self.estimation_budget,
                # float() keeps numpy scalars such as np.float32 / np.int64 JSON serializable.
                "baseline_value": float(self.baseline_value),
            },
            "data": interactions_to_dict(interactions=self.dict_values),
        }
        save_json(json_data, path)

    @classmethod
    def from_json_file(cls, path: Path | str) -> InteractionValues:
        """Loads an InteractionValues object from a JSON file.

        Args:
            path: The path to the JSON file (``Path`` or ``str``). Note that the path must end
                with `'.json'`.

        Returns:
            The InteractionValues object loaded from the JSON file.

        Raises:
            ValueError: If the path does not end with `'.json'`.
        """
        from shapiq.utils.saving import dict_to_lookup_and_values

        path = Path(path)
        if not path.name.endswith(".json"):
            msg = f"Path {path} does not end with .json. Cannot load InteractionValues."
            raise ValueError(msg)

        with path.open("r", encoding="utf-8") as file:
            json_data = json.load(file)

        metadata = json_data["metadata"]
        interaction_lookup, values = dict_to_lookup_and_values(json_data["data"])
        return cls(
            values=values,
            index=metadata["index"],
            max_order=metadata["max_order"],
            n_players=metadata["n_players"],
            min_order=metadata["min_order"],
            interaction_lookup=interaction_lookup,
            estimated=metadata["estimated"],
            estimation_budget=metadata["estimation_budget"],
            baseline_value=metadata["baseline_value"],
        )

    def sparsify(self, threshold: float = 1e-3) -> None:
        """Manually sets values close to zero actually to zero (removing values).

        Interactions whose absolute value is strictly below ``threshold`` are removed in place.

        Args:
            threshold: The threshold value below which interactions are zeroed out. Defaults to
                1e-3.
        """
        # ``not (|v| < t)`` rather than ``|v| >= t`` so that NaN scores are kept.
        self.interactions = {
            interaction: value
            for interaction, value in self.interactions.items()
            if not np.abs(value) < threshold
        }

    def get_top_k_interactions(self, k: int) -> InteractionValues:
        """Returns the top k interactions.

        The ``k`` interactions with the largest absolute values are kept, in their original
        relative order. If ``k`` exceeds the number of interactions, all are kept.

        Args:
            k: The number of top interactions to return.

        Returns:
            The top k interactions as an InteractionValues object.

        Raises:
            ValueError: If ``k`` is negative.
        """
        if k < 0:
            msg = f"k must be non-negative. Got {k}."
            raise ValueError(msg)
        values = self.values  # build the array once instead of once per interaction
        top_positions = set(np.argsort(np.abs(values))[::-1][:k].tolist())
        # Size by the actual number of selected interactions (k may exceed len(self)).
        new_values = np.zeros(len(top_positions), dtype=float)
        new_interaction_lookup: dict[tuple[int, ...], int] = {}
        for position, interaction in enumerate(self.interactions):
            if position in top_positions:
                new_position = len(new_interaction_lookup)
                new_values[new_position] = float(values[position])
                new_interaction_lookup[interaction] = new_position
        return InteractionValues(
            values=new_values,
            index=self.index,
            max_order=self.max_order,
            n_players=self.n_players,
            min_order=self.min_order,
            interaction_lookup=new_interaction_lookup,
            estimated=self.estimated,
            estimation_budget=self.estimation_budget,
            baseline_value=self.baseline_value,
        )

    def get_top_k(
        self, k: int, *, as_interaction_values: bool = True
    ) -> InteractionValues | tuple[dict, list[tuple]]:
        """Returns the top k interactions.

        Interactions are selected by absolute value.

        Args:
            k: The number of top interactions to return.
            as_interaction_values: Whether to return the top `k` interactions as an
                InteractionValues object. Defaults to ``True``.

        Returns:
            If ``as_interaction_values`` is ``True``, an InteractionValues object (see
            :meth:`get_top_k_interactions`). Otherwise a dictionary of the top k interactions (in
            their original order) and a list of ``(interaction, value)`` tuples sorted by value
            in descending order.

        Raises:
            ValueError: If ``k`` is negative.

        Examples:
            >>> interaction_values = InteractionValues(
            ...     values=np.array([0.1, 0.2, 0.3, 0.4, 0.5, 0.6]),
            ...     interaction_lookup={(0,): 0, (1,): 1, (2,): 2, (0, 1): 3, (0, 2): 4, (1, 2): 5},
            ...     index="SII",
            ...     max_order=2,
            ...     n_players=3,
            ...     min_order=1,
            ...     baseline_value=0.0,
            ... )
            >>> top_k, sorted_top_k = interaction_values.get_top_k(2, as_interaction_values=False)
            >>> top_k
            {(0, 2): 0.5, (1, 2): 0.6}
            >>> sorted_top_k
            [((1, 2), 0.6), ((0, 2), 0.5)]
        """
        if k < 0:
            msg = f"k must be non-negative. Got {k}."
            raise ValueError(msg)
        if as_interaction_values:
            return self.get_top_k_interactions(k)
        values = self.values  # build the array once instead of once per interaction
        top_positions = set(np.argsort(np.abs(values))[::-1][:k].tolist())
        top_k_interactions = {
            interaction: values[position]
            for position, interaction in enumerate(self.interactions)
            if position in top_positions
        }
        # Note: selection uses absolute values, but this list is ordered by the signed value.
        sorted_top_k_interactions = sorted(
            top_k_interactions.items(), key=lambda item: item[1], reverse=True
        )
        return top_k_interactions, sorted_top_k_interactions

    def __repr__(self) -> str:
        """Returns the representation of the InteractionValues object."""
        representation = "InteractionValues(\n"
        representation += (
            f"    index={self.index}, max_order={self.max_order}, min_order={self.min_order}"
            f", estimated={self.estimated}, estimation_budget={self.estimation_budget},\n"
            f"    n_players={self.n_players}, baseline_value={self.baseline_value}\n)"
        )
        return representation

    def __str__(self) -> str:
        """Returns the string representation of the InteractionValues object."""
        representation = self.__repr__()
        representation = representation[:-2]  # remove the last "\n)" and add values
        _, sorted_top_10_interactions = self.get_top_k(
            10, as_interaction_values=False
        )  # get top 10 interactions
        # add values to string representation
        representation += ",\n    Top 10 interactions:\n"
        for interaction, value in sorted_top_10_interactions:
            representation += f"        {interaction}: {value}\n"
        representation += ")"
        return representation

    def __len__(self) -> int:
        """Returns the number of stored interactions."""
        return len(self.interactions)

    def __iter__(self) -> np.nditer:
        """Returns an iterator over the values of the InteractionValues object.

        The iterator yields zero-dimensional numpy arrays. ``zerosize_ok`` lets an object without
        interactions iterate as empty instead of raising ``ValueError``.
        """
        return np.nditer(self.values, flags=["zerosize_ok"])

    def __getitem__(self, item: int | tuple[int, ...]) -> float:
        """Returns the score for the given interaction.

        Args:
            item: The interaction as a tuple of integers (in any order) for which to return the
                score. If ``item`` is an integer (Python or numpy) it serves as the (possibly
                negative) position in the values vector.

        Returns:
            The interaction value. If the interaction is not present zero is returned.

        Raises:
            IndexError: If ``item`` is an integer position outside the values vector.
        """
        if isinstance(item, (int, np.integer)):
            return float(self.values[item])
        return float(self.interactions.get(tuple(sorted(item)), 0.0))

    def __setitem__(self, item: int | tuple[int, ...], value: float) -> None:
        """Sets the score for an interaction that is already present.

        Args:
            item: The interaction as a tuple of integers (in any order) for which to set the
                score. If ``item`` is an integer (Python or numpy) it serves as the (possibly
                negative) position in the values vector.
            value: The value to set for the interaction.

        Raises:
            KeyError: If the interaction (or position) is not found in the InteractionValues
                object. New interactions cannot be added through item assignment.
        """

        def _not_found(key: object) -> KeyError:
            msg = f"Interaction {key} not found in the InteractionValues. Unable to set a value."
            return KeyError(msg)

        if isinstance(item, (int, np.integer)):
            # dict preserves insertion order, so positions match those of ``values``.
            try:
                interaction = list(self.interactions)[item]
            except IndexError as e:
                raise _not_found(item) from e
        else:
            try:
                interaction = tuple(sorted(item))
                present = interaction in self.interactions
            except TypeError as e:
                raise _not_found(item) from e
            if not present:
                raise _not_found(interaction)
        self.interactions[interaction] = value

    def __eq__(self, other: object) -> bool:
        """Checks if two InteractionValues objects are equal.

        Objects are equal if index, orders, number of players and interaction layout match and
        the baseline values and the values are close (``np.allclose``).

        Args:
            other: The other InteractionValues object.

        Returns:
            True if the two objects are equal, False otherwise.

        Raises:
            TypeError: If ``other`` is not an InteractionValues object.
        """
        if not isinstance(other, InteractionValues):
            msg = "Cannot compare InteractionValues with other types."
            raise TypeError(msg)
        if (
            self.index != other.index
            or self.max_order != other.max_order
            or self.min_order != other.min_order
            or self.n_players != other.n_players
            or not np.allclose(self.baseline_value, other.baseline_value)
        ):
            return False
        # Compare the layout first: np.allclose raises on vectors of different lengths.
        if self.interaction_lookup != other.interaction_lookup:
            return False
        return bool(np.allclose(self.values, other.values))

    def __ne__(self, other: object) -> bool:
        """Checks if two InteractionValues objects are not equal.

        Args:
            other: The other InteractionValues object.

        Returns:
            True if the two objects are not equal, False otherwise.
        """
        return not self.__eq__(other)

    def __hash__(self) -> int:
        """Returns the hash of the InteractionValues object.

        Note:
            The hash uses the exact values while ``__eq__`` uses a tolerance, and the object is
            mutable (e.g. via ``__setitem__`` or ``sparsify``), so hashes of "equal" or mutated
            objects can differ.
        """
        return hash(
            (
                self.index,
                self.max_order,
                self.min_order,
                self.n_players,
                tuple(self.values.flatten()),
            ),
        )

    def __copy__(self) -> InteractionValues:
        """Returns a copy of the InteractionValues object."""
        return InteractionValues(
            values=self.values,  # ``values`` already builds a fresh array
            index=self.index,
            max_order=self.max_order,
            estimated=self.estimated,
            estimation_budget=self.estimation_budget,
            n_players=self.n_players,
            interaction_lookup=self.interaction_lookup,  # also a fresh dict
            min_order=self.min_order,
            baseline_value=self.baseline_value,
        )

    def __add__(self, other: InteractionValues | float) -> InteractionValues:
        """Adds two InteractionValues objects together or a scalar.

        For two objects with different interactions, the union of interactions is returned and
        ``n_players``/``max_order`` are maximized while ``min_order`` is minimized. A scalar is
        added to every interaction and to the baseline value.

        Raises:
            TypeError: If ``other`` is neither an InteractionValues object nor a real number.
        """
        n_players, min_order, max_order = self.n_players, self.min_order, self.max_order
        if isinstance(other, InteractionValues):
            if self.index != other.index:  # different indices
                msg = (
                    f"The indices of the InteractionValues objects are different: "
                    f"{self.index} != {other.index}. Addition might not be meaningful."
                )
                warn(msg, stacklevel=2)
            same_layout = (
                self.interaction_lookup == other.interaction_lookup
                and self.n_players == other.n_players
                and self.min_order == other.min_order
                and self.max_order == other.max_order
            )
            if same_layout:
                added_interactions = {
                    interaction: self.interactions[interaction] + other.interactions[interaction]
                    for interaction in self.interactions
                }
                interaction_lookup = self.interaction_lookup
            else:  # different interactions but addable
                added_interactions = dict(self.interactions)
                for interaction, value in other.interactions.items():
                    if interaction in added_interactions:
                        # rebind instead of ``+=`` so shared value objects are never mutated
                        added_interactions[interaction] = added_interactions[interaction] + value
                    else:
                        added_interactions[interaction] = value
                interaction_lookup = {
                    interaction: i for i, interaction in enumerate(added_interactions)
                }
                n_players = max(self.n_players, other.n_players)
                min_order = min(self.min_order, other.min_order)
                max_order = max(self.max_order, other.max_order)
            baseline_value = self.baseline_value + other.baseline_value
        elif isinstance(other, (int, float, np.integer, np.floating)):
            added_interactions = {
                interaction: value + other for interaction, value in self.interactions.items()
            }
            interaction_lookup = self.interaction_lookup
            baseline_value = self.baseline_value + other
        else:
            msg = f"Cannot add InteractionValues with object of type {type(other)}."
            raise TypeError(msg)

        return InteractionValues(
            values=added_interactions,
            index=self.index,
            max_order=max_order,
            n_players=n_players,
            min_order=min_order,
            interaction_lookup=interaction_lookup,
            estimated=self.estimated,
            estimation_budget=self.estimation_budget,
            baseline_value=baseline_value,
        )

    def __radd__(self, other: InteractionValues | float) -> InteractionValues:
        """Adds two InteractionValues objects together or a scalar."""
        return self.__add__(other)

    def __neg__(self) -> InteractionValues:
        """Negates the InteractionValues object (values and baseline value)."""
        return InteractionValues(
            values=-self.values,
            index=self.index,
            max_order=self.max_order,
            n_players=self.n_players,
            min_order=self.min_order,
            interaction_lookup=self.interaction_lookup,
            estimated=self.estimated,
            estimation_budget=self.estimation_budget,
            baseline_value=-self.baseline_value,
        )

    def __sub__(self, other: InteractionValues | float) -> InteractionValues:
        """Subtracts two InteractionValues objects or a scalar."""
        return self.__add__(-other)

    def __rsub__(self, other: InteractionValues | float) -> InteractionValues:
        """Computes ``other - self`` for a scalar or InteractionValues ``other``."""
        return (-self).__add__(other)

    def __mul__(self, other: float) -> InteractionValues:
        """Multiplies an InteractionValues object (values and baseline value) by a scalar.

        Raises:
            TypeError: If ``other`` is not a real number.
        """
        if not isinstance(other, (int, float, np.integer, np.floating)):
            msg = f"Cannot multiply InteractionValues with object of type {type(other)}."
            raise TypeError(msg)
        interactions = {
            interaction: value * other for interaction, value in self.interactions.items()
        }
        return InteractionValues(
            values=interactions,
            index=self.index,
            max_order=self.max_order,
            n_players=self.n_players,
            min_order=self.min_order,
            interaction_lookup=self.interaction_lookup,
            estimated=self.estimated,
            estimation_budget=self.estimation_budget,
            baseline_value=self.baseline_value * other,
        )

    def __rmul__(self, other: float) -> InteractionValues:
        """Multiplies an InteractionValues object by a scalar."""
        return self.__mul__(other)

    def __abs__(self) -> InteractionValues:
        """Returns the absolute values of the interactions; the baseline value is kept as is."""
        interactions = {interaction: abs(value) for interaction, value in self.interactions.items()}
        return InteractionValues(
            values=interactions,
            index=self.index,
            max_order=self.max_order,
            n_players=self.n_players,
            min_order=self.min_order,
            interaction_lookup=self.interaction_lookup,
            estimated=self.estimated,
            estimation_budget=self.estimation_budget,
            baseline_value=self.baseline_value,
        )

    def get_n_order_values(self, order: int) -> np.ndarray:
        """Returns the interaction values of a specific order as a numpy array.

        Note:
            Depending on the order and number of players the resulting array might be sparse and
            very large.

        Args:
            order: The order of the interactions to return.

        Returns:
            The interaction values of the specified order as a numpy array of shape
            ``(n_players,)`` for order ``1`` and ``(n_players, n_players)`` for order ``2``, etc.
            The array is symmetric: every permutation of an interaction holds its value.

        Raises:
            ValueError: If the order is less than ``1``.
        """
        from itertools import permutations

        if order < 1:
            msg = "Order must be greater or equal to 1."
            raise ValueError(msg)
        values = np.zeros(tuple([self.n_players] * order), dtype=float)
        for interaction, value in self.interactions.items():
            if len(interaction) != order:
                continue
            # fill every ordering of the interaction, e.g. (0, 1) and (1, 0) for (0, 1)
            for perm in permutations(interaction):
                values[perm] = value
        return values

    def get_n_order(
        self,
        order: int | None = None,
        min_order: int | None = None,
        max_order: int | None = None,
    ) -> InteractionValues:
        """Select particular order of interactions.

        Creates a new InteractionValues object containing only the interactions within the
        specified order range.

        You can specify:
            - `order`: to select interactions of a single specific order (e.g., all pairwise
                interactions).
            - `min_order` and/or `max_order`: to select a range of interaction orders.
            - If `order` and `min_order`/`max_order` are both set, `min_order` and `max_order`
                will override the `order` value.

        Example:
            >>> interaction_values = InteractionValues(
            ...     values=np.array([1, 2, 3, 4, 5, 6, 7]),
            ...     interaction_lookup={(0,): 0, (1,): 1, (2,): 2, (0, 1): 3, (0, 2): 4, (1, 2): 5, (0, 1, 2): 6},
            ...     index="SII",
            ...     max_order=3,
            ...     n_players=3,
            ...     min_order=1,
            ...     baseline_value=0.0,
            ... )
            >>> interaction_values.get_n_order(order=1).dict_values
            {(0,): 1.0, (1,): 2.0, (2,): 3.0}
            >>> interaction_values.get_n_order(min_order=1, max_order=2).dict_values
            {(0,): 1.0, (1,): 2.0, (2,): 3.0, (0, 1): 4.0, (0, 2): 5.0, (1, 2): 6.0}
            >>> interaction_values.get_n_order(min_order=2).dict_values
            {(0, 1): 4.0, (0, 2): 5.0, (1, 2): 6.0, (0, 1, 2): 7.0}

        Args:
            order: The order of the interactions to return. Defaults to ``None`` which requires
                ``min_order`` or ``max_order`` to be set.
            min_order: The minimum order of the interactions to return. Defaults to ``None``
                which sets it to ``order`` (or to ``self.min_order`` if ``order`` is ``None``).
            max_order: The maximum order of the interactions to return. Defaults to ``None``
                which sets it to ``order`` (or to ``self.max_order`` if ``order`` is ``None``).

        Returns:
            The interaction values of the specified order.

        Raises:
            ValueError: If all three parameters are set to ``None`` or if the resulting
                ``min_order`` is larger than ``max_order``.
        """
        if order is None and min_order is None and max_order is None:
            msg = "Either order, min_order or max_order must be set."
            raise ValueError(msg)

        if order is not None:
            max_order = order if max_order is None else max_order
            min_order = order if min_order is None else min_order
        else:  # order is None
            min_order = self.min_order if min_order is None else min_order
            max_order = self.max_order if max_order is None else max_order

        if min_order > max_order:
            msg = f"min_order ({min_order}) must be less than or equal to max_order ({max_order})."
            raise ValueError(msg)

        new_values = []
        new_interaction_lookup = {}
        for interaction, value in self.interactions.items():
            if min_order <= len(interaction) <= max_order:
                new_interaction_lookup[interaction] = len(new_values)
                new_values.append(float(value))

        return InteractionValues(
            values=np.array(new_values),
            index=self.index,
            max_order=max_order,
            n_players=self.n_players,
            min_order=min_order,
            interaction_lookup=new_interaction_lookup,
            estimated=self.estimated,
            estimation_budget=self.estimation_budget,
            baseline_value=self.baseline_value,
        )

    def get_subset(self, players: list[int]) -> InteractionValues:
        """Selects a subset of players from the InteractionValues object.

        Only interactions whose players are all contained in ``players`` are kept (this includes
        the empty interaction, if present). Player indices are not re-mapped.

        Args:
            players: List of players to select from the InteractionValues object.

        Returns:
            InteractionValues: Filtered InteractionValues object containing only values related
            to selected players. Its ``n_players`` is the number of distinct selected players.

        Example:
            >>> interaction_values = InteractionValues(
            ...     values=np.array([0.1, 0.2, 0.3, 0.4, 0.5, 0.6]),
            ...     interaction_lookup={(0,): 0, (1,): 1, (2,): 2, (0, 1): 3, (0, 2): 4, (1, 2): 5},
            ...     index="SII",
            ...     max_order=2,
            ...     n_players=3,
            ...     min_order=1,
            ...     baseline_value=0.0,
            ... )
            >>> interaction_values.get_subset([0, 1]).dict_values
            {(0,): 0.1, (1,): 0.2, (0, 1): 0.4}
            >>> interaction_values.get_subset([0, 2]).dict_values
            {(0,): 0.1, (2,): 0.3, (0, 2): 0.5}
            >>> interaction_values.get_subset([1]).dict_values
            {(1,): 0.2}
        """
        selected = set(players)  # O(1) membership tests
        positions, kept_interactions = [], []
        for position, interaction in enumerate(self.interactions):
            if selected.issuperset(interaction):
                positions.append(position)
                kept_interactions.append(interaction)
        new_values = self.values[positions]
        new_interaction_lookup = {
            interaction: position for position, interaction in enumerate(kept_interactions)
        }
        return InteractionValues(
            values=new_values,
            index=self.index,
            max_order=self.max_order,
            # number of selected players (previously the number of *removed* players)
            n_players=len(selected),
            min_order=self.min_order,
            interaction_lookup=new_interaction_lookup,
            estimated=self.estimated,
            estimation_budget=self.estimation_budget,
            baseline_value=self.baseline_value,
        )

    def save(self, path: Path | str) -> None:
        """Save the InteractionValues object to a JSON file.

        The parent directory is created if it does not exist.

        Args:
            path: The path to save the InteractionValues object to.
        """
        path = Path(path)
        directory = path.parent
        if not directory.exists():
            # best effort: a FileNotFoundError while creating the directory is suppressed
            with contextlib.suppress(FileNotFoundError):
                directory.mkdir(parents=True, exist_ok=True)
        self.to_json_file(path)

    @classmethod
    def load(cls, path: Path | str) -> InteractionValues:
        """Load an InteractionValues object from a file.

        Args:
            path: The path to load the InteractionValues object from.

        Returns:
            The loaded InteractionValues object.

        Raises:
            ValueError: If the path does not end with ``'.json'``.
        """
        path = Path(path)
        if not path.name.endswith(".json"):
            msg = f"Path {path} does not end with .json. Cannot load InteractionValues."
            raise ValueError(msg)
        return cls.from_json_file(path)

    @classmethod
    def from_dict(cls, data: dict[str, Any]) -> InteractionValues:
        """Create an InteractionValues object from a dictionary.

        Args:
            data: The dictionary containing the data to create the InteractionValues object from,
                with the keys produced by :meth:`to_dict`.

        Returns:
            The InteractionValues object created from the dictionary.
        """
        return cls(
            values=data["values"],
            index=data["index"],
            max_order=data["max_order"],
            n_players=data["n_players"],
            min_order=data["min_order"],
            interaction_lookup=data["interaction_lookup"],
            estimated=data["estimated"],
            estimation_budget=data["estimation_budget"],
            baseline_value=data["baseline_value"],
        )

    def to_dict(self) -> dict:
        """Convert the InteractionValues object to a dictionary.

        Returns:
            The InteractionValues object as a dictionary. ``"values"`` is a copy of the
            interactions, so mutating the result does not modify this object.
        """
        return {
            "values": dict(self.interactions),
            "index": self.index,
            "max_order": self.max_order,
            "n_players": self.n_players,
            "min_order": self.min_order,
            "interaction_lookup": self.interaction_lookup,
            "estimated": self.estimated,
            "estimation_budget": self.estimation_budget,
            "baseline_value": self.baseline_value,
        }

    @classmethod
    def from_first_order_array(
        cls, first_order_values: np.ndarray, index: str, baseline_value: float = 0
    ) -> InteractionValues:
        """Convert an array of first-order values to an :class:`shapiq.InteractionValues` object.

        Args:
            first_order_values: An array containing the value of player ``i`` at index ``i``.
            index: The game theoretic index of the resulting :class:`shapiq.InteractionValues`
                object.
            baseline_value: Baseline value, defaults to ``0``.

        Returns:
            An :class:`~shapiq.InteractionValues` object containing the provided values.
        """
        n_players = first_order_values.shape[0]
        interaction_lookup: dict[tuple[int, ...], int] = {(i,): i for i in range(n_players)}
        # NOTE(review): min_order=0 although the lookup has no empty interaction — kept as-is.
        return cls(
            first_order_values,
            index=index,
            min_order=0,
            max_order=1,
            n_players=n_players,
            baseline_value=baseline_value,
            interaction_lookup=interaction_lookup,
        )

    def to_first_order_array(self) -> np.ndarray:
        """Convert to an array of first-order values.

        Returns:
            An array of shape ``(self.n_players,)`` containing at index ``i`` the first-order value
            of player ``i`` (``0.0`` for players without a stored value). The empty interaction
            is ignored.

        Raises:
            ValueError: If the method was called on an :class:`~shapiq.InteractionValues` object
                with max order not equal to ``1``.
        """
        if self.max_order != 1:
            msg = f"Max order must be 1 but was {self.max_order}"
            raise ValueError(msg)
        out = np.zeros((self.n_players,))
        for coalition, value in self.interactions.items():
            if coalition == ():
                continue
            out[coalition[0]] = value
        return out

    def aggregate(
        self,
        others: Sequence[InteractionValues],
        aggregation: str = "mean",
    ) -> InteractionValues:
        """Aggregates InteractionValues objects using a specific aggregation method.

        Args:
            others: A list of InteractionValues objects to aggregate with this one.
            aggregation: The aggregation method to use. Defaults to ``"mean"``. Other options are
                ``"median"``, ``"sum"``, ``"max"``, and ``"min"``.

        Returns:
            The aggregated InteractionValues object.

        Note:
            For documentation on the aggregation methods, see the
            ``aggregate_interaction_values()`` function.
        """
        return aggregate_interaction_values([self, *others], aggregation)

    def plot_network(self, *, show: bool = True, **kwargs: Any) -> tuple[Figure, Axes] | None:
        """Visualize InteractionValues on a graph.

        Note:
            For arguments, see :func:`shapiq.plot.network.network_plot`.

        Args:
            show: Whether to show the plot. Defaults to ``True``.
            **kwargs: Additional keyword arguments to pass to the plotting function.

        Returns:
            If show is ``False``, the function returns a tuple with the figure and the axis of the
            plot.

        Raises:
            ValueError: If ``max_order`` is not larger than ``1``.
        """
        # Validate before importing the plotting module.
        if self.max_order <= 1:
            msg = (
                "InteractionValues contains only 1-order values, "
                "but requires also 2-order values for the network plot."
            )
            raise ValueError(msg)

        from shapiq.plot.network import network_plot

        return network_plot(
            interaction_values=self,
            show=show,
            **kwargs,
        )

    def plot_si_graph(self, *, show: bool = True, **kwargs: Any) -> tuple[Figure, Axes] | None:
        """Visualize InteractionValues as a SI graph.

        For arguments, see :func:`shapiq.plot.si_graph.si_graph_plot`.

        Returns:
            The SI graph as a tuple containing the figure and the axes.
        """
        from shapiq.plot.si_graph import si_graph_plot

        return si_graph_plot(self, show=show, **kwargs)

    def plot_stacked_bar(self, *, show: bool = True, **kwargs: Any) -> tuple[Figure, Axes] | None:
        """Visualize InteractionValues as a stacked bar plot.

        For arguments, see ``shapiq.stacked_bar_plot``.

        Returns:
            The stacked bar plot as a tuple containing the figure and the axes.
        """
        from shapiq import stacked_bar_plot

        return stacked_bar_plot(self, show=show, **kwargs)

    def plot_force(
        self,
        feature_names: np.ndarray | None = None,
        *,
        show: bool = True,
        abbreviate: bool = True,
        contribution_threshold: float = 0.05,
    ) -> Figure | None:
        """Visualize InteractionValues on a force plot.

        For arguments, see ``shapiq.plot.force_plot``.

        Args:
            feature_names: The feature names used for plotting. If no feature names are provided,
                the feature indices are used instead. Defaults to ``None``.
            show: Whether to show the plot. Defaults to ``True``.
            abbreviate: Whether to abbreviate the feature names or not. Defaults to ``True``.
            contribution_threshold: The threshold for contributions to be displayed in percent.
                Defaults to ``0.05``.

        Returns:
            The force plot as a matplotlib figure (if show is ``False``).
        """
        from .plot import force_plot

        return force_plot(
            self,
            feature_names=feature_names,
            show=show,
            abbreviate=abbreviate,
            contribution_threshold=contribution_threshold,
        )

    def plot_waterfall(
        self,
        feature_names: np.ndarray | None = None,
        *,
        show: bool = True,
        abbreviate: bool = True,
        max_display: int = 10,
    ) -> Axes | None:
        """Draws interaction values on a waterfall plot.

        Note:
            Requires the ``shap`` Python package to be installed.

        Args:
            feature_names: The feature names used for plotting. If no feature names are provided,
                the feature indices are used instead. Defaults to ``None``.
            show: Whether to show the plot. Defaults to ``True``.
            abbreviate: Whether to abbreviate the feature names or not. Defaults to ``True``.
            max_display: The maximum number of interactions to display. Defaults to ``10``.

        Returns:
            Whatever ``shapiq.waterfall_plot`` returns (annotated as the axes or ``None``).
        """
        from shapiq import waterfall_plot

        return waterfall_plot(
            self,
            feature_names=feature_names,
            show=show,
            max_display=max_display,
            abbreviate=abbreviate,
        )

    def plot_sentence(
        self,
        words: list[str],
        *,
        show: bool = True,
        **kwargs: Any,
    ) -> tuple[Figure, Axes] | None:
        """Plots the first order effects (attributions) of a sentence or paragraph.

        For arguments, see :func:`shapiq.plot.sentence.sentence_plot`.

        Returns:
            If ``show`` is ``True``, the function returns ``None``. Otherwise, it returns a tuple
            with the figure and the axis of the plot.
        """
        from shapiq.plot.sentence import sentence_plot

        return sentence_plot(self, words, show=show, **kwargs)

    def plot_upset(self, *, show: bool = True, **kwargs: Any) -> Figure | None:
        """Plots the upset plot.

        For arguments, see :func:`shapiq.plot.upset.upset_plot`.

        Returns:
            The upset plot as a matplotlib figure (if show is ``False``).
        """
        from shapiq.plot.upset import upset_plot

        return upset_plot(self, show=show, **kwargs)
[docs]
def aggregate_interaction_values(
    interaction_values: Sequence[InteractionValues],
    aggregation: str = "mean",
) -> InteractionValues:
    """Aggregates InteractionValues objects using a specific aggregation method.

    For every interaction that occurs in at least one of the objects, the scores
    ``iv[interaction]`` of all objects are combined with the chosen aggregation method. The
    baseline values of all objects are aggregated the same way.

    Args:
        interaction_values: The InteractionValues objects to aggregate (any iterable is
            accepted). Must contain at least one object.
        aggregation: The aggregation method to use. Defaults to ``"mean"``. Other options are
            ``"median"``, ``"sum"``, ``"max"``, and ``"min"``.

    Returns:
        The aggregated InteractionValues object.

    Example:
        >>> iv1 = InteractionValues(
        ...     values=np.array([0.1, 0.2, 0.3, 0.4, 0.5, 0.6]),
        ...     interaction_lookup={(0,): 0, (1,): 1, (2,): 2, (0, 1): 3, (0, 2): 4, (1, 2): 5},
        ...     index="SII",
        ...     max_order=2,
        ...     n_players=3,
        ...     min_order=1,
        ...     baseline_value=0.0,
        ... )
        >>> iv2 = InteractionValues(
        ...     values=np.array([0.2, 0.3, 0.4, 0.5, 0.6]),  # this iv is missing the (1, 2) value
        ...     interaction_lookup={(0,): 0, (1,): 1, (2,): 2, (0, 1): 3, (0, 2): 4},  # no (1, 2)
        ...     index="SII",
        ...     max_order=2,
        ...     n_players=3,
        ...     min_order=1,
        ...     baseline_value=1.0,
        ... )
        >>> aggregated = aggregate_interaction_values([iv1, iv2], "mean")
        >>> round(aggregated[(0, 1)], 2)
        0.45
        >>> float(aggregated.baseline_value)
        0.5
        >>> (1, 2) in aggregated.interaction_lookup
        True

    Note:
        The index and the estimation budget of the aggregated InteractionValues object are taken
        from the first InteractionValues object in the list; ``estimated`` is always ``True``.
        ``max_order`` and ``n_players`` are the maxima and ``min_order`` the minimum over all
        objects.

    Raises:
        ValueError: If the aggregation method is not supported or ``interaction_values`` is
            empty.
    """
    aggregators = {
        "mean": np.mean,
        "median": np.median,
        "sum": np.sum,
        "max": np.max,
        "min": np.min,
    }
    # validate eagerly so an unsupported method fails regardless of the input contents
    if aggregation not in aggregators:
        msg = f"Aggregation method {aggregation} is not supported."
        raise ValueError(msg)
    aggregate = aggregators[aggregation]

    # materialize once: the objects are traversed several times below
    objects = list(interaction_values)
    if not objects:
        msg = "At least one InteractionValues object is required for aggregation."
        raise ValueError(msg)

    # union of all interactions, sorted for a deterministic layout of the values vector
    all_keys = sorted({key for iv in objects for key in iv.interaction_lookup})
    new_lookup = {key: position for position, key in enumerate(all_keys)}
    # float(...) keeps every method's result a plain Python float (np.sum/max/min return np types)
    new_values = np.array(
        [float(aggregate([iv[key] for iv in objects])) for key in all_keys],
        dtype=float,
    )
    baseline_value = float(aggregate([float(iv.baseline_value) for iv in objects]))

    first = objects[0]
    return InteractionValues(
        values=new_values,
        index=first.index,
        max_order=max(iv.max_order for iv in objects),
        n_players=max(iv.n_players for iv in objects),
        min_order=min(iv.min_order for iv in objects),
        interaction_lookup=new_lookup,
        estimated=True,
        estimation_budget=first.estimation_budget,
        baseline_value=baseline_value,
    )
def _validate_and_return_interactions(
    values: np.ndarray | dict[tuple[int, ...], float],
    interaction_lookup: dict[tuple[int, ...], int] | None,
    n_players: int,
    min_order: int,
    max_order: int,
    baseline_value: float | np.number,
) -> dict[tuple[int, ...], float]:
    """Check the interactions for validity and consistency and build the interaction dict.

    Args:
        values (np.ndarray | dict[tuple[int, ...], float]): The interaction values, either as a
            vector (any array-like, e.g. a list, is accepted) indexed via ``interaction_lookup``
            or as a dict mapping interactions directly to their values.
        interaction_lookup (dict[tuple[int, ...], int] | None): A mapping from interactions to
            their position in ``values``. Ignored if ``values`` is a dict. If ``None`` and
            ``values`` is a vector, it is generated from ``n_players``, ``min_order`` and
            ``max_order``.
        n_players (int): The number of players.
        min_order (int): The minimum order of interactions.
        max_order (int): The maximum order of interactions.
        baseline_value (float | np.number): The value inserted for the empty interaction ``()``
            when ``min_order == 0`` and ``()`` is not already present.

    Returns:
        dict[tuple[int, ...], float]: A new dict mapping interactions to their values. The input
        ``values`` is never modified.

    Raises:
        TypeError: If ``values`` is not a dict and ``interaction_lookup`` (given or generated) is
            not a dictionary.
    """
    interactions: dict[tuple[int, ...], float]
    if isinstance(values, dict):
        # A shallow copy suffices: keys are tuple[int, ...] and values are immutable scalars.
        interactions = dict(values)
    else:
        if interaction_lookup is None:
            interaction_lookup = generate_interaction_lookup(
                players=n_players,
                min_order=min_order,
                max_order=max_order,
            )
        if not isinstance(interaction_lookup, dict):
            msg = f"Interaction lookup must be a dictionary. Got {type(interaction_lookup)}."
            raise TypeError(msg)
        # accept any array-like; np.asarray does not copy inputs that already are ndarrays
        value_array = np.asarray(values)
        # .item() converts the NumPy scalar into a native Python scalar
        interactions = {
            interaction: value_array[position].item()
            for interaction, position in interaction_lookup.items()
        }
    if min_order == 0 and () not in interactions:
        # order 0 is requested but no empty interaction was supplied: fill it with the baseline
        interactions[()] = float(baseline_value)
    return interactions
def _update_interactions_for_index(
    interactions: InteractionScores,
    index: str,
    target_index: str,
    max_order: int,
    min_order: int,
    baseline_value: float | np.number,
) -> tuple[InteractionScores, str, int, float]:
    """Reconcile interactions, index, minimum order and baseline for the requested target index.

    If ``target_index`` is an aggregated index that differs from ``index``, the interactions are
    first converted with ``aggregate_base_attributions``, which also returns the new index and
    minimum order. Afterwards the empty interaction ``()`` and ``baseline_value`` are made
    consistent with each other:

    * If ``()`` is present and differs from the baseline (and the index is not ``"SII"``), then
      either ``()`` is overwritten with the baseline (when ``is_empty_value_the_baseline(index)``)
      or the baseline is taken from ``()``.
    * If ``()`` is absent and ``min_order == 0``, ``()`` is set to the baseline.

    Note:
        The ``()`` entry is written into ``interactions`` directly, so when no aggregation takes
        place the caller's mapping is modified in place.

    Args:
        interactions: Mapping from interactions to their scores.
        index: The index the ``interactions`` are currently expressed in.
        target_index: The index requested by the caller.
        max_order: The maximum interaction order (passed as ``order`` to the aggregation).
        min_order: The minimum interaction order.
        baseline_value: The baseline (empty-coalition) value provided by the caller.

    Returns:
        A tuple ``(interactions, index, min_order, baseline_value)`` holding the possibly updated
        values; ``baseline_value`` is always returned as a ``float``.
    """
    # NOTE(review): imported locally, presumably to avoid a circular import — confirm
    from .game_theory.aggregation import aggregate_base_attributions

    if is_index_aggregated(target_index) and target_index != index:
        # conversion changes index and min_order, which the checks below then use
        interactions, index, min_order = aggregate_base_attributions(
            interactions=interactions,
            index=index,
            order=max_order,
            min_order=min_order,
            baseline_value=float(baseline_value),
        )
    if () in interactions:
        empty_value = interactions[()]
        # SII is excluded from reconciliation (the reason is not visible in this function)
        if empty_value != baseline_value and index != "SII":
            if is_empty_value_the_baseline(index):
                # insert the empty value given in baseline into the values
                interactions[()] = float(baseline_value)
            else:  # manually set baseline to the empty value
                baseline_value = interactions[()]
    elif min_order == 0:
        # TODO(mmshlk): this might not be what we really want to do always: what if empty and baseline are different?
        # https://github.com/mmschlk/shapiq/issues/385
        interactions[()] = float(baseline_value)
    return interactions, index, min_order, float(baseline_value)