# Source code for shapiq.approximator.permutation.sii

"""This module implements the Permutation Sampling approximator for the SII (and k-SII) index."""

from __future__ import annotations

from typing import TYPE_CHECKING, Any, Literal, get_args

import numpy as np

from shapiq.approximator.base import Approximator
from shapiq.interaction_values import InteractionValues
from shapiq.utils.sets import powerset

if TYPE_CHECKING:
    from collections.abc import Callable

    from shapiq.game import Game

# Names of the interaction indices accepted by the approximator in this module.
# ``PermutationSamplingSII`` uses this alias as its type parameter and derives its
# ``valid_indices`` tuple from it via ``typing.get_args``.
ValidPermutationSIIIndices = Literal["SII", "k-SII"]


class PermutationSamplingSII(Approximator[ValidPermutationSIIIndices]):
    """Permutation Sampling approximator for the SII (and k-SII) index.

    For every sampled permutation and every interaction order ``s`` under consideration, each
    window of ``s`` consecutive players in the permutation is treated as one interaction. Its
    discrete derivative is evaluated with the players preceding the window as the base
    coalition, and the per-interaction estimates are averaged over all samples.

    See Also:
        - :class:`~shapiq.approximator.permutation.stii.PermutationSamplingSTII`: The Permutation
            Sampling approximator for the STII index
        - :class:`~shapiq.approximator.permutation.sv.PermutationSamplingSV`: The Permutation
            Sampling approximator for the SV index

    """

    #: override the valid indices for this approximator
    valid_indices: tuple[ValidPermutationSIIIndices, ...] = tuple(
        get_args(ValidPermutationSIIIndices)
    )
    """The valid indices for this permutation sampling approximator."""

    def __init__(
        self,
        n: int,
        max_order: int = 2,
        index: ValidPermutationSIIIndices = "k-SII",
        *,
        top_order: bool = False,
        random_state: int | None = None,
    ) -> None:
        """Initialize the Permutation Sampling approximator for SII (and k-SII).

        Args:
            n: The number of players.

            max_order: The interaction order of the approximation. Defaults to ``2``.

            index: The interaction index to compute. Must be either ``'SII'`` or ``'k-SII'``.

            top_order: Whether to approximate only the top order interactions (``True``) or all
                orders up to the specified order (``False``, default).

            random_state: The random state to use for the permutation sampling. Defaults to
                ``None``.

        Raises:
            ValueError: If ``index`` is not one of the valid indices of this approximator.

        """
        # validate against the class-level tuple so the accepted names are defined in one place
        if index not in self.valid_indices:
            msg = f"Invalid index {index}. Must be either 'SII' or 'k-SII'."
            raise ValueError(msg)
        super().__init__(
            n=n,
            max_order=max_order,
            index=index,
            top_order=top_order,
            random_state=random_state,
        )
        # number of game evaluations required to process one sampled permutation
        self.iteration_cost: int = self._compute_iteration_cost()

    def _compute_iteration_cost(self) -> int:
        """Compute the cost of a single iteration of the permutation sampling.

        Computes the cost of performing a single iteration of the permutation sampling given
        the order, the number of players, and the SII index. For every considered order ``s``
        a permutation contains ``n - s + 1`` windows of ``s`` consecutive players, and each
        window requires ``2**s`` coalition evaluations (one per subset of the window).

        Returns:
            int: The cost of a single iteration.

        """
        min_order = 1 if not self.top_order else self.max_order
        return sum((self.n - s + 1) * 2**s for s in range(min_order, self.max_order + 1))

    def _compute_order_iterator(self) -> np.ndarray:
        """Computes the order iterator for the SII index.

        Returns:
            np.ndarray: The interaction orders to process: only ``max_order`` when
                ``top_order`` is set, otherwise ``1, ..., max_order``.

        """
        min_order = 1 if not self.top_order else self.max_order
        return np.arange(min_order, self.max_order + 1)

    def approximate(
        self,
        budget: int,
        game: Game | Callable[[np.ndarray], np.ndarray],
        batch_size: int | None = 5,
        **kwargs: Any,  # noqa: ARG002
    ) -> InteractionValues:
        """Approximates the interaction values.

        One game evaluation is spent on the empty coalition (used as the baseline value). The
        remaining budget is consumed in batches of permutations, where each permutation costs
        ``self.iteration_cost`` game evaluations.

        Args:
            budget: The budget for the approximation.

            game: The game function as a callable that takes a set of players and returns the
                value. It is called once with a one-dimensional all-``False`` boolean vector of
                length ``n`` and afterwards with two-dimensional boolean matrices whose rows are
                coalitions.

            batch_size: The size of the batch. If ``None``, the batch size is set to ``1``.
                Defaults to ``5``.

            **kwargs: Additional keyword arguments (unused).

        Returns:
            InteractionValues: The estimated interaction values.

        """
        batch_size = 1 if batch_size is None else batch_size
        used_budget = 0

        result = self._init_result()
        counts = self._init_result(dtype=int)

        # the value of the empty coalition serves as the baseline value
        empty_value = game(np.zeros(self.n, dtype=bool))[0]
        used_budget += 1

        # compute the number of iterations and size of the last batch (can be smaller than original)
        n_iterations, last_batch_size = self._calc_iteration_count(
            budget - used_budget,
            batch_size,
            self.iteration_cost,
        )

        # the considered orders depend only on the configuration, so compute them once
        orders = self._compute_order_iterator()

        # main permutation sampling loop
        for iteration in range(1, n_iterations + 1):
            # only the final iteration may use the (possibly smaller) last batch size
            current_batch_size = batch_size if iteration != n_iterations else last_batch_size

            # create the permutations: a 2d matrix of shape (batch_size, n) where each row is a
            # permutation of the players
            permutations = np.tile(np.arange(self.n), (current_batch_size, 1))
            self._rng.permuted(permutations, axis=1, out=permutations)
            n_permutations = permutations.shape[0]
            n_subsets = n_permutations * self.iteration_cost

            # get all subsets to evaluate per iteration
            subsets = np.zeros(shape=(n_subsets, self.n), dtype=bool)
            subset_index = 0
            for permutation_id in range(n_permutations):
                for order in orders:
                    for k in range(self.n - order + 1):
                        subset = permutations[permutation_id, k : k + order]
                        previous_subset = permutations[permutation_id, :k]
                        for subset_ in powerset(subset, min_size=0):
                            # astype(int) keeps the result usable as an index array even when
                            # both parts of the concatenation are empty
                            subset_eval = np.concatenate((previous_subset, subset_)).astype(int)
                            subsets[subset_index, subset_eval] = True
                            subset_index += 1

            # evaluate all subsets on the game
            game_values = game(subsets)

            # update the interaction scores by iterating over the permutations again; the loop
            # nesting must mirror the construction pass above so that ``subset_index`` addresses
            # the game value of the matching coalition
            subset_index = 0
            for permutation_id in range(n_permutations):
                for order in orders:
                    for k in range(self.n - order + 1):
                        interaction = permutations[permutation_id, k : k + order]
                        interaction = tuple(sorted(interaction))
                        interaction_index = self._interaction_lookup[interaction]
                        counts[interaction_index] += 1
                        # update the discrete derivative given the subset
                        # NOTE(review): only ``len(subset_)`` is used here, so this relies on
                        # ``powerset`` yielding subset sizes in the same sequence for the sorted
                        # tuple as for the permutation-ordered window above - confirm against
                        # ``shapiq.utils.sets.powerset``.
                        for subset_ in powerset(interaction, min_size=0):
                            game_value = game_values[subset_index]
                            update = game_value * (-1) ** (order - len(subset_))
                            result[interaction_index] += update
                            subset_index += 1

            used_budget += self.iteration_cost * current_batch_size

        # compute mean of interactions (entries that were never sampled keep their initial value)
        result = np.divide(result, counts, out=result, where=counts != 0)

        return InteractionValues(
            n_players=self.n,
            values=result,
            index=self.approximation_index,
            interaction_lookup=self._interaction_lookup,
            baseline_value=empty_value,
            min_order=self.min_order,
            max_order=self.max_order,
            estimated=True,
            estimation_budget=used_budget,
            target_index=self.index,
        )