Skip to content

Commit 654e9f9

Browse files
added test
1 parent 1c61183 commit 654e9f9

File tree

1 file changed

+9
-2
lines changed

1 file changed

+9
-2
lines changed

tests/backends/test_arviz.py

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -436,18 +436,25 @@ def test_constant_data(self, use_context):
436436
with pm.Model() as model:
437437
x = pm.ConstantData("x", [1.0, 2.0, 3.0])
438438
y = pm.MutableData("y", [1.0, 2.0, 3.0])
439-
beta = pm.Normal("beta", 0, 1)
439+
beta_sigma = pm.MutableData("beta_sigma", 1)
440+
beta = pm.Normal("beta", 0, beta_sigma)
440441
obs = pm.Normal("obs", x * beta, 1, observed=y) # pylint: disable=unused-variable
441442
trace = pm.sample(100, chains=2, tune=100, return_inferencedata=False)
442443
if use_context:
443444
inference_data = to_inference_data(trace=trace, log_likelihood=True)
444445

445446
if not use_context:
446447
inference_data = to_inference_data(trace=trace, model=model, log_likelihood=True)
447-
test_dict = {"posterior": ["beta"], "observed_data": ["obs"], "constant_data": ["x"]}
448+
test_dict = {
449+
"posterior": ["beta"],
450+
"observed_data": ["obs"],
451+
"constant_data": ["x", "y", "beta_sigma"],
452+
}
448453
fails = check_multiple_attrs(test_dict, inference_data)
449454
assert not fails
450455
assert inference_data.log_likelihood["obs"].shape == (2, 100, 3)
456+
# test that scalars are dimensionless in constant_data (issue #6755)
457+
assert inference_data.constant_data["beta_sigma"].ndim == 0
451458

452459
def test_predictions_constant_data(self):
453460
with pm.Model():

0 commit comments

Comments
 (0)