From be7a7f545deb95e51b1f96ef4faa66c2c5988d83 Mon Sep 17 00:00:00 2001 From: Will Dean Date: Wed, 26 Jun 2024 10:25:15 +0200 Subject: [PATCH 01/17] add PlateMeta and NodeMeta --- pymc/model_graph.py | 167 +++++++++++++++++++++++++++++--------------- 1 file changed, 111 insertions(+), 56 deletions(-) diff --git a/pymc/model_graph.py b/pymc/model_graph.py index 30b79bb194..3ec196ef2d 100644 --- a/pymc/model_graph.py +++ b/pymc/model_graph.py @@ -14,7 +14,8 @@ import warnings from collections import defaultdict -from collections.abc import Callable, Iterable, Sequence +from collections.abc import Callable, Iterable +from dataclasses import dataclass from enum import Enum from os import path from typing import Any @@ -39,6 +40,34 @@ ) +@dataclass +class PlateMeta: + names: tuple[str] + sizes: tuple[int] + dim_info: bool = True + + def __hash__(self) -> int: + return hash((self.names, self.sizes)) + + +def create_plate_label(plate_meta: PlateMeta, include_size: bool = True) -> str: + def create_label(d: int, dname: str, dlen: int): + if plate_meta.dim_info: + label = f"{dname}" + else: + label = f"{dname}_dim{d}" + + if include_size: + label = f"{label} ({dlen})" + + return label + + values = enumerate( + zip(plate_meta.names, plate_meta.sizes), + ) + return " x ".join(create_label(d, dname, dlen) for d, (dname, dlen) in values) + + def fast_eval(var): return function([], var, mode="FAST_COMPILE")() @@ -53,6 +82,15 @@ class NodeType(str, Enum): DATA = "Data" +@dataclass +class NodeMeta: + var: TensorVariable + node_type: NodeType + + def __hash__(self) -> int: + return hash(self.var.name) + + GraphvizNodeKwargs = dict[str, Any] NodeFormatter = Callable[[TensorVariable], GraphvizNodeKwargs] @@ -265,31 +303,26 @@ def make_compute_graph( def _make_node( self, - var_name, - graph, + node: NodeMeta, *, node_formatters: NodeTypeFormatterMapping, - nx=False, - cluster=False, + add_node: Callable[[str, ...], None], + cluster: bool = False, formatting: str = "plain", ): """Attaches the given variable to a graphviz or networkx Digraph""" - v = self.model[var_name] - - node_type = get_node_type(var_name, self.model) - node_formatter = node_formatters[node_type] - - kwargs = node_formatter(v) + node_formatter = node_formatters[node.node_type] + kwargs = node_formatter(node.var) if cluster: kwargs["cluster"] = cluster - if nx: - graph.add_node(var_name.replace(":", "&"), **kwargs) - else: - graph.node(var_name.replace(":", "&"), **kwargs) + add_node(node.var.name.replace(":", "&"), **kwargs) - def get_plates(self, var_names: Iterable[VarName] | None = None) -> dict[str, set[VarName]]: + def get_plates( + self, + var_names: Iterable[VarName] | None = None, + ) -> dict[PlateMeta, set[NodeMeta]]: """Rough but surprisingly accurate plate detection. Just groups by the shape of the underlying distribution. Will be wrong @@ -302,33 +335,50 @@ def get_plates(self, var_names: Iterable[VarName] | None = None) -> dict[str, se """ plates = defaultdict(set) - # TODO: Evaluate all RV shapes and dim_length at once. - # This should help to find discrepancies, and - # avoids unnecessary function compiles for deetermining labels. + # TODO: Evaluate all RV shapes at once + # This should help find discrepencies, and + # avoids unnecessary function compiles for determining labels. + dim_lengths: dict[str, int] = { + name: fast_eval(value).item() for name, value in self.model.dim_lengths.items() + } for var_name in self.vars_to_plot(var_names): v = self.model[var_name] - shape: Sequence[int] = fast_eval(v.shape) - dim_labels = [] + shape: tuple[int, ...] = tuple(fast_eval(v.shape)) if var_name in self.model.named_vars_to_dims: # The RV is associated with `dims` information. + names = [] + sizes = [] for d, dname in enumerate(self.model.named_vars_to_dims[var_name]): - if dname is None: - # Unnamed dimension in a `dims` tuple! - dlen = shape[d] - dname = f"{var_name}_dim{d}" - else: - dlen = fast_eval(self.model.dim_lengths[dname]) - dim_labels.append(f"{dname} ({dlen})") - plate_label = " x ".join(dim_labels) + names.append(dname) + sizes.append(dim_lengths.get(dname, shape[d])) + + plate_meta = PlateMeta( + names=tuple(names), + sizes=tuple(sizes), + ) else: # The RV has no `dims` information. - dim_labels = [str(x) for x in shape] - plate_label = " x ".join(map(str, shape)) - plates[plate_label].add(var_name) + sizes = tuple(shape) + plate_meta = PlateMeta( + names=tuple([var_name] * len(sizes)), + sizes=sizes, + dim_info=False, + ) + + v = self.model[var_name] + node_type = get_node_type(var_name, self.model) + var = NodeMeta(var=v, node_type=node_type) + plates[plate_meta].add(var) return dict(plates) + def edges(self, var_names: Iterable[VarName] | None = None): + for child, parents in self.make_compute_graph(var_names=var_names).items(): + # parents is a set of rv names that precede child rv nodes + for parent in parents: + yield child.replace(":", "&"), parent.replace(":", "&") + def make_graph( self, var_names: Iterable[VarName] | None = None, @@ -337,6 +387,7 @@ def make_graph( figsize=None, dpi=300, node_formatters: NodeTypeFormatterMapping | None = None, + include_size: bool = True, ): """Make graphviz Digraph of PyMC model @@ -357,26 +408,31 @@ def make_graph( node_formatters = update_node_formatters(node_formatters) graph = graphviz.Digraph(self.model.name) - for plate_label, all_var_names in self.get_plates(var_names).items(): - if plate_label: + for plate_meta, all_vars in self.get_plates(var_names).items(): + if plate_meta: # must be preceded by 'cluster' to get a box around it + plate_label = create_plate_label(plate_meta, include_size=include_size) with graph.subgraph(name="cluster" + plate_label) as sub: - for var_name in all_var_names: + for var in all_vars: self._make_node( - var_name, sub, formatting=formatting, node_formatters=node_formatters + node=var, + formatting=formatting, + node_formatters=node_formatters, + add_node=sub.node, ) # plate label goes bottom right sub.attr(label=plate_label, labeljust="r", labelloc="b", style="rounded") else: - for var_name in all_var_names: + for var in all_vars: self._make_node( - var_name, graph, formatting=formatting, node_formatters=node_formatters + node=var, + formatting=formatting, + node_formatters=node_formatters, + add_node=graph.node, ) - for child, parents in self.make_compute_graph(var_names=var_names).items(): - # parents is a set of rv names that precede child rv nodes - for parent in parents: - graph.edge(parent.replace(":", "&"), child.replace(":", "&")) + for child, parent in self.edges(var_names=var_names): + graph.edge(parent, child) if save is not None: width, height = (None, None) if figsize is None else figsize @@ -397,6 +453,7 @@ def make_networkx( var_names: Iterable[VarName] | None = None, formatting: str = "plain", node_formatters: NodeTypeFormatterMapping | None = None, + include_size: bool = True, ): """Make networkx Digraph of PyMC model @@ -417,20 +474,20 @@ def make_networkx( node_formatters = update_node_formatters(node_formatters) graphnetwork = networkx.DiGraph(name=self.model.name) - for plate_label, all_var_names in self.get_plates(var_names).items(): - if plate_label: + for plate_meta, all_vars in self.get_plates(var_names).items(): + if plate_meta: # # must be preceded by 'cluster' to get a box around it + plate_label = create_plate_label(plate_meta, include_size=include_size) subgraphnetwork = networkx.DiGraph(name="cluster" + plate_label, label=plate_label) - for var_name in all_var_names: + for var in all_vars: self._make_node( - var_name, - subgraphnetwork, - nx=True, + node=var, node_formatters=node_formatters, cluster="cluster" + plate_label, formatting=formatting, + add_node=subgraphnetwork.add_node, ) for sgn in subgraphnetwork.nodes: networkx.set_node_attributes( @@ -446,19 +503,17 @@ def make_networkx( networkx.set_node_attributes(graphnetwork, node_data) graphnetwork.graph["name"] = self.model.name else: - for var_name in all_var_names: + for var in all_vars: self._make_node( - var_name, - graphnetwork, - nx=True, + node=var, formatting=formatting, node_formatters=node_formatters, + add_node=graphnetwork.add_node, ) - for child, parents in self.make_compute_graph(var_names=var_names).items(): - # parents is a set of rv names that precede child rv nodes - for parent in parents: - graphnetwork.add_edge(parent.replace(":", "&"), child.replace(":", "&")) + for child, parents in self.edges(var_names=var_names): + graphnetwork.add_edge(parents, child) + return graphnetwork From f96502f6cd90fbb993fca4e911e154b188d274d4 Mon Sep 17 00:00:00 2001 From: Will Dean Date: Wed, 26 Jun 2024 10:56:20 +0200 Subject: [PATCH 02/17] remove dim info and add kwargs --- pymc/model_graph.py | 37 ++++++++++++++++++++++++------------- 1 file changed, 24 insertions(+), 13 deletions(-) diff --git a/pymc/model_graph.py b/pymc/model_graph.py index 3ec196ef2d..13f82ab5fe 100644 --- a/pymc/model_graph.py +++ b/pymc/model_graph.py @@ -44,18 +44,21 @@ class PlateMeta: names: tuple[str] sizes: tuple[int] - dim_info: bool = True def __hash__(self) -> int: return hash((self.names, self.sizes)) -def create_plate_label(plate_meta: PlateMeta, include_size: bool = True) -> str: +def create_plate_label( + var_name: str, + plate_meta: PlateMeta, + include_size: bool = True, +) -> str: def create_label(d: int, dname: str, dlen: int): - if plate_meta.dim_info: + if dname: label = f"{dname}" else: - label = f"{dname}_dim{d}" + label = f"{var_name}_dim{d}" if include_size: label = f"{label} ({dlen})" @@ -359,11 +362,9 @@ def get_plates( ) else: # The RV has no `dims` information. - sizes = tuple(shape) plate_meta = PlateMeta( - names=tuple([var_name] * len(sizes)), - sizes=sizes, - dim_info=False, + names=(), + sizes=tuple(shape), ) v = self.model[var_name] @@ -387,7 +388,7 @@ def make_graph( figsize=None, dpi=300, node_formatters: NodeTypeFormatterMapping | None = None, - include_size: bool = True, + include_shape_size: bool = True, ): """Make graphviz Digraph of PyMC model @@ -411,7 +412,7 @@ def make_graph( for plate_meta, all_vars in self.get_plates(var_names).items(): if plate_meta: # must be preceded by 'cluster' to get a box around it - plate_label = create_plate_label(plate_meta, include_size=include_size) + plate_label = create_plate_label(plate_meta, include_size=include_shape_size) with graph.subgraph(name="cluster" + plate_label) as sub: for var in all_vars: self._make_node( @@ -453,7 +454,7 @@ def make_networkx( var_names: Iterable[VarName] | None = None, formatting: str = "plain", node_formatters: NodeTypeFormatterMapping | None = None, - include_size: bool = True, + include_shape_size: bool = True, ): """Make networkx Digraph of PyMC model @@ -478,7 +479,7 @@ def make_networkx( if plate_meta: # # must be preceded by 'cluster' to get a box around it - plate_label = create_plate_label(plate_meta, include_size=include_size) + plate_label = create_plate_label(plate_meta, include_size=include_shape_size) subgraphnetwork = networkx.DiGraph(name="cluster" + plate_label, label=plate_label) for var in all_vars: @@ -523,6 +524,7 @@ def model_to_networkx( var_names: Iterable[VarName] | None = None, formatting: str = "plain", node_formatters: NodeTypeFormatterMapping | None = None, + include_shape_size: bool = True, ): """Produce a networkx Digraph from a PyMC model. @@ -548,6 +550,8 @@ def model_to_networkx( A dictionary mapping node types to functions that return a dictionary of node attributes. Check out the networkx documentation for more information how attributes are added to nodes: https://networkx.org/documentation/stable/reference/classes/generated/networkx.Graph.add_node.html + include_shape_size : bool + Include the shape size in the plate label. Default is True. Examples -------- @@ -596,7 +600,10 @@ def model_to_networkx( ) model = pm.modelcontext(model) return ModelGraph(model).make_networkx( - var_names=var_names, formatting=formatting, node_formatters=node_formatters + var_names=var_names, + formatting=formatting, + node_formatters=node_formatters, + include_shape_size=include_shape_size, ) @@ -609,6 +616,7 @@ def model_to_graphviz( figsize: tuple[int, int] | None = None, dpi: int = 300, node_formatters: NodeTypeFormatterMapping | None = None, + include_shape_size: bool = True, ): """Produce a graphviz Digraph from a PyMC model. @@ -640,6 +648,8 @@ def model_to_graphviz( A dictionary mapping node types to functions that return a dictionary of node attributes. Check out graphviz documentation for more information on available attributes. https://graphviz.org/docs/nodes/ + include_shape_size : bool + Include the shape size in the plate label. Default is True. Examples -------- @@ -701,4 +711,5 @@ def model_to_graphviz( figsize=figsize, dpi=dpi, node_formatters=node_formatters, + include_shape_size=include_shape_size, ) From 962eab81b29b9a3504259ae0bbf7895f2059d9cc Mon Sep 17 00:00:00 2001 From: Will Dean Date: Thu, 27 Jun 2024 08:25:19 +0200 Subject: [PATCH 03/17] wrap each plate in single class --- pymc/model_graph.py | 71 +++++++++++++++++++++++++++++++++------------ 1 file changed, 53 insertions(+), 18 deletions(-) diff --git a/pymc/model_graph.py b/pymc/model_graph.py index 13f82ab5fe..0c407f3bd9 100644 --- a/pymc/model_graph.py +++ b/pymc/model_graph.py @@ -17,6 +17,7 @@ from collections.abc import Callable, Iterable from dataclasses import dataclass from enum import Enum +from itertools import zip_longest from os import path from typing import Any @@ -45,7 +46,7 @@ class PlateMeta: names: tuple[str] sizes: tuple[int] - def __hash__(self) -> int: + def __hash__(self): return hash((self.names, self.sizes)) @@ -55,10 +56,10 @@ def create_plate_label( include_size: bool = True, ) -> str: def create_label(d: int, dname: str, dlen: int): - if dname: - label = f"{dname}" - else: - label = f"{var_name}_dim{d}" + if not dname: + return f"{dlen}" + + label = f"{dname}" if include_size: label = f"{label} ({dlen})" @@ -66,7 +67,7 @@ def create_label(d: int, dname: str, dlen: int): return label values = enumerate( - zip(plate_meta.names, plate_meta.sizes), + zip_longest(plate_meta.names, plate_meta.sizes, fillvalue=None), ) return " x ".join(create_label(d, dname, dlen) for d, (dname, dlen) in values) @@ -90,10 +91,16 @@ class NodeMeta: var: TensorVariable node_type: NodeType - def __hash__(self) -> int: + def __hash__(self): return hash(self.var.name) +@dataclass +class Plate: + meta: PlateMeta + variables: list[NodeMeta] + + GraphvizNodeKwargs = dict[str, Any] NodeFormatter = Callable[[TensorVariable], GraphvizNodeKwargs] @@ -325,7 +332,7 @@ def _make_node( def get_plates( self, var_names: Iterable[VarName] | None = None, - ) -> dict[PlateMeta, set[NodeMeta]]: + ) -> list[Plate]: """Rough but surprisingly accurate plate detection. Just groups by the shape of the underlying distribution. Will be wrong @@ -372,13 +379,33 @@ def get_plates( var = NodeMeta(var=v, node_type=node_type) plates[plate_meta].add(var) - return dict(plates) + return [ + Plate(meta=plate_meta, variables=list(variables)) + for plate_meta, variables in plates.items() + ] + + def edges( + self, + var_names: Iterable[VarName] | None = None, + ) -> list[tuple[VarName, VarName]]: + """Get edges between the variables in the model. + + Parameters + ---------- + var_names : iterable of str, optional + Subset of variables to be plotted that identify a subgraph with respect to the entire model graph - def edges(self, var_names: Iterable[VarName] | None = None): - for child, parents in self.make_compute_graph(var_names=var_names).items(): - # parents is a set of rv names that precede child rv nodes - for parent in parents: - yield child.replace(":", "&"), parent.replace(":", "&") + Returns + ------- + list of tuple + List of edges between the variables in the model. + + """ + return [ + (VarName(child.replace(":", "&")), VarName(parent.replace(":", "&"))) + for child, parents in self.make_compute_graph(var_names=var_names).items() + for parent in parents + ] def make_graph( self, @@ -409,10 +436,14 @@ def make_graph( node_formatters = update_node_formatters(node_formatters) graph = graphviz.Digraph(self.model.name) - for plate_meta, all_vars in self.get_plates(var_names).items(): + for plate in self.get_plates(var_names): + plate_meta = plate.meta + all_vars = plate.variables if plate_meta: # must be preceded by 'cluster' to get a box around it - plate_label = create_plate_label(plate_meta, include_size=include_shape_size) + plate_label = create_plate_label( + all_vars[0].var.name, plate_meta, include_size=include_shape_size + ) with graph.subgraph(name="cluster" + plate_label) as sub: for var in all_vars: self._make_node( @@ -475,11 +506,15 @@ def make_networkx( node_formatters = update_node_formatters(node_formatters) graphnetwork = networkx.DiGraph(name=self.model.name) - for plate_meta, all_vars in self.get_plates(var_names).items(): + for plate in self.get_plates(var_names): + plate_meta = plate.meta + all_vars = plate.variables if plate_meta: # # must be preceded by 'cluster' to get a box around it - plate_label = create_plate_label(plate_meta, include_size=include_shape_size) + plate_label = create_plate_label( + all_vars[0].var.name, plate_meta, include_size=include_shape_size + ) subgraphnetwork = networkx.DiGraph(name="cluster" + plate_label, label=plate_label) for var in all_vars: From 38428de4ad1ac77ca33c436dc60eb847b391bc63 Mon Sep 17 00:00:00 2001 From: Will Dean Date: Thu, 27 Jun 2024 08:35:02 +0200 Subject: [PATCH 04/17] no plate for scalars --- pymc/model_graph.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/pymc/model_graph.py b/pymc/model_graph.py index 0c407f3bd9..8063434487 100644 --- a/pymc/model_graph.py +++ b/pymc/model_graph.py @@ -439,7 +439,7 @@ def make_graph( for plate in self.get_plates(var_names): plate_meta = plate.meta all_vars = plate.variables - if plate_meta: + if plate_meta.names or plate_meta.sizes: # must be preceded by 'cluster' to get a box around it plate_label = create_plate_label( all_vars[0].var.name, plate_meta, include_size=include_shape_size @@ -509,7 +509,7 @@ def make_networkx( for plate in self.get_plates(var_names): plate_meta = plate.meta all_vars = plate.variables - if plate_meta: + if plate_meta.names or plate_meta.sizes: # # must be preceded by 'cluster' to get a box around it plate_label = create_plate_label( From f492a031f49c0e1fe26a2ace4e56e5c925eda1d7 Mon Sep 17 00:00:00 2001 From: Will Dean Date: Thu, 27 Jun 2024 12:29:24 +0200 Subject: [PATCH 05/17] pull out methods into functions --- pymc/model_graph.py | 316 +++++++++++++++++++++++--------------------- 1 file changed, 165 insertions(+), 151 deletions(-) diff --git a/pymc/model_graph.py b/pymc/model_graph.py index 8063434487..c6332dde26 100644 --- a/pymc/model_graph.py +++ b/pymc/model_graph.py @@ -49,6 +49,9 @@ class PlateMeta: def __hash__(self): return hash((self.names, self.sizes)) + def __bool__(self) -> bool: + return len(self.sizes) > 0 or len(self.names) > 0 + def create_plate_label( var_name: str, @@ -97,7 +100,7 @@ def __hash__(self): @dataclass class Plate: - meta: PlateMeta + meta: PlateMeta | None variables: list[NodeMeta] @@ -204,6 +207,24 @@ def update_node_formatters(node_formatters: NodeTypeFormatterMapping) -> NodeTyp return node_formatters +def _make_node( + node: NodeMeta, + *, + node_formatters: NodeTypeFormatterMapping, + add_node: Callable[[str, ...], None], + cluster: bool = False, + formatting: str = "plain", +): + """Attaches the given variable to a graphviz or networkx Digraph""" + node_formatter = node_formatters[node.node_type] + kwargs = node_formatter(node.var) + + if cluster: + kwargs["cluster"] = cluster + + add_node(node.var.name.replace(":", "&"), **kwargs) + + class ModelGraph: def __init__(self, model): self.model = model @@ -311,24 +332,6 @@ def make_compute_graph( return input_map - def _make_node( - self, - node: NodeMeta, - *, - node_formatters: NodeTypeFormatterMapping, - add_node: Callable[[str, ...], None], - cluster: bool = False, - formatting: str = "plain", - ): - """Attaches the given variable to a graphviz or networkx Digraph""" - node_formatter = node_formatters[node.node_type] - kwargs = node_formatter(node.var) - - if cluster: - kwargs["cluster"] = cluster - - add_node(node.var.name.replace(":", "&"), **kwargs) - def get_plates( self, var_names: Iterable[VarName] | None = None, @@ -380,7 +383,7 @@ def get_plates( plates[plate_meta].add(var) return [ - Plate(meta=plate_meta, variables=list(variables)) + Plate(meta=plate_meta if plate_meta else None, variables=list(variables)) for plate_meta, variables in plates.items() ] @@ -407,150 +410,153 @@ def edges( for parent in parents ] - def make_graph( - self, - var_names: Iterable[VarName] | None = None, - formatting: str = "plain", - save=None, - figsize=None, - dpi=300, - node_formatters: NodeTypeFormatterMapping | None = None, - include_shape_size: bool = True, - ): - """Make graphviz Digraph of PyMC model - Returns - ------- - graphviz.Digraph - """ - try: - import graphviz - except ImportError: - raise ImportError( - "This function requires the python library graphviz, along with binaries. " - "The easiest way to install all of this is by running\n\n" - "\tconda install -c conda-forge python-graphviz" - ) +def make_graph( + name: str, + plates: list[Plate], + edges: list[tuple[VarName, VarName]], + formatting: str = "plain", + save=None, + figsize=None, + dpi=300, + node_formatters: NodeTypeFormatterMapping | None = None, + include_shape_size: bool = True, +): + """Make graphviz Digraph of PyMC model - node_formatters = node_formatters or {} - node_formatters = update_node_formatters(node_formatters) - - graph = graphviz.Digraph(self.model.name) - for plate in self.get_plates(var_names): - plate_meta = plate.meta - all_vars = plate.variables - if plate_meta.names or plate_meta.sizes: - # must be preceded by 'cluster' to get a box around it - plate_label = create_plate_label( - all_vars[0].var.name, plate_meta, include_size=include_shape_size - ) - with graph.subgraph(name="cluster" + plate_label) as sub: - for var in all_vars: - self._make_node( - node=var, - formatting=formatting, - node_formatters=node_formatters, - add_node=sub.node, - ) - # plate label goes bottom right - sub.attr(label=plate_label, labeljust="r", labelloc="b", style="rounded") - else: - for var in all_vars: - self._make_node( + Returns + ------- + graphviz.Digraph + """ + try: + import graphviz + except ImportError: + raise ImportError( + "This function requires the python library graphviz, along with binaries. " + "The easiest way to install all of this is by running\n\n" + "\tconda install -c conda-forge python-graphviz" + ) + + node_formatters = node_formatters or {} + node_formatters = update_node_formatters(node_formatters) + + graph = graphviz.Digraph(name) + for plate in plates: + if plate.meta: + # must be preceded by 'cluster' to get a box around it + plate_label = create_plate_label( + plate.variables[0].var.name, + plate.meta, + include_size=include_shape_size, + ) + with graph.subgraph(name="cluster" + plate_label) as sub: + for var in plate.variables: + _make_node( node=var, formatting=formatting, node_formatters=node_formatters, - add_node=graph.node, + add_node=sub.node, ) + # plate label goes bottom right + sub.attr(label=plate_label, labeljust="r", labelloc="b", style="rounded") + else: + for var in plate.variables: + _make_node( + node=var, + formatting=formatting, + node_formatters=node_formatters, + add_node=graph.node, + ) - for child, parent in self.edges(var_names=var_names): - graph.edge(parent, child) + for child, parent in edges: + graph.edge(parent, child) - if save is not None: - width, height = (None, None) if figsize is None else figsize - base, ext = path.splitext(save) - if ext: - ext = ext.replace(".", "") - else: - ext = "png" - graph_c = graph.copy() - graph_c.graph_attr.update(size=f"{width},{height}!") - graph_c.graph_attr.update(dpi=str(dpi)) - graph_c.render(filename=base, format=ext, cleanup=True) + if save is not None: + width, height = (None, None) if figsize is None else figsize + base, ext = path.splitext(save) + if ext: + ext = ext.replace(".", "") + else: + ext = "png" + graph_c = graph.copy() + graph_c.graph_attr.update(size=f"{width},{height}!") + graph_c.graph_attr.update(dpi=str(dpi)) + graph_c.render(filename=base, format=ext, cleanup=True) - return graph + return graph - def make_networkx( - self, - var_names: Iterable[VarName] | None = None, - formatting: str = "plain", - node_formatters: NodeTypeFormatterMapping | None = None, - include_shape_size: bool = True, - ): - """Make networkx Digraph of PyMC model - Returns - ------- - networkx.Digraph - """ - try: - import networkx - except ImportError: - raise ImportError( - "This function requires the python library networkx, along with binaries. " - "The easiest way to install all of this is by running\n\n" - "\tconda install networkx" - ) +def make_networkx( + name: str, + plates: list[Plate], + edges: list[tuple[VarName, VarName]], + formatting: str = "plain", + node_formatters: NodeTypeFormatterMapping | None = None, + include_shape_size: bool = True, +): + """Make networkx Digraph of PyMC model - node_formatters = node_formatters or {} - node_formatters = update_node_formatters(node_formatters) + Returns + ------- + networkx.Digraph + """ + try: + import networkx + except ImportError: + raise ImportError( + "This function requires the python library networkx, along with binaries. " + "The easiest way to install all of this is by running\n\n" + "\tconda install networkx" + ) - graphnetwork = networkx.DiGraph(name=self.model.name) - for plate in self.get_plates(var_names): - plate_meta = plate.meta - all_vars = plate.variables - if plate_meta.names or plate_meta.sizes: - # # must be preceded by 'cluster' to get a box around it + node_formatters = node_formatters or {} + node_formatters = update_node_formatters(node_formatters) - plate_label = create_plate_label( - all_vars[0].var.name, plate_meta, include_size=include_shape_size - ) - subgraphnetwork = networkx.DiGraph(name="cluster" + plate_label, label=plate_label) + graphnetwork = networkx.DiGraph(name=name) + for plate in plates: + if plate.meta: + # # must be preceded by 'cluster' to get a box around it - for var in all_vars: - self._make_node( - node=var, - node_formatters=node_formatters, - cluster="cluster" + plate_label, - formatting=formatting, - add_node=subgraphnetwork.add_node, - ) - for sgn in subgraphnetwork.nodes: - networkx.set_node_attributes( - subgraphnetwork, - {sgn: {"labeljust": "r", "labelloc": "b", "style": "rounded"}}, - ) - node_data = { - e[0]: e[1] - for e in graphnetwork.nodes(data=True) & subgraphnetwork.nodes(data=True) - } - - graphnetwork = networkx.compose(graphnetwork, subgraphnetwork) - networkx.set_node_attributes(graphnetwork, node_data) - graphnetwork.graph["name"] = self.model.name - else: - for var in all_vars: - self._make_node( - node=var, - formatting=formatting, - node_formatters=node_formatters, - add_node=graphnetwork.add_node, - ) + plate_label = create_plate_label( + plate.variables[0].var.name, + plate.meta, + include_size=include_shape_size, + ) + subgraphnetwork = networkx.DiGraph(name="cluster" + plate_label, label=plate_label) + + for var in plate.variables: + _make_node( + node=var, + node_formatters=node_formatters, + cluster="cluster" + plate_label, + formatting=formatting, + add_node=subgraphnetwork.add_node, + ) + for sgn in subgraphnetwork.nodes: + networkx.set_node_attributes( + subgraphnetwork, + {sgn: {"labeljust": "r", "labelloc": "b", "style": "rounded"}}, + ) + node_data = { + e[0]: e[1] for e in graphnetwork.nodes(data=True) & subgraphnetwork.nodes(data=True) + } + + graphnetwork = networkx.compose(graphnetwork, subgraphnetwork) + networkx.set_node_attributes(graphnetwork, node_data) + graphnetwork.graph["name"] = name + else: + for var in plate.variables: + _make_node( + node=var, + formatting=formatting, + node_formatters=node_formatters, + add_node=graphnetwork.add_node, + ) - for child, parents in self.edges(var_names=var_names): - graphnetwork.add_edge(parents, child) + for child, parents in edges: + graphnetwork.add_edge(parents, child) - return graphnetwork + return graphnetwork def model_to_networkx( @@ -633,9 +639,13 @@ def model_to_networkx( UserWarning, stacklevel=2, ) + model = pm.modelcontext(model) - return ModelGraph(model).make_networkx( - var_names=var_names, + graph = ModelGraph(model) + return make_networkx( + name=model.name, + plates=graph.get_plates(var_names=var_names), + edges=graph.edges(var_names=var_names), formatting=formatting, node_formatters=node_formatters, include_shape_size=include_shape_size, @@ -738,9 +748,13 @@ def model_to_graphviz( UserWarning, stacklevel=2, ) + model = pm.modelcontext(model) - return ModelGraph(model).make_graph( - var_names=var_names, + graph = ModelGraph(model) + return make_graph( + model.name, + plates=graph.get_plates(var_names=var_names), + edges=graph.edges(var_names=var_names), formatting=formatting, save=save, figsize=figsize, From d1b5390dbe75eeb40048528690f2fb14e517ef25 Mon Sep 17 00:00:00 2001 From: Will Dean Date: Fri, 28 Jun 2024 09:42:13 +0200 Subject: [PATCH 06/17] change name and loop over at begining --- pymc/model_graph.py | 35 +++++++++++++++++++---------------- 1 file changed, 19 insertions(+), 16 deletions(-) diff --git a/pymc/model_graph.py b/pymc/model_graph.py index c6332dde26..9292cc4f41 100644 --- a/pymc/model_graph.py +++ b/pymc/model_graph.py @@ -43,7 +43,7 @@ @dataclass class PlateMeta: - names: tuple[str] + names: tuple[str | None] sizes: tuple[int] def __hash__(self): @@ -352,12 +352,15 @@ def get_plates( # This should help find discrepencies, and # avoids unnecessary function compiles for determining labels. dim_lengths: dict[str, int] = { - name: fast_eval(value).item() for name, value in self.model.dim_lengths.items() + dim_name: fast_eval(value).item() for dim_name, value in self.model.dim_lengths.items() + } + var_shapes: dict[str, tuple[int]] = { + var_name: tuple(fast_eval(self.model[var_name].shape)) + for var_name in self.vars_to_plot(var_names) } for var_name in self.vars_to_plot(var_names): - v = self.model[var_name] - shape: tuple[int, ...] = tuple(fast_eval(v.shape)) + shape: tuple[int] = var_shapes[var_name] if var_name in self.model.named_vars_to_dims: # The RV is associated with `dims` information. names = [] @@ -420,7 +423,7 @@ def make_graph( figsize=None, dpi=300, node_formatters: NodeTypeFormatterMapping | None = None, - include_shape_size: bool = True, + include_dim_lengths: bool = True, ): """Make graphviz Digraph of PyMC model @@ -447,7 +450,7 @@ def make_graph( plate_label = create_plate_label( plate.variables[0].var.name, plate.meta, - include_size=include_shape_size, + include_size=include_dim_lengths, ) with graph.subgraph(name="cluster" + plate_label) as sub: for var in plate.variables: @@ -492,7 +495,7 @@ def make_networkx( edges: list[tuple[VarName, VarName]], formatting: str = "plain", node_formatters: NodeTypeFormatterMapping | None = None, - include_shape_size: bool = True, + include_dim_lengths: bool = True, ): """Make networkx Digraph of PyMC model @@ -520,7 +523,7 @@ def make_networkx( plate_label = create_plate_label( plate.variables[0].var.name, plate.meta, - include_size=include_shape_size, + include_size=include_dim_lengths, ) subgraphnetwork = networkx.DiGraph(name="cluster" + plate_label, label=plate_label) @@ -565,7 +568,7 @@ def model_to_networkx( var_names: Iterable[VarName] | None = None, formatting: str = "plain", node_formatters: NodeTypeFormatterMapping | None = None, - include_shape_size: bool = True, + include_dim_lengths: bool = True, ): """Produce a networkx Digraph from a PyMC model. @@ -591,8 +594,8 @@ def model_to_networkx( A dictionary mapping node types to functions that return a dictionary of node attributes. Check out the networkx documentation for more information how attributes are added to nodes: https://networkx.org/documentation/stable/reference/classes/generated/networkx.Graph.add_node.html - include_shape_size : bool - Include the shape size in the plate label. Default is True. + include_dim_lengths : bool + Include the dim length in the plate label. Default is True. Examples -------- @@ -648,7 +651,7 @@ def model_to_networkx( edges=graph.edges(var_names=var_names), formatting=formatting, node_formatters=node_formatters, - include_shape_size=include_shape_size, + include_dim_lengths=include_dim_lengths, ) @@ -661,7 +664,7 @@ def model_to_graphviz( figsize: tuple[int, int] | None = None, dpi: int = 300, node_formatters: NodeTypeFormatterMapping | None = None, - include_shape_size: bool = True, + include_dim_lengths: bool = True, ): """Produce a graphviz Digraph from a PyMC model. @@ -693,8 +696,8 @@ def model_to_graphviz( A dictionary mapping node types to functions that return a dictionary of node attributes. Check out graphviz documentation for more information on available attributes. https://graphviz.org/docs/nodes/ - include_shape_size : bool - Include the shape size in the plate label. Default is True. + include_dim_lengths : bool + Include the dim lengths in the plate label. Default is True. Examples -------- @@ -760,5 +763,5 @@ def model_to_graphviz( figsize=figsize, dpi=dpi, node_formatters=node_formatters, - include_shape_size=include_shape_size, + include_dim_lengths=include_dim_lengths, ) From 66675576356011daa9d57cb60b1b85fe4e32fe66 Mon Sep 17 00:00:00 2001 From: Will Dean Date: Fri, 28 Jun 2024 09:49:31 +0200 Subject: [PATCH 07/17] test none dim --- tests/test_model_graph.py | 31 ++++++++++++++++++++++++++++++- 1 file changed, 30 insertions(+), 1 deletion(-) diff --git a/tests/test_model_graph.py b/tests/test_model_graph.py index 718f38e53b..0abd7dc865 100644 --- a/tests/test_model_graph.py +++ b/tests/test_model_graph.py @@ -24,7 +24,15 @@ import pymc as pm from pymc.exceptions import ImputationWarning -from pymc.model_graph import ModelGraph, model_to_graphviz, model_to_networkx +from pymc.model_graph import ( + ModelGraph, + NodeMeta, + NodeType, + Plate, + PlateMeta, + model_to_graphviz, + model_to_networkx, +) def school_model(): @@ -473,3 +481,24 @@ def test_custom_node_formatting_graphviz(simple_model): ] ) assert body == items + + +def test_none_dim_in_plate_meta() -> None: + coords = { + "obs": range(5), + } + with pm.Model(coords=coords) as model: + C = pt.as_tensor_variable( + np.ones((5, 3)), + name="C", + ) + pm.Deterministic("C", C, dims=("obs", None)) + + graph = ModelGraph(model) + + assert graph.get_plates() == [ + Plate( + meta=PlateMeta(names=("obs", None), sizes=(5, 3)), + variables=[NodeMeta(var=model["C"], node_type=NodeType.DETERMINISTIC)], + ), + ] From 559dc423e6670f129d9e5efc5b8132d796bcdf09 Mon Sep 17 00:00:00 2001 From: Will Dean Date: Fri, 28 Jun 2024 10:11:06 +0200 Subject: [PATCH 08/17] rename away from meta --- pymc/model_graph.py | 40 +++++++++++++++++++++------------------ tests/test_model_graph.py | 32 +++++++++++++++++++++++-------- 2 files changed, 46 insertions(+), 26 deletions(-) diff --git a/pymc/model_graph.py b/pymc/model_graph.py index 9292cc4f41..46ca70cce5 100644 --- a/pymc/model_graph.py +++ b/pymc/model_graph.py @@ -42,7 +42,7 @@ @dataclass -class PlateMeta: +class DimInfo: names: tuple[str | None] sizes: tuple[int] @@ -55,7 +55,7 @@ def __bool__(self) -> bool: def create_plate_label( var_name: str, - plate_meta: PlateMeta, + dim_info: DimInfo, include_size: bool = True, ) -> str: def create_label(d: int, dname: str, dlen: int): @@ -70,7 +70,7 @@ def create_label(d: int, dname: str, dlen: int): return label values = enumerate( - zip_longest(plate_meta.names, plate_meta.sizes, fillvalue=None), + zip_longest(dim_info.names, dim_info.sizes, fillvalue=None), ) return " x ".join(create_label(d, dname, dlen) for d, (dname, dlen) in values) @@ -90,7 +90,7 @@ class NodeType(str, Enum): @dataclass -class NodeMeta: +class NodeInfo: var: TensorVariable node_type: NodeType @@ -100,8 +100,8 @@ def __hash__(self): @dataclass class Plate: - meta: PlateMeta | None - variables: list[NodeMeta] + dim_info: DimInfo | None + variables: list[NodeInfo] GraphvizNodeKwargs = dict[str, Any] @@ -208,7 +208,7 @@ def update_node_formatters(node_formatters: NodeTypeFormatterMapping) -> NodeTyp def _make_node( - node: NodeMeta, + node: NodeInfo, *, node_formatters: NodeTypeFormatterMapping, add_node: Callable[[str, ...], None], @@ -369,25 +369,29 @@ def get_plates( names.append(dname) sizes.append(dim_lengths.get(dname, shape[d])) - plate_meta = PlateMeta( + dim_info = DimInfo( names=tuple(names), sizes=tuple(sizes), ) else: # The RV has no `dims` information. - plate_meta = PlateMeta( - names=(), + dim_size = len(shape) + dim_info = DimInfo( + names=tuple([None] * dim_size), sizes=tuple(shape), ) v = self.model[var_name] node_type = get_node_type(var_name, self.model) - var = NodeMeta(var=v, node_type=node_type) - plates[plate_meta].add(var) + var = NodeInfo(var=v, node_type=node_type) + plates[dim_info].add(var) return [ - Plate(meta=plate_meta if plate_meta else None, variables=list(variables)) - for plate_meta, variables in plates.items() + Plate( + dim_info=dim_info if dim_info else None, + variables=list(variables), + ) + for dim_info, variables in plates.items() ] def edges( @@ -445,11 +449,11 @@ def make_graph( graph = graphviz.Digraph(name) for plate in plates: - if plate.meta: + if plate.dim_info: # must be preceded by 'cluster' to get a box around it plate_label = create_plate_label( plate.variables[0].var.name, - plate.meta, + plate.dim_info, include_size=include_dim_lengths, ) with graph.subgraph(name="cluster" + plate_label) as sub: @@ -517,12 +521,12 @@ def make_networkx( graphnetwork = networkx.DiGraph(name=name) for plate in plates: - if plate.meta: + if plate.dim_info: # # must be preceded by 'cluster' to get a box around it plate_label = create_plate_label( plate.variables[0].var.name, - plate.meta, + plate.dim_info, include_size=include_dim_lengths, ) subgraphnetwork = networkx.DiGraph(name="cluster" + plate_label, label=plate_label) diff --git a/tests/test_model_graph.py b/tests/test_model_graph.py index 0abd7dc865..e884378d20 100644 --- a/tests/test_model_graph.py +++ b/tests/test_model_graph.py @@ -25,11 +25,11 @@ from pymc.exceptions import ImputationWarning from pymc.model_graph import ( + DimInfo, ModelGraph, - NodeMeta, + NodeInfo, NodeType, Plate, - PlateMeta, model_to_graphviz, model_to_networkx, ) @@ -483,22 +483,38 @@ def test_custom_node_formatting_graphviz(simple_model): assert body == items -def test_none_dim_in_plate_meta() -> None: +def test_none_dim_in_plate() -> None: coords = { "obs": range(5), } with pm.Model(coords=coords) as model: - C = pt.as_tensor_variable( + data = pt.as_tensor_variable( np.ones((5, 3)), - name="C", + name="data", ) - pm.Deterministic("C", C, dims=("obs", None)) + pm.Deterministic("C", data, dims=("obs", None)) graph = ModelGraph(model) assert graph.get_plates() == [ Plate( - meta=PlateMeta(names=("obs", None), sizes=(5, 3)), - variables=[NodeMeta(var=model["C"], node_type=NodeType.DETERMINISTIC)], + dim_info=DimInfo(names=("obs", None), sizes=(5, 3)), + variables=[NodeInfo(var=model["C"], node_type=NodeType.DETERMINISTIC)], ), ] + assert graph.edges() == [] + + +def test_shape_without_dims() -> None: + with pm.Model() as model: + pm.Normal("mu", shape=5) + + graph = ModelGraph(model) + + assert graph.get_plates() == [ + Plate( + dim_info=DimInfo(names=(None,), sizes=(5,)), + variables=[NodeInfo(var=model["mu"], node_type=NodeType.FREE_RV)], + ), + ] + assert graph.edges() == [] From aec7ae580ac9eea4dd53171bfa742f6aa9268a5a Mon Sep 17 00:00:00 2001 From: Will Dean Date: Fri, 28 Jun 2024 10:51:11 +0200 Subject: [PATCH 09/17] test for scalar case --- tests/test_model_graph.py | 16 ++++++++++++++++ 1 file changed, 16 insertions(+) diff --git a/tests/test_model_graph.py b/tests/test_model_graph.py index e884378d20..c5c4f4699b 100644 --- a/tests/test_model_graph.py +++ b/tests/test_model_graph.py @@ -518,3 +518,19 @@ def test_shape_without_dims() -> None: ), ] assert graph.edges() == [] + + +def test_scalars_have_no_dim_info() -> None: + with pm.Model() as model: + pm.Normal("x") + + graph = ModelGraph(model) + + assert graph.get_plates() == [ + Plate( + dim_info=None, + variables=[NodeInfo(var=model["x"], node_type=NodeType.FREE_RV)], + ) + ] + + assert graph.edges() == [] From 6d8b2ee7d2a836b2e6441b63fd36118ad84392ac Mon Sep 17 00:00:00 2001 From: Will Dean Date: Fri, 28 Jun 2024 12:03:58 +0200 Subject: [PATCH 10/17] dim info with empty tuples --- pymc/model_graph.py | 11 ++++++----- 1 file changed, 6 insertions(+), 5 deletions(-) diff --git a/pymc/model_graph.py b/pymc/model_graph.py index 46ca70cce5..8e26ac4007 100644 --- a/pymc/model_graph.py +++ b/pymc/model_graph.py @@ -17,7 +17,6 @@ from collections.abc import Callable, Iterable from dataclasses import dataclass from enum import Enum -from itertools import zip_longest from os import path from typing import Any @@ -46,6 +45,10 @@ class DimInfo: names: tuple[str | None] sizes: tuple[int] + def __post_init__(self) -> None: + if len(self.names) != len(self.sizes): + raise ValueError("The number of names and sizes must be equal.") + def __hash__(self): return hash((self.names, self.sizes)) @@ -69,9 +72,7 @@ def create_label(d: int, dname: str, dlen: int): return label - values = enumerate( - zip_longest(dim_info.names, dim_info.sizes, fillvalue=None), - ) + values = enumerate(zip(dim_info.names, dim_info.sizes)) return " x ".join(create_label(d, dname, dlen) for d, (dname, dlen) in values) @@ -388,7 +389,7 @@ def get_plates( return [ Plate( - dim_info=dim_info if dim_info else None, + dim_info=dim_info, variables=list(variables), ) for dim_info, variables in plates.items() From 382a573a30ebd0ccc4d48b464fb3816f0cfa19ec Mon Sep 17 00:00:00 2001 From: Will Dean Date: Fri, 28 Jun 2024 12:04:48 +0200 Subject: [PATCH 11/17] test square None dim case and change scalar expected --- tests/test_model_graph.py | 13 +++++++++---- 1 file changed, 9 insertions(+), 4 deletions(-) diff --git a/tests/test_model_graph.py b/tests/test_model_graph.py index c5c4f4699b..d671afc976 100644 --- a/tests/test_model_graph.py +++ b/tests/test_model_graph.py @@ -489,18 +489,23 @@ def test_none_dim_in_plate() -> None: } with pm.Model(coords=coords) as model: data = pt.as_tensor_variable( - np.ones((5, 3)), + np.ones((5, 5)), name="data", ) pm.Deterministic("C", data, dims=("obs", None)) + pm.Deterministic("D", data.T, dims=(None, "obs")) graph = ModelGraph(model) assert graph.get_plates() == [ Plate( - dim_info=DimInfo(names=("obs", None), sizes=(5, 3)), + dim_info=DimInfo(names=("obs", None), sizes=(5, 5)), variables=[NodeInfo(var=model["C"], node_type=NodeType.DETERMINISTIC)], ), + Plate( + dim_info=DimInfo(names=(None, "obs"), sizes=(5, 5)), + variables=[NodeInfo(var=model["D"], node_type=NodeType.DETERMINISTIC)], + ), ] assert graph.edges() == [] @@ -520,7 +525,7 @@ def test_shape_without_dims() -> None: assert graph.edges() == [] -def test_scalars_have_no_dim_info() -> None: +def test_scalars_dim_info() -> None: with pm.Model() as model: pm.Normal("x") @@ -528,7 +533,7 @@ def test_scalars_have_no_dim_info() -> None: assert graph.get_plates() == [ Plate( - dim_info=None, + dim_info=DimInfo(names=(), sizes=()), variables=[NodeInfo(var=model["x"], node_type=NodeType.FREE_RV)], ) ] From a2e9e60950bb7baaa801025eb4beb20bb34db6a4 Mon Sep 17 00:00:00 2001 From: Will Dean Date: Fri, 28 Jun 2024 12:10:49 +0200 Subject: [PATCH 12/17] change sizes to lengths --- pymc/model_graph.py | 20 ++++++++++---------- tests/test_model_graph.py | 8 ++++---- 2 files changed, 14 insertions(+), 14 deletions(-) diff --git a/pymc/model_graph.py b/pymc/model_graph.py index 8e26ac4007..7e898ca015 100644 --- a/pymc/model_graph.py +++ b/pymc/model_graph.py @@ -43,17 +43,17 @@ @dataclass class DimInfo: names: tuple[str | None] - sizes: tuple[int] + lengths: tuple[int] def __post_init__(self) -> None: - if len(self.names) != len(self.sizes): - raise ValueError("The number of names and sizes must be equal.") + if len(self.names) != len(self.lengths): + raise ValueError("The number of names and lengths must be equal.") def __hash__(self): - return hash((self.names, self.sizes)) + return hash((self.names, self.lengths)) def __bool__(self) -> bool: - return len(self.sizes) > 0 or len(self.names) > 0 + return len(self.lengths) > 0 or len(self.names) > 0 def create_plate_label( @@ -72,7 +72,7 @@ def create_label(d: int, dname: str, dlen: int): return label - values = enumerate(zip(dim_info.names, dim_info.sizes)) + values = enumerate(zip(dim_info.names, dim_info.lengths)) return " x ".join(create_label(d, dname, dlen) for d, (dname, dlen) in values) @@ -365,21 +365,21 @@ def get_plates( if var_name in self.model.named_vars_to_dims: # The RV is associated with `dims` information. names = [] - sizes = [] + lengths = [] for d, dname in enumerate(self.model.named_vars_to_dims[var_name]): names.append(dname) - sizes.append(dim_lengths.get(dname, shape[d])) + lengths.append(dim_lengths.get(dname, shape[d])) dim_info = DimInfo( names=tuple(names), - sizes=tuple(sizes), + lengths=tuple(lengths), ) else: # The RV has no `dims` information. dim_size = len(shape) dim_info = DimInfo( names=tuple([None] * dim_size), - sizes=tuple(shape), + lengths=tuple(shape), ) v = self.model[var_name] diff --git a/tests/test_model_graph.py b/tests/test_model_graph.py index d671afc976..ff6e31b9de 100644 --- a/tests/test_model_graph.py +++ b/tests/test_model_graph.py @@ -499,11 +499,11 @@ def test_none_dim_in_plate() -> None: assert graph.get_plates() == [ Plate( - dim_info=DimInfo(names=("obs", None), sizes=(5, 5)), + dim_info=DimInfo(names=("obs", None), lengths=(5, 5)), variables=[NodeInfo(var=model["C"], node_type=NodeType.DETERMINISTIC)], ), Plate( - dim_info=DimInfo(names=(None, "obs"), sizes=(5, 5)), + dim_info=DimInfo(names=(None, "obs"), lengths=(5, 5)), variables=[NodeInfo(var=model["D"], node_type=NodeType.DETERMINISTIC)], ), ] @@ -518,7 +518,7 @@ def test_shape_without_dims() -> None: assert graph.get_plates() == [ Plate( - dim_info=DimInfo(names=(None,), sizes=(5,)), + dim_info=DimInfo(names=(None,), lengths=(5,)), variables=[NodeInfo(var=model["mu"], node_type=NodeType.FREE_RV)], ), ] @@ -533,7 +533,7 @@ def test_scalars_dim_info() -> None: assert graph.get_plates() == [ Plate( - dim_info=DimInfo(names=(), sizes=()), + dim_info=DimInfo(names=(), lengths=()), variables=[NodeInfo(var=model["x"], node_type=NodeType.FREE_RV)], ) ] From 2411da06e833daeffc8d613f4b1ffe81ef6ac108 Mon Sep 17 00:00:00 2001 From: Will Dean Date: Sun, 30 Jun 2024 08:43:15 +0200 Subject: [PATCH 13/17] remove var_name parameter --- pymc/model_graph.py | 56 +++++++++++++++++++++++---------------------- 1 file changed, 29 insertions(+), 27 deletions(-) diff --git a/pymc/model_graph.py b/pymc/model_graph.py index 7e898ca015..f357d2579f 100644 --- a/pymc/model_graph.py +++ b/pymc/model_graph.py @@ -56,24 +56,29 @@ def __bool__(self) -> bool: return len(self.lengths) > 0 or len(self.names) > 0 -def create_plate_label( - var_name: str, +PlateLabelFunc = Callable[[DimInfo], str] + + +def create_plate_label_without_dim_length( dim_info: DimInfo, - include_size: bool = True, ) -> str: - def create_label(d: int, dname: str, dlen: int): - if not dname: - return f"{dlen}" + def create_label(dname: str | None, dlen: int): + return f"{dname}" if dname else f"{dlen}" - label = f"{dname}" + return " x ".join( + create_label(dname, dlen) for (dname, dlen) in zip(dim_info.names, dim_info.lengths) + ) - if include_size: - label = f"{label} ({dlen})" - return label +def create_plate_label_with_dim_length( + dim_info: DimInfo, +) -> str: + def create_label(dname: str | None, dlen: int): + return f"{dname} ({dlen})" if dname else f"{dlen}" - values = enumerate(zip(dim_info.names, dim_info.lengths)) - return " x ".join(create_label(d, dname, dlen) for d, (dname, dlen) in values) + return " x ".join( + create_label(dname, dlen) for (dname, dlen) in zip(dim_info.names, dim_info.lengths) + ) def fast_eval(var): @@ -101,7 +106,7 @@ def __hash__(self): @dataclass class Plate: - dim_info: DimInfo | None + dim_info: DimInfo variables: list[NodeInfo] @@ -428,7 +433,7 @@ def make_graph( figsize=None, dpi=300, node_formatters: NodeTypeFormatterMapping | None = None, - include_dim_lengths: bool = True, + create_plate_label: PlateLabelFunc = create_plate_label_with_dim_length, ): """Make graphviz Digraph of PyMC model @@ -452,11 +457,8 @@ def make_graph( for plate in plates: if plate.dim_info: # must be preceded by 'cluster' to get a box around it - plate_label = create_plate_label( - plate.variables[0].var.name, - plate.dim_info, - include_size=include_dim_lengths, - ) + plate_label = create_plate_label(plate.dim_info) + with graph.subgraph(name="cluster" + plate_label) as sub: for var in plate.variables: _make_node( @@ -500,7 +502,7 @@ def make_networkx( edges: list[tuple[VarName, VarName]], formatting: str = "plain", node_formatters: NodeTypeFormatterMapping | None = None, - include_dim_lengths: bool = True, + create_plate_label: PlateLabelFunc = create_plate_label_with_dim_length, ): """Make networkx Digraph of PyMC model @@ -525,11 +527,7 @@ def make_networkx( if plate.dim_info: # # must be preceded by 'cluster' to get a box around it - plate_label = create_plate_label( - plate.variables[0].var.name, - plate.dim_info, - include_size=include_dim_lengths, - ) + plate_label = create_plate_label(plate.dim_info) subgraphnetwork = networkx.DiGraph(name="cluster" + plate_label, label=plate_label) for var in plate.variables: @@ -656,7 +654,9 @@ def model_to_networkx( edges=graph.edges(var_names=var_names), formatting=formatting, node_formatters=node_formatters, - include_dim_lengths=include_dim_lengths, + create_plate_label=create_plate_label_with_dim_length + if include_dim_lengths + else create_plate_label_without_dim_length, ) @@ -768,5 +768,7 @@ def model_to_graphviz( figsize=figsize, dpi=dpi, node_formatters=node_formatters, - include_dim_lengths=include_dim_lengths, + create_plate_label=create_plate_label_with_dim_length + if include_dim_lengths + else create_plate_label_without_dim_length, ) From 633a8cca3c298e2e2df2360e3ebdee5785bb2eff Mon Sep 17 00:00:00 2001 From: Will Dean Date: Sun, 30 Jun 2024 10:06:52 +0200 Subject: [PATCH 14/17] workthrough mypy --- pymc/model_graph.py | 31 +++++++++++++++++++------------ 1 file changed, 19 insertions(+), 12 deletions(-) diff --git a/pymc/model_graph.py b/pymc/model_graph.py index f357d2579f..2ef124b467 100644 --- a/pymc/model_graph.py +++ b/pymc/model_graph.py @@ -18,7 +18,7 @@ from dataclasses import dataclass from enum import Enum from os import path -from typing import Any +from typing import Any, Protocol, cast from pytensor import function from pytensor.graph import Apply @@ -42,8 +42,8 @@ @dataclass class DimInfo: - names: tuple[str | None] - lengths: tuple[int] + names: tuple[str | None, ...] + lengths: tuple[int, ...] def __post_init__(self) -> None: if len(self.names) != len(self.lengths): @@ -213,22 +213,27 @@ def update_node_formatters(node_formatters: NodeTypeFormatterMapping) -> NodeTyp return node_formatters +class AddNode(Protocol): + def __call__(self, arg1: str, **kwargs: Any) -> None: ... + + def _make_node( node: NodeInfo, *, node_formatters: NodeTypeFormatterMapping, - add_node: Callable[[str, ...], None], - cluster: bool = False, + add_node: AddNode, + cluster: str | None = None, formatting: str = "plain", ): """Attaches the given variable to a graphviz or networkx Digraph""" node_formatter = node_formatters[node.node_type] kwargs = node_formatter(node.var) - if cluster: + if cluster is not None: kwargs["cluster"] = cluster - add_node(node.var.name.replace(":", "&"), **kwargs) + var_name: str = cast(str, node.var.name) + add_node(var_name.replace(":", "&"), **kwargs) class ModelGraph: @@ -360,13 +365,13 @@ def get_plates( dim_lengths: dict[str, int] = { dim_name: fast_eval(value).item() for dim_name, value in self.model.dim_lengths.items() } - var_shapes: dict[str, tuple[int]] = { + var_shapes: dict[str, tuple[int, ...]] = { var_name: tuple(fast_eval(self.model[var_name].shape)) for var_name in self.vars_to_plot(var_names) } for var_name in self.vars_to_plot(var_names): - shape: tuple[int] = var_shapes[var_name] + shape: tuple[int, ...] = var_shapes[var_name] if var_name in self.model.named_vars_to_dims: # The RV is associated with `dims` information. names = [] @@ -458,8 +463,9 @@ def make_graph( if plate.dim_info: # must be preceded by 'cluster' to get a box around it plate_label = create_plate_label(plate.dim_info) + plate_name = f"cluster{plate_label}" - with graph.subgraph(name="cluster" + plate_label) as sub: + with graph.subgraph(name=plate_name) as sub: for var in plate.variables: _make_node( node=var, @@ -528,13 +534,14 @@ def make_networkx( # # must be preceded by 'cluster' to get a box around it plate_label = create_plate_label(plate.dim_info) - subgraphnetwork = networkx.DiGraph(name="cluster" + plate_label, label=plate_label) + plate_name = f"cluster{plate_label}" + subgraphnetwork = networkx.DiGraph(name=plate_name, label=plate_label) for var in plate.variables: _make_node( node=var, node_formatters=node_formatters, - cluster="cluster" + plate_label, + cluster=plate_name, formatting=formatting, add_node=subgraphnetwork.add_node, ) From 950409dcb2df69b1593635bd1f7b6fda79ee857d Mon Sep 17 00:00:00 2001 From: Will Dean Date: Wed, 3 Jul 2024 12:49:53 +0200 Subject: [PATCH 15/17] use inline in generator --- pymc/model_graph.py | 12 ++++-------- 1 file changed, 4 insertions(+), 8 deletions(-) diff --git a/pymc/model_graph.py b/pymc/model_graph.py index 2ef124b467..52b04e4c87 100644 --- a/pymc/model_graph.py +++ b/pymc/model_graph.py @@ -62,22 +62,18 @@ def __bool__(self) -> bool: def create_plate_label_without_dim_length( dim_info: DimInfo, ) -> str: - def create_label(dname: str | None, dlen: int): - return f"{dname}" if dname else f"{dlen}" - return " x ".join( - create_label(dname, dlen) for (dname, dlen) in zip(dim_info.names, dim_info.lengths) + f"{dname}" if dname else f"{dlen}" + for (dname, dlen) in zip(dim_info.names, dim_info.lengths) ) def create_plate_label_with_dim_length( dim_info: DimInfo, ) -> str: - def create_label(dname: str | None, dlen: int): - return f"{dname} ({dlen})" if dname else f"{dlen}" - return " x ".join( - create_label(dname, dlen) for (dname, dlen) in zip(dim_info.names, dim_info.lengths) + f"{dname} ({dlen})" if dname else f"{dlen}" + for (dname, dlen) in zip(dim_info.names, dim_info.lengths) ) From b9bcf92a3c131d8edcda9245627b6a15e68ffc87 Mon Sep 17 00:00:00 2001 From: Will Dean Date: Wed, 3 Jul 2024 13:38:00 +0200 Subject: [PATCH 16/17] adjust previous tests --- pymc/model_graph.py | 6 ++ tests/test_model_graph.py | 176 +++++++++++++++++++++++++++++--------- 2 files changed, 140 insertions(+), 42 deletions(-) diff --git a/pymc/model_graph.py b/pymc/model_graph.py index 52b04e4c87..2f6d24e795 100644 --- a/pymc/model_graph.py +++ b/pymc/model_graph.py @@ -105,6 +105,12 @@ class Plate: dim_info: DimInfo variables: list[NodeInfo] + def __eq__(self, other) -> bool: + if not isinstance(other, Plate): + return False + + return self.dim_info == other.dim_info and set(self.variables) == set(other.variables) + GraphvizNodeKwargs = dict[str, Any] NodeFormatter = Callable[[TensorVariable], GraphvizNodeKwargs] diff --git a/tests/test_model_graph.py b/tests/test_model_graph.py index ff6e31b9de..5c2633f226 100644 --- a/tests/test_model_graph.py +++ b/tests/test_model_graph.py @@ -35,6 +35,10 @@ ) +def sort_plates(plates: list[Plate]) -> list[Plate]: + return sorted(plates, key=lambda x: x.dim_info.lengths) + + def school_model(): """ Schools model to use in testing model_to_networkx function @@ -130,13 +134,36 @@ def radon_model(): # of the model variables that the observations belong to: "log_radon": {"y_like"}, } - plates = { - "": {"b", "sigma_a", "sigma_y", "floor_measure_offset"}, - "3": {"gamma"}, - "85": {"eps_a"}, - "919": {"a", "mu_a", "y_like", "log_radon"}, - } - return model, compute_graph, plates + plates = [ + Plate( + dim_info=DimInfo(names=(), lengths=()), + variables=[ + NodeInfo(var=model["b"], node_type=NodeType.FREE_RV), + NodeInfo(var=model["sigma_a"], node_type=NodeType.FREE_RV), + NodeInfo(var=model["sigma_y"], node_type=NodeType.FREE_RV), + NodeInfo(var=model["floor_measure_offset"], node_type=NodeType.DATA), + ], + ), + Plate( + dim_info=DimInfo(names=(None,), lengths=(3,)), + variables=[NodeInfo(var=model["gamma"], node_type=NodeType.FREE_RV)], + ), + Plate( + dim_info=DimInfo(names=(None,), lengths=(85,)), + variables=[NodeInfo(var=model["eps_a"], node_type=NodeType.FREE_RV)], + ), + Plate( + dim_info=DimInfo(names=(None,), lengths=(919,)), + variables=[ + NodeInfo(var=model["a"], node_type=NodeType.DETERMINISTIC), + NodeInfo(var=model["mu_a"], node_type=NodeType.DETERMINISTIC), + NodeInfo(var=model["y_like"], node_type=NodeType.OBSERVED_RV), + NodeInfo(var=model["log_radon"], node_type=NodeType.DATA), + ], + ), + ] + + return model, compute_graph, sort_plates(plates) def model_with_imputations(): @@ -156,13 +183,25 @@ def model_with_imputations(): "L_observed": {"a"}, "L": {"L_unobserved", "L_observed"}, } - plates = { - "": {"a"}, - "2": {"L_unobserved"}, - "10": {"L_observed"}, - "12": {"L"}, - } - return model, compute_graph, plates + plates = [ + Plate( + dim_info=DimInfo(names=(), lengths=()), + variables=[NodeInfo(var=model["a"], node_type=NodeType.FREE_RV)], + ), + Plate( + dim_info=DimInfo(names=(None,), lengths=(2,)), + variables=[NodeInfo(var=model["L_unobserved"], node_type=NodeType.FREE_RV)], + ), + Plate( + dim_info=DimInfo(names=(None,), lengths=(10,)), + variables=[NodeInfo(var=model["L_observed"], node_type=NodeType.OBSERVED_RV)], + ), + Plate( + dim_info=DimInfo(names=(None,), lengths=(12,)), + variables=[NodeInfo(var=model["L"], node_type=NodeType.DETERMINISTIC)], + ), + ] + return model, compute_graph, sort_plates(plates) def model_with_dims(): @@ -188,15 +227,33 @@ def model_with_dims(): "L": {"tax revenue"}, "observed": {"L"}, } - plates = { - "1": {"economics"}, - "city (4)": {"population"}, - "year (3)": {"time"}, - "year (3) x city (4)": {"tax revenue"}, - "3 x 4": {"L", "observed"}, - } + plates = [ + Plate( + dim_info=DimInfo(names=(None,), lengths=(1,)), + variables=[NodeInfo(var=pmodel["economics"], node_type=NodeType.FREE_RV)], + ), + Plate( + dim_info=DimInfo(names=("city",), lengths=(4,)), + variables=[NodeInfo(var=pmodel["population"], node_type=NodeType.FREE_RV)], + ), + Plate( + dim_info=DimInfo(names=("year",), lengths=(3,)), + variables=[NodeInfo(var=pmodel["time"], node_type=NodeType.DATA)], + ), + Plate( + dim_info=DimInfo(names=("year", "city"), lengths=(3, 4)), + variables=[NodeInfo(var=pmodel["tax revenue"], node_type=NodeType.DETERMINISTIC)], + ), + Plate( + dim_info=DimInfo(names=(None, None), lengths=(3, 4)), + variables=[ + NodeInfo(var=pmodel["L"], node_type=NodeType.OBSERVED_RV), + NodeInfo(var=pmodel["observed"], node_type=NodeType.DATA), + ], + ), + ] - return pmodel, compute_graph, plates + return pmodel, compute_graph, sort_plates(plates) def model_unnamed_observed_node(): @@ -213,12 +270,24 @@ def model_unnamed_observed_node(): "mu": set(), "y": {"mu"}, } - plates = { - "": {"mu"}, - "4": {"y"}, - } + plates = [ + Plate( + dim_info=DimInfo( + names=(), + lengths=(), + ), + variables=[NodeInfo(var=model["mu"], node_type=NodeType.FREE_RV)], + ), + Plate( + dim_info=DimInfo( + names=(None,), + lengths=(4,), + ), + variables=[NodeInfo(var=model["y"], node_type=NodeType.OBSERVED_RV)], + ), + ] - return model, compute_graph, plates + return model, compute_graph, sort_plates(plates) def model_observation_dtype_casting(): @@ -235,9 +304,21 @@ def model_observation_dtype_casting(): "response": {"p"}, "data": {"response"}, } - plates = {"": {"p"}, "4": {"data", "response"}} + plates = [ + Plate( + dim_info=DimInfo(names=(), lengths=()), + variables=[NodeInfo(var=model["p"], node_type=NodeType.FREE_RV)], + ), + Plate( + dim_info=DimInfo(names=(None,), lengths=(4,)), + variables=[ + NodeInfo(var=model["data"], node_type=NodeType.DATA), + NodeInfo(var=model["response"], node_type=NodeType.OBSERVED_RV), + ], + ), + ] - return model, compute_graph, plates + return model, compute_graph, sort_plates(plates) def model_non_random_variable_rvs(): @@ -262,12 +343,21 @@ def model_non_random_variable_rvs(): "y": {"mu"}, "z": {"y"}, } - plates = { - "": {"mu", "y"}, - "5": {"z"}, - } + plates = [ + Plate( + dim_info=DimInfo(names=(), lengths=()), + variables=[ + NodeInfo(var=model["mu"], node_type=NodeType.FREE_RV), + NodeInfo(var=model["y"], node_type=NodeType.FREE_RV), + ], + ), + Plate( + dim_info=DimInfo(names=(None,), lengths=(5,)), + variables=[NodeInfo(var=model["z"], node_type=NodeType.OBSERVED_RV)], + ), + ] - return model, compute_graph, plates + return model, compute_graph, sort_plates(plates) class BaseModelGraphTest: @@ -296,14 +386,11 @@ def test_compute_graph(self): assert actual == expected def test_plates(self): - assert self.plates == self.model_graph.get_plates() + assert self.plates == sort_plates(self.model_graph.get_plates()) def test_graphviz(self): # just make sure everything runs without error - g = self.model_graph.make_graph() - for key in self.compute_graph: - assert key in g.source g = model_to_graphviz(self.model) for key in self.compute_graph: assert key in g.source @@ -354,10 +441,15 @@ def test_issue_6335_dims_containing_none(self): pm.Deterministic("n", data, dims=(None, "time")) mg = ModelGraph(pmodel) - plates_actual = mg.get_plates() - plates_expected = { - "n_dim0 (3) x time (5)": {"n"}, - } + plates_actual = sort_plates(mg.get_plates()) + plates_expected = sort_plates( + [ + Plate( + dim_info=DimInfo(names=(None, "time"), lengths=(3, 5)), + variables=[NodeInfo(var=pmodel["n"], node_type=NodeType.DETERMINISTIC)], + ), + ] + ) assert plates_actual == plates_expected From e30f6d9201ffe480c1e87fa9041cfada309cceb1 Mon Sep 17 00:00:00 2001 From: Will Dean Date: Wed, 3 Jul 2024 13:53:48 +0200 Subject: [PATCH 17/17] get rid of protocol --- pymc/model_graph.py | 7 +++---- 1 file changed, 3 insertions(+), 4 deletions(-) diff --git a/pymc/model_graph.py b/pymc/model_graph.py index 2f6d24e795..1c230fc5a1 100644 --- a/pymc/model_graph.py +++ b/pymc/model_graph.py @@ -18,7 +18,7 @@ from dataclasses import dataclass from enum import Enum from os import path -from typing import Any, Protocol, cast +from typing import Any, cast from pytensor import function from pytensor.graph import Apply @@ -215,8 +215,7 @@ def update_node_formatters(node_formatters: NodeTypeFormatterMapping) -> NodeTyp return node_formatters -class AddNode(Protocol): - def __call__(self, arg1: str, **kwargs: Any) -> None: ... +AddNode = Callable[[str, GraphvizNodeKwargs], None] def _make_node( @@ -235,7 +234,7 @@ def _make_node( kwargs["cluster"] = cluster var_name: str = cast(str, node.var.name) - add_node(var_name.replace(":", "&"), **kwargs) + add_node(var_name.replace(":", "&"), **kwargs) # type: ignore class ModelGraph: