"""ProxySHAP approximator class."""
from __future__ import annotations

from typing import TYPE_CHECKING, Literal, get_args

import numpy as np

from shapiq.approximator.base import Approximator
from shapiq.approximator.montecarlo.shapiq import SHAPIQ
from shapiq.approximator.montecarlo.svarmiq import SVARMIQ
from shapiq.approximator.regression.kernelshapiq import KernelSHAPIQ, ValidKernelSHAPIQIndices
from shapiq.game import Game
from shapiq.interaction_values import InteractionValues
from shapiq.tree.interventional.explainer import InterventionalTreeExplainer

if TYPE_CHECKING:
    from collections.abc import Callable

    from shapiq.typing import CoalitionMatrix, FloatVector, GameValues
# Interaction index names accepted by ProxySHAP for its ``index`` argument.
ValidProxySHAPIndices = Literal["k-SII", "FSII", "FBII", "SII", "SV", "BV"]
class ResidualGame(Game):
    """Game that replays a fixed, precomputed vector of residual values.

    The game does not evaluate coalitions itself: whatever coalitions it is asked about,
    it hands back the stored vector. It is therefore only meaningful when the caller
    queries the coalitions in the same order in which the stored values were produced.
    """

    def __init__(self, n_players: int, game_values: np.ndarray) -> None:
        """Create the game and remember the residual values to replay.

        Args:
            n_players: Number of players of the game.
            game_values: One residual value per coalition. The array is stored as-is
                (no copy) on ``self.vals``.
        """
        # Normalization is switched off so the stored values are returned untouched.
        super().__init__(n_players=n_players, normalize=False)
        self.vals = game_values

    def value_function(self, coalitions: CoalitionMatrix) -> GameValues:  # noqa: ARG002
        """Return the stored residual values, ignoring ``coalitions``.

        Args:
            coalitions: Binary matrix of shape ``(n_samples, n_features)``; a 1 marks a
                feature that is part of the coalition, a 0 one that is not. The argument
                is not inspected: the i-th row is expected to correspond to the i-th
                entry of ``self.vals``.

        Returns:
            The stored vector ``self.vals``, expected to hold one value per row of
            ``coalitions``.
        """
        return self.vals
