TreeSHAP-IQ for LightGBM
This example demonstrates TreeExplainer on a LightGBM model
trained on the bike-sharing dataset. TreeSHAP-IQ computes exact Shapley
interaction values in linear time for tree ensembles.
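To make the term "Shapley interaction values" concrete, the following minimal sketch computes Shapley values and pairwise Shapley interaction indices by brute force for a hypothetical three-player toy game. The game `toy_value` and all helper names are illustrative assumptions and not part of shapiq. The enumeration is exponential in the number of players, so it only illustrates the quantities that TreeSHAP-IQ targets, not the algorithm itself.

```python
from itertools import combinations
from math import factorial

N_PLAYERS = 3


def toy_value(coalition):
    """Hypothetical game: additive effects for players 0 and 1 plus a joint bonus."""
    members = set(coalition)
    value = 0.0
    if 0 in members:
        value += 2.0
    if 1 in members:
        value += 1.0
    if 0 in members and 1 in members:
        value += 3.0  # pure interaction between players 0 and 1
    return value


def subsets(players):
    """Yield every subset of `players` as a tuple."""
    for size in range(len(players) + 1):
        yield from combinations(players, size)


def shapley_value(i, n=N_PLAYERS, value=toy_value):
    """Weighted average of the marginal contribution of player i."""
    others = [p for p in range(n) if p != i]
    total = 0.0
    for s in subsets(others):
        weight = factorial(len(s)) * factorial(n - len(s) - 1) / factorial(n)
        total += weight * (value(s + (i,)) - value(s))
    return total


def shapley_interaction(i, j, n=N_PLAYERS, value=toy_value):
    """Weighted average of the joint effect of i and j beyond their individual effects."""
    others = [p for p in range(n) if p not in (i, j)]
    total = 0.0
    for s in subsets(others):
        weight = factorial(len(s)) * factorial(n - len(s) - 2) / factorial(n - 1)
        delta = value(s + (i, j)) - value(s + (i,)) - value(s + (j,)) + value(s)
        total += weight * delta
    return total


shapley = {i: shapley_value(i) for i in range(N_PLAYERS)}
pairwise = {
    pair: shapley_interaction(*pair) for pair in combinations(range(N_PLAYERS), 2)
}
print(shapley)
print(pairwise)
```

In this toy game the joint bonus of 3 is split evenly between players 0 and 1 in their Shapley values, while the pairwise index for (0, 1) recovers it as a separate term. Player 2 contributes nothing and receives zero everywhere.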
from __future__ import annotations
import lightgbm
from sklearn.model_selection import train_test_split
import shapiq
Load Data and Train Model
X, y = shapiq.load_bike_sharing()
X_train, X_test, y_train, y_test = train_test_split(
    X.values,
    y.values,
    test_size=0.25,
    random_state=42,
)
n_features = X_train.shape[1]
model = lightgbm.LGBMRegressor(
    n_estimators=100,
    max_depth=n_features,
    random_state=42,
    verbose=-1,
)
model.fit(X_train, y_train)
print(f"Train R2: {model.score(X_train, y_train):.4f}")
print(f"Test R2: {model.score(X_test, y_test):.4f}")
/home/docs/checkouts/readthedocs.org/user_builds/shapiq/checkouts/latest/.venv/lib/python3.12/site-packages/sklearn/utils/validation.py:2691: UserWarning: X does not have valid feature names, but LGBMRegressor was fitted with feature names
warnings.warn(
Train R2: 0.9599
/home/docs/checkouts/readthedocs.org/user_builds/shapiq/checkouts/latest/.venv/lib/python3.12/site-packages/sklearn/utils/validation.py:2691: UserWarning: X does not have valid feature names, but LGBMRegressor was fitted with feature names
warnings.warn(
Test R2: 0.9478
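Compute Shapley Interactions
We compute k-SII (k-Shapley Interaction Index) values up to order 3 for a single instance. The printed InteractionValues object reports the baseline value and the ten top-ranked interactions, each keyed by a tuple of feature indices.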
InteractionValues(
index=k-SII, max_order=3, min_order=1, estimated=False, estimation_budget=None,
n_players=12, baseline_value=190.379622526228,
Top 10 interactions:
(np.int64(0),): 35.085159510882214
(np.int64(1), np.int64(5)): 14.984490827500759
(np.int64(0), np.int64(1)): 14.033445365073064
(np.int64(1), np.int64(6)): 11.124251580989434
(np.int64(0), np.int64(8)): -13.612044956259176
(np.int64(2),): -15.387584854862604
(np.int64(6),): -21.97379736656223
(np.int64(0), np.int64(9)): -32.8740365775636
(np.int64(5),): -42.991609226571114
(np.int64(1),): -56.72710843542341
)
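A property worth keeping in mind when reading such output is that k-SII values are constructed so that the values of all orders together account for the difference between the prediction and the baseline. The sketch below illustrates this for order 2 on a hypothetical three-player toy game. The numbers are made up, and the order-2 aggregation rule used here (subtract half of every pairwise interaction from the Shapley value of each player in the pair) is written from the k-SII definition as an assumption; it is not shapiq code.

```python
# Hypothetical toy game: v(S) = 2*[0 in S] + 1*[1 in S] + 3*[0 and 1 in S].
# Its Shapley values and pairwise Shapley interaction indices (SII) are:
shapley = {0: 3.5, 1: 2.5, 2: 0.0}
sii_pairs = {(0, 1): 3.0, (0, 2): 0.0, (1, 2): 0.0}
value_full, value_empty = 6.0, 0.0  # v(all players) and v(no players)


def two_sii(shapley, sii_pairs):
    """Aggregate Shapley values and pairwise SII into order-2 k-SII values."""
    values = {}
    for player, phi in shapley.items():
        # Each pairwise interaction is kept as its own term, so half of it
        # is removed from the first-order value of both participating players.
        correction = sum(v for pair, v in sii_pairs.items() if player in pair)
        values[(player,)] = phi - 0.5 * correction
    values.update(sii_pairs)
    return values


k_sii = two_sii(shapley, sii_pairs)
total = sum(k_sii.values())
print(k_sii)
print(total, value_full - value_empty)
```

The first-order terms shrink to the purely individual effects (2 and 1), the pair (0, 1) carries the joint bonus of 3, and all terms together equal the full payoff minus the empty payoff.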
First-order Values (Shapley Values)
print(interaction_values.get_n_order(1).dict_values)
{(np.int64(0),): 35.085159510882214, (np.int64(1),): -56.72710843542341, (np.int64(5),): -42.991609226571114, (np.int64(9),): -4.207060851031486, (np.int64(10),): -10.861061944301156, (np.int64(11),): -2.472089266311177, (np.int64(2),): -15.387584854862604, (np.int64(3),): 5.371857243915892, (np.int64(6),): -21.97379736656223, (np.int64(8),): -5.934637391289779, (np.int64(7),): 0.25444939802939326, (np.int64(4),): 0.5157679457473514}
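The printed dictionary is not ordered by magnitude, which makes it hard to scan. As a small sketch, the first-order values can be ranked by absolute value with plain Python. The numbers below are copied from the output above and rounded to four decimals; the variable names are illustrative.

```python
# First-order values copied from the printed dictionary (rounded to 4 decimals).
first_order = {
    (0,): 35.0852, (1,): -56.7271, (5,): -42.9916, (9,): -4.2071,
    (10,): -10.8611, (11,): -2.4721, (2,): -15.3876, (3,): 5.3719,
    (6,): -21.9738, (8,): -5.9346, (7,): 0.2544, (4,): 0.5158,
}

# Sort by absolute attribution so the most influential features come first.
ranked = sorted(first_order.items(), key=lambda item: abs(item[1]), reverse=True)

for (feature_index,), value in ranked:
    print(f"feature {feature_index:>2}: {value:+9.4f}")
```

For this instance, features 1, 5, and 0 carry the largest first-order attributions, and features 4 and 7 are close to zero.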
Visualization: Network Plot
shapiq.network_plot(interaction_values=interaction_values, feature_names=list(X.columns))

(<Figure size 700x700 with 1 Axes>, <Axes: >)
Stacked Bar Plot (First Order)
shapiq.stacked_bar_plot(
    interaction_values=interaction_values.get_n_order(1),
    feature_names=list(X.columns),
)

(<Figure size 640x480 with 1 Axes>, <Axes: xlabel='features', ylabel='SI values'>)
Stacked Bar Plot (First + Second Order)
shapiq.stacked_bar_plot(
    interaction_values=interaction_values.get_n_order(2, min_order=1),
    feature_names=list(X.columns),
)

(<Figure size 640x480 with 1 Axes>, <Axes: xlabel='features', ylabel='SI values'>)
Force Plot
interaction_values.plot_force(feature_names=list(X.columns), contribution_threshold=0.03)

Global Feature Importance
Compute interaction values for the first 50 test instances and show a global bar plot.
list_of_ivs = explainer.explain_X(X_test[:50])
shapiq.plot.bar_plot(list_of_ivs, feature_names=list(X.columns), max_display=20)

<Axes: xlabel='Attribution'>
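A common way to turn per-instance explanations into a global ranking is to average the absolute attribution of each interaction across instances. The sketch below shows that aggregation on made-up data, assuming this SHAP-style convention. It is not shapiq's implementation, and the exact aggregation used by the bar plot should be checked against the shapiq documentation.

```python
from collections import defaultdict

# Hypothetical per-instance attributions keyed by interaction tuple (made-up numbers).
per_instance = [
    {(0,): 4.0, (1,): -2.0, (0, 1): 1.0},
    {(0,): -2.0, (1,): -4.0, (0, 1): 3.0},
    {(0,): 3.0, (1,): 0.0},  # an interaction missing from an instance counts as zero
]


def mean_absolute_attribution(explanations):
    """Average |attribution| per interaction over all explained instances."""
    totals = defaultdict(float)
    for explanation in explanations:
        for interaction, value in explanation.items():
            totals[interaction] += abs(value)
    return {key: total / len(explanations) for key, total in totals.items()}


global_importance = mean_absolute_attribution(per_instance)
for interaction, score in sorted(global_importance.items(), key=lambda kv: -kv[1]):
    print(interaction, round(score, 4))
```

Taking absolute values before averaging keeps positive and negative attributions from cancelling, so a feature that pushes predictions strongly in both directions still ranks as important.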