Skip to content

Commit a21e3ee

Browse files
committed
Replace uses of in2out and out2in by a depth-first search rewriter
1 parent 68db6d2 commit a21e3ee

File tree

17 files changed

+86
-101
lines changed

17 files changed

+86
-101
lines changed

pytensor/graph/rewriting/basic.py

Lines changed: 16 additions & 5 deletions
Original file line number | Diff line number | Diff line change
@@ -27,7 +27,12 @@
2727
from pytensor.graph.fg import FunctionGraph, Output
2828
from pytensor.graph.op import Op
2929
from pytensor.graph.rewriting.unify import Var, convert_strs_to_vars
30-
from pytensor.graph.traversal import applys_between, toposort, vars_between
30+
from pytensor.graph.traversal import (
31+
apply_ancestors,
32+
applys_between,
33+
toposort,
34+
vars_between,
35+
)
3136
from pytensor.graph.utils import AssocList, InconsistencyError
3237
from pytensor.misc.ordered_set import OrderedSet
3338
from pytensor.utils import flatten
@@ -1819,12 +1824,13 @@ class WalkingGraphRewriter(NodeProcessingGraphRewriter):
18191824
def __init__(
18201825
self,
18211826
node_rewriter: NodeRewriter,
1822-
order: Literal["out_to_in", "in_to_out"] = "in_to_out",
1827+
order: Literal["out_to_in", "in_to_out", "dfs"] = "in_to_out",
18231828
ignore_newtrees: bool = False,
18241829
failure_callback: FailureCallbackType | None = None,
18251830
):
1826-
if order not in ("out_to_in", "in_to_out"):
1827-
raise ValueError("order must be 'out_to_in' or 'in_to_out'")
1831+
valid_orders = ("out_to_in", "in_to_out", "dfs")
1832+
if order not in valid_orders:
1833+
raise ValueError(f"order must be one of {valid_orders}, got {order}")
18281834
self.order = order
18291835
super().__init__(node_rewriter, ignore_newtrees, failure_callback)
18301836

@@ -1834,7 +1840,11 @@ def apply(self, fgraph, start_from=None):
18341840
callback_before = fgraph.execute_callbacks_time
18351841
nb_nodes_start = len(fgraph.apply_nodes)
18361842
t0 = time.perf_counter()
1837-
q = deque(toposort(start_from))
1843+
q = deque(
1844+
apply_ancestors(start_from)
1845+
if (self.order == "dfs")
1846+
else toposort(start_from)
1847+
)
18381848
io_t = time.perf_counter() - t0
18391849

