Skip to content

Commit ccf4f8e

Browse files
committed
Unify and generalize rand!, logpdf and pdf
1 parent 5c7a82a commit ccf4f8e

File tree

10 files changed

+286
-366
lines changed

10 files changed

+286
-366
lines changed

src/genericrand.jl

Lines changed: 96 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,32 @@ rand(s::Sampleable, dim1::Int, moredims::Int...) =
2626
rand(rng::AbstractRNG, s::Sampleable, dim1::Int, moredims::Int...) =
2727
rand(rng, s, (dim1, moredims...))
2828

29+
# default fallback (redefined for univariate distributions)
30+
function rand(rng::AbstractRNG, s::Sampleable{<:ArrayLikeVariate,Continuous})
31+
return @inbounds rand!(rng, sampler(s), Array{float(eltype(s))}(undef, size(s)))
32+
end
33+
function rand(rng::AbstractRNG, s::Sampleable{<:ArrayLikeVariate,Discrete})
34+
return @inbounds rand!(rng, sampler(s), Array{eltype(s)}(undef, size(s)))
35+
end
36+
37+
# multiple samples (redefined for univariate distributions)
38+
function rand(
39+
rng::AbstractRNG, s::Sampleable{ArrayLikeVariate{N},Discrete}, dims::Dims,
40+
) where {N}
41+
sz = size(s)
42+
ax = map(Base.OneTo, dims)
43+
out = [Array{eltype(s),N}(undef, sz) for _ in Iterators.product(ax)]
44+
return @inbounds rand!(rng, sampler(s), out)
45+
end
46+
function rand(
47+
rng::AbstractRNG, s::Sampleable{ArrayLikeVariate{N},Continuous}, dims::Dims,
48+
) where {N}
49+
sz = size(s)
50+
ax = map(Base.OneTo, dims)
51+
out = [Array{float(eltype(s)),N}(undef, sz) for _ in Iterators.product(ax)]
52+
return @inbounds rand!(rng, sampler(s), out)
53+
end
54+
2955
"""
3056
rand!([rng::AbstractRNG,] s::Sampleable, A::AbstractArray)
3157
@@ -40,11 +66,78 @@ form as specified above. The rules are summarized as below:
4066
matrices with each element for a sample matrix.
4167
"""
4268
function rand! end
43-
rand!(s::Sampleable, X::AbstractArray{<:AbstractArray}, allocate::Bool) =
44-
rand!(GLOBAL_RNG, s, X, allocate)
45-
rand!(s::Sampleable, X::AbstractArray) = rand!(GLOBAL_RNG, s, X)
69+
Base.@propagate_inbounds rand!(s::Sampleable, X::AbstractArray) = rand!(GLOBAL_RNG, s, X)
4670
rand!(rng::AbstractRNG, s::Sampleable, X::AbstractArray) = _rand!(rng, s, X)
4771

