"""Module for plotting the explanation graph of interaction values."""
from __future__ import annotations
import math
from typing import TYPE_CHECKING
from warnings import warn
import matplotlib.patches as mpatches
import matplotlib.path as mpath
import networkx as nx
import numpy as np
from matplotlib import pyplot as plt
from ._config import get_color
from .utils import add_image_in_center
if TYPE_CHECKING:
from matplotlib.axes import Axes
from matplotlib.figure import Figure
from matplotlib.legend import Legend
from PIL.Image import Image
from shapiq.interaction_values import InteractionValues
NORMAL_NODE_SIZE = 0.125 # 0.125
BASE_ALPHA_VALUE = 1.0 # the transparency level for the highest interaction
BASE_SIZE = 0.05 # the size of the highest interaction edge (with scale factor 1)
ADJUST_NODE_ALPHA = True
LABEL_OFFSET = 0.07
__all__ = ["get_legend", "si_graph_plot"]
[docs]
def si_graph_plot(
interaction_values: InteractionValues,
*,
show: bool = False,
n_interactions: int | None = None,
draw_threshold: float = 0.0,
interaction_direction: str | None = None,
min_max_order: tuple[int, int] = (1, -1),
size_factor: float = 1.0,
node_size_scaling: float = 1.0,
min_max_interactions: tuple[float, float] | None = None,
feature_names: list | dict | None = None,
graph: list[tuple] | nx.Graph | None = None,
plot_original_nodes: bool = False,
plot_explanation: bool = True,
pos: dict | None = None,
circular_layout: bool = True,
random_seed: int = 42,
adjust_node_pos: bool = False,
spring_k: float | None = None,
compactness: float = 1e10,
center_image: Image | np.ndarray | None = None,
center_image_size: float = 0.4,
feature_image_patches: dict[int, Image] | list[Image] | None = None,
feature_image_patches_size: float = 0.2,
) -> tuple[Figure, Axes] | None:
"""Plots the interaction values as an explanation graph.
An explanation graph is an undirected graph where the nodes represent players and the edges
represent interactions between the players. The size of the nodes and edges represent the
strength of the interaction values. The color of the edges represents the sign of the
interaction values (red for positive and blue for negative). The SI-graph plot is presented in
:footcite:t:`Muschalik.2024b`.
Args:
interaction_values: The interaction values to plot.
show: Whether to show or return the plot. Defaults to ``True``.
n_interactions: The number of interactions to plot. If ``None``, all interactions are
plotted according to the draw_threshold.
draw_threshold: The threshold to draw an edge (i.e. only draw explanations with an
interaction value higher than this threshold).
interaction_direction: The sign of the interaction values to plot. If ``None``, all
interactions are plotted. Possible values are ``"positive"`` and
``"negative"``. Defaults to ``None``.
min_max_order: Only show interactions of min <= size <= max. First order interactions are
always shown. To use maximum order of interaction values, set max to -1. Defaults to
``(1, -1)``.
size_factor: The factor to scale the explanations by (a higher value will make the
interactions and main effects larger). Defaults to ``1.0``.
node_size_scaling: The scaling factor for the node sizes. This can be used to make the nodes
larger or smaller depending on how the graph looks. Defaults to ``1.0`` (no scaling).
Values between ``0.0`` and ``1.0`` will make the nodes smaller, higher values will make
the nodes larger.
min_max_interactions: The minimum and maximum interaction values to use for scaling the
interactions as a tuple ``(min, max)``. If ``None``, the minimum and maximum interaction
values are used. Defaults to ``None``.
feature_names: The feature names used for plotting. List/dict mapping index of the player as
index/key to name. If no feature names are provided, the feature indices are used
instead. Defaults to ``None``.
graph: The underlying graph structure as a list of edge tuples or a networkx graph. If a
networkx graph is provided, the nodes are used as the players and the edges are used as
the connections between the players. Defaults to ``None``, which creates a graph with
all nodes from the interaction values without any edges between them.
plot_original_nodes: If set to ``True``, nodes are shown as white circles with the label
inside, large first-order-effects appear as halos around the node. Set to ``False``,
only the explanation nodes are shown, their labels next to them. Defaults to ``False``.
plot_explanation: Whether to plot the explanation or only the original graph. Defaults to
``True``.
pos: The positions of the nodes in the graph. If ``None``, the spring layout is used to
position the nodes. Defaults to ``None``.
circular_layout: plot the players in a circle according to their order.
random_seed: The random seed to use for layout of the graph (if not circular).
adjust_node_pos: Whether to adjust the node positions such that the nodes are at least
``NORMAL_NODE_SIZE`` apart. Defaults to ``False``.
spring_k: The spring constant for the spring layout. If `None`, the spring constant is
calculated based on the number of nodes in the graph. Defaults to ``None``.
compactness: A scaling factor for the underlying spring layout. A higher compactness value
will move the interactions closer to the graph nodes. If your graph looks weird, try
adjusting this value, e.g. ``[0.1, 1.0, 10.0, 100.0, 1000.0]``. Defaults to ``1e10``.
center_image: An optional image to be displayed in the center of the graph. If provided,
the image displayed with size ``center_image_size``. If the number of features is
a perfect square, we assume a vision transformer style grid was used and overlay the
image with a grid of feature image patches. Defaults to ``None``.
center_image_size: The size of the center image. Defaults to ``0.4``. Adjust this value
to make the image larger or smaller in the center of the graph.
feature_image_patches: A dictionary/list containing the image patches to be displayed
instead of the feature labels in the network. The keys/indices of the list are the
feature indices and the values are the feature images. If explicit feature names are
provided, they are displayed on top of the image. Defaults to ``None``.
feature_image_patches_size: The size of the feature image patches. Defaults to ``0.2``.
Returns:
The figure and axis of the plot if ``show`` is ``False``. Otherwise, ``None``.
References:
.. footbibliography::
"""
if interaction_values is None:
msg = "Interaction_values must be provided."
raise ValueError(msg)
normal_node_size = NORMAL_NODE_SIZE * node_size_scaling
base_size = BASE_SIZE * node_size_scaling
label_mapping = None
if isinstance(feature_names, list):
label_mapping = {i: feature_names[i] for i in range(len(feature_names))}
else:
label_mapping = feature_names
player_ids = {
interaction[0]
for interaction in interaction_values.interaction_lookup
if len(interaction) == 1
}
# fill the original graph with the edges and nodes
if isinstance(graph, nx.Graph):
original_graph = graph
graph_nodes = list(original_graph.nodes)
# check if graph has labels
if "label" not in original_graph.nodes[graph_nodes[0]]:
for node in graph_nodes:
node_label = label_mapping.get(node, node) if label_mapping is not None else node
original_graph.nodes[node]["label"] = node_label
elif isinstance(graph, list):
circular_layout = False
original_graph, graph_nodes = nx.Graph(), []
for edge in graph:
original_graph.add_edge(*edge)
nodel_labels = [edge[0], edge[1]]
if label_mapping is not None:
nodel_labels = [label_mapping.get(node, node) for node in nodel_labels]
original_graph.add_node(edge[0], label=nodel_labels[0])
original_graph.add_node(edge[1], label=nodel_labels[1])
graph_nodes.extend([edge[0], edge[1]])
else: # graph is considered None
original_graph = nx.Graph()
graph_nodes = list(player_ids)
for node in graph_nodes:
node_label = label_mapping.get(node, node) if label_mapping is not None else node
original_graph.add_node(node, label=node_label)
for player_id in player_ids:
if player_id not in original_graph.nodes:
msg = (
f"The given graph does not contain player {player_id}, which can lead to misattributions in the plot.\n"
f"The given graph: {graph} and the players of the given interaction values: {player_ids}"
)
warn(msg, stacklevel=2)
break
if n_interactions is not None:
# get the top n interactions
interaction_values = interaction_values.get_top_k_interactions(
n_interactions
) # TODO(advueu963): Was get_top_k(n_interactions) which should be wrong. # noqa: TD003
min_order, max_order = min_max_order
min_order = max(1, min_order)
if max_order == -1:
max_order = interaction_values.max_order
# get the interactions to plot (sufficiently large, right order)
interactions_to_plot = {}
min_interaction, max_interaction = 1e10, 0.0
for interaction, interaction_pos in interaction_values.interaction_lookup.items():
if len(interaction) < min_order or len(interaction) > max_order:
continue
interaction_value = interaction_values.values[interaction_pos]
min_interaction = min(abs(interaction_value), min_interaction)
max_interaction = max(abs(interaction_value), max_interaction)
if abs(interaction_value) > draw_threshold:
if interaction_direction == "positive" and interaction_value < 0:
continue
if interaction_direction == "negative" and interaction_value > 0:
continue
interactions_to_plot[interaction] = interaction_value
if min_max_interactions is not None:
min_interaction, max_interaction = min_max_interactions
# create explanation graph
explanation_graph, explanation_nodes, explanation_edges = nx.Graph(), [], []
for interaction, interaction_value in interactions_to_plot.items():
interaction_size = len(interaction)
interaction_strength = abs(interaction_value)
attributes = {
"color": get_color(interaction_value),
"alpha": _normalize_value(interaction_value, max_interaction, BASE_ALPHA_VALUE),
"interaction": interaction,
"weight": interaction_strength * compactness,
"size": _normalize_value(interaction_value, max_interaction, base_size * size_factor),
}
# add main effect explanations as nodes
if interaction_size == 1:
player = interaction[0]
explanation_graph.add_node(player, **attributes)
explanation_nodes.append(player)
# add 2-way interaction explanations as edges
if interaction_size >= 2:
explanation_edges.append(interaction)
player_last = interaction[-1]
if interaction_size > 2:
dummy_node = tuple(interaction)
explanation_graph.add_node(dummy_node, **attributes)
player_last = dummy_node
# add the edges between the players
for player in interaction[:-1]:
explanation_graph.add_edge(player, player_last, **attributes)
# position first the original graph structure
if isinstance(graph, nx.Graph | list):
circular_layout = False
adjusted_pos: dict = {}
if pos is None:
# TODO(advueu963): pos is statically just a Mapping. Forcing it to be dict is way stronger but necessary as far I see it # noqa: TD003
if circular_layout:
adjusted_pos = nx.circular_layout(original_graph) # pyright: ignore[reportAssignmentType]
else:
adjusted_pos = nx.spring_layout(original_graph, seed=random_seed, k=spring_k) # pyright: ignore[reportAssignmentType]
adjusted_pos = nx.kamada_kawai_layout(original_graph, scale=1, pos=pos) # pyright: ignore[reportAssignmentType]
else:
# pos is given, but we need to scale the positions potentially
min_pos = np.min(list(pos.values()), axis=0)
max_pos = np.max(list(pos.values()), axis=0)
adjusted_pos = {node: (pos[node] - min_pos) / (max_pos - min_pos) for node in pos}
# adjust pos such that the nodes are at least NORMAL_NODE_SIZE apart
if adjust_node_pos:
adjusted_pos = _adjust_position(adjusted_pos, original_graph)
# create the plot
fig, ax = plt.subplots(figsize=(7, 7))
if plot_explanation:
# position now again the hyper-edges onto the normal nodes weight param is weight
pos_explain = nx.spring_layout(
explanation_graph,
weight="weight",
seed=random_seed,
pos=adjusted_pos,
fixed=graph_nodes,
)
adjusted_pos.update(pos_explain)
_draw_fancy_hyper_edges(ax, adjusted_pos, explanation_graph, hyper_edges=explanation_edges)
_draw_explanation_nodes(
ax,
adjusted_pos,
explanation_graph,
nodes=explanation_nodes,
normal_node_size=normal_node_size,
)
# add the original graph structure on top
if plot_original_nodes or not plot_explanation:
_draw_graph_nodes(ax, adjusted_pos, original_graph, normal_node_size=normal_node_size)
_draw_graph_edges(ax, adjusted_pos, original_graph, normal_node_size=normal_node_size)
# add images
if feature_image_patches is not None:
_draw_feature_images(
ax,
adjusted_pos,
original_graph,
feature_image_patches,
feature_image_patches_size,
)
if feature_image_patches is None or plot_original_nodes:
_draw_graph_labels(
ax,
adjusted_pos,
original_graph,
normal_node_size=normal_node_size,
plot_white_nodes=plot_original_nodes,
)
# add the center image
if center_image is not None:
n_features = interaction_values.n_players
if feature_image_patches is not None:
n_features = len(feature_image_patches)
# if the number is not a square we should not draw a grid, otherwise we assume a grid
if math.isqrt(n_features) ** 2 != n_features:
n_features = None
add_image_in_center(
image=center_image,
axis=ax,
size=center_image_size,
n_features=n_features,
)
# tidy up the plot
ax.set_aspect("equal", adjustable="datalim") # make y- and x-axis scales equal
ax.axis("off") # remove axis
if not show:
return fig, ax
plt.show()
return None
# TODO(advueu963): This function is not used at all. If not given an axis it will also crash. What is the meaning of this function # noqa: TD003
def get_legend(axis: Axes) -> tuple[Legend, Legend]:
"""Gets the legend for the SI graph plot.
Returns a tuple of legends, a legend for first order (nodes) and one for higher order (edges)
interactions. If an axis is provided, it adds the legend to the axis.
Args:
axis (plt.Axes): The axis to add the legend to.
Returns:
a tuple of two legend objects: the first is the legend for the first order interactions, the second for higher
order interactions.
"""
interaction_values = [1.0, 0.4, -0.4, -1]
labels = ["high pos.", "low pos.", "low neg.", "high neg."]
plot_circles = []
plot_edges = []
for value in interaction_values:
color = get_color(value)
node_size = abs(value) / 2 + 1 / 2
edge_size = abs(value) / 2
alpha = _normalize_value(value, 1, BASE_ALPHA_VALUE)
circle = axis.plot(
[], [], c=color, marker="o", markersize=node_size * 8, linestyle="None", alpha=alpha
)
plot_circles.append(circle[0])
line = axis.plot([], [], c=color, linewidth=edge_size * 6, alpha=alpha)
plot_edges.append(line[0])
font_size = plt.rcParams["legend.fontsize"]
legend1 = plt.gca().legend(
plot_circles,
labels,
frameon=True,
framealpha=0.5,
facecolor="white",
title=r"$\bf{First\ Order}$",
fontsize=font_size,
labelspacing=0.5,
handletextpad=0.5,
borderpad=0.5,
handlelength=1.5,
title_fontsize=font_size,
loc="upper left",
)
legend2 = plt.legend(
plot_edges,
labels,
frameon=True,
framealpha=0.5,
facecolor="white",
title=r"$\bf{Higher\ Order}$",
fontsize=font_size,
labelspacing=0.5,
handletextpad=0.5,
borderpad=0.5,
handlelength=1.5,
title_fontsize=font_size,
loc="upper right",
)
if axis:
axis.add_artist(legend1)
axis.add_artist(legend2)
return legend1, legend2
def _normalize_value(
value: float | np.ndarray,
max_value: float,
base_value: float,
) -> float | np.ndarray:
"""Scale a value between 0 and 1 based on the maximum value and a base value.
Args:
value: The value to normalize/scale.
max_value: The maximum value to normalize/scale the value by.
base_value: The base value to scale the value by. For example, the alpha value for the
highest interaction (as defined in ``BASE_ALPHA_VALUE``) or the size of the highest
interaction edge (as defined in ``BASE_SIZE``).
Returns:
The normalized/scaled value.
"""
ratio = abs(value) / abs(max_value) # ratio is always positive in [0, 1]
return ratio * base_value
def _draw_fancy_hyper_edges(
axis: Axes,
pos: dict,
graph: nx.Graph,
hyper_edges: list[tuple],
) -> None:
"""Draws a collection of hyper-edges as a fancy hyper-edge on the graph.
Note:
This is also used to draw normal 2-way edges in a fancy way.
Args:
axis: The axis to draw the hyper-edges on.
pos: The positions of the nodes.
graph: The graph to draw the hyper-edges on.
hyper_edges: The hyper-edges to draw.
"""
for hyper_edge in hyper_edges:
# store all paths for the hyper-edge to combine them later
all_paths = []
# make also normal (2-way) edges plottable -> one node becomes the "center" node
is_hyper_edge = True
if len(hyper_edge) == 2:
u, v = hyper_edge
center_pos = pos[v]
node_size = graph[u][v]["size"]
color = graph[u][v]["color"]
alpha = graph[u][v]["alpha"]
is_hyper_edge = False
else: # a hyper-edge encodes its information in an artificial "center" node
center_pos = pos[hyper_edge]
# TODO(advueu963): Technically there is not guarantee it is not sure that the hyper_edge must exist # noqa: TD003
node_size = graph.nodes.get(hyper_edge)["size"] # pyright: ignore[reportOptionalSubscript]
color = graph.nodes.get(hyper_edge)["color"] # pyright: ignore[reportOptionalSubscript]
alpha = graph.nodes.get(hyper_edge)["alpha"] # pyright: ignore[reportOptionalSubscript]
alpha = min(1.0, max(0.0, alpha))
# draw the connection point of the hyper-edge
circle = mpath.Path.circle(center_pos, radius=node_size / 2)
all_paths.append(circle)
axis.scatter(center_pos[0], center_pos[1], s=0, c="none", lw=0) # add empty point for limit
# draw the fancy connections from the other nodes to the center node
for player in hyper_edge:
player_pos = pos[player]
circle_p = mpath.Path.circle(player_pos, radius=node_size / 2)
all_paths.append(circle_p)
axis.scatter(player_pos[0], player_pos[1], s=0, c="none", lw=0) # for axis limits
# get the direction of the connection
direction = (center_pos[0] - player_pos[0], center_pos[1] - player_pos[1])
direction = np.array(direction) / np.linalg.norm(direction)
# get 90 degree of the direction
direction_90 = np.array([-direction[1], direction[0]])
# get the distance between the player and the center node
distance = np.linalg.norm(center_pos - player_pos)
# get the position of the start and end of the connection
start_pos = player_pos - direction_90 * (node_size / 2)
middle_pos = player_pos + direction * distance / 2
end_pos_one = center_pos - direction_90 * (node_size / 2)
end_pos_two = center_pos + direction_90 * (node_size / 2)
start_pos_two = player_pos + direction_90 * (node_size / 2)
# create the connection
connection = mpath.Path(
[
start_pos,
middle_pos,
end_pos_one,
end_pos_two,
middle_pos,
start_pos_two,
start_pos,
],
[
mpath.Path.MOVETO,
mpath.Path.CURVE3,
mpath.Path.CURVE3,
mpath.Path.LINETO,
mpath.Path.CURVE3,
mpath.Path.CURVE3,
mpath.Path.LINETO,
],
)
# add the connection to the list of all paths
all_paths.append(connection)
# break after the first hyper-edge if there are only two players
if not is_hyper_edge:
break
# combine all paths into one patch
combined_path = mpath.Path.make_compound_path(*all_paths)
patch = mpatches.PathPatch(combined_path, facecolor=color, lw=0, alpha=alpha)
axis.add_patch(patch)
def _draw_graph_nodes(
ax: Axes,
pos: dict,
graph: nx.Graph,
nodes: list | None = None,
normal_node_size: float = NORMAL_NODE_SIZE,
) -> None:
"""Draws the nodes of the graph as circles with a fixed size.
Args:
ax: The axis to draw the nodes on.
pos: The positions of the nodes.
graph: The graph to draw the nodes on.
nodes: The nodes to draw. If ``None``, all nodes are drawn. Defaults to ``None``.
normal_node_size: The size of the nodes. Defaults to ``NORMAL_NODE_SIZE``.
"""
for node in graph.nodes:
if nodes is not None and node not in nodes:
continue
position = pos[node]
circle = mpath.Path.circle(position, radius=normal_node_size / 2)
patch = mpatches.PathPatch(circle, facecolor="white", lw=1, alpha=1, edgecolor="black")
ax.add_patch(patch)
# add empty scatter for the axis to adjust the limits later
ax.scatter(position[0], position[1], s=0, c="none", lw=0)
def _draw_explanation_nodes(
ax: Axes,
pos: dict,
graph: nx.Graph,
nodes: list | None = None,
normal_node_size: float = NORMAL_NODE_SIZE,
) -> None:
"""Adds the node level explanations to the graph as circles with varying sizes.
Args:
ax: The axis to draw the nodes on.
pos: The positions of the nodes.
graph: The graph to draw the nodes on.
nodes: The nodes to draw. If ``None``, all nodes are drawn. Defaults to ``None``.
normal_node_size: The size of the nodes. Defaults to ``NORMAL_NODE_SIZE``.
"""
for node in graph.nodes:
if isinstance(node, tuple):
continue
if nodes is not None and node not in nodes:
continue
position = pos[node]
# TODO(advueu963): Statically it seems to be not clear that we get an object which is subscriptable # noqa: TD003
color = graph.nodes.get(node)["color"] # pyright: ignore[reportOptionalSubscript]
explanation_size = graph.nodes.get(node)["size"] # pyright: ignore[reportOptionalSubscript]
alpha = 1.0
if ADJUST_NODE_ALPHA:
alpha = graph.nodes.get(node)["alpha"] # pyright: ignore[reportOptionalSubscript]
alpha = min(1.0, max(0.0, alpha))
radius = normal_node_size / 2 + explanation_size / 2
circle = mpath.Path.circle(position, radius=radius)
patch = mpatches.PathPatch(circle, facecolor="white", lw=1, edgecolor="white", alpha=1.0)
ax.add_patch(patch)
patch = mpatches.PathPatch(circle, facecolor=color, lw=1, edgecolor="white", alpha=alpha)
ax.add_patch(patch)
ax.scatter(position[0], position[1], s=0, c="none", lw=0) # add empty point for limits
def _draw_graph_edges(
ax: Axes,
pos: dict,
graph: nx.Graph,
edges: list[tuple] | None = None,
normal_node_size: float = NORMAL_NODE_SIZE,
) -> None:
"""Draws black lines between the nodes.
Args:
ax: The axis to draw the edges on.
pos: The positions of the nodes.
graph: The graph to draw the edges on.
edges: The edges to draw. If ``None`` (default), all edges are drawn.
normal_node_size: The size of the nodes. Defaults to ``NORMAL_NODE_SIZE``.
"""
for u, v in graph.edges:
if edges is not None and (u, v) not in edges and (v, u) not in edges:
continue
u_pos = pos[u]
v_pos = pos[v]
direction = v_pos - u_pos
direction = direction / np.linalg.norm(direction)
start_point = u_pos + direction * normal_node_size / 2
end_point = v_pos - direction * normal_node_size / 2
connection = mpath.Path(
[start_point, end_point],
[mpath.Path.MOVETO, mpath.Path.LINETO],
)
patch = mpatches.PathPatch(connection, facecolor="none", lw=1, edgecolor="black")
ax.add_patch(patch)
def _draw_graph_labels(
ax: Axes,
pos: dict,
graph: nx.Graph,
*,
nodes: list | None = None,
normal_node_size: float = 1.0,
plot_white_nodes: bool = False,
) -> None:
"""Adds labels to the nodes of the graph.
Args:
ax: The axis to draw the labels on.
pos: The positions of the nodes.
graph: The graph to draw the labels on.
nodes: The nodes to draw the labels on. If ``None`` (default), all nodes are drawn.
normal_node_size: The size of the nodes. Defaults to ``1.0``.
plot_white_nodes: If set to ``True``, the nodes are drawn as white circles with the label
inside. If set to ``False``, the labels are drawn next to the nodes. Defaults to
``False``.
"""
for node in graph.nodes:
if nodes is not None and node not in nodes:
continue
label = graph.nodes.get(node)["label"] # pyright: ignore[reportOptionalSubscript]
position = pos[node]
if plot_white_nodes:
offset = (0, 0)
else:
# offset so the text is next to the node
offset_norm = np.sqrt(position[0] ** 2 + position[1] ** 2)
offset = (
(LABEL_OFFSET + normal_node_size) * position[0] / offset_norm,
(LABEL_OFFSET + normal_node_size) * position[1] / offset_norm,
)
ax.text(
position[0] + offset[0],
position[1] + offset[1],
label,
fontsize=plt.rcParams["font.size"] + 1,
ha="center",
va="center",
color="black",
)
def _draw_feature_images(
ax: Axes,
pos: dict,
graph: nx.Graph,
feature_image_patches: dict[int, Image] | list[Image],
patch_size: float,
) -> None:
"""Draws the feature images.
Args:
ax: The axis to draw the edges on.
pos: The positions of the nodes.
graph: The graph to draw the edges on.
feature_image_patches: a dict that stores the images for the players
patch_size: The size of the feature images.
"""
x_min, x_max = ax.get_xlim()
img_scale = x_max - x_min
extend = img_scale * patch_size / 2
for node in graph.nodes:
if node < len(feature_image_patches):
image = feature_image_patches[node]
x, y = pos[node]
offset_norm = np.sqrt(x**2 + y**2)
# 1.55 -> bit more than sqrt(2) to position the middle of the image
offset = (
1.55 * patch_size * x / offset_norm,
1.55 * patch_size * y / offset_norm,
)
# x and y are the middle of the image
x, y = x + offset[0], y + offset[1]
ax.imshow(image, extent=(x - extend, x + extend, y - extend, y + extend))
# set the plot to show the whole graph
x_min -= img_scale * patch_size
x_max += img_scale * patch_size
ax.set_xlim(x_min, x_max)
ax.set_ylim(x_min, x_max)
def _adjust_position(
pos: dict,
graph: nx.Graph,
normal_node_size: float = NORMAL_NODE_SIZE,
) -> dict:
"""Moves the nodes in the graph further apart if they are too close together."""
# get the minimum distance between two nodes
min_distance = 1e10
for u, v in graph.edges:
distance = np.linalg.norm(pos[u] - pos[v]).item()
min_distance = min(min_distance, distance)
# adjust the positions if the nodes are too close together
min_edge_distance = normal_node_size + normal_node_size / 2
if min_distance < min_edge_distance:
for node, position in pos.items():
pos[node] = position * min_edge_distance / min_distance
return pos