Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 1 addition & 3 deletions Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,6 @@ Logging = "56ddb016-857b-54e1-b83d-db4d58db5568"
LoggingExtras = "e6f89c97-d47a-5376-807f-9c37f3926c36"
OptimizationBase = "bca83a33-5cc9-4baa-983d-23429ab6bcbb"
Printf = "de0858da-6303-5e67-8744-51eddeeeb8d7"
ProgressLogging = "33c8b6b6-d38a-422a-b730-caa89a2f386c"
Reexport = "189a3867-3050-52da-a836-e630ba90ab69"
SciMLBase = "0bca4576-84f4-4d90-8ffe-ffa030f20462"
SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf"
Expand Down Expand Up @@ -58,7 +57,6 @@ OptimizationOptimisers = "0.3"
OrdinaryDiffEqTsit5 = "1"
Pkg = "1"
Printf = "1.10"
ProgressLogging = "0.1"
Random = "1.10"
Reexport = "1.2"
ReverseDiff = "1"
Expand Down Expand Up @@ -109,6 +107,6 @@ Mooncake = "da2b9cff-9c12-43a0-ae48-6db2b0edb7d6"

[targets]
test = ["Aqua", "BenchmarkTools", "Boltz", "ComponentArrays", "DiffEqFlux", "Enzyme", "FiniteDiff", "Flux", "ForwardDiff",
"Ipopt", "IterTools", "Lux", "MLUtils", "ModelingToolkit", "Optim", "OptimizationLBFGSB", "OptimizationMOI", "OptimizationOptimJL", "OptimizationOptimisers",
"Ipopt", "IterTools", "Lux", "MLUtils", "ModelingToolkit", "Optim", "OptimizationLBFGSB", "OptimizationMOI", "OptimizationOptimJL", "OptimizationOptimisers",
"OrdinaryDiffEqTsit5", "Pkg", "Random", "ReverseDiff", "SafeTestsets", "SciMLSensitivity", "SparseArrays",
"Symbolics", "Test", "Tracker", "Zygote", "Mooncake"]
10 changes: 6 additions & 4 deletions lib/OptimizationOptimisers/Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -2,13 +2,14 @@ name = "OptimizationOptimisers"
uuid = "42dfb2eb-d2b4-4451-abcd-913932933ac1"
authors = ["Vaibhav Dixit <[email protected]> and contributors"]
version = "0.3.13"

[deps]
OptimizationBase = "bca83a33-5cc9-4baa-983d-23429ab6bcbb"
ProgressLogging = "33c8b6b6-d38a-422a-b730-caa89a2f386c"
Printf = "de0858da-6303-5e67-8744-51eddeeeb8d7"
SciMLBase = "0bca4576-84f4-4d90-8ffe-ffa030f20462"
Optimisers = "3bd65402-5787-11e9-1adc-39752487f4e2"
Reexport = "189a3867-3050-52da-a836-e630ba90ab69"
UUIDs = "cf7118a7-6976-5b1a-9a39-7adc72f591a4"
Logging = "56ddb016-857b-54e1-b83d-db4d58db5568"

[extras]
ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
Expand All @@ -19,14 +20,15 @@ Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
ComponentArrays = "b0b7db55-cfe3-40fc-9ded-d10e2dbeff66"
Lux = "b2108857-7c20-44ae-9111-449ecde12c47"
Printf = "de0858da-6303-5e67-8744-51eddeeeb8d7"

[compat]
julia = "1.10"
OptimizationBase = "3"
ProgressLogging = "0.1"
SciMLBase = "2.58"
Optimisers = "0.2, 0.3, 0.4"
Reexport = "1.2"
Logging = "1.10"

[targets]
test = ["ComponentArrays", "ForwardDiff", "Lux", "MLDataDevices", "MLUtils", "Random", "Test", "Zygote"]
test = ["ComponentArrays", "ForwardDiff", "Lux", "MLDataDevices", "MLUtils", "Random", "Test", "Zygote", "Printf"]
121 changes: 59 additions & 62 deletions lib/OptimizationOptimisers/src/OptimizationOptimisers.jl
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
module OptimizationOptimisers

using Reexport, Printf, ProgressLogging
using Reexport, UUIDs, Logging
@reexport using Optimisers, OptimizationBase
using SciMLBase

