diff --git a/.github/workflows/pytest.yml b/.github/workflows/pytest.yml index 3055fd2473..ce08a797d2 100644 --- a/.github/workflows/pytest.yml +++ b/.github/workflows/pytest.yml @@ -30,7 +30,7 @@ jobs: # → pytest will run only these files - | --ignore=pymc/tests/test_distributions_timeseries.py - --ignore=pymc/tests/test_initvals.py + --ignore=pymc/tests/test_initial_point.py --ignore=pymc/tests/test_mixture.py --ignore=pymc/tests/test_model_graph.py --ignore=pymc/tests/test_modelcontext.py @@ -61,7 +61,7 @@ jobs: --ignore=pymc/tests/test_idata_conversion.py - | - pymc/tests/test_initvals.py + pymc/tests/test_initial_point.py pymc/tests/test_distributions.py - | @@ -154,7 +154,7 @@ jobs: floatx: [float32, float64] test-subset: - | - pymc/tests/test_initvals.py + pymc/tests/test_initial_point.py pymc/tests/test_distributions_random.py pymc/tests/test_distributions_timeseries.py - | diff --git a/benchmarks/benchmarks/benchmarks.py b/benchmarks/benchmarks/benchmarks.py index 82771087bf..e8f029aed1 100644 --- a/benchmarks/benchmarks/benchmarks.py +++ b/benchmarks/benchmarks/benchmarks.py @@ -173,12 +173,14 @@ class NUTSInitSuite: def time_glm_hierarchical_init(self, init): """How long does it take to run the initialization.""" with glm_hierarchical_model(): - pm.init_nuts(init=init, chains=self.chains, progressbar=False) + pm.init_nuts( + init=init, chains=self.chains, progressbar=False, seeds=np.arange(self.chains) + ) def track_glm_hierarchical_ess(self, init): with glm_hierarchical_model(): start, step = pm.init_nuts( - init=init, chains=self.chains, progressbar=False, random_seed=123 + init=init, chains=self.chains, progressbar=False, seeds=np.arange(self.chains) ) t0 = time.time() idata = pm.sample( @@ -187,7 +189,7 @@ def track_glm_hierarchical_ess(self, init): cores=4, chains=self.chains, start=start, - random_seed=100, + seeds=np.arange(self.chains), progressbar=False, compute_convergence_checks=False, ) @@ -199,7 +201,7 @@ def track_marginal_mixture_model_ess(self, init): model, start = mixture_model() with model: _, step = pm.init_nuts( - init=init, chains=self.chains, progressbar=False, random_seed=123 + init=init, chains=self.chains, progressbar=False, seeds=np.arange(self.chains) ) start = [{k: v for k, v in start.items()} for _ in range(self.chains)] t0 = time.time() @@ -209,7 +211,7 @@ def track_marginal_mixture_model_ess(self, init): cores=4, chains=self.chains, start=start, - random_seed=100, + seeds=np.arange(self.chains), progressbar=False, compute_convergence_checks=False, ) diff --git a/conda-envs/environment-dev-py37.yml b/conda-envs/environment-dev-py37.yml index a3e14c419a..a5e0023fff 100644 --- a/conda-envs/environment-dev-py37.yml +++ b/conda-envs/environment-dev-py37.yml @@ -4,7 +4,7 @@ channels: - conda-forge - defaults dependencies: -- aesara>=2.1.0 +- aesara>=2.2.2 - arviz>=0.11.4 - cachetools>=4.2.1 - cloudpickle diff --git a/conda-envs/environment-dev-py38.yml b/conda-envs/environment-dev-py38.yml index b0bb7b4922..ce1eaf7dd0 100644 --- a/conda-envs/environment-dev-py38.yml +++ b/conda-envs/environment-dev-py38.yml @@ -4,7 +4,7 @@ channels: - conda-forge - defaults dependencies: -- aesara>=2.1.0 +- aesara>=2.2.2 - arviz>=0.11.4 - cachetools>=4.2.1 - cloudpickle diff --git a/conda-envs/environment-dev-py39.yml b/conda-envs/environment-dev-py39.yml index a8e929b112..f86088aff0 100644 --- a/conda-envs/environment-dev-py39.yml +++ b/conda-envs/environment-dev-py39.yml @@ -4,7 +4,7 @@ channels: - conda-forge - defaults dependencies: -- aesara>=2.1.0 +- aesara>=2.2.2 
- arviz>=0.11.4 - cachetools>=4.2.1 - cloudpickle diff --git a/conda-envs/environment-test-py37.yml b/conda-envs/environment-test-py37.yml index cb979c85ad..8092df0d63 100644 --- a/conda-envs/environment-test-py37.yml +++ b/conda-envs/environment-test-py37.yml @@ -4,7 +4,7 @@ channels: - conda-forge - defaults dependencies: -- aesara>=2.1.0 +- aesara>=2.2.2 - arviz>=0.11.4 - cachetools>=4.2.1 - cloudpickle diff --git a/conda-envs/environment-test-py38.yml b/conda-envs/environment-test-py38.yml index 1db9766278..e80765af2b 100644 --- a/conda-envs/environment-test-py38.yml +++ b/conda-envs/environment-test-py38.yml @@ -4,7 +4,7 @@ channels: - conda-forge - defaults dependencies: -- aesara>=2.1.0 +- aesara>=2.2.2 - arviz>=0.11.4 - cachetools>=4.2.1 - cloudpickle diff --git a/conda-envs/environment-test-py39.yml b/conda-envs/environment-test-py39.yml index 8aedc89930..713d8c1bda 100644 --- a/conda-envs/environment-test-py39.yml +++ b/conda-envs/environment-test-py39.yml @@ -4,7 +4,7 @@ channels: - conda-forge - defaults dependencies: -- aesara>=2.1.0 +- aesara>=2.2.2 - arviz>=0.11.4 - cachetools - cloudpickle diff --git a/conda-envs/windows-environment-dev-py38.yml b/conda-envs/windows-environment-dev-py38.yml index 77616756d9..bdf326f74f 100644 --- a/conda-envs/windows-environment-dev-py38.yml +++ b/conda-envs/windows-environment-dev-py38.yml @@ -4,7 +4,7 @@ channels: - defaults dependencies: # base dependencies (see install guide for Windows) -- aesara>=2.1.0 +- aesara>=2.2.2 - arviz>=0.11.4 - cachetools>=4.2.1 - cloudpickle diff --git a/conda-envs/windows-environment-test-py38.yml b/conda-envs/windows-environment-test-py38.yml index 646bb2d01b..53fb5d9bf1 100644 --- a/conda-envs/windows-environment-test-py38.yml +++ b/conda-envs/windows-environment-test-py38.yml @@ -4,7 +4,7 @@ channels: - defaults dependencies: # base dependencies (see install guide for Windows) -- aesara>=2.1.0 +- aesara>=2.2.2 - arviz>=0.11.2 - cachetools - cloudpickle diff --git a/pymc/distributions/continuous.py b/pymc/distributions/continuous.py index d5d5dd39ed..282fba1dc9 100644 --- a/pymc/distributions/continuous.py +++ b/pymc/distributions/continuous.py @@ -365,10 +365,13 @@ class Flat(Continuous): rv_op = flat + def __new__(cls, *args, **kwargs): + kwargs.setdefault("initval", "moment") + return super().__new__(cls, *args, **kwargs) + @classmethod def dist(cls, *, size=None, **kwargs): res = super().dist([], size=size, **kwargs) - res.tag.test_value = np.full(size, floatX(0.0)) return res def get_moment(rv, size, *rv_inputs): @@ -430,10 +433,13 @@ class HalfFlat(PositiveContinuous): rv_op = halfflat + def __new__(cls, *args, **kwargs): + kwargs.setdefault("initval", "moment") + return super().__new__(cls, *args, **kwargs) + @classmethod def dist(cls, *, size=None, **kwargs): res = super().dist([], size=size, **kwargs) - res.tag.test_value = np.full(size, floatX(1.0)) return res def get_moment(value_var, size, *rv_inputs): diff --git a/pymc/distributions/distribution.py b/pymc/distributions/distribution.py index 561f0d5820..efb648c087 100644 --- a/pymc/distributions/distribution.py +++ b/pymc/distributions/distribution.py @@ -165,8 +165,10 @@ def __new__( dims : tuple, optional A tuple of dimension names known to the model. initval : optional - Test value to be attached to the output RV. - Must match its shape exactly. + Numeric or symbolic untransformed initial value of matching shape, + or one of the following initial value strategies: "moment", "prior". 
+ Depending on the sampler's settings, a random jitter may be added to numeric, symbolic + or moment-based initial values in the transformed space. observed : optional Observed data to be passed when registering the random variable in the model. See ``Model.register_rv``. @@ -600,31 +602,16 @@ def dist(cls, *args, **kwargs): else: dtype = cls.rv_op.dtype ndim_supp = cls.rv_op.ndim_supp - if not hasattr(output.tag, "test_value"): - size = to_tuple(kwargs.get("size", None)) + (1,) * ndim_supp - output.tag.test_value = np.zeros(size, dtype) return output def default_not_implemented(rv_name, method_name): - if method_name == "random": - # This is a hack to catch the NotImplementedError when creating the RV without random - # If the message starts with "Cannot sample from", then it uses the test_value as - # the initial_val. - message = ( - f"Cannot sample from the DensityDist '{rv_name}' because the {method_name} " - "keyword argument was not provided when the distribution was " - f"constructed. Please re-build your model and provide a callable " - f"to '{rv_name}'s {method_name} keyword argument.\n" - ) - else: - message = ( - f"Attempted to run {method_name} on the DensityDist '{rv_name}', " - f"but this method had not been provided when the distribution was " - f"constructed. Please re-build your model and provide a callable " - f"to '{rv_name}'s {method_name} keyword argument.\n" - ) + message = ( + f"Attempted to run {method_name} on the DensityDist '{rv_name}', " + f"but this method had not been provided when the distribution was " + f"constructed. Please re-build your model and provide a callable " + f"to '{rv_name}'s {method_name} keyword argument.\n" + ) def func(*args, **kwargs): raise NotImplementedError(message) diff --git a/pymc/initial_point.py b/pymc/initial_point.py new file mode 100644 index 0000000000..34dee7e381 --- /dev/null +++ b/pymc/initial_point.py @@ -0,0 +1,320 @@ +# Copyright 2021 The PyMC Developers +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import functools + +from typing import Callable, Dict, List, Optional, Sequence, Set, Union + +import aesara +import aesara.tensor as at +import numpy as np + +from aesara.graph.basic import Variable, graph_inputs +from aesara.graph.fg import FunctionGraph +from aesara.tensor.var import TensorVariable + +from pymc.aesaraf import compile_rv_inplace +from pymc.util import get_transformed_name, get_untransformed_name, is_transformed_name + +StartDict = Dict[Union[Variable, str], Union[np.ndarray, Variable, str]] +PointType = Dict[str, np.ndarray] + + +def convert_str_to_rv_dict( + model, start: StartDict +) -> Dict[TensorVariable, Optional[Union[np.ndarray, Variable, str]]]: + """Helper function for converting a user-provided start dict with str keys of (transformed) variable names + to a dict mapping the RV tensors to untransformed initvals.
+ TODO: Deprecate this functionality and only accept TensorVariables as keys + """ + initvals = {} + for key, initval in start.items(): + if isinstance(key, str): + if is_transformed_name(key): + rv = model[get_untransformed_name(key)] + initvals[rv] = model.rvs_to_values[rv].tag.transform.backward(rv, initval) + else: + initvals[model[key]] = initval + else: + initvals[key] = initval + return initvals + + +def filter_rvs_to_jitter(step) -> Set[TensorVariable]: + """Find the set of RVs for which the responsible step methods ask for + the addition of jitter to the initial point. + + Parameters + ---------- + step : BlockedStep or CompoundStep + One or many step methods that were assigned model variables. + + Returns + ------- + rvs_to_jitter : set + The random variables for which jitter should be added. + """ + # TODO: implement this + return {} + + +def make_initial_point_fns_per_chain( + *, + model, + overrides: Optional[Union[StartDict, Sequence[Optional[StartDict]]]], + jitter_rvs: Set[TensorVariable], + chains: int, +) -> List[Callable]: + """Create an initial point function for each chain, as defined by initvals + + If a single initval dictionary is passed, the function is replicated for each + chain, otherwise a unique function is compiled for each entry in the dictionary. + + Parameters + ---------- + overrides : optional, list or dict + Initial value strategy overrides that should take precedence over the defaults from the model. + A sequence of None or dicts will be treated as chain-wise strategies and must have the same length as `seeds`. + jitter_rvs : set + Random variable tensors for which U(-1, 1) jitter shall be applied. + (To the transformed space if applicable.) + + Raises + ------ + ValueError + If the number of entries in initvals is different than the number of chains + + """ + if isinstance(overrides, dict) or overrides is None: + # One strategy for all chains + # Only one function compilation is needed. + ipfns = [ + make_initial_point_fn( + model=model, + overrides=overrides, + jitter_rvs=jitter_rvs, + return_transformed=True, + ) + ] * chains + elif len(overrides) == chains: + ipfns = [ + make_initial_point_fn( + model=model, + jitter_rvs=jitter_rvs, + overrides=chain_overrides, + return_transformed=True, + ) + for chain_overrides in overrides + ] + else: + raise ValueError( + f"Number of initval dicts ({len(overrides)}) does not match the number of chains ({chains})." + ) + + return ipfns + + +def make_initial_point_fn( + *, + model, + overrides: Optional[StartDict] = None, + jitter_rvs: Optional[Set[TensorVariable]] = None, + default_strategy: str = "prior", + return_transformed: bool = True, +) -> Callable: + """Create seeded function that computes initial values for all free model variables. + + Parameters + ---------- + jitter_rvs : set + The set (or list or tuple) of random variables for which a U(-1, +1) jitter should be + added to the initial value. Only available for variables that have a transform or real-valued support. + default_strategy : str + Which of { "moment", "prior" } to prefer if the initval setting for an RV is None. + overrides : dict + Initial value (strategies) to use instead of what's specified in `Model.initial_values`. + return_transformed : bool + If `True` the returned variables will correspond to transformed initial values. 
+ """ + + def find_rng_nodes(variables): + return [ + node + for node in graph_inputs(variables) + if isinstance( + node, + ( + at.random.var.RandomStateSharedVariable, + at.random.var.RandomGeneratorSharedVariable, + ), + ) + ] + + overrides = convert_str_to_rv_dict(model, overrides or {}) + + initial_values = make_initial_point_expression( + free_rvs=model.free_RVs, + rvs_to_values=model.rvs_to_values, + initval_strategies={**model.initial_values, **(overrides or {})}, + jitter_rvs=jitter_rvs, + default_strategy=default_strategy, + return_transformed=return_transformed, + ) + + # Replace original rng shared variables so that we don't mess with them + # when calling the final seeded function + graph = FunctionGraph(outputs=initial_values, clone=False) + rng_nodes = find_rng_nodes(graph.outputs) + new_rng_nodes = [] + for rng_node in rng_nodes: + if isinstance(rng_node, at.random.var.RandomStateSharedVariable): + new_rng = np.random.RandomState(np.random.PCG64()) + else: + new_rng = np.random.Generator(np.random.PCG64()) + new_rng_nodes.append(aesara.shared(new_rng)) + graph.replace_all(zip(rng_nodes, new_rng_nodes), import_missing=True) + func = compile_rv_inplace( + inputs=[], outputs=graph.outputs, mode=aesara.compile.mode.FAST_COMPILE + ) + + varnames = [] + for var in model.free_RVs: + transform = getattr(model.rvs_to_values[var].tag, "transform", None) + if transform is not None and return_transformed: + name = get_transformed_name(var.name, transform) + else: + name = var.name + varnames.append(name) + + def make_seeded_function(func): + + rngs = find_rng_nodes(func.maker.fgraph.outputs) + + @functools.wraps(func) + def inner(seed, *args, **kwargs): + seeds = [ + np.random.PCG64(sub_seed) + for sub_seed in np.random.SeedSequence(seed).spawn(len(rngs)) + ] + for rng, seed in zip(rngs, seeds): + if isinstance(rng, at.random.var.RandomStateSharedVariable): + new_rng = np.random.RandomState(seed) + else: + new_rng = np.random.Generator(seed) + rng.set_value(new_rng, True) + values = func(*args, **kwargs) + return dict(zip(varnames, values)) + + return inner + + return make_seeded_function(func) + + +def make_initial_point_expression( + *, + free_rvs: Sequence[TensorVariable], + rvs_to_values: Dict[TensorVariable, TensorVariable], + initval_strategies: Dict[TensorVariable, Optional[Union[np.ndarray, Variable, str]]], + jitter_rvs: Set[TensorVariable] = None, + default_strategy: str = "prior", + return_transformed: bool = False, +) -> List[TensorVariable]: + """Creates the tensor variables that need to be evaluated to obtain an initial point. + + Parameters + ---------- + free_rvs : list + Tensors of free random variables in the model. + rvs_to_values : dict + Mapping of free random variable tensors to value variable tensors. + initval_strategies : dict + Mapping of free random variable tensors to initial value strategies. + For example the `Model.initial_values` dictionary. + jitter_rvs : set + The set (or list or tuple) of random variables for which a U(-1, +1) jitter should be + added to the initial value. Only available for variables that have a transform or real-valued support. + default_strategy : str + Which of { "moment", "prior" } to prefer if the initval strategy setting for an RV is None. + return_transformed : bool + Switches between returning the tensors for untransformed or transformed initial points. + + Returns + ------- + initial_points : list of TensorVariable + Aesara expressions for initial values of the free random variables. 
+ """ + from pymc.distributions.distribution import get_moment + + if jitter_rvs is None: + jitter_rvs = set() + + initial_values = [] + initial_values_transformed = [] + + for variable in free_rvs: + strategy = initval_strategies.get(variable, None) + + if strategy is None: + strategy = default_strategy + + if strategy == "moment": + value = get_moment(variable) + elif strategy == "prior": + value = variable + else: + value = at.as_tensor(strategy, dtype=variable.dtype).astype(variable.dtype) + + transform = getattr(rvs_to_values[variable].tag, "transform", None) + + if transform is not None: + value = transform.forward(variable, value) + + if variable in jitter_rvs: + jitter = at.random.uniform(-1, 1, size=value.shape) + jitter.name = f"{variable.name}_jitter" + value = value + jitter + + initial_values_transformed.append(value) + + if transform is not None: + value = transform.backward(variable, value) + + initial_values.append(value) + + all_outputs = [] + all_outputs.extend(free_rvs) + all_outputs.extend(initial_values) + all_outputs.extend(initial_values_transformed) + + copy_graph = FunctionGraph(outputs=all_outputs, clone=True) + + n_variables = len(free_rvs) + free_rvs_clone = copy_graph.outputs[:n_variables] + initial_values_clone = copy_graph.outputs[n_variables:-n_variables] + initial_values_transformed_clone = copy_graph.outputs[-n_variables:] + + # In the order the variables were created, replace each previous variable + # with the init_point for that variable. + initial_values = [] + initial_values_transformed = [] + + for i in range(n_variables): + outputs = [initial_values_clone[i], initial_values_transformed_clone[i]] + graph = FunctionGraph(outputs=outputs, clone=False) + graph.replace_all(zip(free_rvs_clone[:i], initial_values), import_missing=True) + initial_values.append(graph.outputs[0]) + initial_values_transformed.append(graph.outputs[1]) + + if return_transformed: + return initial_values_transformed + return initial_values diff --git a/pymc/model.py b/pymc/model.py index 240921e99e..dc08a5d2d9 100644 --- a/pymc/model.py +++ b/pymc/model.py @@ -38,7 +38,6 @@ import numpy as np import scipy.sparse as sps -from aesara.compile.mode import Mode, get_mode from aesara.compile.sharedvalue import SharedVariable from aesara.graph.basic import Constant, Variable, graph_inputs from aesara.graph.fg import FunctionGraph @@ -59,8 +58,8 @@ from pymc.blocking import DictToArrayBijection, RaveledVars from pymc.data import GenTensorVariable, Minibatch from pymc.distributions import logp_transform, logpt, logpt_sum -from pymc.distributions.transforms import Transform from pymc.exceptions import ImputationWarning, SamplingError, ShapeError +from pymc.initial_point import make_initial_point_fn from pymc.math import flatten_list from pymc.util import ( UNSET, @@ -645,7 +644,6 @@ def __init__( # The sequence of model-generated RNGs self.rng_seq = [] self._initial_values = {} - self._initial_point_cache = {} if self.parent is not None: self.named_vars = treedict(parent=self.parent.named_vars) @@ -917,121 +915,51 @@ def cont_vars(self): @property def test_point(self) -> Dict[str, np.ndarray]: - """Deprecated alias for `Model.initial_point`.""" + """Deprecated alias for `Model.recompute_initial_point(seed=None)`.""" warnings.warn( - "`Model.test_point` has been deprecated. Use `Model.initial_point` or `Model.recompute_initial_point()`.", + "`Model.test_point` has been deprecated. 
Use `Model.recompute_initial_point(seed=None)`.", DeprecationWarning, ) - return self.initial_point + return self.recompute_initial_point() @property def initial_point(self) -> Dict[str, np.ndarray]: - """Maps free variable names to transformed, numeric initial values.""" - if set(self._initial_point_cache) != {get_var_name(k) for k in self.initial_values}: - return self.recompute_initial_point() - return self._initial_point_cache + """Deprecated alias for `Model.recompute_initial_point(seed=None)`.""" + warnings.warn( + "`Model.initial_point` has been deprecated. Use `Model.recompute_initial_point(seed=None)`.", + DeprecationWarning, + ) + return self.recompute_initial_point() - def recompute_initial_point(self) -> Dict[str, np.ndarray]: - """Recomputes numeric initial values for all free model variables. + def recompute_initial_point(self, seed=None) -> Dict[str, np.ndarray]: + """Recomputes the initial point of the model. Returns ------- - initial_point : dict - Maps free variable names to transformed, numeric initial values. + ip : dict + Maps names of transformed variables to numeric initial values in the transformed space. """ - self._initial_point_cache = Point(list(self.initial_values.items()), model=self) - return self._initial_point_cache + if seed is None: + seed = self.rng_seeder.randint(2 ** 30, dtype=np.int64) + fn = make_initial_point_fn(model=self, return_transformed=True) + return Point(fn(seed), model=self) @property - def initial_values(self) -> Dict[TensorVariable, np.ndarray]: - """Maps transformed variables to initial values. + def initial_values(self) -> Dict[TensorVariable, Optional[Union[np.ndarray, Variable, str]]]: + """Maps transformed variables to initial value placeholders. - ⚠ The keys are NOT the objects returned by, `pm.Normal(...)`. - For a name-based dictionary use the `initial_point` property. + Keys are the random variables (as returned by e.g. ``pm.Uniform()``) and + values are the numeric/symbolic initial values, strings denoting the strategy to get them, or None. """ return self._initial_values def set_initval(self, rv_var, initval): - if initval is not None: + """Sets an initial value (strategy) for a random variable.""" + if initval is not None and not isinstance(initval, (Variable, str)): + # Convert scalars or array-like inputs to ndarrays initval = rv_var.type.filter(initval) - test_value = getattr(rv_var.tag, "test_value", None) - - rv_value_var = self.rvs_to_values[rv_var] - transform = getattr(rv_value_var.tag, "transform", None) - - if initval is None or transform: - initval = self._eval_initval(rv_var, initval, test_value, transform) - - self.initial_values[rv_value_var] = initval - - def _eval_initval( - self, - rv_var: TensorVariable, - initval: Optional[Variable], - test_value: Optional[np.ndarray], - transform: Optional[Transform], - ) -> np.ndarray: - """Sample/evaluate an initial value using the existing initial values, - and with the least effect on the RNGs involved (i.e. no in-placing). - - Parameters - ---------- - rv_var : TensorVariable - The model variable the initival belongs to. - initval : Variable or None - The initial value to be evaluated. - If `None` a random draw will be made. - test_value : optional, ndarray - Fallback option if initval is None and random draws are not implemented. - This is relevant for pm.Flat or pm.HalfFlat distributions and is subject - to ongoing refactoring of the initval API. - transform : optional, Transform - A transformation associated with the random variable. 
- Transformations are automatically applied to initial values. - - Returns - ------- - initval : np.ndarray - Numeric (transformed) initial value. - """ - mode = get_mode(None) - opt_qry = mode.provided_optimizer.excluding("random_make_inplace") - mode = Mode(linker=mode.linker, optimizer=opt_qry) - - if transform: - if initval is not None: - value = initval - else: - value = rv_var - rv_var = at.as_tensor_variable(transform.forward(rv_var, value)) - - def initval_to_rvval(value_var, value): - rv_var = self.values_to_rvs[value_var] - initval = value_var.type.make_constant(value) - transform = getattr(value_var.tag, "transform", None) - if transform: - return transform.backward(rv_var, initval) - else: - return initval - - givens = { - self.values_to_rvs[k]: initval_to_rvval(k, v) for k, v in self.initial_values.items() - } - initval_fn = aesara.function([], rv_var, mode=mode, givens=givens, on_unused_input="ignore") - try: - initval = initval_fn() - except NotImplementedError as ex: - if "Cannot sample from" in ex.args[0]: - # The RV does not have a random number generator. - # Our last chance is to take the test_value. - # Note that this is a workaround for Flat and HalfFlat - # until an initval default mechanism is implemented (#4752). - initval = test_value - else: - raise - - return initval + self.initial_values[rv_var] = initval def next_rng(self) -> RandomStateSharedVariable: """Generate a new ``RandomStateSharedVariable``. @@ -1594,25 +1522,10 @@ def update_start_vals(self, a: Dict[str, np.ndarray], b: Dict[str, np.ndarray]): conditional on the values of `b` and stored in `b`. """ - # TODO FIXME XXX: If we're going to incrementally update transformed - # variables, we should do it in topological order. - for a_name, a_value in tuple(a.items()): - # If the name is a random variable, get its value variable and - # potentially transform it - var = self.named_vars.get(a_name, None) - value_var = self.rvs_to_values.get(var, None) - if value_var: - transform = getattr(value_var.tag, "transform", None) - if transform: - fval_graph = transform.forward(var, a_value) - (fval_graph,), _ = rvs_to_value_vars((fval_graph,), apply_transforms=True) - fval_graph_inputs = {i: b[i.name] for i in inputvars(fval_graph) if i.name in b} - rv_var_value = fval_graph.eval(fval_graph_inputs) - # Why are these transformed values stored in `b`? They're - # not going to be used to update `a`. - b[value_var.name] = rv_var_value - - a.update({k: v for k, v in b.items() if k not in a}) + raise DeprecationWarning( + "The `Model.update_start_vals` method was removed." + " To change initial values you may set the items of `Model.initial_values` directly." + ) def eval_rv_shapes(self) -> Dict[str, Tuple[int, ...]]: """Evaluates shapes of untransformed AND transformed free variables. 
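The model.py changes above replace the eagerly evaluated test-value/`update_start_vals` machinery with lazily evaluated initial value strategies. A minimal sketch of how the pieces introduced in this diff fit together (the concrete model below is illustrative, not taken from this PR):

import numpy as np
import pymc as pm

with pm.Model() as model:
    # Strategies are recorded at creation time but not evaluated:
    a = pm.Uniform("a", 0, 1, initval=0.3)         # numeric, untransformed initval
    b = pm.HalfFlat("b", initval="moment")         # strategy string ("moment" or "prior")
    c = pm.Normal("c", mu=a + b, initval="prior")  # drawn from the prior on evaluation

# `Model.initial_values` is keyed by the RVs themselves, not by value variables:
assert isinstance(model.initial_values[a], np.ndarray)   # array(0.3)
assert model.initial_values[b] == "moment"

# Numeric, transformed initial points are only computed on demand, from a seed:
ip = model.recompute_initial_point(seed=42)
# -> keys like "a_interval__", "b_log__" and "c" map to numeric values

# Strategies can be swapped out after model creation and re-evaluated:
model.initial_values[a] = 0.9
ip2 = model.recompute_initial_point(seed=42)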
diff --git a/pymc/sampling.py b/pymc/sampling.py index e3d5c1ac7f..3c6da9f6f8 100644 --- a/pymc/sampling.py +++ b/pymc/sampling.py @@ -45,6 +45,12 @@ from pymc.blocking import DictToArrayBijection from pymc.distributions import NoDistribution from pymc.exceptions import IncorrectArgumentsError, SamplingError +from pymc.initial_point import ( + PointType, + StartDict, + filter_rvs_to_jitter, + make_initial_point_fns_per_chain, +) from pymc.model import Model, Point, modelcontext from pymc.parallel_sampling import Draw, _cpu_count from pymc.step_methods import ( @@ -93,7 +99,6 @@ Step = Union[BlockedStep, CompoundStep] ArrayLike = Union[np.ndarray, List[float]] -PointType = Dict[str, np.ndarray] PointList = List[PointType] Backend = Union[BaseTrace, MultiTrace, NDArray] @@ -253,7 +258,7 @@ def sample( step=None, init="auto", n_init=200_000, - initvals: Optional[Union[PointType, Sequence[Optional[PointType]]]] = None, + initvals: Optional[Union[StartDict, Sequence[Optional[StartDict]]]] = None, trace: Optional[Union[BaseTrace, List[str]]] = None, chain_idx=0, chains=None, @@ -292,7 +297,7 @@ def sample( n_init : int Number of iterations of initializer. Only works for 'ADVI' init methods. initvals : optional, dict, array of dict - Dict or list of dicts with initial values to use instead of the defaults from `Model.initial_values`. + Dict or list of dicts with initial value strategies to use instead of the defaults from `Model.initial_values`. The keys should be names of transformed random variables. Initialization methods for NUTS (see ``init`` keyword) can overwrite the default. trace : backend or list @@ -420,7 +425,7 @@ def sample( if initvals is not None: raise ValueError("Passing both `start` and `initvals` is not supported.") warnings.warn( - "The `start` kwarg was renamed to `initvals`. Please check the docstring.", + "The `start` kwarg was renamed to `initvals` and can now do more. Please check the docstring.", FutureWarning, stacklevel=2, ) @@ -482,7 +487,7 @@ def sample( chains=chains, n_init=n_init, model=model, - random_seed=random_seed, + seeds=random_seed, progressbar=progressbar, jitter_max_retries=jitter_max_retries, tune=tune, @@ -501,15 +506,14 @@ def sample( step = CompoundStep(step) if initial_points is None: - initvals = initvals or {} - if isinstance(initvals, dict): - initvals = [initvals] * chains - initial_points = [] - mip = model.initial_point - for ivals in initvals: - ivals = deepcopy(ivals) - model.update_start_vals(ivals, mip) - initial_points.append(ivals) + # Time to draw/evaluate numeric start points for each chain. + ipfns = make_initial_point_fns_per_chain( + model=model, + overrides=initvals, + jitter_rvs=filter_rvs_to_jitter(step), + chains=chains, + ) + initial_points = [ipfn(seed) for ipfn, seed in zip(ipfns, random_seed)] # One final check that shapes and logps at the starting points are okay. for ip in initial_points: @@ -1234,7 +1238,7 @@ def _prepare_iter_population( raise ValueError("Argument `draws` should be above 0.") # The initialization of traces, samplers and points must happen in the right order: - # 1. traces are initialized and update_start_vals configures variable transforms + # 1. traces are initialized # 2. population of points is created # 3. steppers are initialized and linked to the points object # 4. traces are configured to track the sampler stats @@ -1245,7 +1249,7 @@ def _prepare_iter_population( # 2. 
create a population (points) that tracks each chain # it is updated as the chains are advanced - population = [Point(start[c], model=model) for c in range(nchains)] + population = [start[c] for c in range(nchains)] # 3. Set up the steppers steppers: List[Step] = [] @@ -1983,7 +1987,13 @@ def sample_prior_predictive( return prior -def _init_jitter(model, point, chains, jitter_max_retries): +def _init_jitter( + model: Model, + initvals: Optional[Union[StartDict, Sequence[Optional[StartDict]]]], + seeds: Sequence[int], + jitter: bool, + jitter_max_retries: int, +) -> PointType: """Apply a uniform jitter in [-1, 1] to the test value as starting point in each chain. ``model.check_start_vals`` is used to test whether the jittered starting @@ -1993,9 +2003,8 @@ def _init_jitter(model, point, chains, jitter_max_retries): Parameters ---------- - model : pymc.Model - point : dict - chains : int + jitter: bool + Whether to apply jitter or not. jitter_max_retries : int Maximum number of repeated attempts at initializing values (per chain). @@ -2004,36 +2013,45 @@ def _init_jitter(model, point, chains, jitter_max_retries): start : ``pymc.model.Point`` Starting point for sampler """ - start = [] - for _ in range(chains): - for i in range(jitter_max_retries + 1): - mean = {var: val.copy() for var, val in point.items()} - for val in mean.values(): - val[...] += 2 * np.random.rand(*val.shape) - 1 + ipfns = make_initial_point_fns_per_chain( + model=model, + overrides=initvals, + jitter_rvs=set(model.free_RVs) if jitter else {}, + chains=len(seeds), + ) + + if not jitter: + return [ipfn(seed) for ipfn, seed in zip(ipfns, seeds)] + + initial_points = [] + for ipfn, seed in zip(ipfns, seeds): + rng = np.random.RandomState(seed) + for i in range(jitter_max_retries + 1): + point = ipfn(seed) if i < jitter_max_retries: try: - model.check_start_vals(mean) + model.check_start_vals(point) except SamplingError: - pass + # Retry with a new seed + seed = rng.randint(2 ** 30, dtype=np.int64) else: break - - start.append(mean) - return start + initial_points.append(point) + return initial_points def init_nuts( + *, init="auto", chains=1, - n_init=500000, + n_init=500_000, model=None, - random_seed=None, + seeds: Sequence[int] = None, progressbar=True, jitter_max_retries=10, tune=None, - *, - initvals: Optional[Union[PointType, Sequence[Optional[PointType]]]] = None, + initvals: Optional[Union[StartDict, Sequence[Optional[StartDict]]]] = None, **kwargs, ) -> Tuple[Sequence[PointType], NUTS]: """Set up the mass matrix initialization for NUTS. @@ -2076,6 +2094,8 @@ def init_nuts( n_init : int Number of iterations of initializer. Only works for 'ADVI' init methods. model : Model (optional if in ``with`` context) + seeds : list + Seed values for each chain. progressbar : bool Whether or not to display a progressbar for advi sampling. jitter_max_retries : int @@ -2109,35 +2129,45 @@ def init_nuts( if init == "auto": init = "jitter+adapt_diag" - _log.info(f"Initializing NUTS using {init}...") + if seeds is None: + seeds = model.rng_seeder.randint(2 ** 30, dtype=np.int64, size=chains) + if not isinstance(seeds, (list, tuple, np.ndarray)): + raise ValueError(f"The `seeds` must be array-like. Got {type(seeds)} instead.") + if len(seeds) != chains: + raise ValueError( + f"Number of seeds ({len(seeds)}) does not match the number of chains ({chains})." 
+ ) - if random_seed is not None: - random_seed = int(np.atleast_1d(random_seed)[0]) + _log.info(f"Initializing NUTS using {init}...") cb = [ pm.callbacks.CheckParametersConvergence(tolerance=1e-2, diff="absolute"), pm.callbacks.CheckParametersConvergence(tolerance=1e-2, diff="relative"), ] - # TODO: Consider `initvals` for selecting the starting point. + initial_points = _init_jitter( + model, + initvals, + seeds=seeds, + jitter="jitter" in init, + jitter_max_retries=jitter_max_retries, + ) - apoint = DictToArrayBijection.map(model.initial_point) + apoints = [DictToArrayBijection.map(point) for point in initial_points] + apoints_data = [apoint.data for apoint in apoints] if init == "adapt_diag": - start = [model.initial_point] * chains - mean = np.mean([apoint.data] * chains, axis=0) + mean = np.mean(apoints_data, axis=0) var = np.ones_like(mean) n = len(var) potential = quadpotential.QuadPotentialDiagAdapt(n, mean, var, 10) elif init == "jitter+adapt_diag": - start = _init_jitter(model, model.initial_point, chains, jitter_max_retries) - mean = np.mean([DictToArrayBijection.map(vals).data for vals in start], axis=0) + mean = np.mean(apoints_data, axis=0) var = np.ones_like(mean) n = len(var) potential = quadpotential.QuadPotentialDiagAdapt(n, mean, var, 10) elif init == "jitter+adapt_diag_grad": - start = _init_jitter(model, model.initial_point, chains, jitter_max_retries) - mean = np.mean([DictToArrayBijection.map(vals).data for vals in start], axis=0) + mean = np.mean(apoints_data, axis=0) var = np.ones_like(mean) n = len(var) @@ -2155,7 +2185,7 @@ def init_nuts( ) elif init == "advi+adapt_diag": approx = pm.fit( - random_seed=random_seed, + random_seed=seeds[0], n=n_init, method="advi", model=model, @@ -2163,8 +2193,7 @@ def init_nuts( progressbar=progressbar, obj_optimizer=pm.adagrad_window, ) - start = approx.sample(draws=chains) - start = list(start) + initial_points = list(approx.sample(draws=chains)) std_apoint = approx.std.eval() cov = std_apoint ** 2 mean = approx.mean.get_value() @@ -2173,7 +2202,7 @@ def init_nuts( potential = quadpotential.QuadPotentialDiagAdapt(n, mean, cov, weight) elif init == "advi": approx = pm.fit( - random_seed=random_seed, + random_seed=seeds[0], n=n_init, method="advi", model=model, @@ -2181,41 +2210,37 @@ def init_nuts( progressbar=progressbar, obj_optimizer=pm.adagrad_window, ) - start = approx.sample(draws=chains) - start = list(start) + initial_points = list(approx.sample(draws=chains)) cov = approx.std.eval() ** 2 potential = quadpotential.QuadPotentialDiag(cov) elif init == "advi_map": start = pm.find_MAP(include_transformed=True) approx = pm.MeanField(model=model, start=start) pm.fit( - random_seed=random_seed, + random_seed=seeds[0], n=n_init, method=pm.KLqp(approx), callbacks=cb, progressbar=progressbar, obj_optimizer=pm.adagrad_window, ) - start = approx.sample(draws=chains) - start = list(start) + initial_points = list(approx.sample(draws=chains)) cov = approx.std.eval() ** 2 potential = quadpotential.QuadPotentialDiag(cov) elif init == "map": start = pm.find_MAP(include_transformed=True) cov = pm.find_hessian(point=start) - start = [start] * chains + initial_points = [start] * chains potential = quadpotential.QuadPotentialFull(cov) elif init == "adapt_full": - initial_point = model.initial_point - start = [initial_point] * chains - mean = np.mean([apoint.data] * chains, axis=0) + mean = np.mean(apoints_data * chains, axis=0) + initial_point = initial_points[0] initial_point_model_size = sum(initial_point[n.name].size for n in 
model.value_vars) cov = np.eye(initial_point_model_size) potential = quadpotential.QuadPotentialFullAdapt(initial_point_model_size, mean, cov, 10) elif init == "jitter+adapt_full": - initial_point = model.initial_point - start = _init_jitter(model, initial_point, chains, jitter_max_retries) - mean = np.mean([DictToArrayBijection.map(vals).data for vals in start], axis=0) + mean = np.mean(apoints_data, axis=0) + initial_point = initial_points[0] initial_point_model_size = sum(initial_point[n.name].size for n in model.value_vars) cov = np.eye(initial_point_model_size) potential = quadpotential.QuadPotentialFullAdapt(initial_point_model_size, mean, cov, 10) @@ -2224,25 +2249,4 @@ def init_nuts( step = pm.NUTS(potential=potential, model=model, **kwargs) - # The "start" dict determined from initialization methods does not always respect the support of variables. - # The next block combines it with the user-provided initvals such that initvals take priority. - if initvals is None or isinstance(initvals, dict): - initvals = [initvals or {}] * chains - if isinstance(start, dict): - start = [start] * chains - mip = model.initial_point - initial_points = [] - for st, iv in zip(start, initvals): - from_init = deepcopy(st) - model.update_start_vals(from_init, mip) - - from_user = deepcopy(iv) - model.update_start_vals(from_user, mip) - - initial_points.append( - { - **from_init, - **from_user, # prioritize user-provided - } - ) return initial_points, step diff --git a/pymc/tests/test_distributions.py b/pymc/tests/test_distributions.py index 75dffe586d..486c5f3677 100644 --- a/pymc/tests/test_distributions.py +++ b/pymc/tests/test_distributions.py @@ -1007,7 +1007,6 @@ def test_flat(self): self.check_logp(Flat, Runif, {}, lambda value: 0) with Model(): x = Flat("a") - assert_allclose(x.tag.test_value, 0) self.check_logcdf(Flat, R, {}, lambda value: np.log(0.5)) # Check infinite cases individually. assert 0.0 == logcdf(Flat.dist(), np.inf).eval() @@ -1017,8 +1016,6 @@ def test_half_flat(self): self.check_logp(HalfFlat, Rplus, {}, lambda value: 0) with Model(): x = HalfFlat("a", size=2) - assert_allclose(x.tag.test_value, 1) - assert x.tag.test_value.shape == (2,) self.check_logcdf(HalfFlat, Rplus, {}, lambda value: -np.inf) # Check infinite cases individually. 
assert 0.0 == logcdf(HalfFlat.dist(), np.inf).eval() @@ -3232,9 +3229,12 @@ def test_serialize_density_dist(): def func(x): return -2 * (x ** 2).sum() + def random(rng, size): + return rng.uniform(-2, 2, size=size) + with pm.Model(): pm.Normal("x") - y = pm.DensityDist("y", logp=func) + y = pm.DensityDist("y", logp=func, random=random) pm.sample(draws=5, tune=1, mp_ctx="spawn") import cloudpickle diff --git a/pymc/tests/test_distributions_random.py b/pymc/tests/test_distributions_random.py index b6990c6a38..119d8f6202 100644 --- a/pymc/tests/test_distributions_random.py +++ b/pymc/tests/test_distributions_random.py @@ -1861,8 +1861,9 @@ def test_density_dist_without_random(self): mu, logp=lambda value, mu: pm.Normal.logp(value, mu, 1), observed=np.random.randn(100), + initval=0, ) - idata = pm.sample(100, cores=1) + idata = pm.sample(tune=50, draws=100, cores=1, step=pm.Metropolis()) samples = 500 with pytest.raises(NotImplementedError): diff --git a/pymc/tests/test_initial_point.py b/pymc/tests/test_initial_point.py new file mode 100644 index 0000000000..918e37a710 --- /dev/null +++ b/pymc/tests/test_initial_point.py @@ -0,0 +1,220 @@ +# Copyright 2020 The PyMC Developers +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import aesara +import aesara.tensor as at +import numpy as np +import pytest + +import pymc as pm + +from pymc.distributions.distribution import get_moment +from pymc.initial_point import make_initial_point_fn, make_initial_point_fns_per_chain + + +def transform_fwd(rv, expected_untransformed): + return rv.tag.value_var.tag.transform.forward(rv, expected_untransformed).eval() + + +def transform_back(rv, transformed) -> np.ndarray: + return rv.tag.value_var.tag.transform.backward(rv, transformed).eval() + + +class TestInitvalAssignment: + def test_dist_warnings_and_errors(self): + with pytest.warns(DeprecationWarning, match="argument is deprecated and has no effect"): + rv = pm.Exponential.dist(lam=1, testval=0.5) + assert not hasattr(rv.tag, "test_value") + + with pytest.raises(TypeError, match="Unexpected keyword argument `initval`."): + pm.Normal.dist(1, 2, initval=None) + pass + + def test_new_warnings(self): + with pm.Model() as pmodel: + with pytest.warns(DeprecationWarning, match="`testval` argument is deprecated"): + rv = pm.Uniform("u", 0, 1, testval=0.75) + initial_point = pmodel.recompute_initial_point(seed=0) + assert initial_point["u_interval__"] == transform_fwd(rv, 0.75) + assert not hasattr(rv.tag, "test_value") + pass + + +class TestInitvalEvaluation: + def test_make_initial_point_fns_per_chain_checks_kwargs(self): + with pm.Model() as pmodel: + A = pm.Uniform("A", 0, 1, initval=0.5) + B = pm.Uniform("B", lower=A, upper=1.5, transform=None, initval="moment") + with pytest.raises(ValueError, match="Number of initval dicts"): + make_initial_point_fns_per_chain( + model=pmodel, + overrides=[{}, None], + jitter_rvs={}, + chains=1, + ) + pass + + def test_dependent_initvals(self): + with pm.Model() as pmodel: + L = pm.Uniform("L", 0, 1, initval=0.5) + B = 
pm.Uniform("B", lower=L, upper=2, initval=1.25) + ip = pmodel.recompute_initial_point(seed=0) + assert ip["L_interval__"] == 0 + assert ip["B_interval__"] == 0 + + # Modify initval of L and re-evaluate + pmodel.initial_values[L] = 0.9 + ip = pmodel.recompute_initial_point(seed=0) + assert ip["B_interval__"] < 0 + pass + + def test_initval_resizing(self): + with pm.Model() as pmodel: + data = aesara.shared(np.arange(4)) + rv = pm.Uniform("u", lower=data, upper=10, initval="prior") + + ip = pmodel.recompute_initial_point(seed=0) + assert np.shape(ip["u_interval__"]) == (4,) + + data.set_value(np.arange(5)) + ip = pmodel.recompute_initial_point(seed=0) + assert np.shape(ip["u_interval__"]) == (5,) + pass + + def test_seeding(self): + with pm.Model() as pmodel: + pm.Normal("A", initval="prior") + pm.Uniform("B", initval="prior") + pm.Normal("C", initval="moment") + ip1 = pmodel.recompute_initial_point(seed=42) + ip2 = pmodel.recompute_initial_point(seed=42) + ip3 = pmodel.recompute_initial_point(seed=15) + assert ip1 == ip2 + assert ip3 != ip2 + pass + + def test_untransformed_initial_point(self): + with pm.Model() as pmodel: + pm.Flat("A", initval="moment") + pm.HalfFlat("B", initval="moment") + fn = make_initial_point_fn(model=pmodel, jitter_rvs={}, return_transformed=False) + iv = fn(0) + assert iv["A"] == 0 + assert iv["B"] == 1 + pass + + def test_adds_jitter(self): + with pm.Model() as pmodel: + A = pm.Flat("A", initval="moment") + B = pm.HalfFlat("B", initval="moment") + C = pm.Normal("C", mu=A + B, initval="moment") + fn = make_initial_point_fn(model=pmodel, jitter_rvs={B}, return_transformed=True) + iv = fn(0) + # Moment of the Flat is 0 + assert iv["A"] == 0 + # Moment of the HalfFlat is 1, but HalfFlat is log-transformed by default + # so the transformed initial value with jitter will be zero plus a jitter between [-1, 1]. + b_transformed = iv["B_log__"] + b_untransformed = transform_back(B, b_transformed) + assert b_transformed != 0 + assert -1 < b_transformed < 1 + # C is centered on 0 + untransformed initval of B + assert np.isclose(iv["C"], np.array(0 + b_untransformed, dtype=aesara.config.floatX)) + # Test jitter respects seeding. 
+ assert fn(0) == fn(0) + assert fn(0) != fn(1) + + def test_respects_overrides(self): + with pm.Model() as pmodel: + A = pm.Flat("A", initval="moment") + B = pm.HalfFlat("B", initval=4) + C = pm.Normal("C", mu=A + B, initval="moment") + fn = make_initial_point_fn( + model=pmodel, + jitter_rvs={}, + return_transformed=True, + overrides={ + A: at.as_tensor(2, dtype=int), + B: 3, + C: 5, + }, + ) + iv = fn(0) + assert iv["A"] == 2 + assert np.isclose(iv["B_log__"], np.log(3)) + assert iv["C"] == 5 + + def test_string_overrides_work(self): + with pm.Model() as pmodel: + A = pm.Flat("A", initval=10) + B = pm.HalfFlat("B", initval=10) + C = pm.HalfFlat("C", initval=10) + + fn = make_initial_point_fn( + model=pmodel, + jitter_rvs={}, + return_transformed=True, + overrides={ + "A": 1, + "B": 1, + "C_log__": 0, + }, + ) + iv = fn(0) + assert iv["A"] == 1 + assert np.isclose(iv["B_log__"], 0) + assert iv["C_log__"] == 0 + + +class TestMoment: + def test_basic(self): + # Standard distributions + rv = pm.Normal.dist(mu=2.3) + np.testing.assert_allclose(get_moment(rv).eval(), 2.3) + + # Special distributions + rv = pm.Flat.dist() + assert get_moment(rv).eval() == np.zeros(()) + rv = pm.HalfFlat.dist() + assert get_moment(rv).eval() == np.ones(()) + rv = pm.Flat.dist(size=(2, 4)) + assert np.all(get_moment(rv).eval() == np.zeros((2, 4))) + rv = pm.HalfFlat.dist(size=(2, 4)) + assert np.all(get_moment(rv).eval() == np.ones((2, 4))) + + @pytest.mark.parametrize("rv_cls", [pm.Flat, pm.HalfFlat]) + def test_numeric_moment_shape(self, rv_cls): + rv = rv_cls.dist(shape=(2,)) + assert not hasattr(rv.tag, "test_value") + assert tuple(get_moment(rv).shape.eval()) == (2,) + + @pytest.mark.parametrize("rv_cls", [pm.Flat, pm.HalfFlat]) + def test_symbolic_moment_shape(self, rv_cls): + s = at.scalar() + rv = rv_cls.dist(shape=(s,)) + assert not hasattr(rv.tag, "test_value") + assert tuple(get_moment(rv).shape.eval({s: 4})) == (4,) + pass + + @pytest.mark.parametrize("rv_cls", [pm.Flat, pm.HalfFlat]) + def test_moment_from_dims(self, rv_cls): + with pm.Model( + coords={ + "year": [2019, 2020, 2021, 2022], + "city": ["Bonn", "Paris", "Lisbon"], + } + ): + rv = rv_cls("rv", dims=("year", "city")) + assert not hasattr(rv.tag, "test_value") + assert tuple(get_moment(rv).shape.eval()) == (4, 3) + pass diff --git a/pymc/tests/test_initvals.py b/pymc/tests/test_initvals.py deleted file mode 100644 index 6b4ef717a4..0000000000 --- a/pymc/tests/test_initvals.py +++ /dev/null @@ -1,141 +0,0 @@ -# Copyright 2020 The PyMC Developers -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
-import aesara.tensor as at -import numpy as np -import pytest - -import pymc as pm - -from pymc.distributions.distribution import get_moment - - -def transform_fwd(rv, expected_untransformed): - return rv.tag.value_var.tag.transform.forward(rv, expected_untransformed).eval() - - -class TestInitvalAssignment: - def test_dist_warnings_and_errors(self): - with pytest.warns(DeprecationWarning, match="argument is deprecated and has no effect"): - rv = pm.Exponential.dist(lam=1, testval=0.5) - assert not hasattr(rv.tag, "test_value") - - with pytest.raises(TypeError, match="Unexpected keyword argument `initval`."): - pm.Normal.dist(1, 2, initval=None) - pass - - def test_new_warnings(self): - with pm.Model() as pmodel: - with pytest.warns(DeprecationWarning, match="`testval` argument is deprecated"): - rv = pm.Uniform("u", 0, 1, testval=0.75) - assert pmodel.initial_values[rv.tag.value_var] == transform_fwd(rv, 0.75) - assert not hasattr(rv.tag, "test_value") - pass - - -class TestInitvalEvaluation: - def test_random_draws(self): - pmodel = pm.Model() - rv = pm.Uniform.dist(lower=1, upper=2) - iv = pmodel._eval_initval( - rv_var=rv, - initval=None, - test_value=None, - transform=None, - ) - assert isinstance(iv, np.ndarray) - assert 1 <= iv <= 2 - pass - - def test_applies_transform(self): - pmodel = pm.Model() - rv = pm.Uniform.dist() - tf = pm.Uniform.default_transform() - iv = pmodel._eval_initval( - rv_var=rv, - initval=0.5, - test_value=None, - transform=tf, - ) - assert isinstance(iv, np.ndarray) - assert iv == 0 - pass - - def test_falls_back_to_test_value(self): - pmodel = pm.Model() - rv = pm.Flat.dist() - iv = pmodel._eval_initval( - rv_var=rv, - initval=None, - test_value=0.6, - transform=None, - ) - assert iv == 0.6 - pass - - -class TestSpecialDistributions: - def test_automatically_assigned_test_values(self): - # ...because they don't have random number generators. 
- rv = pm.Flat.dist() - assert hasattr(rv.tag, "test_value") - rv = pm.HalfFlat.dist() - assert hasattr(rv.tag, "test_value") - pass - - -class TestMoment: - def test_basic(self): - # Standard distributions - rv = pm.Normal.dist(mu=2.3) - np.testing.assert_allclose(get_moment(rv).eval(), 2.3) - - # Special distributions - rv = pm.Flat.dist() - assert get_moment(rv).eval() == np.zeros(()) - rv = pm.HalfFlat.dist() - assert get_moment(rv).eval() == np.ones(()) - rv = pm.Flat.dist(size=(2, 4)) - assert np.all(get_moment(rv).eval() == np.zeros((2, 4))) - rv = pm.HalfFlat.dist(size=(2, 4)) - assert np.all(get_moment(rv).eval() == np.ones((2, 4))) - - @pytest.mark.xfail(reason="Test values are still used for initvals.") - @pytest.mark.parametrize("rv_cls", [pm.Flat, pm.HalfFlat]) - def test_numeric_moment_shape(self, rv_cls): - rv = rv_cls.dist(shape=(2,)) - assert not hasattr(rv.tag, "test_value") - assert tuple(get_moment(rv).shape.eval()) == (2,) - - @pytest.mark.xfail(reason="Test values are still used for initvals.") - @pytest.mark.parametrize("rv_cls", [pm.Flat, pm.HalfFlat]) - def test_symbolic_moment_shape(self, rv_cls): - s = at.scalar() - rv = rv_cls.dist(shape=(s,)) - assert not hasattr(rv.tag, "test_value") - assert tuple(get_moment(rv).shape.eval({s: 4})) == (4,) - pass - - @pytest.mark.xfail(reason="Test values are still used for initvals.") - @pytest.mark.parametrize("rv_cls", [pm.Flat, pm.HalfFlat]) - def test_moment_from_dims(self, rv_cls): - with pm.Model( - coords={ - "year": [2019, 2020, 2021, 2022], - "city": ["Bonn", "Paris", "Lisbon"], - } - ): - rv = rv_cls("rv", dims=("year", "city")) - assert not hasattr(rv.tag, "test_value") - assert tuple(get_moment(rv).shape.eval()) == (4, 3) - pass diff --git a/pymc/tests/test_model.py b/pymc/tests/test_model.py index 16c88e18fa..a5b1bf1487 100644 --- a/pymc/tests/test_model.py +++ b/pymc/tests/test_model.py @@ -513,10 +513,11 @@ def test_initial_point(): with model: y = pm.Normal("y", initval=y_initval) - assert model.rvs_to_values[a] in model.initial_values - assert model.rvs_to_values[x] in model.initial_values - assert model.initial_values[b_value_var] == b_initval_trans - assert model.initial_values[model.rvs_to_values[y]] == y_initval + assert a in model.initial_values + assert x in model.initial_values + assert model.initial_values[b] == b_initval + assert model.recompute_initial_point(0)["b_interval__"] == b_initval_trans + assert model.initial_values[y] == y_initval def test_point_logps(): @@ -532,68 +533,6 @@ def test_point_logps(): assert "a" in logp_vals.keys() -class TestUpdateStartVals(SeededTest): - def setup_method(self): - super().setup_method() - - def test_soft_update_all_present(self): - model = pm.Model() - start = {"a": 1, "b": 2} - test_point = {"a": 3, "b": 4} - model.update_start_vals(start, test_point) - assert start == {"a": 1, "b": 2} - - def test_soft_update_one_missing(self): - model = pm.Model() - start = { - "a": 1, - } - test_point = {"a": 3, "b": 4} - model.update_start_vals(start, test_point) - assert start == {"a": 1, "b": 4} - - def test_soft_update_empty(self): - model = pm.Model() - start = {} - test_point = {"a": 3, "b": 4} - model.update_start_vals(start, test_point) - assert start == test_point - - def test_soft_update_transformed(self): - with pm.Model() as model: - pm.Exponential("a", 1) - start = {"a": 2.0} - test_point = {"a_log__": 0} - model.update_start_vals(start, test_point) - assert_almost_equal(np.exp(start["a_log__"]), start["a"]) - - def test_soft_update_parent(self): - with 
pm.Model() as model: - a = pm.Uniform("a", lower=0.0, upper=1.0) - b = pm.Uniform("b", lower=2.0, upper=3.0) - pm.Uniform("lower", lower=a, upper=3.0) - pm.Uniform("upper", lower=0.0, upper=b) - pm.Uniform("interv", lower=a, upper=b) - - initial_point = { - "a_interval__": np.array(0.0, dtype=aesara.config.floatX), - "b_interval__": np.array(0.0, dtype=aesara.config.floatX), - "lower_interval__": np.array(0.0, dtype=aesara.config.floatX), - "upper_interval__": np.array(0.0, dtype=aesara.config.floatX), - "interv_interval__": np.array(0.0, dtype=aesara.config.floatX), - } - start = {"a": 0.3, "b": 2.1, "lower": 1.4, "upper": 1.4, "interv": 1.4} - test_point = { - "lower_interval__": -0.3746934494414109, - "upper_interval__": 0.693147180559945, - "interv_interval__": 0.4519851237430569, - } - model.update_start_vals(start, initial_point) - assert_almost_equal(start["lower_interval__"], test_point["lower_interval__"]) - assert_almost_equal(start["upper_interval__"], test_point["upper_interval__"]) - assert_almost_equal(start["interv_interval__"], test_point["interv_interval__"]) - - class TestShapeEvaluation: def test_eval_rv_shapes(self): with pm.Model( @@ -625,8 +564,10 @@ def test_valid_start_point(self): a = pm.Uniform("a", lower=0.0, upper=1.0) b = pm.Uniform("b", lower=2.0, upper=3.0) - start = {"a": 0.3, "b": 2.1} - model.update_start_vals(start, model.initial_point) + start = { + "a_interval__": model.rvs_to_values[a].tag.transform.forward(a, 0.3).eval(), + "b_interval__": model.rvs_to_values[b].tag.transform.forward(b, 2.1).eval(), + } model.check_start_vals(start) def test_invalid_start_point(self): @@ -634,8 +575,10 @@ def test_invalid_start_point(self): a = pm.Uniform("a", lower=0.0, upper=1.0) b = pm.Uniform("b", lower=2.0, upper=3.0) - start = {"a": np.nan, "b": np.nan} - model.update_start_vals(start, model.initial_point) + start = { + "a_interval__": np.nan, + "b_interval__": model.rvs_to_values[b].tag.transform.forward(b, 2.1).eval(), + } with pytest.raises(pm.exceptions.SamplingError): model.check_start_vals(start) @@ -644,8 +587,11 @@ def test_invalid_variable_name(self): a = pm.Uniform("a", lower=0.0, upper=1.0) b = pm.Uniform("b", lower=2.0, upper=3.0) - start = {"a": 0.3, "b": 2.1, "c": 1.0} - model.update_start_vals(start, model.initial_point) + start = { + "a_interval__": model.rvs_to_values[a].tag.transform.forward(a, 0.3).eval(), + "b_interval__": model.rvs_to_values[b].tag.transform.forward(b, 2.1).eval(), + "c": 1.0, + } with pytest.raises(KeyError): model.check_start_vals(start) @@ -661,9 +607,9 @@ def test_set_initval(): alpha = pm.HalfNormal("alpha", initval=100) value = pm.NegativeBinomial("value", mu=mu, alpha=alpha) - assert np.array_equal(model.initial_values[model.rvs_to_values[mu]], np.array([[100.0]])) - np.testing.assert_almost_equal(model.initial_values[model.rvs_to_values[alpha]], np.log(100)) - assert 50 < model.initial_values[model.rvs_to_values[value]] < 150 + assert np.array_equal(model.initial_values[mu], np.array([[100.0]])) + np.testing.assert_array_equal(model.initial_values[alpha], np.array(100)) + assert model.initial_values[value] is None # `Flat` cannot be sampled, so let's make sure that doesn't break initial # value computations @@ -671,7 +617,7 @@ def test_set_initval(): x = pm.Flat("x") y = pm.Normal("y", x, 1) - assert model.rvs_to_values[y] in model.initial_values + assert y in model.initial_values def test_datalogpt_multiple_shapes(): diff --git a/pymc/tests/test_ndarray_backend.py b/pymc/tests/test_ndarray_backend.py index 
30e1fafbcf..e3edbd1fe7 100644
--- a/pymc/tests/test_ndarray_backend.py
+++ b/pymc/tests/test_ndarray_backend.py
@@ -221,6 +221,9 @@ def setup_class(cls):
         with TestSaveLoad.model():
             cls.trace = pm.sample(return_inferencedata=False)

+    @pytest.mark.xfail(
+        reason="Needs aeppl integration due to unintentional model graph rewrite #5007."
+    )
     def test_save_new_model(self, tmpdir_factory):
         directory = str(tmpdir_factory.mktemp("data"))
         save_dir = pm.save_trace(self.trace, directory, overwrite=True)
@@ -239,6 +242,9 @@ def test_save_new_model(self, tmpdir_factory):

         assert (new_trace["w"] == new_trace_copy["w"]).all()

+    @pytest.mark.xfail(
+        reason="Needs aeppl integration due to unintentional model graph rewrite #5007."
+    )
     def test_save_and_load(self, tmpdir_factory):
         directory = str(tmpdir_factory.mktemp("data"))
         save_dir = pm.save_trace(self.trace, directory, overwrite=True)
@@ -256,11 +262,17 @@ def test_save_and_load(self, tmpdir_factory):
                 "Restored value of statistic %s does not match stored value" % stat
             )

+    @pytest.mark.xfail(
+        reason="Needs aeppl integration due to unintentional model graph rewrite #5007."
+    )
     def test_bad_load(self, tmpdir_factory):
         directory = str(tmpdir_factory.mktemp("data"))
         with pytest.raises(pm.TraceDirectoryError):
             pm.load_trace(directory, model=TestSaveLoad.model())

+    @pytest.mark.xfail(
+        reason="Needs aeppl integration due to unintentional model graph rewrite #5007."
+    )
     def test_sample_posterior_predictive(self, tmpdir_factory):
         directory = str(tmpdir_factory.mktemp("data"))
         save_dir = pm.save_trace(self.trace, directory, overwrite=True)
diff --git a/pymc/tests/test_sampling.py b/pymc/tests/test_sampling.py
index 132781815d..3265802d37 100644
--- a/pymc/tests/test_sampling.py
+++ b/pymc/tests/test_sampling.py
@@ -39,6 +39,19 @@
 from pymc.tests.models import simple_init


+class TestInitNuts(SeededTest):
+    def setup_method(self):
+        super().setup_method()
+        self.model, self.start, self.step, _ = simple_init()
+
+    def test_checks_seeds_kwarg(self):
+        with self.model:
+            with pytest.raises(ValueError, match="must be array-like"):
+                pm.sampling.init_nuts(seeds=1)
+            with pytest.raises(ValueError, match="Number of seeds"):
+                pm.sampling.init_nuts(chains=2, seeds=[1])
+
+
 class TestSample(SeededTest):
     def setup_method(self):
         super().setup_method()
@@ -160,7 +173,7 @@ def test_reset_tuning(self):
         with self.model:
             tune = 50
             chains = 2
-            start, step = pm.sampling.init_nuts(chains=chains)
+            start, step = pm.sampling.init_nuts(chains=chains, seeds=[1, 2])
             pm.sample(draws=2, tune=tune, chains=chains, step=step, start=start, cores=1)
             assert step.potential._n_samples == tune
             assert step.step_adapt._count == tune + 1
@@ -301,7 +314,6 @@ def test_exceptions(self):
             xvars = [t["mu"] for t in trace]


-@pytest.mark.xfail(reason="Lognormal not refactored for v4")
 def test_sample_find_MAP_does_not_modify_start():
     # see https://github.com/pymc-devs/pymc/pull/4458
     with pm.Model():
@@ -831,13 +843,13 @@ def check_exec_nuts_init(method):
         pm.Normal("a", mu=0, sigma=1, size=2)
         pm.HalfNormal("b", sigma=1)
     with model:
-        start, _ = pm.init_nuts(init=method, n_init=10)
+        start, _ = pm.init_nuts(init=method, n_init=10, seeds=[1])
         assert isinstance(start, list)
         assert len(start) == 1
         assert isinstance(start[0], dict)
         assert model.a.tag.value_var.name in start[0]
         assert model.b.tag.value_var.name in start[0]
-        start, _ = pm.init_nuts(init=method, n_init=10, chains=2)
+        start, _ = pm.init_nuts(init=method, n_init=10, chains=2, seeds=[1, 2])
         assert isinstance(start, list)
         assert len(start) == 2
         assert isinstance(start[0], dict)
@@ -873,6 +885,7 @@ def test_exec_nuts_init(method):
     check_exec_nuts_init(method)


+@pytest.mark.skip(reason="Test requires monkey patching of RandomGenerator")
 @pytest.mark.parametrize(
     "initval, jitter_max_retries, expectation",
     [
@@ -890,9 +903,13 @@ def test_init_jitter(initval, jitter_max_retries, expectation):
    with expectation:
-        # Starting value is negative (invalid) when np.random.rand returns 0 (jitter = -1)
+        # Starting value is negative (invalid) when Generator.uniform returns -1 (jitter = -1)
         # and positive (valid) when it returns 1 (jitter = 1)
-        with mock.patch("numpy.random.rand", side_effect=[0, 0, 0, 1, 0]):
+        with mock.patch("numpy.random.Generator.uniform", side_effect=[-1, -1, -1, 1, -1]):
             start = pm.sampling._init_jitter(
-                m, m.initial_point, chains=1, jitter_max_retries=jitter_max_retries
+                model=m,
+                initvals=None,
+                seeds=[1],
+                jitter=True,
+                jitter_max_retries=jitter_max_retries,
             )
             m.check_start_vals(start)

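
The new `TestInitNuts` class above pins down the calling convention that `init_nuts` now enforces: `seeds` must be array-like, with one entry per chain. A sketch of a conforming call (the model and seed values are illustrative):

    import pymc as pm

    with pm.Model():
        pm.Normal("x")
        # One seed per chain; a scalar, or a length that does not match
        # `chains`, now raises a ValueError.
        start, step = pm.init_nuts(init="adapt_diag", chains=2, seeds=[1, 2])
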
diff --git a/pymc/tests/test_tuning.py b/pymc/tests/test_tuning.py
index 9686e265e1..e8e37978ab 100644
--- a/pymc/tests/test_tuning.py
+++ b/pymc/tests/test_tuning.py
@@ -13,6 +13,7 @@
 # limitations under the License.

 import numpy as np
+import pytest

 from numpy import inf

@@ -33,17 +34,13 @@ def test_guess_scaling():
     assert all((a1 > 0) & (a1 < 1e200))


-def test_mle_jacobian():
+@pytest.mark.parametrize("bounded", [False, True])
+def test_mle_jacobian(bounded):
     """Test MAP / MLE estimation for distributions with flat priors."""
     truth = 10.0  # Simple normal model should give mu=10.0
-    rtol = 1e-5  # this rtol should work on both floatX precisions
+    rtol = 1e-4  # this rtol should work on both floatX precisions

-    start, model, _ = models.simple_normal(bounded_prior=False)
-    with model:
-        map_estimate = find_MAP(method="BFGS", model=model)
-    np.testing.assert_allclose(map_estimate["mu_i"], truth, rtol=rtol)
-
-    start, model, _ = models.simple_normal(bounded_prior=True)
+    start, model, _ = models.simple_normal(bounded_prior=bounded)
     with model:
         map_estimate = find_MAP(method="BFGS", model=model)
     np.testing.assert_allclose(map_estimate["mu_i"], truth, rtol=rtol)
diff --git a/pymc/tuning/starting.py b/pymc/tuning/starting.py
index 1035d9e0ad..3cd3613a35 100644
--- a/pymc/tuning/starting.py
+++ b/pymc/tuning/starting.py
@@ -17,9 +17,10 @@

 @author: johnsalvatier
 """
-import copy
 import sys

+from typing import Optional
+
 import aesara.gradient as tg
 import numpy as np

@@ -31,7 +32,8 @@

 from pymc.aesaraf import inputvars
 from pymc.blocking import DictToArrayBijection, RaveledVars
-from pymc.model import Point, modelcontext
+from pymc.initial_point import make_initial_point_fn
+from pymc.model import modelcontext
 from pymc.util import get_default_varnames, get_var_name
 from pymc.vartypes import discrete_types, typefilter
@@ -48,6 +50,7 @@ def find_MAP(
     maxeval=5000,
     model=None,
     *args,
+    seed: Optional[int] = None,
     **kwargs
 ):
     """Finds the local maximum a posteriori point given a model.
@@ -95,15 +98,17 @@ def find_MAP(
     vars = inputvars(vars)
     disc_vars = list(typefilter(vars, discrete_types))
     allinmodel(vars, model)
-    start = copy.deepcopy(start)
-    if start is None:
-        start = model.initial_point
-    else:
-        model.update_start_vals(start, model.initial_point)
+    ipfn = make_initial_point_fn(
+        model=model,
+        jitter_rvs={},
+        return_transformed=True,
+        overrides=start,
+    )
+    if seed is None:
+        seed = model.rng_seeder.randint(2 ** 30, dtype=np.int64)
+    start = ipfn(seed)
     model.check_start_vals(start)

-    start = Point(start, model=model)
-
     x0 = DictToArrayBijection.map(start)

     # TODO: If the mapping is fixed, we can simply create graphs for the
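
The `find_MAP` hunk above replaces the deep-copied `start` dict with a seeded initial-point function. The same helper can be driven directly; a sketch using the signature shown in the hunk (the model, the override key, and its constrained-scale semantics are assumptions):

    import pymc as pm

    from pymc.initial_point import make_initial_point_fn

    with pm.Model() as model:
        pm.Exponential("a", 1.0)

    # `overrides` plays the role of the user-supplied `start` dict; the
    # returned function maps an integer seed to a transformed start point,
    # e.g. {"a_log__": log(2.0)}.
    ipfn = make_initial_point_fn(
        model=model,
        jitter_rvs={},
        return_transformed=True,
        overrides={"a": 2.0},
    )
    start = ipfn(123)
    model.check_start_vals(start)
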
diff --git a/pymc/variational/approximations.py b/pymc/variational/approximations.py
index 5392e63133..54900d1ec8 100644
--- a/pymc/variational/approximations.py
+++ b/pymc/variational/approximations.py
@@ -23,6 +23,7 @@
 from pymc.blocking import DictToArrayBijection
 from pymc.distributions.dist_math import rho2sigma
+from pymc.initial_point import make_initial_point_fn
 from pymc.math import batched_diag
 from pymc.variational import flows, opvi
 from pymc.variational.opvi import Approximation, Group, node_property

@@ -69,12 +70,13 @@ def __init_group__(self, group):
         self._finalize_init()

     def create_shared_params(self, start=None):
-        if start is None:
-            start = self.model.initial_point
-        else:
-            start_ = start.copy()
-            self.model.update_start_vals(start_, self.model.initial_point)
-            start = start_
+        ipfn = make_initial_point_fn(
+            model=self.model,
+            overrides=start,
+            jitter_rvs={},
+            return_transformed=True,
+        )
+        start = ipfn(self.model.rng_seeder.randint(2 ** 30, dtype=np.int64))
         if self.batched:
             start = start[self.group[0].name][0]
         else:
@@ -124,12 +126,13 @@ def __init_group__(self, group):
         self._finalize_init()

     def create_shared_params(self, start=None):
-        if start is None:
-            start = self.model.initial_point
-        else:
-            start_ = start.copy()
-            self.model.update_start_vals(start_, self.model.initial_point)
-            start = start_
+        ipfn = make_initial_point_fn(
+            model=self.model,
+            overrides=start,
+            jitter_rvs={},
+            return_transformed=True,
+        )
+        start = ipfn(self.model.rng_seeder.randint(2 ** 30, dtype=np.int64))
         if self.batched:
             start = start[self.group[0].name][0]
         else:
@@ -238,12 +241,13 @@ def create_shared_params(self, trace=None, size=None, jitter=1, start=None):
             if size is None:
                 raise opvi.ParametrizationError("Need `trace` or `size` to initialize")
             else:
-                if start is None:
-                    start = self.model.initial_point
-                else:
-                    start_ = self.model.initial_point.copy()
-                    self.model.update_start_vals(start_, start)
-                    start = start_
+                ipfn = make_initial_point_fn(
+                    model=self.model,
+                    overrides=start,
+                    jitter_rvs={},
+                    return_transformed=True,
+                )
+                start = ipfn(self.model.rng_seeder.randint(2 ** 30, dtype=np.int64))
                 start = pm.floatX(DictToArrayBijection.map(start))
                 # Initialize particles
                 histogram = np.tile(start, (size, 1))
diff --git a/requirements-dev.txt b/requirements-dev.txt
index d7691ddb7c..0f898473d2 100644
--- a/requirements-dev.txt
+++ b/requirements-dev.txt
@@ -1,7 +1,7 @@
 # This file is auto-generated by scripts/generate_pip_deps_from_conda.py, do not modify.
 # See that file for comments about the need/usage of each dependency.
-aesara>=2.1.0
+aesara>=2.2.2
 arviz>=0.11.4
 cachetools>=4.2.1
 cloudpickle
diff --git a/requirements.txt b/requirements.txt
index 28c9e4b3b2..87066f3532 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -1,4 +1,4 @@
-aesara>=2.1.0
+aesara>=2.2.2
 arviz>=0.11.4
 cachetools>=4.2.1
 cloudpickle