From fdf054b6aae41d1a2b26b6313dec32ea951f08df Mon Sep 17 00:00:00 2001 From: Ricardo Vieira Date: Wed, 28 Sep 2022 16:45:51 +0200 Subject: [PATCH] Assume default_output is the only measurable output in SymbolicRandomVariables --- pymc/distributions/distribution.py | 8 ++++++++ pymc/distributions/timeseries.py | 7 ------- pymc/tests/distributions/test_distribution.py | 16 +++++++++++++++- 3 files changed, 23 insertions(+), 8 deletions(-) diff --git a/pymc/distributions/distribution.py b/pymc/distributions/distribution.py index f3c116a8ff..541a7030a8 100644 --- a/pymc/distributions/distribution.py +++ b/pymc/distributions/distribution.py @@ -381,6 +381,14 @@ def dist( @_get_measurable_outputs.register(SymbolicRandomVariable) def _get_measurable_outputs_symbolic_random_variable(op, node): # This tells Aeppl that any non RandomType outputs are measurable + + # Assume that if there is one default_output, that's the only one that is measurable + # In the rare case this is not what one wants, a specialized _get_measuarable_outputs + # can dispatch for a subclassed Op + if op.default_output is not None: + return [node.default_output()] + + # Otherwise assume that any outputs that are not of RandomType are measurable return [out for out in node.outputs if not isinstance(out.type, RandomType)] diff --git a/pymc/distributions/timeseries.py b/pymc/distributions/timeseries.py index bb9f6a9da3..c99732471b 100644 --- a/pymc/distributions/timeseries.py +++ b/pymc/distributions/timeseries.py @@ -19,7 +19,6 @@ import aesara.tensor as at import numpy as np -from aeppl.abstract import _get_measurable_outputs from aeppl.logprob import _logprob from aesara.graph.basic import Node, clone_replace from aesara.raise_op import Assert @@ -203,12 +202,6 @@ def rv_op(cls, init_dist, innovation_dist, steps, size=None): )(init_dist, innovation_dist, steps) -@_get_measurable_outputs.register(RandomWalkRV) -def _get_measurable_outputs_random_walk(op, node): - # Ignore steps output - return [node.default_output()] - - @_change_dist_size.register(RandomWalkRV) def change_random_walk_size(op, dist, new_size, expand): init_dist, innovation_dist, steps = dist.owner.inputs diff --git a/pymc/tests/distributions/test_distribution.py b/pymc/tests/distributions/test_distribution.py index 13a82f608e..8b0a4fba55 100644 --- a/pymc/tests/distributions/test_distribution.py +++ b/pymc/tests/distributions/test_distribution.py @@ -339,7 +339,9 @@ class TestInlinedSymbolicRV(SymbolicRandomVariable): x_inline = TestInlinedSymbolicRV([], [Flat.dist()], ndim_supp=0)() assert np.isclose(logp(x_inline, 0).eval(), 0) - def test_measurable_outputs(self): + def test_measurable_outputs_rng_ignored(self): + """Test that any RandomType outputs are ignored as a measurable_outputs""" + class TestSymbolicRV(SymbolicRandomVariable): pass @@ -347,3 +349,15 @@ class TestSymbolicRV(SymbolicRandomVariable): next_rng, dirac_delta = TestSymbolicRV([], [next_rng_, dirac_delta_], ndim_supp=0)() node = dirac_delta.owner assert get_measurable_outputs(node.op, node) == [dirac_delta] + + @pytest.mark.parametrize("default_output_idx", (0, 1)) + def test_measurable_outputs_default_output(self, default_output_idx): + """Test that if provided, a default output is considered the only measurable_output""" + + class TestSymbolicRV(SymbolicRandomVariable): + default_output = default_output_idx + + dirac_delta_1_ = DiracDelta.dist(5) + dirac_delta_2_ = DiracDelta.dist(10) + node = TestSymbolicRV([], [dirac_delta_1_, dirac_delta_2_], ndim_supp=0)().owner + assert get_measurable_outputs(node.op, node) == [node.outputs[default_output_idx]]