Source code for shapiq.explainer.base

"""The base Explainer classes for the shapiq package."""

from __future__ import annotations

from abc import abstractmethod
from typing import TYPE_CHECKING, Any

from tqdm.auto import tqdm

from .utils import (
    get_explainers,
    get_predict_function_and_model_type,
    print_class,
)
from .validation import validate_data_predict_function, validate_index_and_max_order

if TYPE_CHECKING:
    from collections.abc import Callable
    from typing import Any

    import numpy as np

    from shapiq.approximator.base import Approximator
    from shapiq.game import Game
    from shapiq.game_theory import ExactComputer
    from shapiq.imputer.base import Imputer
    from shapiq.interaction_values import InteractionValues
    from shapiq.typing import Model

    from .custom_types import ExplainerIndices


def generic_to_specific_explainer(
    generic_explainer: Explainer,
    explainer_cls: type[Explainer],
    model: Model | Game | Callable[[np.ndarray], np.ndarray],
    data: np.ndarray | None = None,
    class_index: int | None = None,
    index: ExplainerIndices = "SV",
    max_order: int = 1,
    **kwargs: Any,
) -> None:
    """Transform the base Explainer instance into a specific explainer subclass.

    This function modifies the class of the given object to the specified explainer class and
    initializes it with the provided parameters.

    Args:
        generic_explainer: The base Explainer instance to be transformed.
        explainer_cls: The specific explainer subclass to transform into.
        model: The model object to be explained.
        data: A background dataset to be used for imputation.
        class_index: The class index of the model to explain.
        index: The type of Shapley interaction index to use. Defaults to ``"SV"``.
        max_order: The maximum interaction order to be computed. Defaults to ``1``.
        **kwargs: Additional keyword-only arguments passed to the specific explainer class.
    """
    generic_explainer.__class__ = explainer_cls
    explainer_cls.__init__(
        generic_explainer,
        model=model,
        data=data,
        class_index=class_index,
        index=index,
        max_order=max_order,
        **kwargs,
    )


