Skip to content

Commit e46c08b

Browse files
authored
Merge pull request #876 from huanglangwen/simplify_dutils
[WIP] Simplify `derivative_utils.jl`
2 parents 35030ac + 58804f1 commit e46c08b

File tree

1 file changed

+36
-1
lines changed

1 file changed

+36
-1
lines changed

src/derivative_utils.jl

Lines changed: 36 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -51,6 +51,8 @@ function calc_J(integrator, cache::OrdinaryDiffEqConstantCache, is_compos)
5151
if DiffEqBase.has_jac(f)
5252
J = f.jac(uprev, p, t)
5353
else
54+
cache.uf.t = t
55+
cache.uf.p = p
5456
J = jacobian(cache.uf,uprev,integrator)
5557
end
5658
integrator.destats.njacs += 1
@@ -63,6 +65,8 @@ function calc_J(nlsolver, integrator, cache::OrdinaryDiffEqConstantCache, is_com
6365
if DiffEqBase.has_jac(f)
6466
J = f.jac(uprev, p, t)
6567
else
68+
nlsolver.uf.t = t
69+
nlsolver.uf.p = p
6670
J = jacobian(nlsolver.uf,uprev,integrator)
6771
end
6872
integrator.destats.njacs += 1
@@ -81,7 +85,7 @@ jacobian update function, then it will be called for the update. Otherwise,
8185
either ForwardDiff or finite difference will be used depending on the
8286
`jac_config` of the cache.
8387
"""
84-
function calc_J!(integrator, cache::OrdinaryDiffEqMutableCache, is_compos)
88+
function calc_J!(integrator, cache::OrdinaryDiffEqCache, is_compos)
8589
if isdefined(cache, :nlsolver)
8690
calc_J!(cache.nlsolver, integrator, cache, is_compos)
8791
elseif isdefined(cache, :J)
@@ -106,6 +110,10 @@ function calc_J!(nlsolver::NLSolver, integrator, cache::OrdinaryDiffEqMutableCac
106110
is_compos && (integrator.eigen_est = opnorm(J, Inf))
107111
end
108112

113+
function calc_J!(nlsolver::NLSolver, integrator, cache::OrdinaryDiffEqConstantCache, is_compos)
114+
nlsolver.cache.J = calc_J(nlsolver,integrator,cache,is_compos)
115+
end
116+
109117
function calc_J_in_cache!(integrator, cache::OrdinaryDiffEqMutableCache, is_compos)
110118
@unpack t,dt,uprev,u,f,p = integrator
111119
J = cache.J
@@ -121,6 +129,10 @@ function calc_J_in_cache!(integrator, cache::OrdinaryDiffEqMutableCache, is_comp
121129
is_compos && (integrator.eigen_est = opnorm(J, Inf))
122130
end
123131

132+
function calc_J_in_cache!(integrator, cache::OrdinaryDiffEqConstantCache, is_compos)
133+
cache.J = calc_J(integrator,cache,is_compos)
134+
end
135+
124136
"""
125137
WOperator(mass_matrix,gamma,J[;transform=false])
126138
@@ -312,6 +324,7 @@ end
312324

313325
@noinline _throwWJerror(W, J) = throw(DimensionMismatch("W: $(axes(W)), J: $(axes(J))"))
314326
@noinline _throwWMerror(W, mass_matrix) = throw(DimensionMismatch("W: $(axes(W)), mass matrix: $(axes(mass_matrix))"))
327+
@noinline _throwJMerror(J, mass_matrix) = throw(DimensionMismatch("J: $(axes(J)), mass matrix: $(axes(mass_matrix))"))
315328

316329
@inline function jacobian2W!(W::AbstractMatrix, mass_matrix::MT, dtgamma::Number, J::AbstractMatrix, W_transform::Bool)::Nothing where MT
317330
# check size and dimension
@@ -341,6 +354,28 @@ end
341354
return nothing
342355
end
343356

357+
@inline function jacobian2W(mass_matrix::MT, dtgamma::Number, J::AbstractMatrix, W_transform::Bool)::Nothing where MT
358+
# check size and dimension
359+
mass_matrix isa UniformScaling || @boundscheck axes(mass_matrix) === axes(J) || _throwJMerror(J, mass_matrix)
360+
@inbounds if W_transform
361+
invdtgamma = inv(dtgamma)
362+
if MT <: UniformScaling
363+
λ = -mass_matrix.λ
364+
W = J +* invdtgamma)*I
365+
else
366+
W = muladd(-mass_matrix, invdtgamma, J)
367+
end
368+
else
369+
if MT <: UniformScaling
370+
λ = -mass_matrix.λ
371+
W = dtgamma*J + λ*I
372+
else
373+
W = muladd(dtgamma, J, -mass_matrix)
374+
end
375+
end
376+
return W
377+
end
378+
344379
function calc_W!(integrator, cache::OrdinaryDiffEqMutableCache, dtgamma, repeat_step, W_transform=false)
345380
@unpack t,dt,uprev,u,f,p = integrator
346381
@unpack J,W = cache

0 commit comments

Comments
 (0)