
Cannot pass dims to idata_kwargs parameter in sample_numpyro_nuts #5932

@jhrcook

Description


A minimal, self-contained, and reproducible example.

```python
import pymc as pm
import pymc.sampling_jax
import arviz as az

coords = {"param": ["a", "b"]}
with pm.Model(coords=coords) as model:
    chol, corr, stds = pm.LKJCholeskyCov(
        "chol", n=2, eta=2.0, sd_dist=pm.Gamma.dist(2, 1)
    )
    trace = pm.sample(
        10, tune=10, cores=2, chains=2, idata_kwargs={"dims": {"chol_stds": ["param"]}}
    )

with model:
    trace = pymc.sampling_jax.sample_numpyro_nuts(
        10, tune=10, chains=2, idata_kwargs={"dims": {"chol_stds": ["param"]}}
    )
```

Full traceback.

```
---------------------------------------------------------------------------
TypeError                                 Traceback (most recent call last)
Input In [5], in <cell line: 1>()
      1 with model:
----> 2     trace = pymc.sampling_jax.sample_numpyro_nuts(10, tune=10, chains=2, idata_kwargs={"dims": {"chol_stds": ["param"]}})

File /usr/local/Caskroom/miniconda/base/envs/bluishred/lib/python3.10/site-packages/pymc/sampling_jax.py:564, in sample_numpyro_nuts(draws, tune, chains, target_accept, random_seed, initvals, model, var_names, progress_bar, keep_untransformed, chain_method, postprocessing_backend, idata_kwargs, nuts_kwargs)
    559 attrs = {
    560     "sampling_time": (tic3 - tic2).total_seconds(),
    561 }
    563 posterior = mcmc_samples
--> 564 az_trace = az.from_dict(
    565     posterior=posterior,
    566     log_likelihood=log_likelihood,
    567     observed_data=find_observations(model),
    568     constant_data=find_constants(model),
    569     sample_stats=_sample_stats_to_xarray(pmap_numpyro),
    570     coords=coords,
    571     dims=dims,
    572     attrs=make_attrs(attrs, library=numpyro),
    573     **idata_kwargs,
    574 )
    576 return az_trace

TypeError: arviz.data.io_dict.from_dict() got multiple values for keyword argument 'dims'
```
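The collision can be reproduced without PyMC. The sketch below uses a hypothetical stand-in, `fake_from_dict`, for a converter that accepts `coords` and `dims` keywords, and a hypothetical `MODEL_DIMS` in place of the dims the sampler derives from the model. `buggy_call` mirrors the pattern at lines 570 to 573 of the traceback (explicit `dims=` followed by `**idata_kwargs`). `merged_call` shows one possible fix, merging user-supplied `coords`/`dims` into the sampler's own before the call; it is only a sketch, not necessarily how PyMC resolved the issue.

```python
from typing import Any, Dict, List

# Hypothetical dims that the sampler derives from the model on its own.
MODEL_DIMS: Dict[str, List[str]] = {"chol_corr": ["chol_corr_dim_0", "chol_corr_dim_1"]}


def fake_from_dict(posterior=None, coords=None, dims=None, **kwargs: Any) -> dict:
    """Hypothetical stand-in for a converter such as az.from_dict."""
    return {"posterior": posterior, "coords": coords, "dims": dims, **kwargs}


def buggy_call(idata_kwargs: dict) -> dict:
    # Same shape as the failing call: explicit dims= plus **idata_kwargs,
    # so a "dims" key in idata_kwargs supplies the keyword twice.
    dims = dict(MODEL_DIMS)
    return fake_from_dict(posterior={}, coords={}, dims=dims, **idata_kwargs)


def merged_call(idata_kwargs: dict) -> dict:
    # Copy so the caller's dict is not mutated by pop().
    extra = dict(idata_kwargs)
    # A real sampler would also merge its model coords here; this sketch has none.
    coords = dict(extra.pop("coords", {}))
    # User-supplied dims win over model-derived ones on key clashes.
    dims = {**MODEL_DIMS, **extra.pop("dims", {})}
    return fake_from_dict(posterior={}, coords=coords, dims=dims, **extra)


user_kwargs = {"dims": {"chol_stds": ["param"]}}
try:
    buggy_call(user_kwargs)
except TypeError as err:
    print(err)  # TypeError message about multiple values for keyword argument 'dims'
print(merged_call(user_kwargs)["dims"])
```

Popping the keys out of a copy keeps every other entry of `idata_kwargs` flowing through unchanged, which is the behaviour the `**idata_kwargs` unpacking was presumably meant to provide.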

Details.

I am trying to add coordinate labels to the variables produced by LKJCholeskyCov() by passing the dimensions through the idata_kwargs parameter of the PyMC sampling function, as demonstrated in Oriol Abril's blog post (https://oriolabrilpla.cat/python/arviz/pymc/xarray/2022/06/07/pymc-arviz.html#2nd-example:-radon-multilevel-model). This works with the default PyMC sampler but fails with the NumPyro JAX backend. The traceback suggests why: sample_numpyro_nuts already passes coords=coords and dims=dims to az.from_dict() and then unpacks **idata_kwargs, so a "dims" key in idata_kwargs supplies the same keyword twice. A full example is given in the code above; please let me know if more details are needed.
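Until the collision is fixed, one possible workaround is to sample without custom dims and relabel afterwards with xarray. The sketch below builds a hypothetical stand-in for the posterior group, assuming ArviZ's default `<var>_dim_<i>` naming for unlabeled dimensions (so `chol_stds_dim_0`); check the actual dimension name in the returned InferenceData before relying on it. The draw values are placeholders, not sampler output.

```python
import numpy as np
import xarray as xr

# Hypothetical stand-in for the posterior group returned without custom dims.
# The "chol_stds_dim_0" name assumes ArviZ's default naming for unlabeled dims.
posterior = xr.Dataset(
    {"chol_stds": (("chain", "draw", "chol_stds_dim_0"), np.ones((2, 10, 2)))},
    coords={"chain": [0, 1], "draw": np.arange(10)},
)

# Rename the generic dimension, then attach the labels from the model coords.
relabeled = posterior.rename_dims({"chol_stds_dim_0": "param"}).assign_coords(
    param=["a", "b"]
)

# Selecting by label now works as it would with dims passed at sampling time.
print(relabeled["chol_stds"].sel(param="a").shape)
```

The same two calls would need to be applied to every group that contains `chol_stds` (for example the posterior group of the InferenceData returned by sample_numpyro_nuts).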

Versions and main components

  • PyMC Version: 4.0.1
  • Aesara Version: 2.7.3
  • Python Version: 3.10.5
  • Other relevant libraries: jax=0.3.14, jaxlib=0.3.10, numpyro=0.9.2
  • Operating system: macOS Monterey (v12.4)
  • How did you install PyMC/PyMC3: conda
