diff --git a/pymc/sampling_jax.py b/pymc/sampling_jax.py index b4e104eb6f..d9cf85130f 100644 --- a/pymc/sampling_jax.py +++ b/pymc/sampling_jax.py @@ -254,7 +254,9 @@ def sample_blackjax_nuts( idata_kwargs : dict, optional Keyword arguments for :func:`arviz.from_dict`. It also accepts a boolean as value for the ``log_likelihood`` key to indicate that the pointwise log likelihood should - not be included in the returned object. + not be included in the returned object. Values for ``observed_data``, ``constant_data``, + ``coords``, and ``dims`` are inferred from the ``model`` argument if not provided + in ``idata_kwargs``. Returns ------- @@ -365,16 +367,17 @@ def sample_blackjax_nuts( } posterior = mcmc_samples - az_trace = az.from_dict( - posterior=posterior, + # Use 'partial' to set default arguments before passing 'idata_kwargs' + to_trace = partial( + az.from_dict, log_likelihood=log_likelihood, observed_data=find_observations(model), constant_data=find_constants(model), coords=coords, dims=dims, attrs=make_attrs(attrs, library=blackjax), - **idata_kwargs, ) + az_trace = to_trace(posterior=posterior, **idata_kwargs) return az_trace @@ -431,7 +434,9 @@ def sample_numpyro_nuts( idata_kwargs : dict, optional Keyword arguments for :func:`arviz.from_dict`. It also accepts a boolean as value for the ``log_likelihood`` key to indicate that the pointwise log likelihood should - not be included in the returned object. + not be included in the returned object. Values for ``observed_data``, ``constant_data``, + ``coords``, and ``dims`` are inferred from the ``model`` argument if not provided + in ``idata_kwargs``. nuts_kwargs: dict, optional Keyword arguments for :func:`numpyro.infer.NUTS`. @@ -560,8 +565,9 @@ def sample_numpyro_nuts( } posterior = mcmc_samples - az_trace = az.from_dict( - posterior=posterior, + # Use 'partial' to set default arguments before passing 'idata_kwargs' + to_trace = partial( + az.from_dict, log_likelihood=log_likelihood, observed_data=find_observations(model), constant_data=find_constants(model), @@ -569,7 +575,7 @@ def sample_numpyro_nuts( coords=coords, dims=dims, attrs=make_attrs(attrs, library=numpyro), - **idata_kwargs, ) + az_trace = to_trace(posterior=posterior, **idata_kwargs) return az_trace diff --git a/pymc/tests/test_sampling_jax.py b/pymc/tests/test_sampling_jax.py index 4e6e3c627b..7daa168eb4 100644 --- a/pymc/tests/test_sampling_jax.py +++ b/pymc/tests/test_sampling_jax.py @@ -153,6 +153,16 @@ def test_get_jaxified_logp(): assert not np.isinf(jax_fn((np.array(5000.0), np.array(5000.0)))) +@pytest.fixture +def model_test_idata_kwargs(scope="module"): + with pm.Model(coords={"x_coord": ["a", "b"], "x_coord2": [1, 2]}) as m: + x = pm.Normal("x", shape=(2,), dims=["x_coord"]) + y = pm.Normal("y", x, observed=[0, 0]) + pm.ConstantData("constantdata", [1, 2, 3]) + pm.MutableData("mutabledata", 2) + return m + + @pytest.mark.parametrize( "sampler", [ @@ -165,15 +175,17 @@ def test_get_jaxified_logp(): [ dict(), dict(log_likelihood=False), + # Overwrite models coords + dict(coords={"x_coord": ["x1", "x2"]}), + # Overwrite dims from dist specification in model + dict(dims={"x": ["x_coord2"]}), + # Overwrite both coords and dims + dict(coords={"x_coord3": ["A", "B"]}, dims={"x": ["x_coord3"]}), ], ) @pytest.mark.parametrize("postprocessing_backend", [None, "cpu"]) -def test_idata_kwargs(sampler, idata_kwargs, postprocessing_backend): - with pm.Model() as m: - x = pm.Normal("x") - y = pm.Normal("y", x, observed=0) - pm.ConstantData("constantdata", [1, 2, 3]) - pm.MutableData("mutabledata", 2) +def test_idata_kwargs(model_test_idata_kwargs, sampler, idata_kwargs, postprocessing_backend): + with model_test_idata_kwargs: idata = sampler( tune=50, draws=50, @@ -189,6 +201,12 @@ def test_idata_kwargs(sampler, idata_kwargs, postprocessing_backend): else: assert "log_likelihood" not in idata + x_dim_expected = idata_kwargs.get("dims", model_test_idata_kwargs.RV_dims)["x"][0] + assert idata.posterior.x.dims[-1] == x_dim_expected + + x_coords_expected = idata_kwargs.get("coords", model_test_idata_kwargs.coords)[x_dim_expected] + assert list(x_coords_expected) == list(idata.posterior.x.coords[x_dim_expected].values) + def test_get_batched_jittered_initial_points(): with pm.Model() as model: