Skip to content

Commit 262bd26

Browse files
committed
Replace uses of in2out and out2in by bfs_rewriter
1 parent 2641518 commit 262bd26

File tree

19 files changed: +93 additions, -106 deletions

pytensor/graph/rewriting/basic.py

Lines changed: 15 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,12 @@
2727
from pytensor.graph.fg import FunctionGraph, Output
2828
from pytensor.graph.op import Op
2929
from pytensor.graph.rewriting.unify import Var, convert_strs_to_vars
30-
from pytensor.graph.traversal import apply_toposort, applys_between, vars_between
30+
from pytensor.graph.traversal import (
31+
apply_ancestors,
32+
apply_toposort,
33+
applys_between,
34+
vars_between,
35+
)
3136
from pytensor.graph.utils import AssocList, InconsistencyError
3237
from pytensor.misc.ordered_set import OrderedSet
3338
from pytensor.utils import flatten
@@ -1819,12 +1824,13 @@ class WalkingGraphRewriter(NodeProcessingGraphRewriter):
18191824
def __init__(
18201825
self,
18211826
node_rewriter: NodeRewriter,
1822-
order: Literal["out_to_in", "in_to_out"] = "in_to_out",
1827+
order: Literal["out_to_in", "in_to_out", "bfs"] = "in_to_out",
18231828
ignore_newtrees: bool = False,
18241829
failure_callback: FailureCallbackType | None = None,
18251830
):
1826-
if order not in ("out_to_in", "in_to_out"):
1827-
raise ValueError("order must be 'out_to_in' or 'in_to_out'")
1831+
valid_orders = ("out_to_in", "in_to_out", "bfs")
1832+
if order not in valid_orders:
1833+
raise ValueError(f"order must be one of {valid_orders}, got {order}")
18281834
self.order = order
18291835
super().__init__(node_rewriter, ignore_newtrees, failure_callback)
18301836

@@ -1834,7 +1840,10 @@ def apply(self, fgraph, start_from=None):
18341840
callback_before = fgraph.execute_callbacks_time
18351841
nb_nodes_start = len(fgraph.apply_nodes)
18361842
t0 = time.perf_counter()
1837-
q = deque(apply_toposort(output_nodes=(o.owner for o in start_from)))
1843+
if self.order == "bfs":
1844+
q = deque(apply_ancestors(o.owner for o in start_from))
1845+
else:
1846+
q = deque(apply_toposort(o.owner for o in start_from))
18381847
io_t = time.perf_counter() - t0
18391848

