SHAP-IQ with scikit-learn

This example shows how to compute Shapley Interaction Index (SII) values up to second order for a single prediction of a scikit-learn random forest regressor trained on the California housing dataset.
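For reference, the Shapley Interaction Index of a coalition S of features is usually defined as a weighted average of discrete derivatives of a value function v, taken over every coalition T that excludes S. Here n is the number of features and N is the full feature set. This is the standard textbook definition. The example does not show how shapiq builds v from the model and the background data, so v is left abstract.

```latex
I^{\mathrm{SII}}(S) = \sum_{T \subseteq N \setminus S}
  \frac{(n - |T| - |S|)!\,|T|!}{(n - |S| + 1)!}\,\Delta_S(T),
\qquad
\Delta_S(T) = \sum_{L \subseteq S} (-1)^{|S| - |L|}\, v(T \cup L)
```

For a pair S = {i, j} the derivative is v(T ∪ {i, j}) - v(T ∪ {i}) - v(T ∪ {j}) + v(T). It is positive when i and j together add more than the sum of their separate effects. For a single feature the formula reduces to the Shapley value.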

from __future__ import annotations

from sklearn.ensemble import RandomForestRegressor
from sklearn.model_selection import train_test_split

import shapiq

Load Data and Train Model

X, y = shapiq.load_california_housing()
X_train, X_test, y_train, y_test = train_test_split(
    X.values,
    y.values,
    test_size=0.25,
    random_state=42,
)
n_features = X_train.shape[1]

model = RandomForestRegressor(
    n_estimators=100,
    max_depth=n_features,
    max_features=2 / 3,
    max_samples=2 / 3,
    random_state=42,
)
model.fit(X_train, y_train)
print(f"Train R2: {model.score(X_train, y_train):.4f}")
print(f"Test  R2: {model.score(X_test, y_test):.4f}")
Train R2: 0.7965
Test  R2: 0.7431
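Because max_features and max_samples are given as fractions, scikit-learn converts them to counts. The sketch below reproduces that conversion in plain Python. It assumes the rules in the scikit-learn documentation: max(1, int(fraction * n_features)) features per split, and max(round(fraction * n_train), 1) bootstrap rows per tree. It also assumes that train_test_split rounds a float test_size up. The helper names are hypothetical, and the 20,640-row total is an assumption about the size of the dataset returned by shapiq.load_california_housing.

```python
import math


def train_size(n_samples, test_size):
    """Rows kept for training by train_test_split with a float test_size (hypothetical helper)."""
    # Assumption: the test split size is rounded up, the training split gets the rest.
    return n_samples - math.ceil(test_size * n_samples)


def n_split_features(max_features, n_features):
    """Features drawn at each split when max_features is a float (hypothetical helper)."""
    return max(1, int(max_features * n_features))


def n_bootstrap_samples(max_samples, n_train):
    """Bootstrap rows drawn per tree when max_samples is a float (hypothetical helper)."""
    return max(1, round(max_samples * n_train))


n_features = 8                      # columns in the California housing data
n_train = train_size(20640, 0.25)   # assumes the full 20,640-row dataset
print(n_train)                                  # 15480
print(n_split_features(2 / 3, n_features))      # 5
print(n_bootstrap_samples(2 / 3, n_train))      # 10320
```

Under these assumptions, each tree sees 10,320 bootstrapped rows and considers 5 of the 8 features at every split. Separately, max_depth=n_features caps every tree at a depth of 8.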

Compute Second-Order SII

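TabularExplainer with index="SII" and max_order=2 computes SII values for every single feature and every pair of features. With 8 features there are 2**8 = 256 feature coalitions, so budget=2**n_features lets the explainer evaluate all of them. The output therefore reports estimated=False: the values are computed exactly rather than approximated.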
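To make the exhaustive computation concrete, the following pure-Python sketch computes exact SII values for a tiny hand-made game by enumerating every coalition. It illustrates the standard SII formula and is not shapiq's implementation. The function sii and the toy game pair_game are hypothetical.

```python
from itertools import combinations
from math import factorial


def sii(value, n, S):
    """Exact Shapley Interaction Index of coalition S in a game on players 0..n-1.

    value maps a frozenset of players to a float. Every coalition T that
    excludes S is enumerated, so the cost grows like 2**n.
    """
    S = frozenset(S)
    s = len(S)
    rest = [p for p in range(n) if p not in S]
    total = 0.0
    for t in range(len(rest) + 1):
        # The SII weight depends only on the size of T.
        weight = factorial(n - t - s) * factorial(t) / factorial(n - s + 1)
        for T in combinations(rest, t):
            T = frozenset(T)
            # Discrete derivative of v with respect to S, evaluated at T
            # (inclusion-exclusion over all subsets L of S).
            delta = 0.0
            for k in range(s + 1):
                for L in combinations(S, k):
                    delta += (-1) ** (s - k) * value(T | frozenset(L))
            total += weight * delta
    return total


def pair_game(T):
    """Toy game: payoff 1 only when players 0 and 1 are both present."""
    return 1.0 if {0, 1} <= T else 0.0


print(round(sii(pair_game, 3, {0, 1}), 6))  # 1.0  (pure interaction)
print(round(sii(pair_game, 3, {0, 2}), 6))  # 0.0  (player 2 is irrelevant)
print(round(sii(pair_game, 3, {0}), 6))     # 0.5  (Shapley value of player 0)
```

For a single player the weights are the Shapley weights, so single-feature SII values coincide with the Shapley values of the same game. Enumerating all 2**n coalitions is only practical for small n. This example can afford it because n_features is 8.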

explainer = shapiq.TabularExplainer(model=model, data=X_train, index="SII", max_order=2)
x = X_test[24]
iv = explainer.explain(x, budget=2**n_features, random_state=0)
print(iv)
InteractionValues(
    index=SII, max_order=2, min_order=0, estimated=False, estimation_budget=256,
    n_players=8, baseline_value=2.0701874006108745,
    Top 10 interactions:
        (6,): 0.14782505874465351
        (1, 5): 0.10379041856327899
        (5, 6): -0.033596353469553024
        (6, 7): -0.0442855101404568
        (0, 1): -0.0466491301504133
        (0, 6): -0.05216940041380504
        (1,): -0.08062385308954095
        (0, 5): -0.08271511045840176
        (5,): -0.14868378053570117
        (7,): -0.25600704507907596
)
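The printed summary lists only ten interactions. To rank interactions yourself, one option is to sort a mapping from feature-index tuples to values by absolute magnitude. The sketch below does this on a few values copied from the output above. It assumes you already have such a dict, for example the dict_values attribute used in the next section. The helper top_k_by_magnitude is hypothetical.

```python
def top_k_by_magnitude(values, k):
    """Return the k (interaction, value) pairs with the largest absolute value."""
    return sorted(values.items(), key=lambda item: abs(item[1]), reverse=True)[:k]


# A subset of the SII values printed above, keyed by feature-index tuples.
values = {
    (6,): 0.14782505874465351,
    (1, 5): 0.10379041856327899,
    (5, 6): -0.033596353469553024,
    (1,): -0.08062385308954095,
    (5,): -0.14868378053570117,
    (7,): -0.25600704507907596,
}

for interaction, value in top_k_by_magnitude(values, 3):
    print(interaction, round(value, 4))
# (7,) -0.256
# (5,) -0.1487
# (6,) 0.1478
```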

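Second-Order Interaction Matrix

get_n_order(2) keeps only the pairwise (order-2) interactions, and dict_values maps each pair of feature indices to its SII value.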

print(iv.get_n_order(2).dict_values)
{(0, 1): -0.0466491301504133, (0, 2): 0.014949695474409294, (0, 3): -0.025717417761207676, (0, 4): -0.021236780611487253, (0, 5): -0.08271511045840176, (0, 6): -0.05216940041380504, (0, 7): 0.006477300221049124, (1, 2): -0.013604570474865177, (1, 3): -0.019193608276303717, (1, 4): -0.01815192423491655, (1, 5): 0.10379041856327899, (1, 6): -0.021629200501209332, (1, 7): -0.025722172748253485, (2, 3): -0.02003480871301562, (2, 4): -0.020121479362014635, (2, 5): -0.020934610114452155, (2, 6): -0.017573708513835187, (2, 7): -0.025719160157678812, (3, 4): -0.02078192059503527, (3, 5): -0.015707988739720742, (3, 6): -0.024584795557661743, (3, 7): -0.022438635539386957, (4, 5): -0.02418830077352101, (4, 6): -0.021910782836244265, (4, 7): -0.019706331972747453, (5, 6): -0.033596353469553024, (5, 7): -0.006788349224924663, (6, 7): -0.0442855101404568}
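dict_values returns a flat dict keyed by feature-index pairs rather than a matrix. A symmetric n × n matrix is often easier to inspect or pass to a heatmap. The plain-Python sketch below builds one, assuming that keys are index pairs as in the output above. The helper pairs_to_matrix is hypothetical. The diagonal is left at zero because it holds no pairwise value.

```python
def pairs_to_matrix(pair_values, n):
    """Turn {(i, j): value} into a symmetric n x n list-of-lists matrix."""
    matrix = [[0.0] * n for _ in range(n)]
    for (i, j), value in pair_values.items():
        # A pair's interaction value does not depend on the order of i and j.
        matrix[i][j] = value
        matrix[j][i] = value
    return matrix


# A few pairwise values copied from the output above.
pairs = {
    (0, 1): -0.0466491301504133,
    (1, 5): 0.10379041856327899,
    (5, 7): -0.006788349224924663,
}

matrix = pairs_to_matrix(pairs, 8)
for row in matrix:
    print(" ".join(f"{v:+.3f}" for v in row))
```

With 8 features the full dict has 8 * 7 / 2 = 28 entries, which matches the output above.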

Visualization: Network Plot

shapiq.network_plot(interaction_values=iv, feature_names=list(X.columns))
(<Figure size 700x700 with 1 Axes>, <Axes: >)

Stacked Bar Plot (First Order)

shapiq.stacked_bar_plot(iv.get_n_order(1), feature_names=list(X.columns))
(<Figure size 640x480 with 1 Axes>, <Axes: xlabel='features', ylabel='SI values'>)

Stacked Bar Plot (All Orders)

shapiq.stacked_bar_plot(interaction_values=iv, feature_names=list(X.columns))
(<Figure size 640x480 with 1 Axes>, <Axes: xlabel='features', ylabel='SI values'>)
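One way to read a stacked bar plot over all orders is as a per-feature decomposition. Each feature's bar stacks its first-order value with a share of every pairwise interaction it belongs to, and positive and negative parts are stacked separately. The sketch below computes such per-order totals. It assumes each interaction's value is split equally among the features it contains. This convention is an assumption and is not taken from shapiq, whose plotting code may aggregate differently. The helper stack_by_order is hypothetical.

```python
def stack_by_order(values, n_features, max_order):
    """Per-feature positive and negative totals for each interaction order.

    Assumption: an interaction's value is split equally among its members.
    Returns (pos, neg), each a list of max_order rows with n_features entries.
    """
    pos = [[0.0] * n_features for _ in range(max_order)]
    neg = [[0.0] * n_features for _ in range(max_order)]
    for interaction, value in values.items():
        order = len(interaction)
        if order == 0 or order > max_order:
            continue  # skip the baseline (empty coalition) and higher orders
        share = value / order
        target = pos if share > 0 else neg
        for feature in interaction:
            target[order - 1][feature] += share
    return pos, neg


# Feature 5's first-order value and two of its pairwise values from the output above.
values = {
    (5,): -0.14868378053570117,
    (1, 5): 0.10379041856327899,
    (0, 5): -0.08271511045840176,
}
pos, neg = stack_by_order(values, n_features=8, max_order=2)
print(round(neg[0][5], 4), round(pos[1][5], 4), round(neg[1][5], 4))
```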

Force Plot

iv.plot_force(feature_names=list(X.columns))

Total running time of the script: (0 minutes 3.465 seconds)