.. DO NOT EDIT. .. THIS FILE WAS AUTOMATICALLY GENERATED BY SPHINX-GALLERY. .. TO MAKE CHANGES, EDIT THE SOURCE PYTHON FILE: .. "auto_examples/nearest_neighbor/plot_nn_data_valuation.py" .. LINE NUMBERS ARE GIVEN BELOW. .. only:: html .. note:: :class: sphx-glr-download-link-note :ref:`Go to the end ` to download the full example code. .. rst-class:: sphx-glr-example-title .. _sphx_glr_auto_examples_nearest_neighbor_plot_nn_data_valuation.py: Data Valuation with Nearest Neighbor Explainers =============================================== This notebook shows how explainers of nearest-neighbor (NN) models can be used for Data Valuation, the task of evaluating the usefulness of individual training data points in classification problems. When explaining NN models, a game is defined by first choosing an explanation point :math:`x_\text{explain}` and class :math:`y_\text{explain}`; the training data points :math:`\mathcal{D} := \mathcal{X} \times \mathcal{Y}` are the game's players, and the definition of the utility :math:`\nu(S)` of a coalition :math:`S \subseteq \mathcal{D}` is based on the probability of the model predicting class :math:`y_\text{explain}` on :math:`x_\text{explain}` if it's training data were limited to :math:`S`. .. GENERATED FROM PYTHON SOURCE LINES 10-16 There is support for explaining the the ``KNeighborsClassifier`` model (with ``'uniform'`` or ``'distance'`` weights) and ``RadiusNeighborsClassifier`` model from the ``scikit-learn`` library. The algorithms are based on the publications from `Jia et al. (2019) `__, `Wang et al. (2024) `__ and `Wang et al. (2023) `__, respectively. Let's start by generating a synthetic classification datset and fitting a simple `KNeighborsClassifier` to it. .. GENERATED FROM PYTHON SOURCE LINES 16-47 .. code-block:: Python import matplotlib.pyplot as plt import numpy as np from sklearn.datasets import make_classification from sklearn.neighbors import KNeighborsClassifier, RadiusNeighborsClassifier from util_plot import plot_datasets X_train, y_train = make_classification( n_samples=30, n_features=2, n_redundant=0, n_clusters_per_class=1, n_informative=2, n_classes=2, random_state=45, ) fig, ax = plt.subplots(figsize=(6, 6)) plot_datasets(ax, X_train, y_train) model = KNeighborsClassifier(n_neighbors=3) model.fit(X_train, y_train) x_explain = np.array([[-0.75, -0.4]]) y_explain_pred = model.predict(x_explain)[0] print(f"Prediction: class {y_explain_pred}") y_explain_proba = model.predict_proba(x_explain)[0] print(f"Prediction probabilities: {y_explain_proba}") .. image-sg:: /auto_examples/nearest_neighbor/images/sphx_glr_plot_nn_data_valuation_001.png :alt: plot nn data valuation :srcset: /auto_examples/nearest_neighbor/images/sphx_glr_plot_nn_data_valuation_001.png :class: sphx-glr-single-img .. rst-class:: sphx-glr-script-out .. code-block:: none Prediction: class 0 Prediction probabilities: [0.66666667 0.33333333] .. GENERATED FROM PYTHON SOURCE LINES 48-52 Using the ``KNNExplainer`` for Unweighted :math:`k`-Nearest Neighbor Models --------------------------------------------------------------------------- To explain the prediction, we create an explainer for the model by passing it to the constructor of ``Explainer``, which will automatically dispatch to the adequate subclass ``KNNExplainer``. .. GENERATED FROM PYTHON SOURCE LINES 52-59 .. code-block:: Python from shapiq import Explainer explainer = Explainer(model, class_index=y_explain_pred, index="SV", max_order=1) print(type(explainer)) .. rst-class:: sphx-glr-script-out .. code-block:: none .. GENERATED FROM PYTHON SOURCE LINES 60-63 Note that we set ``class_index=y_explain_pred``, since for now, we want to quantify the contribution of the training data to the class that was actually predicted. (We could also set a different class index if we wished to see how much the data points contribute to shifting the prediction towards another class.) Now we can get an explanation for the prediction we saw above: .. GENERATED FROM PYTHON SOURCE LINES 63-68 .. code-block:: Python iv = explainer.explain(x_explain) print(iv) .. rst-class:: sphx-glr-script-out .. code-block:: none InteractionValues( index=SV, max_order=1, min_order=0, estimated=True, estimation_budget=None, n_players=30, baseline_value=0, Top 10 interactions: (9,): 0.20212345126138226 (23,): 0.20212345126138226 (2,): 0.11879011792804894 (14,): 0.08545678459471562 (15,): 0.06759964173757277 (28,): 0.06759964173757277 (16,): 0.058508732646663675 (19,): -0.08120988207195104 (1,): -0.13120988207195106 (4,): -0.13120988207195106 ) .. GENERATED FROM PYTHON SOURCE LINES 69-71 Explaining Weighted :math:`k`-Nearest Neighbor and Threshold Nearest Neighbor Models ------------------------------------------------------------------------------------ .. GENERATED FROM PYTHON SOURCE LINES 73-74 There are separate explainers for weighted :math:`k`-NN and threshold NN models, which are selected automatically when an `Explainer` is instantiated with a corresponding model: .. GENERATED FROM PYTHON SOURCE LINES 74-86 .. code-block:: Python wknn_model = KNeighborsClassifier(n_neighbors=3, weights="distance") wknn_model.fit(X_train, y_train) wknn_explainer = Explainer(wknn_model, class_index=0, index="SV", max_order=1) print(type(wknn_explainer)) tnn_model = RadiusNeighborsClassifier() tnn_model.fit(X_train, y_train) tnn_explainer = Explainer(tnn_model, class_index=0, index="SV", max_order=1) print(type(tnn_explainer)) .. rst-class:: sphx-glr-script-out .. code-block:: none .. GENERATED FROM PYTHON SOURCE LINES 87-88 They can be used just the same way: .. GENERATED FROM PYTHON SOURCE LINES 88-93 .. code-block:: Python print(wknn_explainer.explain(x_explain)) print(tnn_explainer.explain(x_explain)) .. rst-class:: sphx-glr-script-out .. code-block:: none InteractionValues( index=SV, max_order=1, min_order=0, estimated=True, estimation_budget=None, n_players=30, baseline_value=0, Top 10 interactions: (23,): 0.5414010940326729 (9,): 0.4580677606993394 (2,): 0.1858067858067858 (14,): 0.09743098032571718 (15,): 0.08314526604000289 (28,): 0.08314526604000289 (21,): -0.06915214415214416 (19,): -0.09058071558071559 (1,): -0.2653449377133588 (4,): -0.3752655726339937 ) InteractionValues( index=SV, max_order=1, min_order=0, estimated=True, estimation_budget=None, n_players=30, baseline_value=0, Top 10 interactions: (): 0.5 (2,): 0.13726715203987933 (14,): 0.13726715203987933 (15,): 0.13726715203987933 (23,): 0.13726715203987933 (1,): -0.1556296733569461 (4,): -0.1556296733569461 (17,): -0.1556296733569461 (19,): -0.1556296733569461 (21,): -0.1556296733569461 ) .. GENERATED FROM PYTHON SOURCE LINES 94-98 Large numbers of training samples ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ Since the algorithms are pretty efficient, we can run them on large sets of training data. .. GENERATED FROM PYTHON SOURCE LINES 98-129 .. code-block:: Python from time import time def print_explain_times(model, n, n_test) -> None: X_train, y_train = make_classification( n_samples=n, n_features=5, n_redundant=0, n_clusters_per_class=1, n_informative=3, n_classes=2, random_state=45, ) X_test = X_train[:n_test] X_train = X_train[n_test:] y_train = y_train[n_test:] model.fit(X_train, y_train) explainer = Explainer(model, class_index=0, index="SV", max_order=1) times = np.zeros((n_test,)) for i, x_test in enumerate(X_test): t_start = time() explainer.explain(x_test) t_end = time() times[i] = t_end - t_start mean = np.mean(times) * 1000 std = np.std(times) * 1000 print(f"{explainer.__class__.__name__} on {n} samples: average {mean:.1f}±{std:.1f}ms") .. GENERATED FROM PYTHON SOURCE LINES 130-131 The cell below which uses the KNN explainer takes roughly 0.15 s to explain a single data point on a consumer-grade laptop with a 12th Gen Intel i5 processor. .. GENERATED FROM PYTHON SOURCE LINES 131-135 .. code-block:: Python print_explain_times(KNeighborsClassifier(n_neighbors=5, weights="uniform"), n=100_000, n_test=50) .. rst-class:: sphx-glr-script-out .. code-block:: none KNNExplainer on 100000 samples: average 151.3±1.2ms .. GENERATED FROM PYTHON SOURCE LINES 136-137 Since the algorithm of the WKNN explainer is less efficient, featuring a quadratic runtime complexity, the number of data points needs to be limited. .. GENERATED FROM PYTHON SOURCE LINES 137-141 .. code-block:: Python print_explain_times(KNeighborsClassifier(n_neighbors=5, weights="distance"), n=200, n_test=10) .. rst-class:: sphx-glr-script-out .. code-block:: none WeightedKNNExplainer on 200 samples: average 1843.1±6.2ms .. GENERATED FROM PYTHON SOURCE LINES 142-143 The TNN algorithm, on the other hand, is faster: .. GENERATED FROM PYTHON SOURCE LINES 143-147 .. code-block:: Python print_explain_times(RadiusNeighborsClassifier(radius=5), n=100_000, n_test=50) .. rst-class:: sphx-glr-script-out .. code-block:: none ThresholdNNExplainer on 100000 samples: average 66.8±1.0ms .. GENERATED FROM PYTHON SOURCE LINES 148-154 ## Identifying corrupted training samples ----------------------------------------- We can estimate the usefulness of each point of a training datset by calculating Shapley values for a set of test data points and averaging the results. This will allow us to identify potentially mislabeled data points. First, let's create a classification datset and split it into train and test sets. We will corrupt the training data by changing the class of a few randomly selected data points. .. GENERATED FROM PYTHON SOURCE LINES 154-189 .. code-block:: Python from sklearn.model_selection import train_test_split X, y = make_classification( n_samples=100, n_features=2, n_redundant=0, n_clusters_per_class=1, n_informative=2, n_classes=2, flip_y=0, random_state=49, class_sep=1.5, ) X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.25, random_state=42) y_train_corrupted = y_train.copy() n_corrupt = 7 rng = np.random.default_rng(seed=43) corrupted = rng.choice(np.arange(X_train.shape[0]), size=n_corrupt, replace=False) # Since our only class indices are 0 and 1, this is a quick way to flip the class y_train_corrupted[corrupted] = 1 - y_train[corrupted] fig, ax = plt.subplots(figsize=(6, 6)) plot_datasets(ax, X_train, y_train_corrupted, X_test, y_test) # Mark corrupted datapoints ax.scatter( X_train[corrupted, 0], X_train[corrupted, 1], marker="o", edgecolors="#b1170c", facecolors="none", s=100, ) .. image-sg:: /auto_examples/nearest_neighbor/images/sphx_glr_plot_nn_data_valuation_002.png :alt: plot nn data valuation :srcset: /auto_examples/nearest_neighbor/images/sphx_glr_plot_nn_data_valuation_002.png :class: sphx-glr-single-img .. rst-class:: sphx-glr-script-out .. code-block:: none .. GENERATED FROM PYTHON SOURCE LINES 190-191 Now, we can use the `KNNExplainer` to compute the training points' Shapley values based on the entire test dataset by averaging the Shapley values computed using each test point. .. GENERATED FROM PYTHON SOURCE LINES 191-206 .. code-block:: Python # Train the model with the corrupted training data model = KNeighborsClassifier(n_neighbors=5) model.fit(X_train, y_train_corrupted) sv_test = np.zeros(X_train.shape[0], dtype=np.float64) for x_test_current, y_test_current in zip(X_test, y_test, strict=True): explainer = Explainer(model, class_index=y_test_current, index="SV", max_order=1) iv = explainer.explain(x_test_current) sv_test += iv.to_first_order_array() sv_test /= X_test.shape[0] .. GENERATED FROM PYTHON SOURCE LINES 207-208 We can reasonably assume that the corrupted training data points will on average make the model's prediction worse, resulting in negative Shapley values. So let's filter out just those indices where the Shapley value is below zero and compare with our original array of corrupted indices: .. GENERATED FROM PYTHON SOURCE LINES 208-213 .. code-block:: Python print(f"Corrupted: {np.sort(corrupted)}") # Sort for easier comparison print(f"Negative Shapley values: {np.where(sv_test < 0)[0]}") .. rst-class:: sphx-glr-script-out .. code-block:: none Corrupted: [ 1 3 20 28 34 42 45] Negative Shapley values: [ 1 3 28 34 42 45] .. GENERATED FROM PYTHON SOURCE LINES 214-215 We have identified the set corrupted samples almost exactly. .. rst-class:: sphx-glr-timing **Total running time of the script:** (0 minutes 29.646 seconds) .. _sphx_glr_download_auto_examples_nearest_neighbor_plot_nn_data_valuation.py: .. only:: html .. container:: sphx-glr-footer sphx-glr-footer-example .. container:: sphx-glr-download sphx-glr-download-jupyter :download:`Download Jupyter notebook: plot_nn_data_valuation.ipynb ` .. container:: sphx-glr-download sphx-glr-download-python :download:`Download Python source code: plot_nn_data_valuation.py ` .. container:: sphx-glr-download sphx-glr-download-zip :download:`Download zipped: plot_nn_data_valuation.zip `