29
29
EquilibriumGraphRewriter ,
30
30
GraphRewriter ,
31
31
copy_stack_trace ,
32
- in2out ,
32
+ dfs_rewriter ,
33
33
node_rewriter ,
34
34
)
35
35
from pytensor .graph .rewriting .db import EquilibriumDB , SequenceDB
@@ -2558,15 +2558,15 @@ def apply(self, fgraph, start_from=None):
2558
2558
# ScanSaveMem should execute only once per node.
2559
2559
optdb .register (
2560
2560
"scan_save_mem_prealloc" ,
2561
- in2out (scan_save_mem_prealloc , ignore_newtrees = True ),
2561
+ dfs_rewriter (scan_save_mem_prealloc , ignore_newtrees = True ),
2562
2562
"fast_run" ,
2563
2563
"scan" ,
2564
2564
"scan_save_mem" ,
2565
2565
position = 1.61 ,
2566
2566
)
2567
2567
optdb .register (
2568
2568
"scan_save_mem_no_prealloc" ,
2569
- in2out (scan_save_mem_no_prealloc , ignore_newtrees = True ),
2569
+ dfs_rewriter (scan_save_mem_no_prealloc , ignore_newtrees = True ),
2570
2570
"numba" ,
2571
2571
"jax" ,
2572
2572
"pytorch" ,
@@ -2587,7 +2587,7 @@ def apply(self, fgraph, start_from=None):
2587
2587
2588
2588
scan_seqopt1 .register (
2589
2589
"scan_remove_constants_and_unused_inputs0" ,
2590
- in2out (remove_constants_and_unused_inputs_scan , ignore_newtrees = True ),
2590
+ dfs_rewriter (remove_constants_and_unused_inputs_scan , ignore_newtrees = True ),
2591
2591
"remove_constants_and_unused_inputs_scan" ,
2592
2592
"fast_run" ,
2593
2593
"scan" ,
@@ -2596,7 +2596,7 @@ def apply(self, fgraph, start_from=None):
2596
2596
2597
2597
scan_seqopt1 .register (
2598
2598
"scan_push_out_non_seq" ,
2599
- in2out (scan_push_out_non_seq , ignore_newtrees = True ),
2599
+ dfs_rewriter (scan_push_out_non_seq , ignore_newtrees = True ),
2600
2600
"scan_pushout_nonseqs_ops" , # For backcompat: so it can be tagged with old name
2601
2601
"fast_run" ,
2602
2602
"scan" ,
@@ -2606,7 +2606,7 @@ def apply(self, fgraph, start_from=None):
2606
2606
2607
2607
scan_seqopt1 .register (
2608
2608
"scan_push_out_seq" ,
2609
- in2out (scan_push_out_seq , ignore_newtrees = True ),
2609
+ dfs_rewriter (scan_push_out_seq , ignore_newtrees = True ),
2610
2610
"scan_pushout_seqs_ops" , # For backcompat: so it can be tagged with old name
2611
2611
"fast_run" ,
2612
2612
"scan" ,
@@ -2617,7 +2617,7 @@ def apply(self, fgraph, start_from=None):
2617
2617
2618
2618
scan_seqopt1 .register (
2619
2619
"scan_push_out_dot1" ,
2620
- in2out (scan_push_out_dot1 , ignore_newtrees = True ),
2620
+ dfs_rewriter (scan_push_out_dot1 , ignore_newtrees = True ),
2621
2621
"scan_pushout_dot1" , # For backcompat: so it can be tagged with old name
2622
2622
"fast_run" ,
2623
2623
"more_mem" ,
@@ -2630,7 +2630,7 @@ def apply(self, fgraph, start_from=None):
2630
2630
scan_seqopt1 .register (
2631
2631
"scan_push_out_add" ,
2632
2632
# TODO: Perhaps this should be an `EquilibriumGraphRewriter`?
2633
- in2out (scan_push_out_add , ignore_newtrees = False ),
2633
+ dfs_rewriter (scan_push_out_add , ignore_newtrees = False ),
2634
2634
"scan_pushout_add" , # For backcompat: so it can be tagged with old name
2635
2635
"fast_run" ,
2636
2636
"more_mem" ,
@@ -2641,22 +2641,22 @@ def apply(self, fgraph, start_from=None):
2641
2641
2642
2642
scan_eqopt2 .register (
2643
2643
"while_scan_merge_subtensor_last_element" ,
2644
- in2out (while_scan_merge_subtensor_last_element , ignore_newtrees = True ),
2644
+ dfs_rewriter (while_scan_merge_subtensor_last_element , ignore_newtrees = True ),
2645
2645
"fast_run" ,
2646
2646
"scan" ,
2647
2647
)
2648
2648
2649
2649
scan_eqopt2 .register (
2650
2650
"constant_folding_for_scan2" ,
2651
- in2out (constant_folding , ignore_newtrees = True ),
2651
+ dfs_rewriter (constant_folding , ignore_newtrees = True ),
2652
2652
"fast_run" ,
2653
2653
"scan" ,
2654
2654
)
2655
2655
2656
2656
2657
2657
scan_eqopt2 .register (
2658
2658
"scan_remove_constants_and_unused_inputs1" ,
2659
- in2out (remove_constants_and_unused_inputs_scan , ignore_newtrees = True ),
2659
+ dfs_rewriter (remove_constants_and_unused_inputs_scan , ignore_newtrees = True ),
2660
2660
"remove_constants_and_unused_inputs_scan" ,
2661
2661
"fast_run" ,
2662
2662
"scan" ,
@@ -2671,23 +2671,23 @@ def apply(self, fgraph, start_from=None):
2671
2671
# After Merge optimization
2672
2672
scan_eqopt2 .register (
2673
2673
"scan_remove_constants_and_unused_inputs2" ,
2674
- in2out (remove_constants_and_unused_inputs_scan , ignore_newtrees = True ),
2674
+ dfs_rewriter (remove_constants_and_unused_inputs_scan , ignore_newtrees = True ),
2675
2675
"remove_constants_and_unused_inputs_scan" ,
2676
2676
"fast_run" ,
2677
2677
"scan" ,
2678
2678
)
2679
2679
2680
2680
scan_eqopt2 .register (
2681
2681
"scan_merge_inouts" ,
2682
- in2out (scan_merge_inouts , ignore_newtrees = True ),
2682
+ dfs_rewriter (scan_merge_inouts , ignore_newtrees = True ),
2683
2683
"fast_run" ,
2684
2684
"scan" ,
2685
2685
)
2686
2686
2687
2687
# After everything else
2688
2688
scan_eqopt2 .register (
2689
2689
"scan_remove_constants_and_unused_inputs3" ,
2690
- in2out (remove_constants_and_unused_inputs_scan , ignore_newtrees = True ),
2690
+ dfs_rewriter (remove_constants_and_unused_inputs_scan , ignore_newtrees = True ),
2691
2691
"remove_constants_and_unused_inputs_scan" ,
2692
2692
"fast_run" ,
2693
2693
"scan" ,
0 commit comments