Source code for shapiq.approximator.permutation.sv

"""This module contains the permutation sampling approximation method for the Shapley value (SV)."""

from __future__ import annotations

from typing import TYPE_CHECKING, Any, Literal

import numpy as np

from shapiq.approximator.base import Approximator
from shapiq.interaction_values import InteractionValues

if TYPE_CHECKING:
    from collections.abc import Callable

    from shapiq.game import Game
    from shapiq.typing import FloatVector, IntVector


# Index names accepted by the approximator in this module; only the Shapley value ("SV").
ValidPermutationSVIndices = Literal["SV"]


class PermutationSamplingSV(Approximator[ValidPermutationSVIndices]):
    """Estimate Shapley values by sampling random orderings of the players.

    Permutation Sampling :cite:t:`Castro.2009` (also known as ApproShapley) draws random
    permutations of the player set and, for each permutation, collects the marginal
    contribution of every player when it joins the players that precede it. The Shapley
    value estimate of a player is the average of its collected marginal contributions. For
    details, see Castro et al. (2009) :cite:t:`Castro.2009`.

    Args:
        n: The number of players.

        random_state: The random state to use for the permutation sampling. Defaults to
            ``None``.

    See Also:
        - :class:`~shapiq.approximator.permutation.sii.PermutationSamplingSII`: The
            Permutation Sampling approximator for the SII index
        - :class:`~shapiq.approximator.permutation.stii.PermutationSamplingSTII`: The
            Permutation Sampling approximator for the STII index

    """

    valid_indices: tuple[ValidPermutationSVIndices, ...] = ("SV",)

    def __init__(
        self,
        n: int,
        random_state: int | None = None,
        **kwargs: Any,  # noqa: ARG002
    ) -> None:
        """Set up the Permutation Sampling approximator for Shapley values.

        Args:
            n: The number of players.

            random_state: The random state to use for the permutation sampling. Defaults to
                ``None``.

            **kwargs: Additional keyword arguments (ignored, accepted only for compatibility).

        """
        super().__init__(n=n, max_order=1, index="SV", top_order=False, random_state=random_state)
        # One permutation needs n - 1 evaluations: the empty and the full coalition are
        # evaluated once up front and shared by all permutations.
        self.iteration_cost: int = n - 1

    def approximate(
        self,
        budget: int,
        game: Game | Callable[[np.ndarray], np.ndarray],
        batch_size: int | None = 5,
        **kwargs: Any,  # noqa: ARG002
    ) -> InteractionValues:
        """Approximate the Shapley values with ApproShapley.

        Two evaluations are always spent on the empty and the full coalition; the remaining
        ``budget - 2`` evaluations are spent on prefixes of randomly drawn permutations.

        Args:
            budget: The number of game evaluations for approximation.

            game: The game function as a callable that takes a set of players and returns
                the value.

            batch_size: The number of permutations evaluated per call to ``game``. If
                ``None``, the batch size is set to ``1``. Defaults to ``5``.

            **kwargs: Additional keyword arguments (not used, only for compatibility).

        Returns:
            The estimated interaction values.

        """
        n_players = self.n
        lookup = self._interaction_lookup

        estimates: FloatVector = self._init_result()
        n_updates: IntVector = self._init_result(dtype=int)
        if batch_size is None:
            batch_size = 1

        # The empty and the full coalition are shared by every permutation, so they are
        # evaluated once here instead of twice per permutation.
        empty_val = float(game(np.zeros(n_players, dtype=bool))[0])
        full_val = float(game(np.ones(n_players, dtype=bool))[0])
        used_budget = 2

        def package(values: FloatVector, spent: int) -> InteractionValues:
            # Wrap the estimates together with the approximator's metadata.
            return InteractionValues(
                values=values,
                interaction_lookup=self._interaction_lookup,
                baseline_value=empty_val,
                min_order=self.min_order,
                max_order=self.max_order,
                n_players=self.n,
                index=self.approximation_index,
                estimated=True,
                estimation_budget=spent,
                target_index=self.index,
            )

        # A single-player game has no proper prefixes to iterate over: the Shapley value of
        # the only player is the gap between the full and the empty coalition.
        if n_players == 1:
            slot = lookup[self._grand_coalition_tuple]
            estimates[slot] = full_val - empty_val
            n_updates[slot] = 1
            return package(estimates, used_budget)

        # Number of batches and the size of the final batch (which may be smaller).
        n_iterations, last_batch_size = self._calc_iteration_count(
            budget - 2,
            batch_size,
            self.iteration_cost,
        )

        for iteration in range(1, n_iterations + 1):
            if iteration == n_iterations:
                batch_size = last_batch_size

            # Matrix of shape (batch_size, n): every row starts as 0..n-1 and is then
            # shuffled in place, so each row is one permutation.
            permutations = np.tile(np.arange(n_players), (batch_size, 1))
            self._rng.permuted(permutations, axis=1, out=permutations)

            # One row per proper, non-empty prefix of every permutation.
            coalitions = np.zeros(
                shape=(batch_size * self.iteration_cost, n_players),
                dtype=bool,
            )
            row = 0
            for permutation in permutations:
                for prefix_len in range(1, n_players):
                    coalitions[row, permutation[:prefix_len]] = True
                    row += 1

            # Evaluate all prefixes of the batch in a single call.
            game_values = game(coalitions)
            used_budget += len(coalitions)

            # Walk through the values in the same order the coalitions were generated.
            row = 0
            for permutation in permutations:
                # First player: joins the empty coalition.
                delta = game_values[row] - empty_val
                slot = lookup[(permutation[0],)]
                estimates[slot] += delta
                n_updates[slot] += 1

                # Middle players: difference between two consecutive prefixes.
                for position in range(1, n_players - 1):
                    delta = game_values[row + 1] - game_values[row]
                    slot = lookup[(permutation[position],)]
                    estimates[slot] += delta
                    n_updates[slot] += 1
                    row += 1

                # Last player: completes the full coalition.
                delta = full_val - game_values[row]
                slot = lookup[(permutation[n_players - 1],)]
                estimates[slot] += delta
                n_updates[slot] += 1
                row += 1

        # Average the accumulated contributions; entries that were never updated are left
        # untouched.
        estimates = np.divide(estimates, n_updates, out=estimates, where=n_updates != 0)

        return package(estimates, used_budget)