-
Notifications
You must be signed in to change notification settings - Fork 2.2k
Add blas_cores argument to pm.sample #7318
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Merged
Merged
Changes from 1 commit
Commits
Show all changes
3 commits
Select commit
Hold shift + click to select a range
File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -14,13 +14,14 @@ | |
|
|
||
| """Functions for MCMC sampling.""" | ||
|
|
||
| import contextlib | ||
| import logging | ||
| import pickle | ||
| import sys | ||
| import time | ||
| import warnings | ||
|
|
||
| from collections.abc import Iterator, Mapping, Sequence | ||
| from collections.abc import Callable, Iterator, Mapping, Sequence | ||
| from typing import ( | ||
| Any, | ||
| Literal, | ||
|
|
@@ -37,6 +38,7 @@ | |
| from rich.console import Console | ||
| from rich.progress import Progress | ||
| from rich.theme import Theme | ||
| from threadpoolctl import threadpool_limits | ||
| from typing_extensions import Protocol | ||
|
|
||
| import pymc as pm | ||
|
|
@@ -396,6 +398,7 @@ def sample( | |
| nuts_sampler_kwargs: dict[str, Any] | None = None, | ||
| callback=None, | ||
| mp_ctx=None, | ||
| blas_cores: int | None | Literal["auto"] = "auto", | ||
| **kwargs, | ||
| ) -> InferenceData: ... | ||
|
|
||
|
|
@@ -427,6 +430,7 @@ def sample( | |
| callback=None, | ||
| mp_ctx=None, | ||
| model: Model | None = None, | ||
| blas_cores: int | None | Literal["auto"] = "auto", | ||
| **kwargs, | ||
| ) -> MultiTrace: ... | ||
|
|
||
|
|
@@ -456,6 +460,7 @@ def sample( | |
| nuts_sampler_kwargs: dict[str, Any] | None = None, | ||
| callback=None, | ||
| mp_ctx=None, | ||
| blas_cores: int | None | Literal["auto"] = "auto", | ||
| model: Model | None = None, | ||
| **kwargs, | ||
| ) -> InferenceData | MultiTrace: | ||
|
|
@@ -499,6 +504,13 @@ def sample( | |
| Which NUTS implementation to run. One of ["pymc", "nutpie", "blackjax", "numpyro"]. | ||
| This requires the chosen sampler to be installed. | ||
| All samplers, except "pymc", require the full model to be continuous. | ||
| blas_cores : int or "auto" or None, default = "auto" | ||
| The total number of threads that BLAS and OpenMP functions should use during sampling. If set to None, | ||
| this will keep the default behavior of whatever BLAS implementation is used at runtime. | ||
| Setting it to "auto" makes the total number of active BLAS threads the | ||
| same as the `cores` argument. If set to an integer, the sampler will try to use that total | ||
| number of BLAS threads. If `blas_cores` is not divisible by `cores`, it might get rounded | ||
| down. | ||
| initvals : optional, dict, array of dict | ||
| Dict or list of dicts with initial value strategies to use instead of the defaults from | ||
| `Model.initial_values`. The keys should be names of transformed random variables. | ||
|
|
@@ -644,6 +656,37 @@ def sample( | |
| if chains is None: | ||
| chains = max(2, cores) | ||
|
|
||
| if blas_cores == "auto": | ||
| blas_cores = cores | ||
|
|
||
| cores = min(cores, chains) | ||
|
|
||
| if cores < 1: | ||
| raise ValueError("`cores` must be larger or equal to one") | ||
|
|
||
| if chains < 1: | ||
| raise ValueError("`chains` must be larger or equal to one") | ||
|
|
||
| if blas_cores is not None and blas_cores < 1: | ||
| raise ValueError("`blas_cores` must be larger or equal to one") | ||
|
||
|
|
||
| num_blas_cores_per_chain: int | None | ||
| joined_blas_limiter: Callable[[], Any] | ||
|
|
||
| if blas_cores is None: | ||
| joined_blas_limiter = contextlib.nullcontext | ||
| num_blas_cores_per_chain = None | ||
| elif isinstance(blas_cores, int): | ||
|
|
||
| def joined_blas_limiter(): | ||
| return threadpool_limits(limits=blas_cores) | ||
|
|
||
| num_blas_cores_per_chain = blas_cores // cores | ||
| else: | ||
| raise ValueError( | ||
| f"Invalid argument `blas_cores`, must be int, 'auto' or None: {blas_cores}" | ||
| ) | ||
|
|
||
| if random_seed == -1: | ||
| random_seed = None | ||
| random_seed_list = _get_seeds_per_chain(random_seed, chains) | ||
|
|
@@ -685,21 +728,22 @@ def sample( | |
| raise ValueError( | ||
| "Model can not be sampled with NUTS alone. Your model is probably not continuous." | ||
| ) | ||
| return _sample_external_nuts( | ||
| sampler=nuts_sampler, | ||
| draws=draws, | ||
| tune=tune, | ||
| chains=chains, | ||
| target_accept=kwargs.pop("nuts", {}).get("target_accept", 0.8), | ||
| random_seed=random_seed, | ||
| initvals=initvals, | ||
| model=model, | ||
| var_names=var_names, | ||
| progressbar=progressbar, | ||
| idata_kwargs=idata_kwargs, | ||
| nuts_sampler_kwargs=nuts_sampler_kwargs, | ||
| **kwargs, | ||
| ) | ||
| with joined_blas_limiter(): | ||
| return _sample_external_nuts( | ||
| sampler=nuts_sampler, | ||
| draws=draws, | ||
| tune=tune, | ||
| chains=chains, | ||
| target_accept=kwargs.pop("nuts", {}).get("target_accept", 0.8), | ||
| random_seed=random_seed, | ||
| initvals=initvals, | ||
| model=model, | ||
| var_names=var_names, | ||
| progressbar=progressbar, | ||
| idata_kwargs=idata_kwargs, | ||
| nuts_sampler_kwargs=nuts_sampler_kwargs, | ||
| **kwargs, | ||
| ) | ||
|
|
||
| if isinstance(step, list): | ||
| step = CompoundStep(step) | ||
|
|
@@ -708,18 +752,19 @@ def sample( | |
| nuts_kwargs = kwargs.pop("nuts") | ||
| [kwargs.setdefault(k, v) for k, v in nuts_kwargs.items()] | ||
| _log.info("Auto-assigning NUTS sampler...") | ||
| initial_points, step = init_nuts( | ||
| init=init, | ||
| chains=chains, | ||
| n_init=n_init, | ||
| model=model, | ||
| random_seed=random_seed_list, | ||
| progressbar=progressbar, | ||
| jitter_max_retries=jitter_max_retries, | ||
| tune=tune, | ||
| initvals=initvals, | ||
| **kwargs, | ||
| ) | ||
| with joined_blas_limiter(): | ||
| initial_points, step = init_nuts( | ||
| init=init, | ||
| chains=chains, | ||
| n_init=n_init, | ||
| model=model, | ||
| random_seed=random_seed_list, | ||
| progressbar=progressbar, | ||
| jitter_max_retries=jitter_max_retries, | ||
| tune=tune, | ||
| initvals=initvals, | ||
| **kwargs, | ||
| ) | ||
|
|
||
| if initial_points is None: | ||
| # Time to draw/evaluate numeric start points for each chain. | ||
|
|
@@ -756,7 +801,8 @@ def sample( | |
| ) | ||
|
|
||
| sample_args = { | ||
| "draws": draws + tune, # FIXME: Why is tune added to draws? | ||
| # draws is now the total number of draws, including tuning | ||
| "draws": draws + tune, | ||
| "step": step, | ||
| "start": initial_points, | ||
| "traces": traces, | ||
|
|
@@ -772,6 +818,7 @@ def sample( | |
| } | ||
| parallel_args = { | ||
| "mp_ctx": mp_ctx, | ||
| "blas_cores": num_blas_cores_per_chain, | ||
| } | ||
|
|
||
| sample_args.update(kwargs) | ||
|
|
@@ -817,11 +864,15 @@ def sample( | |
| if has_population_samplers: | ||
| _log.info(f"Population sampling ({chains} chains)") | ||
| _print_step_hierarchy(step) | ||
| _sample_population(initial_points=initial_points, parallelize=cores > 1, **sample_args) | ||
| with joined_blas_limiter(): | ||
| _sample_population( | ||
| initial_points=initial_points, parallelize=cores > 1, **sample_args | ||
| ) | ||
| else: | ||
| _log.info(f"Sequential sampling ({chains} chains in 1 job)") | ||
| _print_step_hierarchy(step) | ||
| _sample_many(**sample_args) | ||
| with joined_blas_limiter(): | ||
| _sample_many(**sample_args) | ||
|
|
||
| t_sampling = time.time() - t_start | ||
|
|
||
|
|
@@ -1139,6 +1190,7 @@ def _mp_sample( | |
| traces: Sequence[IBaseTrace], | ||
| model: Model | None = None, | ||
| callback: SamplingIteratorCallback | None = None, | ||
| blas_cores: int | None = None, | ||
| mp_ctx=None, | ||
| **kwargs, | ||
| ) -> None: | ||
|
|
@@ -1190,6 +1242,7 @@ def _mp_sample( | |
| step_method=step, | ||
| progressbar=progressbar, | ||
| progressbar_theme=progressbar_theme, | ||
| blas_cores=blas_cores, | ||
| mp_ctx=mp_ctx, | ||
| ) | ||
| try: | ||
|
|
||
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Explain the default first?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
+1
Do you think we should already default to "auto", or first release something where the default is None so that this can be tested a bit more?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I think the new default makes much more sense. This often shows up in MvNormal models and it's very tricky for beginners to debug