Skip to content

Commit a7c3080

Browse files
ColbyDeLisleRafael Haenel
andauthored
feat: Small update to the Python API (zxcalc#76)
This PR contains a few small updates to the Python API for convenience. Essentially it simply exposes the following methods: * `VecGraph`: * `.adjoint` * `.plug` * `.clone` * `Decomposer`: * `.done` * `.save` * `.decomp_parallel` --------- Co-authored-by: Rafael Haenel <[email protected]>
1 parent c13bf87 commit a7c3080

File tree

7 files changed

+72
-21
lines changed

7 files changed

+72
-21
lines changed

pybindings/quizx/__init__.py

Lines changed: 3 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,7 @@
1-
from . import _quizx, simplify
1+
from . import simplify
22
from .graph import VecGraph
3-
from .circuit import Circuit
3+
from .circuit import Circuit, extract_circuit
44
from .decompose import Decomposer
55
from ._quizx import Scalar
66

7-
__all__ = ["VecGraph", "Circuit", "simplify", "Decomposer", "Scalar"]
8-
9-
10-
def extract_circuit(g):
11-
c = Circuit()
12-
c._c = _quizx.extract_circuit(g._g)
13-
return c
7+
__all__ = ["VecGraph", "Circuit", "simplify", "Decomposer", "Scalar", "extract_circuit"]

pybindings/quizx/_quizx.pyi

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -65,6 +65,9 @@ class VecGraph:
6565
def outputs(self) -> list[int]: ...
6666
def num_outputs(self) -> int: ...
6767
def set_outputs(self, outputs: list[int]) -> None: ...
68+
def adjoint(self) -> None: ...
69+
def plug(self, other: "VecGraph") -> None: ...
70+
def clone(self) -> "VecGraph": ...
6871

6972
@final
7073
class Circuit:
@@ -95,10 +98,13 @@ class Decomposer:
9598
def empty() -> Decomposer: ...
9699
def __init__(self, g: VecGraph) -> None: ...
97100
def graphs(self) -> list[VecGraph]: ...
101+
def done(self) -> list[VecGraph]: ...
102+
def save(self, b: bool) -> None: ...
98103
def apply_optimizations(self, b: bool) -> None: ...
99104
def max_terms(self) -> int: ...
100105
def decomp_top(self) -> None: ...
101106
def decomp_all(self) -> None: ...
107+
def decomp_parallel(self, depth: int) -> None: ...
102108
def decomp_until_depth(self, depth: int) -> None: ...
103109
def use_cats(self, b: bool) -> None: ...
104110
def get_nterms(self) -> int: ...

pybindings/quizx/circuit.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,12 @@
22
from .graph import VecGraph
33

44

5+
def extract_circuit(g: VecGraph) -> "Circuit":
6+
c = Circuit()
7+
c._c = _quizx.extract_circuit(g.get_raw_graph())
8+
return c
9+
10+
511
class Circuit:
612
def __init__(self):
713
self._c = None

pybindings/quizx/decompose.py

Lines changed: 12 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
from typing import List, Optional
1+
from typing import Optional
22

33
from . import _quizx
44
from .graph import VecGraph
@@ -15,8 +15,14 @@ def __init__(self, graph: Optional[VecGraph] = None):
1515
else:
1616
self._d = _quizx.Decomposer(graph.get_raw_graph())
1717

18-
def graphs(self) -> List[VecGraph]:
19-
return [VecGraph(g) for g in self._d.graphs()]
18+
def graphs(self) -> list[VecGraph]:
19+
return [VecGraph.from_raw_graph(g) for g in self._d.graphs()]
20+
21+
def done(self) -> list[VecGraph]:
22+
return [VecGraph.from_raw_graph(g) for g in self._d.done()]
23+
24+
def save(self, b: bool):
25+
self._d.save(b)
2026

2127
def apply_optimizations(self, b: bool):
2228
self._d.apply_optimizations(b)
@@ -30,6 +36,9 @@ def decomp_top(self):
3036
def decomp_all(self):
3137
self._d.decomp_all()
3238

39+
def decomp_parallel(self, depth: int = 4):
40+
self._d.decomp_parallel(depth)
41+
3342
def decomp_until_depth(self, depth: int):
3443
self._d.decomp_until_depth(depth)
3544

pybindings/quizx/graph.py

Lines changed: 17 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@
1616

1717
from .scalar import from_pyzx_scalar, to_pyzx_scalar
1818
from fractions import Fraction
19-
from typing import Tuple, Dict, Any, Optional
19+
from typing import Tuple, Dict, Any
2020
from pyzx.graph.base import BaseGraph # type: ignore
2121
from pyzx.utils import VertexType, EdgeType # type: ignore
2222
from pyzx.graph.scalar import Scalar
@@ -32,12 +32,9 @@ class VecGraph(BaseGraph[int, Tuple[int, int]]):
3232

3333
# The documentation of what these methods do
3434
# can be found in base.BaseGraph
35-
def __init__(self, rust_graph: Optional[_quizx.VecGraph] = None):
36-
if rust_graph:
37-
self._g = rust_graph
38-
else:
39-
self._g = _quizx.VecGraph()
40-
BaseGraph.__init__(self)
35+
def __init__(self) -> None:
36+
self._g = _quizx.VecGraph()
37+
super().__init__()
4138
self._vdata: Dict[int, Any] = dict()
4239

4340
def get_raw_graph(self) -> _quizx.VecGraph:
@@ -172,7 +169,7 @@ def vertices_in_range(self, start, end):
172169
for v in self.vertices():
173170
if not start < v < end:
174171
continue
175-
if all(start < v2 < end for v2 in self.graph[v]):
172+
if all(start < v2 < end for v2 in self.neighbors(v)):
176173
yield v
177174

178175
def edges(self):
@@ -328,3 +325,15 @@ def scalar(self, s: Scalar):
328325

329326
def is_ground(self, vertex):
330327
return False
328+
329+
def adjoint(self):
330+
self._g.adjoint()
331+
332+
def plug(self, other: "VecGraph"):
333+
if other._g is self._g:
334+
self._g.plug(other._g.clone())
335+
else:
336+
self._g.plug(other._g)
337+
338+
def clone(self) -> "VecGraph":
339+
return VecGraph.from_raw_graph(self._g.clone())

pybindings/src/lib.rs

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -313,6 +313,18 @@ impl VecGraph {
313313
fn set_scalar(&mut self, scalar: Scalar) {
314314
*self.g.scalar_mut() = scalar.into();
315315
}
316+
317+
fn adjoint(&mut self) {
318+
self.g.adjoint()
319+
}
320+
321+
fn plug(&mut self, other: &VecGraph) {
322+
self.g.plug(&other.g);
323+
}
324+
325+
fn clone(&self) -> VecGraph {
326+
VecGraph { g: self.g.clone() }
327+
}
316328
}
317329

318330
#[pyclass]
@@ -354,6 +366,18 @@ impl Decomposer {
354366
Ok(gs)
355367
}
356368

369+
fn done(&self) -> PyResult<Vec<VecGraph>> {
370+
let mut gs = vec![];
371+
for g in &self.d.done {
372+
gs.push(VecGraph { g: g.clone() });
373+
}
374+
Ok(gs)
375+
}
376+
377+
fn save(&mut self, b: bool) {
378+
self.d.save(b);
379+
}
380+
357381
fn apply_optimizations(&mut self, b: bool) {
358382
if b {
359383
self.d.with_simp(quizx::decompose::SimpFunc::FullSimp);
@@ -374,6 +398,9 @@ impl Decomposer {
374398
fn decomp_until_depth(&mut self, depth: usize) {
375399
self.d.decomp_until_depth(depth);
376400
}
401+
fn decomp_parallel(&mut self, depth: usize) {
402+
self.d = self.d.clone().decomp_parallel(depth);
403+
}
377404
fn use_cats(&mut self, b: bool) {
378405
self.d.use_cats(b);
379406
}

quizx/src/json/scalar.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -47,7 +47,7 @@ impl JsonScalar {
4747

4848
// In the Clifford+T case where we have Scalar4, we can extract factors of sqrt(2) directly from the
4949
// coefficients. Since the coefficients are reduced, sqrt(2) is represented as
50-
// [1, 0, +-1, 0], [0, 1, +-1, 0], where the +- lead to phase contributions already extracted in `phase`
50+
// [1, 0, +-1, 0], [0, 1, 0, +-1], where the +- lead to phase contributions already extracted in `phase`
5151
let (power_sqrt2, floatfactor) =
5252
match coeffs.iter_coeffs().collect::<Vec<_>>().as_slice() {
5353
[a, 0, b, 0] | [0, a, 0, b]

0 commit comments

Comments
 (0)