Skip to content

Commit ef3982d

Browse files
Merge pull request #875 from JuliaDiffEq/memory
memory optimize the ImplicitEulerExtrapolation cache
2 parents 65bda13 + 6c5f074 commit ef3982d

File tree

2 files changed

+21
-16
lines changed

2 files changed

+21
-16
lines changed

src/caches/extrapolation_caches.jl

Lines changed: 18 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -69,28 +69,23 @@ end
6969

7070
@cache mutable struct ImplicitEulerExtrapolationCache{uType,rateType,arrayType,dtType,JType,WType,F,JCType,GCType,uNoUnitsType,TFType,UFType} <: OrdinaryDiffEqMutableCache
7171
uprev::uType
72-
u_tmp::uType
7372
u_tmps::Array{uType,1}
7473
utilde::uType
7574
tmp::uType
7675
atmp::uNoUnitsType
77-
k_tmp::rateType
7876
k_tmps::Array{rateType,1}
7977
dtpropose::dtType
8078
T::arrayType
8179
cur_order::Int
8280
work::dtType
8381
A::Int
8482
step_no::Int
85-
86-
8783
du1::rateType
8884
du2::rateType
8985
J::JType
9086
W::WType
9187
tf::TFType
9288
uf::UFType
93-
linsolve_tmp::rateType
9489
linsolve_tmps::Array{rateType,1}
9590
linsolve::Array{F,1}
9691
jac_config::JCType
@@ -126,7 +121,8 @@ function alg_cache(alg::ImplicitEulerExtrapolation,u,rate_prototype,uEltypeNoUni
126121
u_tmp = similar(u)
127122
u_tmps = Array{typeof(u_tmp),1}(undef, Threads.nthreads())
128123

129-
for i=1:Threads.nthreads()
124+
u_tmps[1] = u_tmp
125+
for i=2:Threads.nthreads()
130126
u_tmps[i] = zero(u_tmp)
131127
end
132128

@@ -135,7 +131,8 @@ function alg_cache(alg::ImplicitEulerExtrapolation,u,rate_prototype,uEltypeNoUni
135131
k_tmp = zero(rate_prototype)
136132
k_tmps = Array{typeof(k_tmp),1}(undef, Threads.nthreads())
137133

138-
for i=1:Threads.nthreads()
134+
k_tmps[1] = k_tmp
135+
for i=2:Threads.nthreads()
139136
k_tmps[i] = zero(rate_prototype)
140137
end
141138

@@ -163,10 +160,17 @@ function alg_cache(alg::ImplicitEulerExtrapolation,u,rate_prototype,uEltypeNoUni
163160
J = false .* vec(rate_prototype) .* vec(rate_prototype)' # uEltype?
164161
W_el = similar(J)
165162
end
163+
166164
W = Array{typeof(W_el),1}(undef, Threads.nthreads())
167-
for i=1:Threads.nthreads()
168-
W[i] = zero(W_el)
165+
W[1] = W_el
166+
for i=2:Threads.nthreads()
167+
if W_el isa WOperator
168+
# Store a fresh WOperator in this thread's slot (W[1] already holds W_el);
# assigning to W_el here would leave W[i] unset.
W[i] = WOperator(f, dt, true)
169+
else
170+
W[i] = zero(W_el)
171+
end
169172
end
173+
170174
tf = DiffEqDiffTools.TimeGradientWrapper(f,uprev,p)
171175
uf = DiffEqDiffTools.UJacobianWrapper(f,t,p)
172176
linsolve_tmp = zero(rate_prototype)
@@ -178,15 +182,16 @@ function alg_cache(alg::ImplicitEulerExtrapolation,u,rate_prototype,uEltypeNoUni
178182

179183
linsolve_el = alg.linsolve(Val{:init},uf,u)
180184
linsolve = Array{typeof(linsolve_el),1}(undef, Threads.nthreads())
181-
for i=1:Threads.nthreads()
185+
linsolve[1] = linsolve_el
186+
for i=2:Threads.nthreads()
182187
linsolve[i] = alg.linsolve(Val{:init},uf,u)
183188
end
184189
grad_config = build_grad_config(alg,f,tf,du1,t)
185190
jac_config = build_jac_config(alg,f,uf,du1,uprev,u,du1,du2)
186191

187192

188-
ImplicitEulerExtrapolationCache(uprev,u_tmp,u_tmps,utilde,tmp,atmp,k_tmp,k_tmps,dtpropose,T,cur_order,work,A,step_no,
189-
du1,du2,J,W,tf,uf,linsolve_tmp,linsolve_tmps,linsolve,jac_config,grad_config)
193+
ImplicitEulerExtrapolationCache(uprev,u_tmps,utilde,tmp,atmp,k_tmps,dtpropose,T,cur_order,work,A,step_no,
194+
du1,du2,J,W,tf,uf,linsolve_tmps,linsolve,jac_config,grad_config)
190195
end
191196

192197

@@ -699,4 +704,4 @@ function alg_cache(alg::ImplicitHairerWannerExtrapolation,u,rate_prototype,uElty
699704
ImplicitHairerWannerExtrapolationCache(utilde, u_temp1, u_temp2, u_temp3, u_temp4, tmp, T, res, fsalfirst, k, k_tmps,
700705
cc.Q, cc.n_curr, cc.n_old, cc.coefficients, cc.stage_number, cc.sigma, du1, du2, J, W, tf, uf, linsolve_tmp,
701706
linsolve, jac_config, grad_config)
702-
end
707+
end

src/perform_step/extrapolation_perform_step.jl

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -233,7 +233,7 @@ end
233233
function initialize!(integrator,cache::ImplicitEulerExtrapolationCache)
234234
integrator.kshortsize = 2
235235

236-
integrator.fsalfirst = zero(cache.k_tmp)
236+
integrator.fsalfirst = zero(first(cache.k_tmps))
237237
integrator.f(integrator.fsalfirst, integrator.u, integrator.p, integrator.t)
238238
integrator.fsallast = zero(integrator.fsalfirst)
239239
resize!(integrator.k, integrator.kshortsize)
@@ -247,8 +247,8 @@ end
247247

248248
function perform_step!(integrator,cache::ImplicitEulerExtrapolationCache,repeat_step=false)
249249
@unpack t,dt,uprev,u,f,p = integrator
250-
@unpack u_tmp,k_tmp,T,utilde,atmp,dtpropose,cur_order,A = cache
251-
@unpack J,W,uf,tf,linsolve_tmp,jac_config = cache
250+
@unpack T,utilde,atmp,dtpropose,cur_order,A = cache
251+
@unpack J,W,uf,tf,jac_config = cache
252252
@unpack u_tmps, k_tmps, linsolve_tmps = cache
253253

254254
max_order = min(size(T)[1],cur_order+1)

0 commit comments

Comments
 (0)