@@ -436,18 +436,25 @@ def test_constant_data(self, use_context):
436
436
with pm .Model () as model :
437
437
x = pm .ConstantData ("x" , [1.0 , 2.0 , 3.0 ])
438
438
y = pm .MutableData ("y" , [1.0 , 2.0 , 3.0 ])
439
- beta = pm .Normal ("beta" , 0 , 1 )
439
+ beta_sigma = pm .MutableData ("beta_sigma" , 1 )
440
+ beta = pm .Normal ("beta" , 0 , beta_sigma )
440
441
obs = pm .Normal ("obs" , x * beta , 1 , observed = y ) # pylint: disable=unused-variable
441
442
trace = pm .sample (100 , chains = 2 , tune = 100 , return_inferencedata = False )
442
443
if use_context :
443
444
inference_data = to_inference_data (trace = trace , log_likelihood = True )
444
445
445
446
if not use_context :
446
447
inference_data = to_inference_data (trace = trace , model = model , log_likelihood = True )
447
- test_dict = {"posterior" : ["beta" ], "observed_data" : ["obs" ], "constant_data" : ["x" ]}
448
+ test_dict = {
449
+ "posterior" : ["beta" ],
450
+ "observed_data" : ["obs" ],
451
+ "constant_data" : ["x" , "y" , "beta_sigma" ],
452
+ }
448
453
fails = check_multiple_attrs (test_dict , inference_data )
449
454
assert not fails
450
455
assert inference_data .log_likelihood ["obs" ].shape == (2 , 100 , 3 )
456
+ # test that scalars are dimensionless in constant_data (issue #6755)
457
+ assert inference_data .constant_data ["beta_sigma" ].ndim == 0
451
458
452
459
def test_predictions_constant_data (self ):
453
460
with pm .Model ():
0 commit comments