Skip to content
3 changes: 3 additions & 0 deletions pymc/distributions/multivariate.py
Original file line number Diff line number Diff line change
Expand Up @@ -224,6 +224,9 @@ def dist(cls, mu, cov=None, tau=None, chol=None, lower=True, **kwargs):
cov = quaddist_matrix(cov, chol, tau, lower)
return super().dist([mu, cov], **kwargs)

def get_moment(rv, size, mu, cov):
return mu

def logp(value, mu, cov):
"""
Calculate log-probability of Multivariate Normal distribution
Expand Down
14 changes: 14 additions & 0 deletions pymc/tests/test_distributions_moments.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@
ZeroInflatedBinomial,
ZeroInflatedPoisson,
)
from pymc.distributions.multivariate import MvNormal
from pymc.distributions.shape_utils import rv_size_is_none
from pymc.initial_point import make_initial_point_fn
from pymc.model import Model
Expand Down Expand Up @@ -595,3 +596,16 @@ def test_discrete_uniform_moment(lower, upper, size, expected):
with Model() as model:
DiscreteUniform("x", lower=lower, upper=upper, size=size)
assert_moment_is_expected(model, expected)


@pytest.mark.parametrize(
"mu, cov, size, expected",
[
(np.array([1.]), np.array([[1.]]), None, np.array([1.])),
(np.ones((10, )), np.identity(10), None, np.ones((10, )))
]
)
def test_mv_normal_moment(mu, cov, size, expected):
with Model() as model:
MvNormal("x", mu=mu, cov=cov, size=size)
assert_moment_is_expected(model, expected)