SHAP-IQ with scikit-learn
This example shows how to use shapiq to compute second-order Shapley Interaction Index (SII) values for a scikit-learn random forest regressor trained on the California housing dataset, and how to visualize them.
from __future__ import annotations
from sklearn.ensemble import RandomForestRegressor
from sklearn.model_selection import train_test_split
import shapiq
Load Data and Train Model
X, y = shapiq.load_california_housing()
X_train, X_test, y_train, y_test = train_test_split(
    X.values,
    y.values,
    test_size=0.25,
    random_state=42,
)
n_features = X_train.shape[1]
model = RandomForestRegressor(
    n_estimators=100,
    max_depth=n_features,
    max_features=2 / 3,
    max_samples=2 / 3,
    random_state=42,
)
model.fit(X_train, y_train)
print(f"Train R2: {model.score(X_train, y_train):.4f}")
print(f"Test R2: {model.score(X_test, y_test):.4f}")
Train R2: 0.7965
Test R2: 0.7431
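Compute Second-Order SII
A shapiq.TabularExplainer configured with index="SII" and max_order=2 computes SII values for every feature subset of size one or two. These are the individual (first-order) effects and the pairwise interactions. Printing the resulting InteractionValues object, iv, for one explained instance gives: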
InteractionValues(
index=SII, max_order=2, min_order=0, estimated=False, estimation_budget=256,
n_players=8, baseline_value=2.0701874006108745,
Top 10 interactions:
(6,): 0.1478250587446534
(1, 5): 0.1037904175132812
(5, 6): -0.03359635308036175
(6, 7): -0.04428551372622448
(0, 1): -0.04664913064576401
(0, 6): -0.052169398346945554
(1,): -0.08062385308954102
(0, 5): -0.08271510822132412
(5,): -0.1486837805357012
(7,): -0.25600704507907607
)
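For intuition about what these numbers mean, the following sketch evaluates the SII definition by brute force on a small, hypothetical three-player game. The game (weights, bonus, value) and the helper names are illustrative only and are not part of shapiq. For a coalition S of size s among n players, SII averages the discrete derivative Delta_S(T) over all T that exclude S, with weight (n - t - s)! t! / (n - s + 1)!. For a single player this reduces to the Shapley value. Exact enumeration visits every coalition, so its cost grows as 2^n. With the 8 features here that is 2^8 = 256 coalitions, which appears consistent with the printed estimation_budget=256 and estimated=False.

```python
from itertools import combinations
from math import factorial


def subsets(players):
    """Yield every subset of `players` as a tuple, from the empty set upward."""
    for size in range(len(players) + 1):
        yield from combinations(players, size)


def discrete_derivative(value, S, T):
    """Delta_S(T) = sum over L subset of S of (-1)^(|S| - |L|) * v(T | L)."""
    total = 0.0
    for L in subsets(S):
        sign = (-1) ** (len(S) - len(L))
        total += sign * value(frozenset(T) | frozenset(L))
    return total


def sii(value, n, S):
    """Exact Shapley Interaction Index of coalition S in an n-player game."""
    s = len(S)
    others = [p for p in range(n) if p not in S]
    total = 0.0
    for T in subsets(others):
        t = len(T)
        # SII weight for a context T of size t: (n - t - s)! * t! / (n - s + 1)!
        weight = factorial(n - t - s) * factorial(t) / factorial(n - s + 1)
        total += weight * discrete_derivative(value, S, T)
    return total


# Hypothetical toy game: additive per-player weights plus a bonus that
# is only paid when players 0 and 1 are both in the coalition.
weights = {0: 1.0, 1: 2.0, 2: 0.5}
bonus = 3.0


def value(coalition):
    payoff = sum(weights[p] for p in coalition)
    if {0, 1} <= coalition:
        payoff += bonus
    return payoff


for S in [(0,), (1,), (2,), (0, 1), (0, 2), (1, 2)]:
    print(S, round(sii(value, 3, S), 6))
```

Because the bonus only appears when players 0 and 1 are together, the pair (0, 1) receives the full bonus of 3.0 as its interaction, each of those two players gets half of it added to its Shapley value, and the pairs involving player 2 get zero.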
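Second-Order Interaction Matrix
get_n_order(2) keeps only the pairwise values, and dict_values maps each feature-index pair (i, j) to its SII value: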
print(iv.get_n_order(2).dict_values)
{(0, 1): -0.04664913064576401, (0, 2): 0.014949694826485354, (0, 3): -0.0257174181375876, (0, 4): -0.021236779964447273, (0, 5): -0.08271510822132412, (0, 6): -0.052169398346945554, (0, 7): 0.006477296789726558, (1, 2): -0.01360457209368639, (1, 3): -0.019193610480654984, (1, 4): -0.018151921067836082, (1, 5): 0.1037904175132812, (1, 6): -0.021629198688194986, (1, 7): -0.025722172359827346, (2, 3): -0.020034807366737448, (2, 4): -0.020121479855404616, (2, 5): -0.020934611006186002, (2, 6): -0.01757370914212751, (2, 7): -0.025719157223795603, (3, 4): -0.020781919700815447, (3, 5): -0.015707991158175984, (3, 6): -0.024584793081930522, (3, 7): -0.022438635256429716, (4, 5): -0.024188302454580896, (4, 6): -0.021910785366980597, (4, 7): -0.019706331975901582, (5, 6): -0.03359635308036175, (5, 7): -0.006788345809946858, (6, 7): -0.04428551372622448}
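The printed dict_values is keyed by feature-index pairs rather than laid out as a matrix. A minimal sketch of turning such a dictionary into a symmetric matrix with NumPy, assuming every key is an (i, j) tuple as printed above; pairs_to_matrix is a hypothetical helper, and only three of the printed values are copied in for brevity:

```python
import numpy as np


def pairs_to_matrix(pair_values, n_players):
    """Build a symmetric n_players x n_players matrix from {(i, j): value}."""
    matrix = np.zeros((n_players, n_players))
    for (i, j), v in pair_values.items():
        matrix[i, j] = v
        matrix[j, i] = v  # the pair {i, j} is unordered, so mirror the value
    return matrix


# Three of the second-order SII values printed above (n_players=8).
pair_values = {
    (0, 1): -0.04664913064576401,
    (0, 5): -0.08271510822132412,
    (1, 5): 0.1037904175132812,
}
matrix = pairs_to_matrix(pair_values, n_players=8)

# Locate the pair with the largest absolute interaction.
i, j = np.unravel_index(np.argmax(np.abs(matrix)), matrix.shape)
print(f"strongest pair: ({i}, {j}), value {matrix[i, j]:.4f}")
```

The diagonal stays zero because first-order values are excluded. Passing the full dictionary from iv.get_n_order(2).dict_values yields the complete 8 x 8 matrix; in the printed values, (1, 5) is also the pair with the largest absolute interaction.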
Visualization: Network Plot
shapiq.network_plot(interaction_values=iv, feature_names=list(X.columns))
[Figure: network plot of the SII values]
Stacked Bar Plot (First Order)
shapiq.stacked_bar_plot(iv.get_n_order(1), feature_names=list(X.columns))
[Figure: stacked bar plot of the first-order SII values; x-axis "features", y-axis "SI values"]
Stacked Bar Plot (All Orders)
shapiq.stacked_bar_plot(interaction_values=iv, feature_names=list(X.columns))
[Figure: stacked bar plot of the SII values of all orders; x-axis "features", y-axis "SI values"]
Force Plot
iv.plot_force(feature_names=list(X.columns))
[Figure: force plot]
Total running time of the script: (0 minutes 3.796 seconds)