From 66e05974f8337a2cea8033d59eb644ed7dd42611 Mon Sep 17 00:00:00 2001 From: Thierry Martinez Date: Mon, 6 Oct 2025 17:37:37 +0200 Subject: [PATCH 1/4] Enable type-checking for `visualization.py` This commit adds type annotations to `visualization.py`, along with its corresponding tests, enabling full type-checking with mypy while preserving existing functionality. **Related issue:** This PR continues the work started in #302, #308, and #312. --- examples/visualization.py | 2 +- graphix/generator.py | 2 +- graphix/gflow.py | 2 +- graphix/pattern.py | 4 +- graphix/visualization.py | 414 ++++++++++++++++-------------------- pyproject.toml | 6 - tests/test_visualization.py | 4 +- 7 files changed, 196 insertions(+), 238 deletions(-) diff --git a/examples/visualization.py b/examples/visualization.py index 8bc981247..8721f5492 100644 --- a/examples/visualization.py +++ b/examples/visualization.py @@ -73,7 +73,7 @@ edges = [(1, 4), (1, 6), (2, 4), (2, 5), (2, 6), (3, 5), (3, 6)] inputs = {1, 2, 3} outputs = {4, 5, 6} -graph = nx.Graph() +graph: nx.Graph[int] = nx.Graph() graph.add_nodes_from(nodes) graph.add_edges_from(edges) meas_planes = {1: Plane.XY, 2: Plane.XY, 3: Plane.XY} diff --git a/graphix/generator.py b/graphix/generator.py index 815488c2f..39cf4b508 100644 --- a/graphix/generator.py +++ b/graphix/generator.py @@ -79,7 +79,7 @@ def generate_from_graph( # search for flow first f, l_k = graphix.gflow.find_flow(graph, inputs_set, outputs_set, meas_planes=meas_planes) - if f is not None: + if f is not None and l_k is not None: # flow found pattern = _flow2pattern(graph, angles, inputs, f, l_k) pattern.reorder_output_nodes(outputs) diff --git a/graphix/gflow.py b/graphix/gflow.py index 5faae477c..5f19bb939 100644 --- a/graphix/gflow.py +++ b/graphix/gflow.py @@ -102,7 +102,7 @@ def find_flow( iset: set[int], oset: set[int], meas_planes: dict[int, Plane] | None = None, -) -> tuple[dict[int, set[int]], dict[int, int]]: +) -> tuple[dict[int, set[int]], dict[int, int]] | tuple[None, None]: """Causal flow finding algorithm. For open graph g with input, output, and measurement planes, this returns causal flow. diff --git a/graphix/pattern.py b/graphix/pattern.py index 2c96d77b2..cd63cfdab 100644 --- a/graphix/pattern.py +++ b/graphix/pattern.py @@ -998,7 +998,7 @@ def get_measurement_order_from_flow(self) -> list[int] | None: vout = set(self.output_nodes) meas_planes = self.get_meas_plane() f, l_k = find_flow(graph, vin, vout, meas_planes=meas_planes) - if f is None: + if f is None or l_k is None: return None depth, layer = get_layers(l_k) meas_order: list[int] = [] @@ -1389,7 +1389,7 @@ def draw_graph( show_local_clifford: bool = False, show_measurement_planes: bool = False, show_loop: bool = True, - node_distance: tuple[int, int] = (1, 1), + node_distance: tuple[float, float] = (1, 1), figsize: tuple[int, int] | None = None, save: bool = False, filename: str | None = None, diff --git a/graphix/visualization.py b/graphix/visualization.py index 3b835473a..0a00969b3 100644 --- a/graphix/visualization.py +++ b/graphix/visualization.py @@ -4,7 +4,7 @@ import math from copy import deepcopy -from typing import TYPE_CHECKING, SupportsFloat +from typing import TYPE_CHECKING import networkx as nx import numpy as np @@ -12,13 +12,24 @@ from graphix import gflow from graphix.fundamentals import Plane +from graphix.measurements import PauliMeasurement if TYPE_CHECKING: + from collections.abc import Collection, Hashable, Iterable, Mapping, Sequence + from typing import TypeAlias, TypeVar + + import numpy.typing as npt + # MEMO: Potential circular import from graphix.clifford import Clifford from graphix.parameter import ExpressionOrFloat from graphix.pattern import Pattern + _Edge: TypeAlias = tuple[int, int] + _Point: TypeAlias = tuple[float, float] + + _HashableT = TypeVar("_HashableT", bound=Hashable) # reusable node type variable + class GraphVisualizer: """A class for visualizing MBQC graphs with flow or gflow structure. @@ -42,12 +53,12 @@ class GraphVisualizer: def __init__( self, - g: nx.Graph, - v_in: list[int], - v_out: list[int], - meas_plane: dict[int, Plane] | None = None, - meas_angles: dict[int, ExpressionOrFloat] | None = None, - local_clifford: dict[int, Clifford] | None = None, + g: nx.Graph[int], + v_in: Collection[int], + v_out: Collection[int], + meas_plane: Mapping[int, Plane] | None = None, + meas_angles: Mapping[int, ExpressionOrFloat] | None = None, + local_clifford: Mapping[int, Clifford] | None = None, ): """ Construct a graph visualizer. @@ -74,7 +85,7 @@ def __init__( if meas_plane is None: self.meas_planes = dict.fromkeys(iter(g.nodes), Plane.XY) else: - self.meas_planes = meas_plane + self.meas_planes = dict(meas_plane) self.meas_angles = meas_angles self.local_clifford = local_clifford @@ -84,11 +95,11 @@ def visualize( show_local_clifford: bool = False, show_measurement_planes: bool = False, show_loop: bool = True, - node_distance: tuple[int, int] = (1, 1), + node_distance: tuple[float, float] = (1, 1), figsize: tuple[int, int] | None = None, save: bool = False, filename: str | None = None, - ): + ) -> None: """ Visualize the graph with flow or gflow structure. @@ -117,7 +128,7 @@ def visualize( Filename of the saved plot. """ f, l_k = gflow.find_flow(self.graph, set(self.v_in), set(self.v_out), meas_planes=self.meas_planes) # try flow - if f: + if f is not None and l_k is not None: print("Flow detected in the graph.") self.visualize_w_flow( f, @@ -132,7 +143,7 @@ def visualize( ) else: g, l_k = gflow.find_gflow(self.graph, set(self.v_in), set(self.v_out), self.meas_planes) # try gflow - if g: + if g is not None and l_k is not None: print("Gflow detected in the graph. (flow not detected)") self.visualize_w_gflow( g, @@ -165,11 +176,11 @@ def visualize_from_pattern( show_local_clifford: bool = False, show_measurement_planes: bool = False, show_loop: bool = True, - node_distance: tuple[int, int] = (1, 1), + node_distance: tuple[float, float] = (1, 1), figsize: tuple[int, int] | None = None, save: bool = False, filename: str | None = None, - ): + ) -> None: """ Visualize the graph with flow or gflow structure found from the given pattern. @@ -231,12 +242,12 @@ def visualize_from_pattern( else: print("The pattern is not consistent with flow or gflow structure.") depth, layers = pattern.get_layers() - layers = {element: key for key, value_set in layers.items() for element in value_set} + unfolded_layers = {element: key for key, value_set in layers.items() for element in value_set} for output in pattern.output_nodes: - layers[output] = depth + 1 + unfolded_layers[output] = depth + 1 xflow, zflow = gflow.get_corrections_from_pattern(pattern) self.visualize_all_correction( - layers, + unfolded_layers, xflow, zflow, show_pauli_measurement, @@ -250,16 +261,16 @@ def visualize_from_pattern( def visualize_w_flow( self, - f: dict[int, set[int]], - l_k: dict[int, int], + f: Mapping[int, set[int]], + l_k: Mapping[int, int], show_pauli_measurement: bool = True, show_local_clifford: bool = False, show_measurement_planes: bool = False, - node_distance: tuple[int, int] = (1, 1), - figsize: tuple[int, int] | None = None, + node_distance: tuple[float, float] = (1, 1), + figsize: _Point | None = None, save: bool = False, filename: str | None = None, - ): + ) -> None: """ Visualizes the graph with flow structure. @@ -300,24 +311,17 @@ def visualize_w_flow( if len(edge_path[edge]) == 2: nx.draw_networkx_edges(self.graph, pos, edgelist=[edge], style="dashed", alpha=0.7) else: - t = np.linspace(0, 1, 100) - curve = self._bezier_curve(edge_path[edge], t) + curve = self._bezier_curve_linspace(edge_path[edge]) plt.plot(curve[:, 0], curve[:, 1], "k--", linewidth=1, alpha=0.7) - for arrow in arrow_path: - if len(arrow_path[arrow]) == 2: + for arrow, path in arrow_path.items(): + if len(path) == 2: nx.draw_networkx_edges( self.graph, pos, edgelist=[arrow], edge_color="black", arrowstyle="->", arrows=True ) else: - path = arrow_path[arrow] - last = np.array(path[-1]) - second_last = np.array(path[-2]) - path[-1] = list( - last - (last - second_last) / np.linalg.norm(last - second_last) * 0.2 - ) # Shorten the last edge not to hide arrow under the node - t = np.linspace(0, 1, 100) - curve = self._bezier_curve(path, t) + GraphVisualizer._shorten_path(path) + curve = self._bezier_curve_linspace(path) plt.plot(curve[:, 0], curve[:, 1], c="k", linewidth=1) plt.annotate( @@ -329,21 +333,13 @@ def visualize_w_flow( self.__draw_nodes_role(pos, show_pauli_measurement) - if show_local_clifford and self.local_clifford is not None: - for node in self.graph.nodes(): - if node in self.local_clifford: - plt.text(*pos[node] + np.array([0.2, 0.2]), f"{self.local_clifford[node]}", fontsize=10, zorder=3) + if show_local_clifford: + self.__draw_local_clifford(pos) if show_measurement_planes: - for node in self.graph.nodes(): - if node in self.meas_planes: - plt.text(*pos[node] + np.array([0.22, -0.2]), f"{self.meas_planes[node]}", fontsize=9, zorder=3) + self.__draw_measurement_planes(pos) - # Draw the labels - fontsize = 12 - if max(self.graph.nodes()) >= 100: - fontsize = fontsize * 2 / len(str(max(self.graph.nodes()))) - nx.draw_networkx_labels(self.graph, pos, font_size=fontsize) + self._draw_labels(pos) x_min = min(pos[node][0] for node in self.graph.nodes()) # Get the minimum x coordinate x_max = max(pos[node][0] for node in self.graph.nodes()) # Get the maximum x coordinate @@ -368,13 +364,27 @@ def visualize_w_flow( plt.savefig(filename) plt.show() - def __draw_nodes_role(self, pos: dict[int, tuple[float, float]], show_pauli_measurement: bool = False) -> None: + @staticmethod + def _shorten_path(path: list[_Point]) -> None: + """Shorten the last edge not to hide arrow under the node.""" + last = np.array(path[-1]) + second_last = np.array(path[-2]) + last_edge: _Point = tuple(last - (last - second_last) / np.linalg.norm(last - second_last) * 0.2) + path[-1] = last_edge + + def _draw_labels(self, pos: Mapping[int, _Point]) -> None: + fontsize = 12 + if max(self.graph.nodes()) >= 100: + fontsize = int(fontsize * 2 / len(str(max(self.graph.nodes())))) + nx.draw_networkx_labels(self.graph, pos, font_size=fontsize) + + def __draw_nodes_role(self, pos: Mapping[int, _Point], show_pauli_measurement: bool = False) -> None: """ Draw the nodes with different colors based on their role (input, output, or other). Parameters ---------- - pos : dict[int, tuple[float, float]] + pos : Mapping[int, tuple[float, float]] dictionary of node positions. show_pauli_measurement : bool If True, the nodes with Pauli measurement angles are colored light blue. @@ -388,10 +398,8 @@ def __draw_nodes_role(self, pos: dict[int, tuple[float, float]], show_pauli_meas inner_color = "lightgray" elif ( show_pauli_measurement - and isinstance(self.meas_angles, SupportsFloat) - and ( - 2 * self.meas_angles[node] == int(2 * self.meas_angles[node]) - ) # measurement angle is integer or half-integer + and self.meas_angles is not None + and PauliMeasurement.try_from(Plane.XY, self.meas_angles[node]) is not None ): inner_color = "lightblue" plt.scatter( @@ -400,17 +408,17 @@ def __draw_nodes_role(self, pos: dict[int, tuple[float, float]], show_pauli_meas def visualize_w_gflow( self, - g: dict[int, set[int]], - l_k: dict[int, int], + g: Mapping[int, set[int]], + l_k: Mapping[int, int], show_pauli_measurement: bool = True, show_local_clifford: bool = False, show_measurement_planes: bool = False, show_loop: bool = True, - node_distance: tuple[int, int] = (1, 1), - figsize: tuple[int, int] | None = None, + node_distance: tuple[float, float] = (1, 1), + figsize: _Point | None = None, save: bool = False, filename: str | None = None, - ): + ) -> None: """ Visualizes the graph with flow structure. @@ -455,15 +463,13 @@ def visualize_w_gflow( if len(edge_path[edge]) == 2: nx.draw_networkx_edges(self.graph, pos, edgelist=[edge], style="dashed", alpha=0.7) else: - t = np.linspace(0, 1, 100) - curve = self._bezier_curve(edge_path[edge], t) + curve = self._bezier_curve_linspace(edge_path[edge]) plt.plot(curve[:, 0], curve[:, 1], "k--", linewidth=1, alpha=0.7) - for arrow in arrow_path: + for arrow in arrow_path.values(): if arrow[0] == arrow[1]: # self loop if show_loop: - t = np.linspace(0, 1, 100) - curve = self._bezier_curve(arrow_path[arrow], t) + curve = self._bezier_curve_linspace(arrow) plt.plot(curve[:, 0], curve[:, 1], c="k", linewidth=1) plt.annotate( "", @@ -471,19 +477,13 @@ def visualize_w_gflow( xytext=curve[-2], arrowprops={"arrowstyle": "->", "color": "k", "lw": 1}, ) - elif len(arrow_path[arrow]) == 2: # straight line + elif len(arrow) == 2: # straight line nx.draw_networkx_edges( self.graph, pos, edgelist=[arrow], edge_color="black", arrowstyle="->", arrows=True ) else: - path = arrow_path[arrow] - last = np.array(path[-1]) - second_last = np.array(path[-2]) - path[-1] = list( - last - (last - second_last) / np.linalg.norm(last - second_last) * 0.2 - ) # Shorten the last edge not to hide arrow under the node - t = np.linspace(0, 1, 100) - curve = self._bezier_curve(path, t) + GraphVisualizer._shorten_path(arrow) + curve = self._bezier_curve_linspace(arrow) plt.plot(curve[:, 0], curve[:, 1], c="k", linewidth=1) plt.annotate( @@ -495,21 +495,13 @@ def visualize_w_gflow( self.__draw_nodes_role(pos, show_pauli_measurement) - if show_local_clifford and self.local_clifford is not None: - for node in self.graph.nodes(): - if node in self.local_clifford: - plt.text(*pos[node] + np.array([0.2, 0.2]), f"{self.local_clifford[node]}", fontsize=10, zorder=3) + if show_local_clifford: + self.__draw_local_clifford(pos) if show_measurement_planes: - for node in self.graph.nodes(): - if node in self.meas_planes: - plt.text(*pos[node] + np.array([0.22, -0.2]), f"{self.meas_planes[node]}", fontsize=9, zorder=3) + self.__draw_measurement_planes(pos) - # Draw the labels - fontsize = 12 - if max(self.graph.nodes()) >= 100: - fontsize = fontsize * 2 / len(str(max(self.graph.nodes()))) - nx.draw_networkx_labels(self.graph, pos, font_size=fontsize) + self._draw_labels(pos) x_min = min(pos[node][0] for node in self.graph.nodes()) # Get the minimum x coordinate x_max = max(pos[node][0] for node in self.graph.nodes()) # Get the maximum x coordinate @@ -534,16 +526,29 @@ def visualize_w_gflow( plt.savefig(filename) plt.show() + def __draw_local_clifford(self, pos: Mapping[int, _Point]) -> None: + if self.local_clifford is not None: + for node in self.graph.nodes(): + if node in self.local_clifford: + x, y = pos[node] + np.array([0.2, 0.2]) + plt.text(x, y, f"{self.local_clifford[node]}", fontsize=10, zorder=3) + + def __draw_measurement_planes(self, pos: Mapping[int, _Point]) -> None: + for node in self.graph.nodes(): + if node in self.meas_planes: + x, y = pos[node] + np.array([0.22, -0.2]) + plt.text(x, y, f"{self.meas_planes[node]}", fontsize=9, zorder=3) + def visualize_wo_structure( self, show_pauli_measurement: bool = True, show_local_clifford: bool = False, show_measurement_planes: bool = False, - node_distance: tuple[int, int] = (1, 1), - figsize: tuple[int, int] | None = None, + node_distance: tuple[float, float] = (1, 1), + figsize: _Point | None = None, save: bool = False, filename: str | None = None, - ): + ) -> None: """ Visualizes the graph without flow or gflow. @@ -582,27 +587,18 @@ def visualize_wo_structure( if len(edge_path[edge]) == 2: nx.draw_networkx_edges(self.graph, pos, edgelist=[edge], style="dashed", alpha=0.7) else: - t = np.linspace(0, 1, 100) - curve = self._bezier_curve(edge_path[edge], t) + curve = self._bezier_curve_linspace(edge_path[edge]) plt.plot(curve[:, 0], curve[:, 1], "k--", linewidth=1, alpha=0.7) self.__draw_nodes_role(pos, show_pauli_measurement) - if show_local_clifford and self.local_clifford is not None: - for node in self.graph.nodes(): - if node in self.local_clifford: - plt.text(*pos[node] + np.array([0.2, 0.2]), f"{self.local_clifford[node]}", fontsize=10, zorder=3) + if show_local_clifford: + self.__draw_local_clifford(pos) if show_measurement_planes: - for node in self.graph.nodes(): - if node in self.meas_planes: - plt.text(*pos[node] + np.array([0.22, -0.2]), f"{self.meas_planes[node]}", fontsize=9, zorder=3) + self.__draw_measurement_planes(pos) - # Draw the labels - fontsize = 12 - if max(self.graph.nodes()) >= 100: - fontsize = fontsize * 2 / len(str(max(self.graph.nodes()))) - nx.draw_networkx_labels(self.graph, pos, font_size=fontsize) + self._draw_labels(pos) x_min = min(pos[node][0] for node in self.graph.nodes()) # Get the minimum x coordinate x_max = max(pos[node][0] for node in self.graph.nodes()) # Get the maximum x coordinate @@ -620,17 +616,17 @@ def visualize_wo_structure( def visualize_all_correction( self, - layers: dict[int, int], - xflow: dict[int, set[int]], - zflow: dict[int, set[int]], + layers: Mapping[int, int], + xflow: Mapping[int, set[int]], + zflow: Mapping[int, set[int]], show_pauli_measurement: bool = True, show_local_clifford: bool = False, show_measurement_planes: bool = False, - node_distance: tuple[int, int] = (1, 1), - figsize: tuple[int, int] | None = None, + node_distance: tuple[float, float] = (1, 1), + figsize: _Point | None = None, save: bool = False, filename: str | None = None, - ): + ) -> None: """ Visualizes the graph of pattern with all correction flows. @@ -669,7 +665,7 @@ def visualize_all_correction( figsize = (figsize[0] + 3.0, figsize[1]) plt.figure(figsize=figsize) - xzflow = {} + xzflow: dict[int, set[int]] = {} for key, value in deepcopy(xflow).items(): if key in xzflow: xzflow[key] |= value @@ -686,8 +682,7 @@ def visualize_all_correction( if len(edge_path[edge]) == 2: nx.draw_networkx_edges(self.graph, pos, edgelist=[edge], style="dashed", alpha=0.7) else: - t = np.linspace(0, 1, 100) - curve = self._bezier_curve(edge_path[edge], t) + curve = self._bezier_curve_linspace(edge_path[edge]) plt.plot(curve[:, 0], curve[:, 1], "k--", linewidth=1, alpha=0.7) for arrow in arrow_path: if arrow[1] not in xflow.get(arrow[0], set()): @@ -702,14 +697,8 @@ def visualize_all_correction( ) else: path = arrow_path[arrow] - last = np.array(path[-1]) - second_last = np.array(path[-2]) - path[-1] = list( - last - (last - second_last) / np.linalg.norm(last - second_last) * 0.2 - ) # Shorten the last edge not to hide arrow under the node - - t = np.linspace(0, 1, 100) - curve = self._bezier_curve(path, t) + GraphVisualizer._shorten_path(path) + curve = self._bezier_curve_linspace(path) plt.plot(curve[:, 0], curve[:, 1], c=color, linewidth=1) plt.annotate( @@ -721,21 +710,13 @@ def visualize_all_correction( self.__draw_nodes_role(pos, show_pauli_measurement) - if show_local_clifford and self.local_clifford is not None: - for node in self.graph.nodes(): - if node in self.local_clifford: - plt.text(*pos[node] + np.array([0.2, 0.2]), f"{self.local_clifford[node]}", fontsize=10, zorder=3) + if show_local_clifford: + self.__draw_local_clifford(pos) if show_measurement_planes: - for node in self.graph.nodes(): - if node in self.meas_planes: - plt.text(*pos[node] + np.array([0.22, -0.2]), f"{self.meas_planes[node]}", fontsize=9, zorder=3) + self.__draw_measurement_planes(pos) - # Draw the labels - fontsize = 12 - if max(self.graph.nodes()) >= 100: - fontsize = fontsize * 2 / len(str(max(self.graph.nodes()))) - nx.draw_networkx_labels(self.graph, pos, font_size=fontsize) + self._draw_labels(pos) # legend for arrow colors plt.plot([], [], "k--", alpha=0.7, label="graph edge") @@ -761,10 +742,10 @@ def visualize_all_correction( def get_figsize( self, - l_k: dict[int, int], - pos: dict[int, tuple[float, float]] | None = None, - node_distance: tuple[int, int] = (1, 1), - ) -> tuple[int, int]: + l_k: Mapping[int, int] | None, + pos: Mapping[int, _Point] | None = None, + node_distance: tuple[float, float] = (1, 1), + ) -> _Point: """ Return the figure size of the graph. @@ -783,13 +764,17 @@ def get_figsize( figure size of the graph. """ if l_k is None: + if pos is None: + raise ValueError("l_k and pos cannot be both None") width = len({pos[node][0] for node in self.graph.nodes()}) * 0.8 else: width = (max(l_k.values()) + 1) * 0.8 height = len({pos[node][1] for node in self.graph.nodes()}) if pos is not None else len(self.v_out) return (width * node_distance[0], height * node_distance[1]) - def get_edge_path(self, flow: dict[int, int | set[int]], pos: dict[int, tuple[float, float]]) -> dict[int, list]: + def get_edge_path( + self, flow: Mapping[int, int | set[int]], pos: Mapping[int, _Point] + ) -> tuple[dict[int, list[_Point]], dict[_Edge, list[_Point]]]: """ Return the path of edges and gflow arrows. @@ -808,49 +793,15 @@ def get_edge_path(self, flow: dict[int, int | set[int]], pos: dict[int, tuple[fl dictionary of arrow paths. """ max_iter = 5 - edge_path = {} - arrow_path = {} + edge_path = self.get_edge_path_wo_structure(pos) edge_set = set(self.graph.edges()) - flow_arrows = {(k, v) for k, values in flow.items() for v in values} - # set of mid-points of the edges - # mid_points = {(0.5 * (pos[k][0] + pos[v][0]), 0.5 * (pos[k][1] + pos[v][1])) for k, v in edge_set} - set(pos[node] for node in self.g.nodes()) - - for edge in edge_set: - iteration = 0 - nodes = self.graph.nodes() - bezier_path = [pos[edge[0]], pos[edge[1]]] - while True: - iteration += 1 - intersect = False - if iteration > max_iter: - break - ctrl_points = [] - for i in range(len(bezier_path) - 1): - start = bezier_path[i] - end = bezier_path[i + 1] - for node in nodes: - if node != edge[0] and node != edge[1] and self._edge_intersects_node(start, end, pos[node]): - intersect = True - ctrl_points.append( - [ - i, - self._control_point( - bezier_path[0], bezier_path[-1], pos[node], distance=0.6 / iteration - ), - ] - ) - nodes = set(nodes) - {node} - if not intersect: - break - for i, ctrl_point in enumerate(ctrl_points): - bezier_path.insert(ctrl_point[0] + i + 1, ctrl_point[1]) - bezier_path = self._check_path(bezier_path) - edge_path[edge] = bezier_path + arrow_path: dict[_Edge, list[_Point]] = {} + flow_arrows = {(k, v) for k, values in flow.items() for v in ((values,) if isinstance(values, int) else values)} for arrow in flow_arrows: if arrow[0] == arrow[1]: # Self loop - def _point_from_node(pos, dist, angle): + def _point_from_node(pos: Sequence[float], dist: float, angle: float) -> _Point: """Return a point at a given distance and angle from ``pos``. Parameters @@ -865,11 +816,11 @@ def _point_from_node(pos, dist, angle): Returns ------- - list[float] + _Point The new ``[x, y]`` coordinate. """ angle = np.deg2rad(angle) - return [pos[0] + dist * np.cos(angle), pos[1] + dist * np.sin(angle)] + return (pos[0] + dist * np.cos(angle), pos[1] + dist * np.sin(angle)) bezier_path = [ _point_from_node(pos[arrow[0]], 0.2, 170), @@ -882,7 +833,7 @@ def _point_from_node(pos, dist, angle): ] else: iteration = 0 - nodes = self.graph.nodes() + nodes = set(self.graph.nodes()) bezier_path = [pos[arrow[0]], pos[arrow[1]]] if arrow in edge_set or (arrow[1], arrow[0]) in edge_set: mid_point = ( @@ -909,21 +860,21 @@ def _point_from_node(pos, dist, angle): ): intersect = True ctrl_points.append( - [ + ( i, self._control_point(start, end, pos[node], distance=0.6 / iteration), - ] + ) ) if not intersect: break - for i, ctrl_point in enumerate(ctrl_points): - bezier_path.insert(ctrl_point[0] + i + 1, ctrl_point[1]) + for i, (index, ctrl_point) in enumerate(ctrl_points): + bezier_path.insert(index + i + 1, ctrl_point) bezier_path = self._check_path(bezier_path, pos[arrow[1]]) arrow_path[arrow] = bezier_path return edge_path, arrow_path - def get_edge_path_wo_structure(self, pos: dict[int, tuple[float, float]]) -> dict[int, list]: + def get_edge_path_wo_structure(self, pos: Mapping[int, _Point]) -> dict[int, list[_Point]]: """ Return the path of edges. @@ -938,33 +889,33 @@ def get_edge_path_wo_structure(self, pos: dict[int, tuple[float, float]]) -> dic dictionary of edge paths. """ max_iter = 5 - edge_path = {} + edge_path: dict[int, list[_Point]] = {} edge_set = set(self.graph.edges()) for edge in edge_set: iteration = 0 - nodes = self.graph.nodes() + nodes = set(self.graph.nodes()) bezier_path = [pos[edge[0]], pos[edge[1]]] while True: iteration += 1 intersect = False if iteration > max_iter: break - ctrl_points = [] + ctrl_points: list[tuple[int, _Point]] = [] for i in range(len(bezier_path) - 1): start = bezier_path[i] end = bezier_path[i + 1] - for node in nodes: + for node in list(nodes): if node != edge[0] and node != edge[1] and self._edge_intersects_node(start, end, pos[node]): intersect = True ctrl_points.append( - [ + ( i, self._control_point( bezier_path[0], bezier_path[-1], pos[node], distance=0.6 / iteration ), - ] + ) ) - nodes = set(nodes) - {node} + nodes -= {node} if not intersect: break for i, ctrl_point in enumerate(ctrl_points): @@ -973,7 +924,7 @@ def get_edge_path_wo_structure(self, pos: dict[int, tuple[float, float]]) -> dic edge_path[edge] = bezier_path return edge_path - def get_pos_from_flow(self, f: dict[int, int], l_k: dict[int, int]) -> dict[int, tuple[float, float]]: + def get_pos_from_flow(self, f: Mapping[int, set[int]], l_k: Mapping[int, int]) -> dict[int, _Point]: """ Return the position of nodes based on the flow. @@ -990,7 +941,7 @@ def get_pos_from_flow(self, f: dict[int, int], l_k: dict[int, int]) -> dict[int, dictionary of node positions. """ values_union = set().union(*f.values()) - start_nodes = self.graph.nodes() - values_union + start_nodes = set(self.graph.nodes()) - values_union pos = {node: [0, 0] for node in self.graph.nodes()} for i, k in enumerate(start_nodes): pos[k][1] = i @@ -1003,9 +954,9 @@ def get_pos_from_flow(self, f: dict[int, int], l_k: dict[int, int]) -> dict[int, # Change the x coordinates of the nodes based on their layer, sort in descending order for node, layer in l_k.items(): pos[node][0] = lmax - layer - return {k: tuple(v) for k, v in pos.items()} + return {k: (x, y) for k, (x, y) in pos.items()} - def get_pos_from_gflow(self, g: dict[int, set[int]], l_k: dict[int, int]) -> dict[int, tuple[float, float]]: + def get_pos_from_gflow(self, g: Mapping[int, set[int]], l_k: Mapping[int, int]) -> dict[int, _Point]: """ Return the position of nodes based on the gflow. @@ -1021,7 +972,7 @@ def get_pos_from_gflow(self, g: dict[int, set[int]], l_k: dict[int, int]) -> dic pos : dict dictionary of node positions. """ - g_edges = [] + g_edges: list[_Edge] = [] for node, node_list in g.items(): g_edges.extend((node, n) for n in node_list) @@ -1033,7 +984,7 @@ def get_pos_from_gflow(self, g: dict[int, set[int]], l_k: dict[int, int]) -> dic l_max = max(l_k.values()) l_reverse = {v: l_max - l for v, l in l_k.items()} - nx.set_node_attributes(g_prime, l_reverse, "subset") + _set_node_attributes(g_prime, l_reverse, "subset") pos = nx.multipartite_layout(g_prime) @@ -1047,7 +998,7 @@ def get_pos_from_gflow(self, g: dict[int, set[int]], l_k: dict[int, int]) -> dic return pos - def get_pos_wo_structure(self) -> dict[int, tuple[float, float]]: + def get_pos_wo_structure(self) -> dict[int, _Point]: """ Return the position of nodes based on the graph. @@ -1061,12 +1012,12 @@ def get_pos_wo_structure(self) -> dict[int, tuple[float, float]]: pos : dict dictionary of node positions. """ - layers = {} + layers: dict[int, int] = {} connected_components = list(nx.connected_components(self.graph)) for component in connected_components: subgraph = self.graph.subgraph(component) - initial_pos = dict.fromkeys(component, (0, 0)) + initial_pos: dict[int, tuple[int, int]] = dict.fromkeys(component, (0, 0)) if len(set(self.v_out) & set(component)) == 0 and len(set(self.v_in) & set(component)) == 0: pos = nx.spring_layout(subgraph) @@ -1128,7 +1079,7 @@ def get_pos_wo_structure(self) -> dict[int, tuple[float, float]]: g_prime.add_edges_from(self.graph.edges()) l_max = max(layers.values()) l_reverse = {v: l_max - l for v, l in layers.items()} - nx.set_node_attributes(g_prime, l_reverse, "subset") + _set_node_attributes(g_prime, l_reverse, "subset") pos = nx.multipartite_layout(g_prime) for node, layer in layers.items(): pos[node][0] = l_max - layer @@ -1138,7 +1089,7 @@ def get_pos_wo_structure(self) -> dict[int, tuple[float, float]]: pos[node][1] = vert.index(pos[node][1]) return pos - def get_pos_all_correction(self, layers: dict[int, int]) -> dict[int, tuple[float, float]]: + def get_pos_all_correction(self, layers: Mapping[int, int]) -> dict[int, _Point]: """ Return the position of nodes based on the pattern. @@ -1155,63 +1106,75 @@ def get_pos_all_correction(self, layers: dict[int, int]) -> dict[int, tuple[floa g_prime = self.graph.copy() g_prime.add_nodes_from(self.graph.nodes()) g_prime.add_edges_from(self.graph.edges()) - nx.set_node_attributes(g_prime, layers, "subset") - pos = nx.multipartite_layout(g_prime) - for node, layer in layers.items(): - pos[node][0] = layer - vert = list({pos[node][1] for node in self.graph.nodes()}) + _set_node_attributes(g_prime, layers, "subset") + layout = nx.multipartite_layout(g_prime) + vert = list({layout[node][1] for node in self.graph.nodes()}) vert.sort() - for node in self.graph.nodes(): - pos[node][1] = vert.index(pos[node][1]) - return pos + return {node: (layers[node], vert.index(layout[node][1])) for node in self.graph.nodes()} @staticmethod - def _edge_intersects_node(start, end, node_pos, buffer=0.2): + def _edge_intersects_node( + start: _Point, + end: _Point, + node_pos: _Point, + buffer: float = 0.2, + ) -> bool: """Determine if an edge intersects a node.""" - start = np.array(start) - end = np.array(end) - if np.all(start == end): + start_array = np.array(start) + end_array = np.array(end) + if np.all(start_array == end_array): return False - node_pos = np.array(node_pos) + node_pos_array = np.array(node_pos) # Vector from start to end - line_vec = end - start + line_vec = end_array - start_array # Vector from start to node_pos - point_vec = node_pos - start + point_vec = node_pos_array - start_array t = np.dot(point_vec, line_vec) / np.dot(line_vec, line_vec) if t < 0.0 or t > 1.0: return False # Find the projection point - projection = start + t * line_vec + projection = start_array + t * line_vec distance = np.linalg.norm(projection - node_pos) - return distance < buffer + return bool(distance < buffer) @staticmethod - def _control_point(start, end, node_pos, distance=0.6): + def _control_point( + start: _Point, + end: _Point, + node_pos: _Point, + distance: float = 0.6, + ) -> _Point: """Generate a control point to bend the edge around a node.""" + node_pos_array = np.array(node_pos) edge_vector = np.asarray(end, dtype=np.float64) - np.asarray(start, dtype=np.float64) # Rotate the edge vector 90 degrees or -90 degrees according to the node position - cross = np.cross(edge_vector, np.array(node_pos) - np.array(start)) + cross = np.cross(edge_vector, node_pos_array - np.array(start)) if cross > 0: dir_vector = np.array([edge_vector[1], -edge_vector[0]]) # Rotate the edge vector 90 degrees else: dir_vector = np.array([-edge_vector[1], edge_vector[0]]) dir_vector /= np.linalg.norm(dir_vector) # Normalize the vector - control = node_pos + distance * dir_vector - return control.tolist() + u, v = node_pos_array + distance * dir_vector + return u, v @staticmethod - def _bezier_curve(bezier_path, t): + def _bezier_curve(bezier_path: Sequence[_Point], t: npt.NDArray[np.float64]) -> npt.NDArray[np.float64]: """Generate a bezier curve from a list of points.""" n = len(bezier_path) - 1 # order of the curve curve = np.zeros((len(t), 2)) for i, point in enumerate(bezier_path): - curve += np.outer(comb(n, i) * ((1 - t) ** (n - i)) * (t**i), np.array(point)) + curve += np.outer(math.comb(n, i) * ((1 - t) ** (n - i)) * (t**i), np.array(point)) return curve @staticmethod - def _check_path(path, target_node_pos=None): + def _bezier_curve_linspace(bezier_path: Sequence[_Point]) -> npt.NDArray[np.float64]: + t = np.linspace(0, 1, 100, dtype=np.float64) + return GraphVisualizer._bezier_curve(bezier_path, t) + + @staticmethod + def _check_path(path: Iterable[_Point], target_node_pos: _Point | None = None) -> list[_Point]: """If there is an acute angle in the path, merge points.""" path = np.array(path) acute = True @@ -1238,7 +1201,7 @@ def _check_path(path, target_node_pos=None): it += 1 else: acute = False - new_path = path.tolist() + new_path: list[_Point] = path.tolist() if target_node_pos is not None: for point in new_path[:-1]: if np.linalg.norm(np.array(point) - np.array(target_node_pos)) < 0.2: @@ -1246,6 +1209,5 @@ def _check_path(path, target_node_pos=None): return new_path -def comb(n, r): - """Return the binomial coefficient of n and r.""" - return math.factorial(n) // (math.factorial(n - r) * math.factorial(r)) +def _set_node_attributes(graph: nx.Graph[_HashableT], attrs: Mapping[_HashableT, object], name: str) -> None: + nx.set_node_attributes(graph, attrs, name=name) # type: ignore[arg-type] diff --git a/pyproject.toml b/pyproject.toml index 68e31b4e5..1d9de9da0 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -142,12 +142,10 @@ exclude = [ '^examples/qnn\.py$', '^examples/rotation\.py$', '^examples/tn_simulation\.py$', - '^examples/visualization\.py$', '^graphix/device_interface\.py$', '^graphix/gflow\.py$', '^graphix/linalg\.py$', '^graphix/random_objects\.py$', - '^graphix/visualization\.py$', '^tests/test_density_matrix\.py$', '^tests/test_gflow\.py$', '^tests/test_linalg\.py$', @@ -157,7 +155,6 @@ exclude = [ '^tests/test_statevec\.py$', '^tests/test_statevec_backend\.py$', '^tests/test_tnsim\.py$', - '^tests/test_visualization\.py$', ] follow_imports = "silent" follow_untyped_imports = true # required for qiskit, requires mypy >=1.14 @@ -176,12 +173,10 @@ exclude = [ "examples/qnn.py", "examples/rotation.py", "examples/tn_simulation.py", - "examples/visualization.py", "graphix/device_interface.py", "graphix/gflow.py", "graphix/linalg.py", "graphix/random_objects.py", - "graphix/visualization.py", "tests/test_density_matrix.py", "tests/test_gflow.py", "tests/test_linalg.py", @@ -192,7 +187,6 @@ exclude = [ "tests/test_statevec_backend.py", "tests/test_tnsim.py", "tests/test_transpiler.py", - "tests/test_visualization.py", ] reportUnknownArgumentType = "information" reportUnknownLambdaType = "information" diff --git a/tests/test_visualization.py b/tests/test_visualization.py index ddd52cf6b..b68e330be 100644 --- a/tests/test_visualization.py +++ b/tests/test_visualization.py @@ -3,7 +3,7 @@ from graphix import gflow, transpiler, visualization -def test_get_pos_from_flow(): +def test_get_pos_from_flow() -> None: circuit = transpiler.Circuit(1) circuit.h(0) pattern = circuit.transpile().pattern @@ -15,5 +15,7 @@ def test_get_pos_from_flow(): local_clifford = pattern.get_vops() vis = visualization.GraphVisualizer(graph, vin, vout, meas_planes, meas_angles, local_clifford) f, l_k = gflow.find_flow(graph, set(vin), set(vout), meas_planes) + assert f is not None + assert l_k is not None pos = vis.get_pos_from_flow(f, l_k) assert pos is not None From 2eca607c56cdb612fd64f5b395df28673f984f16 Mon Sep 17 00:00:00 2001 From: Thierry Martinez Date: Tue, 7 Oct 2025 11:01:48 +0200 Subject: [PATCH 2/4] Fix regression in `visualize_w_gflow` and add test --- graphix/visualization.py | 10 +++++----- tests/test_visualization.py | 23 +++++++++++++++++++++++ 2 files changed, 28 insertions(+), 5 deletions(-) diff --git a/graphix/visualization.py b/graphix/visualization.py index 0a00969b3..d7d1ad36b 100644 --- a/graphix/visualization.py +++ b/graphix/visualization.py @@ -466,10 +466,10 @@ def visualize_w_gflow( curve = self._bezier_curve_linspace(edge_path[edge]) plt.plot(curve[:, 0], curve[:, 1], "k--", linewidth=1, alpha=0.7) - for arrow in arrow_path.values(): + for arrow, path in arrow_path.items(): if arrow[0] == arrow[1]: # self loop if show_loop: - curve = self._bezier_curve_linspace(arrow) + curve = self._bezier_curve_linspace(path) plt.plot(curve[:, 0], curve[:, 1], c="k", linewidth=1) plt.annotate( "", @@ -477,13 +477,13 @@ def visualize_w_gflow( xytext=curve[-2], arrowprops={"arrowstyle": "->", "color": "k", "lw": 1}, ) - elif len(arrow) == 2: # straight line + elif len(path) == 2: # straight line nx.draw_networkx_edges( self.graph, pos, edgelist=[arrow], edge_color="black", arrowstyle="->", arrows=True ) else: - GraphVisualizer._shorten_path(arrow) - curve = self._bezier_curve_linspace(arrow) + GraphVisualizer._shorten_path(path) + curve = self._bezier_curve_linspace(path) plt.plot(curve[:, 0], curve[:, 1], c="k", linewidth=1) plt.annotate( diff --git a/tests/test_visualization.py b/tests/test_visualization.py index b68e330be..02fe23d98 100644 --- a/tests/test_visualization.py +++ b/tests/test_visualization.py @@ -1,5 +1,10 @@ from __future__ import annotations +from math import pi + +import matplotlib.pyplot as plt +import pytest + from graphix import gflow, transpiler, visualization @@ -19,3 +24,21 @@ def test_get_pos_from_flow() -> None: assert l_k is not None pos = vis.get_pos_from_flow(f, l_k) assert pos is not None + + +@pytest.fixture +def mock_plot(monkeypatch: pytest.MonkeyPatch) -> None: + monkeypatch.setattr(plt, "show", lambda: None) + + +@pytest.mark.usefixtures("mock_plot") +def test_draw_graph_flow_from_pattern() -> None: + circuit = transpiler.Circuit(3) + circuit.cnot(0, 1) + circuit.cnot(2, 1) + circuit.rx(0, pi / 3) + circuit.x(2) + circuit.cnot(2, 1) + pattern = circuit.transpile().pattern + pattern.perform_pauli_measurements(leave_input=True) + pattern.draw_graph(flow_from_pattern=True, show_measurement_planes=True, node_distance=(0.7, 0.6)) From e5411f337f0bb51b09b3cd3d87cd8d94bd49bad6 Mon Sep 17 00:00:00 2001 From: Thierry Martinez Date: Thu, 30 Oct 2025 19:52:38 +0100 Subject: [PATCH 3/4] Refactor visualization --- graphix/gflow.py | 22 +- graphix/pattern.py | 18 +- graphix/visualization.py | 682 ++++++++++-------------------------- tests/test_generator.py | 154 ++++---- tests/test_visualization.py | 122 ++++++- 5 files changed, 393 insertions(+), 605 deletions(-) diff --git a/graphix/gflow.py b/graphix/gflow.py index 5f19bb939..439c63674 100644 --- a/graphix/gflow.py +++ b/graphix/gflow.py @@ -282,7 +282,7 @@ def find_pauliflow( return pf[0], pf[1] -def flow_from_pattern(pattern: Pattern) -> tuple[dict[int, set[int]], dict[int, int]]: +def flow_from_pattern(pattern: Pattern) -> tuple[dict[int, set[int]], dict[int, int]] | tuple[None, None]: """Check if the pattern has a valid flow. If so, return the flow and layers. Parameters @@ -292,9 +292,11 @@ def flow_from_pattern(pattern: Pattern) -> tuple[dict[int, set[int]], dict[int, Returns ------- - f: dict + None, None: + The tuple ``(None, None)`` is returned if the pattern does not have a valid causal flow. + f: dict[int, set[int]] flow function. g[i] is the set of qubits to be corrected for the measurement of qubit i. - l_k: dict + l_k: dict[int, int] layers obtained by flow algorithm. l_k[d] is a node set of depth d. """ if not pattern.is_standard(strict=True): @@ -332,7 +334,7 @@ def flow_from_pattern(pattern: Pattern) -> tuple[dict[int, set[int]], dict[int, return None, None -def gflow_from_pattern(pattern: Pattern) -> tuple[dict[int, set[int]], dict[int, int]]: +def gflow_from_pattern(pattern: Pattern) -> tuple[dict[int, set[int]], dict[int, int]] | tuple[None, None]: """Check if the pattern has a valid gflow. If so, return the gflow and layers. Parameters @@ -342,9 +344,11 @@ def gflow_from_pattern(pattern: Pattern) -> tuple[dict[int, set[int]], dict[int, Returns ------- - g: dict + None, None: + The tuple ``(None, None)`` is returned if the pattern does not have a valid gflow. + g: dict[int, set[int]] gflow function. g[i] is the set of qubits to be corrected for the measurement of qubit i. - l_k: dict + l_k: dict[int, int] layers obtained by gflow algorithm. l_k[d] is a node set of depth d. """ if not pattern.is_standard(strict=True): @@ -404,9 +408,11 @@ def pauliflow_from_pattern( Returns ------- - p: dict + None, None: + The tuple ``(None, None)`` is returned if the pattern does not have a valid Pauli flow. + p: dict[int, set[int]] Pauli flow function. p[i] is the set of qubits to be corrected for the measurement of qubit i. - l_k: dict + l_k: dict[int, int] layers obtained by Pauli flow algorithm. l_k[d] is a node set of depth d. """ if not pattern.is_standard(strict=True): diff --git a/graphix/pattern.py b/graphix/pattern.py index cd63cfdab..d136989f9 100644 --- a/graphix/pattern.py +++ b/graphix/pattern.py @@ -994,7 +994,7 @@ def get_measurement_order_from_flow(self) -> list[int] | None: measurement order """ graph = self.extract_graph() - vin = set(self.input_nodes) if self.input_nodes is not None else set() + vin = set(self.input_nodes) vout = set(self.output_nodes) meas_planes = self.get_meas_plane() f, l_k = find_flow(graph, vin, vout, meas_planes=meas_planes) @@ -1020,7 +1020,7 @@ def get_measurement_order_from_gflow(self) -> list[int]: isolated = list(nx.isolates(graph)) if isolated: raise ValueError("The input graph must be connected") - vin = set(self.input_nodes) if self.input_nodes is not None else set() + vin = set(self.input_nodes) vout = set(self.output_nodes) meas_planes = self.get_meas_plane() flow, l_k = find_gflow(graph, vin, vout, meas_planes=meas_planes) @@ -1391,8 +1391,7 @@ def draw_graph( show_loop: bool = True, node_distance: tuple[float, float] = (1, 1), figsize: tuple[int, int] | None = None, - save: bool = False, - filename: str | None = None, + filename: Path | None = None, ) -> None: """Visualize the underlying graph of the pattern with flow or gflow structure. @@ -1412,13 +1411,12 @@ def draw_graph( Distance multiplication factor between nodes for x and y directions. figsize : tuple Figure size of the plot. - save : bool - If True, the plot is saved as a png file. - filename : str - Filename of the saved plot. + filename : Path | None + If not None, filename of the png file to save the plot. If None, the plot is not saved. + Default in None. """ graph = self.extract_graph() - vin = self.input_nodes if self.input_nodes is not None else [] + vin = self.input_nodes vout = self.output_nodes meas_planes = self.get_meas_plane() meas_angles = self.get_angles() @@ -1435,7 +1433,6 @@ def draw_graph( show_loop=show_loop, node_distance=node_distance, figsize=figsize, - save=save, filename=filename, ) else: @@ -1446,7 +1443,6 @@ def draw_graph( show_loop=show_loop, node_distance=node_distance, figsize=figsize, - save=save, filename=filename, ) diff --git a/graphix/visualization.py b/graphix/visualization.py index d7d1ad36b..98d7036cb 100644 --- a/graphix/visualization.py +++ b/graphix/visualization.py @@ -16,6 +16,8 @@ if TYPE_CHECKING: from collections.abc import Collection, Hashable, Iterable, Mapping, Sequence + from collections.abc import Set as AbstractSet + from pathlib import Path from typing import TypeAlias, TypeVar import numpy.typing as npt @@ -83,7 +85,7 @@ def __init__( self.v_in = v_in self.v_out = v_out if meas_plane is None: - self.meas_planes = dict.fromkeys(iter(g.nodes), Plane.XY) + self.meas_planes = dict.fromkeys(g.nodes - set(v_out), Plane.XY) else: self.meas_planes = dict(meas_plane) self.meas_angles = meas_angles @@ -97,8 +99,7 @@ def visualize( show_loop: bool = True, node_distance: tuple[float, float] = (1, 1), figsize: tuple[int, int] | None = None, - save: bool = False, - filename: str | None = None, + filename: Path | None = None, ) -> None: """ Visualize the graph with flow or gflow structure. @@ -122,52 +123,40 @@ def visualize( Distance multiplication factor between nodes for x and y directions. figsize : tuple Figure size of the plot. - save : bool - If True, the plot is saved as a png file. - filename : str - Filename of the saved plot. + filename : Path | None + If not None, filename of the png file to save the plot. If None, the plot is not saved. + Default in None. """ f, l_k = gflow.find_flow(self.graph, set(self.v_in), set(self.v_out), meas_planes=self.meas_planes) # try flow if f is not None and l_k is not None: print("Flow detected in the graph.") - self.visualize_w_flow( - f, - l_k, - show_pauli_measurement, - show_local_clifford, - show_measurement_planes, - node_distance, - figsize, - save, - filename, - ) + pos = self.get_pos_from_flow(f, l_k) + edge_path, arrow_path = self.get_edge_path(f, pos) else: g, l_k = gflow.find_gflow(self.graph, set(self.v_in), set(self.v_out), self.meas_planes) # try gflow if g is not None and l_k is not None: print("Gflow detected in the graph. (flow not detected)") - self.visualize_w_gflow( - g, - l_k, - show_pauli_measurement, - show_local_clifford, - show_measurement_planes, - show_loop, - node_distance, - figsize, - save, - filename, - ) + pos = self.get_pos_from_gflow(g, l_k) + edge_path, arrow_path = self.get_edge_path(g, pos) else: print("No flow or gflow detected in the graph.") - self.visualize_wo_structure( - show_pauli_measurement, - show_local_clifford, - show_measurement_planes, - node_distance, - figsize, - save, - filename, - ) + pos = self.get_pos_wo_structure() + edge_path = self.get_edge_path_wo_structure(pos) + arrow_path = None + self.visualize_graph( + pos, + edge_path, + arrow_path, + l_k, + None, + show_pauli_measurement, + show_local_clifford, + show_measurement_planes, + show_loop, + node_distance, + figsize, + filename, + ) def visualize_from_pattern( self, @@ -178,8 +167,7 @@ def visualize_from_pattern( show_loop: bool = True, node_distance: tuple[float, float] = (1, 1), figsize: tuple[int, int] | None = None, - save: bool = False, - filename: str | None = None, + filename: Path | None = None, ) -> None: """ Visualize the graph with flow or gflow structure found from the given pattern. @@ -204,41 +192,23 @@ def visualize_from_pattern( Distance multiplication factor between nodes for x and y directions. figsize : tuple Figure size of the plot. - save : bool - If True, the plot is saved as a png file. - filename : str - Filename of the saved plot. + filename : Path | None + If not None, filename of the png file to save the plot. If None, the plot is not saved. + Default in None. """ f, l_k = gflow.flow_from_pattern(pattern) # try flow - if f: + if f is not None and l_k is not None: print("The pattern is consistent with flow structure.") - self.visualize_w_flow( - f, - l_k, - show_pauli_measurement, - show_local_clifford, - show_measurement_planes, - node_distance, - figsize, - save, - filename, - ) + pos = self.get_pos_from_flow(f, l_k) + edge_path, arrow_path = self.get_edge_path(f, pos) + corrections: tuple[Mapping[int, AbstractSet[int]], Mapping[int, AbstractSet[int]]] | None = None else: g, l_k = gflow.gflow_from_pattern(pattern) # try gflow - if g: + if g is not None and l_k is not None: print("The pattern is consistent with gflow structure. (not with flow)") - self.visualize_w_gflow( - g, - l_k, - show_pauli_measurement, - show_local_clifford, - show_measurement_planes, - show_loop, - node_distance, - figsize, - save, - filename, - ) + pos = self.get_pos_from_gflow(g, l_k) + edge_path, arrow_path = self.get_edge_path(g, pos) + corrections = None else: print("The pattern is not consistent with flow or gflow structure.") depth, layers = pattern.get_layers() @@ -246,123 +216,29 @@ def visualize_from_pattern( for output in pattern.output_nodes: unfolded_layers[output] = depth + 1 xflow, zflow = gflow.get_corrections_from_pattern(pattern) - self.visualize_all_correction( - unfolded_layers, - xflow, - zflow, - show_pauli_measurement, - show_local_clifford, - show_measurement_planes, - node_distance, - figsize, - save, - filename, - ) - - def visualize_w_flow( - self, - f: Mapping[int, set[int]], - l_k: Mapping[int, int], - show_pauli_measurement: bool = True, - show_local_clifford: bool = False, - show_measurement_planes: bool = False, - node_distance: tuple[float, float] = (1, 1), - figsize: _Point | None = None, - save: bool = False, - filename: str | None = None, - ) -> None: - """ - Visualizes the graph with flow structure. - - Nodes are colored based on their role (input, output, or other) and edges are depicted as arrows - or dashed lines depending on whether they are in the flow mapping. Vertical dashed lines separate - different layers of the graph. This function does not return anything but plots the graph - using matplotlib's pyplot. - - Parameters - ---------- - f : dict - flow mapping. - l_k : dict - Layer mapping. - show_pauli_measurement : bool - If True, the nodes with Pauli measurement angles are colored light blue. - show_local_clifford : bool - If True, indexes of the local Clifford operator are displayed adjacent to the nodes. - show_measurement_planes : bool - If True, the measurement planes are displayed adjacent to the nodes. - node_distance : tuple - Distance multiplication factor between nodes for x and y directions. - figsize : tuple - Figure size of the plot. - save : bool - If True, the plot is saved. - filename : str - Filename of the saved plot. - """ - if figsize is None: - figsize = self.get_figsize(l_k, node_distance=node_distance) - plt.figure(figsize=figsize) - pos = self.get_pos_from_flow(f, l_k) - - edge_path, arrow_path = self.get_edge_path(f, pos) - - for edge in edge_path: - if len(edge_path[edge]) == 2: - nx.draw_networkx_edges(self.graph, pos, edgelist=[edge], style="dashed", alpha=0.7) - else: - curve = self._bezier_curve_linspace(edge_path[edge]) - plt.plot(curve[:, 0], curve[:, 1], "k--", linewidth=1, alpha=0.7) - - for arrow, path in arrow_path.items(): - if len(path) == 2: - nx.draw_networkx_edges( - self.graph, pos, edgelist=[arrow], edge_color="black", arrowstyle="->", arrows=True - ) - else: - GraphVisualizer._shorten_path(path) - curve = self._bezier_curve_linspace(path) - - plt.plot(curve[:, 0], curve[:, 1], c="k", linewidth=1) - plt.annotate( - "", - xy=curve[-1], - xytext=curve[-2], - arrowprops={"arrowstyle": "->", "color": "k", "lw": 1}, - ) - - self.__draw_nodes_role(pos, show_pauli_measurement) - - if show_local_clifford: - self.__draw_local_clifford(pos) - - if show_measurement_planes: - self.__draw_measurement_planes(pos) - - self._draw_labels(pos) - - x_min = min(pos[node][0] for node in self.graph.nodes()) # Get the minimum x coordinate - x_max = max(pos[node][0] for node in self.graph.nodes()) # Get the maximum x coordinate - y_min = min(pos[node][1] for node in self.graph.nodes()) # Get the minimum y coordinate - y_max = max(pos[node][1] for node in self.graph.nodes()) # Get the maximum y coordinate - - # Draw the vertical lines to separate different layers - for layer in range(min(l_k.values()), max(l_k.values())): - plt.axvline( - x=(layer + 0.5) * node_distance[0], color="gray", linestyle="--", alpha=0.5 - ) # Draw line between layers - for layer in range(min(l_k.values()), max(l_k.values()) + 1): - plt.text( - layer * node_distance[0], y_min - 0.5, f"l: {max(l_k.values()) - layer}", ha="center", va="top" - ) # Add layer label at bottom - - plt.xlim( - x_min - 0.5 * node_distance[0], x_max + 0.5 * node_distance[0] - ) # Add some padding to the left and right - plt.ylim(y_min - 1, y_max + 0.5) # Add some padding to the top and bottom - if save: - plt.savefig(filename) - plt.show() + xzflow: dict[int, set[int]] = deepcopy(xflow) + for key, value in zflow.items(): + if key in xzflow: + xzflow[key] |= value + else: + xzflow[key] = set(value) # copy + pos = self.get_pos_all_correction(unfolded_layers) + edge_path, arrow_path = self.get_edge_path(xzflow, pos) + corrections = xflow, zflow + self.visualize_graph( + pos, + edge_path, + arrow_path, + l_k, + corrections, + show_pauli_measurement, + show_local_clifford, + show_measurement_planes, + show_loop, + node_distance, + figsize, + filename, + ) @staticmethod def _shorten_path(path: list[_Point]) -> None: @@ -406,21 +282,23 @@ def __draw_nodes_role(self, pos: Mapping[int, _Point], show_pauli_measurement: b *pos[node], edgecolor=color, facecolor=inner_color, s=350, zorder=2 ) # Draw the nodes manually with scatter() - def visualize_w_gflow( + def visualize_graph( self, - g: Mapping[int, set[int]], - l_k: Mapping[int, int], + pos: Mapping[int, _Point], + edge_path: Mapping[_Edge, Sequence[_Point]], + arrow_path: Mapping[_Edge, list[_Point]] | None, + l_k: Mapping[int, int] | None, + corrections: tuple[Mapping[int, AbstractSet[int]], Mapping[int, AbstractSet[int]]] | None, show_pauli_measurement: bool = True, show_local_clifford: bool = False, show_measurement_planes: bool = False, show_loop: bool = True, node_distance: tuple[float, float] = (1, 1), figsize: _Point | None = None, - save: bool = False, - filename: str | None = None, + filename: Path | None = None, ) -> None: """ - Visualizes the graph with flow structure. + Visualizes the graph. Nodes are colored based on their role (input, output, or other) and edges are depicted as arrows or dashed lines depending on whether they are in the flow mapping. Vertical dashed lines separate @@ -429,10 +307,16 @@ def visualize_w_gflow( Parameters ---------- - g : dict - gflow mapping. - l_k : dict - Layer mapping. + pos: Mapping[int, _Point] + Node positions. + edge_path: Sequence[Mapping[int, Sequence[_Point]]] + Mapping of edge paths. + arrow_path: Mapping[_Edge, list[_Point]] | None + Mapping of arrow paths. + l_k: Mapping[int, int] | None + Layer mapping if any. + corrections: tuple[Mapping[int, AbstractSet[int]], Mapping[int, AbstractSet[int]]] | None + X and Z corrections if any. show_pauli_measurement : bool If True, the nodes with Pauli measurement angles are colored light blue. show_local_clifford : bool @@ -445,18 +329,19 @@ def visualize_w_gflow( Distance multiplication factor between nodes for x and y directions. figsize : tuple Figure size of the plot. - save : bool - If True, the plot is saved as a png file. - filename : str - Filename of the saved plot. + filename : Path | None + If not None, filename of the png file to save the plot. If None, the plot is not saved. + Default in None. """ - pos = self.get_pos_from_gflow(g, l_k) pos = {k: (v[0] * node_distance[0], v[1] * node_distance[1]) for k, v in pos.items()} # Scale the layout - edge_path, arrow_path = self.get_edge_path(g, pos) - if figsize is None: figsize = self.get_figsize(l_k, pos, node_distance=node_distance) + + if corrections is not None: + # add some padding to the right for the legend + figsize = (figsize[0] + 3.0, figsize[1]) + plt.figure(figsize=figsize) for edge in edge_path: @@ -466,32 +351,42 @@ def visualize_w_gflow( curve = self._bezier_curve_linspace(edge_path[edge]) plt.plot(curve[:, 0], curve[:, 1], "k--", linewidth=1, alpha=0.7) - for arrow, path in arrow_path.items(): - if arrow[0] == arrow[1]: # self loop - if show_loop: + if arrow_path is not None: + for arrow, path in arrow_path.items(): + if corrections is None: + color = "k" + else: + xflow, zflow = corrections + if arrow[1] not in xflow.get(arrow[0], set()): + color = "tab:green" + elif arrow[1] not in zflow.get(arrow[0], set()): + color = "tab:red" + else: + color = "tab:brown" + if arrow[0] == arrow[1]: # self loop + if show_loop: + curve = self._bezier_curve_linspace(path) + plt.plot(curve[:, 0], curve[:, 1], c="k", linewidth=1) + plt.annotate( + "", + xy=curve[-1], + xytext=curve[-2], + arrowprops={"arrowstyle": "->", "color": color, "lw": 1}, + ) + elif len(path) == 2: # straight line + nx.draw_networkx_edges( + self.graph, pos, edgelist=[arrow], edge_color=color, arrowstyle="->", arrows=True + ) + else: + GraphVisualizer._shorten_path(path) curve = self._bezier_curve_linspace(path) - plt.plot(curve[:, 0], curve[:, 1], c="k", linewidth=1) + plt.plot(curve[:, 0], curve[:, 1], c=color, linewidth=1) plt.annotate( "", xy=curve[-1], xytext=curve[-2], - arrowprops={"arrowstyle": "->", "color": "k", "lw": 1}, + arrowprops={"arrowstyle": "->", "color": color, "lw": 1}, ) - elif len(path) == 2: # straight line - nx.draw_networkx_edges( - self.graph, pos, edgelist=[arrow], edge_color="black", arrowstyle="->", arrows=True - ) - else: - GraphVisualizer._shorten_path(path) - curve = self._bezier_curve_linspace(path) - - plt.plot(curve[:, 0], curve[:, 1], c="k", linewidth=1) - plt.annotate( - "", - xy=curve[-1], - xytext=curve[-2], - arrowprops={"arrowstyle": "->", "color": "k", "lw": 1}, - ) self.__draw_nodes_role(pos, show_pauli_measurement) @@ -503,26 +398,38 @@ def visualize_w_gflow( self._draw_labels(pos) + if corrections is not None: + # legend for arrow colors + plt.plot([], [], "k--", alpha=0.7, label="graph edge") + plt.plot([], [], color="tab:red", label="xflow") + plt.plot([], [], color="tab:green", label="zflow") + plt.plot([], [], color="tab:brown", label="xflow and zflow") + x_min = min(pos[node][0] for node in self.graph.nodes()) # Get the minimum x coordinate x_max = max(pos[node][0] for node in self.graph.nodes()) # Get the maximum x coordinate y_min = min(pos[node][1] for node in self.graph.nodes()) # Get the minimum y coordinate y_max = max(pos[node][1] for node in self.graph.nodes()) # Get the maximum y coordinate - # Draw the vertical lines to separate different layers - for layer in range(min(l_k.values()), max(l_k.values())): - plt.axvline( - x=(layer + 0.5) * node_distance[0], color="gray", linestyle="--", alpha=0.5 - ) # Draw line between layers - for layer in range(min(l_k.values()), max(l_k.values()) + 1): - plt.text( - layer * node_distance[0], y_min - 0.5, f"l: {max(l_k.values()) - layer}", ha="center", va="top" - ) # Add layer label at bottom + if l_k is not None: + # Draw the vertical lines to separate different layers + for layer in range(min(l_k.values()), max(l_k.values())): + plt.axvline( + x=(layer + 0.5) * node_distance[0], color="gray", linestyle="--", alpha=0.5 + ) # Draw line between layers + for layer in range(min(l_k.values()), max(l_k.values()) + 1): + plt.text( + layer * node_distance[0], y_min - 0.5, f"l: {max(l_k.values()) - layer}", ha="center", va="top" + ) # Add layer label at bottom plt.xlim( x_min - 0.5 * node_distance[0], x_max + 0.5 * node_distance[0] ) # Add some padding to the left and right plt.ylim(y_min - 1, y_max + 0.5) # Add some padding to the top and bottom - if save: + + if corrections is not None: + plt.legend(loc="upper right", fontsize=10) + + if filename is not None: plt.savefig(filename) plt.show() @@ -539,207 +446,6 @@ def __draw_measurement_planes(self, pos: Mapping[int, _Point]) -> None: x, y = pos[node] + np.array([0.22, -0.2]) plt.text(x, y, f"{self.meas_planes[node]}", fontsize=9, zorder=3) - def visualize_wo_structure( - self, - show_pauli_measurement: bool = True, - show_local_clifford: bool = False, - show_measurement_planes: bool = False, - node_distance: tuple[float, float] = (1, 1), - figsize: _Point | None = None, - save: bool = False, - filename: str | None = None, - ) -> None: - """ - Visualizes the graph without flow or gflow. - - Nodes are colored based on their role (input, output, or other) and edges are depicted as arrows - or dashed lines depending on whether they are in the flow mapping. Vertical dashed lines separate - different layers of the graph. This function does not return anything but plots the graph - using matplotlib's pyplot. - - Parameters - ---------- - show_pauli_measurement : bool - If True, the nodes with Pauli measurement angles are colored light blue. - show_local_clifford : bool - If True, indexes of the local Clifford operator are displayed adjacent to the nodes. - show_measurement_planes : bool - If True, the measurement planes are displayed adjacent to the nodes. - node_distance : tuple - Distance multiplication factor between nodes for x and y directions. - figsize : tuple - Figure size of the plot. - save : bool - If True, the plot is saved as a png file. - filename : str - Filename of the saved plot. - """ - pos = self.get_pos_wo_structure() - pos = {k: (v[0] * node_distance[0], v[1] * node_distance[1]) for k, v in pos.items()} # Scale the layout - - if figsize is None: - figsize = self.get_figsize(None, pos, node_distance=node_distance) - plt.figure(figsize=figsize) - - edge_path = self.get_edge_path_wo_structure(pos) - - for edge in edge_path: - if len(edge_path[edge]) == 2: - nx.draw_networkx_edges(self.graph, pos, edgelist=[edge], style="dashed", alpha=0.7) - else: - curve = self._bezier_curve_linspace(edge_path[edge]) - plt.plot(curve[:, 0], curve[:, 1], "k--", linewidth=1, alpha=0.7) - - self.__draw_nodes_role(pos, show_pauli_measurement) - - if show_local_clifford: - self.__draw_local_clifford(pos) - - if show_measurement_planes: - self.__draw_measurement_planes(pos) - - self._draw_labels(pos) - - x_min = min(pos[node][0] for node in self.graph.nodes()) # Get the minimum x coordinate - x_max = max(pos[node][0] for node in self.graph.nodes()) # Get the maximum x coordinate - y_min = min(pos[node][1] for node in self.graph.nodes()) # Get the minimum y coordinate - y_max = max(pos[node][1] for node in self.graph.nodes()) # Get the maximum y coordinate - - plt.xlim( - x_min - 0.5 * node_distance[0], x_max + 0.5 * node_distance[0] - ) # Add some padding to the left and right - plt.ylim(y_min - 0.5, y_max + 0.5) # Add some padding to the top and bottom - - if save: - plt.savefig(filename) - plt.show() - - def visualize_all_correction( - self, - layers: Mapping[int, int], - xflow: Mapping[int, set[int]], - zflow: Mapping[int, set[int]], - show_pauli_measurement: bool = True, - show_local_clifford: bool = False, - show_measurement_planes: bool = False, - node_distance: tuple[float, float] = (1, 1), - figsize: _Point | None = None, - save: bool = False, - filename: str | None = None, - ) -> None: - """ - Visualizes the graph of pattern with all correction flows. - - Nodes are colored based on their role (input, output, or other) and edges of graph are depicted as dashed lines. - Xflow is depicted as red arrows and Zflow is depicted as blue arrows. The function does not return anything but plots the graph using matplotlib's pyplot. - - Parameters - ---------- - layers : dict - Layer mapping obtained from the measurement order of the pattern. - xflow : dict - Dictionary for x correction of the pattern. - zflow : dict - Dictionary for z correction of the pattern. - show_pauli_measurement : bool - If True, the nodes with Pauli measurement angles are colored light blue. - show_local_clifford : bool - If True, indexes of the local Clifford operator are displayed adjacent to the nodes. - show_measurement_planes : bool - If True, the measurement planes are displayed adjacent to the nodes. - node_distance : tuple - Distance multiplication factor between nodes for x and y directions. - figsize : tuple - Figure size of the plot. - save : bool - If True, the plot is saved as a png file. - filename : str - Filename of the saved plot. - """ - pos = self.get_pos_all_correction(layers) - pos = {k: (v[0] * node_distance[0], v[1] * node_distance[1]) for k, v in pos.items()} # Scale the layout - - if figsize is None: - figsize = self.get_figsize(layers, pos, node_distance=node_distance) - # add some padding to the right for the legend - figsize = (figsize[0] + 3.0, figsize[1]) - plt.figure(figsize=figsize) - - xzflow: dict[int, set[int]] = {} - for key, value in deepcopy(xflow).items(): - if key in xzflow: - xzflow[key] |= value - else: - xzflow[key] = value - for key, value in deepcopy(zflow).items(): - if key in xzflow: - xzflow[key] |= value - else: - xzflow[key] = value - edge_path, arrow_path = self.get_edge_path(xzflow, pos) - - for edge in edge_path: - if len(edge_path[edge]) == 2: - nx.draw_networkx_edges(self.graph, pos, edgelist=[edge], style="dashed", alpha=0.7) - else: - curve = self._bezier_curve_linspace(edge_path[edge]) - plt.plot(curve[:, 0], curve[:, 1], "k--", linewidth=1, alpha=0.7) - for arrow in arrow_path: - if arrow[1] not in xflow.get(arrow[0], set()): - color = "tab:green" - elif arrow[1] not in zflow.get(arrow[0], set()): - color = "tab:red" - else: - color = "tab:brown" - if len(arrow_path[arrow]) == 2: # straight line - nx.draw_networkx_edges( - self.graph, pos, edgelist=[arrow], edge_color=color, arrowstyle="->", arrows=True - ) - else: - path = arrow_path[arrow] - GraphVisualizer._shorten_path(path) - curve = self._bezier_curve_linspace(path) - - plt.plot(curve[:, 0], curve[:, 1], c=color, linewidth=1) - plt.annotate( - "", - xy=curve[-1], - xytext=curve[-2], - arrowprops={"arrowstyle": "->", "color": color, "lw": 1}, - ) - - self.__draw_nodes_role(pos, show_pauli_measurement) - - if show_local_clifford: - self.__draw_local_clifford(pos) - - if show_measurement_planes: - self.__draw_measurement_planes(pos) - - self._draw_labels(pos) - - # legend for arrow colors - plt.plot([], [], "k--", alpha=0.7, label="graph edge") - plt.plot([], [], color="tab:red", label="xflow") - plt.plot([], [], color="tab:green", label="zflow") - plt.plot([], [], color="tab:brown", label="xflow and zflow") - - x_min = min(pos[node][0] for node in self.graph.nodes()) # Get the minimum x coordinate - x_max = max(pos[node][0] for node in self.graph.nodes()) - y_min = min(pos[node][1] for node in self.graph.nodes()) - y_max = max(pos[node][1] for node in self.graph.nodes()) - - plt.xlim( - x_min - 0.5 * node_distance[0], x_max + 3.5 * node_distance[0] - ) # Add some padding to the left and right - plt.ylim(y_min - 0.5, y_max + 0.5) # Add some padding to the top and bottom - - plt.legend(loc="upper right", fontsize=10) - - if save: - plt.savefig(filename) - plt.show() - def get_figsize( self, l_k: Mapping[int, int] | None, @@ -765,7 +471,7 @@ def get_figsize( """ if l_k is None: if pos is None: - raise ValueError("l_k and pos cannot be both None") + raise ValueError("Figure size can only be computed given a layer mapping (l_k) or node positions (pos)") width = len({pos[node][0] for node in self.graph.nodes()}) * 0.8 else: width = (max(l_k.values()) + 1) * 0.8 @@ -774,7 +480,7 @@ def get_figsize( def get_edge_path( self, flow: Mapping[int, int | set[int]], pos: Mapping[int, _Point] - ) -> tuple[dict[int, list[_Point]], dict[_Edge, list[_Point]]]: + ) -> tuple[dict[_Edge, list[_Point]], dict[_Edge, list[_Point]]]: """ Return the path of edges and gflow arrows. @@ -792,7 +498,6 @@ def get_edge_path( arrow_path : dict dictionary of arrow paths. """ - max_iter = 5 edge_path = self.get_edge_path_wo_structure(pos) edge_set = set(self.graph.edges()) arrow_path: dict[_Edge, list[_Point]] = {} @@ -832,8 +537,6 @@ def _point_from_node(pos: Sequence[float], dist: float, angle: float) -> _Point: _point_from_node(pos[arrow[0]], 0.17, 95), ] else: - iteration = 0 - nodes = set(self.graph.nodes()) bezier_path = [pos[arrow[0]], pos[arrow[1]]] if arrow in edge_set or (arrow[1], arrow[0]) in edge_set: mid_point = ( @@ -843,38 +546,45 @@ def _point_from_node(pos: Sequence[float], dist: float, angle: float) -> _Point: if self._edge_intersects_node(pos[arrow[0]], pos[arrow[1]], mid_point, buffer=0.05): ctrl_point = self._control_point(pos[arrow[0]], pos[arrow[1]], mid_point, distance=0.2) bezier_path.insert(1, ctrl_point) - while True: - iteration += 1 - intersect = False - if iteration > max_iter: - break - ctrl_points = [] - for i in range(len(bezier_path) - 1): - start = bezier_path[i] - end = bezier_path[i + 1] - for node in nodes: - if ( - node != arrow[0] - and node != arrow[1] - and self._edge_intersects_node(start, end, pos[node]) - ): - intersect = True - ctrl_points.append( - ( - i, - self._control_point(start, end, pos[node], distance=0.6 / iteration), - ) - ) - if not intersect: - break - for i, (index, ctrl_point) in enumerate(ctrl_points): - bezier_path.insert(index + i + 1, ctrl_point) - bezier_path = self._check_path(bezier_path, pos[arrow[1]]) + bezier_path = self._find_bezier_path(arrow, bezier_path, pos) + arrow_path[arrow] = bezier_path return edge_path, arrow_path - def get_edge_path_wo_structure(self, pos: Mapping[int, _Point]) -> dict[int, list[_Point]]: + def _find_bezier_path( + self, arrow: _Edge, bezier_path: Iterable[tuple[float, float]], pos: Mapping[int, _Point] + ) -> list[_Point]: + bezier_path = list(bezier_path) + max_iter = 5 + iteration = 0 + nodes = set(self.graph.nodes()) + while True: + iteration += 1 + intersect = False + if iteration > max_iter: + break + ctrl_points = [] + for i in range(len(bezier_path) - 1): + start = bezier_path[i] + end = bezier_path[i + 1] + for node in set(nodes): + if node != arrow[0] and node != arrow[1] and self._edge_intersects_node(start, end, pos[node]): + intersect = True + ctrl_points.append( + ( + i, + self._control_point(start, end, pos[node], distance=0.6 / iteration), + ) + ) + nodes -= {node} + if not intersect: + break + for i, (index, ctrl_point) in enumerate(ctrl_points): + bezier_path.insert(index + i + 1, ctrl_point) + return self._check_path(bezier_path, pos[arrow[1]]) + + def get_edge_path_wo_structure(self, pos: Mapping[int, _Point]) -> dict[_Edge, list[_Point]]: """ Return the path of edges. @@ -888,41 +598,7 @@ def get_edge_path_wo_structure(self, pos: Mapping[int, _Point]) -> dict[int, lis edge_path : dict dictionary of edge paths. """ - max_iter = 5 - edge_path: dict[int, list[_Point]] = {} - edge_set = set(self.graph.edges()) - for edge in edge_set: - iteration = 0 - nodes = set(self.graph.nodes()) - bezier_path = [pos[edge[0]], pos[edge[1]]] - while True: - iteration += 1 - intersect = False - if iteration > max_iter: - break - ctrl_points: list[tuple[int, _Point]] = [] - for i in range(len(bezier_path) - 1): - start = bezier_path[i] - end = bezier_path[i + 1] - for node in list(nodes): - if node != edge[0] and node != edge[1] and self._edge_intersects_node(start, end, pos[node]): - intersect = True - ctrl_points.append( - ( - i, - self._control_point( - bezier_path[0], bezier_path[-1], pos[node], distance=0.6 / iteration - ), - ) - ) - nodes -= {node} - if not intersect: - break - for i, ctrl_point in enumerate(ctrl_points): - bezier_path.insert(ctrl_point[0] + i + 1, ctrl_point[1]) - bezier_path = self._check_path(bezier_path) - edge_path[edge] = bezier_path - return edge_path + return {edge: self._find_bezier_path(edge, [pos[edge[0]], pos[edge[1]]], pos) for edge in self.graph.edges()} def get_pos_from_flow(self, f: Mapping[int, set[int]], l_k: Mapping[int, int]) -> dict[int, _Point]: """ diff --git a/tests/test_generator.py b/tests/test_generator.py index b2ba81567..c9f8989d3 100644 --- a/tests/test_generator.py +++ b/tests/test_generator.py @@ -14,92 +14,93 @@ from graphix.random_objects import rand_gate if TYPE_CHECKING: + from collections.abc import Callable + from numpy.random import Generator + from graphix import Pattern -class TestGenerator: - def get_graph_pflow(self, fx_rng: Generator) -> OpenGraph: - """Create a graph which has pflow but no gflow. - - Parameters - ---------- - fx_rng : :class:`numpy.random.Generator` - See graphix.tests.conftest.py - - Returns - ------- - OpenGraph: :class:`graphix.opengraph.OpenGraph` - """ - graph: nx.Graph[int] = nx.Graph( - [(0, 2), (1, 4), (2, 3), (3, 4), (2, 5), (3, 6), (4, 7), (5, 6), (6, 7), (5, 8), (7, 9)] - ) - inputs = [1, 0] - outputs = [9, 8] - - # Heuristic mixture of Pauli and non-Pauli angles ensuring there's no gflow but there's pflow. - meas_angles: dict[int, float] = { - **dict.fromkeys(range(4), 0), - **dict(zip(range(4, 8), (2 * fx_rng.random(4)).tolist())), - } - meas_planes = dict.fromkeys(range(8), Plane.XY) - meas = {i: Measurement(angle, plane) for (i, angle), plane in zip(meas_angles.items(), meas_planes.values())} - - gf, _ = find_gflow(graph=graph, iset=set(inputs), oset=set(outputs), meas_planes=meas_planes) - pf, _ = find_pauliflow( - graph=graph, iset=set(inputs), oset=set(outputs), meas_planes=meas_planes, meas_angles=meas_angles - ) - - assert gf is None # example graph doesn't have gflow - assert pf is not None # example graph has Pauli flow - - return OpenGraph(inside=graph, inputs=inputs, outputs=outputs, measurements=meas) - - def test_pattern_generation_determinism_flow(self, fx_rng: Generator) -> None: - graph: nx.Graph[int] = nx.Graph([(0, 3), (1, 4), (2, 5), (1, 3), (2, 4), (3, 6), (4, 7), (5, 8)]) - inputs = [1, 0, 2] # non-trivial order to check order is conserved. - outputs = [7, 6, 8] - angles = dict(zip(range(6), (2 * fx_rng.random(6)).tolist())) - meas_planes = dict.fromkeys(range(6), Plane.XY) - - pattern = generate_from_graph(graph, angles, inputs, outputs, meas_planes=meas_planes) - pattern.standardize() - pattern.minimize_space() - repeats = 3 # for testing the determinism of a pattern - results = [pattern.simulate_pattern(rng=fx_rng) for _ in range(repeats)] +def example_flow(rng: Generator) -> Pattern: + graph: nx.Graph[int] = nx.Graph([(0, 3), (1, 4), (2, 5), (1, 3), (2, 4), (3, 6), (4, 7), (5, 8)]) + inputs = [1, 0, 2] # non-trivial order to check order is conserved. + outputs = [7, 6, 8] + angles = dict(zip(range(6), (2 * rng.random(6)).tolist())) + meas_planes = dict.fromkeys(range(6), Plane.XY) - for i in range(1, 3): - inner_product = np.dot(results[0].flatten(), results[i].flatten().conjugate()) - assert abs(inner_product) == pytest.approx(1) + pattern = generate_from_graph(graph, angles, inputs, outputs, meas_planes=meas_planes) + pattern.standardize() - assert pattern.input_nodes == inputs - assert pattern.output_nodes == outputs + assert pattern.input_nodes == inputs + assert pattern.output_nodes == outputs + return pattern - def test_pattern_generation_determinism_gflow(self, fx_rng: Generator) -> None: - graph: nx.Graph[int] = nx.Graph([(1, 2), (2, 3), (3, 4), (4, 5), (5, 6), (3, 6), (1, 6)]) - inputs = [3, 1, 5] - outputs = [4, 2, 6] - angles = dict(zip([1, 3, 5], (2 * fx_rng.random(3)).tolist())) - meas_planes = dict.fromkeys([1, 3, 5], Plane.XY) - pattern = generate_from_graph(graph, angles, inputs, outputs, meas_planes=meas_planes) - pattern.standardize() - pattern.minimize_space() +def example_gflow(rng: Generator) -> Pattern: + graph: nx.Graph[int] = nx.Graph([(1, 2), (2, 3), (3, 4), (4, 5), (5, 6), (3, 6), (1, 6)]) + inputs = [3, 1, 5] + outputs = [4, 2, 6] + angles = dict(zip([1, 3, 5], (2 * rng.random(3)).tolist())) + meas_planes = dict.fromkeys([1, 3, 5], Plane.XY) - repeats = 3 # for testing the determinism of a pattern - results = [pattern.simulate_pattern(rng=fx_rng) for _ in range(repeats)] + pattern = generate_from_graph(graph, angles, inputs, outputs, meas_planes=meas_planes) + pattern.standardize() - for i in range(1, 3): - inner_product = np.dot(results[0].flatten(), results[i].flatten().conjugate()) - assert abs(inner_product) == pytest.approx(1) + assert pattern.input_nodes == inputs + assert pattern.output_nodes == outputs + return pattern - assert pattern.input_nodes == inputs - assert pattern.output_nodes == outputs - def test_pattern_generation_determinism_pflow(self, fx_rng: Generator) -> None: - og = self.get_graph_pflow(fx_rng) - pattern = og.to_pattern() - pattern.standardize() +def example_graph_pflow(rng: Generator) -> OpenGraph: + """Create a graph which has pflow but no gflow. + + Parameters + ---------- + rng : :class:`numpy.random.Generator` + See graphix.tests.conftest.py + + Returns + ------- + OpenGraph: :class:`graphix.opengraph.OpenGraph` + """ + graph: nx.Graph[int] = nx.Graph( + [(0, 2), (1, 4), (2, 3), (3, 4), (2, 5), (3, 6), (4, 7), (5, 6), (6, 7), (5, 8), (7, 9)] + ) + inputs = [1, 0] + outputs = [9, 8] + + # Heuristic mixture of Pauli and non-Pauli angles ensuring there's no gflow but there's pflow. + meas_angles: dict[int, float] = { + **dict.fromkeys(range(4), 0), + **dict(zip(range(4, 8), (2 * rng.random(4)).tolist())), + } + meas_planes = dict.fromkeys(range(8), Plane.XY) + meas = {i: Measurement(angle, plane) for (i, angle), plane in zip(meas_angles.items(), meas_planes.values())} + + gf, _ = find_gflow(graph=graph, iset=set(inputs), oset=set(outputs), meas_planes=meas_planes) + pf, _ = find_pauliflow( + graph=graph, iset=set(inputs), oset=set(outputs), meas_planes=meas_planes, meas_angles=meas_angles + ) + + assert gf is None # example graph doesn't have gflow + assert pf is not None # example graph has Pauli flow + + return OpenGraph(inside=graph, inputs=inputs, outputs=outputs, measurements=meas) + + +def example_pflow(rng: Generator) -> Pattern: + og = example_graph_pflow(rng) + pattern = og.to_pattern() + pattern.standardize() + assert og.inputs == pattern.input_nodes + assert og.outputs == pattern.output_nodes + return pattern + + +class TestGenerator: + @pytest.mark.parametrize("example", [example_flow, example_gflow, example_pflow]) + def test_pattern_generation_determinism(self, example: Callable[[Generator], Pattern], fx_rng: Generator) -> None: + pattern = example(fx_rng) pattern.minimize_space() repeats = 3 # for testing the determinism of a pattern @@ -109,9 +110,6 @@ def test_pattern_generation_determinism_pflow(self, fx_rng: Generator) -> None: inner_product = np.dot(results[0].flatten(), results[i].flatten().conjugate()) assert abs(inner_product) == pytest.approx(1) - assert og.inputs == pattern.input_nodes - assert og.outputs == pattern.output_nodes - def test_pattern_generation_flow(self, fx_rng: Generator) -> None: nqubits = 3 depth = 2 @@ -147,7 +145,7 @@ def test_pattern_generation_no_internal_nodes(self) -> None: assert nx.utils.graphs_equal(graph, graph_ref) def test_pattern_generation_pflow(self, fx_rng: Generator) -> None: - og = self.get_graph_pflow(fx_rng) + og = example_graph_pflow(fx_rng) pattern = og.to_pattern() graph_generated_pattern = pattern.extract_graph() diff --git a/tests/test_visualization.py b/tests/test_visualization.py index 02fe23d98..76098f342 100644 --- a/tests/test_visualization.py +++ b/tests/test_visualization.py @@ -1,15 +1,25 @@ from __future__ import annotations from math import pi +from pathlib import Path +from tempfile import TemporaryDirectory +from typing import TYPE_CHECKING import matplotlib.pyplot as plt import pytest -from graphix import gflow, transpiler, visualization +from graphix import Circuit, Pattern, command, gflow, visualization +from graphix.visualization import GraphVisualizer +from tests.test_generator import example_flow, example_gflow, example_pflow + +if TYPE_CHECKING: + from collections.abc import Callable + + from numpy.random import Generator def test_get_pos_from_flow() -> None: - circuit = transpiler.Circuit(1) + circuit = Circuit(1) circuit.h(0) pattern = circuit.transpile().pattern graph = pattern.extract_graph() @@ -32,8 +42,59 @@ def mock_plot(monkeypatch: pytest.MonkeyPatch) -> None: @pytest.mark.usefixtures("mock_plot") -def test_draw_graph_flow_from_pattern() -> None: - circuit = transpiler.Circuit(3) +@pytest.mark.parametrize("example", [example_flow, example_gflow, example_pflow]) +@pytest.mark.parametrize("flow_from_pattern", [False, True]) +def test_draw_graph(example: Callable[[Generator], Pattern], flow_from_pattern: bool, fx_rng: Generator) -> None: + pattern = example(fx_rng) + pattern.draw_graph( + flow_from_pattern=flow_from_pattern, + node_distance=(0.7, 0.6), + ) + + +def example_hadamard() -> Pattern: + circuit = Circuit(1) + circuit.h(0) + return circuit.transpile().pattern + + +def example_local_clifford() -> Pattern: + pattern = example_hadamard() + pattern.perform_pauli_measurements() + return pattern + + +@pytest.mark.usefixtures("mock_plot") +def test_draw_graph_show_local_clifford() -> None: + pattern = example_local_clifford() + pattern.standardize() + pattern.draw_graph( + show_local_clifford=True, + node_distance=(0.7, 0.6), + ) + + +@pytest.mark.usefixtures("mock_plot") +def test_draw_graph_show_measurement_planes(fx_rng: Generator) -> None: + pattern = example_pflow(fx_rng) + pattern.draw_graph( + show_measurement_planes=True, + node_distance=(0.7, 0.6), + ) + + +@pytest.mark.usefixtures("mock_plot") +def test_draw_graph_show_loop(fx_rng: Generator) -> None: + pattern = example_pflow(fx_rng) + pattern.draw_graph( + show_loop=True, + node_distance=(0.7, 0.6), + ) + + +@pytest.mark.usefixtures("mock_plot") +def test_draw_graph_save() -> None: + circuit = Circuit(3) circuit.cnot(0, 1) circuit.cnot(2, 1) circuit.rx(0, pi / 3) @@ -41,4 +102,55 @@ def test_draw_graph_flow_from_pattern() -> None: circuit.cnot(2, 1) pattern = circuit.transpile().pattern pattern.perform_pauli_measurements(leave_input=True) - pattern.draw_graph(flow_from_pattern=True, show_measurement_planes=True, node_distance=(0.7, 0.6)) + with TemporaryDirectory() as dirname: + filename = Path(dirname) / "image.png" + pattern.draw_graph(node_distance=(0.7, 0.6), filename=filename) + assert filename.exists() + + +def example_visualizer() -> tuple[GraphVisualizer, Pattern]: + pattern = example_hadamard() + graph = pattern.extract_graph() + vis = GraphVisualizer(graph, pattern.input_nodes, pattern.output_nodes) + return vis, pattern + + +@pytest.mark.usefixtures("mock_plot") +def test_graph_visualizer_without_plane() -> None: + vis, pattern = example_visualizer() + vis.visualize() + vis.visualize_from_pattern(pattern) + + +@pytest.mark.usefixtures("mock_plot") +def test_draw_graph_without_flow() -> None: + pattern = Pattern(input_nodes=[0], cmds=[command.N(1), command.E((0, 1)), command.M(0)]) + pattern.draw_graph() + + +@pytest.mark.usefixtures("mock_plot") +def test_large_node_number() -> None: + pattern = Pattern(input_nodes=[100]) + pattern.draw_graph() + + +def test_get_figsize_without_layers_or_pos() -> None: + vis, _pattern = example_visualizer() + with pytest.raises(ValueError): + vis.get_figsize(None, None) + + +def test_edge_intersects_node_equals() -> None: + vis, _pattern = example_visualizer() + assert not vis._edge_intersects_node((0, 0), (0, 0), (0, 0)) + + +@pytest.mark.usefixtures("mock_plot") +def test_custom_corrections() -> None: + pattern = Pattern( + input_nodes=[0, 1, 2, 3], + cmds=[command.M(0), command.M(1), command.X(2, {0}), command.Z(2, {0}), command.Z(3, {1})], + ) + graph = pattern.extract_graph() + vis = GraphVisualizer(graph, pattern.input_nodes, pattern.output_nodes) + vis.visualize_from_pattern(pattern) From 703ed8963a0f004541d8d3d5a376e5b9604b3147 Mon Sep 17 00:00:00 2001 From: Thierry Martinez Date: Wed, 5 Nov 2025 16:07:57 +0100 Subject: [PATCH 4/4] Fix plane labels and scaling --- graphix/visualization.py | 36 ++++++++++++++++++++++-------------- 1 file changed, 22 insertions(+), 14 deletions(-) diff --git a/graphix/visualization.py b/graphix/visualization.py index 98d7036cb..99f723029 100644 --- a/graphix/visualization.py +++ b/graphix/visualization.py @@ -241,12 +241,14 @@ def visualize_from_pattern( ) @staticmethod - def _shorten_path(path: list[_Point]) -> None: + def _shorten_path(path: Sequence[_Point]) -> list[_Point]: """Shorten the last edge not to hide arrow under the node.""" - last = np.array(path[-1]) - second_last = np.array(path[-2]) + new_path = list(path) + last = np.array(new_path[-1]) + second_last = np.array(new_path[-2]) last_edge: _Point = tuple(last - (last - second_last) / np.linalg.norm(last - second_last) * 0.2) - path[-1] = last_edge + new_path[-1] = last_edge + return new_path def _draw_labels(self, pos: Mapping[int, _Point]) -> None: fontsize = 12 @@ -286,7 +288,7 @@ def visualize_graph( self, pos: Mapping[int, _Point], edge_path: Mapping[_Edge, Sequence[_Point]], - arrow_path: Mapping[_Edge, list[_Point]] | None, + arrow_path: Mapping[_Edge, Sequence[_Point]] | None, l_k: Mapping[int, int] | None, corrections: tuple[Mapping[int, AbstractSet[int]], Mapping[int, AbstractSet[int]]] | None, show_pauli_measurement: bool = True, @@ -309,9 +311,9 @@ def visualize_graph( ---------- pos: Mapping[int, _Point] Node positions. - edge_path: Sequence[Mapping[int, Sequence[_Point]]] + edge_path: Mapping[int, Sequence[_Point]] Mapping of edge paths. - arrow_path: Mapping[_Edge, list[_Point]] | None + arrow_path: Mapping[_Edge, Sequence[_Point]] | None Mapping of arrow paths. l_k: Mapping[int, int] | None Layer mapping if any. @@ -333,7 +335,11 @@ def visualize_graph( If not None, filename of the png file to save the plot. If None, the plot is not saved. Default in None. """ - pos = {k: (v[0] * node_distance[0], v[1] * node_distance[1]) for k, v in pos.items()} # Scale the layout + # Scale the layout. + pos = {k: _scale_pos(v, node_distance) for k, v in pos.items()} + edge_path = {k: [_scale_pos(p, node_distance) for p in l] for k, l in edge_path.items()} + if arrow_path is not None: + arrow_path = {k: [_scale_pos(p, node_distance) for p in l] for k, l in arrow_path.items()} if figsize is None: figsize = self.get_figsize(l_k, pos, node_distance=node_distance) @@ -378,8 +384,8 @@ def visualize_graph( self.graph, pos, edgelist=[arrow], edge_color=color, arrowstyle="->", arrows=True ) else: - GraphVisualizer._shorten_path(path) - curve = self._bezier_curve_linspace(path) + new_path = GraphVisualizer._shorten_path(path) + curve = self._bezier_curve_linspace(new_path) plt.plot(curve[:, 0], curve[:, 1], c=color, linewidth=1) plt.annotate( "", @@ -444,7 +450,7 @@ def __draw_measurement_planes(self, pos: Mapping[int, _Point]) -> None: for node in self.graph.nodes(): if node in self.meas_planes: x, y = pos[node] + np.array([0.22, -0.2]) - plt.text(x, y, f"{self.meas_planes[node]}", fontsize=9, zorder=3) + plt.text(x, y, f"{self.meas_planes[node].name}", fontsize=9, zorder=3) def get_figsize( self, @@ -552,9 +558,7 @@ def _point_from_node(pos: Sequence[float], dist: float, angle: float) -> _Point: return edge_path, arrow_path - def _find_bezier_path( - self, arrow: _Edge, bezier_path: Iterable[tuple[float, float]], pos: Mapping[int, _Point] - ) -> list[_Point]: + def _find_bezier_path(self, arrow: _Edge, bezier_path: Iterable[_Point], pos: Mapping[int, _Point]) -> list[_Point]: bezier_path = list(bezier_path) max_iter = 5 iteration = 0 @@ -887,3 +891,7 @@ def _check_path(path: Iterable[_Point], target_node_pos: _Point | None = None) - def _set_node_attributes(graph: nx.Graph[_HashableT], attrs: Mapping[_HashableT, object], name: str) -> None: nx.set_node_attributes(graph, attrs, name=name) # type: ignore[arg-type] + + +def _scale_pos(pos: _Point, node_distance: tuple[float, float]) -> _Point: + return (pos[0] * node_distance[0], pos[1] * node_distance[1])