TreeSHAP-IQ for LightGBM

This example demonstrates shapiq.TreeExplainer on a LightGBM regressor trained on the bike-sharing dataset. TreeSHAP-IQ computes exact (non-sampled) Shapley interaction values for tree ensembles in linear time.

from __future__ import annotations

import lightgbm
from sklearn.model_selection import train_test_split

import shapiq

Load Data and Train Model

X, y = shapiq.load_bike_sharing()
X_train, X_test, y_train, y_test = train_test_split(
    X.values,
    y.values,
    test_size=0.25,
    random_state=42,
)
n_features = X_train.shape[1]

model = lightgbm.LGBMRegressor(
    n_estimators=100,
    max_depth=n_features,
    random_state=42,
    verbose=-1,
)
model.fit(X_train, y_train)
print(f"Train R2: {model.score(X_train, y_train):.4f}")
print(f"Test  R2: {model.score(X_test, y_test):.4f}")
Train R2: 0.9599
Test  R2: 0.9478
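
For a scikit-learn-style regressor, `score` reports the coefficient of determination R2. As a minimal sketch of what that number measures, the helper below recomputes it by hand. The function name `r2_score_manual` and the toy data are illustrative assumptions and are not part of the example above.

```python
def r2_score_manual(y_true, y_pred):
    """Coefficient of determination: 1 - SS_res / SS_tot."""
    mean_y = sum(y_true) / len(y_true)
    # Residual sum of squares: error left after using the model's predictions.
    ss_res = sum((t - p) ** 2 for t, p in zip(y_true, y_pred))
    # Total sum of squares: error of always predicting the mean target.
    ss_tot = sum((t - mean_y) ** 2 for t in y_true)
    return 1.0 - ss_res / ss_tot


# Toy targets and predictions (made up for illustration).
y_true = [3.0, 5.0, 7.0, 9.0]
y_pred = [2.5, 5.5, 7.0, 9.5]
r2 = r2_score_manual(y_true, y_pred)
print(f"R2: {r2:.4f}")
```

A value close to 1 means the predictions explain almost all of the variance in the targets, which is how the train and test scores above can be read.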

Compute Shapley Interactions

We compute k-SII scores for interaction orders 1 through 3 for a single test instance.

explainer = shapiq.TreeExplainer(model=model, index="k-SII", min_order=1, max_order=3)
x = X_test[1234]
interaction_values = explainer.explain(x)
print(interaction_values)
InteractionValues(
    index=k-SII, max_order=3, min_order=1, estimated=False, estimation_budget=None,
    n_players=12, baseline_value=190.379622526228,
    Top 10 interactions:
        (np.int64(0),): 35.085159510882214
        (np.int64(1), np.int64(5)): 14.984490827500759
        (np.int64(0), np.int64(1)): 14.033445365073064
        (np.int64(1), np.int64(6)): 11.124251580989434
        (np.int64(0), np.int64(8)): -13.612044956259176
        (np.int64(2),): -15.387584854862604
        (np.int64(6),): -21.97379736656223
        (np.int64(0), np.int64(9)): -32.8740365775636
        (np.int64(5),): -42.991609226571114
        (np.int64(1),): -56.72710843542341
)
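
To make the k-SII scores more concrete, the sketch below computes them by brute force for a hypothetical 3-player toy game. It is not the TreeSHAP-IQ algorithm: it enumerates every coalition, so it only works for a handful of players. It assumes the usual definition in which k-SII is a Bernoulli-weighted aggregation of Shapley interaction index (SII) values; all function names and the toy game are illustrative.

```python
from fractions import Fraction
from itertools import combinations
from math import factorial

# Bernoulli numbers B_0..B_4 with the convention B_1 = -1/2.
BERNOULLI = [Fraction(1), Fraction(-1, 2), Fraction(1, 6), Fraction(0), Fraction(-1, 30)]


def subsets(players, min_size=0, max_size=None):
    """Yield all subsets of `players` as tuples, ordered by size."""
    players = tuple(players)
    max_size = len(players) if max_size is None else max_size
    for size in range(min_size, max_size + 1):
        yield from combinations(players, size)


def discrete_derivative(game, S, T):
    """Sum over L subset of S of (-1)^(|S|-|L|) * game(T union L)."""
    return sum(
        (-1) ** (len(S) - len(L)) * game(frozenset(T) | frozenset(L))
        for L in subsets(S)
    )


def sii(game, n, S):
    """Shapley interaction index of coalition S in an n-player game."""
    rest = [i for i in range(n) if i not in S]
    s = len(S)
    total = Fraction(0)
    for T in subsets(rest):
        t = len(T)
        weight = Fraction(factorial(n - t - s) * factorial(t), factorial(n - s + 1))
        total += weight * discrete_derivative(game, S, T)
    return total


def k_sii(game, n, max_order):
    """Aggregate SII values of orders 1..max_order into k-SII values."""
    base = {S: sii(game, n, S) for S in subsets(range(n), 1, max_order)}
    return {
        S: sum(
            BERNOULLI[len(T) - len(S)] * value
            for T, value in base.items()
            if set(S) <= set(T)
        )
        for S in base
    }


def toy_game(coalition):
    """Hypothetical game: a main effect, a pairwise and a three-way interaction."""
    value = 0
    if 0 in coalition:
        value += 10
    if {0, 1} <= coalition:
        value += 4
    if {0, 1, 2} <= coalition:
        value += 6
    return value


order1 = k_sii(toy_game, 3, 1)  # plain Shapley values
order2 = k_sii(toy_game, 3, 2)
order3 = k_sii(toy_game, 3, 3)
for name, values in [("order 1", order1), ("order 2", order2), ("order 3", order3)]:
    print(name, "sum =", sum(values.values()), "player 0 =", values[(0,)])
```

For every choice of `max_order` the scores sum to the value of the full coalition minus the value of the empty one, which mirrors how the interaction values and `baseline_value` in the output above relate to the model prediction. The first-order score of player 0 changes with `max_order`, because higher orders absorb part of the attribution.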

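First-order Values

These are the order-1 components of the k-SII explanation. With max_order=3 they generally differ from plain Shapley values, because k-SII splits the attribution across interaction orders; the two coincide only when max_order=1.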

print(interaction_values.get_n_order(1).dict_values)
{(np.int64(0),): 35.085159510882214, (np.int64(1),): -56.72710843542341, (np.int64(5),): -42.991609226571114, (np.int64(9),): -4.207060851031486, (np.int64(10),): -10.861061944301156, (np.int64(11),): -2.472089266311177, (np.int64(2),): -15.387584854862604, (np.int64(3),): 5.371857243915892, (np.int64(6),): -21.97379736656223, (np.int64(8),): -5.934637391289779, (np.int64(7),): 0.25444939802939326, (np.int64(4),): 0.5157679457473514}
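
The printed dictionary is in no particular order, which makes it hard to scan. A small sketch of ranking the first-order values by magnitude follows; the numbers are copied from the output above, with the `np.int64` keys written as plain integers for readability.

```python
# First-order values copied from the printed output (keys as plain ints).
first_order = {
    (0,): 35.085159510882214,
    (1,): -56.72710843542341,
    (5,): -42.991609226571114,
    (9,): -4.207060851031486,
    (10,): -10.861061944301156,
    (11,): -2.472089266311177,
    (2,): -15.387584854862604,
    (3,): 5.371857243915892,
    (6,): -21.97379736656223,
    (8,): -5.934637391289779,
    (7,): 0.25444939802939326,
    (4,): 0.5157679457473514,
}

# Sort by absolute value so strong negative contributions are not buried.
ranked = sorted(first_order.items(), key=lambda item: abs(item[1]), reverse=True)
for (feature,), value in ranked[:5]:
    print(f"feature {feature}: {value:+.2f}")
```

Sorting by absolute value rather than by signed value keeps large negative contributions, such as features 1 and 5 here, at the top of the list.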

Visualization: Network Plot

shapiq.network_plot(interaction_values=interaction_values, feature_names=list(X.columns))
[Figure: network plot of the interaction values]
(<Figure size 700x700 with 1 Axes>, <Axes: >)

Stacked Bar Plot (First Order)

shapiq.stacked_bar_plot(
    interaction_values=interaction_values.get_n_order(1),
    feature_names=list(X.columns),
)
[Figure: stacked bar plot of the first-order values]
(<Figure size 640x480 with 1 Axes>, <Axes: xlabel='features', ylabel='SI values'>)

Stacked Bar Plot (First + Second Order)

shapiq.stacked_bar_plot(
    interaction_values=interaction_values.get_n_order(2, min_order=1),
    feature_names=list(X.columns),
)
[Figure: stacked bar plot of the first- and second-order values]
(<Figure size 640x480 with 1 Axes>, <Axes: xlabel='features', ylabel='SI values'>)

Force Plot

interaction_values.plot_force(feature_names=list(X.columns), contribution_threshold=0.03)
[Figure: force plot of the interaction values]

Global Feature Importance

Compute interaction values for the first 50 test instances and show a global bar plot.

list_of_ivs = explainer.explain_X(X_test[:50])
shapiq.plot.bar_plot(list_of_ivs, feature_names=list(X.columns), max_display=20)
[Figure: global bar plot of aggregated attributions]
<Axes: xlabel='Attribution'>
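
One common way to turn per-instance explanations into a global importance score is to average the absolute value of each interaction over all instances. The sketch below illustrates that idea on made-up numbers. Whether `shapiq.plot.bar_plot` uses exactly this aggregation is an assumption, and the helper name `mean_absolute_importance` is hypothetical.

```python
from collections import defaultdict


def mean_absolute_importance(explanations):
    """Average |value| per interaction over a list of {interaction: value} dicts."""
    totals = defaultdict(float)
    for explanation in explanations:
        for interaction, value in explanation.items():
            totals[interaction] += abs(value)
    # Interactions missing from an instance count as zero for that instance.
    return {
        interaction: total / len(explanations)
        for interaction, total in totals.items()
    }


# Two toy explanations (made up for illustration).
per_instance = [
    {(0,): 2.0, (1,): -1.0, (0, 1): 0.5},
    {(0,): -4.0, (1,): 1.0},
]
global_importance = mean_absolute_importance(per_instance)
print(global_importance)
```

Taking absolute values before averaging prevents positive and negative contributions of the same feature from cancelling out across instances.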

Total running time of the script: (0 minutes 47.719 seconds)