@@ -1700,12 +1700,14 @@ def test_bernoulli_wrong_arguments(self):
17001700 Bernoulli ("x" )
17011701
17021702 def test_discrete_weibull (self ):
1703- check_logp (
1704- DiscreteWeibull ,
1705- Nat ,
1706- {"q" : Unit , "beta" : NatSmall },
1707- discrete_weibull_logpmf ,
1708- )
1703+ with warnings .catch_warnings ():
1704+ warnings .filterwarnings ("ignore" , "divide by zero encountered in log" , RuntimeWarning )
1705+ check_logp (
1706+ DiscreteWeibull ,
1707+ Nat ,
1708+ {"q" : Unit , "beta" : NatSmall },
1709+ discrete_weibull_logpmf ,
1710+ )
17091711 check_selfconsistency_discrete_logcdf (
17101712 DiscreteWeibull ,
17111713 Nat ,
@@ -1732,8 +1734,10 @@ def test_poisson(self):
17321734 )
17331735
17341736 def test_diracdeltadist (self ):
1735- check_logp (DiracDelta , I , {"c" : I }, lambda value , c : np .log (c == value ))
1736- check_logcdf (DiracDelta , I , {"c" : I }, lambda value , c : np .log (value >= c ))
1737+ with warnings .catch_warnings ():
1738+ warnings .filterwarnings ("ignore" , "divide by zero encountered in log" , RuntimeWarning )
1739+ check_logp (DiracDelta , I , {"c" : I }, lambda value , c : np .log (c == value ))
1740+ check_logcdf (DiracDelta , I , {"c" : I }, lambda value , c : np .log (value >= c ))
17371741
17381742 def test_zeroinflatedpoisson (self ):
17391743 def logp_fn (value , psi , mu ):
@@ -2370,12 +2374,15 @@ def test_categorical_p_not_normalized(self):
23702374
23712375 @pytest .mark .parametrize ("n" , [2 , 3 , 4 ])
23722376 def test_orderedlogistic (self , n ):
2373- check_logp (
2374- OrderedLogistic ,
2375- Domain (range (n ), dtype = "int64" , edges = (None , None )),
2376- {"eta" : R , "cutpoints" : Vector (R , n - 1 )},
2377- lambda value , eta , cutpoints : orderedlogistic_logpdf (value , eta , cutpoints ),
2378- )
2377+ with warnings .catch_warnings ():
2378+ warnings .filterwarnings ("ignore" , "invalid value encountered in log" , RuntimeWarning )
2379+ warnings .filterwarnings ("ignore" , "divide by zero encountered in log" , RuntimeWarning )
2380+ check_logp (
2381+ OrderedLogistic ,
2382+ Domain (range (n ), dtype = "int64" , edges = (None , None )),
2383+ {"eta" : R , "cutpoints" : Vector (R , n - 1 )},
2384+ lambda value , eta , cutpoints : orderedlogistic_logpdf (value , eta , cutpoints ),
2385+ )
23792386
23802387 @pytest .mark .parametrize ("n" , [2 , 3 , 4 ])
23812388 def test_orderedprobit (self , n ):
@@ -2622,6 +2629,7 @@ def ref_logp(value, mu, sigma, steps):
26222629 {"mu" : R , "sigma" : Rplus , "steps" : Nat },
26232630 ref_logp ,
26242631 decimal = select_by_precision (float64 = 6 , float32 = 1 ),
2632+ extra_args = {"init_dist" : Normal .dist (0 , 100 )},
26252633 )
26262634
26272635
@@ -2631,8 +2639,14 @@ class TestBound:
26312639 def test_continuous (self ):
26322640 with Model () as model :
26332641 dist = Normal .dist (mu = 0 , sigma = 1 )
2634- UnboundedNormal = Bound ("unbound" , dist , transform = None )
2635- InfBoundedNormal = Bound ("infbound" , dist , lower = - np .inf , upper = np .inf , transform = None )
2642+ with warnings .catch_warnings ():
2643+ warnings .filterwarnings (
2644+ "ignore" , "invalid value encountered in add" , RuntimeWarning
2645+ )
2646+ UnboundedNormal = Bound ("unbound" , dist , transform = None )
2647+ InfBoundedNormal = Bound (
2648+ "infbound" , dist , lower = - np .inf , upper = np .inf , transform = None
2649+ )
26362650 LowerNormal = Bound ("lower" , dist , lower = 0 , transform = None )
26372651 UpperNormal = Bound ("upper" , dist , upper = 0 , transform = None )
26382652 BoundedNormal = Bound ("bounded" , dist , lower = 1 , upper = 10 , transform = None )
@@ -2667,7 +2681,11 @@ def test_continuous(self):
26672681 def test_discrete (self ):
26682682 with Model () as model :
26692683 dist = Poisson .dist (mu = 4 )
2670- UnboundedPoisson = Bound ("unbound" , dist )
2684+ with warnings .catch_warnings ():
2685+ warnings .filterwarnings (
2686+ "ignore" , "invalid value encountered in add" , RuntimeWarning
2687+ )
2688+ UnboundedPoisson = Bound ("unbound" , dist )
26712689 LowerPoisson = Bound ("lower" , dist , lower = 1 )
26722690 UpperPoisson = Bound ("upper" , dist , upper = 10 )
26732691 BoundedPoisson = Bound ("bounded" , dist , lower = 1 , upper = 10 )
@@ -2714,8 +2732,12 @@ def test_arguments_checks(self):
27142732 msg = "Cannot transform discrete variable."
27152733 with pm .Model () as m :
27162734 x = pm .Poisson .dist (0.5 )
2717- with pytest .raises (ValueError , match = msg ):
2718- pm .Bound ("bound" , x , transform = pm .distributions .transforms .log )
2735+ with warnings .catch_warnings ():
2736+ warnings .filterwarnings (
2737+ "ignore" , "invalid value encountered in add" , RuntimeWarning
2738+ )
2739+ with pytest .raises (ValueError , match = msg ):
2740+ pm .Bound ("bound" , x , transform = pm .distributions .transforms .log )
27192741
27202742 msg = "Given dims do not exist in model coordinates."
27212743 with pm .Model () as m :
@@ -2784,8 +2806,12 @@ def test_bound_dist(self):
27842806 def test_array_bound (self ):
27852807 with Model () as model :
27862808 dist = Normal .dist ()
2787- LowerPoisson = Bound ("lower" , dist , lower = [1 , None ], transform = None )
2788- UpperPoisson = Bound ("upper" , dist , upper = [np .inf , 10 ], transform = None )
2809+ with warnings .catch_warnings ():
2810+ warnings .filterwarnings (
2811+ "ignore" , "invalid value encountered in add" , RuntimeWarning
2812+ )
2813+ LowerPoisson = Bound ("lower" , dist , lower = [1 , None ], transform = None )
2814+ UpperPoisson = Bound ("upper" , dist , upper = [np .inf , 10 ], transform = None )
27892815 BoundedPoisson = Bound ("bounded" , dist , lower = [1 , 2 ], upper = [9 , 10 ], transform = None )
27902816
27912817 first , second = joint_logp (LowerPoisson , [0 , 0 ], sum = False )[0 ].eval ()
@@ -3081,7 +3107,9 @@ def random(rng, size):
30813107 with pm .Model ():
30823108 pm .Normal ("x" )
30833109 y = pm .DensityDist ("y" , logp = func , random = random )
3084- pm .sample (draws = 5 , tune = 1 , mp_ctx = "spawn" )
3110+ with warnings .catch_warnings ():
3111+ warnings .filterwarnings ("ignore" , ".*number of samples.*" , UserWarning )
3112+ pm .sample (draws = 5 , tune = 1 , mp_ctx = "spawn" )
30853113
30863114 import cloudpickle
30873115
0 commit comments