Source code for shapiq.explainer.nn.knn
"""KNN Classifier Explainer."""
from __future__ import annotations
from typing import TYPE_CHECKING, override
import numpy as np
from shapiq.explainer.nn.base import NNExplainerBase
from shapiq.interaction_values import InteractionValues
from ._util import (
assert_valid_index_and_order,
warn_ignored_parameters,
)
if TYPE_CHECKING:
import numpy.typing as npt
from sklearn.neighbors import KNeighborsClassifier
from shapiq.explainer.custom_types import ValidNNExplainerIndices
[docs]
class KNNExplainer(NNExplainerBase):
    r"""Explainer for unweighted KNN models.

    Implements the algorithm proposed by :footcite:t:`Jia.2019` to efficiently calculate Shapley
    values for unweighted KNN models. Every training point is treated as a player, and its Shapley
    value is computed with respect to the class selected by ``class_index``.

    The algorithm itself has a linear time complexity, but requires sorting training points by
    distance to the test point, resulting in a time complexity of :math:`O(N \log N)` for
    explaining a single data point.

    References:
        .. footbibliography::
    """

    model: KNeighborsClassifier

    @override
    def __init__(
        self,
        model: KNeighborsClassifier,
        class_index: int | None = None,
        data: np.ndarray | None = None,
        index: ValidNNExplainerIndices = "SV",
        max_order: int = 1,
    ) -> None:
        """Initialize the explainer for a uniformly weighted KNN classifier.

        Args:
            model: The ``KNeighborsClassifier`` to explain. It must use ``weights="uniform"``.
            class_index: The class whose Shapley values are computed. Forwarded unchanged to
                ``NNExplainerBase``; how ``None`` is resolved is decided there.
            data: Not used by this explainer. It is handed to ``warn_ignored_parameters``,
                presumably so the caller is warned when it is supplied.
            index: The interaction index. Validated together with ``max_order`` by
                ``assert_valid_index_and_order``.
            max_order: The maximum interaction order. Validated together with ``index``.

        Raises:
            ValueError: If ``model.weights`` is not ``"uniform"``.
            TypeError: If ``model.n_neighbors`` is not an integer (Python ``int`` or NumPy
                integer scalar).
        """
        # Both calls must run before any other local is bound: ``locals()`` is inspected.
        assert_valid_index_and_order(index, max_order)
        warn_ignored_parameters(locals(), ["data"], self.__class__.__name__)

        if model.weights != "uniform":
            msg = f"KNeighborsClassifier must use weights='uniform', but has weights='{model.weights}'"
            raise ValueError(msg)
        # scikit-learn accepts any integral ``n_neighbors`` (e.g. ``np.int64`` values coming
        # from a hyperparameter grid), so NumPy integer scalars must not be rejected here.
        if not isinstance(model.n_neighbors, (int, np.integer)):
            msg = f"Expected KNeighborsClassifier.n_neighbors to be int but got {type(model.n_neighbors)}"
            raise TypeError(msg)

        super().__init__(model, class_index=class_index)
        # Normalize to a plain Python int so the arithmetic below is type-stable.
        self.k = int(model.n_neighbors)

    @override
    def explain_function(
        self, x: npt.NDArray[np.floating]
    ) -> InteractionValues:  # ty: ignore[invalid-method-override]
        """Compute the exact Shapley value of every training point for one test point.

        Args:
            x: The test point to explain. It is reshaped to a single row with
                ``x.reshape(1, -1)`` before querying the model.

        Returns:
            First-order ``InteractionValues`` with index ``"SV"``, holding one value per
            training point in the original training order.
        """
        n_train = len(self.X_train)

        # Indices of all training points, ordered from nearest to farthest from ``x``.
        sortperm = self.model.kneighbors(
            x.reshape(1, -1), n_neighbors=n_train, return_distance=False
        )[0]

        # Indicator (in distance order) of whether a training label equals the explained class.
        is_class = (self.y_train_indices[sortperm] == self.class_index).astype(int)

        # Recursion of Jia et al. (2019), with 1-based distance rank i and K = self.k:
        #   s_N = 1[y_N] / N
        #   s_i = s_{i+1} + (1[y_i] - 1[y_{i+1}]) / K * min(K, i) / i
        # Unrolled, s_i is s_N plus the suffix sum of the increments, i.e. a reverse
        # cumulative sum. np.cumsum accumulates sequentially, in the same order as the
        # explicit backward loop would.
        ranks = np.arange(1, n_train)  # 1-based ranks i = 1, ..., N - 1
        increments = ((is_class[:-1] - is_class[1:]) / self.k) * (np.minimum(self.k, ranks) / ranks)
        base = is_class[-1] / n_train
        sv_sorted = np.cumsum(np.concatenate(([base], increments[::-1])))[::-1]

        # Scatter the values from distance order back to the original training order.
        sv = np.empty(n_train)
        sv[sortperm] = sv_sorted
        return InteractionValues.from_first_order_array(sv, index="SV")