Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
67 changes: 67 additions & 0 deletions pymc/sampling/jax.py
Original file line number Diff line number Diff line change
Expand Up @@ -521,6 +521,73 @@ def sample_jax_nuts(
compute_convergence_checks: bool = True,
nuts_sampler: Literal["numpyro", "blackjax"],
) -> az.InferenceData:
"""
Draw samples from the posterior using a jax NUTS method.

Parameters
----------
draws : int, default 1000
The number of samples to draw. The number of tuned samples are discarded by
default.
tune : int, default 1000
Number of iterations to tune. Samplers adjust the step sizes, scalings or
similar during tuning. Tuning samples will be drawn in addition to the number
specified in the ``draws`` argument. Tuned samples are discarded.
chains : int, default 4
The number of chains to sample.
target_accept : float in [0, 1].
The step size is tuned such that we approximate this acceptance rate. Higher
values like 0.9 or 0.95 often work better for problematic posteriors.
random_seed : int, RandomState or Generator, optional
Random seed used by the sampling steps.
initvals: StartDict or Sequence[Optional[StartDict]], optional
Initial values for random variables provided as a dictionary (or sequence of
dictionaries) mapping the random variable (by name or reference) to desired
starting values.
jitter: bool, default True
If True, add jitter to initial points.
model : Model, optional
Model to sample from. The model needs to have free random variables. When inside
a ``with`` model context, it defaults to that model, otherwise the model must be
passed explicitly.
var_names : sequence of str, optional
Names of variables for which to compute the posterior samples. Defaults to all
variables in the posterior.
nuts_kwargs : dict, optional
Keyword arguments for the underlying nuts sampler
progressbar : bool, default True
If True, display a progressbar while sampling
keep_untransformed : bool, default False
Include untransformed variables in the posterior samples.
chain_method : str, default "parallel"
Specify how samples should be drawn. The choices include "parallel", and
"vectorized".
postprocessing_backend : Optional[Literal["cpu", "gpu"]], default None,
Specify how postprocessing should be computed. gpu or cpu
postprocessing_vectorize : Literal["vmap", "scan"], default "scan"
How to vectorize the postprocessing: vmap or sequential scan
postprocessing_chunks : None
This argument is deprecated
idata_kwargs : dict, optional
Keyword arguments for :func:`arviz.from_dict`. It also accepts a boolean as
value for the ``log_likelihood`` key to indicate that the pointwise log
likelihood should not be included in the returned object. Values for
``observed_data``, ``constant_data``, ``coords``, and ``dims`` are inferred from
the ``model`` argument if not provided in ``idata_kwargs``. If ``coords`` and
``dims`` are provided, they are used to update the inferred dictionaries.
compute_convergence_checks : bool, default True
If True, compute ess and rhat values and warn if they indicate potential sampling issues.
nuts_sampler : Literal["numpyro", "blackjax"]
Nuts sampler library to use - do not change - use sample_numpyro_nuts or
sample_blackjax_nuts as appropriate

Returns
-------
InferenceData
ArviZ ``InferenceData`` object that contains the posterior samples, together
with their respective sample stats and pointwise log likeihood values (unless
skipped with ``idata_kwargs``).
"""
if postprocessing_chunks is not None:
import warnings

Expand Down