diff --git a/pymc/tests/test_variational_inference.py b/pymc/tests/test_variational_inference.py
index c5b8a80cf..1e5eb2f9e 100644
--- a/pymc/tests/test_variational_inference.py
+++ b/pymc/tests/test_variational_inference.py
@@ -499,6 +499,65 @@ def simple_model(simple_model_data):
     return model
 
 
+@pytest.fixture
+def hierarchical_model_data():
+    """Simulated observations and metadata for a hierarchical normal model with group dims."""
+    group_coords = {
+        "group_d1": np.arange(3),
+        "group_d2": np.arange(7),
+    }
+    group_shape = tuple(len(d) for d in group_coords.values())
+    data_coords = {"data_d": np.arange(11), **group_coords}
+
+    data_shape = tuple(len(d) for d in data_coords.values())
+
+    mu = -5.0
+
+    sigma_group_mu = 3
+    group_mu = sigma_group_mu * np.random.randn(*group_shape)
+
+    sigma = 3.0
+
+    data = sigma * np.random.randn(*data_shape) + group_mu + mu
+
+    return dict(
+        group_coords=group_coords,
+        group_shape=group_shape,
+        data_coords=data_coords,
+        data_shape=data_shape,
+        mu=mu,
+        sigma_group_mu=sigma_group_mu,
+        sigma=sigma,
+        group_mu=group_mu,
+        data=data,
+    )
+
+
+@pytest.fixture
+def hierarchical_model(hierarchical_model_data):
+    """Hierarchical normal model whose `group_mu` variable carries the group dims."""
+    with pm.Model(coords=hierarchical_model_data["data_coords"]) as model:
+        mu = pm.Normal("mu", mu=0, sigma=10)
+        sigma_group_mu = pm.HalfNormal("sigma_group_mu", sigma=5)
+
+        group_mu = pm.Normal(
+            "group_mu",
+            mu=0,
+            sigma=sigma_group_mu,
+            dims=list(hierarchical_model_data["group_coords"].keys()),
+        )
+
+        sigma = pm.HalfNormal("sigma", sigma=3)
+
+        pm.Normal(
+            "data",
+            mu=(mu + group_mu),
+            sigma=sigma,
+            observed=hierarchical_model_data["data"],
+        )
+    return model
+
+
 @pytest.fixture(
     scope="module",
     params=[
@@ -571,6 +630,29 @@ def test_fit_oo(inference, fit_kwargs, simple_model_data):
     np.testing.assert_allclose(np.std(trace.posterior["mu"]), np.sqrt(1.0 / d), rtol=0.2)
 
 
+def test_fit_data(inference, fit_kwargs, simple_model_data):
+    """Compare `mean_data`/`std_data` of `mu` with the values in `simple_model_data`."""
+    fitted = inference.fit(**fit_kwargs)
+    mu_post = simple_model_data["mu_post"]
+    d = simple_model_data["d"]
+    np.testing.assert_allclose(fitted.mean_data["mu"].values, mu_post, rtol=0.05)
+    np.testing.assert_allclose(fitted.std_data["mu"].values, np.sqrt(1.0 / d), rtol=0.2)
+
+
+def test_fit_data_coords(hierarchical_model, hierarchical_model_data):
+    """Check names, shapes and coordinates of `mean_data`/`std_data` for a model with dims."""
+    with hierarchical_model:
+        fitted = pm.fit(1)
+
+    for data in [fitted.mean_data, fitted.std_data]:
+        assert set(data.keys()) == {"sigma_group_mu_log__", "sigma_log__", "group_mu", "mu"}
+        assert data["group_mu"].shape == hierarchical_model_data["group_shape"]
+        assert list(data["group_mu"].coords.keys()) == list(
+            hierarchical_model_data["group_coords"].keys()
+        )
+        assert data["mu"].shape == ()
+
+
 def test_profile(inference):
     inference.run_profiling(n=100).summary()
 
diff --git a/pymc/variational/opvi.py b/pymc/variational/opvi.py
index 0e6cc7f65..b59f339f0 100644
--- a/pymc/variational/opvi.py
+++ b/pymc/variational/opvi.py
@@ -52,6 +52,7 @@
 import aesara
 import aesara.tensor as at
 import numpy as np
+import xarray
 
 from aesara.graph.basic import Variable
 
@@ -1109,16 +1110,68 @@ def __str__(self):
 
     @node_property
     def std(self):
+        """Standard deviation of the latent variables as an unstructured 1-dimensional Aesara variable"""
         raise NotImplementedError
 
     @node_property
     def cov(self):
+        """Covariance between the latent variables as an unstructured 2-dimensional Aesara variable"""
         raise NotImplementedError
 
     @node_property
     def mean(self):
+        """Mean of the latent variables as an unstructured 1-dimensional Aesara variable"""
         raise NotImplementedError
 
+    def var_to_data(self, shared):
+        """Map a flat 1-dimensional Aesara variable to an xarray Dataset.
+
+        The variable is evaluated and split into one `xarray.DataArray` per entry of
+        `self.ordering`, which supplies the name, slice, shape and dtype of each piece.
+        Dimension names are taken from `self.model.RV_dims` and coordinate values from
+        `self.model.coords`; dimensions whose coordinate values are `None` are passed to
+        xarray without coordinates.
+
+        Parameters
+        ----------
+        shared
+            Flat Aesara variable laid out according to `self.ordering`.
+
+        Returns
+        -------
+        xarray.Dataset
+            One data variable per entry of `self.ordering`.
+        """
+        # This is somewhat similar to `DictToArrayBijection.rmap`, which doesn't work here since we
+        # don't have `RaveledVars` and need to take the information from `self.ordering` instead
+        flat_values = shared.eval()
+        result = {}
+        for name, s, shape, dtype in self.ordering.values():
+            dims = self.model.RV_dims.get(name, None)
+            if dims is not None:
+                # Dimensions registered without coordinate values appear to be stored as `None`
+                # in `model.coords`; skip them so xarray is not handed a 0-d object array.
+                coords = {
+                    d: np.array(self.model.coords[d])
+                    for d in dims
+                    if self.model.coords[d] is not None
+                }
+            else:
+                coords = None
+            values = np.array(flat_values[s]).reshape(shape).astype(dtype)
+            result[name] = xarray.DataArray(values, coords=coords, dims=dims, name=name)
+        return xarray.Dataset(result)
+
+    @property
+    def mean_data(self):
+        """Mean of the latent variables as an xarray Dataset"""
+        return self.var_to_data(self.mean)
+
+    @property
+    def std_data(self):
+        """Standard deviation of the latent variables as an xarray Dataset"""
+        return self.var_to_data(self.std)
+
 
 group_for_params = Group.group_for_params
 group_for_short_name = Group.group_for_short_name