Parallel Computation with joblib

This example shows how to speed up shapiq explanations using parallel computation via joblib. We use a simple synthetic game to keep runtime short.

from __future__ import annotations

import numpy as np

import shapiq

Define a Synthetic Game

We define a lightweight callable to explain: a linear model with one pairwise interaction between features 0 and 1, which keeps each model call cheap.

rng = np.random.default_rng(42)
n_features = 8
weights = rng.standard_normal(n_features)


def synthetic_model(x: np.ndarray) -> np.ndarray:
    """Simple linear model with interaction term."""
    return x @ weights + 0.5 * x[:, 0] * x[:, 1]


# Create synthetic data
X_background = rng.standard_normal((100, n_features))
X_test = rng.standard_normal((6, n_features))
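To make the interaction term concrete, the following sketch uses a hypothetical three-feature version of the model and a single baseline point. Both are assumptions for illustration; the explainer in this example draws on X_background rather than one fixed baseline. The second-order discrete derivative for features 0 and 1 isolates the product term, because the linear part cancels.

```python
# Hypothetical toy setup: three features, fixed weights, one baseline point.
toy_weights = [0.3, -1.2, 0.7]
x = [2.0, -1.0, 0.5]         # instance to explain
baseline = [0.5, 1.0, 0.0]   # values used for "absent" features


def toy_model(z):
    """Linear model plus a 0.5 * z0 * z1 interaction, mirroring synthetic_model."""
    linear = sum(w * v for w, v in zip(toy_weights, z))
    return linear + 0.5 * z[0] * z[1]


def value(coalition):
    """Game value: features in the coalition take x, all others the baseline."""
    z = [x[i] if i in coalition else baseline[i] for i in range(len(x))]
    return toy_model(z)


# Second-order discrete derivative of features 0 and 1 on the empty coalition.
delta_01 = value({0, 1}) - value({0}) - value({1}) + value(set())

# The linear terms cancel, leaving only the product term.
expected = 0.5 * (x[0] - baseline[0]) * (x[1] - baseline[1])
print(delta_01, expected)
```

The two printed numbers agree up to floating-point rounding, which is why a pairwise index such as k-SII can attribute a nonzero value to the pair (0, 1) in this kind of model.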

Explain a Single Instance

explainer = shapiq.Explainer(
    model=synthetic_model,
    data=X_background,
    index="k-SII",
    max_order=2,
    random_state=0,
)
print(f"Explainer type: {type(explainer).__name__}")

iv = explainer.explain(X_test[0], budget=256)
print(iv)

shapiq.network_plot(interaction_values=iv, feature_names=[f"x{i}" for i in range(n_features)])
Explainer type: TabularExplainer
InteractionValues(
    index=k-SII, max_order=2, min_order=0, estimated=False, estimation_budget=256,
    n_players=8, baseline_value=0.3123873157557045,
    Top 10 interactions:
        (5,): 0.5176617643254364
        (): 0.3123873157557045
        (0, 1): 0.03694416683658724
        (6,): -0.06933886084594183
        (0,): -0.11273767884210385
        (7,): -0.1365021776001595
        (1,): -0.2622014244845414
        (2,): -0.26702089174324456
        (4,): -0.4915673843261027
        (3,): -0.5965700344493814
)

(<Figure size 700x700 with 1 Axes>, <Axes: >)
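The output above reports estimated=False. A plausible reading, stated here as an assumption rather than documented behavior, is that the budget of 256 covers every coalition of the 8 features (2^8 = 256), so no sampling is needed. The short sketch below only counts coalitions with the standard library; it does not call shapiq.

```python
from itertools import combinations

n_features = 8
budget = 256

# Enumerate every subset (coalition) of the feature indices, grouped by size.
coalitions = [
    subset
    for size in range(n_features + 1)
    for subset in combinations(range(n_features), size)
]

n_coalitions = len(coalitions)
covers_all = budget >= n_coalitions
print(f"{n_coalitions} coalitions, budget covers all: {covers_all}")
```

The coalition count doubles with every added feature, so for larger feature sets a fixed budget covers only a fraction of the coalitions.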

Parallel Explanation of Multiple Instances

Pass n_jobs to explain_X() to explain several instances in parallel. Here the six test instances are distributed across two joblib workers.

ivs = explainer.explain_X(X_test, budget=256, n_jobs=2)
print(f"Computed {len(ivs)} explanations")
Computed 6 explanations
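Conceptually, parallelizing over instances follows the usual joblib pattern of one function call per row. The sketch below illustrates that pattern with a hypothetical stand-in function, explain_one, instead of a real explainer; it is not shapiq's internal implementation. The prefer="threads" hint is used only to keep the toy example lightweight and can be dropped to use joblib's default process-based backend.

```python
from joblib import Parallel, delayed


def explain_one(row):
    """Hypothetical stand-in for explaining one instance: returns the row sum."""
    return sum(row)


# Six toy "instances", each with three feature values.
rows = [[i, i + 1, i + 2] for i in range(6)]

# One task per instance, distributed over two workers.
# Results come back in the same order as the inputs.
results = Parallel(n_jobs=2, prefer="threads")(
    delayed(explain_one)(row) for row in rows
)

# Reference run without parallelism, for comparison.
sequential = [explain_one(row) for row in rows]
print(f"Computed {len(results)} results, match sequential: {results == sequential}")
```

Because each instance is handled independently, the speedup depends on how expensive a single explanation is relative to the overhead of starting workers and moving data to them. For a cheap model like the one in this example, the gain may be small.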

Global Feature Importance

shapiq.plot.bar_plot(ivs, feature_names=[f"x{i}" for i in range(n_features)])
<Axes: xlabel='Attribution'>

Total running time of the script: (0 minutes 2.522 seconds)