From 86af2ddbaa724ee13862274f83365d7e9cdf72c6 Mon Sep 17 00:00:00 2001 From: Joseph Hall Date: Fri, 26 May 2023 19:44:00 +0100 Subject: [PATCH 1/5] More informative error message in instantiate_steppers --- pymc/sampling/mcmc.py | 14 ++++++++++++-- tests/sampling/test_mcmc.py | 6 ++---- 2 files changed, 14 insertions(+), 6 deletions(-) diff --git a/pymc/sampling/mcmc.py b/pymc/sampling/mcmc.py index baf2946de5..a2e7fbc4d0 100644 --- a/pymc/sampling/mcmc.py +++ b/pymc/sampling/mcmc.py @@ -90,7 +90,10 @@ def __call__(self, trace: IBaseTrace, draw: Draw): def instantiate_steppers( - model, steps: List[Step], selected_steps, step_kwargs=None + model, + steps: List[Step], + selected_steps: Dict[str, List[Any]], + step_kwargs: Optional[Dict[str, Dict]] = None, ) -> Union[Step, List[Step]]: """Instantiate steppers assigned to the model variables. @@ -129,7 +132,14 @@ def instantiate_steppers( unused_args = set(step_kwargs).difference(used_keys) if unused_args: - raise ValueError("Unused step method arguments: %s" % unused_args) + s = "s" if len(unused_args) > 1 else "" + example_arg = sorted(unused_args)[0] + example_step = list(selected_steps.keys())[0] + raise ValueError( + f"Invalid key{s} found in step_kwargs: {unused_args}. " + "Keys must be step names and values valid kwargs for that stepper. " + f'Did you mean {{"{example_step}": {{"{example_arg}": ...}}}}?' + ) if len(steps) == 1: return steps[0] diff --git a/tests/sampling/test_mcmc.py b/tests/sampling/test_mcmc.py index 6dff413179..90ed8a63f8 100644 --- a/tests/sampling/test_mcmc.py +++ b/tests/sampling/test_mcmc.py @@ -180,13 +180,11 @@ def test_sample_init(self): def test_sample_args(self): with self.model: - with pytest.raises(ValueError) as excinfo: + with pytest.raises(ValueError, match=r"'foo'"): pm.sample(50, tune=0, chains=1, step=pm.Metropolis(), foo=1) - assert "'foo'" in str(excinfo.value) - with pytest.raises(ValueError) as excinfo: + with pytest.raises(ValueError, match=r"'foo'") as excinfo: pm.sample(50, tune=0, chains=1, step=pm.Metropolis(), foo={}) - assert "foo" in str(excinfo.value) def test_parallel_start(self): with self.model: From 9ff19f280be97612665bbdb77911cce496545e6b Mon Sep 17 00:00:00 2001 From: Joseph Hall Date: Wed, 7 Jun 2023 10:34:44 +0100 Subject: [PATCH 2/5] Fix mypy issue --- pymc/sampling/mcmc.py | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/pymc/sampling/mcmc.py b/pymc/sampling/mcmc.py index a2e7fbc4d0..e8a179394c 100644 --- a/pymc/sampling/mcmc.py +++ b/pymc/sampling/mcmc.py @@ -30,6 +30,7 @@ Optional, Sequence, Tuple, + Type, Union, overload, ) @@ -92,7 +93,7 @@ def __call__(self, trace: IBaseTrace, draw: Draw): def instantiate_steppers( model, steps: List[Step], - selected_steps: Dict[str, List[Any]], + selected_steps: Dict[Type[BlockedStep], List[Any]], step_kwargs: Optional[Dict[str, Dict]] = None, ) -> Union[Step, List[Step]]: """Instantiate steppers assigned to the model variables. @@ -125,8 +126,9 @@ def instantiate_steppers( used_keys = set() for step_class, vars in selected_steps.items(): if vars: - args = step_kwargs.get(step_class.name, {}) - used_keys.add(step_class.name) + name = getattr(step_class, "name") + args = step_kwargs.get(name, {}) + used_keys.add(name) step = step_class(vars=vars, model=model, **args) steps.append(step) From 4716abf4b18f51242f899d5fb1aed9c9202b868d Mon Sep 17 00:00:00 2001 From: Joseph Hall Date: Wed, 7 Jun 2023 11:19:33 +0100 Subject: [PATCH 3/5] Add typing in assign_step_methods --- pymc/sampling/mcmc.py | 44 ++++++++++++++++++++--------------- pymc/step_methods/compound.py | 2 +- 2 files changed, 26 insertions(+), 20 deletions(-) diff --git a/pymc/sampling/mcmc.py b/pymc/sampling/mcmc.py index e8a179394c..1959b8f75a 100644 --- a/pymc/sampling/mcmc.py +++ b/pymc/sampling/mcmc.py @@ -20,15 +20,16 @@ import time import warnings -from collections import defaultdict from typing import ( Any, Dict, Iterator, List, Literal, + Mapping, Optional, Sequence, + Set, Tuple, Type, Union, @@ -40,6 +41,7 @@ from arviz import InferenceData from fastprogress.fastprogress import progress_bar +from pytensor.graph.basic import Variable from typing_extensions import Protocol, TypeAlias import pymc as pm @@ -91,9 +93,9 @@ def __call__(self, trace: IBaseTrace, draw: Draw): def instantiate_steppers( - model, + model: Model, steps: List[Step], - selected_steps: Dict[Type[BlockedStep], List[Any]], + selected_steps: Mapping[Type[BlockedStep], List[Any]], step_kwargs: Optional[Dict[str, Dict]] = None, ) -> Union[Step, List[Step]]: """Instantiate steppers assigned to the model variables. @@ -149,7 +151,12 @@ def instantiate_steppers( return steps -def assign_step_methods(model, step=None, methods=None, step_kwargs=None): +def assign_step_methods( + model: Model, + step: Optional[Union[Step, Sequence[Step]]] = None, + methods: Optional[Sequence[Type[BlockedStep]]] = None, + step_kwargs: Optional[Dict[str, Any]] = None, +) -> Union[Step, List[Step]]: """Assign model variables to appropriate step methods. Passing a specified model will auto-assign its constituent stochastic @@ -179,49 +186,48 @@ def assign_step_methods(model, step=None, methods=None, step_kwargs=None): methods : list List of step methods associated with the model's variables. """ - steps = [] - assigned_vars = set() - - if methods is None: - methods = pm.STEP_METHODS + steps: List[Step] = [] + assigned_vars: Set[Variable] = set() if step is not None: - try: - steps += list(step) - except TypeError: + if isinstance(step, (BlockedStep, CompoundStep)): steps.append(step) + else: + steps.extend(step) for step in steps: for var in step.vars: if var not in model.value_vars: raise ValueError( - f"{var} assigned to {step} sampler is not a value variable in the model. You can use `util.get_value_vars_from_user_vars` to parse user provided variables." + f"{var} assigned to {step} sampler is not a value variable in the model. " + "You can use `util.get_value_vars_from_user_vars` to parse user provided variables." ) assigned_vars = assigned_vars.union(set(step.vars)) # Use competence classmethods to select step methods for remaining # variables - selected_steps = defaultdict(list) + methods_list: List[Type[BlockedStep]] = list(methods or pm.STEP_METHODS) + selected_steps: Dict[Type[BlockedStep], List] = {} model_logp = model.logp() for var in model.value_vars: if var not in assigned_vars: # determine if a gradient can be computed - has_gradient = var.dtype not in discrete_types + has_gradient = getattr(var, "dtype") not in discrete_types if has_gradient: try: - tg.grad(model_logp, var) + tg.grad(model_logp, var) # type: ignore except (NotImplementedError, tg.NullTypeGradError): has_gradient = False # select the best method rv_var = model.values_to_rvs[var] selected = max( - methods, - key=lambda method, var=rv_var, has_gradient=has_gradient: method._competence( + methods_list, + key=lambda method, var=rv_var, has_gradient=has_gradient: method._competence( # type: ignore var, has_gradient ), ) - selected_steps[selected].append(var) + selected_steps.setdefault(selected, []).append(var) return instantiate_steppers(model, steps, selected_steps, step_kwargs) diff --git a/pymc/step_methods/compound.py b/pymc/step_methods/compound.py index 28a1efb718..c0a3a13594 100644 --- a/pymc/step_methods/compound.py +++ b/pymc/step_methods/compound.py @@ -246,7 +246,7 @@ def reset_tuning(self): method.reset_tuning() @property - def vars(self): + def vars(self) -> List[Variable]: return [var for method in self.methods for var in method.vars] From 6be1647f76edcf85b212302c192812f44d2dd051 Mon Sep 17 00:00:00 2001 From: Joseph Hall Date: Wed, 7 Jun 2023 12:22:57 +0100 Subject: [PATCH 4/5] Fix failing tests --- pymc/sampling/mcmc.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/pymc/sampling/mcmc.py b/pymc/sampling/mcmc.py index 1959b8f75a..5d7f648a3b 100644 --- a/pymc/sampling/mcmc.py +++ b/pymc/sampling/mcmc.py @@ -138,11 +138,11 @@ def instantiate_steppers( if unused_args: s = "s" if len(unused_args) > 1 else "" example_arg = sorted(unused_args)[0] - example_step = list(selected_steps.keys())[0] + example_step = (list(selected_steps.keys()) or pm.STEP_METHODS)[0] raise ValueError( f"Invalid key{s} found in step_kwargs: {unused_args}. " "Keys must be step names and values valid kwargs for that stepper. " - f'Did you mean {{"{example_step}": {{"{example_arg}": ...}}}}?' + f'Did you mean {{"{example_step.name}": {{"{example_arg}": ...}}}}?' ) if len(steps) == 1: From 483cd5e3d52060251d158d4d25d1e4178a2b7009 Mon Sep 17 00:00:00 2001 From: Joseph Hall Date: Wed, 7 Jun 2023 12:30:53 +0100 Subject: [PATCH 5/5] Fix mypy --- pymc/sampling/mcmc.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/pymc/sampling/mcmc.py b/pymc/sampling/mcmc.py index 5d7f648a3b..5986c1444b 100644 --- a/pymc/sampling/mcmc.py +++ b/pymc/sampling/mcmc.py @@ -139,10 +139,11 @@ def instantiate_steppers( s = "s" if len(unused_args) > 1 else "" example_arg = sorted(unused_args)[0] example_step = (list(selected_steps.keys()) or pm.STEP_METHODS)[0] + example_step_name = getattr(example_step, "name") raise ValueError( f"Invalid key{s} found in step_kwargs: {unused_args}. " "Keys must be step names and values valid kwargs for that stepper. " - f'Did you mean {{"{example_step.name}": {{"{example_arg}": ...}}}}?' + f'Did you mean {{"{example_step_name}": {{"{example_arg}": ...}}}}?' ) if len(steps) == 1: