Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
22 changes: 14 additions & 8 deletions pymc/sampling_jax.py
Original file line number Diff line number Diff line change
Expand Up @@ -254,7 +254,9 @@ def sample_blackjax_nuts(
idata_kwargs : dict, optional
Keyword arguments for :func:`arviz.from_dict`. It also accepts a boolean as value
for the ``log_likelihood`` key to indicate that the pointwise log likelihood should
not be included in the returned object.
not be included in the returned object. Values for ``observed_data``, ``constant_data``,
``coords``, and ``dims`` are inferred from the ``model`` argument if not provided
in ``idata_kwargs``.

Returns
-------
Expand Down Expand Up @@ -365,16 +367,17 @@ def sample_blackjax_nuts(
}

posterior = mcmc_samples
az_trace = az.from_dict(
posterior=posterior,
# Use 'partial' to set default arguments before passing 'idata_kwargs'
to_trace = partial(
az.from_dict,
log_likelihood=log_likelihood,
observed_data=find_observations(model),
constant_data=find_constants(model),
coords=coords,
dims=dims,
attrs=make_attrs(attrs, library=blackjax),
**idata_kwargs,
)
az_trace = to_trace(posterior=posterior, **idata_kwargs)

return az_trace

Expand Down Expand Up @@ -431,7 +434,9 @@ def sample_numpyro_nuts(
idata_kwargs : dict, optional
Keyword arguments for :func:`arviz.from_dict`. It also accepts a boolean as value
for the ``log_likelihood`` key to indicate that the pointwise log likelihood should
not be included in the returned object.
not be included in the returned object. Values for ``observed_data``, ``constant_data``,
``coords``, and ``dims`` are inferred from the ``model`` argument if not provided
in ``idata_kwargs``.
nuts_kwargs: dict, optional
Keyword arguments for :func:`numpyro.infer.NUTS`.

Expand Down Expand Up @@ -560,16 +565,17 @@ def sample_numpyro_nuts(
}

posterior = mcmc_samples
az_trace = az.from_dict(
posterior=posterior,
# Use 'partial' to set default arguments before passing 'idata_kwargs'
to_trace = partial(
az.from_dict,
log_likelihood=log_likelihood,
observed_data=find_observations(model),
constant_data=find_constants(model),
sample_stats=_sample_stats_to_xarray(pmap_numpyro),
coords=coords,
dims=dims,
attrs=make_attrs(attrs, library=numpyro),
**idata_kwargs,
)
az_trace = to_trace(posterior=posterior, **idata_kwargs)

return az_trace
30 changes: 24 additions & 6 deletions pymc/tests/test_sampling_jax.py
Original file line number Diff line number Diff line change
Expand Up @@ -153,6 +153,16 @@ def test_get_jaxified_logp():
assert not np.isinf(jax_fn((np.array(5000.0), np.array(5000.0))))


@pytest.fixture
def model_test_idata_kwargs(scope="module"):
with pm.Model(coords={"x_coord": ["a", "b"], "x_coord2": [1, 2]}) as m:
x = pm.Normal("x", shape=(2,), dims=["x_coord"])
y = pm.Normal("y", x, observed=[0, 0])
pm.ConstantData("constantdata", [1, 2, 3])
pm.MutableData("mutabledata", 2)
return m


@pytest.mark.parametrize(
"sampler",
[
Expand All @@ -165,15 +175,17 @@ def test_get_jaxified_logp():
[
dict(),
dict(log_likelihood=False),
# Overwrite models coords
dict(coords={"x_coord": ["x1", "x2"]}),
# Overwrite dims from dist specification in model
dict(dims={"x": ["x_coord2"]}),
# Overwrite both coords and dims
dict(coords={"x_coord3": ["A", "B"]}, dims={"x": ["x_coord3"]}),
],
)
@pytest.mark.parametrize("postprocessing_backend", [None, "cpu"])
def test_idata_kwargs(sampler, idata_kwargs, postprocessing_backend):
with pm.Model() as m:
x = pm.Normal("x")
y = pm.Normal("y", x, observed=0)
pm.ConstantData("constantdata", [1, 2, 3])
pm.MutableData("mutabledata", 2)
def test_idata_kwargs(model_test_idata_kwargs, sampler, idata_kwargs, postprocessing_backend):
with model_test_idata_kwargs:
idata = sampler(
tune=50,
draws=50,
Expand All @@ -189,6 +201,12 @@ def test_idata_kwargs(sampler, idata_kwargs, postprocessing_backend):
else:
assert "log_likelihood" not in idata

x_dim_expected = idata_kwargs.get("dims", model_test_idata_kwargs.RV_dims)["x"][0]
assert idata.posterior.x.dims[-1] == x_dim_expected

x_coords_expected = idata_kwargs.get("coords", model_test_idata_kwargs.coords)[x_dim_expected]
assert list(x_coords_expected) == list(idata.posterior.x.coords[x_dim_expected].values)


def test_get_batched_jittered_initial_points():
with pm.Model() as model:
Expand Down