- 
                Notifications
    
You must be signed in to change notification settings  - Fork 22
 
Adaptive proposals #39
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from 5 commits
e9fd602
              5f0ddfa
              f42b784
              a188780
              d8989aa
              565f12a
              cc16195
              802ec67
              c7623c4
              a33937e
              4007bd0
              71e010b
              2d59ede
              279aea7
              93f17c5
              16715e1
              046c21b
              b91fcc0
              387eff4
              8fb1048
              4071675
              a63262d
              afd3ed1
              fe8562c
              f66f647
              2988ff2
              e5ad041
              7aa8631
              4999349
              d4b3f6b
              cb52c7f
              5a2a175
              b2be967
              518aab1
              File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change | 
|---|---|---|
| @@ -0,0 +1,93 @@ | ||
| mutable struct Adaptor | ||
| accepted::Integer | ||
| total ::Integer | ||
| tune ::Integer # tuning interval | ||
| target ::Float64 # target acceptance rate | ||
| bound ::Float64 # bound on logσ of Gaussian kernel | ||
| δmax ::Float64 # maximum adaptation step | ||
                
      
                  arzwa marked this conversation as resolved.
               
              
                Outdated
          
            Show resolved
            Hide resolved
         | 
||
| end | ||
| 
     | 
||
| Adaptor(; tune=25, target=0.44, bound=10., δmax=0.2) = | ||
| Adaptor(0, 0, tune, target, bound, δmax) | ||
                
      
                  arzwa marked this conversation as resolved.
               
              
                Outdated
          
            Show resolved
            Hide resolved
         | 
||
| 
     | 
||
| """ | ||
| AdaptiveProposal{P} | ||
| An adaptive Metropolis-Hastings proposal. In order for this to work, the | ||
| proposal kernel should implement the `adapted(proposal, δ)` method, where `δ` | ||
| is the increment/decrement applied to the scale of the proposal distribution | ||
| during adaptation (e.g. for a Normal distribution the scale is `log(σ)`, so | ||
| that after adaptation the proposal is `Normal(0, exp(log(σ) + δ))`). | ||
                
       | 
||
| # Example | ||
| ```julia | ||
| julia> p = AdaptiveProposal(Uniform(-0.2, 0.2)); | ||
| julia> rand(p) | ||
| 0.07975590594518434 | ||
| ``` | ||
| # References | ||
| Roberts, Gareth O., and Jeffrey S. Rosenthal. "Examples of adaptive MCMC." | ||
| Journal of Computational and Graphical Statistics 18.2 (2009): 349-367. | ||
| """ | ||
| mutable struct AdaptiveProposal{P} <: Proposal{P} | ||
| proposal::P | ||
| adaptor ::Adaptor | ||
                
      
                  arzwa marked this conversation as resolved.
               
              
                Outdated
          
            Show resolved
            Hide resolved
                
       | 
||
| end | ||
| 
     | 
||
| function AdaptiveProposal(p; kwargs...) | ||
| AdaptiveProposal(p, Adaptor(; kwargs...)) | ||
                
      
                  arzwa marked this conversation as resolved.
               
              
                Outdated
          
            Show resolved
            Hide resolved
         | 
||
| end | ||
| 
     | 
||
| accepted!(p::AdaptiveProposal) = p.adaptor.accepted += 1 | ||
| accepted!(p::Vector{<:AdaptiveProposal}) = map(accepted!, p) | ||
| accepted!(p::NamedTuple{names}) where names = map(x->accepted!(getfield(p, x)), names) | ||
| 
     | 
||
| # this is defined because the first draw has no transition yet (I think) | ||
| propose(rng::Random.AbstractRNG, p::AdaptiveProposal, m::DensityModel) = | ||
| rand(rng, p.proposal) | ||
                
      
                  arzwa marked this conversation as resolved.
               
              
                Outdated
          
            Show resolved
            Hide resolved
         | 
||
| 
     | 
||
| # the actual proposal happens here | ||
| function propose( | ||
| rng::Random.AbstractRNG, | ||
| proposal::AdaptiveProposal{<:Union{Distribution,Proposal}}, | ||
| model::DensityModel, | ||
| t | ||
| ) | ||
| consider_adaptation!(proposal) | ||
| t + rand(rng, proposal.proposal) | ||
                
      
                  arzwa marked this conversation as resolved.
               
              
                Outdated
          
            Show resolved
            Hide resolved
         | 
