Description of the problem
In the current implementation of `pymc.sampling_jax.sample_numpyro_nuts()`, only some of the arguments of the NumPyro NUTS sampler are available to the user, because a few of them are hard-coded in the function:
```python
if nuts_kwargs is None:
    nuts_kwargs = {}
nuts_kernel = NUTS(
    potential_fn=logp_fn,
    target_accept_prob=target_accept,
    adapt_step_size=True,
    adapt_mass_matrix=True,
    dense_mass=False,
    **nuts_kwargs,
)
```
I think it would be useful to make all of these arguments configurable through the PyMC interface. Is there a reason they aren't? If not, could I open a PR for this?
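One possible shape for such a change is sketched below. It is a minimal sketch, not PyMC's actual code: the helper `build_nuts_kwargs` and the constant `DEFAULT_NUTS_KWARGS` are hypothetical names. The idea is to keep the current hard-coded values as defaults and let anything in `nuts_kwargs` override them, so every keyword reaches `NUTS` exactly once.

```python
from typing import Any, Dict, Optional

# Hypothetical defaults mirroring the values currently hard-coded in
# sample_numpyro_nuts()
DEFAULT_NUTS_KWARGS: Dict[str, Any] = {
    "adapt_step_size": True,
    "adapt_mass_matrix": True,
    "dense_mass": False,
}


def build_nuts_kwargs(
    target_accept: float, nuts_kwargs: Optional[Dict[str, Any]] = None
) -> Dict[str, Any]:
    """Hypothetical helper: merge user-supplied NUTS options over the defaults."""
    # Copy so the module-level defaults are never mutated
    merged = dict(DEFAULT_NUTS_KWARGS)
    merged["target_accept_prob"] = target_accept
    # User values win; each key appears once, so no duplicate-keyword TypeError
    merged.update(nuts_kwargs or {})
    return merged


print(build_nuts_kwargs(0.8, {"dense_mass": True})["dense_mass"])  # True
```

Inside the sampler the call would then become something like `NUTS(potential_fn=logp_fn, **build_nuts_kwargs(target_accept, nuts_kwargs))`. With this ordering, a `target_accept_prob` passed in `nuts_kwargs` would also override the `target_accept` argument; whether that is desirable is a design decision for the PR.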
For clarity, here is a simple example of what currently happens if a PyMC user wants to set `dense_mass=True`:
```python
import pymc as pm
import pymc.sampling_jax

with pm.Model():
    y = pm.Normal("y", 0, 1)
    trace = pymc.sampling_jax.sample_numpyro_nuts(nuts_kwargs={"dense_mass": True})
```
```
---------------------------------------------------------------------------
TypeError                                 Traceback (most recent call last)
Input In [24], in <cell line: 1>()
      1 with pm.Model():
      2     y = pm.Normal("y", 0, 1)
----> 3     trace = pymc.sampling_jax.sample_numpyro_nuts(nuts_kwargs={"dense_mass": True})

File /usr/local/Caskroom/miniconda/base/envs/speclet/lib/python3.10/site-packages/pymc/sampling_jax.py:487, in sample_numpyro_nuts(draws, tune, chains, target_accept, random_seed, initvals, model, var_names, progress_bar, keep_untransformed, chain_method, postprocessing_backend, idata_kwargs, nuts_kwargs)
    485 if nuts_kwargs is None:
    486     nuts_kwargs = {}
--> 487 nuts_kernel = NUTS(
    488     potential_fn=logp_fn,
    489     target_accept_prob=target_accept,
    490     adapt_step_size=True,
    491     adapt_mass_matrix=True,
    492     # dense_mass=False,
    493     **nuts_kwargs,
    494 )
    496 pmap_numpyro = MCMC(
    497     nuts_kernel,
    498     num_warmup=tune,
   (...)
    503     progress_bar=progress_bar,
    504 )
    506 tic2 = datetime.now()

TypeError: numpyro.infer.hmc.NUTS() got multiple values for keyword argument 'dense_mass'
```
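The error is plain Python call semantics rather than anything NumPyro-specific: a keyword passed explicitly and again through `**` unpacking is rejected at the call site. The sketch below reproduces it without NumPyro; `nuts_stub` is a hypothetical stand-in for `numpyro.infer.NUTS`, not its real signature.

```python
def nuts_stub(potential_fn=None, dense_mass=False, **kwargs):
    """Hypothetical stand-in for numpyro.infer.NUTS; just echoes dense_mass."""
    return {"dense_mass": dense_mass}


user_kwargs = {"dense_mass": True}
message = None
try:
    # Same pattern as sample_numpyro_nuts(): dense_mass is given explicitly
    # AND again via **user_kwargs, so the call fails before the body runs
    nuts_stub(potential_fn=None, dense_mass=False, **user_kwargs)
except TypeError as err:
    message = str(err)

print(message)
```

Dropping the explicit `dense_mass=False` (or merging it into the unpacked dict first) makes the same call succeed.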
Versions and main components
- PyMC/PyMC3 Version: 4.0.1
- Aesara/Theano Version: 2.7.3
- numpyro Version: 0.9.2
- jax Version: 0.3.14
- jaxlib Version: 0.3.14
- Python Version: 3.10.5
- Operating system: macOS 12.4
- How did you install PyMC/PyMC3: conda