 import aesara.tensor as at
 import aesara.tensor.random.basic as arb
 import numpy as np
-from aesara import scan, shared
+from aesara import scan
 from aesara.compile.builders import OpFromGraph
 from aesara.graph.basic import Node
 from aesara.graph.fg import FunctionGraph
 from aesara.scalar.basic import clip as scalar_clip
 from aesara.scan import until
 from aesara.tensor.elemwise import Elemwise
+from aesara.tensor.random import RandomStream
 from aesara.tensor.random.op import RandomVariable
 from aesara.tensor.var import TensorConstant, TensorVariable
 
@@ -188,7 +189,11 @@ def __str__(self):
 
 
 def truncate(
-    rv: TensorVariable, lower=None, upper=None, max_n_steps: int = 10_000, rng=None
+    rv: TensorVariable,
+    lower=None,
+    upper=None,
+    max_n_steps: int = 10_000,
+    srng: Optional[RandomStream] = None,
 ) -> Tuple[TensorVariable, Tuple[TensorVariable, TensorVariable]]:
     """Truncate a univariate `RandomVariable` between lower and upper.
 
@@ -218,13 +223,13 @@ def truncate(
     lower = at.as_tensor_variable(lower) if lower is not None else at.constant(-np.inf)
     upper = at.as_tensor_variable(upper) if upper is not None else at.constant(np.inf)
 
-    if rng is None:
-        rng = shared(np.random.RandomState(), borrow=True)
+    if srng is None:
+        srng = RandomStream()
 
     # Try to use specialized Op
     try:
         truncated_rv, updates = _truncated(
-            rv.owner.op, lower, upper, rng, *rv.owner.inputs[1:]
+            rv.owner.op, lower, upper, srng, *rv.owner.inputs[1:]
         )
         return truncated_rv, updates
     except NotImplementedError:
@@ -235,8 +240,8 @@ def truncate(
     # though it would not be necessary for the icdf OpFromGraph
     graph_inputs = [*rv.owner.inputs[1:], lower, upper]
     graph_inputs_ = [inp.type() for inp in graph_inputs]
-    *rv_inputs_, lower_, upper_ = graph_inputs_
-    rv_ = rv.owner.op.make_node(rng, *rv_inputs_).default_output()
+    size_, dtype_, *rv_inputs_, lower_, upper_ = graph_inputs_
+    rv_ = srng.gen(rv.owner.op, *rv_inputs_, size=size_, dtype=dtype_)
 
     # Try to use inverted cdf sampling
     try:
@@ -245,11 +250,10 @@ def truncate(
         lower_value = lower_ - 1 if rv.owner.op.dtype.startswith("int") else lower_
         cdf_lower_ = at.exp(logcdf(rv_, lower_value))
         cdf_upper_ = at.exp(logcdf(rv_, upper_))
-        uniform_ = at.random.uniform(
+        uniform_ = srng.uniform(
             cdf_lower_,
             cdf_upper_,
-            rng=rng,
-            size=rv_inputs_[0],
+            size=size_,
         )
         truncated_rv_ = icdf(rv_, uniform_)
         truncated_rv = TruncatedRV(
@@ -265,27 +269,23 @@ def truncate(
 
     # Fallback to rejection sampling
     # TODO: Handle potential broadcast by lower / upper
-    def loop_fn(truncated_rv, reject_draws, lower, upper, rng, *rv_inputs):
-        next_rng, new_truncated_rv = rv.owner.op.make_node(rng, *rv_inputs).outputs
+    def loop_fn(truncated_rv, reject_draws, lower, upper, size, dtype, *rv_inputs):
+        new_truncated_rv = srng.gen(rv.owner.op, *rv_inputs, size=size, dtype=dtype)  # type: ignore
         truncated_rv = at.set_subtensor(
             truncated_rv[reject_draws],
             new_truncated_rv[reject_draws],
         )
         reject_draws = at.or_((truncated_rv < lower), (truncated_rv > upper))
 
-        return (
-            (truncated_rv, reject_draws),
-            [(rng, next_rng)],
-            until(~at.any(reject_draws)),
-        )
+        return (truncated_rv, reject_draws), until(~at.any(reject_draws))
 
     (truncated_rv_, reject_draws_), updates = scan(
         loop_fn,
         outputs_info=[
             at.zeros_like(rv_),
             at.ones_like(rv_, dtype=bool),
         ],
-        non_sequences=[lower_, upper_, rng, *rv_inputs_],
+        non_sequences=[lower_, upper_, size_, dtype_, *rv_inputs_],
         n_steps=max_n_steps,
         strict=True,
     )
@@ -299,18 +299,28 @@ def loop_fn(truncated_rv, reject_draws, lower, upper, rng, *rv_inputs):
     truncated_rv = TruncatedRV(
         base_rv_op=rv.owner.op,
         inputs=graph_inputs_,
-        outputs=[truncated_rv_, tuple(updates.values())[0]],
+        # This will fail with `n_steps==1`, because in that case `Scan` won't return any updates
+        outputs=[truncated_rv_, rv_.owner.outputs[0], tuple(updates.values())[0]],
         inline=True,
     )(*graph_inputs)
-    updates = {truncated_rv.owner.inputs[-1]: truncated_rv.owner.outputs[-1]}
+    # TODO: Is the order of multiple shared variables deterministic?
+    assert truncated_rv.owner.inputs[-2] is rv_.owner.inputs[0]
+    updates = {
+        truncated_rv.owner.inputs[-2]: truncated_rv.owner.outputs[-2],
+        truncated_rv.owner.inputs[-1]: truncated_rv.owner.outputs[-1],
+    }
     return truncated_rv, updates
 
 
 @_logprob.register(TruncatedRV)
 def truncated_logprob(op, values, *inputs, **kwargs):
     (value,) = values
 
-    *rv_inputs, lower_bound, upper_bound, rng = inputs
+    # The rejection sampling graph has two rngs
+    if len(op.shared_inputs) == 2:
+        *rv_inputs, lower_bound, upper_bound, _, rng = inputs
+    else:
+        *rv_inputs, lower_bound, upper_bound, rng = inputs
     rv_inputs = [rng, *rv_inputs]
 
     base_rv_op = op.base_rv_op
@@ -361,11 +371,11 @@ def truncated_logprob(op, values, *inputs, **kwargs):
 
 
 @_truncated.register(arb.UniformRV)
-def uniform_truncated(op, lower, upper, rng, size, dtype, lower_orig, upper_orig):
-    truncated_uniform = at.random.uniform(
+def uniform_truncated(op, lower, upper, srng, size, dtype, lower_orig, upper_orig):
+    truncated_uniform = srng.gen(
+        op,
         at.max((lower_orig, lower)),
         at.min((upper_orig, upper)),
-        rng=rng,
         size=size,
         dtype=dtype,
     )