Source code for shapiq.tree.interventional.explainer

"""Interventional TreeShap Explainer Implementation."""

from __future__ import annotations

from math import comb
from typing import TYPE_CHECKING

import numpy as np
from scipy.special import binom

from shapiq.game_theory.indices import get_computation_index
from shapiq.interaction_values import InteractionValues
from shapiq.tree.validation import validate_tree_model

# Largest dense result size (number of interaction entries, as counted by
# ``_dense_flatten_result_size``) for which the dense flatten path is still used;
# anything above this budget is routed to the sparse path instead.
# For scale: max_order=3 with 100 features needs 166,750 entries, and max_order=4
# with 20 features would need 6,195. Note that the explainer additionally routes
# every max_order > 3 to the sparse path regardless of this budget.
_DENSE_FLATTEN_MAX_RESULT_SIZE = 1_000_000


def _dense_flatten_result_size(n_features: int, max_order: int) -> int:
    return sum(comb(n_features, k) for k in range(1, max_order + 1))


if TYPE_CHECKING:
    from collections.abc import Callable

    from shapiq.tree.base import TreeModel

# Interaction indices accepted by the C implementation. Each index is listed exactly
# once ("SV" used to appear twice), so the list can be iterated without duplicates.
INDICES_C_IMPLEMENTATION_CAPABLE = [
    "SV",
    "SII",
    "BII",
    "BV",
    "CHII",
    "CV",
    "FBII",
    "FSII",
    "STII",
    "CUSTOM",
]


def obtain_E_R_values(tree: TreeModel) -> tuple[list[np.ndarray], list[np.ndarray], list[float]]:
    """Enumerate the reachable leaves of a tree together with their E and R feature sets.

    Taking a left branch is read as "the split feature equals 0" (feature joins R) and
    taking a right branch as "the split feature equals 1" (feature joins E). Branches
    that contradict a constraint already collected on the path are not followed.

    Args:
        tree: Tree exposing ``children_left``, ``children_right``, ``features`` and
            ``values`` indexed by node id, with node ``0`` as the root.

    Returns:
        A tuple ``(E, R, leaf_vals)`` with one entry per visited leaf: the sorted
        feature ids constrained to 1, the sorted feature ids constrained to 0, and
        the leaf value.
    """
    e_per_leaf: list[np.ndarray] = []
    r_per_leaf: list[np.ndarray] = []
    leaf_values: list[float] = []

    # Explicit stack instead of recursion; each item is (node, features=1, features=0).
    pending = [(0, frozenset(), frozenset())]
    while pending:
        node, ones, zeros = pending.pop()
        left_child = tree.children_left[node]
        right_child = tree.children_right[node]

        if left_child != right_child:
            # Internal node: branch on the split feature.
            split = int(tree.features[node])
            if split not in ones:
                # Left means split == 0; impossible if the path already fixed it to 1.
                pending.append((left_child, ones, zeros | {split}))
            if split not in zeros:
                # Right means split == 1; impossible if the path already fixed it to 0.
                pending.append((right_child, ones | {split}, zeros))
        else:
            # Both children identical marks a leaf: record the collected constraints.
            e_per_leaf.append(np.array(sorted(ones), dtype=np.int64))
            r_per_leaf.append(np.array(sorted(zeros), dtype=np.int64))
            leaf_values.append(tree.values[node].item())

    return e_per_leaf, r_per_leaf, leaf_values


