Computing Shapley Values

This example introduces cooperative game theory and shows how to compute Shapley values with shapiq – both exactly and via approximation – and how to apply them for explainable AI (XAI).

from __future__ import annotations

import numpy as np

import shapiq

The Cooking Game

Three cooks (Alice, Bob, Charlie) prepare a meal together. We model their joint productivity as a cooperative game and compute exact Shapley values.

class CookingGame(shapiq.Game):
    """Three-player cooperative game in which Alice, Bob and Charlie cook a meal.

    Players are indexed 0 (Alice), 1 (Bob) and 2 (Charlie). The worth of every
    coalition is stored explicitly in ``characteristic_function``, keyed by the
    tuple of member indices in ascending order.
    """

    def __init__(self) -> None:
        # Worth of each of the 2**3 = 8 coalitions; the grand coalition is worth 15.
        self.characteristic_function = {
            (): 0,
            (0,): 4,
            (1,): 3,
            (2,): 2,
            (0, 1): 9,
            (0, 2): 8,
            (1, 2): 7,
            (0, 1, 2): 15,
        }
        # The worth of the empty coalition is handed to the base class as the
        # normalization value.
        empty_coalition_worth = self.characteristic_function[()]
        super().__init__(
            n_players=3,
            player_names=["Alice", "Bob", "Charlie"],
            normalization_value=empty_coalition_worth,
        )

    def value_function(self, coalitions: np.ndarray) -> np.ndarray:
        """Return the worth of every coalition in ``coalitions``.

        Args:
            coalitions: Array whose rows are player-membership masks, one row
                per coalition; any non-zero entry marks a member.

        Returns:
            Array with one worth per row of ``coalitions``.
        """
        worths = []
        for mask in coalitions:
            # Ascending member indices form the lookup key, e.g. (0, 2).
            members = tuple(np.flatnonzero(mask))
            worths.append(self.characteristic_function[members])
        return np.array(worths)


cooking_game = CookingGame()

Exact Shapley Values

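The ExactComputer evaluates all \(2^n\) coalitions of the \(n\) players. With three cooks that is only \(2^3 = 8\) coalitions, but the cost grows exponentially with \(n\).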

# Enumerate every coalition of the cooking game and compute exact Shapley values.
exact_computer = shapiq.ExactComputer(
    game=cooking_game,
    n_players=cooking_game.n_players,
)
sv_exact = exact_computer(index="SV")
print(sv_exact)

# Visualize the attribution of each cook.
sv_exact.plot_stacked_bar(
    feature_names=["Alice", "Bob", "Charlie"],
    xlabel="Cooks",
    ylabel="Shapley Values",
)
plot sv calculation
InteractionValues(
    index=SV, max_order=1, min_order=0, estimated=False, estimation_budget=None,
    n_players=3, baseline_value=0.0,
    Top 10 interactions:
        (0,): 6.0
        (1,): 5.0
        (2,): 3.9999999999999996
        (): 0.0
)

Approximating Shapley Values

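For games with many players, exact computation quickly becomes infeasible because the number of coalitions grows as \(2^n\). Here we define a 10-player restaurant game with random coalition qualities and approximate its Shapley values with KernelSHAP, using a budget of only 100 of the \(2^{10} = 1024\) possible coalition evaluations.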

# Seeded generator so the random game, and everything derived from it, is reproducible.
rng = np.random.default_rng(42)
# A random quality for every coalition of the 10 cooks, scaled by coalition
# size; a coalition of size 0 therefore gets quality 0.0.
quality_dict = {
    cooks: len(cooks) * rng.random()
    for cooks in shapiq.powerset(range(10))
}


def restaurant_value_function(coalitions: np.ndarray) -> np.ndarray:
    """Map each coalition to its quality from the module-level ``quality_dict``.

    Args:
        coalitions: Array whose rows are player-membership masks, one row per
            coalition; any non-zero entry marks a member.

    Returns:
        Array with one quality value per row of ``coalitions``.
    """
    qualities = []
    for mask in coalitions:
        # Ascending member indices form the lookup key, e.g. (1, 4, 7).
        members = tuple(np.flatnonzero(mask))
        qualities.append(quality_dict[members])
    return np.array(qualities)


# KernelSHAP estimates the 10 cooks' Shapley values with a budget of 100 game
# evaluations instead of all 2**10 = 1024 coalitions.
approx = shapiq.KernelSHAP(random_state=42, n=10)
sv_approx = approx(budget=100, game=restaurant_value_function)
print(sv_approx)

sv_approx.plot_stacked_bar(
    feature_names=[f"Cook {idx}" for idx in range(10)],
    xlabel="Cooks",
    ylabel="Shapley Values",
)
plot sv calculation
InteractionValues(
    index=SV, max_order=1, min_order=0, estimated=True, estimation_budget=100,
    n_players=10, baseline_value=0.0,
    Top 10 interactions:
        (4,): 1.2877903529704664
        (7,): 0.935247478916706
        (6,): 0.7299621089517951
        (3,): 0.7097258286171887
        (2,): 0.6497116716925838
        (8,): 0.6355807967812966
        (1,): 0.34334542495116643
        (5,): 0.13950819236481088
        (9,): 0.02633086370686733
        (0,): -0.14306459717347872
)

XAI with Shapley Values

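We train a random forest on the California housing dataset and explain a single prediction with TabularExplainer. Because the budget passed to explain covers all \(2^8 = 256\) coalitions of the eight features, these Shapley values are computed exactly (estimated=False in the output).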

from sklearn.ensemble import RandomForestRegressor
from sklearn.model_selection import train_test_split

# Load the California housing data: a feature table with named columns plus targets.
data, targets = shapiq.datasets.load_california_housing()
feature_names = list(data.columns)
n_features = len(feature_names)

# Hold out 20% of the rows for evaluation; fixed seeds keep the example reproducible.
x_train, x_test, y_train, y_test = train_test_split(
    data.to_numpy(),
    targets.to_numpy(),
    random_state=42,
    test_size=0.2,
)
# fit() returns the estimator itself, so construction and training can be chained.
rf = RandomForestRegressor(random_state=42, n_estimators=30).fit(x_train, y_train)
print(f"Test R2: {rf.score(x_test, y_test):.4f}")
Test R2: 0.8001

Explain a Single Prediction

# Explain the third test instance and compare its prediction with the mean
# prediction over the whole test set.
x_explain = x_test[2]
y_pred = rf.predict(x_explain.reshape(1, -1))[0]
print(f"Predicted: {y_pred:.3f}, Average: {rf.predict(x_test).mean():.3f}")

# Marginal imputation against the test set; index="SV" with max_order=1
# requests plain (first-order) Shapley values.
explainer = shapiq.TabularExplainer(
    model=rf,
    data=x_test,
    imputer="marginal",
    index="SV",
    max_order=1,
    sample_size=100,
    random_state=42,
)
# A budget of 2**n_features is enough to cover every coalition of the features.
sv = explainer.explain(x_explain, budget=2**n_features)
print(sv)

sv.plot_force(feature_names=feature_names)
plot sv calculation
Predicted: 4.956, Average: 2.066
InteractionValues(
    index=SV, max_order=1, min_order=0, estimated=False, estimation_budget=256,
    n_players=8, baseline_value=2.0660426668281655,
    Top 10 interactions:
        (): 2.0660426668281655
        (5,): 0.9779035092551271
        (7,): 0.7369852732988685
        (1,): 0.4837160440140105
        (6,): 0.3593446030694959
        (0,): 0.1496069218651042
        (3,): 0.10402608194125455
        (2,): 0.04175938314779219
        (4,): 0.036188507565703666
)

TreeExplainer for Exact Tree-based SV

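For tree-based models, TreeExplainer computes Shapley values directly from the tree structure, without evaluating coalitions through an imputer. It relies on the trees' own path-dependent expectations rather than the marginal imputer over x_test. As a result, its baseline value and attributions differ slightly from the TabularExplainer results above.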

# TreeExplainer works directly on the fitted forest; no background data or
# budget is passed here.
tree_explainer = shapiq.TreeExplainer(max_order=1, index="SV", model=rf)
sv_tree = tree_explainer.explain(x_explain)
print(sv_tree)

sv_tree.plot_force(feature_names=feature_names)
plot sv calculation
InteractionValues(
    index=SV, max_order=1, min_order=0, estimated=True, estimation_budget=None,
    n_players=8, baseline_value=2.072542972807655,
    Top 10 interactions:
        (): 2.072542972807655
        (5,): 1.0855436627047892
        (7,): 0.674322137806524
        (1,): 0.5348652885441032
        (6,): 0.24030738010944624
        (0,): 0.2383988567381223
        (3,): 0.06806164509226233
        (2,): 0.02662700821814218
        (4,): 0.014904826015329302
)

Total running time of the script: (0 minutes 4.228 seconds)