
Best practice for DiffOpt.jl implementation with Flux (logsumexp) #228

@JinraeKim

Description


Hi, developers! Thanks for this promising and potentially useful package.

I'm studying differentiable convex optimisation and trying to apply it to the PLSE, a neural network that I proposed.
I used to use cvxpylayers, but I'm tired of how slow the Python stack is, so I'm wondering whether I can implement this with DiffOpt.jl.

Background

I have a neural network (called PLSE) f(x, u; \theta) with two inputs, x (the condition) and u (the decision), and network parameters \theta. f(x, \cdot; \theta) is guaranteed to be convex, and the corresponding convex optimisation problem is an exponential cone program (the original form is log-sum-exp). This is implemented in ParametrisedConvexApproximators.jl.
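Concretely (in my own notation: z(u) is the affine map inside the log-sum-exp, s_i are auxiliary variables, and K_exp is the exponential cone), the epigraph reformulation I believe applies here is:

```latex
% Minimising T \log \sum_{i=1}^{N} \exp(z_i(u)/T) is equivalent to
\min_{u,\,t,\,s}\ t
\quad\text{s.t.}\quad
\left(\frac{z_i(u) - t}{T},\, 1,\, s_i\right) \in K_{\exp}
\quad (i = 1,\dots,N),
\qquad
\sum_{i=1}^{N} s_i \le 1,
% where K_{\exp} = \{(a, b, c) : b\, e^{a/b} \le c,\ b > 0\}, so each cone
% constraint enforces \exp\big((z_i(u) - t)/T\big) \le s_i, and the sum
% constraint gives t \ge T \log \sum_i \exp(z_i(u)/T).
```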

What I'm trying to do

It is pretty simple: I want to obtain the derivative du*/d\theta, where u*(x, \theta) is the optimal decision minimising f(x, \cdot; \theta), possibly over a prescribed set (the decision space), and \theta is the network parameter.
You can find this idea with cvxpylayers here.

Issues with DiffOpt.jl

A caveat before going further: I'm not yet familiar with this package, so please let me know if there is a workaround I missed.
What I tried is to follow the custom ReLU example. For this, I need to define the objective function. An example code would be:

using ParametrisedConvexApproximators
using JuMP
import DiffOpt
import SCS
import ChainRulesCore
import Flux


function main()
    model = Model(() -> DiffOpt.diff_optimizer(SCS.Optimizer))
    n, m = 3, 2
    i_max = 20
    T = 1e-0
    h_array = [64]
    act = Flux.relu
    plse = PLSE(n, m, i_max, T, h_array, act)
    x = rand(n)
    @show plse(x, rand(m))
    @variable(model, u[1:m])
    # What I would like to do next (this errors; see the stack trace below):
    # @objective(model, Min, plse(x, u)[1])
    # optimize!(model)
    # return value.(u)
end

Note that the output of plse is a 1-element vector.

And the following is how plse(x, u) is computed, which can be found here.

function (nn::PLSE)(x::AbstractArray, u::AbstractArray)
    @unpack T = nn
    is_vector = length(size(x)) == 1
    @assert is_vector == (length(size(u)) == 1)
    x = is_vector ? reshape(x, :, 1) : x
    u = is_vector ? reshape(u, :, 1) : u
    @assert size(x)[2] == size(u)[2]
    tmp = affine_map(nn, x, u)
    _res = T * Flux.logsumexp((1/T)*tmp, dims=1)
    res = is_vector ? reshape(_res, 1) : _res
    return res
end
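For context, here is a minimal sketch (my own paraphrase, not NNlib's actual source) of what Flux.logsumexp computes; the point is the call to maximum, which needs isless on the element type:

```julia
# Paraphrase of the log-sum-exp trick used by NNlib.logsumexp:
# subtract the per-slice maximum before exponentiating, for numerical stability.
function naive_logsumexp(x; dims=1)
    m = maximum(x; dims=dims)   # <- needs isless, which is undefined for AffExpr
    m .+ log.(sum(exp.(x .- m); dims=dims))
end

naive_logsumexp([1.0 2.0; 3.0 4.0]; dims=1)   # 1×2 matrix of columnwise LSE
```

Since JuMP's AffExpr has no ordering, the maximum call (and hence logsumexp) cannot be evaluated symbolically on JuMP expressions.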

And inside Flux.logsumexp, I encountered this error:

1|julia> Flux.logsumexp((1/T)*tmp, dims=1)
ERROR: MethodError: no method matching isless(::AffExpr, ::AffExpr)
Closest candidates are:
  isless(::Any, ::Missing) at /Applications/Julia-1.7.app/Contents/Resources/julia/share/julia/base/missing.jl:88
  isless(::Missing, ::Any) at /Applications/Julia-1.7.app/Contents/Resources/julia/share/julia/base/missing.jl:87
Stacktrace:
  [1] max(x::AffExpr, y::AffExpr)
    @ Base ./operators.jl:492
  [2] mapreduce_impl(f::typeof(identity), op::typeof(max), A::Matrix{AffExpr}, first::Int64, last::Int64)
    @ Base ./reduce.jl:635
  [3] _mapreducedim!(f::typeof(identity), op::typeof(max), R::Matrix{AffExpr}, A::Matrix{AffExpr})
    @ Base ./reducedim.jl:260
  [4] mapreducedim!
    @ ./reducedim.jl:289 [inlined]
  [5] _mapreduce_dim
    @ ./reducedim.jl:336 [inlined]
  [6] #mapreduce#731
    @ ./reducedim.jl:322 [inlined]
  [7] #_maximum#769
    @ ./reducedim.jl:916 [inlined]
  [8] _maximum
    @ ./reducedim.jl:916 [inlined]
  [9] #_maximum#768
    @ ./reducedim.jl:915 [inlined]
 [10] _maximum
    @ ./reducedim.jl:915 [inlined]
 [11] #maximum#746
    @ ./reducedim.jl:889 [inlined]
 [12] logsumexp(x::Matrix{AffExpr}; dims::Int64)
    @ NNlib ~/.julia/packages/NNlib/tvMmZ/src/softmax.jl:142
 [13] top-level scope
    @ none:1
 [14] eval
    @ ./boot.jl:373 [inlined]
 [15] eval_code(frame::JuliaInterpreter.Frame, expr::Expr)
    @ JuliaInterpreter ~/.julia/packages/JuliaInterpreter/4B89D/src/utils.jl:649
 [16] eval_code(frame::JuliaInterpreter.Frame, command::String)
    @ JuliaInterpreter ~/.julia/packages/JuliaInterpreter/4B89D/src/utils.jl:627
 [17] _eval_code(frame::JuliaInterpreter.Frame, code::String)
    @ Debugger ~/.julia/packages/Debugger/I4w2y/src/repl.jl:211
 [18] (::Debugger.var"#27#29"{Debugger.DebuggerState})(s::REPL.LineEdit.MIState, buf::IOBuffer, ok::Bool)
    @ Debugger ~/.julia/packages/Debugger/I4w2y/src/repl.jl:194
 [19] #invokelatest#2
    @ ./essentials.jl:716 [inlined]
 [20] invokelatest
    @ ./essentials.jl:714 [inlined]
 [21] run_interface(terminal::REPL.Terminals.TextTerminal, m::REPL.LineEdit.ModalInterface, s::REPL.LineEdit.MIState)
    @ REPL.LineEdit /Applications/Julia-1.7.app/Contents/Resources/julia/share/julia/stdlib/v1.7/REPL/src/LineEdit.jl:2493
 [22] run_interface(terminal::REPL.Terminals.TextTerminal, m::REPL.LineEdit.ModalInterface)
    @ REPL.LineEdit /Applications/Julia-1.7.app/Contents/Resources/julia/share/julia/stdlib/v1.7/REPL/src/LineEdit.jl:2487
 [23] RunDebugger(frame::JuliaInterpreter.Frame, repl::Nothing, terminal::Nothing; initial_continue::Bool)
    @ Debugger ~/.julia/packages/Debugger/I4w2y/src/repl.jl:167
 [24] macro expansion
    @ ~/.julia/packages/Debugger/I4w2y/src/Debugger.jl:137 [inlined]
 [25] main()
    @ Main ~/.julia/dev/ParametrisedConvexApproximators/test/tmp.jl:20
 [26] top-level scope
    @ REPL[2]:1
 [27] top-level scope
    @ ~/.julia/packages/CUDA/sCev8/src/initialization.jl:52

1|julia> maximum(tmp; dims=1)
ERROR: MethodError: no method matching isless(::AffExpr, ::AffExpr)
Closest candidates are:
  isless(::Any, ::Missing) at /Applications/Julia-1.7.app/Contents/Resources/julia/share/julia/base/missing.jl:88
  isless(::Missing, ::Any) at /Applications/Julia-1.7.app/Contents/Resources/julia/share/julia/base/missing.jl:87
Stacktrace: (identical to the trace above, minus the logsumexp frame)

This may be due to my lack of background knowledge of how to use JuMP and DiffOpt.
How can I realise my idea with DiffOpt.jl?
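For what it's worth, my current guess at a workaround is to model the log-sum-exp epigraph directly with MOI.ExponentialCone constraints, instead of calling Flux.logsumexp on AffExprs. A minimal sketch (untested as a DiffOpt layer), with a hypothetical affine map A*u + b standing in for affine_map and a box as the decision space:

```julia
using JuMP
import SCS

# Hypothetical stand-in for affine_map(nn, x, u): z = A*u + b, three exp terms.
T = 1.0
A = [1.0 0.0; 0.0 1.0; 1.0 1.0]
b = [0.1, -0.2, 0.3]
N = length(b)

model = Model(SCS.Optimizer)
set_silent(model)
@variable(model, -1 <= u[1:2] <= 1)   # bounded decision space
@variable(model, t)                   # epigraph variable: t >= T*logsumexp(z/T)
@variable(model, s[1:N])              # s[i] >= exp((z[i] - t)/T)
z = A * u + b
# (a, 1, c) in MOI.ExponentialCone() enforces exp(a) <= c
@constraint(model, [i = 1:N], [(z[i] - t) / T, 1.0, s[i]] in MOI.ExponentialCone())
@constraint(model, sum(s) <= 1.0)     # together: t >= T*logsumexp(z/T)
@objective(model, Min, t)
optimize!(model)
value.(u), objective_value(model)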