||
| end | ||
| 
     | 
||
| function q(proposal::AdaptiveProposal, t, t_cond) | ||
| logpdf(proposal, t - t_cond) | ||
                
      
                  arzwa marked this conversation as resolved.
               
              
                Outdated
          
            Show resolved
            Hide resolved
         | 
||
| end | ||
| 
     | 
||
| function consider_adaptation!(p) | ||
| (p.adaptor.total % p.adaptor.tune == 0) && adapt!(p) | ||
| p.adaptor.total += 1 | ||
| end | ||
| 
     | 
||
| function adapt!(p::AdaptiveProposal) | ||
| a = p.adaptor | ||
| a.total == 0 && return | ||
| δ = min(a.δmax, 1. /√(a.total/a.tune)) # diminishing adaptation | ||
                
      
                  arzwa marked this conversation as resolved.
               
              
                Outdated
          
            Show resolved
            Hide resolved
         | 
||
| α = a.accepted / a.tune # acceptance ratio | ||
| p_ = adapted(p.proposal, α > a.target ? δ : -δ, a.bound) | ||
| a.accepted = 0 | ||
| p.proposal = p_ | ||
| end | ||
| 
     | 
||
| function adapted(d::Normal, δ, bound=Inf) | ||
| lσ = log(d.σ) + δ | ||
| lσ = abs(lσ) > bound ? sign(lσ) * bound : lσ | ||
                
      
                  arzwa marked this conversation as resolved.
               
              
                Outdated
          
            Show resolved
            Hide resolved
         | 
||
| Normal(d.μ, exp(lσ)) | ||
                
      
                  arzwa marked this conversation as resolved.
               
              
                Outdated
          
            Show resolved
            Hide resolved
         | 
||
| end | ||
| 
     | 
||
| function adapted(d::Uniform, δ, bound=Inf) | ||
| lσ = log(d.b) + δ | ||
| σ = abs(lσ) > bound ? exp(sign(lσ) * bound) : exp(lσ) | ||
                
      
                  arzwa marked this conversation as resolved.
               
              
                Outdated
          
            Show resolved
            Hide resolved
         | 
||
| Uniform(-σ, σ) | ||
                
      
                  arzwa marked this conversation as resolved.
               
              
                Outdated
          
            Show resolved
            Hide resolved
         | 
||
| end | ||
| 
     | 
||
| Original file line number | Diff line number | Diff line change | 
|---|---|---|
| 
          
            
          
           | 
    @@ -103,4 +103,4 @@ function q( | |
| t_cond | ||
| ) | ||
| return q(proposal(t_cond), t, t_cond) | ||
| end | ||
| end | ||
| Original file line number | Diff line number | Diff line change | 
|---|---|---|
| 
          
            
          
           | 
    @@ -11,7 +11,7 @@ using Test | |
| Random.seed!(1234) | ||
| 
     | 
||
| # Generate a set of data from the posterior we want to estimate. | ||
| data = rand(Normal(0, 1), 300) | ||
| data = rand(Normal(0., 1), 300) | ||
                
      
                  arzwa marked this conversation as resolved.
               
              
                Outdated
          
            Show resolved
            Hide resolved
         | 
||
| 
     | 
||
| # Define the components of a basic model. | ||
| insupport(θ) = θ[2] >= 0 | ||
| 
          
            
          
           | 
    @@ -52,6 +52,32 @@ using Test | |
| @test mean(chain2.μ) ≈ 0.0 atol=0.1 | ||
| @test mean(chain2.σ) ≈ 1.0 atol=0.1 | ||
| end | ||
| 
     | 
||
| @testset "Adaptive random walk" begin | ||
| # Set up our sampler with initial parameters. | ||
| spl1 = MetropolisHastings([AdaptiveProposal(Normal(0,.4)), AdaptiveProposal(Normal(0,1.2))]) | ||
| spl2 = MetropolisHastings((μ=AdaptiveProposal(Normal(0,1.4)), σ=AdaptiveProposal(Normal(0,0.2)))) | ||
| 
     | 
