diff --git a/doc/gallery/optimize/root.ipynb b/doc/gallery/optimize/root.ipynb index dc63107c9a..715a0788cc 100644 --- a/doc/gallery/optimize/root.ipynb +++ b/doc/gallery/optimize/root.ipynb @@ -1772,7 +1772,9 @@ } ], "source": [ - "from pytensor.graph.basic import explicit_graph_inputs\n", + "\n", + "from pytensor.graph.traversal import explicit_graph_inputs\n", + "\n", "list(explicit_graph_inputs(w_bar_2))" ] }, diff --git a/pytensor/compile/builders.py b/pytensor/compile/builders.py index 8a53ee3192..eedefa430f 100644 --- a/pytensor/compile/builders.py +++ b/pytensor/compile/builders.py @@ -16,13 +16,12 @@ Constant, NominalVariable, Variable, - graph_inputs, - io_connection_pattern, ) from pytensor.graph.fg import FunctionGraph from pytensor.graph.null_type import NullType -from pytensor.graph.op import HasInnerGraph, Op +from pytensor.graph.op import HasInnerGraph, Op, io_connection_pattern from pytensor.graph.replace import clone_replace +from pytensor.graph.traversal import graph_inputs from pytensor.graph.utils import MissingInputError diff --git a/pytensor/compile/debugmode.py b/pytensor/compile/debugmode.py index 384f9eb874..52a4e6f305 100644 --- a/pytensor/compile/debugmode.py +++ b/pytensor/compile/debugmode.py @@ -27,11 +27,12 @@ from pytensor.compile.mode import Mode, register_mode from pytensor.compile.ops import OutputGuard, _output_guard from pytensor.configdefaults import config -from pytensor.graph.basic import Variable, io_toposort +from pytensor.graph.basic import Variable from pytensor.graph.destroyhandler import DestroyHandler from pytensor.graph.features import AlreadyThere, BadOptimization from pytensor.graph.fg import Output from pytensor.graph.op import HasInnerGraph, Op +from pytensor.graph.traversal import io_toposort from pytensor.graph.utils import InconsistencyError, MethodNotDefined from pytensor.link.basic import Container, LocalLinker from pytensor.link.c.op import COp diff --git a/pytensor/compile/function/types.py b/pytensor/compile/function/types.py index 246354de0f..635af25e47 100644 --- a/pytensor/compile/function/types.py +++ b/pytensor/compile/function/types.py @@ -20,14 +20,13 @@ from pytensor.graph.basic import ( Constant, Variable, - ancestors, clone_get_equiv, - graph_inputs, ) from pytensor.graph.destroyhandler import DestroyHandler from pytensor.graph.features import AlreadyThere, Feature, PreserveVariableAttributes from pytensor.graph.fg import FunctionGraph from pytensor.graph.op import HasInnerGraph +from pytensor.graph.traversal import ancestors, graph_inputs from pytensor.graph.utils import InconsistencyError, get_variable_trace_string from pytensor.link.basic import Container from pytensor.link.utils import raise_with_op diff --git a/pytensor/d3viz/formatting.py b/pytensor/d3viz/formatting.py index df39335c19..98c66fadc0 100644 --- a/pytensor/d3viz/formatting.py +++ b/pytensor/d3viz/formatting.py @@ -10,8 +10,9 @@ import pytensor from pytensor.compile import Function, builders -from pytensor.graph.basic import Apply, Constant, Variable, graph_inputs +from pytensor.graph.basic import Apply, Constant, Variable from pytensor.graph.fg import FunctionGraph +from pytensor.graph.traversal import graph_inputs from pytensor.printing import _try_pydot_import diff --git a/pytensor/graph/__init__.py b/pytensor/graph/__init__.py index 189dfed237..5753479d25 100644 --- a/pytensor/graph/__init__.py +++ b/pytensor/graph/__init__.py @@ -5,10 +5,9 @@ Apply, Variable, Constant, - graph_inputs, clone, - ancestors, ) +from pytensor.graph.traversal import 
ancestors, graph_inputs from pytensor.graph.replace import clone_replace, graph_replace, vectorize_graph from pytensor.graph.op import Op from pytensor.graph.type import Type diff --git a/pytensor/graph/basic.py b/pytensor/graph/basic.py index 512f0ef3ab..5d6667683b 100644 --- a/pytensor/graph/basic.py +++ b/pytensor/graph/basic.py @@ -2,15 +2,9 @@ import abc import warnings -from collections import deque from collections.abc import ( - Callable, - Collection, - Generator, Hashable, Iterable, - Iterator, - Reversible, Sequence, ) from copy import copy @@ -23,7 +17,6 @@ TypeVar, Union, cast, - overload, ) import numpy as np @@ -37,7 +30,6 @@ add_tag_trace, get_variable_trace_string, ) -from pytensor.misc.ordered_set import OrderedSet if TYPE_CHECKING: @@ -50,9 +42,39 @@ _TypeType = TypeVar("_TypeType", bound="Type") _IdType = TypeVar("_IdType", bound=Hashable) -T = TypeVar("T", bound="Node") -NoParams = object() -NodeAndChildren = tuple[T, Iterable[T] | None] +_MOVED_FUNCTIONS = { + "walk", + "ancestors", + "graph_inputs", + "explicit_graph_inputs", + "vars_between", + "orhpans_between", + "applys_between", + "apply_depends_on", + "truncated_graph_inputs", + "general_toposort", + "io_toposort", + "list_of_nodes", + "get_var_by_name", +} + + +def __getattr__(name): + """Provide backwards-compatibility for functions moved to graph/traversal.py.""" + if name in _MOVED_FUNCTIONS: + warnings.warn( + ( + f"`pytensor.graph.basic.{name}` was moved to `pytensor.graph.traversal.{name}`. " + "Calling it from the old location will fail in a future release." + ), + FutureWarning, + stacklevel=2, + ) + from pytensor.graph import traversal + + return getattr(traversal, name) + + raise AttributeError(f"module {__name__!r} has no attribute {name!r}") class Node(MetaObject): @@ -200,6 +222,7 @@ def default_output(self): return self.outputs[do] def __str__(self): + # FIXME: The called function is too complicated for this simple use case. return op_as_string(self.inputs, self) def __repr__(self): @@ -428,23 +451,7 @@ class Variable(Node, Generic[_TypeType, OptionalApplyType]): # __slots__ = ['type', 'owner', 'index', 'name'] __count__ = count(0) - _owner: OptionalApplyType - - @property - def owner(self) -> OptionalApplyType: - return self._owner - - @owner.setter - def owner(self, value) -> None: - self._owner = value - - @property - def index(self): - return self._index - - @index.setter - def index(self, value): - self._index = value + owner: OptionalApplyType def __init__( self, @@ -459,7 +466,7 @@ def __init__( self.type = type - self._owner = owner + self.owner = owner if owner is not None and not isinstance(owner, Apply): raise TypeError("owner must be an Apply instance") @@ -615,6 +622,7 @@ def eval( function, so don't use it too much in real scripts. """ from pytensor.compile.function import function + from pytensor.graph.traversal import get_var_by_name ignore_unused_input = kwargs.get("on_unused_input", None) in ("ignore", "warn") @@ -832,390 +840,6 @@ def value(self): return self.data -def walk( - nodes: Iterable[T], - expand: Callable[[T], Iterable[T] | None], - bfs: bool = True, - return_children: bool = False, - hash_fn: Callable[[T], int] = id, -) -> Generator[T | NodeAndChildren, None, None]: - r"""Walk through a graph, either breadth- or depth-first. - - Parameters - ---------- - nodes - The nodes from which to start walking. - expand - A callable that is applied to each node in `nodes`, the results of - which are either new nodes to visit or ``None``. 
- bfs - If ``True``, breath first search is used; otherwise, depth first - search. - return_children - If ``True``, each output node will be accompanied by the output of - `expand` (i.e. the corresponding child nodes). - hash_fn - The function used to produce hashes of the elements in `nodes`. - The default is ``id``. - - Notes - ----- - A node will appear at most once in the return value, even if it - appears multiple times in the `nodes` parameter. - - """ - - nodes = deque(nodes) - - rval_set: set[int] = set() - - nodes_pop: Callable[[], T] - if bfs: - nodes_pop = nodes.popleft - else: - nodes_pop = nodes.pop - - while nodes: - node: T = nodes_pop() - - node_hash: int = hash_fn(node) - - if node_hash not in rval_set: - rval_set.add(node_hash) - - new_nodes: Iterable[T] | None = expand(node) - - if return_children: - yield node, new_nodes - else: - yield node - - if new_nodes: - nodes.extend(new_nodes) - - -def ancestors( - graphs: Iterable[Variable], blockers: Collection[Variable] | None = None -) -> Generator[Variable, None, None]: - r"""Return the variables that contribute to those in given graphs (inclusive). - - Parameters - ---------- - graphs : list of `Variable` instances - Output `Variable` instances from which to search backward through - owners. - blockers : list of `Variable` instances - A collection of `Variable`\s that, when found, prevent the graph search - from preceding from that point. - - Yields - ------ - `Variable`\s - All input nodes, in the order found by a left-recursive depth-first - search started at the nodes in `graphs`. - - """ - - def expand(r: Variable) -> Iterator[Variable] | None: - if r.owner and (not blockers or r not in blockers): - return reversed(r.owner.inputs) - return None - - yield from cast(Generator[Variable, None, None], walk(graphs, expand, False)) - - -def graph_inputs( - graphs: Iterable[Variable], blockers: Collection[Variable] | None = None -) -> Generator[Variable, None, None]: - r"""Return the inputs required to compute the given Variables. - - Parameters - ---------- - graphs : list of `Variable` instances - Output `Variable` instances from which to search backward through - owners. - blockers : list of `Variable` instances - A collection of `Variable`\s that, when found, prevent the graph search - from preceding from that point. - - Yields - ------ - Input nodes with no owner, in the order found by a left-recursive - depth-first search started at the nodes in `graphs`. - - """ - yield from (r for r in ancestors(graphs, blockers) if r.owner is None) - - -def explicit_graph_inputs( - graph: Variable | Iterable[Variable], -) -> Generator[Variable, None, None]: - """ - Get the root variables needed as inputs to a function that computes `graph` - - Parameters - ---------- - graph : TensorVariable - Output `Variable` instances for which to search backward through - owners. - - Returns - ------- - iterable - Generator of root Variables (without owner) needed to compile a function that evaluates `graphs`. - - Examples - -------- - - .. code-block:: python - - import pytensor - import pytensor.tensor as pt - from pytensor.graph.basic import explicit_graph_inputs - - x = pt.vector("x") - y = pt.constant(2) - z = pt.mul(x * y) - - inputs = list(explicit_graph_inputs(z)) - f = pytensor.function(inputs, z) - eval = f([1, 2, 3]) - - print(eval) - # [2. 4. 6.] 
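The module-level ``__getattr__`` shim added to ``pytensor/graph/basic.py`` above keeps the old import locations working while nudging users toward ``pytensor.graph.traversal`` (as the notebook cell at the top of this diff now does for ``explicit_graph_inputs``). A minimal sketch of the intended behaviour, assuming the ``_MOVED_FUNCTIONS`` table and ``FutureWarning`` shown in this diff:

.. code-block:: python

    import warnings

    import pytensor.graph.basic as basic

    with warnings.catch_warnings(record=True) as caught:
        warnings.simplefilter("always")
        walk = basic.walk  # old location, resolved through the shim
    assert any(issubclass(w.category, FutureWarning) for w in caught)

    from pytensor.graph.traversal import walk  # preferred new location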
- """ - from pytensor.compile.sharedvalue import SharedVariable - - if isinstance(graph, Variable): - graph = [graph] - - return ( - v - for v in graph_inputs(graph) - if isinstance(v, Variable) and not isinstance(v, Constant | SharedVariable) - ) - - -def vars_between( - ins: Iterable[Variable], outs: Iterable[Variable] -) -> Generator[Variable, None, None]: - r"""Extract the `Variable`\s within the sub-graph between input and output nodes. - - Parameters - ---------- - ins - Input `Variable`\s. - outs - Output `Variable`\s. - - Yields - ------ - The `Variable`\s that are involved in the subgraph that lies - between `ins` and `outs`. This includes `ins`, `outs`, - ``orphans_between(ins, outs)`` and all values of all intermediary steps from - `ins` to `outs`. - - """ - - ins = set(ins) - - def expand(r: Variable) -> Iterable[Variable] | None: - if r.owner and r not in ins: - return reversed(r.owner.inputs + r.owner.outputs) - return None - - yield from cast(Generator[Variable, None, None], walk(outs, expand)) - - -def orphans_between( - ins: Collection[Variable], outs: Iterable[Variable] -) -> Generator[Variable, None, None]: - r"""Extract the `Variable`\s not within the sub-graph between input and output nodes. - - Parameters - ---------- - ins : list - Input `Variable`\s. - outs : list - Output `Variable`\s. - - Yields - ------- - Variable - The `Variable`\s upon which one or more `Variable`\s in `outs` - depend, but are neither in `ins` nor in the sub-graph that lies between - them. - - Examples - -------- - >>> from pytensor.graph.basic import orphans_between - >>> from pytensor.tensor import scalars - >>> x, y = scalars("xy") - >>> list(orphans_between([x], [(x + y)])) - [y] - - """ - yield from (r for r in vars_between(ins, outs) if r.owner is None and r not in ins) - - -def applys_between( - ins: Collection[Variable], outs: Iterable[Variable] -) -> Generator[Apply, None, None]: - r"""Extract the `Apply`\s contained within the sub-graph between given input and output variables. - - Parameters - ---------- - ins : list - Input `Variable`\s. - outs : list - Output `Variable`\s. - - Yields - ------ - The `Apply`\s that are contained within the sub-graph that lies - between `ins` and `outs`, including the owners of the `Variable`\s in - `outs` and intermediary `Apply`\s between `ins` and `outs`, but not the - owners of the `Variable`\s in `ins`. - - """ - yield from ( - r.owner for r in vars_between(ins, outs) if r not in ins and r.owner is not None - ) - - -def truncated_graph_inputs( - outputs: Sequence[Variable], - ancestors_to_include: Collection[Variable] | None = None, -) -> list[Variable]: - """Get the truncate graph inputs. - - Unlike :func:`graph_inputs` this function will return - the closest variables to outputs that do not depend on - ``ancestors_to_include``. So given all the returned - variables provided there is no missing variable to - compute the output and all variables are independent - from each other. - - Parameters - ---------- - outputs : Collection[Variable] - Variable to get conditions for - ancestors_to_include : Optional[Collection[Variable]] - Additional ancestors to assume, by default None - - Returns - ------- - List[Variable] - Variables required to compute ``outputs`` - - Examples - -------- - The returned variables marked in (parenthesis), ancestors variables are ``c``, output variables are ``o`` - - * No ancestors to include - - .. code-block:: - - n - n - (o) - - * One ancestors to include - - .. 
code-block:: - - n - (c) - o - - * Two ancestors to include where on depends on another, both returned - - .. code-block:: - - (c) - (c) - o - - * Additional variables are present - - .. code-block:: - - (c) - n - o - n - (n) -' - - * Disconnected ancestors to include not returned - - .. code-block:: - - (c) - n - o - c - - * Disconnected output is present and returned - - .. code-block:: - - (c) - (c) - o - (o) - - * ancestors to include that include itself adds itself - - .. code-block:: - - n - (c) - (o/c) - - """ - # simple case, no additional ancestors to include - truncated_inputs: list[Variable] = list() - # blockers have known independent variables and ancestors to include - candidates = list(outputs) - if not ancestors_to_include: # None or empty - # just filter out unique variables - for variable in candidates: - if variable not in truncated_inputs: - truncated_inputs.append(variable) - # no more actions are needed - return truncated_inputs - - blockers: set[Variable] = set(ancestors_to_include) - # variables that go here are under check already, do not repeat the loop for them - seen: set[Variable] = set() - # enforce O(1) check for variable in ancestors to include - ancestors_to_include = blockers.copy() - - while candidates: - # on any new candidate - variable = candidates.pop() - # we've looked into this variable already - if variable in seen: - continue - # check if the variable is independent, never go above blockers; - # blockers are independent variables and ancestors to include - elif variable in ancestors_to_include: - # The case where variable is in ancestors to include so we check if it depends on others - # it should be removed from the blockers to check against the rest - dependent = variable_depends_on(variable, ancestors_to_include - {variable}) - # ancestors to include that are present in the graph (not disconnected) - # should be added to truncated_inputs - truncated_inputs.append(variable) - if dependent: - # if the ancestors to include is still dependent we need to go above, the search is not yet finished - # owner can never be None for a dependent variable - candidates.extend(n for n in variable.owner.inputs if n not in seen) - else: - # A regular variable to check - dependent = variable_depends_on(variable, blockers) - # all regular variables fall to blockers - # 1. it is dependent - further search irrelevant - # 2. 
it is independent - the search variable is inside the closure - blockers.add(variable) - # if we've found an independent variable and it is not in blockers so far - # it is a new independent variable not present in ancestors to include - if dependent: - # populate search if it's not an independent variable - # owner can never be None for a dependent variable - candidates.extend(n for n in variable.owner.inputs if n not in seen) - else: - # otherwise, do not search beyond - truncated_inputs.append(variable) - # add variable to seen, no point in checking it once more - seen.add(variable) - return truncated_inputs - - def clone( inputs: Sequence[Variable], outputs: Sequence[Variable], @@ -1320,12 +944,11 @@ def clone_node_and_cache( def clone_get_equiv( inputs: Iterable[Variable], - outputs: Reversible[Variable], + outputs: Iterable[Variable], copy_inputs: bool = True, copy_orphans: bool = True, - memo: ( - dict[Union[Apply, Variable, "Op"], Union[Apply, Variable, "Op"]] | None - ) = None, + memo: dict[Union[Apply, Variable, "Op"], Union[Apply, Variable, "Op"]] + | None = None, clone_inner_graphs: bool = False, **kwargs, ) -> dict[Union[Apply, Variable, "Op"], Union[Apply, Variable, "Op"]]: @@ -1362,6 +985,8 @@ def clone_get_equiv( Keywords passed to `Apply.clone_with_new_inputs`. """ + from pytensor.graph.traversal import toposort + if memo is None: memo = {} @@ -1376,7 +1001,7 @@ def clone_get_equiv( memo.setdefault(input, input) # go through the inputs -> outputs graph cloning as we go - for apply in io_toposort(inputs, outputs): + for apply in toposort(outputs, blockers=inputs): for input in apply.inputs: if input not in memo: if not isinstance(input, Constant) and copy_orphans: @@ -1397,310 +1022,11 @@ def clone_get_equiv( return memo -@overload -def general_toposort( - outputs: Iterable[T], - deps: None, - compute_deps_cache: Callable[[T], OrderedSet | list[T] | None], - deps_cache: dict[T, list[T]] | None, - clients: dict[T, list[T]] | None, -) -> list[T]: ... - - -@overload -def general_toposort( - outputs: Iterable[T], - deps: Callable[[T], OrderedSet | list[T]], - compute_deps_cache: None, - deps_cache: None, - clients: dict[T, list[T]] | None, -) -> list[T]: ... - - -def general_toposort( - outputs: Iterable[T], - deps: Callable[[T], OrderedSet | list[T]] | None, - compute_deps_cache: Callable[[T], OrderedSet | list[T] | None] | None = None, - deps_cache: dict[T, list[T]] | None = None, - clients: dict[T, list[T]] | None = None, -) -> list[T]: - """Perform a topological sort of all nodes starting from a given node. - - Parameters - ---------- - deps : callable - A Python function that takes a node as input and returns its dependence. - compute_deps_cache : optional - If provided, `deps_cache` should also be provided. This is a function like - `deps`, but that also caches its results in a ``dict`` passed as `deps_cache`. - deps_cache : dict - A ``dict`` mapping nodes to their children. This is populated by - `compute_deps_cache`. - clients : dict - If a ``dict`` is passed, it will be filled with a mapping of - nodes-to-clients for each node in the subgraph. - - Notes - ----- - - ``deps(i)`` should behave like a pure function (no funny business with - internal state). - - ``deps(i)`` will be cached by this function (to be fast). - - The order of the return value list is determined by the order of nodes - returned by the `deps` function. - - The second option removes a Python function call, and allows for more - specialized code, so it can be faster. 
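``general_toposort`` only needs a ``deps`` callable, so its contract is easy to check on plain Python objects. A toy sketch (the dependency dict below is illustrative, not part of this diff):

.. code-block:: python

    from pytensor.graph.traversal import general_toposort

    # Each key maps to the nodes it depends on.
    deps = {"c": ["a", "b"], "b": ["a"], "a": []}

    order = general_toposort(["c"], lambda node: deps[node])
    # Dependencies come before their dependents.
    assert order.index("a") < order.index("b") < order.index("c")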
- - """ - if compute_deps_cache is None: - if deps_cache is None: - deps_cache = {} - - def _compute_deps_cache_(io): - if io not in deps_cache: - d = deps(io) - - if d: - if not isinstance(d, list | OrderedSet): - raise TypeError( - "Non-deterministic collections found; make" - " toposort non-deterministic." - ) - deps_cache[io] = list(d) - else: - deps_cache[io] = None - - return d - else: - return deps_cache[io] - - _compute_deps_cache = _compute_deps_cache_ - - else: - _compute_deps_cache = compute_deps_cache - - if deps_cache is None: - raise ValueError("deps_cache cannot be None") - - search_res: list[NodeAndChildren] = cast( - list[NodeAndChildren], - list(walk(outputs, _compute_deps_cache, bfs=False, return_children=True)), - ) - - _clients: dict[T, list[T]] = {} - sources: deque[T] = deque() - search_res_len = len(search_res) - for snode, children in search_res: - if children: - for child in children: - _clients.setdefault(child, []).append(snode) - if not deps_cache.get(snode): - sources.append(snode) - - if clients is not None: - clients.update(_clients) - - rset: set[T] = set() - rlist: list[T] = [] - while sources: - node: T = sources.popleft() - if node not in rset: - rlist.append(node) - rset.add(node) - for client in _clients.get(node, []): - d = [a for a in deps_cache[client] if a is not node] - deps_cache[client] = d - if not d: - sources.append(client) - - if len(rlist) != search_res_len: - raise ValueError("graph contains cycles") - - return rlist - - -def io_toposort( - inputs: Iterable[Variable], - outputs: Reversible[Variable], - orderings: dict[Apply, list[Apply]] | None = None, - clients: dict[Variable, list[Variable]] | None = None, -) -> list[Apply]: - """Perform topological sort from input and output nodes. - - Parameters - ---------- - inputs : list or tuple of Variable instances - Graph inputs. - outputs : list or tuple of Apply instances - Graph outputs. - orderings : dict - Keys are `Apply` instances, values are lists of `Apply` instances. - clients : dict - If provided, it will be filled with mappings of nodes-to-clients for - each node in the subgraph that is sorted. - - """ - if not orderings and clients is None: # ordering can be None or empty dict - # Specialized function that is faster when more then ~10 nodes - # when no ordering. - - # Do a new stack implementation with the vm algo. - # This will change the order returned. - computed = set(inputs) - todo = [o.owner for o in reversed(outputs) if o.owner] - order = [] - while todo: - cur = todo.pop() - if all(out in computed for out in cur.outputs): - continue - if all(i in computed or i.owner is None for i in cur.inputs): - computed.update(cur.outputs) - order.append(cur) - else: - todo.append(cur) - todo.extend( - i.owner for i in cur.inputs if (i.owner and i not in computed) - ) - return order - - iset = set(inputs) - - if not orderings: # ordering can be None or empty dict - # Specialized function that is faster when no ordering. - # Also include the cache in the function itself for speed up. 
- - deps_cache: dict = {} - - def compute_deps_cache(obj): - if obj in deps_cache: - return deps_cache[obj] - rval = [] - if obj not in iset: - if isinstance(obj, Variable): - if obj.owner: - rval = [obj.owner] - elif isinstance(obj, Apply): - rval = list(obj.inputs) - if rval: - deps_cache[obj] = list(rval) - else: - deps_cache[obj] = rval - else: - deps_cache[obj] = rval - return rval - - topo = general_toposort( - outputs, - deps=None, - compute_deps_cache=compute_deps_cache, - deps_cache=deps_cache, - clients=clients, - ) - - else: - # the inputs are used only here in the function that decides what - # 'predecessors' to explore - def compute_deps(obj): - rval = [] - if obj not in iset: - if isinstance(obj, Variable): - if obj.owner: - rval = [obj.owner] - elif isinstance(obj, Apply): - rval = list(obj.inputs) - rval.extend(orderings.get(obj, [])) - else: - assert not orderings.get(obj, None) - return rval - - topo = general_toposort( - outputs, - deps=compute_deps, - compute_deps_cache=None, - deps_cache=None, - clients=clients, - ) - return [o for o in topo if isinstance(o, Apply)] - - -default_leaf_formatter = str - - def default_node_formatter(op, argstrings): return f"{op.op}({', '.join(argstrings)})" -def io_connection_pattern(inputs, outputs): - """Return the connection pattern of a subgraph defined by given inputs and outputs.""" - inner_nodes = io_toposort(inputs, outputs) - - # Initialize 'connect_pattern_by_var' by establishing each input as - # connected only to itself - connect_pattern_by_var = {} - nb_inputs = len(inputs) - - for i in range(nb_inputs): - input = inputs[i] - inp_connection_pattern = [i == j for j in range(nb_inputs)] - connect_pattern_by_var[input] = inp_connection_pattern - - # Iterate through the nodes used to produce the outputs from the - # inputs and, for every node, infer their connection pattern to - # every input from the connection patterns of their parents. - for n in inner_nodes: - # Get the connection pattern of the inner node's op. If the op - # does not define a connection_pattern method, assume that - # every node output is connected to every node input - try: - op_connection_pattern = n.op.connection_pattern(n) - except AttributeError: - op_connection_pattern = [[True] * len(n.outputs)] * len(n.inputs) - - # For every output of the inner node, figure out which inputs it - # is connected to by combining the connection pattern of the inner - # node and the connection patterns of the inner node's inputs. 
- for out_idx in range(len(n.outputs)): - out = n.outputs[out_idx] - out_connection_pattern = [False] * nb_inputs - - for inp_idx in range(len(n.inputs)): - inp = n.inputs[inp_idx] - - if inp in connect_pattern_by_var: - inp_connection_pattern = connect_pattern_by_var[inp] - - # If the node output is connected to the node input, it - # means it is connected to every inner input that the - # node inputs is connected to - if op_connection_pattern[inp_idx][out_idx]: - out_connection_pattern = [ - out_connection_pattern[i] or inp_connection_pattern[i] - for i in range(nb_inputs) - ] - - # Store the connection pattern of the node output - connect_pattern_by_var[out] = out_connection_pattern - - # Obtain the global connection pattern by combining the - # connection patterns of the individual outputs - global_connection_pattern = [[] for o in range(len(inputs))] - for out in outputs: - out_connection_pattern = connect_pattern_by_var.get(out) - if out_connection_pattern is None: - # the output is completely isolated from inputs - out_connection_pattern = [False] * len(inputs) - for i in range(len(inputs)): - global_connection_pattern[i].append(out_connection_pattern[i]) - - return global_connection_pattern - - -def op_as_string( - i, op, leaf_formatter=default_leaf_formatter, node_formatter=default_node_formatter -): +def op_as_string(i, op, leaf_formatter=str, node_formatter=default_node_formatter): """Return a function that returns a string representation of the subgraph between `i` and :attr:`op.inputs`""" strs = as_string(i, op.inputs, leaf_formatter, node_formatter) return node_formatter(op, strs) @@ -1709,7 +1035,7 @@ def op_as_string( def as_string( inputs: list[Variable], outputs: list[Variable], - leaf_formatter=default_leaf_formatter, + leaf_formatter=str, node_formatter=default_node_formatter, ) -> list[str]: r"""Returns a string representation of the subgraph between `inputs` and `outputs`. @@ -1737,6 +1063,8 @@ def as_string( viewing convenience). """ + from pytensor.graph.traversal import applys_between, orphans_between + i = set(inputs) orph = list(orphans_between(i, outputs)) @@ -1787,82 +1115,6 @@ def describe(r): return [describe(output) for output in outputs] -def view_roots(node: Variable) -> list[Variable]: - """Return the leaves from a search through consecutive view-maps.""" - owner = node.owner - if owner is not None: - try: - vars_to_views = {owner.outputs[o]: i for o, i in owner.op.view_map.items()} - except AttributeError: - return [node] - if node in vars_to_views: - answer = [] - for i in vars_to_views[node]: - answer += view_roots(owner.inputs[i]) - return answer - else: - return [node] - else: - return [node] - - -def apply_depends_on(apply: Apply, depends_on: Apply | Collection[Apply]) -> bool: - """Determine if any `depends_on` is in the graph given by ``apply``. - - Parameters - ---------- - apply : Apply - The Apply node to check. 
- depends_on : Union[Apply, Collection[Apply]] - Apply nodes to check dependency on - - Returns - ------- - bool - - """ - computed = set() - todo = [apply] - if not isinstance(depends_on, Collection): - depends_on = {depends_on} - else: - depends_on = set(depends_on) - while todo: - cur = todo.pop() - if cur.outputs[0] in computed: - continue - if all(i in computed or i.owner is None for i in cur.inputs): - computed.update(cur.outputs) - if cur in depends_on: - return True - else: - todo.append(cur) - todo.extend(i.owner for i in cur.inputs if i.owner) - return False - - -def variable_depends_on( - variable: Variable, depends_on: Variable | Collection[Variable] -) -> bool: - """Determine if any `depends_on` is in the graph given by ``variable``. - Parameters - ---------- - variable: Variable - Node to check - depends_on: Collection[Variable] - Nodes to check dependency on - - Returns - ------- - bool - """ - if not isinstance(depends_on, Collection): - depends_on = {depends_on} - else: - depends_on = set(depends_on) - return any(interim in depends_on for interim in ancestors([variable])) - - def equal_computations( xs: list[np.ndarray | Variable], ys: list[np.ndarray | Variable], @@ -2029,78 +1281,3 @@ def compare_nodes(nd_x, nd_y, common, different): return False return True - - -def get_var_by_name( - graphs: Iterable[Variable], target_var_id: str, ids: str = "CHAR" -) -> tuple[Variable, ...]: - r"""Get variables in a graph using their names. - - Parameters - ---------- - graphs: - The graph, or graphs, to search. - target_var_id: - The name to match against either ``Variable.name`` or - ``Variable.auto_name``. - - Returns - ------- - A ``tuple`` containing all the `Variable`\s that match `target_var_id`. - - """ - from pytensor.graph.op import HasInnerGraph - - def expand(r) -> list[Variable] | None: - if not r.owner: - return None - - res = list(r.owner.inputs) - - if isinstance(r.owner.op, HasInnerGraph): - res.extend(r.owner.op.inner_outputs) - - return res - - results: tuple[Variable, ...] = () - for var in walk(graphs, expand, False): - var = cast(Variable, var) - if target_var_id == var.name or target_var_id == var.auto_name: - results += (var,) - - return results - - -def replace_nominals_with_dummies(inputs, outputs): - """Replace nominal inputs with dummy variables. - - When constructing a new graph with nominal inputs from an existing graph, - pre-existing nominal inputs need to be replaced with dummy variables - beforehand; otherwise, sequential ID ordering (i.e. when nominals are IDed - based on the ordered inputs to which they correspond) of the nominals could - be broken, and/or circular replacements could manifest. - - FYI: This function assumes that all the nominal variables in the subgraphs - between `inputs` and `outputs` are present in `inputs`. 
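``apply_depends_on`` and ``variable_depends_on`` answer reachability questions about a graph; the variable-level version behaves roughly as sketched below:

.. code-block:: python

    import pytensor.tensor as pt
    from pytensor.graph.traversal import variable_depends_on

    x = pt.scalar("x")
    y = pt.scalar("y")
    out = x + 1

    assert variable_depends_on(out, x)      # x is an ancestor of out
    assert not variable_depends_on(out, y)  # y is unrelated to out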
- - """ - existing_nominal_replacements = { - i: i.type() for i in inputs if isinstance(i, NominalVariable) - } - - if existing_nominal_replacements: - # Replace existing nominal variables, because we need to produce an - # inner-graph for which the nominal variable IDs correspond exactly - # to their input order - _ = clone_get_equiv( - inputs, - outputs, - copy_inputs=False, - copy_orphans=False, - memo=existing_nominal_replacements, - ) - - outputs = [existing_nominal_replacements[o] for o in outputs] - inputs = [existing_nominal_replacements[i] for i in inputs] - - return inputs, outputs diff --git a/pytensor/graph/features.py b/pytensor/graph/features.py index 74744d6732..a6e5b870c7 100644 --- a/pytensor/graph/features.py +++ b/pytensor/graph/features.py @@ -9,7 +9,8 @@ import pytensor from pytensor.configdefaults import config -from pytensor.graph.basic import Variable, io_toposort +from pytensor.graph.basic import Variable +from pytensor.graph.traversal import toposort from pytensor.graph.utils import InconsistencyError @@ -339,11 +340,11 @@ def clone(self): class Bookkeeper(Feature): def on_attach(self, fgraph): - for node in io_toposort(fgraph.inputs, fgraph.outputs): + for node in toposort(fgraph.outputs): self.on_import(fgraph, node, "on_attach") def on_detach(self, fgraph): - for node in io_toposort(fgraph.inputs, fgraph.outputs): + for node in toposort(fgraph.outputs): self.on_prune(fgraph, node, "Bookkeeper.detach") diff --git a/pytensor/graph/fg.py b/pytensor/graph/fg.py index bdaefc42f8..b97068da0f 100644 --- a/pytensor/graph/fg.py +++ b/pytensor/graph/fg.py @@ -11,17 +11,19 @@ Apply, AtomicVariable, Variable, - applys_between, clone_get_equiv, - graph_inputs, - io_toposort, - vars_between, ) from pytensor.graph.basic import as_string as graph_as_string from pytensor.graph.features import AlreadyThere, Feature, ReplaceValidate from pytensor.graph.op import Op +from pytensor.graph.traversal import ( + applys_between, + graph_inputs, + toposort, + toposort_with_orderings, + vars_between, +) from pytensor.graph.utils import MetaObject, MissingInputError, TestValueError -from pytensor.misc.ordered_set import OrderedSet ClientType = tuple[Apply, int] @@ -130,7 +132,6 @@ def __init__( features = [] self._features: list[Feature] = [] - # All apply nodes in the subgraph defined by inputs and # outputs are cached in this field self.apply_nodes: set[Apply] = set() @@ -158,7 +159,8 @@ def __init__( "input's owner or use graph.clone." ) - self.add_input(in_var, check=False) + self.inputs.append(in_var) + self.clients.setdefault(in_var, []) for output in outputs: self.add_output(output, reason="init") @@ -186,16 +188,6 @@ def add_input(self, var: Variable, check: bool = True) -> None: return self.inputs.append(var) - self.setup_var(var) - - def setup_var(self, var: Variable) -> None: - """Set up a variable so it belongs to this `FunctionGraph`. 
- - Parameters - ---------- - var : pytensor.graph.basic.Variable - - """ self.clients.setdefault(var, []) def get_clients(self, var: Variable) -> list[ClientType]: @@ -319,10 +311,11 @@ def import_var( """ # Imports the owners of the variables - if var.owner and var.owner not in self.apply_nodes: - self.import_node(var.owner, reason=reason, import_missing=import_missing) + apply = var.owner + if apply is not None and apply not in self.apply_nodes: + self.import_node(apply, reason=reason, import_missing=import_missing) elif ( - var.owner is None + apply is None and not isinstance(var, AtomicVariable) and var not in self.inputs ): @@ -333,10 +326,11 @@ def import_var( f"Computation graph contains a NaN. {var.type.why_null}" ) if import_missing: - self.add_input(var) + self.inputs.append(var) + self.clients.setdefault(var, []) else: raise MissingInputError(f"Undeclared input: {var}", variable=var) - self.setup_var(var) + self.clients.setdefault(var, []) self.variables.add(var) def import_node( @@ -353,29 +347,29 @@ def import_node( apply_node : Apply The node to be imported. check : bool - Check that the inputs for the imported nodes are also present in - the `FunctionGraph`. + Check that the inputs for the imported nodes are also present in the `FunctionGraph`. reason : str The name of the optimization or operation in progress. import_missing : bool Add missing inputs instead of raising an exception. """ # We import the nodes in topological order. We only are interested in - # new nodes, so we use all variables we know of as if they were the - # input set. (The functions in the graph module only use the input set - # to know where to stop going down.) - new_nodes = io_toposort(self.variables, apply_node.outputs) - - if check: - for node in new_nodes: + # new nodes, so we use all nodes we know of as inputs to interrupt the toposort + self_variables = self.variables + self_clients = self.clients + self_apply_nodes = self.apply_nodes + self_inputs = self.inputs + for node in toposort(apply_node.outputs, blockers=self_variables): + if check: for var in node.inputs: if ( var.owner is None and not isinstance(var, AtomicVariable) - and var not in self.inputs + and var not in self_inputs ): if import_missing: - self.add_input(var) + self_inputs.append(var) + self_clients.setdefault(var, []) else: error_msg = ( f"Input {node.inputs.index(var)} ({var})" @@ -387,20 +381,20 @@ def import_node( ) raise MissingInputError(error_msg, variable=var) - for node in new_nodes: - assert node not in self.apply_nodes - self.apply_nodes.add(node) - if not hasattr(node.tag, "imported_by"): - node.tag.imported_by = [] - node.tag.imported_by.append(str(reason)) + self_apply_nodes.add(node) + tag = node.tag + if not hasattr(tag, "imported_by"): + tag.imported_by = [str(reason)] + else: + tag.imported_by.append(str(reason)) for output in node.outputs: - self.setup_var(output) - self.variables.add(output) - for i, input in enumerate(node.inputs): - if input not in self.variables: - self.setup_var(input) - self.variables.add(input) - self.add_client(input, (node, i)) + self_clients.setdefault(output, []) + self_variables.add(output) + for i, inp in enumerate(node.inputs): + if inp not in self_variables: + self_clients.setdefault(inp, []) + self_variables.add(inp) + self_clients[inp].append((node, i)) self.execute_callbacks("on_import", node, reason) def change_node_input( @@ -454,7 +448,7 @@ def change_node_input( self.outputs[node.op.idx] = new_var self.import_var(new_var, reason=reason, import_missing=import_missing) - 
self.add_client(new_var, (node, i)) + self.clients[new_var].append((node, i)) self.remove_client(r, (node, i), reason=reason) # Precondition: the substitution is semantically valid However it may # introduce cycles to the graph, in which case the transaction will be @@ -753,11 +747,7 @@ def toposort(self) -> list[Apply]: :meth:`FunctionGraph.orderings`. """ - if len(self.apply_nodes) < 2: - # No sorting is necessary - return list(self.apply_nodes) - - return io_toposort(self.inputs, self.outputs, self.orderings()) + return list(toposort_with_orderings(self.outputs, orderings=self.orderings())) def orderings(self) -> dict[Apply, list[Apply]]: """Return a map of node to node evaluation dependencies. @@ -776,29 +766,17 @@ def orderings(self) -> dict[Apply, list[Apply]]: take care of computing the dependencies by itself. """ - assert isinstance(self._features, list) - all_orderings: list[dict] = [] + all_orderings: list[dict] = [ + orderings + for feature in self._features + if ( + hasattr(feature, "orderings") and (orderings := feature.orderings(self)) + ) + ] - for feature in self._features: - if hasattr(feature, "orderings"): - orderings = feature.orderings(self) - if not isinstance(orderings, dict): - raise TypeError( - "Non-deterministic return value from " - + str(feature.orderings) - + ". Nondeterministic object is " - + str(orderings) - ) - if len(orderings) > 0: - all_orderings.append(orderings) - for node, prereqs in orderings.items(): - if not isinstance(prereqs, list | OrderedSet): - raise TypeError( - "prereqs must be a type with a " - "deterministic iteration order, or toposort " - " will be non-deterministic." - ) - if len(all_orderings) == 1: + if not all_orderings: + return {} + elif len(all_orderings) == 1: # If there is only 1 ordering, we reuse it directly. return all_orderings[0].copy() else: diff --git a/pytensor/graph/op.py b/pytensor/graph/op.py index 3a00922c87..ab83230d6e 100644 --- a/pytensor/graph/op.py +++ b/pytensor/graph/op.py @@ -14,6 +14,7 @@ import pytensor from pytensor.configdefaults import config from pytensor.graph.basic import Apply, Variable +from pytensor.graph.traversal import io_toposort from pytensor.graph.utils import ( MetaObject, TestValueError, @@ -753,3 +754,68 @@ def get_test_values(*args: Variable) -> Any | list[Any]: return rval return [tuple(rval)] + + +def io_connection_pattern(inputs, outputs): + """Return the connection pattern of a subgraph defined by given inputs and outputs.""" + inner_nodes = io_toposort(inputs, outputs) + + # Initialize 'connect_pattern_by_var' by establishing each input as + # connected only to itself + connect_pattern_by_var = {} + nb_inputs = len(inputs) + + for i in range(nb_inputs): + input = inputs[i] + inp_connection_pattern = [i == j for j in range(nb_inputs)] + connect_pattern_by_var[input] = inp_connection_pattern + + # Iterate through the nodes used to produce the outputs from the + # inputs and, for every node, infer their connection pattern to + # every input from the connection patterns of their parents. + for n in inner_nodes: + # Get the connection pattern of the inner node's op. 
If the op + # does not define a connection_pattern method, assume that + # every node output is connected to every node input + try: + op_connection_pattern = n.op.connection_pattern(n) + except AttributeError: + op_connection_pattern = [[True] * len(n.outputs)] * len(n.inputs) + + # For every output of the inner node, figure out which inputs it + # is connected to by combining the connection pattern of the inner + # node and the connection patterns of the inner node's inputs. + for out_idx in range(len(n.outputs)): + out = n.outputs[out_idx] + out_connection_pattern = [False] * nb_inputs + + for inp_idx in range(len(n.inputs)): + inp = n.inputs[inp_idx] + + if inp in connect_pattern_by_var: + inp_connection_pattern = connect_pattern_by_var[inp] + + # If the node output is connected to the node input, it + # means it is connected to every inner input that the + # node inputs is connected to + if op_connection_pattern[inp_idx][out_idx]: + out_connection_pattern = [ + out_connection_pattern[i] or inp_connection_pattern[i] + for i in range(nb_inputs) + ] + + # Store the connection pattern of the node output + connect_pattern_by_var[out] = out_connection_pattern + + # Obtain the global connection pattern by combining the + # connection patterns of the individual outputs + global_connection_pattern = [[] for o in range(len(inputs))] + for out in outputs: + out_connection_pattern = connect_pattern_by_var.get(out) + if out_connection_pattern is None: + # the output is completely isolated from inputs + out_connection_pattern = [False] * len(inputs) + for i in range(len(inputs)): + global_connection_pattern[i].append(out_connection_pattern[i]) + + return global_connection_pattern diff --git a/pytensor/graph/replace.py b/pytensor/graph/replace.py index 6cb46b6301..bb49245ebe 100644 --- a/pytensor/graph/replace.py +++ b/pytensor/graph/replace.py @@ -7,11 +7,13 @@ Apply, Constant, Variable, - io_toposort, - truncated_graph_inputs, ) from pytensor.graph.fg import FunctionGraph from pytensor.graph.op import Op +from pytensor.graph.traversal import ( + toposort, + truncated_graph_inputs, +) ReplaceTypes = Iterable[tuple[Variable, Variable]] | dict[Variable, Variable] @@ -296,7 +298,7 @@ def vectorize_graph( new_inputs = [replace.get(inp, inp) for inp in inputs] vect_vars = dict(zip(inputs, new_inputs, strict=True)) - for node in io_toposort(inputs, seq_outputs): + for node in toposort(seq_outputs, blockers=inputs): vect_inputs = [vect_vars.get(inp, inp) for inp in node.inputs] vect_node = vectorize_node(node, *vect_inputs) for output, vect_output in zip(node.outputs, vect_node.outputs, strict=True): diff --git a/pytensor/graph/rewriting/basic.py b/pytensor/graph/rewriting/basic.py index 66d5f844b1..d03be5e707 100644 --- a/pytensor/graph/rewriting/basic.py +++ b/pytensor/graph/rewriting/basic.py @@ -22,14 +22,17 @@ AtomicVariable, Constant, Variable, - applys_between, - io_toposort, - vars_between, ) from pytensor.graph.features import AlreadyThere, Feature from pytensor.graph.fg import FunctionGraph, Output from pytensor.graph.op import Op from pytensor.graph.rewriting.unify import Var, convert_strs_to_vars +from pytensor.graph.traversal import ( + apply_ancestors, + applys_between, + toposort, + vars_between, +) from pytensor.graph.utils import AssocList, InconsistencyError from pytensor.misc.ordered_set import OrderedSet from pytensor.utils import flatten @@ -1821,12 +1824,13 @@ class WalkingGraphRewriter(NodeProcessingGraphRewriter): def __init__( self, node_rewriter: NodeRewriter, - order: 
Literal["out_to_in", "in_to_out"] = "in_to_out", + order: Literal["out_to_in", "in_to_out", "dfs"] = "in_to_out", ignore_newtrees: bool = False, failure_callback: FailureCallbackType | None = None, ): - if order not in ("out_to_in", "in_to_out"): - raise ValueError("order must be 'out_to_in' or 'in_to_out'") + valid_orders = ("out_to_in", "in_to_out", "dfs") + if order not in valid_orders: + raise ValueError(f"order must be one of {valid_orders}, got {order}") self.order = order super().__init__(node_rewriter, ignore_newtrees, failure_callback) @@ -1836,7 +1840,11 @@ def apply(self, fgraph, start_from=None): callback_before = fgraph.execute_callbacks_time nb_nodes_start = len(fgraph.apply_nodes) t0 = time.perf_counter() - q = deque(io_toposort(fgraph.inputs, start_from)) + q = deque( + apply_ancestors(start_from) + if (self.order == "dfs") + else toposort(start_from) + ) io_t = time.perf_counter() - t0 def importer(node): @@ -1959,6 +1967,7 @@ def walking_rewriter( in2out = partial(walking_rewriter, "in_to_out") out2in = partial(walking_rewriter, "out_to_in") +dfs_rewriter = partial(walking_rewriter, "dfs") class ChangeTracker(Feature): @@ -2166,7 +2175,7 @@ def apply_cleanup(profs_dict): changed |= apply_cleanup(iter_cleanup_sub_profs) topo_t0 = time.perf_counter() - q = deque(io_toposort(fgraph.inputs, start_from)) + q = deque(toposort(start_from)) io_toposort_timing.append(time.perf_counter() - topo_t0) nb_nodes.append(len(q)) diff --git a/pytensor/graph/rewriting/utils.py b/pytensor/graph/rewriting/utils.py index a8acc89e1c..d345068a99 100644 --- a/pytensor/graph/rewriting/utils.py +++ b/pytensor/graph/rewriting/utils.py @@ -7,11 +7,10 @@ Apply, Variable, equal_computations, - graph_inputs, - vars_between, ) from pytensor.graph.fg import FunctionGraph, Output from pytensor.graph.rewriting.db import RewriteDatabaseQuery +from pytensor.graph.traversal import graph_inputs, vars_between if TYPE_CHECKING: diff --git a/pytensor/graph/traversal.py b/pytensor/graph/traversal.py new file mode 100644 index 0000000000..08b0bc2fd8 --- /dev/null +++ b/pytensor/graph/traversal.py @@ -0,0 +1,773 @@ +from collections import deque +from collections.abc import ( + Callable, + Generator, + Iterable, + Reversible, + Sequence, +) +from typing import ( + Literal, + TypeVar, + overload, +) + +from pytensor.graph.basic import Apply, Constant, Node, Variable + + +T = TypeVar("T", bound=Node) +NodeAndChildren = tuple[T, Iterable[T] | None] + + +@overload +def walk( + nodes: Iterable[T], + expand: Callable[[T], Iterable[T] | None], + bfs: bool = True, + return_children: Literal[False] = False, +) -> Generator[T, None, None]: ... + + +@overload +def walk( + nodes: Iterable[T], + expand: Callable[[T], Iterable[T] | None], + bfs: bool, + return_children: Literal[True], +) -> Generator[NodeAndChildren, None, None]: ... + + +def walk( + nodes: Iterable[T], + expand: Callable[[T], Iterable[T] | None], + bfs: bool = True, + return_children: bool = False, +) -> Generator[T | NodeAndChildren, None, None]: + r"""Walk through a graph, either breadth- or depth-first. + + Parameters + ---------- + nodes + The nodes from which to start walking. + expand + A callable that is applied to each node in `nodes`, the results of + which are either new nodes to visit or ``None``. + bfs + If ``True``, breath first search is used; otherwise, depth first + search. + return_children + If ``True``, each output node will be accompanied by the output of + `expand` (i.e. the corresponding child nodes). 
+ + Notes + ----- + A node will appear at most once in the return value, even if it + appears multiple times in the `nodes` parameter. + + """ + + rval_set: set[T] = set() + nodes = deque(nodes) + nodes_pop: Callable[[], T] = nodes.popleft if bfs else nodes.pop + node: T + new_nodes: Iterable[T] | None + try: + if return_children: + while True: + node = nodes_pop() + if node not in rval_set: + new_nodes = expand(node) + yield node, new_nodes + rval_set.add(node) + if new_nodes: + nodes.extend(new_nodes) + else: + while True: + node = nodes_pop() + if node not in rval_set: + yield node + rval_set.add(node) + new_nodes = expand(node) + if new_nodes: + nodes.extend(new_nodes) + except IndexError: + return None + + +def ancestors( + graphs: Iterable[Variable], + blockers: Iterable[Variable] | None = None, +) -> Generator[Variable, None, None]: + r"""Return the variables that contribute to those in given graphs (inclusive), stopping at blockers. + + Parameters + ---------- + graphs : list of `Variable` instances + Output `Variable` instances from which to search backward through + owners. + blockers : list of `Variable` instances + A collection of `Variable`\s that, when found, prevent the graph search + from preceding from that point. + + Yields + ------ + `Variable`\s + All ancestor variables, in the order found by a right-recursive depth-first search + started at the variables in `graphs`. + """ + + seen = set() + queue = list(graphs) + try: + if blockers: + blockers = frozenset(blockers) + while True: + if (var := queue.pop()) not in seen: + yield var + seen.add(var) + if var not in blockers and (apply := var.owner) is not None: + queue.extend(apply.inputs) + else: + while True: + if (var := queue.pop()) not in seen: + yield var + seen.add(var) + if (apply := var.owner) is not None: + queue.extend(apply.inputs) + except IndexError: + return + + +variable_ancestors = ancestors + + +def apply_ancestors( + graphs: Iterable[Variable], + blockers: Iterable[Variable] | None = None, +) -> Generator[Apply, None, None]: + """Return the Apply nodes that contribute to those in given graphs (inclusive).""" + seen = {None} # This filters out Variables without an owner + for var in ancestors(graphs, blockers): + # For multi-output nodes, we'll see multiple variables + # but we should only yield the Apply once + if (apply := var.owner) not in seen: + yield apply + seen.add(apply) + return + + +def graph_inputs( + graphs: Iterable[Variable], blockers: Iterable[Variable] | None = None +) -> Generator[Variable, None, None]: + r"""Return the inputs required to compute the given Variables. + + Parameters + ---------- + graphs : list of `Variable` instances + Output `Variable` instances from which to search backward through + owners. + blockers : list of `Variable` instances + A collection of `Variable`\s that, when found, prevent the graph search + from preceding from that point. + + Yields + ------ + Input nodes with no owner, in the order found by a breath first search started at the nodes in `graphs`. + + """ + yield from (var for var in ancestors(graphs, blockers) if var.owner is None) + + +def explicit_graph_inputs( + graph: Variable | Iterable[Variable], +) -> Generator[Variable, None, None]: + """ + Get the root variables needed as inputs to a function that computes `graph` + + Parameters + ---------- + graph : TensorVariable + Output `Variable` instances for which to search backward through + owners. 
+ + Returns + ------- + iterable + Generator of root Variables (without owner) needed to compile a function that evaluates `graphs`. + + Examples + -------- + + .. code-block:: python + + import pytensor + import pytensor.tensor as pt + from pytensor.graph.traversal import explicit_graph_inputs + + x = pt.vector("x") + y = pt.constant(2) + z = pt.mul(x * y) + + inputs = list(explicit_graph_inputs(z)) + f = pytensor.function(inputs, z) + eval = f([1, 2, 3]) + + print(eval) + # [2. 4. 6.] + """ + from pytensor.compile.sharedvalue import SharedVariable + + if isinstance(graph, Variable): + graph = (graph,) + + return ( + var + for var in ancestors(graph) + if var.owner is None and not isinstance(var, Constant | SharedVariable) + ) + + +def vars_between( + ins: Iterable[Variable], outs: Iterable[Variable] +) -> Generator[Variable, None, None]: + r"""Extract the `Variable`\s within the sub-graph between input and output nodes. + + Notes + ----- + This function is like ancestors(outs, blockers=ins), + except it can also yield disconnected output variables from multi-output apply nodes. + + Parameters + ---------- + ins + Input `Variable`\s. + outs + Output `Variable`\s. + + Yields + ------ + The `Variable`\s that are involved in the subgraph that lies + between `ins` and `outs`. This includes `ins`, `outs`, + ``orphans_between(ins, outs)`` and all values of all intermediary steps from + `ins` to `outs`. + + """ + + def expand(var: Variable, ins=frozenset(ins)) -> Iterable[Variable] | None: + if var.owner is not None and var not in ins: + return (*var.owner.inputs, *var.owner.outputs) + return None + + # With bfs = False, it iterates similarly to ancestors + yield from walk(outs, expand, bfs=False) + + +def orphans_between( + ins: Iterable[Variable], outs: Iterable[Variable] +) -> Generator[Variable, None, None]: + r"""Extract the root `Variable`\s not within the sub-graph between input and output nodes. + + Parameters + ---------- + ins : list + Input `Variable`\s. + outs : list + Output `Variable`\s. + + Yields + ------- + Variable + The `Variable`\s upon which one or more `Variable`\s in `outs` + depend, but are neither in `ins` nor in the sub-graph that lies between + them. + + Examples + -------- + >>> from pytensor.graph.traversal import orphans_between + >>> from pytensor.tensor import scalars + >>> x, y = scalars("xy") + >>> list(orphans_between([x], [(x + y)])) + [y] + + """ + ins = frozenset(ins) + yield from ( + var + for var in vars_between(ins, outs) + if ((var.owner is None) and (var not in ins)) + ) + + +def applys_between( + ins: Iterable[Variable], outs: Iterable[Variable] +) -> Generator[Apply, None, None]: + r"""Extract the `Apply`\s contained within the sub-graph between given input and output variables. + + Notes + ----- + This is identical to apply_ancestors(outs, blockers=ins) + + Parameters + ---------- + ins : list + Input `Variable`\s. + outs : list + Output `Variable`\s. + + Yields + ------ + The `Apply`\s that are contained within the sub-graph that lies + between `ins` and `outs`, including the owners of the `Variable`\s in + `outs` and intermediary `Apply`\s between `ins` and `outs`, but not the + owners of the `Variable`\s in `ins`. + + """ + return apply_ancestors(outs, blockers=ins) + + +def apply_depends_on(apply: Apply, depends_on: Apply | Iterable[Apply]) -> bool: + """Determine if any `depends_on` is in the graph given by ``apply``. + + Parameters + ---------- + apply : Apply + The Apply node to check. 
+ depends_on : Union[Apply, Collection[Apply]] + Apply nodes to check dependency on + + Returns + ------- + bool + + """ + if isinstance(depends_on, Apply): + depends_on = frozenset((depends_on,)) + else: + depends_on = frozenset(depends_on) + return (apply in depends_on) or any( + apply in depends_on for apply in apply_ancestors(apply.inputs) + ) + + +def variable_depends_on( + variable: Variable, depends_on: Variable | Iterable[Variable] +) -> bool: + """Determine if any `depends_on` is in the graph given by ``variable``. + + Notes + ----- + The interpretation of dependency is done at a variable level. + A variable may depend on some output variables from a multi-output apply node but not others. + + + Parameters + ---------- + variable: Variable + T to check + depends_on: Iterable[Variable] + Nodes to check dependency on + + Returns + ------- + bool + """ + if isinstance(depends_on, Variable): + depends_on_set = frozenset((depends_on,)) + else: + depends_on_set = frozenset(depends_on) + return any(var in depends_on_set for var in variable_ancestors([variable])) + + +def truncated_graph_inputs( + outputs: Sequence[Variable], + ancestors_to_include: Iterable[Variable] | None = None, +) -> list[Variable]: + """Get the truncate graph inputs. + + Unlike :func:`graph_inputs` this function will return + the closest variables to outputs that do not depend on + ``ancestors_to_include``. So given all the returned + variables provided there is no missing variable to + compute the output and all variables are independent + from each other. + + Parameters + ---------- + outputs : Iterable[Variable] + Variable to get conditions for + ancestors_to_include : Optional[Iterable[Variable]] + Additional ancestors to assume, by default None + + Returns + ------- + List[Variable] + Variables required to compute ``outputs`` + + Examples + -------- + The returned variables marked in (parenthesis), ancestors variables are ``c``, output variables are ``o`` + + * No ancestors to include + + .. code-block:: + + n - n - (o) + + * One ancestors to include + + .. code-block:: + + n - (c) - o + + * Two ancestors to include where on depends on another, both returned + + .. code-block:: + + (c) - (c) - o + + * Additional variables are present + + .. code-block:: + + (c) - n - o + n - (n) -' + + * Disconnected ancestors to include not returned + + .. code-block:: + + (c) - n - o + c + + * Disconnected output is present and returned + + .. code-block:: + + (c) - (c) - o + (o) + + * ancestors to include that include itself adds itself + + .. 
code-block:: + + n - (c) - (o/c) + + """ + truncated_inputs: list[Variable] = list() + seen: set[Variable] = set() + + # simple case, no additional ancestors to include + if not ancestors_to_include: + # just filter out unique variables + for variable in outputs: + if variable not in seen: + seen.add(variable) + truncated_inputs.append(variable) + return truncated_inputs + + # blockers have known independent variables and ancestors to include + blockers: set[Variable] = set(ancestors_to_include) + # enforce O(1) check for variable in ancestors to include + ancestors_to_include = blockers.copy() + candidates = list(outputs) + try: + while True: + if (variable := candidates.pop()) not in seen: + seen.add(variable) + # check if the variable is independent, never go above blockers; + # blockers are independent variables and ancestors to include + if variable in ancestors_to_include: + # ancestors to include that are present in the graph (not disconnected) + # should be added to truncated_inputs + truncated_inputs.append(variable) + # if the ancestors to include is still dependent on other ancestors we need to go above, + # FIXME: This seems wrong? The other ancestors above are either redundant given this variable, + # or another path leads to them and the special casing isn't needed + # It seems the only reason we are expanding on these inputs is to find other ancestors_to_include + # (instead of treating them as disconnected), but this may yet cause other unrelated variables + # to become "independent" in the process + if variable_depends_on(variable, ancestors_to_include - {variable}): + # owner can never be None for a dependent variable + candidates.extend( + n for n in variable.owner.inputs if n not in seen + ) + else: + # A regular variable to check + # if we've found an independent variable and it is not in blockers so far + # it is a new independent variable not present in ancestors to include + if variable_depends_on(variable, blockers): + # If it's not an independent variable, inputs become candidates + candidates.extend(variable.owner.inputs) + else: + # otherwise it's a truncated input itself + truncated_inputs.append(variable) + # all regular variables fall to blockers + # 1. it is dependent - we already expanded on the inputs, nothing to do if we find it again + # 2. it is independent - this is a truncated input, search for other nodes can stop here + blockers.add(variable) + except IndexError: # pop from an empty list + pass + + return truncated_inputs + + +def walk_toposort( + graphs: Iterable[T], + deps: Callable[[T], Iterable[T] | None], +) -> Generator[T, None, None]: + """Perform a topological sort of all nodes starting from a given node. + + Parameters + ---------- + graphs: + An iterable of nodes from which to start the topological sort. + deps : callable + A Python function that takes a node as input and returns its dependence. + + Notes + ----- + + ``deps(i)`` should behave like a pure function (no funny business with internal state). + + The order of the return value list is determined by the order of nodes + returned by the `deps` function. 
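+
+    Examples
+    --------
+    A minimal sketch; the toy graph and the ``deps`` callback below are purely
+    illustrative and not part of the API:
+
+    .. code-block:: python
+
+        import pytensor.tensor as pt
+        from pytensor.graph.traversal import walk_toposort
+
+        x = pt.vector("x")
+        y = pt.exp(x)
+        z = y + 1
+
+        def deps(var):
+            # A variable depends on the inputs of the Apply node that produced it
+            return var.owner.inputs if var.owner is not None else None
+
+        # Ancestors are yielded before the nodes that depend on them,
+        # so ``x`` appears before ``y``, and ``y`` before ``z``
+        ordered = list(walk_toposort([z], deps))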
+    """
+
+    # Cache the dependencies (ancestors) as we iterate over the nodes with the deps function
+    deps_cache: dict[T, list[T]] = {}
+
+    def compute_deps_cache(obj, deps_cache=deps_cache):
+        if obj in deps_cache:
+            return deps_cache[obj]
+        d = deps_cache[obj] = deps(obj) or []
+        return d
+
+    clients: dict[T, list[T]] = {}
+    sources: deque[T] = deque()
+    total_nodes = 0
+    for node, children in walk(
+        graphs, compute_deps_cache, bfs=False, return_children=True
+    ):
+        total_nodes += 1
+        # Mypy doesn't know that `children` can never be `None` here, because `compute_deps_cache` returns `deps(obj) or []`
+        for child in children:  # type: ignore
+            clients.setdefault(child, []).append(node)
+        if not deps_cache[node]:
+            # Add nodes without dependencies to the queue
+            sources.append(node)
+
+    rset: set[T] = set()
+    try:
+        while True:
+            if (node := sources.popleft()) not in rset:
+                yield node
+                total_nodes -= 1
+                rset.add(node)
+                # Iterate over each client node (that is, each node that depends on the current node)
+                for client in clients.get(node, []):
+                    # Remove the current node from the dependency (ancestor) list of each client
+                    d = deps_cache[client] = [
+                        a for a in deps_cache[client] if a is not node
+                    ]
+                    if not d:
+                        # If there are no dependencies left to visit for this client, add it to the queue
+                        sources.append(client)
+    except IndexError:
+        pass
+
+    if total_nodes != 0:
+        raise ValueError("graph contains cycles")
+
+
+def general_toposort(
+    outputs: Iterable[T],
+    deps: Callable[[T], Iterable[T] | None],
+    compute_deps_cache: Callable[[T], Iterable[T] | None] | None = None,
+    deps_cache: dict[T, list[T]] | None = None,
+    clients: dict[T, list[T]] | None = None,
+) -> list[T]:
+    """Perform a topological sort of all nodes starting from a given node.
+
+    Parameters
+    ----------
+    outputs : iterable
+        An iterable of nodes from which to start the topological sort.
+    deps : callable
+        A Python function that takes a node as input and returns its dependencies.
+    compute_deps_cache : optional
+        No longer supported; a ``ValueError`` is raised if provided.
+    deps_cache : optional
+        No longer supported; a ``ValueError`` is raised if provided.
+    clients : optional
+        No longer supported; a ``ValueError`` is raised if provided.
+
+    Notes
+    -----
+    This is a simple wrapper around `walk_toposort` for backwards compatibility.
+
+    ``deps(i)`` should behave like a pure function (no funny business with
+    internal state).
+
+    The order of the return value list is determined by the order of nodes
+    returned by the `deps` function.
+    """
+    # TODO: Deprecate me later
+    if compute_deps_cache is not None:
+        raise ValueError("compute_deps_cache is no longer supported")
+    if deps_cache is not None:
+        raise ValueError("deps_cache is no longer supported")
+    if clients is not None:
+        raise ValueError("clients is no longer supported")
+    return list(walk_toposort(outputs, deps))
+
+
+def toposort(
+    graphs: Iterable[Variable],
+    blockers: Iterable[Variable] | None = None,
+) -> Generator[Apply, None, None]:
+    """Topologically sort the Apply nodes between graphs (outputs) and blockers (inputs).
+
+    This is a streamlined version of `toposort_with_orderings` for when no additional
+    ordering constraints are needed.
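+
+    Examples
+    --------
+    A minimal sketch with illustrative variable names:
+
+    .. code-block:: python
+
+        import pytensor.tensor as pt
+        from pytensor.graph.traversal import toposort
+
+        a = pt.vector("a")
+        b = pt.exp(a)
+        c = b + 1
+
+        # Apply nodes needed to compute ``c``, ancestors first (the exp node, then the add node)
+        applies = list(toposort([c]))
+
+        # With ``b`` as a blocker it is assumed already computed,
+        # so only the add node is returned
+        applies = list(toposort([c], blockers=[b]))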
+    """
+
+    # We can put blocker variables in computed, as we only return apply nodes
+    computed = set(blockers or ())
+    todo = list(graphs)
+    try:
+        while True:
+            if (cur := todo.pop()) not in computed and (apply := cur.owner) is not None:
+                uncomputed_inputs = tuple(
+                    i
+                    for i in apply.inputs
+                    if (i not in computed and i.owner is not None)
+                )
+                if not uncomputed_inputs:
+                    yield apply
+                    computed.update(apply.outputs)
+                else:
+                    todo.append(cur)
+                    todo.extend(uncomputed_inputs)
+    except IndexError:  # queue is empty
+        return
+
+
+def toposort_with_orderings(
+    graphs: Iterable[Variable],
+    *,
+    blockers: Iterable[Variable] | None = None,
+    orderings: dict[Apply, list[Apply]] | None = None,
+) -> Generator[Apply, None, None]:
+    """Topologically sort the Apply nodes between blocker (input) and graphs (output) variables, with arbitrary extra orderings.
+
+    Extra orderings can be used to force sorting of variables that are not naturally related in the graph.
+    This can be used by inplace optimizations to ensure a variable is only destroyed after all other uses.
+    Those other uses show up as dependencies of the destroying node, in the orderings dictionary.
+
+    Parameters
+    ----------
+    graphs : list or tuple of Variable instances
+        Graph outputs from which the sort starts.
+    blockers : list or tuple of Variable instances, optional
+        Graph inputs; the traversal does not expand past these variables.
+    orderings : dict
+        Keys are `Apply` or `Variable` instances, values are lists of `Apply` or `Variable` instances.
+
+    """
+    if not orderings:
+        # Faster branch
+        yield from toposort(graphs, blockers=blockers)
+
+    else:
+        # The blockers are used to decide where to stop expanding
+        if blockers:
+
+            def compute_deps(obj, blocker_set=frozenset(blockers), orderings=orderings):
+                if obj in blocker_set:
+                    return None
+                if isinstance(obj, Apply):
+                    return [*obj.inputs, *orderings.get(obj, [])]
+                else:
+                    if (apply := obj.owner) is not None:
+                        return [apply, *orderings.get(apply, [])]
+                    else:
+                        return orderings.get(obj, [])
+        else:
+            # mypy doesn't like conditional functions with different signatures,
+            # but binding the outer values as default arguments is faster
+            def compute_deps(obj, orderings=orderings):  # type: ignore[misc]
+                if isinstance(obj, Apply):
+                    return [*obj.inputs, *orderings.get(obj, [])]
+                else:
+                    if (apply := obj.owner) is not None:
+                        return [apply, *orderings.get(apply, [])]
+                    else:
+                        return orderings.get(obj, [])
+
+        yield from (
+            apply
+            for apply in walk_toposort(graphs, deps=compute_deps)
+            # mypy doesn't understand that our generator will return both Apply and Variables
+            if isinstance(apply, Apply)  # type: ignore
+        )
+
+
+def io_toposort(
+    inputs: Iterable[Variable],
+    outputs: Reversible[Variable],
+    orderings: dict[Apply, list[Apply]] | None = None,
+    clients: dict[Variable, list[Variable]] | None = None,
+) -> list[Apply]:
+    """Topologically sort the Apply nodes between input and output variables.
+
+    Notes
+    -----
+    This is just a wrapper around `toposort_with_orderings` for backwards compatibility.
+
+    Parameters
+    ----------
+    inputs : list or tuple of Variable instances
+        Graph inputs.
+    outputs : list or tuple of Variable instances
+        Graph outputs.
+    orderings : dict
+        Keys are `Apply` instances, values are lists of `Apply` instances.
+    clients : optional
+        No longer supported; a ``ValueError`` is raised if provided.
+    """
+    # TODO: Deprecate me later
+    if clients is not None:
+        raise ValueError("clients is no longer supported")
+
+    return list(toposort_with_orderings(outputs, blockers=inputs, orderings=orderings))
+
+
+def get_var_by_name(
+    graphs: Iterable[Variable], target_var_id: str
+) -> tuple[Variable, ...]:
+    r"""Get variables in a graph using their names.
+ + Parameters + ---------- + graphs: + The graph, or graphs, to search. + target_var_id: + The name to match against either ``Variable.name`` or + ``Variable.auto_name``. + + Returns + ------- + A ``tuple`` containing all the `Variable`\s that match `target_var_id`. + + """ + from pytensor.graph.op import HasInnerGraph + + def expand(r: Variable) -> list[Variable] | None: + if (apply := r.owner) is not None: + if isinstance(apply.op, HasInnerGraph): + return [*apply.inputs, *apply.op.inner_outputs] + else: + # Mypy doesn't know these will never be None + return apply.inputs # type: ignore + else: + return None + + return tuple( + var + for var in walk(graphs, expand) + if (target_var_id == var.name or target_var_id == var.auto_name) + ) diff --git a/pytensor/ifelse.py b/pytensor/ifelse.py index 970b1bec1c..f8e033a431 100644 --- a/pytensor/ifelse.py +++ b/pytensor/ifelse.py @@ -21,10 +21,11 @@ from pytensor import as_symbolic from pytensor.compile import optdb from pytensor.configdefaults import config -from pytensor.graph.basic import Apply, Variable, apply_depends_on +from pytensor.graph.basic import Apply, Variable from pytensor.graph.op import _NoPythonOp from pytensor.graph.replace import clone_replace from pytensor.graph.rewriting.basic import GraphRewriter, in2out, node_rewriter +from pytensor.graph.traversal import apply_depends_on from pytensor.graph.type import HasDataType, HasShape from pytensor.tensor.shape import Reshape, Shape, SpecifyShape diff --git a/pytensor/link/c/basic.py b/pytensor/link/c/basic.py index 8d2a35b9ac..a45179bbe1 100644 --- a/pytensor/link/c/basic.py +++ b/pytensor/link/c/basic.py @@ -15,10 +15,8 @@ from pytensor.graph.basic import ( AtomicVariable, Constant, - NoParams, - io_toposort, - vars_between, ) +from pytensor.graph.traversal import io_toposort, vars_between from pytensor.graph.utils import MethodNotDefined from pytensor.link.basic import Container, Linker, LocalLinker, PerformLinker from pytensor.link.c.cmodule import ( @@ -35,6 +33,9 @@ from pytensor.utils import difference, uniq +NoParams = object() + + if TYPE_CHECKING: from pytensor.graph.fg import FunctionGraph from pytensor.link.c.cmodule import ModuleCache diff --git a/pytensor/printing.py b/pytensor/printing.py index f0c98c911d..c54a0a9b05 100644 --- a/pytensor/printing.py +++ b/pytensor/printing.py @@ -18,9 +18,10 @@ from pytensor.compile.io import In, Out from pytensor.compile.profiling import ProfileStats from pytensor.configdefaults import config -from pytensor.graph.basic import Apply, Constant, Variable, graph_inputs, io_toposort +from pytensor.graph.basic import Apply, Constant, Variable from pytensor.graph.fg import FunctionGraph from pytensor.graph.op import HasInnerGraph, Op, StorageMapType +from pytensor.graph.traversal import graph_inputs, toposort from pytensor.graph.utils import Scratchpad @@ -1102,7 +1103,7 @@ def process_graph(self, inputs, outputs, updates=None, display_inputs=False): ) inv_updates = {b: a for (a, b) in updates.items()} i = 1 - for node in io_toposort([*inputs, *updates], [*outputs, *updates.values()]): + for node in toposort([*outputs, *updates.values()], [*inputs, *updates]): for output in node.outputs: if output in inv_updates: name = str(inv_updates[output]) diff --git a/pytensor/scalar/basic.py b/pytensor/scalar/basic.py index de92555401..ee4609120b 100644 --- a/pytensor/scalar/basic.py +++ b/pytensor/scalar/basic.py @@ -24,10 +24,11 @@ from pytensor import printing from pytensor.configdefaults import config from pytensor.gradient import 
DisconnectedType, grad_undefined -from pytensor.graph.basic import Apply, Constant, Variable, applys_between, clone +from pytensor.graph.basic import Apply, Constant, Variable, clone from pytensor.graph.fg import FunctionGraph from pytensor.graph.op import HasInnerGraph from pytensor.graph.rewriting.basic import MergeOptimizer +from pytensor.graph.traversal import applys_between from pytensor.graph.type import HasDataType, HasShape from pytensor.graph.utils import MetaObject, MethodNotDefined from pytensor.link.c.op import COp diff --git a/pytensor/scan/basic.py b/pytensor/scan/basic.py index ae3785958c..6b03917a2b 100644 --- a/pytensor/scan/basic.py +++ b/pytensor/scan/basic.py @@ -6,9 +6,10 @@ from pytensor.compile.function.pfunc import construct_pfunc_ins_and_outs from pytensor.compile.sharedvalue import SharedVariable, collect_new_shareds from pytensor.configdefaults import config -from pytensor.graph.basic import Constant, Variable, graph_inputs +from pytensor.graph.basic import Constant, Variable from pytensor.graph.op import get_test_value from pytensor.graph.replace import clone_replace +from pytensor.graph.traversal import graph_inputs from pytensor.graph.utils import MissingInputError, TestValueError from pytensor.scan.op import Scan, ScanInfo from pytensor.scan.utils import expand_empty, safe_new, until diff --git a/pytensor/scan/op.py b/pytensor/scan/op.py index c1ae4db04d..05a860584e 100644 --- a/pytensor/scan/op.py +++ b/pytensor/scan/op.py @@ -67,12 +67,11 @@ Apply, Variable, equal_computations, - graph_inputs, - io_connection_pattern, ) from pytensor.graph.features import NoOutputFromInplace -from pytensor.graph.op import HasInnerGraph, Op +from pytensor.graph.op import HasInnerGraph, Op, io_connection_pattern from pytensor.graph.replace import clone_replace +from pytensor.graph.traversal import graph_inputs from pytensor.graph.type import HasShape from pytensor.graph.utils import InconsistencyError, MissingInputError from pytensor.link.c.basic import CLinker diff --git a/pytensor/scan/rewriting.py b/pytensor/scan/rewriting.py index adab47d37b..09793ab15a 100644 --- a/pytensor/scan/rewriting.py +++ b/pytensor/scan/rewriting.py @@ -18,11 +18,7 @@ Constant, NominalVariable, Variable, - ancestors, - apply_depends_on, equal_computations, - graph_inputs, - io_toposort, ) from pytensor.graph.destroyhandler import DestroyHandler from pytensor.graph.features import ReplaceValidate @@ -33,11 +29,16 @@ EquilibriumGraphRewriter, GraphRewriter, copy_stack_trace, - in2out, + dfs_rewriter, node_rewriter, ) from pytensor.graph.rewriting.db import EquilibriumDB, SequenceDB from pytensor.graph.rewriting.utils import get_clients_at_depth +from pytensor.graph.traversal import ( + ancestors, + apply_depends_on, + graph_inputs, +) from pytensor.graph.type import HasShape from pytensor.graph.utils import InconsistencyError from pytensor.raise_op import Assert @@ -220,12 +221,9 @@ def scan_push_out_non_seq(fgraph, node): it to the outer function to be executed only once, before the `Scan` `Op`, reduces the amount of computation that needs to be performed. """ - if not isinstance(node.op, Scan): - return False - node_inputs, node_outputs = node.op.inner_inputs, node.op.inner_outputs - local_fgraph_topo = io_toposort(node_inputs, node_outputs) + local_fgraph_topo = node.op.fgraph.toposort() local_fgraph_outs_set = set(node_outputs) local_fgraph_outs_map = {v: k for k, v in enumerate(node_outputs)} @@ -430,12 +428,9 @@ def scan_push_out_seq(fgraph, node): many times on many smaller tensors. 
In many cases, this optimization can increase memory usage but, in some specific cases, it can also decrease it. """ - if not isinstance(node.op, Scan): - return False - node_inputs, node_outputs = node.op.inner_inputs, node.op.inner_outputs - local_fgraph_topo = io_toposort(node_inputs, node_outputs) + local_fgraph_topo = node.op.fgraph.toposort() local_fgraph_outs_set = set(node_outputs) local_fgraph_outs_map = {v: k for k, v in enumerate(node_outputs)} @@ -658,10 +653,9 @@ def inner_sitsot_only_last_step_used( fgraph: FunctionGraph, var: Variable, scan_args: ScanArgs ) -> bool: """ - Given a inner nit-sot output of `Scan`, return ``True`` iff the outer - nit-sot output has only one client and that client is a `Subtensor` - instance that takes only the last step (last element along the first - axis). + Given a inner sit-sot output of `Scan`, return ``True`` iff the outer + sit-sot output has only one client and that client is a `Subtensor` + instance that takes only the last step (last element along the first axis). """ idx = scan_args.inner_out_sit_sot.index(var) outer_var = scan_args.outer_out_sit_sot[idx] @@ -697,7 +691,6 @@ def push_out_inner_vars( old_scan_args: ScanArgs, ) -> tuple[list[Variable], ScanArgs, dict[Variable, Variable]]: tmp_outer_vars: list[Variable | None] = [] - new_scan_node = old_scan_node new_scan_args = old_scan_args replacements: dict[Variable, Variable] = {} @@ -832,58 +825,78 @@ def scan_push_out_add(fgraph, node): Like `scan_push_out_seq`, this optimization aims to replace many operations on small tensors by few operations on large tensors. It can also lead to increased memory usage. + + FIXME: This rewrite doesn't cover user defined graphs, + since it doesn't account for the intermediate slice + returned by the scan constructor for sit-sot (i.e., something like output[1:]). + It only looks for `outputs[-1]` but the user will only ever write `outputs[1:][-1]` + The relevant helper function is `inner_sitsot_only_last_step_used` which is only used by this rewrite + Note this rewrite is registered before subtensor_merge, but even if it were after subtensor_merge is a mess + and doesn't simplify to x[1:][-1] to x[-1] unless x length is statically known """ # Don't perform the optimization on `as_while` `Scan`s. Because these # `Scan`s don't run for a predetermined number of steps, handling them is # more complicated and this optimization doesn't support it at the moment. 
- if not (isinstance(node.op, Scan) and not node.op.info.as_while): + op = node.op + if op.info.as_while: return False - op = node.op + # apply_ancestors(args.inner_outputs) + + add_of_dot_nodes = [ + n + for n in op.fgraph.apply_nodes + if + ( + # We have an Add + isinstance(n.op, Elemwise) + and isinstance(n.op.scalar_op, ps.Add) + and any( + ( + # With a Dot input that's only used in the Add + n_inp.owner is not None + and isinstance(n_inp.owner.op, Dot) + and len(op.fgraph.clients[n_inp]) == 1 + ) + for n_inp in n.inputs + ) + ) + ] - # Use `ScanArgs` to parse the inputs and outputs of scan for ease of - # use - args = ScanArgs( - node.inputs, node.outputs, op.inner_inputs, op.inner_outputs, op.info - ) + if not add_of_dot_nodes: + return False - clients = {} - local_fgraph_topo = io_toposort( - args.inner_inputs, args.inner_outputs, clients=clients + # Use `ScanArgs` to parse the inputs and outputs of scan for ease of access + args = ScanArgs( + node.inputs, + node.outputs, + op.inner_inputs, + op.inner_outputs, + op.info, + clone=False, ) - for nd in local_fgraph_topo: + for nd in add_of_dot_nodes: if ( - isinstance(nd.op, Elemwise) - and isinstance(nd.op.scalar_op, ps.Add) - and nd.out in args.inner_out_sit_sot + nd.out in args.inner_out_sit_sot + # FIXME: This function doesn't handle `sitsot_out[1:][-1]` pattern and inner_sitsot_only_last_step_used(fgraph, nd.out, args) ): # Ensure that one of the input to the add is the output of # the add from a previous iteration of the inner function sitsot_idx = args.inner_out_sit_sot.index(nd.out) if args.inner_in_sit_sot[sitsot_idx] in nd.inputs: - # Ensure that the other input to the add is a dot product - # between 2 matrices which will become a tensor3 and a - # matrix if pushed outside of the scan. Also make sure - # that the output of the Dot is ONLY used by the 'add' - # otherwise doing a Dot in the outer graph will only - # duplicate computation. - sitsot_in_idx = nd.inputs.index(args.inner_in_sit_sot[sitsot_idx]) # 0 if sitsot_in_idx==1, 1 if sitsot_in_idx==0 dot_in_idx = 1 - sitsot_in_idx - dot_input = nd.inputs[dot_in_idx] + assert dot_input.owner is not None and isinstance( + dot_input.owner.op, Dot + ) if ( - dot_input.owner is not None - and isinstance(dot_input.owner.op, Dot) - and len(clients[dot_input]) == 1 - and dot_input.owner.inputs[0].ndim == 2 - and dot_input.owner.inputs[1].ndim == 2 - and get_outer_ndim(dot_input.owner.inputs[0], args) == 3 + get_outer_ndim(dot_input.owner.inputs[0], args) == 3 and get_outer_ndim(dot_input.owner.inputs[1], args) == 3 ): # The optimization can be be applied in this case. @@ -920,6 +933,7 @@ def scan_push_out_add(fgraph, node): # external Dot instead of the output of scan # Modify the outer graph to add the outer Dot outer_sitsot = new_scan_args.outer_out_sit_sot[sitsot_idx] + # TODO: If we fix the FIXME above, we have to make sure we replace the last subtensor, not the immediate one subtensor_node = fgraph.clients[outer_sitsot][0][0] outer_sitsot_last_step = subtensor_node.outputs[0] @@ -2544,7 +2558,7 @@ def apply(self, fgraph, start_from=None): # ScanSaveMem should execute only once per node. 
optdb.register( "scan_save_mem_prealloc", - in2out(scan_save_mem_prealloc, ignore_newtrees=True), + dfs_rewriter(scan_save_mem_prealloc, ignore_newtrees=True), "fast_run", "scan", "scan_save_mem", @@ -2552,7 +2566,7 @@ def apply(self, fgraph, start_from=None): ) optdb.register( "scan_save_mem_no_prealloc", - in2out(scan_save_mem_no_prealloc, ignore_newtrees=True), + dfs_rewriter(scan_save_mem_no_prealloc, ignore_newtrees=True), "numba", "jax", "pytorch", @@ -2573,7 +2587,7 @@ def apply(self, fgraph, start_from=None): scan_seqopt1.register( "scan_remove_constants_and_unused_inputs0", - in2out(remove_constants_and_unused_inputs_scan, ignore_newtrees=True), + dfs_rewriter(remove_constants_and_unused_inputs_scan, ignore_newtrees=True), "remove_constants_and_unused_inputs_scan", "fast_run", "scan", @@ -2582,7 +2596,7 @@ def apply(self, fgraph, start_from=None): scan_seqopt1.register( "scan_push_out_non_seq", - in2out(scan_push_out_non_seq, ignore_newtrees=True), + dfs_rewriter(scan_push_out_non_seq, ignore_newtrees=True), "scan_pushout_nonseqs_ops", # For backcompat: so it can be tagged with old name "fast_run", "scan", @@ -2592,7 +2606,7 @@ def apply(self, fgraph, start_from=None): scan_seqopt1.register( "scan_push_out_seq", - in2out(scan_push_out_seq, ignore_newtrees=True), + dfs_rewriter(scan_push_out_seq, ignore_newtrees=True), "scan_pushout_seqs_ops", # For backcompat: so it can be tagged with old name "fast_run", "scan", @@ -2603,7 +2617,7 @@ def apply(self, fgraph, start_from=None): scan_seqopt1.register( "scan_push_out_dot1", - in2out(scan_push_out_dot1, ignore_newtrees=True), + dfs_rewriter(scan_push_out_dot1, ignore_newtrees=True), "scan_pushout_dot1", # For backcompat: so it can be tagged with old name "fast_run", "more_mem", @@ -2616,7 +2630,7 @@ def apply(self, fgraph, start_from=None): scan_seqopt1.register( "scan_push_out_add", # TODO: Perhaps this should be an `EquilibriumGraphRewriter`? 
- in2out(scan_push_out_add, ignore_newtrees=False), + dfs_rewriter(scan_push_out_add, ignore_newtrees=False), "scan_pushout_add", # For backcompat: so it can be tagged with old name "fast_run", "more_mem", @@ -2627,14 +2641,14 @@ def apply(self, fgraph, start_from=None): scan_eqopt2.register( "while_scan_merge_subtensor_last_element", - in2out(while_scan_merge_subtensor_last_element, ignore_newtrees=True), + dfs_rewriter(while_scan_merge_subtensor_last_element, ignore_newtrees=True), "fast_run", "scan", ) scan_eqopt2.register( "constant_folding_for_scan2", - in2out(constant_folding, ignore_newtrees=True), + dfs_rewriter(constant_folding, ignore_newtrees=True), "fast_run", "scan", ) @@ -2642,7 +2656,7 @@ def apply(self, fgraph, start_from=None): scan_eqopt2.register( "scan_remove_constants_and_unused_inputs1", - in2out(remove_constants_and_unused_inputs_scan, ignore_newtrees=True), + dfs_rewriter(remove_constants_and_unused_inputs_scan, ignore_newtrees=True), "remove_constants_and_unused_inputs_scan", "fast_run", "scan", @@ -2657,7 +2671,7 @@ def apply(self, fgraph, start_from=None): # After Merge optimization scan_eqopt2.register( "scan_remove_constants_and_unused_inputs2", - in2out(remove_constants_and_unused_inputs_scan, ignore_newtrees=True), + dfs_rewriter(remove_constants_and_unused_inputs_scan, ignore_newtrees=True), "remove_constants_and_unused_inputs_scan", "fast_run", "scan", @@ -2665,7 +2679,7 @@ def apply(self, fgraph, start_from=None): scan_eqopt2.register( "scan_merge_inouts", - in2out(scan_merge_inouts, ignore_newtrees=True), + dfs_rewriter(scan_merge_inouts, ignore_newtrees=True), "fast_run", "scan", ) @@ -2673,7 +2687,7 @@ def apply(self, fgraph, start_from=None): # After everything else scan_eqopt2.register( "scan_remove_constants_and_unused_inputs3", - in2out(remove_constants_and_unused_inputs_scan, ignore_newtrees=True), + dfs_rewriter(remove_constants_and_unused_inputs_scan, ignore_newtrees=True), "remove_constants_and_unused_inputs_scan", "fast_run", "scan", diff --git a/pytensor/scan/utils.py b/pytensor/scan/utils.py index 3b924225ac..4f39cfe0ad 100644 --- a/pytensor/scan/utils.py +++ b/pytensor/scan/utils.py @@ -15,9 +15,10 @@ from pytensor import tensor as pt from pytensor.compile.profiling import ProfileStats from pytensor.configdefaults import config -from pytensor.graph.basic import Constant, Variable, equal_computations, graph_inputs +from pytensor.graph.basic import Constant, Variable, equal_computations from pytensor.graph.op import get_test_value from pytensor.graph.replace import clone_replace +from pytensor.graph.traversal import graph_inputs from pytensor.graph.type import HasDataType from pytensor.graph.utils import TestValueError from pytensor.tensor.basic import AllocEmpty, cast diff --git a/pytensor/tensor/_linalg/solve/rewriting.py b/pytensor/tensor/_linalg/solve/rewriting.py index c0a1c5cce8..95f45dc373 100644 --- a/pytensor/tensor/_linalg/solve/rewriting.py +++ b/pytensor/tensor/_linalg/solve/rewriting.py @@ -3,7 +3,7 @@ from pytensor.compile import optdb from pytensor.graph import Constant, graph_inputs -from pytensor.graph.rewriting.basic import copy_stack_trace, in2out, node_rewriter +from pytensor.graph.rewriting.basic import copy_stack_trace, dfs_rewriter, node_rewriter from pytensor.scan.op import Scan from pytensor.scan.rewriting import scan_seqopt1 from pytensor.tensor._linalg.solve.tridiagonal import ( @@ -243,7 +243,7 @@ def scan_split_non_sequence_decomposition_and_solve(fgraph, node): scan_seqopt1.register( 
scan_split_non_sequence_decomposition_and_solve.__name__, - in2out(scan_split_non_sequence_decomposition_and_solve, ignore_newtrees=True), + dfs_rewriter(scan_split_non_sequence_decomposition_and_solve, ignore_newtrees=True), "fast_run", "scan", "scan_pushout", @@ -260,7 +260,7 @@ def reuse_decomposition_multiple_solves_jax(fgraph, node): optdb["specialize"].register( reuse_decomposition_multiple_solves_jax.__name__, - in2out(reuse_decomposition_multiple_solves_jax, ignore_newtrees=True), + dfs_rewriter(reuse_decomposition_multiple_solves_jax, ignore_newtrees=True), "jax", use_db_name_as_tag=False, ) @@ -275,7 +275,9 @@ def scan_split_non_sequence_decomposition_and_solve_jax(fgraph, node): scan_seqopt1.register( scan_split_non_sequence_decomposition_and_solve_jax.__name__, - in2out(scan_split_non_sequence_decomposition_and_solve_jax, ignore_newtrees=True), + dfs_rewriter( + scan_split_non_sequence_decomposition_and_solve_jax, ignore_newtrees=True + ), "jax", use_db_name_as_tag=False, position=2, diff --git a/pytensor/tensor/blas.py b/pytensor/tensor/blas.py index a183431a0e..d175f50219 100644 --- a/pytensor/tensor/blas.py +++ b/pytensor/tensor/blas.py @@ -85,7 +85,7 @@ import numpy as np from scipy.linalg import get_blas_funcs -from pytensor.graph import vectorize_graph +from pytensor.graph import Variable, vectorize_graph from pytensor.npy_2_compat import normalize_axis_tuple @@ -97,7 +97,7 @@ import pytensor.scalar from pytensor.configdefaults import config -from pytensor.graph.basic import Apply, view_roots +from pytensor.graph.basic import Apply from pytensor.graph.op import Op from pytensor.graph.utils import InconsistencyError, MethodNotDefined, TestValueError from pytensor.link.c.op import COp @@ -114,6 +114,25 @@ _logger = logging.getLogger("pytensor.tensor.blas") +def view_roots(node: Variable) -> list[Variable]: + """Return the leaves from a search through consecutive view-maps.""" + owner = node.owner + if owner is not None: + try: + vars_to_views = {owner.outputs[o]: i for o, i in owner.op.view_map.items()} + except AttributeError: + return [node] + if node in vars_to_views: + answer = [] + for i in vars_to_views[node]: + answer += view_roots(owner.inputs[i]) + return answer + else: + return [node] + else: + return [node] + + def must_initialize_y_gemv(): # Check whether Scipy GEMV could output nan if y in not initialized from scipy.linalg.blas import get_blas_funcs diff --git a/pytensor/tensor/blockwise.py b/pytensor/tensor/blockwise.py index 14d9a53251..0181699851 100644 --- a/pytensor/tensor/blockwise.py +++ b/pytensor/tensor/blockwise.py @@ -8,7 +8,7 @@ from pytensor.compile.builders import OpFromGraph from pytensor.gradient import DisconnectedType from pytensor.graph import FunctionGraph -from pytensor.graph.basic import Apply, Constant, explicit_graph_inputs +from pytensor.graph.basic import Apply, Constant from pytensor.graph.null_type import NullType from pytensor.graph.op import Op from pytensor.graph.replace import ( @@ -16,6 +16,7 @@ _vectorize_not_needed, vectorize_graph, ) +from pytensor.graph.traversal import explicit_graph_inputs from pytensor.link.c.op import COp from pytensor.scalar import ScalarType from pytensor.tensor import as_tensor_variable diff --git a/pytensor/tensor/optimize.py b/pytensor/tensor/optimize.py index 99a3d8b444..2088dd99cd 100644 --- a/pytensor/tensor/optimize.py +++ b/pytensor/tensor/optimize.py @@ -8,10 +8,11 @@ import pytensor.scalar as ps from pytensor.compile.function import function from pytensor.gradient import grad, hessian, 
jacobian -from pytensor.graph import Apply, Constant, FunctionGraph -from pytensor.graph.basic import ancestors, truncated_graph_inputs +from pytensor.graph.basic import Apply, Constant +from pytensor.graph.fg import FunctionGraph from pytensor.graph.op import ComputeMapType, HasInnerGraph, Op, StorageMapType from pytensor.graph.replace import graph_replace +from pytensor.graph.traversal import ancestors, truncated_graph_inputs from pytensor.tensor.basic import ( atleast_2d, concatenate, diff --git a/pytensor/tensor/random/rewriting/basic.py b/pytensor/tensor/random/rewriting/basic.py index 6de1a6b527..d67a6653f4 100644 --- a/pytensor/tensor/random/rewriting/basic.py +++ b/pytensor/tensor/random/rewriting/basic.py @@ -4,7 +4,11 @@ from pytensor.configdefaults import config from pytensor.graph import ancestors from pytensor.graph.op import compute_test_value -from pytensor.graph.rewriting.basic import copy_stack_trace, in2out, node_rewriter +from pytensor.graph.rewriting.basic import ( + copy_stack_trace, + dfs_rewriter, + node_rewriter, +) from pytensor.tensor import NoneConst, TensorVariable from pytensor.tensor.basic import constant from pytensor.tensor.elemwise import DimShuffle @@ -57,7 +61,7 @@ def random_make_inplace(fgraph, node): optdb.register( "random_make_inplace", - in2out(random_make_inplace, ignore_newtrees=True), + dfs_rewriter(random_make_inplace, ignore_newtrees=True), "fast_run", "inplace", position=50.9, diff --git a/pytensor/tensor/random/rewriting/jax.py b/pytensor/tensor/random/rewriting/jax.py index fa30e10c18..bc40dc57d8 100644 --- a/pytensor/tensor/random/rewriting/jax.py +++ b/pytensor/tensor/random/rewriting/jax.py @@ -2,8 +2,7 @@ from pytensor.compile import optdb from pytensor.graph import Constant -from pytensor.graph.rewriting.basic import in2out, node_rewriter -from pytensor.graph.rewriting.db import SequenceDB +from pytensor.graph.rewriting.basic import dfs_rewriter, in2out, node_rewriter from pytensor.tensor import abs as abs_t from pytensor.tensor import broadcast_arrays, exp, floor, log, log1p, reciprocal, sqrt from pytensor.tensor.basic import ( @@ -179,51 +178,16 @@ def materialize_implicit_arange_choice_without_replacement(fgraph, node): return new_op.make_node(rng, size, a_vector_param, *other_params).outputs -random_vars_opt = SequenceDB() -random_vars_opt.register( - "lognormal_from_normal", - in2out(lognormal_from_normal), - "jax", -) -random_vars_opt.register( - "halfnormal_from_normal", - in2out(halfnormal_from_normal), - "jax", -) -random_vars_opt.register( - "geometric_from_uniform", - in2out(geometric_from_uniform), - "jax", -) -random_vars_opt.register( - "negative_binomial_from_gamma_poisson", - in2out(negative_binomial_from_gamma_poisson), - "jax", -) -random_vars_opt.register( - "inverse_gamma_from_gamma", - in2out(inverse_gamma_from_gamma), - "jax", -) -random_vars_opt.register( - "generalized_gamma_from_gamma", - in2out(generalized_gamma_from_gamma), - "jax", -) -random_vars_opt.register( - "wald_from_normal_uniform", - in2out(wald_from_normal_uniform), - "jax", -) -random_vars_opt.register( - "beta_binomial_from_beta_binomial", - in2out(beta_binomial_from_beta_binomial), - "jax", -) -random_vars_opt.register( - "materialize_implicit_arange_choice_without_replacement", - in2out(materialize_implicit_arange_choice_without_replacement), - "jax", +random_vars_opt = dfs_rewriter( + lognormal_from_normal, + halfnormal_from_normal, + geometric_from_uniform, + negative_binomial_from_gamma_poisson, + inverse_gamma_from_gamma, + 
generalized_gamma_from_gamma, + wald_from_normal_uniform, + beta_binomial_from_beta_binomial, + materialize_implicit_arange_choice_without_replacement, ) optdb.register("jax_random_vars_rewrites", random_vars_opt, "jax", position=110) diff --git a/pytensor/tensor/random/rewriting/numba.py b/pytensor/tensor/random/rewriting/numba.py index 75d213eb26..3ed6c12781 100644 --- a/pytensor/tensor/random/rewriting/numba.py +++ b/pytensor/tensor/random/rewriting/numba.py @@ -1,6 +1,6 @@ from pytensor.compile import optdb from pytensor.graph import node_rewriter -from pytensor.graph.rewriting.basic import out2in +from pytensor.graph.rewriting.basic import dfs_rewriter from pytensor.tensor import as_tensor, constant from pytensor.tensor.random.op import RandomVariable, RandomVariableWithCoreShape from pytensor.tensor.rewriting.shape import ShapeFeature @@ -82,7 +82,7 @@ def introduce_explicit_core_shape_rv(fgraph, node): optdb.register( introduce_explicit_core_shape_rv.__name__, - out2in(introduce_explicit_core_shape_rv), + dfs_rewriter(introduce_explicit_core_shape_rv), "numba", position=100, ) diff --git a/pytensor/tensor/rewriting/basic.py b/pytensor/tensor/rewriting/basic.py index e9c2c8e47e..86bdae9e64 100644 --- a/pytensor/tensor/rewriting/basic.py +++ b/pytensor/tensor/rewriting/basic.py @@ -36,6 +36,7 @@ NodeRewriter, Rewriter, copy_stack_trace, + dfs_rewriter, in2out, node_rewriter, ) @@ -518,7 +519,7 @@ def local_alloc_empty_to_zeros(fgraph, node): compile.optdb.register( "local_alloc_empty_to_zeros", - in2out(local_alloc_empty_to_zeros), + dfs_rewriter(local_alloc_empty_to_zeros), # After move to gpu and merge2, before inplace. "alloc_empty_to_zeros", position=49.3, diff --git a/pytensor/tensor/rewriting/blas.py b/pytensor/tensor/rewriting/blas.py index 685cec5785..03a1e8b0ab 100644 --- a/pytensor/tensor/rewriting/blas.py +++ b/pytensor/tensor/rewriting/blas.py @@ -59,6 +59,7 @@ import numpy as np +from pytensor.graph.traversal import toposort from pytensor.tensor.rewriting.basic import register_specialize @@ -76,7 +77,7 @@ EquilibriumGraphRewriter, GraphRewriter, copy_stack_trace, - in2out, + dfs_rewriter, node_rewriter, ) from pytensor.graph.rewriting.db import SequenceDB @@ -459,6 +460,9 @@ def apply(self, fgraph): callbacks_before = fgraph.execute_callbacks_times.copy() callback_before = fgraph.execute_callbacks_time + nodelist = list(toposort(fgraph.outputs)) + nodelist.reverse() + def on_import(new_node): if new_node is not node: nodelist.append(new_node) @@ -470,10 +474,8 @@ def on_import(new_node): while did_something: nb_iter += 1 t0 = time.perf_counter() - nodelist = pytensor.graph.basic.io_toposort(fgraph.inputs, fgraph.outputs) time_toposort += time.perf_counter() - t0 did_something = False - nodelist.reverse() for node in nodelist: if not ( isinstance(node.op, Elemwise) @@ -719,7 +721,7 @@ def local_dot22_to_ger_or_gemv(fgraph, node): # fast_compile is needed to have GpuDot22 created. 
blas_optdb.register( "local_dot_to_dot22", - in2out(local_dot_to_dot22), + dfs_rewriter(local_dot_to_dot22), "fast_run", "fast_compile", position=0, @@ -742,7 +744,7 @@ def local_dot22_to_ger_or_gemv(fgraph, node): ) -blas_opt_inplace = in2out( +blas_opt_inplace = dfs_rewriter( local_inplace_gemm, local_inplace_gemv, local_inplace_ger, name="blas_opt_inplace" ) optdb.register( @@ -881,7 +883,7 @@ def local_dot22_to_dot22scalar(fgraph, node): # dot22scalar and gemm give more speed up then dot22scalar blas_optdb.register( "local_dot22_to_dot22scalar", - in2out(local_dot22_to_dot22scalar), + dfs_rewriter(local_dot22_to_dot22scalar), "fast_run", position=12, ) diff --git a/pytensor/tensor/rewriting/blas_c.py b/pytensor/tensor/rewriting/blas_c.py index 1723cf36f8..827aa64077 100644 --- a/pytensor/tensor/rewriting/blas_c.py +++ b/pytensor/tensor/rewriting/blas_c.py @@ -1,5 +1,5 @@ from pytensor.configdefaults import config -from pytensor.graph.rewriting.basic import in2out +from pytensor.graph.rewriting.basic import dfs_rewriter from pytensor.tensor import basic as ptb from pytensor.tensor.blas import gemv_inplace, gemv_no_inplace, ger, ger_destructive from pytensor.tensor.blas_c import ( @@ -56,13 +56,15 @@ def make_c_gemv_destructive(fgraph, node): blas_optdb.register( - "use_c_blas", in2out(use_c_ger, use_c_gemv), "fast_run", "c_blas", position=20 + "use_c_blas", dfs_rewriter(use_c_ger, use_c_gemv), "fast_run", "c_blas", position=20 ) # this matches the InplaceBlasOpt defined in blas.py optdb.register( "c_blas_destructive", - in2out(make_c_ger_destructive, make_c_gemv_destructive, name="c_blas_destructive"), + dfs_rewriter( + make_c_ger_destructive, make_c_gemv_destructive, name="c_blas_destructive" + ), "fast_run", "inplace", "c_blas", diff --git a/pytensor/tensor/rewriting/blockwise.py b/pytensor/tensor/rewriting/blockwise.py index 023c8aae51..3c1d808c6d 100644 --- a/pytensor/tensor/rewriting/blockwise.py +++ b/pytensor/tensor/rewriting/blockwise.py @@ -2,7 +2,7 @@ from pytensor.graph import Constant, node_rewriter from pytensor.graph.destroyhandler import inplace_candidates from pytensor.graph.replace import vectorize_node -from pytensor.graph.rewriting.basic import copy_stack_trace, out2in +from pytensor.graph.rewriting.basic import copy_stack_trace, dfs_rewriter from pytensor.tensor.basic import Alloc, ARange, alloc, shape_padleft from pytensor.tensor.blockwise import Blockwise, _squeeze_left from pytensor.tensor.math import Dot @@ -59,7 +59,7 @@ def local_useless_unbatched_blockwise(fgraph, node): # We do it after position>=60 so that Blockwise inplace rewrites will work also on useless Blockwise Ops optdb.register( "local_useless_unbatched_blockwise", - out2in(local_useless_unbatched_blockwise, ignore_newtrees=True), + dfs_rewriter(local_useless_unbatched_blockwise, ignore_newtrees=True), "fast_run", "fast_compile", "blockwise", diff --git a/pytensor/tensor/rewriting/elemwise.py b/pytensor/tensor/rewriting/elemwise.py index f08f19f06c..1923aa4a9e 100644 --- a/pytensor/tensor/rewriting/elemwise.py +++ b/pytensor/tensor/rewriting/elemwise.py @@ -13,19 +13,21 @@ from pytensor.compile.function.types import Supervisor from pytensor.compile.mode import get_target_language from pytensor.configdefaults import config -from pytensor.graph import FunctionGraph, Op -from pytensor.graph.basic import Apply, Variable, ancestors +from pytensor.graph.basic import Apply, Variable from pytensor.graph.destroyhandler import DestroyHandler, inplace_candidates from pytensor.graph.features import 
ReplaceValidate -from pytensor.graph.fg import Output +from pytensor.graph.fg import FunctionGraph, Output +from pytensor.graph.op import Op from pytensor.graph.rewriting.basic import ( GraphRewriter, copy_stack_trace, + dfs_rewriter, in2out, node_rewriter, out2in, ) from pytensor.graph.rewriting.db import SequenceDB +from pytensor.graph.traversal import ancestors from pytensor.graph.utils import InconsistencyError, MethodNotDefined from pytensor.scalar.math import Grad2F1Loop, _grad_2f1_loop from pytensor.tensor.basic import ( @@ -1241,21 +1243,21 @@ def constant_fold_branches_of_add_mul(fgraph, node): ) fuse_seqopt.register( "local_useless_composite_outputs", - in2out(local_useless_composite_outputs), + dfs_rewriter(local_useless_composite_outputs), "fast_run", "fusion", position=2, ) fuse_seqopt.register( "local_careduce_fusion", - in2out(local_careduce_fusion), + dfs_rewriter(local_careduce_fusion), "fast_run", "fusion", position=10, ) fuse_seqopt.register( "local_inline_composite_constants", - in2out(local_inline_composite_constants, ignore_newtrees=True), + dfs_rewriter(local_inline_composite_constants, ignore_newtrees=True), "fast_run", "fusion", position=20, diff --git a/pytensor/tensor/rewriting/jax.py b/pytensor/tensor/rewriting/jax.py index 00ed3f2b14..6e0ee1c1d6 100644 --- a/pytensor/tensor/rewriting/jax.py +++ b/pytensor/tensor/rewriting/jax.py @@ -1,6 +1,6 @@ import pytensor.tensor as pt from pytensor.compile import optdb -from pytensor.graph.rewriting.basic import in2out, node_rewriter +from pytensor.graph.rewriting.basic import dfs_rewriter, node_rewriter from pytensor.tensor.basic import MakeVector from pytensor.tensor.elemwise import DimShuffle from pytensor.tensor.math import Sum @@ -46,7 +46,7 @@ def boolean_indexing_set_or_inc(fgraph, node): optdb.register( "jax_boolean_indexing_set_or_inc", - in2out(boolean_indexing_set_or_inc), + dfs_rewriter(boolean_indexing_set_or_inc), "jax", position=100, ) @@ -96,7 +96,7 @@ def boolean_indexing_sum(fgraph, node): optdb.register( - "jax_boolean_indexing_sum", in2out(boolean_indexing_sum), "jax", position=100 + "jax_boolean_indexing_sum", dfs_rewriter(boolean_indexing_sum), "jax", position=100 ) @@ -144,7 +144,7 @@ def shape_parameter_as_tuple(fgraph, node): optdb.register( "jax_shape_parameter_as_tuple", - in2out(shape_parameter_as_tuple), + dfs_rewriter(shape_parameter_as_tuple), "jax", position=100, ) diff --git a/pytensor/tensor/rewriting/linalg.py b/pytensor/tensor/rewriting/linalg.py index 8367642c4c..9535526da7 100644 --- a/pytensor/tensor/rewriting/linalg.py +++ b/pytensor/tensor/rewriting/linalg.py @@ -10,7 +10,7 @@ from pytensor.graph import Apply, FunctionGraph from pytensor.graph.rewriting.basic import ( copy_stack_trace, - in2out, + dfs_rewriter, node_rewriter, ) from pytensor.scalar.basic import Abs, Log, Mul, Sign @@ -952,7 +952,7 @@ def jax_bilinaer_lyapunov_to_direct(fgraph: FunctionGraph, node: Apply): optdb.register( "jax_bilinaer_lyapunov_to_direct", - in2out(jax_bilinaer_lyapunov_to_direct), + dfs_rewriter(jax_bilinaer_lyapunov_to_direct), "jax", position=0.9, # Run before canonicalization ) diff --git a/pytensor/tensor/rewriting/numba.py b/pytensor/tensor/rewriting/numba.py index 60c4e41c2d..6bb9ed5bd9 100644 --- a/pytensor/tensor/rewriting/numba.py +++ b/pytensor/tensor/rewriting/numba.py @@ -1,7 +1,7 @@ from pytensor.compile import optdb from pytensor.graph import node_rewriter -from pytensor.graph.basic import applys_between -from pytensor.graph.rewriting.basic import out2in +from 
pytensor.graph.rewriting.basic import dfs_rewriter +from pytensor.graph.traversal import applys_between from pytensor.tensor.basic import as_tensor, constant from pytensor.tensor.blockwise import Blockwise, BlockwiseWithCoreShape from pytensor.tensor.rewriting.shape import ShapeFeature @@ -102,7 +102,7 @@ def introduce_explicit_core_shape_blockwise(fgraph, node): optdb.register( introduce_explicit_core_shape_blockwise.__name__, - out2in(introduce_explicit_core_shape_blockwise), + dfs_rewriter(introduce_explicit_core_shape_blockwise), "numba", position=100, ) diff --git a/pytensor/tensor/rewriting/ofg.py b/pytensor/tensor/rewriting/ofg.py index 52472de47b..098d380fad 100644 --- a/pytensor/tensor/rewriting/ofg.py +++ b/pytensor/tensor/rewriting/ofg.py @@ -4,7 +4,7 @@ from pytensor.compile import optdb from pytensor.compile.builders import OpFromGraph from pytensor.graph import Apply, node_rewriter -from pytensor.graph.rewriting.basic import copy_stack_trace, in2out +from pytensor.graph.rewriting.basic import copy_stack_trace, dfs_rewriter from pytensor.tensor.basic import AllocDiag from pytensor.tensor.rewriting.basic import register_specialize @@ -37,7 +37,7 @@ def inline_ofg_expansion(fgraph, node): # and before the first scan optimizer. optdb.register( "inline_ofg_expansion", - in2out(inline_ofg_expansion), + dfs_rewriter(inline_ofg_expansion), "fast_compile", "fast_run", position=-0.01, diff --git a/pytensor/tensor/rewriting/shape.py b/pytensor/tensor/rewriting/shape.py index 1eb10d247b..f784954dc9 100644 --- a/pytensor/tensor/rewriting/shape.py +++ b/pytensor/tensor/rewriting/shape.py @@ -7,7 +7,7 @@ import pytensor from pytensor.configdefaults import config -from pytensor.graph.basic import Constant, Variable, ancestors, equal_computations +from pytensor.graph.basic import Constant, Variable, equal_computations from pytensor.graph.features import AlreadyThere, Feature from pytensor.graph.fg import FunctionGraph from pytensor.graph.rewriting.basic import ( @@ -15,6 +15,7 @@ copy_stack_trace, node_rewriter, ) +from pytensor.graph.traversal import ancestors from pytensor.graph.utils import InconsistencyError, get_variable_trace_string from pytensor.scalar import ScalarType from pytensor.tensor.basic import ( diff --git a/pytensor/tensor/rewriting/subtensor.py b/pytensor/tensor/rewriting/subtensor.py index 93a94fac09..083a927c81 100644 --- a/pytensor/tensor/rewriting/subtensor.py +++ b/pytensor/tensor/rewriting/subtensor.py @@ -10,9 +10,9 @@ from pytensor.graph.rewriting.basic import ( WalkingGraphRewriter, copy_stack_trace, + dfs_rewriter, in2out, node_rewriter, - out2in, ) from pytensor.raise_op import Assert from pytensor.scalar import Add, ScalarConstant, ScalarType @@ -1562,7 +1562,7 @@ def local_uint_constant_indices(fgraph, node): compile.optdb.register( local_uint_constant_indices.__name__, - out2in(local_uint_constant_indices), + dfs_rewriter(local_uint_constant_indices), # We don't include in the Python / C because those always cast indices to int64 internally. 
"numba", "jax", diff --git a/pytensor/xtensor/rewriting/utils.py b/pytensor/xtensor/rewriting/utils.py index 43c60df370..a9fcda6122 100644 --- a/pytensor/xtensor/rewriting/utils.py +++ b/pytensor/xtensor/rewriting/utils.py @@ -2,7 +2,7 @@ from collections.abc import Sequence from pytensor.compile import optdb -from pytensor.graph.rewriting.basic import NodeRewriter, in2out +from pytensor.graph.rewriting.basic import NodeRewriter, dfs_rewriter from pytensor.graph.rewriting.db import EquilibriumDB, RewriteDatabase from pytensor.tensor.rewriting.ofg import inline_ofg_expansion from pytensor.tensor.variable import TensorVariable @@ -23,7 +23,7 @@ # Register OFG inline again after lowering xtensor optdb.register( "inline_ofg_expansion_xtensor", - in2out(inline_ofg_expansion), + dfs_rewriter(inline_ofg_expansion), "fast_run", "fast_compile", position=0.11, diff --git a/scripts/run_mypy.py b/scripts/run_mypy.py index 34cc810647..7b9ac9af59 100644 --- a/scripts/run_mypy.py +++ b/scripts/run_mypy.py @@ -157,7 +157,9 @@ def check_no_unexpected_results(mypy_lines: Iterable[str]): for section, sdf in df.reset_index().groupby(args.groupby): print(f"\n\n[{section}]") for row in sdf.itertuples(): - print(f"{row.file}:{row.line}: {row.type}: {row.message}") + print( + f"{row.file}:{row.line}: {row.type} [{row.errorcode}]: {row.message}" + ) print() else: print( diff --git a/tests/compile/test_profiling.py b/tests/compile/test_profiling.py index e0610b9783..8839ee5cf2 100644 --- a/tests/compile/test_profiling.py +++ b/tests/compile/test_profiling.py @@ -50,23 +50,14 @@ def test_profiling(self): the_string = buf.getvalue() lines1 = [l for l in the_string.split("\n") if "Max if linker" in l] lines2 = [l for l in the_string.split("\n") if "Minimum peak" in l] - if config.device == "cpu": - assert "CPU: 4112KB (4104KB)" in the_string, (lines1, lines2) - assert "CPU: 8204KB (8196KB)" in the_string, (lines1, lines2) - assert "CPU: 8208KB" in the_string, (lines1, lines2) - assert ( - "Minimum peak from all valid apply node order is 4104KB" - in the_string - ), (lines1, lines2) - else: - assert "CPU: 16KB (16KB)" in the_string, (lines1, lines2) - assert "GPU: 8204KB (8204KB)" in the_string, (lines1, lines2) - assert "GPU: 12300KB (12300KB)" in the_string, (lines1, lines2) - assert "GPU: 8212KB" in the_string, (lines1, lines2) - assert ( - "Minimum peak from all valid apply node order is 4116KB" - in the_string - ), (lines1, lines2) + # NODE: The specific numbers can change for distinct (but correct) toposort orderings + # Update the test values if a different algorithm is used + assert "CPU: 4112KB (4112KB)" in the_string, (lines1, lines2) + assert "CPU: 8204KB (8204KB)" in the_string, (lines1, lines2) + assert "CPU: 8208KB" in the_string, (lines1, lines2) + assert ( + "Minimum peak from all valid apply node order is 4104KB" in the_string + ), (lines1, lines2) finally: config.profile = config1 diff --git a/tests/graph/rewriting/test_kanren.py b/tests/graph/rewriting/test_kanren.py index a1dc310ce5..7cb66a4ba0 100644 --- a/tests/graph/rewriting/test_kanren.py +++ b/tests/graph/rewriting/test_kanren.py @@ -160,7 +160,7 @@ def distributes(in_lv, out_lv): assert expr_opt.owner.op == pt.add assert isinstance(expr_opt.owner.inputs[0].owner.op, Dot) - assert fgraph_opt.inputs[0] is A_pt + assert fgraph_opt.inputs[-1] is A_pt assert expr_opt.owner.inputs[0].owner.inputs[0].name == "A" assert expr_opt.owner.inputs[1].owner.op == pt.add assert isinstance(expr_opt.owner.inputs[1].owner.inputs[0].owner.op, Dot) diff --git 
a/tests/graph/test_basic.py b/tests/graph/test_basic.py index 84ffb365b5..9a929155bd 100644 --- a/tests/graph/test_basic.py +++ b/tests/graph/test_basic.py @@ -11,25 +11,13 @@ Apply, NominalVariable, Variable, - ancestors, - apply_depends_on, - applys_between, as_string, clone, clone_get_equiv, equal_computations, - explicit_graph_inputs, - general_toposort, - get_var_by_name, - graph_inputs, - io_toposort, - orphans_between, - truncated_graph_inputs, - variable_depends_on, - vars_between, - walk, ) from pytensor.graph.op import Op +from pytensor.graph.traversal import applys_between, graph_inputs from pytensor.graph.type import Type from pytensor.printing import debugprint from pytensor.tensor import constant @@ -37,7 +25,7 @@ from pytensor.tensor.type import TensorType, iscalars, matrix, scalars, vector from pytensor.tensor.type_other import NoneConst from pytensor.tensor.variable import TensorVariable -from tests.graph.utils import MyInnerGraphOp, op_multiple_outputs +from tests.graph.utils import MyInnerGraphOp class MyType(Type): @@ -207,129 +195,6 @@ def test_clone_inner_graph(self): ) -def prenode(obj): - if isinstance(obj, Variable): - if obj.owner: - return [obj.owner] - if isinstance(obj, Apply): - return obj.inputs - - -class TestToposort: - def test_simple(self): - # Test a simple graph - r1, r2, r5 = MyVariable(1), MyVariable(2), MyVariable(5) - o = MyOp(r1, r2) - o.name = "o1" - o2 = MyOp(o, r5) - o2.name = "o2" - - clients = {} - res = general_toposort([o2], prenode, clients=clients) - - assert clients == { - o2.owner: [o2], - o: [o2.owner], - r5: [o2.owner], - o.owner: [o], - r1: [o.owner], - r2: [o.owner], - } - assert res == [r5, r2, r1, o.owner, o, o2.owner, o2] - - with pytest.raises(ValueError): - general_toposort( - [o2], prenode, compute_deps_cache=lambda x: None, deps_cache=None - ) - - res = io_toposort([r5], [o2]) - assert res == [o.owner, o2.owner] - - def test_double_dependencies(self): - # Test a graph with double dependencies - r1, r5 = MyVariable(1), MyVariable(5) - o = MyOp.make_node(r1, r1) - o2 = MyOp.make_node(o.outputs[0], r5) - all = general_toposort(o2.outputs, prenode) - assert all == [r5, r1, o, o.outputs[0], o2, o2.outputs[0]] - - def test_inputs_owners(self): - # Test a graph where the inputs have owners - r1, r5 = MyVariable(1), MyVariable(5) - o = MyOp.make_node(r1, r1) - r2b = o.outputs[0] - o2 = MyOp.make_node(r2b, r2b) - all = io_toposort([r2b], o2.outputs) - assert all == [o2] - - o2 = MyOp.make_node(r2b, r5) - all = io_toposort([r2b], o2.outputs) - assert all == [o2] - - def test_not_connected(self): - # Test a graph which is not connected - r1, r2, r3, r4 = MyVariable(1), MyVariable(2), MyVariable(3), MyVariable(4) - o0 = MyOp.make_node(r1, r2) - o1 = MyOp.make_node(r3, r4) - all = io_toposort([r1, r2, r3, r4], o0.outputs + o1.outputs) - assert all == [o1, o0] or all == [o0, o1] - - def test_io_chain(self): - # Test inputs and outputs mixed together in a chain graph - r1, r2 = MyVariable(1), MyVariable(2) - o0 = MyOp.make_node(r1, r2) - o1 = MyOp.make_node(o0.outputs[0], r1) - all = io_toposort([r1, o0.outputs[0]], [o0.outputs[0], o1.outputs[0]]) - assert all == [o1] - - def test_outputs_clients(self): - # Test when outputs have clients - r1, r2, r4 = MyVariable(1), MyVariable(2), MyVariable(4) - o0 = MyOp.make_node(r1, r2) - MyOp.make_node(o0.outputs[0], r4) - all = io_toposort([], o0.outputs) - assert all == [o0] - - def test_multi_output_nodes(self): - l0, r0 = op_multiple_outputs(shared(0.0)) - l1, r1 = op_multiple_outputs(shared(0.0)) 
- - v0 = r0 + 1 - v1 = pt.exp(v0) - out = r1 * v1 - - # When either r0 or r1 is provided as an input, the respective node shouldn't be part of the toposort - assert set(io_toposort([], [out])) == { - r0.owner, - r1.owner, - v0.owner, - v1.owner, - out.owner, - } - assert set(io_toposort([r0], [out])) == { - r1.owner, - v0.owner, - v1.owner, - out.owner, - } - assert set(io_toposort([r1], [out])) == { - r0.owner, - v0.owner, - v1.owner, - out.owner, - } - assert set(io_toposort([r0, r1], [out])) == {v0.owner, v1.owner, out.owner} - - # When l0 and/or l1 are provided, we still need to compute the respective nodes - assert set(io_toposort([l0, l1], [out])) == { - r0.owner, - r1.owner, - v0.owner, - v1.owner, - out.owner, - } - - class TestEval: def setup_method(self): self.x, self.y = scalars("x", "y") @@ -463,99 +328,6 @@ def test_equal_computations(): assert equal_computations(max_argmax1, max_argmax2) -def test_walk(): - r1, r2, r3 = MyVariable(1), MyVariable(2), MyVariable(3) - o1 = MyOp(r1, r2) - o1.name = "o1" - o2 = MyOp(r3, o1) - o2.name = "o2" - - def expand(r): - if r.owner: - return r.owner.inputs - - res = walk([o2], expand, bfs=True, return_children=False) - res_list = list(res) - assert res_list == [o2, r3, o1, r1, r2] - - res = walk([o2], expand, bfs=False, return_children=False) - res_list = list(res) - assert res_list == [o2, o1, r2, r1, r3] - - res = walk([o2], expand, bfs=True, return_children=True) - res_list = list(res) - assert res_list == [ - (o2, [r3, o1]), - (r3, None), - (o1, [r1, r2]), - (r1, None), - (r2, None), - ] - - -def test_ancestors(): - r1, r2, r3 = MyVariable(1), MyVariable(2), MyVariable(3) - o1 = MyOp(r1, r2) - o1.name = "o1" - o2 = MyOp(r3, o1) - o2.name = "o2" - - res = ancestors([o2], blockers=None) - res_list = list(res) - assert res_list == [o2, r3, o1, r1, r2] - - res = ancestors([o2], blockers=None) - assert r3 in res - res_list = list(res) - assert res_list == [o1, r1, r2] - - res = ancestors([o2], blockers=[o1]) - res_list = list(res) - assert res_list == [o2, r3, o1] - - -def test_graph_inputs(): - r1, r2, r3 = MyVariable(1), MyVariable(2), MyVariable(3) - o1 = MyOp(r1, r2) - o1.name = "o1" - o2 = MyOp(r3, o1) - o2.name = "o2" - - res = graph_inputs([o2], blockers=None) - res_list = list(res) - assert res_list == [r3, r1, r2] - - -def test_explicit_graph_inputs(): - x = pt.fscalar() - y = pt.constant(2) - z = shared(1) - a = pt.sum(x + y + z) - b = pt.true_div(x, y) - - res = list(explicit_graph_inputs([a])) - res1 = list(explicit_graph_inputs(b)) - - assert res == [x] - assert res1 == [x] - - -def test_variables_and_orphans(): - r1, r2, r3 = MyVariable(1), MyVariable(2), MyVariable(3) - o1 = MyOp(r1, r2) - o1.name = "o1" - o2 = MyOp(r3, o1) - o2.name = "o2" - - vars_res = vars_between([r1, r2], [o2]) - orphans_res = orphans_between([r1, r2], [o2]) - - vars_res_list = list(vars_res) - orphans_res_list = list(orphans_res) - assert vars_res_list == [o2, o1, r3, r2, r1] - assert orphans_res_list == [r3] - - def test_ops(): r1, r2, r3, r4 = MyVariable(1), MyVariable(2), MyVariable(3), MyVariable(4) o1 = MyOp(r1, r2) @@ -570,64 +342,6 @@ def test_ops(): assert res_list == [o3.owner, o2.owner, o1.owner] -def test_apply_depends_on(): - r1, r2, r3 = MyVariable(1), MyVariable(2), MyVariable(3) - o1 = MyOp(r1, r2) - o1.name = "o1" - o2 = MyOp(r1, o1) - o2.name = "o2" - o3 = MyOp(r3, o1, o2) - o3.name = "o3" - - assert apply_depends_on(o2.owner, o1.owner) - assert apply_depends_on(o2.owner, o2.owner) - assert apply_depends_on(o3.owner, [o1.owner, 
o2.owner]) - - -@pytest.mark.xfail(reason="Not implemented") -def test_io_connection_pattern(): - raise AssertionError() - - -@pytest.mark.xfail(reason="Not implemented") -def test_view_roots(): - raise AssertionError() - - -def test_get_var_by_name(): - r1, r2, r3 = MyVariable(1), MyVariable(2), MyVariable(3) - o1 = MyOp(r1, r2) - o1.name = "o1" - - # Inner graph - igo_in_1 = MyVariable(4) - igo_in_2 = MyVariable(5) - igo_out_1 = MyOp(igo_in_1, igo_in_2) - igo_out_1.name = "igo1" - - igo = MyInnerGraphOp([igo_in_1, igo_in_2], [igo_out_1]) - - o2 = igo(r3, o1) - o2.name = "o1" - - res = get_var_by_name([o1, o2], "blah") - - assert res == () - - res = get_var_by_name([o1, o2], "o1") - - assert set(res) == {o1, o2} - - (res,) = get_var_by_name([o1, o2], o1.auto_name) - - assert res == o1 - - (res,) = get_var_by_name([o1, o2], "igo1") - - exp_res = igo.fgraph.outputs[0] - assert res == exp_res - - def test_clone_new_inputs(): """Make sure that `Apply.clone_with_new_inputs` properly handles `Type` changes.""" @@ -752,133 +466,6 @@ def test_NominalVariable_create_variable_type(): assert ntv_unpkld is ntv -def test_variable_depends_on(): - x = MyVariable(1) - x.name = "x" - y = MyVariable(1) - y.name = "y" - x2 = MyOp(x) - x2.name = "x2" - y2 = MyOp(y) - y2.name = "y2" - o = MyOp(x2, y) - assert variable_depends_on(o, x) - assert variable_depends_on(o, [x]) - assert not variable_depends_on(o, [y2]) - assert variable_depends_on(o, [y2, x]) - assert not variable_depends_on(y, [y2]) - assert variable_depends_on(y, [y]) - - -class TestTruncatedGraphInputs: - def test_basic(self): - """ - * No conditions - n - n - (o) - - * One condition - n - (c) - o - - * Two conditions where on depends on another, both returned - (c) - (c) - o - - * Additional nodes are present - (c) - n - o - n - (n) -' - - * Disconnected condition not returned - (c) - n - o - c - - * Disconnected output is present and returned - (c) - (c) - o - (o) - - * Condition on itself adds itself - n - (c) - (o/c) - """ - x = MyVariable(1) - x.name = "x" - y = MyVariable(1) - y.name = "y" - z = MyVariable(1) - z.name = "z" - x2 = MyOp(x) - x2.name = "x2" - y2 = MyOp(y, x2) - y2.name = "y2" - o = MyOp(y2) - o2 = MyOp(o) - # No conditions - assert truncated_graph_inputs([o]) == [o] - # One condition - assert truncated_graph_inputs([o2], [y2]) == [y2] - # Condition on itself adds itself - assert truncated_graph_inputs([o], [y2, o]) == [o, y2] - # Two conditions where on depends on another, both returned - assert truncated_graph_inputs([o2], [y2, o]) == [o, y2] - # Additional nodes are present - assert truncated_graph_inputs([o], [y]) == [x2, y] - # Disconnected condition - assert truncated_graph_inputs([o2], [y2, z]) == [y2] - # Disconnected output is present - assert truncated_graph_inputs([o2, z], [y2]) == [z, y2] - - def test_repeated_input(self): - """Test that truncated_graph_inputs does not return repeated inputs.""" - x = MyVariable(1) - x.name = "x" - y = MyVariable(1) - y.name = "y" - - trunc_inp1 = MyOp(x, y) - trunc_inp1.name = "trunc_inp1" - - trunc_inp2 = MyOp(x, y) - trunc_inp2.name = "trunc_inp2" - - o = MyOp(trunc_inp1, trunc_inp1, trunc_inp2, trunc_inp2) - o.name = "o" - - assert truncated_graph_inputs([o], [trunc_inp1]) == [trunc_inp2, trunc_inp1] - - def test_repeated_nested_input(self): - """Test that truncated_graph_inputs does not return repeated inputs.""" - x = MyVariable(1) - x.name = "x" - y = MyVariable(1) - y.name = "y" - - trunc_inp = MyOp(x, y) - trunc_inp.name = "trunc_inp" - - o1 = MyOp(trunc_inp, trunc_inp, x, 
x) - o1.name = "o1" - - assert truncated_graph_inputs([o1], [trunc_inp]) == [x, trunc_inp] - - # Reverse order of inputs - o2 = MyOp(x, x, trunc_inp, trunc_inp) - o2.name = "o2" - - assert truncated_graph_inputs([o2], [trunc_inp]) == [trunc_inp, x] - - def test_single_pass_per_node(self, mocker): - import pytensor.graph.basic - - inspect = mocker.spy(pytensor.graph.basic, "variable_depends_on") - x = pt.dmatrix("x") - m = x.shape[0][None, None] - - f = x / m - w = x / m - f - truncated_graph_inputs([w], [x]) - # make sure there were exactly the same calls as unique variables seen by the function - assert len(inspect.call_args_list) == len( - {a for ((a, b), kw) in inspect.call_args_list} - ) - - def test_dprint(): r1, r2 = MyVariable(1), MyVariable(2) o1 = MyOp(r1, r2) diff --git a/tests/graph/test_fg.py b/tests/graph/test_fg.py index 54ec654095..cd69e9022c 100644 --- a/tests/graph/test_fg.py +++ b/tests/graph/test_fg.py @@ -56,7 +56,7 @@ def test_validate_inputs(self): with pytest.raises(TypeError, match="'Variable' object is not iterable"): FunctionGraph(var1, [var2]) - with pytest.raises(TypeError, match="'Variable' object is not reversible"): + with pytest.raises(TypeError, match="'Variable' object is not iterable"): FunctionGraph([var1], var2) with pytest.raises( diff --git a/tests/graph/test_op.py b/tests/graph/test_op.py index 5ec545015b..d0d8b6c5fb 100644 --- a/tests/graph/test_op.py +++ b/tests/graph/test_op.py @@ -275,3 +275,8 @@ def perform(self, node, inputs, outputs): res_nameless = single_op(x) assert res_nameless.name is None + + +@pytest.mark.xfail(reason="Not implemented") +def test_io_connection_pattern(): + raise AssertionError() diff --git a/tests/graph/test_replace.py b/tests/graph/test_replace.py index 2c822587a3..f0d64ee76b 100644 --- a/tests/graph/test_replace.py +++ b/tests/graph/test_replace.py @@ -4,13 +4,14 @@ import pytensor.tensor as pt from pytensor import config, function, shared -from pytensor.graph.basic import equal_computations, graph_inputs +from pytensor.graph.basic import equal_computations from pytensor.graph.replace import ( clone_replace, graph_replace, vectorize_graph, vectorize_node, ) +from pytensor.graph.traversal import graph_inputs from pytensor.tensor import dvector, fvector, vector from tests import unittest_tools as utt from tests.graph.utils import MyOp, MyVariable, op_multiple_outputs @@ -27,7 +28,7 @@ def test_cloning_no_replace_strict_copy_inputs(self): f1 = z * (x + y) ** 2 + 5 f2 = clone_replace(f1, replace=None, rebuild_strict=True, copy_inputs_over=True) - f2_inp = graph_inputs([f2]) + f2_inp = tuple(graph_inputs([f2])) assert z in f2_inp assert x in f2_inp @@ -64,7 +65,7 @@ def test_cloning_replace_strict_copy_inputs(self): f2 = clone_replace( f1, replace={y: y2}, rebuild_strict=True, copy_inputs_over=True ) - f2_inp = graph_inputs([f2]) + f2_inp = tuple(graph_inputs([f2])) assert z in f2_inp assert x in f2_inp assert y2 in f2_inp @@ -82,7 +83,7 @@ def test_cloning_replace_not_strict_copy_inputs(self): f2 = clone_replace( f1, replace={y: y2}, rebuild_strict=False, copy_inputs_over=True ) - f2_inp = graph_inputs([f2]) + f2_inp = tuple(graph_inputs([f2])) assert z in f2_inp assert x in f2_inp assert y2 in f2_inp diff --git a/tests/graph/test_traversal.py b/tests/graph/test_traversal.py new file mode 100644 index 0000000000..30ff171348 --- /dev/null +++ b/tests/graph/test_traversal.py @@ -0,0 +1,442 @@ +import pytest + +from pytensor import Variable, shared +from pytensor import tensor as pt +from pytensor.graph import Apply, 
ancestors, graph_inputs +from pytensor.graph.traversal import ( + apply_ancestors, + apply_depends_on, + explicit_graph_inputs, + general_toposort, + get_var_by_name, + io_toposort, + orphans_between, + toposort, + toposort_with_orderings, + truncated_graph_inputs, + variable_ancestors, + variable_depends_on, + vars_between, + walk, +) +from tests.graph.test_basic import MyOp, MyVariable +from tests.graph.utils import MyInnerGraphOp, op_multiple_outputs + + +class TestToposort: + @staticmethod + def prenode(obj): + if isinstance(obj, Variable): + if obj.owner: + return [obj.owner] + if isinstance(obj, Apply): + return obj.inputs + + def test_simple(self): + # Test a simple graph + r1, r2, r5 = MyVariable(1), MyVariable(2), MyVariable(5) + o = MyOp(r1, r2) + o.name = "o1" + o2 = MyOp(o, r5) + o2.name = "o2" + + res = general_toposort([o2], self.prenode) + assert res == [r5, r2, r1, o.owner, o, o2.owner, o2] + + def circular_dependency(obj): + if obj is o: + # o2 depends on o, so o cannot depend on o2 + return [o2, *self.prenode(obj)] + return self.prenode(obj) + + with pytest.raises(ValueError, match="graph contains cycles"): + general_toposort([o2], circular_dependency) + + res = io_toposort([r5], [o2]) + assert res == [o.owner, o2.owner] + + def test_double_dependencies(self): + # Test a graph with double dependencies + r1, r5 = MyVariable(1), MyVariable(5) + o = MyOp.make_node(r1, r1) + o2 = MyOp.make_node(o.outputs[0], r5) + all = general_toposort(o2.outputs, self.prenode) + assert all == [r5, r1, o, o.outputs[0], o2, o2.outputs[0]] + + def test_inputs_owners(self): + # Test a graph where the inputs have owners + r1, r5 = MyVariable(1), MyVariable(5) + o = MyOp.make_node(r1, r1) + r2b = o.outputs[0] + o2 = MyOp.make_node(r2b, r2b) + all = io_toposort([r2b], o2.outputs) + assert all == [o2] + + o2 = MyOp.make_node(r2b, r5) + all = io_toposort([r2b], o2.outputs) + assert all == [o2] + + def test_not_connected(self): + # Test a graph which is not connected + r1, r2, r3, r4 = MyVariable(1), MyVariable(2), MyVariable(3), MyVariable(4) + o0 = MyOp.make_node(r1, r2) + o1 = MyOp.make_node(r3, r4) + all = io_toposort([r1, r2, r3, r4], o0.outputs + o1.outputs) + assert all == [o1, o0] or all == [o0, o1] + + def test_io_chain(self): + # Test inputs and outputs mixed together in a chain graph + r1, r2 = MyVariable(1), MyVariable(2) + o0 = MyOp.make_node(r1, r2) + o1 = MyOp.make_node(o0.outputs[0], r1) + all = io_toposort([r1, o0.outputs[0]], [o0.outputs[0], o1.outputs[0]]) + assert all == [o1] + + def test_outputs_clients(self): + # Test when outputs have clients + r1, r2, r4 = MyVariable(1), MyVariable(2), MyVariable(4) + o0 = MyOp.make_node(r1, r2) + MyOp.make_node(o0.outputs[0], r4) + all = io_toposort([], o0.outputs) + assert all == [o0] + + def test_multi_output_nodes(self): + l0, r0 = op_multiple_outputs(shared(0.0)) + l1, r1 = op_multiple_outputs(shared(0.0)) + + v0 = r0 + 1 + v1 = pt.exp(v0) + out = r1 * v1 + + # When either r0 or r1 is provided as an input, the respective node shouldn't be part of the toposort + assert set(io_toposort([], [out])) == { + r0.owner, + r1.owner, + v0.owner, + v1.owner, + out.owner, + } + assert set(io_toposort([r0], [out])) == { + r1.owner, + v0.owner, + v1.owner, + out.owner, + } + assert set(io_toposort([r1], [out])) == { + r0.owner, + v0.owner, + v1.owner, + out.owner, + } + assert set(io_toposort([r0, r1], [out])) == {v0.owner, v1.owner, out.owner} + + # When l0 and/or l1 are provided, we still need to compute the respective nodes + assert 
set(io_toposort([l0, l1], [out])) == { + r0.owner, + r1.owner, + v0.owner, + v1.owner, + out.owner, + } + + +def test_walk(): + r1, r2, r3 = MyVariable(1), MyVariable(2), MyVariable(3) + o1 = MyOp(r1, r2) + o1.name = "o1" + o2 = MyOp(r3, o1) + o2.name = "o2" + + def expand(r): + if r.owner: + return r.owner.inputs + + res = walk([o2], expand, bfs=True, return_children=False) + res_list = list(res) + assert res_list == [o2, r3, o1, r1, r2] + + res = walk([o2], expand, bfs=False, return_children=False) + res_list = list(res) + assert res_list == [o2, o1, r2, r1, r3] + + res = walk([o2], expand, bfs=True, return_children=True) + res_list = list(res) + assert res_list == [ + (o2, [r3, o1]), + (r3, None), + (o1, [r1, r2]), + (r1, None), + (r2, None), + ] + + +def test_ancestors(): + r1, r2, r3 = MyVariable(1), MyVariable(2), MyVariable(3) + o1 = MyOp(r1, r2) + o1.name = "o1" + o2 = MyOp(r3, o1) + o2.name = "o2" + + res = ancestors([o2], blockers=None) + res_list = list(res) + assert res_list == [o2, o1, r2, r1, r3] + + res = ancestors([o2], blockers=None) + assert o1 in res + res_list = list(res) + assert res_list == [r2, r1, r3] + + res = ancestors([o2], blockers=[o1]) + res_list = list(res) + assert res_list == [o2, o1, r3] + + +def test_graph_inputs(): + r1, r2, r3 = MyVariable(1), MyVariable(2), MyVariable(3) + o1 = MyOp(r1, r2) + o1.name = "o1" + o2 = MyOp(r3, o1) + o2.name = "o2" + + res = graph_inputs([o2], blockers=None) + res_list = list(res) + assert res_list == [r2, r1, r3] + + +def test_explicit_graph_inputs(): + x = pt.fscalar() + y = pt.constant(2) + z = shared(1) + a = pt.sum(x + y + z) + b = pt.true_div(x, y) + + res = list(explicit_graph_inputs([a])) + res1 = list(explicit_graph_inputs(b)) + + assert res == [x] + assert res1 == [x] + + +def test_variables_and_orphans(): + r1, r2, r3 = MyVariable(1), MyVariable(2), MyVariable(3) + o1 = MyOp(r1, r2) + o1.name = "o1" + o2 = MyOp(r3, o1) + o2.name = "o2" + + vars_res = vars_between([r1, r2], [o2]) + orphans_res = orphans_between([r1, r2], [o2]) + + vars_res_list = list(vars_res) + orphans_res_list = list(orphans_res) + assert vars_res_list == [o2, o1, r2, r1, r3] + assert orphans_res_list == [r3] + + +def test_apply_depends_on(): + r1, r2, r3 = MyVariable(1), MyVariable(2), MyVariable(3) + o1 = MyOp(r1, r2) + o1.name = "o1" + o2 = MyOp(r1, o1) + o2.name = "o2" + o3 = MyOp(r3, o1, o2) + o3.name = "o3" + + assert apply_depends_on(o2.owner, o1.owner) + assert apply_depends_on(o2.owner, o2.owner) + assert apply_depends_on(o3.owner, [o1.owner, o2.owner]) + + +def test_variable_depends_on(): + x = MyVariable(1) + x.name = "x" + y = MyVariable(1) + y.name = "y" + x2 = MyOp(x) + x2.name = "x2" + y2 = MyOp(y) + y2.name = "y2" + o = MyOp(x2, y) + assert variable_depends_on(o, x) + assert variable_depends_on(o, [x]) + assert not variable_depends_on(o, [y2]) + assert variable_depends_on(o, [y2, x]) + assert not variable_depends_on(y, [y2]) + assert variable_depends_on(y, [y]) + + +class TestTruncatedGraphInputs: + def test_basic(self): + """ + * No conditions + n - n - (o) + + * One condition + n - (c) - o + + * Two conditions where on depends on another, both returned + (c) - (c) - o + + * Additional nodes are present + (c) - n - o + n - (n) -' + + * Disconnected condition not returned + (c) - n - o + c + + * Disconnected output is present and returned + (c) - (c) - o + (o) + + * Condition on itself adds itself + n - (c) - (o/c) + """ + x = MyVariable(1) + x.name = "x" + y = MyVariable(1) + y.name = "y" + z = MyVariable(1) + z.name = "z" + x2 
= MyOp(x)
+        x2.name = "x2"
+        y2 = MyOp(y, x2)
+        y2.name = "y2"
+        o = MyOp(y2)
+        o2 = MyOp(o)
+        # No conditions
+        assert truncated_graph_inputs([o]) == [o]
+        # One condition
+        assert truncated_graph_inputs([o2], [y2]) == [y2]
+        # Condition on itself adds itself
+        assert truncated_graph_inputs([o], [y2, o]) == [o, y2]
+        # Two conditions where one depends on another, both returned
+        assert truncated_graph_inputs([o2], [y2, o]) == [o, y2]
+        # Additional nodes are present
+        assert truncated_graph_inputs([o], [y]) == [x2, y]
+        # Disconnected condition
+        assert truncated_graph_inputs([o2], [y2, z]) == [y2]
+        # Disconnected output is present
+        assert truncated_graph_inputs([o2, z], [y2]) == [z, y2]
+
+    def test_repeated_input(self):
+        """Test that truncated_graph_inputs does not return repeated inputs."""
+        x = MyVariable(1)
+        x.name = "x"
+        y = MyVariable(1)
+        y.name = "y"
+
+        trunc_inp1 = MyOp(x, y)
+        trunc_inp1.name = "trunc_inp1"
+
+        trunc_inp2 = MyOp(x, y)
+        trunc_inp2.name = "trunc_inp2"
+
+        o = MyOp(trunc_inp1, trunc_inp1, trunc_inp2, trunc_inp2)
+        o.name = "o"
+
+        assert truncated_graph_inputs([o], [trunc_inp1]) == [trunc_inp2, trunc_inp1]
+
+    def test_repeated_nested_input(self):
+        """Test that truncated_graph_inputs does not return repeated inputs."""
+        x = MyVariable(1)
+        x.name = "x"
+        y = MyVariable(1)
+        y.name = "y"
+
+        trunc_inp = MyOp(x, y)
+        trunc_inp.name = "trunc_inp"
+
+        o1 = MyOp(trunc_inp, trunc_inp, x, x)
+        o1.name = "o1"
+
+        assert truncated_graph_inputs([o1], [trunc_inp]) == [x, trunc_inp]
+
+        # Reverse order of inputs
+        o2 = MyOp(x, x, trunc_inp, trunc_inp)
+        o2.name = "o2"
+
+        assert truncated_graph_inputs([o2], [trunc_inp]) == [trunc_inp, x]
+
+    def test_single_pass_per_node(self, mocker):
+        import pytensor.graph.traversal
+
+        inspect = mocker.spy(pytensor.graph.traversal, "variable_depends_on")
+        x = pt.dmatrix("x")
+        m = x.shape[0][None, None]
+
+        f = x / m
+        w = x / m - f
+        truncated_graph_inputs([w], [x])
+        # make sure there were exactly as many calls as unique variables seen by the function
+        assert len(inspect.call_args_list) == len(
+            {a for ((a, b), kw) in inspect.call_args_list}
+        )
+
+
+def test_get_var_by_name():
+    r1, r2, r3 = MyVariable(1), MyVariable(2), MyVariable(3)
+    o1 = MyOp(r1, r2)
+    o1.name = "o1"
+
+    # Inner graph
+    igo_in_1 = MyVariable(4)
+    igo_in_2 = MyVariable(5)
+    igo_out_1 = MyOp(igo_in_1, igo_in_2)
+    igo_out_1.name = "igo1"
+
+    igo = MyInnerGraphOp([igo_in_1, igo_in_2], [igo_out_1])
+
+    o2 = igo(r3, o1)
+    o2.name = "o1"
+
+    res = get_var_by_name([o1, o2], "blah")
+
+    assert res == ()
+
+    res = get_var_by_name([o1, o2], "o1")
+
+    assert set(res) == {o1, o2}
+
+    (res,) = get_var_by_name([o1, o2], o1.auto_name)
+
+    assert res == o1
+
+    (res,) = get_var_by_name([o1, o2], "igo1")
+
+    exp_res = igo.fgraph.outputs[0]
+    assert res == exp_res
+
+
+@pytest.mark.parametrize(
+    "func",
+    [
+        lambda x: all(variable_ancestors([x])),
+        lambda x: all(variable_ancestors([x], blockers=[x.clone()])),
+        lambda x: all(apply_ancestors([x])),
+        lambda x: all(apply_ancestors([x], blockers=[x.clone()])),
+        lambda x: all(toposort([x])),
+        lambda x: all(toposort([x], blockers=[x.clone()])),
+        lambda x: all(toposort_with_orderings([x], orderings={x: []})),
+        lambda x: all(
+            toposort_with_orderings([x], blockers=[x.clone()], orderings={x: []})
+        ),
+    ],
+    ids=[
+        "variable_ancestors",
+        "variable_ancestors_with_blockers",
+        "apply_ancestors",
+        "apply_ancestors_with_blockers",
+        "toposort",
+        "toposort_with_blockers",
+        "toposort_with_orderings",
+        
"toposort_with_orderings_and_blockers", + ], +) +def test_traversal_benchmark(func, benchmark): + r1 = MyVariable(1) + out = r1 + for i in range(50): + out = MyOp(out, out) + + benchmark(func, out) diff --git a/tests/link/numba/test_basic.py b/tests/link/numba/test_basic.py index 95cf6ec557..fd9a48111f 100644 --- a/tests/link/numba/test_basic.py +++ b/tests/link/numba/test_basic.py @@ -6,9 +6,9 @@ import numpy as np import pytest +import scipy from pytensor.compile import SymbolicInput -from tests.tensor.test_math_scipy import scipy numba = pytest.importorskip("numba") diff --git a/tests/link/numba/test_scan.py b/tests/link/numba/test_scan.py index 8c0d9d4f52..a43ed72770 100644 --- a/tests/link/numba/test_scan.py +++ b/tests/link/numba/test_scan.py @@ -1,3 +1,5 @@ +from itertools import chain + import numpy as np import pytest @@ -490,6 +492,7 @@ def step(ztm3, ztm1, xtm1, ytm1, ytm2, a): if isinstance(node.op, Scan) ] + # Collect inner inputs we expect to be destroyed by the step function # Scan reorders inputs internally, so we need to check its ordering inner_inps = scan_op.fgraph.inputs mit_sot_inps = scan_op.inner_mitsot(inner_inps) @@ -501,28 +504,22 @@ def step(ztm3, ztm1, xtm1, ytm1, ytm2, a): ] [sit_sot_inp] = scan_op.inner_sitsot(inner_inps) - inner_outs = scan_op.fgraph.outputs - mit_sot_outs = scan_op.inner_mitsot_outs(inner_outs) - [sit_sot_out] = scan_op.inner_sitsot_outs(inner_outs) - [nit_sot_out] = scan_op.inner_nitsot_outs(inner_outs) + destroyed_inputs = [] + for inner_out in scan_op.fgraph.outputs: + node = inner_out.owner + dm = node.op.destroy_map + if dm: + destroyed_inputs.extend( + node.inputs[idx] for idx in chain.from_iterable(dm.values()) + ) if n_steps_constant: - assert mit_sot_outs[0].owner.op.destroy_map == { - 0: [mit_sot_outs[0].owner.inputs.index(oldest_mit_sot_inps[0])] - } - assert mit_sot_outs[1].owner.op.destroy_map == { - 0: [mit_sot_outs[1].owner.inputs.index(oldest_mit_sot_inps[1])] - } - assert sit_sot_out.owner.op.destroy_map == { - 0: [sit_sot_out.owner.inputs.index(sit_sot_inp)] - } + assert len(destroyed_inputs) == 3 + assert set(destroyed_inputs) == {*oldest_mit_sot_inps, sit_sot_inp} else: # This is not a feature, but a current limitation # https://github.com/pymc-devs/pytensor/issues/1283 - assert mit_sot_outs[0].owner.op.destroy_map == {} - assert mit_sot_outs[1].owner.op.destroy_map == {} - assert sit_sot_out.owner.op.destroy_map == {} - assert nit_sot_out.owner.op.destroy_map == {} + assert not destroyed_inputs @pytest.mark.parametrize( diff --git a/tests/scan/test_basic.py b/tests/scan/test_basic.py index 896d131f57..2d3daeec0e 100644 --- a/tests/scan/test_basic.py +++ b/tests/scan/test_basic.py @@ -27,11 +27,12 @@ from pytensor.compile.sharedvalue import shared from pytensor.configdefaults import config from pytensor.gradient import NullTypeGradError, Rop, disconnected_grad, grad, hessian -from pytensor.graph import vectorize_graph -from pytensor.graph.basic import Apply, ancestors, equal_computations +from pytensor.graph.basic import Apply, equal_computations from pytensor.graph.fg import FunctionGraph from pytensor.graph.op import Op +from pytensor.graph.replace import vectorize_graph from pytensor.graph.rewriting.basic import MergeOptimizer +from pytensor.graph.traversal import ancestors from pytensor.graph.utils import MissingInputError from pytensor.raise_op import assert_op from pytensor.scan.basic import scan diff --git a/tests/scan/test_rewriting.py b/tests/scan/test_rewriting.py index 1b7fac98a4..9871f2d8bf 100644 --- 
a/tests/scan/test_rewriting.py
+++ b/tests/scan/test_rewriting.py
@@ -9,9 +9,10 @@
 from pytensor.compile.mode import get_default_mode
 from pytensor.configdefaults import config
 from pytensor.gradient import grad, jacobian
-from pytensor.graph.basic import Constant, ancestors, equal_computations
+from pytensor.graph.basic import Constant, equal_computations
 from pytensor.graph.fg import FunctionGraph
 from pytensor.graph.replace import clone_replace
+from pytensor.graph.traversal import ancestors
 from pytensor.scan.op import Scan
 from pytensor.scan.rewriting import ScanInplaceOptimizer, ScanMerge
 from pytensor.scan.utils import until
@@ -600,10 +601,12 @@ class TestPushOutAddScan:
     is used to compute the sum over the dot products between the corresponding
     elements of two list of matrices.
 
-    TODO FIXME XXX: These aren't real tests; they simply confirm that a few
+    FIXME: These aren't real tests; they simply confirm that a few
     graph that could be relevant to the push-out optimizations can be
     compiled and evaluated. None of them confirm that a push-out
     optimization has been performed.
+
+    FIXME: The rewrite is indeed broken, and probably has been for a long while; see the FIXME in the respective rewrite for details.
     """
 
     def test_sum_dot(self):
@@ -614,7 +617,15 @@ def test_sum_dot(self):
             sequences=[A.dimshuffle(0, 1, "x"), B.dimshuffle(0, "x", 1)],
             outputs_info=[pt.zeros_like(A)],
         )
+        # FIXME: This `S.owner.inputs[0][-1]` is a hack; users will never write that.
+        # They would write `S[-1]`, which the rewrite fails to identify because it explicitly looks for `scan_out[-1]`
+        # instead of the `scan_out[1:][-1]` graph that `S[-1]` actually produces.
+        # It does, however, test the only case the rewrite currently supports.
         f = function([A, B], S.owner.inputs[0][-1])
+        has_scan = any(isinstance(node.op, Scan) for node in f.maker.fgraph.apply_nodes)
+        # The rewrite is only triggered in FAST_RUN mode
+        assert has_scan if (config.mode == "FAST_COMPILE") else (not has_scan)
+
        rng = np.random.default_rng(utt.fetch_seed())
         vA = rng.uniform(size=(5, 5)).astype(config.floatX)
         vB = rng.uniform(size=(5, 5)).astype(config.floatX)
diff --git a/tests/sparse/test_basic.py b/tests/sparse/test_basic.py
index 7da993b3dc..4286660ccf 100644
--- a/tests/sparse/test_basic.py
+++ b/tests/sparse/test_basic.py
@@ -12,8 +12,9 @@
 from pytensor.compile.io import In, Out
 from pytensor.configdefaults import config
 from pytensor.gradient import GradientError
-from pytensor.graph.basic import Apply, Constant, applys_between
+from pytensor.graph.basic import Apply, Constant
 from pytensor.graph.op import Op
+from pytensor.graph.traversal import applys_between
 from pytensor.sparse import (
     CSC,
     CSM,
diff --git a/tests/tensor/random/rewriting/test_basic.py b/tests/tensor/random/rewriting/test_basic.py
index b968131525..837b2c3b0c 100644
--- a/tests/tensor/random/rewriting/test_basic.py
+++ b/tests/tensor/random/rewriting/test_basic.py
@@ -7,10 +7,11 @@
 from pytensor import config, shared
 from pytensor.compile.function import function
 from pytensor.compile.mode import Mode
-from pytensor.graph.basic import Constant, Variable, ancestors
+from pytensor.graph.basic import Constant, Variable
 from pytensor.graph.fg import FunctionGraph
 from pytensor.graph.rewriting.basic import EquilibriumGraphRewriter, check_stack_trace
 from pytensor.graph.rewriting.db import RewriteDatabaseQuery
+from pytensor.graph.traversal import ancestors
 from pytensor.tensor import constant
 from pytensor.tensor.elemwise import DimShuffle
 from pytensor.tensor.random.basic import (
diff --git 
a/tests/tensor/random/test_basic.py b/tests/tensor/random/test_basic.py index 06af82ddf7..065cdeb1f6 100644 --- a/tests/tensor/random/test_basic.py +++ b/tests/tensor/random/test_basic.py @@ -11,11 +11,12 @@ from pytensor.compile.mode import Mode from pytensor.compile.sharedvalue import SharedVariable from pytensor.configdefaults import config -from pytensor.graph.basic import Constant, Variable, graph_inputs +from pytensor.graph.basic import Constant, Variable from pytensor.graph.fg import FunctionGraph from pytensor.graph.op import get_test_value from pytensor.graph.replace import clone_replace from pytensor.graph.rewriting.db import RewriteDatabaseQuery +from pytensor.graph.traversal import graph_inputs from pytensor.tensor import ones, stack from pytensor.tensor.random.basic import ( ChoiceWithoutReplacement, diff --git a/tests/tensor/rewriting/test_math.py b/tests/tensor/rewriting/test_math.py index 6ce9d10ef9..eda10cbd25 100644 --- a/tests/tensor/rewriting/test_math.py +++ b/tests/tensor/rewriting/test_math.py @@ -16,9 +16,9 @@ from pytensor.compile.mode import Mode, get_default_mode, get_mode from pytensor.compile.ops import DeepCopyOp, deep_copy_op from pytensor.configdefaults import config -from pytensor.graph import vectorize_graph -from pytensor.graph.basic import Apply, ancestors, equal_computations +from pytensor.graph.basic import Apply, equal_computations from pytensor.graph.fg import FunctionGraph +from pytensor.graph.replace import vectorize_graph from pytensor.graph.rewriting.basic import ( SequentialNodeRewriter, WalkingGraphRewriter, @@ -28,6 +28,7 @@ ) from pytensor.graph.rewriting.db import RewriteDatabaseQuery from pytensor.graph.rewriting.utils import is_same_graph, rewrite_graph +from pytensor.graph.traversal import ancestors from pytensor.printing import debugprint from pytensor.scalar import PolyGamma, Psi, TriGamma from pytensor.tensor import inplace diff --git a/tests/tensor/rewriting/test_subtensor.py b/tests/tensor/rewriting/test_subtensor.py index 95f84790d9..91a1f96e81 100644 --- a/tests/tensor/rewriting/test_subtensor.py +++ b/tests/tensor/rewriting/test_subtensor.py @@ -12,8 +12,9 @@ from pytensor.compile.ops import DeepCopyOp from pytensor.configdefaults import config from pytensor.graph import rewrite_graph, vectorize_graph -from pytensor.graph.basic import Constant, Variable, ancestors, equal_computations +from pytensor.graph.basic import Constant, Variable, equal_computations from pytensor.graph.rewriting.basic import check_stack_trace +from pytensor.graph.traversal import ancestors from pytensor.raise_op import Assert from pytensor.tensor.basic import Alloc, _convert_to_int8 from pytensor.tensor.blockwise import Blockwise diff --git a/tests/tensor/signal/test_conv.py b/tests/tensor/signal/test_conv.py index a22a07d101..4df25cc1ca 100644 --- a/tests/tensor/signal/test_conv.py +++ b/tests/tensor/signal/test_conv.py @@ -5,8 +5,8 @@ from scipy.signal import convolve as scipy_convolve from pytensor import config, function, grad -from pytensor.graph.basic import ancestors, io_toposort from pytensor.graph.rewriting import rewrite_graph +from pytensor.graph.traversal import ancestors, io_toposort from pytensor.tensor import matrix, tensor, vector from pytensor.tensor.blockwise import Blockwise from pytensor.tensor.signal.conv import Convolve1d, convolve1d diff --git a/tests/tensor/test_extra_ops.py b/tests/tensor/test_extra_ops.py index 8274ddbcea..352238adec 100644 --- a/tests/tensor/test_extra_ops.py +++ b/tests/tensor/test_extra_ops.py @@ -9,7 +9,8 @@ 
from pytensor.compile.mode import Mode from pytensor.configdefaults import config from pytensor.graph import rewrite_graph -from pytensor.graph.basic import Constant, applys_between, equal_computations +from pytensor.graph.basic import Constant, equal_computations +from pytensor.graph.traversal import applys_between from pytensor.npy_2_compat import old_np_unique from pytensor.raise_op import Assert from pytensor.tensor import alloc diff --git a/tests/tensor/test_math.py b/tests/tensor/test_math.py index 935b9ada52..c68b8a3159 100644 --- a/tests/tensor/test_math.py +++ b/tests/tensor/test_math.py @@ -19,9 +19,10 @@ from pytensor.compile.sharedvalue import shared from pytensor.configdefaults import config from pytensor.gradient import NullTypeGradError, grad, numeric_grad -from pytensor.graph.basic import Variable, ancestors, applys_between, equal_computations +from pytensor.graph.basic import Variable, equal_computations from pytensor.graph.fg import FunctionGraph from pytensor.graph.replace import vectorize_node +from pytensor.graph.traversal import ancestors, applys_between from pytensor.link.c.basic import DualLinker from pytensor.npy_2_compat import using_numpy_2 from pytensor.printing import pprint diff --git a/tests/tensor/test_math_scipy.py b/tests/tensor/test_math_scipy.py index e7579b10ac..5d69ff1f09 100644 --- a/tests/tensor/test_math_scipy.py +++ b/tests/tensor/test_math_scipy.py @@ -1,24 +1,19 @@ import warnings +from functools import partial import numpy as np import pytest - -from pytensor.gradient import NullTypeGradError, verify_grad -from pytensor.scalar import ScalarLoop -from pytensor.tensor.elemwise import Elemwise - - -scipy = pytest.importorskip("scipy") - -from functools import partial - +import scipy from scipy import special, stats from pytensor import function, grad from pytensor import tensor as pt from pytensor.compile.mode import get_default_mode from pytensor.configdefaults import config +from pytensor.gradient import NullTypeGradError, verify_grad +from pytensor.scalar import ScalarLoop from pytensor.tensor import gammaincc, inplace, kn, kv, kve, vector +from pytensor.tensor.elemwise import Elemwise from tests import unittest_tools as utt from tests.tensor.utils import ( _good_broadcast_unary_chi2sf, @@ -1175,8 +1170,8 @@ def test_unused_grad_loop_opt(self, wrt): if isinstance(node.op, Elemwise) and isinstance(node.op.scalar_op, ScalarLoop) ] - assert scalar_loop_op1.nin == 10 + 3 * 2 # wrt=[0, 1] - assert scalar_loop_op2.nin == 10 + 3 * 1 # wrt=[2] + assert scalar_loop_op1.nin == 10 + 3 * 1 # wrt=[2] + assert scalar_loop_op2.nin == 10 + 3 * 2 # wrt=[0, 1] else: [scalar_loop_op] = [ node.op.scalar_op diff --git a/tests/tensor/test_variable.py b/tests/tensor/test_variable.py index 2c6f818c30..e4a0841910 100644 --- a/tests/tensor/test_variable.py +++ b/tests/tensor/test_variable.py @@ -10,6 +10,7 @@ from pytensor.compile import DeepCopyOp from pytensor.compile.mode import get_default_mode from pytensor.graph.basic import Constant, equal_computations +from pytensor.graph.traversal import io_toposort from pytensor.tensor import get_vector_length from pytensor.tensor.basic import constant from pytensor.tensor.elemwise import DimShuffle @@ -144,7 +145,7 @@ def test__getitem__Subtensor(): i = iscalar("i") z = x[i] - op_types = [type(node.op) for node in pytensor.graph.basic.io_toposort([x, i], [z])] + op_types = [type(node.op) for node in io_toposort([x, i], [z])] assert op_types[-1] == Subtensor # This should ultimately do nothing (i.e. 
just return `x`) @@ -156,29 +157,29 @@ def test__getitem__Subtensor(): # It lands in the `full_slices` condition in # `_tensor_py_operators.__getitem__` z = x[..., None] - op_types = [type(node.op) for node in pytensor.graph.basic.io_toposort([x, i], [z])] + op_types = [type(node.op) for node in io_toposort([x, i], [z])] assert all(op_type == DimShuffle for op_type in op_types) z = x[None, :, None, :] - op_types = [type(node.op) for node in pytensor.graph.basic.io_toposort([x, i], [z])] + op_types = [type(node.op) for node in io_toposort([x, i], [z])] assert all(op_type == DimShuffle for op_type in op_types) # This one lands in the non-`full_slices` condition in # `_tensor_py_operators.__getitem__` z = x[:i, :, None] - op_types = [type(node.op) for node in pytensor.graph.basic.io_toposort([x, i], [z])] + op_types = [type(node.op) for node in io_toposort([x, i], [z])] assert op_types[1:] == [DimShuffle, Subtensor] z = x[:] - op_types = [type(node.op) for node in pytensor.graph.basic.io_toposort([x, i], [z])] + op_types = [type(node.op) for node in io_toposort([x, i], [z])] assert op_types[-1] == Subtensor z = x[..., :] - op_types = [type(node.op) for node in pytensor.graph.basic.io_toposort([x, i], [z])] + op_types = [type(node.op) for node in io_toposort([x, i], [z])] assert op_types[-1] == Subtensor z = x[..., i, :] - op_types = [type(node.op) for node in pytensor.graph.basic.io_toposort([x, i], [z])] + op_types = [type(node.op) for node in io_toposort([x, i], [z])] assert op_types[-1] == Subtensor @@ -187,24 +188,24 @@ def test__getitem__AdvancedSubtensor_bool(): i = TensorType("bool", shape=(None, None))("i") z = x[i] - op_types = [type(node.op) for node in pytensor.graph.basic.io_toposort([x, i], [z])] + op_types = [type(node.op) for node in io_toposort([x, i], [z])] assert op_types[-1] == AdvancedSubtensor i = TensorType("bool", shape=(None,))("i") z = x[:, i] - op_types = [type(node.op) for node in pytensor.graph.basic.io_toposort([x, i], [z])] + op_types = [type(node.op) for node in io_toposort([x, i], [z])] assert op_types[-1] == AdvancedSubtensor i = TensorType("bool", shape=(None,))("i") z = x[..., i] - op_types = [type(node.op) for node in pytensor.graph.basic.io_toposort([x, i], [z])] + op_types = [type(node.op) for node in io_toposort([x, i], [z])] assert op_types[-1] == AdvancedSubtensor with pytest.raises(TypeError): z = x[[True, False], i] z = x[ivector("b"), i] - op_types = [type(node.op) for node in pytensor.graph.basic.io_toposort([x, i], [z])] + op_types = [type(node.op) for node in io_toposort([x, i], [z])] assert op_types[-1] == AdvancedSubtensor @@ -215,26 +216,26 @@ def test__getitem__AdvancedSubtensor(): # This is a `__getitem__` call that's redirected to `_tensor_py_operators.take` z = x[i] - op_types = [type(node.op) for node in pytensor.graph.basic.io_toposort([x, i], [z])] + op_types = [type(node.op) for node in io_toposort([x, i], [z])] assert op_types[-1] == AdvancedSubtensor # This should index nothing (i.e. 
return an empty copy of `x`) # We check that the index is empty z = x[[]] - op_types = [type(node.op) for node in pytensor.graph.basic.io_toposort([x, i], [z])] + op_types = [type(node.op) for node in io_toposort([x, i], [z])] assert op_types == [AdvancedSubtensor] assert isinstance(z.owner.inputs[1], TensorConstant) z = x[:, i] - op_types = [type(node.op) for node in pytensor.graph.basic.io_toposort([x, i], [z])] + op_types = [type(node.op) for node in io_toposort([x, i], [z])] assert op_types == [MakeSlice, AdvancedSubtensor] z = x[..., i, None] - op_types = [type(node.op) for node in pytensor.graph.basic.io_toposort([x, i], [z])] + op_types = [type(node.op) for node in io_toposort([x, i], [z])] assert op_types == [MakeSlice, AdvancedSubtensor] z = x[i, None] - op_types = [type(node.op) for node in pytensor.graph.basic.io_toposort([x, i], [z])] + op_types = [type(node.op) for node in io_toposort([x, i], [z])] assert op_types[-1] == AdvancedSubtensor diff --git a/tests/test_gradient.py b/tests/test_gradient.py index 8de9c24b18..ed1b596691 100644 --- a/tests/test_gradient.py +++ b/tests/test_gradient.py @@ -28,9 +28,10 @@ zero_grad, zero_grad_, ) -from pytensor.graph.basic import Apply, graph_inputs +from pytensor.graph.basic import Apply from pytensor.graph.null_type import NullType from pytensor.graph.op import Op +from pytensor.graph.traversal import graph_inputs from pytensor.scan.op import Scan from pytensor.tensor.math import add, dot, exp, outer, sigmoid, sqr, sqrt, tanh from pytensor.tensor.math import sum as pt_sum diff --git a/tests/typed_list/test_basic.py b/tests/typed_list/test_basic.py index 19598bfb21..8a9db1ae4b 100644 --- a/tests/typed_list/test_basic.py +++ b/tests/typed_list/test_basic.py @@ -1,5 +1,6 @@ import numpy as np import pytest +import scipy import pytensor import pytensor.typed_list @@ -37,8 +38,7 @@ def rand_ranged_matrix(minimum, maximum, shape): def random_lil(shape, dtype, nnz): - sp = pytest.importorskip("scipy") - rval = sp.sparse.lil_matrix(shape, dtype=dtype) + rval = scipy.sparse.lil_matrix(shape, dtype=dtype) huge = 2**30 for k in range(nnz): # set non-zeros in random locations (row x, col y) @@ -451,7 +451,6 @@ def test_non_tensor_type(self): assert f([[x, y], [x, y, y]], [x, y]) == 0 def test_sparse(self): - sp = pytest.importorskip("scipy") mySymbolicSparseList = TypedListType( sparse.SparseTensorType("csr", pytensor.config.floatX) )() @@ -461,8 +460,8 @@ def test_sparse(self): f = pytensor.function([mySymbolicSparseList, mySymbolicSparse], z) - x = sp.sparse.csr_matrix(random_lil((10, 40), pytensor.config.floatX, 3)) - y = sp.sparse.csr_matrix(random_lil((10, 40), pytensor.config.floatX, 3)) + x = scipy.sparse.csr_matrix(random_lil((10, 40), pytensor.config.floatX, 3)) + y = scipy.sparse.csr_matrix(random_lil((10, 40), pytensor.config.floatX, 3)) assert f([x, y], y) == 1 @@ -519,7 +518,6 @@ def test_non_tensor_type(self): assert f([[x, y], [x, y, y]], [x, y]) == 1 def test_sparse(self): - sp = pytest.importorskip("scipy") mySymbolicSparseList = TypedListType( sparse.SparseTensorType("csr", pytensor.config.floatX) )() @@ -529,8 +527,8 @@ def test_sparse(self): f = pytensor.function([mySymbolicSparseList, mySymbolicSparse], z) - x = sp.sparse.csr_matrix(random_lil((10, 40), pytensor.config.floatX, 3)) - y = sp.sparse.csr_matrix(random_lil((10, 40), pytensor.config.floatX, 3)) + x = scipy.sparse.csr_matrix(random_lil((10, 40), pytensor.config.floatX, 3)) + y = scipy.sparse.csr_matrix(random_lil((10, 40), pytensor.config.floatX, 3)) assert f([x, y, 
y], y) == 2