Skip to content
Merged
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
316 changes: 165 additions & 151 deletions pymc/model_graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,9 @@ class PlateMeta:
def __hash__(self) -> int:
    """Hash the plate metadata by its dim names and sizes.

    Hashing on ``(names, sizes)`` lets the metadata be used as a mapping key
    when variables are grouped by plate.
    """
    return hash((self.names, self.sizes))

def __bool__(self) -> bool:
    """Return True when the plate has at least one size or one dim name.

    A metadata object with neither sizes nor names is falsy.
    """
    return bool(self.sizes or self.names)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

What is a plate without names?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

with pm.Model(): 
    pm.Normal("y", shape=3)

y has sizes but no dim names.
Currently creates Plate(meta=PlateMeta(names=(), sizes=(3,)), variables=[NodeMeta(var=y, node_type=...)])

I think there should be some test cases for this now that the logic is exposed; that will make the behavior much easier to confirm.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

names here are dim names

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

What happens for a Deterministic with dims=("test_dim", None)? Apparently we still allow None dims for things that are not RVs

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

That y should be names=(None,) ?

Copy link
Member

@ricardoV94 ricardoV94 Jun 27, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm thinking of pm.Deterministic("x", np.zeros((3, 3)), dims=("hello", None)) and pm.Deterministic("y", np.zeros((3, 3)), dims=(None, "hello")). We don't want to put those in the same plate because dims can't be repeated, so they are definitely different things?

Can we add a test for that?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Added a test. I had to wrap the data in as_tensor_variable; otherwise I'd get an error saying the data needs a name attribute.



