@@ -547,19 +547,27 @@ def test_batched_size(self, explicit_shape, batched_param):
547
547
initial_vol = 2.5 ,
548
548
)
549
549
kwargs0 = init_kwargs .copy ()
550
- kwargs0 [arg_name ] = init_kwargs [arg_name ] * param_val
550
+ kwargs0 [batched_param ] = init_kwargs [batched_param ] * param_val
551
+ if explicit_shape :
552
+ kwargs0 ["shape" ] = (batch_size , steps )
553
+ else :
554
+ kwargs0 ["steps" ] = steps - 1
551
555
with Model () as t0 :
552
- y = GARCH11 ("y" , shape = ( batch_size , steps ), ** kwargs0 )
556
+ y = GARCH11 ("y" , ** kwargs0 )
553
557
554
558
y_eval = draw (y , draws = 2 )
555
559
assert y_eval [0 ].shape == (batch_size , steps )
556
560
assert not np .any (np .isclose (y_eval [0 ], y_eval [1 ]))
557
561
558
562
kwargs1 = init_kwargs .copy ()
563
+ if explicit_shape :
564
+ kwargs1 ["shape" ] = steps
565
+ else :
566
+ kwargs1 ["steps" ] = steps - 1
559
567
with Model () as t1 :
560
568
for i in range (batch_size ):
561
- kwargs1 [arg_name ] = init_kwargs [arg_name ] * param_val [i ]
562
- GARCH11 (f"y_{ i } " , shape = steps , ** kwargs1 )
569
+ kwargs1 [batched_param ] = init_kwargs [batched_param ] * param_val [i ]
570
+ GARCH11 (f"y_{ i } " , ** kwargs1 )
563
571
564
572
np .testing .assert_allclose (
565
573
t0 .compile_logp ()(t0 .initial_point ()),
@@ -584,7 +592,7 @@ def test_moment(self, size, expected):
584
592
steps = 7 ,
585
593
size = size ,
586
594
)
587
- assert_moment_is_expected (model , expected , check_finite_logp = False )
595
+ assert_moment_is_expected (model , expected , check_finite_logp = True )
588
596
589
597
def test_change_dist_size (self ):
590
598
base_dist = pm .GARCH11 .dist (
0 commit comments