88// except according to those terms.
99
1010//! The dirichlet distribution.
11-
12- use crate :: utils :: Float ;
11+ #! [ cfg ( feature = "alloc" ) ]
12+ use num_traits :: Float ;
1313use crate :: { Distribution , Exp1 , Gamma , Open01 , StandardNormal } ;
1414use rand:: Rng ;
15- use std:: { error, fmt} ;
15+ use core:: fmt;
16+ use alloc:: { boxed:: Box , vec, vec:: Vec } ;
1617
1718/// The Dirichlet distribution `Dirichlet(alpha)`.
1819///
@@ -26,14 +27,20 @@ use std::{error, fmt};
2627/// use rand::prelude::*;
2728/// use rand_distr::Dirichlet;
2829///
29- /// let dirichlet = Dirichlet::new(vec! [1.0, 2.0, 3.0]).unwrap();
30+ /// let dirichlet = Dirichlet::new(& [1.0, 2.0, 3.0]).unwrap();
3031/// let samples = dirichlet.sample(&mut rand::thread_rng());
3132/// println!("{:?} is from a Dirichlet([1.0, 2.0, 3.0]) distribution", samples);
3233/// ```
3334#[ derive( Clone , Debug ) ]
34- pub struct Dirichlet < N > {
35+ pub struct Dirichlet < F >
36+ where
37+ F : Float ,
38+ StandardNormal : Distribution < F > ,
39+ Exp1 : Distribution < F > ,
40+ Open01 : Distribution < F > ,
41+ {
3542 /// Concentration parameters (alpha)
36- alpha : Vec < N > ,
43+ alpha : Box < [ F ] > ,
3744}
3845
3946/// Error type returned from `Dirchlet::new`.
@@ -58,68 +65,70 @@ impl fmt::Display for Error {
5865 }
5966}
6067
61- impl error:: Error for Error { }
68+ #[ cfg( feature = "std" ) ]
69+ impl std:: error:: Error for Error { }
6270
63- impl < N : Float > Dirichlet < N >
71+ impl < F > Dirichlet < F >
6472where
65- StandardNormal : Distribution < N > ,
66- Exp1 : Distribution < N > ,
67- Open01 : Distribution < N > ,
73+ F : Float ,
74+ StandardNormal : Distribution < F > ,
75+ Exp1 : Distribution < F > ,
76+ Open01 : Distribution < F > ,
6877{
6978 /// Construct a new `Dirichlet` with the given alpha parameter `alpha`.
7079 ///
7180 /// Requires `alpha.len() >= 2`.
7281 #[ inline]
73- pub fn new < V : Into < Vec < N > > > ( alpha : V ) -> Result < Dirichlet < N > , Error > {
74- let a = alpha. into ( ) ;
75- if a. len ( ) < 2 {
82+ pub fn new ( alpha : & [ F ] ) -> Result < Dirichlet < F > , Error > {
83+ if alpha. len ( ) < 2 {
7684 return Err ( Error :: AlphaTooShort ) ;
7785 }
78- for & ai in & a {
79- if !( ai > N :: from ( 0.0 ) ) {
86+ for & ai in alpha . iter ( ) {
87+ if !( ai > F :: zero ( ) ) {
8088 return Err ( Error :: AlphaTooSmall ) ;
8189 }
8290 }
8391
84- Ok ( Dirichlet { alpha : a } )
92+ Ok ( Dirichlet { alpha : alpha . to_vec ( ) . into_boxed_slice ( ) } )
8593 }
8694
8795 /// Construct a new `Dirichlet` with the given shape parameter `alpha` and `size`.
8896 ///
8997 /// Requires `size >= 2`.
9098 #[ inline]
91- pub fn new_with_size ( alpha : N , size : usize ) -> Result < Dirichlet < N > , Error > {
92- if !( alpha > N :: from ( 0.0 ) ) {
99+ pub fn new_with_size ( alpha : F , size : usize ) -> Result < Dirichlet < F > , Error > {
100+ if !( alpha > F :: zero ( ) ) {
93101 return Err ( Error :: AlphaTooSmall ) ;
94102 }
95103 if size < 2 {
96104 return Err ( Error :: SizeTooSmall ) ;
97105 }
98106 Ok ( Dirichlet {
99- alpha : vec ! [ alpha; size] ,
107+ alpha : vec ! [ alpha; size] . into_boxed_slice ( ) ,
100108 } )
101109 }
102110}
103111
104- impl < N : Float > Distribution < Vec < N > > for Dirichlet < N >
112+ impl < F > Distribution < Vec < F > > for Dirichlet < F >
105113where
106- StandardNormal : Distribution < N > ,
107- Exp1 : Distribution < N > ,
108- Open01 : Distribution < N > ,
114+ F : Float ,
115+ StandardNormal : Distribution < F > ,
116+ Exp1 : Distribution < F > ,
117+ Open01 : Distribution < F > ,
109118{
110- fn sample < R : Rng + ?Sized > ( & self , rng : & mut R ) -> Vec < N > {
119+ fn sample < R : Rng + ?Sized > ( & self , rng : & mut R ) -> Vec < F > {
111120 let n = self . alpha . len ( ) ;
112- let mut samples = vec ! [ N :: from ( 0.0 ) ; n] ;
113- let mut sum = N :: from ( 0.0 ) ;
121+ let mut samples = vec ! [ F :: zero ( ) ; n] ;
122+ let mut sum = F :: zero ( ) ;
114123
115124 for ( s, & a) in samples. iter_mut ( ) . zip ( self . alpha . iter ( ) ) {
116- let g = Gamma :: new ( a, N :: from ( 1.0 ) ) . unwrap ( ) ;
125+ let g = Gamma :: new ( a, F :: one ( ) ) . unwrap ( ) ;
117126 * s = g. sample ( rng) ;
118- sum += * s ;
127+ sum = sum + ( * s ) ;
119128 }
120- let invacc = N :: from ( 1.0 ) / sum;
129+ let invacc = F :: one ( ) / sum;
121130 for s in samples. iter_mut ( ) {
122- * s *= invacc;
131+ * s = ( * s ) * invacc;
123132 }
124133 samples
125134 }
@@ -131,7 +140,7 @@ mod test {
131140
132141 #[ test]
133142 fn test_dirichlet ( ) {
134- let d = Dirichlet :: new ( vec ! [ 1.0 , 2.0 , 3.0 ] ) . unwrap ( ) ;
143+ let d = Dirichlet :: new ( & [ 1.0 , 2.0 , 3.0 ] ) . unwrap ( ) ;
135144 let mut rng = crate :: test:: rng ( 221 ) ;
136145 let samples = d. sample ( & mut rng) ;
137146 let _: Vec < f64 > = samples
@@ -170,20 +179,4 @@ mod test {
170179 fn test_dirichlet_invalid_alpha ( ) {
171180 Dirichlet :: new_with_size ( 0.0f64 , 2 ) . unwrap ( ) ;
172181 }
173-
174- #[ test]
175- fn value_stability ( ) {
176- let mut rng = crate :: test:: rng ( 223 ) ;
177- assert_eq ! (
178- rng. sample( Dirichlet :: new( vec![ 1.0 , 2.0 , 3.0 ] ) . unwrap( ) ) ,
179- vec![ 0.12941567177708177 , 0.4702121891675036 , 0.4003721390554146 ]
180- ) ;
181- assert_eq ! ( rng. sample( Dirichlet :: new_with_size( 8.0 , 5 ) . unwrap( ) ) , vec![
182- 0.17684200044809556 ,
183- 0.29915953935953055 ,
184- 0.1832858056608014 ,
185- 0.1425623503573967 ,
186- 0.19815030417417595
187- ] ) ;
188- }
189182}
0 commit comments