27
27
from pytensor .graph .replace import clone_replace
28
28
from pytensor .graph .rewriting .basic import (
29
29
GraphRewriter ,
30
+ bfs_rewriter ,
30
31
copy_stack_trace ,
31
- in2out ,
32
32
node_rewriter ,
33
33
)
34
34
from pytensor .graph .rewriting .db import EquilibriumDB , SequenceDB
@@ -2534,15 +2534,15 @@ def scan_push_out_dot1(fgraph, node):
2534
2534
# ScanSaveMem should execute only once per node.
2535
2535
optdb .register (
2536
2536
"scan_save_mem_prealloc" ,
2537
- in2out (scan_save_mem_prealloc , ignore_newtrees = True ),
2537
+ bfs_rewriter (scan_save_mem_prealloc , ignore_newtrees = True ),
2538
2538
"fast_run" ,
2539
2539
"scan" ,
2540
2540
"scan_save_mem" ,
2541
2541
position = 1.61 ,
2542
2542
)
2543
2543
optdb .register (
2544
2544
"scan_save_mem_no_prealloc" ,
2545
- in2out (scan_save_mem_no_prealloc , ignore_newtrees = True ),
2545
+ bfs_rewriter (scan_save_mem_no_prealloc , ignore_newtrees = True ),
2546
2546
"numba" ,
2547
2547
"jax" ,
2548
2548
"pytorch" ,
@@ -2563,7 +2563,7 @@ def scan_push_out_dot1(fgraph, node):
2563
2563
2564
2564
scan_seqopt1 .register (
2565
2565
"scan_remove_constants_and_unused_inputs0" ,
2566
- in2out (remove_constants_and_unused_inputs_scan , ignore_newtrees = True ),
2566
+ bfs_rewriter (remove_constants_and_unused_inputs_scan , ignore_newtrees = True ),
2567
2567
"remove_constants_and_unused_inputs_scan" ,
2568
2568
"fast_run" ,
2569
2569
"scan" ,
@@ -2572,7 +2572,7 @@ def scan_push_out_dot1(fgraph, node):
2572
2572
2573
2573
scan_seqopt1 .register (
2574
2574
"scan_push_out_non_seq" ,
2575
- in2out (scan_push_out_non_seq , ignore_newtrees = True ),
2575
+ bfs_rewriter (scan_push_out_non_seq , ignore_newtrees = True ),
2576
2576
"scan_pushout_nonseqs_ops" , # For backcompat: so it can be tagged with old name
2577
2577
"fast_run" ,
2578
2578
"scan" ,
@@ -2582,7 +2582,7 @@ def scan_push_out_dot1(fgraph, node):
2582
2582
2583
2583
scan_seqopt1 .register (
2584
2584
"scan_push_out_seq" ,
2585
- in2out (scan_push_out_seq , ignore_newtrees = True ),
2585
+ bfs_rewriter (scan_push_out_seq , ignore_newtrees = True ),
2586
2586
"scan_pushout_seqs_ops" , # For backcompat: so it can be tagged with old name
2587
2587
"fast_run" ,
2588
2588
"scan" ,
@@ -2593,7 +2593,7 @@ def scan_push_out_dot1(fgraph, node):
2593
2593
2594
2594
scan_seqopt1 .register (
2595
2595
"scan_push_out_dot1" ,
2596
- in2out (scan_push_out_dot1 , ignore_newtrees = True ),
2596
+ bfs_rewriter (scan_push_out_dot1 , ignore_newtrees = True ),
2597
2597
"scan_pushout_dot1" , # For backcompat: so it can be tagged with old name
2598
2598
"fast_run" ,
2599
2599
"more_mem" ,
@@ -2606,7 +2606,7 @@ def scan_push_out_dot1(fgraph, node):
2606
2606
scan_seqopt1 .register (
2607
2607
"scan_push_out_add" ,
2608
2608
# TODO: Perhaps this should be an `EquilibriumGraphRewriter`?
2609
- in2out (scan_push_out_add , ignore_newtrees = False ),
2609
+ bfs_rewriter (scan_push_out_add , ignore_newtrees = False ),
2610
2610
"scan_pushout_add" , # For backcompat: so it can be tagged with old name
2611
2611
"fast_run" ,
2612
2612
"more_mem" ,
@@ -2617,22 +2617,22 @@ def scan_push_out_dot1(fgraph, node):
2617
2617
2618
2618
scan_eqopt2 .register (
2619
2619
"while_scan_merge_subtensor_last_element" ,
2620
- in2out (while_scan_merge_subtensor_last_element , ignore_newtrees = True ),
2620
+ bfs_rewriter (while_scan_merge_subtensor_last_element , ignore_newtrees = True ),
2621
2621
"fast_run" ,
2622
2622
"scan" ,
2623
2623
)
2624
2624
2625
2625
scan_eqopt2 .register (
2626
2626
"constant_folding_for_scan2" ,
2627
- in2out (constant_folding , ignore_newtrees = True ),
2627
+ bfs_rewriter (constant_folding , ignore_newtrees = True ),
2628
2628
"fast_run" ,
2629
2629
"scan" ,
2630
2630
)
2631
2631
2632
2632
2633
2633
scan_eqopt2 .register (
2634
2634
"scan_remove_constants_and_unused_inputs1" ,
2635
- in2out (remove_constants_and_unused_inputs_scan , ignore_newtrees = True ),
2635
+ bfs_rewriter (remove_constants_and_unused_inputs_scan , ignore_newtrees = True ),
2636
2636
"remove_constants_and_unused_inputs_scan" ,
2637
2637
"fast_run" ,
2638
2638
"scan" ,
@@ -2647,23 +2647,23 @@ def scan_push_out_dot1(fgraph, node):
2647
2647
# After Merge optimization
2648
2648
scan_eqopt2 .register (
2649
2649
"scan_remove_constants_and_unused_inputs2" ,
2650
- in2out (remove_constants_and_unused_inputs_scan , ignore_newtrees = True ),
2650
+ bfs_rewriter (remove_constants_and_unused_inputs_scan , ignore_newtrees = True ),
2651
2651
"remove_constants_and_unused_inputs_scan" ,
2652
2652
"fast_run" ,
2653
2653
"scan" ,
2654
2654
)
2655
2655
2656
2656
scan_eqopt2 .register (
2657
2657
"scan_merge_inouts" ,
2658
- in2out (scan_merge_inouts , ignore_newtrees = True ),
2658
+ bfs_rewriter (scan_merge_inouts , ignore_newtrees = True ),
2659
2659
"fast_run" ,
2660
2660
"scan" ,
2661
2661
)
2662
2662
2663
2663
# After everything else
2664
2664
scan_eqopt2 .register (
2665
2665
"scan_remove_constants_and_unused_inputs3" ,
2666
- in2out (remove_constants_and_unused_inputs_scan , ignore_newtrees = True ),
2666
+ bfs_rewriter (remove_constants_and_unused_inputs_scan , ignore_newtrees = True ),
2667
2667
"remove_constants_and_unused_inputs_scan" ,
2668
2668
"fast_run" ,
2669
2669
"scan" ,
0 commit comments