61 changes: 40 additions & 21 deletions pymc/sampling/mcmc.py
@@ -20,16 +20,18 @@
import time
import warnings

from collections import defaultdict
from typing import (
Any,
Dict,
Iterator,
List,
Literal,
Mapping,
Optional,
Sequence,
Set,
Tuple,
Type,
Union,
overload,
)
@@ -39,6 +41,7 @@

from arviz import InferenceData
from fastprogress.fastprogress import progress_bar
from pytensor.graph.basic import Variable
from typing_extensions import Protocol, TypeAlias

import pymc as pm
@@ -90,7 +93,10 @@ def __call__(self, trace: IBaseTrace, draw: Draw):


def instantiate_steppers(
model, steps: List[Step], selected_steps, step_kwargs=None
model: Model,
steps: List[Step],
selected_steps: Mapping[Type[BlockedStep], List[Any]],
step_kwargs: Optional[Dict[str, Dict]] = None,
) -> Union[Step, List[Step]]:
"""Instantiate steppers assigned to the model variables.

@@ -122,22 +128,36 @@ def instantiate_steppers(
used_keys = set()
for step_class, vars in selected_steps.items():
if vars:
args = step_kwargs.get(step_class.name, {})
used_keys.add(step_class.name)
name = getattr(step_class, "name")
args = step_kwargs.get(name, {})
used_keys.add(name)
step = step_class(vars=vars, model=model, **args)
steps.append(step)

unused_args = set(step_kwargs).difference(used_keys)
if unused_args:
raise ValueError("Unused step method arguments: %s" % unused_args)
s = "s" if len(unused_args) > 1 else ""
example_arg = sorted(unused_args)[0]
example_step = (list(selected_steps.keys()) or pm.STEP_METHODS)[0]
example_step_name = getattr(example_step, "name")
raise ValueError(
f"Invalid key{s} found in step_kwargs: {unused_args}. "
"Keys must be step names and values valid kwargs for that stepper. "
f'Did you mean {{"{example_step_name}": {{"{example_arg}": ...}}}}?'
)

if len(steps) == 1:
return steps[0]

return steps
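
A minimal sketch (not part of this diff) of how the sharpened error surfaces through `pm.sample`, which forwards unrecognised keyword arguments as `step_kwargs` — mirroring the updated test at the bottom of this diff; `foo` is the deliberately invalid key:

```python
import pymc as pm

# Sketch, assuming extra pm.sample() keywords are forwarded as step_kwargs:
# an unrecognised key now fails with a concrete suggestion instead of a
# bare "Unused step method arguments" dump.
with pm.Model():
    pm.Normal("x")
    try:
        pm.sample(50, tune=0, chains=1, step=pm.Metropolis(), foo={"bar": 1})
    except ValueError as err:
        print(err)
        # Roughly: Invalid key found in step_kwargs: {'foo'}. Keys must be
        # step names and values valid kwargs for that stepper.
        # Did you mean {"nuts": {"foo": ...}}?
```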


def assign_step_methods(model, step=None, methods=None, step_kwargs=None):
def assign_step_methods(
model: Model,
step: Optional[Union[Step, Sequence[Step]]] = None,
methods: Optional[Sequence[Type[BlockedStep]]] = None,
step_kwargs: Optional[Dict[str, Any]] = None,
) -> Union[Step, List[Step]]:
"""Assign model variables to appropriate step methods.

Passing a specified model will auto-assign its constituent stochastic
@@ -167,49 +187,48 @@ def assign_step_methods(model, step=None, methods=None, step_kwargs=None):
methods : list
List of step methods associated with the model's variables.
"""
steps = []
assigned_vars = set()

if methods is None:
methods = pm.STEP_METHODS
steps: List[Step] = []
assigned_vars: Set[Variable] = set()

if step is not None:
try:
steps += list(step)
except TypeError:
if isinstance(step, (BlockedStep, CompoundStep)):
steps.append(step)
else:
steps.extend(step)
for step in steps:
for var in step.vars:
if var not in model.value_vars:
raise ValueError(
f"{var} assigned to {step} sampler is not a value variable in the model. You can use `util.get_value_vars_from_user_vars` to parse user provided variables."
f"{var} assigned to {step} sampler is not a value variable in the model. "
"You can use `util.get_value_vars_from_user_vars` to parse user provided variables."
)
assigned_vars = assigned_vars.union(set(step.vars))

# Use competence classmethods to select step methods for remaining
# variables
selected_steps = defaultdict(list)
methods_list: List[Type[BlockedStep]] = list(methods or pm.STEP_METHODS)
selected_steps: Dict[Type[BlockedStep], List] = {}
model_logp = model.logp()

for var in model.value_vars:
if var not in assigned_vars:
# determine if a gradient can be computed
has_gradient = var.dtype not in discrete_types
has_gradient = getattr(var, "dtype") not in discrete_types
if has_gradient:
try:
tg.grad(model_logp, var)
tg.grad(model_logp, var) # type: ignore
except (NotImplementedError, tg.NullTypeGradError):
has_gradient = False

# select the best method
rv_var = model.values_to_rvs[var]
selected = max(
methods,
key=lambda method, var=rv_var, has_gradient=has_gradient: method._competence(
methods_list,
key=lambda method, var=rv_var, has_gradient=has_gradient: method._competence( # type: ignore
var, has_gradient
),
)
selected_steps[selected].append(var)
selected_steps.setdefault(selected, []).append(var)

return instantiate_steppers(model, steps, selected_steps, step_kwargs)

2 changes: 1 addition & 1 deletion pymc/step_methods/compound.py
@@ -246,7 +246,7 @@ def reset_tuning(self):
method.reset_tuning()

@property
def vars(self):
def vars(self) -> List[Variable]:
return [var for method in self.methods for var in method.vars]
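
For reference, a hedged sketch of what the newly annotated property returns: the concatenated value variables of every contained method, in method order.

```python
import pymc as pm
from pymc.step_methods.compound import CompoundStep

# Sketch under assumed usage: building a CompoundStep by hand and reading
# back the flattened value variables of its sub-steppers.
with pm.Model() as model:
    a = pm.Normal("a")
    b = pm.Bernoulli("b", 0.5)
    compound = CompoundStep([pm.NUTS([a]), pm.Metropolis([b])])

print(compound.vars)  # value variables of `a` and `b`, method order preserved
```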


6 changes: 2 additions & 4 deletions tests/sampling/test_mcmc.py
@@ -180,13 +180,11 @@ def test_sample_init(self):

def test_sample_args(self):
with self.model:
with pytest.raises(ValueError) as excinfo:
with pytest.raises(ValueError, match=r"'foo'"):
pm.sample(50, tune=0, chains=1, step=pm.Metropolis(), foo=1)
assert "'foo'" in str(excinfo.value)

with pytest.raises(ValueError) as excinfo:
with pytest.raises(ValueError, match=r"'foo'") as excinfo:
pm.sample(50, tune=0, chains=1, step=pm.Metropolis(), foo={})
assert "foo" in str(excinfo.value)

def test_parallel_start(self):
with self.model: