diff --git a/src/algorithms.jl b/src/algorithms.jl index 5c96dcab1e..3d42dfdb1a 100644 --- a/src/algorithms.jl +++ b/src/algorithms.jl @@ -53,14 +53,15 @@ struct ImplicitEulerExtrapolation{CS,AD,F,F2} <: OrdinaryDiffEqImplicitExtrapola max_order::Int min_order::Int init_order::Int + threading::Bool end ImplicitEulerExtrapolation(;chunk_size=0,autodiff=true,diff_type=Val{:forward}, linsolve=DEFAULT_LINSOLVE, - max_order=10,min_order=1,init_order=5) = + max_order=10,min_order=1,init_order=5,threading=true) = ImplicitEulerExtrapolation{chunk_size,autodiff, - typeof(linsolve),typeof(diff_type)}(linsolve,max_order,min_order,init_order) + typeof(linsolve),typeof(diff_type)}(linsolve,max_order,min_order,init_order,threading) struct ExtrapolationMidpointDeuflhard <: OrdinaryDiffEqExtrapolationVarOrderVarStepAlgorithm n_min::Int # Minimal extrapolation order diff --git a/src/caches/extrapolation_caches.jl b/src/caches/extrapolation_caches.jl index 6df2487210..0c40fc17f5 100644 --- a/src/caches/extrapolation_caches.jl +++ b/src/caches/extrapolation_caches.jl @@ -70,10 +70,12 @@ end @cache mutable struct ImplicitEulerExtrapolationCache{uType,rateType,arrayType,dtType,JType,WType,F,JCType,GCType,uNoUnitsType,TFType,UFType} <: OrdinaryDiffEqMutableCache uprev::uType u_tmp::uType + u_tmps::Array{uType,1} utilde::uType tmp::uType atmp::uNoUnitsType k_tmp::rateType + k_tmps::Array{rateType,1} dtpropose::dtType T::arrayType cur_order::Int @@ -89,7 +91,8 @@ end tf::TFType uf::UFType linsolve_tmp::rateType - linsolve::F + linsolve_tmps::Array{rateType,1} + linsolve::Array{F,1} jac_config::JCType grad_config::GCType end @@ -121,10 +124,21 @@ end function alg_cache(alg::ImplicitEulerExtrapolation,u,rate_prototype,uEltypeNoUnits,uBottomEltypeNoUnits,tTypeNoUnits,uprev,uprev2,f,t,dt,reltol,p,calck,::Type{Val{true}}) u_tmp = similar(u) + u_tmps = Array{typeof(u_tmp),1}(undef, Threads.nthreads()) + + for i=1:Threads.nthreads() + u_tmps[i] = zero(u_tmp) + end + utilde = similar(u) tmp = similar(u) - k = zero(rate_prototype) k_tmp = zero(rate_prototype) + k_tmps = Array{typeof(k_tmp),1}(undef, Threads.nthreads()) + + for i=1:Threads.nthreads() + k_tmps[i] = zero(rate_prototype) + end + cur_order = max(alg.init_order, alg.min_order) dtpropose = zero(dt) T = Array{typeof(u),2}(undef, alg.max_order, alg.max_order) @@ -143,22 +157,36 @@ function alg_cache(alg::ImplicitEulerExtrapolation,u,rate_prototype,uEltypeNoUni du2 = zero(rate_prototype) if DiffEqBase.has_jac(f) && !DiffEqBase.has_Wfact(f) && f.jac_prototype !== nothing - W = WOperator(f, dt, true) + W_el = WOperator(f, dt, true) J = nothing # is J = W.J better? else J = false .* vec(rate_prototype) .* vec(rate_prototype)' # uEltype? - W = similar(J) + W_el = similar(J) + end + W = Array{typeof(W_el),1}(undef, Threads.nthreads()) + for i=1:Threads.nthreads() + W[i] = zero(W_el) end tf = DiffEqDiffTools.TimeGradientWrapper(f,uprev,p) uf = DiffEqDiffTools.UJacobianWrapper(f,t,p) linsolve_tmp = zero(rate_prototype) - linsolve = alg.linsolve(Val{:init},uf,u) + linsolve_tmps = Array{typeof(linsolve_tmp),1}(undef, Threads.nthreads()) + + for i=1:Threads.nthreads() + linsolve_tmps[i] = zero(rate_prototype) + end + + linsolve_el = alg.linsolve(Val{:init},uf,u) + linsolve = Array{typeof(linsolve_el),1}(undef, Threads.nthreads()) + for i=1:Threads.nthreads() + linsolve[i] = alg.linsolve(Val{:init},uf,u) + end grad_config = build_grad_config(alg,f,tf,du1,t) jac_config = build_jac_config(alg,f,uf,du1,uprev,u,du1,du2) - ImplicitEulerExtrapolationCache(uprev,u_tmp,utilde,tmp,atmp,k_tmp,dtpropose,T,cur_order,work,A,step_no, - du1,du2,J,W,tf,uf,linsolve_tmp,linsolve,jac_config,grad_config) + ImplicitEulerExtrapolationCache(uprev,u_tmp,u_tmps,utilde,tmp,atmp,k_tmp,k_tmps,dtpropose,T,cur_order,work,A,step_no, + du1,du2,J,W,tf,uf,linsolve_tmp,linsolve_tmps,linsolve,jac_config,grad_config) end diff --git a/src/derivative_utils.jl b/src/derivative_utils.jl index 70869a5f63..872be25218 100644 --- a/src/derivative_utils.jl +++ b/src/derivative_utils.jl @@ -392,6 +392,57 @@ function calc_W!(integrator, cache::OrdinaryDiffEqMutableCache, dtgamma, repeat_ return nothing end +function calc_W!(integrator, cache::OrdinaryDiffEqMutableCache, dtgamma, repeat_step, W_index::Int, W_transform=false) + @unpack t,dt,uprev,u,f,p = integrator + @unpack J,W = cache + alg = unwrap_alg(integrator, true) + mass_matrix = integrator.f.mass_matrix + is_compos = integrator.alg isa CompositeAlgorithm + isnewton = alg isa NewtonAlgorithm + + if W_transform && DiffEqBase.has_Wfact_t(f) + f.Wfact_t(W[W_index], u, p, dtgamma, t) + is_compos && (integrator.eigen_est = opnorm(LowerTriangular(W[W_index]), Inf) + inv(dtgamma)) # TODO: better estimate + return nothing + elseif !W_transform && DiffEqBase.has_Wfact(f) + f.Wfact(W[W_index], u, p, dtgamma, t) + if is_compos + opn = opnorm(LowerTriangular(W[W_index]), Inf) + integrator.eigen_est = (opn + one(opn)) / dtgamma # TODO: better estimate + end + return nothing + end + + # fast pass + # we only want to factorize the linear operator once + new_jac = true + new_W = true + if (f isa ODEFunction && islinear(f.f)) || (integrator.alg isa SplitAlgorithms && f isa SplitFunction && islinear(f.f1.f)) + new_jac = false + @goto J2W # Jump to W calculation directly, because we already have J + end + + # check if we need to update J or W + W_dt = isnewton ? cache.nlsolver.cache.W_dt : dt # TODO: RosW + new_jac = isnewton ? do_newJ(integrator, alg, cache, repeat_step) : true + new_W = isnewton ? do_newW(integrator, cache.nlsolver, new_jac, W_dt) : true + + # calculate W + if DiffEqBase.has_jac(f) && f.jac_prototype !== nothing && !ArrayInterface.isstructured(f.jac_prototype) + isnewton || DiffEqBase.update_coefficients!(W[W_index],uprev,p,t) # we will call `update_coefficients!` in NLNewton + @label J2W + W[W_index].transform = W_transform; set_gamma!(W[W_index], dtgamma) + else # concrete W using jacobian from `calc_J!` + new_jac && calc_J!(integrator, cache, is_compos) + new_W && jacobian2W!(W[W_index], mass_matrix, dtgamma, J, W_transform) + end + if isnewton + set_new_W!(cache.nlsolver, new_W) && DiffEqBase.set_W_dt!(cache.nlsolver, dt) + end + new_W && (integrator.destats.nw += 1) + return nothing +end + function calc_W!(nlsolver, integrator, cache::OrdinaryDiffEqMutableCache, dtgamma, repeat_step, W_transform=false) @unpack t,dt,uprev,u,f,p = integrator @unpack J,W = nlsolver.cache diff --git a/src/perform_step/extrapolation_perform_step.jl b/src/perform_step/extrapolation_perform_step.jl index 347bc139c7..ed6e2015b4 100644 --- a/src/perform_step/extrapolation_perform_step.jl +++ b/src/perform_step/extrapolation_perform_step.jl @@ -249,27 +249,39 @@ function perform_step!(integrator,cache::ImplicitEulerExtrapolationCache,repeat_ @unpack t,dt,uprev,u,f,p = integrator @unpack u_tmp,k_tmp,T,utilde,atmp,dtpropose,cur_order,A = cache @unpack J,W,uf,tf,linsolve_tmp,jac_config = cache + @unpack u_tmps, k_tmps, linsolve_tmps = cache max_order = min(size(T)[1],cur_order+1) - for i in 1:max_order - dt_temp = dt/(2^(i-1)) # Romberg sequence - calc_W!(integrator, cache, dt_temp, repeat_step) - k_tmp = copy(integrator.fsalfirst) - u_tmp = copy(uprev) - for j in 1:2^(i-1) - linsolve_tmp = dt_temp*k_tmp - cache.linsolve(vec(k_tmp), W, vec(linsolve_tmp), !repeat_step) - @.. k_tmp = -k_tmp - @.. u_tmp = u_tmp + k_tmp - f(k_tmp, u_tmp,p,t+j*dt_temp) - end + let max_order=max_order, uprev=uprev, dt=dt, p=p, t=t, T=T, W=W, + integrator=integrator, cache=cache, repeat_step = repeat_step, + k_tmps=k_tmps, u_tmps=u_tmps + Threads.@threads for i in 1:2 + startIndex = (i == 1) ? 1 : max_order + endIndex = (i == 1) ? max_order - 1 : max_order + for index in startIndex:endIndex + dt_temp = dt/(2^(index-1)) # Romberg sequence + calc_W!(integrator, cache, dt_temp, repeat_step, Threads.threadid()) + k_tmps[Threads.threadid()] = copy(integrator.fsalfirst) + u_tmps[Threads.threadid()] = copy(uprev) + for j in 1:2^(index-1) + @.. linsolve_tmps[Threads.threadid()] = dt_temp*k_tmps[Threads.threadid()] + cache.linsolve[Threads.threadid()](vec(k_tmps[Threads.threadid()]), W[Threads.threadid()], vec(linsolve_tmps[Threads.threadid()]), !repeat_step) + @.. k_tmps[Threads.threadid()] = -k_tmps[Threads.threadid()] + @.. u_tmps[Threads.threadid()] = u_tmps[Threads.threadid()] + k_tmps[Threads.threadid()] + f(k_tmps[Threads.threadid()], u_tmps[Threads.threadid()],p,t+j*dt_temp) + end - @.. T[i,1] = u_tmp - for j in 2:i - @.. T[i,j] = ((2^(j-1))*T[i,j-1] - T[i-1,j-1])/((2^(j-1)) - 1) + @.. T[index,1] = u_tmps[Threads.threadid()] + end + end + for i in 2:max_order + for j in 2:i + @.. T[i,j] = ((2^(j-1))*T[i,j-1] - T[i-1,j-1])/((2^(j-1)) - 1) + end end end + integrator.dt = dt if integrator.opts.adaptive @@ -332,23 +344,33 @@ function perform_step!(integrator,cache::ImplicitEulerExtrapolationConstantCache max_order = min(size(T)[1], cur_order+1) - for i in 1:max_order - dt_temp = dt/(2^(i-1)) # Romberg sequence - W = calc_W!(integrator, cache, dt_temp, repeat_step) - k_copy = integrator.fsalfirst - u_tmp = uprev - for j in 1:2^(i-1) - k = _reshape(W\-_vec(dt_temp*k_copy), axes(uprev)) - integrator.destats.nsolve += 1 - u_tmp = u_tmp + k - k_copy = f(u_tmp, p, t+j*dt_temp) + let max_order=max_order, dt=dt, integrator=integrator, cache=cache, repeat_step=repeat_step, + uprev=uprev, T=T + Threads.@threads for i in 1:2 + startIndex = (i==1) ? 1 : max_order + endIndex = (i==1) ? max_order-1 : max_order + for index in startIndex:endIndex + dt_temp = dt/(2^(index-1)) # Romberg sequence + W = calc_W!(integrator, cache, dt_temp, repeat_step) + k_copy = integrator.fsalfirst + u_tmp = uprev + for j in 1:2^(index-1) + k = _reshape(W\-_vec(dt_temp*k_copy), axes(uprev)) + integrator.destats.nsolve += 1 + u_tmp = u_tmp + k + k_copy = f(u_tmp, p, t+j*dt_temp) + end + T[index,1] = u_tmp + end end - T[i,1] = u_tmp - # Richardson Extrapolation - for j in 2:i - T[i,j] = ((2^(j-1))*T[i,j-1] - T[i-1,j-1])/((2^(j-1)) - 1) + + for i=2:max_order + for j=2:i + T[i,j] = ((2^(j-1))*T[i,j-1] - T[i-1,j-1])/((2^(j-1)) - 1) + end end end + integrator.destats.nf += 2^(max_order) - 1 integrator.dt = dt @@ -391,9 +413,9 @@ function perform_step!(integrator,cache::ImplicitEulerExtrapolationConstantCache # Use extrapolated value of u integrator.u = T[cache.cur_order, cache.cur_order] - k = f(integrator.u, p, t+dt) + k_temp = f(integrator.u, p, t+dt) integrator.destats.nf += 1 - integrator.fsallast = k + integrator.fsallast = k_temp integrator.k[1] = integrator.fsalfirst integrator.k[2] = integrator.fsallast end