diff --git a/src/caches/extrapolation_caches.jl b/src/caches/extrapolation_caches.jl index 0c40fc17f5..3b4f6b94cd 100644 --- a/src/caches/extrapolation_caches.jl +++ b/src/caches/extrapolation_caches.jl @@ -69,12 +69,10 @@ end @cache mutable struct ImplicitEulerExtrapolationCache{uType,rateType,arrayType,dtType,JType,WType,F,JCType,GCType,uNoUnitsType,TFType,UFType} <: OrdinaryDiffEqMutableCache uprev::uType - u_tmp::uType u_tmps::Array{uType,1} utilde::uType tmp::uType atmp::uNoUnitsType - k_tmp::rateType k_tmps::Array{rateType,1} dtpropose::dtType T::arrayType @@ -82,15 +80,12 @@ end work::dtType A::Int step_no::Int - - du1::rateType du2::rateType J::JType W::WType tf::TFType uf::UFType - linsolve_tmp::rateType linsolve_tmps::Array{rateType,1} linsolve::Array{F,1} jac_config::JCType @@ -126,7 +121,8 @@ function alg_cache(alg::ImplicitEulerExtrapolation,u,rate_prototype,uEltypeNoUni u_tmp = similar(u) u_tmps = Array{typeof(u_tmp),1}(undef, Threads.nthreads()) - for i=1:Threads.nthreads() + u_tmps[1] = u_tmp + for i=2:Threads.nthreads() u_tmps[i] = zero(u_tmp) end @@ -135,7 +131,8 @@ function alg_cache(alg::ImplicitEulerExtrapolation,u,rate_prototype,uEltypeNoUni k_tmp = zero(rate_prototype) k_tmps = Array{typeof(k_tmp),1}(undef, Threads.nthreads()) - for i=1:Threads.nthreads() + k_tmps[1] = k_tmp + for i=2:Threads.nthreads() k_tmps[i] = zero(rate_prototype) end @@ -163,10 +160,17 @@ function alg_cache(alg::ImplicitEulerExtrapolation,u,rate_prototype,uEltypeNoUni J = false .* vec(rate_prototype) .* vec(rate_prototype)' # uEltype? W_el = similar(J) end + W = Array{typeof(W_el),1}(undef, Threads.nthreads()) - for i=1:Threads.nthreads() - W[i] = zero(W_el) + W[1] = W_el + for i=2:Threads.nthreads() + if W_el isa WOperator + W_el = WOperator(f, dt, true) + else + W[i] = zero(W_el) + end end + tf = DiffEqDiffTools.TimeGradientWrapper(f,uprev,p) uf = DiffEqDiffTools.UJacobianWrapper(f,t,p) linsolve_tmp = zero(rate_prototype) @@ -178,15 +182,16 @@ function alg_cache(alg::ImplicitEulerExtrapolation,u,rate_prototype,uEltypeNoUni linsolve_el = alg.linsolve(Val{:init},uf,u) linsolve = Array{typeof(linsolve_el),1}(undef, Threads.nthreads()) - for i=1:Threads.nthreads() + linsolve[1] = linsolve_el + for i=2:Threads.nthreads() linsolve[i] = alg.linsolve(Val{:init},uf,u) end grad_config = build_grad_config(alg,f,tf,du1,t) jac_config = build_jac_config(alg,f,uf,du1,uprev,u,du1,du2) - ImplicitEulerExtrapolationCache(uprev,u_tmp,u_tmps,utilde,tmp,atmp,k_tmp,k_tmps,dtpropose,T,cur_order,work,A,step_no, - du1,du2,J,W,tf,uf,linsolve_tmp,linsolve_tmps,linsolve,jac_config,grad_config) + ImplicitEulerExtrapolationCache(uprev,u_tmps,utilde,tmp,atmp,k_tmps,dtpropose,T,cur_order,work,A,step_no, + du1,du2,J,W,tf,uf,linsolve_tmps,linsolve,jac_config,grad_config) end @@ -699,4 +704,4 @@ function alg_cache(alg::ImplicitHairerWannerExtrapolation,u,rate_prototype,uElty ImplicitHairerWannerExtrapolationCache(utilde, u_temp1, u_temp2, u_temp3, u_temp4, tmp, T, res, fsalfirst, k, k_tmps, cc.Q, cc.n_curr, cc.n_old, cc.coefficients, cc.stage_number, cc.sigma, du1, du2, J, W, tf, uf, linsolve_tmp, linsolve, jac_config, grad_config) -end \ No newline at end of file +end diff --git a/src/perform_step/extrapolation_perform_step.jl b/src/perform_step/extrapolation_perform_step.jl index ed6e2015b4..b85e4df18b 100644 --- a/src/perform_step/extrapolation_perform_step.jl +++ b/src/perform_step/extrapolation_perform_step.jl @@ -233,7 +233,7 @@ end function initialize!(integrator,cache::ImplicitEulerExtrapolationCache) integrator.kshortsize = 2 - integrator.fsalfirst = zero(cache.k_tmp) + integrator.fsalfirst = zero(first(cache.k_tmps)) integrator.f(integrator.fsalfirst, integrator.u, integrator.p, integrator.t) integrator.fsallast = zero(integrator.fsalfirst) resize!(integrator.k, integrator.kshortsize) @@ -247,8 +247,8 @@ end function perform_step!(integrator,cache::ImplicitEulerExtrapolationCache,repeat_step=false) @unpack t,dt,uprev,u,f,p = integrator - @unpack u_tmp,k_tmp,T,utilde,atmp,dtpropose,cur_order,A = cache - @unpack J,W,uf,tf,linsolve_tmp,jac_config = cache + @unpack T,utilde,atmp,dtpropose,cur_order,A = cache + @unpack J,W,uf,tf,jac_config = cache @unpack u_tmps, k_tmps, linsolve_tmps = cache max_order = min(size(T)[1],cur_order+1)