72+
# default definitions for arraylike variates
73+
@inline function rand!(
74+
rng::AbstractRNG,
75+
s::Sampleable{ArrayLikeVariate{N}},
76+
x::AbstractArray{<:Real,N},
77+
) where {N}
78+
@boundscheck begin
79+
size(x) == size(s) || throw(DimensionMismatch("inconsistent array dimensions"))
80+
end
81+
# the function barrier fixes performance issues if `sampler(s)` is type unstable
82+
return _rand!(rng, sampler(s), x)
83+
end
84+
85+
@inline function rand!(
86+
rng::AbstractRNG,
87+
s::Sampleable{ArrayLikeVariate{N}},
88+
x::AbstractArray{<:Real,M},
89+
) where {N,M}
90+
@boundscheck begin
91+
M > N ||
92+
throw(DimensionMismatch(
93+
"number of dimensions of `x` ($M) must be greater than number of dimensions of `s` ($N)"
94+
))
95+
ntuple(i -> size(x, i), Val(N)) == size(s) ||
96+
throw(DimensionMismatch("inconsistent array dimensions"))
97+
end
98+
return _rand!(rng, sampler(s), x)
99+
end
100+
101+
function _rand!(
102+
rng::AbstractRNG,
103+
s::Sampleable{<:ArrayLikeVariate},
104+
x::AbstractArray{<:Real},
105+
)
106+
@inbounds for xi in eachvariate(x, variate_form(typeof(s)))
107+
rand!(rng, s, xi)
108+
end
109+
return x
110+
end
111+
112+
Base.@propagate_inbounds function rand!(
113+
rng::AbstractRNG,
114+
s::Sampleable{ArrayLikeVariate{N}},
115+
x::AbstractArray{<:AbstractArray{<:Real,N}},
116+
) where {N}
117+
# the function barrier fixes performance issues if `sampler(s)` is type unstable
118+
return _rand!(rng, sampler(s), x)
119+
end
120+
121+
Base.@propagate_inbounds function rand!(
122+
rng::AbstractRNG,
123+
s::Sampleable{ArrayLikeVariate{N}},
124+
x::AbstractArray{<:AbstractArray{<:Real,N}},
125+
) where {N}
126+
# the function barrier fixes performance issues if `sampler(s)` is type unstable
127+
return _rand!(rng, sampler(s), x)
128+
end
129+
130+
Base.@propagate_inbounds function _rand!(
131+
rng::AbstractRNG,
132+
s::Sampleable{ArrayLikeVariate{N}},
133+
x::AbstractArray{<:AbstractArray{<:Real,N}},
134+
) where {N}
135+
for i in eachindex(x)
136+
rand!(rng, s, @inbounds(x[i]))
137+
end
138+
return x
139+
end
140+
48141
"""
49142
sampler(d::Distribution) -> Sampleable
50143
sampler(s::Sampleable) -> s

src/matrixvariates.jl

Lines changed: 1 addition & 114 deletions
Original file line numberDiff line numberDiff line change
@@ -93,17 +93,6 @@ _rand!(::AbstractRNG, ::MatrixDistribution, A::AbstractMatrix)
9393

9494
## sampling
9595

96-
# multivariate with pre-allocated 3D array
97-
function _rand!(rng::AbstractRNG, s::Sampleable{Matrixvariate},
98-
m::AbstractArray{<:Real, 3})
99-
@boundscheck (size(m, 1), size(m, 2)) == (size(s, 1), size(s, 2)) ||
100-
throw(DimensionMismatch("Output size inconsistent with matrix size."))
101-
smp = sampler(s)
102-
for i in Base.OneTo(size(m,3))
103-
_rand!(rng, smp, view(m,:,:,i))
104-
end
105-
return m
106-
end
10796

10897
# multiple matrix-variates with pre-allocated array of maybe pre-allocated matrices
10998
rand!(rng::AbstractRNG, s::Sampleable{Matrixvariate},
@@ -127,120 +116,18 @@ function rand!(rng::AbstractRNG, s::Sampleable{Matrixvariate},
127116
return X
128117
end
129118

130-
# multiple matrix-variates, must allocate array of arrays
131-
rand(rng::AbstractRNG, s::Sampleable{Matrixvariate}, dims::Dims) =
132-
rand!(rng, s, Array{Matrix{eltype(s)}}(undef, dims), true)
133-
rand(rng::AbstractRNG, s::Sampleable{Matrixvariate,Continuous}, dims::Dims) =
134-
rand!(rng, s, Array{Matrix{float(eltype(s))}}(undef, dims), true)
135-
136-
# single matrix-variate, must allocate one matrix
137-
rand(rng::AbstractRNG, s::Sampleable{Matrixvariate}) =
138-
_rand!(rng, s, Matrix{eltype(s)}(undef, size(s)))
139-
rand(rng::AbstractRNG, s::Sampleable{Matrixvariate,Continuous}) =
140-
_rand!(rng, s, Matrix{float(eltype(s))}(undef, size(s)))
141-
142-
# single matrix-variate with pre-allocated matrix
143-
function rand!(rng::AbstractRNG, s::Sampleable{Matrixvariate},
144-
A::AbstractMatrix{<:Real})
145-
@boundscheck size(A) == size(s) ||
146-
throw(DimensionMismatch("Output size inconsistent with matrix size."))
147-
return _rand!(rng, s, A)
148-
end
149-
150119
# pdf & logpdf
151120

152121
_logpdf(d::MatrixDistribution, X::AbstractMatrix{<:Real}) = logkernel(d, X) + d.logc0
153122

154-
_pdf(d::MatrixDistribution, x::AbstractMatrix{<:Real}) = exp(_logpdf(d, x))
155-
156-
"""
157-
logpdf(d::MatrixDistribution, AbstractMatrix)
158-
159-
Compute the logarithm of the probability density at the input matrix `x`.
160-
"""
161-
function logpdf(d::MatrixDistribution, x::AbstractMatrix{<:Real})
162-
size(x) == size(d) ||
163-
throw(DimensionMismatch("Inconsistent array dimensions."))
164-
_logpdf(d, x)
165-
end
166-
167-
"""
168-
pdf(d::MatrixDistribution, x::AbstractArray)
169-
170-
Compute the probability density at the input matrix `x`.
171-
"""
172-
function pdf(d::MatrixDistribution, x::AbstractMatrix{<:Real})
173-
size(x) == size(d) ||
174-
throw(DimensionMismatch("Inconsistent array dimensions."))
175-
_pdf(d, x)
176-
end
177-
178-
function _logpdf!(r::AbstractArray, d::MatrixDistribution, X::AbstractArray{<:AbstractMatrix{<:Real}})
179-
for i = 1:length(X)
180-
r[i] = logpdf(d, X[i])
181-
end
182-
return r
183-
end
184-
185-
function _pdf!(r::AbstractArray, d::MatrixDistribution, X::AbstractArray{M}) where M<:Matrix
186-
for i = 1:length(X)
187-
r[i] = pdf(d, X[i])
188-
end
189-
return r
190-
end
191-
192-
function logpdf!(r::AbstractArray, d::MatrixDistribution, X::AbstractArray{M}) where M<:Matrix
193-
length(X) == length(r) ||
194-
throw(DimensionMismatch("Inconsistent array dimensions."))
195-
_logpdf!(r, d, X)
196-
end
197-
198-
function pdf!(r::AbstractArray, d::MatrixDistribution, X::AbstractArray{M}) where M<:Matrix
199-
length(X) == length(r) ||
200-
throw(DimensionMismatch("Inconsistent array dimensions."))
201-
_pdf!(r, d, X)
202-
end
203-
204-
function logpdf(d::MatrixDistribution, X::AbstractArray{<:AbstractMatrix{<:Real}})
205-
map(Base.Fix1(logpdf, d), X)
206-
end
207-
208-
function pdf(d::MatrixDistribution, X::AbstractArray{<:AbstractMatrix{<:Real}})
209-
map(Base.Fix1(pdf, d), X)
210-
end
211-
212-
"""
213-
_logpdf(d::MatrixDistribution, x::AbstractArray)
214-
215-
Evaluate logarithm of pdf value for a given sample `x`. This function need not perform dimension checking.
216-
"""
217-
_logpdf(d::MatrixDistribution, x::AbstractArray)
218-
219-
"""
220-
loglikelihood(d::MatrixDistribution, x::AbstractArray)
221-
222-
The log-likelihood of distribution `d` with respect to all samples contained in array `x`.
223-
224-
Here, `x` can be a matrix of size `size(d)`, a three-dimensional array with `size(d, 1)`
225-
rows and `size(d, 2)` columns, or an array of matrices of size `size(d)`.
226-
"""
227-
loglikelihood(d::MatrixDistribution, X::AbstractMatrix{<:Real}) = logpdf(d, X)
228-
function loglikelihood(d::MatrixDistribution, X::AbstractArray{<:Real,3})
229-
(size(X, 1), size(X, 2)) == size(d) || throw(DimensionMismatch("Inconsistent array dimensions."))
230-
return sum(i -> _logpdf(d, view(X, :, :, i)), axes(X, 3))
231-
end
232-
function loglikelihood(d::MatrixDistribution, X::AbstractArray{<:AbstractMatrix{<:Real}})
233-
return sum(x -> logpdf(d, x), X)
234-
end
235-
236123
# for testing
237124
is_univariate(d::MatrixDistribution) = size(d) == (1, 1)
238125
check_univariate(d::MatrixDistribution) = is_univariate(d) || throw(ArgumentError("not 1 x 1"))
239126

240127
##### Specific distributions #####
241128

242129
for fname in ["wishart.jl", "inversewishart.jl", "matrixnormal.jl",
243-
"matrixreshaped.jl", "matrixtdist.jl", "matrixbeta.jl",
130+
"matrixreshaped.jl", "matrixtdist.jl", "matrixbeta.jl",
244131
"matrixfdist.jl", "lkj.jl"]
245132
include(joinpath("matrix", fname))
246133
end

src/multivariates.jl

Lines changed: 3 additions & 110 deletions
Original file line numberDiff line numberDiff line change
@@ -24,25 +24,6 @@ a vector of length `dim(d)` or a matrix with `dim(d)` rows.
2424
"""
2525
rand!(rng::AbstractRNG, d::MultivariateDistribution, x::AbstractArray)
2626

