Source code for shapiq.game_theory.aggregation

"""Aggregation functions for summarizing base interaction indices into efficient indices useful for explanations."""

from __future__ import annotations

import warnings

import numpy as np
import scipy as sp

from shapiq.interaction_values import InteractionValues
from shapiq.utils.sets import powerset


def _change_index(index: str) -> str:
    """Changes the index of the interaction values to the new index.

    Args:
        index: The current index of the interaction values.

    Returns:
        The new index of the interaction values.

    """
    if index in ["SV", "BV"]:  # no change for probabilistic values like SV or BV
        return index
    return f"k-{index}"


def aggregate_base_attributions(
    interactions: dict[tuple[int, ...], float],
    index: str,
    order: int,
    min_order: int,
    baseline_value: float,
) -> tuple[dict[tuple[int, ...], float], str, int]:
    """Aggregates base interaction values into an efficient interaction index.

    An example aggregation would be the transformation from `SII` values to `k-SII` values.

    Each base interaction ``B`` with ``|B| <= order`` is projected onto every non-empty subset
    ``T`` of ``B`` with ``|T| <= order``. The projection is weighted by the Bernoulli number
    ``B_{|B| - |T|}`` (as returned by ``scipy.special.bernoulli``). Base interactions with more
    than ``order`` players do not contribute. This matches the k-SII definition, where only base
    interactions of size at most ``k`` are aggregated.

    Args:
        interactions: The base interaction values to aggregate, mapping interactions to values.
        index: The index of the interaction values.
        order: The order of the aggregation. For example, the order of the k-SII aggregation.
        min_order: The minimum order of the base interactions. If the base interactions have a
            minimum order greater than 1, a warning is raised.
        baseline_value: The baseline value of the interaction values. It is stored as the value
            of the empty interaction ``()`` of the result. For example, the baseline value of
            SII values need not be the same as the SII value of the empty set.

    Returns:
        A tuple containing:
            - A dictionary mapping interactions to their aggregated values. Interactions whose
              aggregated value is exactly 0 are omitted.
            - The new index of the interaction values.
            - The new minimum order of the interaction values (always 0 for this aggregation).

    Raises:
        ValueError: If the `order` is smaller than 0.

    """
    if order < 0:
        msg = f"The order of the aggregation must be non-negative, but got {order}."
        raise ValueError(msg)

    if min_order > 1:
        warnings.warn(
            UserWarning(
                "The base interaction values have a minimum order greater than 1. Aggregation may "
                "not be meaningful.",
            ),
            stacklevel=2,
        )

    # B_0 ... B_order; only indices up to order - 1 are needed because |B| <= order and |T| >= 1
    bernoulli_numbers = sp.special.bernoulli(order)
    transformed_interactions: dict[tuple[int, ...], float] = {(): baseline_value}
    for base_interaction, base_interaction_value in interactions.items():
        base_size = len(base_interaction)
        if base_size > order:
            # Base interactions larger than the aggregation order are not part of the k-order
            # aggregation. Without this guard they would also index past ``bernoulli_numbers``.
            continue
        # project the base interaction onto all of its subsets T with 1 <= |T| <= order
        for interaction in powerset(base_interaction, min_size=1, max_size=order):
            scaling = float(bernoulli_numbers[base_size - len(interaction)])
            update_interaction = scaling * base_interaction_value
            if update_interaction == 0:
                continue  # e.g. odd Bernoulli numbers B_3, B_5, ... are zero
            transformed_interactions[interaction] = (
                transformed_interactions.get(interaction, 0) + update_interaction
            )
            # drop interactions whose contributions cancel out exactly to keep the result sparse
            if transformed_interactions[interaction] == 0:
                transformed_interactions.pop(interaction)

    # update the index name after the aggregation (e.g., SII -> k-SII)
    new_index = _change_index(index)
    return transformed_interactions, new_index, 0  # always min order 0 for this aggregation


def aggregate_base_interaction(
    base_interactions: InteractionValues,
    order: int | None = None,
) -> InteractionValues:
    """Aggregates the basis interaction values into an efficient interaction index.

    An example aggregation would be the transformation from `SII` values to `k-SII` values.

    Args:
        base_interactions: The basis interaction values to aggregate.
        order: The order of the aggregation. For example, the order of the k-SII aggregation. If
            `None`, the maximum order of the base interactions is used. An explicit order of
            ``0`` is respected and not replaced by the default. Defaults to `None`.

    Returns:
        The aggregated interaction values.

    Raises:
        ValueError: If the `order` is smaller than 0.

    Examples:
        >>> import numpy as np
        >>> from shapiq.interaction_values import InteractionValues
        >>> sii_values = InteractionValues(
        ...     n_players=3,
        ...     values=np.array([-0.1, 0.1, 0.2, 0.3, 0.4, 0.5, 0.6]),
        ...     index="SII",
        ...     interaction_lookup={(): 0, (1,): 1, (2,): 2, (3,): 3, (1, 2): 4, (2, 3): 5, (1, 3): 6},
        ...     baseline_value=0,  # for SII, the baseline value need not equal the value of emptyset
        ...     min_order=0,
        ...     max_order=2,
        ... )
        >>> k_sii_values = aggregate_base_interaction(sii_values)
        >>> k_sii_values.index
        'k-SII'
        >>> k_sii_values.baseline_value
        0
        >>> k_sii_values.interaction_lookup
        {(): 0, (1,): 1, (2,): 2, (3,): 3, (1, 2): 4, (2, 3): 5, (1, 3): 6}
        >>> k_sii_values.max_order
        2

    """
    # Compare against ``None`` explicitly: ``order or ...`` would silently replace an explicit
    # order of 0 with the maximum order of the base interactions.
    if order is None:
        order = base_interactions.max_order

    transformed_interactions, new_index, new_min_order = aggregate_base_attributions(
        interactions=base_interactions.interactions,
        index=base_interactions.index,
        order=order,
        min_order=base_interactions.min_order,
        baseline_value=float(base_interactions.baseline_value),
    )
    return InteractionValues(
        values=transformed_interactions,
        n_players=base_interactions.n_players,
        index=new_index,
        baseline_value=base_interactions.baseline_value,
        min_order=new_min_order,
        max_order=order,
        estimated=base_interactions.estimated,
        estimation_budget=base_interactions.estimation_budget,
    )
def aggregate_to_one_dimension(
    interactions: InteractionValues,
) -> tuple[np.ndarray, np.ndarray]:
    """Flattens the higher-order interaction values to positive and negative one-dimensional values.

    Every non-empty interaction is split evenly among its players. An interaction value ``v``
    on a coalition of ``m`` players adds ``v / m`` to each of those players. The share goes to
    the positive array when it is ``>= 0`` and to the negative array otherwise. For example, the
    value 5 on `(1, 2)` gives 2.5 to player 1 and 2.5 to player 2. The empty interaction is
    ignored.

    Args:
        interactions: The interaction values to convert.

    Returns:
        A pair ``(positive, negative)`` of arrays of length ``n_players``. They hold the summed
        non-negative and negative shares per player, respectively.

    """
    n_players = interactions.n_players
    positive = np.zeros(n_players, dtype=float)
    negative = np.zeros(n_players, dtype=float)

    for coalition in interactions.interaction_lookup:
        size = len(coalition)
        if not size:
            continue  # the empty set carries no per-player attribution
        share = interactions[coalition] / size
        # a single comparison decides the target array for every member of the coalition
        target = positive if share >= 0 else negative
        for player in coalition:
            target[player] += share

    return positive, negative