diff --git a/pymc/backends/__init__.py b/pymc/backends/__init__.py
index 1347f5583..9b1af4bd4 100644
--- a/pymc/backends/__init__.py
+++ b/pymc/backends/__init__.py
@@ -81,7 +81,7 @@
     from pymc.backends.mcbackend import init_chain_adapters
 
-    TraceOrBackend = BaseTrace | Backend
+    TraceOrBackend: TypeAlias = BaseTrace | Backend
     RunType: TypeAlias = Run
     HAS_MCB = True
 except ImportError:
diff --git a/pymc/distributions/distribution.py b/pymc/distributions/distribution.py
index bcd461391..b07c5826d 100644
--- a/pymc/distributions/distribution.py
+++ b/pymc/distributions/distribution.py
@@ -28,7 +28,7 @@
 from pytensor import tensor as pt
 from pytensor.compile.builders import OpFromGraph
 from pytensor.graph import FunctionGraph, clone_replace, node_rewriter
-from pytensor.graph.basic import Node, Variable, io_toposort
+from pytensor.graph.basic import Apply, Variable, io_toposort
 from pytensor.graph.features import ReplaceValidate
 from pytensor.graph.rewriting.basic import GraphRewriter, in2out
 from pytensor.graph.utils import MetaType
@@ -421,7 +421,7 @@ def __init__(
         kwargs.setdefault("strict", True)
         super().__init__(*args, **kwargs)
 
-    def update(self, node: Node) -> dict[Variable, Variable]:
+    def update(self, node: Apply) -> dict[Variable, Variable]:
         """Symbolic update expression for input random state variables
 
         Returns a dictionary with the symbolic expressions required for correct updating
@@ -430,7 +430,7 @@ def update(self, node: Node) -> dict[Variable, Variable]:
         """
         return collect_default_updates_inner_fgraph(node)
 
-    def batch_ndim(self, node: Node) -> int:
+    def batch_ndim(self, node: Apply) -> int:
         """Number of dimensions of the distribution's batch shape."""
         out_ndim = max(getattr(out.type, "ndim", 0) for out in node.outputs)
         return out_ndim - self.ndim_supp
diff --git a/pymc/distributions/mixture.py b/pymc/distributions/mixture.py
index 5fff3cd3d..b02d1706c 100644
--- a/pymc/distributions/mixture.py
+++ b/pymc/distributions/mixture.py
@@ -18,7 +18,7 @@
 import pytensor
 import pytensor.tensor as pt
 
-from pytensor.graph.basic import Node, equal_computations
+from pytensor.graph.basic import Apply, equal_computations
 from pytensor.tensor import TensorVariable
 from pytensor.tensor.random.op import RandomVariable
 from pytensor.tensor.random.utils import normalize_size_param
@@ -156,7 +156,7 @@ def _resize_components(cls, size, *components):
         return [change_dist_size(component, size) for component in components]
 
-    def update(self, node: Node):
+    def update(self, node: Apply):
         # Update for the internal mix_indexes RV
         return {node.inputs[0]: node.outputs[0]}
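Reviewer note on the `Node` -> `Apply` annotation changes above: they are purely static. In PyTensor, `Apply` is the concrete node type that ties an `Op` to its input and output `Variable`s, and it is what node rewriters and methods like `update`/`batch_ndim` actually receive; the broader `Node` base class (shared with `Variable`) does not expose `.op`, `.inputs`, or `.outputs` to the typechecker. A minimal standalone sketch of the relationship (illustrative only, not part of the patch):

    import pytensor.tensor as pt
    from pytensor.graph.basic import Apply

    x = pt.scalar("x")
    y = x + 1  # building `x + 1` creates an Apply node that owns the output

    node = y.owner  # the Apply node that produced `y`
    assert isinstance(node, Apply)
    print(node.op, node.inputs, node.outputs)  # the Op plus its input/output variables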
diff --git a/pymc/logprob/order.py b/pymc/logprob/order.py
index b46562c82..fb19370bf 100644
--- a/pymc/logprob/order.py
+++ b/pymc/logprob/order.py
@@ -35,9 +35,11 @@
 # SOFTWARE.
 
+from typing import cast
+
 import pytensor.tensor as pt
 
-from pytensor.graph.basic import Node
+from pytensor.graph.basic import Apply
 from pytensor.graph.fg import FunctionGraph
 from pytensor.graph.rewriting.basic import node_rewriter
 from pytensor.tensor.elemwise import Elemwise
@@ -72,7 +74,7 @@ class MeasurableMaxDiscrete(Max):
 
 
 @node_rewriter([Max])
-def find_measurable_max(fgraph: FunctionGraph, node: Node) -> list[TensorVariable] | None:
+def find_measurable_max(fgraph: FunctionGraph, node: Apply) -> list[TensorVariable] | None:
     rv_map_feature = getattr(fgraph, "preserve_rv_mappings", None)
     if rv_map_feature is None:
         return None  # pragma: no cover
@@ -80,7 +82,7 @@ def find_measurable_max(fgraph: FunctionGraph, node: Node) -> list[TensorVariable] | None:
     if isinstance(node.op, MeasurableMax):
         return None  # pragma: no cover
 
-    base_var = node.inputs[0]
+    base_var = cast(TensorVariable, node.inputs[0])
 
     if base_var.owner is None:
         return None
@@ -104,6 +106,7 @@ def find_measurable_max(fgraph: FunctionGraph, node: Node) -> list[TensorVariable] | None:
         return None
 
     # distinguish measurable discrete and continuous (because logprob is different)
+    measurable_max: Max
     if base_var.owner.op.dtype.startswith("int"):
         measurable_max = MeasurableMaxDiscrete(list(axis))
     else:
@@ -173,7 +176,7 @@ class MeasurableDiscreteMaxNeg(Max):
 
 
 @node_rewriter(tracks=[Max])
-def find_measurable_max_neg(fgraph: FunctionGraph, node: Node) -> list[TensorVariable] | None:
+def find_measurable_max_neg(fgraph: FunctionGraph, node: Apply) -> list[TensorVariable] | None:
     rv_map_feature = getattr(fgraph, "preserve_rv_mappings", None)
 
     if rv_map_feature is None:
@@ -182,7 +185,7 @@ def find_measurable_max_neg(fgraph: FunctionGraph, node: Node) -> list[TensorVariable] | None:
     if isinstance(node.op, MeasurableMaxNeg):
         return None  # pragma: no cover
 
-    base_var = node.inputs[0]
+    base_var = cast(TensorVariable, node.inputs[0])
 
     # Min is the Max of the negation of the same distribution. Hence, op must be Elemwise
     if not (base_var.owner is not None and isinstance(base_var.owner.op, Elemwise)):
@@ -213,6 +216,7 @@ def find_measurable_max_neg(fgraph: FunctionGraph, node: Node) -> list[TensorVariable] | None:
         return None
 
     # distinguish measurable discrete and continuous (because logprob is different)
+    measurable_min: Max
    if base_rv.owner.op.dtype.startswith("int"):
         measurable_min = MeasurableDiscreteMaxNeg(list(axis))
     else:
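The order.py hunks lean on two mypy idioms: `typing.cast`, which asserts a static type with no runtime effect, and pre-annotating a variable (`measurable_max: Max`) with a common base class so that both branches of an if/else may assign different subclasses. A hypothetical toy showing both, independent of PyMC:

    from typing import cast

    class Base: ...
    class IntCase(Base): ...
    class FloatCase(Base): ...

    def pick(use_int: bool) -> Base:
        # Pre-annotation: without it, mypy infers `impl: IntCase` from the
        # first branch and rejects the second assignment.
        impl: Base
        if use_int:
            impl = IntCase()
        else:
            impl = FloatCase()
        return impl

    # `cast` is a no-op at runtime; it only informs the typechecker.
    value: object = "5"
    s = cast(str, value)
    assert s.upper() == "5"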
diff --git a/pymc/printing.py b/pymc/printing.py
index 13361741e..6c6bdbb71 100644
--- a/pymc/printing.py
+++ b/pymc/printing.py
@@ -13,6 +13,8 @@
 # limitations under the License.
 
+from functools import partial
+
 from pytensor.compile import SharedVariable
 from pytensor.graph.basic import Constant, walk
 from pytensor.tensor.basic import TensorVariable, Variable
@@ -55,7 +57,7 @@ def str_for_dist(
 
     if "latex" in formatting:
         if print_name is not None:
-            print_name = r"\text{" + _latex_escape(dist.name.strip("$")) + "}"
+            print_name = r"\text{" + _latex_escape(print_name.strip("$")) + "}"
 
         op_name = (
             dist.owner.op._print_name[1]
@@ -96,17 +98,16 @@ def str_for_model(model: Model, formatting: str = "plain", include_params: bool
     """Make a human-readable string representation of Model, listing all random variables
     and their distributions, optionally including parameter values."""
 
-    kwargs = dict(formatting=formatting, include_params=include_params)
-    free_rv_reprs = [str_for_dist(dist, **kwargs) for dist in model.free_RVs]
-    observed_rv_reprs = [str_for_dist(rv, **kwargs) for rv in model.observed_RVs]
-    det_reprs = [
-        str_for_potential_or_deterministic(dist, **kwargs, dist_name="Deterministic")
-        for dist in model.deterministics
-    ]
-    potential_reprs = [
-        str_for_potential_or_deterministic(pot, **kwargs, dist_name="Potential")
-        for pot in model.potentials
-    ]
+    # Wrap functions to avoid confusing typecheckers
+    sfd = partial(str_for_dist, formatting=formatting, include_params=include_params)
+    sfp = partial(
+        str_for_potential_or_deterministic, formatting=formatting, include_params=include_params
+    )
+
+    free_rv_reprs = [sfd(dist) for dist in model.free_RVs]
+    observed_rv_reprs = [sfd(rv) for rv in model.observed_RVs]
+    det_reprs = [sfp(dist, dist_name="Deterministic") for dist in model.deterministics]
+    potential_reprs = [sfp(pot, dist_name="Potential") for pot in model.potentials]
 
     var_reprs = free_rv_reprs + det_reprs + observed_rv_reprs + potential_reprs
@@ -162,6 +163,8 @@ def _str_for_input_var(var: Variable, formatting: str) -> str:
     from pymc.distributions.distribution import SymbolicRandomVariable
 
     def _is_potential_or_deterministic(var: Variable) -> bool:
+        if not hasattr(var, "str_repr"):
+            return False
         try:
             return var.str_repr.__func__.func is str_for_potential_or_deterministic
         except AttributeError:
@@ -175,6 +178,7 @@ def _is_potential_or_deterministic(var: Variable) -> bool:
     ) or _is_potential_or_deterministic(var):
         # show the names for RandomVariables, Deterministics, and Potentials, rather
         # than the full expression
+        assert isinstance(var, TensorVariable)
         return _str_for_input_rv(var, formatting)
     elif isinstance(var.owner.op, DimShuffle):
         return _str_for_input_var(var.owner.inputs[0], formatting)
@@ -182,7 +186,7 @@ def _is_potential_or_deterministic(var: Variable) -> bool:
         return _str_for_expression(var, formatting)
 
 
-def _str_for_input_rv(var: Variable, formatting: str) -> str:
+def _str_for_input_rv(var: TensorVariable, formatting: str) -> str:
     _str = (
         var.name
         if var.name is not None
@@ -221,12 +225,15 @@ def _expand(x):
         if x.owner and (not isinstance(x.owner.op, RandomVariable | SymbolicRandomVariable)):
             return reversed(x.owner.inputs)
 
-    parents = [
-        x
-        for x in walk(nodes=var.owner.inputs, expand=_expand)
-        if x.owner and isinstance(x.owner.op, RandomVariable | SymbolicRandomVariable)
-    ]
-    names = [x.name for x in parents]
+    parents = []
+    names = []
+    for x in walk(nodes=var.owner.inputs, expand=_expand):
+        assert isinstance(x, Variable)
+        if x.owner and isinstance(x.owner.op, RandomVariable | SymbolicRandomVariable):
+            parents.append(x)
+            xname = x.name
+            assert xname is not None
+            names.append(xname)
 
     if "latex" in formatting:
         return (
@@ -257,6 +264,8 @@ def _default_repr_pretty(obj: TensorVariable | Model, p, cycle):
     """Handy plug-in method to instruct IPython-like REPLs to use our str_repr above."""
     # we know that our str_repr does not recurse, so we can ignore cycle
     try:
+        if not hasattr(obj, "str_repr"):
+            raise AttributeError
         output = obj.str_repr()
         # Find newlines and replace them with p.break_()
         # (see IPython.lib.pretty._repr_pprint)
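Besides the typing work, the first printing.py hunk fixes a real bug: the LaTeX branch rebuilt the name from `dist.name`, discarding a caller-supplied `print_name`. As for `functools.partial` replacing the `**kwargs` dict in `str_for_model`: mypy cannot match a dynamically built dict of keyword arguments against a callee's signature, while `partial` keeps each call site checkable. A standalone sketch with a hypothetical helper (not PyMC code):

    from functools import partial

    def render(name: str, formatting: str = "plain", include_params: bool = False) -> str:
        return f"{name} [{formatting}, params={include_params}]"

    # Bind the shared keyword arguments once; each call site then passes only
    # the varying positional argument, which typecheckers can verify.
    r = partial(render, formatting="latex", include_params=True)
    print([r(n) for n in ("alpha", "beta")])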
diff --git a/pymc/pytensorf.py b/pymc/pytensorf.py
index add8cbb17..d585ba6c1 100644
--- a/pymc/pytensorf.py
+++ b/pymc/pytensorf.py
@@ -14,6 +14,7 @@
 import warnings
 
 from collections.abc import Callable, Generator, Iterable, Sequence
+from typing import cast
 
 import numpy as np
 import pandas as pd
@@ -29,7 +30,6 @@
 from pytensor.graph.basic import (
     Apply,
     Constant,
-    Node,
     Variable,
     clone_get_equiv,
     graph_inputs,
@@ -208,8 +208,8 @@ def replace_vars_in_graphs(
     """
     # Clone graphs and get equivalences
     inputs = [i for i in graph_inputs(graphs) if not isinstance(i, Constant)]
-    equiv = {k: k for k in replacements.keys()}
-    equiv = clone_get_equiv(inputs, graphs, False, False, equiv)
+    memo = {k: k for k in replacements.keys()}
+    equiv = clone_get_equiv(inputs, graphs, False, False, memo)
 
     fg = FunctionGraph(
         [equiv[i] for i in inputs],
@@ -753,7 +753,7 @@ def find_rng_nodes(
     ]
 
 
-def replace_rng_nodes(outputs: Sequence[TensorVariable]) -> Sequence[TensorVariable]:
+def replace_rng_nodes(outputs: Sequence[TensorVariable]) -> list[TensorVariable]:
     """Replace any RNG nodes upstream of outputs by new RNGs of the same type
 
     This can be used when combining a pre-existing graph with a cloned one, to ensure
@@ -775,7 +775,7 @@ def replace_rng_nodes(outputs: Sequence[TensorVariable]) -> Sequence[TensorVariable]:
             rng_cls = np.random.Generator
         new_rng_nodes.append(pytensor.shared(rng_cls(np.random.PCG64())))
     graph.replace_all(zip(rng_nodes, new_rng_nodes), import_missing=True)
-    return graph.outputs
+    return cast(list[TensorVariable], graph.outputs)
 
 
 SeedSequenceSeed = None | int | Sequence[int] | np.ndarray | np.random.SeedSequence
@@ -798,7 +798,7 @@ def reseed_rngs(
         rng.set_value(new_rng, borrow=True)
 
 
-def collect_default_updates_inner_fgraph(node: Node) -> dict[Variable, Variable]:
+def collect_default_updates_inner_fgraph(node: Apply) -> dict[Variable, Variable]:
     """Collect default updates from node with inner fgraph."""
     op = node.op
     inner_updates = collect_default_updates(
@@ -926,15 +926,15 @@ def find_default_update(clients, rng: Variable) -> None | Variable:
     if inputs is None:
         inputs = []
 
-    outputs = makeiter(outputs)
-    fg = FunctionGraph(outputs=outputs, clone=False)
+    outs = makeiter(outputs)
+    fg = FunctionGraph(outputs=outs, clone=False)
     clients = fg.clients
 
     rng_updates = {}
     # Iterate over input RNGs. Only consider shared RNGs if `must_be_shared==True`
     for input_rng in (
         inp
-        for inp in graph_inputs(outputs, blockers=inputs)
+        for inp in graph_inputs(outs, blockers=inputs)
         if (
             (not must_be_shared or isinstance(inp, SharedVariable))
             and isinstance(inp.type, RandomType)
@@ -945,7 +945,7 @@ def find_default_update(clients, rng: Variable) -> None | Variable:
         default_update = find_default_update(clients, input_rng)
 
         # Respect default update if provided
-        if getattr(input_rng, "default_update", None):
+        if hasattr(input_rng, "default_update") and input_rng.default_update is not None:
             rng_updates[input_rng] = input_rng.default_update
         else:
             if default_update is not None:
@@ -1001,7 +1001,8 @@ def compile_pymc(
 
     # We always reseed random variables as this provides RNGs with no chances of collision
     if rng_updates:
-        reseed_rngs(rng_updates.keys(), random_seed)
+        rngs = cast(list[SharedVariable], list(rng_updates))
+        reseed_rngs(rngs, random_seed)
 
     # If called inside a model context, see if check_bounds flag is set to False
     try:
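The `default_update` hunk swaps a truthiness test for an explicit presence-and-not-None test. That distinction matters whenever the attribute can hold a value that is falsy, or one whose `__bool__` is unreliable (symbolic variables being a prime candidate). A minimal illustration with a plain Python object, purely to show the pitfall:

    class Holder:
        default_update = 0  # present, but falsy

    h = Holder()

    # Truthiness check: silently skips the falsy-but-present attribute.
    if getattr(h, "default_update", None):
        print("seen by getattr-truthiness")  # not reached

    # Explicit check: only skips when the attribute is absent or None.
    if hasattr(h, "default_update") and h.default_update is not None:
        print("seen by explicit check")  # reached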
diff --git a/pymc/smc/sampling.py b/pymc/smc/sampling.py
index ad6f7ede4..2f78480d9 100644
--- a/pymc/smc/sampling.py
+++ b/pymc/smc/sampling.py
@@ -225,7 +225,7 @@ def sample_smc(
     trace = MultiTrace(traces)
     _t_sampling = time.time() - t1
 
-    sample_stats, idata = _save_sample_stats(
+    _, idata = _save_sample_stats(
         sample_settings,
         sample_stats,
         chains,
diff --git a/scripts/run_mypy.py b/scripts/run_mypy.py
index ada8f71b2..9aacc0ce5 100644
--- a/scripts/run_mypy.py
+++ b/scripts/run_mypy.py
@@ -25,7 +25,6 @@
 pymc/distributions/continuous.py
 pymc/distributions/dist_math.py
 pymc/distributions/distribution.py
-pymc/distributions/mixture.py
 pymc/distributions/multivariate.py
 pymc/distributions/timeseries.py
 pymc/distributions/truncated.py
@@ -34,7 +33,6 @@
 pymc/logprob/censoring.py
 pymc/logprob/basic.py
 pymc/logprob/mixture.py
-pymc/logprob/order.py
 pymc/logprob/rewriting.py
 pymc/logprob/scan.py
 pymc/logprob/tensor.py
@@ -44,7 +42,6 @@
 pymc/model/core.py
 pymc/model/fgraph.py
 pymc/model/transform/conditioning.py
-pymc/printing.py
 pymc/pytensorf.py
 pymc/sampling/jax.py
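The scripts/run_mypy.py hunks shrink that script's list of files still expected to fail mypy: the modules fixed in this patch (pymc/distributions/mixture.py, pymc/logprob/order.py, pymc/printing.py) drop out, so any typing regression in them will now fail CI. The general expected-failure pattern, as a hypothetical sketch (names invented; not the actual script):

    # Hypothetical sketch of an expected-failure list for a type checker.
    EXPECTED_FAILING = {"pymc/distributions/multivariate.py"}  # illustrative entry

    def check(currently_failing: set[str]) -> None:
        # Flag drift in either direction: new failures must be fixed,
        # newly passing files must leave the list so they stay checked.
        unexpected = currently_failing - EXPECTED_FAILING
        fixed = EXPECTED_FAILING - currently_failing
        if unexpected:
            raise SystemExit(f"New type errors in: {sorted(unexpected)}")
        if fixed:
            print(f"Now clean, remove from the list: {sorted(fixed)}")

    check(set())  # everything passes: prompts removal of the stale entry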