diff --git a/pymc/tests/test_variational_inference.py b/pymc/tests/test_variational_inference.py
index 7dbed2ab4..c5b8a80cf 100644
--- a/pymc/tests/test_variational_inference.py
+++ b/pymc/tests/test_variational_inference.py
@@ -571,28 +571,6 @@ def test_fit_oo(inference, fit_kwargs, simple_model_data):
     np.testing.assert_allclose(np.std(trace.posterior["mu"]), np.sqrt(1.0 / d), rtol=0.2)
 
 
-def test_fit_start(inference_spec, simple_model):
-    mu_init = 17
-    mu_sigma_init = 13
-
-    with simple_model:
-        if type(inference_spec()) == ADVI:
-            has_start_sigma = True
-        else:
-            has_start_sigma = False
-
-    kw = {"start": {"mu": mu_init}}
-    if has_start_sigma:
-        kw.update({"start_sigma": {"mu": mu_sigma_init}})
-
-    with simple_model:
-        inference = inference_spec(**kw)
-    trace = inference.fit(n=0).sample(10000)
-    np.testing.assert_allclose(np.mean(trace.posterior["mu"]), mu_init, rtol=0.05)
-    if has_start_sigma:
-        np.testing.assert_allclose(np.std(trace.posterior["mu"]), mu_sigma_init, rtol=0.05)
-
-
 def test_profile(inference):
     inference.run_profiling(n=100).summary()
 
diff --git a/pymc/variational/approximations.py b/pymc/variational/approximations.py
index cace8f03b..c0b9c7e3d 100644
--- a/pymc/variational/approximations.py
+++ b/pymc/variational/approximations.py
@@ -67,27 +67,12 @@ def std(self):
     def __init_group__(self, group):
         super().__init_group__(group)
         if not self._check_user_params():
-            self.shared_params = self.create_shared_params(
-                self._kwargs.get("start", None), self._kwargs.get("start_sigma", None)
-            )
+            self.shared_params = self.create_shared_params(self._kwargs.get("start", None))
         self._finalize_init()
 
-    def create_shared_params(self, start=None, start_sigma=None):
-        # NOTE: `Group._prepare_start` uses `self.model.free_RVs` to identify free variables and
-        # `DictToArrayBijection` to turn them into a flat array, while `Approximation.rslice` assumes that the free
-        # variables are given by `self.group` and that the mapping between original variables and flat array is given
-        # by `self.ordering`. In the cases I looked into these turn out to be the same, but there may be edge cases or
-        # future code changes that break this assumption.
+    def create_shared_params(self, start=None):
         start = self._prepare_start(start)
-        rho1 = np.zeros((self.ddim,))
-
-        if start_sigma is not None:
-            for name, slice_, *_ in self.ordering.values():
-                sigma = start_sigma.get(name)
-                if sigma is not None:
-                    rho1[slice_] = np.log(np.expm1(np.abs(sigma)))
-        rho = rho1
-
+        rho = np.zeros((self.ddim,))
         return {
             "mu": aesara.shared(pm.floatX(start), "mu"),
             "rho": aesara.shared(pm.floatX(rho), "rho"),
diff --git a/pymc/variational/inference.py b/pymc/variational/inference.py
index e932df1e5..88b5e7b74 100644
--- a/pymc/variational/inference.py
+++ b/pymc/variational/inference.py
@@ -257,9 +257,7 @@ def _infmean(input_array):
                 )
             )
         else:
-            if n == 0:
-                logger.info(f"Initialization only")
-            elif n < 10:
+            if n < 10:
                 logger.info(f"Finished [100%]: Loss = {scores[-1]:,.5g}")
             else:
                 avg_loss = _infmean(scores[max(0, i - 1000) : i + 1])
@@ -435,10 +433,8 @@ class ADVI(KLqp):
     random_seed: None or int
         leave None to use package global RandomStream or other
         valid value to create instance specific one
-    start: `dict[str, np.ndarray]` or `StartDict`
+    start: `Point`
         starting point for inference
-    start_sigma: `dict[str, np.ndarray]`
-        starting standard deviation for inference, only available for method 'advi'
 
     References
     ----------
@@ -468,7 +464,7 @@ class FullRankADVI(KLqp):
     random_seed: None or int
         leave None to use package global RandomStream or other
         valid value to create instance specific one
-    start: `dict[str, np.ndarray]` or `StartDict`
+    start: `Point`
         starting point for inference
 
     References
@@ -536,11 +532,11 @@ class SVGD(ImplicitGradient):
         kernel function for KSD :math:`f(histogram) -> (k(x,.), \nabla_x k(x,.))`
     temperature: float
         parameter responsible for exploration, higher temperature gives more broad posterior estimate
-    start: `dict[str, np.ndarray]` or `StartDict`
+    start: `dict`
         initial point for inference
     random_seed: None or int
         leave None to use package global RandomStream or other
         valid value to create instance specific one
     kwargs: other keyword arguments passed to estimator
 
     References
@@ -631,11 +627,7 @@ def __init__(self, approx=None, estimator=KSD, kernel=test_functions.rbf, **kwargs):
                 "is often **underestimated** when using temperature = 1."
            )
         if approx is None:
-            approx = FullRank(
-                model=kwargs.pop("model", None),
-                random_seed=kwargs.pop("random_seed", None),
-                start=kwargs.pop("start", None),
-            )
+            approx = FullRank(model=kwargs.pop("model", None))
         super().__init__(estimator=estimator, approx=approx, kernel=kernel, **kwargs)
 
     def fit(
@@ -666,7 +658,6 @@ def fit(
     model=None,
     random_seed=None,
     start=None,
-    start_sigma=None,
     inf_kwargs=None,
     **kwargs,
 ):
@@ -691,10 +682,8 @@ def fit(
         valid value to create instance specific one
     inf_kwargs: dict
         additional kwargs passed to :class:`Inference`
-    start: `dict[str, np.ndarray]` or `StartDict`
+    start: `Point`
         starting point for inference
-    start_sigma: `dict[str, np.ndarray]`
-        starting standard deviation for inference, only available for method 'advi'
 
     Other Parameters
     ----------------
@@ -737,10 +726,6 @@ def fit(
         inf_kwargs["random_seed"] = random_seed
     if start is not None:
         inf_kwargs["start"] = start
-    if start_sigma is not None:
-        if method != "advi":
-            raise NotImplementedError("start_sigma is only available for method advi")
-        inf_kwargs["start_sigma"] = start_sigma
     if model is None:
         model = pm.modelcontext(model)
     _select = dict(advi=ADVI, fullrank_advi=FullRankADVI, svgd=SVGD, asvgd=ASVGD)
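
Note on the `start_sigma` machinery removed above: the mean-field approximation stores its posterior scale through the unconstrained shared parameter `rho`, with `sigma = log(1 + exp(rho))` (softplus; see the `std` property that appears as hunk context in `approximations.py`). Seeding a target standard deviation therefore means applying the inverse softplus, which is exactly what the deleted line `rho1[slice_] = np.log(np.expm1(np.abs(sigma)))` computed. A minimal round-trip sketch of that transform, in plain NumPy; the helper names are illustrative and not part of PyMC's API:

    import numpy as np

    def softplus(rho):
        # forward map: unconstrained rho -> positive standard deviation
        return np.log1p(np.exp(rho))

    def inv_softplus(sigma):
        # inverse map, as in the removed start_sigma code:
        # rho = log(expm1(|sigma|))
        return np.log(np.expm1(np.abs(sigma)))

    target_sigma = 13.0                # the value used by the deleted test_fit_start
    rho = inv_softplus(target_sigma)   # what was written into the rho slice
    np.testing.assert_allclose(softplus(rho), target_sigma)

The absolute value mirrors the removed code, which accepted negative sigma values rather than raising on them.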