"""Tabular Explainer class for the shapiq package."""
from __future__ import annotations
from typing import TYPE_CHECKING, Any, Literal
from warnings import warn
from shapiq.explainer.base import Explainer
from shapiq.game_theory.indices import is_empty_value_the_baseline
from .configuration import setup_approximator
from .custom_types import ExplainerIndices
if TYPE_CHECKING:
import numpy as np
from shapiq.approximator.base import Approximator
from shapiq.imputer.base import Imputer
from shapiq.interaction_values import InteractionValues
from shapiq.typing import Model
# Approximator names that may be passed as a string to ``TabularExplainer(approximator=...)``
# (in addition to ``"auto"``).
TabularExplainerApproximators = Literal[
    "spex",
    "montecarlo",
    "svarm",
    "permutation",
    "regression",
]
# Imputer names that may be passed as a string to ``TabularExplainer(imputer=...)``.
TabularExplainerImputers = Literal[
    "marginal",
    "baseline",
    "conditional",
]
# Index names accepted by ``TabularExplainer(index=...)``; currently an alias of the generic
# ``ExplainerIndices`` type.
TabularExplainerIndices = ExplainerIndices
class TabularExplainer(Explainer):
    """The tabular explainer as the main interface for the shapiq package.

    The ``TabularExplainer`` class is the main interface for the ``shapiq`` package and tabular
    data. It can be used to explain the predictions of any model by estimating the Shapley
    interaction values.

    Attributes:
        index: Type of Shapley interaction index to use.
        data: A background data to use for the explainer.

    Properties:
        baseline_value: A baseline value of the explainer.
    """

    def __init__(
        self,
        model: Model,
        data: np.ndarray,
        *,
        class_index: int | None = None,
        imputer: Imputer | TabularExplainerImputers = "marginal",
        approximator: (
            Literal["auto"] | TabularExplainerApproximators | Approximator[TabularExplainerIndices]
        ) = "auto",
        index: TabularExplainerIndices = "SV",
        max_order: int = 1,
        random_state: int | None = None,
        verbose: bool = False,
        **kwargs: Any,
    ) -> None:
        """Initializes the TabularExplainer.

        Args:
            model: The model to be explained as a callable function expecting data points as input
                and returning 1-dimensional predictions.
            data: A background dataset to be used for imputation.
            class_index: The class index of the model to explain. Defaults to ``None``, which will
                set the class index to ``1`` per default for classification models and is ignored
                for regression models.
            imputer: Either an instance of :class:`~shapiq.imputer.base.Imputer` (any subclass,
                e.g. one of the imputers in the :mod:`~shapiq.imputer` module or a custom one), or
                a literal string from ``["marginal", "baseline", "conditional"]``. Defaults to
                ``"marginal"``, which initializes the default
                :class:`~shapiq.imputer.MarginalImputer` with its default parameters or as
                provided in ``kwargs``. A string builds a new imputer around this explainer's
                prediction function and ``data``; an instance is used as-is (only its ``verbose``
                flag is overwritten).
            approximator: An :class:`~shapiq.approximator.Approximator` object to use for the
                explainer or a literal string from
                ``["auto", "spex", "montecarlo", "svarm", "permutation", "regression"]``.
                Defaults to ``"auto"`` which automatically selects:
                :class:`~shapiq.approximator.KernelSHAP` for ``"SV"``,
                :class:`~shapiq.approximator.KernelSHAPIQ` for ``"SII"``/``"k-SII"``,
                :class:`~shapiq.approximator.RegressionFSII` for ``"FSII"``,
                :class:`~shapiq.approximator.RegressionFBII` for ``"FBII"``, and
                :class:`~shapiq.approximator.SVARMIQ` for ``"STII"``.
            index: The index to explain the model with. Defaults to ``"SV"`` which computes the
                Shapley value. Options: ``"SV"`` (Shapley value), ``"k-SII"``
                (k-Shapley Interaction Index), ``"FSII"`` (Faithful Shapley Interaction Index),
                ``"FBII"`` (Faithful Banzhaf Interaction Index, becomes ``BV`` for order 1),
                ``"STII"`` (Shapley Taylor Interaction Index), ``"SII"`` (Shapley Interaction
                Index).
            max_order: The maximum order of interactions to be computed. Set to ``1`` for no
                interactions (i.e, for Shapley values ``"SV"`` or Banzhaf values ``"BV"``).
            random_state: The random state to initialize Imputer and Approximator with. It is
                only forwarded to the imputer when the imputer is built from a string. Defaults
                to ``None``.
            verbose: Whether to show a progress bar during the computation. Defaults to ``False``.
            **kwargs: Additional keyword-only arguments passed to the imputer when it is built
                from a string. They are ignored when an imputer instance is passed.

        Raises:
            ValueError: If ``imputer`` is neither one of ``"marginal"``, ``"baseline"``,
                ``"conditional"`` nor an instance of :class:`~shapiq.imputer.base.Imputer`.
        """
        from shapiq.imputer import (
            BaselineImputer,
            GenerativeConditionalImputer,
            MarginalImputer,
        )
        from shapiq.imputer.base import Imputer as BaseImputer

        super().__init__(model, data, class_index, index=index, max_order=max_order)

        # Compare by class name so that subclasses of this explainer do not trigger the warning.
        class_name = self.__class__.__name__
        if self._model_type == "tabpfn" and class_name == "TabularExplainer":
            warn(
                "You are using a TabPFN model with the ``shapiq.TabularExplainer`` directly. This "
                "is not recommended as it uses missing value imputation and not contextualization. "
                "Consider using the ``shapiq.TabPFNExplainer`` instead. For more information see "
                "the documentation and the example notebooks.",
                stacklevel=2,
            )

        # String shortcuts map to imputer classes that all share the same constructor call.
        imputer_classes = {
            "marginal": MarginalImputer,
            "conditional": GenerativeConditionalImputer,
            "baseline": BaselineImputer,
        }
        if isinstance(imputer, str) and imputer in imputer_classes:
            self._imputer = imputer_classes[imputer](
                self.predict,
                self._data,
                random_state=random_state,
                **kwargs,
            )
        elif isinstance(imputer, BaseImputer):
            # A ready-made imputer is used unchanged; random_state and kwargs are not applied.
            self._imputer = imputer
        else:
            msg = (
                f"Invalid imputer {imputer}. "
                'Must be one of ["marginal", "baseline", "conditional"], or a valid Imputer '
                "object."
            )
            raise ValueError(msg)

        self._n_features: int = self._data.shape[1]
        # The explainer's verbose flag always overrides the imputer's own setting.
        self.imputer.verbose = verbose
        self._approximator = setup_approximator(
            approximator,
            self.index,
            self._max_order,
            self._n_features,
            random_state,
        )

    def explain_function(  # type: ignore[override]
        self,
        x: np.ndarray,
        budget: int,
        *,
        random_state: int | None = None,
    ) -> InteractionValues:
        """Explains the model's predictions.

        Args:
            x: The data point to explain as a 2-dimensional array with shape
                (1, n_features).
            budget: The budget to use for the approximation. It indicates how many coalitions are
                sampled, thus high values indicate more accurate approximations, but induce higher
                computational costs.
            random_state: The random state to re-initialize Imputer and Approximator with.
                Defaults to ``None``, which will not set a random state.

        Returns:
            An object of class :class:`~shapiq.interaction_values.InteractionValues` containing
            the computed interaction values.
        """
        self.set_random_state(random_state)
        # initialize the imputer with the explanation point
        self.imputer.fit(x)
        # explain: the fitted imputer acts as the game evaluated by the approximator
        interaction_values = self.approximator(budget=budget, game=self.imputer)
        interaction_values.baseline_value = self.baseline_value
        # For indices whose empty-set value is the baseline, store the baseline explicitly as
        # the value of the empty coalition.
        if is_empty_value_the_baseline(interaction_values.index):
            interaction_values[()] = interaction_values.baseline_value
        return interaction_values

    @property
    def baseline_value(self) -> float:
        """Returns the baseline value of the explainer (the imputer's ``empty_prediction``)."""
        return self.imputer.empty_prediction