diff --git a/pymc/sampling/mcmc.py b/pymc/sampling/mcmc.py
index 8fb41c6de..a79e8bd53 100644
--- a/pymc/sampling/mcmc.py
+++ b/pymc/sampling/mcmc.py
@@ -21,15 +21,14 @@
 import warnings
 
 from collections import defaultdict
-from copy import copy
-from typing import Iterator, List, Optional, Sequence, Tuple, Union
+from typing import Any, Dict, Iterator, List, Optional, Sequence, Tuple, Union
 
 import numpy as np
 import pytensor.gradient as tg
 
 from arviz import InferenceData
 from fastprogress.fastprogress import progress_bar
-from typing_extensions import TypeAlias
+from typing_extensions import Protocol, TypeAlias
 
 import pymc as pm
 
@@ -42,7 +41,7 @@
 from pymc.sampling.parallel import Draw, _cpu_count
 from pymc.sampling.population import _sample_population
 from pymc.stats.convergence import log_warning_stats, run_convergence_checks
-from pymc.step_methods import NUTS, CompoundStep, DEMetropolis
+from pymc.step_methods import NUTS, CompoundStep
 from pymc.step_methods.arraystep import BlockedStep, PopulationArrayStepShared
 from pymc.step_methods.hmc import quadpotential
 from pymc.util import (
@@ -59,13 +58,19 @@
 
 __all__ = [
     "sample",
-    "iter_sample",
     "init_nuts",
 ]
 
 Step: TypeAlias = Union[BlockedStep, CompoundStep]
 
 
+class SamplingIteratorCallback(Protocol):
+    """Signature of the callable that may be passed to `pm.sample(callback=...)`."""
+
+    def __call__(self, trace: BaseTrace, draw: Draw):
+        pass
+
+
 _log = logging.getLogger("pymc")
 
 
@@ -223,7 +228,7 @@ def sample(
     cores: Optional[int] = None,
     tune: int = 1000,
     progressbar: bool = True,
-    model=None,
+    model: Optional[Model] = None,
     random_seed: RandomState = None,
     discard_tuned_samples: bool = True,
     compute_convergence_checks: bool = True,
@@ -481,11 +486,23 @@ def sample(
         model.check_start_vals(ip)
     _check_start_shape(model, ip)
 
+    # Create trace backends for each chain
+    traces = [
+        _init_trace(
+            expected_length=draws + tune,
+            stats_dtypes=step.stats_dtypes,
+            chain_number=chain_number,
+            trace=trace,
+            model=model,
+        )
+        for chain_number in range(chains)
+    ]
+
     sample_args = {
         "draws": draws,
         "step": step,
         "start": initial_points,
-        "trace": trace,
+        "traces": traces,
         "chains": chains,
         "tune": tune,
         "progressbar": progressbar,
@@ -526,7 +543,7 @@
         _log.info(f"Multiprocess sampling ({chains} chains in {cores} jobs)")
         _print_step_hierarchy(step)
         try:
-            mtrace = _mp_sample(**sample_args, **parallel_args)
+            _mp_sample(**sample_args, **parallel_args)
         except pickle.PickleError:
             _log.warning("Could not pickle model, sampling singlethreaded.")
             _log.debug("Pickling error:", exc_info=True)
@@ -539,38 +556,23 @@
             parallel = False
     if not parallel:
         if has_population_samplers:
-            has_demcmc = np.any(
-                [
-                    isinstance(m, DEMetropolis)
-                    for m in (step.methods if isinstance(step, CompoundStep) else [step])
-                ]
-            )
             _log.info(f"Population sampling ({chains} chains)")
-
-            initial_point_model_size = sum(initial_points[0][n.name].size for n in model.value_vars)
-
-            if has_demcmc and chains < 3:
-                raise ValueError(
-                    "DEMetropolis requires at least 3 chains. "
-                    "For this {}-dimensional model you should use ≥{} chains".format(
-                        initial_point_model_size, initial_point_model_size + 1
-                    )
-                )
-            if has_demcmc and chains <= initial_point_model_size:
-                warnings.warn(
-                    "DEMetropolis should be used with more chains than dimensions! "
-                    "(The model has {} dimensions.)".format(initial_point_model_size),
-                    UserWarning,
-                    stacklevel=2,
-                )
             _print_step_hierarchy(step)
-            mtrace = _sample_population(parallelize=cores > 1, **sample_args)
+            _sample_population(initial_points=initial_points, parallelize=cores > 1, **sample_args)
         else:
             _log.info(f"Sequential sampling ({chains} chains in 1 job)")
             _print_step_hierarchy(step)
-            mtrace = _sample_many(**sample_args)
+            _sample_many(**sample_args)
 
     t_sampling = time.time() - t_start
+
+    # Wrap chain traces in a MultiTrace
+    if discard_tuned_samples:
+        traces, length = _choose_chains(traces, tune)
+    else:
+        traces, length = _choose_chains(traces, 0)
+    mtrace = MultiTrace(traces)[:length]
+
     # count the number of tune/draw iterations that happened
     # ideally via the "tune" statistic, but not all samplers record it!
     if "tune" in mtrace.stat_names:
@@ -604,7 +606,7 @@ def sample(
 
     idata = None
     if compute_convergence_checks or return_inferencedata:
-        ikwargs = dict(model=model, save_warmup=not discard_tuned_samples)
+        ikwargs: Dict[str, Any] = dict(model=model, save_warmup=not discard_tuned_samples)
         if idata_kwargs:
             ikwargs.update(idata_kwargs)
         idata = pm.to_inference_data(mtrace, **ikwargs)
@@ -654,14 +656,16 @@ def _check_start_shape(model, start: PointType):
 
 def _sample_many(
+    *,
     draws: int,
     chains: int,
+    traces: Sequence[BaseTrace],
     start: Sequence[PointType],
     random_seed: Optional[Sequence[RandomSeed]],
-    step,
-    callback=None,
+    step: Step,
+    callback: Optional[SamplingIteratorCallback] = None,
     **kwargs,
-) -> MultiTrace:
+):
     """Samples all chains sequentially.
 
     Parameters
     ----------
@@ -676,35 +680,19 @@ def _sample_many(
         A list of seeds, one for each chain
     step: function
         Step function
-
-    Returns
-    -------
-    mtrace: MultiTrace
-        Contains samples of all chains
     """
-    traces: List[BaseTrace] = []
     for i in range(chains):
-        trace = _sample(
+        _sample(
             draws=draws,
             chain=i,
             start=start[i],
             step=step,
+            trace=traces[i],
             random_seed=None if random_seed is None else random_seed[i],
             callback=callback,
             **kwargs,
         )
-        if trace is None:
-            if len(traces) == 0:
-                raise ValueError("Sampling stopped before a sample was created.")
-            else:
-                break
-        elif len(trace) < draws:
-            if len(traces) == 0:
-                traces.append(trace)
-            break
-        else:
-            traces.append(trace)
-    return MultiTrace(traces)
+    return
 
 
 def _sample(
@@ -714,13 +702,13 @@ def _sample(
     *,
     chain: int,
     progressbar: bool,
     random_seed: RandomSeed,
     start: PointType,
     draws: int,
-    step=None,
-    trace: Optional[BaseTrace] = None,
+    step: Step,
+    trace: BaseTrace,
     tune: int,
     model: Optional[Model] = None,
     callback=None,
     **kwargs,
-) -> BaseTrace:
+) -> None:
     """Main iteration for singleprocess sampling.
 
     Multiple step methods are supported via compound step methods.
@@ -741,23 +729,23 @@ def _sample(
     step : function
         Step function
     trace : backend, optional
-        A backend instance or None.
-        If None, the NDArray backend is used.
+        A backend instance.
     tune : int
         Number of iterations to tune.
     model : Model (optional if in ``with`` context)
-
-    Returns
-    -------
-    strace : BaseTrace
-        A ``BaseTrace`` object that contains the samples for this chain.
     """
     skip_first = kwargs.get("skip_first", 0)
 
-    trace = copy(trace)
-
     sampling_gen = _iter_sample(
-        draws, step, start, trace, chain, tune, model, random_seed, callback
+        draws=draws,
+        step=step,
+        start=start,
+        trace=trace,
+        chain=chain,
+        tune=tune,
+        model=model,
+        random_seed=random_seed,
+        callback=callback,
     )
     _pbar_data = {"chain": chain, "divergences": 0}
     _desc = "Sampling chain {chain:d}, {divergences:,d} divergences"
@@ -767,87 +755,27 @@ def _sample(
     if progressbar:
         sampling = progress_bar(sampling_gen, total=draws, display=progressbar)
         sampling.comment = _desc.format(**_pbar_data)
     else:
         sampling = sampling_gen
     try:
-        strace = None
-        for it, (strace, diverging) in enumerate(sampling):
+        for it, diverging in enumerate(sampling):
             if it >= skip_first and diverging:
                 _pbar_data["divergences"] += 1
                 if progressbar:
                     sampling.comment = _desc.format(**_pbar_data)
     except KeyboardInterrupt:
         pass
-    if strace is None:
-        raise Exception("KeyboardInterrupt happened before the base trace was created.")
-    return strace
-
-
-def iter_sample(
-    draws: int,
-    step,
-    start: PointType,
-    trace=None,
-    chain: int = 0,
-    tune: int = 0,
-    model: Optional[Model] = None,
-    random_seed: RandomSeed = None,
-    callback=None,
-) -> Iterator[MultiTrace]:
-    """Generate a trace on each iteration using the given step method.
-
-    Multiple step methods are supported via compound step methods. Returns the
-    amount of time taken.
-
-    Parameters
-    ----------
-    draws : int
-        The number of samples to draw
-    step : function
-        Step function
-    start : dict
-        Starting point in parameter space (or partial point).
-    trace : backend or list
-        This should be a backend instance, or a list of variables to track.
-        If None or a list of variables, the NDArray backend is used.
-    chain : int, optional
-        Chain number used to store sample in backend.
-    tune : int, optional
-        Number of iterations to tune (defaults to 0).
-    model : Model (optional if in ``with`` context)
-    random_seed : single random seed, optional
-    callback :
-        A function which gets called for every sample from the trace of a chain. The function is
-        called with the trace and the current draw and will contain all samples for a single trace.
-        the ``draw.chain`` argument can be used to determine which of the active chains the sample
-        is drawn from.
-        Sampling can be interrupted by throwing a ``KeyboardInterrupt`` in the callback.
-
-    Yields
-    ------
-    trace : MultiTrace
-        Contains all samples up to the current iteration
-
-    Examples
-    --------
-    ::
-
-        for trace in iter_sample(500, step):
-            ...
-    """
-    sampling = _iter_sample(draws, step, start, trace, chain, tune, model, random_seed, callback)
-    for i, (strace, _) in enumerate(sampling):
-        yield MultiTrace([strace[: i + 1]])
 
 
 def _iter_sample(
+    *,
     draws: int,
-    step,
+    step: Step,
     start: PointType,
-    trace: Optional[BaseTrace] = None,
+    trace: BaseTrace,
     chain: int = 0,
     tune: int = 0,
-    model=None,
+    model: Optional[Model] = None,
     random_seed: RandomSeed = None,
-    callback=None,
-) -> Iterator[Tuple[BaseTrace, bool]]:
+    callback: Optional[SamplingIteratorCallback] = None,
+) -> Iterator[bool]:
     """Generator for sampling one chain. (Used in singleprocess sampling.)
 
     Parameters
     ----------
@@ -859,9 +787,8 @@ def _iter_sample(
     draws : int
         The number of samples to draw
     step : function
        Step function
     start : dict
        Starting point in parameter space (or partial point).
        Must contain numeric (transformed) initial values for all (transformed) free variables.
-    trace : backend, optional
-        A backend instance or None.
-        If None, the NDArray backend is used.
+    trace : backend
+        A backend instance.
     chain : int, optional
         Chain number used to store sample in backend.
     tune : int, optional
         Number of iterations to tune (defaults to 0).
@@ -871,8 +798,6 @@ def _iter_sample(
 
     Yields
     ------
-    strace : BaseTrace
-        The trace object containing the samples for this chain
     diverging : bool
         Indicates if the draw is divergent. Only available with some samplers.
     """
@@ -885,27 +810,13 @@ def _iter_sample(
     if random_seed is not None:
         np.random.seed(random_seed)
 
-    try:
-        step = CompoundStep(step)
-    except TypeError:
-        pass
-
     point = start
 
-    strace: BaseTrace = _init_trace(
-        expected_length=draws + tune,
-        stats_dtypes=step.stats_dtypes,
-        chain_number=chain,
-        trace=trace,
-        model=model,
-    )
-
     try:
         step.tune = bool(tune)
         if hasattr(step, "reset_tuning"):
             step.reset_tuning()
         for i in range(draws):
-            stats = None
             diverging = False
 
             if i == 0 and hasattr(step, "iter_count"):
@@ -913,27 +824,28 @@ def _iter_sample(
             if i == tune:
                 step.stop_tuning()
             point, stats = step.step(point)
-            strace.record(point, stats)
+            trace.record(point, stats)
             log_warning_stats(stats)
-            diverging = i > tune and stats and stats[0].get("diverging")
+            diverging = i > tune and len(stats) > 0 and (stats[0].get("diverging") == True)
             if callback is not None:
                 callback(
-                    trace=strace,
+                    trace=trace,
                     draw=Draw(chain, i == draws, i, i < tune, stats, point),
                 )
 
-            yield strace, diverging
+            yield diverging
     except KeyboardInterrupt:
-        strace.close()
+        trace.close()
         raise
     except BaseException:
-        strace.close()
+        trace.close()
         raise
     else:
-        strace.close()
+        trace.close()
 
 
 def _mp_sample(
+    *,
     draws: int,
     tune: int,
     step,
@@ -942,13 +854,12 @@ def _mp_sample(
     chains: int,
     cores: int,
     random_seed: Sequence[RandomSeed],
     start: Sequence[PointType],
     progressbar: bool = True,
-    trace: Optional[BaseTrace] = None,
-    model=None,
-    callback=None,
-    discard_tuned_samples: bool = True,
+    traces: Sequence[BaseTrace],
+    model: Optional[Model] = None,
+    callback: Optional[SamplingIteratorCallback] = None,
     mp_ctx=None,
     **kwargs,
-) -> MultiTrace:
+) -> None:
     """Main iteration for multiprocess sampling.
 
     Parameters
     ----------
@@ -974,34 +885,18 @@ def _mp_sample(
     trace : backend, optional
         A backend instance, or None. If None, the NDArray backend is used.
     model : Model (optional if in ``with`` context)
-    callback : Callable
+    callback
         A function which gets called for every sample from the trace of a chain. The function is
         called with the trace and the current draw and will contain all samples for a single trace.
         the ``draw.chain`` argument can be used to determine which of the active chains the sample
        is drawn from.
        Sampling can be interrupted by throwing a ``KeyboardInterrupt`` in the callback.
-
-    Returns
-    -------
-    mtrace : pymc.backends.base.MultiTrace
-        A ``MultiTrace`` object that contains the samples for all chains.
     """
     import pymc.sampling.parallel as ps
 
     # We did draws += tune in pm.sample
     draws -= tune
 
-    traces = [
-        _init_trace(
-            expected_length=draws + tune,
-            stats_dtypes=step.stats_dtypes,
-            chain_number=chain_number,
-            trace=trace,
-            model=model,
-        )
-        for chain_number in range(chains)
-    ]
-
     sampler = ps.ParallelSampler(
@@ -1024,7 +919,7 @@ def _mp_sample(
                         strace.close()
 
                     if callback is not None:
-                        callback(trace=trace, draw=draw)
+                        callback(trace=strace, draw=draw)
 
         except ps.ParallelSamplingError as error:
             strace = traces[error._chain]
@@ -1034,13 +929,8 @@ def _mp_sample(
             multitrace = MultiTrace(traces)
             multitrace._report._log_summary()
             raise
-        return MultiTrace(traces)
     except KeyboardInterrupt:
-        if discard_tuned_samples:
-            traces, length = _choose_chains(traces, tune)
-        else:
-            traces, length = _choose_chains(traces, 0)
-        return MultiTrace(traces)[:length]
+        pass
     finally:
         for strace in traces:
             strace.close()
@@ -1105,7 +995,7 @@ def init_nuts(
     init: str = "auto",
     chains: int = 1,
     n_init: int = 500_000,
-    model=None,
+    model: Optional[Model] = None,
     random_seed: RandomSeed = None,
     progressbar=True,
     jitter_max_retries: int = 10,
diff --git a/pymc/sampling/population.py b/pymc/sampling/population.py
index 742b80478..3545b0af4 100644
--- a/pymc/sampling/population.py
+++ b/pymc/sampling/population.py
@@ -15,6 +15,7 @@
 """Specializes on running MCMCs with population step methods."""
 
 import logging
+import warnings
 
 from copy import copy
 from typing import Iterator, List, Sequence, Tuple, Union
@@ -25,10 +26,9 @@
 from fastprogress.fastprogress import progress_bar
 from typing_extensions import TypeAlias
 
-from pymc.backends import _init_trace
-from pymc.backends.base import BaseTrace, MultiTrace
+from pymc.backends.base import BaseTrace
 from pymc.initial_point import PointType
-from pymc.model import modelcontext
+from pymc.model import Model, modelcontext
 from pymc.stats.convergence import log_warning_stats
 from pymc.step_methods import CompoundStep
 from pymc.step_methods.arraystep import (
@@ -36,6 +36,7 @@
     PopulationArrayStepShared,
     StatsType,
 )
+from pymc.step_methods.metropolis import DEMetropolis
 from pymc.util import RandomSeed
 
 __all__ = ()
@@ -48,25 +49,25 @@
 
 
 def _sample_population(
+    *,
+    initial_points: Sequence[PointType],
     draws: int,
-    chains: int,
     start: Sequence[PointType],
     random_seed: RandomSeed,
-    step,
+    step: Union[BlockedStep, CompoundStep],
     tune: int,
-    model,
+    model: Model,
     progressbar: bool = True,
     parallelize: bool = False,
+    traces: Sequence[BaseTrace],
     **kwargs,
-) -> MultiTrace:
+):
     """Performs sampling of a population of chains using the ``PopulationStepper``.
 
     Parameters
     ----------
     draws : int
         The number of samples to draw
-    chains : int
-        The total number of chains in the population
     start : list
         Start points for each chain
     random_seed : single random seed, optional
@@ -79,17 +80,20 @@ def _sample_population(
         Show progress bars?
         (defaults to True)
     parallelize : bool
         Setting for multiprocess parallelization
-
-    Returns
-    -------
-    trace : MultiTrace
-        Contains samples of all chains
     """
+    warn_population_size(
+        step=step,
+        initial_points=initial_points,
+        model=model,
+        chains=len(traces),
+    )
+
     sampling = _prepare_iter_population(
-        draws,
-        step,
-        start,
-        parallelize,
+        draws=draws,
+        step=step,
+        start=start,
+        parallelize=parallelize,
+        traces=traces,
         tune=tune,
         model=model,
         random_seed=random_seed,
@@ -99,10 +103,43 @@ def _sample_population(
     if progressbar:
         sampling = progress_bar(sampling, total=draws, display=progressbar)
 
-    latest_traces = None
-    for it, traces in enumerate(sampling):
-        latest_traces = traces
-    return MultiTrace(latest_traces)
+    for _ in sampling:
+        pass
+    return
+
+
+def warn_population_size(
+    *,
+    step: Union[BlockedStep, CompoundStep],
+    initial_points: Sequence[PointType],
+    model: Model,
+    chains: int,
+):
+    """Emit informative errors/warnings for dangerously small population size."""
+    has_demcmc = np.any(
+        [
+            isinstance(m, DEMetropolis)
+            for m in (step.methods if isinstance(step, CompoundStep) else [step])
+        ]
+    )
+
+    initial_point_model_size = sum(initial_points[0][n.name].size for n in model.value_vars)
+
+    if has_demcmc and chains < 3:
+        raise ValueError(
+            "DEMetropolis requires at least 3 chains. "
+            "For this {}-dimensional model you should use ≥{} chains".format(
+                initial_point_model_size, initial_point_model_size + 1
+            )
+        )
+    if has_demcmc and chains <= initial_point_model_size:
+        warnings.warn(
+            "DEMetropolis should be used with more chains than dimensions! "
+            "(The model has {} dimensions.)".format(initial_point_model_size),
+            UserWarning,
+            stacklevel=2,
+        )
+    return
 
 
 class PopulationStepper:
@@ -259,15 +296,17 @@ def step(self, tune_stop: bool, population) -> List[Tuple[PointType, StatsType]]
 
 
 def _prepare_iter_population(
+    *,
     draws: int,
     step,
     start: Sequence[PointType],
     parallelize: bool,
+    traces: Sequence[BaseTrace],
     tune: int,
     model=None,
     random_seed: RandomSeed = None,
     progressbar=True,
-) -> Iterator[Sequence[BaseTrace]]:
+) -> Iterator[int]:
     """Prepare a PopulationStepper and traces for population sampling.
 
     Parameters
     ----------
@@ -290,7 +329,7 @@ def _prepare_iter_population(
     Returns
     -------
     _iter_population : generator
-        Yields traces of all chains at the same time
+        Main sampling iterator yielding the iteration number.
     """
     nchains = len(start)
     model = modelcontext(model)
@@ -305,8 +344,7 @@ def _prepare_iter_population(
     # The initialization of traces, samplers and points must happen in the right order:
     # 1. population of points is created
     # 2. steppers are initialized and linked to the points object
-    # 3. traces are initialized
-    # 4. a PopulationStepper is configured for parallelized stepping
+    # 3. a PopulationStepper is configured for parallelized stepping
 
     # 1. create a population (points) that tracks each chain
     # it is updated as the chains are advanced
@@ -327,29 +365,25 @@ def _prepare_iter_population(
         sm.link_population(population, c)
         steppers.append(chainstep)
 
-    # 3. Initialize a BaseTrace for each chain
-    traces: List[BaseTrace] = [
-        _init_trace(
-            expected_length=draws + tune,
-            stats_dtypes=steppers[c].stats_dtypes,
-            chain_number=c,
-            trace=None,
-            model=model,
-        )
-        for c in range(nchains)
-    ]
-
-    # 4. configure the PopulationStepper (expensive call)
+    # 3. configure the PopulationStepper (expensive call)
     popstep = PopulationStepper(steppers, parallelize, progressbar=progressbar)
 
     # Because the preparations above are expensive, the actual iterator is
     # in another method. This way the progbar will not be disturbed.
-    return _iter_population(draws, tune, popstep, steppers, traces, population)
+    return _iter_population(
+        draws=draws, tune=tune, popstep=popstep, steppers=steppers, traces=traces, points=population
+    )
 
 
 def _iter_population(
-    draws: int, tune: int, popstep: PopulationStepper, steppers, traces: Sequence[BaseTrace], points
-) -> Iterator[Sequence[BaseTrace]]:
+    *,
+    draws: int,
+    tune: int,
+    popstep: PopulationStepper,
+    steppers,
+    traces: Sequence[BaseTrace],
+    points,
+) -> Iterator[int]:
     """Iterate a ``PopulationStepper``.
 
     Parameters
     ----------
@@ -369,8 +403,8 @@ def _iter_population(
 
     Yields
     ------
-    traces : list
-        List of trace objects of the individual chains
+    i
+        Iteration number.
     """
     try:
         with popstep:
@@ -386,7 +420,7 @@ def _iter_population(
                     strace.record(points[c], stats)
                     log_warning_stats(stats)
                 # yield the state of all chains in parallel
-                yield traces
+                yield i
     except KeyboardInterrupt:
         for c, strace in enumerate(traces):
             strace.close()
diff --git a/pymc/stats/log_likelihood.py b/pymc/stats/log_likelihood.py
index ec473536f..ad5c903d5 100644
--- a/pymc/stats/log_likelihood.py
+++ b/pymc/stats/log_likelihood.py
@@ -11,7 +11,7 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-from typing import Optional, Sequence
+from typing import Optional, Sequence, cast
 
 import numpy as np
 
@@ -22,6 +22,7 @@
 
 from pymc.backends.arviz import _DefaultTrace
 from pymc.model import Model, modelcontext
+from pymc.pytensorf import PointFunc
 from pymc.util import dataset_to_point_list
 
 __all__ = ("compute_log_likelihood",)
@@ -86,6 +87,7 @@ def compute_log_likelihood(
             outs=model.logp(vars=observed_vars, sum=False),
             on_unused_input="ignore",
         )
+        elemwise_loglike_fn = cast(PointFunc, elemwise_loglike_fn)
     finally:
         model.rvs_to_values = original_rvs_to_values
         model.rvs_to_transforms = original_rvs_to_transforms
diff --git a/pymc/step_methods/compound.py b/pymc/step_methods/compound.py
index 486f2fc8f..be89bfdc0 100644
--- a/pymc/step_methods/compound.py
+++ b/pymc/step_methods/compound.py
@@ -36,6 +36,7 @@ def __init__(self, methods):
         self.name = (
             f"Compound[{', '.join(getattr(m, 'name', 'UNNAMED_STEP') for m in self.methods)}]"
         )
+        self.tune = True
 
     def step(self, point) -> Tuple[PointType, StatsType]:
         stats = []
diff --git a/pymc/tests/sampling/test_mcmc.py b/pymc/tests/sampling/test_mcmc.py
index e472a947d..f23428b3e 100644
--- a/pymc/tests/sampling/test_mcmc.py
+++ b/pymc/tests/sampling/test_mcmc.py
@@ -205,18 +205,6 @@ def test_sample_args(self):
             pm.sample(50, tune=0, foo={})
         assert "foo" in str(excinfo.value)
 
-    def test_iter_sample(self):
-        with self.model:
-            samps = pm.sampling.mcmc.iter_sample(
-                draws=5,
-                step=self.step,
-                start=self.start,
-                tune=0,
-                random_seed=self.random_seed,
-            )
-            for i, trace in enumerate(samps):
-                assert i == len(trace) - 1, "Trace does not have correct length."
-
     def test_parallel_start(self):
         with self.model:
             with warnings.catch_warnings():
diff --git a/pymc/util.py b/pymc/util.py
index 00d47a4c4..476dbd58e 100644
--- a/pymc/util.py
+++ b/pymc/util.py
@@ -235,7 +235,7 @@ def enhanced(*args, **kwargs):
 
 
 def dataset_to_point_list(
-    ds: xarray.Dataset, sample_dims: List
+    ds: xarray.Dataset, sample_dims: Sequence[str]
 ) -> Tuple[List[Dict[str, np.ndarray]], Dict[str, Any]]:
     # All keys of the dataset must be a str
     var_names = list(ds.keys())
diff --git a/scripts/run_mypy.py b/scripts/run_mypy.py
index d2eda132e..23265045a 100644
--- a/scripts/run_mypy.py
+++ b/scripts/run_mypy.py
@@ -42,7 +42,6 @@
 pymc/printing.py
 pymc/pytensorf.py
 pymc/sampling/jax.py
-pymc/stats/log_likelihood.py
 pymc/variational/approximations.py
 pymc/variational/opvi.py
 """
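Usage sketch — `SamplingIteratorCallback` (illustrative only, not part of the patch). The new protocol merely formalizes the signature that `pm.sample(callback=...)` already expects: the callable receives the per-chain `BaseTrace` plus the current `Draw`, and may raise `KeyboardInterrupt` to stop sampling. The toy model and the 100-draw cutoff below are assumptions for the example, not taken from the patch:

    import pymc as pm

    def stop_after_100(trace, draw):
        # `trace` holds all samples recorded so far for the chain that
        # produced `draw`; raising KeyboardInterrupt aborts sampling.
        if len(trace) >= 100:
            raise KeyboardInterrupt

    with pm.Model():
        pm.Normal("x")
        idata = pm.sample(draws=500, chains=1, callback=stop_after_100)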
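Behavior sketch — `warn_population_size` (illustrative only; the function itself is private). The DEMetropolis chain-count checks move essentially verbatim from `pm.sample` into `pymc/sampling/population.py`, so the user-facing behavior should be unchanged. With a hypothetical 4-dimensional model:

    import pymc as pm

    with pm.Model():
        pm.Normal("x", shape=4)  # initial point has 4 dimensions
        # chains=2 would raise ValueError ("DEMetropolis requires at least 3 chains."):
        # pm.sample(step=pm.DEMetropolis(), chains=2)
        # chains <= dimensions samples normally but emits the UserWarning quoted above:
        idata = pm.sample(step=pm.DEMetropolis(), chains=4, draws=100, tune=100)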