-
Notifications
You must be signed in to change notification settings - Fork 162
(Ready for review): Switch combinator #334
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Merged
marcoct
merged 31 commits into
probcomp:master
from
femtomc:20201116_mrb_switch_combinator
Dec 8, 2020
Merged
Changes from 30 commits
Commits
Show all changes
31 commits
Select commit
Hold shift + click to select a range
3e4f695
Initial work on a Switch combinator.
femtomc bd4f830
Initial implementation of propose and generate.
femtomc 374a7b0
Added implementaton of simulate.
femtomc 5872593
Corrected some bugs with Bernoulli vs bernoulli.
femtomc 9c0a9f2
Added assess implementation.
femtomc 95baf07
Split into two combinators: Switch and WithProbability implementations.
femtomc 29b7797
Working on Switch update and regenerate.
femtomc 3e6e307
Added Switch update and regenerate.
femtomc 7929b86
Added Switch update and regenerate - working out kinks in update.
femtomc 73618a1
update and regenerate appear to be computing the correct ratios. To c…
femtomc 252413f
Fixed generate index type bug.
femtomc ac3528e
Branch dispatch done using diff types.
femtomc eaf3327
Branch dispatch done using diff types.
femtomc 6d58aac
Branch dispatch done using diff types.
femtomc e413e9c
Added custom methods in update for Switch which allow the merging of …
femtomc 435493f
Added custom methods in update for Switch which allow the merging of …
femtomc 32fec4f
Idiomatic check for EmptyChoiceMap.
femtomc bb767e7
Working on backprop - seems simple? Could it really be?
femtomc a35e2e7
Extracting WithProb combinator into another PR.
femtomc 562667e
Testing backprop.
femtomc b74a071
Fixed backprop - was thinking in Zygote lang. Gradients appear to be …
femtomc 915811d
Merge branch 'master' of https://github.com/probcomp/Gen.jl into 2020…
femtomc 849d61e
Added docstring and docs example.
femtomc adf73a5
Fixed numerous bugs uncovered while constructing test suite. One seri…
femtomc dfe0125
Fixed numerous bugs uncovered while constructing test suite. One seri…
femtomc 3717d65
Tests for everything but gradients - working on gradients now.
femtomc cb62fb5
Last tests I need to write: accumulate_param_gradients!
femtomc 97473d0
Added accumulate_param_gradients! tests.
femtomc 176b9e9
Reverted particle filter fix - will be handled in another issue.
femtomc 0465965
Renamed mix field of Switch generative function to branches to more a…
femtomc 43c7274
Addressed review comments. Added docstrings where necessary. Correcte…
femtomc File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,20 @@ | ||
| # ------------ Switch trace ------------ # | ||
|
|
||
| struct SwitchTrace{T} <: Trace | ||
| gen_fn::GenerativeFunction{T} | ||
| index::Int | ||
| branch::Trace | ||
| retval::T | ||
| args::Tuple | ||
| score::Float64 | ||
| noise::Float64 | ||
| end | ||
|
|
||
| @inline get_choices(tr::SwitchTrace) = get_choices(tr.branch) | ||
| @inline get_retval(tr::SwitchTrace) = tr.retval | ||
| @inline get_args(tr::SwitchTrace) = tr.args | ||
| @inline get_score(tr::SwitchTrace) = tr.score | ||
| @inline get_gen_fn(tr::SwitchTrace) = tr.gen_fn | ||
| @inline Base.getindex(tr::SwitchTrace, addr) = Base.getindex(tr.branch, addr) | ||
| @inline project(tr::SwitchTrace, selection::Selection) = project(tr.branch, selection) | ||
| @inline project(tr::SwitchTrace, ::EmptySelection) = tr.noise |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,26 @@ | ||
| mutable struct SwitchAssessState{T} | ||
| weight::Float64 | ||
| retval::T | ||
| SwitchAssessState{T}(weight::Float64) where T = new{T}(weight) | ||
| end | ||
|
|
||
| function process!(gen_fn::Switch{C, N, K, T}, | ||
| index::Int, | ||
| args::Tuple, | ||
| choices::ChoiceMap, | ||
| state::SwitchAssessState{T}) where {C, N, K, T} | ||
| (weight, retval) = assess(getindex(gen_fn.branches, index), args, choices) | ||
| state.weight = weight | ||
| state.retval = retval | ||
| end | ||
|
|
||
| @inline process!(gen_fn::Switch{C, N, K, T}, index::C, args::Tuple, choices::ChoiceMap, state::SwitchAssessState{T}) where {C, N, K, T} = process!(gen_fn, getindex(gen_fn.cases, index), args, choices, state) | ||
|
|
||
| function assess(gen_fn::Switch{C, N, K, T}, | ||
| args::Tuple, | ||
| choices::ChoiceMap) where {C, N, K, T} | ||
| index = args[1] | ||
| state = SwitchAssessState{T}(0.0) | ||
| process!(gen_fn, index, args[2 : end], choices, state) | ||
| return state.weight, state.retval | ||
| end |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,2 @@ | ||
| @inline choice_gradients(trace::SwitchTrace{T}, selection::Selection, retval_grad) where T = choice_gradients(getfield(trace, :branch), selection, retval_grad) | ||
| @inline accumulate_param_gradients!(trace::SwitchTrace{T}, retval_grad, scale_factor = 1.) where {T} = accumulate_param_gradients!(getfield(trace, :branch), retval_grad, scale_factor) |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,34 @@ | ||
| mutable struct SwitchGenerateState{T} | ||
| score::Float64 | ||
| noise::Float64 | ||
| weight::Float64 | ||
| index::Int | ||
| subtrace::Trace | ||
| retval::T | ||
| SwitchGenerateState{T}(score::Float64, noise::Float64, weight::Float64) where T = new{T}(score, noise, weight) | ||
| end | ||
|
|
||
| function process!(gen_fn::Switch{C, N, K, T}, | ||
| index::Int, | ||
| args::Tuple, | ||
| choices::ChoiceMap, | ||
| state::SwitchGenerateState{T}) where {C, N, K, T} | ||
|
|
||
| (subtrace, weight) = generate(getindex(gen_fn.branches, index), args, choices) | ||
| state.index = index | ||
| state.subtrace = subtrace | ||
| state.weight += weight | ||
| state.retval = get_retval(subtrace) | ||
| end | ||
|
|
||
| @inline process!(gen_fn::Switch{C, N, K, T}, index::C, args::Tuple, choices::ChoiceMap, state::SwitchGenerateState{T}) where {C, N, K, T} = process!(gen_fn, getindex(gen_fn.cases, index), args, choices, state) | ||
|
|
||
| function generate(gen_fn::Switch{C, N, K, T}, | ||
| args::Tuple, | ||
| choices::ChoiceMap) where {C, N, K, T} | ||
|
|
||
| index = args[1] | ||
| state = SwitchGenerateState{T}(0.0, 0.0, 0.0) | ||
| process!(gen_fn, index, args[2 : end], choices, state) | ||
| return SwitchTrace{T}(gen_fn, state.index, state.subtrace, state.retval, args[2 : end], state.score, state.noise), state.weight | ||
| end |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,29 @@ | ||
| mutable struct SwitchProposeState{T} | ||
| choices::DynamicChoiceMap | ||
| weight::Float64 | ||
| retval::T | ||
| SwitchProposeState{T}(choices, weight) where T = new{T}(choices, weight) | ||
| end | ||
|
|
||
| function process!(gen_fn::Switch{C, N, K, T}, | ||
| index::Int, | ||
| args::Tuple, | ||
| state::SwitchProposeState{T}) where {C, N, K, T} | ||
|
|
||
| (submap, weight, retval) = propose(getindex(gen_fn.branches, index), args) | ||
| state.choices = submap | ||
| state.weight += weight | ||
| state.retval = retval | ||
| end | ||
|
|
||
| @inline process!(gen_fn::Switch{C, N, K, T}, index::C, args::Tuple, state::SwitchProposeState{T}) where {C, N, K, T} = process!(gen_fn, getindex(gen_fn.cases, index), args, state) | ||
|
|
||
| function propose(gen_fn::Switch{C, N, K, T}, | ||
| args::Tuple) where {C, N, K, T} | ||
|
|
||
| index = args[1] | ||
| choices = choicemap() | ||
| state = SwitchProposeState{T}(choices, 0.0) | ||
| process!(gen_fn, index, args[2:end], state) | ||
| return state.choices, state.weight, state.retval | ||
| end |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,82 @@ | ||
| mutable struct SwitchRegenerateState{T} | ||
| weight::Float64 | ||
| score::Float64 | ||
| noise::Float64 | ||
| prev_trace::Trace | ||
| trace::Trace | ||
| index::Int | ||
| retdiff::Diff | ||
| SwitchRegenerateState{T}(weight::Float64, score::Float64, noise::Float64, prev_trace::Trace) where T = new{T}(weight, score, noise, prev_trace) | ||
| end | ||
|
|
||
| @inline regenerate_recurse_merge(prev_choices::ChoiceMap, selection::EmptySelection) = prev_choices | ||
| @inline regenerate_recurse_merge(prev_choices::ChoiceMap, selection::AllSelection) = choicemap() | ||
| function regenerate_recurse_merge(prev_choices::ChoiceMap, selection::Selection) | ||
| prev_choice_value_iterator = get_values_shallow(prev_choices) | ||
| prev_choice_submap_iterator = get_submaps_shallow(prev_choices) | ||
| new_choices = choicemap() | ||
| for (key, value) in prev_choice_value_iterator | ||
| in(key, selection) && continue | ||
| set_value!(new_choices, key, value) | ||
| end | ||
| for (key, node1) in prev_choice_submap_iterator | ||
| if in(key, selection) | ||
| subsel = getindex(selection, key) | ||
| node = regenerate_recurse_merge(node1, subsel) | ||
| set_submap!(new_choices, key, node) | ||
| else | ||
| set_submap!(new_choices, key, node1) | ||
| end | ||
| end | ||
| return new_choices | ||
| end | ||
|
|
||
| function process!(gen_fn::Switch{C, N, K, T}, | ||
| index::Int, | ||
| index_argdiff::UnknownChange, | ||
| args::Tuple, | ||
| kernel_argdiffs::Tuple, | ||
| selection::Selection, | ||
| state::SwitchRegenerateState{T}) where {C, N, K, T} | ||
| branch_fn = getfield(gen_fn.branches, index) | ||
| merged = regenerate_recurse_merge(get_choices(state.prev_trace), selection) | ||
|
||
| new_trace, weight = generate(branch_fn, args, merged) | ||
| retdiff = UnknownChange() | ||
| weight -= project(state.prev_trace, complement(selection)) | ||
| weight += (project(new_trace, selection) - project(state.prev_trace, selection)) | ||
| state.index = index | ||
| state.weight = weight | ||
| state.noise = project(new_trace, EmptySelection()) - project(state.prev_trace, EmptySelection()) | ||
| state.score = get_score(new_trace) | ||
| state.trace = new_trace | ||
| state.retdiff = retdiff | ||
| end | ||
|
|
||
| function process!(gen_fn::Switch{C, N, K, T}, | ||
| index::Int, | ||
| index_argdiff::NoChange, | ||
| args::Tuple, | ||
| kernel_argdiffs::Tuple, | ||
| selection::Selection, | ||
| state::SwitchRegenerateState{T}) where {C, N, K, T} | ||
| new_trace, weight, retdiff = regenerate(getfield(state.prev_trace, :branch), args, kernel_argdiffs, selection) | ||
| state.index = index | ||
| state.weight = weight | ||
| state.noise = project(new_trace, EmptySelection()) - project(state.prev_trace, EmptySelection()) | ||
| state.score = get_score(new_trace) | ||
| state.trace = new_trace | ||
| state.retdiff = retdiff | ||
| end | ||
|
|
||
| @inline process!(gen_fn::Switch{C, N, K, T}, index::C, index_argdiff::Diff, args::Tuple, kernel_argdiffs::Tuple, selection::Selection, state::SwitchRegenerateState{T}) where {C, N, K, T} = process!(gen_fn, getindex(gen_fn.cases, index), index_argdiff, args, kernel_argdiffs, selection, state) | ||
|
|
||
| function regenerate(trace::SwitchTrace{T}, | ||
| args::Tuple, | ||
| argdiffs::Tuple, | ||
| selection::Selection) where T | ||
| gen_fn = trace.gen_fn | ||
| index, index_argdiff = args[1], argdiffs[1] | ||
| state = SwitchRegenerateState{T}(0.0, 0.0, 0.0, trace) | ||
| process!(gen_fn, index, index_argdiff, args[2 : end], argdiffs[2 : end], selection, state) | ||
| return SwitchTrace(gen_fn, state.index, state.trace, get_retval(state.trace), args, state.score, state.noise), state.weight, state.retdiff | ||
| end | ||
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,32 @@ | ||
| mutable struct SwitchSimulateState{T} | ||
| score::Float64 | ||
| noise::Float64 | ||
| index::Int | ||
| subtrace::Trace | ||
| retval::T | ||
| SwitchSimulateState{T}(score::Float64, noise::Float64) where T = new{T}(score, noise) | ||
| end | ||
|
|
||
| function process!(gen_fn::Switch{C, N, K, T}, | ||
| index::Int, | ||
| args::Tuple, | ||
| state::SwitchSimulateState{T}) where {C, N, K, T} | ||
| local retval::T | ||
| subtrace = simulate(getindex(gen_fn.branches, index), args) | ||
| state.index = index | ||
| state.noise += project(subtrace, EmptySelection()) | ||
| state.subtrace = subtrace | ||
| state.score += get_score(subtrace) | ||
| state.retval = get_retval(subtrace) | ||
| end | ||
|
|
||
| @inline process!(gen_fn::Switch{C, N, K, T}, index::C, args::Tuple, state::SwitchSimulateState{T}) where {C, N, K, T} = process!(gen_fn, getindex(gen_fn.cases, index), args, state) | ||
|
|
||
| function simulate(gen_fn::Switch{C, N, K, T}, | ||
| args::Tuple) where {C, N, K, T} | ||
|
|
||
| index = args[1] | ||
| state = SwitchSimulateState{T}(0.0, 0.0) | ||
| process!(gen_fn, index, args[2 : end], state) | ||
| return SwitchTrace{T}(gen_fn, state.index, state.subtrace, state.retval, args[2 : end], state.score, state.noise) | ||
| end |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,56 @@ | ||
| struct Switch{C, N, K, T} <: GenerativeFunction{T, Trace} | ||
| branches::NTuple{N, GenerativeFunction{T}} | ||
| cases::Dict{C, Int} | ||
| function Switch(gen_fns::GenerativeFunction...) | ||
| @assert !isempty(gen_fns) | ||
| rettype = get_return_type(getindex(gen_fns, 1)) | ||
| new{Int, length(gen_fns), typeof(gen_fns), rettype}(gen_fns, Dict{Int, Int}()) | ||
| end | ||
| function Switch(d::Dict{C, Int}, gen_fns::GenerativeFunction...) where C | ||
| @assert !isempty(gen_fns) | ||
| rettype = get_return_type(getindex(gen_fns, 1)) | ||
| new{C, length(gen_fns), typeof(gen_fns), rettype}(gen_fns, d) | ||
| end | ||
| end | ||
| export Switch | ||
|
|
||
| has_argument_grads(switch_fn::Switch) = map(zip(map(has_argument_grads, switch_fn.branches)...)) do as | ||
| all(as) | ||
| end | ||
| accepts_output_grad(switch_fn::Switch) = all(accepts_output_grad, switch_fn.branches) | ||
|
|
||
| function (gen_fn::Switch)(index::Int, args...) | ||
| (_, _, retval) = propose(gen_fn, (index, args...)) | ||
| retval | ||
| end | ||
|
|
||
| function (gen_fn::Switch{C})(index::C, args...) where C | ||
| (_, _, retval) = propose(gen_fn, (gen_fn.cases[index], args...)) | ||
| retval | ||
| end | ||
|
|
||
| include("assess.jl") | ||
| include("propose.jl") | ||
| include("simulate.jl") | ||
| include("generate.jl") | ||
| include("update.jl") | ||
| include("regenerate.jl") | ||
| include("backprop.jl") | ||
|
|
||
| @doc( | ||
| """ | ||
| gen_fn = Switch(gen_fns::GenerativeFunction...) | ||
|
|
||
| Returns a new generative function that accepts an argument tuple of type `Tuple{Int, ...}` where the first index indicates which branch to call. | ||
|
|
||
| gen_fn = Switch(d::Dict{T, Int}, gen_fns::GenerativeFunction...) where T | ||
|
|
||
| Returns a new generative function that accepts an argument tuple of type `Tuple{Int, ...}` or an argument tuple of type `Tuple{T, ...}` where the first index either indicates which branch to call, or indicates an index into `d` which maps to the selected branch. This form is meant for convenience - it allows the programmer to use `d` like if-else or case statements. | ||
|
|
||
| `Switch` is designed to allow for the expression of patterns of if-else control flow. `gen_fns` must satisfy a few requirements: | ||
|
|
||
| 1. Each `gen_fn` in `gen_fns` must accept the same argument types. | ||
| 2. Each `gen_fn` in `gen_fns` must return the same return type. | ||
|
|
||
| Otherwise, each `gen_fn` can come from different modeling languages, possess different traces, etc. | ||
| """, Switch) |
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
It might be better to use
Diffinstead ofUnknownChangehere. There might be otherDifftypes that could be passed that are intermediate betweenUnknownChangeandNoChange(e.g. there is anIntDiffalready, which tracks the arithmetic difference between two integers). This method should apply to anything that's not aNoChange.