diff --git a/pymc/sampling/jax.py b/pymc/sampling/jax.py
index 42e7ee1564..01f8b0d502 100644
--- a/pymc/sampling/jax.py
+++ b/pymc/sampling/jax.py
@@ -47,6 +47,7 @@
 from pymc.initial_point import StartDict
 from pymc.logprob.utils import CheckParameterValue
 from pymc.sampling.mcmc import _init_jitter
+from pymc.stats.convergence import log_warnings, run_convergence_checks
 from pymc.util import (
     RandomSeed,
     RandomState,
@@ -157,25 +158,19 @@ def logp_fn_wrap(x):
     return logp_fn_wrap
 
 
-# Adopted from arviz numpyro extractor
-def _sample_stats_to_xarray(posterior):
-    """Extract sample_stats from NumPyro posterior."""
-    rename_key = {
-        "potential_energy": "lp",
-        "adapt_state.step_size": "step_size",
-        "num_steps": "n_steps",
-        "accept_prob": "acceptance_rate",
-    }
-    data = {}
-    for stat, value in posterior.get_extra_fields(group_by_chain=True).items():
-        if isinstance(value, (dict, tuple)):
-            continue
-        name = rename_key.get(stat, stat)
-        value = value.copy()
-        data[name] = value
-        if stat == "num_steps":
-            data["tree_depth"] = np.log2(value).astype(int) + 1
-    return data
+def _get_log_likelihood(
+    model: Model,
+    samples,
+    backend: Optional[Literal["cpu", "gpu"]] = None,
+    postprocessing_vectorize: Literal["vmap", "scan"] = "scan",
+) -> dict:
+    """Compute log-likelihood for all observations"""
+    elemwise_logp = model.logp(model.observed_RVs, sum=False)
+    jax_fn = get_jaxified_graph(inputs=model.value_vars, outputs=elemwise_logp)
+    result = _postprocess_samples(
+        jax_fn, samples, backend, postprocessing_vectorize=postprocessing_vectorize
+    )
+    return {v.name: r for v, r in zip(model.observed_RVs, result)}
 
 
 def _device_put(input, device: str):
@@ -203,55 +198,6 @@ def _postprocess_samples(
         raise ValueError(f"Unrecognized postprocessing_vectorize: {postprocessing_vectorize}")
 
 
-def _blackjax_stats_to_dict(sample_stats, potential_energy) -> dict:
-    """Extract compatible stats from blackjax NUTS sampler
-    with PyMC/Arviz naming conventions.
-
-    Parameters
-    ----------
-    sample_stats: NUTSInfo
-        Blackjax NUTSInfo object containing sampler statistics
-    potential_energy: ArrayLike
-        Potential energy values of sampled positions.
-
-    Returns
-    -------
-    Dict[str, ArrayLike]
-        Dictionary of sampler statistics.
- """ - rename_key = { - "is_divergent": "diverging", - "energy": "energy", - "num_trajectory_expansions": "tree_depth", - "num_integration_steps": "n_steps", - "acceptance_rate": "acceptance_rate", # naming here is - "acceptance_probability": "acceptance_rate", # depending on blackjax version - } - converted_stats = {} - converted_stats["lp"] = potential_energy - for old_name, new_name in rename_key.items(): - value = getattr(sample_stats, old_name, None) - if value is None: - continue - converted_stats[new_name] = value - return converted_stats - - -def _get_log_likelihood( - model: Model, - samples, - backend: Optional[Literal["cpu", "gpu"]] = None, - postprocessing_vectorize: Literal["vmap", "scan"] = "scan", -) -> dict: - """Compute log-likelihood for all observations""" - elemwise_logp = model.logp(model.observed_RVs, sum=False) - jax_fn = get_jaxified_graph(inputs=model.value_vars, outputs=elemwise_logp) - result = _postprocess_samples( - jax_fn, samples, backend, postprocessing_vectorize=postprocessing_vectorize - ) - return {v.name: r for v, r in zip(model.observed_RVs, result)} - - def _get_batched_jittered_initial_points( model: Model, chains: int, @@ -282,16 +228,6 @@ def _get_batched_jittered_initial_points( return [np.stack(init_state) for init_state in zip(*initial_points_values)] -def _update_coords_and_dims( - coords: dict[str, Any], dims: dict[str, Any], idata_kwargs: dict[str, Any] -) -> None: - """Update 'coords' and 'dims' dicts with values in 'idata_kwargs'.""" - if "coords" in idata_kwargs: - coords.update(idata_kwargs.pop("coords")) - if "dims" in idata_kwargs: - dims.update(idata_kwargs.pop("dims")) - - def _blackjax_inference_loop( seed, init_position, logprob_fn, draws, tune, target_accept, **adaptation_kwargs ): @@ -323,7 +259,6 @@ def _one_step(state, xs): if progress_bar: from blackjax.progress_bar import progress_bar_scan - logger.info("Sample with tuned parameters") one_step = jax.jit(progress_bar_scan(draws)(_one_step)) else: one_step = jax.jit(_one_step) @@ -334,24 +269,51 @@ def _one_step(state, xs): return states, infos -def sample_blackjax_nuts( - draws: int = 1000, - tune: int = 1000, - chains: int = 4, - target_accept: float = 0.8, - random_seed: Optional[RandomState] = None, - initvals: Optional[Union[StartDict, Sequence[Optional[StartDict]]]] = None, - jitter: bool = True, - model: Optional[Model] = None, - var_names: Optional[Sequence[str]] = None, - progress_bar: bool = False, - keep_untransformed: bool = False, - chain_method: str = "parallel", - postprocessing_backend: Optional[Literal["cpu", "gpu"]] = None, - postprocessing_vectorize: Literal["vmap", "scan"] = "scan", - idata_kwargs: Optional[dict[str, Any]] = None, - adaptation_kwargs: Optional[dict[str, Any]] = None, - postprocessing_chunks=None, # deprecated +def _blackjax_stats_to_dict(sample_stats, potential_energy) -> dict: + """Extract compatible stats from blackjax NUTS sampler + with PyMC/Arviz naming conventions. + + Parameters + ---------- + sample_stats: NUTSInfo + Blackjax NUTSInfo object containing sampler statistics + potential_energy: ArrayLike + Potential energy values of sampled positions. + + Returns + ------- + Dict[str, ArrayLike] + Dictionary of sampler statistics. 
+ """ + rename_key = { + "is_divergent": "diverging", + "energy": "energy", + "num_trajectory_expansions": "tree_depth", + "num_integration_steps": "n_steps", + "acceptance_rate": "acceptance_rate", # naming here is + "acceptance_probability": "acceptance_rate", # depending on blackjax version + } + converted_stats = {} + converted_stats["lp"] = potential_energy + for old_name, new_name in rename_key.items(): + value = getattr(sample_stats, old_name, None) + if value is None: + continue + converted_stats[new_name] = value + return converted_stats + + +def _sample_blackjax_nuts( + model: Model, + target_accept: float, + tune: int, + draws: int, + chains: int, + chain_method: Optional[str], + progressbar: bool, + random_seed: int, + initial_points, + nuts_kwargs, ) -> az.InferenceData: """ Draw samples from the posterior using the NUTS method from the ``blackjax`` library. @@ -409,58 +371,20 @@ def sample_blackjax_nuts( with their respective sample stats and pointwise log likeihood values (unless skipped with ``idata_kwargs``). """ - if postprocessing_chunks is not None: - import warnings - warnings.warn( - "postprocessing_chunks is deprecated due to being unstable, " - "using postprocessing_vectorize='scan' instead", - DeprecationWarning, - ) import blackjax - model = modelcontext(model) - - if var_names is None: - var_names = model.unobserved_value_vars - - vars_to_sample = list(get_default_varnames(var_names, include_transformed=keep_untransformed)) - - (random_seed,) = _get_seeds_per_chain(random_seed, 1) - - tic1 = datetime.now() - logger.info("Compiling...") - - init_params = _get_batched_jittered_initial_points( - model=model, - chains=chains, - initvals=initvals, - random_seed=random_seed, - jitter=jitter, - ) - - if chains == 1: - init_params = [np.stack(init_state) for init_state in zip(init_params)] - - logprob_fn = get_jaxified_logp(model) - - seed = jax.random.PRNGKey(random_seed) - keys = jax.random.split(seed, chains) - - if adaptation_kwargs is None: - adaptation_kwargs = {} - # Adapted from numpyro if chain_method == "parallel": map_fn = jax.pmap - if progress_bar: + if progressbar: import warnings warnings.warn( "BlackJax currently only display progress bar correctly under " "`chain_method == 'vectorized'`. Setting `progressbar=False`." 
             )
-            progress_bar = False
+            progressbar = False
     elif chain_method == "vectorized":
         map_fn = jax.vmap
     else:
         raise ValueError(
@@ -468,102 +392,116 @@
             "Only supporting the following methods to draw chains:"
             ' "parallel" or "vectorized"'
         )
-    adaptation_kwargs["progress_bar"] = progress_bar
+    if chains == 1:
+        initial_points = [np.stack(init_state) for init_state in zip(initial_points)]
+
+    logprob_fn = get_jaxified_logp(model)
+
+    seed = jax.random.PRNGKey(random_seed)
+    keys = jax.random.split(seed, chains)
+
+    nuts_kwargs["progress_bar"] = progressbar
     get_posterior_samples = partial(
         _blackjax_inference_loop,
         logprob_fn=logprob_fn,
         tune=tune,
         draws=draws,
         target_accept=target_accept,
-        **adaptation_kwargs,
+        **nuts_kwargs,
     )
 
-    tic2 = datetime.now()
-    logger.info(f"Compilation time = {tic2 - tic1}")
-
-    logger.info("Sampling...")
-
-    states, stats = map_fn(get_posterior_samples)(keys, init_params)
+    states, stats = map_fn(get_posterior_samples)(keys, initial_points)
     raw_mcmc_samples = states.position
     potential_energy = states.logdensity.block_until_ready()
-    tic3 = datetime.now()
-    logger.info(f"Sampling time = {tic3 - tic2}")
+    sample_stats = _blackjax_stats_to_dict(stats, potential_energy)
 
-    logger.info("Transforming variables...")
-    jax_fn = get_jaxified_graph(inputs=model.value_vars, outputs=vars_to_sample)
-    result = _postprocess_samples(
-        jax_fn,
-        raw_mcmc_samples,
-        postprocessing_backend=postprocessing_backend,
-        postprocessing_vectorize=postprocessing_vectorize,
-    )
-    mcmc_samples = {v.name: r for v, r in zip(vars_to_sample, result)}
-    mcmc_stats = _blackjax_stats_to_dict(stats, potential_energy)
-    tic4 = datetime.now()
-    logger.info(f"Transformation time = {tic4 - tic3}")
+    return raw_mcmc_samples, sample_stats, blackjax
 
-    if idata_kwargs is None:
-        idata_kwargs = {}
-    else:
-        idata_kwargs = idata_kwargs.copy()
 
-    if idata_kwargs.pop("log_likelihood", False):
-        tic5 = datetime.now()
-        logger.info("Computing Log Likelihood...")
-        log_likelihood = _get_log_likelihood(
-            model,
-            raw_mcmc_samples,
-            backend=postprocessing_backend,
-            postprocessing_vectorize=postprocessing_vectorize,
-        )
-        tic6 = datetime.now()
-        logger.info(f"Log Likelihood time = {tic6 - tic5}")
-    else:
-        log_likelihood = None
-
-    attrs = {
-        "sampling_time": (tic3 - tic2).total_seconds(),
+# Adapted from arviz numpyro extractor
+def _numpyro_stats_to_dict(posterior):
+    """Extract sample_stats from NumPyro posterior."""
+    rename_key = {
+        "potential_energy": "lp",
+        "adapt_state.step_size": "step_size",
+        "num_steps": "n_steps",
+        "accept_prob": "acceptance_rate",
     }
+    data = {}
+    for stat, value in posterior.get_extra_fields(group_by_chain=True).items():
+        if isinstance(value, (dict, tuple)):
+            continue
+        name = rename_key.get(stat, stat)
+        value = value.copy()
+        data[name] = value
+        if stat == "num_steps":
+            data["tree_depth"] = np.log2(value).astype(int) + 1
+    return data
 
-    coords, dims = coords_and_dims_for_inferencedata(model)
-    # Update 'coords' and 'dims' extracted from the model with user 'idata_kwargs'
-    # and drop keys 'coords' and 'dims' from 'idata_kwargs' if present.
-    _update_coords_and_dims(coords=coords, dims=dims, idata_kwargs=idata_kwargs)
-    # Use 'partial' to set default arguments before passing 'idata_kwargs'
-    to_trace = partial(
-        az.from_dict,
-        log_likelihood=log_likelihood,
-        observed_data=find_observations(model),
-        constant_data=find_constants(model),
-        sample_stats=mcmc_stats,
-        coords=coords,
-        dims=dims,
-        attrs=make_attrs(attrs, library=blackjax),
-    )
-    az_trace = to_trace(posterior=mcmc_samples, **idata_kwargs)
-    return az_trace
+def _sample_numpyro_nuts(
+    model: Model,
+    target_accept: float,
+    tune: int,
+    draws: int,
+    chains: int,
+    chain_method: Optional[str],
+    progressbar: bool,
+    random_seed: int,
+    initial_points,
+    nuts_kwargs: dict[str, Any],
+):
+    import numpyro
+
+    from numpyro.infer import MCMC, NUTS
 
+    logp_fn = get_jaxified_logp(model, negative_logp=False)
 
-def _numpyro_nuts_defaults() -> dict[str, Any]:
-    """Defaults parameters for Numpyro NUTS."""
-    return {
-        "adapt_step_size": True,
-        "adapt_mass_matrix": True,
-        "dense_mass": False,
-    }
+    nuts_kwargs.setdefault("adapt_step_size", True)
+    nuts_kwargs.setdefault("adapt_mass_matrix", True)
+    nuts_kwargs.setdefault("dense_mass", False)
+
+    nuts_kernel = NUTS(
+        potential_fn=logp_fn,
+        target_accept_prob=target_accept,
+        **nuts_kwargs,
+    )
 
+    pmap_numpyro = MCMC(
+        nuts_kernel,
+        num_warmup=tune,
+        num_samples=draws,
+        num_chains=chains,
+        postprocess_fn=None,
+        chain_method=chain_method,
+        progress_bar=progressbar,
+    )
 
-def _update_numpyro_nuts_kwargs(nuts_kwargs: Optional[dict[str, Any]]) -> dict[str, Any]:
-    """Update default Numpyro NUTS parameters with new values."""
-    nuts_kwargs_defaults = _numpyro_nuts_defaults()
-    if nuts_kwargs is not None:
-        nuts_kwargs_defaults.update(nuts_kwargs)
-    return nuts_kwargs_defaults
+    map_seed = jax.random.PRNGKey(random_seed)
+    if chains > 1:
+        map_seed = jax.random.split(map_seed, chains)
 
+    pmap_numpyro.run(
+        map_seed,
+        init_params=initial_points,
+        extra_fields=(
+            "num_steps",
+            "potential_energy",
+            "energy",
+            "adapt_state.step_size",
+            "accept_prob",
+            "diverging",
+        ),
+    )
+
+    raw_mcmc_samples = pmap_numpyro.get_samples(group_by_chain=True)
+    sample_stats = _numpyro_stats_to_dict(pmap_numpyro)
+    return raw_mcmc_samples, sample_stats, numpyro
 
-def sample_numpyro_nuts(
+
+def sample_jax_nuts(
     draws: int = 1000,
+    *,
     tune: int = 1000,
     chains: int = 4,
     target_accept: float = 0.8,
@@ -572,77 +510,17 @@
     random_seed: Optional[RandomState] = None,
     initvals: Optional[Union[StartDict, Sequence[Optional[StartDict]]]] = None,
     jitter: bool = True,
     model: Optional[Model] = None,
     var_names: Optional[Sequence[str]] = None,
+    nuts_kwargs: Optional[dict] = None,
     progressbar: bool = True,
     keep_untransformed: bool = False,
     chain_method: str = "parallel",
     postprocessing_backend: Optional[Literal["cpu", "gpu"]] = None,
     postprocessing_vectorize: Literal["vmap", "scan"] = "scan",
-    idata_kwargs: Optional[dict] = None,
-    nuts_kwargs: Optional[dict] = None,
     postprocessing_chunks=None,
+    idata_kwargs: Optional[dict] = None,
+    compute_convergence_checks: bool = True,
+    nuts_sampler: Literal["numpyro", "blackjax"],
 ) -> az.InferenceData:
-    """
-    Draw samples from the posterior using the NUTS method from the ``numpyro`` library.
-
-    Parameters
-    ----------
-    draws : int, default 1000
-        The number of samples to draw. The number of tuned samples are discarded by
-        default.
-    tune : int, default 1000
-        Number of iterations to tune. Samplers adjust the step sizes, scalings or
-        similar during tuning. Tuning samples will be drawn in addition to the number
-        specified in the ``draws`` argument.
-    chains : int, default 4
-        The number of chains to sample.
-    target_accept : float in [0, 1].
-        The step size is tuned such that we approximate this acceptance rate. Higher
-        values like 0.9 or 0.95 often work better for problematic posteriors.
-    random_seed : int, RandomState or Generator, optional
-        Random seed used by the sampling steps.
-    initvals: StartDict or Sequence[Optional[StartDict]], optional
-        Initial values for random variables provided as a dictionary (or sequence of
-        dictionaries) mapping the random variable (by name or reference) to desired
-        starting values.
-    jitter: bool, default True
-        If True, add jitter to initial points.
-    model : Model, optional
-        Model to sample from. The model needs to have free random variables. When inside
-        a ``with`` model context, it defaults to that model, otherwise the model must be
-        passed explicitly.
-    var_names : sequence of str, optional
-        Names of variables for which to compute the posterior samples. Defaults to all
-        variables in the posterior.
-    progressbar : bool, default True
-        Whether or not to display a progress bar in the command line. The bar shows the
-        percentage of completion, the sampling speed in samples per second (SPS), and
-        the estimated remaining time until completion ("expected time of arrival"; ETA).
-    keep_untransformed : bool, default False
-        Include untransformed variables in the posterior samples. Defaults to False.
-    chain_method : str, default "parallel"
-        Specify how samples should be drawn. The choices include "sequential",
-        "parallel", and "vectorized".
-    postprocessing_backend: Optional[Literal["cpu", "gpu"]], default None,
-        Specify how postprocessing should be computed. gpu or cpu
-    postprocessing_vectorize: Literal["vmap", "scan"], default "scan"
-        How to vectorize the postprocessing: vmap or sequential scan
-    idata_kwargs : dict, optional
-        Keyword arguments for :func:`arviz.from_dict`. It also accepts a boolean as
-        value for the ``log_likelihood`` key to indicate that the pointwise log
-        likelihood should not be included in the returned object. Values for
-        ``observed_data``, ``constant_data``, ``coords``, and ``dims`` are inferred from
-        the ``model`` argument if not provided in ``idata_kwargs``. If ``coords`` and
-        ``dims`` are provided, they are used to update the inferred dictionaries.
-    nuts_kwargs: dict, optional
-        Keyword arguments for :func:`numpyro.infer.NUTS`.
-
-    Returns
-    -------
-    InferenceData
-        ArviZ ``InferenceData`` object that contains the posterior samples, together
-        with their respective sample stats and pointwise log likeihood values (unless
-        skipped with ``idata_kwargs``).
- """ if postprocessing_chunks is not None: import warnings @@ -651,23 +529,22 @@ def sample_numpyro_nuts( "using postprocessing_vectorize='scan' instead", DeprecationWarning, ) - import numpyro - - from numpyro.infer import MCMC, NUTS model = modelcontext(model) if var_names is None: var_names = model.unobserved_value_vars + if nuts_kwargs is None: + nuts_kwargs = {} + else: + nuts_kwargs = nuts_kwargs.copy() + vars_to_sample = list(get_default_varnames(var_names, include_transformed=keep_untransformed)) (random_seed,) = _get_seeds_per_chain(random_seed, 1) - tic1 = datetime.now() - logger.info("Compiling...") - - init_params = _get_batched_jittered_initial_points( + initial_points = _get_batched_jittered_initial_points( model=model, chains=chains, initvals=initvals, @@ -675,53 +552,28 @@ def sample_numpyro_nuts( jitter=jitter, ) - logp_fn = get_jaxified_logp(model, negative_logp=False) - - nuts_kwargs = _update_numpyro_nuts_kwargs(nuts_kwargs) - nuts_kernel = NUTS( - potential_fn=logp_fn, - target_accept_prob=target_accept, - **nuts_kwargs, - ) + if nuts_sampler == "numpyro": + sampler_fn = _sample_numpyro_nuts + elif nuts_sampler == "blackjax": + sampler_fn = _sample_blackjax_nuts + else: + raise ValueError(f"{nuts_sampler=} not recognized") - pmap_numpyro = MCMC( - nuts_kernel, - num_warmup=tune, - num_samples=draws, - num_chains=chains, - postprocess_fn=None, + tic1 = datetime.now() + raw_mcmc_samples, sample_stats, library = sampler_fn( + model=model, + target_accept=target_accept, + tune=tune, + draws=draws, + chains=chains, chain_method=chain_method, - progress_bar=progressbar, + progressbar=progressbar, + random_seed=random_seed, + initial_points=initial_points, + nuts_kwargs=nuts_kwargs, ) - tic2 = datetime.now() - logger.info(f"Compilation time = {tic2 - tic1}") - - logger.info("Sampling...") - - map_seed = jax.random.PRNGKey(random_seed) - if chains > 1: - map_seed = jax.random.split(map_seed, chains) - - pmap_numpyro.run( - map_seed, - init_params=init_params, - extra_fields=( - "num_steps", - "potential_energy", - "energy", - "adapt_state.step_size", - "accept_prob", - "diverging", - ), - ) - - raw_mcmc_samples = pmap_numpyro.get_samples(group_by_chain=True) - - tic3 = datetime.now() - logger.info(f"Sampling time = {tic3 - tic2}") - logger.info("Transforming variables...") jax_fn = get_jaxified_graph(inputs=model.value_vars, outputs=vars_to_sample) result = _postprocess_samples( jax_fn, @@ -731,48 +583,51 @@ def sample_numpyro_nuts( ) mcmc_samples = {v.name: r for v, r in zip(vars_to_sample, result)} - tic4 = datetime.now() - logger.info(f"Transformation time = {tic4 - tic3}") - if idata_kwargs is None: idata_kwargs = {} else: idata_kwargs = idata_kwargs.copy() if idata_kwargs.pop("log_likelihood", False): - tic5 = datetime.now() - logger.info("Computing Log Likelihood...") log_likelihood = _get_log_likelihood( model, raw_mcmc_samples, backend=postprocessing_backend, postprocessing_vectorize=postprocessing_vectorize, ) - tic6 = datetime.now() - logger.info( - f"Log Likelihood time = {tic6 - tic5}", - ) else: log_likelihood = None attrs = { - "sampling_time": (tic3 - tic2).total_seconds(), + "sampling_time": (tic2 - tic1).total_seconds(), } coords, dims = coords_and_dims_for_inferencedata(model) # Update 'coords' and 'dims' extracted from the model with user 'idata_kwargs' # and drop keys 'coords' and 'dims' from 'idata_kwargs' if present. 
-    _update_coords_and_dims(coords=coords, dims=dims, idata_kwargs=idata_kwargs)
+    if "coords" in idata_kwargs:
+        coords.update(idata_kwargs.pop("coords"))
+    if "dims" in idata_kwargs:
+        dims.update(idata_kwargs.pop("dims"))
+
     # Use 'partial' to set default arguments before passing 'idata_kwargs'
     to_trace = partial(
         az.from_dict,
         log_likelihood=log_likelihood,
         observed_data=find_observations(model),
         constant_data=find_constants(model),
-        sample_stats=_sample_stats_to_xarray(pmap_numpyro),
+        sample_stats=sample_stats,
         coords=coords,
         dims=dims,
-        attrs=make_attrs(attrs, library=numpyro),
+        attrs=make_attrs(attrs, library=library),
     )
     az_trace = to_trace(posterior=mcmc_samples, **idata_kwargs)
+
+    if compute_convergence_checks:
+        warns = run_convergence_checks(az_trace, model)
+        log_warnings(warns)
+
     return az_trace
+
+
+sample_numpyro_nuts = partial(sample_jax_nuts, nuts_sampler="numpyro")
+sample_blackjax_nuts = partial(sample_jax_nuts, nuts_sampler="blackjax")
diff --git a/pymc/sampling/mcmc.py b/pymc/sampling/mcmc.py
index dcb461066c..7e21d606f9 100644
--- a/pymc/sampling/mcmc.py
+++ b/pymc/sampling/mcmc.py
@@ -256,7 +256,7 @@ def all_continuous(vars):
 
 
 def _sample_external_nuts(
-    sampler: str,
+    sampler: Literal["nutpie", "numpyro", "blackjax"],
     draws: int,
     tune: int,
     chains: int,
@@ -337,10 +337,10 @@ def _sample_external_nuts(
         )
         return idata
 
-    elif sampler == "numpyro":
+    elif sampler in ("numpyro", "blackjax"):
         import pymc.sampling.jax as pymc_jax
 
-        idata = pymc_jax.sample_numpyro_nuts(
+        idata = pymc_jax.sample_jax_nuts(
             draws=draws,
             tune=tune,
             chains=chains,
@@ -349,23 +349,7 @@
             initvals=initvals,
             model=model,
             progressbar=progressbar,
-            idata_kwargs=idata_kwargs,
-            **nuts_sampler_kwargs,
-        )
-        return idata
-
-    elif sampler == "blackjax":
-        import pymc.sampling.jax as pymc_jax
-
-        idata = pymc_jax.sample_blackjax_nuts(
-            draws=draws,
-            tune=tune,
-            chains=chains,
-            target_accept=target_accept,
-            random_seed=random_seed,
-            initvals=initvals,
-            model=model,
-            progress_bar=progressbar,
+            nuts_sampler=sampler,
             idata_kwargs=idata_kwargs,
             **nuts_sampler_kwargs,
         )
@@ -387,7 +371,7 @@ def sample(
     random_seed: RandomState = None,
     progressbar: bool = True,
     step=None,
-    nuts_sampler: str = "pymc",
+    nuts_sampler: Literal["pymc", "nutpie", "numpyro", "blackjax"] = "pymc",
     initvals: Optional[Union[StartDict, Sequence[Optional[StartDict]]]] = None,
     init: str = "auto",
     jitter_max_retries: int = 10,
@@ -416,7 +400,7 @@ def sample(
     random_seed: RandomState = None,
     progressbar: bool = True,
     step=None,
-    nuts_sampler: str = "pymc",
+    nuts_sampler: Literal["pymc", "nutpie", "numpyro", "blackjax"] = "pymc",
    initvals: Optional[Union[StartDict, Sequence[Optional[StartDict]]]] = None,
     init: str = "auto",
     jitter_max_retries: int = 10,
@@ -445,7 +429,7 @@ def sample(
     random_seed: RandomState = None,
     progressbar: bool = True,
     step=None,
-    nuts_sampler: str = "pymc",
+    nuts_sampler: Literal["pymc", "nutpie", "numpyro", "blackjax"] = "pymc",
     initvals: Optional[Union[StartDict, Sequence[Optional[StartDict]]]] = None,
     init: str = "auto",
     jitter_max_retries: int = 10,
diff --git a/pymc/stats/convergence.py b/pymc/stats/convergence.py
index 470b79f808..aee6df767c 100644
--- a/pymc/stats/convergence.py
+++ b/pymc/stats/convergence.py
@@ -62,20 +62,28 @@ class SamplerWarning:
 
 
 def run_convergence_checks(idata: arviz.InferenceData, model) -> list[SamplerWarning]:
+    warnings: list[SamplerWarning] = []
+
     if not hasattr(idata, "posterior"):
         msg = "No posterior samples. Unable to run convergence checks"
         warn = SamplerWarning(WarningType.BAD_PARAMS, msg, "info", None, None, None)
-        return [warn]
+        warnings.append(warn)
+        return warnings
+
+    warnings += warn_divergences(idata)
+    warnings += warn_treedepth(idata)
 
     if idata["posterior"].sizes["draw"] < 100:
         msg = "The number of samples is too small to check convergence reliably."
         warn = SamplerWarning(WarningType.BAD_PARAMS, msg, "info", None, None, None)
-        return [warn]
+        warnings.append(warn)
+        return warnings
 
     if idata["posterior"].sizes["chain"] == 1:
         msg = "Only one chain was sampled, this makes it impossible to run some convergence checks"
         warn = SamplerWarning(WarningType.BAD_PARAMS, msg, "info")
-        return [warn]
+        warnings.append(warn)
+        return warnings
 
     elif idata["posterior"].sizes["chain"] < 4:
         msg = (
@@ -83,9 +91,8 @@ def run_convergence_checks(idata: arviz.InferenceData, model) -> list[SamplerWar
             "convergence diagnostics"
         )
         warn = SamplerWarning(WarningType.BAD_PARAMS, msg, "info")
-        return [warn]
+        warnings.append(warn)
 
-    warnings: list[SamplerWarning] = []
     valid_name = [rv.name for rv in model.free_RVs + model.deterministics]
     varnames = []
     for rv in model.free_RVs:
@@ -99,7 +106,6 @@ def run_convergence_checks(idata: arviz.InferenceData, model) -> list[SamplerWar
     ess = arviz.ess(idata, var_names=varnames)
     rhat = arviz.rhat(idata, var_names=varnames)
 
-    warnings = []
     rhat_max = max(val.max() for val in rhat.values())
     if rhat_max > 1.01:
         msg = (
@@ -121,9 +127,6 @@ def run_convergence_checks(idata: arviz.InferenceData, model) -> list[SamplerWar
         warn = SamplerWarning(WarningType.CONVERGENCE, msg, "error", extra=ess)
         warnings.append(warn)
 
-    warnings += warn_divergences(idata)
-    warnings += warn_treedepth(idata)
-
     return warnings
diff --git a/tests/sampling/test_jax.py b/tests/sampling/test_jax.py
index 57716c7d8f..d8d0cae246 100644
--- a/tests/sampling/test_jax.py
+++ b/tests/sampling/test_jax.py
@@ -11,6 +11,8 @@
 #   WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 #   See the License for the specific language governing permissions and
 #   limitations under the License.
+import logging
+import re
 import warnings
 
 from typing import Any, Callable, Optional
@@ -35,9 +37,7 @@
 from pymc.sampling.jax import (
     _get_batched_jittered_initial_points,
     _get_log_likelihood,
-    _numpyro_nuts_defaults,
     _replace_shared_variables,
-    _update_numpyro_nuts_kwargs,
     get_jaxified_graph,
     get_jaxified_logp,
     sample_blackjax_nuts,
@@ -400,29 +400,6 @@ def test_seeding(chains, random_seed, sampler):
     assert np.all(result2.posterior["x"].sel(chain=0) != result2.posterior["x"].sel(chain=1))
 
 
-@pytest.mark.parametrize(
-    "nuts_kwargs",
-    [
-        {"adapt_step_size": False},
-        {"adapt_mass_matrix": True},
-        {"dense_mass": True},
-        {"adapt_step_size": False, "adapt_mass_matrix": True, "dense_mass": True},
-        {"fake-key": "fake-value"},
-    ],
-)
-def test_update_numpyro_nuts_kwargs(nuts_kwargs: dict[str, Any]):
-    original_kwargs = nuts_kwargs.copy()
-    new_kwargs = _update_numpyro_nuts_kwargs(nuts_kwargs)
-
-    # Maintains original key-value pairs.
-    for k, v in original_kwargs.items():
-        assert new_kwargs[k] == v
-
-    for k, v in _numpyro_nuts_defaults().items():
-        if k not in original_kwargs:
-            assert new_kwargs[k] == v
-
-
 @mock.patch("numpyro.infer.MCMC")
 def test_numpyro_nuts_kwargs_are_used(mocked: mock.MagicMock):
     mocked.side_effect = MCMC
@@ -512,3 +489,17 @@ def test_sample_partially_observed():
     assert idata.observed_data["x_observed"].shape == (2,)
     assert idata.posterior["x_unobserved"].shape == (1, 10, 1)
     assert idata.posterior["x"].shape == (1, 10, 3)
+
+
+@pytest.mark.parametrize("nuts_sampler", ("numpyro", "blackjax"))
+def test_convergence_warnings(caplog, nuts_sampler):
+    with pm.Model() as m:
+        # Model that should diverge
+        sigma = pm.Normal("sigma", initval=3, transform=None)
+        pm.Normal("obs", mu=0, sigma=sigma, observed=[0.99, 1.0, 1.01])
+
+        with caplog.at_level(logging.WARNING, logger="pymc"):
+            pm.sample(nuts_sampler=nuts_sampler, random_seed=581)
+
+    [record] = caplog.records
+    assert re.match(r"There were \d+ divergences after tuning", record.message)
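Below is a minimal, hypothetical usage sketch of the unified entry point introduced by this patch; the toy model and the keyword values are illustrative only and are not part of the change:

    import pymc as pm

    with pm.Model():
        x = pm.Normal("x")  # illustrative one-variable model
        # `nuts_sampler` now routes both JAX backends through the same
        # sample_jax_nuts code path; `nuts_sampler_kwargs` is forwarded to it.
        idata = pm.sample(
            draws=500,
            tune=500,
            nuts_sampler="numpyro",  # or "blackjax"
            nuts_sampler_kwargs={"postprocessing_vectorize": "vmap"},
        )

The old module-level functions remain importable: sample_numpyro_nuts and sample_blackjax_nuts are now thin partial() aliases of sample_jax_nuts, so existing call sites keep working, and convergence checks run by default via compute_convergence_checks=True.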