Source code for shapiq.approximator.montecarlo.svarmiq

"""SVARM-IQ approximation."""

from __future__ import annotations

from typing import TYPE_CHECKING, Any, Literal, TypeVar, get_args

from .base import MonteCarlo, ValidMonteCarloIndices

if TYPE_CHECKING:
    from shapiq.typing import FloatVector

TIndices = TypeVar("TIndices", bound=ValidMonteCarloIndices)
"""A type variable for the valid indices of the MonteCarlo approximator."""


[docs] class SVARMIQ(MonteCarlo[ValidMonteCarloIndices]): """The SVARM-IQ approximator for Shapley interactions. SVARM-IQ utilizes MonteCarlo approximation with two stratification strategies. SVARM-IQ is a generalization of the SVARM algorithm :cite:p:`Kolpaczki.2024a` and can approximate any-order Shapley interactions efficiently. For details about the algorithm see the original paper by :cite:t:`Kolpaczki.2024b`. """ def __init__( self, n: int, max_order: int = 2, index: ValidMonteCarloIndices = "k-SII", *, top_order: bool = False, pairing_trick: bool = False, sampling_weights: FloatVector | None = None, random_state: int | None = None, ) -> None: """Initialize the SVARMIQ approximator. Args: n: The number of players. max_order: The interaction order of the approximation. Defaults to ``2``. index: The interaction index to be used. Choose from ``['k-SII', 'SII']``. Defaults to ``'k-SII'``. top_order: If ``True``, the top-order interactions are estimated. Defaults to ``False``. pairing_trick: If ``True``, the pairing trick is applied to the sampling procedure. Defaults to ``False``. sampling_weights: An optional array of weights for the sampling procedure. The weights must be of shape ``(n + 1,)`` and are used to determine the probability of sampling a coalition of a certain size. Defaults to ``None``. random_state: The random state of the estimator. Defaults to ``None``. """ super().__init__( n, max_order=max_order, index=index, top_order=top_order, stratify_coalition_size=True, stratify_intersection=True, random_state=random_state, sampling_weights=sampling_weights, pairing_trick=pairing_trick, )
ValidIndicesSVARM = Literal["SV", "BV"]
[docs] class SVARM(MonteCarlo[ValidIndicesSVARM]): """The SVARM approximator for estimating the Shapley value (SV). SVARM is a MonteCarlo approximation algorithm that estimates the Shapley value. For details about the algorithm see the original paper by Kolpaczki et al. (2024) :footcite:t:`Kolpaczki.2024a`. References: .. footbibliography:: """ valid_indices: tuple[ValidIndicesSVARM, ...] = tuple(get_args(ValidIndicesSVARM)) """The valid indices for the SVARM approximator.""" def __init__( self, n: int, index: ValidIndicesSVARM = "SV", *, random_state: int | None = None, pairing_trick: bool = False, sampling_weights: FloatVector | None = None, **kwargs: Any, # noqa: ARG002 ) -> None: """Initialize the SVARM approximator. Args: n: The number of players. index: The interaction index to be used. Choose from ``['SV', 'BV']``. Defaults to ``'SV'``. random_state: The random state of the estimator. Defaults to ``None``. pairing_trick: If ``True``, the pairing trick is applied to the sampling procedure. Defaults to ``False``. sampling_weights: An optional array of weights for the sampling procedure. The weights must be of shape ``(n + 1,)`` and are used to determine the probability of sampling a coalition of a certain size. Defaults to ``None``. **kwargs: Additional keyword arguments (not used only for compatibility). """ super().__init__( n, max_order=1, index=index, top_order=False, stratify_coalition_size=True, stratify_intersection=True, random_state=random_state, sampling_weights=sampling_weights, pairing_trick=pairing_trick, )