-
Notifications
You must be signed in to change notification settings - Fork 2.2k
Refactor model graph and allow suppressing dim lengths #7392
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from 1 commit
be7a7f5
f96502f
962eab8
38428de
f492a03
d1b5390
6667557
559dc42
aec7ae5
6d8b2ee
382a573
a2e9e60
2411da0
633a8cc
950409d
b9bcf92
e30f6d9
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -49,6 +49,9 @@ class PlateMeta: | |
| def __hash__(self): | ||
| return hash((self.names, self.sizes)) | ||
|
|
||
| def __bool__(self) -> bool: | ||
| return len(self.sizes) > 0 or len(self.names) > 0 | ||
|
|
||
|
|
||
| def create_plate_label( | ||
| var_name: str, | ||
|
|
@@ -97,7 +100,7 @@ def __hash__(self): | |
|
|
||
| @dataclass | ||
| class Plate: | ||
| meta: PlateMeta | ||
| meta: PlateMeta | None | ||
| variables: list[NodeMeta] | ||
|
|
||
|
|
||
|
|
@@ -204,6 +207,24 @@ def update_node_formatters(node_formatters: NodeTypeFormatterMapping) -> NodeTyp | |
| return node_formatters | ||
|
|
||
|
|
||
| def _make_node( | ||
| node: NodeMeta, | ||
| *, | ||
| node_formatters: NodeTypeFormatterMapping, | ||
| add_node: Callable[[str, ...], None], | ||
| cluster: bool = False, | ||
| formatting: str = "plain", | ||
| ): | ||
| """Attaches the given variable to a graphviz or networkx Digraph""" | ||
| node_formatter = node_formatters[node.node_type] | ||
| kwargs = node_formatter(node.var) | ||
|
|
||
| if cluster: | ||
| kwargs["cluster"] = cluster | ||
|
|
||
| add_node(node.var.name.replace(":", "&"), **kwargs) | ||
|
|
||
|
|
||
| class ModelGraph: | ||
| def __init__(self, model): | ||
| self.model = model | ||
|
|
@@ -311,24 +332,6 @@ def make_compute_graph( | |
|
|
||
| return input_map | ||
|
|
||
| def _make_node( | ||
| self, | ||
| node: NodeMeta, | ||
| *, | ||
| node_formatters: NodeTypeFormatterMapping, | ||
| add_node: Callable[[str, ...], None], | ||
| cluster: bool = False, | ||
| formatting: str = "plain", | ||
| ): | ||
| """Attaches the given variable to a graphviz or networkx Digraph""" | ||
| node_formatter = node_formatters[node.node_type] | ||
| kwargs = node_formatter(node.var) | ||
|
|
||
| if cluster: | ||
| kwargs["cluster"] = cluster | ||
|
|
||
| add_node(node.var.name.replace(":", "&"), **kwargs) | ||
|
|
||
| def get_plates( | ||
| self, | ||
| var_names: Iterable[VarName] | None = None, | ||
|
|
@@ -380,7 +383,7 @@ def get_plates( | |
| plates[plate_meta].add(var) | ||
|
|
||
| return [ | ||
| Plate(meta=plate_meta, variables=list(variables)) | ||
| Plate(meta=plate_meta if plate_meta else None, variables=list(variables)) | ||
| for plate_meta, variables in plates.items() | ||
| ] | ||
|
|
||
|
|
@@ -407,150 +410,153 @@ def edges( | |
| for parent in parents | ||
| ] | ||
|
|
||
| def make_graph( | ||
| self, | ||
| var_names: Iterable[VarName] | None = None, | ||
| formatting: str = "plain", | ||
| save=None, | ||
| figsize=None, | ||
| dpi=300, | ||
| node_formatters: NodeTypeFormatterMapping | None = None, | ||
| include_shape_size: bool = True, | ||
| ): | ||
| """Make graphviz Digraph of PyMC model | ||
|
|
||
| Returns | ||
| ------- | ||
| graphviz.Digraph | ||
| """ | ||
| try: | ||
| import graphviz | ||
| except ImportError: | ||
| raise ImportError( | ||
| "This function requires the python library graphviz, along with binaries. " | ||
| "The easiest way to install all of this is by running\n\n" | ||
| "\tconda install -c conda-forge python-graphviz" | ||
| ) | ||
| def make_graph( | ||
| name: str, | ||
| plates: list[Plate], | ||
| edges: list[tuple[VarName, VarName]], | ||
| formatting: str = "plain", | ||
| save=None, | ||
| figsize=None, | ||
| dpi=300, | ||
| node_formatters: NodeTypeFormatterMapping | None = None, | ||
| include_shape_size: bool = True, | ||
williambdean marked this conversation as resolved.
Outdated
Show resolved
Hide resolved
|
||
| ): | ||
| """Make graphviz Digraph of PyMC model | ||
|
|
||
| node_formatters = node_formatters or {} | ||
| node_formatters = update_node_formatters(node_formatters) | ||
|
|
||
| graph = graphviz.Digraph(self.model.name) | ||
| for plate in self.get_plates(var_names): | ||
| plate_meta = plate.meta | ||
| all_vars = plate.variables | ||
| if plate_meta.names or plate_meta.sizes: | ||
| # must be preceded by 'cluster' to get a box around it | ||
| plate_label = create_plate_label( | ||
| all_vars[0].var.name, plate_meta, include_size=include_shape_size | ||
| ) | ||
| with graph.subgraph(name="cluster" + plate_label) as sub: | ||
| for var in all_vars: | ||
| self._make_node( | ||
| node=var, | ||
| formatting=formatting, | ||
| node_formatters=node_formatters, | ||
| add_node=sub.node, | ||
| ) | ||
| # plate label goes bottom right | ||
| sub.attr(label=plate_label, labeljust="r", labelloc="b", style="rounded") | ||
| else: | ||
| for var in all_vars: | ||
| self._make_node( | ||
| Returns | ||
| ------- | ||
| graphviz.Digraph | ||
| """ | ||
| try: | ||
| import graphviz | ||
| except ImportError: | ||
| raise ImportError( | ||
| "This function requires the python library graphviz, along with binaries. " | ||
| "The easiest way to install all of this is by running\n\n" | ||
| "\tconda install -c conda-forge python-graphviz" | ||
| ) | ||
|
|
||
| node_formatters = node_formatters or {} | ||
| node_formatters = update_node_formatters(node_formatters) | ||
|
|
||
| graph = graphviz.Digraph(name) | ||
| for plate in plates: | ||
| if plate.meta: | ||
| # must be preceded by 'cluster' to get a box around it | ||
| plate_label = create_plate_label( | ||
| plate.variables[0].var.name, | ||
| plate.meta, | ||
| include_size=include_shape_size, | ||
| ) | ||
|
||
| with graph.subgraph(name="cluster" + plate_label) as sub: | ||
| for var in plate.variables: | ||
| _make_node( | ||
| node=var, | ||
| formatting=formatting, | ||
| node_formatters=node_formatters, | ||
| add_node=graph.node, | ||
| add_node=sub.node, | ||
| ) | ||
| # plate label goes bottom right | ||
| sub.attr(label=plate_label, labeljust="r", labelloc="b", style="rounded") | ||
| else: | ||
| for var in plate.variables: | ||
| _make_node( | ||
| node=var, | ||
| formatting=formatting, | ||
| node_formatters=node_formatters, | ||
| add_node=graph.node, | ||
| ) | ||
|
|
||
| for child, parent in self.edges(var_names=var_names): | ||
| graph.edge(parent, child) | ||
| for child, parent in edges: | ||
| graph.edge(parent, child) | ||
|
|
||
| if save is not None: | ||
| width, height = (None, None) if figsize is None else figsize | ||
| base, ext = path.splitext(save) | ||
| if ext: | ||
| ext = ext.replace(".", "") | ||
| else: | ||
| ext = "png" | ||
| graph_c = graph.copy() | ||
| graph_c.graph_attr.update(size=f"{width},{height}!") | ||
| graph_c.graph_attr.update(dpi=str(dpi)) | ||
| graph_c.render(filename=base, format=ext, cleanup=True) | ||
| if save is not None: | ||
| width, height = (None, None) if figsize is None else figsize | ||
| base, ext = path.splitext(save) | ||
| if ext: | ||
| ext = ext.replace(".", "") | ||
| else: | ||
| ext = "png" | ||
| graph_c = graph.copy() | ||
| graph_c.graph_attr.update(size=f"{width},{height}!") | ||
| graph_c.graph_attr.update(dpi=str(dpi)) | ||
| graph_c.render(filename=base, format=ext, cleanup=True) | ||
|
|
||
| return graph | ||
| return graph | ||
|
|
||
| def make_networkx( | ||
| self, | ||
| var_names: Iterable[VarName] | None = None, | ||
| formatting: str = "plain", | ||
| node_formatters: NodeTypeFormatterMapping | None = None, | ||
| include_shape_size: bool = True, | ||
| ): | ||
| """Make networkx Digraph of PyMC model | ||
|
|
||
| Returns | ||
| ------- | ||
| networkx.Digraph | ||
| """ | ||
| try: | ||
| import networkx | ||
| except ImportError: | ||
| raise ImportError( | ||
| "This function requires the python library networkx, along with binaries. " | ||
| "The easiest way to install all of this is by running\n\n" | ||
| "\tconda install networkx" | ||
| ) | ||
| def make_networkx( | ||
| name: str, | ||
| plates: list[Plate], | ||
| edges: list[tuple[VarName, VarName]], | ||
| formatting: str = "plain", | ||
| node_formatters: NodeTypeFormatterMapping | None = None, | ||
| include_shape_size: bool = True, | ||
| ): | ||
| """Make networkx Digraph of PyMC model | ||
|
|
||
| node_formatters = node_formatters or {} | ||
| node_formatters = update_node_formatters(node_formatters) | ||
| Returns | ||
| ------- | ||
| networkx.Digraph | ||
| """ | ||
| try: | ||
| import networkx | ||
| except ImportError: | ||
| raise ImportError( | ||
| "This function requires the python library networkx, along with binaries. " | ||
| "The easiest way to install all of this is by running\n\n" | ||
| "\tconda install networkx" | ||
| ) | ||
|
|
||
| graphnetwork = networkx.DiGraph(name=self.model.name) | ||
| for plate in self.get_plates(var_names): | ||
| plate_meta = plate.meta | ||
| all_vars = plate.variables | ||
| if plate_meta.names or plate_meta.sizes: | ||
| # # must be preceded by 'cluster' to get a box around it | ||
| node_formatters = node_formatters or {} | ||
| node_formatters = update_node_formatters(node_formatters) | ||
|
|
||
| plate_label = create_plate_label( | ||
| all_vars[0].var.name, plate_meta, include_size=include_shape_size | ||
| ) | ||
| subgraphnetwork = networkx.DiGraph(name="cluster" + plate_label, label=plate_label) | ||
| graphnetwork = networkx.DiGraph(name=name) | ||
| for plate in plates: | ||
| if plate.meta: | ||
| # # must be preceded by 'cluster' to get a box around it | ||
|
|
||
| for var in all_vars: | ||
| self._make_node( | ||
| node=var, | ||
| node_formatters=node_formatters, | ||
| cluster="cluster" + plate_label, | ||
| formatting=formatting, | ||
| add_node=subgraphnetwork.add_node, | ||
| ) | ||
| for sgn in subgraphnetwork.nodes: | ||
| networkx.set_node_attributes( | ||
| subgraphnetwork, | ||
| {sgn: {"labeljust": "r", "labelloc": "b", "style": "rounded"}}, | ||
| ) | ||
| node_data = { | ||
| e[0]: e[1] | ||
| for e in graphnetwork.nodes(data=True) & subgraphnetwork.nodes(data=True) | ||
| } | ||
|
|
||
| graphnetwork = networkx.compose(graphnetwork, subgraphnetwork) | ||
| networkx.set_node_attributes(graphnetwork, node_data) | ||
| graphnetwork.graph["name"] = self.model.name | ||
| else: | ||
| for var in all_vars: | ||
| self._make_node( | ||
| node=var, | ||
| formatting=formatting, | ||
| node_formatters=node_formatters, | ||
| add_node=graphnetwork.add_node, | ||
| ) | ||
| plate_label = create_plate_label( | ||
| plate.variables[0].var.name, | ||
| plate.meta, | ||
| include_size=include_shape_size, | ||
| ) | ||
| subgraphnetwork = networkx.DiGraph(name="cluster" + plate_label, label=plate_label) | ||
|
|
||
| for var in plate.variables: | ||
| _make_node( | ||
| node=var, | ||
| node_formatters=node_formatters, | ||
| cluster="cluster" + plate_label, | ||
| formatting=formatting, | ||
| add_node=subgraphnetwork.add_node, | ||
| ) | ||
| for sgn in subgraphnetwork.nodes: | ||
| networkx.set_node_attributes( | ||
| subgraphnetwork, | ||
| {sgn: {"labeljust": "r", "labelloc": "b", "style": "rounded"}}, | ||
| ) | ||
| node_data = { | ||
| e[0]: e[1] for e in graphnetwork.nodes(data=True) & subgraphnetwork.nodes(data=True) | ||
| } | ||
|
|
||
| graphnetwork = networkx.compose(graphnetwork, subgraphnetwork) | ||
| networkx.set_node_attributes(graphnetwork, node_data) | ||
| graphnetwork.graph["name"] = name | ||
| else: | ||
| for var in plate.variables: | ||
| _make_node( | ||
| node=var, | ||
| formatting=formatting, | ||
| node_formatters=node_formatters, | ||
| add_node=graphnetwork.add_node, | ||
| ) | ||
|
|
||
| for child, parents in self.edges(var_names=var_names): | ||
| graphnetwork.add_edge(parents, child) | ||
| for child, parents in edges: | ||
| graphnetwork.add_edge(parents, child) | ||
|
|
||
| return graphnetwork | ||
| return graphnetwork | ||
|
|
||
|
|
||
| def model_to_networkx( | ||
|
|
@@ -633,9 +639,13 @@ def model_to_networkx( | |
| UserWarning, | ||
| stacklevel=2, | ||
| ) | ||
|
|
||
| model = pm.modelcontext(model) | ||
| return ModelGraph(model).make_networkx( | ||
| var_names=var_names, | ||
| graph = ModelGraph(model) | ||
| return make_networkx( | ||
| name=model.name, | ||
| plates=graph.get_plates(var_names=var_names), | ||
| edges=graph.edges(var_names=var_names), | ||
| formatting=formatting, | ||
| node_formatters=node_formatters, | ||
| include_shape_size=include_shape_size, | ||
|
|
@@ -738,9 +748,13 @@ def model_to_graphviz( | |
| UserWarning, | ||
| stacklevel=2, | ||
| ) | ||
|
|
||
| model = pm.modelcontext(model) | ||
| return ModelGraph(model).make_graph( | ||
| var_names=var_names, | ||
| graph = ModelGraph(model) | ||
| return make_graph( | ||
| model.name, | ||
| plates=graph.get_plates(var_names=var_names), | ||
| edges=graph.edges(var_names=var_names), | ||
| formatting=formatting, | ||
| save=save, | ||
| figsize=figsize, | ||
|
|
||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
What is a plate without names?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
y has sizes but no dim names.
Currently creates Plate(meta=PlateMeta(names=(), sizes=(5, )), variables=[NodeMeta(var=y, node_type=...)])
I think there should be some cases to test now that this logic is exposed. It will be much easier to confirm.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
names here are dim names
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
What happens for a Deterministic with `dims=("test_dim", None)`? Apparently we still allow None dims for things that are not RVs.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
That `y` should be `names=(None,)`?
Uh oh!
There was an error while loading. Please reload this page.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I'm thinking of `pm.Deterministic("x", np.zeros((3, 3)), dims=("hello", None))` and `pm.Deterministic("y", np.zeros((3, 3)), dims=(None, "hello"))`. We don't want to put those in the same plate because dims can't be repeated, so they are definitely different things?
Can we add a test for that?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Added a test. I had to wrap the data in `as_tensor_variable`, or I'd get an error saying the data needs a name attribute.