Skip to content
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 12 additions & 6 deletions src/derivative_utils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -531,21 +531,27 @@ function build_J_W(alg,u,uprev,p,t,dt,f,uEltypeNoUnits,::Val{IIP}) where IIP
end
W = WOperator(f.mass_matrix, dt, J, IIP)
else
J = false .* _vec(u) .* _vec(u)'
J = if f.jac_prototype === nothing
false .* _vec(u) .* _vec(u)'
else
deepcopy(f.jac_prototype)
end
isdae = alg isa DAEAlgorithm
W = if isdae
J
elseif IIP
similar(J)
else
W = if u isa StaticArray
lu(J)
if u isa StaticArray
lu(J)
elseif u isa Number
u
else
elseif f.jac_prototype===nothing
LU{LinearAlgebra.lutype(uEltypeNoUnits)}(Matrix{uEltypeNoUnits}(undef, 0, 0),
Vector{LinearAlgebra.BlasInt}(undef, 0),
zero(LinearAlgebra.BlasInt))
Vector{LinearAlgebra.BlasInt}(undef, 0),
zero(LinearAlgebra.BlasInt))
else
ArrayInterface.lu_instance(f.jac_prototype)
end
end # end W
end
Expand Down
35 changes: 9 additions & 26 deletions src/derivative_wrappers.jl
Original file line number Diff line number Diff line change
Expand Up @@ -38,15 +38,9 @@ end

jacobian_autodiff(f, x, odefun) = (ForwardDiff.derivative(f,x),1)
function jacobian_autodiff(f, x::AbstractArray, odefun)
if DiffEqBase.has_colorvec(odefun)
colorvec = odefun.colorvec
sparsity = odefun.jac_prototype
jac_prototype = nothing
else
colorvec = 1:length(x)
sparsity = nothing
jac_prototype = odefun.jac_prototype
end
colorvec = DiffEqBase.has_colorvec(odefun) ? odefun.colorvec : 1:length(x)
sparsity = odefun.jac_prototype
jac_prototype = odefun.jac_prototype
maxcolor = maximum(colorvec)
chunksize = getsize(default_chunk_size(maxcolor))
num_of_chunks = Int(ceil(maxcolor / chunksize))
Expand All @@ -71,22 +65,15 @@ jacobian_finitediff(f, x, diff_type, dir, colorvec, sparsity, jac_prototype) =
jacobian_finitediff(f, x::AbstractArray, diff_type, dir, colorvec, sparsity, jac_prototype) =
(FiniteDiff.finite_difference_jacobian(f, x, diff_type, eltype(x), diff_type==Val{:forward} ? f(x) : similar(x),
dir = dir, colorvec = colorvec, sparsity = sparsity, jac_prototype = jac_prototype),_nfcount(maximum(colorvec),diff_type))

function jacobian(f, x, integrator)
alg = unwrap_alg(integrator, true)
local tmp
if alg_autodiff(alg)
J, tmp = jacobian_autodiff(f, x, integrator.f)
else
if DiffEqBase.has_colorvec(integrator.f)
colorvec = integrator.f.colorvec
sparsity = integrator.f.jac_prototype
jac_prototype = nothing
else
colorvec = 1:length(x)
sparsity = nothing
jac_prototype = integrator.f.jac_prototype
end
colorvec = DiffEqBase.has_colorvec(integrator.f) ? integrator.f.colorvec : 1:length(x)
sparsity = integrator.f.jac_prototype
jac_prototype = integrator.f.jac_prototype
dir = diffdir(integrator)
J, tmp = jacobian_finitediff(f, x, alg.diff_type, dir, colorvec, sparsity, jac_prototype)
end
Expand Down Expand Up @@ -125,13 +112,9 @@ end
function DiffEqBase.build_jac_config(alg::Union{OrdinaryDiffEqAlgorithm,DAEAlgorithm},f,uf,du1,uprev,u,tmp,du2,::Val{transform}=Val(true)) where transform
if !DiffEqBase.has_jac(f) && ((!transform && !DiffEqBase.has_Wfact(f)) || (transform && !DiffEqBase.has_Wfact_t(f)))
if alg_autodiff(alg)
if DiffEqBase.has_colorvec(f)
colorvec = f.colorvec
sparsity = f.jac_prototype
else
colorvec = 1:length(uprev)
sparsity = nothing
end
colorvec = DiffEqBase.has_colorvec(f) ? f.colorvec : 1:length(uprev)
sparsity = f.jac_prototype
jac_prototype = f.jac_prototype
jac_config = ForwardColorJacCache(uf,uprev,colorvec=colorvec,sparsity=sparsity)
else
if alg.diff_type != Val{:complex}
Expand Down
97 changes: 71 additions & 26 deletions test/interface/sparsediff_tests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -3,15 +3,28 @@ using OrdinaryDiffEq
using SparseArrays
using LinearAlgebra

