diff --git a/pymc/distributions/discrete.py b/pymc/distributions/discrete.py
index b21e0fe7c..ce88d8bd7 100644
--- a/pymc/distributions/discrete.py
+++ b/pymc/distributions/discrete.py
@@ -16,6 +16,7 @@
 import aesara.tensor as at
 import numpy as np
 
+from aesara.tensor import TensorConstant
 from aesara.tensor.random.basic import (
     RandomVariable,
     ScipyRandomVariable,
@@ -1285,17 +1286,21 @@ def dist(cls, p=None, logit_p=None, **kwargs):
         if logit_p is not None:
             p = pm.math.softmax(logit_p, axis=-1)
 
-        if isinstance(p, np.ndarray) or isinstance(p, list):
-            if (np.asarray(p) < 0).any():
-                raise ValueError(f"Negative `p` parameters are not valid, got: {p}")
-            p_sum = np.sum([p], axis=-1)
-            if not np.all(np.isclose(p_sum, 1.0)):
+        p = at.as_tensor_variable(p)
+        if isinstance(p, TensorConstant):
+            p_ = np.asarray(p.data)
+            if np.any(p_ < 0):
+                raise ValueError(f"Negative `p` parameters are not valid, got: {p_}")
+            p_sum_ = np.sum([p_], axis=-1)
+            if not np.all(np.isclose(p_sum_, 1.0)):
                 warnings.warn(
-                    f"`p` parameters sum to {p_sum}, instead of 1.0. They will be automatically rescaled. You can rescale them directly to get rid of this warning.",
+                    f"`p` parameters sum to {p_sum_}, instead of 1.0. "
+                    "They will be automatically rescaled. "
+                    "You can rescale them directly to get rid of this warning.",
                     UserWarning,
                 )
-                p = p / at.sum(p, axis=-1, keepdims=True)
-        p = at.as_tensor_variable(floatX(p))
+                p_ = p_ / at.sum(p_, axis=-1, keepdims=True)
+                p = at.as_tensor_variable(p_)
         return super().dist([p], **kwargs)
 
     def moment(rv, size, p):
@@ -1341,7 +1346,11 @@ def logp(value, p):
         )
 
         return check_parameters(
-            res, at.all(p_ >= 0, axis=-1), at.all(p <= 1, axis=-1), msg="0 <= p <=1"
+            res,
+            p_ >= 0,
+            p_ <= 1,
+            at.isclose(at.sum(p, axis=-1), 1),
+            msg="0 <= p <=1, sum(p) = 1",
         )
 
 
diff --git a/pymc/distributions/multivariate.py b/pymc/distributions/multivariate.py
index 2f21d9e94..2191d3f55 100644
--- a/pymc/distributions/multivariate.py
+++ b/pymc/distributions/multivariate.py
@@ -30,7 +30,7 @@
 from aesara.graph.op import Op
 from aesara.raise_op import Assert
 from aesara.sparse.basic import sp_sum
-from aesara.tensor import gammaln, sigmoid
+from aesara.tensor import TensorConstant, gammaln, sigmoid
 from aesara.tensor.nlinalg import det, eigh, matrix_inverse, trace
 from aesara.tensor.random.basic import dirichlet, multinomial, multivariate_normal
 from aesara.tensor.random.op import RandomVariable, default_supp_shape_from_params
@@ -543,16 +543,21 @@ class Multinomial(Discrete):
     @classmethod
     def dist(cls, n, p, *args, **kwargs):
-        if isinstance(p, np.ndarray) or isinstance(p, list):
-            if (np.asarray(p) < 0).any():
-                raise ValueError(f"Negative `p` parameters are not valid, got: {p}")
-            p_sum = np.sum([p], axis=-1)
-            if not np.all(np.isclose(p_sum, 1.0)):
+        p = at.as_tensor_variable(p)
+        if isinstance(p, TensorConstant):
+            p_ = np.asarray(p.data)
+            if np.any(p_ < 0):
+                raise ValueError(f"Negative `p` parameters are not valid, got: {p_}")
+            p_sum_ = np.sum([p_], axis=-1)
+            if not np.all(np.isclose(p_sum_, 1.0)):
                 warnings.warn(
-                    f"`p` parameters sum up to {p_sum}, instead of 1.0. They will be automatically rescaled. You can rescale them directly to get rid of this warning.",
+                    f"`p` parameters sum to {p_sum_}, instead of 1.0. "
+                    "They will be automatically rescaled. "
" + "You can rescale them directly to get rid of this warning.", UserWarning, ) - p = p / at.sum(p, axis=-1, keepdims=True) + p_ = p_ / at.sum(p_, axis=-1, keepdims=True) + p = at.as_tensor_variable(p_) n = at.as_tensor_variable(n) p = at.as_tensor_variable(p) return super().dist([n, p], *args, **kwargs) @@ -591,10 +596,11 @@ def logp(value, n, p): ) return check_parameters( res, + p >= 0, p <= 1, at.isclose(at.sum(p, axis=-1), 1), at.ge(n, 0), - msg="p <= 1, sum(p) = 1, n >= 0", + msg="0 <= p <= 1, sum(p) = 1, n >= 0", ) diff --git a/pymc/tests/distributions/test_discrete.py b/pymc/tests/distributions/test_discrete.py index 43fc62e4b..aedfb2925 100644 --- a/pymc/tests/distributions/test_discrete.py +++ b/pymc/tests/distributions/test_discrete.py @@ -469,6 +469,15 @@ def logcdf_fn(value, psi, n, p): {"n": NatSmall, "p": Unit, "psi": Unit}, ) + @pytest.mark.parametrize("n", [2, 3, 4]) + def test_categorical(self, n): + check_logp( + pm.Categorical, + Domain(range(n), dtype="int64", edges=(0, n)), + {"p": Simplex(n)}, + lambda value, p: categorical_logpdf(value, p), + ) + @aesara.config.change_flags(compute_test_value="raise") def test_categorical_bounds(self): with pm.Model(): @@ -488,42 +497,40 @@ def test_categorical_bounds(self): # entries if there is a single or pair number of negative values # and the rest are zero np.array([-1, -1, 0, 0]), + at.as_tensor_variable([-1, -1, 0, 0]), ], ) def test_categorical_negative_p(self, p): - with pytest.raises(ValueError, match=f"{p}"): + with pytest.raises(ValueError, match="Negative `p` parameters are not valid"): with pm.Model(): x = pm.Categorical("x", p=p) - def test_categorical_negative_p_symbolic(self): - with pytest.raises(ParameterValueError): - value = np.array([[1, 1, 1]]) - invalid_dist = pm.Categorical.dist(p=at.as_tensor_variable([-1, 0.5, 0.5])) - pm.logp(invalid_dist, value).eval() - - def test_categorical_p_not_normalized_symbolic(self): - with pytest.raises(ParameterValueError): - value = np.array([[1, 1, 1]]) - invalid_dist = pm.Categorical.dist(p=at.as_tensor_variable([2, 2, 2])) - pm.logp(invalid_dist, value).eval() - - @pytest.mark.parametrize("n", [2, 3, 4]) - def test_categorical(self, n): - check_logp( - pm.Categorical, - Domain(range(n), dtype="int64", edges=(0, n)), - {"p": Simplex(n)}, - lambda value, p: categorical_logpdf(value, p), - ) - def test_categorical_p_not_normalized(self): # test UserWarning is raised for p vals that sum to more than 1 # and normaliation is triggered - with pytest.warns(UserWarning, match="[5]"): + with pytest.warns(UserWarning, match="They will be automatically rescaled"): with pm.Model() as m: x = pm.Categorical("x", p=[1, 1, 1, 1, 1]) assert np.isclose(m.x.owner.inputs[3].sum().eval(), 1.0) + def test_categorical_negative_p_symbolic(self): + value = np.array([[1, 1, 1]]) + + x = at.scalar("x") + invalid_dist = pm.Categorical.dist(p=[x, x, x]) + + with pytest.raises(ParameterValueError): + pm.logp(invalid_dist, value).eval({x: -1 / 3}) + + def test_categorical_p_not_normalized_symbolic(self): + value = np.array([[1, 1, 1]]) + + x = at.scalar("x") + invalid_dist = pm.Categorical.dist(p=(x, x, x)) + + with pytest.raises(ParameterValueError): + pm.logp(invalid_dist, value).eval({x: 0.5}) + @pytest.mark.parametrize("n", [2, 3, 4]) def test_orderedlogistic(self, n): with warnings.catch_warnings(): diff --git a/pymc/tests/distributions/test_mixture.py b/pymc/tests/distributions/test_mixture.py index 113f3e37b..4fba88e5a 100644 --- a/pymc/tests/distributions/test_mixture.py +++ 
@@ -1009,12 +1009,12 @@ def setup_class(cls):
     @pytest.mark.parametrize("batch_shape", [(3, 4), (20,)], ids=str)
     def test_with_multinomial(self, batch_shape):
         p = np.random.uniform(size=(*batch_shape, self.mixture_comps, 3))
+        p /= p.sum(axis=-1, keepdims=True)
         n = 100 * np.ones((*batch_shape, 1))
         w = np.ones(self.mixture_comps) / self.mixture_comps
         mixture_axis = len(batch_shape)
         with Model() as model:
-            with pytest.warns(UserWarning, match="parameters sum up to"):
-                comp_dists = Multinomial.dist(p=p, n=n, shape=(*batch_shape, self.mixture_comps, 3))
+            comp_dists = Multinomial.dist(p=p, n=n, shape=(*batch_shape, self.mixture_comps, 3))
             mixture = Mixture(
                 "mixture",
                 w=w,
diff --git a/pymc/tests/distributions/test_multivariate.py b/pymc/tests/distributions/test_multivariate.py
index d023912ce..eb8574bbb 100644
--- a/pymc/tests/distributions/test_multivariate.py
+++ b/pymc/tests/distributions/test_multivariate.py
@@ -548,14 +548,14 @@ def test_multinomial_invalid_value(self):
 
     def test_multinomial_negative_p(self):
         # test passing a list/numpy with negative p raises an immediate error
-        with pytest.raises(ValueError, match="[-1, 1, 1]"):
+        with pytest.raises(ValueError, match="Negative `p` parameters are not valid"):
             with pm.Model() as model:
                 x = pm.Multinomial("x", n=5, p=[-1, 1, 1])
 
     def test_multinomial_p_not_normalized(self):
         # test UserWarning is raised for p vals that sum to more than 1
         # and normaliation is triggered
-        with pytest.warns(UserWarning, match="[5]"):
+        with pytest.warns(UserWarning, match="They will be automatically rescaled"):
             with pm.Model() as m:
                 x = pm.Multinomial("x", n=5, p=[1, 1, 1, 1, 1])
             # test stored p-vals have been normalised
@@ -564,18 +564,23 @@ def test_multinomial_p_not_normalized(self):
     def test_multinomial_negative_p_symbolic(self):
         # Passing symbolic negative p does not raise an immediate error, but evaluating
         # logp raises a ParameterValueError
+        value = np.array([[1, 1, 1]])
+
+        x = at.scalar("x")
+        invalid_dist = pm.Multinomial.dist(n=1, p=[x, x, x])
+
         with pytest.raises(ParameterValueError):
-            value = np.array([[1, 1, 1]])
-            invalid_dist = pm.Multinomial.dist(n=1, p=at.as_tensor_variable([-1, 0.5, 0.5]))
-            pm.logp(invalid_dist, value).eval()
+            pm.logp(invalid_dist, value).eval({x: -1 / 3})
 
     def test_multinomial_p_not_normalized_symbolic(self):
         # Passing symbolic p that do not add up to on does not raise any warning, but evaluating
         # logp raises a ParameterValueError
+        value = np.array([[1, 1, 1]])
+
+        x = at.scalar("x")
+        invalid_dist = pm.Multinomial.dist(n=1, p=(x, x, x))
         with pytest.raises(ParameterValueError):
-            value = np.array([[1, 1, 1]])
-            invalid_dist = pm.Multinomial.dist(n=1, p=at.as_tensor_variable([1, 0.5, 0.5]))
-            pm.logp(invalid_dist, value).eval()
+            pm.logp(invalid_dist, value).eval({x: 0.5})
 
     @pytest.mark.parametrize("n", [(10), ([10, 11]), ([[5, 6], [10, 11]])])
     @pytest.mark.parametrize(
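
For reference, a minimal sketch of the behavior this patch enforces, not part of the diff itself. It assumes this PyMC revision; the exception raised at logp evaluation (ParameterValueError, per the tests above) is caught generically here to keep the sketch import-light, and the error/warning texts in the comments are abbreviated.

import aesara.tensor as at
import numpy as np
import pymc as pm

# Constant `p` becomes a TensorConstant, so `dist` can validate it eagerly:
try:
    pm.Categorical.dist(p=[-1, 0.5, 0.5])
except ValueError as err:
    print(err)  # Negative `p` parameters are not valid, got: ...

# A constant `p` that does not sum to 1 warns and is rescaled at build time:
pm.Categorical.dist(p=[1, 1, 1, 1, 1])  # UserWarning: ... automatically rescaled ...

# Symbolic `p` cannot be inspected when the distribution is built; the same
# checks are instead deferred to logp evaluation via check_parameters:
x = at.scalar("x")
invalid_dist = pm.Categorical.dist(p=[x, x, x])
value = np.array([[1, 1, 1]])
try:
    pm.logp(invalid_dist, value).eval({x: -1 / 3})
except Exception as err:  # ParameterValueError: 0 <= p <=1, sum(p) = 1
    print(err)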