[docs] class Explainer: """The main Explainer class for a simpler user interface. shapiq.Explainer is a simplified interface for the ``shapiq`` package. It detects between :class:`~shapiq.explainer.tabular.TabularExplainer`, :class:`~shapiq.tree.TreeExplainer`, and :class:`~shapiq.explainer.tabpfn.TabPFNExplainer`. For a detailed description of the different explainers, see the respective classes. """ model: Model | Game | Callable[[np.ndarray], np.ndarray] """The model to be explained, either as a Model instance or a callable function.""" _index: ExplainerIndices _max_order: int def __init__( self, model: Model | Game | Callable[[np.ndarray], np.ndarray], data: np.ndarray | None = None, class_index: int | None = None, index: ExplainerIndices = "SV", max_order: int = 1, **kwargs: Any, ) -> None: """Initialize the Explainer class. Args: model: The model object to be explained. data: A background dataset to be used for imputation in :class:`~shapiq.explainer.tabular.TabularExplainer` or :class:`~shapiq.explainer.tabpfn.TabPFNExplainer`. This is a 2-dimensional NumPy array with shape ``(n_samples, n_features)``. Can be empty for the :class:`~shapiq.tree.TreeExplainer`, which does not require background data. class_index: The class index of the model to explain. Defaults to ``None``, which will set the class index to ``1`` per default for classification models and is ignored for regression models. Note, it is important to specify the class index for your classification model. index: The type of Shapley interaction index to use. Defaults to ``"SV"``, which computes the Shapley value. To compute interactions, pass an interaction index explicitly and set ``max_order`` accordingly. Options are: - ``"SV"``: Shapley value - ``"k-SII"``: k-Shapley Interaction Index - ``"FSII"``: Faithful Shapley Interaction Index - ``"FBII"``: Faithful Banzhaf Interaction Index (becomes ``BV`` for order 1) - ``"STII"``: Shapley Taylor Interaction Index - ``"SII"``: Shapley Interaction Index max_order: The maximum order of interactions to be computed. Set to ``1`` for no interactions (i.e, for Shapley values ``"SV"`` or Banzhaf values ``"BV"``). **kwargs: Additional keyword-only arguments passed to the specific explainer classes. """ # If Explainer is instantiated directly, dynamically dispatch to the appropriate subclass if self.__class__ is Explainer: model_class = print_class(model) _, model_type = get_predict_function_and_model_type(model, model_class, class_index) explainer_classes = get_explainers() if model_type in explainer_classes: explainer_cls = explainer_classes[model_type] generic_to_specific_explainer( self, explainer_cls, model=model, data=data, class_index=class_index, index=index, max_order=max_order, **kwargs, ) return msg = f"Model '{model_class}' with type '{model_type}' is not supported by shapiq.Explainer." raise TypeError(msg) # proceed with the base Explainer initialization self._model_class = print_class(model) self._shapiq_predict_function, self._model_type = get_predict_function_and_model_type( model, self._model_class, class_index ) # validate the model and data self.model = model if data is not None: validate_data_predict_function(data, predict_function=self.predict, raise_error=False) self._data: np.ndarray = data # validate index and max_order and set them as attributes self._index, self._max_order = validate_index_and_max_order(index, max_order) # initialize private attributes self._imputer: Imputer | None = None self._approximator: Approximator | None = None self._exact_computer: ExactComputer | None = None @property def imputer(self) -> Imputer: """The imputer used by the explainer (or None in the base class).""" if self._imputer is None: msg = "The explainer does not have an imputer. Use a specific explainer class." raise NotImplementedError(msg) return self._imputer @property def exact_computer(self) -> ExactComputer: """The exact computer used by the explainer (or None in the base class).""" if self._exact_computer is None: msg = "The explainer does not have an exact computer. Use a specific explainer class." raise NotImplementedError(msg) return self._exact_computer @property def approximator(self) -> Approximator: """The approximator used by the explainer (or None in the base class).""" if self._approximator is None: msg = "The explainer does not have an approximator. Use a specific explainer class." raise NotImplementedError(msg) return self._approximator @property def index(self) -> ExplainerIndices: """The type of Shapley interaction index the explainer is using.""" return self._index @property def max_order(self) -> int: """The maximum interaction order the explainer is using.""" return self._max_order
[docs] def explain(self, x: np.ndarray | None = None, **kwargs: Any) -> InteractionValues: """Explain a single prediction in terms of interaction values. Args: x: A numpy array of a data point to be explained. **kwargs: Additional keyword-only arguments passed to the specific explainer's ``explain_function`` method. Returns: The interaction values of the prediction. """ return self.explain_function(x=x, **kwargs)
[docs] def set_random_state(self, random_state: int | None = None) -> None: """Set the random state for the explainer and its components. Note: Setting the random state in the explainer will also overwrite the random state in the approximator and imputer, if they are set. Args: random_state: The random state to set. If ``None``, no random state is set. """ if random_state is None: return if self.approximator is not None: self.approximator.set_random_state(random_state=random_state) if self.imputer is not None: self.imputer.set_random_state(random_state=random_state)
[docs] @abstractmethod def explain_function( self, x: np.ndarray | None, *args: Any, **kwargs: Any ) -> InteractionValues: """Explain a single prediction in terms of interaction values. Args: x: A numpy array of a data point to be explained. *args: Additional positional arguments passed to the explainer. **kwargs: Additional keyword-only arguments passed to the explainer. Returns: The interaction values of the prediction. """ msg = "The method `explain` must be implemented in a subclass." raise NotImplementedError(msg)
[docs] def explain_X( self, X: np.ndarray, *, n_jobs: int | None = None, random_state: int | None = None, verbose: bool = False, **kwargs: Any, ) -> list[InteractionValues]: """Explain multiple predictions at once. This method is a wrapper around the ``explain`` method. It allows to explain multiple predictions at once. It is a convenience method that uses the ``joblib`` library to parallelize the computation of the interaction values. Args: X: A 2-dimensional matrix of inputs to be explained with shape (n_samples, n_features). n_jobs: Number of jobs for ``joblib.Parallel``. Defaults to ``None``, which will use no parallelization. If set to ``-1``, all available cores will be used. random_state: The random state to re-initialize Imputer and Approximator with. Defaults to ``None``. verbose: Whether to print a progress bar. Defaults to ``False``. **kwargs: Additional keyword-only arguments passed to the explainer's ``explain_function`` method. Returns: A list of interaction values for each prediction in the input matrix ``X``. """ if len(X.shape) != 2: msg = "The `X` must be a 2-dimensional matrix." raise TypeError(msg) self.set_random_state(random_state=random_state) if n_jobs: # parallelization with joblib import joblib parallel = joblib.Parallel(n_jobs=n_jobs) ivs: list[InteractionValues] = list( parallel(joblib.delayed(self.explain)(X[i, :], **kwargs) for i in range(X.shape[0])) ) else: ivs: list[InteractionValues] = [] pbar = tqdm(total=X.shape[0], desc="Explaining") if verbose else None for i in range(X.shape[0]): ivs.append(self.explain(X[i, :], **kwargs)) if pbar is not None: pbar.update(1) return ivs
[docs] def predict(self, x: np.ndarray) -> np.ndarray: """Provides a unified prediction interface for the explainer. Args: x: An instance/point/sample/observation to be explained. Returns: The model's prediction for the given data point as a vector. """ if isinstance(self._shapiq_predict_function, RuntimeError): raise self._shapiq_predict_function return self._shapiq_predict_function(self.model, x)