Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 11 additions & 1 deletion pymc/backends/arviz.py
Original file line number Diff line number Diff line change
Expand Up @@ -398,14 +398,24 @@ def constant_data_to_xarray(self):
if not constant_data:
return None

return dict_to_dataset(
xarray_dataset = dict_to_dataset(
constant_data,
library=pymc,
coords=self.coords,
dims=self.dims,
default_dims=[],
)

# provisional handling of scalars in constant
# data to prevent promotion to rank 1
# in the future this will be handled by arviz
scalars = [var_name for var_name, value in constant_data.items() if np.ndim(value) == 0]
for s in scalars:
s_dim_0_name = f"{s}_dim_0"
xarray_dataset = xarray_dataset.squeeze(s_dim_0_name, drop=True)

return xarray_dataset

def to_inference_data(self):
"""Convert all available data to an InferenceData object.

Expand Down
11 changes: 9 additions & 2 deletions tests/backends/test_arviz.py
Original file line number Diff line number Diff line change
Expand Up @@ -436,18 +436,25 @@ def test_constant_data(self, use_context):
with pm.Model() as model:
x = pm.ConstantData("x", [1.0, 2.0, 3.0])
y = pm.MutableData("y", [1.0, 2.0, 3.0])
beta = pm.Normal("beta", 0, 1)
beta_sigma = pm.MutableData("beta_sigma", 1)
beta = pm.Normal("beta", 0, beta_sigma)
obs = pm.Normal("obs", x * beta, 1, observed=y) # pylint: disable=unused-variable
trace = pm.sample(100, chains=2, tune=100, return_inferencedata=False)
if use_context:
inference_data = to_inference_data(trace=trace, log_likelihood=True)

if not use_context:
inference_data = to_inference_data(trace=trace, model=model, log_likelihood=True)
test_dict = {"posterior": ["beta"], "observed_data": ["obs"], "constant_data": ["x"]}
test_dict = {
"posterior": ["beta"],
"observed_data": ["obs"],
"constant_data": ["x", "y", "beta_sigma"],
}
fails = check_multiple_attrs(test_dict, inference_data)
assert not fails
assert inference_data.log_likelihood["obs"].shape == (2, 100, 3)
# test that scalars are dimensionless in constant_data (issue #6755)
assert inference_data.constant_data["beta_sigma"].ndim == 0

def test_predictions_constant_data(self):
with pm.Model():
Expand Down