"""Implementation of TabPFNExplainer class.
The TabPFNExplainer is a class for explaining the predictions of a TabPFN model.
"""
from __future__ import annotations
from typing import TYPE_CHECKING
import numpy as np
from .tabular import TabularExplainer
from .utils import get_predict_function_and_model_type
if TYPE_CHECKING:
from typing import Literal
from shapiq.approximator.base import Approximator
from shapiq.typing import Model
from .custom_types import ExplainerIndices
[docs]
class TabPFNExplainer(TabularExplainer):
    """The TabPFN explainer as the main interface for the shapiq package.

    The ``TabPFNExplainer`` class is the dedicated interface for the ``shapiq`` package and
    TabPFN :footcite:t:`Hollmann.2025` models such as the ``TabPFNClassifier`` and
    ``TabPFNRegressor``. The explainer does not rely on classical imputation methods and is
    optimized for TabPFN's in-context learning approach. The explanation paradigm for TabPFN is
    described in :footcite:t:`Rundel.2024`. In essence the explainer is a wrapper around the
    :class:`~shapiq.explainer.tabular.TabularExplainer` class and uses the same API.

    References:
        .. footbibliography::

    """

    def __init__(
        self,
        model: Model,
        data: np.ndarray,
        labels: np.ndarray,
        *,
        index: ExplainerIndices = "SV",
        max_order: int = 1,
        x_test: np.ndarray | None = None,
        empty_prediction: float | None = None,
        class_index: int | None = None,
        approximator: Approximator
        | Literal["auto", "spex", "montecarlo", "svarm", "permutation", "regression"] = "auto",
        verbose: bool = False,
    ) -> None:
        """Initialize the TabPFNExplainer.

        The predict function resolved for ``model`` is stored on the model itself as
        ``_shapiq_predict_function``. A ``TabPFNImputer`` is then built from the
        training/test data and handed to the parent ``TabularExplainer``.

        Args:
            model: Either a TabPFNClassifier or TabPFNRegressor model to be explained.

            data: The background data to use for the explainer as a 2-dimensional array with
                shape ``(n_samples, n_features)``. This data is used to contextualize the model.

            labels: The labels for the background data as a 1-dimensional array with shape
                ``(n_samples,)``. These labels are used to contextualize the model.

            index: The index to explain the model with. Defaults to ``"SV"`` which computes the
                Shapley value. Options: ``"SV"`` (Shapley value), ``"k-SII"``
                (k-Shapley Interaction Index), ``"FSII"`` (Faithful Shapley Interaction Index),
                ``"FBII"`` (Faithful Banzhaf Interaction Index, becomes ``BV`` for order 1),
                ``"STII"`` (Shapley Taylor Interaction Index), ``"SII"`` (Shapley Interaction
                Index).

            max_order: The maximum order of interactions to be computed. Set to ``1`` for no
                interactions (i.e., for Shapley values ``"SV"`` or Banzhaf values ``"BV"``).

            x_test: An optional test data set to compute the model's empty prediction (average
                prediction) on. If neither ``x_test`` nor ``empty_prediction`` is given, the
                last 20% of the background data is used as test data and the remaining 80% as
                training data for contextualization. Defaults to ``None``.

            empty_prediction: Optional value for the model's average prediction on an empty data
                point (all features missing). If provided, it takes precedence over ``x_test``
                and skips the computation of the empty prediction. Defaults to ``None``.

            class_index: The class index of the model to explain. Defaults to ``None``, which
                will set the class index to ``1`` per default for classification models and is
                ignored for regression models.

            approximator: An :class:`~shapiq.approximator.Approximator` object to use for the
                explainer or a literal string from
                ``["auto", "spex", "montecarlo", "svarm", "permutation", "regression"]``.
                Defaults to ``"auto"`` which automatically selects:
                :class:`~shapiq.approximator.KernelSHAP` for ``"SV"``,
                :class:`~shapiq.approximator.KernelSHAPIQ` for ``"SII"``/``"k-SII"``,
                :class:`~shapiq.approximator.RegressionFSII` for ``"FSII"``,
                :class:`~shapiq.approximator.RegressionFBII` for ``"FBII"``, and
                :class:`~shapiq.approximator.SVARMIQ` for ``"STII"``.

            verbose: Whether to show a progress bar during the computation. Defaults to
                ``False``. Note that verbosity can slow down the computation for large datasets.

        Raises:
            ValueError: If ``data`` and ``labels`` do not contain the same number of samples, or
                if the automatic 80/20 split is requested but ``data`` holds fewer than two
                samples (which would leave no training data for contextualization).

        """
        # imported lazily so that the module can be imported without the imputer's dependencies
        from shapiq.imputer.tabpfn_imputer import TabPFNImputer

        _predict_function, _ = get_predict_function_and_model_type(model, class_index=class_index)
        model._shapiq_predict_function = _predict_function  # noqa: SLF001

        # check that data and labels have the same number of samples
        n_samples = data.shape[0]
        if n_samples != labels.shape[0]:
            msg = (
                f"The number of samples in `data` and `labels` must be equal (got data.shape= "
                f"{data.shape} and labels.shape={labels.shape})."
            )
            raise ValueError(msg)

        x_train, y_train = data, labels
        if x_test is None and empty_prediction is None:
            # hold out the last 20% of the background data to compute the empty prediction
            split_at = int(0.8 * n_samples)
            if split_at < 1:
                # int(0.8 * n) is 0 for n < 2, which would leave an empty training set
                msg = (
                    f"Cannot split `data` into training and test parts with only {n_samples} "
                    "sample(s). Provide at least 2 samples, or pass `x_test` or "
                    "`empty_prediction` explicitly."
                )
                raise ValueError(msg)
            x_train, x_test = np.split(data, [split_at])
            y_train, _ = np.split(labels, [split_at])

        if x_test is None:
            x_test = x_train  # is not used in the TabPFNImputer if empty_prediction is set

        imputer = TabPFNImputer(
            model=model,
            x_train=x_train,
            y_train=y_train,
            x_test=x_test,
            empty_prediction=empty_prediction,
            verbose=verbose,
        )

        super().__init__(
            model,
            data=x_test,
            imputer=imputer,
            class_index=class_index,
            approximator=approximator,
            index=index,
            max_order=max_order,
        )

    @property
    def is_available(self) -> bool:
        """Check if the TabPFN package is available.

        Returns:
            ``True`` if ``tabpfn`` can be imported, ``False`` if the import raises an
            ``ImportError``. Note that the check performs the actual import.

        """
        import importlib

        try:
            importlib.import_module("tabpfn")
        except ImportError:
            return False
        return True