Beeswarm Plot

This example demonstrates shapiq.beeswarm_plot(), which gives a global view of feature interactions. It plots the interaction values of many explained instances side by side, one row per interaction, and colors each point by the corresponding feature value.
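Conceptually, the plot needs one interaction value per (instance, interaction) pair, and each interaction column becomes one row of the beeswarm. The sketch below shows that data layout using plain dictionaries. The `{interaction tuple: value}` form and the helper `stack_interactions` are hypothetical stand-ins for illustration, not shapiq's internal representation.

```python
import numpy as np


def stack_interactions(per_instance):
    """Stack per-instance {interaction: value} dicts into a dense matrix.

    Hypothetical helper. Rows are instances, columns are interactions sorted
    by order and then by feature indices. Interactions missing from an
    instance are treated as 0.0.
    """
    keys = sorted({k for d in per_instance for k in d}, key=lambda k: (len(k), k))
    matrix = np.array([[d.get(k, 0.0) for k in keys] for d in per_instance])
    return keys, matrix


# Two made-up instances with two main effects and one pairwise interaction
explained = [
    {(0,): 0.5, (1,): -0.2, (0, 1): 0.1},
    {(0,): -0.3, (1,): 0.4, (0, 1): -0.05},
]
keys, matrix = stack_interactions(explained)
print(keys)          # [(0,), (1,), (0, 1)]
print(matrix.shape)  # (2, 3)
```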

from __future__ import annotations

import matplotlib.pyplot as plt
from sklearn.model_selection import train_test_split
from xgboost import XGBRegressor

import shapiq

Train a Model

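# Load the data as pandas objects to keep the column names,
# then convert features and target to NumPy arrays
x_data, y_data = shapiq.datasets.load_california_housing(to_numpy=False)
feature_names = list(x_data.columns)
x_data, y_data = x_data.values, y_data.values
# Hold out 20% of the data as a test set
x_train, x_test, y_train, y_test = train_test_split(
    x_data,
    y_data,
    test_size=0.2,
    random_state=42,
)
# A small gradient-boosted tree ensemble keeps training fast
model = XGBRegressor(random_state=42, max_depth=4, n_estimators=50)
model.fit(x_train, y_train)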


Compute Explanations for Multiple Instances

We explain only the first 20 test instances to keep the example fast.

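x_explain = x_test[:20]  # the first 20 test instances
explainer = shapiq.TabularExplainer(
    model,
    data=x_test,  # background data used to impute absent features
    index="FSII",  # Faithful Shapley Interaction Index
    max_order=3,  # interactions of up to three features
    random_state=42,
)
# One explanation per row of x_explain; budget caps the number of
# coalitions evaluated for each instance
explanations = explainer.explain_X(x_explain, budget=200)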
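To put `budget=200` and `max_order=3` in perspective, we can count coalitions and interactions. This assumes the dataset has the 8 features of the standard California housing data. With 8 features there are 2^8 = 256 coalitions, so a budget of 200 cannot enumerate all of them and the FSII values are approximated. Orders 1 to 3 give 8 + 28 + 56 = 92 non-empty interactions per instance.

```python
from math import comb

n_features = 8  # assumption: the standard California housing feature set
max_order = 3   # matches max_order=3 above
budget = 200    # matches budget=200 above

n_coalitions = 2 ** n_features
# Number of feature subsets of size 1, 2, ..., max_order
n_interactions = sum(comb(n_features, k) for k in range(1, max_order + 1))

print(n_coalitions)           # 256
print(n_interactions)         # 92
print(budget < n_coalitions)  # True: not every coalition is evaluated
```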

Basic Beeswarm Plot

shapiq.beeswarm_plot(explanations, x_explain)
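Each point is colored by the value of the underlying feature for that instance. A common approach in beeswarm-style plots is to clip feature values to a central percentile range and rescale them to [0, 1] before applying a colormap. Whether beeswarm_plot() uses exactly this scheme or these percentiles is an assumption, and the helper `normalize_for_color` is hypothetical.

```python
import numpy as np


def normalize_for_color(values, lower_pct=5.0, upper_pct=95.0):
    """Map raw feature values to [0, 1] for a colormap (hypothetical helper).

    Values are clipped to the [lower_pct, upper_pct] percentile range first,
    so a few extreme instances do not wash out the color scale.
    """
    values = np.asarray(values, dtype=float)
    vmin, vmax = np.percentile(values, [lower_pct, upper_pct])
    if vmax == vmin:  # constant feature: use the middle of the colormap
        return np.full_like(values, 0.5)
    return np.clip((values - vmin) / (vmax - vmin), 0.0, 1.0)


# Made-up feature values for five instances, one of them an outlier
feature_values = [1.5, 2.0, 3.2, 4.1, 15.0]
colors = normalize_for_color(feature_values)
```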

With Feature Names

shapiq.beeswarm_plot(explanations, x_explain, feature_names=feature_names)
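Passing feature_names lets row labels show names instead of bare feature indices. The sketch below turns an interaction tuple into a readable label. The " x " separator, the "F{i}" fallback, and the helper `interaction_label` are illustrative assumptions, not the exact formatting used by beeswarm_plot(). The name list follows the standard scikit-learn column order, assumed here.

```python
def interaction_label(interaction, feature_names=None):
    """Build a row label such as 'MedInc x Latitude' (hypothetical helper)."""
    if feature_names is None:
        parts = [f"F{i}" for i in interaction]  # fall back to index labels
    else:
        parts = [feature_names[i] for i in interaction]
    return " x ".join(parts)


names = ["MedInc", "HouseAge", "AveRooms", "AveBedrms",
         "Population", "AveOccup", "Latitude", "Longitude"]
print(interaction_label((0, 6), names))  # MedInc x Latitude
print(interaction_label((0, 6)))         # F0 x F6
```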

Full Feature Names (no abbreviation)

shapiq.beeswarm_plot(
    explanations,
    x_explain,
    feature_names=feature_names,
    abbreviate=False,
)
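Abbreviation keeps labels short, which matters most for higher-order interactions whose labels join several names. The rule sketched below is hypothetical: it keeps the capital letters of a CamelCase name and otherwise truncates. shapiq's actual abbreviation rule may differ.

```python
def abbreviate(name, max_len=4):
    """Shorten a feature name, e.g. 'HouseAge' -> 'HA', 'Latitude' -> 'Lati'.

    Hypothetical rule: use the capital letters if there are at least two,
    otherwise keep the first max_len characters.
    """
    capitals = "".join(c for c in name if c.isupper())
    if len(capitals) > 1:
        return capitals
    return name[:max_len]


names = ["MedInc", "HouseAge", "AveOccup", "Latitude"]
short = [abbreviate(n) for n in names]
print(short)  # ['MI', 'HA', 'AO', 'Lati']
```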

Limit Displayed Interactions

shapiq.beeswarm_plot(
    explanations,
    x_explain,
    feature_names=feature_names,
    abbreviate=False,
    max_display=5,
)
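With max_display=5 only five rows are drawn. A natural way to pick them, assumed in this sketch, is to rank interactions by their mean absolute value across the explained instances and keep the top k. The helper `top_k_interactions` is hypothetical.

```python
import numpy as np


def top_k_interactions(matrix, keys, k):
    """Return the k interactions with the largest mean |value| over instances.

    Hypothetical helper: `matrix` has shape (n_instances, n_interactions)
    and `keys` names its columns.
    """
    importance = np.abs(matrix).mean(axis=0)
    order = np.argsort(importance)[::-1][:k]  # indices in descending importance
    return [keys[i] for i in order]


keys = [(0,), (1,), (2,), (0, 1)]
matrix = np.array([
    [0.9, -0.1, 0.3, 0.05],
    [-0.7, 0.2, -0.4, -0.02],
])
top = top_k_interactions(matrix, keys, k=2)
print(top)  # [(0,), (2,)]
```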

Adjust Row Height

shapiq.beeswarm_plot(
    explanations,
    x_explain,
    feature_names=feature_names,
    abbreviate=False,
    row_height=1.0,
)
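row_height controls how much vertical space each row gets, and therefore how far overlapping points can spread. One standard way to lay out a beeswarm, sketched here as an assumption rather than shapiq's exact algorithm, is to bin points along the x-axis and stack points that share a bin alternately above and below the row's center line. In this sketch `row_height` is the vertical extent of a row in data units; in beeswarm_plot() it may instead scale the figure height. The helper `swarm_offsets` is hypothetical.

```python
import numpy as np


def swarm_offsets(x, row_height=0.4, n_bins=20):
    """Vertical offsets that spread overlapping points within one row.

    Hypothetical helper. Points in the same x-bin are placed alternately
    above and below the center line, then all offsets are scaled to lie
    within +/- row_height / 2.
    """
    x = np.asarray(x, dtype=float)
    edges = np.linspace(x.min(), x.max(), n_bins + 1)
    bins = np.clip(np.digitize(x, edges) - 1, 0, n_bins - 1)
    offsets = np.zeros_like(x)
    for b in np.unique(bins):
        idx = np.flatnonzero(bins == b)
        ranks = np.arange(len(idx))
        # rank 0 stays centered, then +1, -1, +2, -2, ...
        offsets[idx] = np.where(ranks % 2 == 1, (ranks + 1) // 2, -(ranks // 2))
    max_abs = np.abs(offsets).max()
    if max_abs > 0:
        offsets *= (row_height / 2) / max_abs
    return offsets


rng = np.random.default_rng(0)
values = rng.normal(size=50)  # synthetic interaction values for one row
y_small = swarm_offsets(values, row_height=0.4)
y_large = swarm_offsets(values, row_height=1.0)
```

A larger row height leaves the horizontal positions unchanged and only stretches the vertical spread, which makes dense rows easier to read at the cost of a taller figure.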

Custom Axis

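# Create our own 6 x 6 inch figure so the plot can be sized or embedded
fig, ax = plt.subplots(figsize=(6, 6))
shapiq.beeswarm_plot(
    explanations,
    x_explain,
    feature_names=feature_names,
    abbreviate=False,
    ax=ax,  # draw into the existing axes instead of a new figure
)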

Total running time of the script: (0 minutes 2.955 seconds)