.. DO NOT EDIT.
.. THIS FILE WAS AUTOMATICALLY GENERATED BY SPHINX-GALLERY.
.. TO MAKE CHANGES, EDIT THE SOURCE PYTHON FILE:
.. "auto_examples/vision/plot_vision_transformer.py"
.. LINE NUMBERS ARE GIVEN BELOW.

.. only:: html

    .. note::
        :class: sphx-glr-download-link-note

        :ref:`Go to the end `
        to download the full example code.

.. rst-class:: sphx-glr-example-title

.. _sphx_glr_auto_examples_vision_plot_vision_transformer.py:


Explaining a Vision Transformer
=================================

This example shows how to explain the prediction of a Vision Transformer (ViT)
image classifier using ``shapiq``. The image is divided into patches, and each
patch becomes a player in a cooperative game. Shapley values then quantify how
much each patch contributes to the predicted class.

We use the :class:`~shapiq_games.benchmark.ImageClassifierLocalXAI` game from the
``shapiq_games`` package, which wraps a pretrained ViT model and handles patch
masking internally.

.. GENERATED FROM PYTHON SOURCE LINES 14-32

.. code-block:: Python


    from __future__ import annotations

    import os

    # Prevent OpenMP/MKL thread conflicts with the PyTorch backend
    os.environ.setdefault("OMP_NUM_THREADS", "1")
    os.environ.setdefault("MKL_NUM_THREADS", "1")

    import tempfile
    from pathlib import Path

    import numpy as np
    from PIL import Image

    import shapiq
    from shapiq_games.benchmark import ImageClassifierLocalXAI

.. GENERATED FROM PYTHON SOURCE LINES 33-41

Set Up the Image Game
----------------------

We use a ViT model that splits the image into a 3x3 grid of 9 patches. Each
patch is a player in the cooperative game. The value of a coalition is the
model's predicted probability for the top class when only the patches in that
coalition are visible and the rest are masked.

We create a synthetic image so that the example runs without external files.
In practice, you would pass the path to a real photograph.

.. GENERATED FROM PYTHON SOURCE LINES 41-55

.. code-block:: Python


    # Random RGB noise image; ``high`` is exclusive, so 256 covers the full 0-255 range
    rng = np.random.default_rng(42)
    image = Image.fromarray(rng.integers(0, 256, (384, 384, 3), dtype=np.uint8))

    # The game loads the image from a file path, so save it to a temporary file
    image_path = str(Path(tempfile.gettempdir()) / "shapiq_vit_example.png")
    image.save(image_path)

    game = ImageClassifierLocalXAI(
        model_name="vit_9_patches",
        x_explain_path=image_path,
        normalize=True,
    )
    print(f"Number of patches (players): {game.n_players}")
    print(f"Grand coalition value: {game.grand_coalition_value:.3f}")

.. container:: sphx-glr-download sphx-glr-download-python

    :download:`Download Python source code: plot_vision_transformer.py `

.. container:: sphx-glr-download sphx-glr-download-zip

    :download:`Download zipped: plot_vision_transformer.zip `
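
The following sketch shows what "Shapley values of a patch game" means in practice, without downloading a model. It is a hypothetical illustration and is not part of ``shapiq`` or ``shapiq_games``. The function ``toy_patch_value`` is an invented stand-in for the ViT's class probability on a partially masked image. The function ``exact_shapley`` is a plain brute-force implementation of the standard Shapley formula, where :math:`N` is the set of :math:`n` patches and :math:`v(S)` is the value of coalition :math:`S`:

.. math::

    \phi_i = \sum_{S \subseteq N \setminus \{i\}}
        \frac{|S|!\,(n - |S| - 1)!}{n!}\,\bigl[v(S \cup \{i\}) - v(S)\bigr]

With 9 players, each player needs only :math:`2^8 = 256` coalitions, so exact enumeration is cheap at this grid size.

.. code-block:: Python

    from __future__ import annotations

    from itertools import combinations
    from math import factorial

    import numpy as np

    N_PATCHES = 9  # 3x3 grid, matching the game set up in this example

    # Hypothetical per-patch evidence for the predicted class (not real ViT output).
    PATCH_EVIDENCE = np.array([0.1, 0.4, 0.1, 0.3, 1.2, 0.5, 0.0, 0.2, 0.1])


    def toy_patch_value(coalition: tuple[int, ...]) -> float:
        """Hypothetical stand-in for the class probability of a masked image.

        Each visible patch adds its evidence to a logit, and a sigmoid maps the
        logit to a probability. Patches 4 and 5 interact: they add a bonus only
        when both are visible.
        """
        visible = set(coalition)
        logit = -2.0 + sum(PATCH_EVIDENCE[i] for i in visible)
        if {4, 5} <= visible:
            logit += 0.8
        return float(1.0 / (1.0 + np.exp(-logit)))


    def exact_shapley(value_fn, n_players: int) -> np.ndarray:
        """Exact Shapley values, enumerating every coalition without player i."""
        phi = np.zeros(n_players)
        n_fact = factorial(n_players)
        for i in range(n_players):
            others = [j for j in range(n_players) if j != i]
            for size in range(n_players):  # coalition sizes 0 .. n-1
                weight = factorial(size) * factorial(n_players - size - 1) / n_fact
                for subset in combinations(others, size):
                    with_i = subset + (i,)
                    phi[i] += weight * (value_fn(with_i) - value_fn(subset))
        return phi


    phi = exact_shapley(toy_patch_value, N_PATCHES)
    empty_value = toy_patch_value(())
    full_value = toy_patch_value(tuple(range(N_PATCHES)))

    print(np.round(phi, 3).reshape(3, 3))  # attributions laid out on the patch grid
    print(f"sum of Shapley values:           {phi.sum():.4f}")
    print(f"v(all patches) - v(no patches):  {full_value - empty_value:.4f}")

The last two printed lines agree because Shapley values satisfy efficiency: they sum to :math:`v(N) - v(\emptyset)`. The number of coalitions grows as :math:`2^n`, so brute force stops being practical for finer patch grids. In that regime, an approximation method is needed instead of exact enumeration.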