2 changes: 1 addition & 1 deletion pymc/backends/__init__.py
@@ -81,7 +81,7 @@

from pymc.backends.mcbackend import init_chain_adapters

-TraceOrBackend = BaseTrace | Backend
+TraceOrBackend: TypeAlias = BaseTrace | Backend
RunType: TypeAlias = Run
HAS_MCB = True
except ImportError:
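As context for the one-line change above, a minimal self-contained sketch (stand-in classes, hypothetical function) of why the explicit TypeAlias annotation helps: without it, a checker may treat the conditional assignment as an ordinary module-level variable rather than a type usable in annotations.

```python
from typing import TypeAlias


class BaseTrace:
    """Stand-in for the real BaseTrace class."""


class Backend:
    """Stand-in for the optional mcbackend Backend class."""


# The annotation marks this as a type alias rather than a runtime variable,
# so it can be used in signatures even though it is assigned conditionally.
TraceOrBackend: TypeAlias = BaseTrace | Backend


def use_backend(trace_backend: TraceOrBackend | None = None) -> None:
    ...
```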
6 changes: 3 additions & 3 deletions pymc/distributions/distribution.py
@@ -28,7 +28,7 @@
from pytensor import tensor as pt
from pytensor.compile.builders import OpFromGraph
from pytensor.graph import FunctionGraph, clone_replace, node_rewriter
-from pytensor.graph.basic import Node, Variable, io_toposort
+from pytensor.graph.basic import Apply, Variable, io_toposort

Member:
Only comment: it's a bit awkward to have variables called node typed as Apply instead of Node. I was thinking you could add an alias here, from pytensor.graph.basic import Apply as ApplyNode, to make the type hints a bit more obvious? Or maybe that's a rename we should consider at the PyTensor level, @ricardoV94. It would match the Variable -> TensorVariable pattern.
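If it helps to picture the suggestion, a minimal sketch; the surrounding class is hypothetical, only the import alias is the actual proposal:

```python
from pytensor.graph.basic import Apply as ApplyNode, Variable


class SomeSymbolicRV:
    """Hypothetical Op-like class standing in for e.g. SymbolicRandomVariable."""

    def update(self, node: ApplyNode) -> dict[Variable, Variable]:
        # `node` is still an Apply instance; the alias only makes the
        # annotation read as "node" without renaming anything in PyTensor.
        return {node.inputs[0]: node.outputs[0]}
```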

Member:
Node is a bit of a useless base class for Apply and Variable; the only thing the two share is get_parents. In the codebase we routinely call Apply(Node) instances "nodes", but never call the corresponding Variable(Node) instances "nodes" or "variable nodes". The shared base class gives us so little that we could remove it and keep only Variables and (Apply) nodes (although I understand why they wanted to call a Variable a node as well). Anyway, that should be a discussion on the PyTensor side; I don't think PyMC should invent new names here.
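A rough, simplified sketch of the relationship being described (not the actual PyTensor source):

```python
class Node:
    """Shared base class; get_parents is essentially the only common behaviour."""

    def get_parents(self):
        raise NotImplementedError


class Apply(Node):
    """An Op applied to inputs, producing outputs; what the codebase calls a "node"."""


class Variable(Node):
    """A graph variable; technically also a Node, which is where the naming gets awkward."""
```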

Member:
Yeah, I agree; I'll raise an issue on the PyTensor side about typing generally. @michaelosthege, please ignore my comment, the PR looks good to go.

Member Author:
Thanks for the explanation. I agree that things should be improved on the PyTensor side. The type hinting is a bit of a mess down there too.

from pytensor.graph.features import ReplaceValidate
from pytensor.graph.rewriting.basic import GraphRewriter, in2out
from pytensor.graph.utils import MetaType
@@ -421,7 +421,7 @@ def __init__(
kwargs.setdefault("strict", True)
super().__init__(*args, **kwargs)

-def update(self, node: Node) -> dict[Variable, Variable]:
+def update(self, node: Apply) -> dict[Variable, Variable]:
"""Symbolic update expression for input random state variables

Returns a dictionary with the symbolic expressions required for correct updating
@@ -430,7 +430,7 @@ def update(self, node: Node) -> dict[Variable, Variable]:
"""
return collect_default_updates_inner_fgraph(node)

-def batch_ndim(self, node: Node) -> int:
+def batch_ndim(self, node: Apply) -> int:
"""Number of dimensions of the distribution's batch shape."""
out_ndim = max(getattr(out.type, "ndim", 0) for out in node.outputs)
return out_ndim - self.ndim_supp
4 changes: 2 additions & 2 deletions pymc/distributions/mixture.py
@@ -18,7 +18,7 @@
import pytensor
import pytensor.tensor as pt

-from pytensor.graph.basic import Node, equal_computations
+from pytensor.graph.basic import Apply, equal_computations
from pytensor.tensor import TensorVariable
from pytensor.tensor.random.op import RandomVariable
from pytensor.tensor.random.utils import normalize_size_param
@@ -156,7 +156,7 @@ def _resize_components(cls, size, *components):

return [change_dist_size(component, size) for component in components]

-def update(self, node: Node):
+def update(self, node: Apply):
# Update for the internal mix_indexes RV
return {node.inputs[0]: node.outputs[0]}

14 changes: 9 additions & 5 deletions pymc/logprob/order.py
@@ -35,9 +35,11 @@
# SOFTWARE.


+from typing import cast

import pytensor.tensor as pt

-from pytensor.graph.basic import Node
+from pytensor.graph.basic import Apply
from pytensor.graph.fg import FunctionGraph
from pytensor.graph.rewriting.basic import node_rewriter
from pytensor.tensor.elemwise import Elemwise
@@ -72,15 +74,15 @@ class MeasurableMaxDiscrete(Max):


@node_rewriter([Max])
-def find_measurable_max(fgraph: FunctionGraph, node: Node) -> list[TensorVariable] | None:
+def find_measurable_max(fgraph: FunctionGraph, node: Apply) -> list[TensorVariable] | None:
rv_map_feature = getattr(fgraph, "preserve_rv_mappings", None)
if rv_map_feature is None:
return None # pragma: no cover

if isinstance(node.op, MeasurableMax):
return None # pragma: no cover

-base_var = node.inputs[0]
+base_var = cast(TensorVariable, node.inputs[0])

if base_var.owner is None:
return None
@@ -104,6 +106,7 @@ def find_measurable_max(fgraph: FunctionGraph, node: Node) -> list[TensorVariabl
return None

# distinguish measurable discrete and continuous (because logprob is different)
+measurable_max: Max
if base_var.owner.op.dtype.startswith("int"):
measurable_max = MeasurableMaxDiscrete(list(axis))
else:
@@ -173,7 +176,7 @@ class MeasurableDiscreteMaxNeg(Max):


@node_rewriter(tracks=[Max])
-def find_measurable_max_neg(fgraph: FunctionGraph, node: Node) -> list[TensorVariable] | None:
+def find_measurable_max_neg(fgraph: FunctionGraph, node: Apply) -> list[TensorVariable] | None:
rv_map_feature = getattr(fgraph, "preserve_rv_mappings", None)

if rv_map_feature is None:
@@ -182,7 +185,7 @@ def find_measurable_max_neg(fgraph: FunctionGraph, node: Node) -> list[TensorVar
if isinstance(node.op, MeasurableMaxNeg):
return None # pragma: no cover

-base_var = node.inputs[0]
+base_var = cast(TensorVariable, node.inputs[0])

# Min is the Max of the negation of the same distribution. Hence, op must be Elemwise
if not (base_var.owner is not None and isinstance(base_var.owner.op, Elemwise)):
@@ -213,6 +216,7 @@ def find_measurable_max_neg(fgraph: FunctionGraph, node: Node) -> list[TensorVar
return None

# distinguish measurable discrete and continuous (because logprob is different)
+measurable_min: Max
if base_rv.owner.op.dtype.startswith("int"):
measurable_min = MeasurableDiscreteMaxNeg(list(axis))
else:
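Two typing patterns recur in the order.py hunks above: typing.cast to narrow a graph input, and an up-front annotation for a variable assigned in both branches of an if/else. A small standalone sketch with hypothetical names:

```python
from typing import cast


class Base:
    pass


class DiscreteVariant(Base):
    pass


class ContinuousVariant(Base):
    pass


def widest(values: list[object]) -> Base:
    # cast() has no runtime effect; it only informs the checker, mirroring
    # `base_var = cast(TensorVariable, node.inputs[0])` in the diff.
    first = cast(Base, values[0])
    return first


def pick(discrete: bool) -> Base:
    # Declaring the type up front gives the branch-assigned variable one
    # declared type, mirroring the added `measurable_max: Max` annotation.
    chosen: Base
    if discrete:
        chosen = DiscreteVariant()
    else:
        chosen = ContinuousVariant()
    return chosen
```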
47 changes: 28 additions & 19 deletions pymc/printing.py
@@ -13,6 +13,8 @@
# limitations under the License.


+from functools import partial

from pytensor.compile import SharedVariable
from pytensor.graph.basic import Constant, walk
from pytensor.tensor.basic import TensorVariable, Variable
@@ -55,7 +57,7 @@ def str_for_dist(

if "latex" in formatting:
if print_name is not None:
-print_name = r"\text{" + _latex_escape(dist.name.strip("$")) + "}"
+print_name = r"\text{" + _latex_escape(print_name.strip("$")) + "}"

op_name = (
dist.owner.op._print_name[1]
@@ -96,17 +98,16 @@ def str_for_model(model: Model, formatting: str = "plain", include_params: bool
"""Make a human-readable string representation of Model, listing all random variables
and their distributions, optionally including parameter values."""

-kwargs = dict(formatting=formatting, include_params=include_params)
-free_rv_reprs = [str_for_dist(dist, **kwargs) for dist in model.free_RVs]
-observed_rv_reprs = [str_for_dist(rv, **kwargs) for rv in model.observed_RVs]
-det_reprs = [
-str_for_potential_or_deterministic(dist, **kwargs, dist_name="Deterministic")
-for dist in model.deterministics
-]
-potential_reprs = [
-str_for_potential_or_deterministic(pot, **kwargs, dist_name="Potential")
-for pot in model.potentials
-]
+# Wrap functions to avoid confusing typecheckers
+sfd = partial(str_for_dist, formatting=formatting, include_params=include_params)
+sfp = partial(
+str_for_potential_or_deterministic, formatting=formatting, include_params=include_params
+)
+
+free_rv_reprs = [sfd(dist) for dist in model.free_RVs]
+observed_rv_reprs = [sfd(rv) for rv in model.observed_RVs]
+det_reprs = [sfp(dist, dist_name="Deterministic") for dist in model.deterministics]
+potential_reprs = [sfp(pot, dist_name="Potential") for pot in model.potentials]

var_reprs = free_rv_reprs + det_reprs + observed_rv_reprs + potential_reprs

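The new comment states the intent: binding the shared keyword arguments with functools.partial, rather than splatting a kwargs dict, keeps the original signatures visible to the type checker. A short standalone sketch with hypothetical names:

```python
from functools import partial


def describe(item: str, *, formatting: str = "plain", include_params: bool = True) -> str:
    suffix = " (with params)" if include_params else ""
    return f"{item} [{formatting}]{suffix}"


# The checker still sees describe's full signature and can verify the
# remaining positional argument, unlike an untyped **kwargs dict.
sfd = partial(describe, formatting="latex", include_params=False)
print(sfd("alpha"))  # -> "alpha [latex]"
```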
@@ -162,6 +163,8 @@ def _str_for_input_var(var: Variable, formatting: str) -> str:
from pymc.distributions.distribution import SymbolicRandomVariable

def _is_potential_or_deterministic(var: Variable) -> bool:
+if not hasattr(var, "str_repr"):
+return False
try:
return var.str_repr.__func__.func is str_for_potential_or_deterministic
except AttributeError:
@@ -175,14 +178,15 @@ def _is_potential_or_deterministic(var: Variable) -> bool:
) or _is_potential_or_deterministic(var):
# show the names for RandomVariables, Deterministics, and Potentials, rather
# than the full expression
+assert isinstance(var, TensorVariable)
return _str_for_input_rv(var, formatting)
elif isinstance(var.owner.op, DimShuffle):
return _str_for_input_var(var.owner.inputs[0], formatting)
else:
return _str_for_expression(var, formatting)


-def _str_for_input_rv(var: Variable, formatting: str) -> str:
+def _str_for_input_rv(var: TensorVariable, formatting: str) -> str:
_str = (
var.name
if var.name is not None
@@ -221,12 +225,15 @@ def _expand(x):
if x.owner and (not isinstance(x.owner.op, RandomVariable | SymbolicRandomVariable)):
return reversed(x.owner.inputs)

-parents = [
-x
-for x in walk(nodes=var.owner.inputs, expand=_expand)
-if x.owner and isinstance(x.owner.op, RandomVariable | SymbolicRandomVariable)
-]
-names = [x.name for x in parents]
+parents = []
+names = []
+for x in walk(nodes=var.owner.inputs, expand=_expand):
+assert isinstance(x, Variable)
+if x.owner and isinstance(x.owner.op, RandomVariable | SymbolicRandomVariable):
+parents.append(x)
+xname = x.name
+assert xname is not None
+names.append(xname)

if "latex" in formatting:
return (
@@ -257,6 +264,8 @@ def _default_repr_pretty(obj: TensorVariable | Model, p, cycle):
"""Handy plug-in method to instruct IPython-like REPLs to use our str_repr above."""
# we know that our str_repr does not recurse, so we can ignore cycle
try:
+if not hasattr(obj, "str_repr"):
+raise AttributeError
output = obj.str_repr()
# Find newlines and replace them with p.break_()
# (see IPython.lib.pretty._repr_pprint)
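The printing.py additions guard attribute access with hasattr and use assert isinstance purely for narrowing. A minimal sketch of both patterns (hypothetical names):

```python
from typing import Any


def pretty(obj: Any) -> str:
    # Guard: only objects that actually carry a str_repr method are handled here.
    if not hasattr(obj, "str_repr"):
        raise AttributeError("object has no str_repr")
    return obj.str_repr()


def shout(value: object) -> str:
    # assert isinstance narrows `value` to str for the checker and fails
    # loudly at runtime if the assumption is ever violated.
    assert isinstance(value, str)
    return value.upper()
```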
23 changes: 12 additions & 11 deletions pymc/pytensorf.py
@@ -14,6 +14,7 @@
import warnings

from collections.abc import Callable, Generator, Iterable, Sequence
+from typing import cast

import numpy as np
import pandas as pd
@@ -29,7 +30,6 @@
from pytensor.graph.basic import (
Apply,
Constant,
-Node,
Variable,
clone_get_equiv,
graph_inputs,
@@ -208,8 +208,8 @@ def replace_vars_in_graphs(
"""
# Clone graphs and get equivalences
inputs = [i for i in graph_inputs(graphs) if not isinstance(i, Constant)]
-equiv = {k: k for k in replacements.keys()}
-equiv = clone_get_equiv(inputs, graphs, False, False, equiv)
+memo = {k: k for k in replacements.keys()}
+equiv = clone_get_equiv(inputs, graphs, False, False, memo)

fg = FunctionGraph(
[equiv[i] for i in inputs],
@@ -753,7 +753,7 @@ def find_rng_nodes(
]


-def replace_rng_nodes(outputs: Sequence[TensorVariable]) -> Sequence[TensorVariable]:
+def replace_rng_nodes(outputs: Sequence[TensorVariable]) -> list[TensorVariable]:
"""Replace any RNG nodes upstream of outputs by new RNGs of the same type

This can be used when combining a pre-existing graph with a cloned one, to ensure
@@ -775,7 +775,7 @@ def replace_rng_nodes(outputs: Sequence[TensorVariable]) -> Sequence[TensorVaria
rng_cls = np.random.Generator
new_rng_nodes.append(pytensor.shared(rng_cls(np.random.PCG64())))
graph.replace_all(zip(rng_nodes, new_rng_nodes), import_missing=True)
-return graph.outputs
+return cast(list[TensorVariable], graph.outputs)


SeedSequenceSeed = None | int | Sequence[int] | np.ndarray | np.random.SeedSequence
@@ -798,7 +798,7 @@ def reseed_rngs(
rng.set_value(new_rng, borrow=True)


-def collect_default_updates_inner_fgraph(node: Node) -> dict[Variable, Variable]:
+def collect_default_updates_inner_fgraph(node: Apply) -> dict[Variable, Variable]:
"""Collect default updates from node with inner fgraph."""
op = node.op
inner_updates = collect_default_updates(
@@ -926,15 +926,15 @@ def find_default_update(clients, rng: Variable) -> None | Variable:
if inputs is None:
inputs = []

-outputs = makeiter(outputs)
-fg = FunctionGraph(outputs=outputs, clone=False)
+outs = makeiter(outputs)
+fg = FunctionGraph(outputs=outs, clone=False)
clients = fg.clients

rng_updates = {}
# Iterate over input RNGs. Only consider shared RNGs if `must_be_shared==True`
for input_rng in (
inp
-for inp in graph_inputs(outputs, blockers=inputs)
+for inp in graph_inputs(outs, blockers=inputs)
if (
(not must_be_shared or isinstance(inp, SharedVariable))
and isinstance(inp.type, RandomType)
@@ -945,7 +945,7 @@ def find_default_update(clients, rng: Variable) -> None | Variable:
default_update = find_default_update(clients, input_rng)

# Respect default update if provided
-if getattr(input_rng, "default_update", None):
+if hasattr(input_rng, "default_update") and input_rng.default_update is not None:
rng_updates[input_rng] = input_rng.default_update
else:
if default_update is not None:
@@ -1001,7 +1001,8 @@ def compile_pymc(

# We always reseed random variables as this provides RNGs with no chances of collision
if rng_updates:
-reseed_rngs(rng_updates.keys(), random_seed)
+rngs = cast(list[SharedVariable], list(rng_updates))
+reseed_rngs(rngs, random_seed)

# If called inside a model context, see if check_bounds flag is set to False
try:
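Presumably the motivation for the last hunk: reseed_rngs appears to expect a sequence of shared RNG variables, while rng_updates is a dict whose keys view is neither a list nor narrowly typed, so the keys are materialised into a list and cast. A standalone sketch under that assumption (all names hypothetical):

```python
from collections.abc import Sequence
from typing import cast


class SharedRNG:
    """Hypothetical stand-in for a shared RNG variable."""


def reseed(rngs: Sequence[SharedRNG], seed: int) -> None:
    for _rng in rngs:
        ...  # reseeding would happen here


rng_updates: dict[object, object] = {SharedRNG(): object()}

# dict.keys() is a KeysView (not a Sequence) typed with the broader key type,
# so the keys are turned into a list and cast before the call; cast() changes
# nothing at runtime.
reseed(cast(list[SharedRNG], list(rng_updates)), seed=1)
```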
2 changes: 1 addition & 1 deletion pymc/smc/sampling.py
@@ -225,7 +225,7 @@ def sample_smc(
trace = MultiTrace(traces)

_t_sampling = time.time() - t1
-sample_stats, idata = _save_sample_stats(
+_, idata = _save_sample_stats(
sample_settings,
sample_stats,
chains,
3 changes: 0 additions & 3 deletions scripts/run_mypy.py
@@ -25,7 +25,6 @@
pymc/distributions/continuous.py
pymc/distributions/dist_math.py
pymc/distributions/distribution.py
-pymc/distributions/mixture.py
pymc/distributions/multivariate.py
pymc/distributions/timeseries.py
pymc/distributions/truncated.py
@@ -34,7 +33,6 @@
pymc/logprob/censoring.py
pymc/logprob/basic.py
pymc/logprob/mixture.py
-pymc/logprob/order.py
pymc/logprob/rewriting.py
pymc/logprob/scan.py
pymc/logprob/tensor.py
@@ -44,7 +42,6 @@
pymc/model/core.py
pymc/model/fgraph.py
pymc/model/transform/conditioning.py
-pymc/printing.py
pymc/pytensorf.py
pymc/sampling/jax.py
"""