Skip to content

Commit 487a943

Browse files
authored
Merge pull request #19 from TuringLang/resampling_rng
Add RNG to resampling methods
2 parents c87cba7 + fbea67e commit 487a943

File tree

8 files changed

+53
-43
lines changed

8 files changed

+53
-43
lines changed

Project.toml

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,12 @@
11
name = "AdvancedPS"
22
uuid = "576499cb-2369-40b2-a588-c64705576edc"
33
authors = ["TuringLang"]
4-
version = "0.1.0"
4+
version = "0.2.0"
55

66
[deps]
77
Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f"
88
Libtask = "6f1fad26-d15e-5dc8-ae53-837a1d7b8c9f"
9+
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
910
StatsFuns = "4c63d2b9-4356-54db-8cca-17b64c39e42c"
1011

1112
[compat]

src/AdvancedPS.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@ module AdvancedPS
22

33
import Distributions
44
import Libtask
5+
import Random
56
import StatsFuns
67

78
include("resampling.jl")

src/container.jl

Lines changed: 13 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -166,8 +166,8 @@ function effectiveSampleSize(pc::ParticleContainer)
166166
end
167167

168168
"""
169-
resample_propagate!(pc::ParticleContainer[, randcat = resample, ref = nothing;
170-
weights = getweights(pc)])
169+
resample_propagate!(rng, pc::ParticleContainer[, randcat = resample_systematic,
170+
ref = nothing; weights = getweights(pc)])
171171
172172
Resample and propagate the particles in `pc`.
173173
@@ -176,8 +176,9 @@ of the particle `weights`. For Particle Gibbs sampling, one can provide a refere
176176
`ref` that is ensured to survive the resampling step.
177177
"""
178178
function resample_propagate!(
179+
rng::Random.AbstractRNG,
179180
pc::ParticleContainer,
180-
randcat = resample,
181+
randcat = resample_systematic,
181182
ref::Union{Particle, Nothing} = nothing;
182183
weights = getweights(pc)
183184
)
@@ -187,7 +188,7 @@ function resample_propagate!(
187188
# sample ancestor indices
188189
n = length(pc)
189190
nresamples = ref === nothing ? n : n - 1
190-
indx = randcat(weights, nresamples)
191+
indx = randcat(rng, weights, nresamples)
191192

192193
# count number of children for each particle
193194
num_children = zeros(Int, n)
@@ -230,6 +231,7 @@ function resample_propagate!(
230231
end
231232

232233
function resample_propagate!(
234+
rng::Random.AbstractRNG,
233235
pc::ParticleContainer,
234236
resampler::ResampleWithESSThreshold,
235237
ref::Union{Particle,Nothing} = nothing;
@@ -239,7 +241,7 @@ function resample_propagate!(
239241
ess = inv(sum(abs2, weights))
240242

241243
if ess resampler.threshold * length(pc)
242-
resample_propagate!(pc, resampler.resampler, ref; weights = weights)
244+
resample_propagate!(rng, pc, resampler.resampler, ref; weights = weights)
243245
end
244246

245247
pc
@@ -292,7 +294,7 @@ function reweight!(pc::ParticleContainer)
292294
end
293295

294296
"""
295-
sweep!(pc::ParticleContainer, resampler)
297+
sweep!(rng, pc::ParticleContainer, resampler)
296298
297299
Perform a particle sweep and return an unbiased estimate of the log evidence.
298300
@@ -303,11 +305,11 @@ The resampling steps use the given `resampler`.
303305
Del Moral, P., Doucet, A., & Jasra, A. (2006). Sequential monte carlo samplers.
304306
Journal of the Royal Statistical Society: Series B (Statistical Methodology), 68(3), 411-436.
305307
"""
306-
function sweep!(pc::ParticleContainer, resampler)
308+
function sweep!(rng::Random.AbstractRNG, pc::ParticleContainer, resampler)
307309
# Initial step:
308310

309311
# Resample and propagate particles.
310-
resample_propagate!(pc, resampler)
312+
resample_propagate!(rng, pc, resampler)
311313

312314
# Compute the current normalizing constant ``Z₀`` of the unnormalized logarithmic
313315
# weights.
@@ -317,7 +319,7 @@ function sweep!(pc::ParticleContainer, resampler)
317319
logZ0 = logZ(pc)
318320

319321
# Reweight the particles by including the first observation ``y₁``.
320-
isdone = reweight!(pc)
322+
isdone = reweight!(rng, pc)
321323

322324
# Compute the normalizing constant ``Z₁`` after reweighting.
323325
logZ1 = logZ(pc)
@@ -328,14 +330,14 @@ function sweep!(pc::ParticleContainer, resampler)
328330
# For observations ``y₂, …, yₜ``:
329331
while !isdone
330332
# Resample and propagate particles.
331-
resample_propagate!(pc, resampler)
333+
resample_propagate!(rng, pc, resampler)
332334

333335
# Compute the current normalizing constant ``Z₀`` of the unnormalized logarithmic
334336
# weights.
335337
logZ0 = logZ(pc)
336338

337339
# Reweight the particles by including the next observation ``yₜ``.
338-
isdone = reweight!(pc)
340+
isdone = reweight!(rng, pc)
339341

340342
# Compute the normalizing constant ``Z₁`` after reweighting.
341343
logZ1 = logZ(pc)

src/resampling.jl

Lines changed: 27 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -12,33 +12,33 @@ struct ResampleWithESSThreshold{R, T<:Real}
1212
threshold::T
1313
end
1414

15-
function ResampleWithESSThreshold(resampler = resample)
15+
function ResampleWithESSThreshold(resampler = resample_systematic)
1616
ResampleWithESSThreshold(resampler, 0.5)
1717
end
1818

1919
# More stable, faster version of rand(Categorical)
20-
function randcat(p::AbstractVector{<:Real})
20+
function randcat(rng::Random.AbstractRNG, p::AbstractVector{<:Real})
2121
T = eltype(p)
22-
r = rand(T)
22+
r = rand(rng, T)
23+
cp = p[1]
2324
s = 1
24-
for j in eachindex(p)
25-
r -= p[j]
26-
if r <= zero(T)
27-
s = j
28-
break
29-
end
25+
n = length(p)
26+
while cp <= r && s < n
27+
@inbounds cp += p[s += 1]
3028
end
3129
return s
3230
end
3331

3432
function resample_multinomial(
33+
rng::Random.AbstractRNG,
3534
w::AbstractVector{<:Real},
3635
num_particles::Integer = length(w),
3736
)
38-
return rand(Distributions.sampler(Distributions.Categorical(w)), num_particles)
37+
return rand(rng, Distributions.sampler(Distributions.Categorical(w)), num_particles)
3938
end
4039

4140
function resample_residual(
41+
rng::Random.AbstractRNG,
4242
w::AbstractVector{<:Real},
4343
num_particles::Integer = length(weights),
4444
)
@@ -57,19 +57,19 @@ function resample_residual(
5757
end
5858
residuals[j] = x - floor_x
5959
end
60-
60+
6161
# sampling from residuals
6262
if i <= num_particles
6363
residuals ./= sum(residuals)
64-
rand!(Distributions.Categorical(residuals), view(indices, i:num_particles))
64+
rand!(rng, Distributions.Categorical(residuals), view(indices, i:num_particles))
6565
end
66-
66+
6767
return indices
6868
end
6969

7070

7171
"""
72-
resample_stratified(weights, n)
72+
resample_stratified(rng, weights, n)
7373
7474
Return a vector of `n` samples `x₁`, ..., `xₙ` from the numbers 1, ..., `length(weights)`,
7575
generated by stratified resampling.
@@ -80,7 +80,11 @@ are selected according to the multinomial distribution defined by the normalized
8080
i.e., `xᵢ = j` if and only if
8181
``uᵢ \\in [\\sum_{s=1}^{j-1} weights_{s}, \\sum_{s=1}^{j} weights_{s})``.
8282
"""
83-
function resample_stratified(weights::AbstractVector{<:Real}, n::Integer = length(weights))
83+
function resample_stratified(
84+
rng::Random.AbstractRNG,
85+
weights::AbstractVector{<:Real},
86+
n::Integer = length(weights),
87+
)
8488
# check input
8589
m = length(weights)
8690
m > 0 || error("weight vector is empty")
@@ -93,7 +97,7 @@ function resample_stratified(weights::AbstractVector{<:Real}, n::Integer = lengt
9397
sample = 1
9498
@inbounds for i in 1:n
9599
# sample next `u` (scaled by `n`)
96-
u = oftype(v, i - 1 + rand())
100+
u = oftype(v, i - 1 + rand(rng))
97101

98102
# as long as we have not found the next sample
99103
while v < u
@@ -114,7 +118,7 @@ function resample_stratified(weights::AbstractVector{<:Real}, n::Integer = lengt
114118
end
115119

116120
"""
117-
resample_systematic(weights, n)
121+
resample_systematic(rng, weights, n)
118122
119123
Return a vector of `n` samples `x₁`, ..., `xₙ` from the numbers 1, ..., `length(weights)`,
120124
generated by systematic resampling.
@@ -125,14 +129,18 @@ numbers `u₁`, ..., `uₙ` where ``uₖ = (u + k − 1) / n``. Based on these n
125129
normalized `weights`, i.e., `xᵢ = j` if and only if
126130
``uᵢ \\in [\\sum_{s=1}^{j-1} weights_{s}, \\sum_{s=1}^{j} weights_{s})``.
127131
"""
128-
function resample_systematic(weights::AbstractVector{<:Real}, n::Integer = length(weights))
132+
function resample_systematic(
133+
rng::Random.AbstractRNG,
134+
weights::AbstractVector{<:Real},
135+
n::Integer = length(weights),
136+
)
129137
# check input
130138
m = length(weights)
131139
m > 0 || error("weight vector is empty")
132140

133141
# pre-calculations
134142
@inbounds v = n * weights[1]
135-
u = oftype(v, rand())
143+
u = oftype(v, rand(rng))
136144

137145
# find all samples
138146
samples = Array{Int}(undef, n)
@@ -158,6 +166,3 @@ function resample_systematic(weights::AbstractVector{<:Real}, n::Integer = lengt
158166

159167
return samples
160168
end
161-
162-
# Default resampling scheme
163-
const resample = resample_systematic

test/Project.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
[deps]
22
Libtask = "6f1fad26-d15e-5dc8-ae53-837a1d7b8c9f"
3+
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
34
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
45

56
[compat]

test/container.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -47,7 +47,7 @@
4747
@test AdvancedPS.logZ(pc) log(sum(exp, 2 .* logps))
4848

4949
# Resample and propagate particles.
50-
AdvancedPS.resample_propagate!(pc)
50+
AdvancedPS.resample_propagate!(Random.GLOBAL_RNG, pc)
5151
@test pc.logWs == zeros(3)
5252
@test AdvancedPS.getweights(pc) == fill(1/3, 3)
5353
@test all(AdvancedPS.getweight(pc, i) == 1/3 for i in 1:3)

test/resampling.jl

Lines changed: 7 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,17 +1,16 @@
11
@testset "resampling.jl" begin
22
D = [0.3, 0.4, 0.3]
33
num_samples = Int(1e6)
4+
rng = Random.GLOBAL_RNG
45

5-
resSystematic = AdvancedPS.resample_systematic(D, num_samples )
6-
resStratified = AdvancedPS.resample_stratified(D, num_samples )
7-
resMultinomial= AdvancedPS.resample_multinomial(D, num_samples )
8-
resResidual = AdvancedPS.resample_residual(D, num_samples )
9-
AdvancedPS.resample(D)
10-
resSystematic2= AdvancedPS.resample(D, num_samples )
6+
resSystematic = AdvancedPS.resample_systematic(rng, D, num_samples)
7+
resStratified = AdvancedPS.resample_stratified(rng, D, num_samples)
8+
resMultinomial= AdvancedPS.resample_multinomial(rng, D, num_samples)
9+
resResidual = AdvancedPS.resample_residual(rng, D, num_samples)
10+
AdvancedPS.resample_systematic(rng, D)
1111

1212
@test sum(resSystematic .== 2) (num_samples * 0.4) atol=1e-3*num_samples
13-
@test sum(resSystematic2 .== 2) (num_samples * 0.4) atol=1e-3*num_samples
1413
@test sum(resStratified .== 2) (num_samples * 0.4) atol=1e-3*num_samples
1514
@test sum(resMultinomial .== 2) (num_samples * 0.4) atol=1e-2*num_samples
1615
@test sum(resResidual .== 2) (num_samples * 0.4) atol=1e-2*num_samples
17-
end
16+
end

test/runtests.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
using AdvancedPS
22
using Libtask
3+
using Random
34
using Test
45

56
@testset "AdvancedPS.jl" begin

0 commit comments

Comments
 (0)