diff --git a/pymc/backends/mcbackend.py b/pymc/backends/mcbackend.py index 133a3ca965..d59e3025de 100644 --- a/pymc/backends/mcbackend.py +++ b/pymc/backends/mcbackend.py @@ -33,6 +33,7 @@ BlockedStep, CompoundStep, StatsBijection, + check_step_emits_tune, flat_statname, flatten_steps, ) @@ -207,11 +208,10 @@ def make_runmeta_and_point_fn( ) -> Tuple[mcb.RunMeta, PointFunc]: variables, point_fn = get_variables_and_point_fn(model, initial_point) - sample_stats = [ - mcb.Variable("tune", "bool"), - ] + check_step_emits_tune(step) # In PyMC the sampler stats are grouped by the sampler. + sample_stats = [] steps = flatten_steps(step) for s, sm in enumerate(steps): for statname, (dtype, shape) in sm.stats_dtypes_shapes.items(): @@ -221,9 +221,13 @@ def make_runmeta_and_point_fn( (-1 if s is None else s) for s in (shape or []) ] + dt = np.dtype(dtype).name + # Object types will be pickled by the ChainRecordAdapter! + if dt == "object": + dt = "str" svar = mcb.Variable( name=sname, - dtype=np.dtype(dtype).name, + dtype=dt, shape=sshape, undefined_ndim=shape is None, ) diff --git a/pymc/step_methods/compound.py b/pymc/step_methods/compound.py index c0a3a13594..5e054efed4 100644 --- a/pymc/step_methods/compound.py +++ b/pymc/step_methods/compound.py @@ -262,6 +262,16 @@ def flatten_steps(step: Union[BlockedStep, CompoundStep]) -> List[BlockedStep]: return steps +def check_step_emits_tune(step: Union[CompoundStep, BlockedStep]): + if isinstance(step, BlockedStep) and "tune" not in step.stats_dtypes_shapes: + raise TypeError(f"{type(step)} does not emit the required 'tune' stat.") + elif isinstance(step, CompoundStep): + for sstep in step.methods: + if "tune" not in sstep.stats_dtypes_shapes: + raise TypeError(f"{type(sstep)} does not emit the required 'tune' stat.") + return + + class StatsBijection: """Map between a `list` of stats to `dict` of stats.""" diff --git a/pymc/step_methods/metropolis.py b/pymc/step_methods/metropolis.py index 8ebe621d0c..82aa0d8753 100644 --- a/pymc/step_methods/metropolis.py +++ b/pymc/step_methods/metropolis.py @@ -458,9 +458,16 @@ class BinaryGibbsMetropolis(ArrayStep): name = "binary_gibbs_metropolis" + stats_dtypes_shapes = { + "tune": (bool, []), + } + def __init__(self, vars, order="random", transit_p=0.8, model=None): model = pm.modelcontext(model) + # Doesn't actually tune, but it's required to emit a sampler stat + # that indicates whether a draw was done in a tuning phase. + self.tune = True # transition probabilities self.transit_p = transit_p @@ -483,6 +490,11 @@ def __init__(self, vars, order="random", transit_p=0.8, model=None): super().__init__(vars, [model.compile_logp()]) + def reset_tuning(self): + # There are no tuning parameters in this step method. + self.tune = False + return + def astep(self, apoint: RaveledVars, *args) -> Tuple[RaveledVars, StatsType]: logp: Callable[[RaveledVars], np.ndarray] = args[0] order = self.order @@ -503,7 +515,10 @@ def astep(self, apoint: RaveledVars, *args) -> Tuple[RaveledVars, StatsType]: if accepted: logp_curr = logp_prop - return q, [] + stats = { + "tune": self.tune, + } + return q, [stats] @staticmethod def competence(var): @@ -543,6 +558,10 @@ class CategoricalGibbsMetropolis(ArrayStep): name = "categorical_gibbs_metropolis" + stats_dtypes_shapes = { + "tune": (bool, []), + } + def __init__(self, vars, proposal="uniform", order="random", model=None): model = pm.modelcontext(model) @@ -593,8 +612,17 @@ def __init__(self, vars, proposal="uniform", order="random", model=None): else: raise ValueError("Argument 'proposal' should either be 'uniform' or 'proportional'") + # Doesn't actually tune, but it's required to emit a sampler stat + # that indicates whether a draw was done in a tuning phase. + self.tune = True + super().__init__(vars, [model.compile_logp()]) + def reset_tuning(self): + # There are no tuning parameters in this step method. + self.tune = False + return + def astep_unif(self, apoint: RaveledVars, *args) -> Tuple[RaveledVars, StatsType]: logp = args[0] point_map_info = apoint.point_map_info @@ -614,7 +642,10 @@ def astep_unif(self, apoint: RaveledVars, *args) -> Tuple[RaveledVars, StatsType if accepted: logp_curr = logp_prop - return q, [] + stats = { + "tune": self.tune, + } + return q, [stats] def astep_prop(self, apoint: RaveledVars, *args) -> Tuple[RaveledVars, StatsType]: logp = args[0] diff --git a/tests/backends/test_mcbackend.py b/tests/backends/test_mcbackend.py index 2e3693c785..aa2e26ba01 100644 --- a/tests/backends/test_mcbackend.py +++ b/tests/backends/test_mcbackend.py @@ -119,7 +119,20 @@ def test_make_runmeta_and_point_fn(simple_model): assert not vars["vector"].is_deterministic assert not vars["vector_interval__"].is_deterministic assert vars["matrix"].is_deterministic - assert len(rmeta.sample_stats) == 1 + len(step.stats_dtypes[0]) + assert len(rmeta.sample_stats) == len(step.stats_dtypes[0]) + + with simple_model: + step = pm.NUTS() + rmeta, point_fn = make_runmeta_and_point_fn( + initial_point=simple_model.initial_point(), + step=step, + model=simple_model, + ) + assert isinstance(rmeta, mcb.RunMeta) + svars = {s.name: s for s in rmeta.sample_stats} + # Unbeknownst to McBackend, object stats are pickled to str + assert "sampler_0__warning" in svars + assert svars["sampler_0__warning"].dtype == "str" pass diff --git a/tests/step_methods/test_compound.py b/tests/step_methods/test_compound.py index 93cbecbc0d..1ed1e77ac9 100644 --- a/tests/step_methods/test_compound.py +++ b/tests/step_methods/test_compound.py @@ -26,6 +26,7 @@ Slice, ) from pymc.step_methods.compound import ( + BlockedStep, StatsBijection, flatten_steps, get_stats_dtypes_shapes_from_steps, @@ -36,6 +37,16 @@ from tests.models import simple_2model_continuous +def test_all_stepmethods_emit_tune_stat(): + attrs = [getattr(pm.step_methods, n) for n in dir(pm.step_methods)] + step_types = [ + attr for attr in attrs if isinstance(attr, type) and issubclass(attr, BlockedStep) + ] + assert len(step_types) > 5 + for cls in step_types: + assert "tune" in cls.stats_dtypes_shapes + + class TestCompoundStep: samplers = (Metropolis, Slice, HamiltonianMC, NUTS, DEMetropolis)