diff --git a/pymc/tests/variational/test_inference.py b/pymc/tests/variational/test_inference.py
index 238b0ea091..8c29760fb0 100644
--- a/pymc/tests/variational/test_inference.py
+++ b/pymc/tests/variational/test_inference.py
@@ -350,3 +350,81 @@ def test_clear_cache():
     assert any(len(c) != 0 for c in inference_new.approx._cache.values())
     inference_new.approx._cache.clear()
     assert all(len(c) == 0 for c in inference_new.approx._cache.values())
+
+
+def test_fit_data(inference, fit_kwargs, simple_model_data):
+    fitted = inference.fit(**fit_kwargs)
+    mu_post = simple_model_data["mu_post"]
+    d = simple_model_data["d"]
+    np.testing.assert_allclose(fitted.mean_data["mu"].values, mu_post, rtol=0.05)
+    np.testing.assert_allclose(fitted.std_data["mu"], np.sqrt(1.0 / d), rtol=0.2)
+
+
+@pytest.fixture
+def hierarchical_model_data():
+    group_coords = {
+        "group_d1": np.arange(3),
+        "group_d2": np.arange(7),
+    }
+    group_shape = tuple(len(d) for d in group_coords.values())
+    data_coords = {"data_d": np.arange(11), **group_coords}
+
+    data_shape = tuple(len(d) for d in data_coords.values())
+
+    mu = -5.0
+
+    sigma_group_mu = 3
+    group_mu = sigma_group_mu * np.random.randn(*group_shape)
+
+    sigma = 3.0
+
+    data = sigma * np.random.randn(*data_shape) + group_mu + mu
+
+    return dict(
+        group_coords=group_coords,
+        group_shape=group_shape,
+        data_coords=data_coords,
+        data_shape=data_shape,
+        mu=mu,
+        sigma_group_mu=sigma_group_mu,
+        sigma=sigma,
+        group_mu=group_mu,
+        data=data,
+    )
+
+
+@pytest.fixture
+def hierarchical_model(hierarchical_model_data):
+    with pm.Model(coords=hierarchical_model_data["data_coords"]) as model:
+        mu = pm.Normal("mu", mu=0, sigma=10)
+        sigma_group_mu = pm.HalfNormal("sigma_group_mu", sigma=5)
+
+        group_mu = pm.Normal(
+            "group_mu",
+            mu=0,
+            sigma=sigma_group_mu,
+            dims=list(hierarchical_model_data["group_coords"].keys()),
+        )
+
+        sigma = pm.HalfNormal("sigma", sigma=3)
+
+        pm.Normal(
+            "data",
+            mu=(mu + group_mu),
+            sigma=sigma,
+            observed=hierarchical_model_data["data"],
+        )
+    return model
+
+
+def test_fit_data_coords(hierarchical_model, hierarchical_model_data):
+    with hierarchical_model:
+        fitted = pm.fit(1)
+
+    for data in [fitted.mean_data, fitted.std_data]:
+        assert set(data.keys()) == {"sigma_group_mu_log__", "sigma_log__", "group_mu", "mu"}
+        assert data["group_mu"].shape == hierarchical_model_data["group_shape"]
+        assert list(data["group_mu"].coords.keys()) == list(
+            hierarchical_model_data["group_coords"].keys()
+        )
+        assert data["mu"].shape == tuple()
diff --git a/pymc/variational/opvi.py b/pymc/variational/opvi.py
index 80733bf1e5..c2f69ec9ec 100644
--- a/pymc/variational/opvi.py
+++ b/pymc/variational/opvi.py
@@ -54,6 +54,7 @@
 import numpy as np
 import pytensor
 import pytensor.tensor as at
+import xarray
 
 from pytensor.graph.basic import Variable
 
@@ -1105,16 +1106,47 @@ def __str__(self):
         return f"{self.__class__.__name__}[{shp}]"
 
     @node_property
-    def std(self):
-        raise NotImplementedError
+    def std(self) -> at.TensorVariable:
+        """Standard deviation of the latent variables as an unstructured 1-dimensional tensor variable"""
+        raise NotImplementedError()
 
     @node_property
-    def cov(self):
-        raise NotImplementedError
+    def cov(self) -> at.TensorVariable:
+        """Covariance between the latent variables as an unstructured 2-dimensional tensor variable"""
+        raise NotImplementedError()
 
     @node_property
-    def mean(self):
-        raise NotImplementedError
+    def mean(self) -> at.TensorVariable:
+        """Mean of the latent variables as an unstructured 1-dimensional tensor variable"""
+        raise NotImplementedError()
+
+    def var_to_data(self, shared: at.TensorVariable) -> xarray.Dataset:
+        """Takes a flat 1-dimensional tensor variable and maps it to an xarray data set based on the information in
+        `self.ordering`.
+        """
+        # This is somewhat similar to `DictToArrayBijection.rmap`, which doesn't work here since we don't have
+        # `RaveledVars` and need to take the information from `self.ordering` instead
+        shared_nda = shared.eval()
+        result = dict()
+        for name, s, shape, dtype in self.ordering.values():
+            dims = self.model.named_vars_to_dims.get(name, None)
+            if dims is not None:
+                coords = {d: np.array(self.model.coords[d]) for d in dims}
+            else:
+                coords = None
+            values = shared_nda[s].reshape(shape).astype(dtype)
+            result[name] = xarray.DataArray(values, coords=coords, dims=dims, name=name)
+        return xarray.Dataset(result)
+
+    @property
+    def mean_data(self) -> xarray.Dataset:
+        """Mean of the latent variables as an xarray Dataset"""
+        return self.var_to_data(self.mean)
+
+    @property
+    def std_data(self) -> xarray.Dataset:
+        """Standard deviation of the latent variables as an xarray Dataset"""
+        return self.var_to_data(self.std)
 
 
 group_for_params = Group.group_for_params