Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
175 changes: 96 additions & 79 deletions base/random/RNGs.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,9 @@

## RandomDevice

const BoolBitIntegerType = Union{Type{Bool},Base.BitIntegerType}
const BoolBitIntegerArray = Union{Array{Bool},Base.BitIntegerArray}
# SamplerUnion(Union{X,Y,...}) == Union{SamplerType{X},SamplerType{Y},...}
SamplerUnion(U::Union) = Union{map(T->SamplerType{T}, Base.uniontypes(U))...}
const SamplerBoolBitInteger = SamplerUnion(Union{Bool, Base.BitInteger})

if Sys.iswindows()
struct RandomDevice <: AbstractRNG
Expand All @@ -12,15 +13,9 @@ if Sys.iswindows()
RandomDevice() = new(Vector{UInt128}(uninitialized, 1))
end

function rand(rd::RandomDevice, T::BoolBitIntegerType)
function rand(rd::RandomDevice, sp::SamplerBoolBitInteger)
rand!(rd, rd.buffer)
@inbounds return rd.buffer[1] % T
end

function rand!(rd::RandomDevice, A::BoolBitIntegerArray)
ccall((:SystemFunction036, :Advapi32), stdcall, UInt8, (Ptr{Void}, UInt32),
A, sizeof(A))
A
@inbounds return rd.buffer[1] % sp[]
end
else # !windows
struct RandomDevice <: AbstractRNG
Expand All @@ -31,10 +26,22 @@ else # !windows
new(open(unlimited ? "/dev/urandom" : "/dev/random"), unlimited)
end

rand(rd::RandomDevice, T::BoolBitIntegerType) = read( rd.file, T)
rand!(rd::RandomDevice, A::BoolBitIntegerArray) = read!(rd.file, A)
rand(rd::RandomDevice, sp::SamplerBoolBitInteger) = read( rd.file, sp[])
end # os-test

# NOTE: this can't be put within the if-else block above
for T in (Bool, Base.BitInteger_types...)
if Sys.iswindows()
@eval function rand!(rd::RandomDevice, A::Array{$T}, ::SamplerType{$T})
ccall((:SystemFunction036, :Advapi32), stdcall, UInt8, (Ptr{Void}, UInt32),
A, sizeof(A))
A
end
else
@eval rand!(rd::RandomDevice, A::Array{$T}, ::SamplerType{$T}) = read!(rd.file, A)
end
end

