# Source code for shapiq.tree.interventional.game

"""Interventional game for tree-based models."""

from __future__ import annotations

import numpy as np

from shapiq.game import Game
from shapiq.utils.modules import safe_isinstance


class InterventionalGame(Game):
    """A cooperative game for interventional tree-based model explanations.

    The value of a coalition is the model output averaged over the reference data, where each
    reference row takes the target instance's feature values for the players in the coalition
    and keeps its own values for all other features.
    """

    def __init__(
        self,
        model: object,
        reference_data: np.ndarray,
        target_instance: np.ndarray,
        class_index: int | None = None,
    ) -> None:
        """Initialize the InterventionalGame.

        Args:
            model: The tree-based model to explain.
            reference_data: Background dataset used as reference.
            target_instance: The instance to explain. A 1-D array is reshaped into a single row.
            class_index: Class index for classification models. Defaults to ``None``; if the
                model exposes ``predict_proba`` and no index is given, class ``1`` is used.
        """
        if target_instance.ndim == 1:
            target_instance = target_instance.reshape(1, -1)

        # One player per feature (column) of the target instance.
        super().__init__(
            n_players=target_instance.shape[1], normalize=False, normalization_value=0
        )

        # Set class index if classification model
        if hasattr(model, "predict_proba") and class_index is None:
            class_index = 1  # default to positive class for binary classification

        self.model = model
        self.data = reference_data
        self.target_instance = target_instance
        self.class_index = class_index

    def value_function(self, coalitions: np.ndarray) -> np.ndarray:
        """Compute the value function for the given coalitions.

        Args:
            coalitions: Boolean array of shape (n_coalitions, n_players). A 1-D array is
                treated as a single coalition.

        Returns:
            Array of shape (n_coalitions,) with the value of each coalition.

        Raises:
            ValueError: If a binary booster is explained with ``class_index`` not in {0, 1}.
        """
        coalitions = np.asarray(coalitions)
        if coalitions.ndim == 1:
            # A flat mask is one coalition; indexing it row-wise would yield scalar masks.
            coalitions = coalitions.reshape(1, -1)

        values = np.zeros(coalitions.shape[0])
        for i, coalition in enumerate(coalitions):
            # The (n_players,) mask broadcasts against the target row(s) and the reference
            # rows, producing an imputed copy of the reference data for this coalition.
            # NOTE(review): assumes reference_data is (n_samples, n_players).
            imputed = np.where(coalition, self.target_instance, self.data)
            values[i] = np.mean(self._predict(imputed))
        return values

    def _predict(self, instances: np.ndarray) -> np.ndarray:
        """Return the per-row model output that is being explained for ``instances``."""
        if self.class_index is None:
            return self.model.predict(instances)  # ty: ignore[unresolved-attribute]

        if safe_isinstance(self.model, "xgboost.sklearn.XGBClassifier"):
            import xgboost as xgb

            # For XGBClassifier, we need to use DMatrix for prediction with output_margin
            booster = self.model.get_booster()  # ty: ignore[unresolved-attribute]
            logits = booster.predict(xgb.DMatrix(instances), output_margin=True)
            return self._select_class(logits)

        if safe_isinstance(self.model, "lightgbm.LGBMClassifier"):
            raw_scores = self.model.predict(  # ty: ignore[unresolved-attribute]
                instances, raw_score=True
            )
            return self._select_class(raw_scores)

        # Any other classifier is explained on its predicted probability for the class.
        proba = self.model.predict_proba(instances)  # ty: ignore[unresolved-attribute]
        return proba[:, self.class_index]

    def _select_class(self, scores: np.ndarray) -> np.ndarray:
        """Pick the column of raw (margin) scores that belongs to ``self.class_index``.

        A 1-D score array (binary classification) holds one margin per row, which is used as
        the score of class ``1``; class ``0`` uses its negation. A 2-D array holds one column
        per class.
        """
        if scores.ndim == 1:
            if self.class_index == 1:
                return scores
            if self.class_index == 0:
                return -scores
            msg = (
                "class_index must be 0 or 1 for binary classification, "
                f"got {self.class_index}."
            )
            raise ValueError(msg)
        return scores[:, self.class_index]