28
28
from pytensor .graph .rewriting .basic import (
29
29
GraphRewriter ,
30
30
copy_stack_trace ,
31
- in2out ,
31
+ dfs_rewriter ,
32
32
node_rewriter ,
33
33
)
34
34
from pytensor .graph .rewriting .db import EquilibriumDB , SequenceDB
@@ -2548,15 +2548,15 @@ def scan_push_out_dot1(fgraph, node):
2548
2548
# ScanSaveMem should execute only once per node.
2549
2549
optdb .register (
2550
2550
"scan_save_mem_prealloc" ,
2551
- in2out (scan_save_mem_prealloc , ignore_newtrees = True ),
2551
+ dfs_rewriter (scan_save_mem_prealloc , ignore_newtrees = True ),
2552
2552
"fast_run" ,
2553
2553
"scan" ,
2554
2554
"scan_save_mem" ,
2555
2555
position = 1.61 ,
2556
2556
)
2557
2557
optdb .register (
2558
2558
"scan_save_mem_no_prealloc" ,
2559
- in2out (scan_save_mem_no_prealloc , ignore_newtrees = True ),
2559
+ dfs_rewriter (scan_save_mem_no_prealloc , ignore_newtrees = True ),
2560
2560
"numba" ,
2561
2561
"jax" ,
2562
2562
"pytorch" ,
@@ -2577,7 +2577,7 @@ def scan_push_out_dot1(fgraph, node):
2577
2577
2578
2578
scan_seqopt1 .register (
2579
2579
"scan_remove_constants_and_unused_inputs0" ,
2580
- in2out (remove_constants_and_unused_inputs_scan , ignore_newtrees = True ),
2580
+ dfs_rewriter (remove_constants_and_unused_inputs_scan , ignore_newtrees = True ),
2581
2581
"remove_constants_and_unused_inputs_scan" ,
2582
2582
"fast_run" ,
2583
2583
"scan" ,
@@ -2586,7 +2586,7 @@ def scan_push_out_dot1(fgraph, node):
2586
2586
2587
2587
scan_seqopt1 .register (
2588
2588
"scan_push_out_non_seq" ,
2589
- in2out (scan_push_out_non_seq , ignore_newtrees = True ),
2589
+ dfs_rewriter (scan_push_out_non_seq , ignore_newtrees = True ),
2590
2590
"scan_pushout_nonseqs_ops" , # For backcompat: so it can be tagged with old name
2591
2591
"fast_run" ,
2592
2592
"scan" ,
@@ -2596,7 +2596,7 @@ def scan_push_out_dot1(fgraph, node):
2596
2596
2597
2597
scan_seqopt1 .register (
2598
2598
"scan_push_out_seq" ,
2599
- in2out (scan_push_out_seq , ignore_newtrees = True ),
2599
+ dfs_rewriter (scan_push_out_seq , ignore_newtrees = True ),
2600
2600
"scan_pushout_seqs_ops" , # For backcompat: so it can be tagged with old name
2601
2601
"fast_run" ,
2602
2602
"scan" ,
@@ -2607,7 +2607,7 @@ def scan_push_out_dot1(fgraph, node):
2607
2607
2608
2608
scan_seqopt1 .register (
2609
2609
"scan_push_out_dot1" ,
2610
- in2out (scan_push_out_dot1 , ignore_newtrees = True ),
2610
+ dfs_rewriter (scan_push_out_dot1 , ignore_newtrees = True ),
2611
2611
"scan_pushout_dot1" , # For backcompat: so it can be tagged with old name
2612
2612
"fast_run" ,
2613
2613
"more_mem" ,
@@ -2620,7 +2620,7 @@ def scan_push_out_dot1(fgraph, node):
2620
2620
scan_seqopt1 .register (
2621
2621
"scan_push_out_add" ,
2622
2622
# TODO: Perhaps this should be an `EquilibriumGraphRewriter`?
2623
- in2out (scan_push_out_add , ignore_newtrees = False ),
2623
+ dfs_rewriter (scan_push_out_add , ignore_newtrees = False ),
2624
2624
"scan_pushout_add" , # For backcompat: so it can be tagged with old name
2625
2625
"fast_run" ,
2626
2626
"more_mem" ,
@@ -2631,22 +2631,22 @@ def scan_push_out_dot1(fgraph, node):
2631
2631
2632
2632
scan_eqopt2 .register (
2633
2633
"while_scan_merge_subtensor_last_element" ,
2634
- in2out (while_scan_merge_subtensor_last_element , ignore_newtrees = True ),
2634
+ dfs_rewriter (while_scan_merge_subtensor_last_element , ignore_newtrees = True ),
2635
2635
"fast_run" ,
2636
2636
"scan" ,
2637
2637
)
2638
2638
2639
2639
scan_eqopt2 .register (
2640
2640
"constant_folding_for_scan2" ,
2641
- in2out (constant_folding , ignore_newtrees = True ),
2641
+ dfs_rewriter (constant_folding , ignore_newtrees = True ),
2642
2642
"fast_run" ,
2643
2643
"scan" ,
2644
2644
)
2645
2645
2646
2646
2647
2647
scan_eqopt2 .register (
2648
2648
"scan_remove_constants_and_unused_inputs1" ,
2649
- in2out (remove_constants_and_unused_inputs_scan , ignore_newtrees = True ),
2649
+ dfs_rewriter (remove_constants_and_unused_inputs_scan , ignore_newtrees = True ),
2650
2650
"remove_constants_and_unused_inputs_scan" ,
2651
2651
"fast_run" ,
2652
2652
"scan" ,
@@ -2661,23 +2661,23 @@ def scan_push_out_dot1(fgraph, node):
2661
2661
# After Merge optimization
2662
2662
scan_eqopt2 .register (
2663
2663
"scan_remove_constants_and_unused_inputs2" ,
2664
- in2out (remove_constants_and_unused_inputs_scan , ignore_newtrees = True ),
2664
+ dfs_rewriter (remove_constants_and_unused_inputs_scan , ignore_newtrees = True ),
2665
2665
"remove_constants_and_unused_inputs_scan" ,
2666
2666
"fast_run" ,
2667
2667
"scan" ,
2668
2668
)
2669
2669
2670
2670
scan_eqopt2 .register (
2671
2671
"scan_merge_inouts" ,
2672
- in2out (scan_merge_inouts , ignore_newtrees = True ),
2672
+ dfs_rewriter (scan_merge_inouts , ignore_newtrees = True ),
2673
2673
"fast_run" ,
2674
2674
"scan" ,
2675
2675
)
2676
2676
2677
2677
# After everything else
2678
2678
scan_eqopt2 .register (
2679
2679
"scan_remove_constants_and_unused_inputs3" ,
2680
- in2out (remove_constants_and_unused_inputs_scan , ignore_newtrees = True ),
2680
+ dfs_rewriter (remove_constants_and_unused_inputs_scan , ignore_newtrees = True ),
2681
2681
"remove_constants_and_unused_inputs_scan" ,
2682
2682
"fast_run" ,
2683
2683
"scan" ,
0 commit comments