diff --git a/pymc/backends/base.py b/pymc/backends/base.py
index bc810af730..5f462d2685 100644
--- a/pymc/backends/base.py
+++ b/pymc/backends/base.py
@@ -21,7 +21,18 @@
 import warnings
 
 from abc import ABC
-from typing import Dict, List, Optional, Sequence, Set, Tuple, Union, cast
+from typing import (
+    Dict,
+    List,
+    Optional,
+    Sequence,
+    Set,
+    Sized,
+    Tuple,
+    TypeVar,
+    Union,
+    cast,
+)
 
 import numpy as np
@@ -510,7 +521,10 @@ def _squeeze_cat(results, combine, squeeze):
     return results
 
 
-def _choose_chains(traces: Sequence[BaseTrace], tune: int) -> Tuple[List[BaseTrace], int]:
+S = TypeVar("S", bound=Sized)
+
+
+def _choose_chains(traces: Sequence[S], tune: int) -> Tuple[List[S], int]:
     """
     Filter and slice traces such that (n_traces * len(shortest_trace)) is maximized.
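With the `Sized`-bound type variable, `_choose_chains` no longer insists on `BaseTrace` inputs; anything with a length is accepted. A minimal sketch of the loosened signature, using plain lists as stand-ins for chains of unequal length (the data is illustrative, and the function remains private API):

```python
from pymc.backends.base import _choose_chains

# Plain lists satisfy the new Sized bound; no BaseTrace instances needed.
chains = [list(range(50)), list(range(100)), list(range(110))]

# Per the docstring, the function maximizes n_traces * len(shortest_trace):
# keeping 2 chains of >= 100 draws (2 * 100) beats 3 * 50 and 1 * 110.
kept, length = _choose_chains(chains, tune=0)
print(len(kept), length)  # expected: 2 100
```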
diff --git a/pymc/backends/report.py b/pymc/backends/report.py
index b336eeb58c..e7cd49c3ff 100644
--- a/pymc/backends/report.py
+++ b/pymc/backends/report.py
@@ -17,14 +17,7 @@
 
 from typing import Dict, List, Optional
 
-import arviz
-
-from pymc.stats.convergence import (
-    _LEVELS,
-    SamplerWarning,
-    log_warnings,
-    run_convergence_checks,
-)
+from pymc.stats.convergence import _LEVELS, SamplerWarning
 
 logger = logging.getLogger("pymc")
 
@@ -73,10 +66,6 @@ def raise_ok(self, level="error"):
         if errors:
             raise ValueError("Serious convergence issues during sampling.")
 
-    def _run_convergence_checks(self, idata: arviz.InferenceData, model):
-        warnings = run_convergence_checks(idata, model)
-        self._add_warnings(warnings)
-
     def _add_warnings(self, warnings, chain=None):
         if chain is None:
             warn_list = self._global_warnings
@@ -84,11 +73,6 @@ def _add_warnings(self, warnings, chain=None):
             warn_list = self._chain_warnings.setdefault(chain, [])
         warn_list.extend(warnings)
 
-    def _log_summary(self):
-        for chain, warns in self._chain_warnings.items():
-            log_warnings(warns)
-        log_warnings(self._global_warnings)
-
     def _slice(self, start, stop, step):
         report = SamplerReport()
diff --git a/pymc/sampling/mcmc.py b/pymc/sampling/mcmc.py
index a79e8bd536..54f985367c 100644
--- a/pymc/sampling/mcmc.py
+++ b/pymc/sampling/mcmc.py
@@ -40,7 +40,11 @@
 from pymc.model import Model, modelcontext
 from pymc.sampling.parallel import Draw, _cpu_count
 from pymc.sampling.population import _sample_population
-from pymc.stats.convergence import log_warning_stats, run_convergence_checks
+from pymc.stats.convergence import (
+    log_warning_stats,
+    log_warnings,
+    run_convergence_checks,
+)
 from pymc.step_methods import NUTS, CompoundStep
 from pymc.step_methods.arraystep import BlockedStep, PopulationArrayStepShared
 from pymc.step_methods.hmc import quadpotential
@@ -602,7 +606,6 @@ def sample(
             f"({n_tune*n_chains:_d} + {n_draws*n_chains:_d} draws total) "
             f"took {t_sampling:.0f} seconds."
         )
-        mtrace.report._log_summary()
 
     idata = None
     if compute_convergence_checks or return_inferencedata:
@@ -612,14 +615,9 @@ def sample(
         idata = pm.to_inference_data(mtrace, **ikwargs)
 
         if compute_convergence_checks:
-            if draws - tune < 100:
-                warnings.warn(
-                    "The number of samples is too small to check convergence reliably.",
-                    stacklevel=2,
-                )
-            else:
-                convergence_warnings = run_convergence_checks(idata, model)
-                mtrace.report._add_warnings(convergence_warnings)
+            warns = run_convergence_checks(idata, model)
+            mtrace.report._add_warnings(warns)
+            log_warnings(warns)
 
     if return_inferencedata:
         # By default we drop the "warning" stat which contains `SamplerWarning`
@@ -925,9 +923,6 @@ def _mp_sample(
                 strace = traces[error._chain]
             for strace in traces:
                 strace.close()
-
-            multitrace = MultiTrace(traces)
-            multitrace._report._log_summary()
             raise
     except KeyboardInterrupt:
         pass
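With `_run_convergence_checks` and `_log_summary` removed from `SamplerReport`, checking and logging are now plain functions in `pymc.stats.convergence`, invoked once at the end of `sample()`. A sketch of the same flow run by hand on a toy model (the model and sampler settings are illustrative):

```python
import pymc as pm
from pymc.stats.convergence import log_warnings, run_convergence_checks

with pm.Model() as model:
    pm.Normal("x")
    # Disable the built-in checks so we can invoke them explicitly below.
    idata = pm.sample(draws=150, tune=100, chains=2, compute_convergence_checks=False)

# This mirrors what sample() now does internally: collect SamplerWarnings,
# then route them through the "pymc" logger.
warns = run_convergence_checks(idata, model)
log_warnings(warns)
```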
diff --git a/pymc/smc/kernels.py b/pymc/smc/kernels.py
index 59fdf01ad5..45398a0ce5 100644
--- a/pymc/smc/kernels.py
+++ b/pymc/smc/kernels.py
@@ -16,7 +16,7 @@
 import warnings
 
 from abc import ABC
-from typing import Dict, cast
+from typing import Dict, Union, cast
 
 import numpy as np
 import pytensor.tensor as at
@@ -24,6 +24,7 @@
 from pytensor.graph.replace import clone_replace
 from scipy.special import logsumexp
 from scipy.stats import multivariate_normal
+from typing_extensions import TypeAlias
 
 from pymc.backends.ndarray import NDArray
 from pymc.blocking import DictToArrayBijection
@@ -39,6 +40,9 @@
 from pymc.step_methods.metropolis import MultivariateNormalProposal
 from pymc.vartypes import discrete_types
 
+SMCStats: TypeAlias = Dict[str, Union[int, float]]
+SMCSettings: TypeAlias = Dict[str, Union[int, float]]
+
 
 class SMC_KERNEL(ABC):
     """Base class for the Sequential Monte Carlo kernels.
@@ -304,7 +308,7 @@ def mutate(self):
         """Apply kernel-specific perturbation to the particles once per stage"""
         pass
 
-    def sample_stats(self) -> Dict:
+    def sample_stats(self) -> SMCStats:
         """Stats to be saved at the end of each stage
 
         These stats will be saved under `sample_stats` in the final InferenceData object.
@@ -314,7 +318,7 @@
             "beta": self.beta,
         }
 
-    def sample_settings(self) -> Dict:
+    def sample_settings(self) -> SMCSettings:
         """SMC_kernel settings to be saved once at the end of sampling.
 
         These stats will be saved under `sample_stats` in the final InferenceData object.
@@ -425,7 +429,7 @@ def mutate(self):
 
         self.acc_rate = np.mean(ac_)
 
-    def sample_stats(self):
+    def sample_stats(self) -> SMCStats:
         stats = super().sample_stats()
         stats.update(
             {
@@ -434,7 +438,7 @@
         )
         return stats
 
-    def sample_settings(self):
+    def sample_settings(self) -> SMCSettings:
         stats = super().sample_settings()
         stats.update(
             {
@@ -543,7 +547,7 @@ def mutate(self):
 
         self.chain_acc_rate = np.mean(ac_, axis=0)
 
-    def sample_stats(self):
+    def sample_stats(self) -> SMCStats:
         stats = super().sample_stats()
         stats.update(
             {
@@ -553,7 +557,7 @@
         )
         return stats
 
-    def sample_settings(self):
+    def sample_settings(self) -> SMCSettings:
         stats = super().sample_settings()
         stats.update(
             {
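The `SMCStats` and `SMCSettings` aliases document the `Dict[str, Union[int, float]]` contract that kernel subclasses extend. A hypothetical kernel built on `IMH` (the subclass and its extra stat are made up for illustration):

```python
from pymc.smc.kernels import IMH, SMCStats

class TunedIMH(IMH):
    """Hypothetical IMH variant reporting one extra stat."""

    def sample_stats(self) -> SMCStats:
        stats = super().sample_stats()
        # Values must stay int/float for the SMCStats alias to hold.
        stats.update({"proposal_scale": 1.0})
        return stats
```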
diff --git a/pymc/smc/sampling.py b/pymc/smc/sampling.py
index 0cd5c39cc0..3654332093 100644
--- a/pymc/smc/sampling.py
+++ b/pymc/smc/sampling.py
@@ -19,6 +19,7 @@
 
 from collections import defaultdict
 from itertools import repeat
+from typing import Any, Dict, Optional, Tuple, Union
 
 import cloudpickle
 import numpy as np
@@ -30,9 +31,10 @@
 from pymc.backends.arviz import dict_to_dataset, to_inference_data
 from pymc.backends.base import MultiTrace
-from pymc.model import modelcontext
+from pymc.model import Model, modelcontext
 from pymc.sampling.parallel import _cpu_count
 from pymc.smc.kernels import IMH
+from pymc.stats.convergence import log_warnings, run_convergence_checks
 from pymc.util import RandomState, _get_seeds_per_chain
 
 
@@ -50,7 +52,7 @@ def sample_smc(
     idata_kwargs=None,
     progressbar=True,
     **kernel_kwargs,
-):
+) -> Union[InferenceData, MultiTrace]:
     r"""
     Sequential Monte Carlo based sampling.
 
@@ -236,20 +238,28 @@ def sample_smc(
     )
 
     if compute_convergence_checks:
-        _compute_convergence_checks(idata, draws, model, trace)
-    return idata if return_inferencedata else trace
+        if idata is None:
+            idata = to_inference_data(trace, log_likelihood=False)
+        warns = run_convergence_checks(idata, model)
+        trace.report._add_warnings(warns)
+        log_warnings(warns)
+
+    if return_inferencedata:
+        assert idata is not None
+        return idata
+    return trace
 
 
 def _save_sample_stats(
     sample_settings,
     sample_stats,
     chains,
-    trace,
-    return_inferencedata,
+    trace: MultiTrace,
+    return_inferencedata: bool,
     _t_sampling,
     idata_kwargs,
-    model,
-):
+    model: Model,
+) -> Tuple[Optional[Any], Optional[InferenceData]]:
     sample_settings_dict = sample_settings[0]
     sample_settings_dict["_t_sampling"] = _t_sampling
     sample_stats_dict = sample_stats[0]
@@ -262,12 +272,12 @@ def _save_sample_stats(
             value_list.append(chain_sample_stats[stat])
         sample_stats_dict[stat] = value_list
 
+    idata: Optional[InferenceData] = None
     if not return_inferencedata:
         for stat, value in sample_stats_dict.items():
             setattr(trace.report, stat, value)
         for stat, value in sample_settings_dict.items():
             setattr(trace.report, stat, value)
-        idata = None
     else:
         for stat, value in sample_stats_dict.items():
             if chains > 1:
@@ -284,7 +294,7 @@ def _save_sample_stats(
             library=pymc,
         )
 
-        ikwargs = dict(model=model)
+        ikwargs: Dict[str, Any] = dict(model=model)
         if idata_kwargs is not None:
             ikwargs.update(idata_kwargs)
         idata = to_inference_data(trace, **ikwargs)
@@ -293,19 +303,6 @@ def _save_sample_stats(
     return sample_stats, idata
 
 
-def _compute_convergence_checks(idata, draws, model, trace):
-    if draws < 100:
-        warnings.warn(
-            "The number of samples is too small to check convergence reliably.",
-            stacklevel=2,
-        )
-    else:
-        if idata is None:
-            idata = to_inference_data(trace, log_likelihood=False)
-        trace.report._run_convergence_checks(idata, model)
-        trace.report._log_summary()
-
-
 def _sample_smc_int(
     draws,
     kernel,
diff --git a/pymc/stats/convergence.py b/pymc/stats/convergence.py
index 7d585e9248..a8b4129e42 100644
--- a/pymc/stats/convergence.py
+++ b/pymc/stats/convergence.py
@@ -53,10 +53,13 @@ def run_convergence_checks(idata: arviz.InferenceData, model) -> List[SamplerWar
         warn = SamplerWarning(WarningType.BAD_PARAMS, msg, "info", None, None, None)
         return [warn]
 
+    if idata["posterior"].sizes["draw"] < 100:
+        msg = "The number of samples is too small to check convergence reliably."
+        warn = SamplerWarning(WarningType.BAD_PARAMS, msg, "info", None, None, None)
+        return [warn]
+
     if idata["posterior"].sizes["chain"] == 1:
-        msg = (
-            "Only one chain was sampled, this makes it impossible to " "run some convergence checks"
-        )
+        msg = "Only one chain was sampled, which makes it impossible to run some convergence checks"
         warn = SamplerWarning(WarningType.BAD_PARAMS, msg, "info")
         return [warn]
diff --git a/pymc/step_methods/compound.py b/pymc/step_methods/compound.py
index e9665909f7..548424ff4b 100644
--- a/pymc/step_methods/compound.py
+++ b/pymc/step_methods/compound.py
@@ -20,7 +20,7 @@
 
 from abc import ABC, abstractmethod
 from enum import IntEnum, unique
-from typing import Dict, List, Sequence, Tuple, Union
+from typing import Any, Dict, List, Mapping, Sequence, Tuple, Union
 
 import numpy as np
 
@@ -181,14 +181,14 @@ def flatten_steps(step: Union[BlockedStep, CompoundStep]) -> List[BlockedStep]:
 class StatsBijection:
     """Map between a `list` of stats to `dict` of stats."""
 
-    def __init__(self, sampler_stats_dtypes: Sequence[Dict[str, type]]) -> None:
+    def __init__(self, sampler_stats_dtypes: Sequence[Mapping[str, type]]) -> None:
         # Keep a list of flat vs. original stat names
         self._stat_groups: List[List[Tuple[str, str]]] = [
             [(f"sampler_{s}__{statname}", statname) for statname, _ in names_dtypes.items()]
             for s, names_dtypes in enumerate(sampler_stats_dtypes)
         ]
 
-    def map(self, stats_list: StatsType) -> StatsDict:
+    def map(self, stats_list: Sequence[Mapping[str, Any]]) -> StatsDict:
         """Combine stats dicts of multiple samplers into one dict."""
         stats_dict = {}
         for s, sts in enumerate(stats_list):
@@ -197,7 +197,7 @@ def map(self, stats_list: StatsType) -> StatsDict:
             stats_dict[sname] = sval
         return stats_dict
 
-    def rmap(self, stats_dict: StatsDict) -> StatsType:
+    def rmap(self, stats_dict: Mapping[str, Any]) -> StatsType:
         """Split a global stats dict into a list of sampler-wise stats dicts."""
         stats_list = []
         for namemap in self._stat_groups:
diff --git a/pymc/tests/smc/test_smc.py b/pymc/tests/smc/test_smc.py
index bdced49aba..09cbd10522 100644
--- a/pymc/tests/smc/test_smc.py
+++ b/pymc/tests/smc/test_smc.py
@@ -11,6 +11,7 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+import logging
 import warnings
 
 import numpy as np
@@ -215,13 +216,11 @@ def test_return_datatype(self, chains):
         assert mt.nchains == chains
         assert mt["x"].size == chains * draws
 
-    def test_convergence_checks(self):
-        with self.fast_model:
-            with pytest.warns(
-                UserWarning,
-                match="The number of samples is too small",
-            ):
+    def test_convergence_checks(self, caplog):
+        with caplog.at_level(logging.INFO):
+            with self.fast_model:
                 pm.sample_smc(draws=99)
+        assert "The number of samples is too small" in caplog.text
 
     def test_deprecated_parallel_arg(self):
         with self.fast_model:
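Accepting `Mapping` in `StatsBijection` admits read-only mappings without changing behavior. A small round trip through `map`/`rmap` (the stat names are arbitrary examples):

```python
from pymc.step_methods.compound import StatsBijection

# One name-to-dtype mapping per sampler in a compound step.
bij = StatsBijection([{"accept": float}, {"tune": bool}])

flat = bij.map([{"accept": 0.9}, {"tune": True}])
# {'sampler_0__accept': 0.9, 'sampler_1__tune': True}

nested = bij.rmap(flat)
# [{'accept': 0.9}, {'tune': True}]
```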
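Because the small-sample message is now an "info"-level `SamplerWarning` emitted through the `pymc` logger rather than a `UserWarning`, callers who relied on `pytest.warns` or `warnings.catch_warnings` should capture log records instead, as the updated test does with `caplog`. A sketch of the same check outside pytest (the model is a throwaway example):

```python
import logging
import pymc as pm

logging.basicConfig(level=logging.INFO)

with pm.Model():
    pm.Normal("x")
    # With the default compute_convergence_checks=True, fewer than 100 draws
    # now logs the message instead of raising a UserWarning.
    pm.sample_smc(draws=99, chains=2)
```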