4141from pytensor .graph .basic import Node
4242from pytensor .graph .fg import FunctionGraph
4343from pytensor .graph .rewriting .basic import node_rewriter
44- from pytensor .scalar .basic import Mul
45- from pytensor .tensor .basic import get_underlying_scalar_constant_value
4644from pytensor .tensor .elemwise import Elemwise
47- from pytensor .tensor .exceptions import NotScalarConstantError
4845from pytensor .tensor .math import Max
4946from pytensor .tensor .random .op import RandomVariable
5047from pytensor .tensor .variable import TensorVariable
5653 _logprob_helper ,
5754)
5855from pymc .logprob .rewriting import measurable_ir_rewrites_db
56+ from pymc .logprob .utils import find_negated_var
5957from pymc .math import logdiffexp
6058from pymc .pytensorf import constant_fold
6159
@@ -168,6 +166,13 @@ class MeasurableMaxNeg(Max):
168166MeasurableVariable .register (MeasurableMaxNeg )
169167
170168
class MeasurableDiscreteMaxNeg(Max):
    """Placeholder `Max` subclass marking the maximum of a negated discrete variable.

    The rewrite in this module substitutes it for a plain `Max` so that a
    dedicated log-probability implementation can be dispatched on this type.
    """


# Registering the class makes its outputs count as measurable variables.
MeasurableVariable.register(MeasurableDiscreteMaxNeg)
174+
175+
171176@node_rewriter (tracks = [Max ])
172177def find_measurable_max_neg (fgraph : FunctionGraph , node : Node ) -> Optional [List [TensorVariable ]]:
173178 rv_map_feature = getattr (fgraph , "preserve_rv_mappings" , None )
@@ -180,37 +185,20 @@ def find_measurable_max_neg(fgraph: FunctionGraph, node: Node) -> Optional[List[
180185
181186 base_var = node .inputs [0 ]
182187
183- if base_var .owner is None :
184- return None
185-
186- if not rv_map_feature .request_measurable (node .inputs ):
187- return None
188-
189188 # Min is the Max of the negation of the same distribution. Hence, op must be Elemwise
190- if not isinstance (base_var .owner .op , Elemwise ):
189+ if not ( base_var . owner is not None and isinstance (base_var .owner .op , Elemwise ) ):
191190 return None
192191
192+ base_rv = find_negated_var (base_var )
193+
193194 # negation is rv * (-1). Hence the scalar_op must be Mul
194- try :
195- if not (
196- isinstance (base_var .owner .op .scalar_op , Mul )
197- and len (base_var .owner .inputs ) == 2
198- and get_underlying_scalar_constant_value (base_var .owner .inputs [1 ]) == - 1
199- ):
200- return None
201- except NotScalarConstantError :
195+ if base_rv is None :
202196 return None
203197
204- base_rv = base_var .owner .inputs [0 ]
205-
206198 # Non-univariate distributions and non-RVs must be rejected
207199 if not (isinstance (base_rv .owner .op , RandomVariable ) and base_rv .owner .op .ndim_supp == 0 ):
208200 return None
209201
210- # TODO: We are currently only supporting continuous rvs
211- if isinstance (base_rv .owner .op , RandomVariable ) and base_rv .owner .op .dtype .startswith ("int" ):
212- return None
213-
214202 # univariate i.i.d. test which also rules out other distributions
215203 for params in base_rv .owner .inputs [3 :]:
216204 if params .type .ndim != 0 :
@@ -222,11 +210,16 @@ def find_measurable_max_neg(fgraph: FunctionGraph, node: Node) -> Optional[List[
222210 if axis != base_var_dims :
223211 return None
224212
225- measurable_min = MeasurableMaxNeg (list (axis ))
226- min_rv_node = measurable_min .make_node (base_var )
227- min_rv = min_rv_node .outputs
213+ if not rv_map_feature .request_measurable ([base_rv ]):
214+ return None
228215
229- return min_rv
216+ # distinguish measurable discrete and continuous (because logprob is different)
217+ if base_rv .owner .op .dtype .startswith ("int" ):
218+ measurable_min = MeasurableDiscreteMaxNeg (list (axis ))
219+ else :
220+ measurable_min = MeasurableMaxNeg (list (axis ))
221+
222+ return measurable_min .make_node (base_rv ).outputs
230223
231224
232225measurable_ir_rewrites_db .register (
@@ -238,14 +231,13 @@ def find_measurable_max_neg(fgraph: FunctionGraph, node: Node) -> Optional[List[
238231
239232
240233@_logprob .register (MeasurableMaxNeg )
241- def max_neg_logprob (op , values , base_var , ** kwargs ):
234+ def max_neg_logprob (op , values , base_rv , ** kwargs ):
242235 r"""Compute the log-likelihood graph for the `Max` operation.
243236 The formula that we use here is :
244237 \ln(f_{(n)}(x)) = \ln(n) + (n-1) \ln(1 - F(x)) + \ln(f(x))
245238 where f(x) represents the p.d.f and F(x) represents the c.d.f of the distribution respectively.
246239 """
247240 (value ,) = values
248- base_rv = base_var .owner .inputs [0 ]
249241
250242 logprob = _logprob_helper (base_rv , - value )
251243 logcdf = _logcdf_helper (base_rv , - value )
@@ -254,3 +246,31 @@ def max_neg_logprob(op, values, base_var, **kwargs):
254246 logprob = (n - 1 ) * pt .math .log (1 - pt .math .exp (logcdf )) + logprob + pt .math .log (n )
255247
256248 return logprob
249+
250+
@_logprob.register(MeasurableDiscreteMaxNeg)
def discrete_max_neg_logprob(op, values, base_rv, **kwargs):
    r"""Build the log-probability graph for the maximum of negated discrete i.i.d. variables.

    Let :math:`F` be the c.d.f. of each element of ``base_rv`` and ``n`` the
    number of elements. The graph evaluates, at :math:`x = -\text{value}`,

    .. math::
        \ln(P_{(n)}(x)) = \ln((1 - F(x - 1))^n - (1 - F(x))^n)

    which is the p.m.f. of the smallest of ``n`` i.i.d. draws, i.e. of the
    negated maximum of the negated variables.

    Parameters
    ----------
    op
        The dispatched `MeasurableDiscreteMaxNeg` op (not used in the body).
    values
        Sequence holding exactly one value variable.
    base_rv
        The un-negated random variable that is the input of the op.
    **kwargs
        Accepted for dispatch compatibility; not used in the body.

    Returns
    -------
    The symbolic log-probability of ``value``.
    """
    (value,) = values

    # The c.d.f. of the negated variable at `value` is the survival function of
    # `base_rv` at `-value`: log(1 - F(-value)), and likewise one step further.
    log_survival = pt.log1mexp(_logcdf_helper(base_rv, -value))
    log_survival_prev = pt.log1mexp(_logcdf_helper(base_rv, -(value + 1)))

    # Number of i.i.d. elements, resolved to a constant up front.
    (n,) = constant_fold([base_rv.size])

    # When both survival terms are -inf the value has zero probability; return
    # -inf directly (presumably because logdiffexp of two -inf terms is
    # undefined; confirm against pymc.math.logdiffexp).
    both_zero_mass = pt.and_(
        pt.eq(log_survival, -pt.inf),
        pt.eq(log_survival_prev, -pt.inf),
    )
    log_pmf = logdiffexp(n * log_survival_prev, n * log_survival)

    return pt.where(both_zero_mass, -pt.inf, log_pmf)
0 commit comments