Skip to content

Commit 7c90922

Browse files
committed
Generalize Product to ProductDistribution
1 parent 34cd1ac commit 7c90922

File tree

2 files changed

+59
-37
lines changed

2 files changed

+59
-37
lines changed

src/Distributions.jl

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -139,7 +139,6 @@ export
139139
NormalInverseGaussian,
140140
Pareto,
141141
PGeneralizedGaussian,
142-
Product,
143142
Poisson,
144143
PoissonBinomial,
145144
QQPair,

src/product.jl

Lines changed: 59 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -1,66 +1,89 @@
11
import Statistics: mean, var, cov
22

33
"""
4-
Product <: MultivariateDistribution
4+
ProductDistribution <: Distribution{<:ValueSupport,<:ArrayLikeVariate}
55
6-
An N dimensional `MultivariateDistribution` constructed from a vector of N independent
7-
`UnivariateDistribution`s.
6+
A distribution of `M + N`-dimensional arrays, constructed from an `N`-dimensional array of
7+
independent `M`-dimensional distributions by stacking them.
88
9-
```julia
10-
Product(Uniform.(rand(10), 1)) # A 10-dimensional Product from 10 independent `Uniform` distributions.
11-
```
9+
Users should use [`product_distribution`](@ref) to construct a product distribution of
10+
independent distributions instead of constructing a `ProductDistribution` directly.
1211
"""
13-
struct Product{
12+
struct ProductDistribution{
13+
N,
1414
S<:ValueSupport,
15-
T<:UnivariateDistribution{S},
16-
V<:AbstractVector{T},
17-
} <: MultivariateDistribution{S}
15+
T<:Distribution{<:ArrayLikeVariate,S},
16+
V<:AbstractArray{T},
17+
} <: Distribution{ArrayLikeVariate{N},S}
1818
v::V
19-
function Product(v::V) where
20-
V<:AbstractVector{T} where
21-
T<:UnivariateDistribution{S} where
22-
S<:ValueSupport
23-
return new{S, T, V}(v)
19+
20+
function ProductDistribution(v::AbstractArray{T,N}) where {S<:ValueSupport, M, T<:Distribution{ArrayLikeVariate{M},S}, N}
21+
return new{M + N, S, T, typeof(v)}(v)
2422
end
2523
end
2624

27-
length(d::Product) = length(d.v)
28-
function Base.eltype(::Type{<:Product{S,T}}) where {S<:ValueSupport,
29-
T<:UnivariateDistribution{S}}
25+
# aliases
26+
const VectorOfUnivariateDistribution{S<:ValueSupport,T<:UnivariateDistribution{S},V<:AbstractVector{T}} =
27+
ProductDistribution{1,S,T,V}
28+
const MatrixOfUnivariateDistribution{S<:ValueSupport,T<:UnivariateDistribution{S},V<:AbstractMatrix{T}} =
29+
ProductDistribution{2,S,T,V}
30+
const VectorOfMultivariateDistribution{S<:ValueSupport,T<:MultivariateDistribution{S},V<:AbstractVector{T}} =
31+
ProductDistribution{2,S,T,V}
32+
33+
## deprecations
34+
# type parameters can't be deprecated it seems: https://github.com/JuliaLang/julia/issues/9830
35+
# so we define an alias and deprecate the corresponding constructor
36+
const Product{S<:ValueSupport,T<:UnivariateDistribution{S},V<:AbstractVector{T}} = ProductDistribution{1,S,T,V}
37+
Base.@deprecate Product(v::AbstractVector{<:UnivariateDistribution}) ProductDistribution(v)
38+
39+
## General definitions
40+
function Base.eltype(::Type{<:ProductDistribution{S,T}}) where {S<:ValueSupport,T<:Distribution{S,<:ArrayLikeVariate}}
3041
return eltype(T)
3142
end
3243

33-
_rand!(rng::AbstractRNG, d::Product, x::AbstractVector{<:Real}) =
44+
45+
## Vector of univariate distributions
46+
length(d::VectorOfUnivariateDistribution) = length(d.v)
47+
48+
_rand!(rng::AbstractRNG, d::VectorOfUnivariateDistribution, x::AbstractVector{<:Real}) =
3449
broadcast!(dn->rand(rng, dn), x, d.v)
35-
_logpdf(d::Product, x::AbstractVector{<:Real}) =
50+
_logpdf(d::VectorOfUnivariateDistribution, x::AbstractVector{<:Real}) =
3651
sum(n->logpdf(d.v[n], x[n]), 1:length(d))
3752

38-
mean(d::Product) = mean.(d.v)
39-
var(d::Product) = var.(d.v)
40-
cov(d::Product) = Diagonal(var(d))
41-
entropy(d::Product) = sum(entropy, d.v)
42-
insupport(d::Product, x::AbstractVector) = all(insupport.(d.v, x))
53+
mean(d::VectorOfUnivariateDistribution) = map(mean, d.v)
54+
var(d::VectorOfUnivariateDistribution) = map(var, d.v)
55+
cov(d::VectorOfUnivariateDistribution) = Diagonal(var(d))
56+
entropy(d::VectorOfUnivariateDistribution) = sum(entropy, d.v)
57+
function insupport(d::VectorOfUnivariateDistribution, x::AbstractVector)
58+
length(d) == length(x) && all(insupport(vi, xi) for (vi, xi) in zip(d.v, x))
59+
end
4360

4461
"""
45-
product_distribution(dists::AbstractVector{<:UnivariateDistribution})
62+
product_distribution(
63+
dists::AbstractArray{<:Distribution{<:ValueSupport,<:ArrayLikeVariate{M}},N}
64+
)
4665
47-
Creates a multivariate product distribution `P` from a vector of univariate distributions.
48-
Fallback is the `Product constructor`, but specialized methods can be defined
49-
for distributions with a special multivariate product.
66+
Create a distribution of `M + N`-dimensional arrays as a product distribution of
67+
independent `M`-dimensional distributions by stacking them.
68+
69+
The function falls back to constructing a [`ProductDistribution`](@ref) distribution but
70+
specialized methods can be defined.
5071
"""
51-
function product_distribution(dists::AbstractVector{<:UnivariateDistribution})
52-
return Product(dists)
72+
function product_distribution(
73+
dists::AbstractArray{<:Distribution{<:ValueSupport,<:ArrayLikeVariate}}
74+
)
75+
return ProductDistribution(dists)
5376
end
5477

5578
"""
5679
product_distribution(dists::AbstractVector{<:Normal})
5780
58-
Computes the multivariate Normal distribution obtained by stacking the univariate
59-
normal distributions. The result is a multivariate Gaussian with a diagonal
60-
covariance matrix.
81+
Create a multivariate normal distribution by stacking the univariate normal distributions.
82+
83+
The resulting distribution of type [`MvNormal`](@ref) has a diagonal covariance matrix.
6184
"""
6285
function product_distribution(dists::AbstractVector{<:Normal})
63-
µ = mean.(dists)
64-
σ2 = var.(dists)
86+
µ = map(mean, dists)
87+
σ2 = map(var, dists)
6588
return MvNormal(µ, Diagonal(σ2))
6689
end

0 commit comments

Comments
 (0)