Skip to content
82 changes: 52 additions & 30 deletions pymc/sampling_jax.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
import warnings

from functools import partial
from typing import Callable, Dict, List, Optional, Sequence, Union
from typing import Any, Callable, Dict, List, Optional, Sequence, Union

from pymc.initial_point import StartDict
from pymc.sampling import RandomSeed, _get_seeds_per_chain, _init_jitter
Expand Down Expand Up @@ -382,69 +382,95 @@ def sample_blackjax_nuts(
return az_trace


def _numpyro_nuts_defaults() -> Dict[str, Any]:
"""Defaults parameters for Numpyro NUTS."""
return {
"adapt_step_size": True,
"adapt_mass_matrix": True,
"dense_mass": False,
}


def _update_numpyro_nuts_kwargs(nuts_kwargs: Optional[Dict[str, Any]]) -> Dict[str, Any]:
    """Overlay user-supplied Numpyro NUTS parameters onto the defaults.

    Keys present in ``nuts_kwargs`` take precedence over the values returned
    by :func:`_numpyro_nuts_defaults`; passing ``None`` yields the defaults
    unchanged. A new dict is always returned.
    """
    return {**_numpyro_nuts_defaults(), **(nuts_kwargs or {})}


def sample_numpyro_nuts(
draws: int = 1000,
tune: int = 1000,
chains: int = 4,
target_accept: float = 0.8,
random_seed: RandomSeed = None,
random_seed: Optional[RandomSeed] = None,
initvals: Optional[Union[StartDict, Sequence[Optional[StartDict]]]] = None,
model: Optional[Model] = None,
var_names=None,
var_names: Optional[Sequence[str]] = None,
progress_bar: bool = True,
keep_untransformed: bool = False,
chain_method: str = "parallel",
postprocessing_backend: str = None,
postprocessing_backend: Optional[str] = None,
idata_kwargs: Optional[Dict] = None,
nuts_kwargs: Optional[Dict] = None,
):
) -> az.InferenceData:
"""
Draw samples from the posterior using the NUTS method from the ``numpyro`` library.

Parameters
----------
draws : int, default 1000
The number of samples to draw. The number of tuned samples are discarded by default.
The number of samples to draw. The number of tuned samples are discarded by
default.
tune : int, default 1000
Number of iterations to tune. Samplers adjust the step sizes, scalings or
similar during tuning. Tuning samples will be drawn in addition to the number specified in
the ``draws`` argument.
similar during tuning. Tuning samples will be drawn in addition to the number
specified in the ``draws`` argument.
chains : int, default 4
The number of chains to sample.
target_accept : float in [0, 1].
The step size is tuned such that we approximate this acceptance rate. Higher values like
0.9 or 0.95 often work better for problematic posteriors.
The step size is tuned such that we approximate this acceptance rate. Higher
values like 0.9 or 0.95 often work better for problematic posteriors.
random_seed : int, RandomState or Generator, optional
Random seed used by the sampling steps.
initvals: StartDict or Sequence[Optional[StartDict]], optional
Initial values for random variables provided as a dictionary (or sequence of
dictionaries) mapping the random variable (by name or reference) to desired
starting values.
model : Model, optional
Model to sample from. The model needs to have free random variables. When inside a ``with`` model
context, it defaults to that model, otherwise the model must be passed explicitly.
var_names : iterable of str, optional
Names of variables for which to compute the posterior samples. Defaults to all variables in the posterior
Model to sample from. The model needs to have free random variables. When inside
a ``with`` model context, it defaults to that model, otherwise the model must be
passed explicitly.
var_names : sequence of str, optional
Names of variables for which to compute the posterior samples. Defaults to all
variables in the posterior.
progress_bar : bool, default True
Whether or not to display a progress bar in the command line. The bar shows the percentage
of completion, the sampling speed in samples per second (SPS), and the estimated remaining
time until completion ("expected time of arrival"; ETA).
Whether or not to display a progress bar in the command line. The bar shows the
percentage of completion, the sampling speed in samples per second (SPS), and
the estimated remaining time until completion ("expected time of arrival"; ETA).
keep_untransformed : bool, default False
Include untransformed variables in the posterior samples. Defaults to False.
chain_method : str, default "parallel"
Specify how samples should be drawn. The choices include "sequential", "parallel", and "vectorized".
Specify how samples should be drawn. The choices include "sequential",
"parallel", and "vectorized".
postprocessing_backend : Optional[str]
Specify how postprocessing should be computed. gpu or cpu
idata_kwargs : dict, optional
Keyword arguments for :func:`arviz.from_dict`. It also accepts a boolean as value
for the ``log_likelihood`` key to indicate that the pointwise log likelihood should
not be included in the returned object. Values for ``observed_data``, ``constant_data``,
``coords``, and ``dims`` are inferred from the ``model`` argument if not provided
in ``idata_kwargs``.
Keyword arguments for :func:`arviz.from_dict`. It also accepts a boolean as
value for the ``log_likelihood`` key to indicate that the pointwise log
likelihood should not be included in the returned object. Values for
``observed_data``, ``constant_data``, ``coords``, and ``dims`` are inferred from
the ``model`` argument if not provided in ``idata_kwargs``.
nuts_kwargs: dict, optional
Keyword arguments for :func:`numpyro.infer.NUTS`.

Returns
-------
InferenceData
ArviZ ``InferenceData`` object that contains the posterior samples, together with their respective sample stats and
pointwise log likeihood values (unless skipped with ``idata_kwargs``).
ArviZ ``InferenceData`` object that contains the posterior samples, together
with their respective sample stats and pointwise log likelihood values (unless
skipped with ``idata_kwargs``).
"""

