18 changes: 18 additions & 0 deletions funsor/distribution.py
@@ -599,6 +599,24 @@ def eager_dirichlet_multinomial(red_op, bin_op, reduced_vars, x, y):
    return eager(Contraction, red_op, bin_op, reduced_vars, (x, y))


def eager_plate_multinomial(op, x, reduced_vars):
    if not reduced_vars.isdisjoint(x.probs.inputs):
        return None
    if not reduced_vars.issubset(x.value.inputs):
        return None

    backend_dist = import_module(BACKEND_TO_DISTRIBUTIONS_BACKEND[get_backend()])
    total_count = x.total_count
    for v in reduced_vars:
        if v in total_count.inputs:
            total_count = total_count.reduce(ops.add, v)
        else:
            total_count = total_count * x.inputs[v].size
    return backend_dist.Multinomial(total_count=total_count,
                                    probs=x.probs,
                                    value=x.value.reduce(ops.add, reduced_vars))


def _log_beta(x, y):
    return ops.lgamma(x) + ops.lgamma(y) - ops.lgamma(x + y)

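For orientation, a minimal usage sketch of the reduction this new rule handles, assuming the torch backend and the funsor submodules shown in this PR (the names probs/value/plate and the sizes below are illustrative, not taken from the PR): summing i.i.d. Multinomial factors over a plate collapses them into a single Multinomial whose value and total_count are summed over the plate.

import torch
from collections import OrderedDict

import funsor
import funsor.ops as ops
import funsor.torch.distributions as dist
from funsor.domains import Bint, Reals
from funsor.tensor import Tensor
from funsor.terms import Variable

funsor.set_backend("torch")

# Latent probabilities are a free Variable, so the Multinomial stays lazy.
probs = Variable("probs", Reals[4])
# Seven plate slots of 4-category counts, plus the matching per-slot totals.
value = Tensor(torch.ones(7, 4), OrderedDict(plate=Bint[7]))
total_count = Tensor(value.data.sum(-1), OrderedDict(plate=Bint[7]))

plated = dist.Multinomial(total_count=total_count, probs=probs, value=value)
collapsed = plated.reduce(ops.add, 'plate')  # dispatches to eager_plate_multinomial

Because probs is a free Variable, the Multinomial cannot be evaluated eagerly to a Tensor, so the plate reduction is handled by eager_plate_multinomial rather than by a plain tensor sum.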
6 changes: 4 additions & 2 deletions funsor/jax/distributions.py
@@ -29,6 +29,7 @@
    eager_multinomial,
    eager_mvn,
    eager_normal,
    eager_plate_multinomial,
    indepdist_to_funsor,
    make_dist,
    maskeddist_to_funsor,
@@ -37,7 +38,7 @@
from funsor.domains import Real, Reals
import funsor.ops as ops
from funsor.tensor import Tensor, dummy_numeric_array
from funsor.terms import Binary, Funsor, Variable, eager, to_data, to_funsor
from funsor.terms import Binary, Funsor, Reduce, Variable, eager, to_data, to_funsor
from funsor.util import methodof


@@ -286,6 +287,7 @@ def deltadist_to_funsor(pyro_dist, output=None, dim_to_name=None):
if hasattr(dist, "DirichletMultinomial"):
    eager.register(Binary, ops.SubOp, JointDirichletMultinomial, DirichletMultinomial)(  # noqa: F821
        eager_dirichlet_posterior)

eager.register(Reduce, ops.AddOp, Multinomial[Tensor, Funsor, Funsor], frozenset)(  # noqa: F821
    eager_plate_multinomial)

__all__ = list(x[0] for x in FUNSOR_DIST_NAMES if _get_numpyro_dist(x[0]) is not None)
2 changes: 1 addition & 1 deletion funsor/memoize.py
@@ -1,7 +1,7 @@
# Copyright Contributors to the Pyro project.
# SPDX-License-Identifier: Apache-2.0

from collections import Hashable
from collections.abc import Hashable
from contextlib import contextmanager

import funsor.interpreter as interpreter
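For context on this one-line change, a small compatibility sketch (the version facts come from CPython's deprecation notes, not from this PR): the ABC aliases in collections were deprecated since Python 3.3 and removed in Python 3.10, so only the collections.abc import keeps working.

# Portable across supported Python versions:
from collections.abc import Hashable

# Removed in Python 3.10 (deprecated since 3.3); raises ImportError there:
# from collections import Hashable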
3 changes: 2 additions & 1 deletion funsor/terms.py
@@ -7,7 +7,8 @@
import numbers
import typing
import warnings
from collections import Hashable, OrderedDict
from collections import OrderedDict
from collections.abc import Hashable
from functools import reduce, singledispatch
from weakref import WeakValueDictionary

5 changes: 4 additions & 1 deletion funsor/torch/distributions.py
@@ -33,6 +33,7 @@
    eager_multinomial,
    eager_mvn,
    eager_normal,
    eager_plate_multinomial,
    indepdist_to_funsor,
    make_dist,
    maskeddist_to_funsor,
@@ -41,7 +42,7 @@
from funsor.domains import Real, Reals
import funsor.ops as ops
from funsor.tensor import Tensor, dummy_numeric_array
from funsor.terms import Binary, Funsor, Variable, eager, to_data, to_funsor
from funsor.terms import Binary, Funsor, Reduce, Variable, eager, to_data, to_funsor
from funsor.util import methodof


@@ -275,3 +276,5 @@ def deltadist_to_funsor(pyro_dist, output=None, dim_to_name=None):
    eager_gamma_poisson)
eager.register(Binary, ops.SubOp, JointDirichletMultinomial, DirichletMultinomial)(  # noqa: F821
    eager_dirichlet_posterior)
eager.register(Reduce, ops.AddOp, Multinomial[Tensor, Funsor, Funsor], frozenset)(  # noqa: F821
    eager_plate_multinomial)

Member Author commented on this registration:
    For pyro.collapse purposes, it is enough to use Multinomial[Tensor, Funsor, Tensor] here. Then we don't have to worry about the x.value.reduce(...) issue in eager_plate_multinomial above. Is there a simple solution for that, or should I go ahead and replace Funsor by Tensor here in the meantime?

Member replied (quoting "should I go ahead and replace Funsor by Tensor here in the meantime?"):
    Sure, makes sense to me. I don't think there's a simple solution to the plate membership indication problem, at least not one simple enough to block this PR.
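A minimal sketch of the narrower registration discussed in the comments above, assuming the same pattern-typed eager.register API; constraining the value slot to Tensor would make the x.value.reduce(ops.add, ...) call in eager_plate_multinomial a concrete tensor sum:

eager.register(Reduce, ops.AddOp, Multinomial[Tensor, Funsor, Tensor], frozenset)(  # noqa: F821
    eager_plate_multinomial)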
24 changes: 24 additions & 0 deletions test/test_distribution.py
@@ -1037,6 +1037,30 @@ def test_dirichlet_categorical_conjugate(batch_shape, size):
    _assert_conjugate_density_ok(latent, conditional, obs)


@pytest.mark.parametrize('batch_shape', [(), (5,), (2, 3)], ids=str)
@pytest.mark.parametrize('size', [2, 4, 5], ids=str)
def test_dirichlet_multinomial_conjugate_plate(batch_shape, size):
    max_count = 10
    batch_dims = ('i', 'j', 'k')[:len(batch_shape)]
    inputs = OrderedDict((k, Bint[v]) for k, v in zip(batch_dims, batch_shape))
    full_shape = batch_shape + (size,)
    prior = Variable("prior", Reals[size])
    concentration = Tensor(ops.exp(randn(full_shape)), inputs)
    value_data = ops.astype(randint(0, max_count, size=batch_shape + (7, size)), 'float32')
    obs_inputs = inputs.copy()
    obs_inputs['plate'] = Bint[7]
    obs = Tensor(value_data, obs_inputs)
    total_count_data = value_data.sum(-1)
    total_count = Tensor(total_count_data, obs_inputs)
    latent = dist.Dirichlet(concentration, value=prior)
    conditional = dist.Multinomial(probs=prior, total_count=total_count, value=obs)
    p = latent + conditional.reduce(ops.add, 'plate')
    reduced = p.reduce(ops.logaddexp, 'prior')
    assert isinstance(reduced, Tensor)

    _assert_conjugate_density_ok(latent, conditional, obs)


@pytest.mark.parametrize('batch_shape', [(), (5,), (2, 3)], ids=str)
@pytest.mark.parametrize('size', [2, 4, 5], ids=str)
def test_dirichlet_multinomial_conjugate(batch_shape, size):