# [docs]  (Sphinx viewcode link marker left over from HTML extraction; not code)
class ProxySHAP(Approximator[ValidProxySHAPIndices]):
    """Proxy-based approximator that corrects a surrogate model with a residual adjustment.

    ProxySHAP fits a regression model (the proxy) on the sampled coalitions, computes the
    requested interaction index exactly for that proxy, and then applies an adjustment
    method to the proxy's residuals so that the result better matches the true value
    function.

    It extends RegressionMSR to any-order cardinal-probabilistic indices and supports
    multiple adjustment methods, including MSR, SVARMIQ, and KernelSHAPIQ.

    Example:
        >>> from shapiq_games.synthetic import DummyGame
        >>> from shapiq.approximator import ProxySHAP
        >>> game = DummyGame(n=5, interaction=(1, 2))
        >>> approximator = ProxySHAP(n=5, max_order=2, index="k-SII", adjustment="svarm")
        >>> approximator.approximate(budget=20, game=game)
        InteractionValues(
            index=k-SII, max_order=2, estimated=True, estimation_budget=20
        )
    """

    def __init__(
        self,
        n: int,
        *,
        max_order: int = 2,
        index: ValidProxySHAPIndices = "k-SII",
        proxy_model: object | None = None,
        adjustment: str = "msr",
        sampling_weights: FloatVector | None = None,
        pairing_trick: bool = True,
        random_state: int | None = None,
    ) -> None:
        """Initialize the ProxySHAP approximator.

        Args:
            n: Number of features (players).
            max_order: Maximum order of interactions to consider.
            index: Interaction index to approximate. One of ``"k-SII"``, ``"FSII"``,
                ``"FBII"``, ``"SII"``, ``"SV"`` or ``"BV"``.
            proxy_model: Optional proxy model used to approximate the value function. It
                is fitted with ``fit``, queried with ``predict`` and explained with
                ``InterventionalTreeExplainer``. If None, a default XGBoost regressor is
                used.
            adjustment: Method for adjusting the proxy model's predictions to better
                match the true value function. Options are ``"none"`` (no adjustment),
                ``"msr"``, ``"svarm"`` (stratified MSR) and ``"kernel"`` (KernelSHAPIQ).
            sampling_weights: Optional array of weights for the sampling procedure. The
                weights must be of shape ``(n + 1,)`` and are used to determine the
                probability of sampling a coalition. Defaults to None.
            pairing_trick: If True, the pairing trick is applied to the sampling
                procedure. Defaults to True.
            random_state: The random state of the estimator. Defaults to None.

        Raises:
            ImportError: If no ``proxy_model`` is given and XGBoost is not installed.
            ValueError: If ``adjustment`` is unknown, or if ``"kernel"`` is requested for
                an index that KernelSHAPIQ does not support.
        """
        super().__init__(
            n=n,
            max_order=max_order,
            index=index,
            sampling_weights=sampling_weights,
            pairing_trick=pairing_trick,
            random_state=random_state,
            initialize_dict=False,
        )
        # Kept so the adjustment approximator can be built with the same sampling setup.
        self._sampling_weights = sampling_weights
        self._pairing_trick = pairing_trick

        if proxy_model is None:
            # XGBoost is an optional dependency; import it only when it is needed.
            try:
                from xgboost import XGBRegressor
            except ImportError as e:
                msg = "XGBoost is required for the default proxy model. Install it with: pip install 'shapiq[proxy]' or provide a custom proxy_model that implements the scikit-learn regressor interface."
                raise ImportError(msg) from e
            proxy_model = XGBRegressor(random_state=random_state)
        self.proxy_model = proxy_model

        self.set_adjustment_method(adjustment)

    def set_adjustment_method(self, adjustment: str) -> None:
        """Select the method for adjusting the proxy model's predictions.

        The approximator is only modified once the request has been validated and the
        adjustment approximator has been built, so a failing call leaves the previous
        configuration intact.

        Args:
            adjustment: One of ``"none"``, ``"msr"``, ``"svarm"`` or ``"kernel"``.

        Raises:
            ValueError: If ``adjustment`` is not one of the supported names, or if
                ``"kernel"`` is requested while ``self.index`` is not supported by
                KernelSHAPIQ.
        """
        adjustment_classes = {"msr": SHAPIQ, "svarm": SVARMIQ, "kernel": KernelSHAPIQ}
        if adjustment != "none" and adjustment not in adjustment_classes:
            msg = f"Invalid adjustment method: {adjustment}"
            raise ValueError(msg)

        if adjustment == "kernel":
            # A ``Literal`` alias is not a container, so ``in`` cannot be applied to it
            # directly; ``get_args`` extracts the allowed names. The fallback covers the
            # case where the imported object is a plain collection instead.
            kernel_indices = get_args(ValidKernelSHAPIQIndices) or tuple(ValidKernelSHAPIQIndices)
            if self.index not in kernel_indices:
                msg = f"KernelSHAPIQ adjustment is only supported for indices {ValidKernelSHAPIQIndices}, but got index {self.index}"
                raise ValueError(msg)

        if adjustment == "none":
            # Nothing to run; also drops any approximator left from a previous setting.
            adjustment_method = None
        else:
            adjustment_method = adjustment_classes[adjustment](
                n=self.n,
                max_order=self.max_order,
                index=self.index,
                sampling_weights=self._sampling_weights,
                pairing_trick=self._pairing_trick,
                random_state=self._random_state,
            )

        self.adjustment = adjustment
        self.adjustment_method = adjustment_method

    def approximate(
        self, budget: int, game: Game | Callable[[np.ndarray], np.ndarray], **_: dict
    ) -> InteractionValues:
        """Approximate the interaction values using the proxy model and adjustment method.

        The procedure is: sample coalitions and evaluate the game on them, fit the proxy
        model on the baseline-normalized values, compute the index for the proxy with
        ``InterventionalTreeExplainer``, and, unless the adjustment is ``"none"``, add the
        adjustment approximator's estimate for the proxy's residuals.

        Args:
            budget: Number of coalitions to sample and evaluate.
            game: The game (or a callable mapping a coalition matrix to a vector of
                values) to approximate.
            **_: Additional keyword arguments; ignored.

        Returns:
            The approximated interaction values. Both ``baseline_value`` and the entry of
            the empty interaction are set to the game value of the first sampled
            coalition.
        """
        # 1. Sample coalitions, evaluate the game and fit the proxy model.
        self._sampler.sample(budget)
        coalitions_matrix = self._sampler.coalitions_matrix
        # Work on a float copy so the array returned by the game is never mutated.
        coalition_values = np.array(game(coalitions_matrix), dtype=float)
        # The first sampled coalition is treated as the empty coalition.
        baseline_value = float(coalition_values[0])
        coalition_values -= baseline_value  # Normalize values
        self.proxy_model.fit(  # ty: ignore[unresolved-attribute]
            coalitions_matrix, coalition_values
        )

        # 2. Compute exact index&max_order for the proxy model.
        explainer = InterventionalTreeExplainer(
            self.proxy_model,
            data=np.zeros((1, self.n)),  # reference data for boolean tree
            class_index=None,
            index=self.approximation_index,
            max_order=self.max_order,
            bool_tree=True,
        )
        proxy_values = explainer.explain_function(np.ones((1, self.n)))
        proxy_interactions = InteractionValues(
            values=proxy_values.interactions,
            index=self.approximation_index,
            max_order=self.max_order,
            n_players=self.n,
            min_order=0,
            # Only a budget below the number of all 2**n coalitions yields an estimate.
            estimated=budget < 2**self.n,
            estimation_budget=budget,
            baseline_value=baseline_value,
            target_index=self.index,
        )

        # 3. Correct the proxy result with the adjustment method run on the residuals.
        if self.adjustment != "none":
            proxy_predictions = self.proxy_model.predict(  # ty: ignore[unresolved-attribute]
                coalitions_matrix
            )
            residual_values = coalition_values - proxy_predictions
            residual_values = residual_values - residual_values[0]  # Normalize residuals
            # NOTE(review): ResidualGame returns the stored residuals regardless of the
            # coalitions it is asked about, so this relies on the adjustment approximator
            # sampling the same coalitions in the same order as ``self._sampler`` (it is
            # built with the same sampling weights, pairing trick and random state).
            residual_game = ResidualGame(n_players=self.n, game_values=residual_values)
            proxy_interactions += self.adjustment_method.approximate(budget, residual_game)

        # Re-assert the baseline after the addition above may have changed it.
        proxy_interactions.baseline_value = baseline_value
        proxy_interactions[()] = baseline_value  # Ensure empty coalition value is correct
        return proxy_interactions