"""
RandomDevice()

Expand All @@ -49,7 +56,7 @@ srand(rng::RandomDevice) = rng

### generation of floats

rand(r::RandomDevice, I::FloatInterval) = rand_generic(r, I)
rand(r::RandomDevice, sp::SamplerTrivial{<:FloatInterval}) = rand_generic(r, sp[])


## MersenneTwister
Expand Down Expand Up @@ -229,30 +236,31 @@ rand_ui23_raw(r::MersenneTwister) = rand_ui52_raw(r)

#### floats

rand(r::MersenneTwister, I::FloatInterval_64) = (reserve_1(r); rand_inbounds(r, I))
rand(r::MersenneTwister, sp::SamplerTrivial{<:FloatInterval_64}) =
(reserve_1(r); rand_inbounds(r, sp[]))

rand(r::MersenneTwister, I::FloatInterval) = rand_generic(r, I)
rand(r::MersenneTwister, sp::SamplerTrivial{<:FloatInterval}) = rand_generic(r, sp[])

#### integers

rand(r::MersenneTwister, T::Union{Type{Bool}, Type{Int8}, Type{UInt8}, Type{Int16}, Type{UInt16},
Type{Int32}, Type{UInt32}}) =
rand_ui52_raw(r) % T
rand(r::MersenneTwister,
T::SamplerUnion(Union{Bool,Int8,UInt8,Int16,UInt16,Int32,UInt32})) =
rand_ui52_raw(r) % T[]

function rand(r::MersenneTwister, ::Type{UInt64})
function rand(r::MersenneTwister, ::SamplerType{UInt64})
reserve(r, 2)
rand_ui52_raw_inbounds(r) << 32 ⊻ rand_ui52_raw_inbounds(r)
end

function rand(r::MersenneTwister, ::Type{UInt128})
function rand(r::MersenneTwister, ::SamplerType{UInt128})
reserve(r, 3)
xor(rand_ui52_raw_inbounds(r) % UInt128 << 96,
rand_ui52_raw_inbounds(r) % UInt128 << 48,
rand_ui52_raw_inbounds(r))
end

rand(r::MersenneTwister, ::Type{Int64}) = reinterpret(Int64, rand(r, UInt64))
rand(r::MersenneTwister, ::Type{Int128}) = reinterpret(Int128, rand(r, UInt128))
rand(r::MersenneTwister, ::SamplerType{Int64}) = reinterpret(Int64, rand(r, UInt64))
rand(r::MersenneTwister, ::SamplerType{Int128}) = reinterpret(Int128, rand(r, UInt128))

#### arrays of floats

Expand All @@ -278,16 +286,17 @@ function rand_AbstractArray_Float64!(r::MersenneTwister, A::AbstractArray{Float6
A
end

rand!(r::MersenneTwister, A::AbstractArray{Float64}) = rand_AbstractArray_Float64!(r, A)
rand!(r::MersenneTwister, A::AbstractArray{Float64}, I::SamplerTrivial{<:FloatInterval_64}) =
rand_AbstractArray_Float64!(r, A, length(A), I[])

fill_array!(s::DSFMT_state, A::Ptr{Float64}, n::Int, ::CloseOpen_64) =
dsfmt_fill_array_close_open!(s, A, n)

fill_array!(s::DSFMT_state, A::Ptr{Float64}, n::Int, ::Close1Open2_64) =
dsfmt_fill_array_close1_open2!(s, A, n)

function rand!(r::MersenneTwister, A::Array{Float64}, n::Int=length(A),
I::FloatInterval_64=CloseOpen())
function _rand!(r::MersenneTwister, A::Array{Float64}, n::Int,
I::FloatInterval_64)
# depending on the alignment of A, the data written by fill_array! may have
# to be left-shifted by up to 15 bytes (cf. unsafe_copy! below) for
# reproducibility purposes;
Expand Down Expand Up @@ -317,65 +326,63 @@ function rand!(r::MersenneTwister, A::Array{Float64}, n::Int=length(A),
A
end

rand!(r::MersenneTwister, A::Array{Float64}, sp::SamplerTrivial{<:FloatInterval_64}) =
_rand!(r, A, length(A), sp[])

mask128(u::UInt128, ::Type{Float16}) =
(u & 0x03ff03ff03ff03ff03ff03ff03ff03ff) | 0x3c003c003c003c003c003c003c003c00

mask128(u::UInt128, ::Type{Float32}) =
(u & 0x007fffff007fffff007fffff007fffff) | 0x3f8000003f8000003f8000003f800000

function rand!(r::MersenneTwister, A::Union{Array{Float16},Array{Float32}},
::Close1Open2_64)
T = eltype(A)
n = length(A)
n128 = n * sizeof(T) ÷ 16
Base.@gc_preserve A rand!(r, unsafe_wrap(Array, convert(Ptr{Float64}, pointer(A)), 2*n128),
2*n128, Close1Open2())
# FIXME: This code is completely invalid!!!
A128 = unsafe_wrap(Array, convert(Ptr{UInt128}, pointer(A)), n128)
@inbounds for i in 1:n128
u = A128[i]
u ⊻= u << 26
# at this point, the 64 low bits of u, "k" being the k-th bit of A128[i] and "+"
# the bit xor, are:
# [..., 58+32,..., 53+27, 52+26, ..., 33+7, 32+6, ..., 27+1, 26, ..., 1]
# the bits needing to be random are
# [1:10, 17:26, 33:42, 49:58] (for Float16)
# [1:23, 33:55] (for Float32)
# this is obviously satisfied on the 32 low bits side, and on the high side,
# the entropy comes from bits 33:52 of A128[i] and then from bits 27:32
# (which are discarded on the low side)
# this is similar for the 64 high bits of u
A128[i] = mask128(u, T)
end
for i in 16*n128÷sizeof(T)+1:n
@inbounds A[i] = rand(r, T) + oneunit(T)
for T in (Float16, Float32)
@eval function rand!(r::MersenneTwister, A::Array{$T}, ::SamplerTrivial{Close1Open2{$T}})
n = length(A)
n128 = n * sizeof($T) ÷ 16
Base.@gc_preserve A _rand!(r, unsafe_wrap(Array, convert(Ptr{Float64}, pointer(A)), 2*n128),
2*n128, Close1Open2())
# FIXME: This code is completely invalid!!!
A128 = unsafe_wrap(Array, convert(Ptr{UInt128}, pointer(A)), n128)
@inbounds for i in 1:n128
u = A128[i]
u ⊻= u << 26
# at this point, the 64 low bits of u, "k" being the k-th bit of A128[i] and "+"
# the bit xor, are:
# [..., 58+32,..., 53+27, 52+26, ..., 33+7, 32+6, ..., 27+1, 26, ..., 1]
# the bits needing to be random are
# [1:10, 17:26, 33:42, 49:58] (for Float16)
# [1:23, 33:55] (for Float32)
# this is obviously satisfied on the 32 low bits side, and on the high side,
# the entropy comes from bits 33:52 of A128[i] and then from bits 27:32
# (which are discarded on the low side)
# this is similar for the 64 high bits of u
A128[i] = mask128(u, $T)
end
for i in 16*n128÷sizeof($T)+1:n
@inbounds A[i] = rand(r, $T) + oneunit($T)
end
A
end
A
end

function rand!(r::MersenneTwister, A::Union{Array{Float16},Array{Float32}}, ::CloseOpen_64)
rand!(r, A, Close1Open2())
I32 = one(Float32)
for i in eachindex(A)
@inbounds A[i] = Float32(A[i])-I32 # faster than "A[i] -= one(T)" for T==Float16
@eval function rand!(r::MersenneTwister, A::Array{$T}, ::SamplerTrivial{CloseOpen{$T}})
rand!(r, A, Close1Open2($T))
I32 = one(Float32)
for i in eachindex(A)
@inbounds A[i] = Float32(A[i])-I32 # faster than "A[i] -= one(T)" for T==Float16
end
A
end
A
end

rand!(r::MersenneTwister, A::Union{Array{Float16},Array{Float32}}) =
rand!(r, A, CloseOpen())

#### arrays of integers

function rand!(r::MersenneTwister, A::Array{UInt128}, n::Int=length(A))
if n > length(A)
throw(BoundsError(A,n))
end
function rand!(r::MersenneTwister, A::Array{UInt128}, ::SamplerType{UInt128})
n::Int=length(A)
# FIXME: This code is completely invalid!!!
Af = unsafe_wrap(Array, convert(Ptr{Float64}, pointer(A)), 2n)
i = n
while true
rand!(r, Af, 2i, Close1Open2())
_rand!(r, Af, 2i, Close1Open2())
n < 5 && break
i = 0
@inbounds while n-i >= 5
Expand All @@ -396,17 +403,18 @@ function rand!(r::MersenneTwister, A::Array{UInt128}, n::Int=length(A))
A
end

# A::Array{UInt128} will match the specialized method above
function rand!(r::MersenneTwister, A::Base.BitIntegerArray)
n = length(A)
T = eltype(A)
n128 = n * sizeof(T) ÷ 16
# FIXME: This code is completely invalid!!!
rand!(r, unsafe_wrap(Array, convert(Ptr{UInt128}, pointer(A)), n128))
for i = 16*n128÷sizeof(T)+1:n
@inbounds A[i] = rand(r, T)
for T in Base.BitInteger_types
T === UInt128 && continue
@eval function rand!(r::MersenneTwister, A::Array{$T}, ::SamplerType{$T})
n = length(A)
n128 = n * sizeof($T) ÷ 16
# FIXME: This code is completely invalid!!!
rand!(r, unsafe_wrap(Array, convert(Ptr{UInt128}, pointer(A)), n128))
for i = 16*n128÷sizeof($T)+1:n
@inbounds A[i] = rand(r, $T)
end
A
end
A
end

#### from a range
Expand All @@ -418,7 +426,9 @@ function rand_lteq(r::AbstractRNG, randfun, u::U, mask::U) where U<:Integer
end
end

function rand(rng::MersenneTwister, r::UnitRange{T}) where T<:Union{Base.BitInteger64,Bool}
function rand(rng::MersenneTwister,
sp::SamplerTrivial{UnitRange{T}}) where T<:Union{Base.BitInteger64,Bool}
r = sp[]
isempty(r) && throw(ArgumentError("range must be non-empty"))
m = last(r) % UInt64 - first(r) % UInt64
bw = (64 - leading_zeros(m)) % UInt # bit-width
Expand All @@ -428,7 +438,9 @@ function rand(rng::MersenneTwister, r::UnitRange{T}) where T<:Union{Base.BitInte
(x + first(r) % UInt64) % T
end

function rand(rng::MersenneTwister, r::UnitRange{T}) where T<:Union{Int128,UInt128}
function rand(rng::MersenneTwister,
sp::SamplerTrivial{UnitRange{T}}) where T<:Union{Int128,UInt128}
r = sp[]
isempty(r) && throw(ArgumentError("range must be non-empty"))
m = (last(r)-first(r)) % UInt128
bw = (128 - leading_zeros(m)) % UInt # bit-width
Expand All @@ -439,6 +451,11 @@ function rand(rng::MersenneTwister, r::UnitRange{T}) where T<:Union{Int128,UInt1
x % T + first(r)
end

for T in (Bool, Base.BitInteger_types...) # eval because of ambiguity otherwise
@eval Sampler(rng::MersenneTwister, r::UnitRange{$T}, ::Val{1}) =
SamplerTrivial(r)
end


### randjump

Expand Down
Loading