Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
27 changes: 18 additions & 9 deletions pymc/distributions/discrete.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
import aesara.tensor as at
import numpy as np

from aesara.tensor import TensorConstant
from aesara.tensor.random.basic import (
RandomVariable,
ScipyRandomVariable,
Expand Down Expand Up @@ -1285,17 +1286,21 @@ def dist(cls, p=None, logit_p=None, **kwargs):
if logit_p is not None:
p = pm.math.softmax(logit_p, axis=-1)

if isinstance(p, np.ndarray) or isinstance(p, list):
if (np.asarray(p) < 0).any():
raise ValueError(f"Negative `p` parameters are not valid, got: {p}")
p_sum = np.sum([p], axis=-1)
if not np.all(np.isclose(p_sum, 1.0)):
p = at.as_tensor_variable(p)
if isinstance(p, TensorConstant):
p_ = np.asarray(p.data)
if np.any(p_ < 0):
raise ValueError(f"Negative `p` parameters are not valid, got: {p_}")
p_sum_ = np.sum([p_], axis=-1)
if not np.all(np.isclose(p_sum_, 1.0)):
warnings.warn(
f"`p` parameters sum to {p_sum}, instead of 1.0. They will be automatically rescaled. You can rescale them directly to get rid of this warning.",
f"`p` parameters sum to {p_sum_}, instead of 1.0. "
"They will be automatically rescaled. "
"You can rescale them directly to get rid of this warning.",
UserWarning,
)
p = p / at.sum(p, axis=-1, keepdims=True)
p = at.as_tensor_variable(floatX(p))
p_ = p_ / at.sum(p_, axis=-1, keepdims=True)
p = at.as_tensor_variable(p_)
return super().dist([p], **kwargs)

def moment(rv, size, p):
Expand Down Expand Up @@ -1341,7 +1346,11 @@ def logp(value, p):
)

return check_parameters(
res, at.all(p_ >= 0, axis=-1), at.all(p <= 1, axis=-1), msg="0 <= p <=1"
res,
p_ >= 0,
p_ <= 1,
at.isclose(at.sum(p, axis=-1), 1),
msg="0 <= p <=1, sum(p) = 1",
)


Expand Down
24 changes: 15 additions & 9 deletions pymc/distributions/multivariate.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@
from aesara.graph.op import Op
from aesara.raise_op import Assert
from aesara.sparse.basic import sp_sum
from aesara.tensor import gammaln, sigmoid
from aesara.tensor import TensorConstant, gammaln, sigmoid
from aesara.tensor.nlinalg import det, eigh, matrix_inverse, trace
from aesara.tensor.random.basic import dirichlet, multinomial, multivariate_normal
from aesara.tensor.random.op import RandomVariable, default_supp_shape_from_params
Expand Down Expand Up @@ -543,16 +543,21 @@ class Multinomial(Discrete):

@classmethod
def dist(cls, n, p, *args, **kwargs):
if isinstance(p, np.ndarray) or isinstance(p, list):
if (np.asarray(p) < 0).any():
raise ValueError(f"Negative `p` parameters are not valid, got: {p}")
p_sum = np.sum([p], axis=-1)
if not np.all(np.isclose(p_sum, 1.0)):
p = at.as_tensor_variable(p)
if isinstance(p, TensorConstant):
p_ = np.asarray(p.data)
if np.any(p_ < 0):
raise ValueError(f"Negative `p` parameters are not valid, got: {p_}")
p_sum_ = np.sum([p_], axis=-1)
if not np.all(np.isclose(p_sum_, 1.0)):
warnings.warn(
f"`p` parameters sum up to {p_sum}, instead of 1.0. They will be automatically rescaled. You can rescale them directly to get rid of this warning.",
f"`p` parameters sum to {p_sum_}, instead of 1.0. "
"They will be automatically rescaled. "
"You can rescale them directly to get rid of this warning.",
UserWarning,
)
p = p / at.sum(p, axis=-1, keepdims=True)
p_ = p_ / at.sum(p_, axis=-1, keepdims=True)
p = at.as_tensor_variable(p_)
n = at.as_tensor_variable(n)
p = at.as_tensor_variable(p)
return super().dist([n, p], *args, **kwargs)
Expand Down Expand Up @@ -591,10 +596,11 @@ def logp(value, n, p):
)
return check_parameters(
res,
p >= 0,
p <= 1,
at.isclose(at.sum(p, axis=-1), 1),
at.ge(n, 0),
msg="p <= 1, sum(p) = 1, n >= 0",
msg="0 <= p <= 1, sum(p) = 1, n >= 0",
)


Expand Down
53 changes: 30 additions & 23 deletions pymc/tests/distributions/test_discrete.py
Original file line number Diff line number Diff line change
Expand Up @@ -469,6 +469,15 @@ def logcdf_fn(value, psi, n, p):
{"n": NatSmall, "p": Unit, "psi": Unit},
)

@pytest.mark.parametrize("n", [2, 3, 4])
def test_categorical(self, n):
check_logp(
pm.Categorical,
Domain(range(n), dtype="int64", edges=(0, n)),
{"p": Simplex(n)},
lambda value, p: categorical_logpdf(value, p),
)

@aesara.config.change_flags(compute_test_value="raise")
def test_categorical_bounds(self):
with pm.Model():
Expand All @@ -488,42 +497,40 @@ def test_categorical_bounds(self):
# entries if there is a single or pair number of negative values
# and the rest are zero
np.array([-1, -1, 0, 0]),
at.as_tensor_variable([-1, -1, 0, 0]),
],
)
def test_categorical_negative_p(self, p):
with pytest.raises(ValueError, match=f"{p}"):
with pytest.raises(ValueError, match="Negative `p` parameters are not valid"):
with pm.Model():
x = pm.Categorical("x", p=p)

