diff --git a/pymc/backends/arviz.py b/pymc/backends/arviz.py index 0326ea57e0..bfd3008a3f 100644 --- a/pymc/backends/arviz.py +++ b/pymc/backends/arviz.py @@ -398,7 +398,7 @@ def constant_data_to_xarray(self): if not constant_data: return None - return dict_to_dataset( + xarray_dataset = dict_to_dataset( constant_data, library=pymc, coords=self.coords, @@ -406,6 +406,16 @@ def constant_data_to_xarray(self): default_dims=[], ) + # provisional handling of scalars in constant + # data to prevent promotion to rank 1 + # in the future this will be handled by arviz + scalars = [var_name for var_name, value in constant_data.items() if np.ndim(value) == 0] + for s in scalars: + s_dim_0_name = f"{s}_dim_0" + xarray_dataset = xarray_dataset.squeeze(s_dim_0_name, drop=True) + + return xarray_dataset + def to_inference_data(self): """Convert all available data to an InferenceData object. diff --git a/tests/backends/test_arviz.py b/tests/backends/test_arviz.py index 16b6bb9e86..a16bf43db5 100644 --- a/tests/backends/test_arviz.py +++ b/tests/backends/test_arviz.py @@ -436,7 +436,8 @@ def test_constant_data(self, use_context): with pm.Model() as model: x = pm.ConstantData("x", [1.0, 2.0, 3.0]) y = pm.MutableData("y", [1.0, 2.0, 3.0]) - beta = pm.Normal("beta", 0, 1) + beta_sigma = pm.MutableData("beta_sigma", 1) + beta = pm.Normal("beta", 0, beta_sigma) obs = pm.Normal("obs", x * beta, 1, observed=y) # pylint: disable=unused-variable trace = pm.sample(100, chains=2, tune=100, return_inferencedata=False) if use_context: @@ -444,10 +445,16 @@ def test_constant_data(self, use_context): if not use_context: inference_data = to_inference_data(trace=trace, model=model, log_likelihood=True) - test_dict = {"posterior": ["beta"], "observed_data": ["obs"], "constant_data": ["x"]} + test_dict = { + "posterior": ["beta"], + "observed_data": ["obs"], + "constant_data": ["x", "y", "beta_sigma"], + } fails = check_multiple_attrs(test_dict, inference_data) assert not fails assert inference_data.log_likelihood["obs"].shape == (2, 100, 3) + # test that scalars are dimensionless in constant_data (issue #6755) + assert inference_data.constant_data["beta_sigma"].ndim == 0 def test_predictions_constant_data(self): with pm.Model():