Skip to content

Commit dbb4551

Browse files
committed
Improve rewrites
1 parent 5bfbcbe commit dbb4551

File tree

4 files changed

+52
-21
lines changed

4 files changed

+52
-21
lines changed

pytensor/graph/basic.py

Lines changed: 18 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -875,23 +875,26 @@ def walk(
875875
else:
876876
nodes_pop = nodes.pop
877877

878-
while nodes:
879-
node: T = nodes_pop()
880-
881-
node_hash: int = hash_fn(node)
882-
883-
if node_hash not in rval_set:
884-
rval_set.add(node_hash)
885-
886-
new_nodes: Iterable[T] | None = expand(node)
887-
888-
if return_children:
878+
if return_children:
879+
while nodes:
880+
node: T = nodes_pop()
881+
node_hash: int = hash_fn(node)
882+
if node_hash not in rval_set:
883+
new_nodes: Iterable[T] | None = expand(node)
889884
yield node, new_nodes
890-
else:
885+
rval_set.add(node_hash)
886+
if new_nodes:
887+
nodes.extend(new_nodes)
888+
else:
889+
while nodes:
890+
node: T = nodes_pop()
891+
node_hash: int = hash_fn(node)
892+
if node_hash not in rval_set:
893+
new_nodes: Iterable[T] | None = expand(node)
891894
yield node
892-
893-
if new_nodes:
894-
nodes.extend(new_nodes)
895+
rval_set.add(node_hash)
896+
if new_nodes:
897+
nodes.extend(new_nodes)
895898

896899

897900
def ancestors(

pytensor/graph/rewriting/basic.py

Lines changed: 18 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@
2525
applys_between,
2626
io_toposort,
2727
vars_between,
28+
walk,
2829
)
2930
from pytensor.graph.features import AlreadyThere, Feature
3031
from pytensor.graph.fg import FunctionGraph, Output
@@ -1821,12 +1822,13 @@ class WalkingGraphRewriter(NodeProcessingGraphRewriter):
18211822
def __init__(
18221823
self,
18231824
node_rewriter: NodeRewriter,
1824-
order: Literal["out_to_in", "in_to_out"] = "in_to_out",
1825+
order: Literal["out_to_in", "in_to_out", "bfs", "dfs"] = "bfs",
18251826
ignore_newtrees: bool = False,
18261827
failure_callback: FailureCallbackType | None = None,
18271828
):
1828-
if order not in ("out_to_in", "in_to_out"):
1829-
raise ValueError("order must be 'out_to_in' or 'in_to_out'")
1829+
valid_options = ("out_to_in", "in_to_out", "bfs", "dfs")
1830+
if order not in valid_options:
1831+
raise ValueError(f"order must be one of {valid_options}, got {order}")
18301832
self.order = order
18311833
super().__init__(node_rewriter, ignore_newtrees, failure_callback)
18321834

@@ -1836,9 +1838,19 @@ def apply(self, fgraph, start_from=None):
18361838
callback_before = fgraph.execute_callbacks_time
18371839
nb_nodes_start = len(fgraph.apply_nodes)
18381840
t0 = time.perf_counter()
1839-
q = deque(io_toposort(fgraph.inputs, start_from))
1841+
if self.order in {"bfs", "dfs"}:
1842+
q = deque(
1843+
walk(
1844+
[out.owner for out in fgraph.outputs if out.owner is not None],
1845+
expand=lambda n: (i.owner for i in n.inputs if i.owner is not None),
1846+
bfs=(self.order == "bfs"),
1847+
)
1848+
)
1849+
else:
1850+
q = deque(io_toposort(fgraph.inputs, start_from))
18401851
io_t = time.perf_counter() - t0
18411852

1853+
# Importer is ignored if self.ignore_newtrees = False
18421854
def importer(node):
18431855
if node is not current_node:
18441856
q.append(node)
@@ -1959,6 +1971,8 @@ def walking_rewriter(
19591971

19601972
in2out = partial(walking_rewriter, "in_to_out")
19611973
out2in = partial(walking_rewriter, "out_to_in")
1974+
bfs = partial(walking_rewriter, "bfs")
1975+
dfs = partial(walking_rewriter, "dfs")
19621976

19631977

19641978
class ChangeTracker(Feature):

pytensor/graph/rewriting/db.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -310,7 +310,10 @@ class EquilibriumDB(RewriteDatabase):
310310
"""
311311

312312
def __init__(
313-
self, ignore_newtrees: bool = True, tracks_on_change_inputs: bool = False
313+
self,
314+
ignore_newtrees: bool = True,
315+
tracks_on_change_inputs: bool = False,
316+
eq_rewriter_class=pytensor_rewriting.EquilibriumGraphRewriter,
314317
):
315318
"""
316319
@@ -329,6 +332,7 @@ def __init__(
329332
self.tracks_on_change_inputs = tracks_on_change_inputs
330333
self.__final__: dict[str, bool] = {}
331334
self.__cleanup__: dict[str, bool] = {}
335+
self.eq_rewriter_class = eq_rewriter_class
332336

333337
def register(
334338
self,
@@ -360,7 +364,7 @@ def query(self, *tags, **kwtags):
360364
final_rewriters = None
361365
if len(cleanup_rewriters) == 0:
362366
cleanup_rewriters = None
363-
return pytensor_rewriting.EquilibriumGraphRewriter(
367+
return self.eq_rewriter_class(
364368
rewriters,
365369
max_use_ratio=config.optdb__max_use_ratio,
366370
ignore_newtrees=self.ignore_newtrees,

pytensor/scan/rewriting.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,7 @@
3030
from pytensor.graph.op import compute_test_value
3131
from pytensor.graph.replace import clone_replace
3232
from pytensor.graph.rewriting.basic import (
33+
EquilibriumGraphRewriter,
3334
GraphRewriter,
3435
copy_stack_trace,
3536
in2out,
@@ -2517,6 +2518,15 @@ def scan_push_out_dot1(fgraph, node):
25172518
return False
25182519

25192520

2521+
class ScanEquilibriumGraphRewriter(EquilibriumGraphRewriter):
2522+
"""Subclass of EquilibriumGraphRewriter that aborts early if there are no Scan Ops in the graph"""
2523+
2524+
def apply(self, fgraph, start_from=None):
2525+
if not any(isinstance(node.op, Scan) for node in fgraph.apply_nodes):
2526+
return
2527+
super().apply(fgraph=fgraph, start_from=start_from)
2528+
2529+
25202530
# I've added an equilibrium because later scan optimization in the sequence
25212531
# can make it such that earlier optimizations should apply. However, in
25222532
# general I do not expect the sequence to run more then once

0 commit comments

Comments
 (0)