18401850
def importer(node):
@@ -1957,6 +1967,7 @@ def walking_rewriter(
19571967

19581968
in2out = partial(walking_rewriter, "in_to_out")
19591969
out2in = partial(walking_rewriter, "out_to_in")
1970+
dfs_rewriter = partial(walking_rewriter, "dfs")
19601971

19611972

19621973
class ChangeTracker(Feature):

pytensor/scan/rewriting.py

Lines changed: 14 additions & 14 deletions
Original file line number | Diff line number | Diff line change
@@ -28,7 +28,7 @@
2828
from pytensor.graph.rewriting.basic import (
2929
GraphRewriter,
3030
copy_stack_trace,
31-
in2out,
31+
dfs_rewriter,
3232
node_rewriter,
3333
)
3434
from pytensor.graph.rewriting.db import EquilibriumDB, SequenceDB
@@ -2548,15 +2548,15 @@ def scan_push_out_dot1(fgraph, node):
25482548
# ScanSaveMem should execute only once per node.
25492549
optdb.register(
25502550
"scan_save_mem_prealloc",
2551-
in2out(scan_save_mem_prealloc, ignore_newtrees=True),
2551+
dfs_rewriter(scan_save_mem_prealloc, ignore_newtrees=True),
25522552
"fast_run",
25532553
"scan",
25542554
"scan_save_mem",
25552555
position=1.61,
25562556
)
25572557
optdb.register(
25582558
"scan_save_mem_no_prealloc",
2559-
in2out(scan_save_mem_no_prealloc, ignore_newtrees=True),
2559+
dfs_rewriter(scan_save_mem_no_prealloc, ignore_newtrees=True),
25602560
"numba",
25612561
"jax",
25622562
"pytorch",
@@ -2577,7 +2577,7 @@ def scan_push_out_dot1(fgraph, node):
25772577

25782578
scan_seqopt1.register(
25792579
"scan_remove_constants_and_unused_inputs0",
2580-
in2out(remove_constants_and_unused_inputs_scan, ignore_newtrees=True),
2580+
dfs_rewriter(remove_constants_and_unused_inputs_scan, ignore_newtrees=True),
25812581
"remove_constants_and_unused_inputs_scan",
25822582
"fast_run",
25832583
"scan",
@@ -2586,7 +2586,7 @@ def scan_push_out_dot1(fgraph, node):
25862586

25872587
scan_seqopt1.register(
25882588
"scan_push_out_non_seq",
2589-
in2out(scan_push_out_non_seq, ignore_newtrees=True),
2589+
dfs_rewriter(scan_push_out_non_seq, ignore_newtrees=True),
25902590
"scan_pushout_nonseqs_ops", # For backcompat: so it can be tagged with old name
25912591
"fast_run",
25922592
"scan",
@@ -2596,7 +2596,7 @@ def scan_push_out_dot1(fgraph, node):
25962596

25972597
scan_seqopt1.register(
25982598
"scan_push_out_seq",
2599-
in2out(scan_push_out_seq, ignore_newtrees=True),
2599+
dfs_rewriter(scan_push_out_seq, ignore_newtrees=True),
26002600
"scan_pushout_seqs_ops", # For backcompat: so it can be tagged with old name
26012601
"fast_run",
26022602
"scan",
@@ -2607,7 +2607,7 @@ def scan_push_out_dot1(fgraph, node):
26072607

26082608
scan_seqopt1.register(
26092609
"scan_push_out_dot1",
2610-
in2out(scan_push_out_dot1, ignore_newtrees=True),
2610+
dfs_rewriter(scan_push_out_dot1, ignore_newtrees=True),
26112611
"scan_pushout_dot1", # For backcompat: so it can be tagged with old name
26122612
"fast_run",
26132613
"more_mem",
@@ -2620,7 +2620,7 @@ def scan_push_out_dot1(fgraph, node):
26202620
scan_seqopt1.register(
26212621
"scan_push_out_add",
26222622
# TODO: Perhaps this should be an `EquilibriumGraphRewriter`?
2623-
in2out(scan_push_out_add, ignore_newtrees=False),
2623+
dfs_rewriter(scan_push_out_add, ignore_newtrees=False),
26242624
"scan_pushout_add", # For backcompat: so it can be tagged with old name
26252625
"fast_run",
26262626
"more_mem",
@@ -2631,22 +2631,22 @@ def scan_push_out_dot1(fgraph, node):
26312631

26322632
scan_eqopt2.register(
26332633
"while_scan_merge_subtensor_last_element",
2634-
in2out(while_scan_merge_subtensor_last_element, ignore_newtrees=True),
2634+
dfs_rewriter(while_scan_merge_subtensor_last_element, ignore_newtrees=True),
26352635
"fast_run",
26362636
"scan",
26372637
)
26382638

26392639
scan_eqopt2.register(
26402640
"constant_folding_for_scan2",
2641-
in2out(constant_folding, ignore_newtrees=True),
2641+
dfs_rewriter(constant_folding, ignore_newtrees=True),
26422642
"fast_run",
26432643
"scan",
26442644
)
26452645

26462646

26472647
scan_eqopt2.register(
26482648
"scan_remove_constants_and_unused_inputs1",
2649-
in2out(remove_constants_and_unused_inputs_scan, ignore_newtrees=True),
2649+
dfs_rewriter(remove_constants_and_unused_inputs_scan, ignore_newtrees=True),
26502650
"remove_constants_and_unused_inputs_scan",
26512651
"fast_run",
26522652
"scan",
@@ -2661,23 +2661,23 @@ def scan_push_out_dot1(fgraph, node):
26612661
# After Merge optimization
26622662
scan_eqopt2.register(
26632663
"scan_remove_constants_and_unused_inputs2",
2664-
in2out(remove_constants_and_unused_inputs_scan, ignore_newtrees=True),
2664+
dfs_rewriter(remove_constants_and_unused_inputs_scan, ignore_newtrees=True),
26652665
"remove_constants_and_unused_inputs_scan",
26662666
"fast_run",
26672667
"scan",
26682668
)
26692669

26702670
scan_eqopt2.register(
26712671
"scan_merge_inouts",
2672-
in2out(scan_merge_inouts, ignore_newtrees=True),
2672+
dfs_rewriter(scan_merge_inouts, ignore_newtrees=True),
26732673
"fast_run",
26742674
"scan",
26752675
)
26762676

26772677
# After everything else
26782678
scan_eqopt2.register(
26792679
"scan_remove_constants_and_unused_inputs3",
2680-
in2out(remove_constants_and_unused_inputs_scan, ignore_newtrees=True),
2680+
dfs_rewriter(remove_constants_and_unused_inputs_scan, ignore_newtrees=True),
26812681
"remove_constants_and_unused_inputs_scan",
26822682
"fast_run",
26832683
"scan",

pytensor/tensor/_linalg/solve/rewriting.py

Lines changed: 6 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -3,7 +3,7 @@
33

44
from pytensor.compile import optdb
55
from pytensor.graph import Constant, graph_inputs
6-
from pytensor.graph.rewriting.basic import copy_stack_trace, in2out, node_rewriter
6+
from pytensor.graph.rewriting.basic import copy_stack_trace, dfs_rewriter, node_rewriter
77
from pytensor.scan.op import Scan
88
from pytensor.scan.rewriting import scan_seqopt1
99
from pytensor.tensor._linalg.solve.tridiagonal import (
@@ -243,7 +243,7 @@ def scan_split_non_sequence_decomposition_and_solve(fgraph, node):
243243

244244
scan_seqopt1.register(
245245
scan_split_non_sequence_decomposition_and_solve.__name__,
246-
in2out(scan_split_non_sequence_decomposition_and_solve, ignore_newtrees=True),
246+
dfs_rewriter(scan_split_non_sequence_decomposition_and_solve, ignore_newtrees=True),
247247
"fast_run",
248248
"scan",
249249
"scan_pushout",
@@ -260,7 +260,7 @@ def reuse_decomposition_multiple_solves_jax(fgraph, node):
260260

261261
optdb["specialize"].register(
262262
reuse_decomposition_multiple_solves_jax.__name__,
263-
in2out(reuse_decomposition_multiple_solves_jax, ignore_newtrees=True),
263+
dfs_rewriter(reuse_decomposition_multiple_solves_jax, ignore_newtrees=True),
264264
"jax",
265265
use_db_name_as_tag=False,
266266
)
@@ -275,7 +275,9 @@ def scan_split_non_sequence_decomposition_and_solve_jax(fgraph, node):
275275

276276
scan_seqopt1.register(
277277
scan_split_non_sequence_decomposition_and_solve_jax.__name__,
278-
in2out(scan_split_non_sequence_decomposition_and_solve_jax, ignore_newtrees=True),
278+
dfs_rewriter(
279+
scan_split_non_sequence_decomposition_and_solve_jax, ignore_newtrees=True
280+
),
279281
"jax",
280282
use_db_name_as_tag=False,
281283
position=2,

pytensor/tensor/random/rewriting/basic.py

Lines changed: 6 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -4,7 +4,11 @@
44
from pytensor.configdefaults import config
55
from pytensor.graph import ancestors
66
from pytensor.graph.op import compute_test_value
7-
from pytensor.graph.rewriting.basic import copy_stack_trace, in2out, node_rewriter
7+
from pytensor.graph.rewriting.basic import (
8+
copy_stack_trace,
9+
dfs_rewriter,
10+
node_rewriter,
11+
)
812
from pytensor.tensor import NoneConst, TensorVariable
913
from pytensor.tensor.basic import constant
1014
from pytensor.tensor.elemwise import DimShuffle
@@ -57,7 +61,7 @@ def random_make_inplace(fgraph, node):
5761

5862
optdb.register(
5963
"random_make_inplace",
60-
in2out(random_make_inplace, ignore_newtrees=True),
64+
dfs_rewriter(random_make_inplace, ignore_newtrees=True),
6165
"fast_run",
6266
"inplace",
6367
position=50.9,

pytensor/tensor/random/rewriting/jax.py

Lines changed: 11 additions & 47 deletions
Original file line number | Diff line number | Diff line change
@@ -2,8 +2,7 @@
22

33
from pytensor.compile import optdb
44
from pytensor.graph import Constant
5-
from pytensor.graph.rewriting.basic import in2out, node_rewriter
6-
from pytensor.graph.rewriting.db import SequenceDB
5+
from pytensor.graph.rewriting.basic import dfs_rewriter, in2out, node_rewriter
76
from pytensor.tensor import abs as abs_t
87
from pytensor.tensor import broadcast_arrays, exp, floor, log, log1p, reciprocal, sqrt
98
from pytensor.tensor.basic import (
@@ -179,51 +178,16 @@ def materialize_implicit_arange_choice_without_replacement(fgraph, node):
179178
return new_op.make_node(rng, size, a_vector_param, *other_params).outputs
180179

181180

182-
random_vars_opt = SequenceDB()
183-
random_vars_opt.register(
184-
"lognormal_from_normal",
185-
in2out(lognormal_from_normal),
186-
"jax",
187-
)
188-
random_vars_opt.register(
189-
"halfnormal_from_normal",
190-
in2out(halfnormal_from_normal),
191-
"jax",
192-
)
193-
random_vars_opt.register(
194-
"geometric_from_uniform",
195-
in2out(geometric_from_uniform),
196-
"jax",
197-
)
198-
random_vars_opt.register(
199-
"negative_binomial_from_gamma_poisson",
200-
in2out(negative_binomial_from_gamma_poisson),
201-
"jax",
202-
)
203-
random_vars_opt.register(
204-
"inverse_gamma_from_gamma",
205-
in2out(inverse_gamma_from_gamma),
206-
"jax",
207-
)
208-
random_vars_opt.register(
209-
"generalized_gamma_from_gamma",
210-
in2out(generalized_gamma_from_gamma),
211-
"jax",
212-
)
213-
random_vars_opt.register(
214-
"wald_from_normal_uniform",
215-
in2out(wald_from_normal_uniform),
216-
"jax",
217-
)
218-
random_vars_opt.register(
219-
"beta_binomial_from_beta_binomial",
220-
in2out(beta_binomial_from_beta_binomial),
221-
"jax",
222-
)
223-
random_vars_opt.register(
224-
"materialize_implicit_arange_choice_without_replacement",
225-
in2out(materialize_implicit_arange_choice_without_replacement),
226-
"jax",
181+
random_vars_opt = dfs_rewriter(
182+
lognormal_from_normal,
183+
halfnormal_from_normal,
184+
geometric_from_uniform,
185+
negative_binomial_from_gamma_poisson,
186+
inverse_gamma_from_gamma,
187+
generalized_gamma_from_gamma,
188+
wald_from_normal_uniform,
189+
beta_binomial_from_beta_binomial,
190+
materialize_implicit_arange_choice_without_replacement,
227191
)
228192
optdb.register("jax_random_vars_rewrites", random_vars_opt, "jax", position=110)
229193

pytensor/tensor/random/rewriting/numba.py

Lines changed: 2 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -1,6 +1,6 @@
11
from pytensor.compile import optdb
22
from pytensor.graph import node_rewriter
3-
from pytensor.graph.rewriting.basic import out2in
3+
from pytensor.graph.rewriting.basic import dfs_rewriter
44
from pytensor.tensor import as_tensor, constant
55
from pytensor.tensor.random.op import RandomVariable, RandomVariableWithCoreShape
66
from pytensor.tensor.rewriting.shape import ShapeFeature
@@ -82,7 +82,7 @@ def introduce_explicit_core_shape_rv(fgraph, node):
8282

8383
optdb.register(
8484
introduce_explicit_core_shape_rv.__name__,
85-
out2in(introduce_explicit_core_shape_rv),
85+
dfs_rewriter(introduce_explicit_core_shape_rv),
8686
"numba",
8787
position=100,
8888
)

pytensor/tensor/rewriting/basic.py

Lines changed: 2 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -36,6 +36,7 @@
3636
NodeRewriter,
3737
Rewriter,
3838
copy_stack_trace,
39+
dfs_rewriter,
3940
in2out,
4041
node_rewriter,
4142
)
@@ -518,7 +519,7 @@ def local_alloc_empty_to_zeros(fgraph, node):
518519

519520
compile.optdb.register(
520521
"local_alloc_empty_to_zeros",
521-
in2out(local_alloc_empty_to_zeros),
522+
dfs_rewriter(local_alloc_empty_to_zeros),
522523
# After move to gpu and merge2, before inplace.
523524
"alloc_empty_to_zeros",
524525
position=49.3,

pytensor/tensor/rewriting/blas.py

Lines changed: 4 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -77,7 +77,7 @@
7777
EquilibriumGraphRewriter,
7878
GraphRewriter,
7979
copy_stack_trace,
80-
in2out,
80+
dfs_rewriter,
8181
node_rewriter,
8282
)
8383
from pytensor.graph.rewriting.db import SequenceDB
@@ -721,7 +721,7 @@ def local_dot22_to_ger_or_gemv(fgraph, node):
721721
# fast_compile is needed to have GpuDot22 created.
722722
blas_optdb.register(
723723
"local_dot_to_dot22",
724-
in2out(local_dot_to_dot22),
724+
dfs_rewriter(local_dot_to_dot22),
725725
"fast_run",
726726
"fast_compile",
727727
position=0,
@@ -744,7 +744,7 @@ def local_dot22_to_ger_or_gemv(fgraph, node):
744744
)
745745

746746

747-
blas_opt_inplace = in2out(
747+
blas_opt_inplace = dfs_rewriter(
748748
local_inplace_gemm, local_inplace_gemv, local_inplace_ger, name="blas_opt_inplace"
749749
)
750750
optdb.register(
@@ -883,7 +883,7 @@ def local_dot22_to_dot22scalar(fgraph, node):
883883
# dot22scalar and gemm give more speed up than dot22scalar
884884
blas_optdb.register(
885885
"local_dot22_to_dot22scalar",
886-
in2out(local_dot22_to_dot22scalar),
886+
dfs_rewriter(local_dot22_to_dot22scalar),
887887
"fast_run",
888888
position=12,
889889
)

0 commit comments

Comments
 (0)