18401849
def importer(node):
@@ -1957,6 +1966,7 @@ def walking_rewriter(
19571966

19581967
in2out = partial(walking_rewriter, "in_to_out")
19591968
out2in = partial(walking_rewriter, "out_to_in")
1969+
bfs_rewriter = partial(walking_rewriter, "bfs")
19601970

19611971

19621972
class ChangeTracker(Feature):

pytensor/scan/rewriting.py

Lines changed: 14 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -27,8 +27,8 @@
2727
from pytensor.graph.replace import clone_replace
2828
from pytensor.graph.rewriting.basic import (
2929
GraphRewriter,
30+
bfs_rewriter,
3031
copy_stack_trace,
31-
in2out,
3232
node_rewriter,
3333
)
3434
from pytensor.graph.rewriting.db import EquilibriumDB, SequenceDB
@@ -2534,15 +2534,15 @@ def scan_push_out_dot1(fgraph, node):
25342534
# ScanSaveMem should execute only once per node.
25352535
optdb.register(
25362536
"scan_save_mem_prealloc",
2537-
in2out(scan_save_mem_prealloc, ignore_newtrees=True),
2537+
bfs_rewriter(scan_save_mem_prealloc, ignore_newtrees=True),
25382538
"fast_run",
25392539
"scan",
25402540
"scan_save_mem",
25412541
position=1.61,
25422542
)
25432543
optdb.register(
25442544
"scan_save_mem_no_prealloc",
2545-
in2out(scan_save_mem_no_prealloc, ignore_newtrees=True),
2545+
bfs_rewriter(scan_save_mem_no_prealloc, ignore_newtrees=True),
25462546
"numba",
25472547
"jax",
25482548
"pytorch",
@@ -2563,7 +2563,7 @@ def scan_push_out_dot1(fgraph, node):
25632563

25642564
scan_seqopt1.register(
25652565
"scan_remove_constants_and_unused_inputs0",
2566-
in2out(remove_constants_and_unused_inputs_scan, ignore_newtrees=True),
2566+
bfs_rewriter(remove_constants_and_unused_inputs_scan, ignore_newtrees=True),
25672567
"remove_constants_and_unused_inputs_scan",
25682568
"fast_run",
25692569
"scan",
@@ -2572,7 +2572,7 @@ def scan_push_out_dot1(fgraph, node):
25722572

25732573
scan_seqopt1.register(
25742574
"scan_push_out_non_seq",
2575-
in2out(scan_push_out_non_seq, ignore_newtrees=True),
2575+
bfs_rewriter(scan_push_out_non_seq, ignore_newtrees=True),
25762576
"scan_pushout_nonseqs_ops", # For backcompat: so it can be tagged with old name
25772577
"fast_run",
25782578
"scan",
@@ -2582,7 +2582,7 @@ def scan_push_out_dot1(fgraph, node):
25822582

25832583
scan_seqopt1.register(
25842584
"scan_push_out_seq",
2585-
in2out(scan_push_out_seq, ignore_newtrees=True),
2585+
bfs_rewriter(scan_push_out_seq, ignore_newtrees=True),
25862586
"scan_pushout_seqs_ops", # For backcompat: so it can be tagged with old name
25872587
"fast_run",
25882588
"scan",
@@ -2593,7 +2593,7 @@ def scan_push_out_dot1(fgraph, node):
25932593

25942594
scan_seqopt1.register(
25952595
"scan_push_out_dot1",
2596-
in2out(scan_push_out_dot1, ignore_newtrees=True),
2596+
bfs_rewriter(scan_push_out_dot1, ignore_newtrees=True),
25972597
"scan_pushout_dot1", # For backcompat: so it can be tagged with old name
25982598
"fast_run",
25992599
"more_mem",
@@ -2606,7 +2606,7 @@ def scan_push_out_dot1(fgraph, node):
26062606
scan_seqopt1.register(
26072607
"scan_push_out_add",
26082608
# TODO: Perhaps this should be an `EquilibriumGraphRewriter`?
2609-
in2out(scan_push_out_add, ignore_newtrees=False),
2609+
bfs_rewriter(scan_push_out_add, ignore_newtrees=False),
26102610
"scan_pushout_add", # For backcompat: so it can be tagged with old name
26112611
"fast_run",
26122612
"more_mem",
@@ -2617,22 +2617,22 @@ def scan_push_out_dot1(fgraph, node):
26172617

26182618
scan_eqopt2.register(
26192619
"while_scan_merge_subtensor_last_element",
2620-
in2out(while_scan_merge_subtensor_last_element, ignore_newtrees=True),
2620+
bfs_rewriter(while_scan_merge_subtensor_last_element, ignore_newtrees=True),
26212621
"fast_run",
26222622
"scan",
26232623
)
26242624

26252625
scan_eqopt2.register(
26262626
"constant_folding_for_scan2",
2627-
in2out(constant_folding, ignore_newtrees=True),
2627+
bfs_rewriter(constant_folding, ignore_newtrees=True),
26282628
"fast_run",
26292629
"scan",
26302630
)
26312631

26322632

26332633
scan_eqopt2.register(
26342634
"scan_remove_constants_and_unused_inputs1",
2635-
in2out(remove_constants_and_unused_inputs_scan, ignore_newtrees=True),
2635+
bfs_rewriter(remove_constants_and_unused_inputs_scan, ignore_newtrees=True),
26362636
"remove_constants_and_unused_inputs_scan",
26372637
"fast_run",
26382638
"scan",
@@ -2647,23 +2647,23 @@ def scan_push_out_dot1(fgraph, node):
26472647
# After Merge optimization
26482648
scan_eqopt2.register(
26492649
"scan_remove_constants_and_unused_inputs2",
2650-
in2out(remove_constants_and_unused_inputs_scan, ignore_newtrees=True),
2650+
bfs_rewriter(remove_constants_and_unused_inputs_scan, ignore_newtrees=True),
26512651
"remove_constants_and_unused_inputs_scan",
26522652
"fast_run",
26532653
"scan",
26542654
)
26552655

26562656
scan_eqopt2.register(
26572657
"scan_merge_inouts",
2658-
in2out(scan_merge_inouts, ignore_newtrees=True),
2658+
bfs_rewriter(scan_merge_inouts, ignore_newtrees=True),
26592659
"fast_run",
26602660
"scan",
26612661
)
26622662

26632663
# After everything else
26642664
scan_eqopt2.register(
26652665
"scan_remove_constants_and_unused_inputs3",
2666-
in2out(remove_constants_and_unused_inputs_scan, ignore_newtrees=True),
2666+
bfs_rewriter(remove_constants_and_unused_inputs_scan, ignore_newtrees=True),
26672667
"remove_constants_and_unused_inputs_scan",
26682668
"fast_run",
26692669
"scan",

pytensor/tensor/_linalg/solve/rewriting.py

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33

44
from pytensor.compile import optdb
55
from pytensor.graph import Constant, graph_inputs
6-
from pytensor.graph.rewriting.basic import copy_stack_trace, in2out, node_rewriter
6+
from pytensor.graph.rewriting.basic import bfs_rewriter, copy_stack_trace, node_rewriter
77
from pytensor.scan.op import Scan
88
from pytensor.scan.rewriting import scan_seqopt1
99
from pytensor.tensor._linalg.solve.tridiagonal import (
@@ -243,7 +243,7 @@ def scan_split_non_sequence_decomposition_and_solve(fgraph, node):
243243

244244
scan_seqopt1.register(
245245
scan_split_non_sequence_decomposition_and_solve.__name__,
246-
in2out(scan_split_non_sequence_decomposition_and_solve, ignore_newtrees=True),
246+
bfs_rewriter(scan_split_non_sequence_decomposition_and_solve, ignore_newtrees=True),
247247
"fast_run",
248248
"scan",
249249
"scan_pushout",
@@ -260,7 +260,7 @@ def reuse_decomposition_multiple_solves_jax(fgraph, node):
260260

261261
optdb["specialize"].register(
262262
reuse_decomposition_multiple_solves_jax.__name__,
263-
in2out(reuse_decomposition_multiple_solves_jax, ignore_newtrees=True),
263+
bfs_rewriter(reuse_decomposition_multiple_solves_jax, ignore_newtrees=True),
264264
"jax",
265265
use_db_name_as_tag=False,
266266
)
@@ -275,7 +275,9 @@ def scan_split_non_sequence_decomposition_and_solve_jax(fgraph, node):
275275

276276
scan_seqopt1.register(
277277
scan_split_non_sequence_decomposition_and_solve_jax.__name__,
278-
in2out(scan_split_non_sequence_decomposition_and_solve_jax, ignore_newtrees=True),
278+
bfs_rewriter(
279+
scan_split_non_sequence_decomposition_and_solve_jax, ignore_newtrees=True
280+
),
279281
"jax",
280282
use_db_name_as_tag=False,
281283
position=2,

pytensor/tensor/random/rewriting/basic.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,11 @@
44
from pytensor.configdefaults import config
55
from pytensor.graph import ancestors
66
from pytensor.graph.op import compute_test_value
7-
from pytensor.graph.rewriting.basic import copy_stack_trace, in2out, node_rewriter
7+
from pytensor.graph.rewriting.basic import (
8+
bfs_rewriter,
9+
copy_stack_trace,
10+
node_rewriter,
11+
)
812
from pytensor.tensor import NoneConst, TensorVariable
913
from pytensor.tensor.basic import constant
1014
from pytensor.tensor.elemwise import DimShuffle
@@ -57,7 +61,7 @@ def random_make_inplace(fgraph, node):
5761

5862
optdb.register(
5963
"random_make_inplace",
60-
in2out(random_make_inplace, ignore_newtrees=True),
64+
bfs_rewriter(random_make_inplace, ignore_newtrees=True),
6165
"fast_run",
6266
"inplace",
6367
position=50.9,

pytensor/tensor/random/rewriting/jax.py

Lines changed: 13 additions & 47 deletions
Original file line numberDiff line numberDiff line change
@@ -2,8 +2,7 @@
22

33
from pytensor.compile import optdb
44
from pytensor.graph import Constant
5-
from pytensor.graph.rewriting.basic import in2out, node_rewriter
6-
from pytensor.graph.rewriting.db import SequenceDB
5+
from pytensor.graph.rewriting.basic import bfs_rewriter, in2out, node_rewriter
76
from pytensor.tensor import abs as abs_t
87
from pytensor.tensor import broadcast_arrays, exp, floor, log, log1p, reciprocal, sqrt
98
from pytensor.tensor.basic import (
@@ -179,51 +178,18 @@ def materialize_implicit_arange_choice_without_replacement(fgraph, node):
179178
return new_op.make_node(rng, size, a_vector_param, *other_params).outputs
180179

181180

182-
random_vars_opt = SequenceDB()
183-
random_vars_opt.register(
184-
"lognormal_from_normal",
185-
in2out(lognormal_from_normal),
186-
"jax",
187-
)
188-
random_vars_opt.register(
189-
"halfnormal_from_normal",
190-
in2out(halfnormal_from_normal),
191-
"jax",
192-
)
193-
random_vars_opt.register(
194-
"geometric_from_uniform",
195-
in2out(geometric_from_uniform),
196-
"jax",
197-
)
198-
random_vars_opt.register(
199-
"negative_binomial_from_gamma_poisson",
200-
in2out(negative_binomial_from_gamma_poisson),
201-
"jax",
202-
)
203-
random_vars_opt.register(
204-
"inverse_gamma_from_gamma",
205-
in2out(inverse_gamma_from_gamma),
206-
"jax",
207-
)
208-
random_vars_opt.register(
209-
"generalized_gamma_from_gamma",
210-
in2out(generalized_gamma_from_gamma),
211-
"jax",
212-
)
213-
random_vars_opt.register(
214-
"wald_from_normal_uniform",
215-
in2out(wald_from_normal_uniform),
216-
"jax",
217-
)
218-
random_vars_opt.register(
219-
"beta_binomial_from_beta_binomial",
220-
in2out(beta_binomial_from_beta_binomial),
221-
"jax",
222-
)
223-
random_vars_opt.register(
224-
"materialize_implicit_arange_choice_without_replacement",
225-
in2out(materialize_implicit_arange_choice_without_replacement),
226-
"jax",
181+
random_vars_opt = bfs_rewriter(
182+
[
183+
lognormal_from_normal,
184+
halfnormal_from_normal,
185+
geometric_from_uniform,
186+
negative_binomial_from_gamma_poisson,
187+
inverse_gamma_from_gamma,
188+
generalized_gamma_from_gamma,
189+
wald_from_normal_uniform,
190+
beta_binomial_from_beta_binomial,
191+
materialize_implicit_arange_choice_without_replacement,
192+
]
227193
)
228194
optdb.register("jax_random_vars_rewrites", random_vars_opt, "jax", position=110)
229195

pytensor/tensor/random/rewriting/numba.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
from pytensor.compile import optdb
22
from pytensor.graph import node_rewriter
3-
from pytensor.graph.rewriting.basic import out2in
3+
from pytensor.graph.rewriting.basic import bfs_rewriter
44
from pytensor.tensor import as_tensor, constant
55
from pytensor.tensor.random.op import RandomVariable, RandomVariableWithCoreShape
66
from pytensor.tensor.rewriting.shape import ShapeFeature
@@ -82,7 +82,7 @@ def introduce_explicit_core_shape_rv(fgraph, node):
8282

8383
optdb.register(
8484
introduce_explicit_core_shape_rv.__name__,
85-
out2in(introduce_explicit_core_shape_rv),
85+
bfs_rewriter(introduce_explicit_core_shape_rv),
8686
"numba",
8787
position=100,
8888
)

pytensor/tensor/rewriting/basic.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,7 @@
3535
NodeProcessingGraphRewriter,
3636
NodeRewriter,
3737
Rewriter,
38+
bfs_rewriter,
3839
copy_stack_trace,
3940
in2out,
4041
node_rewriter,
@@ -518,7 +519,7 @@ def local_alloc_empty_to_zeros(fgraph, node):
518519

519520
compile.optdb.register(
520521
"local_alloc_empty_to_zeros",
521-
in2out(local_alloc_empty_to_zeros),
522+
bfs_rewriter(local_alloc_empty_to_zeros),
522523
# After move to gpu and merge2, before inplace.
523524
"alloc_empty_to_zeros",
524525
position=49.3,

pytensor/tensor/rewriting/blas.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -76,8 +76,8 @@
7676
from pytensor.graph.rewriting.basic import (
7777
EquilibriumGraphRewriter,
7878
GraphRewriter,
79+
bfs_rewriter,
7980
copy_stack_trace,
80-
in2out,
8181
node_rewriter,
8282
)
8383
from pytensor.graph.rewriting.db import SequenceDB
@@ -460,7 +460,8 @@ def apply(self, fgraph):
460460
callbacks_before = fgraph.execute_callbacks_times.copy()
461461
callback_before = fgraph.execute_callbacks_time
462462

463-
nodelist = apply_toposort(o.owner for o in fgraph.outputs)
463+
nodelist = list(apply_toposort(o.owner for o in fgraph.outputs))
464+
nodelist.reverse()
464465

465466
def on_import(new_node):
466467
if new_node is not node:
@@ -475,7 +476,6 @@ def on_import(new_node):
475476
t0 = time.perf_counter()
476477
time_toposort += time.perf_counter() - t0
477478
did_something = False
478-
nodelist.reverse()
479479
for node in nodelist:
480480
if not (
481481
isinstance(node.op, Elemwise)
@@ -721,7 +721,7 @@ def local_dot22_to_ger_or_gemv(fgraph, node):
721721
# fast_compile is needed to have GpuDot22 created.
722722
blas_optdb.register(
723723
"local_dot_to_dot22",
724-
in2out(local_dot_to_dot22),
724+
bfs_rewriter(local_dot_to_dot22),
725725
"fast_run",
726726
"fast_compile",
727727
position=0,
@@ -744,7 +744,7 @@ def local_dot22_to_ger_or_gemv(fgraph, node):
744744
)
745745

746746

747-
blas_opt_inplace = in2out(
747+
blas_opt_inplace = bfs_rewriter(
748748
local_inplace_gemm, local_inplace_gemv, local_inplace_ger, name="blas_opt_inplace"
749749
)
750750
optdb.register(
@@ -883,7 +883,7 @@ def local_dot22_to_dot22scalar(fgraph, node):
883883
# dot22scalar and gemm give more speed up than dot22scalar
884884
blas_optdb.register(
885885
"local_dot22_to_dot22scalar",
886-
in2out(local_dot22_to_dot22scalar),
886+
bfs_rewriter(local_dot22_to_dot22scalar),
887887
"fast_run",
888888
position=12,
889889
)

0 commit comments

Comments
 (0)