Closed
Description
A minimal, self-contained, and reproducible example.
import pymc as pm
import pymc.sampling_jax
import arviz as az

coords = {"param": ["a", "b"]}
with pm.Model(coords=coords) as model:
    chol, corr, stds = pm.LKJCholeskyCov(
        "chol", n=2, eta=2.0, sd_dist=pm.Gamma.dist(2, 1)
    )
    # Default PyMC sampler: passing dims through idata_kwargs works.
    trace = pm.sample(
        10, tune=10, cores=2, chains=2, idata_kwargs={"dims": {"chol_stds": ["param"]}}
    )

with model:
    # NumPyro JAX sampler: the same idata_kwargs raises the TypeError below.
    trace = pymc.sampling_jax.sample_numpyro_nuts(
        10, tune=10, chains=2, idata_kwargs={"dims": {"chol_stds": ["param"]}}
    )
Full traceback.
---------------------------------------------------------------------------
TypeError Traceback (most recent call last)
Input In [5], in <cell line: 1>()
1 with model:
----> 2 trace = pymc.sampling_jax.sample_numpyro_nuts(10, tune=10, chains=2, idata_kwargs={"dims": {"chol_stds": ["param"]}})
File /usr/local/Caskroom/miniconda/base/envs/bluishred/lib/python3.10/site-packages/pymc/sampling_jax.py:564, in sample_numpyro_nuts(draws, tune, chains, target_accept, random_seed, initvals, model, var_names, progress_bar, keep_untransformed, chain_method, postprocessing_backend, idata_kwargs, nuts_kwargs)
559 attrs = {
560 "sampling_time": (tic3 - tic2).total_seconds(),
561 }
563 posterior = mcmc_samples
--> 564 az_trace = az.from_dict(
565 posterior=posterior,
566 log_likelihood=log_likelihood,
567 observed_data=find_observations(model),
568 constant_data=find_constants(model),
569 sample_stats=_sample_stats_to_xarray(pmap_numpyro),
570 coords=coords,
571 dims=dims,
572 attrs=make_attrs(attrs, library=numpyro),
573 **idata_kwargs,
574 )
576 return az_trace
TypeError: arviz.data.io_dict.from_dict() got multiple values for keyword argument 'dims'
Details.
I am trying to add coordinate labels to the variables produced by LKJCholeskyCov() by passing the dimensions to the idata_kwargs parameter of the PyMC sampling function, as demonstrated in Oriol Abril's blog post (https://oriolabrilpla.cat/python/arviz/pymc/xarray/2022/06/07/pymc-arviz.html#2nd-example:-radon-multilevel-model). The method works when sampling with the default PyMC sampler, but fails with the NumPyro JAX backend. I have provided a full example in the code above, but please let me know if more details are needed.
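For anyone reading the traceback, the failure looks like a plain Python keyword collision: the call passes dims=dims explicitly and then unpacks **idata_kwargs, which also contains a "dims" key. Below is a minimal sketch that reproduces that collision without PyMC or ArviZ and shows one possible way to merge the two dictionaries. The function from_dict_stub, the helper names, and the model_dims values are hypothetical stand-ins for illustration only. This is not the actual PyMC code or its eventual fix.

```python
def from_dict_stub(posterior=None, *, coords=None, dims=None, **kwargs):
    """Hypothetical stand-in for az.from_dict; it only echoes its arguments."""
    return {"posterior": posterior, "coords": coords, "dims": dims, **kwargs}


def call_with_collision(idata_kwargs):
    # Same shape as the failing call: `dims` is given explicitly and
    # again through **idata_kwargs, so Python raises TypeError.
    model_dims = {"chol": ["chol_dim_0"]}  # assumed placeholder for model-derived dims
    return from_dict_stub(posterior={}, dims=model_dims, **idata_kwargs)


def call_with_merge(idata_kwargs):
    # One possible approach: pull "dims" out of a copy of the user kwargs
    # and merge it with the model-derived dims, letting user entries win.
    model_dims = {"chol": ["chol_dim_0"]}  # assumed placeholder for model-derived dims
    extra = dict(idata_kwargs)  # copy so the caller's dict is not mutated
    user_dims = extra.pop("dims", {})
    merged_dims = {**model_dims, **user_dims}
    return from_dict_stub(posterior={}, dims=merged_dims, **extra)


if __name__ == "__main__":
    user_kwargs = {"dims": {"chol_stds": ["param"]}}
    try:
        call_with_collision(user_kwargs)
    except TypeError as err:
        print("collision:", err)
    print("merged dims:", call_with_merge(user_kwargs)["dims"])
```

Whether user-supplied dims should override or extend the model's own dims is a design choice. The sketch lets the user entries take precedence.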
Versions and main components
- PyMC Version: 4.0.1
- Aesara Version: 2.7.3
- Python Version: 3.10.5
- Other relevant libraries: jax=v0.3.14, jaxlib=0.3.10, numpyro=0.9.2
- Operating system: macOS Monterey (v12.4)
- How did you install PyMC/PyMC3: conda