Skip to content

Commit 913ec31

Browse files
committed
Implement Beta distribution
1 parent af8aa52 commit 913ec31

File tree

2 files changed

+59
-2
lines changed

2 files changed

+59
-2
lines changed

src/distributions/gamma.rs

Lines changed: 55 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -305,10 +305,49 @@ impl Distribution<f64> for StudentT {
305305
}
306306
}
307307

308+
/// The Beta distribution with shape parameters `alpha` and `beta`.
309+
///
310+
/// # Example
311+
///
312+
/// ```
313+
/// use rand::distributions::{Distribution, Beta};
314+
///
315+
/// let beta = Beta::new(2.0, 5.0);
316+
/// let v = beta.sample(&mut rand::thread_rng());
317+
/// println!("{} is from a Beta(2, 5) distribution", v);
318+
/// ```
319+
#[derive(Clone, Copy, Debug)]
320+
pub struct Beta {
321+
gamma_a: Gamma,
322+
gamma_b: Gamma,
323+
}
324+
325+
impl Beta {
326+
/// Construct an object representing the `Beta(alpha, beta)`
327+
/// distribution.
328+
///
329+
/// Panics if `shape <= 0` or `scale <= 0`.
330+
pub fn new(alpha: f64, beta: f64) -> Beta {
331+
assert!((alpha > 0.) & (beta > 0.));
332+
Beta {
333+
gamma_a: Gamma::new(alpha, 1.),
334+
gamma_b: Gamma::new(beta, 1.),
335+
}
336+
}
337+
}
338+
339+
impl Distribution<f64> for Beta {
340+
fn sample<R: Rng + ?Sized>(&self, rng: &mut R) -> f64 {
341+
let x = self.gamma_a.sample(rng);
342+
let y = self.gamma_b.sample(rng);
343+
x / (x + y)
344+
}
345+
}
346+
308347
#[cfg(test)]
309348
mod test {
310349
use distributions::Distribution;
311-
use super::{ChiSquared, StudentT, FisherF};
350+
use super::{Beta, ChiSquared, StudentT, FisherF};
312351

313352
#[test]
314353
fn test_chi_squared_one() {
@@ -357,4 +396,19 @@ mod test {
357396
t.sample(&mut rng);
358397
}
359398
}
399+
400+
#[test]
401+
fn test_beta() {
402+
let beta = Beta::new(1.0, 2.0);
403+
let mut rng = ::test::rng(201);
404+
for _ in 0..1000 {
405+
beta.sample(&mut rng);
406+
}
407+
}
408+
409+
#[test]
410+
#[should_panic]
411+
fn test_beta_invalid_dof() {
412+
Beta::new(0., 0.);
413+
}
360414
}

src/distributions/mod.rs

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -98,6 +98,7 @@
9898
//! - [`ChiSquared`] distribution
9999
//! - [`StudentT`] distribution
100100
//! - [`FisherF`] distribution
101+
//! - [`Beta`] distribution
101102
//! - Multivariate probability distributions
102103
//! - [`Dirichlet`] distribution
103104
//! - [`UnitSphereSurface`] distribution
@@ -151,6 +152,7 @@
151152
// distributions
152153
//! [`Alphanumeric`]: struct.Alphanumeric.html
153154
//! [`Bernoulli`]: struct.Bernoulli.html
155+
//! [`Beta`]: struct.Beta.html
154156
//! [`Binomial`]: struct.Binomial.html
155157
//! [`Cauchy`]: struct.Cauchy.html
156158
//! [`ChiSquared`]: struct.ChiSquared.html
@@ -184,7 +186,8 @@ pub use self::bernoulli::Bernoulli;
184186
#[cfg(feature="alloc")] pub use self::weighted::{WeightedIndex, WeightedError};
185187
#[cfg(feature="std")] pub use self::unit_sphere::UnitSphereSurface;
186188
#[cfg(feature="std")] pub use self::unit_circle::UnitCircle;
187-
#[cfg(feature="std")] pub use self::gamma::{Gamma, ChiSquared, FisherF, StudentT};
189+
#[cfg(feature="std")] pub use self::gamma::{Gamma, ChiSquared, FisherF,
190+
StudentT, Beta};
188191
#[cfg(feature="std")] pub use self::normal::{Normal, LogNormal, StandardNormal};
189192
#[cfg(feature="std")] pub use self::exponential::{Exp, Exp1};
190193
#[cfg(feature="std")] pub use self::pareto::Pareto;

0 commit comments

Comments
 (0)