|
1 | 1 | import Statistics: mean, var, cov |
2 | 2 |
|
3 | 3 | """ |
4 | | - Product <: MultivariateDistribution |
| 4 | + ProductDistribution <: Distribution{<:ValueSupport,<:ArrayLikeVariate} |
5 | 5 |
|
6 | | -An N dimensional `MultivariateDistribution` constructed from a vector of N independent |
7 | | -`UnivariateDistribution`s. |
| 6 | +A distribution of `M + N`-dimensional arrays, constructed from an `N`-dimensional array of |
| 7 | +independent `M`-dimensional distributions by stacking them. |
8 | 8 |
|
9 | | -```julia |
10 | | -Product(Uniform.(rand(10), 1)) # A 10-dimensional Product from 10 independent `Uniform` distributions. |
11 | | -``` |
| 9 | +Users should use [`product_distribution`](@ref) to construct a product distribution of |
| 10 | +independent distributions instead of constructing a `ProductDistribution` directly. |
12 | 11 | """ |
13 | | -struct Product{ |
| 12 | +struct ProductDistribution{ |
| 13 | + N, |
14 | 14 | S<:ValueSupport, |
15 | | - T<:UnivariateDistribution{S}, |
16 | | - V<:AbstractVector{T}, |
17 | | -} <: MultivariateDistribution{S} |
| 15 | + T<:Distribution{<:ArrayLikeVariate,S}, |
| 16 | + V<:AbstractArray{T}, |
| 17 | +} <: Distribution{ArrayLikeVariate{N},S} |
18 | 18 | v::V |
19 | | - function Product(v::V) where |
20 | | - V<:AbstractVector{T} where |
21 | | - T<:UnivariateDistribution{S} where |
22 | | - S<:ValueSupport |
23 | | - return new{S, T, V}(v) |
| 19 | + |
| 20 | + function ProductDistribution(v::AbstractArray{T,N}) where {S<:ValueSupport, M, T<:Distribution{ArrayLikeVariate{M},S}, N} |
| 21 | + return new{M + N, S, T, typeof(v)}(v) |
24 | 22 | end |
25 | 23 | end |
26 | 24 |
|
27 | | -length(d::Product) = length(d.v) |
28 | | -function Base.eltype(::Type{<:Product{S,T}}) where {S<:ValueSupport, |
29 | | - T<:UnivariateDistribution{S}} |
| 25 | +# aliases |
| 26 | +const VectorOfUnivariateDistribution{S<:ValueSupport,T<:UnivariateDistribution{S},V<:AbstractVector{T}} = |
| 27 | + ProductDistribution{1,S,T,V} |
| 28 | +const MatrixOfUnivariateDistribution{S<:ValueSupport,T<:UnivariateDistribution{S},V<:AbstractMatrix{T}} = |
| 29 | + ProductDistribution{2,S,T,V} |
| 30 | +const VectorOfMultivariateDistribution{S<:ValueSupport,T<:MultivariateDistribution{S},V<:AbstractVector{T}} = |
| 31 | + ProductDistribution{2,S,T,V} |
| 32 | + |
| 33 | +## deprecations |
| 34 | +# type parameters can't be deprecated it seems: https://github.com/JuliaLang/julia/issues/9830 |
| 35 | +# so we define an alias and deprecate the corresponding constructor |
| 36 | +const Product{S<:ValueSupport,T<:UnivariateDistribution{S},V<:AbstractVector{T}} = ProductDistribution{1,S,T,V} |
| 37 | +Base.@deprecate Product(v::AbstractVector{<:UnivariateDistribution}) ProductDistribution(v) |
| 38 | + |
| 39 | +## General definitions |
| 40 | +function Base.eltype(::Type{<:ProductDistribution{S,T}}) where {S<:ValueSupport,T<:Distribution{S,<:ArrayLikeVariate}} |
30 | 41 | return eltype(T) |
31 | 42 | end |
32 | 43 |
|
33 | | -_rand!(rng::AbstractRNG, d::Product, x::AbstractVector{<:Real}) = |
| 44 | + |
| 45 | +## Vector of univariate distributions |
| 46 | +length(d::VectorOfUnivariateDistribution) = length(d.v) |
| 47 | + |
| 48 | +_rand!(rng::AbstractRNG, d::VectorOfUnivariateDistribution, x::AbstractVector{<:Real}) = |
34 | 49 | broadcast!(dn->rand(rng, dn), x, d.v) |
35 | | -_logpdf(d::Product, x::AbstractVector{<:Real}) = |
| 50 | +_logpdf(d::VectorOfUnivariateDistribution, x::AbstractVector{<:Real}) = |
36 | 51 | sum(n->logpdf(d.v[n], x[n]), 1:length(d)) |
37 | 52 |
|
38 | | -mean(d::Product) = mean.(d.v) |
39 | | -var(d::Product) = var.(d.v) |
40 | | -cov(d::Product) = Diagonal(var(d)) |
41 | | -entropy(d::Product) = sum(entropy, d.v) |
42 | | -insupport(d::Product, x::AbstractVector) = all(insupport.(d.v, x)) |
| 53 | +mean(d::VectorOfUnivariateDistribution) = map(mean, d.v) |
| 54 | +var(d::VectorOfUnivariateDistribution) = map(var, d.v) |
| 55 | +cov(d::VectorOfUnivariateDistribution) = Diagonal(var(d)) |
| 56 | +entropy(d::VectorOfUnivariateDistribution) = sum(entropy, d.v) |
| 57 | +function insupport(d::VectorOfUnivariateDistribution, x::AbstractVector) |
| 58 | + length(d) == length(x) && all(insupport(vi, xi) for (vi, xi) in zip(d.v, x)) |
| 59 | +end |
43 | 60 |
|
44 | 61 | """ |
45 | | - product_distribution(dists::AbstractVector{<:UnivariateDistribution}) |
| 62 | + product_distribution( |
| 63 | + dists::AbstractArray{<:Distribution{<:ValueSupport,<:ArrayLikeVariate{M}},N} |
| 64 | + ) |
46 | 65 |
|
47 | | -Creates a multivariate product distribution `P` from a vector of univariate distributions. |
48 | | -Fallback is the `Product constructor`, but specialized methods can be defined |
49 | | -for distributions with a special multivariate product. |
| 66 | +Create a distribution of `M + N`-dimensional arrays as a product distribution of |
| 67 | +independent `M`-dimensional distributions by stacking them. |
| 68 | +
|
| 69 | +The function falls back to constructing a [`ProductDistribution`](@ref) distribution but |
| 70 | +specialized methods can be defined. |
50 | 71 | """ |
51 | | -function product_distribution(dists::AbstractVector{<:UnivariateDistribution}) |
52 | | - return Product(dists) |
| 72 | +function product_distribution( |
| 73 | + dists::AbstractArray{<:Distribution{<:ValueSupport,<:ArrayLikeVariate}} |
| 74 | +) |
| 75 | + return ProductDistribution(dists) |
53 | 76 | end |
54 | 77 |
|
55 | 78 | """ |
56 | 79 | product_distribution(dists::AbstractVector{<:Normal}) |
57 | 80 |
|
58 | | -Computes the multivariate Normal distribution obtained by stacking the univariate |
59 | | -normal distributions. The result is a multivariate Gaussian with a diagonal |
60 | | -covariance matrix. |
| 81 | +Create a multivariate normal distribution by stacking the univariate normal distributions. |
| 82 | +
|
| 83 | +The resulting distribution of type [`MvNormal`](@ref) has a diagonal covariance matrix. |
61 | 84 | """ |
62 | 85 | function product_distribution(dists::AbstractVector{<:Normal}) |
63 | | - µ = mean.(dists) |
64 | | - σ2 = var.(dists) |
| 86 | + µ = map(mean, dists) |
| 87 | + σ2 = map(var, dists) |
65 | 88 | return MvNormal(µ, Diagonal(σ2)) |
66 | 89 | end |
0 commit comments