Source code for shapiq.approximator.proxy.regressionmsr
"""RegressionMSR is a proxy-based approximator that uses a regression model to approximate the value function and applies the MSR adjustment method."""
from __future__ import annotations
from typing import TYPE_CHECKING, Literal
from .proxyshap import ProxySHAP
if TYPE_CHECKING:
import numpy as np
from shapiq.typing import Model
# Index identifiers accepted by ``RegressionMSR``'s ``index`` argument.
# NOTE(review): presumably "SV" = Shapley values and "BV" = Banzhaf values --
# confirm against the package's index documentation.
ValidRegressionMSRIndices = Literal["SV", "BV"]
[docs]
class RegressionMSR(ProxySHAP):
    """Proxy-based approximator pairing a regression proxy with the MSR adjustment.

    A regression model is trained on a subset of the coalitions and stands in
    for the value function. Its predictions are then corrected with the MSR
    adjustment method so that they better match the true value function.

    The method was proposed by Witter et al. (2025) :cite:t:`Witter.2025` and is
    designed to provide more accurate approximations of the Shapley values,
    especially in cases where the value function is complex and non-linear.

    Example:
        >>> from shapiq_games.synthetic import DummyGame
        >>> from shapiq.approximator import RegressionMSR
        >>> game = DummyGame(n=5, interaction=(1, 2))
        >>> approximator = RegressionMSR(n=5, index="SV")
        >>> approximator.approximate(budget=100, game=game)
        InteractionValues(
            index=SV, max_order=1, estimated=True, estimation_budget=100
        )
    """

    def __init__(
        self,
        n: int,
        index: ValidRegressionMSRIndices,
        *,
        proxy_model: Model | None = None,
        sampling_weights: np.ndarray | None = None,
        pairing_trick: bool = True,
        random_state: int | None = None,
    ) -> None:
        """Set up a RegressionMSR approximator.

        All arguments are forwarded unchanged to the parent initializer; this
        class only fixes ``max_order`` to ``1`` and ``adjustment`` to ``"msr"``.

        Args:
            n: Number of players in the game.
            index: Index to approximate. It has to be valid for the chosen
                adjustment method.
            proxy_model: Regression model acting as the proxy. With ``None``, a
                default regression model is used.
            sampling_weights: Sampling weights for the coalitions. With
                ``None``, uniform weights are used.
            pairing_trick: Whether coalitions are sampled with the pairing
                trick. Defaults to ``True``.
            random_state: Random state for reproducibility. Defaults to
                ``None``.
        """
        super().__init__(
            # Game description: first-order values only.
            n=n,
            index=index,
            max_order=1,
            # Proxy configuration: the adjustment method is always MSR here.
            proxy_model=proxy_model,
            adjustment="msr",
            # Coalition sampling options, passed through as given.
            sampling_weights=sampling_weights,
            pairing_trick=pairing_trick,
            random_state=random_state,
        )