Skip to content
Open
Show file tree
Hide file tree
Changes from 11 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 3 additions & 1 deletion numpyro/handlers.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,7 +83,7 @@
import jax.numpy as jnp

import numpyro
from numpyro.distributions.distribution import COERCIONS
from numpyro.distributions.distribution import COERCIONS, ExpandedDistribution
from numpyro.primitives import _PYRO_STACK, Messenger, apply_stack, plate
from numpyro.util import not_jax_tracer

Expand Down Expand Up @@ -268,6 +268,8 @@ def process_message(self, msg):
if msg["type"] == "sample":
if msg["value"] is None:
msg["value"] = msg["name"]
if isinstance(msg["fn"], ExpandedDistribution):
msg["fn"] = msg["fn"].base_dist

if isinstance(msg["fn"], Funsor) or isinstance(msg["value"], (str, Funsor)):
msg["stop"] = True
Expand Down
4 changes: 3 additions & 1 deletion numpyro/primitives.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
import jax.numpy as jnp

import numpyro
from numpyro.distributions.distribution import Distribution
from numpyro.util import identity

_PYRO_STACK = []
Expand Down Expand Up @@ -316,7 +317,8 @@ def process_message(self, msg):
cond_indep_stack = msg['cond_indep_stack']
frame = CondIndepStackFrame(self.name, self.dim, self.subsample_size)
cond_indep_stack.append(frame)
if msg['type'] == 'sample':
# only expand if fn is Distribution, not a Funsor
if msg['type'] == 'sample' and isinstance(msg['fn'], Distribution):
expected_shape = self._get_batch_shape(cond_indep_stack)
dist_batch_shape = msg['fn'].batch_shape
if 'sample_shape' in msg['kwargs']:
Expand Down
86 changes: 85 additions & 1 deletion test/test_handlers.py
Original file line number Diff line number Diff line change
Expand Up @@ -569,7 +569,6 @@ def guide():
svi.update(svi_state)


@pytest.mark.xfail(reason="missing pattern in Funsor")
def test_collapse_beta_binomial_plate():
data = np.array([0., 1., 5., 5.])

Expand All @@ -591,6 +590,91 @@ def guide():
svi.update(svi_state)


def test_collapse_normal_normal():
data = np.array(0.)

def model():
x = numpyro.sample("x", dist.Normal(0, 1))
with handlers.collapse():
y = numpyro.sample("y", dist.Normal(x, 1.))
numpyro.sample("z", dist.Normal(y, 1.), obs=data)

def guide():
loc = numpyro.param("loc", 0.)
scale = numpyro.param("scale", 1., constraint=constraints.positive)
numpyro.sample("x", dist.Normal(loc, scale))

svi = SVI(model, guide, numpyro.optim.Adam(1), Trace_ELBO())
svi_state = svi.init(random.PRNGKey(0))
svi.update(svi_state)


def test_collapse_normal_normal_plate():
data = np.arange(5.)

def model():
x = numpyro.sample("x", dist.Normal(0, 1))
with handlers.collapse():
y = numpyro.sample("y", dist.Normal(x, 1.))
with handlers.plate("data", len(data)):
numpyro.sample("z", dist.Normal(y, 1.), obs=data)

def guide():
loc = numpyro.param("loc", 0.)
scale = numpyro.param("scale", 1., constraint=constraints.positive)
numpyro.sample("x", dist.Normal(loc, scale))

svi = SVI(model, guide, numpyro.optim.Adam(1), Trace_ELBO())
svi_state = svi.init(random.PRNGKey(0))
svi.update(svi_state)


def test_collapse_normal_plate_normal():
data = np.arange(5.)

def model():
x = numpyro.sample("x", dist.Normal(0, 1))
with handlers.collapse():
with handlers.plate("data", len(data)):
y = numpyro.sample("y", dist.Normal(x, 1.))
numpyro.sample("z", dist.Normal(y, 1.), obs=data)

def guide():
loc = numpyro.param("loc", 0.)
scale = numpyro.param("scale", 1., constraint=constraints.positive)
numpyro.sample("x", dist.Normal(loc, scale))

svi = SVI(model, guide, numpyro.optim.Adam(1), Trace_ELBO())
svi_state = svi.init(random.PRNGKey(0))
svi.update(svi_state)


def test_collapse_normal_mvn_mvn():
T, d, S = 5, 2, 3
data = jnp.ones((T, S))

def model():
x = numpyro.sample("x", dist.Exponential(1))
with handlers.collapse():
with numpyro.plate("d", d):
# TODO: verify that to_event works here
beta0 = numpyro.sample("beta0", dist.Normal(0, 1).expand([S]).to_event(1))
# TODO: address beta0 is a str, which cannot do infer_param_domain
beta = numpyro.sample("beta", dist.MultivariateNormal(beta0, jnp.eye(S)))
# FIXME: beta is a string here, how to apply numeric operators
mean = jnp.ones((T, d)) @ beta
with numpyro.plate("data", T, dim=-2):
numpyro.sample("obs", dist.MultivariateNormal(mean, jnp.eye(S)), obs=data)

def guide():
rate = numpyro.param("rate", 1., constraint=constraints.positive)
numpyro.sample("x", dist.Exponential(rate))

svi = SVI(model, guide, numpyro.optim.Adam(1), Trace_ELBO())
svi_state = svi.init(random.PRNGKey(0))
svi.update(svi_state)


def test_prng_key():
assert numpyro.prng_key() is None

Expand Down