# Source code for shapiq.explainer.tabular

"""Tabular Explainer class for the shapiq package."""

from __future__ import annotations

from typing import TYPE_CHECKING, Any, Literal
from warnings import warn

from shapiq.explainer.base import Explainer
from shapiq.game_theory.indices import is_empty_value_the_baseline

from .configuration import setup_approximator
from .custom_types import ExplainerIndices

if TYPE_CHECKING:
    import numpy as np

    from shapiq.approximator.base import Approximator
    from shapiq.imputer.base import Imputer
    from shapiq.interaction_values import InteractionValues
    from shapiq.typing import Model


# Approximator names that ``TabularExplainer`` accepts as string shortcuts
# (``"auto"`` is accepted in addition by ``TabularExplainer.__init__``).
TabularExplainerApproximators = Literal[
    "spex",
    "montecarlo",
    "svarm",
    "permutation",
    "regression",
]
# Built-in imputer names that ``TabularExplainer`` can construct from a string.
TabularExplainerImputers = Literal[
    "marginal",
    "baseline",
    "conditional",
]
# Interaction indices for ``TabularExplainer``; an alias of ``ExplainerIndices``.
TabularExplainerIndices = ExplainerIndices


class TabularExplainer(Explainer):
    """The tabular explainer as the main interface for the shapiq package.

    The ``TabularExplainer`` class is the main interface for the ``shapiq`` package and tabular
    data. It can be used to explain the predictions of any model by estimating the Shapley
    interaction values.

    Attributes:
        index: Type of Shapley interaction index to use.
        data: The background data to use for the explainer.

    Properties:
        baseline_value: The baseline value of the explainer.

    """

    def __init__(
        self,
        model: Model,
        data: np.ndarray,
        *,
        class_index: int | None = None,
        imputer: Imputer | TabularExplainerImputers = "marginal",
        approximator: (
            Literal["auto"] | TabularExplainerApproximators | Approximator[TabularExplainerIndices]
        ) = "auto",
        index: TabularExplainerIndices = "SV",
        max_order: int = 1,
        random_state: int | None = None,
        verbose: bool = False,
        **kwargs: Any,
    ) -> None:
        """Initializes the TabularExplainer.

        Args:
            model: The model to be explained as a callable function expecting data points as input
                and returning 1-dimensional predictions.

            data: A background dataset to be used for imputation.

            class_index: The class index of the model to explain. Defaults to ``None``, which will
                set the class index to ``1`` per default for classification models and is ignored
                for regression models.

            imputer: Either an :class:`~shapiq.imputer.base.Imputer` instance (e.g. one of the
                imputers implemented in :mod:`~shapiq.imputer`), or a literal string from
                ``["marginal", "baseline", "conditional"]``. Defaults to ``"marginal"``, which
                initializes the default :class:`~shapiq.imputer.MarginalImputer` with its default
                parameters or as provided in ``kwargs``. An imputer instance is used as-is:
                ``random_state`` and ``kwargs`` are not applied to it.

            approximator: An :class:`~shapiq.approximator.Approximator` object to use for the
                explainer or a literal string from
                ``["auto", "spex", "montecarlo", "svarm", "permutation", "regression"]``.
                Defaults to ``"auto"`` which automatically selects:
                :class:`~shapiq.approximator.KernelSHAP` for ``"SV"``,
                :class:`~shapiq.approximator.KernelSHAPIQ` for ``"SII"``/``"k-SII"``,
                :class:`~shapiq.approximator.RegressionFSII` for ``"FSII"``,
                :class:`~shapiq.approximator.RegressionFBII` for ``"FBII"``, and
                :class:`~shapiq.approximator.SVARMIQ` for ``"STII"``.

            index: The index to explain the model with. Defaults to ``"SV"`` which computes the
                Shapley value. Options: ``"SV"`` (Shapley value), ``"k-SII"`` (k-Shapley
                Interaction Index), ``"FSII"`` (Faithful Shapley Interaction Index), ``"FBII"``
                (Faithful Banzhaf Interaction Index, becomes ``BV`` for order 1), ``"STII"``
                (Shapley Taylor Interaction Index), ``"SII"`` (Shapley Interaction Index).

            max_order: The maximum order of interactions to be computed. Set to ``1`` for no
                interactions (i.e., for Shapley values ``"SV"`` or Banzhaf values ``"BV"``).

            random_state: The random state to initialize Imputer and Approximator with. Defaults
                to ``None``.

            verbose: Whether to show a progress bar during the computation. Defaults to ``False``.

            **kwargs: Additional keyword-only arguments passed to the imputer constructor when
                ``imputer`` is given as a string (see :mod:`~shapiq.imputer`).

        Raises:
            ValueError: If ``imputer`` is neither one of the supported strings nor an
                :class:`~shapiq.imputer.base.Imputer` instance.

        """
        from shapiq.imputer import (
            BaselineImputer,
            GenerativeConditionalImputer,
            MarginalImputer,
        )
        from shapiq.imputer.base import Imputer

        super().__init__(model, data, class_index, index=index, max_order=max_order)

        # Only warn for the plain TabularExplainer; subclasses (e.g. a TabPFN-specific explainer)
        # are expected to handle TabPFN models themselves.
        class_name = self.__class__.__name__
        if self._model_type == "tabpfn" and class_name == "TabularExplainer":
            warn(
                "You are using a TabPFN model with the ``shapiq.TabularExplainer`` directly. This "
                "is not recommended as it uses missing value imputation and not contextualization. "
                "Consider using the ``shapiq.TabPFNExplainer`` instead. For more information see "
                "the documentation and the example notebooks.",
                stacklevel=2,
            )

        # String shortcuts for the imputers this explainer can construct itself.
        imputer_classes = {
            "marginal": MarginalImputer,
            "conditional": GenerativeConditionalImputer,
            "baseline": BaselineImputer,
        }
        if isinstance(imputer, str) and imputer in imputer_classes:
            self._imputer = imputer_classes[imputer](
                self.predict,
                self._data,
                random_state=random_state,
                **kwargs,
            )
        elif isinstance(imputer, Imputer):
            # Any Imputer instance is accepted, matching the ``imputer`` type annotation. It is
            # used as-is, so ``random_state`` and ``kwargs`` are not forwarded to it.
            self._imputer = imputer
        else:
            msg = (
                f"Invalid imputer {imputer}. "
                f'Must be one of ["marginal", "baseline", "conditional"], or a valid Imputer '
                f"object."
            )
            raise ValueError(msg)

        self._n_features: int = self._data.shape[1]
        self.imputer.verbose = verbose  # set the verbose flag for the imputer
        self._approximator = setup_approximator(
            approximator,
            self.index,
            self._max_order,
            self._n_features,
            random_state,
        )

    def explain_function(  # type: ignore[override]
        self,
        x: np.ndarray,
        budget: int,
        *,
        random_state: int | None = None,
    ) -> InteractionValues:
        """Explains the model's predictions.

        Args:
            x: The data point to explain as a 2-dimensional array with shape (1, n_features).

            budget: The budget to use for the approximation. It indicates how many coalitions are
                sampled, thus high values indicate more accurate approximations, but induce higher
                computational costs.

            random_state: The random state to re-initialize Imputer and Approximator with.
                Defaults to ``None``, which will not set a random state.

        Returns:
            An object of class :class:`~shapiq.interaction_values.InteractionValues` containing
            the computed interaction values.

        """
        self.set_random_state(random_state)

        # initialize the imputer with the explanation point
        self.imputer.fit(x)

        # explain: the fitted imputer acts as the game evaluated by the approximator
        interaction_values = self.approximator(budget=budget, game=self.imputer)
        interaction_values.baseline_value = self.baseline_value

        # Adjust the Baseline Value if the empty value is the baseline
        if is_empty_value_the_baseline(interaction_values.index):
            interaction_values[()] = interaction_values.baseline_value

        return interaction_values

    @property
    def baseline_value(self) -> float:
        """Returns the baseline value of the explainer (the imputer's ``empty_prediction``)."""
        return self.imputer.empty_prediction