Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
37 changes: 36 additions & 1 deletion src/derivative_utils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,8 @@ function calc_J(integrator, cache::OrdinaryDiffEqConstantCache, is_compos)
if DiffEqBase.has_jac(f)
J = f.jac(uprev, p, t)
else
cache.uf.t = t
cache.uf.p = p
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Great! 👍 Can you remove the updates of uf in the oop version of calc_W!?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm thinking of delete oop version of cache_W! directly.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

And only use update_W!? Currently I don't like that calc_W! is called calc_W! both for in-place and out-of-place functions but updates W only for the in-place functions and returns a new W for the out-of-place functions. So I thought maybe it could be done similar to calc_J and calc_J!, i.e., we could have an out-of-place variant calc_W and an updating variant calc_W!.

Moreover, I'm not happy with the fact that at the moment there's a lot of reuse strategies hidden in the implementation of calc_W!. Maybe it would be cleaner if calc_W! would just calculate W, without considering any reuse strategies, similar to what we do in calc_J. Ideally then we would have a separate function (such as update_W!) that first figures out if J should be updated (and then calls calc_J!) and then if W should be updated (and then calls calc_W!).

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Agree.

J = jacobian(cache.uf,uprev,integrator)
end
integrator.destats.njacs += 1
Expand All @@ -63,6 +65,8 @@ function calc_J(nlsolver, integrator, cache::OrdinaryDiffEqConstantCache, is_com
if DiffEqBase.has_jac(f)
J = f.jac(uprev, p, t)
else
nlsolver.uf.t = t
nlsolver.uf.p = p
J = jacobian(nlsolver.uf,uprev,integrator)
end
integrator.destats.njacs += 1
Expand All @@ -81,7 +85,7 @@ jacobian update function, then it will be called for the update. Otherwise,
either ForwardDiff or finite difference will be used depending on the
`jac_config` of the cache.
"""
function calc_J!(integrator, cache::OrdinaryDiffEqMutableCache, is_compos)
function calc_J!(integrator, cache::OrdinaryDiffEqCache, is_compos)
if isdefined(cache, :nlsolver)
calc_J!(cache.nlsolver, integrator, cache, is_compos)
elseif isdefined(cache, :J)
Expand All @@ -106,6 +110,10 @@ function calc_J!(nlsolver::NLSolver, integrator, cache::OrdinaryDiffEqMutableCac
is_compos && (integrator.eigen_est = opnorm(J, Inf))
end

function calc_J!(nlsolver::NLSolver, integrator, cache::OrdinaryDiffEqConstantCache, is_compos)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

There already exists an implementation for out-of-place function, it's called calc_J. So if we need an in-place variant calc_J! (which I'm not sure about) we could just define calc_J! as cache.J = calc_J(...)

nlsolver.cache.J = calc_J(nlsolver,integrator,cache,is_compos)
end

function calc_J_in_cache!(integrator, cache::OrdinaryDiffEqMutableCache, is_compos)
@unpack t,dt,uprev,u,f,p = integrator
J = cache.J
Expand All @@ -121,6 +129,10 @@ function calc_J_in_cache!(integrator, cache::OrdinaryDiffEqMutableCache, is_comp
is_compos && (integrator.eigen_est = opnorm(J, Inf))
end

function calc_J_in_cache!(integrator, cache::OrdinaryDiffEqConstantCache, is_compos)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

As I said, I'm not sure we want to do this.

cache.J = calc_J(integrator,cache,is_compos)
end

"""
WOperator(mass_matrix,gamma,J[;transform=false])

Expand Down Expand Up @@ -312,6 +324,7 @@ end

@noinline _throwWJerror(W, J) = throw(DimensionMismatch("W: $(axes(W)), J: $(axes(J))"))
@noinline _throwWMerror(W, mass_matrix) = throw(DimensionMismatch("W: $(axes(W)), mass matrix: $(axes(mass_matrix))"))
@noinline _throwJMerror(J, mass_matrix) = throw(DimensionMismatch("J: $(axes(J)), mass matrix: $(axes(mass_matrix))"))

@inline function jacobian2W!(W::AbstractMatrix, mass_matrix::MT, dtgamma::Number, J::AbstractMatrix, W_transform::Bool)::Nothing where MT
# check size and dimension
Expand Down Expand Up @@ -341,6 +354,28 @@ end
return nothing
end

@inline function jacobian2W(mass_matrix::MT, dtgamma::Number, J::AbstractMatrix, W_transform::Bool)::Nothing where MT
# check size and dimension
mass_matrix isa UniformScaling || @boundscheck axes(mass_matrix) === axes(J) || _throwJMerror(J, mass_matrix)
@inbounds if W_transform
invdtgamma = inv(dtgamma)
if MT <: UniformScaling
λ = -mass_matrix.λ
W = J + (λ * invdtgamma)*I
else
W = muladd(-mass_matrix, invdtgamma, J)
end
else
if MT <: UniformScaling
λ = -mass_matrix.λ
W = dtgamma*J + λ*I
else
W = muladd(dtgamma, J, -mass_matrix)
end
end
return W
end

function calc_W!(integrator, cache::OrdinaryDiffEqMutableCache, dtgamma, repeat_step, W_transform=false)
@unpack t,dt,uprev,u,f,p = integrator
@unpack J,W = cache
Expand Down