Expand Down Expand Up @@ -95,77 +95,74 @@ function SciMLBase.__solve(cache::OptimizationBase.OptimizationCache{
gevals = 0
t0 = time()
breakall = false
begin
for epoch in 1:epochs
if breakall
break
progress_id = uuid4()
for epoch in 1:epochs, d in data
if cache.f.fg !== nothing && dataiterate
x = cache.f.fg(G, θ, d)
iterations += 1
fevals += 1
gevals += 1
elseif dataiterate
cache.f.grad(G, θ, d)
x = cache.f(θ, d)
iterations += 1
fevals += 2
gevals += 1
elseif cache.f.fg !== nothing
x = cache.f.fg(G, θ)
iterations += 1
fevals += 1
gevals += 1
else
cache.f.grad(G, θ)
x = cache.f(θ)
iterations += 1
fevals += 2
gevals += 1
end
opt_state = OptimizationBase.OptimizationState(
iter = iterations,
u = θ,
p = d,
objective = x[1],
grad = G,
original = state)
breakall = cache.callback(opt_state, x...)
if !(breakall isa Bool)
error("The callback should return a boolean `halt` for whether to stop the optimization process. Please see the `solve` documentation for information.")
elseif breakall
break
end
if cache.progress
message = "Loss: $(round(first(first(x)); digits = 3))"
@logmsg(LogLevel(-1), "Optimization", _id=progress_id,
message=message, progress=iterations / maxiters)
end
if cache.solver_args.save_best
if first(x)[1] < first(min_err)[1] #found a better solution
min_opt = opt
min_err = x
min_θ = copy(θ)
end
for (i, d) in enumerate(data)
if cache.f.fg !== nothing && dataiterate
x = cache.f.fg(G, θ, d)
iterations += 1
fevals += 1
gevals += 1
elseif dataiterate
cache.f.grad(G, θ, d)
x = cache.f(θ, d)
iterations += 1
fevals += 2
gevals += 1
elseif cache.f.fg !== nothing
x = cache.f.fg(G, θ)
iterations += 1
fevals += 1
gevals += 1
else
cache.f.grad(G, θ)
x = cache.f(θ)
iterations += 1
fevals += 2
gevals += 1
end
opt_state = OptimizationBase.OptimizationState(
iter = i + (epoch - 1) * length(data),
if iterations == length(data) * epochs #Last iter, revert to best.
opt = min_opt
x = min_err
θ = min_θ
cache.f.grad(G, θ, d)
opt_state = OptimizationBase.OptimizationState(iter = iterations,
u = θ,
p = d,
objective = x[1],
grad = G,
original = state)
breakall = cache.callback(opt_state, x...)
if !(breakall isa Bool)
error("The callback should return a boolean `halt` for whether to stop the optimization process. Please see the `solve` documentation for information.")
elseif breakall
break
end
msg = @sprintf("loss: %.3g", first(x)[1])
#cache.progress && ProgressLogging.@logprogress msg iterations/maxiters

if cache.solver_args.save_best
if first(x)[1] < first(min_err)[1] #found a better solution
min_opt = opt
min_err = x
min_θ = copy(θ)
end
if iterations == length(data) * epochs #Last iter, revert to best.
opt = min_opt
x = min_err
θ = min_θ
cache.f.grad(G, θ, d)
opt_state = OptimizationBase.OptimizationState(iter = iterations,
u = θ,
p = d,
objective = x[1],
grad = G,
original = state)
breakall = cache.callback(opt_state, x...)
break
end
end
state, θ = Optimisers.update(state, θ, G)
break
end
end
state, θ = Optimisers.update(state, θ, G)
end

cache.progress && @logmsg(LogLevel(-1), "Optimization",
_id=progress_id, message="Done", progress=1.0)
t1 = time()
stats = OptimizationBase.OptimizationStats(; iterations,
time = t1 - t0, fevals, gevals)
Expand Down
2 changes: 1 addition & 1 deletion src/Optimization.jl
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ if !isdefined(Base, :get_extension)
using Requires
end

using Logging, ProgressLogging, ConsoleProgressMonitor, TerminalLoggers, LoggingExtras
using Logging, ConsoleProgressMonitor, TerminalLoggers, LoggingExtras
using ArrayInterface, Base.Iterators, SparseArrays, LinearAlgebra

import OptimizationBase: instantiate_function, OptimizationCache, ReInitCache
Expand Down
Loading