Skip to content
7 changes: 7 additions & 0 deletions pymc/distributions/multivariate.py
Original file line number Diff line number Diff line change
Expand Up @@ -228,6 +228,13 @@ def dist(cls, mu, cov=None, tau=None, chol=None, lower=True, **kwargs):
cov = quaddist_matrix(cov, chol, tau, lower)
return super().dist([mu, cov], **kwargs)

def get_moment(rv, size, mu, cov):
moment = mu
if not rv_size_is_none(size):
moment_size = at.concatenate([size, mu.shape])
moment = at.full(moment_size, mu)
return moment

def logp(value, mu, cov):
"""
Calculate log-probability of Multivariate Normal distribution
Expand Down
35 changes: 35 additions & 0 deletions pymc/tests/test_distributions_moments.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,7 @@
ZeroInflatedBinomial,
ZeroInflatedPoisson,
)
from pymc.distributions.multivariate import MvNormal
from pymc.distributions.shape_utils import rv_size_is_none
from pymc.initial_point import make_initial_point_fn
from pymc.model import Model
Expand Down Expand Up @@ -751,6 +752,40 @@ def test_categorical_moment(p, size, expected):
assert_moment_is_expected(model, expected)


@pytest.mark.parametrize(
"mu, cov, size, expected",
[
(np.ones(1), np.identity(1), None, np.ones(1)),
(np.ones(3), np.identity(3), None, np.ones(3)),
(np.ones((2, 2)), np.identity(2), None, np.ones((2, 2))),
(np.array([1, 0, 3.0]), np.identity(3), None, np.array([1, 0, 3.0])),
(np.array([1, 0, 3.0]), np.identity(3), (4, 2), np.full((4, 2, 3), [1, 0, 3.0])),
(
np.array([1, 3.0]),
np.identity(2),
5,
np.full((5, 2), [1, 3.0]),
),
(
np.array([1, 3.0]),
np.array([[1.0, 0.5], [0.5, 2]]),
(4, 5),
np.full((4, 5, 2), [1, 3.0]),
),
(
np.array([[3.0, 5], [1, 4]]),
np.identity(2),
(4, 5),
np.full((4, 5, 2, 2), [[3.0, 5], [1, 4]]),
),
],
)
def test_mv_normal_moment(mu, cov, size, expected):
with Model() as model:
MvNormal("x", mu=mu, cov=cov, size=size)
assert_moment_is_expected(model, expected)


@pytest.mark.parametrize(
"mu, sigma, size, expected",
[
Expand Down