diff --git a/pymc/distributions/discrete.py b/pymc/distributions/discrete.py index 5af1bc3ca4..205d000371 100644 --- a/pymc/distributions/discrete.py +++ b/pymc/distributions/discrete.py @@ -18,6 +18,7 @@ from aesara.tensor.random.basic import ( RandomVariable, + ScipyRandomVariable, bernoulli, betabinom, binomial, @@ -1117,7 +1118,7 @@ def logcdf(value, good, bad, n): ) -class DiscreteUniformRV(RandomVariable): +class DiscreteUniformRV(ScipyRandomVariable): name = "discrete_uniform" ndim_supp = 0 ndims_params = [0, 0] @@ -1125,7 +1126,7 @@ class DiscreteUniformRV(RandomVariable): _print_name = ("DiscreteUniform", "\\operatorname{DiscreteUniform}") @classmethod - def rng_fn(cls, rng, lower, upper, size=None): + def rng_fn_scipy(cls, rng, lower, upper, size=None): return stats.randint.rvs(lower, upper + 1, size=size, random_state=rng) diff --git a/pymc/tests/distributions/test_discrete.py b/pymc/tests/distributions/test_discrete.py index a885170ba2..43fc62e4bc 100644 --- a/pymc/tests/distributions/test_discrete.py +++ b/pymc/tests/distributions/test_discrete.py @@ -1042,6 +1042,10 @@ def discrete_uniform_rng_fn(self, size, lower, upper, rng): "check_rv_size", ] + def test_implied_degenerate_shape(self): + x = pm.DiscreteUniform.dist(0, [1]) + assert x.eval().shape == (1,) + class TestDiracDelta(BaseTestDistributionRandom): def diracdelta_rng_fn(self, size, c):