Skip to content

Commit 26048a4

Browse files
fonnesbeckmarkusschmausmichaelosthege
authored
Make VI (posterior) mean and std accessible as a structured xarray (#6387)
* make vi (posterior) mean and std accessible as a structured xarray data set * add doc strings * convert to original dtype * add docstring and comment * test `mean_data` and `std_data` for shape and coords * Moved tests to tests/variational * Formatting * Updated deprecated function; removed dict comparison operator * Removed bad dict operator in tests * Fixed type hint * Alternative type hint * Property return type hints Co-authored-by: Michael Osthege <[email protected]> * Property return type hints Co-authored-by: Michael Osthege <[email protected]> * Removed clobbering of shared reference Co-authored-by: Michael Osthege <[email protected]> * Add type hint Co-authored-by: Michael Osthege <[email protected]> --------- Co-authored-by: Markus Schmaus <[email protected]> Co-authored-by: Michael Osthege <[email protected]>
1 parent f2a7174 commit 26048a4

File tree

2 files changed

+117
-7
lines changed

2 files changed

+117
-7
lines changed

pymc/tests/variational/test_inference.py

Lines changed: 78 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -350,3 +350,81 @@ def test_clear_cache():
350350
assert any(len(c) != 0 for c in inference_new.approx._cache.values())
351351
inference_new.approx._cache.clear()
352352
assert all(len(c) == 0 for c in inference_new.approx._cache.values())
353+
354+
355+
def test_fit_data(inference, fit_kwargs, simple_model_data):
356+
fitted = inference.fit(**fit_kwargs)
357+
mu_post = simple_model_data["mu_post"]
358+
d = simple_model_data["d"]
359+
np.testing.assert_allclose(fitted.mean_data["mu"].values, mu_post, rtol=0.05)
360+
np.testing.assert_allclose(fitted.std_data["mu"], np.sqrt(1.0 / d), rtol=0.2)
361+
362+
363+
@pytest.fixture
364+
def hierarchical_model_data():
365+
group_coords = {
366+
"group_d1": np.arange(3),
367+
"group_d2": np.arange(7),
368+
}
369+
group_shape = tuple(len(d) for d in group_coords.values())
370+
data_coords = {"data_d": np.arange(11), **group_coords}
371+
372+
data_shape = tuple(len(d) for d in data_coords.values())
373+
374+
mu = -5.0
375+
376+
sigma_group_mu = 3
377+
group_mu = sigma_group_mu * np.random.randn(*group_shape)
378+
379+
sigma = 3.0
380+
381+
data = sigma * np.random.randn(*data_shape) + group_mu + mu
382+
383+
return dict(
384+
group_coords=group_coords,
385+
group_shape=group_shape,
386+
data_coords=data_coords,
387+
data_shape=data_shape,
388+
mu=mu,
389+
sigma_group_mu=sigma_group_mu,
390+
sigma=sigma,
391+
group_mu=group_mu,
392+
data=data,
393+
)
394+
395+
396+
@pytest.fixture
397+
def hierarchical_model(hierarchical_model_data):
398+
with pm.Model(coords=hierarchical_model_data["data_coords"]) as model:
399+
mu = pm.Normal("mu", mu=0, sigma=10)
400+
sigma_group_mu = pm.HalfNormal("sigma_group_mu", sigma=5)
401+
402+
group_mu = pm.Normal(
403+
"group_mu",
404+
mu=0,
405+
sigma=sigma_group_mu,
406+
dims=list(hierarchical_model_data["group_coords"].keys()),
407+
)
408+
409+
sigma = pm.HalfNormal("sigma", sigma=3)
410+
411+
pm.Normal(
412+
"data",
413+
mu=(mu + group_mu),
414+
sigma=sigma,
415+
observed=hierarchical_model_data["data"],
416+
)
417+
return model
418+
419+
420+
def test_fit_data_coords(hierarchical_model, hierarchical_model_data):
421+
with hierarchical_model:
422+
fitted = pm.fit(1)
423+
424+
for data in [fitted.mean_data, fitted.std_data]:
425+
assert set(data.keys()) == {"sigma_group_mu_log__", "sigma_log__", "group_mu", "mu"}
426+
assert data["group_mu"].shape == hierarchical_model_data["group_shape"]
427+
assert list(data["group_mu"].coords.keys()) == list(
428+
hierarchical_model_data["group_coords"].keys()
429+
)
430+
assert data["mu"].shape == tuple()

pymc/variational/opvi.py

Lines changed: 39 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -54,6 +54,7 @@
5454
import numpy as np
5555
import pytensor
5656
import pytensor.tensor as at
57+
import xarray
5758

5859
from pytensor.graph.basic import Variable
5960

@@ -977,7 +978,7 @@ def symbolic_random(self):
977978

978979
@pytensor.config.change_flags(compute_test_value="off")
979980
def set_size_and_deterministic(
980-
self, node: Variable, s, d: bool, more_replacements: dict | None = None
981+
self, node: Variable, s, d: bool, more_replacements: dict = None
981982
) -> list[Variable]:
982983
"""*Dev* - after node is sampled via :func:`symbolic_sample_over_posterior` or
983984
:func:`symbolic_single_sample` new random generator can be allocated and applied to node
@@ -1105,16 +1106,47 @@ def __str__(self):
11051106
return f"{self.__class__.__name__}[{shp}]"
11061107

11071108
@node_property
1108-
def std(self):
1109-
raise NotImplementedError
1109+
def std(self) -> at.TensorVariable:
1110+
"""Standard deviation of the latent variables as an unstructured 1-dimensional tensor variable"""
1111+
raise NotImplementedError()
11101112

11111113
@node_property
1112-
def cov(self):
1113-
raise NotImplementedError
1114+
def cov(self) -> at.TensorVariable:
1115+
"""Covariance between the latent variables as an unstructured 2-dimensional tensor variable"""
1116+
raise NotImplementedError()
11141117

11151118
@node_property
1116-
def mean(self):
1117-
raise NotImplementedError
1119+
def mean(self) -> at.TensorVariable:
1120+
"""Mean of the latent variables as an unstructured 1-dimensional tensor variable"""
1121+
raise NotImplementedError()
1122+
1123+
def var_to_data(self, shared: at.TensorVariable) -> xarray.Dataset:
1124+
"""Takes a flat 1-dimensional tensor variable and maps it to an xarray data set based on the information in
1125+
`self.ordering`.
1126+
"""
1127+
# This is somewhat similar to `DictToArrayBijection.rmap`, which doesn't work here since we don't have
1128+
# `RaveledVars` and need to take the information from `self.ordering` instead
1129+
shared_nda = shared.eval()
1130+
result = dict()
1131+
for name, s, shape, dtype in self.ordering.values():
1132+
dims = self.model.named_vars_to_dims.get(name, None)
1133+
if dims is not None:
1134+
coords = {d: np.array(self.model.coords[d]) for d in dims}
1135+
else:
1136+
coords = None
1137+
values = shared_nda[s].reshape(shape).astype(dtype)
1138+
result[name] = xarray.DataArray(values, coords=coords, dims=dims, name=name)
1139+
return xarray.Dataset(result)
1140+
1141+
@property
1142+
def mean_data(self) -> xarray.Dataset:
1143+
"""Mean of the latent variables as an xarray Dataset"""
1144+
return self.var_to_data(self.mean)
1145+
1146+
@property
1147+
def std_data(self) -> xarray.Dataset:
1148+
"""Standard deviation of the latent variables as an xarray Dataset"""
1149+
return self.var_to_data(self.std)
11181150

11191151

11201152
group_for_params = Group.group_for_params

0 commit comments

Comments
 (0)