diff --git a/pymc/tests/test_variational_inference.py b/pymc/tests/test_variational_inference.py
index c5b8a80cf4..7dbed2ab48 100644
--- a/pymc/tests/test_variational_inference.py
+++ b/pymc/tests/test_variational_inference.py
@@ -571,6 +571,28 @@ def test_fit_oo(inference, fit_kwargs, simple_model_data):
     np.testing.assert_allclose(np.std(trace.posterior["mu"]), np.sqrt(1.0 / d), rtol=0.2)
 
 
+def test_fit_start(inference_spec, simple_model):
+    mu_init = 17
+    mu_sigma_init = 13
+
+    with simple_model:
+        if type(inference_spec()) == ADVI:
+            has_start_sigma = True
+        else:
+            has_start_sigma = False
+
+    kw = {"start": {"mu": mu_init}}
+    if has_start_sigma:
+        kw.update({"start_sigma": {"mu": mu_sigma_init}})
+
+    with simple_model:
+        inference = inference_spec(**kw)
+        trace = inference.fit(n=0).sample(10000)
+    np.testing.assert_allclose(np.mean(trace.posterior["mu"]), mu_init, rtol=0.05)
+    if has_start_sigma:
+        np.testing.assert_allclose(np.std(trace.posterior["mu"]), mu_sigma_init, rtol=0.05)
+
+
 def test_profile(inference):
     inference.run_profiling(n=100).summary()
 
diff --git a/pymc/variational/approximations.py b/pymc/variational/approximations.py
index c0b9c7e3d6..cace8f03b7 100644
--- a/pymc/variational/approximations.py
+++ b/pymc/variational/approximations.py
@@ -67,12 +67,27 @@ def std(self):
     def __init_group__(self, group):
         super().__init_group__(group)
         if not self._check_user_params():
-            self.shared_params = self.create_shared_params(self._kwargs.get("start", None))
+            self.shared_params = self.create_shared_params(
+                self._kwargs.get("start", None), self._kwargs.get("start_sigma", None)
+            )
         self._finalize_init()
 
-    def create_shared_params(self, start=None):
+    def create_shared_params(self, start=None, start_sigma=None):
+        # NOTE: `Group._prepare_start` uses `self.model.free_RVs` to identify free variables and
+        # `DictToArrayBijection` to turn them into a flat array, while `Approximation.rslice` assumes
+        # that the free variables are given by `self.group` and that the mapping between original
+        # variables and flat array is given by `self.ordering`. In the cases I looked into these turn
+        # out to be the same, but there may be edge cases or future code changes that break this assumption.
         start = self._prepare_start(start)
-        rho = np.zeros((self.ddim,))
+        rho1 = np.zeros((self.ddim,))
+
+        if start_sigma is not None:
+            for name, slice_, *_ in self.ordering.values():
+                sigma = start_sigma.get(name)
+                if sigma is not None:
+                    rho1[slice_] = np.log(np.expm1(np.abs(sigma)))
+        rho = rho1
+
         return {
             "mu": aesara.shared(pm.floatX(start), "mu"),
             "rho": aesara.shared(pm.floatX(rho), "rho"),
diff --git a/pymc/variational/inference.py b/pymc/variational/inference.py
index 88b5e7b744..e932df1e57 100644
--- a/pymc/variational/inference.py
+++ b/pymc/variational/inference.py
@@ -257,7 +257,9 @@ def _infmean(input_array):
                     )
                 )
         else:
-            if n < 10:
+            if n == 0:
+                logger.info("Initialization only")
+            elif n < 10:
                 logger.info(f"Finished [100%]: Loss = {scores[-1]:,.5g}")
             else:
                 avg_loss = _infmean(scores[max(0, i - 1000) : i + 1])
@@ -433,8 +435,10 @@ class ADVI(KLqp):
     random_seed: None or int
         leave None to use package global RandomStream or other
         valid value to create instance specific one
-    start: `Point`
+    start: `dict[str, np.ndarray]` or `StartDict`
         starting point for inference
+    start_sigma: `dict[str, np.ndarray]`
+        starting standard deviation for inference, only available for method 'advi'
 
     References
     ----------
@@ -464,7 +468,7 @@ class FullRankADVI(KLqp):
     random_seed: None or int
         leave None to use package global RandomStream or other
         valid value to create instance specific one
-    start: `Point`
+    start: `dict[str, np.ndarray]` or `StartDict`
         starting point for inference
 
     References
@@ -532,13 +536,11 @@ class SVGD(ImplicitGradient):
         kernel function for KSD :math:`f(histogram) -> (k(x,.), \nabla_x k(x,.))`
     temperature: float
         parameter responsible for exploration, higher temperature gives more broad posterior estimate
-    start: `dict`
+    start: `dict[str, np.ndarray]` or `StartDict`
         initial point for inference
     random_seed: None or int
         leave None to use package global RandomStream or other
         valid value to create instance specific one
-    start: `Point`
-        starting point for inference
     kwargs: other keyword arguments passed to estimator
 
     References
@@ -629,7 +631,11 @@ def __init__(self, approx=None, estimator=KSD, kernel=test_functions.rbf, **kwar
                 "is often **underestimated** when using temperature = 1."
             )
         if approx is None:
-            approx = FullRank(model=kwargs.pop("model", None))
+            approx = FullRank(
+                model=kwargs.pop("model", None),
+                random_seed=kwargs.pop("random_seed", None),
+                start=kwargs.pop("start", None),
+            )
         super().__init__(estimator=estimator, approx=approx, kernel=kernel, **kwargs)
 
     def fit(
@@ -660,6 +666,7 @@ def fit(
     model=None,
     random_seed=None,
     start=None,
+    start_sigma=None,
     inf_kwargs=None,
     **kwargs,
 ):
@@ -684,8 +691,10 @@ def fit(
         valid value to create instance specific one
     inf_kwargs: dict
         additional kwargs passed to :class:`Inference`
-    start: `Point`
+    start: `dict[str, np.ndarray]` or `StartDict`
         starting point for inference
+    start_sigma: `dict[str, np.ndarray]`
+        starting standard deviation for inference, only available for method 'advi'
 
     Other Parameters
     ----------------
@@ -728,6 +737,10 @@ def fit(
         inf_kwargs["random_seed"] = random_seed
     if start is not None:
         inf_kwargs["start"] = start
+    if start_sigma is not None:
+        if method != "advi":
+            raise NotImplementedError("start_sigma is only available for method advi")
+        inf_kwargs["start_sigma"] = start_sigma
     if model is None:
         model = pm.modelcontext(model)
     _select = dict(advi=ADVI, fullrank_advi=FullRankADVI, svgd=SVGD, asvgd=ASVGD)
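For reviewers, a minimal usage sketch of the API this diff adds. The toy model, variable name, and values below are illustrative only, not taken from the PR; what the diff itself implements is the wiring of `start`/`start_sigma` through `pm.fit(method="advi", ...)`, with `rho` initialized to the inverse softplus `log(expm1(|sigma|))` so the approximation starts at exactly the requested standard deviation.

```python
import pymc as pm

# Illustrative toy model (not from the PR's test suite). Per the diff,
# `start_sigma` is only honored for method="advi" and raises
# NotImplementedError for the other methods.
with pm.Model():
    pm.Normal("mu", mu=0.0, sigma=10.0)

    approx = pm.fit(
        n=10_000,
        method="advi",
        start={"mu": 17},        # initial mean of the approximate posterior
        start_sigma={"mu": 13},  # initial std; stored as rho = log(expm1(13))
    )
```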