22 commits
- e332d8c  remove the type `ParamSpaceSGD` (Red-Portal, Sep 15, 2025)
- 1f35cc9  run formatter (Red-Portal, Sep 15, 2025)
- c8404b6  run formatter (Red-Portal, Sep 15, 2025)
- 0cc7538  run formatter (Red-Portal, Sep 15, 2025)
- ede91c6  fix rename file paramspacesgd.jl to interface.jl (Red-Portal, Oct 13, 2025)
- 625f429  Merge branch 'remove_paramspacesgd' of github.com:TuringLang/Advanced… (Red-Portal, Oct 13, 2025)
- e3c2761  Merge branch 'main' of github.com:TuringLang/AdvancedVI.jl into remov… (Red-Portal, Oct 13, 2025)
- 683a09d  throw invalid state for unknown paramspacesgd type (Red-Portal, Oct 13, 2025)
- 570fe11  add docstring for union type of paramspacesgd algorithms (Red-Portal, Oct 13, 2025)
- 2d5f373  fix remove custom state types for paramspacesgd algorithms (Red-Portal, Oct 13, 2025)
- e0221eb  fix remove custom state types for paramspacesgd (Red-Portal, Oct 13, 2025)
- e51ab3c  fix file path (Red-Portal, Oct 13, 2025)
- e49c680  fix bug in BijectorsExt (Red-Portal, Oct 13, 2025)
- 3c5b56f  fix include `SubSampleObjective` as part of `ParamSpaceSGD` (Red-Portal, Oct 13, 2025)
- 30f5160  fix formatting (Red-Portal, Oct 13, 2025)
- 008c4ea  fix revert adding SubsampledObjective into ParamSpaceSGD (Red-Portal, Oct 13, 2025)
- 8a18902  refactor flatten algorithms (Red-Portal, Oct 13, 2025)
- b002e1e  fix error update paths in main file (Red-Portal, Oct 13, 2025)
- 1ba361f  refactor flatten the tests to reflect new structure (Red-Portal, Oct 13, 2025)
- 86baa07  fix file include path in tests (Red-Portal, Oct 13, 2025)
- 67e9375  fix missing operator in subsampledobj tests (Red-Portal, Oct 13, 2025)
- 9b2eabb  fix formatting (Red-Portal, Oct 13, 2025)
10 changes: 9 additions & 1 deletion ext/AdvancedVIBijectorsExt.jl
@@ -25,7 +25,15 @@ function AdvancedVI.init(
obj_st = AdvancedVI.init(rng, objective, adtype, q_init, prob, params, re)
avg_st = AdvancedVI.init(averager, params)
grad_buf = DiffResults.DiffResult(zero(eltype(params)), similar(params))
return AdvancedVI.ParamSpaceSGDState(prob, q_init, 0, grad_buf, opt_st, obj_st, avg_st)
return (
prob=prob,
q=q_init,
iteration=0,
grad_buf=grad_buf,
opt_st=opt_st,
obj_st=obj_st,
avg_st=avg_st,
)
end

function AdvancedVI.apply(
18 changes: 8 additions & 10 deletions src/AdvancedVI.jl
@@ -277,13 +277,10 @@ export optimize
include("utils.jl")
include("optimize.jl")

## Parameter Space SGD
include("algorithms/paramspacesgd/abstractobjective.jl")
include("algorithms/paramspacesgd/paramspacesgd.jl")
## Parameter Space SGD Implementations

export ParamSpaceSGD
include("algorithms/abstractobjective.jl")

## Parameter Space SGD Implementations
### ELBO Maximization

abstract type AbstractEntropyEstimator end
@@ -304,10 +301,10 @@ Estimate the entropy of `q`.
"""
function estimate_entropy end

include("algorithms/paramspacesgd/subsampledobjective.jl")
include("algorithms/paramspacesgd/repgradelbo.jl")
include("algorithms/paramspacesgd/scoregradelbo.jl")
include("algorithms/paramspacesgd/entropy.jl")
include("algorithms/subsampledobjective.jl")
include("algorithms/repgradelbo.jl")
include("algorithms/scoregradelbo.jl")
include("algorithms/entropy.jl")

export RepGradELBO,
ScoreGradELBO,
@@ -318,7 +315,8 @@ export RepGradELBO,
StickingTheLandingEntropyZeroGradient,
SubsampledObjective

include("algorithms/paramspacesgd/constructors.jl")
include("algorithms/constructors.jl")
include("algorithms/interface.jl")

export KLMinRepGradDescent, KLMinRepGradProxDescent, KLMinScoreGradDescent, ADVI, BBVI

@@ -18,13 +18,43 @@ KL divergence minimization by running stochastic gradient descent with the repar
- `operator::AbstractOperator`: Operator to be applied after each gradient descent step. (default: `IdentityOperator()`)
- `subsampling::Union{<:Nothing,<:AbstractSubsampling}`: Data point subsampling strategy. If `nothing`, subsampling is not used. (default: `nothing`)

# Output
- `q_averaged`: The variational approximation formed by the averaged SGD iterates.

# Callback
The callback function `callback` has a signature of

callback(; rng, iteration, restructure, params, averaged_params, gradient)

The arguments are as follows:
- `rng`: Random number generator internally used by the algorithm.
- `iteration`: The index of the current iteration.
- `restructure`: Function that restructures the variational approximation from the variational parameters. Calling `restructure(params)` reconstructs the current variational approximation.
- `params`: Current variational parameters.
- `averaged_params`: Variational parameters averaged according to the averaging strategy.
- `gradient`: The estimated (possibly stochastic) gradient.

# Requirements
- The trainable parameters in the variational approximation are expected to be extractable through `Optimisers.destructure`. This requires the variational approximation to be marked as a functor through `Functors.@functor`.
- The variational approximation ``q_{\\lambda}`` implements `rand`.
- The target distribution and the variational approximation have the same support.
- The target `LogDensityProblems.logdensity(prob, x)` must be differentiable with respect to `x` by the selected AD backend.
- Additional requirements on `q` may apply depending on the choice of `entropy`.
"""
struct KLMinRepGradDescent{
Obj<:Union{<:RepGradELBO,<:SubsampledObjective},
AD<:ADTypes.AbstractADType,
Opt<:Optimisers.AbstractRule,
Avg<:AbstractAverager,
Op<:AbstractOperator,
} <: AbstractVariationalAlgorithm
objective::Obj
adtype::AD
optimizer::Opt
averager::Avg
operator::Op
end

function KLMinRepGradDescent(
adtype::ADTypes.AbstractADType;
entropy::Union{<:ClosedFormEntropy,<:StickingTheLandingEntropy,<:MonteCarloEntropy}=ClosedFormEntropy(),
Expand All @@ -39,7 +69,11 @@ function KLMinRepGradDescent(
else
SubsampledObjective(RepGradELBO(n_samples; entropy=entropy), subsampling)
end
return ParamSpaceSGD(objective, adtype, optimizer, averager, operator)
return KLMinRepGradDescent{
typeof(objective),typeof(adtype),typeof(optimizer),typeof(averager),typeof(operator)
}(
objective, adtype, optimizer, averager, operator
)
end

const ADVI = KLMinRepGradDescent
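
For orientation, a hedged usage sketch of this constructor and the documented callback signature (not part of the diff); the target `prob`, the initial approximation `q0`, the AD backend choice, and the commented-out `optimize` call pattern are illustrative assumptions rather than verified API:

```julia
using ADTypes, AdvancedVI

# Build the algorithm; keyword names and defaults follow the docstring above.
alg = KLMinRepGradDescent(
    AutoForwardDiff();
    entropy=ClosedFormEntropy(),
    n_samples=8,
)

# Callback matching the documented keyword signature; unused keywords are
# absorbed by `kwargs...`, and returning `nothing` records no extra info.
function progress_callback(; iteration, averaged_params, kwargs...)
    iteration % 100 == 0 && @info "VI progress" iteration
    return nothing
end

# Assumed call pattern (hypothetical `prob` and `q0`; argument order not verified here):
# q_avg, info, state = optimize(alg, 10_000, prob, q0; callback=progress_callback)
```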
@@ -63,12 +97,42 @@ Thus, only the entropy estimators with a "ZeroGradient" suffix are allowed.
- `averager::AbstractAverager`: Parameter averaging strategy. (default: `PolynomialAveraging()`)
- `subsampling::Union{<:Nothing,<:AbstractSubsampling}`: Data point subsampling strategy. If `nothing`, subsampling is not used. (default: `nothing`)

# Output
- `q_averaged`: The variational approximation formed by the averaged SGD iterates.

# Callback
The callback function `callback` has a signature of

callback(; rng, iteration, restructure, params, averaged_params, gradient)

The arguments are as follows:
- `rng`: Random number generator internally used by the algorithm.
- `iteration`: The index of the current iteration.
- `restructure`: Function that restructures the variational approximation from the variational parameters. Calling `restructure(params)` reconstructs the current variational approximation.
- `params`: Current variational parameters.
- `averaged_params`: Variational parameters averaged according to the averaging strategy.
- `gradient`: The estimated (possibly stochastic) gradient.

# Requirements
- The variational family is `MvLocationScale`.
- The target distribution and the variational approximation have the same support.
- The target `LogDensityProblems.logdensity(prob, x)` must be differentiable with respect to `x` by the selected AD backend.
- Additional requirements on `q` may apply depending on the choice of `entropy_zerograd`.
"""
struct KLMinRepGradProxDescent{
Obj<:Union{<:RepGradELBO,<:SubsampledObjective},
AD<:ADTypes.AbstractADType,
Opt<:Optimisers.AbstractRule,
Avg<:AbstractAverager,
Op<:ProximalLocationScaleEntropy,
} <: AbstractVariationalAlgorithm
objective::Obj
adtype::AD
optimizer::Opt
averager::Avg
operator::Op
end

function KLMinRepGradProxDescent(
adtype::ADTypes.AbstractADType;
entropy_zerograd::Union{
@@ -85,7 +149,11 @@ function KLMinRepGradProxDescent(
else
SubsampledObjective(RepGradELBO(n_samples; entropy=entropy_zerograd), subsampling)
end
return ParamSpaceSGD(objective, adtype, optimizer, averager, operator)
return KLMinRepGradProxDescent{
typeof(objective),typeof(adtype),typeof(optimizer),typeof(averager),typeof(operator)
}(
objective, adtype, optimizer, averager, operator
)
end
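
A hedged pairing sketch (not part of the diff): this algorithm requires an `MvLocationScale` variational family, so something like the following would satisfy its requirements; `MeanFieldGaussian` and its two-argument constructor are assumptions used purely for illustration.

```julia
using ADTypes, AdvancedVI, LinearAlgebra

d = 5
# Assumed MvLocationScale-based family: a diagonal-scale Gaussian placeholder.
q0 = MeanFieldGaussian(zeros(d), Diagonal(ones(d)))

# No `operator` keyword is documented here; the proximal update on the
# location-scale entropy is supplied by the algorithm itself.
alg = KLMinRepGradProxDescent(AutoForwardDiff(); n_samples=4)
```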

"""
@@ -106,15 +174,45 @@ KL divergence minimization by running stochastic gradient descent with the score
- `operator::Union{<:IdentityOperator, <:ClipScale}`: Operator to be applied after each gradient descent step. (default: `IdentityOperator()`)
- `subsampling::Union{<:Nothing,<:AbstractSubsampling}`: Data point subsampling strategy. If `nothing`, subsampling is not used. (default: `nothing`)

# Output
- `q_averaged`: The variational approximation formed by the averaged SGD iterates.

# Callback
The callback function `callback` has a signature of

callback(; rng, iteration, restructure, params, averaged_params, gradient)

The arguments are as follows:
- `rng`: Random number generator internally used by the algorithm.
- `iteration`: The index of the current iteration.
- `restructure`: Function that restructures the variational approximation from the variational parameters. Calling `restructure(params)` reconstructs the current variational approximation.
- `params`: Current variational parameters.
- `averaged_params`: Variational parameters averaged according to the averaging strategy.
- `gradient`: The estimated (possibly stochastic) gradient.

# Requirements
- The trainable parameters in the variational approximation are expected to be extractable through `Optimisers.destructure`. This requires the variational approximation to be marked as a functor through `Functors.@functor`.
- The variational approximation ``q_{\\lambda}`` implements `rand`.
- The variational approximation ``q_{\\lambda}`` implements `logpdf(q, x)`, which should also be differentiable with respect to `x`.
- The target distribution and the variational approximation have the same support.
"""
struct KLMinScoreGradDescent{
Obj<:Union{<:ScoreGradELBO,<:SubsampledObjective},
AD<:ADTypes.AbstractADType,
Opt<:Optimisers.AbstractRule,
Avg<:AbstractAverager,
Op<:AbstractOperator,
} <: AbstractVariationalAlgorithm
objective::Obj
adtype::AD
optimizer::Opt
averager::Avg
operator::Op
end

function KLMinScoreGradDescent(
adtype::ADTypes.AbstractADType;
optimizer::Union{<:Descent,<:DoG,<:DoWG}=DoWG(),
optimizer::Optimisers.AbstractRule=DoWG(),
n_samples::Int=1,
averager::AbstractAverager=PolynomialAveraging(),
operator::AbstractOperator=IdentityOperator(),
@@ -125,7 +223,11 @@ function KLMinScoreGradDescent(
else
SubsampledObjective(ScoreGradELBO(n_samples), subsampling)
end
return ParamSpaceSGD(objective, adtype, optimizer, averager, operator)
return KLMinScoreGradDescent{
typeof(objective),typeof(adtype),typeof(optimizer),typeof(averager),typeof(operator)
}(
objective, adtype, optimizer, averager, operator
)
end

const BBVI = KLMinScoreGradDescent
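
An illustrative, hedged construction of the score-gradient variant through its `BBVI` alias (not part of the diff); treating `ClipScale()` as constructible with no arguments is an assumption based on the `operator` keyword documented above.

```julia
using ADTypes, AdvancedVI

# Score-gradient estimator: per the requirements above, `q` needs `rand` and a
# differentiable `logpdf(q, x)`; DoWG is the documented default optimizer.
alg = BBVI(AutoForwardDiff(); n_samples=16, operator=ClipScale())
```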
File renamed without changes.
83 changes: 83 additions & 0 deletions src/algorithms/interface.jl
@@ -0,0 +1,83 @@

"""
This family of algorithms (`<:KLMinRepGradDescent`, `<:KLMinRepGradProxDescent`, `<:KLMinScoreGradDescent`) applies stochastic gradient descent (SGD) to the variational `objective` over the (Euclidean) space of variational parameters.
The trainable parameters in the variational approximation are expected to be extractable through `Optimisers.destructure`.
This requires the variational approximation to be marked as a functor through `Functors.@functor`.
"""
const ParamSpaceSGD = Union{
<:KLMinRepGradDescent,<:KLMinRepGradProxDescent,<:KLMinScoreGradDescent
}

function init(rng::Random.AbstractRNG, alg::ParamSpaceSGD, q_init, prob)
(; adtype, optimizer, averager, objective, operator) = alg
if q_init isa AdvancedVI.MvLocationScale && operator isa AdvancedVI.IdentityOperator
@warn(
"IdentityOperator is used with a variational family <:MvLocationScale. Optimization can easily fail under this combination due to singular scale matrices. Consider using the operator `ClipScale` in the algorithm instead.",
)
end
params, re = Optimisers.destructure(q_init)
opt_st = Optimisers.setup(optimizer, params)
obj_st = init(rng, objective, adtype, q_init, prob, params, re)
avg_st = init(averager, params)
grad_buf = DiffResults.DiffResult(zero(eltype(params)), similar(params))
return (
prob=prob,
q=q_init,
iteration=0,
grad_buf=grad_buf,
opt_st=opt_st,
obj_st=obj_st,
avg_st=avg_st,
)
end

function output(alg::ParamSpaceSGD, state)
params_avg = value(alg.averager, state.avg_st)
_, re = Optimisers.destructure(state.q)
return re(params_avg)
end

function step(
rng::Random.AbstractRNG, alg::ParamSpaceSGD, state, callback, objargs...; kwargs...
)
(; adtype, objective, operator, averager) = alg
(; prob, q, iteration, grad_buf, opt_st, obj_st, avg_st) = state

iteration += 1

params, re = Optimisers.destructure(q)

grad_buf, obj_st, info = estimate_gradient!(
rng, objective, adtype, grad_buf, obj_st, params, re, objargs...
)

grad = DiffResults.gradient(grad_buf)
opt_st, params = Optimisers.update!(opt_st, params, grad)
params = apply(operator, typeof(q), opt_st, params, re)
avg_st = apply(averager, avg_st, params)

state = (
prob=prob,
q=re(params),
iteration=iteration,
grad_buf=grad_buf,
opt_st=opt_st,
obj_st=obj_st,
avg_st=avg_st,
)

if !isnothing(callback)
averaged_params = value(averager, avg_st)
info′ = callback(;
rng,
iteration,
restructure=re,
params=params,
averaged_params=averaged_params,
gradient=grad,
state=state,
)
info = !isnothing(info′) ? merge(info′, info) : info
end
state, false, info
end
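
To make the control flow of `init`, `step`, and `output` above concrete, here is a hedged caller-side sketch (not part of the diff); `alg`, `prob`, and `q0` are placeholders assumed to satisfy the documented requirements (in particular, `q0` must work with `Optimisers.destructure`, e.g. by being marked with `Functors.@functor`), and treating these functions as directly callable entry points is itself an assumption:

```julia
using AdvancedVI, Random

# Placeholder driver: `alg`, `prob`, and `q0` are assumed inputs.
function run_vi(alg, prob, q0; max_iter=1_000, rng=Random.default_rng())
    state = AdvancedVI.init(rng, alg, q0, prob)   # named-tuple state, as above
    for t in 1:max_iter
        # `nothing` disables the callback; `step` returns (state, terminate, info).
        state, terminate, _ = AdvancedVI.step(rng, alg, state, nothing)
        terminate && break
    end
    return AdvancedVI.output(alg, state)          # approximation from averaged iterates
end
```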