Skip to content
Closed
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
133 changes: 73 additions & 60 deletions src/nlsolve/utils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -347,86 +347,99 @@ DiffEqBase.@def getoopnlsolvefields begin
uf = nlsolver.uf
end

## resize NLSolver

function nlsolve_resize!(integrator::DEIntegrator, i::Int)
if !isdefined(integrator.cache, :nlsolver)
return nothing
end
alg = integrator.alg; nlsolver = integrator.cache.nlsolver
(isdefined(integrator.alg, :nlsolve) && isdefined(integrator.cache, :nlsolver)) || return

nlsolver = integrator.cache.nlsolver
if nlsolver isa AbstractArray
for idx in eachindex(nlsolver) # looping because we may have multiple nlsolver for threaded case
_nlsolver = nlsolver[idx]
@unpack z,dz,tmp,ztmp,k,du1,uf,jac_config,linsolve,weight,cache = _nlsolver
# doubt: if these fields are always going to be in alg cache too, then we shouldnt do this here.
# double resize doesn't do any bad I think though
resize!(z,i)
resize!(dz,i)
resize!(tmp,i)
resize!(ztmp,i)
resize!(k,i)
resize!(du1,i)
if jac_config !== nothing
_nlsolver.jac_config = resize_jac_config!(jac_config, i)
end
resize!(weight, i)
nlsolve_cache_resize!(cache,alg,i)
for idx in eachindex(nlsolver)
resize!(nlsolver[idx], i)
end
else
@unpack z,dz,tmp,ztmp,k,du1,uf,jac_config,linsolve,weight,cache = nlsolver
resize!(z,i)
resize!(dz,i)
resize!(tmp,i)
resize!(ztmp,i)
resize!(k,i)
resize!(du1,i)
if jac_config !== nothing
nlsolver.jac_config = resize_jac_config!(jac_config,i)
end
resize!(weight, i)
nlsolve_cache_resize!(cache,alg,i)
resize!(nlsolver, i)
end

nothing
end

function nlsolve_cache_resize!(cache::NLNewtonCache, alg, i::Int)
nothing
function Base.resize!(nlsolver::NLSolver, i::Int)
@unpack z,dz,tmp,ztmp,k,cache = nlsolver

resize!(z, i)
resize!(dz, i)
resize!(tmp, i)
resize!(ztmp, i)
resize!(k, i)

if nlsolver.alg isa NLAnderson
resize!(cache, nlsolver.alg, i)
else
resize!(cache, i)
end
end

function nlsolve_cache_resize!(cache::NLNewtonConstantCache, alg, i::Int)
Base.resize!(::AbstractNLSolverCache, ::Int) = nothing

function Base.resize!(nlcache::NLFunctionalCache, i::Int)
resize!(nlcache.z₊, i)
nothing
end

function nlsolve_cache_resize!(cache::NLAndersonCache, alg, i::Int)
resize!(cache.z₊, i)
resize!(cache.dzold, i)
resize!(cache.z₊old, i)
max_history = min(alg.nlsolve.max_history, alg.nlsolve.max_iter, i)
prev_max_history = length(cache.Δz₊s)
resize!(cache.γs, max_history)
resize!(cache.Δz₊s, max_history)
if max_history > prev_max_history
for i in (max_history - prev_max_history):max_history
cache.Δz₊s[i] = zero(z₊)
end
function Base.resize!(nlcache::NLNewtonCache, i::Int)
@unpack du1,jac_config,linsolve,weight = nlcache

resize!(du1, i)
if jac_config !== nothing
nlsolver.jac_config = resize_jac_config!(jac_config, i)
end
cache.Q = typeof(cache.Q)(undef, i, max_history)
cache.R = typeof(cache.R)(undef, max_history, max_history)
nothing
end
resize!(weight, i)

function nlsolve_cache_resize!(cache::NLAndersonConstantCache, alg, i::Int)
max_history = min(alg.nlsolve.max_history, alg.nlsolve.max_iter, i)
resize!(cache.Δz₊s, max_history)
cache.Q = typeof(cache.Q)(undef, i, max_history)
cache.R = typeof(cache.R)(undef, max_history, max_history)
resize!(cache.γs, max_history)
nothing
end

function nlsolve_cache_resize!(cache::NLFunctionalCache, alg, i::Int)
resize!(cache.z₊, i)
function Base.resize!(nlcache::NLAndersonCache, nlalg::NLAnderson, i::Int)
@unpack z₊, dzold, z₊old, γs, Δz₊s = nlcache

resize!(z₊, i)
resize!(dzold, i)
resize!(z₊old, i)

# determine new maximum history
max_history_old = length(Δz₊s)
max_history = min(nlalg.max_history, nlalg.max_iter, i)

resize!(γs, max_history)
resize!(Δz₊s, max_history)
if max_history > max_history_old
for i in (max_history_old + 1):max_history
Δz₊s[i] = zero(z₊)
end
end

if max_history != max_history_old
nlcache.Q = typeof(nlcache.Q)(undef, i, max_history)
nlcache.R = typeof(nlcache.R)(undef, max_history, max_history)
end

nothing
end

function nlsolve_cache_resize!(cache::NLFunctionalConstantCache, alg, i::Int)
function Base.resize!(nlcache::NLAndersonConstantCache, nlalg::NLAnderson, i::Int)
@unpack γs, Δz₊s = nlcache

# determine new maximum history
max_history_old = length(Δz₊s)
max_history = min(nlalg.max_history, nlalg.max_iter, i)

resize!(γs, max_history)
resize!(Δz₊s, max_history)

if max_history != max_history_old
nlcache.Q = typeof(nlcache.Q)(undef, i, max_history)
nlcache.R = typeof(nlcache.R)(undef, max_history, max_history)
end

nothing
end