From 0cebb66393f67674f2d785b2537261c2583ab4ed Mon Sep 17 00:00:00 2001 From: Adrian Seyboldt Date: Wed, 15 May 2024 18:55:44 +0200 Subject: [PATCH 1/3] feat: Add blas_cores argument to pm.sample --- pymc/sampling/mcmc.py | 115 +++++++++++++++++++++++--------- pymc/sampling/parallel.py | 40 ++++++----- requirements.txt | 1 + tests/sampling/test_mcmc.py | 8 +++ tests/sampling/test_parallel.py | 2 + 5 files changed, 119 insertions(+), 47 deletions(-) diff --git a/pymc/sampling/mcmc.py b/pymc/sampling/mcmc.py index 98f490fff..74652064b 100644 --- a/pymc/sampling/mcmc.py +++ b/pymc/sampling/mcmc.py @@ -14,13 +14,14 @@ """Functions for MCMC sampling.""" +import contextlib import logging import pickle import sys import time import warnings -from collections.abc import Iterator, Mapping, Sequence +from collections.abc import Callable, Iterator, Mapping, Sequence from typing import ( Any, Literal, @@ -37,6 +38,7 @@ from rich.console import Console from rich.progress import Progress from rich.theme import Theme +from threadpoolctl import threadpool_limits from typing_extensions import Protocol import pymc as pm @@ -396,6 +398,7 @@ def sample( nuts_sampler_kwargs: dict[str, Any] | None = None, callback=None, mp_ctx=None, + blas_cores: int | None | Literal["auto"] = "auto", **kwargs, ) -> InferenceData: ... @@ -427,6 +430,7 @@ def sample( callback=None, mp_ctx=None, model: Model | None = None, + blas_cores: int | None | Literal["auto"] = "auto", **kwargs, ) -> MultiTrace: ... @@ -456,6 +460,7 @@ def sample( nuts_sampler_kwargs: dict[str, Any] | None = None, callback=None, mp_ctx=None, + blas_cores: int | None | Literal["auto"] = "auto", model: Model | None = None, **kwargs, ) -> InferenceData | MultiTrace: @@ -499,6 +504,13 @@ def sample( Which NUTS implementation to run. One of ["pymc", "nutpie", "blackjax", "numpyro"]. This requires the chosen sampler to be installed. All samplers, except "pymc", require the full model to be continuous. + blas_cores: int or "auto" or None, default = "auto" + The total number of threads blas and openmp functions should use during sampling. If set to None, + this will keep the default behavior of whatever blas implementation is used at runtime. + Setting it to "auto" will set it so that the total number of active blas threads is the + same as the `cores` argument. If set to an integer, the sampler will try to use that total + number of blas threads. If `blas_cores` is not divisible by `cores`, it might get rounded + down. initvals : optional, dict, array of dict Dict or list of dicts with initial value strategies to use instead of the defaults from `Model.initial_values`. The keys should be names of transformed random variables. @@ -644,6 +656,37 @@ def sample( if chains is None: chains = max(2, cores) + if blas_cores == "auto": + blas_cores = cores + + cores = min(cores, chains) + + if cores < 1: + raise ValueError("`cores` must be larger or equal to one") + + if chains < 1: + raise ValueError("`chains` must be larger or equal to one") + + if blas_cores is not None and blas_cores < 1: + raise ValueError("`blas_cores` must be larger or equal to one") + + num_blas_cores_per_chain: int | None + joined_blas_limiter: Callable[[], Any] + + if blas_cores is None: + joined_blas_limiter = contextlib.nullcontext + num_blas_cores_per_chain = None + elif isinstance(blas_cores, int): + + def joined_blas_limiter(): + return threadpool_limits(limits=blas_cores) + + num_blas_cores_per_chain = blas_cores // cores + else: + raise ValueError( + f"Invalid argument `blas_cores`, must be int, 'auto' or None: {blas_cores}" + ) + if random_seed == -1: random_seed = None random_seed_list = _get_seeds_per_chain(random_seed, chains) @@ -685,21 +728,22 @@ def sample( raise ValueError( "Model can not be sampled with NUTS alone. Your model is probably not continuous." ) - return _sample_external_nuts( - sampler=nuts_sampler, - draws=draws, - tune=tune, - chains=chains, - target_accept=kwargs.pop("nuts", {}).get("target_accept", 0.8), - random_seed=random_seed, - initvals=initvals, - model=model, - var_names=var_names, - progressbar=progressbar, - idata_kwargs=idata_kwargs, - nuts_sampler_kwargs=nuts_sampler_kwargs, - **kwargs, - ) + with joined_blas_limiter(): + return _sample_external_nuts( + sampler=nuts_sampler, + draws=draws, + tune=tune, + chains=chains, + target_accept=kwargs.pop("nuts", {}).get("target_accept", 0.8), + random_seed=random_seed, + initvals=initvals, + model=model, + var_names=var_names, + progressbar=progressbar, + idata_kwargs=idata_kwargs, + nuts_sampler_kwargs=nuts_sampler_kwargs, + **kwargs, + ) if isinstance(step, list): step = CompoundStep(step) @@ -708,18 +752,19 @@ def sample( nuts_kwargs = kwargs.pop("nuts") [kwargs.setdefault(k, v) for k, v in nuts_kwargs.items()] _log.info("Auto-assigning NUTS sampler...") - initial_points, step = init_nuts( - init=init, - chains=chains, - n_init=n_init, - model=model, - random_seed=random_seed_list, - progressbar=progressbar, - jitter_max_retries=jitter_max_retries, - tune=tune, - initvals=initvals, - **kwargs, - ) + with joined_blas_limiter(): + initial_points, step = init_nuts( + init=init, + chains=chains, + n_init=n_init, + model=model, + random_seed=random_seed_list, + progressbar=progressbar, + jitter_max_retries=jitter_max_retries, + tune=tune, + initvals=initvals, + **kwargs, + ) if initial_points is None: # Time to draw/evaluate numeric start points for each chain. @@ -756,7 +801,8 @@ def sample( ) sample_args = { - "draws": draws + tune, # FIXME: Why is tune added to draws? + # draws is now the total number of draws, including tuning + "draws": draws + tune, "step": step, "start": initial_points, "traces": traces, @@ -772,6 +818,7 @@ def sample( } parallel_args = { "mp_ctx": mp_ctx, + "blas_cores": num_blas_cores_per_chain, } sample_args.update(kwargs) @@ -817,11 +864,15 @@ def sample( if has_population_samplers: _log.info(f"Population sampling ({chains} chains)") _print_step_hierarchy(step) - _sample_population(initial_points=initial_points, parallelize=cores > 1, **sample_args) + with joined_blas_limiter(): + _sample_population( + initial_points=initial_points, parallelize=cores > 1, **sample_args + ) else: _log.info(f"Sequential sampling ({chains} chains in 1 job)") _print_step_hierarchy(step) - _sample_many(**sample_args) + with joined_blas_limiter(): + _sample_many(**sample_args) t_sampling = time.time() - t_start @@ -1139,6 +1190,7 @@ def _mp_sample( traces: Sequence[IBaseTrace], model: Model | None = None, callback: SamplingIteratorCallback | None = None, + blas_cores: int | None = None, mp_ctx=None, **kwargs, ) -> None: @@ -1190,6 +1242,7 @@ def _mp_sample( step_method=step, progressbar=progressbar, progressbar_theme=progressbar_theme, + blas_cores=blas_cores, mp_ctx=mp_ctx, ) try: diff --git a/pymc/sampling/parallel.py b/pymc/sampling/parallel.py index c2f9791de..11bb5e49e 100644 --- a/pymc/sampling/parallel.py +++ b/pymc/sampling/parallel.py @@ -29,6 +29,7 @@ from rich.console import Console from rich.progress import BarColumn, Progress, TextColumn, TimeElapsedColumn, TimeRemainingColumn from rich.theme import Theme +from threadpoolctl import threadpool_limits from pymc.blocking import DictToArrayBijection from pymc.exceptions import SamplingError @@ -93,6 +94,7 @@ def __init__( draws: int, tune: int, seed, + blas_cores, ): self._msg_pipe = msg_pipe self._step_method = step_method @@ -102,6 +104,7 @@ def __init__( self._at_seed = seed + 1 self._draws = draws self._tune = tune + self._blas_cores = blas_cores def _unpickle_step_method(self): unpickle_error = ( @@ -116,22 +119,23 @@ def _unpickle_step_method(self): raise ValueError(unpickle_error) def run(self): - try: - # We do not create this in __init__, as pickling this - # would destroy the shared memory. - self._unpickle_step_method() - self._point = self._make_numpy_refs() - self._start_loop() - except KeyboardInterrupt: - pass - except BaseException as e: - e = ExceptionWithTraceback(e, e.__traceback__) - # Send is not blocking so we have to force a wait for the abort - # message - self._msg_pipe.send(("error", e)) - self._wait_for_abortion() - finally: - self._msg_pipe.close() + with threadpool_limits(limits=self._blas_cores): + try: + # We do not create this in __init__, as pickling this + # would destroy the shared memory. + self._unpickle_step_method() + self._point = self._make_numpy_refs() + self._start_loop() + except KeyboardInterrupt: + pass + except BaseException as e: + e = ExceptionWithTraceback(e, e.__traceback__) + # Send is not blocking so we have to force a wait for the abort + # message + self._msg_pipe.send(("error", e)) + self._wait_for_abortion() + finally: + self._msg_pipe.close() def _wait_for_abortion(self): while True: @@ -208,6 +212,7 @@ def __init__( chain: int, seed, start: dict[str, np.ndarray], + blas_cores, mp_ctx, ): self.chain = chain @@ -256,6 +261,7 @@ def __init__( draws, tune, seed, + blas_cores, ), ) self._process.start() @@ -378,6 +384,7 @@ def __init__( step_method, progressbar: bool = True, progressbar_theme: Theme | None = default_progress_theme, + blas_cores: int | None = None, mp_ctx=None, ): if any(len(arg) != chains for arg in [seeds, start_points]): @@ -411,6 +418,7 @@ def __init__( chain, seed, start, + blas_cores, mp_ctx, ) for chain, seed, start in zip(range(chains), seeds, start_points) diff --git a/requirements.txt b/requirements.txt index 50bbd8ae8..cfb155e79 100644 --- a/requirements.txt +++ b/requirements.txt @@ -6,4 +6,5 @@ pandas>=0.24.0 pytensor>=2.20,<2.21 rich>=13.7.1 scipy>=1.4.1 +threadpoolctl>=3.1.0,<4.0.0 typing-extensions>=3.7.4 diff --git a/tests/sampling/test_mcmc.py b/tests/sampling/test_mcmc.py index 8611a7028..be71d149e 100644 --- a/tests/sampling/test_mcmc.py +++ b/tests/sampling/test_mcmc.py @@ -507,6 +507,14 @@ def test_empty_model(): error.match("any free variables") +def test_blas_cores(): + with pm.Model(): + pm.Normal("a") + pm.sample(blas_cores="auto", tune=10, cores=2, draws=10) + pm.sample(blas_cores=None, tune=10, cores=2, draws=10) + pm.sample(blas_cores=2, tune=10, cores=2, draws=10) + + def test_partial_trace_with_trace_unsupported(): with pm.Model() as model: a = pm.Normal("a", mu=0, sigma=1) diff --git a/tests/sampling/test_parallel.py b/tests/sampling/test_parallel.py index b48c25f58..c69c75fab 100644 --- a/tests/sampling/test_parallel.py +++ b/tests/sampling/test_parallel.py @@ -161,6 +161,7 @@ def test_explicit_sample(mp_start_method): mp_ctx=ctx, start={"a": floatX(np.array([1.0])), "b_log__": floatX(np.array(2.0))}, step_method_pickled=step_method_pickled, + blas_cores=None, ) proc.start() while True: @@ -193,6 +194,7 @@ def test_iterator(): start_points=[start] * 3, step_method=step, progressbar=False, + blas_cores=None, ) with sampler: for draw in sampler: From d671429c75145f8064ce9c2ddfebcc396cee29ce Mon Sep 17 00:00:00 2001 From: Adrian Seyboldt Date: Thu, 16 May 2024 14:59:27 +0200 Subject: [PATCH 2/3] Improve docs and add threadpoolctl to environments --- conda-envs/environment-dev.yml | 1 + conda-envs/environment-docs.yml | 1 + conda-envs/environment-jax.yml | 1 + conda-envs/environment-test.yml | 1 + conda-envs/windows-environment-dev.yml | 1 + conda-envs/windows-environment-test.yml | 1 + pymc/sampling/mcmc.py | 17 ++++------------- 7 files changed, 10 insertions(+), 13 deletions(-) diff --git a/conda-envs/environment-dev.yml b/conda-envs/environment-dev.yml index db4f5840c..7097facd9 100644 --- a/conda-envs/environment-dev.yml +++ b/conda-envs/environment-dev.yml @@ -18,6 +18,7 @@ dependencies: - networkx - scipy>=1.4.1 - typing-extensions>=3.7.4 +- threadpoolctl>=3.1.0 # Extra dependencies for dev, testing and docs build - ipython>=7.16 - jax diff --git a/conda-envs/environment-docs.yml b/conda-envs/environment-docs.yml index 3609ae618..dc833d62b 100644 --- a/conda-envs/environment-docs.yml +++ b/conda-envs/environment-docs.yml @@ -16,6 +16,7 @@ dependencies: - rich>=13.7.1 - scipy>=1.4.1 - typing-extensions>=3.7.4 +- threadpoolctl>=3.1.0 # Extra dependencies for docs build - ipython>=7.16 - jax diff --git a/conda-envs/environment-jax.yml b/conda-envs/environment-jax.yml index 59b9f404c..442f93137 100644 --- a/conda-envs/environment-jax.yml +++ b/conda-envs/environment-jax.yml @@ -24,6 +24,7 @@ dependencies: - python-graphviz - networkx - rich>=13.7.1 +- threadpoolctl>=3.1.0 # JAX is only compatible with Scipy 1.13.0 from >=0.4.26, but the respective version of # JAXlib is still not on conda: https://github.com/conda-forge/jaxlib-feedstock/pull/243 - scipy>=1.4.1,<1.13.0 diff --git a/conda-envs/environment-test.yml b/conda-envs/environment-test.yml index ab6ce313f..5aac6372a 100644 --- a/conda-envs/environment-test.yml +++ b/conda-envs/environment-test.yml @@ -22,6 +22,7 @@ dependencies: - rich>=13.7.1 - scipy>=1.4.1 - typing-extensions>=3.7.4 +- threadpoolctl>=3.1.0 # Extra dependencies for testing - ipython>=7.16 - pre-commit>=2.8.0 diff --git a/conda-envs/windows-environment-dev.yml b/conda-envs/windows-environment-dev.yml index 253794abf..5e2502c76 100644 --- a/conda-envs/windows-environment-dev.yml +++ b/conda-envs/windows-environment-dev.yml @@ -19,6 +19,7 @@ dependencies: - rich>=13.7.1 - scipy>=1.4.1 - typing-extensions>=3.7.4 +- threadpoolctl>=3.1.0 # Extra dependencies for dev, testing and docs build - ipython>=7.16 - myst-nb diff --git a/conda-envs/windows-environment-test.yml b/conda-envs/windows-environment-test.yml index f851b9a35..195571a45 100644 --- a/conda-envs/windows-environment-test.yml +++ b/conda-envs/windows-environment-test.yml @@ -22,6 +22,7 @@ dependencies: - rich>=13.7.1 - scipy>=1.4.1 - typing-extensions>=3.7.4 +- threadpoolctl>=3.1.0 # Extra dependencies for testing - ipython>=7.16 - pre-commit>=2.8.0 diff --git a/pymc/sampling/mcmc.py b/pymc/sampling/mcmc.py index 74652064b..547babe16 100644 --- a/pymc/sampling/mcmc.py +++ b/pymc/sampling/mcmc.py @@ -505,12 +505,12 @@ def sample( This requires the chosen sampler to be installed. All samplers, except "pymc", require the full model to be continuous. blas_cores: int or "auto" or None, default = "auto" - The total number of threads blas and openmp functions should use during sampling. If set to None, - this will keep the default behavior of whatever blas implementation is used at runtime. - Setting it to "auto" will set it so that the total number of active blas threads is the + The total number of threads blas and openmp functions should use during sampling. + Setting it to "auto" will ensure that the total number of active blas threads is the same as the `cores` argument. If set to an integer, the sampler will try to use that total number of blas threads. If `blas_cores` is not divisible by `cores`, it might get rounded - down. + down. If set to None, this will keep the default behavior of whatever blas implementation + is used at runtime. initvals : optional, dict, array of dict Dict or list of dicts with initial value strategies to use instead of the defaults from `Model.initial_values`. The keys should be names of transformed random variables. @@ -661,15 +661,6 @@ def sample( cores = min(cores, chains) - if cores < 1: - raise ValueError("`cores` must be larger or equal to one") - - if chains < 1: - raise ValueError("`chains` must be larger or equal to one") - - if blas_cores is not None and blas_cores < 1: - raise ValueError("`blas_cores` must be larger or equal to one") - num_blas_cores_per_chain: int | None joined_blas_limiter: Callable[[], Any] From 5c3ccd6a7d9a2d66f8ec2f1ed9ca3315eb587a64 Mon Sep 17 00:00:00 2001 From: Adrian Seyboldt Date: Thu, 16 May 2024 15:05:02 +0200 Subject: [PATCH 3/3] Add threadpoolctl to requirements-dev.txt --- requirements-dev.txt | 1 + 1 file changed, 1 insertion(+) diff --git a/requirements-dev.txt b/requirements-dev.txt index c5344b0b3..f38cbe75b 100644 --- a/requirements-dev.txt +++ b/requirements-dev.txt @@ -28,6 +28,7 @@ sphinx-notfound-page sphinx-remove-toctrees sphinx>=1.5 sphinxext-rediraffe +threadpoolctl>=3.1.0 types-cachetools typing-extensions>=3.7.4 watermark