"""Linear TreeShap Explainer Implementation."""
from __future__ import annotations
from typing import TYPE_CHECKING
import numpy as np
import scipy.special as sp
from shapiq.interaction_values import InteractionValues
from shapiq.tree.conversion.edges import create_edge_tree
from shapiq.tree.validation import validate_tree_model
from shapiq.utils.sets import generate_interaction_lookup, powerset
if TYPE_CHECKING:
from collections.abc import Callable
from shapiq.typing import IntVector, Model
def get_norm_weight(M: int) -> np.ndarray:
    """Return the normalization weights for Linear Tree Shap.

    Args:
        M: Upper index of the binomial coefficients.

    Returns:
        A float array of length ``M + 1`` whose ``i``-th entry is ``binom(M, i)``.
    """
    # ``sp.binom`` is a ufunc, so one vectorised call over 0..M yields the whole row.
    lower_indices = np.arange(M + 1)
    return sp.binom(M, lower_indices)
def get_N_prime(max_size: int = 10) -> np.ndarray:
    """Build the ``N'`` matrix for Linear Tree Shap.

    Args:
        max_size: Size parameter; the returned matrix is square with ``max_size + 2`` rows.

    Returns:
        A ``(max_size + 2, max_size + 2)`` array. The first ``i + 1`` entries of row ``i``
        hold the product of the leading ``(i + 1) x (i + 1)`` block of the binomial table
        with the element-wise reciprocals of ``get_norm_weight(i)``; the remaining entries
        are zero.
    """
    size = max_size + 2

    # Binomial table: row ``n`` stores ``get_norm_weight(n)`` and is zero-padded on the right.
    binomials = np.zeros((size, size))
    for n in range(size):
        binomials[n, : n + 1] = get_norm_weight(n)

    result = np.zeros((size, size))
    for n in range(size):
        width = n + 1
        inverse_weights = 1 / binomials[n, :width]
        leading_block = binomials[:width, :width]
        result[n, :width] = leading_block.dot(inverse_weights)
    return result
def get_N_v2(D: np.ndarray) -> np.ndarray:
    """Get the N_v2 matrix for Linear Tree Shap.

    For every prefix length ``i`` in ``1..len(D)`` the linear system
    ``vander(D[:i]).T @ w = 1 / get_norm_weight(i - 1)`` is solved and the solution ``w``
    is stored in the first ``i`` entries of row ``i``.

    Args:
        D: One-dimensional array of interpolation points.

    Returns:
        An array of shape ``(len(D) + 1, len(D))``. Row ``0`` and all entries to the right
        of column ``i - 1`` in row ``i`` are zero.

    Raises:
        numpy.linalg.LinAlgError: If a Vandermonde prefix matrix is singular (for example
            when ``D`` contains repeated points).
    """
    depth = D.shape[0]
    Ns = np.zeros((depth + 1, depth))
    for i in range(1, depth + 1):
        vandermonde_t = np.vander(D[:i]).T
        rhs = 1.0 / get_norm_weight(i - 1)
        # Solve the system directly rather than forming the explicit inverse: Vandermonde
        # matrices become ill-conditioned quickly, and ``solve`` is both more accurate and
        # cheaper than ``inv(...).dot(...)``.
        Ns[i, :i] = np.linalg.solve(vandermonde_t, rhs)
    return Ns
