3737import numpy as np
3838import pytensor
3939import pytensor .tensor as pt
40+ import pytest
4041import scipy .stats as st
4142
4243from pymc import logp
44+ from pymc .logprob .abstract import MeasurableVariable
4345from pymc .logprob .basic import factorized_joint_logprob
4446from pymc .logprob .censoring import MeasurableClip
4547from pymc .logprob .rewriting import construct_ir_fgraph
@@ -121,10 +123,13 @@ def test_nested_scalar_mixtures():
121123 assert np .isclose (logp_fn (0 , 0 , 1 , 50 ), st .norm .logpdf (150 ) + np .log (0.5 ) * 3 )
122124
123125
124- def test_unvalued_ir_reversion ():
126+ @pytest .mark .parametrize ("nested" , (False , True ))
127+ def test_unvalued_ir_reversion (nested ):
125128 """Make sure that un-valued IR rewrites are reverted."""
126129 x_rv = pt .random .normal ()
127130 y_rv = pt .clip (x_rv , 0 , 1 )
131+ if nested :
132+ y_rv = y_rv + 5
128133 z_rv = pt .random .normal (y_rv , 1 , name = "z" )
129134 z_vv = z_rv .clone ()
130135
@@ -134,14 +139,10 @@ def test_unvalued_ir_reversion():
134139
135140 z_fgraph , _ , memo = construct_ir_fgraph (rv_values )
136141
137- assert memo [y_rv ] in z_fgraph .preserve_rv_mappings .measurable_conversions
138-
139- measurable_y_rv = z_fgraph .preserve_rv_mappings .measurable_conversions [memo [y_rv ]]
140- assert isinstance (measurable_y_rv .owner .op , MeasurableClip )
141-
142- # `construct_ir_fgraph` should've reverted the un-valued measurable IR
143- # change
144- assert measurable_y_rv not in z_fgraph
142+ assert len (z_fgraph .preserve_rv_mappings .measurable_conversions ) == 1 + nested
143+ assert (
144+ sum (isinstance (node .op , MeasurableVariable ) for node in z_fgraph .apply_nodes ) == 2
145+ ) # Just the 2 rvs
145146
146147
147148def test_shifted_cumsum ():
0 commit comments