def obtain_E_R_values_point(
    tree: TreeModel, point_to_explain: np.ndarray, reference_point: np.ndarray
) -> tuple[list[np.ndarray], list[np.ndarray], list[float]]:
    """Collect the leaves reachable by mixing an explained point with a reference point.

    At every split both points are routed through ``tree.decision_function``. If they
    agree, the walk simply follows the shared child. If they disagree, the walk forks:
    following the explained point adds the split feature to E, following the reference
    point adds it to R. A fork is skipped when the feature was already assigned to the
    other set earlier on the path.

    Args:
        tree: Tree exposing ``children_left``, ``children_right``, ``features``,
            ``thresholds``, ``children_left_default``, ``values`` and
            ``decision_function``, with node ``0`` as the root.
        point_to_explain: Feature vector of the instance being explained.
        reference_point: Feature vector of the background sample.

    Returns:
        A tuple ``(E, R, leaf_vals)`` with one entry per visited leaf: the sorted
        feature ids taken from the explained point, the sorted feature ids taken from
        the reference point, and the leaf value.
    """
    e_per_leaf: list[np.ndarray] = []
    r_per_leaf: list[np.ndarray] = []
    leaf_values: list[float] = []

    def _route(sample: np.ndarray, node: int, split: int):
        """Return the child of ``node`` that ``sample`` is sent to."""
        goes_left = tree.decision_function(
            sample[split],
            tree.thresholds[node],
            left_default=tree.children_left_default[node],
        )
        if goes_left:
            return tree.children_left[node]
        return tree.children_right[node]

    # Explicit stack instead of recursion; each item is (node, E features, R features).
    pending = [(0, frozenset(), frozenset())]
    while pending:
        node, from_explain, from_reference = pending.pop()

        if tree.children_left[node] == tree.children_right[node]:
            # Leaf: store the feature partition accumulated along this path.
            e_per_leaf.append(np.array(sorted(from_explain), dtype=np.int64))
            r_per_leaf.append(np.array(sorted(from_reference), dtype=np.int64))
            leaf_values.append(tree.values[node].item())
            continue

        split = int(tree.features[node])
        explain_child = _route(point_to_explain, node, split)
        reference_child = _route(reference_point, node, split)

        if explain_child == reference_child:
            # Both points agree, so this split constrains neither set.
            pending.append((explain_child, from_explain, from_reference))
            continue

        if split not in from_reference:
            # The feature is still free to be taken from the explained point.
            pending.append((explain_child, from_explain | {split}, from_reference))
        if split not in from_explain:
            # The feature is still free to be taken from the reference point.
            pending.append((reference_child, from_explain, from_reference | {split}))

    return e_per_leaf, r_per_leaf, leaf_values


