Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 8 additions & 4 deletions pymc/backends/mcbackend.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@
BlockedStep,
CompoundStep,
StatsBijection,
check_step_emits_tune,
flat_statname,
flatten_steps,
)
Expand Down Expand Up @@ -207,11 +208,10 @@ def make_runmeta_and_point_fn(
) -> Tuple[mcb.RunMeta, PointFunc]:
variables, point_fn = get_variables_and_point_fn(model, initial_point)

sample_stats = [
mcb.Variable("tune", "bool"),
]
check_step_emits_tune(step)

# In PyMC the sampler stats are grouped by the sampler.
sample_stats = []
steps = flatten_steps(step)
for s, sm in enumerate(steps):
for statname, (dtype, shape) in sm.stats_dtypes_shapes.items():
Expand All @@ -221,9 +221,13 @@ def make_runmeta_and_point_fn(
(-1 if s is None else s)
for s in (shape or [])
]
dt = np.dtype(dtype).name
# Object types will be pickled by the ChainRecordAdapter!
if dt == "object":
dt = "str"
svar = mcb.Variable(
name=sname,
dtype=np.dtype(dtype).name,
dtype=dt,
shape=sshape,
undefined_ndim=shape is None,
)
Expand Down
10 changes: 10 additions & 0 deletions pymc/step_methods/compound.py
Original file line number Diff line number Diff line change
Expand Up @@ -262,6 +262,16 @@ def flatten_steps(step: Union[BlockedStep, CompoundStep]) -> List[BlockedStep]:
return steps


def check_step_emits_tune(step: Union[CompoundStep, BlockedStep]):
if isinstance(step, BlockedStep) and "tune" not in step.stats_dtypes_shapes:
raise TypeError(f"{type(step)} does not emit the required 'tune' stat.")
elif isinstance(step, CompoundStep):
for sstep in step.methods:
if "tune" not in sstep.stats_dtypes_shapes:
raise TypeError(f"{type(sstep)} does not emit the required 'tune' stat.")
return


class StatsBijection:
"""Map between a `list` of stats to `dict` of stats."""

Expand Down
35 changes: 33 additions & 2 deletions pymc/step_methods/metropolis.py
Original file line number Diff line number Diff line change
Expand Up @@ -458,9 +458,16 @@ class BinaryGibbsMetropolis(ArrayStep):

name = "binary_gibbs_metropolis"

stats_dtypes_shapes = {
"tune": (bool, []),
}

def __init__(self, vars, order="random", transit_p=0.8, model=None):
model = pm.modelcontext(model)

# Doesn't actually tune, but it's required to emit a sampler stat
# that indicates whether a draw was done in a tuning phase.
self.tune = True
# transition probabilities
self.transit_p = transit_p

Expand All @@ -483,6 +490,11 @@ def __init__(self, vars, order="random", transit_p=0.8, model=None):

super().__init__(vars, [model.compile_logp()])

def reset_tuning(self):
# There are no tuning parameters in this step method.
self.tune = False
return

def astep(self, apoint: RaveledVars, *args) -> Tuple[RaveledVars, StatsType]:
logp: Callable[[RaveledVars], np.ndarray] = args[0]
order = self.order
Expand All @@ -503,7 +515,10 @@ def astep(self, apoint: RaveledVars, *args) -> Tuple[RaveledVars, StatsType]:
if accepted:
logp_curr = logp_prop

return q, []
stats = {
"tune": self.tune,
}
return q, [stats]

@staticmethod
def competence(var):
Expand Down Expand Up @@ -543,6 +558,10 @@ class CategoricalGibbsMetropolis(ArrayStep):

name = "categorical_gibbs_metropolis"

stats_dtypes_shapes = {
"tune": (bool, []),
}

def __init__(self, vars, proposal="uniform", order="random", model=None):
model = pm.modelcontext(model)

Expand Down Expand Up @@ -593,8 +612,17 @@ def __init__(self, vars, proposal="uniform", order="random", model=None):
else:
raise ValueError("Argument 'proposal' should either be 'uniform' or 'proportional'")

# Doesn't actually tune, but it's required to emit a sampler stat
# that indicates whether a draw was done in a tuning phase.
self.tune = True

super().__init__(vars, [model.compile_logp()])

def reset_tuning(self):
# There are no tuning parameters in this step method.
self.tune = False
return

def astep_unif(self, apoint: RaveledVars, *args) -> Tuple[RaveledVars, StatsType]:
logp = args[0]
point_map_info = apoint.point_map_info
Expand All @@ -614,7 +642,10 @@ def astep_unif(self, apoint: RaveledVars, *args) -> Tuple[RaveledVars, StatsType
if accepted:
logp_curr = logp_prop

return q, []
stats = {
"tune": self.tune,
}
return q, [stats]

def astep_prop(self, apoint: RaveledVars, *args) -> Tuple[RaveledVars, StatsType]:
logp = args[0]
Expand Down
15 changes: 14 additions & 1 deletion tests/backends/test_mcbackend.py
Original file line number Diff line number Diff line change
Expand Up @@ -119,7 +119,20 @@ def test_make_runmeta_and_point_fn(simple_model):
assert not vars["vector"].is_deterministic
assert not vars["vector_interval__"].is_deterministic
assert vars["matrix"].is_deterministic
assert len(rmeta.sample_stats) == 1 + len(step.stats_dtypes[0])
assert len(rmeta.sample_stats) == len(step.stats_dtypes[0])

with simple_model:
step = pm.NUTS()
rmeta, point_fn = make_runmeta_and_point_fn(
initial_point=simple_model.initial_point(),
step=step,
model=simple_model,
)
assert isinstance(rmeta, mcb.RunMeta)
svars = {s.name: s for s in rmeta.sample_stats}
# Unbeknownst to McBackend, object stats are pickled to str
assert "sampler_0__warning" in svars
assert svars["sampler_0__warning"].dtype == "str"
pass


Expand Down
11 changes: 11 additions & 0 deletions tests/step_methods/test_compound.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
Slice,
)
from pymc.step_methods.compound import (
BlockedStep,
StatsBijection,
flatten_steps,
get_stats_dtypes_shapes_from_steps,
Expand All @@ -36,6 +37,16 @@
from tests.models import simple_2model_continuous


def test_all_stepmethods_emit_tune_stat():
attrs = [getattr(pm.step_methods, n) for n in dir(pm.step_methods)]
step_types = [
attr for attr in attrs if isinstance(attr, type) and issubclass(attr, BlockedStep)
]
assert len(step_types) > 5
for cls in step_types:
assert "tune" in cls.stats_dtypes_shapes


class TestCompoundStep:
samplers = (Metropolis, Slice, HamiltonianMC, NUTS, DEMetropolis)

Expand Down