import numpyro
Expand Down Expand Up @@ -486,14 +512,10 @@ def sample_numpyro_nuts(

logp_fn = get_jaxified_logp(model, negative_logp=False)

if nuts_kwargs is None:
nuts_kwargs = {}
nuts_kwargs = _update_numpyro_nuts_kwargs(nuts_kwargs)
nuts_kernel = NUTS(
potential_fn=logp_fn,
target_accept_prob=target_accept,
adapt_step_size=True,
adapt_mass_matrix=True,
dense_mass=False,
**nuts_kwargs,
)

Expand Down
49 changes: 49 additions & 0 deletions pymc/tests/test_sampling_jax.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,8 @@
from typing import Any, Dict

import aesara
import aesara.tensor as at
import arviz as az
import jax
import numpy as np
import pytest
Expand All @@ -12,7 +15,9 @@
from pymc.sampling_jax import (
_get_batched_jittered_initial_points,
_get_log_likelihood,
_numpyro_nuts_defaults,
_replace_shared_variables,
_update_numpyro_nuts_kwargs,
get_jaxified_graph,
get_jaxified_logp,
sample_blackjax_nuts,
Expand Down Expand Up @@ -270,3 +275,47 @@ def test_seeding(chains, random_seed, sampler):
if chains > 1:
assert np.all(result1.posterior["x"].sel(chain=0) != result1.posterior["x"].sel(chain=1))
assert np.all(result2.posterior["x"].sel(chain=0) != result2.posterior["x"].sel(chain=1))


@pytest.mark.parametrize(
    "nuts_kwargs",
    [
        {"adapt_step_size": False},
        {"adapt_mass_matrix": True},
        {"dense_mass": True},
        {"adapt_step_size": False, "adapt_mass_matrix": True, "dense_mass": True},
        {"fake-key": "fake-value"},
    ],
)
def test_update_numpyro_nuts_kwargs(nuts_kwargs: Dict[str, Any]):
    """User-supplied kwargs survive the merge; untouched keys fall back to defaults."""
    supplied = dict(nuts_kwargs)
    merged = _update_numpyro_nuts_kwargs(nuts_kwargs)

    # Every user-provided key-value pair takes precedence in the merged result.
    assert all(merged[key] == value for key, value in supplied.items())

    # Defaults fill in only the keys the user did not override.
    assert all(
        merged[key] == value
        for key, value in _numpyro_nuts_defaults().items()
        if key not in supplied
    )


@pytest.mark.parametrize(
"nuts_kwargs",
[
{"adapt_step_size": False},
{"adapt_mass_matrix": True},
{"dense_mass": True},
{"adapt_step_size": False, "adapt_mass_matrix": True, "dense_mass": True},
{"adapt_step_size": False, "step_size": 0.13},
],
)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Let's try to monkey patch? If it's not possible I think we can survive without explicit tests. This one seems too inefficient.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'll give it a shot. I agree that the current method I used is inefficient and isn't really testing the desired behavior.

def test_numpyro_nuts_kwargs_are_used(nuts_kwargs: Dict[str, Any]):
    """Smoke-test that custom NUTS kwargs flow through ``sample_numpyro_nuts``."""
    with pm.Model():
        pm.Normal("a")
        trace = sample_numpyro_nuts(10, tune=10, chains=1, nuts_kwargs=nuts_kwargs)

    # Narrow the type for static analysis and confirm sampling produced stats.
    assert isinstance(trace, az.InferenceData)
    assert hasattr(trace, "sample_stats")

    # When a fixed step size was requested, it must show up in the sample stats.
    if "step_size" in nuts_kwargs:
        assert np.allclose(trace.sample_stats["step_size"], nuts_kwargs["step_size"])