|
20 | 20 | import time |
21 | 21 | import warnings |
22 | 22 |
|
23 | | -from collections import defaultdict |
24 | 23 | from typing import ( |
25 | 24 | Any, |
26 | 25 | Dict, |
27 | 26 | Iterator, |
28 | 27 | List, |
29 | 28 | Literal, |
| 29 | + Mapping, |
30 | 30 | Optional, |
31 | 31 | Sequence, |
| 32 | + Set, |
32 | 33 | Tuple, |
33 | 34 | Type, |
34 | 35 | Union, |
|
40 | 41 |
|
41 | 42 | from arviz import InferenceData |
42 | 43 | from fastprogress.fastprogress import progress_bar |
| 44 | +from pytensor.graph.basic import Variable |
43 | 45 | from typing_extensions import Protocol, TypeAlias |
44 | 46 |
|
45 | 47 | import pymc as pm |
@@ -91,9 +93,9 @@ def __call__(self, trace: IBaseTrace, draw: Draw): |
91 | 93 |
|
92 | 94 |
|
93 | 95 | def instantiate_steppers( |
94 | | - model, |
| 96 | + model: Model, |
95 | 97 | steps: List[Step], |
96 | | - selected_steps: Dict[Type[BlockedStep], List[Any]], |
| 98 | + selected_steps: Mapping[Type[BlockedStep], List[Any]], |
97 | 99 | step_kwargs: Optional[Dict[str, Dict]] = None, |
98 | 100 | ) -> Union[Step, List[Step]]: |
99 | 101 | """Instantiate steppers assigned to the model variables. |
@@ -149,7 +151,12 @@ def instantiate_steppers( |
149 | 151 | return steps |
150 | 152 |
|
151 | 153 |
|
152 | | -def assign_step_methods(model, step=None, methods=None, step_kwargs=None): |
| 154 | +def assign_step_methods( |
| 155 | + model: Model, |
| 156 | + step: Optional[Union[Step, Sequence[Step]]] = None, |
| 157 | + methods: Optional[Sequence[Type[BlockedStep]]] = None, |
| 158 | + step_kwargs: Optional[Dict[str, Any]] = None, |
| 159 | +) -> Union[Step, List[Step]]: |
153 | 160 | """Assign model variables to appropriate step methods. |
154 | 161 |
|
155 | 162 | Passing a specified model will auto-assign its constituent stochastic |
@@ -179,49 +186,48 @@ def assign_step_methods(model, step=None, methods=None, step_kwargs=None): |
179 | 186 | methods : list |
180 | 187 | List of step methods associated with the model's variables. |
181 | 188 | """ |
182 | | - steps = [] |
183 | | - assigned_vars = set() |
184 | | - |
185 | | - if methods is None: |
186 | | - methods = pm.STEP_METHODS |
| 189 | + steps: List[Step] = [] |
| 190 | + assigned_vars: Set[Variable] = set() |
187 | 191 |
|
188 | 192 | if step is not None: |
189 | | - try: |
190 | | - steps += list(step) |
191 | | - except TypeError: |
| 193 | + if isinstance(step, (BlockedStep, CompoundStep)): |
192 | 194 | steps.append(step) |
| 195 | + else: |
| 196 | + steps.extend(step) |
193 | 197 | for step in steps: |
194 | 198 | for var in step.vars: |
195 | 199 | if var not in model.value_vars: |
196 | 200 | raise ValueError( |
197 | | - f"{var} assigned to {step} sampler is not a value variable in the model. You can use `util.get_value_vars_from_user_vars` to parse user provided variables." |
| 201 | + f"{var} assigned to {step} sampler is not a value variable in the model. " |
| 202 | + "You can use `util.get_value_vars_from_user_vars` to parse user provided variables." |
198 | 203 | ) |
199 | 204 | assigned_vars = assigned_vars.union(set(step.vars)) |
200 | 205 |
|
201 | 206 | # Use competence classmethods to select step methods for remaining |
202 | 207 | # variables |
203 | | - selected_steps = defaultdict(list) |
| 208 | + methods_list: List[Type[BlockedStep]] = list(methods or pm.STEP_METHODS) |
| 209 | + selected_steps: Dict[Type[BlockedStep], List] = {} |
204 | 210 | model_logp = model.logp() |
205 | 211 |
|
206 | 212 | for var in model.value_vars: |
207 | 213 | if var not in assigned_vars: |
208 | 214 | # determine if a gradient can be computed |
209 | | - has_gradient = var.dtype not in discrete_types |
| 215 | + has_gradient = getattr(var, "dtype") not in discrete_types |
210 | 216 | if has_gradient: |
211 | 217 | try: |
212 | | - tg.grad(model_logp, var) |
| 218 | + tg.grad(model_logp, var) # type: ignore |
213 | 219 | except (NotImplementedError, tg.NullTypeGradError): |
214 | 220 | has_gradient = False |
215 | 221 |
|
216 | 222 | # select the best method |
217 | 223 | rv_var = model.values_to_rvs[var] |
218 | 224 | selected = max( |
219 | | - methods, |
220 | | - key=lambda method, var=rv_var, has_gradient=has_gradient: method._competence( |
| 225 | + methods_list, |
| 226 | + key=lambda method, var=rv_var, has_gradient=has_gradient: method._competence( # type: ignore |
221 | 227 | var, has_gradient |
222 | 228 | ), |
223 | 229 | ) |
224 | | - selected_steps[selected].append(var) |
| 230 | + selected_steps.setdefault(selected, []).append(var) |
225 | 231 |
|
226 | 232 | return instantiate_steppers(model, steps, selected_steps, step_kwargs) |
227 | 233 |
|
|
0 commit comments