1616
1717from pytensor .tensor import TensorVariable
1818from pytensor .tensor .random .op import RandomVariable
19+ from pytensor .tensor .random .utils import normalize_size_param
1920
2021from pymc .distributions .distribution import (
2122 Distribution ,
2223 SymbolicRandomVariable ,
2324 _support_point ,
2425)
25- from pymc .distributions .shape_utils import _change_dist_size , change_dist_size
26+ from pymc .distributions .shape_utils import (
27+ _change_dist_size ,
28+ change_dist_size ,
29+ implicit_size_from_params ,
30+ rv_size_is_none ,
31+ )
2632from pymc .util import check_dist_not_registered
2733
2834
@@ -31,9 +37,27 @@ class CensoredRV(SymbolicRandomVariable):
3137
3238 inline_logprob = True
3339 signature = "(),(),()->()"
34- ndim_supp = 0
3540 _print_name = ("Censored" , "\\ operatorname{Censored}" )
3641
42+ @classmethod
43+ def rv_op (cls , dist , lower , upper , * , size = None ):
44+ # We don't allow passing `rng` because we don't fully control the rng of the components!
45+ lower = pt .constant (- np .inf ) if lower is None else pt .as_tensor (lower )
46+ upper = pt .constant (np .inf ) if upper is None else pt .as_tensor (upper )
47+ size = normalize_size_param (size )
48+
49+ if rv_size_is_none (size ):
50+ size = implicit_size_from_params (dist , lower , upper , ndims_params = cls .ndims_params )
51+
52+ # Censoring is achieved by clipping the base distribution between lower and upper
53+ dist = change_dist_size (dist , size )
54+ censored_rv = pt .clip (dist , lower , upper )
55+
56+ return CensoredRV (
57+ inputs = [dist , lower , upper ],
58+ outputs = [censored_rv ],
59+ )(dist , lower , upper )
60+
3761
3862class Censored (Distribution ):
3963 r"""
@@ -85,6 +109,7 @@ class Censored(Distribution):
85109 """
86110
87111 rv_type = CensoredRV
112+ rv_op = CensoredRV .rv_op
88113
89114 @classmethod
90115 def dist (cls , dist , lower , upper , ** kwargs ):
@@ -101,24 +126,6 @@ def dist(cls, dist, lower, upper, **kwargs):
101126 check_dist_not_registered (dist )
102127 return super ().dist ([dist , lower , upper ], ** kwargs )
103128
104- @classmethod
105- def rv_op (cls , dist , lower = None , upper = None , size = None ):
106- lower = pt .constant (- np .inf ) if lower is None else pt .as_tensor_variable (lower )
107- upper = pt .constant (np .inf ) if upper is None else pt .as_tensor_variable (upper )
108-
109- # When size is not specified, dist may have to be broadcasted according to lower/upper
110- dist_shape = size if size is not None else pt .broadcast_shape (dist , lower , upper )
111- dist = change_dist_size (dist , dist_shape )
112-
113- # Censoring is achieved by clipping the base distribution between lower and upper
114- dist_ , lower_ , upper_ = dist .type (), lower .type (), upper .type ()
115- censored_rv_ = pt .clip (dist_ , lower_ , upper_ )
116-
117- return CensoredRV (
118- inputs = [dist_ , lower_ , upper_ ],
119- outputs = [censored_rv_ ],
120- )(dist , lower , upper )
121-
122129
123130@_change_dist_size .register (CensoredRV )
124131def change_censored_size (cls , dist , new_size , expand = False ):
0 commit comments