Skip to content
Merged
Show file tree
Hide file tree
Changes from 15 commits
Commits
Show all changes
23 commits
Select commit Hold shift + click to select a range
c981754
make vi (posterior) mean and std accessible as a structured xarray da…
markusschmaus Aug 31, 2022
3c6af3a
add doc strings
markusschmaus Aug 31, 2022
6aa21cb
convert to original dtype
markusschmaus Sep 1, 2022
8c96be1
add docstring and comment
markusschmaus Sep 1, 2022
671cb9c
test `mean_data` and `std_data` for shape and coords
markusschmaus Sep 2, 2022
171798b
Fix merge conflicts
Dec 12, 2022
7d73b53
Merge branch 'markusschmaus-advi_mean_data'
Dec 12, 2022
39fc213
Merge remote-tracking branch 'upstream/main'
fonnesbeck Jan 2, 2023
3e15041
Moved tests to tests/variational
fonnesbeck Jan 2, 2023
e94bef9
Merge branch 'main' of https://github.com/fonnesbeck/pymc3
fonnesbeck Jan 6, 2023
5e46943
Merge remote-tracking branch 'upstream/main'
fonnesbeck Jan 6, 2023
fcb9f52
Formatting
fonnesbeck Jan 6, 2023
1977583
Merge remote-tracking branch 'upstream/main'
fonnesbeck Jan 7, 2023
99dd715
Merge remote-tracking branch 'upstream/main'
fonnesbeck Jan 17, 2023
8512301
Merge remote-tracking branch 'upstream/main'
fonnesbeck Jan 23, 2023
2f501fb
Updated deprecated function; removed dict comparison operator
fonnesbeck Jan 23, 2023
addd6c1
Removed bad dict operator in tests
fonnesbeck Jan 24, 2023
0ed4cea
Fixed type hint
fonnesbeck Jan 24, 2023
a2e634f
Alternative type hint
fonnesbeck Jan 24, 2023
a33986e
Property return type hints
fonnesbeck Jan 30, 2023
087af09
Property return type hints
fonnesbeck Jan 30, 2023
ab56fed
Removed clobbering of shared reference
fonnesbeck Jan 30, 2023
c174630
Add type hint
fonnesbeck Jan 30, 2023
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
78 changes: 78 additions & 0 deletions pymc/tests/variational/test_inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -350,3 +350,81 @@ def test_clear_cache():
assert any(len(c) != 0 for c in inference_new.approx._cache.values())
inference_new.approx._cache.clear()
assert all(len(c) == 0 for c in inference_new.approx._cache.values())


def test_fit_data(inference, fit_kwargs, simple_model_data):
    """Compare the fitted ``mean_data``/``std_data`` for ``mu`` with reference values.

    The mean must agree with ``simple_model_data["mu_post"]`` to within 5 %,
    and the standard deviation with ``sqrt(1 / d)`` to within 20 %.
    """
    approx = inference.fit(**fit_kwargs)
    expected_mean = simple_model_data["mu_post"]
    expected_std = np.sqrt(1.0 / simple_model_data["d"])
    np.testing.assert_allclose(approx.mean_data["mu"].values, expected_mean, rtol=0.05)
    np.testing.assert_allclose(approx.std_data["mu"], expected_std, rtol=0.2)


@pytest.fixture
def hierarchical_model_data():
    """Simulate data and coordinates for a hierarchical normal model.

    Returns
    -------
    dict
        ``group_coords`` / ``group_shape``: coordinates and shape ``(3, 7)``
        of the group-level means; ``data_coords`` / ``data_shape``:
        coordinates and shape ``(11, 3, 7)`` of the observations; the
        parameters used for the simulation (``mu``, ``sigma_group_mu``,
        ``sigma``); the drawn ``group_mu`` and the simulated ``data``.
    """
    group_coords = {
        "group_d1": np.arange(3),
        "group_d2": np.arange(7),
    }
    group_shape = tuple(len(d) for d in group_coords.values())
    # Dict unpacking rather than ``dict | dict``: the union operator only
    # exists on Python >= 3.9. "data_d" stays first so that it is the leading
    # axis of ``data_shape``.
    data_coords = {"data_d": np.arange(11), **group_coords}

    data_shape = tuple(len(d) for d in data_coords.values())

    mu = -5.0

    sigma_group_mu = 3
    group_mu = sigma_group_mu * np.random.randn(*group_shape)

    sigma = 3.0

    # ``group_mu`` (3, 7) broadcasts against the trailing axes of the
    # (11, 3, 7) noise array.
    data = sigma * np.random.randn(*data_shape) + group_mu + mu

    return dict(
        group_coords=group_coords,
        group_shape=group_shape,
        data_coords=data_coords,
        data_shape=data_shape,
        mu=mu,
        sigma_group_mu=sigma_group_mu,
        sigma=sigma,
        group_mu=group_mu,
        data=data,
    )