def test_categorical_negative_p_symbolic(self):
with pytest.raises(ParameterValueError):
value = np.array([[1, 1, 1]])
invalid_dist = pm.Categorical.dist(p=at.as_tensor_variable([-1, 0.5, 0.5]))
pm.logp(invalid_dist, value).eval()

def test_categorical_p_not_normalized_symbolic(self):
with pytest.raises(ParameterValueError):
value = np.array([[1, 1, 1]])
invalid_dist = pm.Categorical.dist(p=at.as_tensor_variable([2, 2, 2]))
pm.logp(invalid_dist, value).eval()

@pytest.mark.parametrize("n", [2, 3, 4])
def test_categorical(self, n):
check_logp(
pm.Categorical,
Domain(range(n), dtype="int64", edges=(0, n)),
{"p": Simplex(n)},
lambda value, p: categorical_logpdf(value, p),
)

def test_categorical_p_not_normalized(self):
# test UserWarning is raised for p vals that sum to more than 1
# and normalization is triggered
with pytest.warns(UserWarning, match="[5]"):
with pytest.warns(UserWarning, match="They will be automatically rescaled"):
with pm.Model() as m:
x = pm.Categorical("x", p=[1, 1, 1, 1, 1])
assert np.isclose(m.x.owner.inputs[3].sum().eval(), 1.0)

def test_categorical_negative_p_symbolic(self):
value = np.array([[1, 1, 1]])

x = at.scalar("x")
invalid_dist = pm.Categorical.dist(p=[x, x, x])

with pytest.raises(ParameterValueError):
pm.logp(invalid_dist, value).eval({x: -1 / 3})

def test_categorical_p_not_normalized_symbolic(self):
value = np.array([[1, 1, 1]])

x = at.scalar("x")
invalid_dist = pm.Categorical.dist(p=(x, x, x))

with pytest.raises(ParameterValueError):
pm.logp(invalid_dist, value).eval({x: 0.5})

@pytest.mark.parametrize("n", [2, 3, 4])
def test_orderedlogistic(self, n):
with warnings.catch_warnings():
Expand Down
4 changes: 2 additions & 2 deletions pymc/tests/distributions/test_mixture.py
Original file line number Diff line number Diff line change
Expand Up @@ -1009,12 +1009,12 @@ def setup_class(cls):
@pytest.mark.parametrize("batch_shape", [(3, 4), (20,)], ids=str)
def test_with_multinomial(self, batch_shape):
p = np.random.uniform(size=(*batch_shape, self.mixture_comps, 3))
p /= p.sum(axis=-1, keepdims=True)
n = 100 * np.ones((*batch_shape, 1))
w = np.ones(self.mixture_comps) / self.mixture_comps
mixture_axis = len(batch_shape)
with Model() as model:
with pytest.warns(UserWarning, match="parameters sum up to"):
comp_dists = Multinomial.dist(p=p, n=n, shape=(*batch_shape, self.mixture_comps, 3))
comp_dists = Multinomial.dist(p=p, n=n, shape=(*batch_shape, self.mixture_comps, 3))
mixture = Mixture(
"mixture",
w=w,
Expand Down
21 changes: 13 additions & 8 deletions pymc/tests/distributions/test_multivariate.py
Original file line number Diff line number Diff line change
Expand Up @@ -548,14 +548,14 @@ def test_multinomial_invalid_value(self):

def test_multinomial_negative_p(self):
# test passing a list/numpy with negative p raises an immediate error
with pytest.raises(ValueError, match="[-1, 1, 1]"):
with pytest.raises(ValueError, match="Negative `p` parameters are not valid"):
with pm.Model() as model:
x = pm.Multinomial("x", n=5, p=[-1, 1, 1])

def test_multinomial_p_not_normalized(self):
# test UserWarning is raised for p vals that sum to more than 1
# and normalization is triggered
with pytest.warns(UserWarning, match="[5]"):
with pytest.warns(UserWarning, match="They will be automatically rescaled"):
with pm.Model() as m:
x = pm.Multinomial("x", n=5, p=[1, 1, 1, 1, 1])
# test stored p-vals have been normalised
Expand All @@ -564,18 +564,23 @@ def test_multinomial_p_not_normalized(self):
def test_multinomial_negative_p_symbolic(self):
# Passing symbolic negative p does not raise an immediate error, but evaluating
# logp raises a ParameterValueError
value = np.array([[1, 1, 1]])

x = at.scalar("x")
invalid_dist = pm.Multinomial.dist(n=1, p=[x, x, x])

with pytest.raises(ParameterValueError):
value = np.array([[1, 1, 1]])
invalid_dist = pm.Multinomial.dist(n=1, p=at.as_tensor_variable([-1, 0.5, 0.5]))
pm.logp(invalid_dist, value).eval()
pm.logp(invalid_dist, value).eval({x: -1 / 3})

def test_multinomial_p_not_normalized_symbolic(self):
# Passing symbolic p that do not add up to one does not raise any warning, but evaluating
# logp raises a ParameterValueError
value = np.array([[1, 1, 1]])

x = at.scalar("x")
invalid_dist = pm.Multinomial.dist(n=1, p=(x, x, x))
with pytest.raises(ParameterValueError):
value = np.array([[1, 1, 1]])
invalid_dist = pm.Multinomial.dist(n=1, p=at.as_tensor_variable([1, 0.5, 0.5]))
pm.logp(invalid_dist, value).eval()
pm.logp(invalid_dist, value).eval({x: 0.5})

@pytest.mark.parametrize("n", [(10), ([10, 11]), ([[5, 6], [10, 11]])])
@pytest.mark.parametrize(
Expand Down