@@ -15,13 +15,13 @@
     clone_get_equiv,
     graph_inputs,
     io_toposort,
+    node_toposort,
     vars_between,
 )
 from pytensor.graph.basic import as_string as graph_as_string
 from pytensor.graph.features import AlreadyThere, Feature, ReplaceValidate
 from pytensor.graph.op import Op
 from pytensor.graph.utils import MetaObject, MissingInputError, TestValueError
-from pytensor.misc.ordered_set import OrderedSet
 
 
 ClientType = tuple[Apply, int]
@@ -353,21 +353,16 @@ def import_node(
         apply_node : Apply
             The node to be imported.
         check : bool
-            Check that the inputs for the imported nodes are also present in
-            the `FunctionGraph`.
+            Check that the inputs for the imported nodes are also present in the `FunctionGraph`.
         reason : str
             The name of the optimization or operation in progress.
         import_missing : bool
             Add missing inputs instead of raising an exception.
         """
         # We import the nodes in topological order. We only are interested in
-        # new nodes, so we use all variables we know of as if they were the
-        # input set. (The functions in the graph module only use the input set
-        # to know where to stop going down.)
-        new_nodes = io_toposort(self.variables, apply_node.outputs)
-
-        if check:
-            for node in new_nodes:
+        # new nodes, so we use all nodes we know of as inputs to interrupt the toposort
+        for node in node_toposort([apply_node], node_inputs=self.apply_nodes):
+            if check:
                 for var in node.inputs:
                     if (
                         var.owner is None
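
For readers unfamiliar with the new helper: `node_toposort` is called here with `node_inputs=self.apply_nodes` so the traversal stops as soon as it reaches nodes the `FunctionGraph` already tracks, yielding only the genuinely new nodes in dependency order. A minimal sketch of that idea (not the actual `node_toposort` implementation; it assumes Apply-like nodes whose `inputs` are variables with an `owner` attribute):

```python
# Illustrative sketch only: a node-level topological sort that stops at
# already-known nodes. Assumes Apply-like objects where `node.inputs` are
# variables and `var.owner` is the Apply that produced them (or None).
def toposort_new_nodes(start_nodes, known_nodes):
    seen = set(known_nodes)  # known nodes act as the traversal boundary
    order = []               # new nodes, dependencies first

    def visit(node):
        if node in seen:
            return
        seen.add(node)
        for var in node.inputs:
            if var.owner is not None:
                visit(var.owner)  # producers are emitted before consumers
        order.append(node)

    for node in start_nodes:
        visit(node)
    return order
```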
@@ -387,8 +382,6 @@ def import_node(
                             )
                             raise MissingInputError(error_msg, variable=var)
 
-        for node in new_nodes:
-            assert node not in self.apply_nodes
             self.apply_nodes.add(node)
             if not hasattr(node.tag, "imported_by"):
                 node.tag.imported_by = []
@@ -753,11 +746,12 @@ def toposort(self) -> list[Apply]:
         :meth:`FunctionGraph.orderings`.
 
         """
-        if len(self.apply_nodes) < 2:
-            # No sorting is necessary
-            return list(self.apply_nodes)
-
-        return io_toposort(self.inputs, self.outputs, self.orderings())
+        orderings = self.orderings()
+        if orderings:
+            return io_toposort(self.inputs, self.outputs, orderings)
+        else:
+            # Faster implementation when no orderings are needed
+            return node_toposort(o.owner for o in self.outputs)
 
     def orderings(self) -> dict[Apply, list[Apply]]:
         """Return a map of node to node evaluation dependencies.
@@ -776,28 +770,14 @@ def orderings(self) -> dict[Apply, list[Apply]]:
         take care of computing the dependencies by itself.
 
         """
-        assert isinstance(self._features, list)
         all_orderings: list[dict] = []
 
         for feature in self._features:
             if hasattr(feature, "orderings"):
                 orderings = feature.orderings(self)
-                if not isinstance(orderings, dict):
-                    raise TypeError(
-                        "Non-deterministic return value from "
-                        + str(feature.orderings)
-                        + ". Nondeterministic object is "
-                        + str(orderings)
-                    )
-                if len(orderings) > 0:
+                if orderings:
                     all_orderings.append(orderings)
-                    for node, prereqs in orderings.items():
-                        if not isinstance(prereqs, list | OrderedSet):
-                            raise TypeError(
-                                "prereqs must be a type with a "
-                                "deterministic iteration order, or toposort "
-                                " will be non-deterministic."
-                            )
+
         if len(all_orderings) == 1:
             # If there is only 1 ordering, we reuse it directly.
             return all_orderings[0].copy()
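
For reference, the dicts this loop collects come from features that implement an `orderings` method (e.g. the destroy handler); the removed runtime checks only guarded against non-deterministic containers, which plain dicts and lists already avoid. A toy, purely illustrative feature contributing such an ordering could look like:

```python
from pytensor.graph.features import Feature


class ForceOrder(Feature):
    """Toy feature (illustration only): force `first` to be evaluated before `second`."""

    def __init__(self, first, second):
        self.first = first    # an Apply node that must run first
        self.second = second  # an Apply node that depends on it at runtime

    def orderings(self, fgraph):
        # A dict mapping a node to the list of nodes that must precede it.
        return {self.second: [self.first]}
```

Attaching a feature like this makes `orderings()` non-empty, which is exactly the case where `toposort` above still falls back to `io_toposort`.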