Source code for shapiq.explainer.nn.threshold_nn
"""Implements the Explainer for threshold nearest neighbor models."""
from __future__ import annotations
from typing import TYPE_CHECKING, cast, override
import numpy as np
from scipy.special import comb
if TYPE_CHECKING:
import numpy.typing as npt
import sklearn.neighbors
from sklearn.neighbors import RadiusNeighborsClassifier
from shapiq.explainer.custom_types import ValidNNExplainerIndices
from shapiq.interaction_values import InteractionValues
from ._util import (
assert_valid_index_and_order,
warn_ignored_parameters,
)
from .base import NNExplainerBase
[docs]
class ThresholdNNExplainer(NNExplainerBase):
    """Explainer for threshold nearest neighbor models.

    Implements the algorithm for efficiently computing exact Shapley values for threshold nearest
    neighbor models proposed by :footcite:t:`Wang.2023`.

    The algorithm has a runtime complexity of :math:`O(N)` (when explaining a single data point),
    where :math:`N` is the number of training samples.

    References:
        .. footbibliography::
    """

    model: RadiusNeighborsClassifier

    @override
    def __init__(
        self,
        model: sklearn.neighbors.RadiusNeighborsClassifier,
        class_index: int | None = None,
        data: np.ndarray | None = None,
        index: ValidNNExplainerIndices = "SV",
        max_order: int = 1,
    ) -> None:
        """Validate the arguments and initialize the base explainer.

        Args:
            model: The radius-neighbors classifier to explain. Its ``radius`` attribute must be
                an ``int`` or ``float``.
            class_index: Forwarded unchanged to the base-class initializer.
            data: Not used by this explainer; it is reported to ``warn_ignored_parameters`` as
                an ignored parameter.
            index: Checked together with ``max_order`` by ``assert_valid_index_and_order``.
            max_order: Checked together with ``index`` by ``assert_valid_index_and_order``.

        Raises:
            TypeError: If ``model.radius`` is neither an ``int`` nor a ``float``.
        """
        assert_valid_index_and_order(index, max_order)
        # NOTE: ``locals()`` must be captured before any extra local variable is introduced, so
        # that only the constructor arguments are handed to the helper.
        warn_ignored_parameters(locals(), ["data"], self.__class__.__name__)

        if not isinstance(model.radius, int | float):
            msg = f"Expected RadiusNeighborsClassifier.radius to be int or float but got {type(model.radius)}"
            raise TypeError(msg)

        super().__init__(model, class_index=class_index)

    @staticmethod
    def _binomial_harmonic_term(c: int, n_neighbors: int) -> float:
        """Return ``-1 + sum_{k=0}^{c} (1 - C(c-k, m) / C(c+1, m)) / (k + 1)`` with ``m = n_neighbors``.

        The binomial ratios are built with the recurrence
        ``C(n-1, m) / C(n, m) = (n - m) / n`` instead of evaluating the binomial coefficients
        themselves. No intermediate value exceeds 1, so the result stays finite and accurate
        even when ``C(c+1, m)`` would overflow a float64.

        Args:
            c: Size of the training set minus one.
            n_neighbors: Number of training points within the radius (``0 <= m <= c + 1``).

        Returns:
            The scalar factor that multiplies ``a1`` in the first summand.
        """
        # n runs over c+1, c, ..., 1; the k-th cumulative product equals C(c-k, m) / C(c+1, m).
        n_values = np.arange(c + 1, 0, -1, dtype=np.float64)
        # Clamp at zero: once n drops to m the ratio is exactly 0 and must stay 0 afterwards.
        step_ratios = np.maximum(n_values - n_neighbors, 0.0) / n_values
        binom_ratios = np.cumprod(step_ratios)
        k_plus_one = np.arange(1, c + 2, dtype=np.float64)
        return float(np.sum((1.0 - binom_ratios) / k_plus_one) - 1.0)

    @override
    def explain_function(
        self, x: npt.NDArray[np.floating]
    ) -> InteractionValues:  # ty: ignore[invalid-method-override]
        """Compute the value of every training point for the prediction at ``x``.

        Args:
            x: The query point. It is reshaped to a single row before the neighbor lookup.

        Returns:
            First-order ``InteractionValues`` (index ``"SV"``) with one entry per training
            sample; the entry for the empty coalition is set to ``1 / n_classes``.
        """
        # Following Theorem 13 and equation (7) in Wang et al. (2023) DOI: 2308.15709v2
        # Counting queries defined in C.2.2 ibid.
        # NOTE(review): X_train, y_train_classes, y_train_indices and class_index are expected
        # to be provided by the base class - confirm against NNExplainerBase.
        n_train = self.X_train.shape[0]
        n_classes = self.y_train_classes.shape[0]

        # radius_neighbors returns one index array per query row; there is exactly one row.
        neighbor_indices = self.model.radius_neighbors(x.reshape(1, -1), return_distance=False)
        neighbor_indices = neighbor_indices[0]

        in_neighborhood = np.zeros((n_train,), dtype=bool)
        in_neighborhood[neighbor_indices] = True
        y_train_is_class_index = self.y_train_indices == self.class_index

        # For entire dataset D
        c_D = n_train
        c_x_tau_D = 1 + len(neighbor_indices)
        c_plus_z_tau_D = cast("int", np.sum(in_neighborhood & y_train_is_class_index))

        # For each training point z_i (the counts above with z_i itself removed)
        c = c_D - 1
        c_x_tau = c_x_tau_D - in_neighborhood.astype(int)
        c_plus_z_tau = c_plus_z_tau_D - cast(
            "npt.NDArray[np.integer]", (in_neighborhood & y_train_is_class_index).astype(int)
        )

        # a1 is only non-zero for neighbors; the c_x_tau >= 2 guard keeps the
        # c_x_tau * (c_x_tau - 1) denominator away from zero.
        a1 = np.zeros((n_train,), dtype=np.float64)
        mask = in_neighborhood & (c_x_tau >= 2)
        a1[mask] = y_train_is_class_index[mask] / c_x_tau[mask] - c_plus_z_tau[mask] / (
            c_x_tau[mask] * (c_x_tau[mask] - 1)
        )

        # Scalar factor shared by all training points; computed overflow-free.
        a2 = self._binomial_harmonic_term(c, c_x_tau_D - 1)

        first_summand = a1 * a2

        second_summand = np.zeros((n_train,), dtype=np.float64)
        second_summand[in_neighborhood] = (
            y_train_is_class_index[in_neighborhood] - 1 / n_classes
        ) / c_x_tau[in_neighborhood]

        sv = first_summand + second_summand

        iv = InteractionValues.from_first_order_array(sv, index="SV")
        iv[()] = 1 / n_classes
        return iv