@@ -333,7 +333,7 @@ def replace_rvs_by_values(
333333 graphs : Sequence [TensorVariable ],
334334 * ,
335335 rvs_to_values : Dict [TensorVariable , TensorVariable ],
336- rvs_to_transforms : Dict [TensorVariable , RVTransform ],
336+ rvs_to_transforms : Optional [ Dict [TensorVariable , RVTransform ]] = None ,
337337 ** kwargs ,
338338) -> List [TensorVariable ]:
339339 """Clone and replace random variables in graphs with their value variables.
@@ -346,7 +346,7 @@ def replace_rvs_by_values(
346346 The graphs in which to perform the replacements.
347347 rvs_to_values
348348 Mapping between the original graph RVs and respective value variables
349- rvs_to_transforms
349+ rvs_to_transforms, optional
350350 Mapping between the original graph RVs and respective value transforms
351351 """
352352
@@ -361,7 +361,8 @@ def replace_rvs_by_values(
361361 for rv , value in rvs_to_values .items ():
362362 equiv_rv = equiv .get (rv , rv )
363363 equiv_rvs_to_values [equiv_rv ] = equiv .get (value , value )
364- equiv_rvs_to_transforms [equiv_rv ] = rvs_to_transforms [rv ]
364+ if rvs_to_transforms is not None :
365+ equiv_rvs_to_transforms [equiv_rv ] = rvs_to_transforms [rv ]
365366
366367 def poulate_replacements (rv , replacements ):
367368 # Populate replacements dict with {rv: value} pairs indicating which graph
@@ -372,14 +373,15 @@ def poulate_replacements(rv, replacements):
372373 if value is None :
373374 return []
374375
375- transform = equiv_rvs_to_transforms .get (rv , None )
376- if transform is not None :
377- # We want to replace uses of the RV by the back-transformation of its value
378- value = transform .backward (value , * rv .owner .inputs )
379- # The value may have a less precise type than the rv. In this case
380- # filter_variable will add a SpecifyShape to ensure they are consistent
381- value = rv .type .filter_variable (value , allow_convert = True )
382- value .name = rv .name
376+ if rvs_to_transforms is not None :
377+ transform = equiv_rvs_to_transforms .get (rv , None )
378+ if transform is not None :
379+ # We want to replace uses of the RV by the back-transformation of its value
380+ value = transform .backward (value , * rv .owner .inputs )
381+ # The value may have a less precise type than the rv. In this case
382+ # filter_variable will add a SpecifyShape to ensure they are consistent
383+ value = rv .type .filter_variable (value , allow_convert = True )
384+ value .name = rv .name
383385
384386 replacements [rv ] = value
385387 # Also walk the graph of the value variable to make any additional
0 commit comments