Skip to content

Commit 20d88e5

Browse files
committed
Speedup toposort
1 parent e1c5da0 commit 20d88e5

File tree

7 files changed

+96
-82
lines changed

7 files changed

+96
-82
lines changed

pytensor/graph/basic.py

Lines changed: 42 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -1522,13 +1522,38 @@ def _compute_deps_cache_(io):
15221522
return rlist
15231523

15241524

1525+
def node_toposort(
1526+
output_nodes: Iterable[Apply] = (),
1527+
input_nodes: Iterable[Apply] = (),
1528+
) -> Generator[Apply]:
1529+
computed = {None, *input_nodes}
1530+
todo = list(output_nodes)
1531+
while todo:
1532+
cur = todo[-1]
1533+
if cur in computed:
1534+
todo.pop()
1535+
continue
1536+
# Since computed includes None we don't need to filter it in this check
1537+
if all(i.owner in computed for i in cur.inputs):
1538+
computed.add(cur)
1539+
yield todo.pop()
1540+
else:
1541+
todo.extend(i.owner for i in cur.inputs if i.owner is not None)
1542+
1543+
15251544
def io_toposort(
15261545
inputs: Iterable[Variable],
15271546
outputs: Reversible[Variable],
15281547
orderings: dict[Apply, list[Apply]] | None = None,
15291548
clients: dict[Variable, list[Variable]] | None = None,
15301549
) -> list[Apply]:
1531-
"""Perform topological sort from input and output nodes.
1550+
"""Perform topological of nodes between input and output variables.
1551+
1552+
Notes
1553+
-----
1554+
If sorting from root or single-output node variables, without orderings or clients,
1555+
it's slightly faster to use `list(node_toposort((o.owner for o in outputs)))` instead,
1556+
as the individual variables can be ignored
15321557
15331558
Parameters
15341559
----------
@@ -1543,32 +1568,35 @@ def io_toposort(
15431568
each node in the subgraph that is sorted.
15441569
15451570
"""
1546-
if not orderings and clients is None: # ordering can be None or empty dict
1547-
# Specialized function that is faster when more then ~10 nodes
1548-
# when no ordering.
1549-
1550-
# Do a new stack implementation with the vm algo.
1551-
# This will change the order returned.
1571+
if not orderings and clients is None:
1572+
# Specialized function that is faster when no special orderings are required and clients need not be computed.
15521573
computed = set(inputs)
1553-
todo = [o.owner for o in reversed(outputs) if o.owner]
1574+
todo = [o.owner for o in outputs if o.owner is not None]
15541575
order = []
15551576
while todo:
1556-
cur = todo.pop()
1557-
if all(out in computed for out in cur.outputs):
1577+
cur = todo[-1]
1578+
# It's faster to short circuit on the first output, as most nodes will have all edges non-computed
1579+
# Starting the `all` iterator has a non-negligeable cost
1580+
if cur.outputs[0] in computed and all(
1581+
out in computed for out in cur.outputs[1:]
1582+
):
1583+
todo.pop()
15581584
continue
15591585
if all(i in computed or i.owner is None for i in cur.inputs):
15601586
computed.update(cur.outputs)
1561-
order.append(cur)
1587+
order.append(todo.pop())
15621588
else:
1563-
todo.append(cur)
15641589
todo.extend(
1565-
i.owner for i in cur.inputs if (i.owner and i not in computed)
1590+
i.owner
1591+
for i in cur.inputs
1592+
if ((i.owner is not None) and (i not in computed))
15661593
)
15671594
return order
15681595

15691596
iset = set(inputs)
15701597

1571-
if not orderings: # ordering can be None or empty dict
1598+
if not orderings:
1599+
# ordering can be None or empty dict
15721600
# Specialized function that is faster when no ordering.
15731601
# Also include the cache in the function itself for speed up.
15741602

pytensor/graph/fg.py

Lines changed: 13 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -15,13 +15,13 @@
1515
clone_get_equiv,
1616
graph_inputs,
1717
io_toposort,
18+
node_toposort,
1819
vars_between,
1920
)
2021
from pytensor.graph.basic import as_string as graph_as_string
2122
from pytensor.graph.features import AlreadyThere, Feature, ReplaceValidate
2223
from pytensor.graph.op import Op
2324
from pytensor.graph.utils import MetaObject, MissingInputError, TestValueError
24-
from pytensor.misc.ordered_set import OrderedSet
2525

2626

2727
ClientType = tuple[Apply, int]
@@ -353,21 +353,16 @@ def import_node(
353353
apply_node : Apply
354354
The node to be imported.
355355
check : bool
356-
Check that the inputs for the imported nodes are also present in
357-
the `FunctionGraph`.
356+
Check that the inputs for the imported nodes are also present in the `FunctionGraph`.
358357
reason : str
359358
The name of the optimization or operation in progress.
360359
import_missing : bool
361360
Add missing inputs instead of raising an exception.
362361
"""
363362
# We import the nodes in topological order. We only are interested in
364-
# new nodes, so we use all variables we know of as if they were the
365-
# input set. (The functions in the graph module only use the input set
366-
# to know where to stop going down.)
367-
new_nodes = io_toposort(self.variables, apply_node.outputs)
368-
369-
if check:
370-
for node in new_nodes:
363+
# new nodes, so we use all nodes we know of as inputs to interrupt the toposort
364+
for node in node_toposort([apply_node], node_inputs=self.apply_nodes):
365+
if check:
371366
for var in node.inputs:
372367
if (
373368
var.owner is None
@@ -387,8 +382,6 @@ def import_node(
387382
)
388383
raise MissingInputError(error_msg, variable=var)
389384

390-
for node in new_nodes:
391-
assert node not in self.apply_nodes
392385
self.apply_nodes.add(node)
393386
if not hasattr(node.tag, "imported_by"):
394387
node.tag.imported_by = []
@@ -753,11 +746,12 @@ def toposort(self) -> list[Apply]:
753746
:meth:`FunctionGraph.orderings`.
754747
755748
"""
756-
if len(self.apply_nodes) < 2:
757-
# No sorting is necessary
758-
return list(self.apply_nodes)
759-
760-
return io_toposort(self.inputs, self.outputs, self.orderings())
749+
orderings = self.orderings()
750+
if orderings:
751+
return io_toposort(self.inputs, self.outputs, orderings)
752+
else:
753+
# Faster implementation when no orderings are needed
754+
return node_toposort(o.owner for o in self.outputs)
761755

762756
def orderings(self) -> dict[Apply, list[Apply]]:
763757
"""Return a map of node to node evaluation dependencies.
@@ -776,28 +770,14 @@ def orderings(self) -> dict[Apply, list[Apply]]:
776770
take care of computing the dependencies by itself.
777771
778772
"""
779-
assert isinstance(self._features, list)
780773
all_orderings: list[dict] = []
781774

782775
for feature in self._features:
783776
if hasattr(feature, "orderings"):
784777
orderings = feature.orderings(self)
785-
if not isinstance(orderings, dict):
786-
raise TypeError(
787-
"Non-deterministic return value from "
788-
+ str(feature.orderings)
789-
+ ". Nondeterministic object is "
790-
+ str(orderings)
791-
)
792-
if len(orderings) > 0:
778+
if orderings:
793779
all_orderings.append(orderings)
794-
for node, prereqs in orderings.items():
795-
if not isinstance(prereqs, list | OrderedSet):
796-
raise TypeError(
797-
"prereqs must be a type with a "
798-
"deterministic iteration order, or toposort "
799-
" will be non-deterministic."
800-
)
780+
801781
if len(all_orderings) == 1:
802782
# If there is only 1 ordering, we reuse it directly.
803783
return all_orderings[0].copy()

pytensor/graph/rewriting/basic.py

Lines changed: 3 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,7 @@
2525
Constant,
2626
Variable,
2727
applys_between,
28-
io_toposort,
28+
node_toposort,
2929
vars_between,
3030
)
3131
from pytensor.graph.features import AlreadyThere, Feature, NodeFinder
@@ -2027,7 +2027,7 @@ def apply(self, fgraph, start_from=None):
20272027
callback_before = fgraph.execute_callbacks_time
20282028
nb_nodes_start = len(fgraph.apply_nodes)
20292029
t0 = time.perf_counter()
2030-
q = deque(io_toposort(fgraph.inputs, start_from))
2030+
q = deque(node_toposort(output_nodes=(o.owner for o in start_from)))
20312031
io_t = time.perf_counter() - t0
20322032

20332033
def importer(node):
@@ -2320,11 +2320,6 @@ def add_requirements(self, fgraph):
23202320
def apply(self, fgraph, start_from=None):
23212321
change_tracker = ChangeTracker()
23222322
fgraph.attach_feature(change_tracker)
2323-
if start_from is None:
2324-
start_from = fgraph.outputs
2325-
else:
2326-
for node in start_from:
2327-
assert node in fgraph.outputs
23282323

23292324
changed = True
23302325
max_use_abort = False
@@ -2403,7 +2398,7 @@ def apply_cleanup(profs_dict):
24032398
changed |= apply_cleanup(iter_cleanup_sub_profs)
24042399

24052400
topo_t0 = time.perf_counter()
2406-
q = deque(io_toposort(fgraph.inputs, start_from))
2401+
q = deque(node_toposort(o.owner for o in fgraph.outputs))
24072402
io_toposort_timing.append(time.perf_counter() - topo_t0)
24082403

24092404
nb_nodes.append(len(q))

pytensor/scan/rewriting.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@
2323
equal_computations,
2424
graph_inputs,
2525
io_toposort,
26+
node_toposort,
2627
)
2728
from pytensor.graph.destroyhandler import DestroyHandler
2829
from pytensor.graph.features import ReplaceValidate
@@ -225,7 +226,7 @@ def scan_push_out_non_seq(fgraph, node):
225226

226227
node_inputs, node_outputs = node.op.inner_inputs, node.op.inner_outputs
227228

228-
local_fgraph_topo = io_toposort(node_inputs, node_outputs)
229+
local_fgraph_topo = node_toposort(node_outputs)
229230
local_fgraph_outs_set = set(node_outputs)
230231
local_fgraph_outs_map = {v: k for k, v in enumerate(node_outputs)}
231232

@@ -435,7 +436,7 @@ def scan_push_out_seq(fgraph, node):
435436

436437
node_inputs, node_outputs = node.op.inner_inputs, node.op.inner_outputs
437438

438-
local_fgraph_topo = io_toposort(node_inputs, node_outputs)
439+
local_fgraph_topo = node_toposort(node_outputs)
439440
local_fgraph_outs_set = set(node_outputs)
440441
local_fgraph_outs_map = {v: k for k, v in enumerate(node_outputs)}
441442

pytensor/tensor/rewriting/blas.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -59,6 +59,7 @@
5959

6060
import numpy as np
6161

62+
from pytensor.graph.basic import node_toposort
6263
from pytensor.tensor.rewriting.basic import register_specialize
6364

6465

@@ -459,6 +460,8 @@ def apply(self, fgraph):
459460
callbacks_before = fgraph.execute_callbacks_times.copy()
460461
callback_before = fgraph.execute_callbacks_time
461462

463+
nodelist = node_toposort(o.owner for o in fgraph.outputs)
464+
462465
def on_import(new_node):
463466
if new_node is not node:
464467
nodelist.append(new_node)
@@ -470,7 +473,6 @@ def on_import(new_node):
470473
while did_something:
471474
nb_iter += 1
472475
t0 = time.perf_counter()
473-
nodelist = pytensor.graph.basic.io_toposort(fgraph.inputs, fgraph.outputs)
474476
time_toposort += time.perf_counter() - t0
475477
did_something = False
476478
nodelist.reverse()

tests/graph/test_basic.py

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@
2323
get_var_by_name,
2424
graph_inputs,
2525
io_toposort,
26+
node_toposort,
2627
orphans_between,
2728
truncated_graph_inputs,
2829
variable_depends_on,
@@ -329,6 +330,23 @@ def test_multi_output_nodes(self):
329330
out.owner,
330331
}
331332

333+
@pytest.mark.parametrize(
334+
"toposort_func",
335+
[
336+
lambda x: io_toposort([], [x]),
337+
lambda x: list(node_toposort([x.owner])),
338+
],
339+
ids=["io_toposort", "node_toposort"],
340+
)
341+
def test_benchmark(self, toposort_func, benchmark):
342+
r1 = MyVariable(1)
343+
out = r1
344+
for i in range(50):
345+
out = MyOp(out, out)
346+
347+
assert toposort_func(out) == io_toposort([r1], [out])
348+
benchmark(toposort_func, out)
349+
332350

333351
class TestEval:
334352
def setup_method(self):

0 commit comments

Comments
 (0)