class InterventionalTreeExplainer:
    """Any-order interventional Shapley-interaction explainer for tree models.

    Extends interventional TreeSHAP to compute exact Shapley interactions of arbitrary
    order over a single decision tree or a tree ensemble (as validated by
    :func:`shapiq.tree.validation.validate_tree_model`). Each coalition's contribution is
    decomposed against a reference background dataset using the ``E``/``R`` partition
    (features fixed by the explained point vs. by the reference), and the recursion is
    offloaded to one of two C++ kernels:

    * **Dense flatten path** — :func:`compute_interactions_flatten` for
      ``max_order <= 3`` when the dense buffer fits within
      :data:`_DENSE_FLATTEN_MAX_RESULT_SIZE`.
    * **Sparse path** — :func:`compute_interactions_batched_sparse` for higher orders
      or wide-feature trees.

    Indices supported via the C path are listed in
    :data:`INDICES_C_IMPLEMENTATION_CAPABLE`. Custom weight functions are accepted via
    ``weight_fn`` and routed through a precomputed lookup table.

    The baseline value is computed from the validated trees by summing per-tree
    predictions over the reference data. ``validate_tree_model`` already scales sklearn
    ensemble trees by ``1/n_estimators`` and extracts class-specific raw scores for
    XGBoost / LightGBM classifiers, so this single path lands on the same scale as the
    kernel for every supported model.

    Attributes:
        tree: Validated tree (or list of trees) from :func:`validate_tree_model`.
        reference_data: Background dataset (shape ``(n_ref, n_features)``) used to
            define interventional baselines, cast to ``float32``.
        baseline_value: Mean tree-prediction over ``reference_data`` (scalar); written
            as the order-0 entry of the returned interactions.
        max_order: Maximum interaction order computed.
        index: Interaction index (e.g. ``"SII"``); replaced with ``"CUSTOM"`` when
            ``weight_fn`` is supplied.
        n_players: Number of features.
        n_features: Same value as ``n_players``; the name used by the preprocessing
            helpers and the dense kernel call.
        weight_fn: The custom weight callable, or ``None`` when a built-in index is
            used.
        look_up_table: Precomputed custom weight table, or ``None`` when a built-in
            index is used.
    """

    def __init__(
        self,
        model: object,
        data: np.ndarray,
        class_index: int | None = None,
        *,
        debug: bool = False,
        max_order: int = 2,
        index: str = "SII",
        index_func: Callable | None = None,
        bool_tree: bool = False,
        weight_fn: Callable[[int, int, int], float] | None = None,
    ) -> None:
        r"""Initialize the InterventionalTreeExplainer.

        Args:
            model: A fitted tree or tree ensemble compatible with
                :func:`shapiq.tree.validation.validate_tree_model` (sklearn, XGBoost,
                LightGBM, or a precomputed leaf-matrix list).
            data: Background dataset of shape ``(n_ref, n_features)`` defining the
                interventional baseline.
            class_index: Class index for classifiers. For binary ``predict_proba``
                models, defaults to ``1`` if left as ``None``. Ignored for regressors.
            debug: If ``True``, the C++ kernel prints debug information. Defaults to
                ``False``.
            max_order: Maximum interaction order to compute. Defaults to ``2``.
            index: Interaction index; one of :data:`INDICES_C_IMPLEMENTATION_CAPABLE`.
                Replaced with ``"CUSTOM"`` when ``weight_fn`` is supplied. Defaults to
                ``"SII"``.
            index_func: Reserved for a Python-side custom index function. Defaults to
                ``None``.
            bool_tree: If ``True``, the tree is treated as boolean (coalitions in
                :math:`\{0, 1\}^n`) and preprocessed once with the BitSet DFS C++
                helper instead of per-explanation. Used by
                :class:`~proxyshap.proxyshap.ProxySHAP`. Defaults to ``False``.
            weight_fn: Optional custom weight callable with signature
                ``weight_fn(coalition_size, interaction_size, n_players) -> float``.
                When supplied, overrides ``index`` and triggers building a precomputed
                lookup table. Defaults to ``None``.
        """
        # If Classification model and class_index is None, set to 1
        if class_index is None and hasattr(model, "predict_proba"):
            class_index = 1

        self.tree = validate_tree_model(model, class_label=class_index)
        self.reference_data: np.ndarray = data.astype(np.float32)
        self.debug = debug
        self.max_order = max_order
        self.index = index
        self.n_players = data.shape[1]
        # Defined up front so no helper depends on a preprocessing step having run first.
        self.n_features = self.n_players
        self.index_func = index_func
        self.bool_tree = bool_tree
        self.look_up_table: np.ndarray | None = None
        self._custom_weight_table: np.ndarray | None = None
        # Always defined (possibly None) so attribute access never fails.
        self.weight_fn = weight_fn

        if weight_fn is not None:
            self.index = "CUSTOM"
            self.look_up_table = self._build_custom_weight_table()

        # Compute the interventional baseline directly from the validated trees.
        per_sample_predictions = np.array(
            [sum(t.predict_one(ref) for t in self.tree) for ref in self.reference_data],
            dtype=np.float64,
        )
        self.baseline_value = float(per_sample_predictions.mean())

        # The sparse C path needs the per-tree flattened arrays. Populate them
        # whenever we'll route there: max_order > 3 (always sparse) or when the
        # dense flatten path's result buffer would exceed our memory budget.
        n_features_hint = int(self.reference_data.shape[1])
        self._use_sparse_path = (
            self.max_order > 3
            or _dense_flatten_result_size(n_features_hint, self.max_order)
            > _DENSE_FLATTEN_MAX_RESULT_SIZE
        )
        if self._use_sparse_path:
            self._preprocess_tree_sparse_path()

        if self.bool_tree:
            self._preprocess_boolean_tree()

    def _preprocess_tree_sparse_path(self) -> None:
        """Flatten per-tree arrays into the layout expected by the sparse C kernel."""
        self.values_list = [tree.values.astype(np.float32).flatten() for tree in self.tree]
        self.threshold_list = [tree.thresholds.astype(np.float32).flatten() for tree in self.tree]
        self.features_list = [tree.features.astype(np.int64).flatten() for tree in self.tree]
        self.children_left_list = [
            tree.children_left.astype(np.int64).flatten() for tree in self.tree
        ]
        self.children_right_list = [
            tree.children_right.astype(np.int64).flatten() for tree in self.tree
        ]
        self.children_left_default_list = [
            tree.children_left_default.astype(bool).flatten() for tree in self.tree
        ]

    def _preprocess_boolean_tree(self) -> None:
        """Gather E and R statistics for boolean tree mode using C++ BitSet DFS."""
        from .cext import preprocess_boolean_trees  # ty: ignore[unresolved-import]

        self.n_features = self.reference_data.shape[1]

        # Prepare tree arrays for C++
        values_list = [tree.values.astype(np.float32).flatten() for tree in self.tree]
        features_list = [tree.features.astype(np.int64).flatten() for tree in self.tree]
        children_left_list = [tree.children_left.astype(np.int64).flatten() for tree in self.tree]
        children_right_list = [tree.children_right.astype(np.int64).flatten() for tree in self.tree]

        (
            self.E_R_flatten,
            self.leaf_vals_flatten,
            self.e_size_flatten,
            self.r_size_flatten,
            self.feature_in_E,
            self.leaf_id,
        ) = preprocess_boolean_trees(
            values_list,
            features_list,
            children_left_list,
            children_right_list,
            self.n_features,
        )
        self.e_length = len(self.E_R_flatten)
        # leaf_id holds one leaf index per flattened entry; the last one gives the count.
        self.n_leafs = int(self.leaf_id[-1]) + 1 if len(self.leaf_id) > 0 else 0

    def _preprocess_tree(self, explain_point: np.ndarray) -> None:
        """Gather E and R statistics for the given explain point.

        Walks every tree once per reference sample and stores both the per-leaf lists
        and the flattened arrays consumed by the dense flatten kernel.

        Args:
            explain_point: The instance to explain as a 1-dimensional array.
        """
        self.n_features = self.reference_data.shape[1]

        E_list: list[np.ndarray] = []
        R_list: list[np.ndarray] = []
        leaf_vals_list: list[float] = []
        for r in self.reference_data:
            for tree in self.tree:
                E, R, leaf_vals = obtain_E_R_values_point(tree, explain_point, r)
                E_list.extend(E)
                R_list.extend(R)
                leaf_vals_list.extend(leaf_vals)

        self.E_list = E_list
        self.R_list = R_list
        self.leaf_vals = np.array(leaf_vals_list, dtype=np.float32)
        n_leafs = len(E_list)

        # Per-leaf sizes — computed once and reused for all flattened arrays.
        e_sizes = np.array([len(e) for e in E_list], dtype=np.int64)
        r_sizes = np.array([len(r) for r in R_list], dtype=np.int64)
        er_sizes = e_sizes + r_sizes
        self.n_features_e = e_sizes
        self.n_features_r = r_sizes

        # Build flatten numpy arrays
        if n_leafs > 0:
            # Per leaf: the E features followed by the R features.
            self.E_R_flatten = np.concatenate(
                [np.concatenate([e, r]) for e, r in zip(E_list, R_list, strict=False)]
            ).astype(np.int64)
            # Parallel mask: 1 for entries that came from E, 0 for entries from R.
            self.feature_in_E = np.concatenate(
                [
                    np.concatenate(
                        [np.ones(int(e), dtype=np.int64), np.zeros(int(r), dtype=np.int64)]
                    )
                    for e, r in zip(e_sizes, r_sizes, strict=False)
                ]
            )
        else:
            # np.concatenate rejects an empty list, so build the empty arrays directly.
            self.E_R_flatten = np.array([], dtype=np.int64)
            self.feature_in_E = np.array([], dtype=np.int64)

        # Broadcast per-leaf quantities to one value per flattened entry.
        self.leaf_id = np.repeat(np.arange(n_leafs, dtype=np.int64), er_sizes)
        self.leaf_vals_flatten = np.repeat(self.leaf_vals, er_sizes)
        self.e_size_flatten = np.repeat(e_sizes, er_sizes)
        self.r_size_flatten = np.repeat(r_sizes, er_sizes)
        self.e_length = len(self.E_R_flatten)
        self.n_leafs = n_leafs

    def _build_custom_weight_table(self) -> np.ndarray:
        """Precompute the flat weight lookup table for the custom weight function.

        Entries are filled for every ``e`` (features in E), ``r`` (features in R, with
        ``e + r <= n``), interaction size ``s`` in ``1..min(k, e + r)`` and ``s_cap_r``
        in ``0..min(r, s)``; every other entry stays ``0``.

        Returns:
            A 1-D float64 numpy array of size ``(n+1) * (n+1) * (k+1)^3`` where
            ``n = self.n_players`` and ``k = self.max_order``.
        """
        # n_players is used because this runs inside __init__, before any
        # preprocessing step has had a chance to run.
        n = self.n_players
        k = self.max_order
        N = n + 1
        K = k + 1
        table = np.zeros(N * N * K * K * K, dtype=np.float64)

        # NOTE(review): strides are kept exactly as before; presumably they mirror the
        # indexing used by the C++ kernels — confirm against the kernel source.
        e_stride = N * K * K * K
        r_stride = K * K

        for e in range(N):
            # r can only go up to n - e since we can't have more than n features in total
            for r in range(N - e):
                # s runs up to min(k, e + r) inclusive: an interaction cannot be larger
                # than max_order nor than the number of features in E and R together.
                for s in range(1, min(k, e + r) + 1):
                    # At most min(r, s) of the interaction's features can lie in R.
                    for s_cap_r in range(min(r, s) + 1):
                        idx = e * e_stride + r * r_stride + s_cap_r * K + s
                        table[idx] = self._general_weight(e, r, s_cap_r, s, n)
        return table

    def _discrete_weight_to_moebius(
        self,
        weight_func: Callable[[int, int, int], float],
        coalition_size: int,
        interaction_size: int,
    ) -> float:
        """Convert a discrete-derivative weight to its Möbius counterpart.

        Args:
            weight_func: Callable ``weight_func(coalition_size, interaction_size,
                n_players) -> float`` returning the discrete-derivative weight.
            coalition_size: Size of the coalition.
            interaction_size: Size of the interaction.

        Returns:
            The corresponding Möbius weight, i.e. ``weight_func`` evaluated at
            ``(coalition_size - interaction_size, interaction_size, coalition_size)``.
        """
        return weight_func(coalition_size - interaction_size, interaction_size, coalition_size)

    def _general_weight(
        self,
        e: int,
        r: int,
        s_cap_r: int,
        s: int,
        n: int,
    ) -> float:
        r"""Compute the general weight $\lambda$ for the interventional tree algorithm.

        Args:
            e: Number of features in E.
            r: Number of features in R.
            s_cap_r: Number of features in R that are part of the interaction.
            s: Size of the interaction.
            n: Total number of features. Kept for interface compatibility; it cancels
                out of the formula.

        Returns:
            The weight $\lambda$ for the given parameters.
        """
        # Features of R outside the interaction (n - b - s_cap_r with b = n - r).
        free_r = r - s_cap_r
        sign = (-1) ** s_cap_r
        return sign * sum(
            (-1) ** k
            * binom(free_r, k)
            * self._discrete_weight_to_moebius(
                weight_func=self.weight_fn,
                coalition_size=k + s_cap_r + e,
                interaction_size=s,
            )
            for k in range(free_r + 1)
        )

    def explain_function(
        self,
        x: np.ndarray,
        **_: dict,
    ) -> InteractionValues:
        """Compute interaction values for a single instance.

        Routes to the dense flatten C kernel for low-order, narrow-feature cases and to
        the sparse batched C kernel otherwise. The empty interaction ``()`` is always
        populated with ``self.baseline_value`` before constructing the result.

        Args:
            x: The instance to explain, as a 1-D array of length ``self.n_players``.

        Returns:
            :class:`~shapiq.interaction_values.InteractionValues` carrying the requested
            interaction ``index``, ``max_order=self.max_order``, and ``min_order=1`` as
            the declared lower bound on computed orders. Note that the underlying
            interactions dict still contains the empty-set entry ``()`` set to
            ``self.baseline_value``, even though ``min_order`` reports ``1``.
        """
        from .cext import (
            compute_interactions_batched_sparse,  # ty: ignore[unresolved-import]
            compute_interactions_flatten,  # ty: ignore[unresolved-import]
        )

        # Boolean trees were preprocessed once in __init__, and the sparse kernel
        # does its own traversal, so only the plain dense path needs this step.
        if not self.bool_tree and not self._use_sparse_path:
            self._preprocess_tree(x)

        computation_index = get_computation_index(self.index)

        # For higher order interactions we need to use the sparse implementation as the
        # flatten one is only optimized for main effects, pairwise, and triple interactions.
        # For orders up to 3, we can use the flatten implementation which is faster.
        # We also redirect to sparse for orders <= 3 when n_features is large
        # enough that the dense flatten buffer would blow memory (see
        # _DENSE_FLATTEN_MAX_RESULT_SIZE). _use_sparse_path is set in __init__.
        if self._use_sparse_path:
            interactions = compute_interactions_batched_sparse(
                self.values_list,
                self.threshold_list,
                self.features_list,
                self.children_left_list,
                self.children_right_list,
                self.children_left_default_list,
                self.reference_data.astype(np.float32),
                x.astype(np.float32).flatten(),
                self.tree[0].decision_type,
                computation_index,
                self.max_order,
                self.debug,  # whether to print debug information
                self.look_up_table,  # optional custom weight table (None → built-in index)
            )
        else:
            interactions = compute_interactions_flatten(
                self.leaf_vals_flatten,
                self.E_R_flatten,
                self.e_size_flatten,
                self.r_size_flatten,
                self.feature_in_E,
                self.leaf_id,
                computation_index,
                len(self.leaf_vals_flatten),
                self.n_features,
                self.e_length,
                self.max_order,
                self.debug,  # whether to print debug information
                float(
                    self.reference_data.shape[0]
                ),  # number of reference samples for scaling the results
                self.look_up_table,  # optional custom weight table (None → built-in index)
            )

        interactions[()] = self.baseline_value
        return InteractionValues(
            interactions,
            max_order=self.max_order,
            min_order=1,
            index=computation_index,
            n_players=self.n_players,
            baseline_value=self.baseline_value,
            target_index=self.index,
        )