-
-
Notifications
You must be signed in to change notification settings - Fork 235
[WIP] Simplify derivative_utils.jl
#876
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -51,6 +51,8 @@ function calc_J(integrator, cache::OrdinaryDiffEqConstantCache, is_compos) | |
| if DiffEqBase.has_jac(f) | ||
| J = f.jac(uprev, p, t) | ||
| else | ||
| cache.uf.t = t | ||
| cache.uf.p = p | ||
| J = jacobian(cache.uf,uprev,integrator) | ||
| end | ||
| integrator.destats.njacs += 1 | ||
|
|
@@ -63,6 +65,8 @@ function calc_J(nlsolver, integrator, cache::OrdinaryDiffEqConstantCache, is_com | |
| if DiffEqBase.has_jac(f) | ||
| J = f.jac(uprev, p, t) | ||
| else | ||
| nlsolver.uf.t = t | ||
| nlsolver.uf.p = p | ||
| J = jacobian(nlsolver.uf,uprev,integrator) | ||
| end | ||
| integrator.destats.njacs += 1 | ||
|
|
@@ -81,7 +85,7 @@ jacobian update function, then it will be called for the update. Otherwise, | |
| either ForwardDiff or finite difference will be used depending on the | ||
| `jac_config` of the cache. | ||
| """ | ||
| function calc_J!(integrator, cache::OrdinaryDiffEqMutableCache, is_compos) | ||
| function calc_J!(integrator, cache::OrdinaryDiffEqCache, is_compos) | ||
| if isdefined(cache, :nlsolver) | ||
| calc_J!(cache.nlsolver, integrator, cache, is_compos) | ||
| elseif isdefined(cache, :J) | ||
|
|
@@ -106,6 +110,10 @@ function calc_J!(nlsolver::NLSolver, integrator, cache::OrdinaryDiffEqMutableCac | |
| is_compos && (integrator.eigen_est = opnorm(J, Inf)) | ||
| end | ||
|
|
||
| function calc_J!(nlsolver::NLSolver, integrator, cache::OrdinaryDiffEqConstantCache, is_compos) | ||
|
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. There already exists an implementation for out-of-place function, it's called |
||
| nlsolver.cache.J = calc_J(nlsolver,integrator,cache,is_compos) | ||
| end | ||
|
|
||
| function calc_J_in_cache!(integrator, cache::OrdinaryDiffEqMutableCache, is_compos) | ||
| @unpack t,dt,uprev,u,f,p = integrator | ||
| J = cache.J | ||
|
|
@@ -121,6 +129,10 @@ function calc_J_in_cache!(integrator, cache::OrdinaryDiffEqMutableCache, is_comp | |
| is_compos && (integrator.eigen_est = opnorm(J, Inf)) | ||
| end | ||
|
|
||
| function calc_J_in_cache!(integrator, cache::OrdinaryDiffEqConstantCache, is_compos) | ||
|
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. As I said, I'm not sure we want to do this. |
||
| cache.J = calc_J(integrator,cache,is_compos) | ||
| end | ||
|
|
||
| """ | ||
| WOperator(mass_matrix,gamma,J[;transform=false]) | ||
|
|
||
|
|
@@ -312,6 +324,7 @@ end | |
|
|
||
| @noinline _throwWJerror(W, J) = throw(DimensionMismatch("W: $(axes(W)), J: $(axes(J))")) | ||
| @noinline _throwWMerror(W, mass_matrix) = throw(DimensionMismatch("W: $(axes(W)), mass matrix: $(axes(mass_matrix))")) | ||
| @noinline _throwJMerror(J, mass_matrix) = throw(DimensionMismatch("J: $(axes(J)), mass matrix: $(axes(mass_matrix))")) | ||
|
|
||
| @inline function jacobian2W!(W::AbstractMatrix, mass_matrix::MT, dtgamma::Number, J::AbstractMatrix, W_transform::Bool)::Nothing where MT | ||
| # check size and dimension | ||
|
|
@@ -341,6 +354,28 @@ end | |
| return nothing | ||
| end | ||
|
|
||
| @inline function jacobian2W(mass_matrix::MT, dtgamma::Number, J::AbstractMatrix, W_transform::Bool)::Nothing where MT | ||
| # check size and dimension | ||
| mass_matrix isa UniformScaling || @boundscheck axes(mass_matrix) === axes(J) || _throwJMerror(J, mass_matrix) | ||
| @inbounds if W_transform | ||
| invdtgamma = inv(dtgamma) | ||
| if MT <: UniformScaling | ||
| λ = -mass_matrix.λ | ||
| W = J + (λ * invdtgamma)*I | ||
| else | ||
| W = muladd(-mass_matrix, invdtgamma, J) | ||
| end | ||
| else | ||
| if MT <: UniformScaling | ||
| λ = -mass_matrix.λ | ||
| W = dtgamma*J + λ*I | ||
| else | ||
| W = muladd(dtgamma, J, -mass_matrix) | ||
| end | ||
| end | ||
| return W | ||
| end | ||
|
|
||
| function calc_W!(integrator, cache::OrdinaryDiffEqMutableCache, dtgamma, repeat_step, W_transform=false) | ||
| @unpack t,dt,uprev,u,f,p = integrator | ||
| @unpack J,W = cache | ||
|
|
||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Great! 👍 Can you remove the updates of
ufin the oop version ofcalc_W!?There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I'm thinking of delete oop version of
cache_W!directly.There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
And only use
update_W!? Currently I don't like thatcalc_W!is calledcalc_W!both for in-place and out-of-place functions but updatesWonly for the in-place functions and returns a newWfor the out-of-place functions. So I thought maybe it could be done similar tocalc_Jandcalc_J!, i.e., we could have an out-of-place variantcalc_Wand an updating variantcalc_W!.Moreover, I'm not happy with the fact that at the moment there's a lot of reuse strategies hidden in the implementation of
calc_W!. Maybe it would be cleaner ifcalc_W!would just calculateW, without considering any reuse strategies, similar to what we do incalc_J. Ideally then we would have a separate function (such asupdate_W!) that first figures out ifJshould be updated (and then callscalc_J!) and then ifWshould be updated (and then callscalc_W!).There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Agree.