Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
31 changes: 18 additions & 13 deletions src/caches/extrapolation_caches.jl
Original file line number Diff line number Diff line change
Expand Up @@ -69,28 +69,23 @@ end

@cache mutable struct ImplicitEulerExtrapolationCache{uType,rateType,arrayType,dtType,JType,WType,F,JCType,GCType,uNoUnitsType,TFType,UFType} <: OrdinaryDiffEqMutableCache
uprev::uType
u_tmp::uType
u_tmps::Array{uType,1}
utilde::uType
tmp::uType
atmp::uNoUnitsType
k_tmp::rateType
k_tmps::Array{rateType,1}
dtpropose::dtType
T::arrayType
cur_order::Int
work::dtType
A::Int
step_no::Int


du1::rateType
du2::rateType
J::JType
W::WType
tf::TFType
uf::UFType
linsolve_tmp::rateType
linsolve_tmps::Array{rateType,1}
linsolve::Array{F,1}
jac_config::JCType
Expand Down Expand Up @@ -126,7 +121,8 @@ function alg_cache(alg::ImplicitEulerExtrapolation,u,rate_prototype,uEltypeNoUni
u_tmp = similar(u)
u_tmps = Array{typeof(u_tmp),1}(undef, Threads.nthreads())

for i=1:Threads.nthreads()
u_tmps[1] = u_tmp
for i=2:Threads.nthreads()
u_tmps[i] = zero(u_tmp)
end

Expand All @@ -135,7 +131,8 @@ function alg_cache(alg::ImplicitEulerExtrapolation,u,rate_prototype,uEltypeNoUni
k_tmp = zero(rate_prototype)
k_tmps = Array{typeof(k_tmp),1}(undef, Threads.nthreads())

for i=1:Threads.nthreads()
k_tmps[1] = k_tmp
for i=2:Threads.nthreads()
k_tmps[i] = zero(rate_prototype)
end

Expand Down Expand Up @@ -163,10 +160,17 @@ function alg_cache(alg::ImplicitEulerExtrapolation,u,rate_prototype,uEltypeNoUni
J = false .* vec(rate_prototype) .* vec(rate_prototype)' # uEltype?
W_el = similar(J)
end

W = Array{typeof(W_el),1}(undef, Threads.nthreads())
for i=1:Threads.nthreads()
W[i] = zero(W_el)
W[1] = W_el
for i=2:Threads.nthreads()
if W_el isa WOperator
W_el = WOperator(f, dt, true)
else
W[i] = zero(W_el)
end
end

tf = DiffEqDiffTools.TimeGradientWrapper(f,uprev,p)
uf = DiffEqDiffTools.UJacobianWrapper(f,t,p)
linsolve_tmp = zero(rate_prototype)
Expand All @@ -178,15 +182,16 @@ function alg_cache(alg::ImplicitEulerExtrapolation,u,rate_prototype,uEltypeNoUni

linsolve_el = alg.linsolve(Val{:init},uf,u)
linsolve = Array{typeof(linsolve_el),1}(undef, Threads.nthreads())
for i=1:Threads.nthreads()
linsolve[1] = linsolve_el
for i=2:Threads.nthreads()
linsolve[i] = alg.linsolve(Val{:init},uf,u)
end
grad_config = build_grad_config(alg,f,tf,du1,t)
jac_config = build_jac_config(alg,f,uf,du1,uprev,u,du1,du2)


ImplicitEulerExtrapolationCache(uprev,u_tmp,u_tmps,utilde,tmp,atmp,k_tmp,k_tmps,dtpropose,T,cur_order,work,A,step_no,
du1,du2,J,W,tf,uf,linsolve_tmp,linsolve_tmps,linsolve,jac_config,grad_config)
ImplicitEulerExtrapolationCache(uprev,u_tmps,utilde,tmp,atmp,k_tmps,dtpropose,T,cur_order,work,A,step_no,
du1,du2,J,W,tf,uf,linsolve_tmps,linsolve,jac_config,grad_config)
end


Expand Down Expand Up @@ -699,4 +704,4 @@ function alg_cache(alg::ImplicitHairerWannerExtrapolation,u,rate_prototype,uElty
ImplicitHairerWannerExtrapolationCache(utilde, u_temp1, u_temp2, u_temp3, u_temp4, tmp, T, res, fsalfirst, k, k_tmps,
cc.Q, cc.n_curr, cc.n_old, cc.coefficients, cc.stage_number, cc.sigma, du1, du2, J, W, tf, uf, linsolve_tmp,
linsolve, jac_config, grad_config)
end
end
6 changes: 3 additions & 3 deletions src/perform_step/extrapolation_perform_step.jl
Original file line number Diff line number Diff line change
Expand Up @@ -233,7 +233,7 @@ end
function initialize!(integrator,cache::ImplicitEulerExtrapolationCache)
integrator.kshortsize = 2

integrator.fsalfirst = zero(cache.k_tmp)
integrator.fsalfirst = zero(first(cache.k_tmps))
integrator.f(integrator.fsalfirst, integrator.u, integrator.p, integrator.t)
integrator.fsallast = zero(integrator.fsalfirst)
resize!(integrator.k, integrator.kshortsize)
Expand All @@ -247,8 +247,8 @@ end

function perform_step!(integrator,cache::ImplicitEulerExtrapolationCache,repeat_step=false)
@unpack t,dt,uprev,u,f,p = integrator
@unpack u_tmp,k_tmp,T,utilde,atmp,dtpropose,cur_order,A = cache
@unpack J,W,uf,tf,linsolve_tmp,jac_config = cache
@unpack T,utilde,atmp,dtpropose,cur_order,A = cache
@unpack J,W,uf,tf,jac_config = cache
@unpack u_tmps, k_tmps, linsolve_tmps = cache

max_order = min(size(T)[1],cur_order+1)
Expand Down