Source code for shapiq.approximator.proxy.proxyspex

"""ProxySPEX approximator for sparse higher-order interactions."""

from __future__ import annotations

import math
from collections import defaultdict
from typing import TYPE_CHECKING, Any

import numpy as np
from sklearn.linear_model import RidgeCV
from sklearn.model_selection import GridSearchCV

from shapiq.approximator.base import Approximator
from shapiq.game_theory.moebius_converter import MoebiusConverter, ValidMoebiusConverterIndices
from shapiq.interaction_values import InteractionValues
from shapiq.tree.conversion import convert_tree_model
from shapiq.utils.sets import powerset

if TYPE_CHECKING:
    from collections.abc import Callable

    from shapiq.game import Game
    from shapiq.tree.base import TreeModel
    from shapiq.typing import Model


# ProxySPEX turns its Moebius coefficients into the requested index through
# MoebiusConverter, so it accepts exactly the index names that converter accepts.
ValidProxySPEXIndices = ValidMoebiusConverterIndices


class ProxySPEX(Approximator[ValidProxySPEXIndices]):
    """ProxySPEX (SParse EXplainer) via Fourier transform sampling.

    An approximator for cardinal interaction indices using Fourier transform sampling to
    efficiently compute sparse higher-order interactions. ProxySPEX is presented in
    :cite:t:`Butler.2025`.

    The pipeline implemented by :meth:`approximate` is: sample coalitions, fit a tree-based
    proxy model on the sampled game values, read the Fourier coefficients off the fitted
    trees, refine the dominant coefficients with a ridge regression, convert them to the
    Moebius transform and finally to the requested interaction index.
    """

    def __init__(
        self,
        *,
        n: int,
        max_order: int = 2,
        index: ValidProxySPEXIndices = "k-SII",
        proxy_model: Model | None = None,
        sampling_weights: np.ndarray | None = None,
        pairing_trick: bool = False,
        top_order: bool = False,
        random_state: int | None = None,
    ) -> None:
        """Initialize the ProxySPEX approximator.

        Args:
            n: Number of players (features).

            max_order: Maximum interaction order to consider.

            index: The Interaction index to use. All indices supported by shapiq's
                :class:`~shapiq.game_theory.moebius_converter.MoebiusConverter` are supported.

            top_order: If ``True``, only reports interactions of exactly order ``max_order``.
                Otherwise, reports all interactions up to order ``max_order``. Defaults to
                ``False``.

            pairing_trick: If ``True``, the pairing trick is applied to the sampling procedure.
                Defaults to ``False``.

            sampling_weights: An optional array of weights for the sampling procedure. The
                weights must be of shape ``(n + 1,)`` and are used to determine the probability
                of sampling a coalition of a certain size. Defaults to ``None``, in which case
                size ``i`` is weighted by ``math.comb(n, i)``.

            proxy_model: Optional proxy model to use for approximating the value function. If
                ``None``, a default LightGBM regressor with hyperparameter tuning (a 5-fold
                ``GridSearchCV``) will be used. The model must implement the scikit-learn
                regressor interface (i.e., it must have fit and predict methods). Defaults to
                ``None``.

            random_state: Seed for random number generator. Defaults to ``None``.

        Raises:
            ImportError: If ``proxy_model`` is ``None`` and ``lightgbm`` is not installed.
        """
        if sampling_weights is None:
            # Weight every coalition size by the number of coalitions of that size.
            sampling_weights = np.array([math.comb(n, i) for i in range(n + 1)], dtype=float)

        if proxy_model is None:
            # lightgbm is an optional dependency, so it is imported only when needed.
            try:
                import lightgbm as lgb
            except ImportError as err:
                msg = (
                    "The 'lightgbm' package is required for the default proxy model in ProxySPEX "
                    "but it is not installed. Install it with: pip install 'shapiq[proxy]' or "
                    "provide a custom proxy_model that implements the scikit-learn regressor "
                    "interface."
                )
                raise ImportError(msg) from err

            decoder_args = {
                "max_depth": [3, 5],
                "max_iter": [500, 1000],
                "learning_rate": [0.01, 0.1],
            }
            base_model = lgb.LGBMRegressor(random_state=random_state)
            self.proxy_model = GridSearchCV(
                estimator=base_model,
                param_grid=decoder_args,
                scoring="r2",
                cv=5,
                verbose=0,
                n_jobs=1,
            )
        else:
            self.proxy_model = proxy_model

        super().__init__(
            n=n,
            max_order=max_order,
            index=index,
            top_order=top_order,
            pairing_trick=pairing_trick,
            random_state=random_state,
            sampling_weights=sampling_weights,
            initialize_dict=True,
        )

    def approximate(
        self,
        budget: int,
        game: Game | Callable[[np.ndarray], np.ndarray],
        **kwargs: Any,  # noqa: ARG002
    ) -> InteractionValues:
        """Approximates the interaction values using a sparse transform approach.

        Args:
            budget: The budget for the approximation.

            game: The game function that returns the values for the coalitions.

            **kwargs: Additional keyword arguments (not used).

        Returns:
            The approximated Shapley interaction values.
        """
        # Draw `budget` coalitions from the sampler and evaluate the game on them
        self._sampler.sample(budget)
        coalitions_matrix = self._sampler.coalitions_matrix
        coalition_values = game(coalitions_matrix)

        # Fit the model on the training data
        self.proxy_model.fit(coalitions_matrix, coalition_values)
        if isinstance(self.proxy_model, GridSearchCV):
            # Only the winning estimator of the grid search is converted to trees
            final_model = self.proxy_model.best_estimator_
        else:
            final_model = self.proxy_model

        # Obtain TreeModel
        tree_models = convert_tree_model(final_model)

        # Obtain fourier coefficients
        unrefined_fourier = self._sklearn_to_fourier(tree_models=tree_models)

        # Refine the Fourier coefficients using the training data
        refined_fourier = self._refine(
            unrefined_fourier,
            coalitions_matrix,
            coalition_values,
        )

        # Convert the Fourier coefficients to the Moebius transform
        moebius_transform = self.fourier_to_moebius(refined_fourier)

        # Convert the Moebius transform to the desired index
        result = self._process_moebius(moebius_transform=moebius_transform)

        # Filter the output as needed
        if self.top_order:
            result = self._filter_order(result)

        # finalize the interactions
        return InteractionValues(
            values=result,
            index=self.approximation_index,
            min_order=self.min_order,
            max_order=self.max_order,
            n_players=self.n,
            interaction_lookup=self.interaction_lookup,
            estimated=True,
            estimation_budget=budget,
            baseline_value=result[self.interaction_lookup[()]]
            if () in self.interaction_lookup
            else 0.0,
            target_index=self.index,
        )

    def fourier_to_moebius(
        self, four_dict: dict[tuple[int, ...], float]
    ) -> dict[tuple[int, ...], float]:
        """Converts a Fourier representation of a function to its Moebius representation.

        Each Fourier coefficient of an interaction ``S`` is distributed over all subsets
        ``T`` of ``S`` with the factor ``(-2) ** len(T)``. This is the expansion of the
        basis function ``prod_{i in S} (1 - 2 * x_i)`` into monomials in ``x``.

        Args:
            four_dict: Mapping from interaction tuples to Fourier coefficients.

        Returns:
            Mapping from interaction tuples to Moebius coefficients.
        """
        moebius_dict = defaultdict(float)
        for four_interaction, four_coef in four_dict.items():
            for moebius_interaction in powerset(four_interaction):
                moebius_dict[moebius_interaction] += four_coef * (-2) ** (len(moebius_interaction))
        return dict(moebius_dict)

    def _sklearn_to_fourier(self, tree_models: list[TreeModel]) -> dict[tuple[int, ...], float]:
        """Extracts the aggregated Fourier coefficients from a list of sklearn tree models.

        This method iterates over all trees in the ensemble, computes the Fourier coefficients
        for each individual tree using the `_sklearn_tree_to_fourier` helper method, and then
        sums these coefficients to get the final Fourier representation of the complete model.

        Args:
            tree_models: A list of `TreeModel` instances representing the tree using the
                sklearn.tree structure.

        Returns:
            A dictionary that maps interaction tuples (representing Fourier frequencies) to
            their aggregated Fourier coefficients. Coefficients that sum to exactly ``0.0``
            are omitted.
        """
        aggregated_coeffs = defaultdict(float)
        for tree_model in tree_models:
            tree_coeffs = self._sklearn_tree_to_fourier(tree_model)
            for interaction, value in tree_coeffs.items():
                aggregated_coeffs[interaction] += value

        # Convert defaultdict to a standard dict, removing zero-valued coefficients
        return {k: v for k, v in aggregated_coeffs.items() if v != 0.0}

    def _sklearn_tree_to_fourier(self, tree_model: TreeModel) -> dict[tuple[int, ...], float]:
        """Recursively extracts Fourier coefficients from a single sklearn decision tree.

        This method traverses the sklearn decision tree structure defined in
        `shapiq.tree.base`. It computes the Fourier representation of the piecewise-constant
        function that the tree defines.

        Args:
            tree_model: A single converted tree. The attributes ``children_left``,
                ``children_right``, ``features`` and ``values`` are read, all indexed by
                node id, with the root at id ``0`` and ``-1`` marking a missing child.

        Returns:
            A dictionary mapping interaction tuples to the Fourier coefficients of this tree.
        """

        def _combine_coeffs(
            left_coeffs: dict[tuple[int, ...], float],
            right_coeffs: dict[tuple[int, ...], float],
            feature_idx: int,
        ) -> dict[tuple[int, ...], float]:
            """Combines Fourier coefficients from the left and right children of a split node."""
            combined_coeffs = {}
            all_interactions = set(left_coeffs.keys()) | set(right_coeffs.keys())
            for interaction in all_interactions:
                left_val = left_coeffs.get(interaction, 0.0)
                right_val = right_coeffs.get(interaction, 0.0)
                # Part of the node function that does not depend on the split feature.
                combined_coeffs[interaction] = (left_val + right_val) / 2
                # Part that flips sign with the split feature. Under the basis
                # (-1) ** x_f this puts the left child at x_f = 0 and the right at x_f = 1.
                # NOTE(review): if feature_idx already occurs in `interaction` (the same
                # feature split again deeper in the subtree), new_interaction equals
                # interaction and this overwrites the average stored above. Presumably
                # such trees do not occur for binary coalition inputs; confirm.
                new_interaction = tuple(sorted(set(interaction) | {feature_idx}))
                combined_coeffs[new_interaction] = (left_val - right_val) / 2
            return combined_coeffs

        def _dfs_traverse(node_id: int) -> dict[tuple[int, ...], float]:
            """Performs a depth-first traversal of the tree to compute coefficients."""
            # Base case: if the node is a leaf, its function is a constant.
            if tree_model.children_left[node_id] == -1 and tree_model.children_right[node_id] == -1:
                # The only non-zero coefficient is for the empty interaction (the bias term).
                return {(): tree_model.values[node_id]}

            # Recursive step: if the node is a split node.
            left_child_id = tree_model.children_left[node_id]
            right_child_id = tree_model.children_right[node_id]
            left_coeffs = _dfs_traverse(left_child_id)
            right_coeffs = _dfs_traverse(right_child_id)
            feature_idx = tree_model.features[node_id]
            return _combine_coeffs(left_coeffs, right_coeffs, feature_idx)

        return _dfs_traverse(0)  # Start traversal from the root node (id=0)

    def _refine(
        self,
        four_dict: dict[tuple[int, ...], float],
        train_X: np.ndarray,
        train_y: np.ndarray,
    ) -> dict[tuple[int, ...], float]:
        """Refines the estimated Fourier coefficients using a Ridge regression model.

        This method takes an initial set of estimated Fourier coefficients and refines them
        to better fit the observed game values. It first identifies the most significant
        coefficients by keeping those that contribute to 95% of the total "energy" (sum of
        squared Fourier coefficients, excluding the baseline). Then, it constructs a new
        feature matrix based on the Fourier basis functions corresponding to these
        significant interactions. Finally, it fits a `RidgeCV` model to re-estimate the values
        of these coefficients, effectively fine-tuning them against the training data.

        The input is returned unchanged when it holds at most ``self.n`` coefficients, when
        all non-baseline coefficients are zero, or when no coefficient survives the
        truncation.

        Args:
            four_dict: A dictionary mapping interaction tuples to their initial estimated
                Fourier coefficient values.

            train_X: The training data matrix where rows are coalitions (binary vectors) and
                columns are players.

            train_y: The corresponding game values for each coalition in `train_X`.

        Returns:
            A dictionary containing the refined Fourier coefficients for the most significant
            interactions.
        """
        n = train_X.shape[1]
        four_items = list(four_dict.items())
        if len(four_items) <= self.n:
            return four_dict

        list_keys = [item[0] for item in four_items]
        four_coefs = np.array([item[1] for item in four_items])

        # The baseline (empty interaction) does not count towards the energy budget.
        nfc_idx = list_keys.index(()) if () in list_keys else None
        four_coefs_for_energy = np.copy(four_coefs)
        if nfc_idx is not None:
            four_coefs_for_energy[nfc_idx] = 0
        four_coefs_sq = four_coefs_for_energy**2
        tot_energy = np.sum(four_coefs_sq)
        if tot_energy == 0:
            return four_dict

        sorted_four_coefs_sq = np.sort(four_coefs_sq)[::-1]
        cumulative_energy_ratio = np.cumsum(sorted_four_coefs_sq / tot_energy)
        # Position just after the first coefficient at which 95% of the energy is reached.
        thresh_idx_95 = np.argmin(cumulative_energy_ratio < 0.95) + 1
        if thresh_idx_95 < len(sorted_four_coefs_sq):
            thresh = np.sqrt(sorted_four_coefs_sq[thresh_idx_95])
        else:
            # Every coefficient is needed to reach 95%, so there is no "next" coefficient
            # to use as a cut-off (indexing would raise IndexError). A zero threshold
            # keeps all of them.
            thresh = 0.0

        # Keys are re-encoded as length-n 0/1 indicator tuples for the basis construction.
        four_dict_trunc = {
            tuple(int(i in k) for i in range(n)): v for k, v in four_dict.items() if abs(v) > thresh
        }
        if not four_dict_trunc:
            # Ties at the threshold can remove every coefficient; with an empty support
            # the regression below cannot be set up, so keep the unrefined values.
            return four_dict
        support = np.array(list(four_dict_trunc.keys()))

        # Construct the fourier basis coefficient matrix for the training data
        X = np.real(np.exp(train_X @ (1j * np.pi * support.T)))

        # Solve the regression problem to obtain refined Fourier coefficients
        reg = RidgeCV(alphas=np.logspace(-6, 6, 100), fit_intercept=False).fit(X, train_y)
        regression_coefs = dict(
            zip([tuple(s.astype(int)) for s in support], reg.coef_, strict=False)
        )
        # Map the indicator tuples back to tuples of player indices.
        return {tuple(i for i, x in enumerate(k) if x): v for k, v in regression_coefs.items()}

    def _process_moebius(self, moebius_transform: dict[tuple, float]) -> np.ndarray:
        """Convert the Moebius transform into the desired index.

        Args:
            moebius_transform: The Moebius transform to process as a dict mapping tuples to
                float values.

        Returns:
            np.ndarray: The converted interaction values based on the specified index. The
            function also updates the internal _interaction_lookup dictionary.
        """
        moebius_interactions = InteractionValues(
            values=moebius_transform,
            index="Moebius",
            min_order=self.min_order,
            max_order=self.max_order,
            n_players=self.n,
            estimated=True,
            baseline_value=moebius_transform.get((), 0.0),
        )
        autoconverter = MoebiusConverter(moebius_coefficients=moebius_interactions)
        converted_interaction_values = autoconverter(index=self.index, order=self.max_order)
        self._interaction_lookup = converted_interaction_values.interaction_lookup
        return converted_interaction_values.values

    def _filter_order(self, result: np.ndarray) -> np.ndarray:
        """Filters the interactions to keep only those of the maximum order.

        This method is used when top_order=True to filter out all interactions that are not
        of exactly the maximum order (self.max_order).

        Args:
            result: Array of interaction values.

        Returns:
            Filtered array containing only interaction values of the maximum order. The
            method also updates the internal _interaction_lookup dictionary.
        """
        filtered_interactions = {}
        filtered_results = []
        i = 0
        for j, key in enumerate(self.interaction_lookup):
            if len(key) == self.max_order:
                # Re-number the kept interactions consecutively from zero.
                filtered_interactions[key] = i
                filtered_results.append(result[j])
                i += 1
        self._interaction_lookup = filtered_interactions
        return np.array(filtered_results)