27-
# multivariate with pre-allocated array
28-
function _rand!(rng::AbstractRNG, s::Sampleable{Multivariate}, m::AbstractMatrix)
29-
@boundscheck size(m, 1) == length(s) ||
30-
throw(DimensionMismatch("Output size inconsistent with sample length."))
31-
smp = sampler(s)
32-
for i in Base.OneTo(size(m,2))
33-
_rand!(rng, smp, view(m,:,i))
34-
end
35-
return m
36-
end
37-
38-
# single multivariate with pre-allocated vector
39-
function rand!(rng::AbstractRNG, s::Sampleable{Multivariate},
40-
v::AbstractVector{<:Real})
41-
@boundscheck length(v) == length(s) ||
42-
throw(DimensionMismatch("Output size inconsistent with sample length."))
43-
_rand!(rng, s, v)
44-
end
45-
4627
# multiple multivariates with pre-allocated array of maybe pre-allocated vectors
4728
rand!(rng::AbstractRNG, s::Sampleable{Multivariate},
4829
X::AbstractArray{<:AbstractVector}) =
@@ -65,22 +46,11 @@ function rand!(rng::AbstractRNG, s::Sampleable{Multivariate},
6546
return X
6647
end
6748

68-
# multiple multivariate, must allocate matrix or array of vectors
69-
rand(s::Sampleable{Multivariate}, n::Int) = rand(GLOBAL_RNG, s, n)
49+
# multiple multivariate, must allocate matrix
7050
rand(rng::AbstractRNG, s::Sampleable{Multivariate}, n::Int) =
71-
_rand!(rng, s, Matrix{eltype(s)}(undef, length(s), n))
51+
@inbounds rand!(rng, sampler(s), Matrix{eltype(s)}(undef, length(s), n))
7252
rand(rng::AbstractRNG, s::Sampleable{Multivariate,Continuous}, n::Int) =
73-
_rand!(rng, s, Matrix{float(eltype(s))}(undef, length(s), n))
74-
rand(rng::AbstractRNG, s::Sampleable{Multivariate}, dims::Dims) =
75-
rand!(rng, s, Array{Vector{eltype(s)}}(undef, dims), true)
76-
rand(rng::AbstractRNG, s::Sampleable{Multivariate,Continuous}, dims::Dims) =
77-
rand!(rng, s, Array{Vector{float(eltype(s))}}(undef, dims), true)
78-
79-
# single multivariate, must allocate vector
80-
rand(rng::AbstractRNG, s::Sampleable{Multivariate}) =
81-
_rand!(rng, s, Vector{eltype(s)}(undef, length(s)))
82-
rand(rng::AbstractRNG, s::Sampleable{Multivariate,Continuous}) =
83-
_rand!(rng, s, Vector{float(eltype(s))}(undef, length(s)))
53+
@inbounds rand!(rng, sampler(s), Matrix{float(eltype(s))}(undef, length(s), n))
8454

8555
## domain
8656

@@ -193,83 +163,6 @@ Return the logarithm of probability density evaluated at `x`.
193163
"""
194164
logpdf(d::MultivariateDistribution, x::AbstractArray)
195165

196-
_pdf(d::MultivariateDistribution, X::AbstractVector) = exp(_logpdf(d, X))
197-
198-
function logpdf(d::MultivariateDistribution, X::AbstractVector)
199-
length(X) == length(d) ||
200-
throw(DimensionMismatch("Inconsistent array dimensions."))
201-
_logpdf(d, X)
202-
end
203-
204-
function pdf(d::MultivariateDistribution, X::AbstractVector)
205-
length(X) == length(d) ||
206-
throw(DimensionMismatch("Inconsistent array dimensions."))
207-
_pdf(d, X)
208-
end
209-
210-
function _logpdf!(r::AbstractArray, d::MultivariateDistribution, X::AbstractMatrix)
211-
for i in 1 : size(X,2)
212-
@inbounds r[i] = logpdf(d, view(X,:,i))
213-
end
214-
return r
215-
end
216-
217-
function _pdf!(r::AbstractArray, d::MultivariateDistribution, X::AbstractMatrix)
218-
for i in 1 : size(X,2)
219-
@inbounds r[i] = pdf(d, view(X,:,i))
220-
end
221-
return r
222-
end
223-
224-
function logpdf!(r::AbstractArray, d::MultivariateDistribution, X::AbstractMatrix)
225-
size(X) == (length(d), length(r)) ||
226-
throw(DimensionMismatch("Inconsistent array dimensions."))
227-
_logpdf!(r, d, X)
228-
end
229-
230-
function pdf!(r::AbstractArray, d::MultivariateDistribution, X::AbstractMatrix)
231-
size(X) == (length(d), length(r)) ||
232-
throw(DimensionMismatch("Inconsistent array dimensions."))
233-
_pdf!(r, d, X)
234-
end
235-
236-
function logpdf(d::MultivariateDistribution, X::AbstractMatrix)
237-
size(X, 1) == length(d) ||
238-
throw(DimensionMismatch("Inconsistent array dimensions."))
239-
map(i -> _logpdf(d, view(X, :, i)), axes(X, 2))
240-
end
241-
242-
function pdf(d::MultivariateDistribution, X::AbstractMatrix)
243-
size(X, 1) == length(d) ||
244-
throw(DimensionMismatch("Inconsistent array dimensions."))
245-
map(i -> _pdf(d, view(X, :, i)), axes(X, 2))
246-
end
247-
248-
"""
249-
_logpdf{T<:Real}(d::MultivariateDistribution, x::AbstractArray)
250-
251-
Evaluate logarithm of pdf value for a given vector `x`. This function need not perform dimension checking.
252-
Generally, one does not need to implement `pdf` (or `_pdf`) as fallback methods are provided in `src/multivariates.jl`.
253-
"""
254-
_logpdf(d::MultivariateDistribution, x::AbstractArray)
255-
256-
"""
257-
loglikelihood(d::MultivariateDistribution, x::AbstractArray)
258-
259-
The log-likelihood of distribution `d` with respect to all samples contained in array `x`.
260-
261-
Here, `x` can be a vector of length `dim(d)`, a matrix with `dim(d)` rows, or an array of
262-
vectors of length `dim(d)`.
263-
"""
264-
loglikelihood(d::MultivariateDistribution, X::AbstractVector{<:Real}) = logpdf(d, X)
265-
function loglikelihood(d::MultivariateDistribution, X::AbstractMatrix{<:Real})
266-
size(X, 1) == length(d) || throw(DimensionMismatch("Inconsistent array dimensions."))
267-
return sum(i -> _logpdf(d, view(X, :, i)), 1:size(X, 2))
268-
end
269-
function loglikelihood(d::MultivariateDistribution, X::AbstractArray{<:AbstractVector})
270-
return sum(x -> logpdf(d, x), X)
271-
end
272-
273166
##### Specific distributions #####
274167

275168
for fname in ["dirichlet.jl",

0 commit comments

Comments
 (0)