Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 8 additions & 2 deletions src/caches/firk_caches.jl
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
mutable struct RadauIIA5ConstantCache{F,Tab,Tol,Dt,U} <: OrdinaryDiffEqConstantCache
mutable struct RadauIIA5ConstantCache{F,Tab,Tol,Dt,U,JType} <: OrdinaryDiffEqConstantCache
uf::F
tab::Tab
κ::Tol
Expand All @@ -10,6 +10,7 @@ mutable struct RadauIIA5ConstantCache{F,Tab,Tol,Dt,U} <: OrdinaryDiffEqConstantC
dtprev::Dt
W_dt::Dt
status::DiffEqBase.NLStatus
J::JType
end

function alg_cache(alg::RadauIIA5,u,rate_prototype,uEltypeNoUnits,uBottomEltypeNoUnits,
Expand All @@ -19,8 +20,13 @@ function alg_cache(alg::RadauIIA5,u,rate_prototype,uEltypeNoUnits,uBottomEltypeN
tab = RadauIIA5Tableau(uToltype, real(tTypeNoUnits))

κ = alg.κ !== nothing ? convert(uToltype, alg.κ) : convert(uToltype, 1//100)
if rate_prototype isa Number
J = 0
else
J = false .* vec(rate_prototype) .* vec(rate_prototype)'
end

RadauIIA5ConstantCache(uf, tab, κ, one(uToltype), 10000, u, u, u, dt, dt, Convergence)
RadauIIA5ConstantCache(uf, tab, κ, one(uToltype), 10000, u, u, u, dt, dt, Convergence, J)
end

mutable struct RadauIIA5Cache{uType,cuType,uNoUnitsType,rateType,JType,W1Type,W2Type,UF,JC,F1,F2,Tab,Tol,Dt,rTol,aTol} <: OrdinaryDiffEqMutableCache
Expand Down
168 changes: 64 additions & 104 deletions src/caches/rosenbrock_caches.jl
Original file line number Diff line number Diff line change
Expand Up @@ -64,19 +64,10 @@ function alg_cache(alg::Rosenbrock23,u,rate_prototype,uEltypeNoUnits,uBottomElty
fsalfirst = zero(rate_prototype)
fsallast = zero(rate_prototype)
dT = zero(rate_prototype)
if ArrayInterface.isstructured(f.jac_prototype) || f.jac_prototype isa SparseMatrixCSC
J = similar(f.jac_prototype)
W = similar(J)
elseif DiffEqBase.has_jac(f) && !DiffEqBase.has_Wfact(f) && f.jac_prototype !== nothing
W = WOperator(f, dt, true)
J = nothing # is J = W.J better?
else
J = false .* vec(rate_prototype) .* vec(rate_prototype)' # uEltype?
W = similar(J)
end
J,W = iip_generate_W(alg,u,uprev,p,t,dt,f,uEltypeNoUnits)
tmp = zero(rate_prototype)
atmp = similar(u, uEltypeNoUnits)
tab = Rosenbrock23ConstantCache(real(uBottomEltypeNoUnits),identity,identity)
tab = Rosenbrock23Tableau(real(uBottomEltypeNoUnits))
tf = DiffEqDiffTools.TimeGradientWrapper(f,uprev,p)
uf = DiffEqDiffTools.UJacobianWrapper(f,t,p)
linsolve_tmp = zero(rate_prototype)
Expand All @@ -101,19 +92,10 @@ function alg_cache(alg::Rosenbrock32,u,rate_prototype,uEltypeNoUnits,uBottomElty
fsalfirst = zero(rate_prototype)
fsallast = zero(rate_prototype)
dT = zero(rate_prototype)
if ArrayInterface.isstructured(f.jac_prototype) || f.jac_prototype isa SparseMatrixCSC
J = similar(f.jac_prototype)
W = similar(J)
elseif DiffEqBase.has_jac(f) && !DiffEqBase.has_Wfact(f) && f.jac_prototype !== nothing
W = WOperator(f, dt, true)
J = nothing # is J = W.J better?
else
J = false .* vec(rate_prototype) .* vec(rate_prototype)' # uEltype?
W = similar(J)
end
J,W = iip_generate_W(alg,u,uprev,p,t,dt,f,uEltypeNoUnits)
tmp = zero(rate_prototype)
atmp = similar(u, uEltypeNoUnits)
tab = Rosenbrock32ConstantCache(real(uBottomEltypeNoUnits),identity,identity)
tab = Rosenbrock32Tableau(real(uBottomEltypeNoUnits))

tf = DiffEqDiffTools.TimeGradientWrapper(f,uprev,p)
uf = DiffEqDiffTools.UJacobianWrapper(f,t,p)
Expand All @@ -124,52 +106,63 @@ function alg_cache(alg::Rosenbrock32,u,rate_prototype,uEltypeNoUnits,uBottomElty
Rosenbrock32Cache(u,uprev,k₁,k₂,k₃,du1,du2,f₁,fsalfirst,fsallast,dT,J,W,tmp,atmp,tab,tf,uf,linsolve_tmp,linsolve,jac_config,grad_config)
end

struct Rosenbrock23ConstantCache{T,TF,UF} <: OrdinaryDiffEqConstantCache
struct Rosenbrock23ConstantCache{T,TF,UF,JType,WType,F} <: OrdinaryDiffEqConstantCache
c₃₂::T
d::T
tf::TF
uf::UF
J::JType
W::WType
linsolve::F
end

function Rosenbrock23ConstantCache(T::Type,tf,uf)
c₃₂ = convert(T,6 + sqrt(2))
d = convert(T,1/(2+sqrt(2)))
Rosenbrock23ConstantCache(c₃₂,d,tf,uf)
function Rosenbrock23ConstantCache(T::Type,tf,uf,J,W,linsolve)
tab = Rosenbrock23Tableau(T)
Rosenbrock23ConstantCache(tab.c₃₂,tab.d,tf,uf,J,W,linsolve)
end

function alg_cache(alg::Rosenbrock23,u,rate_prototype,uEltypeNoUnits,uBottomEltypeNoUnits,tTypeNoUnits,uprev,uprev2,f,t,dt,reltol,p,calck,::Type{Val{false}})
tf = DiffEqDiffTools.TimeDerivativeWrapper(f,u,p)
uf = DiffEqDiffTools.UDerivativeWrapper(f,t,p)
Rosenbrock23ConstantCache(real(uBottomEltypeNoUnits),tf,uf)
J,W = oop_generate_W(alg,u,uprev,p,t,dt,f,uEltypeNoUnits)
linsolve = alg.linsolve(Val{:init},uf,u)
Rosenbrock23ConstantCache(real(uBottomEltypeNoUnits),tf,uf,J,W,linsolve)
end

struct Rosenbrock32ConstantCache{T,TF,UF} <: OrdinaryDiffEqConstantCache
struct Rosenbrock32ConstantCache{T,TF,UF,JType,WType,F} <: OrdinaryDiffEqConstantCache
c₃₂::T
d::T
tf::TF
uf::UF
J::JType
W::WType
linsolve::F
end

function Rosenbrock32ConstantCache(T::Type,tf,uf)
c₃₂ = convert(T,6 + sqrt(2))
d = convert(T,1/(2+sqrt(2)))
Rosenbrock32ConstantCache(c₃₂,d,tf,uf)
function Rosenbrock32ConstantCache(T::Type,tf,uf,J,W,linsolve)
tab = Rosenbrock32Tableau(T)
Rosenbrock32ConstantCache(tab.c₃₂,tab.d,tf,uf,J,W,linsolve)
end

function alg_cache(alg::Rosenbrock32,u,rate_prototype,uEltypeNoUnits,uBottomEltypeNoUnits,tTypeNoUnits,uprev,uprev2,f,t,dt,reltol,p,calck,::Type{Val{false}})
tf = DiffEqDiffTools.TimeDerivativeWrapper(f,u,p)
uf = DiffEqDiffTools.UDerivativeWrapper(f,t,p)
Rosenbrock32ConstantCache(real(uBottomEltypeNoUnits),tf,uf)
J,W = oop_generate_W(alg,u,uprev,p,t,dt,f,uEltypeNoUnits)
linsolve = alg.linsolve(Val{:init},uf,u)
Rosenbrock32ConstantCache(real(uBottomEltypeNoUnits),tf,uf,J,W,linsolve)
end

################################################################################

### 3rd order specialized Rosenbrocks

struct Rosenbrock33ConstantCache{TF,UF,Tab} <: OrdinaryDiffEqConstantCache
struct Rosenbrock33ConstantCache{TF,UF,Tab,JType,WType,F} <: OrdinaryDiffEqConstantCache
tf::TF
uf::UF
tab::Tab
J::JType
W::WType
linsolve::F
end

@cache mutable struct Rosenbrock33Cache{uType,rateType,uNoUnitsType,JType,WType,TabType,TFType,UFType,F,JCType,GCType} <: RosenbrockMutableCache
Expand Down Expand Up @@ -209,16 +202,7 @@ function alg_cache(alg::ROS3P,u,rate_prototype,uEltypeNoUnits,uBottomEltypeNoUni
fsalfirst = zero(rate_prototype)
fsallast = zero(rate_prototype)
dT = zero(rate_prototype)
if ArrayInterface.isstructured(f.jac_prototype) || f.jac_prototype isa SparseMatrixCSC
J = similar(f.jac_prototype)
W = similar(J)
elseif DiffEqBase.has_jac(f) && !DiffEqBase.has_Wfact(f) && f.jac_prototype !== nothing
W = WOperator(f, dt, true)
J = nothing # is J = W.J better?
else
J = false .* vec(rate_prototype) .* vec(rate_prototype)' # uEltype?
W = similar(J)
end
J,W = iip_generate_W(alg,u,uprev,p,t,dt,f,uEltypeNoUnits)
tmp = zero(rate_prototype)
atmp = similar(u, uEltypeNoUnits)
tab = ROS3PConstantCache(constvalue(uBottomEltypeNoUnits),constvalue(tTypeNoUnits))
Expand All @@ -236,7 +220,9 @@ end
function alg_cache(alg::ROS3P,u,rate_prototype,uEltypeNoUnits,uBottomEltypeNoUnits,tTypeNoUnits,uprev,uprev2,f,t,dt,reltol,p,calck,::Type{Val{false}})
tf = DiffEqDiffTools.TimeDerivativeWrapper(f,u,p)
uf = DiffEqDiffTools.UDerivativeWrapper(f,t,p)
Rosenbrock33ConstantCache(tf,uf,ROS3PConstantCache(constvalue(uBottomEltypeNoUnits),constvalue(tTypeNoUnits)))
J,W = oop_generate_W(alg,u,uprev,p,t,dt,f,uEltypeNoUnits)
linsolve = alg.linsolve(Val{:init},uf,u)
Rosenbrock33ConstantCache(tf,uf,ROS3PConstantCache(constvalue(uBottomEltypeNoUnits),constvalue(tTypeNoUnits)),J,W,linsolve)
end

@cache mutable struct Rosenbrock34Cache{uType,rateType,uNoUnitsType,JType,WType,TabType,TFType,UFType,F,JCType,GCType} <: RosenbrockMutableCache
Expand Down Expand Up @@ -276,16 +262,7 @@ function alg_cache(alg::Rodas3,u,rate_prototype,uEltypeNoUnits,uBottomEltypeNoUn
fsalfirst = zero(rate_prototype)
fsallast = zero(rate_prototype)
dT = zero(rate_prototype)
if ArrayInterface.isstructured(f.jac_prototype) || f.jac_prototype isa SparseMatrixCSC
J = similar(f.jac_prototype)
W = similar(J)
elseif DiffEqBase.has_jac(f) && !DiffEqBase.has_Wfact(f) && f.jac_prototype !== nothing
W = WOperator(f, dt, true)
J = nothing # is J = W.J better?
else
J = false .* vec(rate_prototype) .* vec(rate_prototype)' # uEltype?
W = similar(J)
end
J,W = iip_generate_W(alg,u,uprev,p,t,dt,f,uEltypeNoUnits)
tmp = zero(rate_prototype)
atmp = similar(u, uEltypeNoUnits)
tab = Rodas3ConstantCache(constvalue(uBottomEltypeNoUnits),constvalue(tTypeNoUnits))
Expand All @@ -301,16 +278,21 @@ function alg_cache(alg::Rodas3,u,rate_prototype,uEltypeNoUnits,uBottomEltypeNoUn
linsolve,jac_config,grad_config)
end

struct Rosenbrock34ConstantCache{TF,UF,Tab} <: OrdinaryDiffEqConstantCache
struct Rosenbrock34ConstantCache{TF,UF,Tab,JType,WType,F} <: OrdinaryDiffEqConstantCache
tf::TF
uf::UF
tab::Tab
J::JType
W::WType
linsolve::F
end

function alg_cache(alg::Rodas3,u,rate_prototype,uEltypeNoUnits,uBottomEltypeNoUnits,tTypeNoUnits,uprev,uprev2,f,t,dt,reltol,p,calck,::Type{Val{false}})
tf = DiffEqDiffTools.TimeDerivativeWrapper(f,u,p)
uf = DiffEqDiffTools.UDerivativeWrapper(f,t,p)
Rosenbrock34ConstantCache(tf,uf,Rodas3ConstantCache(constvalue(uBottomEltypeNoUnits),constvalue(tTypeNoUnits)))
J,W = oop_generate_W(alg,u,uprev,p,t,dt,f,uEltypeNoUnits)
linsolve = alg.linsolve(Val{:init},uf,u)
Rosenbrock34ConstantCache(tf,uf,Rodas3ConstantCache(constvalue(uBottomEltypeNoUnits),constvalue(tTypeNoUnits)),J,W,linsolve)
end

################################################################################
Expand All @@ -330,10 +312,13 @@ jac_cache(c::Rosenbrock4Cache) = (c.J,c.W)

### Rodas methods

struct Rodas4ConstantCache{TF,UF,Tab} <: OrdinaryDiffEqConstantCache
struct Rodas4ConstantCache{TF,UF,Tab,JType,WType,F} <: OrdinaryDiffEqConstantCache
tf::TF
uf::UF
tab::Tab
J::JType
W::WType
linsolve::F
end

@cache mutable struct Rodas4Cache{uType,rateType,uNoUnitsType,JType,WType,TabType,TFType,UFType,F,JCType,GCType} <: RosenbrockMutableCache
Expand Down Expand Up @@ -381,16 +366,7 @@ function alg_cache(alg::Rodas4,u,rate_prototype,uEltypeNoUnits,uBottomEltypeNoUn
fsalfirst = zero(rate_prototype)
fsallast = zero(rate_prototype)
dT = zero(rate_prototype)
if ArrayInterface.isstructured(f.jac_prototype) || f.jac_prototype isa SparseMatrixCSC
J = similar(f.jac_prototype)
W = similar(J)
elseif DiffEqBase.has_jac(f) && !DiffEqBase.has_Wfact(f) && f.jac_prototype !== nothing
W = WOperator(f, dt, true)
J = nothing # is J = W.J better?
else
J = false .* vec(rate_prototype) .* vec(rate_prototype)' # uEltype?
W = similar(J)
end
J,W = iip_generate_W(alg,u,uprev,p,t,dt,f,uEltypeNoUnits)
tmp = zero(rate_prototype)
atmp = similar(u, uEltypeNoUnits)
tab = Rodas4ConstantCache(constvalue(uBottomEltypeNoUnits),constvalue(tTypeNoUnits))
Expand All @@ -410,7 +386,9 @@ end
function alg_cache(alg::Rodas4,u,rate_prototype,uEltypeNoUnits,uBottomEltypeNoUnits,tTypeNoUnits,uprev,uprev2,f,t,dt,reltol,p,calck,::Type{Val{false}})
tf = DiffEqDiffTools.TimeDerivativeWrapper(f,u,p)
uf = DiffEqDiffTools.UDerivativeWrapper(f,t,p)
Rodas4ConstantCache(tf,uf,Rodas4ConstantCache(constvalue(uBottomEltypeNoUnits),constvalue(tTypeNoUnits)))
J,W = oop_generate_W(alg,u,uprev,p,t,dt,f,uEltypeNoUnits)
linsolve = alg.linsolve(Val{:init},uf,u)
Rodas4ConstantCache(tf,uf,Rodas4ConstantCache(constvalue(uBottomEltypeNoUnits),constvalue(tTypeNoUnits)),J,W,linsolve)
end

function alg_cache(alg::Rodas42,u,rate_prototype,uEltypeNoUnits,uBottomEltypeNoUnits,tTypeNoUnits,uprev,uprev2,f,t,dt,reltol,p,calck,::Type{Val{true}})
Expand All @@ -428,16 +406,7 @@ function alg_cache(alg::Rodas42,u,rate_prototype,uEltypeNoUnits,uBottomEltypeNoU
fsalfirst = zero(rate_prototype)
fsallast = zero(rate_prototype)
dT = zero(rate_prototype)
if ArrayInterface.isstructured(f.jac_prototype) || f.jac_prototype isa SparseMatrixCSC
J = similar(f.jac_prototype)
W = similar(J)
elseif DiffEqBase.has_jac(f) && !DiffEqBase.has_Wfact(f) && f.jac_prototype !== nothing
W = WOperator(f, dt, true)
J = nothing # is J = W.J better?
else
J = false .* vec(rate_prototype) .* vec(rate_prototype)' # uEltype?
W = similar(J)
end
J,W = iip_generate_W(alg,u,uprev,p,t,dt,f,uEltypeNoUnits)
tmp = zero(rate_prototype)
atmp = similar(u, uEltypeNoUnits)
tab = Rodas42ConstantCache(constvalue(uBottomEltypeNoUnits),constvalue(tTypeNoUnits))
Expand All @@ -457,7 +426,9 @@ end
function alg_cache(alg::Rodas42,u,rate_prototype,uEltypeNoUnits,uBottomEltypeNoUnits,tTypeNoUnits,uprev,uprev2,f,t,dt,reltol,p,calck,::Type{Val{false}})
tf = DiffEqDiffTools.TimeDerivativeWrapper(f,u,p)
uf = DiffEqDiffTools.UDerivativeWrapper(f,t,p)
Rodas4ConstantCache(tf,uf,Rodas42ConstantCache(constvalue(uBottomEltypeNoUnits),constvalue(tTypeNoUnits)))
J,W = oop_generate_W(alg,u,uprev,p,t,dt,f,uEltypeNoUnits)
linsolve = alg.linsolve(Val{:init},uf,u)
Rodas4ConstantCache(tf,uf,Rodas42ConstantCache(constvalue(uBottomEltypeNoUnits),constvalue(tTypeNoUnits)),J,W,linsolve)
end

function alg_cache(alg::Rodas4P,u,rate_prototype,uEltypeNoUnits,uBottomEltypeNoUnits,tTypeNoUnits,uprev,uprev2,f,t,dt,reltol,p,calck,::Type{Val{true}})
Expand All @@ -475,16 +446,7 @@ function alg_cache(alg::Rodas4P,u,rate_prototype,uEltypeNoUnits,uBottomEltypeNoU
fsalfirst = zero(rate_prototype)
fsallast = zero(rate_prototype)
dT = zero(rate_prototype)
if ArrayInterface.isstructured(f.jac_prototype) || f.jac_prototype isa SparseMatrixCSC
J = similar(f.jac_prototype)
W = similar(J)
elseif DiffEqBase.has_jac(f) && !DiffEqBase.has_Wfact(f) && f.jac_prototype !== nothing
W = WOperator(f, dt, true)
J = nothing # is J = W.J better?
else
J = false .* vec(rate_prototype) .* vec(rate_prototype)' # uEltype?
W = similar(J)
end
J,W = iip_generate_W(alg,u,uprev,p,t,dt,f,uEltypeNoUnits)
tmp = zero(rate_prototype)
atmp = similar(u, uEltypeNoUnits)
tab = Rodas4PConstantCache(constvalue(uBottomEltypeNoUnits),constvalue(tTypeNoUnits))
Expand All @@ -504,17 +466,22 @@ end
function alg_cache(alg::Rodas4P,u,rate_prototype,uEltypeNoUnits,uBottomEltypeNoUnits,tTypeNoUnits,uprev,uprev2,f,t,dt,reltol,p,calck,::Type{Val{false}})
tf = DiffEqDiffTools.TimeDerivativeWrapper(f,u,p)
uf = DiffEqDiffTools.UDerivativeWrapper(f,t,p)
Rodas4ConstantCache(tf,uf,Rodas4PConstantCache(constvalue(uBottomEltypeNoUnits),constvalue(tTypeNoUnits)))
J,W = oop_generate_W(alg,u,uprev,p,t,dt,f,uEltypeNoUnits)
linsolve = alg.linsolve(Val{:init},uf,u)
Rodas4ConstantCache(tf,uf,Rodas4PConstantCache(constvalue(uBottomEltypeNoUnits),constvalue(tTypeNoUnits)),J,W,linsolve)
end

################################################################################

### Rosenbrock5

struct Rosenbrock5ConstantCache{TF,UF,Tab} <: OrdinaryDiffEqConstantCache
struct Rosenbrock5ConstantCache{TF,UF,Tab,JType,WType,F} <: OrdinaryDiffEqConstantCache
tf::TF
uf::UF
tab::Tab
J::JType
W::WType
linsolve::F
end

@cache mutable struct Rosenbrock5Cache{uType,rateType,uNoUnitsType,JType,WType,TabType,TFType,UFType,F,JCType,GCType} <: RosenbrockMutableCache
Expand Down Expand Up @@ -566,16 +533,7 @@ function alg_cache(alg::Rodas5,u,rate_prototype,uEltypeNoUnits,uBottomEltypeNoUn
fsalfirst = zero(rate_prototype)
fsallast = zero(rate_prototype)
dT = zero(rate_prototype)
if ArrayInterface.isstructured(f.jac_prototype) || f.jac_prototype isa SparseMatrixCSC
J = similar(f.jac_prototype)
W = similar(J)
elseif DiffEqBase.has_jac(f) && !DiffEqBase.has_Wfact(f) && f.jac_prototype !== nothing
W = WOperator(f, dt, true)
J = nothing # is J = W.J better?
else
J = false .* vec(rate_prototype) .* vec(rate_prototype)' # uEltype?
W = similar(J)
end
J,W = iip_generate_W(alg,u,uprev,p,t,dt,f,uEltypeNoUnits)
tmp = zero(rate_prototype)
atmp = similar(u, uEltypeNoUnits)
tab = Rodas5ConstantCache(constvalue(uBottomEltypeNoUnits),constvalue(tTypeNoUnits))
Expand All @@ -595,7 +553,9 @@ end
function alg_cache(alg::Rodas5,u,rate_prototype,uEltypeNoUnits,uBottomEltypeNoUnits,tTypeNoUnits,uprev,uprev2,f,t,dt,reltol,p,calck,::Type{Val{false}})
tf = DiffEqDiffTools.TimeDerivativeWrapper(f,u,p)
uf = DiffEqDiffTools.UDerivativeWrapper(f,t,p)
Rosenbrock5ConstantCache(tf,uf,Rodas5ConstantCache(constvalue(uBottomEltypeNoUnits),constvalue(tTypeNoUnits)))
J,W = oop_generate_W(alg,u,uprev,p,t,dt,f,uEltypeNoUnits)
linsolve = alg.linsolve(Val{:init},uf,u)
Rosenbrock5ConstantCache(tf,uf,Rodas5ConstantCache(constvalue(uBottomEltypeNoUnits),constvalue(tTypeNoUnits)),J,W,linsolve)
end

################################################################################
Expand Down
Loading