11from functools import singledispatch
2- from typing import Tuple
2+ from typing import Optional , Tuple
33
44import aesara .tensor as at
55import aesara .tensor .random .basic as arb
66import numpy as np
7- from aesara import scan , shared
7+ from aesara import scan
88from aesara .compile .builders import OpFromGraph
99from aesara .graph .op import Op
1010from aesara .raise_op import CheckAndRaise
1111from aesara .scan import until
12+ from aesara .tensor .random import RandomStream
1213from aesara .tensor .random .op import RandomVariable
1314from aesara .tensor .var import TensorConstant , TensorVariable
1415
@@ -68,7 +69,11 @@ def __str__(self):
6869
6970
7071def truncate (
71- rv : TensorVariable , lower = None , upper = None , max_n_steps : int = 10_000 , rng = None
72+ rv : TensorVariable ,
73+ lower = None ,
74+ upper = None ,
75+ max_n_steps : int = 10_000 ,
76+ srng : Optional [RandomStream ] = None ,
7277) -> Tuple [TensorVariable , Tuple [TensorVariable , TensorVariable ]]:
7378 """Truncate a univariate `RandomVariable` between `lower` and `upper`.
7479
@@ -99,13 +104,13 @@ def truncate(
99104 lower = at .as_tensor_variable (lower ) if lower is not None else at .constant (- np .inf )
100105 upper = at .as_tensor_variable (upper ) if upper is not None else at .constant (np .inf )
101106
102- if rng is None :
103- rng = shared ( np . random . RandomState (), borrow = True )
107+ if srng is None :
108+ srng = RandomStream ( )
104109
105110 # Try to use specialized Op
106111 try :
107112 truncated_rv , updates = _truncated (
108- rv .owner .op , lower , upper , rng , * rv .owner .inputs [1 :]
113+ rv .owner .op , lower , upper , srng , * rv .owner .inputs [1 :]
109114 )
110115 return truncated_rv , updates
111116 except NotImplementedError :
@@ -116,8 +121,8 @@ def truncate(
116121 # though it would not be necessary for the icdf OpFromGraph
117122 graph_inputs = [* rv .owner .inputs [1 :], lower , upper ]
118123 graph_inputs_ = [inp .type () for inp in graph_inputs ]
119- * rv_inputs_ , lower_ , upper_ = graph_inputs_
120- rv_ = rv .owner .op . make_node ( rng , * rv_inputs_ ). default_output ( )
124+ size_ , dtype_ , * rv_inputs_ , lower_ , upper_ = graph_inputs_
125+ rv_ = srng . gen ( rv .owner .op , * rv_inputs_ , size = size_ , dtype = dtype_ )
121126
122127 # Try to use inverted cdf sampling
123128 try :
@@ -126,11 +131,10 @@ def truncate(
126131 lower_value = lower_ - 1 if rv .owner .op .dtype .startswith ("int" ) else lower_
127132 cdf_lower_ = at .exp (logcdf (rv_ , lower_value ))
128133 cdf_upper_ = at .exp (logcdf (rv_ , upper_ ))
129- uniform_ = at . random .uniform (
134+ uniform_ = srng .uniform (
130135 cdf_lower_ ,
131136 cdf_upper_ ,
132- rng = rng ,
133- size = rv_inputs_ [0 ],
137+ size = size_ ,
134138 )
135139 truncated_rv_ = icdf (rv_ , uniform_ )
136140 truncated_rv = TruncatedRV (
@@ -146,27 +150,23 @@ def truncate(
146150
147151 # Fallback to rejection sampling
148152 # TODO: Handle potential broadcast by lower / upper
149- def loop_fn (truncated_rv , reject_draws , lower , upper , rng , * rv_inputs ):
150- next_rng , new_truncated_rv = rv .owner .op . make_node ( rng , * rv_inputs ). outputs
153+ def loop_fn (truncated_rv , reject_draws , lower , upper , size , dtype , * rv_inputs ):
154+ new_truncated_rv = srng . gen ( rv .owner .op , * rv_inputs , size = size , dtype = dtype ) # type: ignore
151155 truncated_rv = at .set_subtensor (
152156 truncated_rv [reject_draws ],
153157 new_truncated_rv [reject_draws ],
154158 )
155159 reject_draws = at .or_ ((truncated_rv < lower ), (truncated_rv > upper ))
156160
157- return (
158- (truncated_rv , reject_draws ),
159- [(rng , next_rng )],
160- until (~ at .any (reject_draws )),
161- )
161+ return (truncated_rv , reject_draws ), until (~ at .any (reject_draws ))
162162
163163 (truncated_rv_ , reject_draws_ ), updates = scan (
164164 loop_fn ,
165165 outputs_info = [
166166 at .zeros_like (rv_ ),
167167 at .ones_like (rv_ , dtype = bool ),
168168 ],
169- non_sequences = [lower_ , upper_ , rng , * rv_inputs_ ],
169+ non_sequences = [lower_ , upper_ , size_ , dtype_ , * rv_inputs_ ],
170170 n_steps = max_n_steps ,
171171 strict = True ,
172172 )
@@ -180,18 +180,28 @@ def loop_fn(truncated_rv, reject_draws, lower, upper, rng, *rv_inputs):
180180 truncated_rv = TruncatedRV (
181181 base_rv_op = rv .owner .op ,
182182 inputs = graph_inputs_ ,
183- outputs = [truncated_rv_ , tuple (updates .values ())[0 ]],
183+ # This will fail with `n_steps==1`, because in that case `Scan` won't return any updates
184+ outputs = [truncated_rv_ , rv_ .owner .outputs [0 ], tuple (updates .values ())[0 ]],
184185 inline = True ,
185186 )(* graph_inputs )
186- updates = {truncated_rv .owner .inputs [- 1 ]: truncated_rv .owner .outputs [- 1 ]}
187+ # TODO: Is the order of multiple shared variables deterministic?
188+ assert truncated_rv .owner .inputs [- 2 ] is rv_ .owner .inputs [0 ]
189+ updates = {
190+ truncated_rv .owner .inputs [- 2 ]: truncated_rv .owner .outputs [- 2 ],
191+ truncated_rv .owner .inputs [- 1 ]: truncated_rv .owner .outputs [- 1 ],
192+ }
187193 return truncated_rv , updates
188194
189195
190196@_logprob .register (TruncatedRV )
191197def truncated_logprob (op , values , * inputs , ** kwargs ):
192198 (value ,) = values
193199
194- * rv_inputs , lower_bound , upper_bound , rng = inputs
200+ # Rejection sample graph has two rngs
201+ if len (op .shared_inputs ) == 2 :
202+ * rv_inputs , lower_bound , upper_bound , _ , rng = inputs
203+ else :
204+ * rv_inputs , lower_bound , upper_bound , rng = inputs
195205 rv_inputs = [rng , * rv_inputs ]
196206
197207 base_rv_op = op .base_rv_op
@@ -242,11 +252,11 @@ def truncated_logprob(op, values, *inputs, **kwargs):
242252
243253
244254@_truncated .register (arb .UniformRV )
245- def uniform_truncated (op , lower , upper , rng , size , dtype , lower_orig , upper_orig ):
246- truncated_uniform = at .random .uniform (
255+ def uniform_truncated (op , lower , upper , srng , size , dtype , lower_orig , upper_orig ):
256+ truncated_uniform = srng .gen (
257+ op ,
247258 at .max ((lower_orig , lower )),
248259 at .min ((upper_orig , upper )),
249- rng = rng ,
250260 size = size ,
251261 dtype = dtype ,
252262 )
0 commit comments