Skip to content
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 8 additions & 1 deletion pymc/distributions/multivariate.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@
from pymc.distributions.continuous import ChiSquared, Normal, assert_negative_support
from pymc.distributions.dist_math import bound, factln, logpow, multigammaln
from pymc.distributions.distribution import Continuous, Discrete
from pymc.distributions.shape_utils import broadcast_dist_samples_to, to_tuple
from pymc.distributions.shape_utils import broadcast_dist_samples_to, rv_size_is_none, to_tuple
from pymc.math import kron_diag, kron_dot

__all__ = [
Expand Down Expand Up @@ -405,6 +405,13 @@ def dist(cls, a, **kwargs):

return super().dist([a], **kwargs)

def get_moment(rv, size, a):
norm_constant = at.sum(a, axis=-1)[..., None]
moment = a/norm_constant
if not rv_size_is_none(size):
return at.full(size, moment)
return moment

def logp(value, a):
"""
Calculate log-probability of Dirichlet distribution
Expand Down
25 changes: 25 additions & 0 deletions pymc/tests/test_distributions_moments.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
Constant,
DiscreteUniform,
ExGaussian,
Dirichlet,
Exponential,
Flat,
Gamma,
Expand Down Expand Up @@ -611,4 +612,28 @@ def test_hyper_geometric_moment(N, k, n, size, expected):
def test_discrete_uniform_moment(lower, upper, size, expected):
with Model() as model:
DiscreteUniform("x", lower=lower, upper=upper, size=size)

@pytest.mark.parametrize(
"a, size, expected",
[
(
np.array([2, 3, 5, 7, 11]),
None,
np.array([2, 3, 5, 7, 11])/28,
),
(
np.array([[1, 2, 3], [5, 6, 7]]),
None,
np.array([[1, 2, 3], [5, 6, 7]])/np.array([6, 18])[..., np.newaxis],
),
(
np.full(shape=np.array([7, 3]), fill_value=np.array([13, 17, 19])),
(11, 5,),
np.broadcast_to([13, 17, 19], shape=[11, 5, 7, 3]),
),
]
)
def test_dirichlet_moment(a, size, expected):
with Model() as model:
Dirichlet("x", a=a, size=size)
assert_moment_is_expected(model, expected)