@@ -1622,7 +1622,7 @@ def compile_forward_sampling_function(
16221622 basic_rvs : Optional [List [Variable ]] = None ,
16231623 givens_dict : Optional [Dict [Variable , Any ]] = None ,
16241624 ** kwargs ,
1625- ) -> Callable [..., Union [np .ndarray , List [np .ndarray ]]]:
1625+ ) -> Tuple [ Callable [..., Union [np .ndarray , List [np .ndarray ]]], Set [ Variable ]]:
16261626 """Compile a function to draw samples, conditioned on the values of some variables.
16271627
16281628 The goal of this function is to walk the aesara computational graph from the list
@@ -1635,13 +1635,10 @@ def compile_forward_sampling_function(
16351635
16361636 - Variables in the outputs list
16371637 - ``SharedVariable`` instances that are not ``RandomStateSharedVariable`` or ``RandomGeneratorSharedVariable``
1638- - Basic RVs that are not in the ``vars_in_trace`` list
1638+ - Variables that are in the `basic_rvs` list but not in the ``vars_in_trace`` list
16391639 - Variables that are keys in the ``givens_dict``
16401640 - Variables that have volatile inputs
16411641
1642- Where by basic RVs we mean ``Variable`` instances produced by a ``RandomVariable`` ``Op``
1643- that are in the ``basic_rvs`` list.
1644-
16451642 Concretely, this function can be used to compile a function to sample from the
16461643 posterior predictive distribution of a model that has variables that are conditioned
16471644 on ``MutableData`` instances. The variables that depend on the mutable data will be
@@ -1670,12 +1667,19 @@ def compile_forward_sampling_function(
16701667 output of ``model.basic_RVs``) should have a reference to the variables that should
16711668 be considered as random variable instances. This includes variables that have
16721669 a ``RandomVariable`` owner op, but also unpure random variables like Mixtures, or
1673- Censored distributions. If ``None``, only pure random variables will be considered
1674- as potential random variables.
1670+ Censored distributions.
16751671 givens_dict : Optional[Dict[aesara.graph.basic.Variable, Any]]
16761672 A dictionary that maps tensor variables to the values that should be used to replace them
16771673 in the compiled function. The types of the key and value should match or an error will be
16781674 raised during compilation.
1675+
1676+ Returns
1677+ -------
1678+ function: Callable
1679+ Compiled forward sampling Aesara function
1680+ volatile_basic_rvs: Set of Variable
1681+ Set of all basic_rvs that were considered volatile and will be resampled when
1682+ the function is evaluated
16791683 """
16801684 if givens_dict is None :
16811685 givens_dict = {}
@@ -1741,7 +1745,10 @@ def expand(node):
17411745 for node , value in givens_dict .items ()
17421746 ]
17431747
1744- return compile_pymc (inputs , fg .outputs , givens = givens , on_unused_input = "ignore" , ** kwargs )
1748+ return (
1749+ compile_pymc (inputs , fg .outputs , givens = givens , on_unused_input = "ignore" , ** kwargs ),
1750+ set (basic_rvs ) & (volatile_nodes - set (givens_dict )), # Basic RVs that will be resampled
1751+ )
17451752
17461753
17471754def sample_posterior_predictive (
@@ -1900,7 +1907,6 @@ def sample_posterior_predictive(
19001907 vars_ = model .observed_RVs + model .auto_deterministics
19011908
19021909 indices = np .arange (samples )
1903-
19041910 if progressbar :
19051911 indices = progress_bar (indices , total = samples , display = progressbar )
19061912
@@ -1923,17 +1929,17 @@ def sample_posterior_predictive(
19231929 compile_kwargs .setdefault ("allow_input_downcast" , True )
19241930 compile_kwargs .setdefault ("accept_inplace" , True )
19251931
1926- sampler_fn = point_wrapper (
1927- compile_forward_sampling_function (
1928- outputs = vars_to_sample ,
1929- vars_in_trace = vars_in_trace ,
1930- basic_rvs = model .basic_RVs ,
1931- givens_dict = None ,
1932- random_seed = random_seed ,
1933- ** compile_kwargs ,
1934- )
1932+ _sampler_fn , volatile_basic_rvs = compile_forward_sampling_function (
1933+ outputs = vars_to_sample ,
1934+ vars_in_trace = vars_in_trace ,
1935+ basic_rvs = model .basic_RVs ,
1936+ givens_dict = None ,
1937+ random_seed = random_seed ,
1938+ ** compile_kwargs ,
19351939 )
1936-
1940+ sampler_fn = point_wrapper (_sampler_fn )
1941+ # All model variables have a name, but mypy does not know this
1942+ _log .info (f"Sampling: { list (sorted (volatile_basic_rvs , key = lambda var : var .name ))} " ) # type: ignore
19371943 ppc_trace_t = _DefaultTrace (samples )
19381944 try :
19391945 if isinstance (_trace , MultiTrace ):
@@ -2242,7 +2248,7 @@ def sample_prior_predictive(
22422248 compile_kwargs .setdefault ("allow_input_downcast" , True )
22432249 compile_kwargs .setdefault ("accept_inplace" , True )
22442250
2245- sampler_fn = compile_forward_sampling_function (
2251+ sampler_fn , volatile_basic_rvs = compile_forward_sampling_function (
22462252 vars_to_sample ,
22472253 vars_in_trace = [],
22482254 basic_rvs = model .basic_RVs ,
@@ -2251,6 +2257,8 @@ def sample_prior_predictive(
22512257 ** compile_kwargs ,
22522258 )
22532259
2260+ # All model variables have a name, but mypy does not know this
2261+ _log .info (f"Sampling: { list (sorted (volatile_basic_rvs , key = lambda var : var .name ))} " ) # type: ignore
22542262 values = zip (* (sampler_fn () for i in range (samples )))
22552263
22562264 data = {k : np .stack (v ) for k , v in zip (names , values )}
0 commit comments