diff --git a/src/OrdinaryDiffEq.jl b/src/OrdinaryDiffEq.jl index 8a70266f8a..56721b601d 100644 --- a/src/OrdinaryDiffEq.jl +++ b/src/OrdinaryDiffEq.jl @@ -40,15 +40,15 @@ module OrdinaryDiffEq using DiffEqBase: check_error!, @def, @.. , _vec, _reshape - using DiffEqBase: nlsolvefail, isnewton, set_new_W!, get_W, iipnlsolve, oopnlsolve + using DiffEqBase: nlsolvefail, isnewton, set_new_W!, get_W, get_linsolve, build_nlsolver, nlsolve! using DiffEqBase: NLSolver - using DiffEqBase: FastConvergence, Convergence, SlowConvergence, VerySlowConvergence, Divergence + using DiffEqBase: FastConvergence, Convergence, SlowConvergence, VerySlowConvergence, Divergence, MaxItersReached import DiffEqBase: calculate_residuals, calculate_residuals!, nlsolve_f, unwrap_cache, @tight_loop_macros, islinear - import DiffEqBase: iip_get_uf, oop_get_uf, build_jac_config + import DiffEqBase: build_jac_config import SparseDiffTools: forwarddiff_color_jacobian!, ForwardColorJacCache diff --git a/src/caches/adams_bashforth_moulton_caches.jl b/src/caches/adams_bashforth_moulton_caches.jl index 3aedb27acc..0648b12df1 100644 --- a/src/caches/adams_bashforth_moulton_caches.jl +++ b/src/caches/adams_bashforth_moulton_caches.jl @@ -905,8 +905,7 @@ end function alg_cache(alg::CNAB2,u,rate_prototype,uEltypeNoUnits,uBottomEltypeNoUnits,tTypeNoUnits,uprev,uprev2,f,t,dt,reltol,p,calck,::Val{false}) γ, c = 1//2, 1 - J, W = oop_generate_W(alg,u,uprev,p,t,dt,f,uEltypeNoUnits) - nlsolver = oopnlsolve(alg,u,uprev,p,t,dt,f,W,J,rate_prototype,uEltypeNoUnits,uBottomEltypeNoUnits,γ,c) + nlsolver = build_nlsolver(alg,alg.nlsolve,u,rate_prototype,uEltypeNoUnits,uBottomEltypeNoUnits,tTypeNoUnits,uprev,f,t,dt,p,γ,c,Val(false)) k2 = rate_prototype uprev3 = u @@ -917,8 +916,7 @@ end function alg_cache(alg::CNAB2,u,rate_prototype,uEltypeNoUnits,uBottomEltypeNoUnits,tTypeNoUnits,uprev,uprev2,f,t,dt,reltol,p,calck,::Val{true}) γ, c = 1//2, 1 - J, W = iip_generate_W(alg,u,uprev,p,t,dt,f,uEltypeNoUnits) - nlsolver = iipnlsolve(alg,u,uprev,p,t,dt,f,W,J,rate_prototype,uEltypeNoUnits,uBottomEltypeNoUnits,γ,c) + nlsolver = build_nlsolver(alg,alg.nlsolve,u,rate_prototype,uEltypeNoUnits,uBottomEltypeNoUnits,tTypeNoUnits,uprev,f,t,dt,p,γ,c,Val(true)) fsalfirst = zero(rate_prototype) k1 = zero(rate_prototype) @@ -955,8 +953,7 @@ end function alg_cache(alg::CNLF2,u,rate_prototype,uEltypeNoUnits,uBottomEltypeNoUnits,tTypeNoUnits,uprev,uprev2,f,t,dt,reltol,p,calck,::Val{false}) γ, c = 1//1, 1 - J, W = oop_generate_W(alg,u,uprev,p,t,dt,f,uEltypeNoUnits) - nlsolver = oopnlsolve(alg,u,uprev,p,t,dt,f,W,J,rate_prototype,uEltypeNoUnits,uBottomEltypeNoUnits,γ,c) + nlsolver = build_nlsolver(alg,alg.nlsolve,u,rate_prototype,uEltypeNoUnits,uBottomEltypeNoUnits,tTypeNoUnits,uprev,f,t,dt,p,γ,c,Val(false)) k2 = rate_prototype uprev2 = u @@ -968,8 +965,7 @@ end function alg_cache(alg::CNLF2,u,rate_prototype,uEltypeNoUnits,uBottomEltypeNoUnits,tTypeNoUnits,uprev,uprev2,f,t,dt,reltol,p,calck,::Val{true}) γ, c = 1//1, 1 - J, W = iip_generate_W(alg,u,uprev,p,t,dt,f,uEltypeNoUnits) - nlsolver = iipnlsolve(alg,u,uprev,p,t,dt,f,W,J,rate_prototype,uEltypeNoUnits,uBottomEltypeNoUnits,γ,c) + nlsolver = build_nlsolver(alg,alg.nlsolve,u,rate_prototype,uEltypeNoUnits,uBottomEltypeNoUnits,tTypeNoUnits,uprev,f,t,dt,p,γ,c,Val(true)) fsalfirst = zero(rate_prototype) k1 = zero(rate_prototype) diff --git a/src/caches/bdf_caches.jl b/src/caches/bdf_caches.jl index 2e5d02d380..c2c3fb39a7 100644 --- a/src/caches/bdf_caches.jl +++ b/src/caches/bdf_caches.jl @@ -8,8 +8,7 @@ end function alg_cache(alg::ABDF2,u,rate_prototype,uEltypeNoUnits,uBottomEltypeNoUnits,tTypeNoUnits, uprev,uprev2,f,t,dt,reltol,p,calck,::Val{false}) γ, c = 2//3, 1 - J, W = oop_generate_W(alg,u,uprev,p,t,dt,f,uEltypeNoUnits) - nlsolver = oopnlsolve(alg,u,uprev,p,t,dt,f,W,J,rate_prototype,uEltypeNoUnits,uBottomEltypeNoUnits,γ,c) + nlsolver = build_nlsolver(alg,alg.nlsolve,u,rate_prototype,uEltypeNoUnits,uBottomEltypeNoUnits,tTypeNoUnits,uprev,f,t,dt,p,γ,c,Val(false)) eulercache = ImplicitEulerConstantCache(nlsolver) dtₙ₋₁ = one(dt) @@ -34,8 +33,7 @@ end function alg_cache(alg::ABDF2,u,rate_prototype,uEltypeNoUnits,uBottomEltypeNoUnits, tTypeNoUnits,uprev,uprev2,f,t,dt,reltol,p,calck,::Val{true}) γ, c = 2//3, 1 - J, W = iip_generate_W(alg,u,uprev,p,t,dt,f,uEltypeNoUnits) - nlsolver = iipnlsolve(alg,u,uprev,p,t,dt,f,W,J,rate_prototype,uEltypeNoUnits,uBottomEltypeNoUnits,γ,c) + nlsolver = build_nlsolver(alg,alg.nlsolve,u,rate_prototype,uEltypeNoUnits,uBottomEltypeNoUnits,tTypeNoUnits,uprev,f,t,dt,p,γ,c,Val(true)) fsalfirst = zero(rate_prototype) fsalfirstprev = zero(rate_prototype) @@ -84,8 +82,7 @@ end function alg_cache(alg::SBDF,u,rate_prototype,uEltypeNoUnits,uBottomEltypeNoUnits,tTypeNoUnits,uprev,uprev2,f,t,dt,reltol,p,calck,::Val{false}) γ, c = 1//1, 1 - J, W = oop_generate_W(alg,u,uprev,p,t,dt,f,uEltypeNoUnits) - nlsolver = oopnlsolve(alg,u,uprev,p,t,dt,f,W,J,rate_prototype,uEltypeNoUnits,uBottomEltypeNoUnits,γ,c) + nlsolver = build_nlsolver(alg,alg.nlsolve,u,rate_prototype,uEltypeNoUnits,uBottomEltypeNoUnits,tTypeNoUnits,uprev,f,t,dt,p,γ,c,Val(false)) k2 = rate_prototype k₁ = rate_prototype; k₂ = rate_prototype; k₃ = rate_prototype @@ -98,8 +95,7 @@ end function alg_cache(alg::SBDF,u,rate_prototype,uEltypeNoUnits,uBottomEltypeNoUnits,tTypeNoUnits,uprev,uprev2,f,t,dt,reltol,p,calck,::Val{true}) γ, c = 1//1, 1 - J, W = iip_generate_W(alg,u,uprev,p,t,dt,f,uEltypeNoUnits) - nlsolver = iipnlsolve(alg,u,uprev,p,t,dt,f,W,J,rate_prototype,uEltypeNoUnits,uBottomEltypeNoUnits,γ,c) + nlsolver = build_nlsolver(alg,alg.nlsolve,u,rate_prototype,uEltypeNoUnits,uBottomEltypeNoUnits,tTypeNoUnits,uprev,f,t,dt,p,γ,c,Val(true)) fsalfirst = zero(rate_prototype) order = alg.order @@ -144,8 +140,7 @@ end function alg_cache(alg::QNDF1,u,rate_prototype,uEltypeNoUnits,uBottomEltypeNoUnits,tTypeNoUnits,uprev,uprev2,f,t,dt,reltol,p,calck,::Val{false}) γ, c = zero(inv((1-alg.kappa))), 1 - J, W = oop_generate_W(alg,u,uprev,p,t,dt,f,uEltypeNoUnits) - nlsolver = oopnlsolve(alg,u,uprev,p,t,dt,f,W,J,rate_prototype,uEltypeNoUnits,uBottomEltypeNoUnits,γ,c) + nlsolver = build_nlsolver(alg,alg.nlsolve,u,rate_prototype,uEltypeNoUnits,uBottomEltypeNoUnits,tTypeNoUnits,uprev,f,t,dt,p,γ,c,Val(false)) uprev2 = u dtₙ₋₁ = t @@ -162,8 +157,7 @@ end function alg_cache(alg::QNDF1,u,rate_prototype,uEltypeNoUnits,uBottomEltypeNoUnits,tTypeNoUnits,uprev,uprev2,f,t,dt,reltol,p,calck,::Val{true}) γ, c = zero(inv((1-alg.kappa))), 1 - J, W = iip_generate_W(alg,u,uprev,p,t,dt,f,uEltypeNoUnits) - nlsolver = iipnlsolve(alg,u,uprev,p,t,dt,f,W,J,rate_prototype,uEltypeNoUnits,uBottomEltypeNoUnits,γ,c) + nlsolver = build_nlsolver(alg,alg.nlsolve,u,rate_prototype,uEltypeNoUnits,uBottomEltypeNoUnits,tTypeNoUnits,uprev,f,t,dt,p,γ,c,Val(true)) fsalfirst = zero(rate_prototype) D = Array{typeof(u)}(undef, 1, 1) @@ -215,8 +209,7 @@ end function alg_cache(alg::QNDF2,u,rate_prototype,uEltypeNoUnits,uBottomEltypeNoUnits,tTypeNoUnits,uprev,uprev2,f,t,dt,reltol,p,calck,::Val{false}) γ, c = zero(inv((1-alg.kappa))), 1 - J, W = oop_generate_W(alg,u,uprev,p,t,dt,f,uEltypeNoUnits) - nlsolver = oopnlsolve(alg,u,uprev,p,t,dt,f,W,J,rate_prototype,uEltypeNoUnits,uBottomEltypeNoUnits,γ,c) + nlsolver = build_nlsolver(alg,alg.nlsolve,u,rate_prototype,uEltypeNoUnits,uBottomEltypeNoUnits,tTypeNoUnits,uprev,f,t,dt,p,γ,c,Val(false)) uprev2 = u uprev3 = u @@ -235,8 +228,7 @@ end function alg_cache(alg::QNDF2,u,rate_prototype,uEltypeNoUnits,uBottomEltypeNoUnits,tTypeNoUnits,uprev,uprev2,f,t,dt,reltol,p,calck,::Val{true}) γ, c = zero(inv((1-alg.kappa))), 1 - J, W = iip_generate_W(alg,u,uprev,p,t,dt,f,uEltypeNoUnits) - nlsolver = iipnlsolve(alg,u,uprev,p,t,dt,f,W,J,rate_prototype,uEltypeNoUnits,uBottomEltypeNoUnits,γ,c) + nlsolver = build_nlsolver(alg,alg.nlsolve,u,rate_prototype,uEltypeNoUnits,uBottomEltypeNoUnits,tTypeNoUnits,uprev,f,t,dt,p,γ,c,Val(true)) fsalfirst = zero(rate_prototype) D = Array{typeof(u)}(undef, 1, 2) @@ -292,8 +284,7 @@ end function alg_cache(alg::QNDF,u,rate_prototype,uEltypeNoUnits,uBottomEltypeNoUnits,tTypeNoUnits,uprev,uprev2,f,t,dt,reltol,p,calck,::Val{false}) γ, c = one(eltype(alg.kappa)), 1 - J, W = oop_generate_W(alg,u,uprev,p,t,dt,f,uEltypeNoUnits) - nlsolver = oopnlsolve(alg,u,uprev,p,t,dt,f,W,J,rate_prototype,uEltypeNoUnits,uBottomEltypeNoUnits,γ,c) + nlsolver = build_nlsolver(alg,alg.nlsolve,u,rate_prototype,uEltypeNoUnits,uBottomEltypeNoUnits,tTypeNoUnits,uprev,f,t,dt,p,γ,c,Val(false)) udiff = fill(zero(u), 1, 6) dts = fill(zero(dt), 1, 6) @@ -311,8 +302,7 @@ end function alg_cache(alg::QNDF,u,rate_prototype,uEltypeNoUnits,uBottomEltypeNoUnits,tTypeNoUnits,uprev,uprev2,f,t,dt,reltol,p,calck,::Val{true}) γ, c = one(eltype(alg.kappa)), 1 - J, W = iip_generate_W(alg,u,uprev,p,t,dt,f,uEltypeNoUnits) - nlsolver = iipnlsolve(alg,u,uprev,p,t,dt,f,W,J,rate_prototype,uEltypeNoUnits,uBottomEltypeNoUnits,γ,c) + nlsolver = build_nlsolver(alg,alg.nlsolve,u,rate_prototype,uEltypeNoUnits,uBottomEltypeNoUnits,tTypeNoUnits,uprev,f,t,dt,p,γ,c,Val(true)) fsalfirst = zero(rate_prototype) udiff = Array{typeof(u)}(undef, 1, 6) @@ -357,8 +347,7 @@ end function alg_cache(alg::MEBDF2,u,rate_prototype,uEltypeNoUnits,uBottomEltypeNoUnits, tTypeNoUnits,uprev,uprev2,f,t,dt,reltol,p,calck,::Val{true}) γ, c = 1, 1 - J, W = iip_generate_W(alg,u,uprev,p,t,dt,f,uEltypeNoUnits) - nlsolver = iipnlsolve(alg,u,uprev,p,t,dt,f,W,J,rate_prototype,uEltypeNoUnits,uBottomEltypeNoUnits,γ,c) + nlsolver = build_nlsolver(alg,alg.nlsolve,u,rate_prototype,uEltypeNoUnits,uBottomEltypeNoUnits,tTypeNoUnits,uprev,f,t,dt,p,γ,c,Val(true)) fsalfirst = zero(rate_prototype) z₁ = zero(u); z₂ = zero(u); z₃ = zero(u); tmp2 = zero(u) @@ -374,7 +363,6 @@ end function alg_cache(alg::MEBDF2,u,rate_prototype,uEltypeNoUnits,uBottomEltypeNoUnits, tTypeNoUnits,uprev,uprev2,f,t,dt,reltol,p,calck,::Val{false}) γ, c = 1, 1 - J, W = oop_generate_W(alg,u,uprev,p,t,dt,f,uEltypeNoUnits) - nlsolver = oopnlsolve(alg,u,uprev,p,t,dt,f,W,J,rate_prototype,uEltypeNoUnits,uBottomEltypeNoUnits,γ,c) + nlsolver = build_nlsolver(alg,alg.nlsolve,u,rate_prototype,uEltypeNoUnits,uBottomEltypeNoUnits,tTypeNoUnits,uprev,f,t,dt,p,γ,c,Val(false)) MEBDF2ConstantCache(nlsolver) end diff --git a/src/caches/kencarp_kvaerno_caches.jl b/src/caches/kencarp_kvaerno_caches.jl index 248910aa6e..2a3e43db0b 100644 --- a/src/caches/kencarp_kvaerno_caches.jl +++ b/src/caches/kencarp_kvaerno_caches.jl @@ -7,8 +7,7 @@ function alg_cache(alg::KenCarp3,u,rate_prototype,uEltypeNoUnits,uBottomEltypeNo uprev,uprev2,f,t,dt,reltol,p,calck,::Val{false}) tab = KenCarp3Tableau(real(uBottomEltypeNoUnits),real(tTypeNoUnits)) γ, c = tab.γ, tab.c3 - J, W = oop_generate_W(alg,u,uprev,p,t,dt,f,uEltypeNoUnits) - nlsolver = oopnlsolve(alg,u,uprev,p,t,dt,f,W,J,rate_prototype,uEltypeNoUnits,uBottomEltypeNoUnits,γ,c) + nlsolver = build_nlsolver(alg,alg.nlsolve,u,rate_prototype,uEltypeNoUnits,uBottomEltypeNoUnits,tTypeNoUnits,uprev,f,t,dt,p,γ,c,Val(false)) KenCarp3ConstantCache(nlsolver,tab) end @@ -34,8 +33,7 @@ function alg_cache(alg::KenCarp3,u,rate_prototype,uEltypeNoUnits,uBottomEltypeNo tTypeNoUnits,uprev,uprev2,f,t,dt,reltol,p,calck,::Val{true}) tab = KenCarp3Tableau(real(uBottomEltypeNoUnits),real(tTypeNoUnits)) γ, c = tab.γ, tab.c3 - J, W = iip_generate_W(alg,u,uprev,p,t,dt,f,uEltypeNoUnits) - nlsolver = iipnlsolve(alg,u,uprev,p,t,dt,f,W,J,rate_prototype,uEltypeNoUnits,uBottomEltypeNoUnits,γ,c) + nlsolver = build_nlsolver(alg,alg.nlsolve,u,rate_prototype,uEltypeNoUnits,uBottomEltypeNoUnits,tTypeNoUnits,uprev,f,t,dt,p,γ,c,Val(true)) fsalfirst = zero(rate_prototype) if typeof(f) <: SplitFunction @@ -62,8 +60,7 @@ function alg_cache(alg::Kvaerno4,u,rate_prototype,uEltypeNoUnits,uBottomEltypeNo tTypeNoUnits,uprev,uprev2,f,t,dt,reltol,p,calck,::Val{false}) tab = Kvaerno4Tableau(real(uBottomEltypeNoUnits),real(tTypeNoUnits)) γ, c = tab.γ, tab.c3 - J, W = oop_generate_W(alg,u,uprev,p,t,dt,f,uEltypeNoUnits) - nlsolver = oopnlsolve(alg,u,uprev,p,t,dt,f,W,J,rate_prototype,uEltypeNoUnits,uBottomEltypeNoUnits,γ,c) + nlsolver = build_nlsolver(alg,alg.nlsolve,u,rate_prototype,uEltypeNoUnits,uBottomEltypeNoUnits,tTypeNoUnits,uprev,f,t,dt,p,γ,c,Val(false)) Kvaerno4ConstantCache(nlsolver,tab) end @@ -85,8 +82,7 @@ function alg_cache(alg::Kvaerno4,u,rate_prototype,uEltypeNoUnits,uBottomEltypeNo tTypeNoUnits,uprev,uprev2,f,t,dt,reltol,p,calck,::Val{true}) tab = Kvaerno4Tableau(real(uBottomEltypeNoUnits),real(tTypeNoUnits)) γ, c = tab.γ, tab.c3 - J, W = iip_generate_W(alg,u,uprev,p,t,dt,f,uEltypeNoUnits) - nlsolver = iipnlsolve(alg,u,uprev,p,t,dt,f,W,J,rate_prototype,uEltypeNoUnits,uBottomEltypeNoUnits,γ,c) + nlsolver = build_nlsolver(alg,alg.nlsolve,u,rate_prototype,uEltypeNoUnits,uBottomEltypeNoUnits,tTypeNoUnits,uprev,f,t,dt,p,γ,c,Val(true)) fsalfirst = zero(rate_prototype) z₁ = zero(u); z₂ = zero(u); z₃ = zero(u); z₄ = zero(u); z₅ = nlsolver.z @@ -104,8 +100,7 @@ function alg_cache(alg::KenCarp4,u,rate_prototype,uEltypeNoUnits,uBottomEltypeNo uprev,uprev2,f,t,dt,reltol,p,calck,::Val{false}) tab = KenCarp4Tableau(real(uBottomEltypeNoUnits),real(tTypeNoUnits)) γ, c = tab.γ, tab.c3 - J, W = oop_generate_W(alg,u,uprev,p,t,dt,f,uEltypeNoUnits) - nlsolver = oopnlsolve(alg,u,uprev,p,t,dt,f,W,J,rate_prototype,uEltypeNoUnits,uBottomEltypeNoUnits,γ,c) + nlsolver = build_nlsolver(alg,alg.nlsolve,u,rate_prototype,uEltypeNoUnits,uBottomEltypeNoUnits,tTypeNoUnits,uprev,f,t,dt,p,γ,c,Val(false)) KenCarp4ConstantCache(nlsolver,tab) end @@ -134,8 +129,7 @@ function alg_cache(alg::KenCarp4,u,rate_prototype,uEltypeNoUnits,uBottomEltypeNo tTypeNoUnits,uprev,uprev2,f,t,dt,reltol,p,calck,::Val{true}) tab = KenCarp4Tableau(real(uBottomEltypeNoUnits),real(tTypeNoUnits)) γ, c = tab.γ, tab.c3 - J, W = iip_generate_W(alg,u,uprev,p,t,dt,f,uEltypeNoUnits) - nlsolver = iipnlsolve(alg,u,uprev,p,t,dt,f,W,J,rate_prototype,uEltypeNoUnits,uBottomEltypeNoUnits,γ,c) + nlsolver = build_nlsolver(alg,alg.nlsolve,u,rate_prototype,uEltypeNoUnits,uBottomEltypeNoUnits,tTypeNoUnits,uprev,f,t,dt,p,γ,c,Val(true)) fsalfirst = zero(rate_prototype) if typeof(f) <: SplitFunction @@ -165,8 +159,7 @@ function alg_cache(alg::Kvaerno5,u,rate_prototype,uEltypeNoUnits,uBottomEltypeNo uprev,uprev2,f,t,dt,reltol,p,calck,::Val{false}) tab = Kvaerno5Tableau(real(uBottomEltypeNoUnits),real(tTypeNoUnits)) γ, c = tab.γ, tab.c3 - J, W = oop_generate_W(alg,u,uprev,p,t,dt,f,uEltypeNoUnits) - nlsolver = oopnlsolve(alg,u,uprev,p,t,dt,f,W,J,rate_prototype,uEltypeNoUnits,uBottomEltypeNoUnits,γ,c) + nlsolver = build_nlsolver(alg,alg.nlsolve,u,rate_prototype,uEltypeNoUnits,uBottomEltypeNoUnits,tTypeNoUnits,uprev,f,t,dt,p,γ,c,Val(false)) Kvaerno5ConstantCache(nlsolver,tab) end @@ -191,8 +184,7 @@ function alg_cache(alg::Kvaerno5,u,rate_prototype,uEltypeNoUnits,uBottomEltypeNo tTypeNoUnits,uprev,uprev2,f,t,dt,reltol,p,calck,::Val{true}) tab = Kvaerno5Tableau(real(uBottomEltypeNoUnits),real(tTypeNoUnits)) γ, c = tab.γ, tab.c3 - J, W = iip_generate_W(alg,u,uprev,p,t,dt,f,uEltypeNoUnits) - nlsolver = iipnlsolve(alg,u,uprev,p,t,dt,f,W,J,rate_prototype,uEltypeNoUnits,uBottomEltypeNoUnits,γ,c) + nlsolver = build_nlsolver(alg,alg.nlsolve,u,rate_prototype,uEltypeNoUnits,uBottomEltypeNoUnits,tTypeNoUnits,uprev,f,t,dt,p,γ,c,Val(true)) fsalfirst = zero(rate_prototype) z₁ = zero(u); z₂ = zero(u); z₃ = zero(u); z₄ = zero(u); z₅ = zero(u) @@ -211,8 +203,7 @@ function alg_cache(alg::KenCarp5,u,rate_prototype,uEltypeNoUnits,uBottomEltypeNo uprev,uprev2,f,t,dt,reltol,p,calck,::Val{false}) tab = KenCarp5Tableau(real(uBottomEltypeNoUnits),real(tTypeNoUnits)) γ, c = tab.γ, tab.c3 - J, W = oop_generate_W(alg,u,uprev,p,t,dt,f,uEltypeNoUnits) - nlsolver = oopnlsolve(alg,u,uprev,p,t,dt,f,W,J,rate_prototype,uEltypeNoUnits,uBottomEltypeNoUnits,γ,c) + nlsolver = build_nlsolver(alg,alg.nlsolve,u,rate_prototype,uEltypeNoUnits,uBottomEltypeNoUnits,tTypeNoUnits,uprev,f,t,dt,p,γ,c,Val(false)) KenCarp5ConstantCache(nlsolver,tab) end @@ -246,8 +237,7 @@ function alg_cache(alg::KenCarp5,u,rate_prototype,uEltypeNoUnits,uBottomEltypeNo tTypeNoUnits,uprev,uprev2,f,t,dt,reltol,p,calck,::Val{true}) tab = KenCarp5Tableau(real(uBottomEltypeNoUnits),real(tTypeNoUnits)) γ, c = tab.γ, tab.c3 - J, W = iip_generate_W(alg,u,uprev,p,t,dt,f,uEltypeNoUnits) - nlsolver = iipnlsolve(alg,u,uprev,p,t,dt,f,W,J,rate_prototype,uEltypeNoUnits,uBottomEltypeNoUnits,γ,c) + nlsolver = build_nlsolver(alg,alg.nlsolve,u,rate_prototype,uEltypeNoUnits,uBottomEltypeNoUnits,tTypeNoUnits,uprev,f,t,dt,p,γ,c,Val(true)) fsalfirst = zero(rate_prototype) if typeof(f) <: SplitFunction diff --git a/src/caches/pdirk_caches.jl b/src/caches/pdirk_caches.jl index d24943346c..27e3dbbcd3 100644 --- a/src/caches/pdirk_caches.jl +++ b/src/caches/pdirk_caches.jl @@ -48,14 +48,11 @@ end function alg_cache(alg::PDIRK44,u,rate_prototype,uEltypeNoUnits,uBottomEltypeNoUnits,tTypeNoUnits,uprev,uprev2,f,t,dt,reltol,p,calck,::Val{true}) γ, c = 1.0, 1.0 if alg.threading - J1, W1 = iip_generate_W(alg,u,uprev,p,t,dt,f,uEltypeNoUnits) - nlsolver1 = iipnlsolve(alg,u,uprev,p,t,dt,f,W1,J1,rate_prototype,uEltypeNoUnits,uBottomEltypeNoUnits,γ,c) - J2, W2 = iip_generate_W(alg,u,uprev,p,t,dt,f,uEltypeNoUnits) - nlsolver2 = iipnlsolve(alg,u,uprev,p,t,dt,f,W2,J2,rate_prototype,uEltypeNoUnits,uBottomEltypeNoUnits,γ,c) + nlsolver1 = build_nlsolver(alg,alg.nlsolve,u,rate_prototype,uEltypeNoUnits,uBottomEltypeNoUnits,tTypeNoUnits,uprev,f,t,dt,p,γ,c,Val(true)) + nlsolver2 = build_nlsolver(alg,alg.nlsolve,u,rate_prototype,uEltypeNoUnits,uBottomEltypeNoUnits,tTypeNoUnits,uprev,f,t,dt,p,γ,c,Val(true)) nlsolver = [nlsolver1, nlsolver2] else - _J, _W = iip_generate_W(alg,u,uprev,p,t,dt,f,uEltypeNoUnits) - _nlsolver = iipnlsolve(alg,u,uprev,p,t,dt,f,_W,_J,rate_prototype,uEltypeNoUnits,uBottomEltypeNoUnits,γ,c) + _nlsolver = build_nlsolver(alg,alg.nlsolve,u,rate_prototype,uEltypeNoUnits,uBottomEltypeNoUnits,tTypeNoUnits,uprev,f,t,dt,p,γ,c,Val(true)) nlsolver = [_nlsolver] end tab = PDIRK44Tableau(real(uBottomEltypeNoUnits), real(tTypeNoUnits)) @@ -67,14 +64,11 @@ end function alg_cache(alg::PDIRK44,u,rate_prototype,uEltypeNoUnits,uBottomEltypeNoUnits,tTypeNoUnits,uprev,uprev2,f,t,dt,reltol,p,calck,::Val{false}) γ, c = 1.0, 1.0 if alg.threading - J1, W1 = oop_generate_W(alg,u,uprev,p,t,dt,f,uEltypeNoUnits) - nlsolver1 = oopnlsolve(alg,u,uprev,p,t,dt,f,W1,J1,rate_prototype,uEltypeNoUnits,uBottomEltypeNoUnits,γ,c) - J2, W2 = oop_generate_W(alg,u,uprev,p,t,dt,f,uEltypeNoUnits) - nlsolver2 = oopnlsolve(alg,u,uprev,p,t,dt,f,W2,J2,rate_prototype,uEltypeNoUnits,uBottomEltypeNoUnits,γ,c) + nlsolver1 = build_nlsolver(alg,alg.nlsolve,u,rate_prototype,uEltypeNoUnits,uBottomEltypeNoUnits,tTypeNoUnits,uprev,f,t,dt,p,γ,c,Val(false)) + nlsolver2 = build_nlsolver(alg,alg.nlsolve,u,rate_prototype,uEltypeNoUnits,uBottomEltypeNoUnits,tTypeNoUnits,uprev,f,t,dt,p,γ,c,Val(false)) nlsolver = [nlsolver1, nlsolver2] else - _J, _W = oop_generate_W(alg,u,uprev,p,t,dt,f,uEltypeNoUnits) - _nlsolver = oopnlsolve(alg,u,uprev,p,t,dt,f,_W,_J,rate_prototype,uEltypeNoUnits,uBottomEltypeNoUnits,γ,c) + _nlsolver = build_nlsolver(alg,alg.nlsolve,u,rate_prototype,uEltypeNoUnits,uBottomEltypeNoUnits,tTypeNoUnits,uprev,f,t,dt,p,γ,c,Val(false)) nlsolver = [_nlsolver] end tab = PDIRK44Tableau(real(uBottomEltypeNoUnits), real(tTypeNoUnits)) diff --git a/src/caches/rkc_caches.jl b/src/caches/rkc_caches.jl index 847555a6a2..9017823e6d 100644 --- a/src/caches/rkc_caches.jl +++ b/src/caches/rkc_caches.jl @@ -137,8 +137,7 @@ end function alg_cache(alg::IRKC,u,rate_prototype,uEltypeNoUnits,uBottomEltypeNoUnits,tTypeNoUnits,uprev,uprev2,f,t,dt,reltol,p,calck,::Val{false}) γ, c = 1.0, 1.0 - J, W = oop_generate_W(alg,u,uprev,p,t,dt,f,uEltypeNoUnits) - nlsolver = oopnlsolve(alg,u,uprev,p,t,dt,f,W,J,rate_prototype,uEltypeNoUnits,uBottomEltypeNoUnits,γ,c) + nlsolver = build_nlsolver(alg,alg.nlsolve,u,rate_prototype,uEltypeNoUnits,uBottomEltypeNoUnits,tTypeNoUnits,uprev,f,t,dt,p,γ,c,Val(false)) zprev = u du₁ = rate_prototype; du₂ = rate_prototype IRKCConstantCache(50,zprev,nlsolver,du₁,du₂) @@ -146,8 +145,7 @@ end function alg_cache(alg::IRKC,u,rate_prototype,uEltypeNoUnits,uBottomEltypeNoUnits,tTypeNoUnits,uprev,uprev2,f,t,dt,reltol,p,calck,::Val{true}) γ, c = 1.0, 1.0 - J, W = iip_generate_W(alg,u,uprev,p,t,dt,f,uEltypeNoUnits) - nlsolver = iipnlsolve(alg,u,uprev,p,t,dt,f,W,J,rate_prototype,uEltypeNoUnits,uBottomEltypeNoUnits,γ,c) + nlsolver = build_nlsolver(alg,alg.nlsolve,u,rate_prototype,uEltypeNoUnits,uBottomEltypeNoUnits,tTypeNoUnits,uprev,f,t,dt,p,γ,c,Val(true)) gprev = similar(u) gprev2 = similar(u) diff --git a/src/caches/rosenbrock_caches.jl b/src/caches/rosenbrock_caches.jl index 450046a38c..82d8e3be76 100644 --- a/src/caches/rosenbrock_caches.jl +++ b/src/caches/rosenbrock_caches.jl @@ -64,7 +64,7 @@ function alg_cache(alg::Rosenbrock23,u,rate_prototype,uEltypeNoUnits,uBottomElty fsalfirst = zero(rate_prototype) fsallast = zero(rate_prototype) dT = zero(rate_prototype) - J,W = iip_generate_W(alg,u,uprev,p,t,dt,f,uEltypeNoUnits) + J,W = DiffEqBase.build_J_W(alg,u,uprev,p,t,dt,f,uEltypeNoUnits,Val(true)) tmp = zero(rate_prototype) atmp = similar(u, uEltypeNoUnits) tab = Rosenbrock23Tableau(real(uBottomEltypeNoUnits)) @@ -92,7 +92,7 @@ function alg_cache(alg::Rosenbrock32,u,rate_prototype,uEltypeNoUnits,uBottomElty fsalfirst = zero(rate_prototype) fsallast = zero(rate_prototype) dT = zero(rate_prototype) - J,W = iip_generate_W(alg,u,uprev,p,t,dt,f,uEltypeNoUnits) + J,W = DiffEqBase.build_J_W(alg,u,uprev,p,t,dt,f,uEltypeNoUnits,Val(true)) tmp = zero(rate_prototype) atmp = similar(u, uEltypeNoUnits) tab = Rosenbrock32Tableau(real(uBottomEltypeNoUnits)) @@ -124,7 +124,7 @@ end function alg_cache(alg::Rosenbrock23,u,rate_prototype,uEltypeNoUnits,uBottomEltypeNoUnits,tTypeNoUnits,uprev,uprev2,f,t,dt,reltol,p,calck,::Val{false}) tf = DiffEqDiffTools.TimeDerivativeWrapper(f,u,p) uf = DiffEqDiffTools.UDerivativeWrapper(f,t,p) - J,W = oop_generate_W(alg,u,uprev,p,t,dt,f,uEltypeNoUnits) + J,W = DiffEqBase.build_J_W(alg,u,uprev,p,t,dt,f,uEltypeNoUnits,Val(false)) linsolve = alg.linsolve(Val{:init},uf,u) Rosenbrock23ConstantCache(real(uBottomEltypeNoUnits),tf,uf,J,W,linsolve) end @@ -147,7 +147,7 @@ end function alg_cache(alg::Rosenbrock32,u,rate_prototype,uEltypeNoUnits,uBottomEltypeNoUnits,tTypeNoUnits,uprev,uprev2,f,t,dt,reltol,p,calck,::Val{false}) tf = DiffEqDiffTools.TimeDerivativeWrapper(f,u,p) uf = DiffEqDiffTools.UDerivativeWrapper(f,t,p) - J,W = oop_generate_W(alg,u,uprev,p,t,dt,f,uEltypeNoUnits) + J,W = DiffEqBase.build_J_W(alg,u,uprev,p,t,dt,f,uEltypeNoUnits,Val(false)) linsolve = alg.linsolve(Val{:init},uf,u) Rosenbrock32ConstantCache(real(uBottomEltypeNoUnits),tf,uf,J,W,linsolve) end @@ -202,7 +202,7 @@ function alg_cache(alg::ROS3P,u,rate_prototype,uEltypeNoUnits,uBottomEltypeNoUni fsalfirst = zero(rate_prototype) fsallast = zero(rate_prototype) dT = zero(rate_prototype) - J,W = iip_generate_W(alg,u,uprev,p,t,dt,f,uEltypeNoUnits) + J,W = DiffEqBase.build_J_W(alg,u,uprev,p,t,dt,f,uEltypeNoUnits,Val(true)) tmp = zero(rate_prototype) atmp = similar(u, uEltypeNoUnits) tab = ROS3PTableau(constvalue(uBottomEltypeNoUnits),constvalue(tTypeNoUnits)) @@ -220,7 +220,7 @@ end function alg_cache(alg::ROS3P,u,rate_prototype,uEltypeNoUnits,uBottomEltypeNoUnits,tTypeNoUnits,uprev,uprev2,f,t,dt,reltol,p,calck,::Val{false}) tf = DiffEqDiffTools.TimeDerivativeWrapper(f,u,p) uf = DiffEqDiffTools.UDerivativeWrapper(f,t,p) - J,W = oop_generate_W(alg,u,uprev,p,t,dt,f,uEltypeNoUnits) + J,W = DiffEqBase.build_J_W(alg,u,uprev,p,t,dt,f,uEltypeNoUnits,Val(false)) linsolve = alg.linsolve(Val{:init},uf,u) Rosenbrock33ConstantCache(tf,uf,ROS3PTableau(constvalue(uBottomEltypeNoUnits),constvalue(tTypeNoUnits)),J,W,linsolve) end @@ -262,7 +262,7 @@ function alg_cache(alg::Rodas3,u,rate_prototype,uEltypeNoUnits,uBottomEltypeNoUn fsalfirst = zero(rate_prototype) fsallast = zero(rate_prototype) dT = zero(rate_prototype) - J,W = iip_generate_W(alg,u,uprev,p,t,dt,f,uEltypeNoUnits) + J,W = DiffEqBase.build_J_W(alg,u,uprev,p,t,dt,f,uEltypeNoUnits,Val(true)) tmp = zero(rate_prototype) atmp = similar(u, uEltypeNoUnits) tab = Rodas3Tableau(constvalue(uBottomEltypeNoUnits),constvalue(tTypeNoUnits)) @@ -290,7 +290,7 @@ end function alg_cache(alg::Rodas3,u,rate_prototype,uEltypeNoUnits,uBottomEltypeNoUnits,tTypeNoUnits,uprev,uprev2,f,t,dt,reltol,p,calck,::Val{false}) tf = DiffEqDiffTools.TimeDerivativeWrapper(f,u,p) uf = DiffEqDiffTools.UDerivativeWrapper(f,t,p) - J,W = oop_generate_W(alg,u,uprev,p,t,dt,f,uEltypeNoUnits) + J,W = DiffEqBase.build_J_W(alg,u,uprev,p,t,dt,f,uEltypeNoUnits,Val(false)) linsolve = alg.linsolve(Val{:init},uf,u) Rosenbrock34ConstantCache(tf,uf,Rodas3Tableau(constvalue(uBottomEltypeNoUnits),constvalue(tTypeNoUnits)),J,W,linsolve) end @@ -366,7 +366,7 @@ function alg_cache(alg::Rodas4,u,rate_prototype,uEltypeNoUnits,uBottomEltypeNoUn fsalfirst = zero(rate_prototype) fsallast = zero(rate_prototype) dT = zero(rate_prototype) - J,W = iip_generate_W(alg,u,uprev,p,t,dt,f,uEltypeNoUnits) + J,W = DiffEqBase.build_J_W(alg,u,uprev,p,t,dt,f,uEltypeNoUnits,Val(true)) tmp = zero(rate_prototype) atmp = similar(u, uEltypeNoUnits) tab = Rodas4Tableau(constvalue(uBottomEltypeNoUnits),constvalue(tTypeNoUnits)) @@ -386,7 +386,7 @@ end function alg_cache(alg::Rodas4,u,rate_prototype,uEltypeNoUnits,uBottomEltypeNoUnits,tTypeNoUnits,uprev,uprev2,f,t,dt,reltol,p,calck,::Val{false}) tf = DiffEqDiffTools.TimeDerivativeWrapper(f,u,p) uf = DiffEqDiffTools.UDerivativeWrapper(f,t,p) - J,W = oop_generate_W(alg,u,uprev,p,t,dt,f,uEltypeNoUnits) + J,W = DiffEqBase.build_J_W(alg,u,uprev,p,t,dt,f,uEltypeNoUnits,Val(false)) linsolve = alg.linsolve(Val{:init},uf,u) Rodas4ConstantCache(tf,uf,Rodas4Tableau(constvalue(uBottomEltypeNoUnits),constvalue(tTypeNoUnits)),J,W,linsolve) end @@ -406,7 +406,7 @@ function alg_cache(alg::Rodas42,u,rate_prototype,uEltypeNoUnits,uBottomEltypeNoU fsalfirst = zero(rate_prototype) fsallast = zero(rate_prototype) dT = zero(rate_prototype) - J,W = iip_generate_W(alg,u,uprev,p,t,dt,f,uEltypeNoUnits) + J,W = DiffEqBase.build_J_W(alg,u,uprev,p,t,dt,f,uEltypeNoUnits,Val(true)) tmp = zero(rate_prototype) atmp = similar(u, uEltypeNoUnits) tab = Rodas42Tableau(constvalue(uBottomEltypeNoUnits),constvalue(tTypeNoUnits)) @@ -426,7 +426,7 @@ end function alg_cache(alg::Rodas42,u,rate_prototype,uEltypeNoUnits,uBottomEltypeNoUnits,tTypeNoUnits,uprev,uprev2,f,t,dt,reltol,p,calck,::Val{false}) tf = DiffEqDiffTools.TimeDerivativeWrapper(f,u,p) uf = DiffEqDiffTools.UDerivativeWrapper(f,t,p) - J,W = oop_generate_W(alg,u,uprev,p,t,dt,f,uEltypeNoUnits) + J,W = DiffEqBase.build_J_W(alg,u,uprev,p,t,dt,f,uEltypeNoUnits,Val(false)) linsolve = alg.linsolve(Val{:init},uf,u) Rodas4ConstantCache(tf,uf,Rodas42Tableau(constvalue(uBottomEltypeNoUnits),constvalue(tTypeNoUnits)),J,W,linsolve) end @@ -446,7 +446,7 @@ function alg_cache(alg::Rodas4P,u,rate_prototype,uEltypeNoUnits,uBottomEltypeNoU fsalfirst = zero(rate_prototype) fsallast = zero(rate_prototype) dT = zero(rate_prototype) - J,W = iip_generate_W(alg,u,uprev,p,t,dt,f,uEltypeNoUnits) + J,W = DiffEqBase.build_J_W(alg,u,uprev,p,t,dt,f,uEltypeNoUnits,Val(true)) tmp = zero(rate_prototype) atmp = similar(u, uEltypeNoUnits) tab = Rodas4PTableau(constvalue(uBottomEltypeNoUnits),constvalue(tTypeNoUnits)) @@ -466,7 +466,7 @@ end function alg_cache(alg::Rodas4P,u,rate_prototype,uEltypeNoUnits,uBottomEltypeNoUnits,tTypeNoUnits,uprev,uprev2,f,t,dt,reltol,p,calck,::Val{false}) tf = DiffEqDiffTools.TimeDerivativeWrapper(f,u,p) uf = DiffEqDiffTools.UDerivativeWrapper(f,t,p) - J,W = oop_generate_W(alg,u,uprev,p,t,dt,f,uEltypeNoUnits) + J,W = DiffEqBase.build_J_W(alg,u,uprev,p,t,dt,f,uEltypeNoUnits,Val(false)) linsolve = alg.linsolve(Val{:init},uf,u) Rodas4ConstantCache(tf,uf,Rodas4PTableau(constvalue(uBottomEltypeNoUnits),constvalue(tTypeNoUnits)),J,W,linsolve) end @@ -533,7 +533,7 @@ function alg_cache(alg::Rodas5,u,rate_prototype,uEltypeNoUnits,uBottomEltypeNoUn fsalfirst = zero(rate_prototype) fsallast = zero(rate_prototype) dT = zero(rate_prototype) - J,W = iip_generate_W(alg,u,uprev,p,t,dt,f,uEltypeNoUnits) + J,W = DiffEqBase.build_J_W(alg,u,uprev,p,t,dt,f,uEltypeNoUnits,Val(true)) tmp = zero(rate_prototype) atmp = similar(u, uEltypeNoUnits) tab = Rodas5Tableau(constvalue(uBottomEltypeNoUnits),constvalue(tTypeNoUnits)) @@ -553,7 +553,7 @@ end function alg_cache(alg::Rodas5,u,rate_prototype,uEltypeNoUnits,uBottomEltypeNoUnits,tTypeNoUnits,uprev,uprev2,f,t,dt,reltol,p,calck,::Val{false}) tf = DiffEqDiffTools.TimeDerivativeWrapper(f,u,p) uf = DiffEqDiffTools.UDerivativeWrapper(f,t,p) - J,W = oop_generate_W(alg,u,uprev,p,t,dt,f,uEltypeNoUnits) + J,W = DiffEqBase.build_J_W(alg,u,uprev,p,t,dt,f,uEltypeNoUnits,Val(false)) linsolve = alg.linsolve(Val{:init},uf,u) Rosenbrock5ConstantCache(tf,uf,Rodas5Tableau(constvalue(uBottomEltypeNoUnits),constvalue(tTypeNoUnits)),J,W,linsolve) end diff --git a/src/caches/sdirk_caches.jl b/src/caches/sdirk_caches.jl index bf6134be06..46dd8cc49f 100644 --- a/src/caches/sdirk_caches.jl +++ b/src/caches/sdirk_caches.jl @@ -12,8 +12,7 @@ end function alg_cache(alg::ImplicitEuler,u,rate_prototype,uEltypeNoUnits,uBottomEltypeNoUnits, tTypeNoUnits,uprev,uprev2,f,t,dt,reltol,p,calck,::Val{true}) γ, c = 1, 1 - J, W = iip_generate_W(alg,u,uprev,p,t,dt,f,uEltypeNoUnits) - nlsolver = iipnlsolve(alg,u,uprev,p,t,dt,f,W,J,rate_prototype,uEltypeNoUnits,uBottomEltypeNoUnits,γ,c) + nlsolver = build_nlsolver(alg,alg.nlsolve,u,rate_prototype,uEltypeNoUnits,uBottomEltypeNoUnits,tTypeNoUnits,uprev,f,t,dt,p,γ,c,Val(true)) fsalfirst = zero(rate_prototype) atmp = similar(u,uEltypeNoUnits) @@ -28,8 +27,7 @@ end function alg_cache(alg::ImplicitEuler,u,rate_prototype,uEltypeNoUnits,uBottomEltypeNoUnits, tTypeNoUnits,uprev,uprev2,f,t,dt,reltol,p,calck,::Val{false}) γ, c = 1, 1 - J, W = oop_generate_W(alg,u,uprev,p,t,dt,f,uEltypeNoUnits) - nlsolver = oopnlsolve(alg,u,uprev,p,t,dt,f,W,J,rate_prototype,uEltypeNoUnits,uBottomEltypeNoUnits,γ,c) + nlsolver = build_nlsolver(alg,alg.nlsolve,u,rate_prototype,uEltypeNoUnits,uBottomEltypeNoUnits,tTypeNoUnits,uprev,f,t,dt,p,γ,c,Val(false)) ImplicitEulerConstantCache(nlsolver) end @@ -39,8 +37,7 @@ end function alg_cache(alg::ImplicitMidpoint,u,rate_prototype,uEltypeNoUnits,uBottomEltypeNoUnits,tTypeNoUnits,uprev,uprev2,f,t,dt,reltol,p,calck,::Val{false}) γ, c = 1//2, 1//2 - J, W = oop_generate_W(alg,u,uprev,p,t,dt,f,uEltypeNoUnits) - nlsolver = oopnlsolve(alg,u,uprev,p,t,dt,f,W,J,rate_prototype,uEltypeNoUnits,uBottomEltypeNoUnits,γ,c) + nlsolver = build_nlsolver(alg,alg.nlsolve,u,rate_prototype,uEltypeNoUnits,uBottomEltypeNoUnits,tTypeNoUnits,uprev,f,t,dt,p,γ,c,Val(false)) ImplicitMidpointConstantCache(nlsolver) end @@ -54,8 +51,7 @@ end function alg_cache(alg::ImplicitMidpoint,u,rate_prototype,uEltypeNoUnits,uBottomEltypeNoUnits, tTypeNoUnits,uprev,uprev2,f,t,dt,reltol,p,calck,::Val{true}) γ, c = 1//2, 1//2 - J, W = iip_generate_W(alg,u,uprev,p,t,dt,f,uEltypeNoUnits) - nlsolver = iipnlsolve(alg,u,uprev,p,t,dt,f,W,J,rate_prototype,uEltypeNoUnits,uBottomEltypeNoUnits,γ,c) + nlsolver = build_nlsolver(alg,alg.nlsolve,u,rate_prototype,uEltypeNoUnits,uBottomEltypeNoUnits,tTypeNoUnits,uprev,f,t,dt,p,γ,c,Val(true)) fsalfirst = zero(rate_prototype) ImplicitMidpointCache(u,uprev,fsalfirst,nlsolver) end @@ -69,8 +65,7 @@ end function alg_cache(alg::Trapezoid,u,rate_prototype,uEltypeNoUnits,uBottomEltypeNoUnits,tTypeNoUnits, uprev,uprev2,f,t,dt,reltol,p,calck,::Val{false}) γ, c = 1//2, 1 - J, W = oop_generate_W(alg,u,uprev,p,t,dt,f,uEltypeNoUnits) - nlsolver = oopnlsolve(alg,u,uprev,p,t,dt,f,W,J,rate_prototype,uEltypeNoUnits,uBottomEltypeNoUnits,γ,c) + nlsolver = build_nlsolver(alg,alg.nlsolve,u,rate_prototype,uEltypeNoUnits,uBottomEltypeNoUnits,tTypeNoUnits,uprev,f,t,dt,p,γ,c,Val(false)) uprev3 = u tprev2 = t @@ -92,8 +87,7 @@ end function alg_cache(alg::Trapezoid,u,rate_prototype,uEltypeNoUnits,uBottomEltypeNoUnits, tTypeNoUnits,uprev,uprev2,f,t,dt,reltol,p,calck,::Val{true}) γ, c = 1//2, 1 - J, W = iip_generate_W(alg,u,uprev,p,t,dt,f,uEltypeNoUnits) - nlsolver = iipnlsolve(alg,u,uprev,p,t,dt,f,W,J,rate_prototype,uEltypeNoUnits,uBottomEltypeNoUnits,γ,c) + nlsolver = build_nlsolver(alg,alg.nlsolve,u,rate_prototype,uEltypeNoUnits,uBottomEltypeNoUnits,tTypeNoUnits,uprev,f,t,dt,p,γ,c,Val(true)) fsalfirst = zero(rate_prototype) uprev3 = zero(u) @@ -112,8 +106,7 @@ function alg_cache(alg::TRBDF2,u,rate_prototype,uEltypeNoUnits,uBottomEltypeNoUn uprev,uprev2,f,t,dt,reltol,p,calck,::Val{false}) tab = TRBDF2Tableau(real(uBottomEltypeNoUnits),real(tTypeNoUnits)) γ, c = tab.d, tab.γ - J, W = oop_generate_W(alg,u,uprev,p,t,dt,f,uEltypeNoUnits) - nlsolver = oopnlsolve(alg,u,uprev,p,t,dt,f,W,J,rate_prototype,uEltypeNoUnits,uBottomEltypeNoUnits,γ,c) + nlsolver = build_nlsolver(alg,alg.nlsolve,u,rate_prototype,uEltypeNoUnits,uBottomEltypeNoUnits,tTypeNoUnits,uprev,f,t,dt,p,γ,c,Val(false)) TRBDF2ConstantCache(nlsolver,tab) end @@ -132,8 +125,7 @@ function alg_cache(alg::TRBDF2,u,rate_prototype,uEltypeNoUnits,uBottomEltypeNoUn tTypeNoUnits,uprev,uprev2,f,t,dt,reltol,p,calck,::Val{true}) tab = TRBDF2Tableau(real(uBottomEltypeNoUnits),real(tTypeNoUnits)) γ, c = tab.d, tab.γ - J, W = iip_generate_W(alg,u,uprev,p,t,dt,f,uEltypeNoUnits) - nlsolver = iipnlsolve(alg,u,uprev,p,t,dt,f,W,J,rate_prototype,uEltypeNoUnits,uBottomEltypeNoUnits,γ,c) + nlsolver = build_nlsolver(alg,alg.nlsolve,u,rate_prototype,uEltypeNoUnits,uBottomEltypeNoUnits,tTypeNoUnits,uprev,f,t,dt,p,γ,c,Val(true)) fsalfirst = zero(rate_prototype) atmp = similar(u,uEltypeNoUnits); zprev = similar(u); zᵧ = similar(u) @@ -148,8 +140,7 @@ end function alg_cache(alg::SDIRK2,u,rate_prototype,uEltypeNoUnits,uBottomEltypeNoUnits,tTypeNoUnits, uprev,uprev2,f,t,dt,reltol,p,calck,::Val{false}) γ, c = 1, 1 - J, W = oop_generate_W(alg,u,uprev,p,t,dt,f,uEltypeNoUnits) - nlsolver = oopnlsolve(alg,u,uprev,p,t,dt,f,W,J,rate_prototype,uEltypeNoUnits,uBottomEltypeNoUnits,γ,c) + nlsolver = build_nlsolver(alg,alg.nlsolve,u,rate_prototype,uEltypeNoUnits,uBottomEltypeNoUnits,tTypeNoUnits,uprev,f,t,dt,p,γ,c,Val(false)) SDIRK2ConstantCache(nlsolver) end @@ -166,8 +157,7 @@ end function alg_cache(alg::SDIRK2,u,rate_prototype,uEltypeNoUnits,uBottomEltypeNoUnits, tTypeNoUnits,uprev,uprev2,f,t,dt,reltol,p,calck,::Val{true}) γ, c = 1, 1 - J, W = iip_generate_W(alg,u,uprev,p,t,dt,f,uEltypeNoUnits) - nlsolver = iipnlsolve(alg,u,uprev,p,t,dt,f,W,J,rate_prototype,uEltypeNoUnits,uBottomEltypeNoUnits,γ,c) + nlsolver = build_nlsolver(alg,alg.nlsolve,u,rate_prototype,uEltypeNoUnits,uBottomEltypeNoUnits,tTypeNoUnits,uprev,f,t,dt,p,γ,c,Val(true)) fsalfirst = zero(rate_prototype) z₁ = similar(u); z₂ = nlsolver.z @@ -190,8 +180,7 @@ function alg_cache(alg::SDIRK22,u,rate_prototype,uEltypeNoUnits,tTypeNoUnits,uBo tprev2 = t γ, c = 1, 1 - J, W = oop_generate_W(alg,u,uprev,p,t,dt,f,uEltypeNoUnits) - nlsolver = oopnlsolve(alg,u,uprev,p,t,dt,f,W,J,rate_prototype,uEltypeNoUnits,uBottomEltypeNoUnits,γ,c) + nlsolver = build_nlsolver(alg,alg.nlsolve,u,rate_prototype,uEltypeNoUnits,uBottomEltypeNoUnits,tTypeNoUnits,uprev,f,t,dt,p,γ,c,Val(false)) SDIRK22ConstantCache(uprev3,tprev2,nlsolver) end @@ -211,8 +200,7 @@ end function alg_cache(alg::SDIRK22,u,rate_prototype,uEltypeNoUnits,uBottomEltypeNoUnits,tTypeNoUnits,uprev,uprev2,f,t,dt,reltol,p,calck,::Val{true}) tab = SDIRK22Tableau(real(uBottomEltypeNoUnits)) γ, c = 1, 1 - J, W = iip_generate_W(alg,u,uprev,p,t,dt,f,uEltypeNoUnits) - nlsolver = iipnlsolve(alg,u,uprev,p,t,dt,f,W,J,rate_prototype,uEltypeNoUnits,uBottomEltypeNoUnits,γ,c) + nlsolver = build_nlsolver(alg,alg.nlsolve,u,rate_prototype,uEltypeNoUnits,uBottomEltypeNoUnits,tTypeNoUnits,uprev,f,t,dt,p,γ,c,Val(true)) fsalfirst = zero(rate_prototype) uprev3 = zero(u) @@ -229,8 +217,7 @@ end function alg_cache(alg::SSPSDIRK2,u,rate_prototype,uEltypeNoUnits,uBottomEltypeNoUnits,tTypeNoUnits, uprev,uprev2,f,t,dt,reltol,p,calck,::Val{false}) γ, c = 1//4, 1//1 - J, W = oop_generate_W(alg,u,uprev,p,t,dt,f,uEltypeNoUnits) - nlsolver = oopnlsolve(alg,u,uprev,p,t,dt,f,W,J,rate_prototype,uEltypeNoUnits,uBottomEltypeNoUnits,γ,c) + nlsolver = build_nlsolver(alg,alg.nlsolve,u,rate_prototype,uEltypeNoUnits,uBottomEltypeNoUnits,tTypeNoUnits,uprev,f,t,dt,p,γ,c,Val(false)) SSPSDIRK2ConstantCache(nlsolver) end @@ -246,8 +233,7 @@ end function alg_cache(alg::SSPSDIRK2,u,rate_prototype,uEltypeNoUnits,uBottomEltypeNoUnits, tTypeNoUnits,uprev,uprev2,f,t,dt,reltol,p,calck,::Val{true}) γ, c = 1//4, 1//1 - J, W = iip_generate_W(alg,u,uprev,p,t,dt,f,uEltypeNoUnits) - nlsolver = iipnlsolve(alg,u,uprev,p,t,dt,f,W,J,rate_prototype,uEltypeNoUnits,uBottomEltypeNoUnits,γ,c) + nlsolver = build_nlsolver(alg,alg.nlsolve,u,rate_prototype,uEltypeNoUnits,uBottomEltypeNoUnits,tTypeNoUnits,uprev,f,t,dt,p,γ,c,Val(true)) fsalfirst = zero(rate_prototype) z₁ = similar(u); z₂ = nlsolver.z @@ -265,8 +251,7 @@ function alg_cache(alg::Kvaerno3,u,rate_prototype,uEltypeNoUnits,uBottomEltypeNo uprev,uprev2,f,t,dt,reltol,p,calck,::Val{false}) tab = Kvaerno3Tableau(real(uBottomEltypeNoUnits),real(tTypeNoUnits)) γ, c = tab.γ, 2tab.γ - J, W = oop_generate_W(alg,u,uprev,p,t,dt,f,uEltypeNoUnits) - nlsolver = oopnlsolve(alg,u,uprev,p,t,dt,f,W,J,rate_prototype,uEltypeNoUnits,uBottomEltypeNoUnits,γ,c) + nlsolver = build_nlsolver(alg,alg.nlsolve,u,rate_prototype,uEltypeNoUnits,uBottomEltypeNoUnits,tTypeNoUnits,uprev,f,t,dt,p,γ,c,Val(false)) Kvaerno3ConstantCache(nlsolver,tab) end @@ -287,8 +272,7 @@ function alg_cache(alg::Kvaerno3,u,rate_prototype,uEltypeNoUnits,uBottomEltypeNo tTypeNoUnits,uprev,uprev2,f,t,dt,reltol,p,calck,::Val{true}) tab = Kvaerno3Tableau(real(uBottomEltypeNoUnits),real(tTypeNoUnits)) γ, c = tab.γ, 2tab.γ - J, W = iip_generate_W(alg,u,uprev,p,t,dt,f,uEltypeNoUnits) - nlsolver = iipnlsolve(alg,u,uprev,p,t,dt,f,W,J,rate_prototype,uEltypeNoUnits,uBottomEltypeNoUnits,γ,c) + nlsolver = build_nlsolver(alg,alg.nlsolve,u,rate_prototype,uEltypeNoUnits,uBottomEltypeNoUnits,tTypeNoUnits,uprev,f,t,dt,p,γ,c,Val(true)) fsalfirst = zero(rate_prototype) z₁ = similar(u); z₂ = similar(u); z₃ = similar(u); z₄ = nlsolver.z @@ -306,8 +290,7 @@ function alg_cache(alg::Cash4,u,rate_prototype,uEltypeNoUnits,uBottomEltypeNoUni uprev,uprev2,f,t,dt,reltol,p,calck,::Val{false}) tab = Cash4Tableau(real(uBottomEltypeNoUnits),real(tTypeNoUnits)) γ, c = tab.γ,tab.γ - J, W = oop_generate_W(alg,u,uprev,p,t,dt,f,uEltypeNoUnits) - nlsolver = oopnlsolve(alg,u,uprev,p,t,dt,f,W,J,rate_prototype,uEltypeNoUnits,uBottomEltypeNoUnits,γ,c) + nlsolver = build_nlsolver(alg,alg.nlsolve,u,rate_prototype,uEltypeNoUnits,uBottomEltypeNoUnits,tTypeNoUnits,uprev,f,t,dt,p,γ,c,Val(false)) Cash4ConstantCache(nlsolver,tab) end @@ -329,8 +312,7 @@ function alg_cache(alg::Cash4,u,rate_prototype,uEltypeNoUnits,uBottomEltypeNoUni tTypeNoUnits,uprev,uprev2,f,t,dt,reltol,p,calck,::Val{true}) tab = Cash4Tableau(real(uBottomEltypeNoUnits),real(tTypeNoUnits)) γ, c = tab.γ,tab.γ - J, W = iip_generate_W(alg,u,uprev,p,t,dt,f,uEltypeNoUnits) - nlsolver = iipnlsolve(alg,u,uprev,p,t,dt,f,W,J,rate_prototype,uEltypeNoUnits,uBottomEltypeNoUnits,γ,c) + nlsolver = build_nlsolver(alg,alg.nlsolve,u,rate_prototype,uEltypeNoUnits,uBottomEltypeNoUnits,tTypeNoUnits,uprev,f,t,dt,p,γ,c,Val(true)) fsalfirst = zero(rate_prototype) z₁ = similar(u); z₂ = similar(u); z₃ = similar(u); z₄ = similar(u); z₅ = nlsolver.z @@ -352,8 +334,7 @@ function alg_cache(alg::Union{Hairer4,Hairer42},u,rate_prototype,uEltypeNoUnits, tab = Hairer42Tableau(real(uBottomEltypeNoUnits),real(tTypeNoUnits)) end γ, c = tab.γ, tab.γ - J, W = oop_generate_W(alg,u,uprev,p,t,dt,f,uEltypeNoUnits) - nlsolver = oopnlsolve(alg,u,uprev,p,t,dt,f,W,J,rate_prototype,uEltypeNoUnits,uBottomEltypeNoUnits,γ,c) + nlsolver = build_nlsolver(alg,alg.nlsolve,u,rate_prototype,uEltypeNoUnits,uBottomEltypeNoUnits,tTypeNoUnits,uprev,f,t,dt,p,γ,c,Val(false)) Hairer4ConstantCache(nlsolver,tab) end @@ -379,8 +360,7 @@ function alg_cache(alg::Union{Hairer4,Hairer42},u,rate_prototype,uEltypeNoUnits, tab = Hairer42Tableau(real(uBottomEltypeNoUnits),real(tTypeNoUnits)) end γ, c = tab.γ, tab.γ - J, W = iip_generate_W(alg,u,uprev,p,t,dt,f,uEltypeNoUnits) - nlsolver = iipnlsolve(alg,u,uprev,p,t,dt,f,W,J,rate_prototype,uEltypeNoUnits,uBottomEltypeNoUnits,γ,c) + nlsolver = build_nlsolver(alg,alg.nlsolve,u,rate_prototype,uEltypeNoUnits,uBottomEltypeNoUnits,tTypeNoUnits,uprev,f,t,dt,p,γ,c,Val(true)) fsalfirst = zero(rate_prototype) z₁ = similar(u); z₂ = similar(u); z₃ = similar(u); z₄ = similar(u); z₅ = nlsolver.z @@ -403,8 +383,7 @@ function alg_cache(alg::ESDIRK54I8L2SA,u,rate_prototype,uEltypeNoUnits,uBottomEl tTypeNoUnits,uprev,uprev2,f,t,dt,reltol,p,calck,::Val{true}) tab = ESDIRK54I8L2SATableau(real(uBottomEltypeNoUnits),real(tTypeNoUnits)) γ, c = tab.γ, tab.γ - J, W = iip_generate_W(alg,u,uprev,p,t,dt,f,uEltypeNoUnits) - nlsolver = iipnlsolve(alg,u,uprev,p,t,dt,f,W,J,rate_prototype,uEltypeNoUnits,uBottomEltypeNoUnits,γ,c) + nlsolver = build_nlsolver(alg,alg.nlsolve,u,rate_prototype,uEltypeNoUnits,uBottomEltypeNoUnits,tTypeNoUnits,uprev,f,t,dt,p,γ,c,Val(true)) fsalfirst = zero(rate_prototype) z₁ = zero(u); z₂ = zero(u); z₃ = zero(u); z₄ = zero(u) @@ -423,7 +402,6 @@ function alg_cache(alg::ESDIRK54I8L2SA,u,rate_prototype,uEltypeNoUnits,uBottomEl uprev,uprev2,f,t,dt,reltol,p,calck,::Val{false}) tab = ESDIRK54I8L2SATableau(real(uBottomEltypeNoUnits),real(tTypeNoUnits)) γ, c = tab.γ,tab.γ - J, W = oop_generate_W(alg,u,uprev,p,t,dt,f,uEltypeNoUnits) - nlsolver = oopnlsolve(alg,u,uprev,p,t,dt,f,W,J,rate_prototype,uEltypeNoUnits,uBottomEltypeNoUnits,γ,c) + nlsolver = build_nlsolver(alg,alg.nlsolve,u,rate_prototype,uEltypeNoUnits,uBottomEltypeNoUnits,tTypeNoUnits,uprev,f,t,dt,p,γ,c,Val(false)) ESDIRK54I8L2SAConstantCache(nlsolver,tab) end diff --git a/src/derivative_utils.jl b/src/derivative_utils.jl index 56d1ff4aee..3fb745658f 100644 --- a/src/derivative_utils.jl +++ b/src/derivative_utils.jl @@ -36,7 +36,7 @@ function calc_tderivative(integrator, cache) end """ - calc_J(integrator,cache,is_compos) + calc_J(integrator, cache) Interface for calculating the jacobian. @@ -46,36 +46,48 @@ jacobian update function, then it will be called for the update. Otherwise, either ForwardDiff or finite difference will be used depending on the `jac_config` of the cache. """ -function calc_J(integrator, cache::OrdinaryDiffEqConstantCache, is_compos) - @unpack t,dt,uprev,u,f,p = integrator +function calc_J(integrator, cache::OrdinaryDiffEqConstantCache) + @unpack t,uprev,f,p = integrator + if DiffEqBase.has_jac(f) J = f.jac(uprev, p, t) else - cache.uf.t = t - cache.uf.p = p - J = jacobian(cache.uf,uprev,integrator) + @unpack uf = cache + uf.t = t + uf.p = p + J = jacobian(uf,uprev,integrator) end integrator.destats.njacs += 1 - is_compos && (integrator.eigen_est = opnorm(J, Inf)) - return J + + if integrator.alg isa CompositeAlgorithm + integrator.eigen_est = opnorm(J, Inf) + end + + J end -function calc_J(nlsolver, integrator, cache::OrdinaryDiffEqConstantCache, is_compos) - @unpack t,dt,uprev,u,f,p = integrator +function calc_J(nlsolver, integrator, cache::OrdinaryDiffEqConstantCache) + @unpack t,uprev,f,p = integrator + if DiffEqBase.has_jac(f) J = f.jac(uprev, p, t) else - nlsolver.uf.t = t - nlsolver.uf.p = p - J = jacobian(nlsolver.uf,uprev,integrator) + @unpack uf = DiffEqBase.get_cache(nlsolver) + uf.t = t + uf.p = p + J = jacobian(uf,uprev,integrator) end integrator.destats.njacs += 1 - is_compos && (integrator.eigen_est = opnorm(J, Inf)) - return J + + if integrator.alg isa CompositeAlgorithm + integrator.eigen_est = opnorm(J, Inf) + end + + J end """ - calc_J!(integrator,cache,is_compos) + calc_J!(integrator, cache) Interface for calculating the jacobian. @@ -85,20 +97,21 @@ jacobian update function, then it will be called for the update. Otherwise, either ForwardDiff or finite difference will be used depending on the `jac_config` of the cache. """ -function calc_J!(integrator, cache::OrdinaryDiffEqCache, is_compos) +function calc_J!(integrator, cache::OrdinaryDiffEqCache) if isdefined(cache, :nlsolver) - calc_J!(cache.nlsolver, integrator, cache, is_compos) + calc_J!(cache.nlsolver, integrator, cache) elseif isdefined(cache, :J) - calc_J_in_cache!(integrator, cache, is_compos) + calc_J_in_cache!(integrator, cache) else error("No J found in the cache") end end -function calc_J!(nlsolver::NLSolver, integrator, cache::OrdinaryDiffEqMutableCache, is_compos) - @unpack t,dt,uprev,u,f,p = integrator - @unpack du1,uf,jac_config = nlsolver +function calc_J!(nlsolver::NLSolver, integrator, cache::OrdinaryDiffEqMutableCache) + @unpack t,uprev,f,p = integrator + @unpack du1,uf,jac_config = nlsolver.cache J = nlsolver.cache.J + if DiffEqBase.has_jac(f) f.jac(J, uprev, p, t) else @@ -107,14 +120,16 @@ function calc_J!(nlsolver::NLSolver, integrator, cache::OrdinaryDiffEqMutableCac jacobian!(J, uf, uprev, du1, integrator, jac_config) end integrator.destats.njacs += 1 - is_compos && (integrator.eigen_est = opnorm(J, Inf)) + if integrator.alg isa CompositeAlgorithm + integrator.eigen_est = opnorm(J, Inf) + end end -function calc_J!(nlsolver::NLSolver, integrator, cache::OrdinaryDiffEqConstantCache, is_compos) - nlsolver.cache.J = calc_J(nlsolver,integrator,cache,is_compos) +function calc_J!(nlsolver::NLSolver, integrator, cache::OrdinaryDiffEqConstantCache) + nlsolver.cache.J = calc_J(nlsolver, integrator, cache) end -function calc_J_in_cache!(integrator, cache::OrdinaryDiffEqMutableCache, is_compos) +function calc_J_in_cache!(integrator, cache::OrdinaryDiffEqMutableCache) @unpack t,dt,uprev,u,f,p = integrator J = cache.J if DiffEqBase.has_jac(f) @@ -126,11 +141,13 @@ function calc_J_in_cache!(integrator, cache::OrdinaryDiffEqMutableCache, is_comp jacobian!(J, uf, uprev, du1, integrator, jac_config) end integrator.destats.njacs += 1 - is_compos && (integrator.eigen_est = opnorm(J, Inf)) + if integrator.alg isa CompositeAlgorithm + integrator.eigen_est = opnorm(J, Inf) + end end -function calc_J_in_cache!(integrator, cache::OrdinaryDiffEqConstantCache, is_compos) - cache.J = calc_J(integrator,cache,is_compos) +function calc_J_in_cache!(integrator, cache::OrdinaryDiffEqConstantCache) + cache.J = calc_J(integrator, cache) end """ @@ -417,7 +434,7 @@ function calc_W!(integrator, cache::OrdinaryDiffEqMutableCache, dtgamma, repeat_ @label J2W W.transform = W_transform; set_gamma!(W, dtgamma) else # concrete W using jacobian from `calc_J!` - new_jac && calc_J!(integrator, cache, is_compos) + new_jac && calc_J!(integrator, cache) new_W && jacobian2W!(W, mass_matrix, dtgamma, J, W_transform) end if isnewton @@ -468,7 +485,7 @@ function calc_W!(integrator, cache::OrdinaryDiffEqMutableCache, dtgamma, repeat_ @label J2W W[W_index].transform = W_transform; set_gamma!(W[W_index], dtgamma) else # concrete W using jacobian from `calc_J!` - new_jac && calc_J!(integrator, cache, is_compos) + new_jac && calc_J!(integrator, cache) new_W && jacobian2W!(W[W_index], mass_matrix, dtgamma, J, W_transform) end if isnewton @@ -519,7 +536,7 @@ function calc_W!(nlsolver, integrator, cache::OrdinaryDiffEqMutableCache, dtgamm @label J2W W.transform = W_transform; set_gamma!(W, dtgamma) else # concrete W using jacobian from `calc_J!` - new_jac && calc_J!(nlsolver, integrator, cache, is_compos) + new_jac && calc_J!(nlsolver, integrator, cache) new_W && jacobian2W!(W, mass_matrix, dtgamma, J, W_transform) end if isnewton @@ -549,7 +566,7 @@ function calc_W!(integrator, cache::OrdinaryDiffEqConstantCache, dtgamma, repeat integrator.destats.nw += 1 else integrator.destats.nw += 1 - J = calc_J(integrator, cache, is_compos) + J = calc_J(integrator, cache) W_full = W_transform ? -mass_matrix*inv(dtgamma) + J : -mass_matrix + dtgamma*J W = W_full isa Number ? W_full : lu(W_full) @@ -560,7 +577,7 @@ end function calc_W!(nlsolver, integrator, cache::OrdinaryDiffEqConstantCache, dtgamma, repeat_step, W_transform=false) @unpack t,uprev,p,f = integrator - @unpack uf = nlsolver + @unpack uf = nlsolver.cache mass_matrix = integrator.f.mass_matrix isarray = typeof(uprev) <: AbstractArray # calculate W @@ -578,7 +595,7 @@ function calc_W!(nlsolver, integrator, cache::OrdinaryDiffEqConstantCache, dtgam integrator.destats.nw += 1 else integrator.destats.nw += 1 - J = calc_J(nlsolver, integrator, cache, is_compos) + J = calc_J(nlsolver, integrator, cache) W_full = W_transform ? -mass_matrix*inv(dtgamma) + J : -mass_matrix + dtgamma*J W = W_full isa Number ? W_full : lu(W_full) @@ -611,6 +628,7 @@ function update_W!(nlsolver::NLSolver, integrator, cache::OrdinaryDiffEqConstant nothing end - -iip_get_uf(alg::Union{DAEAlgorithm,OrdinaryDiffEqAlgorithm},nf,t,p) = DiffEqDiffTools.UJacobianWrapper(nf,t,p) -oop_get_uf(alg::Union{DAEAlgorithm,OrdinaryDiffEqAlgorithm},nf,t,p) = DiffEqDiffTools.UDerivativeWrapper(nf,t,p) +DiffEqBase.build_uf(alg::Union{DAEAlgorithm,OrdinaryDiffEqAlgorithm},nf,t,p,::Val{true}) = + DiffEqDiffTools.UJacobianWrapper(nf,t,p) +DiffEqBase.build_uf(alg::Union{DAEAlgorithm,OrdinaryDiffEqAlgorithm},nf,t,p,::Val{false}) = + DiffEqDiffTools.UDerivativeWrapper(nf,t,p) \ No newline at end of file diff --git a/src/generic_rosenbrock.jl b/src/generic_rosenbrock.jl index 48d07698e5..c784fee246 100644 --- a/src/generic_rosenbrock.jl +++ b/src/generic_rosenbrock.jl @@ -224,7 +224,7 @@ function gen_algcache(cacheexpr::Expr,constcachename::Symbol,algname::Symbol,tab function alg_cache(alg::$algname,u,rate_prototype,uEltypeNoUnits,uBottomEltypeNoUnits,tTypeNoUnits,uprev,uprev2,f,t,dt,reltol,p,calck,::Val{false}) tf = DiffEqDiffTools.TimeDerivativeWrapper(f,u,p) uf = DiffEqDiffTools.UDerivativeWrapper(f,t,p) - J,W = oop_generate_W(alg,u,uprev,p,t,dt,f,uEltypeNoUnits) + J,W = DiffEqBase.build_J_W(alg,u,uprev,p,t,dt,f,uEltypeNoUnits,Val(false)) linsolve = alg.linsolve(Val{:init},uf,u) $constcachename(tf,uf,$tabname(constvalue(uBottomEltypeNoUnits),constvalue(tTypeNoUnits)),J,W,linsolve) end @@ -236,7 +236,7 @@ function gen_algcache(cacheexpr::Expr,constcachename::Symbol,algname::Symbol,tab fsalfirst = zero(rate_prototype) fsallast = zero(rate_prototype) dT = zero(rate_prototype) - J,W = iip_generate_W(alg,u,uprev,p,t,dt,f,uEltypeNoUnits) + J,W = DiffEqBase.build_J_W(alg,u,uprev,p,t,dt,f,uEltypeNoUnits,Val(true)) tmp = zero(rate_prototype) atmp = similar(u, uEltypeNoUnits) tab = $tabname(constvalue(uBottomEltypeNoUnits),constvalue(tTypeNoUnits)) diff --git a/src/integrators/integrator_utils.jl b/src/integrators/integrator_utils.jl index e98c0177db..e81d679aff 100644 --- a/src/integrators/integrator_utils.jl +++ b/src/integrators/integrator_utils.jl @@ -379,15 +379,12 @@ function reset_fsal!(integrator) # integrator.reeval_fsal = false end -nlsolve!(integrator, cache) = DiffEqBase.nlsolve!(cache.nlsolver, cache.nlsolver.cache, integrator) -nlsolve!(nlsolver::NLSolver, integrator) = DiffEqBase.nlsolve!(nlsolver, nlsolver.cache, integrator) - DiffEqBase.nlsolve_f(f, alg::OrdinaryDiffEqAlgorithm) = f isa SplitFunction && issplit(alg) ? f.f1 : f DiffEqBase.nlsolve_f(f, alg::DAEAlgorithm) = f DiffEqBase.nlsolve_f(integrator::ODEIntegrator) = nlsolve_f(integrator.f, unwrap_alg(integrator, true)) -function iip_generate_W(alg,u,uprev,p,t,dt,f,uEltypeNoUnits) +function DiffEqBase.build_J_W(alg,u,uprev,p,t,dt,f,uEltypeNoUnits,::Val{true}) if alg isa NewtonAlgorithm if alg.nlsolve isa NLNewton nf = nlsolve_f(f, alg) @@ -417,7 +414,7 @@ function iip_generate_W(alg,u,uprev,p,t,dt,f,uEltypeNoUnits) J, W end -function oop_generate_W(alg,u,uprev,p,t,dt,f,uEltypeNoUnits) +function DiffEqBase.build_J_W(alg,u,uprev,p,t,dt,f,uEltypeNoUnits,::Val{false}) islin = false if alg isa NewtonAlgorithm && alg.nlsolve isa NLNewton nf = nlsolve_f(f, alg) diff --git a/src/perform_step/adams_bashforth_moulton_perform_step.jl b/src/perform_step/adams_bashforth_moulton_perform_step.jl index d890f6f2dc..203f0c9410 100644 --- a/src/perform_step/adams_bashforth_moulton_perform_step.jl +++ b/src/perform_step/adams_bashforth_moulton_perform_step.jl @@ -1542,7 +1542,7 @@ function perform_step!(integrator,cache::CNAB2ConstantCache,repeat_step=false) nlsolver.z = z = zprev # Constant extrapolation nlsolver.tmp += γ*zprev - z = nlsolve!(integrator, cache) + z = nlsolve!(nlsolver, integrator) nlsolvefail(nlsolver) && return u = nlsolver.tmp + 1//2*z @@ -1558,7 +1558,12 @@ end function initialize!(integrator, cache::CNAB2Cache) integrator.kshortsize = 2 integrator.fsalfirst = cache.fsalfirst - integrator.fsallast = cache.nlsolver.k + _du_cache = du_cache(cache.nlsolver) + if _du_cache === nothing + integrator.fsallast = zero(integrator.fsalfirst) + else + integrator.fsallast = first(_du_cache) + end resize!(integrator.k, integrator.kshortsize) integrator.k[1] = integrator.fsalfirst integrator.k[2] = integrator.fsallast @@ -1569,7 +1574,8 @@ end function perform_step!(integrator, cache::CNAB2Cache, repeat_step=false) @unpack t,dt,uprev,u,f,p,alg = integrator @unpack k1,k2,du₁,nlsolver = cache - @unpack z,tmp,k = nlsolver + @unpack z,tmp = nlsolver + k = integrator.fsallast cnt = integrator.iter f1 = integrator.f.f1 f2 = integrator.f.f2 @@ -1591,7 +1597,7 @@ function perform_step!(integrator, cache::CNAB2Cache, repeat_step=false) # initial guess @.. z = dt*du₁ @.. tmp += γ*z - z = nlsolve!(integrator, cache) + z = nlsolve!(nlsolver, integrator) nlsolvefail(nlsolver) && return @.. u = tmp + 1//2*z @@ -1645,7 +1651,7 @@ function perform_step!(integrator,cache::CNLF2ConstantCache,repeat_step=false) zprev = dt*du₁ nlsolver.z = z = zprev # Constant extrapolation - z = nlsolve!(integrator, cache) + z = nlsolve!(nlsolver, integrator) nlsolvefail(nlsolver) && return u = nlsolver.tmp + γ*z @@ -1662,7 +1668,12 @@ end function initialize!(integrator, cache::CNLF2Cache) integrator.kshortsize = 2 integrator.fsalfirst = cache.fsalfirst - integrator.fsallast = cache.nlsolver.k + _du_cache = du_cache(cache.nlsolver) + if _du_cache === nothing + integrator.fsallast = zero(integrator.fsalfirst) + else + integrator.fsallast = first(_du_cache) + end resize!(integrator.k, integrator.kshortsize) integrator.k[1] = integrator.fsalfirst integrator.k[2] = integrator.fsallast @@ -1673,7 +1684,7 @@ end function perform_step!(integrator, cache::CNLF2Cache, repeat_step=false) @unpack t,dt,uprev,u,f,p,alg = integrator @unpack uprev2,k2,du₁,nlsolver = cache - @unpack z,k,tmp = nlsolver + @unpack z,tmp = nlsolver cnt = integrator.iter f1 = integrator.f.f1 f2 = integrator.f.f2 @@ -1696,12 +1707,12 @@ function perform_step!(integrator, cache::CNLF2Cache, repeat_step=false) # initial guess @.. z = dt*du₁ - z = nlsolve!(integrator, cache) + z = nlsolve!(nlsolver, integrator) nlsolvefail(nlsolver) && return @.. u = tmp + γ*z cache.uprev2 .= uprev cache.k2 .= du₁ - integrator.f(k,u,p,t+dt) + integrator.f(integrator.fsallast,u,p,t+dt) integrator.destats.nf += 1 end diff --git a/src/perform_step/bdf_perform_step.jl b/src/perform_step/bdf_perform_step.jl index 7d78caf8ad..278a280985 100644 --- a/src/perform_step/bdf_perform_step.jl +++ b/src/perform_step/bdf_perform_step.jl @@ -46,7 +46,7 @@ end nlsolver.z = z nlsolver.tmp = d1*uₙ₋₁ + d2*uₙ₋₂ + d3*zₙ₋₁ - z = nlsolve!(integrator, cache) + z = nlsolve!(nlsolver, integrator) nlsolvefail(nlsolver) && return uₙ = nlsolver.tmp + d*z @@ -76,7 +76,12 @@ end function initialize!(integrator, cache::ABDF2Cache) integrator.kshortsize = 2 integrator.fsalfirst = cache.fsalfirst - integrator.fsallast = cache.nlsolver.k + _du_cache = du_cache(cache.nlsolver) + if _du_cache === nothing + integrator.fsallast = zero(integrator.fsalfirst) + else + integrator.fsallast = first(_du_cache) + end resize!(integrator.k, integrator.kshortsize) integrator.k[1] = integrator.fsalfirst integrator.k[2] = integrator.fsallast @@ -120,7 +125,7 @@ end end @.. tmp = d1*uₙ₋₁ + d2*uₙ₋₂ + d3*zₙ₋₁ - z = nlsolve!(integrator, cache) + z = nlsolve!(nlsolver, integrator) nlsolvefail(nlsolver) && return @.. uₙ = tmp + d*z @@ -195,7 +200,7 @@ function perform_step!(integrator,cache::SBDFConstantCache,repeat_step=false) end nlsolver.z = z - z = nlsolve!(integrator, cache) + z = nlsolve!(nlsolver, integrator) nlsolvefail(nlsolver) && return u = nlsolver.tmp + γ*z @@ -217,7 +222,12 @@ function initialize!(integrator, cache::SBDFCache) @unpack f1, f2 = integrator.f integrator.kshortsize = 2 integrator.fsalfirst = cache.fsalfirst - integrator.fsallast = cache.nlsolver.k + _du_cache = du_cache(cache.nlsolver) + if _du_cache === nothing + integrator.fsallast = zero(integrator.fsalfirst) + else + integrator.fsallast = first(_du_cache) + end resize!(integrator.k, integrator.kshortsize) integrator.k[1] = integrator.fsalfirst integrator.k[2] = integrator.fsallast @@ -231,8 +241,9 @@ end function perform_step!(integrator, cache::SBDFCache, repeat_step=false) @unpack t,dt,uprev,u,f,p,alg = integrator @unpack uprev2,uprev3,uprev4,k₁,k₂,k₃,du₁,du₂,nlsolver = cache - @unpack tmp,z,k = nlsolver + @unpack tmp,z = nlsolver @unpack f1, f2 = integrator.f + cnt = cache.cnt = min(alg.order, integrator.iter+1) integrator.iter == 1 && !integrator.u_modified && ( cnt = cache.cnt = 1 ) nlsolver.γ = γ = inv(γₖ[cnt]) @@ -258,7 +269,7 @@ function perform_step!(integrator, cache::SBDFCache, repeat_step=false) @.. z = zero(eltype(u)) end - z = nlsolve!(integrator, cache) + z = nlsolve!(nlsolver, integrator) nlsolvefail(nlsolver) && return @.. u = tmp + γ*z @@ -269,7 +280,7 @@ function perform_step!(integrator, cache::SBDFCache, repeat_step=false) f2(du₂, u, p, t+dt) integrator.destats.nf += 1 integrator.destats.nf2 += 1 - @.. k = du₁ + du₂ + @.. integrator.fsallast = du₁ + du₂ end # QNDF1 @@ -319,7 +330,7 @@ function perform_step!(integrator,cache::QNDF1ConstantCache,repeat_step=false) nlsolver.z = dt*integrator.fsalfirst nlsolver.γ = γ - z = nlsolve!(integrator, cache) + z = nlsolve!(nlsolver, integrator) nlsolvefail(nlsolver) && return u = nlsolver.tmp + γ*z @@ -347,7 +358,12 @@ end function initialize!(integrator, cache::QNDF1Cache) integrator.kshortsize = 2 integrator.fsalfirst = cache.fsalfirst - integrator.fsallast = cache.nlsolver.k + _du_cache = du_cache(cache.nlsolver) + if _du_cache === nothing + integrator.fsallast = zero(integrator.fsalfirst) + else + integrator.fsallast = first(_du_cache) + end resize!(integrator.k, integrator.kshortsize) integrator.k[1] = integrator.fsalfirst integrator.k[2] = integrator.fsallast @@ -385,7 +401,7 @@ function perform_step!(integrator,cache::QNDF1Cache,repeat_step=false) # initial guess @.. z = dt*integrator.fsalfirst - z = nlsolve!(integrator, cache) + z = nlsolve!(nlsolver, integrator) nlsolvefail(nlsolver) && return @.. u = tmp + γ*z @@ -466,7 +482,7 @@ function perform_step!(integrator,cache::QNDF2ConstantCache,repeat_step=false) # initial guess nlsolver.z = dt*integrator.fsalfirst - z = nlsolve!(integrator, cache) + z = nlsolve!(nlsolver, integrator) nlsolvefail(nlsolver) && return u = nlsolver.tmp + γ*z @@ -505,7 +521,12 @@ end function initialize!(integrator, cache::QNDF2Cache) integrator.kshortsize = 2 integrator.fsalfirst = cache.fsalfirst - integrator.fsallast = cache.nlsolver.k + _du_cache = du_cache(cache.nlsolver) + if _du_cache === nothing + integrator.fsallast = zero(integrator.fsalfirst) + else + integrator.fsallast = first(_du_cache) + end resize!(integrator.k, integrator.kshortsize) integrator.k[1] = integrator.fsalfirst integrator.k[2] = integrator.fsallast @@ -559,7 +580,7 @@ function perform_step!(integrator,cache::QNDF2Cache,repeat_step=false) # initial guess @.. nlsolver.z = dt*integrator.fsalfirst - z = nlsolve!(integrator, cache) + z = nlsolve!(nlsolver, integrator) nlsolvefail(nlsolver) && return @.. u = tmp + γ*z @@ -663,7 +684,7 @@ function perform_step!(integrator,cache::QNDFConstantCache,repeat_step=false) # initial guess nlsolver.z = dt*integrator.fsalfirst - z = nlsolve!(integrator, cache) + z = nlsolve!(nlsolver, integrator) nlsolvefail(nlsolver) && return u = nlsolver.tmp + γ*z @@ -732,7 +753,12 @@ end function initialize!(integrator, cache::QNDFCache) integrator.kshortsize = 2 integrator.fsalfirst = cache.fsalfirst - integrator.fsallast = cache.nlsolver.k + _du_cache = du_cache(cache.nlsolver) + if _du_cache === nothing + integrator.fsallast = zero(integrator.fsalfirst) + else + integrator.fsallast = first(_du_cache) + end resize!(integrator.k, integrator.kshortsize) integrator.k[1] = integrator.fsalfirst integrator.k[2] = integrator.fsallast @@ -804,7 +830,7 @@ function perform_step!(integrator,cache::QNDFCache,repeat_step=false) # initial guess @.. nlsolver.z = dt*integrator.fsalfirst - z = nlsolve!(integrator, cache) + z = nlsolve!(nlsolver, integrator) nlsolvefail(nlsolver) && return @.. u = nlsolver.tmp + γ*z @@ -910,14 +936,14 @@ end ### STEP 1 nlsolver.tmp = uprev - z = nlsolve!(integrator, cache) + z = nlsolve!(nlsolver, integrator) nlsolvefail(nlsolver) && return z₁ = nlsolver.tmp + z ### STEP 2 nlsolver.tmp = z₁ nlsolver.c = 2 set_new_W!(nlsolver, false) - z = nlsolve!(integrator, cache) + z = nlsolve!(nlsolver, integrator) nlsolvefail(nlsolver) && return z₂ = z₁ + z ### STEP 3 @@ -925,7 +951,7 @@ end nlsolver.tmp = tmp2 nlsolver.c = 1 set_new_W!(nlsolver, false) - z = nlsolve!(integrator, cache) + z = nlsolve!(nlsolver, integrator) nlsolvefail(nlsolver) && return u = tmp2 + z @@ -940,7 +966,12 @@ end function initialize!(integrator, cache::MEBDF2Cache) integrator.kshortsize = 2 integrator.fsalfirst = cache.fsalfirst - integrator.fsallast = cache.nlsolver.k + _du_cache = du_cache(cache.nlsolver) + if _du_cache === nothing + integrator.fsallast = zero(integrator.fsalfirst) + else + integrator.fsallast = first(_du_cache) + end resize!(integrator.k, integrator.kshortsize) integrator.k[1] = integrator.fsalfirst integrator.k[2] = integrator.fsallast @@ -965,14 +996,14 @@ end ### STEP 1 nlsolver.tmp = uprev - z = nlsolve!(integrator, cache) + z = nlsolve!(nlsolver, integrator) nlsolvefail(nlsolver) && return @.. z₁ = uprev + z ### STEP 2 nlsolver.tmp = z₁ nlsolver.c = 2 set_new_W!(nlsolver, false) - z = nlsolve!(integrator, cache) + z = nlsolve!(nlsolver, integrator) nlsolvefail(nlsolver) && return @.. z₂ = z₁ + z ### STEP 3 @@ -981,7 +1012,7 @@ end nlsolver.tmp = tmp2 nlsolver.c = 1 set_new_W!(nlsolver, false) - z = nlsolve!(integrator, cache) + z = nlsolve!(nlsolver, integrator) nlsolvefail(nlsolver) && return @.. u = tmp2 + z diff --git a/src/perform_step/exponential_rk_perform_step.jl b/src/perform_step/exponential_rk_perform_step.jl index b43f506c43..ab3daabb2a 100644 --- a/src/perform_step/exponential_rk_perform_step.jl +++ b/src/perform_step/exponential_rk_perform_step.jl @@ -42,8 +42,7 @@ end # Classical ExpRK integrators function perform_step!(integrator, cache::LawsonEulerConstantCache, repeat_step=false) @unpack t,dt,uprev,f,p = integrator - is_compos = isa(integrator.alg, CompositeAlgorithm) - A = isa(f, SplitFunction) ? f.f1.f : calc_J(integrator,cache,is_compos) # get linear operator + A = isa(f, SplitFunction) ? f.f1.f : calc_J(integrator, cache) # get linear operator alg = unwrap_alg(integrator, true) nl = _compute_nl(f, uprev, p, t, A) @@ -71,8 +70,7 @@ end function perform_step!(integrator, cache::LawsonEulerCache, repeat_step=false) @unpack t,dt,uprev,u,f,p = integrator @unpack tmp,rtmp,G,J,exphA,KsCache = cache - is_compos = isa(integrator.alg, CompositeAlgorithm) - A = isa(f, SplitFunction) ? f.f1.f : (calc_J!(integrator,cache,is_compos); J) # get linear operator + A = isa(f, SplitFunction) ? f.f1.f : (calc_J!(integrator, cache); J) # get linear operator alg = unwrap_alg(integrator, true) _compute_nl!(G, f, uprev, p, t, A, rtmp) @@ -98,8 +96,7 @@ end function perform_step!(integrator, cache::NorsettEulerConstantCache, repeat_step=false) @unpack t,dt,uprev,f,p = integrator - is_compos = isa(integrator.alg, CompositeAlgorithm) - A = isa(f, SplitFunction) ? f.f1.f : calc_J(integrator,cache,is_compos) # get linear operator + A = isa(f, SplitFunction) ? f.f1.f : calc_J(integrator, cache) # get linear operator alg = unwrap_alg(integrator, true) if alg.krylov @@ -122,8 +119,7 @@ end function perform_step!(integrator, cache::NorsettEulerCache, repeat_step=false) @unpack t,dt,uprev,u,f,p = integrator @unpack rtmp,J,KsCache = cache - is_compos = isa(integrator.alg, CompositeAlgorithm) - A = isa(f, SplitFunction) ? f.f1.f : (calc_J!(integrator,cache,is_compos); J) # get linear operator + A = isa(f, SplitFunction) ? f.f1.f : (calc_J!(integrator, cache); J) # get linear operator alg = unwrap_alg(integrator, true) if alg.krylov @@ -144,8 +140,7 @@ end function perform_step!(integrator, cache::ETDRK2ConstantCache, repeat_step=false) @unpack t,dt,uprev,f,p = integrator - is_compos = isa(integrator.alg, CompositeAlgorithm) - A = isa(f, SplitFunction) ? f.f1.f : calc_J(integrator,cache,is_compos) # get linear operator + A = isa(f, SplitFunction) ? f.f1.f : calc_J(integrator, cache) # get linear operator alg = unwrap_alg(integrator, true) if alg.krylov @@ -183,8 +178,7 @@ end function perform_step!(integrator, cache::ETDRK2Cache, repeat_step=false) @unpack t,dt,uprev,u,f,p = integrator @unpack tmp,rtmp,F2,J,KsCache = cache - is_compos = isa(integrator.alg, CompositeAlgorithm) - A = isa(f, SplitFunction) ? f.f1.f : (calc_J!(integrator,cache,is_compos); J) # get linear operator + A = isa(f, SplitFunction) ? f.f1.f : (calc_J!(integrator, cache); J) # get linear operator alg = unwrap_alg(integrator, true) if alg.krylov @@ -236,8 +230,7 @@ end function perform_step!(integrator, cache::ETDRK3ConstantCache, repeat_step=false) @unpack t,dt,uprev,f,p = integrator - is_compos = isa(integrator.alg, CompositeAlgorithm) - A = isa(f, SplitFunction) ? f.f1.f : calc_J(integrator,cache,is_compos) # get linear operator + A = isa(f, SplitFunction) ? f.f1.f : calc_J(integrator, cache) # get linear operator alg = unwrap_alg(integrator, true) Au = A * uprev @@ -290,8 +283,7 @@ end function perform_step!(integrator, cache::ETDRK3Cache, repeat_step=false) @unpack t,dt,uprev,u,f,p = integrator @unpack tmp,rtmp,Au,F2,F3,J,KsCache = cache - is_compos = isa(integrator.alg, CompositeAlgorithm) - A = isa(f, SplitFunction) ? f.f1.f : (calc_J!(integrator,cache,is_compos); J) # get linear operator + A = isa(f, SplitFunction) ? f.f1.f : (calc_J!(integrator, cache); J) # get linear operator alg = unwrap_alg(integrator, true) F1 = integrator.fsalfirst @@ -352,8 +344,7 @@ end function perform_step!(integrator, cache::ETDRK4ConstantCache, repeat_step=false) @unpack t,dt,uprev,f,p = integrator - is_compos = isa(integrator.alg, CompositeAlgorithm) - A = isa(f, SplitFunction) ? f.f1.f : calc_J(integrator,cache,is_compos) # get linear operator + A = isa(f, SplitFunction) ? f.f1.f : calc_J(integrator, cache) # get linear operator alg = unwrap_alg(integrator, true) Au = A * uprev @@ -418,8 +409,7 @@ end function perform_step!(integrator, cache::ETDRK4Cache, repeat_step=false) @unpack t,dt,uprev,u,f,p = integrator @unpack tmp,rtmp,Au,F2,F3,F4,J,KsCache = cache - is_compos = isa(integrator.alg, CompositeAlgorithm) - A = isa(f, SplitFunction) ? f.f1.f : (calc_J!(integrator,cache,is_compos); J) # get linear operator + A = isa(f, SplitFunction) ? f.f1.f : (calc_J!(integrator, cache); J) # get linear operator alg = unwrap_alg(integrator, true) F1 = integrator.fsalfirst @@ -496,8 +486,7 @@ end function perform_step!(integrator, cache::HochOst4ConstantCache, repeat_step=false) @unpack t,dt,uprev,f,p = integrator - is_compos = isa(integrator.alg, CompositeAlgorithm) - A = isa(f, SplitFunction) ? f.f1.f : calc_J(integrator,cache,is_compos) # get linear operator + A = isa(f, SplitFunction) ? f.f1.f : calc_J(integrator, cache) # get linear operator alg = unwrap_alg(integrator, true) Au = A * uprev @@ -572,8 +561,7 @@ end function perform_step!(integrator, cache::HochOst4Cache, repeat_step=false) @unpack t,dt,uprev,u,f,p = integrator @unpack tmp,rtmp,rtmp2,Au,F2,F3,F4,F5,J,KsCache = cache - is_compos = isa(integrator.alg, CompositeAlgorithm) - A = isa(f, SplitFunction) ? f.f1.f : (calc_J!(integrator,cache,is_compos); J) # get linear operator + A = isa(f, SplitFunction) ? f.f1.f : (calc_J!(integrator, cache); J) # get linear operator alg = unwrap_alg(integrator, true) F1 = integrator.fsalfirst @@ -662,8 +650,7 @@ end # EPIRK integrators function perform_step!(integrator, cache::Exp4ConstantCache, repeat_step=false) @unpack t,dt,uprev,f,p = integrator - is_compos = isa(integrator.alg, CompositeAlgorithm) - J = calc_J(integrator,cache,is_compos) + J = calc_J(integrator, cache) alg = unwrap_alg(integrator, true) f0 = integrator.fsalfirst # f(uprev) is fsaled ts = [dt/3, 2dt/3, dt] @@ -708,8 +695,7 @@ end function perform_step!(integrator, cache::Exp4Cache, repeat_step=false) @unpack t,dt,uprev,u,f,p = integrator @unpack tmp,rtmp,rtmp2,K,J,B,KsCache = cache - is_compos = isa(integrator.alg, CompositeAlgorithm) - calc_J!(integrator,cache,is_compos) + calc_J!(integrator, cache) alg = unwrap_alg(integrator, true) f0 = integrator.fsalfirst # f(u0) is fsaled ts = [dt/3, 2dt/3, dt] @@ -759,8 +745,7 @@ end function perform_step!(integrator, cache::EPIRK4s3AConstantCache, repeat_step=false) @unpack t,dt,uprev,f,p = integrator - is_compos = isa(integrator.alg, CompositeAlgorithm) - J = calc_J(integrator,cache,is_compos) + J = calc_J(integrator, cache) alg = unwrap_alg(integrator, true) f0 = integrator.fsalfirst # f(uprev) is fsaled kwargs = (tol=integrator.opts.reltol, iop=alg.iop, opnorm=integrator.opts.internalopnorm, @@ -792,8 +777,7 @@ end function perform_step!(integrator, cache::EPIRK4s3ACache, repeat_step=false) @unpack t,dt,uprev,u,f,p = integrator @unpack tmp,rtmp,rtmp2,K,J,B,KsCache = cache - is_compos = isa(integrator.alg, CompositeAlgorithm) - calc_J!(integrator,cache,is_compos) + calc_J!(integrator, cache) alg = unwrap_alg(integrator, true) f0 = integrator.fsalfirst # f(u0) is fsaled kwargs = (tol=integrator.opts.reltol, iop=alg.iop, opnorm=integrator.opts.internalopnorm, @@ -829,8 +813,7 @@ end function perform_step!(integrator, cache::EPIRK4s3BConstantCache, repeat_step=false) @unpack t,dt,uprev,f,p = integrator - is_compos = isa(integrator.alg, CompositeAlgorithm) - J = calc_J(integrator,cache,is_compos) + J = calc_J(integrator, cache) alg = unwrap_alg(integrator, true) f0 = integrator.fsalfirst # f(uprev) is fsaled kwargs = (tol=integrator.opts.reltol, iop=alg.iop, opnorm=integrator.opts.internalopnorm, @@ -864,8 +847,7 @@ end function perform_step!(integrator, cache::EPIRK4s3BCache, repeat_step=false) @unpack t,dt,uprev,u,f,p = integrator @unpack tmp,rtmp,rtmp2,K,J,B,KsCache = cache - is_compos = isa(integrator.alg, CompositeAlgorithm) - calc_J!(integrator,cache,is_compos) + calc_J!(integrator, cache) alg = unwrap_alg(integrator, true) f0 = integrator.fsalfirst # f(u0) is fsaled kwargs = (tol=integrator.opts.reltol, iop=alg.iop, opnorm=integrator.opts.internalopnorm, @@ -906,8 +888,7 @@ end function perform_step!(integrator, cache::EPIRK5s3ConstantCache, repeat_step=false) @unpack t,dt,uprev,f,p = integrator - is_compos = isa(integrator.alg, CompositeAlgorithm) - J = calc_J(integrator,cache,is_compos) + J = calc_J(integrator, cache) alg = unwrap_alg(integrator, true) f0 = integrator.fsalfirst # f(uprev) is fsaled kwargs = (tol=integrator.opts.reltol, iop=alg.iop, opnorm=integrator.opts.internalopnorm, @@ -949,8 +930,7 @@ end function perform_step!(integrator, cache::EPIRK5s3Cache, repeat_step=false) @unpack t,dt,uprev,u,f,p = integrator @unpack tmp,k,rtmp,rtmp2,J,B,KsCache = cache - is_compos = isa(integrator.alg, CompositeAlgorithm) - calc_J!(integrator,cache,is_compos) + calc_J!(integrator, cache) alg = unwrap_alg(integrator, true) f0 = integrator.fsalfirst # f(u0) is fsaled kwargs = (tol=integrator.opts.reltol, iop=alg.iop, opnorm=integrator.opts.internalopnorm, @@ -997,8 +977,7 @@ end function perform_step!(integrator, cache::EXPRB53s3ConstantCache, repeat_step=false) @unpack t,dt,uprev,f,p = integrator - is_compos = isa(integrator.alg, CompositeAlgorithm) - J = calc_J(integrator,cache,is_compos) + J = calc_J(integrator, cache) alg = unwrap_alg(integrator, true) f0 = integrator.fsalfirst # f(uprev) is fsaled kwargs = (tol=integrator.opts.reltol, iop=alg.iop, opnorm=integrator.opts.internalopnorm, @@ -1037,8 +1016,7 @@ end function perform_step!(integrator, cache::EXPRB53s3Cache, repeat_step=false) @unpack t,dt,uprev,u,f,p = integrator @unpack tmp,rtmp,rtmp2,K,J,B,KsCache = cache - is_compos = isa(integrator.alg, CompositeAlgorithm) - calc_J!(integrator,cache,is_compos) + calc_J!(integrator, cache) alg = unwrap_alg(integrator, true) f0 = integrator.fsalfirst # f(u0) is fsaled kwargs = (tol=integrator.opts.reltol, iop=alg.iop, opnorm=integrator.opts.internalopnorm, @@ -1085,8 +1063,7 @@ end function perform_step!(integrator, cache::EPIRK5P1ConstantCache, repeat_step=false) @unpack t,dt,uprev,f,p = integrator - is_compos = isa(integrator.alg, CompositeAlgorithm) - J = calc_J(integrator,cache,is_compos) + J = calc_J(integrator, cache) alg = unwrap_alg(integrator, true) f0 = integrator.fsalfirst # f(uprev) is fsaled kwargs = (tol=integrator.opts.reltol, iop=alg.iop, opnorm=integrator.opts.internalopnorm, @@ -1132,8 +1109,7 @@ end function perform_step!(integrator, cache::EPIRK5P1Cache, repeat_step=false) @unpack t,dt,uprev,u,f,p = integrator @unpack tmp,rtmp,rtmp2,K,J,B,KsCache = cache - is_compos = isa(integrator.alg, CompositeAlgorithm) - calc_J!(integrator,cache,is_compos) + calc_J!(integrator, cache) alg = unwrap_alg(integrator, true) f0 = integrator.fsalfirst # f(u0) is fsaled kwargs = (tol=integrator.opts.reltol, iop=alg.iop, opnorm=integrator.opts.internalopnorm, @@ -1183,8 +1159,7 @@ end function perform_step!(integrator, cache::EPIRK5P2ConstantCache, repeat_step=false) @unpack t,dt,uprev,f,p = integrator - is_compos = isa(integrator.alg, CompositeAlgorithm) - J = calc_J(integrator,cache,is_compos) + J = calc_J(integrator, cache) alg = unwrap_alg(integrator, true) f0 = integrator.fsalfirst # f(uprev) is fsaled kwargs = (tol=integrator.opts.reltol, iop=alg.iop, opnorm=integrator.opts.internalopnorm, @@ -1231,8 +1206,7 @@ end function perform_step!(integrator, cache::EPIRK5P2Cache, repeat_step=false) @unpack t,dt,uprev,u,f,p = integrator @unpack tmp,rtmp,rtmp2,dR,K,J,B,KsCache = cache - is_compos = isa(integrator.alg, CompositeAlgorithm) - calc_J!(integrator,cache,is_compos) + calc_J!(integrator, cache) alg = unwrap_alg(integrator, true) f0 = integrator.fsalfirst # f(u0) is fsaled kwargs = (tol=integrator.opts.reltol, iop=alg.iop, opnorm=integrator.opts.internalopnorm, @@ -1288,8 +1262,7 @@ end # Adaptive exponential Rosenbrock integrators function perform_step!(integrator, cache::Exprb32ConstantCache, repeat_step=false) @unpack t,dt,uprev,f,p = integrator - is_compos = isa(integrator.alg, CompositeAlgorithm) - A = calc_J(integrator,cache,is_compos) # get linear operator + A = calc_J(integrator, cache) # get linear operator alg = unwrap_alg(integrator, true) F1 = integrator.fsalfirst @@ -1317,8 +1290,7 @@ end function perform_step!(integrator, cache::Exprb32Cache, repeat_step=false) @unpack t,dt,uprev,u,f,p = integrator @unpack utilde,tmp,rtmp,F2,J,KsCache = cache - is_compos = isa(integrator.alg, CompositeAlgorithm) - A = (calc_J!(integrator,cache,is_compos); J) # get linear operator + A = (calc_J!(integrator, cache); J) # get linear operator alg = unwrap_alg(integrator, true) F1 = integrator.fsalfirst @@ -1354,8 +1326,7 @@ end function perform_step!(integrator, cache::Exprb43ConstantCache, repeat_step=false) @unpack t,dt,uprev,f,p = integrator - is_compos = isa(integrator.alg, CompositeAlgorithm) - A = calc_J(integrator,cache,is_compos) # get linear operator + A = calc_J(integrator, cache) # get linear operator alg = unwrap_alg(integrator, true) Au = A * uprev @@ -1394,8 +1365,7 @@ end function perform_step!(integrator, cache::Exprb43Cache, repeat_step=false) @unpack t,dt,uprev,u,f,p = integrator @unpack utilde,tmp,rtmp,Au,F2,F3,J,KsCache = cache - is_compos = isa(integrator.alg, CompositeAlgorithm) - A = (calc_J!(integrator,cache,is_compos); J) # get linear operator + A = (calc_J!(integrator, cache); J) # get linear operator alg = unwrap_alg(integrator, true) F1 = integrator.fsalfirst diff --git a/src/perform_step/firk_perform_step.jl b/src/perform_step/firk_perform_step.jl index 4d2511f54b..e5de281f2d 100644 --- a/src/perform_step/firk_perform_step.jl +++ b/src/perform_step/firk_perform_step.jl @@ -42,7 +42,6 @@ end alg = unwrap_alg(integrator, true) @unpack max_iter = alg mass_matrix = integrator.f.mass_matrix - is_compos = integrator.alg isa CompositeAlgorithm # precalculations rtol = @.. reltol^(2/3) / 10 @@ -51,7 +50,7 @@ end c2m1 = c2-1 c1mc2= c1-c2 γdt, αdt, βdt = γ/dt, α/dt, β/dt - J = calc_J(integrator, cache, is_compos) + J = calc_J(integrator, cache) if u isa Number LU1 = -γdt*mass_matrix + J LU2 = -(αdt + βdt*im)*mass_matrix + J @@ -218,14 +217,13 @@ end alg = unwrap_alg(integrator, true) @unpack max_iter = alg mass_matrix = integrator.f.mass_matrix - is_compos = integrator.alg isa CompositeAlgorithm # precalculations c1m1 = c1-1 c2m1 = c2-1 c1mc2= c1-c2 γdt, αdt, βdt = γ/dt, α/dt, β/dt - (new_jac = do_newJ(integrator, alg, cache, repeat_step)) && (calc_J!(integrator, cache, is_compos); cache.W_dt = dt) + (new_jac = do_newJ(integrator, alg, cache, repeat_step)) && (calc_J!(integrator, cache); cache.W_dt = dt) if (new_W = do_newW(integrator, alg, new_jac, cache.W_dt)) @inbounds for II in CartesianIndices(J) W1[II] = -γdt * mass_matrix[Tuple(II)...] + J[II] diff --git a/src/perform_step/kencarp_kvaerno_perform_step.jl b/src/perform_step/kencarp_kvaerno_perform_step.jl index 067436377d..add7f8cc97 100644 --- a/src/perform_step/kencarp_kvaerno_perform_step.jl +++ b/src/perform_step/kencarp_kvaerno_perform_step.jl @@ -23,7 +23,12 @@ function initialize!(integrator, cache::Union{Kvaerno3Cache, KenCarp5Cache}) integrator.kshortsize = 2 integrator.fsalfirst = cache.fsalfirst - integrator.fsallast = cache.nlsolver.k + _du_cache = du_cache(cache.nlsolver) + if _du_cache === nothing + integrator.fsallast = zero(integrator.fsalfirst) + else + integrator.fsallast = first(_du_cache) + end resize!(integrator.k, integrator.kshortsize) integrator.k[1] = integrator.fsalfirst integrator.k[2] = integrator.fsallast @@ -50,7 +55,7 @@ end nlsolver.tmp = uprev + γ*z₁ nlsolver.c = γ - z₂ = nlsolve!(integrator, cache) + z₂ = nlsolve!(nlsolver, integrator) nlsolvefail(nlsolver) && return ################################## Solve Step 3 @@ -60,7 +65,7 @@ end nlsolver.tmp = uprev + a31*z₁ + a32*z₂ nlsolver.c = c3 - z₃ = nlsolve!(integrator, cache) + z₃ = nlsolve!(nlsolver, integrator) nlsolvefail(nlsolver) && return ################################## Solve Step 4 @@ -69,7 +74,7 @@ end nlsolver.tmp = uprev + a41*z₁ + a42*z₂ + a43*z₃ nlsolver.c = 1 - z₄ = nlsolve!(integrator, cache) + z₄ = nlsolve!(nlsolver, integrator) nlsolvefail(nlsolver) && return u = nlsolver.tmp + γ*z₄ @@ -97,7 +102,7 @@ end @muladd function perform_step!(integrator, cache::Kvaerno3Cache, repeat_step=false) @unpack t,dt,uprev,u,f,p = integrator @unpack z₁,z₂,z₃,z₄,atmp,nlsolver = cache - @unpack dz,k,tmp = nlsolver + @unpack gz,tmp = nlsolver @unpack γ,a31,a32,a41,a42,a43,btilde1,btilde2,btilde3,btilde4,c3,α31,α32 = cache.tab alg = unwrap_alg(integrator, true) @@ -117,7 +122,7 @@ end @.. tmp = uprev + γ*z₁ nlsolver.c = γ - z₂ = nlsolve!(integrator, cache) + z₂ = nlsolve!(nlsolver, integrator) nlsolvefail(nlsolver) && return set_new_W!(nlsolver, false) @@ -129,7 +134,7 @@ end @.. tmp = uprev + a31*z₁ + a32*z₂ nlsolver.c = c3 - z₃ = nlsolve!(integrator, cache) + z₃ = nlsolve!(nlsolver, integrator) nlsolvefail(nlsolver) && return ################################## Solve Step 4 @@ -144,7 +149,7 @@ end @.. tmp = uprev + a41*z₁ + a42*z₂ + a43*z₃ nlsolver.c = 1 - z₄ = nlsolve!(integrator, cache) + z₄ = nlsolve!(nlsolver, integrator) nlsolvefail(nlsolver) && return @.. u = tmp + γ*z₄ @@ -152,12 +157,12 @@ end ################################### Finalize if integrator.opts.adaptive - @.. dz = btilde1*z₁ + btilde2*z₂ + btilde3*z₃ + btilde4*z₄ + @.. gz = btilde1*z₁ + btilde2*z₂ + btilde3*z₃ + btilde4*z₄ if isnewton(nlsolver) && alg.smooth_est # From Shampine integrator.destats.nsolve += 1 - nlsolver.linsolve(vec(tmp),get_W(nlsolver),vec(dz),false) + get_linsolve(nlsolver)(vec(tmp),get_W(nlsolver),vec(gz),false) else - tmp .= dz + tmp .= gz end calculate_residuals!(atmp, tmp, uprev, u, integrator.opts.abstol, integrator.opts.reltol,integrator.opts.internalnorm,t) integrator.EEst = integrator.opts.internalnorm(atmp,t) @@ -208,7 +213,7 @@ end end nlsolver.c = γ - z₂ = nlsolve!(integrator, cache) + z₂ = nlsolve!(nlsolver, integrator) nlsolvefail(nlsolver) && return ################################## Solve Step 3 @@ -228,7 +233,7 @@ end nlsolver.tmp = tmp nlsolver.c = c3 - z₃ = nlsolve!(integrator, cache) + z₃ = nlsolve!(nlsolver, integrator) nlsolvefail(nlsolver) && return ################################## Solve Step 4 @@ -248,7 +253,7 @@ end nlsolver.c = 1 nlsolver.tmp = tmp - z₄ = nlsolve!(integrator, cache) + z₄ = nlsolve!(nlsolver, integrator) nlsolvefail(nlsolver) && return u = nlsolver.tmp + γ*z₄ @@ -291,7 +296,7 @@ end @muladd function perform_step!(integrator, cache::KenCarp3Cache, repeat_step=false) @unpack t,dt,uprev,u,p = integrator @unpack z₁,z₂,z₃,z₄,k1,k2,k3,k4,atmp,nlsolver = cache - @unpack dz,k,tmp = nlsolver + @unpack gz,tmp = nlsolver @unpack γ,a31,a32,a41,a42,a43,btilde1,btilde2,btilde3,btilde4,c3,α31,α32 = cache.tab @unpack ea21,ea31,ea32,ea41,ea42,ea43,eb1,eb2,eb3,eb4 = cache.tab @unpack ebtilde1,ebtilde2,ebtilde3,ebtilde4 = cache.tab @@ -336,7 +341,7 @@ end end nlsolver.c = γ - z₂ = nlsolve!(integrator, cache) + z₂ = nlsolve!(nlsolver, integrator) nlsolvefail(nlsolver) && return set_new_W!(nlsolver, false) @@ -359,7 +364,7 @@ end nlsolver.z = z₃ nlsolver.c = c3 - z₃ = nlsolve!(integrator, cache) + z₃ = nlsolve!(nlsolver, integrator) nlsolvefail(nlsolver) && return ################################## Solve Step 4 @@ -381,7 +386,7 @@ end nlsolver.z = z₄ nlsolver.c = 1 - z₄ = nlsolve!(integrator, cache) + z₄ = nlsolve!(nlsolver, integrator) nlsolvefail(nlsolver) && return @.. u = tmp + γ*z₄ @@ -398,18 +403,18 @@ end if integrator.opts.adaptive if typeof(integrator.f) <: SplitFunction - #@.. dz = btilde1*z₁ + btilde2*z₂ + btilde3*z₃ + btilde4*z₄ + ebtilde1*k1 + ebtilde2*k2 + ebtilde3*k3 + ebtilde4*k4 - for i in eachindex(dz) - @inbounds dz[i] = btilde1*z₁[i] + btilde2*z₂[i] + btilde3*z₃[i] + btilde4*z₄[i] + ebtilde1*k1[i] + ebtilde2*k2[i] + ebtilde3*k3[i] + ebtilde4*k4[i] + #@.. gz = btilde1*z₁ + btilde2*z₂ + btilde3*z₃ + btilde4*z₄ + ebtilde1*k1 + ebtilde2*k2 + ebtilde3*k3 + ebtilde4*k4 + for i in eachindex(gz) + @inbounds gz[i] = btilde1*z₁[i] + btilde2*z₂[i] + btilde3*z₃[i] + btilde4*z₄[i] + ebtilde1*k1[i] + ebtilde2*k2[i] + ebtilde3*k3[i] + ebtilde4*k4[i] end else - @.. dz = btilde1*z₁ + btilde2*z₂ + btilde3*z₃ + btilde4*z₄ + @.. gz = btilde1*z₁ + btilde2*z₂ + btilde3*z₃ + btilde4*z₄ end if isnewton(nlsolver) && alg.smooth_est # From Shampine integrator.destats.nsolve += 1 - nlsolver.linsolve(vec(tmp),get_W(nlsolver),vec(dz),false) + get_linsolve(nlsolver)(vec(tmp),get_W(nlsolver),vec(gz),false) else - tmp .= dz + tmp .= gz end calculate_residuals!(atmp, tmp, uprev, u, integrator.opts.abstol, integrator.opts.reltol,integrator.opts.internalnorm,t) integrator.EEst = integrator.opts.internalnorm(atmp,t) @@ -447,7 +452,7 @@ end nlsolver.tmp = uprev + γ*z₁ nlsolver.c = γ - z₂ = nlsolve!(integrator, cache) + z₂ = nlsolve!(nlsolver, integrator) nlsolvefail(nlsolver) && return ################################## Solve Step 3 @@ -456,7 +461,7 @@ end nlsolver.tmp = uprev + a31*z₁ + a32*z₂ nlsolver.c = c3 - z₃ = nlsolve!(integrator, cache) + z₃ = nlsolve!(nlsolver, integrator) nlsolvefail(nlsolver) && return ################################## Solve Step 4 @@ -465,7 +470,7 @@ end nlsolver.tmp = uprev + a41*z₁ + a42*z₂ + a43*z₃ nlsolver.c = c4 - z₄ = nlsolve!(integrator, cache) + z₄ = nlsolve!(nlsolver, integrator) nlsolvefail(nlsolver) && return ################################## Solve Step 5 @@ -475,7 +480,7 @@ end nlsolver.tmp = uprev + a51*z₁ + a52*z₂ + a53*z₃ + a54*z₄ nlsolver.c = 1 - z₅ = nlsolve!(integrator, cache) + z₅ = nlsolve!(nlsolver, integrator) nlsolvefail(nlsolver) && return u = nlsolver.tmp + γ*z₅ @@ -503,7 +508,7 @@ end @muladd function perform_step!(integrator, cache::Kvaerno4Cache, repeat_step=false) @unpack t,dt,uprev,u,f,p = integrator @unpack z₁,z₂,z₃,z₄,z₅,atmp,nlsolver = cache - @unpack dz,k,tmp = nlsolver + @unpack gz,tmp = nlsolver @unpack γ,a31,a32,a41,a42,a43,a51,a52,a53,a54,c3,c4 = cache.tab @unpack α21,α31,α32,α41,α42 = cache.tab @unpack btilde1,btilde2,btilde3,btilde4,btilde5 = cache.tab @@ -526,7 +531,7 @@ end @.. tmp = uprev + γ*z₁ nlsolver.c = γ - z₂ = nlsolve!(integrator, cache) + z₂ = nlsolve!(nlsolver, integrator) nlsolvefail(nlsolver) && return set_new_W!(nlsolver, false) @@ -537,7 +542,7 @@ end @.. tmp = uprev + a31*z₁ + a32*z₂ nlsolver.c = c3 - z₃ = nlsolve!(integrator, cache) + z₃ = nlsolve!(nlsolver, integrator) nlsolvefail(nlsolver) && return ################################## Solve Step 4 @@ -548,7 +553,7 @@ end @.. tmp = uprev + a41*z₁ + a42*z₂ + a43*z₃ nlsolver.c = c4 - z₄ = nlsolve!(integrator, cache) + z₄ = nlsolve!(nlsolver, integrator) nlsolvefail(nlsolver) && return ################################## Solve Step 5 @@ -559,7 +564,7 @@ end @.. tmp = uprev + a51*z₁ + a52*z₂ + a53*z₃ + a54*z₄ nlsolver.c = 1 - z₅ = nlsolve!(integrator, cache) + z₅ = nlsolve!(nlsolver, integrator) nlsolvefail(nlsolver) && return @.. u = tmp + γ*z₅ @@ -567,12 +572,12 @@ end ################################### Finalize if integrator.opts.adaptive - @.. dz = btilde1*z₁ + btilde2*z₂ + btilde3*z₃ + btilde4*z₄ + btilde5*z₅ + @.. gz = btilde1*z₁ + btilde2*z₂ + btilde3*z₃ + btilde4*z₄ + btilde5*z₅ if isnewton(nlsolver) && alg.smooth_est # From Shampine integrator.destats.nsolve += 1 - nlsolver.linsolve(vec(tmp),get_W(nlsolver),vec(dz),false) + get_linsolve(nlsolver)(vec(tmp),get_W(nlsolver),vec(gz),false) else - tmp .= dz + tmp .= gz end calculate_residuals!(atmp, tmp, uprev, u, integrator.opts.abstol, integrator.opts.reltol,integrator.opts.internalnorm,t) integrator.EEst = integrator.opts.internalnorm(atmp,t) @@ -629,7 +634,7 @@ end nlsolver.tmp = tmp nlsolver.c = γ - z₂ = nlsolve!(integrator, cache) + z₂ = nlsolve!(nlsolver, integrator) nlsolvefail(nlsolver) && return ################################## Solve Step 3 @@ -649,7 +654,7 @@ end nlsolver.tmp = tmp nlsolver.c = c3 - z₃ = nlsolve!(integrator, cache) + z₃ = nlsolve!(nlsolver, integrator) nlsolvefail(nlsolver) && return ################################## Solve Step 4 @@ -668,7 +673,7 @@ end nlsolver.tmp = tmp nlsolver.c = c4 - z₄ = nlsolve!(integrator, cache) + z₄ = nlsolve!(nlsolver, integrator) nlsolvefail(nlsolver) && return ################################## Solve Step 5 @@ -688,7 +693,7 @@ end nlsolver.c = c5 u = nlsolver.tmp + γ*z₅ - z₅ = nlsolve!(integrator, cache) + z₅ = nlsolve!(nlsolver, integrator) nlsolvefail(nlsolver) && return ################################## Solve Step 6 @@ -707,7 +712,7 @@ end nlsolver.tmp = tmp nlsolver.c = 1 - z₆ = nlsolve!(integrator, cache) + z₆ = nlsolve!(nlsolver, integrator) nlsolvefail(nlsolver) && return u = nlsolver.tmp + γ*z₆ @@ -750,7 +755,7 @@ end @muladd function perform_step!(integrator, cache::KenCarp4Cache, repeat_step=false) @unpack t,dt,uprev,u,p = integrator @unpack z₁,z₂,z₃,z₄,z₅,z₆,atmp,nlsolver = cache - @unpack dz,k,tmp = nlsolver + @unpack gz,tmp = nlsolver @unpack k1,k2,k3,k4,k5,k6 = cache @unpack γ,a31,a32,a41,a42,a43,a51,a52,a53,a54,a61,a63,a64,a65,c3,c4,c5 = cache.tab @unpack α31,α32,α41,α42,α51,α52,α53,α54,α61,α62,α63,α64,α65 = cache.tab @@ -801,7 +806,7 @@ end end nlsolver.c = γ - z₂ = nlsolve!(integrator, cache) + z₂ = nlsolve!(nlsolver, integrator) nlsolvefail(nlsolver) && return set_new_W!(nlsolver, false) @@ -826,7 +831,7 @@ end nlsolver.z = z₃ nlsolver.c = c3 - z₃ = nlsolve!(integrator, cache) + z₃ = nlsolve!(nlsolver, integrator) nlsolvefail(nlsolver) && return ################################## Solve Step 4 @@ -847,7 +852,7 @@ end nlsolver.z = z₄ nlsolver.c = c4 - z₄ = nlsolve!(integrator, cache) + z₄ = nlsolve!(nlsolver, integrator) nlsolvefail(nlsolver) && return ################################## Solve Step 5 @@ -873,7 +878,7 @@ end nlsolver.z = z₅ nlsolver.c = c5 - z₅ = nlsolve!(integrator, cache) + z₅ = nlsolve!(nlsolver, integrator) nlsolvefail(nlsolver) && return ################################## Solve Step 6 @@ -898,7 +903,7 @@ end nlsolver.z = z₆ nlsolver.c = 1 - z₆ = nlsolve!(integrator, cache) + z₆ = nlsolve!(nlsolver, integrator) nlsolvefail(nlsolver) && return @.. u = tmp + γ*z₆ @@ -915,22 +920,22 @@ end if integrator.opts.adaptive if typeof(integrator.f) <: SplitFunction - #@.. dz = btilde1*z₁ + btilde3*z₃ + btilde4*z₄ + btilde5*z₅ + btilde6*z₆ + ebtilde1*k1 + ebtilde3*k3 + ebtilde4*k4 + ebtilde5*k5 + ebtilde6*k6 + #@.. gz = btilde1*z₁ + btilde3*z₃ + btilde4*z₄ + btilde5*z₅ + btilde6*z₆ + ebtilde1*k1 + ebtilde3*k3 + ebtilde4*k4 + ebtilde5*k5 + ebtilde6*k6 for i in eachindex(u) - @inbounds dz[i] = btilde1*z₁[i] + btilde3*z₃[i] + btilde4*z₄[i] + btilde5*z₅[i] + btilde6*z₆[i] + ebtilde1*k1[i] + ebtilde3*k3[i] + ebtilde4*k4[i] + ebtilde5*k5[i] + ebtilde6*k6[i] + @inbounds gz[i] = btilde1*z₁[i] + btilde3*z₃[i] + btilde4*z₄[i] + btilde5*z₅[i] + btilde6*z₆[i] + ebtilde1*k1[i] + ebtilde3*k3[i] + ebtilde4*k4[i] + ebtilde5*k5[i] + ebtilde6*k6[i] end else - # @.. dz = btilde1*z₁ + btilde3*z₃ + btilde4*z₄ + btilde5*z₅ + btilde6*z₆ + # @.. gz = btilde1*z₁ + btilde3*z₃ + btilde4*z₄ + btilde5*z₅ + btilde6*z₆ @tight_loop_macros for i in eachindex(u) - @inbounds dz[i] = btilde1*z₁[i] + btilde3*z₃[i] + btilde4*z₄[i] + btilde5*z₅[i] + btilde6*z₆[i] + @inbounds gz[i] = btilde1*z₁[i] + btilde3*z₃[i] + btilde4*z₄[i] + btilde5*z₅[i] + btilde6*z₆[i] end end if isnewton(nlsolver) && alg.smooth_est # From Shampine integrator.destats.nsolve += 1 - nlsolver.linsolve(vec(tmp),get_W(nlsolver),vec(dz),false) + get_linsolve(nlsolver)(vec(tmp),get_W(nlsolver),vec(gz),false) else - tmp .= dz + tmp .= gz end calculate_residuals!(atmp, tmp, uprev, u, integrator.opts.abstol, integrator.opts.reltol,integrator.opts.internalnorm,t) integrator.EEst = integrator.opts.internalnorm(atmp,t) @@ -968,7 +973,7 @@ end nlsolver.tmp = uprev + γ*z₁ nlsolver.c = γ - z₂ = nlsolve!(integrator, cache) + z₂ = nlsolve!(nlsolver, integrator) nlsolvefail(nlsolver) && return ################################## Solve Step 3 @@ -977,7 +982,7 @@ end nlsolver.tmp = uprev + a31*z₁ + a32*z₂ nlsolver.c = c3 - z₃ = nlsolve!(integrator, cache) + z₃ = nlsolve!(nlsolver, integrator) nlsolvefail(nlsolver) && return ################################## Solve Step 4 @@ -986,7 +991,7 @@ end nlsolver.tmp = uprev + a41*z₁ + a42*z₂ + a43*z₃ nlsolver.c = c4 - z₄ = nlsolve!(integrator, cache) + z₄ = nlsolve!(nlsolver, integrator) nlsolvefail(nlsolver) && return ################################## Solve Step 5 @@ -995,7 +1000,7 @@ end nlsolver.tmp = uprev + a51*z₁ + a52*z₂ + a53*z₃ + a54*z₄ nlsolver.c = c5 - z₅ = nlsolve!(integrator, cache) + z₅ = nlsolve!(nlsolver, integrator) nlsolvefail(nlsolver) && return ################################## Solve Step 6 @@ -1004,7 +1009,7 @@ end nlsolver.tmp = uprev + a61*z₁ + a63*z₃ + a64*z₄ + a65*z₅ nlsolver.c = c6 - z₆ = nlsolve!(integrator, cache) + z₆ = nlsolve!(nlsolver, integrator) nlsolvefail(nlsolver) && return ################################## Solve Step 7 @@ -1014,7 +1019,7 @@ end nlsolver.tmp = uprev + a71*z₁ + a73*z₃ + a74*z₄ + a75*z₅ + a76*z₆ nlsolver.c = 1 - z₇ = nlsolve!(integrator, cache) + z₇ = nlsolve!(nlsolver, integrator) nlsolvefail(nlsolver) && return u = nlsolver.tmp + γ*z₇ @@ -1042,7 +1047,7 @@ end @muladd function perform_step!(integrator, cache::Kvaerno5Cache, repeat_step=false) @unpack t,dt,uprev,u,f,p = integrator @unpack z₁,z₂,z₃,z₄,z₅,z₆,z₇,atmp,nlsolver = cache - @unpack dz,k,tmp = nlsolver + @unpack gz,tmp = nlsolver @unpack γ,a31,a32,a41,a42,a43,a51,a52,a53,a54,a61,a63,a64,a65,a71,a73,a74,a75,a76,c3,c4,c5,c6 = cache.tab @unpack btilde1,btilde3,btilde4,btilde5,btilde6,btilde7 = cache.tab @unpack α31,α32,α41,α42,α43,α51,α52,α53,α61,α62,α63 = cache.tab @@ -1065,7 +1070,7 @@ end @.. tmp = uprev + γ*z₁ nlsolver.c = γ - z₂ = nlsolve!(integrator, cache) + z₂ = nlsolve!(nlsolver, integrator) nlsolvefail(nlsolver) && return set_new_W!(nlsolver, false) @@ -1076,7 +1081,7 @@ end @.. tmp = uprev + a31*z₁ + a32*z₂ nlsolver.c = c3 - z₃ = nlsolve!(integrator, cache) + z₃ = nlsolve!(nlsolver, integrator) nlsolvefail(nlsolver) && return ################################## Solve Step 4 @@ -1087,7 +1092,7 @@ end @.. tmp = uprev + a41*z₁ + a42*z₂ + a43*z₃ nlsolver.c = c4 - z₄ = nlsolve!(integrator, cache) + z₄ = nlsolve!(nlsolver, integrator) nlsolvefail(nlsolver) && return ################################## Solve Step 5 @@ -1097,7 +1102,7 @@ end @.. tmp = uprev + a51*z₁ + a52*z₂ + a53*z₃ + a54*z₄ nlsolver.c = c5 - z₅ = nlsolve!(integrator, cache) + z₅ = nlsolve!(nlsolver, integrator) nlsolvefail(nlsolver) && return ################################## Solve Step 6 @@ -1107,7 +1112,7 @@ end @.. tmp = uprev + a61*z₁ + a63*z₃ + a64*z₄ + a65*z₅ nlsolver.c = c6 - z₆ = nlsolve!(integrator, cache) + z₆ = nlsolve!(nlsolver, integrator) nlsolvefail(nlsolver) && return ################################## Solve Step 7 @@ -1124,7 +1129,7 @@ end @inbounds tmp[i] = uprev[i] + a71*z₁[i] + a73*z₃[i] + a74*z₄[i] + a75*z₅[i] + a76*z₆[i] end nlsolver.c = 1 - z₇ = nlsolve!(integrator, cache) + z₇ = nlsolve!(nlsolver, integrator) nlsolvefail(nlsolver) && return @.. u = tmp + γ*z₇ @@ -1132,15 +1137,15 @@ end ################################### Finalize if integrator.opts.adaptive - # @.. dz = btilde1*z₁ + btilde3*z₃ + btilde4*z₄ + btilde5*z₅ + btilde6*z₆ + btilde7*z₇ + # @.. gz = btilde1*z₁ + btilde3*z₃ + btilde4*z₄ + btilde5*z₅ + btilde6*z₆ + btilde7*z₇ @tight_loop_macros for i in eachindex(u) - @inbounds dz[i] = btilde1*z₁[i] + btilde3*z₃[i] + btilde4*z₄[i] + btilde5*z₅[i] + btilde6*z₆[i] + btilde7*z₇[i] + @inbounds gz[i] = btilde1*z₁[i] + btilde3*z₃[i] + btilde4*z₄[i] + btilde5*z₅[i] + btilde6*z₆[i] + btilde7*z₇[i] end if isnewton(nlsolver) && alg.smooth_est # From Shampine integrator.destats.nsolve += 1 - nlsolver.linsolve(vec(tmp),get_W(nlsolver),vec(dz),false) + get_linsolve(nlsolver)(vec(tmp),get_W(nlsolver),vec(gz),false) else - tmp .= dz + tmp .= gz end calculate_residuals!(atmp, tmp, uprev, u, integrator.opts.abstol, integrator.opts.reltol,integrator.opts.internalnorm,t) integrator.EEst = integrator.opts.internalnorm(atmp,t) @@ -1200,7 +1205,7 @@ end nlsolver.tmp = tmp nlsolver.c = γ - z₂ = nlsolve!(integrator, cache) + z₂ = nlsolve!(nlsolver, integrator) nlsolvefail(nlsolver) && return ################################## Solve Step 3 @@ -1222,7 +1227,7 @@ end nlsolver.c = c3 nlsolver.tmp = tmp - z₃ = nlsolve!(integrator, cache) + z₃ = nlsolve!(nlsolver, integrator) nlsolvefail(nlsolver) && return ################################## Solve Step 4 @@ -1241,7 +1246,7 @@ end nlsolver.c = c4 nlsolver.tmp = tmp - z₄ = nlsolve!(integrator, cache) + z₄ = nlsolve!(nlsolver, integrator) nlsolvefail(nlsolver) && return ################################## Solve Step 5 @@ -1260,7 +1265,7 @@ end nlsolver.c = c5 nlsolver.tmp = tmp - z₅ = nlsolve!(integrator, cache) + z₅ = nlsolve!(nlsolver, integrator) nlsolvefail(nlsolver) && return ################################## Solve Step 6 @@ -1279,7 +1284,7 @@ end nlsolver.c = c6 nlsolver.tmp = tmp - z₆ = nlsolve!(integrator, cache) + z₆ = nlsolve!(nlsolver, integrator) nlsolvefail(nlsolver) && return ################################## Solve Step 7 @@ -1298,7 +1303,7 @@ end nlsolver.c = c7 nlsolver.tmp = tmp - z₇ = nlsolve!(integrator, cache) + z₇ = nlsolve!(nlsolver, integrator) nlsolvefail(nlsolver) && return ################################## Solve Step 8 @@ -1317,7 +1322,7 @@ end nlsolver.c = 1 nlsolver.tmp = tmp - z₈ = nlsolve!(integrator, cache) + z₈ = nlsolve!(nlsolver, integrator) nlsolvefail(nlsolver) && return u = nlsolver.tmp + γ*z₈ @@ -1361,7 +1366,7 @@ end @unpack t,dt,uprev,u,p = integrator @unpack z₁,z₂,z₃,z₄,z₅,z₆,z₇,z₈,atmp,nlsolver = cache @unpack k1,k2,k3,k4,k5,k6,k7,k8 = cache - @unpack dz,k,tmp = nlsolver + @unpack gz,tmp = nlsolver @unpack γ,a31,a32,a41,a43,a51,a53,a54,a61,a63,a64,a65,a71,a73,a74,a75,a76,a81,a84,a85,a86,a87,c3,c4,c5,c6,c7 = cache.tab @unpack α31,α32,α41,α42,α51,α52,α61,α62,α71,α72,α73,α74,α75,α81,α82,α83,α84,α85 = cache.tab @unpack btilde1,btilde4,btilde5,btilde6,btilde7,btilde8 = cache.tab @@ -1413,7 +1418,7 @@ end end nlsolver.c = γ - z₂ = nlsolve!(integrator, cache) + z₂ = nlsolve!(nlsolver, integrator) nlsolvefail(nlsolver) && return set_new_W!(nlsolver, false) @@ -1438,7 +1443,7 @@ end nlsolver.z = z₃ nlsolver.c = c3 - z₃ = nlsolve!(integrator, cache) + z₃ = nlsolve!(nlsolver, integrator) nlsolvefail(nlsolver) && return ################################## Solve Step 4 @@ -1459,7 +1464,7 @@ end nlsolver.z = z₄ nlsolver.c = c4 - z₄ = nlsolve!(integrator, cache) + z₄ = nlsolve!(nlsolver, integrator) nlsolvefail(nlsolver) && return ################################## Solve Step 5 @@ -1480,7 +1485,7 @@ end nlsolver.z = z₅ nlsolver.c = c5 - z₅ = nlsolve!(integrator, cache) + z₅ = nlsolve!(nlsolver, integrator) nlsolvefail(nlsolver) && return ################################## Solve Step 6 @@ -1504,7 +1509,7 @@ end nlsolver.z = z₆ nlsolver.c = c6 - z₆ = nlsolve!(integrator, cache) + z₆ = nlsolve!(nlsolver, integrator) nlsolvefail(nlsolver) && return ################################## Solve Step 7 @@ -1531,7 +1536,7 @@ end nlsolver.z = z₇ nlsolver.c = c7 - z₇ = nlsolve!(integrator, cache) + z₇ = nlsolve!(nlsolver, integrator) nlsolvefail(nlsolver) && return ################################## Solve Step 8 @@ -1558,7 +1563,7 @@ end nlsolver.z = z₈ nlsolver.c = 1 - z₈ = nlsolve!(integrator, cache) + z₈ = nlsolve!(nlsolver, integrator) nlsolvefail(nlsolver) && return @.. u = tmp + γ*z₈ @@ -1576,23 +1581,23 @@ end if integrator.opts.adaptive if typeof(integrator.f) <: SplitFunction - #@.. dz = btilde1*z₁ + btilde4*z₄ + btilde5*z₅ + btilde6*z₆ + btilde7*z₇ + btilde8*z₈ + ebtilde1*k1 + ebtilde4*k4 + ebtilde5*k5 + ebtilde6*k6 + ebtilde7*k7 + ebtilde8*k8 + #@.. gz = btilde1*z₁ + btilde4*z₄ + btilde5*z₅ + btilde6*z₆ + btilde7*z₇ + btilde8*z₈ + ebtilde1*k1 + ebtilde4*k4 + ebtilde5*k5 + ebtilde6*k6 + ebtilde7*k7 + ebtilde8*k8 for i in eachindex(u) - dz[i] = btilde1*z₁[i] + btilde4*z₄[i] + btilde5*z₅[i] + btilde6*z₆[i] + btilde7*z₇[i] + btilde8*z₈[i] + ebtilde1*k1[i] + ebtilde4*k4[i] + ebtilde5*k5[i] + ebtilde6*k6[i] + ebtilde7*k7[i] + ebtilde8*k8[i] + gz[i] = btilde1*z₁[i] + btilde4*z₄[i] + btilde5*z₅[i] + btilde6*z₆[i] + btilde7*z₇[i] + btilde8*z₈[i] + ebtilde1*k1[i] + ebtilde4*k4[i] + ebtilde5*k5[i] + ebtilde6*k6[i] + ebtilde7*k7[i] + ebtilde8*k8[i] end else - # @.. dz = btilde1*z₁ + btilde4*z₄ + btilde5*z₅ + btilde6*z₆ + btilde7*z₇ + btilde8*z₈ + # @.. gz = btilde1*z₁ + btilde4*z₄ + btilde5*z₅ + btilde6*z₆ + btilde7*z₇ + btilde8*z₈ @tight_loop_macros for i in eachindex(u) - @inbounds dz[i] = btilde1*z₁[i] + btilde4*z₄[i] + btilde5*z₅[i] + btilde6*z₆[i] + btilde7*z₇[i] + btilde8*z₈[i] + @inbounds gz[i] = btilde1*z₁[i] + btilde4*z₄[i] + btilde5*z₅[i] + btilde6*z₆[i] + btilde7*z₇[i] + btilde8*z₈[i] end end if isnewton(nlsolver) && alg.smooth_est # From Shampine integrator.destats.nsolve += 1 - nlsolver.linsolve(vec(tmp),get_W(nlsolver),vec(dz),false) + get_linsolve(nlsolver)(vec(tmp),get_W(nlsolver),vec(gz),false) else - tmp .= dz + tmp .= gz end calculate_residuals!(atmp, tmp, uprev, u, integrator.opts.abstol, integrator.opts.reltol,integrator.opts.internalnorm,t) integrator.EEst = integrator.opts.internalnorm(atmp,t) diff --git a/src/perform_step/rkc_perform_step.jl b/src/perform_step/rkc_perform_step.jl index eb67371ee9..abd659130a 100644 --- a/src/perform_step/rkc_perform_step.jl +++ b/src/perform_step/rkc_perform_step.jl @@ -576,7 +576,7 @@ function perform_step!(integrator,cache::IRKCConstantCache,repeat_step=false) nlsolver.tmp = uprev + dt*μs₁*du₂ nlsolver.γ = μs₁ nlsolver.c = μs - z = nlsolve!(integrator, cache) + z = nlsolve!(nlsolver, integrator) # nlsolvefail(nlsolver) && return gprev = nlsolver.tmp + μs₁*z @@ -608,7 +608,7 @@ function perform_step!(integrator,cache::IRKCConstantCache,repeat_step=false) nlsolver.tmp = (1-μ-ν)*uprev + μ*gprev + ν*gprev2 + dt*μs*f2ⱼ₋₁ + dt*νs*du₂ + (νs - (1 -μ-ν)*μs₁)*dt*du₁ - ν*μs₁*dt*f1ⱼ₋₂ nlsolver.z = dt*f1ⱼ₋₁ nlsolver.c = Cⱼ - z = nlsolve!(integrator, cache) + z = nlsolve!(nlsolver, integrator) # ignoring newton method's convergence failure # nlsolvefail(nlsolver) && return u = nlsolver.tmp + μs₁*z @@ -653,7 +653,12 @@ function initialize!(integrator, cache::IRKCCache) @unpack f1, f2 = integrator.f integrator.kshortsize = 2 integrator.fsalfirst = cache.fsalfirst - integrator.fsallast = cache.nlsolver.k + _du_cache = du_cache(cache.nlsolver) + if _du_cache === nothing + integrator.fsallast = zero(integrator.fsalfirst) + else + integrator.fsallast = first(_du_cache) + end resize!(integrator.k, integrator.kshortsize) integrator.k[1] = integrator.fsalfirst integrator.k[2] = integrator.fsallast @@ -667,7 +672,7 @@ end function perform_step!(integrator, cache::IRKCCache, repeat_step=false) @unpack t,dt,uprev,u,f,p,alg = integrator @unpack gprev,gprev2,f1ⱼ₋₁,f1ⱼ₋₂,f2ⱼ₋₁,du₁,du₂,atmp,nlsolver = cache - @unpack tmp,k,z = nlsolver + @unpack tmp,z = nlsolver @unpack minm = cache.constantcache @unpack f1, f2 = integrator.f @@ -704,7 +709,7 @@ function perform_step!(integrator, cache::IRKCCache, repeat_step=false) @.. nlsolver.tmp = uprev + dt*μs₁*du₂ nlsolver.γ = μs₁ nlsolver.c = μs - z = nlsolve!(integrator, cache) + z = nlsolve!(nlsolver, integrator) # ignoring newton method's convergence failure # nlsolvefail(nlsolver) && return @.. gprev = nlsolver.tmp + μs₁*nlsolver.z @@ -738,7 +743,7 @@ function perform_step!(integrator, cache::IRKCCache, repeat_step=false) @.. nlsolver.z = dt*f1ⱼ₋₁ nlsolver.c = Cⱼ - z = nlsolve!(integrator, cache) + z = nlsolve!(nlsolver, integrator) # nlsolvefail(nlsolver) && return @.. u = nlsolver.tmp + μs₁*nlsolver.z if (iter < mdeg) @@ -768,7 +773,7 @@ function perform_step!(integrator, cache::IRKCCache, repeat_step=false) if isnewton(nlsolver) && integrator.opts.adaptive update_W!(integrator, cache, dt, false) @.. gprev = dt*0.5*(du₂ - f2ⱼ₋₁) + dt*(0.5 - μs₁)*(du₁ - f1ⱼ₋₁) - nlsolver.linsolve(vec(tmp),get_W(nlsolver),vec(gprev),false) + get_linsolve(nlsolver)(vec(tmp),get_W(nlsolver),vec(gprev),false) calculate_residuals!(atmp, tmp, uprev, u, integrator.opts.abstol, integrator.opts.reltol, integrator.opts.internalnorm, t) integrator.EEst = integrator.opts.internalnorm(atmp,t) end diff --git a/src/perform_step/sdirk_perform_step.jl b/src/perform_step/sdirk_perform_step.jl index 538c628eb3..54bf012362 100644 --- a/src/perform_step/sdirk_perform_step.jl +++ b/src/perform_step/sdirk_perform_step.jl @@ -31,7 +31,12 @@ function initialize!(integrator, cache::Union{ImplicitEulerCache, ESDIRK54I8L2SACache}) integrator.kshortsize = 2 integrator.fsalfirst = cache.fsalfirst - integrator.fsallast = cache.nlsolver.k + _du_cache = du_cache(cache.nlsolver) + if _du_cache === nothing + integrator.fsallast = zero(integrator.fsalfirst) + else + integrator.fsallast = first(_du_cache) + end resize!(integrator.k, integrator.kshortsize) integrator.k[1] = integrator.fsalfirst integrator.k[2] = integrator.fsallast @@ -53,7 +58,7 @@ end end nlsolver.tmp = uprev - z = nlsolve!(integrator, cache) + z = nlsolve!(nlsolver, integrator) nlsolvefail(nlsolver) && return u = nlsolver.tmp + z @@ -100,7 +105,7 @@ end end nlsolver.tmp .= uprev - z = nlsolve!(integrator, cache) + z = nlsolve!(nlsolver, integrator) nlsolvefail(nlsolver) && return @.. u = uprev + z @@ -142,7 +147,7 @@ end end nlsolver.tmp = uprev - z = nlsolve!(integrator, cache) + z = nlsolve!(nlsolver, integrator) nlsolvefail(nlsolver) && return u = nlsolver.tmp + z @@ -170,7 +175,7 @@ end end nlsolver.tmp = uprev - z = nlsolve!(integrator, cache) + z = nlsolve!(nlsolver, integrator) nlsolvefail(nlsolver) && return @.. u = nlsolver.tmp + z @@ -192,7 +197,7 @@ end nlsolver.z = zprev # Constant extrapolation nlsolver.tmp = uprev + γdt*integrator.fsalfirst - z = nlsolve!(integrator, cache) + z = nlsolve!(nlsolver, integrator) nlsolvefail(nlsolver) && return u = nlsolver.tmp + 1//2*z @@ -244,7 +249,7 @@ end @muladd function perform_step!(integrator, cache::TrapezoidCache, repeat_step=false) @unpack t,dt,uprev,u,f,p = integrator @unpack atmp,nlsolver = cache - @unpack z,jac_config,tmp = nlsolver + @unpack z,tmp = nlsolver alg = unwrap_alg(integrator, true) mass_matrix = integrator.f.mass_matrix @@ -256,7 +261,7 @@ end # initial guess @.. z = dt*integrator.fsalfirst @.. tmp = uprev + γdt*integrator.fsalfirst - z = nlsolve!(integrator, cache) + z = nlsolve!(nlsolver, integrator) nlsolvefail(nlsolver) && return @.. u = tmp + 1//2*z @@ -322,7 +327,7 @@ end nlsolver.c = γ nlsolver.tmp = uprev + d*zprev - zᵧ = nlsolve!(integrator, cache) + zᵧ = nlsolve!(nlsolver, integrator) nlsolvefail(nlsolver) && return ################################## Solve BDF2 Step @@ -333,7 +338,7 @@ end nlsolver.c = 1 nlsolver.tmp = uprev + ω*zprev + ω*zᵧ - z = nlsolve!(integrator, cache) + z = nlsolve!(nlsolver, integrator) nlsolvefail(nlsolver) && return u = nlsolver.tmp + d*z @@ -361,9 +366,8 @@ end @muladd function perform_step!(integrator, cache::TRBDF2Cache, repeat_step=false) @unpack t,dt,uprev,u,f,p = integrator @unpack zprev,zᵧ,atmp,nlsolver = cache - @unpack dz,z,k,tmp = nlsolver - W = isnewton(nlsolver) ? get_W(nlsolver) : nothing - b = nlsolver.ztmp + @unpack z,gz,tmp = nlsolver + @unpack γ,d,ω,btilde1,btilde2,btilde3,α1,α2 = cache.tab alg = unwrap_alg(integrator, true) @@ -378,7 +382,7 @@ end z .= zᵧ @.. tmp = uprev + d*zprev nlsolver.c = γ - zᵧ .= nlsolve!(integrator, cache) + zᵧ .= nlsolve!(nlsolver, integrator) nlsolvefail(nlsolver) && return ################################## Solve BDF2 Step @@ -388,7 +392,7 @@ end @.. tmp = uprev + ω*zprev + ω*zᵧ nlsolver.c = 1 set_new_W!(nlsolver, false) - nlsolve!(integrator, cache) + nlsolve!(nlsolver, integrator) nlsolvefail(nlsolver) && return @.. u = tmp + d*z @@ -396,12 +400,12 @@ end ################################### Finalize if integrator.opts.adaptive - @.. dz = btilde1*zprev + btilde2*zᵧ + btilde3*z + @.. gz = btilde1*zprev + btilde2*zᵧ + btilde3*z if alg.smooth_est && isnewton(nlsolver) # From Shampine integrator.destats.nsolve += 1 - nlsolver.linsolve(vec(tmp),W,vec(dz),false) + get_linsolve(nlsolver)(vec(tmp),get_W(nlsolver),vec(gz),false) else - tmp .= dz + tmp .= gz end calculate_residuals!(atmp, tmp, uprev, u, integrator.opts.abstol, integrator.opts.reltol,integrator.opts.internalnorm,t) integrator.EEst = integrator.opts.internalnorm(atmp,t) @@ -427,14 +431,14 @@ end end nlsolver.tmp = uprev - z₁ = nlsolve!(integrator, cache) + z₁ = nlsolve!(nlsolver, integrator) nlsolvefail(nlsolver) && return ### Initial Guess Is α₁ = c₂/γ, c₂ = 0 => z₂ = α₁z₁ = 0 z₂ = zero(u) nlsolver.z = z₂ nlsolver.tmp = uprev - z₁ - z₂ = nlsolve!(integrator, cache) + z₂ = nlsolve!(nlsolver, integrator) nlsolvefail(nlsolver) && return u = uprev + z₁/2 + z₂/2 @@ -463,8 +467,8 @@ end @muladd function perform_step!(integrator, cache::SDIRK2Cache, repeat_step=false) @unpack t,dt,uprev,u,f,p = integrator @unpack z₁,z₂,atmp,nlsolver = cache - @unpack dz,k,jac_config,tmp = nlsolver - W = isnewton(nlsolver) ? get_W(nlsolver) : nothing + @unpack gz,tmp = nlsolver + alg = unwrap_alg(integrator, true) update_W!(integrator, cache, dt, repeat_step) @@ -481,7 +485,7 @@ end ##### Step 1 nlsolver.tmp = uprev - z₁ = nlsolve!(integrator, cache) + z₁ = nlsolve!(nlsolver, integrator) nlsolvefail(nlsolver) && return ################################## Solve Step 2 @@ -492,7 +496,7 @@ end set_new_W!(nlsolver, false) @.. tmp = uprev - z₁ nlsolver.tmp = tmp - z₂ = nlsolve!(integrator, cache) + z₂ = nlsolve!(nlsolver, integrator) nlsolvefail(nlsolver) && return @.. u = uprev + z₁/2 + z₂/2 @@ -500,12 +504,12 @@ end ################################### Finalize if integrator.opts.adaptive - @.. dz = z₁/2 - z₂/2 + @.. gz = z₁/2 - z₂/2 if alg.smooth_est && isnewton(nlsolver) # From Shampine integrator.destats.nsolve += 1 - nlsolver.linsolve(vec(tmp),W,vec(dz),false) + get_linsolve(nlsolver)(vec(tmp),get_W(nlsolver),vec(gz),false) else - tmp .= dz + tmp .= gz end calculate_residuals!(atmp, tmp, uprev, u, integrator.opts.abstol, integrator.opts.reltol,integrator.opts.internalnorm,t) integrator.EEst = integrator.opts.internalnorm(atmp,t) @@ -532,7 +536,7 @@ end # first stage nlsolver.tmp = uprev + γdt*integrator.fsalfirst - z = nlsolve!(integrator, cache) + z = nlsolve!(nlsolver, integrator) nlsolvefail(nlsolver) && return uprev = α*nlsolver.tmp + β*z @@ -541,7 +545,7 @@ end γdt = γ*dt update_W!(integrator, cache, γdt, repeat_step) nlsolver.tmp = uprev + γdt*integrator.fsalfirst - z = nlsolve!(integrator, cache) + z = nlsolve!(nlsolver, integrator) nlsolvefail(nlsolver) && return u = nlsolver.tmp @@ -591,7 +595,7 @@ end @muladd function perform_step!(integrator, cache::SDIRK22Cache, repeat_step=false) @unpack t,dt,uprev,u,f,p = integrator @unpack atmp,nlsolver = cache - @unpack z,jac_config,tmp = nlsolver + @unpack z,tmp = nlsolver @unpack a,α,β = cache.tab alg = unwrap_alg(integrator, true) mass_matrix = integrator.f.mass_matrix @@ -604,7 +608,7 @@ end # first stage @.. z = dt*integrator.fsalfirst @.. tmp = uprev + γdt*integrator.fsalfirst - z = nlsolve!(integrator, cache) + z = nlsolve!(nlsolver, integrator) nlsolvefail(nlsolver) && return @.. u = α*tmp + β*z @@ -613,7 +617,7 @@ end γdt = γ*dt update_W!(integrator, cache, γdt, repeat_step) @.. tmp = uprev + γdt*integrator.fsalfirst - z = nlsolve!(integrator, cache) + z = nlsolve!(nlsolver, integrator) nlsolvefail(nlsolver) && return @.. u = nlsolver.tmp @@ -688,7 +692,7 @@ end nlsolver.c = 1 nlsolver.tmp = uprev - z₁ = nlsolve!(integrator, cache) + z₁ = nlsolve!(nlsolver, integrator) nlsolvefail(nlsolver) && return ################################## Solve Step 2 @@ -699,7 +703,7 @@ end nlsolver.tmp = uprev + z₁/2 nlsolver.c = 1 - z₂ = nlsolve!(integrator, cache) + z₂ = nlsolve!(nlsolver, integrator) nlsolvefail(nlsolver) && return u = nlsolver.tmp + z₂/2 @@ -716,7 +720,7 @@ end @muladd function perform_step!(integrator, cache::SSPSDIRK2Cache, repeat_step=false) @unpack t,dt,uprev,u,f,p = integrator @unpack z₁,z₂,nlsolver = cache - @unpack dz,k,jac_config,tmp = nlsolver + @unpack gz,tmp = nlsolver alg = unwrap_alg(integrator, true) γ = eltype(u)(1//4) @@ -736,7 +740,7 @@ end nlsolver.tmp = uprev ##### Step 1 - z₁ = nlsolve!(integrator, cache) + z₁ = nlsolve!(nlsolver, integrator) nlsolvefail(nlsolver) && return ################################## Solve Step 2 @@ -748,7 +752,7 @@ end @.. tmp = uprev + z₁/2 nlsolver.tmp = tmp set_new_W!(nlsolver, false) - z₂ = nlsolve!(integrator, cache) + z₂ = nlsolve!(nlsolver, integrator) nlsolvefail(nlsolver) && return @.. u = tmp + z₂/2 @@ -775,7 +779,7 @@ end nlsolver.c = γ nlsolver.tmp = uprev - z₁ = nlsolve!(integrator, cache) + z₁ = nlsolve!(nlsolver, integrator) nlsolvefail(nlsolver) && return ##### Step 2 @@ -786,7 +790,7 @@ end nlsolver.tmp = uprev + a21*z₁ nlsolver.c = c2 - z₂ = nlsolve!(integrator, cache) + z₂ = nlsolve!(nlsolver, integrator) nlsolvefail(nlsolver) && return ################################## Solve Step 3 @@ -797,7 +801,7 @@ end nlsolver.tmp = uprev + a31*z₁ + a32*z₂ nlsolver.c = c3 - z₃ = nlsolve!(integrator, cache) + z₃ = nlsolve!(nlsolver, integrator) nlsolvefail(nlsolver) && return ################################## Solve Step 4 @@ -808,7 +812,7 @@ end nlsolver.tmp = uprev + a41*z₁ + a42*z₂ + a43*z₃ nlsolver.c = c4 - z₄ = nlsolve!(integrator, cache) + z₄ = nlsolve!(nlsolver, integrator) nlsolvefail(nlsolver) && return ################################## Solve Step 5 @@ -819,7 +823,7 @@ end nlsolver.tmp = uprev + a51*z₁ + a52*z₂ + a53*z₃ + a54*z₄ nlsolver.c = 1 - z₅ = nlsolve!(integrator, cache) + z₅ = nlsolve!(nlsolver, integrator) nlsolvefail(nlsolver) && return u = nlsolver.tmp + γ*z₅ @@ -855,8 +859,8 @@ end @muladd function perform_step!(integrator, cache::Cash4Cache, repeat_step=false) @unpack t,dt,uprev,u,f,p = integrator @unpack z₁,z₂,z₃,z₄,z₅,atmp,nlsolver = cache - @unpack dz,k,tmp = nlsolver - W = isnewton(nlsolver) ? get_W(nlsolver) : nothing + @unpack gz,tmp = nlsolver + @unpack γ,a21,a31,a32,a41,a42,a43,a51,a52,a53,a54,c2,c3,c4 = cache.tab @unpack b1hat1,b2hat1,b3hat1,b4hat1,b1hat2,b2hat2,b3hat2,b4hat2 = cache.tab alg = unwrap_alg(integrator, true) @@ -871,7 +875,7 @@ end nlsolver.tmp = uprev # initial step of NLNewton iteration - z₁ = nlsolve!(integrator, cache) + z₁ = nlsolve!(nlsolver, integrator) nlsolvefail(nlsolver) && return ##### Step 2 @@ -884,7 +888,7 @@ end nlsolver.tmp = tmp set_new_W!(nlsolver, false) nlsolver.c = c2 - z₂ = nlsolve!(integrator, cache) + z₂ = nlsolve!(nlsolver, integrator) nlsolvefail(nlsolver) && return ################################## Solve Step 3 @@ -894,7 +898,7 @@ end nlsolver.z = z₃ @.. tmp = uprev + a31*z₁ + a32*z₂ nlsolver.c = c3 - z₃ = nlsolve!(integrator, cache) + z₃ = nlsolve!(nlsolver, integrator) nlsolvefail(nlsolver) && return ################################## Solve Step 4 @@ -905,7 +909,7 @@ end @.. tmp = uprev + a41*z₁ + a42*z₂ + a43*z₃ nlsolver.c = c4 - z₄ = nlsolve!(integrator, cache) + z₄ = nlsolve!(nlsolver, integrator) nlsolvefail(nlsolver) && return ################################## Solve Step 5 @@ -915,7 +919,7 @@ end nlsolver.z = z₅ @.. tmp = uprev + a51*z₁ + a52*z₂ + a53*z₃ + a54*z₄ nlsolver.c = 1 - z₅ = nlsolve!(integrator, cache) + z₅ = nlsolve!(nlsolver, integrator) nlsolvefail(nlsolver) && return @.. u = tmp + γ*z₅ @@ -931,12 +935,12 @@ end btilde3 = b3hat1-a53; btilde4 = b4hat1-a54; btilde5 = -γ end - @.. dz = btilde1*z₁ + btilde2*z₂ + btilde3*z₃ + btilde4*z₄ + btilde5*z₅ + @.. gz = btilde1*z₁ + btilde2*z₂ + btilde3*z₃ + btilde4*z₄ + btilde5*z₅ if alg.smooth_est && isnewton(nlsolver) # From Shampine integrator.destats.nsolve += 1 - nlsolver.linsolve(vec(tmp),W,vec(dz),false) + get_linsolve(nlsolver)(vec(tmp),get_W(nlsolver),vec(gz),false) else - tmp .= dz + tmp .= gz end calculate_residuals!(atmp, tmp, uprev, u, integrator.opts.abstol, integrator.opts.reltol,integrator.opts.internalnorm,t) integrator.EEst = integrator.opts.internalnorm(atmp,t) @@ -961,7 +965,7 @@ end z₁ = zero(u) nlsolver.z, nlsolver.tmp = z₁, uprev nlsolver.c = γ - z₁ = nlsolve!(integrator, cache) + z₁ = nlsolve!(nlsolver, integrator) nlsolvefail(nlsolver) && return ##### Step 2 @@ -970,7 +974,7 @@ end nlsolver.z = z₂ nlsolver.tmp = uprev + a21*z₁ nlsolver.c = c2 - z₂ = nlsolve!(integrator, cache) + z₂ = nlsolve!(nlsolver, integrator) nlsolvefail(nlsolver) && return ################################## Solve Step 3 @@ -979,7 +983,7 @@ end nlsolver.z = z₃ nlsolver.tmp = uprev + a31*z₁ + a32*z₂ nlsolver.c = c3 - z₃ = nlsolve!(integrator, cache) + z₃ = nlsolve!(nlsolver, integrator) nlsolvefail(nlsolver) && return ################################## Solve Step 4 @@ -988,7 +992,7 @@ end nlsolver.z = z₄ nlsolver.tmp = uprev + a41*z₁ + a42*z₂ + a43*z₃ nlsolver.c = c4 - z₄ = nlsolve!(integrator, cache) + z₄ = nlsolve!(nlsolver, integrator) nlsolvefail(nlsolver) && return ################################## Solve Step 5 @@ -998,7 +1002,7 @@ end nlsolver.z = z₅ nlsolver.tmp = uprev + a51*z₁ + a52*z₂ + a53*z₃ + a54*z₄ nlsolver.c = 1 - z₅ = nlsolve!(integrator, cache) + z₅ = nlsolve!(nlsolver, integrator) nlsolvefail(nlsolver) && return u = nlsolver.tmp + γ*z₅ @@ -1026,8 +1030,7 @@ end @muladd function perform_step!(integrator, cache::Hairer4Cache, repeat_step=false) @unpack t,dt,uprev,u,f,p = integrator @unpack z₁,z₂,z₃,z₄,z₅,atmp,nlsolver = cache - @unpack dz,k,jac_config,tmp = nlsolver - W = isnewton(nlsolver) ? get_W(nlsolver) : nothing + @unpack gz,tmp = nlsolver @unpack γ,a21,a31,a32,a41,a42,a43,a51,a52,a53,a54,c2,c3,c4 = cache.tab @unpack α21,α31,α32,α41,α43 = cache.tab @unpack bhat1,bhat2,bhat3,bhat4,btilde1,btilde2,btilde3,btilde4,btilde5 = cache.tab @@ -1049,7 +1052,7 @@ end ##### Step 1 nlsolver.c = γ - z₁ = nlsolve!(integrator, cache) + z₁ = nlsolve!(nlsolver, integrator) nlsolvefail(nlsolver) && return ##### Step 2 @@ -1060,7 +1063,7 @@ end nlsolver.tmp = tmp nlsolver.c = c2 set_new_W!(nlsolver, false) - z₂ = nlsolve!(integrator, cache) + z₂ = nlsolve!(nlsolver, integrator) nlsolvefail(nlsolver) && return ################################## Solve Step 3 @@ -1069,7 +1072,7 @@ end nlsolver.z = z₃ @.. tmp = uprev + a31*z₁ + a32*z₂ nlsolver.c = c3 - z₃ = nlsolve!(integrator, cache) + z₃ = nlsolve!(nlsolver, integrator) nlsolvefail(nlsolver) && return ################################## Solve Step 4 @@ -1079,7 +1082,7 @@ end nlsolver.z = z₄ @.. tmp = uprev + a41*z₁ + a42*z₂ + a43*z₃ nlsolver.c = c4 - z₄ = nlsolve!(integrator, cache) + z₄ = nlsolve!(nlsolver, integrator) nlsolvefail(nlsolver) && return ################################## Solve Step 5 @@ -1089,7 +1092,7 @@ end nlsolver.z = z₅ @.. tmp = uprev + a51*z₁ + a52*z₂ + a53*z₃ + a54*z₄ nlsolver.c = 1 - z₅ = nlsolve!(integrator, cache) + z₅ = nlsolve!(nlsolver, integrator) nlsolvefail(nlsolver) && return @.. u = tmp + γ*z₅ @@ -1097,15 +1100,15 @@ end ################################### Finalize if integrator.opts.adaptive - # @.. dz = btilde1*z₁ + btilde2*z₂ + btilde3*z₃ + btilde4*z₄ + btilde5*z₅ + # @.. gz = btilde1*z₁ + btilde2*z₂ + btilde3*z₃ + btilde4*z₄ + btilde5*z₅ @tight_loop_macros for i in eachindex(u) - dz[i] = btilde1*z₁[i] + btilde2*z₂[i] + btilde3*z₃[i] + btilde4*z₄[i] + btilde5*z₅[i] + gz[i] = btilde1*z₁[i] + btilde2*z₂[i] + btilde3*z₃[i] + btilde4*z₄[i] + btilde5*z₅[i] end if alg.smooth_est && isnewton(nlsolver) # From Shampine integrator.destats.nsolve += 1 - nlsolver.linsolve(vec(tmp),W,vec(dz),false) + get_linsolve(nlsolver)(vec(tmp),get_W(nlsolver),vec(gz),false) else - tmp .= dz + tmp .= gz end calculate_residuals!(atmp, tmp, uprev, u, integrator.opts.abstol, integrator.opts.reltol,integrator.opts.internalnorm,t) integrator.EEst = integrator.opts.internalnorm(atmp,t) @@ -1145,7 +1148,7 @@ end nlsolver.tmp = uprev + γ*z₁ nlsolver.c = γ - z₂ = nlsolve!(integrator, cache) + z₂ = nlsolve!(nlsolver, integrator) nlsolvefail(nlsolver) && return ################################## Solve Step 3 @@ -1154,7 +1157,7 @@ end nlsolver.tmp = uprev + a31*z₁ + a32*z₂ nlsolver.c = c3 - z₃ = nlsolve!(integrator, cache) + z₃ = nlsolve!(nlsolver, integrator) nlsolvefail(nlsolver) && return ################################## Solve Step 4 @@ -1163,7 +1166,7 @@ end nlsolver.tmp = uprev + a41*z₁ + a42*z₂ + a43*z₃ nlsolver.c = c4 - z₄ = nlsolve!(integrator, cache) + z₄ = nlsolve!(nlsolver, integrator) nlsolvefail(nlsolver) && return ################################## Solve Step 5 @@ -1172,7 +1175,7 @@ end nlsolver.tmp = uprev + a51*z₁ + a52*z₂ + a53*z₃ + a54*z₄ nlsolver.c = c5 - z₅ = nlsolve!(integrator, cache) + z₅ = nlsolve!(nlsolver, integrator) nlsolvefail(nlsolver) && return ################################## Solve Step 6 @@ -1181,7 +1184,7 @@ end nlsolver.tmp = uprev + a61*z₁ + a62*z₂+ a63*z₃ + a64*z₄ + a65*z₅ nlsolver.c = c6 - z₆ = nlsolve!(integrator, cache) + z₆ = nlsolve!(nlsolver, integrator) nlsolvefail(nlsolver) && return ################################## Solve Step 7 @@ -1190,7 +1193,7 @@ end nlsolver.tmp = uprev + a71*z₁ + a72*z₂ + a73*z₃ + a74*z₄ + a75*z₅ + a76*z₆ nlsolver.c = c7 - z₇ = nlsolve!(integrator, cache) + z₇ = nlsolve!(nlsolver, integrator) nlsolvefail(nlsolver) && return ################################## Solve Step 8 @@ -1199,7 +1202,7 @@ end nlsolver.tmp = uprev + a81*z₁ + a82*z₂ + a83*z₃ + a84*z₄ + a85*z₅ + a86*z₆ + a87*z₇ nlsolver.c = 1 - z₈ = nlsolve!(integrator, cache) + z₈ = nlsolve!(nlsolver, integrator) nlsolvefail(nlsolver) && return u = nlsolver.tmp + γ*z₈ @@ -1222,7 +1225,7 @@ end @muladd function perform_step!(integrator, cache::ESDIRK54I8L2SACache, repeat_step=false) @unpack t,dt,uprev,u,f,p = integrator @unpack z₁,z₂,z₃,z₄,z₅,z₆,z₇,z₈,atmp,nlsolver = cache - @unpack k,tmp = nlsolver + @unpack tmp = nlsolver @unpack γ, a31, a32, a41, a42, a43, @@ -1250,7 +1253,7 @@ end @.. tmp = uprev + γ*z₁ nlsolver.c = γ - z₂ = nlsolve!(integrator, cache) + z₂ = nlsolve!(nlsolver, integrator) nlsolvefail(nlsolver) && return set_new_W!(nlsolver, false) @@ -1260,7 +1263,7 @@ end @.. tmp = uprev + a31*z₁ + a32*z₂ nlsolver.c = c3 - z₃ = nlsolve!(integrator, cache) + z₃ = nlsolve!(nlsolver, integrator) nlsolvefail(nlsolver) && return ################################## Solve Step 4 @@ -1270,7 +1273,7 @@ end @.. tmp = uprev + a41*z₁ + a42*z₂ + a43*z₃ nlsolver.c = c4 - z₄ = nlsolve!(integrator, cache) + z₄ = nlsolve!(nlsolver, integrator) nlsolvefail(nlsolver) && return ################################## Solve Step 5 @@ -1279,7 +1282,7 @@ end @.. tmp = uprev + a51*z₁ + a52*z₂ + a53*z₃ + a54*z₄ nlsolver.c = c5 - z₅ = nlsolve!(integrator, cache) + z₅ = nlsolve!(nlsolver, integrator) nlsolvefail(nlsolver) && return ################################## Solve Step 6 @@ -1288,7 +1291,7 @@ end @.. tmp = uprev + a61*z₁ + a62*z₂ + a63*z₃ + a64*z₄ + a65*z₅ nlsolver.c = c6 - z₆ = nlsolve!(integrator, cache) + z₆ = nlsolve!(nlsolver, integrator) nlsolvefail(nlsolver) && return ################################## Solve Step 7 @@ -1297,7 +1300,7 @@ end @.. tmp = uprev + a71*z₁ + a72*z₂ + a73*z₃ + a74*z₄ + a75*z₅ + a76*z₆ nlsolver.c = c7 - z₇ = nlsolve!(integrator, cache) + z₇ = nlsolve!(nlsolver, integrator) nlsolvefail(nlsolver) && return ################################## Solve Step 8 @@ -1306,7 +1309,7 @@ end @.. nlsolver.tmp = uprev + a81*z₁ + a82*z₂ + a83*z₃ + a84*z₄ + a85*z₅ + a86*z₆ + a87*z₇ nlsolver.c = oneunit(nlsolver.c) - z₈ = nlsolve!(integrator, cache) + z₈ = nlsolve!(nlsolver, integrator) nlsolvefail(nlsolver) && return @.. u = tmp + γ*z₈ diff --git a/test/integrators/resize_tests.jl b/test/integrators/resize_tests.jl index 87730f28c7..0bc80db8ce 100644 --- a/test/integrators/resize_tests.jl +++ b/test/integrators/resize_tests.jl @@ -21,15 +21,16 @@ resize!(i, 5) @test length(i.cache.uprev) == 5 # nlsolver fields @test length(i.cache.nlsolver.z) == 5 -@test length(i.cache.nlsolver.dz) == 5 -@test length(i.cache.nlsolver.weight) == 5 -@test length(i.cache.nlsolver.ztmp) == 5 +@test length(i.cache.nlsolver.gz) == 5 @test length(i.cache.nlsolver.tmp) == 5 -@test length(i.cache.nlsolver.k) == 5 -@test length(i.cache.nlsolver.du1) == 5 +@test length(i.cache.nlsolver.cache.dz) == 5 +@test length(i.cache.nlsolver.cache.atmp) == 5 +@test length(i.cache.nlsolver.cache.k) == 5 +@test length(i.cache.nlsolver.cache.weight) == 5 +@test length(i.cache.nlsolver.cache.du1) == 5 # ForwardDiff -@test length(i.cache.nlsolver.jac_config.duals[1]) == 5 -@test length(i.cache.nlsolver.jac_config.duals[2]) == 5 +@test length(i.cache.nlsolver.cache.jac_config.duals[1]) == 5 +@test length(i.cache.nlsolver.cache.jac_config.duals[2]) == 5 @test size(i.cache.nlsolver.cache.W) == (5,5) @test size(i.cache.nlsolver.cache.J) == (5,5) solve!(i) @@ -40,16 +41,17 @@ resize!(i, 5) @test length(i.cache.uprev) == 5 # nlsolver fields @test length(i.cache.nlsolver.z) == 5 -@test length(i.cache.nlsolver.dz) == 5 -@test length(i.cache.nlsolver.weight) == 5 -@test length(i.cache.nlsolver.ztmp) == 5 +@test length(i.cache.nlsolver.gz) == 5 @test length(i.cache.nlsolver.tmp) == 5 -@test length(i.cache.nlsolver.k) == 5 -@test length(i.cache.nlsolver.du1) == 5 +@test length(i.cache.nlsolver.cache.dz) == 5 +@test length(i.cache.nlsolver.cache.atmp) == 5 +@test length(i.cache.nlsolver.cache.k) == 5 +@test length(i.cache.nlsolver.cache.weight) == 5 +@test length(i.cache.nlsolver.cache.du1) == 5 # DiffEqDiffTools -@test length(i.cache.nlsolver.jac_config.x1) == 5 -@test length(i.cache.nlsolver.jac_config.fx) == 5 -@test length(i.cache.nlsolver.jac_config.fx1) == 5 +@test length(i.cache.nlsolver.cache.jac_config.x1) == 5 +@test length(i.cache.nlsolver.cache.jac_config.fx) == 5 +@test length(i.cache.nlsolver.cache.jac_config.fx1) == 5 @test size(i.cache.nlsolver.cache.W) == (5,5) @test size(i.cache.nlsolver.cache.J) == (5,5) solve!(i) diff --git a/test/interface/mass_matrix_tests.jl b/test/interface/mass_matrix_tests.jl index 1ba5dff4ec..2d531d0cdf 100644 --- a/test/interface/mass_matrix_tests.jl +++ b/test/interface/mass_matrix_tests.jl @@ -75,7 +75,7 @@ using OrdinaryDiffEq, Test, LinearAlgebra, Statistics sol = solve(prob, ImplicitMidpoint(extrapolant = :constant, nlsolve=NLFunctional()),dt=1/10,reltol=1e-7,abstol=1e-10) sol2 = solve(prob2,ImplicitMidpoint(extrapolant = :constant, nlsolve=NLFunctional()),dt=1/10,reltol=1e-7,abstol=1e-10) - @test norm(sol .- sol2) ≈ 0 atol=1e-7 + @test norm(sol .- sol2) ≈ 0 atol=1.1e-7 sol = solve(prob,ImplicitEuler(nlsolve=NLAnderson()),dt=1/10,adaptive=false) sol2 = solve(prob2,ImplicitEuler(nlsolve=NLAnderson()),dt=1/10,adaptive=false)