"""Implementation of the TreeExplainer class.
The :class:`~shapiq.tree.explainer.TreeSHAPIQ` uses the
:class:`~shapiq.tree.treeshapiq.TreeSHAPIQ` algorithm for computing any-order Interactions
for tree ensembles.
"""
from __future__ import annotations
from typing import TYPE_CHECKING, Any, Literal
from shapiq.explainer.base import Explainer
from shapiq.tree.interventional.explainer import InterventionalTreeExplainer
from .linear import LinearTreeSHAP
from .treeshapiq import TreeSHAPIQ, TreeSHAPIQIndices
from .validation import validate_tree_model
if TYPE_CHECKING:
import numpy as np
from shapiq.interaction_values import InteractionValues
from shapiq.typing import Model
from .base import TreeModel
TREE_MODES = Literal["pathdependent", "interventional"]
[docs]
class TreeExplainer(Explainer):
"""The TreeExplainer class for tree-based models.
The explainer for tree-based models using the
:class:`~shapiq.tree.treeshapiq.TreeSHAPIQ` algorithm. For details, refer to
`Muschalik et al. (2024)` [Mus24]_.
TreeSHAP-IQ is an algorithm for computing Shapley Interaction values for tree-based models.
It is based on the Linear TreeSHAP algorithm by `Yu et al. (2022)` [Yu22]_, but extended to
compute Shapley Interaction values up to a given order. TreeSHAP-IQ needs to visit each node
only once and makes use of polynomial arithmetic to compute the Shapley Interaction values
efficiently.
The TreeExplainer can be used with a variety of tree-based models, including
``scikit-learn``, ``XGBoost``, ``LightGBM``, and ``CatBoost``. The explainer can handle both
regression and classification models.
References:
.. [Yu22] Peng Yu, Chao Xu, Albert Bifet, Jesse Read. (2022). Linear Tree Shap. In: Proceedings of 36th Conference on Neural Information Processing Systems. https://openreview.net/forum?id=OzbkiUo24g
.. [Mus24] Maximilian Muschalik, Fabian Fumagalli, Barbara Hammer, & Eyke Hüllermeier (2024). Beyond TreeSHAP: Efficient Computation of Any-Order Shapley Interactions for Tree Ensembles. In: Proceedings of the AAAI Conference on Artificial Intelligence, 38(13), 14388-14396. https://doi.org/10.1609/aaai.v38i13.29352
"""
def __init__(
self,
model: dict | TreeModel | list[TreeModel] | Model,
*,
mode: TREE_MODES = "pathdependent",
reference_dataset: np.ndarray | None = None,
max_order: int = 1,
min_order: int = 0,
index: TreeSHAPIQIndices = "SV",
class_index: int | None = None,
**kwargs: Any, # noqa: ARG002
) -> None:
"""Initializes the TreeExplainer.
Args:
model: A tree-based model to explain.
mode: The mode of the explainer, either ``"pathdependent"`` or ``"interventional"``.
In ``"pathdependent"`` mode, the explainer computes path-dependent interaction values using the TreeSHAPIQ algorithm or the Linear TreeSHAP algorithm if the index is ``"SV"``.
In ``"interventional"`` mode, the explainer computes interventional interaction values using the Interventional TreeExplainer algorithm.
Defaults to ``"pathdependent"``.
max_order: The maximum order of interactions to be computed. Set to ``1`` for no
interactions (i.e, for Shapley values ``"SV"`` or Banzhaf values ``"BV"``). Any
value higher than ``1`` computes interaction values up to that order. Defaults to
``1``.
min_order: The minimum interaction order to keep in the returned
:class:`~shapiq.interaction_values.InteractionValues`. Must satisfy
``0 <= min_order <= max_order``. When ``min_order == 0`` the empty interaction
``()`` is included with the baseline value. When ``min_order >= 1`` all
interactions of order below ``min_order`` are filtered out of the result; the
underlying algorithm still computes them internally when required by aggregated
indices such as ``"k-SII"``. Defaults to ``0``.
index: The type of interaction to be computed. It can be one of
``["k-SII", "SII", "STII", "FSII", "BII", "SV"]``. All indices apart from ``"BII"``
will reduce to the ``"SV"`` (Shapley value) for order 1. Defaults to ``"SV"``.
class_index: The class index of the model to explain. Defaults to ``None``, which will
set the class index to ``1`` per default for classification models and is ignored
for regression models.
reference_dataset: A dataset to be used for reference in the explanation when using `mode=interventional`. Defaults to ``None``.
**kwargs: Additional keyword arguments are ignored.
"""
super().__init__(model, index=index, max_order=max_order)
if min_order < 0 or min_order > self._max_order:
msg = (
f"min_order={min_order} must satisfy 0 <= min_order <= max_order "
f"(max_order={self._max_order})."
)
raise ValueError(msg)
# validate and parse model
self._trees: list[TreeModel] = validate_tree_model(model, class_label=class_index)
self._n_trees = len(self._trees)
self._min_order: int = min_order
self._class_label: int | None = class_index
self.mode = mode
self._reference_dataset: np.ndarray | None = reference_dataset
# In ``"pathdependent"`` mode, build exactly one per-tree explainer list — either
# ``LinearTreeSHAP`` (cheap, order-1 only) or ``TreeSHAPIQ`` (any order). The dispatch
# decision is fixed at construction time so callers can mutate the chosen list (e.g.
# ``_tree.thresholds`` rounding in tests) before calling :meth:`explain`. In
# ``"interventional"`` mode no per-tree list is created — the
# :class:`~shapiq.tree.interventional.explainer.InterventionalTreeExplainer` handles the
# full ensemble in one shot, so a per-tree list would be meaningless.
self._treeshapiq_explainers: list[TreeSHAPIQ] = []
self._lineartreeshap_explainers: list[LinearTreeSHAP] = []
self._interventional_explainer: InterventionalTreeExplainer | None = None
if self.mode == "pathdependent":
if self._can_use_lineartreeshap():
self._lineartreeshap_explainers = [
LinearTreeSHAP(model=tree) for tree in self._trees
]
else:
# ``index`` (the local parameter) is already narrowed to ``TreeSHAPIQIndices``;
# ``self.index`` is the broader ``ExplainerIndices`` and would not type-check.
self._treeshapiq_explainers = [
TreeSHAPIQ(model=tree, max_order=self._max_order, index=index)
for tree in self._trees
]
elif self.mode == "interventional":
if self._reference_dataset is None:
msg = (
"InterventionalTreeExplainer requires a reference_dataset; pass one to "
"TreeExplainer(..., mode='interventional', reference_dataset=...)."
)
raise ValueError(msg)
self._interventional_explainer = InterventionalTreeExplainer(
model=self._trees,
data=self._reference_dataset,
class_index=self._class_label,
max_order=self._max_order,
index=self.index,
)
# Baseline is the sum of the per-tree empty predictions and is identical regardless of
# which algorithm runs explain — derive it from the trees directly so the attribute is
# always populated, including in ``"interventional"`` mode where no per-tree list exists.
self.baseline_value: float = float(sum(tree.empty_prediction for tree in self._trees))
def _can_use_lineartreeshap(self) -> bool:
"""Whether the LinearTreeSHAP fast path can replace TreeSHAP-IQ for this configuration.
LinearTreeSHAP is restricted to first-order Shapley values and needs at least two
distinct features per tree (its Chebyshev base ``chebpts2`` requires ``npts >= 2``).
Trivial trees (constant or single-feature) and higher-order interactions fall back to
TreeSHAP-IQ, which carries dedicated trivial-tree fast paths.
"""
return (
self._max_order == 1
and self.index in ("SV", "SII")
and all(tree.n_features_in_tree >= 2 for tree in self._trees)
)
def _explain_function_lineartreeshap(
self,
x: np.ndarray,
**kwargs: Any, # noqa: ARG002
) -> InteractionValues:
"""Compute first-order Shapley values for ``x`` by aggregating the per-tree LinearTreeSHAP results.
Mirrors the per-tree aggregation done by ``_explain_function_treeshapiq``: each
``LinearTreeSHAP`` in ``self._lineartreeshap_explainers`` runs against ``x``, the
resulting :class:`~shapiq.interaction_values.InteractionValues` are summed (which also
sums ``baseline_value`` and the ``()`` entry), and ``min_order`` is finally enforced via
:meth:`InteractionValues.get_n_order` when the user asked for a stricter minimum.
Args:
x: The instance to explain as a 1-dimensional array.
**kwargs: Additional keyword arguments are ignored.
Returns:
The aggregated Shapley values for the instance.
"""
if len(x.shape) != 1:
msg = "explain expects a single instance, not a batch."
raise TypeError(msg)
interaction_values: list[InteractionValues] = [
lts.explain_function(x) for lts in self._lineartreeshap_explainers
]
final_explanation = interaction_values[0]
for iv in interaction_values[1:]:
final_explanation += iv
if self._min_order > final_explanation.min_order:
final_explanation = final_explanation.get_n_order(
min_order=self._min_order,
max_order=self._max_order,
)
return final_explanation
def _explain_function_interventionaltreeshapiq(
self,
x: np.ndarray,
**kwargs: Any, # noqa: ARG002
) -> InteractionValues:
"""Compute interaction values for ``x`` via the eagerly-built :class:`InterventionalTreeExplainer`.
Args:
x: The instance to explain as a 1-dimensional array.
**kwargs: Additional keyword arguments are ignored.
Returns:
The interaction values for the instance.
"""
if self._interventional_explainer is None:
msg = "Interventional explainer is not initialized; mode must be 'interventional'."
raise RuntimeError(msg)
return self._interventional_explainer.explain_function(x)
def _explain_function_treeshapiq(
self,
x: np.ndarray,
**kwargs: Any, # noqa: ARG002
) -> InteractionValues:
"""Computes the Shapley Interaction values for a single instance.
Args:
x: The instance to explain as a 1-dimensional array.
**kwargs: Additional keyword arguments are ignored.
Returns:
The interaction values for the instance.
"""
if len(x.shape) != 1:
msg = "explain expects a single instance, not a batch."
raise TypeError(msg)
# run treeshapiq for all trees
interaction_values: list[InteractionValues] = []
for explainer in self._treeshapiq_explainers:
tree_explanation = explainer.explain(x)
interaction_values.append(tree_explanation)
# combine the explanations for all trees
final_explanation = interaction_values[0]
if len(interaction_values) > 1:
for i in range(1, len(interaction_values)):
final_explanation += interaction_values[i]
if self._min_order == 0 and final_explanation.min_order == 1:
final_explanation.min_order = 0
# Add the baseline value to the empty prediction
# might break for some edge cases
final_explanation.interactions[()] = float(final_explanation.baseline_value)
if self._min_order > final_explanation.min_order:
final_explanation = final_explanation.get_n_order(
min_order=self._min_order,
max_order=self._max_order,
)
return final_explanation
[docs]
def explain_function( # type: ignore[override]
self,
x: np.ndarray,
*args: Any, # noqa: ARG002
**kwargs: Any,
) -> InteractionValues:
"""Computes the interaction index for a single instance.
The method used for computing the explanation depends on the specified mode and the
parameters of the explainer.
Args:
x: The instance to explain as a 1-dimensional array.
*args: Additional positional arguments are ignored.
**kwargs: Additional keyword arguments forwarded to the per-mode explain function.
Returns:
The computed interaction index for the instance.
"""
if self.mode == "pathdependent":
# Dispatch on whichever per-tree list __init__ chose to populate.
if self._lineartreeshap_explainers:
return self._explain_function_lineartreeshap(x, **kwargs)
return self._explain_function_treeshapiq(x, **kwargs)
return self._explain_function_interventionaltreeshapiq(x, **kwargs)