class LinearTreeSHAP:
    """Linear TreeSHAP explainer for first-order Shapley values on tree-based models.

    Implements the Linear TreeSHAP algorithm of `Yu et al. (2022)
    <https://openreview.net/forum?id=OzbkiUo24g>`_ for exact ``order=1`` Shapley value
    computation on a single decision tree. The heavy lifting is delegated to a C++ kernel
    (``linear_tree_shap_iterative``), which is faster than the any-order
    :class:`~shapiq.tree.treeshapiq.TreeSHAPIQ` algorithm when only Shapley values are
    needed.

    Attributes:
        clf: The original tree-based model passed by the user.
        edge_tree: Edge-based representation of the tree (:class:`~shapiq.tree.base.EdgeTree`)
            used by the C++ kernel for efficient traversal.
        N: The :math:`N'` matrix used by Linear TreeSHAP (see :func:`get_N_prime`).
        Base: The Chebyshev (or user-supplied) interpolation base of length ``max_depth``.
        Offset: The Vandermonde-style power cache of ``Base + 1``.
        N_v2: The interpolation N matrix evaluated at ``Base`` (see :func:`get_N_v2`).
    """

    def __init__(
        self,
        model: Model,
        *,
        base_func: Callable[[int], np.ndarray] = np.polynomial.chebyshev.chebpts2,
    ) -> None:
        """Initialize the :class:`LinearTreeSHAP` explainer.

        Args:
            model: A fitted single-tree model accepted by
                :func:`~shapiq.tree.validation.validate_tree_model`.
            base_func: Callable ``int -> np.ndarray`` returning the interpolation base for the
                given depth. Defaults to :func:`numpy.polynomial.chebyshev.chebpts2`.
        """
        self.clf = model
        # Only the first tree returned by the validation step is explained.
        self._tree = validate_tree_model(model, class_label=None)[0]
        # Snapshot the feature ids *before* ``reduce_feature_complexity`` is invoked
        # (presumably that call remaps them to a compact internal range -- see the
        # ``feature_map_internal_original`` lookup in ``shap_values_cpp_iterative``).
        self._relevant_features: np.ndarray = np.array(list(self._tree.feature_ids), dtype=int)
        self._tree.reduce_feature_complexity()
        self._n_nodes: int = self._tree.n_nodes
        self._n_features_in_tree: int = self._tree.n_features_in_tree
        self._max_feature_id: int = self._tree.max_feature_id
        self._feature_ids: set = self._tree.feature_ids
        self._max_order = 1

        # precompute interaction lookup tables
        self._interactions_lookup_relevant: dict[tuple, int] = generate_interaction_lookup(
            self._relevant_features,
            0,
            1,
        )
        self._interactions_lookup: dict[int, dict[tuple, int]] = {}  # lookup for interactions
        self._interaction_update_positions: dict[int, dict[int, IntVector]] = {}  # lookup
        self._init_interaction_lookup_tables()

        self.edge_tree = create_edge_tree(
            children_left=self._tree.children_left,
            children_right=self._tree.children_right,
            features=self._tree.features,
            node_sample_weight=self._tree.node_sample_weight,
            values=self._tree.values,
            max_interaction=1,
            n_features=self._max_feature_id + 1,
            n_nodes=self._n_nodes,
            subset_updates_pos_store=self._interaction_update_positions,
        )

        # Interpolation data, all sized by the depth of the edge tree.
        self.N = get_N_prime(self.edge_tree.max_depth)
        self.Base = base_func(self.edge_tree.max_depth)
        self.Offset = np.vander(self.Base + 1).T[::-1]
        self.N_v2 = get_N_v2(self.Base)

    def _init_interaction_lookup_tables(self) -> None:
        """Initializes the lookup tables for the interaction subsets.

        Fills ``self._interactions_lookup`` and ``self._interaction_update_positions`` for
        every order from ``1`` up to ``self._max_order``.
        """
        for order in range(1, self._max_order + 1):
            order_interactions_lookup = generate_interaction_lookup(
                self._n_features_in_tree,
                order,
                order,
            )
            self._interactions_lookup[order] = order_interactions_lookup
            _, interaction_update_positions = self._precompute_subsets_with_feature(
                interaction_order=order,
                n_features=self._n_features_in_tree,
                order_interactions_lookup=order_interactions_lookup,
            )
            self._interaction_update_positions[order] = interaction_update_positions

    @staticmethod
    def _precompute_subsets_with_feature(
        n_features: int,
        interaction_order: int,
        order_interactions_lookup: dict[tuple, int],
    ) -> tuple[dict[int, list[tuple]], dict[int, IntVector]]:
        """Precomputes the subsets of interactions that include a given feature.

        Args:
            n_features: The number of features in the model.
            interaction_order: The interaction order to be computed.
            order_interactions_lookup: The lookup table of interaction subsets to their positions
                in the interaction values array for a given interaction order (e.g. all 2-way
                interactions for order ``2``).

        Returns:
            interaction_updates: A dictionary (lookup table) containing the interaction subsets
                for each feature given an interaction order.
            interaction_update_positions: A dictionary (lookup table) containing the positions of
                the interaction subsets to update for each feature given an interaction order.
        """
        # Every feature takes part in the same number of interactions of the given order, so
        # the buffer length is computed once. With no features there is nothing to size, and
        # ``sp.binom`` of a negative integer is NaN, hence the guard.
        n_positions = (
            int(sp.binom(n_features - 1, interaction_order - 1)) if n_features > 0 else 0
        )

        # stores interactions that include feature i (needs to be updated when feature i appears)
        interaction_updates: dict[int, list[tuple]] = {
            feature_i: [] for feature_i in range(n_features)
        }
        # stores position of interactions that include feature i
        interaction_update_positions: dict[int, np.ndarray] = {
            feature_i: np.zeros(n_positions, dtype=int) for feature_i in range(n_features)
        }

        # fill the interaction updates and positions
        position_counter = np.zeros(n_features, dtype=int)  # next free slot per feature
        for interaction in powerset(
            range(n_features),
            min_size=interaction_order,
            max_size=interaction_order,
        ):
            for i in interaction:
                interaction_updates[i].append(interaction)
                position = position_counter[i]
                interaction_update_positions[i][position] = order_interactions_lookup[interaction]
                position_counter[i] += 1
        return interaction_updates, interaction_update_positions

    def shap_values_cpp_iterative(self, X: np.ndarray) -> np.ndarray:
        """Shapley Value computation using an Iterative C++ Implementation of LinearTreeShap.

        Args:
            X (np.ndarray): Datapoints

        Returns:
            np.ndarray: The computed shapley values, a ``float64`` array with the same shape
            as ``X``.
        """
        from .cext import (
            linear_tree_shap_iterative,  # ty: ignore[unresolved-import]
        )

        # Output buffer handed to the kernel; allocated C-contiguous in one step.
        V = np.zeros(np.shape(X), dtype=np.float64)

        # Translate the tree's per-node feature ids through ``feature_map_internal_original``;
        # entries equal to ``-2`` are passed through unchanged (presumably the leaf marker).
        orig_feature_indices = np.array(
            [
                self._tree.feature_map_internal_original[i] if i != -2 else i
                for i in self._tree.features
            ],
            dtype=np.int32,
        )
        weights = 1 / self.edge_tree.p_e_values

        # The kernel routes ``x`` honouring the model's split convention (passed as the
        # ``decision_type`` string, same as :class:`InterventionalTreeExplainer`):
        # XGBoost-style trees use strict ``x < threshold``, every other supported family
        # ``x <= threshold``. This must match ``TreeModel.predict_one`` exactly, otherwise
        # instances lying on a split threshold are routed to the wrong leaf and the
        # Shapley efficiency property breaks.
        linear_tree_shap_iterative(
            np.ascontiguousarray(weights, dtype=np.float64),
            np.ascontiguousarray(self.edge_tree.empty_predictions, dtype=np.float64),
            np.ascontiguousarray(self._tree.thresholds, dtype=np.float64),
            np.ascontiguousarray(self.edge_tree.ancestors, dtype=np.int32),
            np.ascontiguousarray(self.edge_tree.edge_heights, dtype=np.int32),
            np.ascontiguousarray(orig_feature_indices, dtype=np.int32),
            np.ascontiguousarray(self._tree.children_left, dtype=np.int32),
            np.ascontiguousarray(self._tree.children_right, dtype=np.int32),
            self.edge_tree.max_depth,
            self._tree.n_nodes,
            # Normalised like every other array argument so that a user-supplied
            # ``base_func`` returning e.g. an integer array is still accepted; this is a
            # no-op for an array that is already contiguous ``float64``.
            np.ascontiguousarray(self.Base, dtype=np.float64),
            np.ascontiguousarray(self.Offset, dtype=np.float64),
            np.ascontiguousarray(self.N_v2, dtype=np.float64),
            np.ascontiguousarray(X, dtype=np.float64),
            V,
            self._tree.decision_type,
        )
        return V

    def explain_function(self, x: np.ndarray) -> InteractionValues:
        """Computes the Shapley values for a single instance.

        Args:
            x: The instance to explain as a 1-dimensional array.

        Returns:
            The interaction values for the instance.
        """
        # n_players matches the user-facing feature space (``x.shape[0]``); the tree's reduced
        # feature count would underreport players and break downstream ``get_n_order_values``.
        n_players = int(x.shape[0])

        shap_values = self.shap_values_cpp_iterative(x.reshape(1, -1)).flatten()
        shap_interactions: dict[tuple[int, ...], float] = {
            (feature,): float(shap_values[feature])
            for feature in range(n_players)  # one entry per input feature, zero outside the tree
        }

        # Prefer the tree's own empty prediction; otherwise sum the edge tree's
        # ``empty_predictions`` entries.
        baseline_value = self._tree.empty_prediction
        if baseline_value is None:
            baseline_value = float(np.sum(self.edge_tree.empty_predictions))

        return InteractionValues(
            values=shap_interactions,
            baseline_value=baseline_value,
            min_order=0,
            max_order=1,
            index="SV",
            n_players=n_players,
        )