def create_plate_label(
var_name: str,
Expand Down Expand Up @@ -97,7 +100,7 @@ def __hash__(self):

@dataclass
class Plate:
    """Group of graph nodes that are drawn together.

    ``meta`` holds the dim names and sizes shared by ``variables``. It is
    ``None`` for variables that have neither sizes nor dim names, i.e. nodes
    that are not placed inside any plate.
    """

    # Shared plate metadata, or None when the variables belong to no plate.
    meta: PlateMeta | None
    # Nodes that belong to this plate.
    variables: list[NodeMeta]


Expand Down Expand Up @@ -204,6 +207,24 @@ def update_node_formatters(node_formatters: NodeTypeFormatterMapping) -> NodeTyp
return node_formatters


def _make_node(
    node: NodeMeta,
    *,
    node_formatters: NodeTypeFormatterMapping,
    add_node: Callable[[str, ...], None],
    cluster: bool | str = False,
    formatting: str = "plain",
) -> None:
    """Attach the given variable to a graphviz or networkx Digraph.

    Parameters
    ----------
    node : NodeMeta
        Node to add; its ``node_type`` selects the formatter and its ``var``
        is passed to that formatter.
    node_formatters : NodeTypeFormatterMapping
        Mapping from node type to a callable returning the keyword arguments
        for ``add_node``.
    add_node : Callable
        Backend function that creates the node. It is called with the node
        name followed by the formatter's keyword arguments.
    cluster : bool or str, default False
        When truthy, forwarded to ``add_node`` as the ``cluster`` keyword.
        Callers building networkx graphs pass the cluster name here.
    formatting : str, default "plain"
        Not used by this function; kept so existing callers keep working.
    """
    format_node = node_formatters[node.node_type]
    node_kwargs = format_node(node.var)

    if cluster:
        node_kwargs["cluster"] = cluster

    # NOTE(review): ":" is swapped for "&" in the node name, presumably because
    # graphviz treats ":" as a port separator in node identifiers -- confirm.
    add_node(node.var.name.replace(":", "&"), **node_kwargs)


class ModelGraph:
def __init__(self, model):
self.model = model
Expand Down Expand Up @@ -311,24 +332,6 @@ def make_compute_graph(

return input_map

def _make_node(
    self,
    node: NodeMeta,
    *,
    node_formatters: NodeTypeFormatterMapping,
    add_node: Callable[[str, ...], None],
    cluster: bool | str = False,
    formatting: str = "plain",
) -> None:
    """Attach the given variable to a graphviz or networkx Digraph.

    The formatter registered for ``node.node_type`` turns ``node.var`` into
    keyword arguments, which are handed to ``add_node`` together with the
    node name. ``self`` and ``formatting`` are not read by this method.

    ``cluster`` may be a cluster name (str); when truthy it is forwarded to
    ``add_node`` as the ``cluster`` keyword.
    """
    formatter = node_formatters[node.node_type]
    kwargs = formatter(node.var)

    if cluster:
        kwargs["cluster"] = cluster

    # ":" in the variable name is replaced by "&" before it is used as node id.
    node_name = node.var.name.replace(":", "&")
    add_node(node_name, **kwargs)

def get_plates(
self,
var_names: Iterable[VarName] | None = None,
Expand Down Expand Up @@ -380,7 +383,7 @@ def get_plates(
plates[plate_meta].add(var)

return [
Plate(meta=plate_meta, variables=list(variables))
Plate(meta=plate_meta if plate_meta else None, variables=list(variables))
for plate_meta, variables in plates.items()
]

Expand All @@ -407,150 +410,153 @@ def edges(
for parent in parents
]

def make_graph(
self,
var_names: Iterable[VarName] | None = None,
formatting: str = "plain",
save=None,
figsize=None,
dpi=300,
node_formatters: NodeTypeFormatterMapping | None = None,
include_shape_size: bool = True,
):
"""Make graphviz Digraph of PyMC model

Returns
-------
graphviz.Digraph
"""
try:
import graphviz
except ImportError:
raise ImportError(
"This function requires the python library graphviz, along with binaries. "
"The easiest way to install all of this is by running\n\n"
"\tconda install -c conda-forge python-graphviz"
)
def make_graph(
name: str,
plates: list[Plate],
edges: list[tuple[VarName, VarName]],
formatting: str = "plain",
save=None,
figsize=None,
dpi=300,
node_formatters: NodeTypeFormatterMapping | None = None,
include_shape_size: bool = True,
):
"""Make graphviz Digraph of PyMC model

node_formatters = node_formatters or {}
node_formatters = update_node_formatters(node_formatters)

graph = graphviz.Digraph(self.model.name)
for plate in self.get_plates(var_names):
plate_meta = plate.meta
all_vars = plate.variables
if plate_meta.names or plate_meta.sizes:
# must be preceded by 'cluster' to get a box around it
plate_label = create_plate_label(
all_vars[0].var.name, plate_meta, include_size=include_shape_size
)
with graph.subgraph(name="cluster" + plate_label) as sub:
for var in all_vars:
self._make_node(
node=var,
formatting=formatting,
node_formatters=node_formatters,
add_node=sub.node,
)
# plate label goes bottom right
sub.attr(label=plate_label, labeljust="r", labelloc="b", style="rounded")
else:
for var in all_vars:
self._make_node(
Returns
-------
graphviz.Digraph
"""
try:
import graphviz
except ImportError:
raise ImportError(
"This function requires the python library graphviz, along with binaries. "
"The easiest way to install all of this is by running\n\n"
"\tconda install -c conda-forge python-graphviz"
)

node_formatters = node_formatters or {}
node_formatters = update_node_formatters(node_formatters)

graph = graphviz.Digraph(name)
for plate in plates:
if plate.meta:
# must be preceded by 'cluster' to get a box around it
plate_label = create_plate_label(
plate.variables[0].var.name,
plate.meta,
include_size=include_shape_size,
)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Should create_plate_label now take plate_formatters that among other things decides on whether to include_size?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yeah, I think that is fair. Where do you view that being exposed?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I exposed create_plate_label in both make_graph and make_networkx. However, I left it out of the model_to_graphviz and model_to_networkx functions.

If a user defines a Callable[[DimInfo], str] function, it can be used in the first two, more general functions.

with graph.subgraph(name="cluster" + plate_label) as sub:
for var in plate.variables:
_make_node(
node=var,
formatting=formatting,
node_formatters=node_formatters,
add_node=graph.node,
add_node=sub.node,
)
# plate label goes bottom right
sub.attr(label=plate_label, labeljust="r", labelloc="b", style="rounded")
else:
for var in plate.variables:
_make_node(
node=var,
formatting=formatting,
node_formatters=node_formatters,
add_node=graph.node,
)

for child, parent in self.edges(var_names=var_names):
graph.edge(parent, child)
for child, parent in edges:
graph.edge(parent, child)

if save is not None:
width, height = (None, None) if figsize is None else figsize
base, ext = path.splitext(save)
if ext:
ext = ext.replace(".", "")
else:
ext = "png"
graph_c = graph.copy()
graph_c.graph_attr.update(size=f"{width},{height}!")
graph_c.graph_attr.update(dpi=str(dpi))
graph_c.render(filename=base, format=ext, cleanup=True)
if save is not None:
width, height = (None, None) if figsize is None else figsize
base, ext = path.splitext(save)
if ext:
ext = ext.replace(".", "")
else:
ext = "png"
graph_c = graph.copy()
graph_c.graph_attr.update(size=f"{width},{height}!")
graph_c.graph_attr.update(dpi=str(dpi))
graph_c.render(filename=base, format=ext, cleanup=True)

return graph
return graph

def make_networkx(
self,
var_names: Iterable[VarName] | None = None,
formatting: str = "plain",
node_formatters: NodeTypeFormatterMapping | None = None,
include_shape_size: bool = True,
):
"""Make networkx Digraph of PyMC model

Returns
-------
networkx.Digraph
"""
try:
import networkx
except ImportError:
raise ImportError(
"This function requires the python library networkx, along with binaries. "
"The easiest way to install all of this is by running\n\n"
"\tconda install networkx"
)
def make_networkx(
name: str,
plates: list[Plate],
edges: list[tuple[VarName, VarName]],
formatting: str = "plain",
node_formatters: NodeTypeFormatterMapping | None = None,
include_shape_size: bool = True,
):
"""Make networkx Digraph of PyMC model

node_formatters = node_formatters or {}
node_formatters = update_node_formatters(node_formatters)
Returns
-------
networkx.Digraph
"""
try:
import networkx
except ImportError:
raise ImportError(
"This function requires the python library networkx, along with binaries. "
"The easiest way to install all of this is by running\n\n"
"\tconda install networkx"
)

graphnetwork = networkx.DiGraph(name=self.model.name)
for plate in self.get_plates(var_names):
plate_meta = plate.meta
all_vars = plate.variables
if plate_meta.names or plate_meta.sizes:
# # must be preceded by 'cluster' to get a box around it
node_formatters = node_formatters or {}
node_formatters = update_node_formatters(node_formatters)

plate_label = create_plate_label(
all_vars[0].var.name, plate_meta, include_size=include_shape_size
)
subgraphnetwork = networkx.DiGraph(name="cluster" + plate_label, label=plate_label)
graphnetwork = networkx.DiGraph(name=name)
for plate in plates:
if plate.meta:
# # must be preceded by 'cluster' to get a box around it

for var in all_vars:
self._make_node(
node=var,
node_formatters=node_formatters,
cluster="cluster" + plate_label,
formatting=formatting,
add_node=subgraphnetwork.add_node,
)
for sgn in subgraphnetwork.nodes:
networkx.set_node_attributes(
subgraphnetwork,
{sgn: {"labeljust": "r", "labelloc": "b", "style": "rounded"}},
)
node_data = {
e[0]: e[1]
for e in graphnetwork.nodes(data=True) & subgraphnetwork.nodes(data=True)
}

graphnetwork = networkx.compose(graphnetwork, subgraphnetwork)
networkx.set_node_attributes(graphnetwork, node_data)
graphnetwork.graph["name"] = self.model.name
else:
for var in all_vars:
self._make_node(
node=var,
formatting=formatting,
node_formatters=node_formatters,
add_node=graphnetwork.add_node,
)
plate_label = create_plate_label(
plate.variables[0].var.name,
plate.meta,
include_size=include_shape_size,
)
subgraphnetwork = networkx.DiGraph(name="cluster" + plate_label, label=plate_label)

for var in plate.variables:
_make_node(
node=var,
node_formatters=node_formatters,
cluster="cluster" + plate_label,
formatting=formatting,
add_node=subgraphnetwork.add_node,
)
for sgn in subgraphnetwork.nodes:
networkx.set_node_attributes(
subgraphnetwork,
{sgn: {"labeljust": "r", "labelloc": "b", "style": "rounded"}},
)
node_data = {
e[0]: e[1] for e in graphnetwork.nodes(data=True) & subgraphnetwork.nodes(data=True)
}

graphnetwork = networkx.compose(graphnetwork, subgraphnetwork)
networkx.set_node_attributes(graphnetwork, node_data)
graphnetwork.graph["name"] = name
else:
for var in plate.variables:
_make_node(
node=var,
formatting=formatting,
node_formatters=node_formatters,
add_node=graphnetwork.add_node,
)

for child, parents in self.edges(var_names=var_names):
graphnetwork.add_edge(parents, child)
for child, parents in edges:
graphnetwork.add_edge(parents, child)

return graphnetwork
return graphnetwork


def model_to_networkx(
Expand Down Expand Up @@ -633,9 +639,13 @@ def model_to_networkx(
UserWarning,
stacklevel=2,
)

model = pm.modelcontext(model)
return ModelGraph(model).make_networkx(
var_names=var_names,
graph = ModelGraph(model)
return make_networkx(
name=model.name,
plates=graph.get_plates(var_names=var_names),
edges=graph.edges(var_names=var_names),
formatting=formatting,
node_formatters=node_formatters,
include_shape_size=include_shape_size,
Expand Down Expand Up @@ -738,9 +748,13 @@ def model_to_graphviz(
UserWarning,
stacklevel=2,
)

model = pm.modelcontext(model)
return ModelGraph(model).make_graph(
var_names=var_names,
graph = ModelGraph(model)
return make_graph(
model.name,
plates=graph.get_plates(var_names=var_names),
edges=graph.edges(var_names=var_names),
formatting=formatting,
save=save,
figsize=figsize,
Expand Down