Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 3 additions & 1 deletion pymc/distributions/distribution.py
Original file line number Diff line number Diff line change
Expand Up @@ -1630,7 +1630,9 @@ def partial_observed_rv_logprob(op, values, dist, mask, **kwargs):
# For the logp, simply join the values
[obs_value, unobs_value] = values
antimask = ~mask
joined_value = pt.empty(constant_fold([dist.shape])[0])
# We don't need it to be completely folded, just to avoid any RVs in the graph of the shape
[folded_shape] = constant_fold([dist.shape], raise_not_constant=False)
joined_value = pt.empty(folded_shape)
joined_value = pt.set_subtensor(joined_value[mask], unobs_value)
joined_value = pt.set_subtensor(joined_value[antimask], obs_value)
joined_logp = logp(dist, joined_value)
Expand Down
30 changes: 25 additions & 5 deletions tests/distributions/test_distribution.py
Original file line number Diff line number Diff line change
Expand Up @@ -979,16 +979,21 @@ def test_univariate(self, symbolic_rv):
np.testing.assert_allclose(obs_logp, st.norm([1, 2]).logpdf([0.25, 0.5]))
np.testing.assert_allclose(unobs_logp, st.norm([3]).logpdf([0.25]))

@pytest.mark.parametrize("mutable_shape", (False, True))
@pytest.mark.parametrize("obs_component_selected", (True, False))
def test_multivariate_constant_mask_separable(self, obs_component_selected):
def test_multivariate_constant_mask_separable(self, obs_component_selected, mutable_shape):
if obs_component_selected:
mask = np.zeros((1, 4), dtype=bool)
else:
mask = np.ones((1, 4), dtype=bool)
obs_data = np.array([[0.1, 0.4, 0.1, 0.4]])
unobs_data = np.array([[0.4, 0.1, 0.4, 0.1]])

rv = pm.Dirichlet.dist([1, 2, 3, 4], shape=(1, 4))
if mutable_shape:
shape = (1, pytensor.shared(np.array(4, dtype=int)))
else:
shape = (1, 4)
rv = pm.Dirichlet.dist(pt.arange(shape[-1]) + 1, shape=shape)
(obs_rv, obs_mask), (unobs_rv, unobs_mask), joined_rv = create_partial_observed_rv(rv, mask)

# Test types
Expand Down Expand Up @@ -1023,6 +1028,10 @@ def test_multivariate_constant_mask_separable(self, obs_component_selected):
np.testing.assert_allclose(obs_logp, expected_obs_logp)
np.testing.assert_allclose(unobs_logp, expected_unobs_logp)

if mutable_shape:
shape[-1].set_value(7)
assert tuple(joined_rv.shape.eval()) == (1, 7)

def test_multivariate_constant_mask_unseparable(self):
mask = pt.constant(np.array([[True, True, False, False]]))
obs_data = np.array([[0.1, 0.4, 0.1, 0.4]])
Expand Down Expand Up @@ -1097,14 +1106,19 @@ def test_multivariate_shared_mask_separable(self):
np.testing.assert_almost_equal(obs_logp, new_expected_logp)
np.testing.assert_array_equal(unobs_logp, [])

def test_multivariate_shared_mask_unseparable(self):
@pytest.mark.parametrize("mutable_shape", (False, True))
def test_multivariate_shared_mask_unseparable(self, mutable_shape):
# Even if the mask is initially not mixing support dims,
# it could later be changed in a way that does!
mask = shared(np.array([[True, True, True, True]]))
obs_data = np.array([[0.1, 0.4, 0.1, 0.4]])
unobs_data = np.array([[0.4, 0.1, 0.4, 0.1]])

rv = pm.Dirichlet.dist([1, 2, 3, 4], shape=(1, 4))
if mutable_shape:
shape = mask.shape
else:
shape = (1, 4)
rv = pm.Dirichlet.dist([1, 2, 3, 4], shape=shape)
(obs_rv, obs_mask), (unobs_rv, unobs_mask), joined_rv = create_partial_observed_rv(rv, mask)

# Test types
Expand Down Expand Up @@ -1134,16 +1148,22 @@ def test_multivariate_shared_mask_unseparable(self):

# Test that we can update a shared mask
mask.set_value(np.array([[False, False, True, True]]))
equivalent_value = np.array([0.1, 0.4, 0.4, 0.1])

assert tuple(obs_rv.shape.eval()) == (2,)
assert tuple(unobs_rv.shape.eval()) == (2,)

new_expected_logp = pm.logp(rv, [0.1, 0.4, 0.4, 0.1]).eval()
new_expected_logp = pm.logp(rv, equivalent_value).eval()
assert not np.isclose(expected_logp, new_expected_logp) # Otherwise test is weak
obs_logp, unobs_logp = logp_fn()
np.testing.assert_almost_equal(obs_logp, new_expected_logp)
np.testing.assert_array_equal(unobs_logp, [])

if mutable_shape:
mask.set_value(np.array([[False, False, True, False], [False, False, False, True]]))
assert tuple(obs_rv.shape.eval()) == (6,)
assert tuple(unobs_rv.shape.eval()) == (2,)

def test_support_point(self):
x = pm.GaussianRandomWalk.dist(init_dist=pm.Normal.dist(-5), mu=1, steps=9)
ref_support_point = support_point(x).eval()
Expand Down
37 changes: 33 additions & 4 deletions tests/model/test_core.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@
import pytensor.sparse as sparse
import pytensor.tensor as pt
import pytest
import scipy
import scipy.sparse as sps
import scipy.stats as st

Expand All @@ -38,7 +39,7 @@

import pymc as pm

from pymc import Deterministic, Model, Potential
from pymc import Deterministic, Model, MvNormal, Potential
from pymc.blocking import DictToArrayBijection, RaveledVars
from pymc.distributions import Normal, transforms
from pymc.distributions.distribution import PartialObservedRV
Expand Down Expand Up @@ -1504,11 +1505,39 @@ def test_truncated_normal(self):
"""
with Model() as m:
mu = pm.TruncatedNormal("mu", mu=1, sigma=2, lower=0)
x = pm.TruncatedNormal(
"x", mu=mu, sigma=0.5, lower=0, observed=np.array([0.1, 0.2, 0.5, np.nan, np.nan])
)
with pytest.warns(ImputationWarning):
x = pm.TruncatedNormal(
"x",
mu=mu,
sigma=0.5,
lower=0,
observed=np.array([0.1, 0.2, 0.5, np.nan, np.nan]),
)
m.check_start_vals(m.initial_point())

def test_coordinates(self):
# Regression test for https://github.com/pymc-devs/pymc/issues/7304

coords = {"trial": range(30), "feature": range(2)}
observed = np.zeros((30, 2))
observed[0, 0] = np.nan

with Model(coords=coords) as model:
with pytest.warns(ImputationWarning):
MvNormal(
"y",
mu=np.zeros(2),
cov=np.eye(2),
observed=observed,
dims=("trial", "feature"),
)

logp_fn = model.compile_logp()
np.testing.assert_allclose(
logp_fn({"y_unobserved": [0]}),
scipy.stats.multivariate_normal.logpdf([0, 0], cov=np.eye(2)) * 30,
)


class TestShared:
def test_deterministic(self):
Expand Down