Skip to content

Commit 4716abf

Browse files
author
Joseph Hall
committed
Add typing in assign_step_methods
1 parent 9ff19f2 commit 4716abf

File tree

2 files changed

+26
-20
lines changed

2 files changed

+26
-20
lines changed

pymc/sampling/mcmc.py

Lines changed: 25 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -20,15 +20,16 @@
2020
import time
2121
import warnings
2222

23-
from collections import defaultdict
2423
from typing import (
2524
Any,
2625
Dict,
2726
Iterator,
2827
List,
2928
Literal,
29+
Mapping,
3030
Optional,
3131
Sequence,
32+
Set,
3233
Tuple,
3334
Type,
3435
Union,
@@ -40,6 +41,7 @@
4041

4142
from arviz import InferenceData
4243
from fastprogress.fastprogress import progress_bar
44+
from pytensor.graph.basic import Variable
4345
from typing_extensions import Protocol, TypeAlias
4446

4547
import pymc as pm
@@ -91,9 +93,9 @@ def __call__(self, trace: IBaseTrace, draw: Draw):
9193

9294

9395
def instantiate_steppers(
94-
model,
96+
model: Model,
9597
steps: List[Step],
96-
selected_steps: Dict[Type[BlockedStep], List[Any]],
98+
selected_steps: Mapping[Type[BlockedStep], List[Any]],
9799
step_kwargs: Optional[Dict[str, Dict]] = None,
98100
) -> Union[Step, List[Step]]:
99101
"""Instantiate steppers assigned to the model variables.
@@ -149,7 +151,12 @@ def instantiate_steppers(
149151
return steps
150152

151153

152-
def assign_step_methods(model, step=None, methods=None, step_kwargs=None):
154+
def assign_step_methods(
155+
model: Model,
156+
step: Optional[Union[Step, Sequence[Step]]] = None,
157+
methods: Optional[Sequence[Type[BlockedStep]]] = None,
158+
step_kwargs: Optional[Dict[str, Any]] = None,
159+
) -> Union[Step, List[Step]]:
153160
"""Assign model variables to appropriate step methods.
154161
155162
Passing a specified model will auto-assign its constituent stochastic
@@ -179,49 +186,48 @@ def assign_step_methods(model, step=None, methods=None, step_kwargs=None):
179186
methods : list
180187
List of step methods associated with the model's variables.
181188
"""
182-
steps = []
183-
assigned_vars = set()
184-
185-
if methods is None:
186-
methods = pm.STEP_METHODS
189+
steps: List[Step] = []
190+
assigned_vars: Set[Variable] = set()
187191

188192
if step is not None:
189-
try:
190-
steps += list(step)
191-
except TypeError:
193+
if isinstance(step, (BlockedStep, CompoundStep)):
192194
steps.append(step)
195+
else:
196+
steps.extend(step)
193197
for step in steps:
194198
for var in step.vars:
195199
if var not in model.value_vars:
196200
raise ValueError(
197-
f"{var} assigned to {step} sampler is not a value variable in the model. You can use `util.get_value_vars_from_user_vars` to parse user provided variables."
201+
f"{var} assigned to {step} sampler is not a value variable in the model. "
202+
"You can use `util.get_value_vars_from_user_vars` to parse user provided variables."
198203
)
199204
assigned_vars = assigned_vars.union(set(step.vars))
200205

201206
# Use competence classmethods to select step methods for remaining
202207
# variables
203-
selected_steps = defaultdict(list)
208+
methods_list: List[Type[BlockedStep]] = list(methods or pm.STEP_METHODS)
209+
selected_steps: Dict[Type[BlockedStep], List] = {}
204210
model_logp = model.logp()
205211

206212
for var in model.value_vars:
207213
if var not in assigned_vars:
208214
# determine if a gradient can be computed
209-
has_gradient = var.dtype not in discrete_types
215+
has_gradient = getattr(var, "dtype") not in discrete_types
210216
if has_gradient:
211217
try:
212-
tg.grad(model_logp, var)
218+
tg.grad(model_logp, var) # type: ignore
213219
except (NotImplementedError, tg.NullTypeGradError):
214220
has_gradient = False
215221

216222
# select the best method
217223
rv_var = model.values_to_rvs[var]
218224
selected = max(
219-
methods,
220-
key=lambda method, var=rv_var, has_gradient=has_gradient: method._competence(
225+
methods_list,
226+
key=lambda method, var=rv_var, has_gradient=has_gradient: method._competence( # type: ignore
221227
var, has_gradient
222228
),
223229
)
224-
selected_steps[selected].append(var)
230+
selected_steps.setdefault(selected, []).append(var)
225231

226232
return instantiate_steppers(model, steps, selected_steps, step_kwargs)
227233

pymc/step_methods/compound.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -246,7 +246,7 @@ def reset_tuning(self):
246246
method.reset_tuning()
247247

248248
@property
249-
def vars(self):
249+
def vars(self) -> List[Variable]:
250250
return [var for method in self.methods for var in method.vars]
251251

252252

0 commit comments

Comments
 (0)