@pytest.fixture
def hierarchical_model(hierarchical_model_data):
    """Build the PyMC model matching ``hierarchical_model_data``.

    The model has a global mean ``mu``, group-level offsets ``group_mu``
    (dimensioned by the group coordinates) with scale ``sigma_group_mu``, and
    an observation scale ``sigma``; the simulated data is attached as the
    observed variable ``data``.
    """
    coords = hierarchical_model_data["data_coords"]
    group_dims = list(hierarchical_model_data["group_coords"])
    observed = hierarchical_model_data["data"]

    with pm.Model(coords=coords) as model:
        mu = pm.Normal("mu", mu=0, sigma=10)
        sigma_group_mu = pm.HalfNormal("sigma_group_mu", sigma=5)
        group_mu = pm.Normal("group_mu", mu=0, sigma=sigma_group_mu, dims=group_dims)
        sigma = pm.HalfNormal("sigma", sigma=3)
        pm.Normal("data", mu=(mu + group_mu), sigma=sigma, observed=observed)
    return model


def test_fit_data_coords(hierarchical_model, hierarchical_model_data):
    """Check variable names, shapes and coordinates of ``mean_data``/``std_data``.

    A single fitting step is enough because only the structure of the
    resulting datasets is inspected, not their values.
    """
    with hierarchical_model:
        approx = pm.fit(1)

    expected_names = {"sigma_group_mu_log__", "sigma_log__", "group_mu", "mu"}
    expected_group_shape = hierarchical_model_data["group_shape"]
    expected_group_dims = list(hierarchical_model_data["group_coords"].keys())

    for dataset in (approx.mean_data, approx.std_data):
        assert set(dataset.keys()) == expected_names
        group_mu = dataset["group_mu"]
        assert group_mu.shape == expected_group_shape
        assert list(group_mu.coords.keys()) == expected_group_dims
        assert dataset["mu"].shape == ()
32 changes: 32 additions & 0 deletions pymc/variational/opvi.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,7 @@
import numpy as np
import pytensor
import pytensor.tensor as at
import xarray

from pytensor.graph.basic import Variable

Expand Down Expand Up @@ -1106,16 +1107,47 @@ def __str__(self):

@node_property
def std(self):
    """Standard deviation of the latent variables as an unstructured 1-dimensional PyTensor variable.

    This definition only raises ``NotImplementedError``; presumably concrete
    subclasses override it (confirm against the implementations).
    """
    raise NotImplementedError

@node_property
def cov(self):
    """Covariance between the latent variables as an unstructured 2-dimensional PyTensor variable.

    This definition only raises ``NotImplementedError``; presumably concrete
    subclasses override it (confirm against the implementations).
    """
    raise NotImplementedError

@node_property
def mean(self):
    """Mean of the latent variables as an unstructured 1-dimensional PyTensor variable.

    This definition only raises ``NotImplementedError``; presumably concrete
    subclasses override it (confirm against the implementations).
    """
    raise NotImplementedError

def var_to_data(self, shared) -> "xarray.Dataset":
    """Map a flat 1-dimensional PyTensor variable to an xarray Dataset.

    The variable is evaluated once, then split into one array per entry of
    ``self.ordering``: each entry supplies the variable name, the slice into
    the flat vector, the target shape and the target dtype.

    Parameters
    ----------
    shared
        Flat symbolic variable; only its ``eval()`` method is used here.

    Returns
    -------
    xarray.Dataset
        One ``DataArray`` per entry of ``self.ordering``. Variables with an
        entry in ``self.model.RV_dims`` carry those dims and the matching
        coordinates from ``self.model.coords``.
    """
    # This is somewhat similar to `DictToArrayBijection.rmap`, which doesn't work here since we don't have
    # `RaveledVars` and need to take the information from `self.ordering` instead
    # Evaluate once into a separate name instead of rebinding the `shared` parameter.
    flat_values = np.asarray(shared.eval())
    result = {}
    for name, slc, shape, dtype in self.ordering.values():
        dims = self.model.RV_dims.get(name)
        if dims is None:
            # NOTE(review): without registered dims xarray falls back to its
            # default dimension names; two such variables of different length
            # may conflict when merged into the Dataset -- confirm.
            coords = None
        else:
            coords = {d: np.array(self.model.coords[d]) for d in dims}
        # `astype` already returns a new array, so no extra copy is needed.
        values = flat_values[slc].reshape(shape).astype(dtype)
        result[name] = xarray.DataArray(values, coords=coords, dims=dims, name=name)
    return xarray.Dataset(result)

@property
def mean_data(self):
    """Return ``self.mean`` converted to an xarray Dataset via ``var_to_data``."""
    flat_mean = self.mean
    return self.var_to_data(flat_mean)

@property
def std_data(self):
    """Return ``self.std`` converted to an xarray Dataset via ``var_to_data``."""
    flat_std = self.std
    return self.var_to_data(flat_std)


# Expose `Group.group_for_params` and `Group.group_for_short_name` under
# module-level names.
group_for_params = Group.group_for_params
group_for_short_name = Group.group_for_short_name
Expand Down