## in-place
#https://github.com/JuliaDiffEq/SparseDiffTools.jl/blob/master/test/test_integration.jl
function f(dx,x,p,t)
for i in 2:length(x)-1
dx[i] = x[i-1] - 2x[i] + x[i+1]
end
dx[1] = -2x[1] + x[2]
dx[end] = x[end-1] - 2x[end]
nothing
function f_ip(dx,x,p,t)
for i in 2:length(x)-1
dx[i] = x[i-1] - 2x[i] + x[i+1]
end
dx[1] = -2x[1] + x[2]
dx[end] = x[end-1] - 2x[end]
nothing
end

## out-of-place
function f_oop(x,p,t)
dx = similar(x)
for i in 2:length(x)-1
dx[i] = x[i-1] - 2x[i] + x[i+1]
end
dx[1] = -2x[1] + x[2]
dx[end] = x[end-1] - 2x[end]
return dx
end


function second_derivative_stencil(N)
A = zeros(N,N)
Expand All @@ -34,22 +47,54 @@ jac_sp = sparse(generate_sparsity_pattern(10))
colors = repeat(1:3,10)[1:10]
u0=[1.,2.,3,4,5,5,4,3,2,1]
tspan=(0.,10.)
odefun_sp= ODEFunction(f,colorvec=colors,jac_prototype=jac_sp)
prob_sp = ODEProblem(odefun_sp,u0,tspan)
prob_std = ODEProblem(f,u0,tspan)

sol_sp=solve(prob_sp,Rodas5(autodiff=false),abstol=1e-10,reltol=1e-10)
@test sol_sp.retcode==:Success#test sparse finitediff
sol=solve(prob_std,Rodas5(autodiff=false),abstol=1e-10,reltol=1e-10)
@test sol_sp.u[end]≈sol.u[end] atol=1e-10
@test length(sol_sp.t)==length(sol.t)

sol_sp=solve(prob_sp,Rodas5(autodiff=false))
sol=solve(prob_std,Rodas5(autodiff=false))
@test sol_sp.u[end]≈sol.u[end]
@test length(sol_sp.t)==length(sol.t)

sol_sp=solve(prob_sp,Rodas5())
sol=solve(prob_std,Rodas5())
@test sol_sp.u[end]≈sol.u[end]
@test length(sol_sp.t)==length(sol.t)

for f in [f_oop, f_ip]
odefun_std = ODEFunction(f)
prob_std = ODEProblem(odefun_std,u0,tspan)

for ad in [true, false]
for Solver in [Rodas5, Trapezoid, KenCarp4]
for tol in [nothing, 1e-10]
# @show f,ad,Solver,tol
sol_std=solve(prob_std,Solver(autodiff=ad),reltol=tol,abstol=tol)
@test sol_std.retcode==:Success
for (i,prob) in enumerate(map(f->ODEProblem(f,u0,tspan),
[ODEFunction(f,colorvec=colors,jac_prototype=jac_sp),
ODEFunction(f,jac_prototype=jac_sp),
ODEFunction(f,colorvec=colors)
]))
isbroken = i==3 && (
(f, ad, tol) == (f_oop, true, nothing) ||
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

all with AD fail?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Only those with ODEFunction(f,colorvec=colors), and Trapezoid passes there too. Note that they are not completely wrong, just about a factor 10 larger (different) error.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I had a closer look at the broken tests:maximum(sol_std.u[end].-sol.u[end]) is around 1e-5 for tol=1e-10 and around 1e-3 for tol=nothing.

(f, ad, Solver, tol) == (f_oop, false, Trapezoid, nothing) ||
(f, Solver, tol) == (f_ip, Trapezoid, nothing) ||
(f, Solver) == (f_oop, Rodas5) ||
(f, Solver) == (f_ip, Rodas5) ||
(f, Solver) == (f_oop, KenCarp4) ||
(f, Solver) == (f_ip, KenCarp4)
)
Comment on lines +67 to +75
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

These failing tests need investigation.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I wouldn't expect colors without sparsity to work.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

IMO, the color stuff should be independent of the matrix type of J.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

These are still "broken".

# @show i
sol=solve(prob,Solver(autodiff=ad),reltol=tol,abstol=tol)
@test sol.retcode==:Success
if tol !=nothing
if isbroken
@test_broken sol_std.u[end]≈sol.u[end] atol=tol
else
@test sol_std.u[end]≈sol.u[end] atol=tol
end
else
if isbroken
@test_broken sol_std.u[end]≈sol.u[end]
else
@test sol_std.u[end]≈sol.u[end]
end
end
if isbroken
@test_broken length(sol_std.t)==length(sol.t)
else
@test length(sol_std.t)==length(sol.t)
end
end
end
end
end
end