||
| # Sample from the posterior. | ||
| chain1 = sample(model, spl1, 100000; chain_type=StructArray, param_names=["μ", "σ"]) | ||
| chain2 = sample(model, spl2, 100000; chain_type=StructArray, param_names=["μ", "σ"]) | ||
| 
     | 
||
| # chn_mean ≈ dist_mean atol=atol_v | ||
| @test mean(chain1.μ) ≈ 0.0 atol=0.1 | ||
| @test mean(chain1.σ) ≈ 1.0 atol=0.1 | ||
| @test mean(chain2.μ) ≈ 0.0 atol=0.1 | ||
| @test mean(chain2.σ) ≈ 1.0 atol=0.1 | ||
| end | ||
| 
     | 
||
| @testset "Compare adaptive to simple random walk" begin | ||
| data = rand(Normal(2., 1.), 500) | ||
| 
         There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. @arzwa This might be the problem - you redefine  You could just rename the variable here but actually I think the better approach might be to "fix" the data in the model to avoid any such surprises in the future. I guess this can be achieved by defining density = let data = data
  θ -> insupport(θ) ? sum(logpdf.(dist(θ), data)) : -Inf
endThere was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. But I haven't tested it, so make sure it actually fixes the problem 😄 There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Yes, I just saw this too, thanks. I'll check and push an updated test suite. (Actually, we could just as well test against the same data defined above in the test suite, but I find testing against a mean different from 0 a bit more reassuring since the sampler actually has to 'move' to somewhere from where it starts).  | 
||
| m1 = DensityModel(x -> loglikelihood(Normal(x,1), data)) | ||
| p1 = RandomWalkProposal(Normal()) | ||
| p2 = AdaptiveProposal(Normal()) | ||
| c1 = sample(m1, MetropolisHastings(p1), 10000; chain_type=Chains) | ||
| c2 = sample(m1, MetropolisHastings(p2), 10000; chain_type=Chains) | ||
| @test ess(c2).nt.ess > ess(c1).nt.ess | ||
| end | ||
| 
     | 
||
| @testset "parallel sampling" begin | ||
| spl1 = StaticMH([Normal(0,1), Normal(0, 1)]) | ||
| 
          
            
          
           | 
    ||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
If possible one should avoid non-concrete fields:
On a more general level, I'm not completely sure if it is useful to have a separate
Adaptorstruct, it seems it could just be integrated intoAdaptiveProposal.On an even more general level, I think it would be better to make this part of the state of the sampler using the AbstractMCMC interface instead of a field of the proposal. With the current design, the proposal will be mutated in every step. However, this (IMO preferred) design requires to implement
AbstractMCMC.stepinstead of just adding theaccept!call.Uh oh!
There was an error while loading. Please reload this page.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Yes of course,
Int <-> Integerconfusion...The
Adaptorstruct may indeed be superfluous, although I found it a bit clearer separated that way. Also, I considered implementing adaptation for multivariate Normal proposals, which uses a different machinery under the hood, and my initial thought was to implement that as anAdaptiveProposalbut with differentAdaptortype. Of course, that could be implemented as another proposal struct altogether.I think I understand conceptually your preferred design at the
steplevel, although ATM my insight in howAbstractMCMCworks is insufficient to see how that should be done, and currently, to me the mutation of the proposals is the most intuitive approach to implement adaptation. Theaccept!call seemed like a very simple, but admittedly somewhat hacky, way to enable adaptive proposals.There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I think we should punt this problem to a later date. I would like to include
accept/rejectas a field in theTransitionstruct, which would make it very easy to count the number of previous acceptances by just adding adding one to thetotal_acceptancesfield in aTransition. Currently AdvancedMH doesn't track that internally, but I can just modify this code to remove the mutation later.There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
It's not only about the number of accepted/rejected steps here though, the state would have to include the updated proposal etc as well, so it won't be solved by including the stats in Transition.
However, I'm fine with postponing this refactoring. Probably best to open an issue so we don't forget it.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
In my thinking, you'd add an extra field to Transitions that just accumulates the total number of acceptances, which is easier to get when you have individual acceptances for each draw. I'll open an issue up.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Yes, I understand (and I think that's a good addition). My point was just that it is not sufficient here.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I've opened an issue (#40).