
Open all NUTS kwargs to user for sampling with Numpyro  #6020

@jhrcook


Description of the problem

In the current implementation of pymc.sampling_jax.sample_numpyro_nuts(), only some of the arguments for the NumPyro NUTS sampler are available to the user. Several are hard-coded in the NUTS(...) call and therefore cannot be overridden through nuts_kwargs:

```python
# Excerpt from pymc/sampling_jax.py, inside sample_numpyro_nuts() (PyMC 4.0.1)
if nuts_kwargs is None:
    nuts_kwargs = {}
nuts_kernel = NUTS(
    potential_fn=logp_fn,
    target_accept_prob=target_accept,
    adapt_step_size=True,
    adapt_mass_matrix=True,
    dense_mass=False,
    **nuts_kwargs,
)
```
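Any key in nuts_kwargs that matches one of the keywords hard-coded above cannot be passed through. A minimal sketch that lists the colliding keys, assuming the preset set is exactly the keywords in the excerpt (PRESET_NUTS_KEYS and conflicting_keys are hypothetical names, not part of PyMC):

```python
from typing import Any, Dict, Optional, Set

# Hypothetical: the keywords hard-coded in the NUTS(...) call in the excerpt above
PRESET_NUTS_KEYS: Set[str] = {
    "potential_fn",
    "target_accept_prob",
    "adapt_step_size",
    "adapt_mass_matrix",
    "dense_mass",
}


def conflicting_keys(nuts_kwargs: Optional[Dict[str, Any]]) -> Set[str]:
    """Hypothetical helper: user-supplied keys that would clash with the preset ones."""
    return PRESET_NUTS_KEYS & set(nuts_kwargs or {})


print(sorted(conflicting_keys({"dense_mass": True, "max_tree_depth": 8})))
# ['dense_mass']
```

Keys outside this set, such as max_tree_depth, already reach NUTS unchanged. Note that target_accept_prob is indirectly controllable through the target_accept argument but still cannot be supplied via nuts_kwargs.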

I think it would be useful to make all of these arguments configurable through the PyMC interface, so that values passed in nuts_kwargs override the defaults instead of raising an error. Is there a reason they aren’t? If not, would this be a PR I could make?
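One possible shape for such a change, sketched under the assumption that the current values should remain the defaults: build a dict of defaults and let the user's nuts_kwargs override it, then pass the merged dict to NUTS(potential_fn=logp_fn, **merged). The helper name build_nuts_kwargs is hypothetical and not part of PyMC.

```python
from typing import Any, Dict, Optional


def build_nuts_kwargs(
    target_accept: float,
    nuts_kwargs: Optional[Dict[str, Any]] = None,
) -> Dict[str, Any]:
    """Hypothetical helper: merge the currently preset NUTS settings with user overrides."""
    defaults: Dict[str, Any] = {
        "target_accept_prob": target_accept,
        "adapt_step_size": True,
        "adapt_mass_matrix": True,
        "dense_mass": False,
    }
    # In a dict merge the later mapping wins, so user-supplied values replace the defaults
    return {**defaults, **(nuts_kwargs or {})}


merged = build_nuts_kwargs(0.8, {"dense_mass": True})
print(merged["dense_mass"])  # True
```

With this merge, a target_accept_prob key in nuts_kwargs would silently take precedence over the target_accept argument; whether that is desirable, or should instead raise an error, is a design decision for the PR.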

For clarity, here is a simple example of what currently happens if a PyMC user wants to set dense_mass=True:

```python
import pymc as pm
import pymc.sampling_jax

with pm.Model():
    y = pm.Normal("y", 0, 1)
    trace = pymc.sampling_jax.sample_numpyro_nuts(nuts_kwargs={"dense_mass": True})
```
```
---------------------------------------------------------------------------
TypeError                                 Traceback (most recent call last)
Input In [24], in <cell line: 1>()
      1 with pm.Model():
      2     y = pm.Normal("y", 0, 1)
----> 3     trace = pymc.sampling_jax.sample_numpyro_nuts(nuts_kwargs={"dense_mass": True})

File /usr/local/Caskroom/miniconda/base/envs/speclet/lib/python3.10/site-packages/pymc/sampling_jax.py:487, in sample_numpyro_nuts(draws, tune, chains, target_accept, random_seed, initvals, model, var_names, progress_bar, keep_untransformed, chain_method, postprocessing_backend, idata_kwargs, nuts_kwargs)
    485 if nuts_kwargs is None:
    486     nuts_kwargs = {}
--> 487 nuts_kernel = NUTS(
    488     potential_fn=logp_fn,
    489     target_accept_prob=target_accept,
    490     adapt_step_size=True,
    491     adapt_mass_matrix=True,
    492     # dense_mass=False,
    493     **nuts_kwargs,
    494 )
    496 pmap_numpyro = MCMC(
    497     nuts_kernel,
    498     num_warmup=tune,
   (...)
    503     progress_bar=progress_bar,
    504 )
    506 tic2 = datetime.now()

TypeError: numpyro.infer.hmc.NUTS() got multiple values for keyword argument 'dense_mass'
```
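The error comes from Python itself rather than from NumPyro: a call that supplies the same keyword both explicitly and through **kwargs is rejected at the call site. A minimal reproduction without PyMC or NumPyro, where fake_nuts is a hypothetical stand-in for numpyro.infer.NUTS that simply echoes its keywords:

```python
from typing import Any, Dict, Optional


def fake_nuts(potential_fn=None, **kwargs: Any) -> Dict[str, Any]:
    """Hypothetical stand-in for numpyro.infer.NUTS: returns the keywords it received."""
    return kwargs


def build_kernel(nuts_kwargs: Optional[Dict[str, Any]] = None) -> Dict[str, Any]:
    """Mimics the call pattern in the excerpt: a hard-coded keyword plus **nuts_kwargs."""
    if nuts_kwargs is None:
        nuts_kwargs = {}
    return fake_nuts(potential_fn=None, dense_mass=False, **nuts_kwargs)


# A non-conflicting key passes straight through
print(build_kernel({"max_tree_depth": 8}))  # {'dense_mass': False, 'max_tree_depth': 8}

# A key that duplicates a hard-coded keyword raises TypeError,
# even though fake_nuts accepts arbitrary **kwargs
try:
    build_kernel({"dense_mass": True})
except TypeError as err:
    print(err)  # message contains "got multiple values for keyword argument 'dense_mass'"
```

Because the check happens before the callee runs, NUTS has no opportunity to resolve the conflict; the duplicate has to be avoided on the PyMC side.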

Versions and main components

  • PyMC/PyMC3 Version: 4.0.1
  • Aesara/Theano Version: 2.7.3
  • numpyro Version: 0.9.2
  • jax Version: 0.3.14
  • jaxlib Version: 0.3.14
  • Python Version: 3.10.5
  • Operating system: macOS 12.4
  • How did you install PyMC/PyMC3: conda
