-
Notifications
You must be signed in to change notification settings - Fork 2.1k
Refactor model graph and allow suppressing dim lengths #7392
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from 4 commits
be7a7f5
f96502f
962eab8
38428de
f492a03
d1b5390
6667557
559dc42
aec7ae5
6d8b2ee
382a573
a2e9e60
2411da0
633a8cc
950409d
b9bcf92
e30f6d9
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -14,8 +14,10 @@ | |
| import warnings | ||
|
|
||
| from collections import defaultdict | ||
| from collections.abc import Callable, Iterable, Sequence | ||
| from collections.abc import Callable, Iterable | ||
| from dataclasses import dataclass | ||
| from enum import Enum | ||
| from itertools import zip_longest | ||
| from os import path | ||
| from typing import Any | ||
|
|
||
|
|
@@ -39,6 +41,37 @@ | |
| ) | ||
|
|
||
|
|
||
| @dataclass | ||
| class PlateMeta: | ||
| names: tuple[str] | ||
| sizes: tuple[int] | ||
|
|
||
| def __hash__(self): | ||
| return hash((self.names, self.sizes)) | ||
|
|
||
|
|
||
| def create_plate_label( | ||
| var_name: str, | ||
| plate_meta: PlateMeta, | ||
| include_size: bool = True, | ||
| ) -> str: | ||
| def create_label(d: int, dname: str, dlen: int): | ||
| if not dname: | ||
| return f"{dlen}" | ||
|
|
||
| label = f"{dname}" | ||
|
|
||
| if include_size: | ||
| label = f"{label} ({dlen})" | ||
|
|
||
| return label | ||
|
|
||
| values = enumerate( | ||
| zip_longest(plate_meta.names, plate_meta.sizes, fillvalue=None), | ||
| ) | ||
| return " x ".join(create_label(d, dname, dlen) for d, (dname, dlen) in values) | ||
|
|
||
|
|
||
| def fast_eval(var): | ||
| return function([], var, mode="FAST_COMPILE")() | ||
|
|
||
|
|
@@ -53,6 +86,21 @@ class NodeType(str, Enum): | |
| DATA = "Data" | ||
|
|
||
|
|
||
| @dataclass | ||
| class NodeMeta: | ||
| var: TensorVariable | ||
| node_type: NodeType | ||
|
|
||
| def __hash__(self): | ||
| return hash(self.var.name) | ||
|
|
||
|
|
||
| @dataclass | ||
| class Plate: | ||
| meta: PlateMeta | ||
| variables: list[NodeMeta] | ||
|
|
||
|
|
||
| GraphvizNodeKwargs = dict[str, Any] | ||
| NodeFormatter = Callable[[TensorVariable], GraphvizNodeKwargs] | ||
|
|
||
|
|
@@ -265,31 +313,26 @@ def make_compute_graph( | |
|
|
||
| def _make_node( | ||
| self, | ||
| var_name, | ||
| graph, | ||
| node: NodeMeta, | ||
| *, | ||
| node_formatters: NodeTypeFormatterMapping, | ||
| nx=False, | ||
| cluster=False, | ||
| add_node: Callable[[str, ...], None], | ||
| cluster: bool = False, | ||
| formatting: str = "plain", | ||
| ): | ||
| """Attaches the given variable to a graphviz or networkx Digraph""" | ||
| v = self.model[var_name] | ||
|
|
||
| node_type = get_node_type(var_name, self.model) | ||
| node_formatter = node_formatters[node_type] | ||
|
|
||
| kwargs = node_formatter(v) | ||
| node_formatter = node_formatters[node.node_type] | ||
| kwargs = node_formatter(node.var) | ||
|
|
||
| if cluster: | ||
| kwargs["cluster"] = cluster | ||
|
|
||
| if nx: | ||
| graph.add_node(var_name.replace(":", "&"), **kwargs) | ||
| else: | ||
| graph.node(var_name.replace(":", "&"), **kwargs) | ||
| add_node(node.var.name.replace(":", "&"), **kwargs) | ||
|
|
||
| def get_plates(self, var_names: Iterable[VarName] | None = None) -> dict[str, set[VarName]]: | ||
| def get_plates( | ||
| self, | ||
| var_names: Iterable[VarName] | None = None, | ||
| ) -> list[Plate]: | ||
| """Rough but surprisingly accurate plate detection. | ||
|
|
||
| Just groups by the shape of the underlying distribution. Will be wrong | ||
|
|
@@ -302,32 +345,67 @@ def get_plates(self, var_names: Iterable[VarName] | None = None) -> dict[str, se | |
| """ | ||
| plates = defaultdict(set) | ||
|
|
||
| # TODO: Evaluate all RV shapes and dim_length at once. | ||
| # This should help to find discrepancies, and | ||
| # avoids unnecessary function compiles for deetermining labels. | ||
| # TODO: Evaluate all RV shapes at once | ||
| # This should help find discrepencies, and | ||
| # avoids unnecessary function compiles for determining labels. | ||
| dim_lengths: dict[str, int] = { | ||
| name: fast_eval(value).item() for name, value in self.model.dim_lengths.items() | ||
| } | ||
|
|
||
| for var_name in self.vars_to_plot(var_names): | ||
| v = self.model[var_name] | ||
| shape: Sequence[int] = fast_eval(v.shape) | ||
| dim_labels = [] | ||
| shape: tuple[int, ...] = tuple(fast_eval(v.shape)) | ||
| if var_name in self.model.named_vars_to_dims: | ||
| # The RV is associated with `dims` information. | ||
| names = [] | ||
| sizes = [] | ||
| for d, dname in enumerate(self.model.named_vars_to_dims[var_name]): | ||
| if dname is None: | ||
| # Unnamed dimension in a `dims` tuple! | ||
| dlen = shape[d] | ||
| dname = f"{var_name}_dim{d}" | ||
williambdean marked this conversation as resolved.
Show resolved
Hide resolved
|
||
| else: | ||
| dlen = fast_eval(self.model.dim_lengths[dname]) | ||
| dim_labels.append(f"{dname} ({dlen})") | ||
| plate_label = " x ".join(dim_labels) | ||
| names.append(dname) | ||
| sizes.append(dim_lengths.get(dname, shape[d])) | ||
ricardoV94 marked this conversation as resolved.
Outdated
Show resolved
Hide resolved
|
||
|
|
||
| plate_meta = PlateMeta( | ||
| names=tuple(names), | ||
| sizes=tuple(sizes), | ||
| ) | ||
|
||
| else: | ||
| # The RV has no `dims` information. | ||
| dim_labels = [str(x) for x in shape] | ||
| plate_label = " x ".join(map(str, shape)) | ||
| plates[plate_label].add(var_name) | ||
| plate_meta = PlateMeta( | ||
| names=(), | ||
| sizes=tuple(shape), | ||
| ) | ||
|
|
||
| v = self.model[var_name] | ||
| node_type = get_node_type(var_name, self.model) | ||
| var = NodeMeta(var=v, node_type=node_type) | ||
| plates[plate_meta].add(var) | ||
|
|
||
| return [ | ||
| Plate(meta=plate_meta, variables=list(variables)) | ||
| for plate_meta, variables in plates.items() | ||
| ] | ||
|
|
||
| def edges( | ||
| self, | ||
| var_names: Iterable[VarName] | None = None, | ||
| ) -> list[tuple[VarName, VarName]]: | ||
| """Get edges between the variables in the model. | ||
|
|
||
| Parameters | ||
| ---------- | ||
| var_names : iterable of str, optional | ||
| Subset of variables to be plotted that identify a subgraph with respect to the entire model graph | ||
|
|
||
| Returns | ||
| ------- | ||
| list of tuple | ||
| List of edges between the variables in the model. | ||
|
|
||
| return dict(plates) | ||
| """ | ||
| return [ | ||
| (VarName(child.replace(":", "&")), VarName(parent.replace(":", "&"))) | ||
| for child, parents in self.make_compute_graph(var_names=var_names).items() | ||
| for parent in parents | ||
| ] | ||
|
|
||
| def make_graph( | ||
|
||
| self, | ||
|
|
@@ -337,6 +415,7 @@ def make_graph( | |
| figsize=None, | ||
| dpi=300, | ||
| node_formatters: NodeTypeFormatterMapping | None = None, | ||
| include_shape_size: bool = True, | ||
| ): | ||
| """Make graphviz Digraph of PyMC model | ||
|
|
||
|
|
@@ -357,26 +436,35 @@ def make_graph( | |
| node_formatters = update_node_formatters(node_formatters) | ||
|
|
||
| graph = graphviz.Digraph(self.model.name) | ||
| for plate_label, all_var_names in self.get_plates(var_names).items(): | ||
| if plate_label: | ||
| for plate in self.get_plates(var_names): | ||
| plate_meta = plate.meta | ||
| all_vars = plate.variables | ||
| if plate_meta.names or plate_meta.sizes: | ||
|
||
| # must be preceded by 'cluster' to get a box around it | ||
| plate_label = create_plate_label( | ||
| all_vars[0].var.name, plate_meta, include_size=include_shape_size | ||
| ) | ||
| with graph.subgraph(name="cluster" + plate_label) as sub: | ||
| for var_name in all_var_names: | ||
| for var in all_vars: | ||
| self._make_node( | ||
| var_name, sub, formatting=formatting, node_formatters=node_formatters | ||
| node=var, | ||
| formatting=formatting, | ||
| node_formatters=node_formatters, | ||
| add_node=sub.node, | ||
| ) | ||
| # plate label goes bottom right | ||
| sub.attr(label=plate_label, labeljust="r", labelloc="b", style="rounded") | ||
| else: | ||
| for var_name in all_var_names: | ||
| for var in all_vars: | ||
| self._make_node( | ||
| var_name, graph, formatting=formatting, node_formatters=node_formatters | ||
| node=var, | ||
| formatting=formatting, | ||
| node_formatters=node_formatters, | ||
| add_node=graph.node, | ||
| ) | ||
|
|
||
| for child, parents in self.make_compute_graph(var_names=var_names).items(): | ||
| # parents is a set of rv names that precede child rv nodes | ||
| for parent in parents: | ||
| graph.edge(parent.replace(":", "&"), child.replace(":", "&")) | ||
| for child, parent in self.edges(var_names=var_names): | ||
| graph.edge(parent, child) | ||
|
|
||
| if save is not None: | ||
| width, height = (None, None) if figsize is None else figsize | ||
|
|
@@ -397,6 +485,7 @@ def make_networkx( | |
| var_names: Iterable[VarName] | None = None, | ||
| formatting: str = "plain", | ||
| node_formatters: NodeTypeFormatterMapping | None = None, | ||
| include_shape_size: bool = True, | ||
| ): | ||
| """Make networkx Digraph of PyMC model | ||
|
|
||
|
|
@@ -417,20 +506,24 @@ def make_networkx( | |
| node_formatters = update_node_formatters(node_formatters) | ||
|
|
||
| graphnetwork = networkx.DiGraph(name=self.model.name) | ||
| for plate_label, all_var_names in self.get_plates(var_names).items(): | ||
| if plate_label: | ||
| for plate in self.get_plates(var_names): | ||
| plate_meta = plate.meta | ||
| all_vars = plate.variables | ||
| if plate_meta.names or plate_meta.sizes: | ||
| # # must be preceded by 'cluster' to get a box around it | ||
|
|
||
| plate_label = create_plate_label( | ||
| all_vars[0].var.name, plate_meta, include_size=include_shape_size | ||
| ) | ||
| subgraphnetwork = networkx.DiGraph(name="cluster" + plate_label, label=plate_label) | ||
|
|
||
| for var_name in all_var_names: | ||
| for var in all_vars: | ||
| self._make_node( | ||
| var_name, | ||
| subgraphnetwork, | ||
| nx=True, | ||
| node=var, | ||
| node_formatters=node_formatters, | ||
| cluster="cluster" + plate_label, | ||
| formatting=formatting, | ||
| add_node=subgraphnetwork.add_node, | ||
| ) | ||
| for sgn in subgraphnetwork.nodes: | ||
| networkx.set_node_attributes( | ||
|
|
@@ -446,19 +539,17 @@ def make_networkx( | |
| networkx.set_node_attributes(graphnetwork, node_data) | ||
| graphnetwork.graph["name"] = self.model.name | ||
| else: | ||
| for var_name in all_var_names: | ||
| for var in all_vars: | ||
| self._make_node( | ||
| var_name, | ||
| graphnetwork, | ||
| nx=True, | ||
| node=var, | ||
| formatting=formatting, | ||
| node_formatters=node_formatters, | ||
| add_node=graphnetwork.add_node, | ||
| ) | ||
|
|
||
| for child, parents in self.make_compute_graph(var_names=var_names).items(): | ||
| # parents is a set of rv names that precede child rv nodes | ||
| for parent in parents: | ||
| graphnetwork.add_edge(parent.replace(":", "&"), child.replace(":", "&")) | ||
| for child, parents in self.edges(var_names=var_names): | ||
| graphnetwork.add_edge(parents, child) | ||
|
|
||
| return graphnetwork | ||
|
|
||
|
|
||
|
|
@@ -468,6 +559,7 @@ def model_to_networkx( | |
| var_names: Iterable[VarName] | None = None, | ||
| formatting: str = "plain", | ||
| node_formatters: NodeTypeFormatterMapping | None = None, | ||
| include_shape_size: bool = True, | ||
| ): | ||
| """Produce a networkx Digraph from a PyMC model. | ||
|
|
||
|
|
@@ -493,6 +585,8 @@ def model_to_networkx( | |
| A dictionary mapping node types to functions that return a dictionary of node attributes. | ||
| Check out the networkx documentation for more information | ||
| how attributes are added to nodes: https://networkx.org/documentation/stable/reference/classes/generated/networkx.Graph.add_node.html | ||
| include_shape_size : bool | ||
| Include the shape size in the plate label. Default is True. | ||
|
|
||
| Examples | ||
| -------- | ||
|
|
@@ -541,7 +635,10 @@ def model_to_networkx( | |
| ) | ||
| model = pm.modelcontext(model) | ||
| return ModelGraph(model).make_networkx( | ||
| var_names=var_names, formatting=formatting, node_formatters=node_formatters | ||
| var_names=var_names, | ||
| formatting=formatting, | ||
| node_formatters=node_formatters, | ||
| include_shape_size=include_shape_size, | ||
| ) | ||
|
|
||
|
|
||
|
|
@@ -554,6 +651,7 @@ def model_to_graphviz( | |
| figsize: tuple[int, int] | None = None, | ||
| dpi: int = 300, | ||
| node_formatters: NodeTypeFormatterMapping | None = None, | ||
| include_shape_size: bool = True, | ||
| ): | ||
| """Produce a graphviz Digraph from a PyMC model. | ||
|
|
||
|
|
@@ -585,6 +683,8 @@ def model_to_graphviz( | |
| A dictionary mapping node types to functions that return a dictionary of node attributes. | ||
| Check out graphviz documentation for more information on available | ||
| attributes. https://graphviz.org/docs/nodes/ | ||
| include_shape_size : bool | ||
| Include the shape size in the plate label. Default is True. | ||
|
|
||
| Examples | ||
| -------- | ||
|
|
@@ -646,4 +746,5 @@ def model_to_graphviz( | |
| figsize=figsize, | ||
| dpi=dpi, | ||
| node_formatters=node_formatters, | ||
| include_shape_size=include_shape_size, | ||
| ) | ||
Uh oh!
There was an error while loading. Please reload this page.