Source code for shapiq.approximator.proxy.regressionmsr
"""RegressionMSR is a proxy-based approximator that uses a regression model to approximate the value function and applies the MSR adjustment method."""
from __future__ import annotations
from typing import TYPE_CHECKING, Literal
from .proxyshap import ProxySHAP
if TYPE_CHECKING:
    # These names are only referenced in annotations in this module. Because the
    # module uses ``from __future__ import annotations``, annotations are never
    # evaluated at runtime, so the imports are needed by static type checkers only.
    import numpy as np

    from shapiq.approximator.proxy._models import ProxyLiteral, ProxyModel, ProxyModelWithHPO

# The ``index`` values accepted by ``RegressionMSR.__init__``.
ValidRegressionMSRIndices = Literal["SV", "BV"]
[docs]
class RegressionMSR(ProxySHAP):
    """Proxy-based approximator combining a regression proxy with the MSR adjustment.

    A regression model (the proxy) is trained on a subset of the coalitions to
    approximate the value function, and its predictions are adjusted using the MSR
    method to better match the true value function. The method was proposed by
    :cite:t:`Witter.2025` and is designed to provide more accurate approximations of
    the Shapley values, especially in cases where the value function is complex and
    non-linear.

    This class is a thin specialization of
    :class:`~shapiq.approximator.proxy.proxyshap.ProxySHAP`: the approximated order is
    fixed to ``1`` and the adjustment method is fixed to ``"msr"``.

    Example:
        >>> from shapiq_games.synthetic import DummyGame
        >>> from shapiq.approximator import RegressionMSR
        >>> game = DummyGame(n=5, interaction=(1, 2))
        >>> approximator = RegressionMSR(n=5, index="SV")
        >>> approximator.approximate(budget=100, game=game)
        InteractionValues(
            index=SV, max_order=1, estimated=True, estimation_budget=100
        )
    """

    def __init__(
        self,
        n: int,
        index: ValidRegressionMSRIndices,
        *,
        proxy_model: ProxyModel | ProxyModelWithHPO | ProxyLiteral = "xgboost",
        sampling_weights: np.ndarray | None = None,
        pairing_trick: bool = True,
        random_state: int | None = None,
    ) -> None:
        """Initialize the RegressionMSR approximator.

        Args:
            n: The number of players in the game.
            index: The index to be approximated. Must be one of ``"SV"`` or ``"BV"``
                (see ``ValidRegressionMSRIndices``). The adjustment method is not
                configurable here; it is always ``"msr"``.
            proxy_model: The model used as the proxy. Either an estimator/HPO wrapper or a
                string tag (``"xgboost"`` (default), ``"lightgbm"``, ``"tree"``,
                ``"linear"``); see :class:`~shapiq.approximator.proxy.proxyshap.ProxySHAP`
                for details.
            sampling_weights: The sampling weights for the coalitions. If ``None``, uniform
                weights will be used.
            pairing_trick: Whether to use the pairing trick for sampling coalitions.
                Default is ``True``.
            random_state: The random state for reproducibility. Default is ``None``.
        """
        super().__init__(
            n=n,
            # RegressionMSR only supports first-order values, so the order is not exposed.
            max_order=1,
            index=index,
            proxy_model=proxy_model,
            # Hard-wired: this subclass is the MSR-adjusted variant of ProxySHAP.
            adjustment="msr",
            sampling_weights=sampling_weights,
            pairing_trick=pairing_trick,
            random_state=random_state,
        )