Computing Shapley Values

This example introduces cooperative game theory and shows how to compute Shapley values with shapiq – both exactly and via approximation – and how to apply them for explainable AI (XAI).

from __future__ import annotations

import numpy as np

import shapiq

The Cooking Game

Three cooks (Alice, Bob, Charlie) prepare a meal together. We model their joint productivity as a cooperative game and compute exact Shapley values.

class CookingGame(shapiq.Game):
    """Cooking game with three cooks.

    Players 0, 1 and 2 correspond to Alice, Bob and Charlie. The worth of
    every coalition is listed explicitly in ``characteristic_function``,
    keyed by the ascending tuple of the coalition's member indices.
    """

    def __init__(self) -> None:
        # One entry per coalition of the 3 players (2**3 = 8 in total).
        worths = {
            (): 0,
            (0,): 4,
            (1,): 3,
            (2,): 2,
            (0, 1): 9,
            (0, 2): 8,
            (1, 2): 7,
            (0, 1, 2): 15,
        }
        self.characteristic_function = worths
        empty_worth = worths[()]
        super().__init__(
            n_players=3,
            player_names=["Alice", "Bob", "Charlie"],
            # The empty coalition's worth is passed as the normalization value.
            normalization_value=empty_worth,
        )

    def value_function(self, coalitions: np.ndarray) -> np.ndarray:
        """Return the worth of each coalition.

        Args:
            coalitions: Iterable of membership masks, one per coalition; the
                indices of the truthy entries identify the coalition's members.

        Returns:
            Array holding the looked-up worth of each coalition, in order.
        """
        table = self.characteristic_function
        worths = []
        for mask in coalitions:
            # Ascending member indices form the lookup key, e.g. (0, 2).
            worths.append(table[tuple(np.flatnonzero(mask))])
        return np.array(worths)


cooking_game = CookingGame()

Exact Shapley Values

The ExactComputer evaluates all \(2^n\) coalitions.

# Evaluate every coalition of the cooking game and derive exact Shapley values.
exact_computer = shapiq.ExactComputer(
    game=cooking_game,
    n_players=cooking_game.n_players,
)
sv_exact = exact_computer(index="SV")
print(sv_exact)

# Visualize each cook's share of the total worth.
sv_exact.plot_stacked_bar(
    feature_names=["Alice", "Bob", "Charlie"],
    xlabel="Cooks",
    ylabel="Shapley Values",
)
(Figure: stacked bar plot of the exact Shapley values of the three cooks.)
InteractionValues(
    index=SV, max_order=1, min_order=0, estimated=False, estimation_budget=None,
    n_players=3, baseline_value=0.0,
    Top 10 interactions:
        (0,): 6.0
        (1,): 5.0
        (2,): 3.9999999999999996
        (): 0.0
)

Approximating Shapley Values

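For larger games, exact computation quickly becomes infeasible, because the number of coalitions grows as \(2^n\). Here we define a 10-player restaurant game and approximate its Shapley values with KernelSHAP. The approximation uses a budget of 100 game evaluations instead of all \(2^{10} = 1024\) coalitions.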

# Number of cooks (players). Used for the game definition, the approximator
# and the plot labels, so the three always agree.
n_cooks = 10

rng = np.random.default_rng(42)
# A coalition's worth is a uniform draw from [0, 1) scaled by the coalition
# size, so a coalition of size 0 is worth 0.
quality_dict = {cooks: rng.random() * len(cooks) for cooks in shapiq.powerset(range(n_cooks))}


def restaurant_value_function(coalitions: np.ndarray) -> np.ndarray:
    """Look up the worth of each coalition in ``quality_dict``.

    Args:
        coalitions: Iterable of membership masks, one per coalition; the
            indices of the truthy entries identify the coalition's members.

    Returns:
        Array holding the worth of each coalition, in order.
    """
    return np.array([quality_dict[tuple(np.where(c)[0])] for c in coalitions])


# Approximate with far fewer evaluations than the 2**n_cooks coalitions.
approx = shapiq.KernelSHAP(n=n_cooks, random_state=42)
sv_approx = approx(game=restaurant_value_function, budget=100)
print(sv_approx)

sv_approx.plot_stacked_bar(
    xlabel="Cooks",
    ylabel="Shapley Values",
    feature_names=[f"Cook {i}" for i in range(n_cooks)],
)
(Figure: stacked bar plot of the approximated Shapley values of the ten cooks.)
InteractionValues(
    index=SV, max_order=1, min_order=0, estimated=True, estimation_budget=100,
    n_players=10, baseline_value=0.0,
    Top 10 interactions:
        (4,): 1.2877903732343707
        (7,): 0.9352475154189892
        (6,): 0.7299620630825716
        (3,): 0.7097257920905784
        (2,): 0.6497116613130968
        (8,): 0.6355807790513686
        (1,): 0.3433454317119682
        (5,): 0.13950816520202586
        (9,): 0.026330920377390772
        (0,): -0.143064579702957
)

XAI with Shapley Values

We train a Random Forest on the California housing dataset and explain a single prediction using TabularExplainer.

from sklearn.ensemble import RandomForestRegressor
from sklearn.model_selection import train_test_split

# Load the California housing features and targets.
data, targets = shapiq.datasets.load_california_housing()
feature_names = list(data.columns)
n_features = len(feature_names)

# Hold out 20% of the rows for testing. ``to_numpy()`` is the pandas-
# recommended way to obtain the underlying arrays (instead of ``.values``).
x_train, x_test, y_train, y_test = train_test_split(
    data.to_numpy(),
    targets.to_numpy(),
    test_size=0.2,
    random_state=42,
)
# A small forest keeps the example fast.
rf = RandomForestRegressor(n_estimators=30, random_state=42)
rf.fit(x_train, y_train)
# ``score`` reports the coefficient of determination (R^2) on held-out rows.
print(f"Test R2: {rf.score(x_test, y_test):.4f}")
Test R2: 0.8001

Explain a Single Prediction

# Pick one test instance to explain.
x_explain = x_test[2]
# ``predict`` expects a 2-D input, so pass the sample as a single row.
y_pred = rf.predict(x_explain.reshape(1, -1))[0]
print(f"Predicted: {y_pred:.3f}, Average: {rf.predict(x_test).mean():.3f}")

# Marginal imputation with the test rows as background data.
explainer = shapiq.TabularExplainer(
    model=rf,
    data=x_test,
    imputer="marginal",
    sample_size=100,
    index="SV",
    max_order=1,
    random_state=42,
)
# A budget of 2 ** n_features covers every coalition of the features.
sv = explainer.explain(x_explain, budget=2 ** n_features)
print(sv)

sv.plot_force(feature_names=feature_names)
(Figure: force plot of the Shapley values explaining the selected prediction.)
Predicted: 4.956, Average: 2.066
InteractionValues(
    index=SV, max_order=1, min_order=0, estimated=False, estimation_budget=256,
    n_players=8, baseline_value=2.0660426668281655,
    Top 10 interactions:
        (): 2.0660426668281655
        (5,): 0.9779035073737224
        (7,): 0.7369852714174642
        (1,): 0.48371604761721454
        (6,): 0.3593446011880915
        (0,): 0.149606913635752
        (3,): 0.10402608554445877
        (2,): 0.04175938675099633
        (4,): 0.03618851062965703
)

TreeExplainer for Exact Tree-Based Shapley Values

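For tree models, TreeExplainer computes exact Shapley values in linear time by exploiting the tree structure. Its baseline value is derived from the trees themselves rather than from the background data given to TabularExplainer. As a result, both the baseline and the individual values differ slightly from the previous result.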

# Explain the same instance with the tree-specific explainer.
tree_explainer = shapiq.TreeExplainer(
    model=rf,
    index="SV",
    max_order=1,
)
sv_tree = tree_explainer.explain(x_explain)
print(sv_tree)

sv_tree.plot_force(feature_names=feature_names)
(Figure: force plot of the TreeExplainer Shapley values for the selected prediction.)
InteractionValues(
    index=SV, max_order=1, min_order=0, estimated=True, estimation_budget=None,
    n_players=8, baseline_value=2.072542972807655,
    Top 10 interactions:
        (): 2.072542972807655
        (5,): 1.0855436489007038
        (7,): 0.6743220798486768
        (1,): 0.5348653217356988
        (6,): 0.24030731415309226
        (0,): 0.23839871195972875
        (3,): 0.06806166913442126
        (2,): 0.026626962106926326
        (4,): 0.014904831205571135
)

Total running time of the script: (0 minutes 3.887 seconds)