Skip to content

Commit 0863dab

Browse files
Merge pull request #1049 from mauro3/m3/sparse_jac
Fix inability to use sparse Jacobian with colorvec for OOP
2 parents c712c67 + c68e2ba commit 0863dab

File tree

4 files changed

+90
-66
lines changed

4 files changed

+90
-66
lines changed

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,7 @@ StaticArrays = "90137ffa-7385-5640-81b9-e52037218182"
2626
[compat]
2727
ArrayInterface = "1.1, 2.0"
2828
DataStructures = "0.17"
29-
DiffEqBase = "6.17.1"
29+
DiffEqBase = "6.19"
3030
ExponentialUtilities = "1.2"
3131
FiniteDiff = "2"
3232
ForwardDiff = "0.10.3"

src/derivative_utils.jl

Lines changed: 7 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -531,23 +531,19 @@ function build_J_W(alg,u,uprev,p,t,dt,f,uEltypeNoUnits,::Val{IIP}) where IIP
531531
end
532532
W = WOperator(f.mass_matrix, dt, J, IIP)
533533
else
534-
J = false .* _vec(u) .* _vec(u)'
534+
J = if f.jac_prototype === nothing
535+
false .* _vec(u) .* _vec(u)'
536+
else
537+
deepcopy(f.jac_prototype)
538+
end
535539
isdae = alg isa DAEAlgorithm
536540
W = if isdae
537541
J
538542
elseif IIP
539543
similar(J)
540544
else
541-
W = if u isa StaticArray
542-
lu(J)
543-
elseif u isa Number
544-
u
545-
else
546-
LU{LinearAlgebra.lutype(uEltypeNoUnits)}(Matrix{uEltypeNoUnits}(undef, 0, 0),
547-
Vector{LinearAlgebra.BlasInt}(undef, 0),
548-
zero(LinearAlgebra.BlasInt))
549-
end
550-
end # end W
545+
ArrayInterface.lu_instance(J)
546+
end
551547
end
552548
return J, W
553549
end

src/derivative_wrappers.jl

Lines changed: 11 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -38,15 +38,9 @@ end
3838

3939
jacobian_autodiff(f, x, odefun) = (ForwardDiff.derivative(f,x),1)
4040
function jacobian_autodiff(f, x::AbstractArray, odefun)
41-
if DiffEqBase.has_colorvec(odefun)
42-
colorvec = odefun.colorvec
43-
sparsity = odefun.jac_prototype
44-
jac_prototype = nothing
45-
else
46-
colorvec = 1:length(x)
47-
sparsity = nothing
48-
jac_prototype = odefun.jac_prototype
49-
end
41+
colorvec = DiffEqBase.has_colorvec(odefun) ? odefun.colorvec : 1:length(x)
42+
sparsity = odefun.sparsity
43+
jac_prototype = odefun.jac_prototype
5044
maxcolor = maximum(colorvec)
5145
chunksize = getsize(default_chunk_size(maxcolor))
5246
num_of_chunks = Int(ceil(maxcolor / chunksize))
@@ -71,22 +65,15 @@ jacobian_finitediff(f, x, diff_type, dir, colorvec, sparsity, jac_prototype) =
7165
jacobian_finitediff(f, x::AbstractArray, diff_type, dir, colorvec, sparsity, jac_prototype) =
7266
(FiniteDiff.finite_difference_jacobian(f, x, diff_type, eltype(x), diff_type==Val{:forward} ? f(x) : similar(x),
7367
dir = dir, colorvec = colorvec, sparsity = sparsity, jac_prototype = jac_prototype),_nfcount(maximum(colorvec),diff_type))
74-
7568
function jacobian(f, x, integrator)
7669
alg = unwrap_alg(integrator, true)
7770
local tmp
7871
if alg_autodiff(alg)
7972
J, tmp = jacobian_autodiff(f, x, integrator.f)
8073
else
81-
if DiffEqBase.has_colorvec(integrator.f)
82-
colorvec = integrator.f.colorvec
83-
sparsity = integrator.f.jac_prototype
84-
jac_prototype = nothing
85-
else
86-
colorvec = 1:length(x)
87-
sparsity = nothing
88-
jac_prototype = integrator.f.jac_prototype
89-
end
74+
colorvec = DiffEqBase.has_colorvec(integrator.f) ? integrator.f.colorvec : 1:length(x)
75+
sparsity = integrator.f.sparsity
76+
jac_prototype = integrator.f.jac_prototype
9077
dir = diffdir(integrator)
9178
J, tmp = jacobian_finitediff(f, x, alg.diff_type, dir, colorvec, sparsity, jac_prototype)
9279
end
@@ -96,10 +83,10 @@ end
9683

9784
jacobian_finitediff_forward!(J,f,x,jac_config,forwardcache,integrator,colorvec)=
9885
(FiniteDiff.finite_difference_jacobian!(J,f,x,jac_config,forwardcache,
99-
dir=diffdir(integrator),colorvec=colorvec,sparsity=integrator.f.jac_prototype);maximum(colorvec))
86+
dir=diffdir(integrator),colorvec=colorvec,sparsity=integrator.f.sparsity);maximum(colorvec))
10087
jacobian_finitediff!(J,f,x,jac_config,integrator,colorvec)=
10188
(FiniteDiff.finite_difference_jacobian!(J,f,x,jac_config,
102-
dir=diffdir(integrator),colorvec=colorvec,sparsity=integrator.f.jac_prototype);2*maximum(colorvec))
89+
dir=diffdir(integrator),colorvec=colorvec,sparsity=integrator.f.sparsity);2*maximum(colorvec))
10390

10491
function jacobian!(J::AbstractMatrix{<:Number}, f, x::AbstractArray{<:Number}, fx::AbstractArray{<:Number}, integrator::DiffEqBase.DEIntegrator, jac_config)
10592
alg = unwrap_alg(integrator, true)
@@ -125,13 +112,9 @@ end
125112
function DiffEqBase.build_jac_config(alg::Union{OrdinaryDiffEqAlgorithm,DAEAlgorithm},f,uf,du1,uprev,u,tmp,du2,::Val{transform}=Val(true)) where transform
126113
if !DiffEqBase.has_jac(f) && ((!transform && !DiffEqBase.has_Wfact(f)) || (transform && !DiffEqBase.has_Wfact_t(f)))
127114
if alg_autodiff(alg)
128-
if DiffEqBase.has_colorvec(f)
129-
colorvec = f.colorvec
130-
sparsity = f.jac_prototype
131-
else
132-
colorvec = 1:length(uprev)
133-
sparsity = nothing
134-
end
115+
colorvec = DiffEqBase.has_colorvec(f) ? f.colorvec : 1:length(uprev)
116+
sparsity = f.sparsity
117+
jac_prototype = f.jac_prototype
135118
jac_config = ForwardColorJacCache(uf,uprev,colorvec=colorvec,sparsity=sparsity)
136119
else
137120
if alg.diff_type != Val{:complex}

test/interface/sparsediff_tests.jl

Lines changed: 71 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -3,15 +3,28 @@ using OrdinaryDiffEq
33
using SparseArrays
44
using LinearAlgebra
55

6+
## in-place
67
#https://github.com/JuliaDiffEq/SparseDiffTools.jl/blob/master/test/test_integration.jl
7-
function f(dx,x,p,t)
8-
for i in 2:length(x)-1
9-
dx[i] = x[i-1] - 2x[i] + x[i+1]
10-
end
11-
dx[1] = -2x[1] + x[2]
12-
dx[end] = x[end-1] - 2x[end]
13-
nothing
8+
function f_ip(dx,x,p,t)
9+
for i in 2:length(x)-1
10+
dx[i] = x[i-1] - 2x[i] + x[i+1]
11+
end
12+
dx[1] = -2x[1] + x[2]
13+
dx[end] = x[end-1] - 2x[end]
14+
nothing
15+
end
16+
17+
## out-of-place
18+
function f_oop(x,p,t)
19+
dx = similar(x)
20+
for i in 2:length(x)-1
21+
dx[i] = x[i-1] - 2x[i] + x[i+1]
1422
end
23+
dx[1] = -2x[1] + x[2]
24+
dx[end] = x[end-1] - 2x[end]
25+
return dx
26+
end
27+
1528

1629
function second_derivative_stencil(N)
1730
A = zeros(N,N)
@@ -34,22 +47,54 @@ jac_sp = sparse(generate_sparsity_pattern(10))
3447
colors = repeat(1:3,10)[1:10]
3548
u0=[1.,2.,3,4,5,5,4,3,2,1]
3649
tspan=(0.,10.)
37-
odefun_sp= ODEFunction(f,colorvec=colors,jac_prototype=jac_sp)
38-
prob_sp = ODEProblem(odefun_sp,u0,tspan)
39-
prob_std = ODEProblem(f,u0,tspan)
40-
41-
sol_sp=solve(prob_sp,Rodas5(autodiff=false),abstol=1e-10,reltol=1e-10)
42-
@test sol_sp.retcode==:Success#test sparse finitediff
43-
sol=solve(prob_std,Rodas5(autodiff=false),abstol=1e-10,reltol=1e-10)
44-
@test sol_sp.u[end]sol.u[end] atol=1e-10
45-
@test length(sol_sp.t)==length(sol.t)
46-
47-
sol_sp=solve(prob_sp,Rodas5(autodiff=false))
48-
sol=solve(prob_std,Rodas5(autodiff=false))
49-
@test sol_sp.u[end]sol.u[end]
50-
@test length(sol_sp.t)==length(sol.t)
51-
52-
sol_sp=solve(prob_sp,Rodas5())
53-
sol=solve(prob_std,Rodas5())
54-
@test sol_sp.u[end]sol.u[end]
55-
@test length(sol_sp.t)==length(sol.t)
50+
51+
for f in [f_oop, f_ip]
52+
odefun_std = ODEFunction(f)
53+
prob_std = ODEProblem(odefun_std,u0,tspan)
54+
55+
for ad in [true, false]
56+
for Solver in [Rodas5, Trapezoid, KenCarp4]
57+
for tol in [nothing, 1e-10]
58+
sol_std=solve(prob_std,Solver(autodiff=ad),reltol=tol,abstol=tol)
59+
@test sol_std.retcode==:Success
60+
for (i,prob) in enumerate(map(f->ODEProblem(f,u0,tspan),
61+
[ODEFunction(f,colorvec=colors,jac_prototype=jac_sp),
62+
ODEFunction(f,jac_prototype=jac_sp),
63+
ODEFunction(f,colorvec=colors)
64+
]))
65+
# TODO: these broken test-cases need to be investigated.
66+
# Note they only happen for prob=ODEFunction(f,colorvec=colors).
67+
isbroken = i==3 && (
68+
(f, ad, tol) == (f_oop, true, nothing) ||
69+
(f, ad, Solver, tol) == (f_oop, false, Trapezoid, nothing) ||
70+
(f, Solver, tol) == (f_ip, Trapezoid, nothing) ||
71+
(f, Solver) == (f_oop, Rodas5) ||
72+
(f, Solver) == (f_ip, Rodas5) ||
73+
(f, Solver) == (f_oop, KenCarp4) ||
74+
(f, Solver) == (f_ip, KenCarp4)
75+
)
76+
sol=solve(prob,Solver(autodiff=ad),reltol=tol,abstol=tol)
77+
@test sol.retcode==:Success
78+
if tol !=nothing
79+
if isbroken
80+
@test_broken sol_std.u[end]sol.u[end] atol=tol
81+
else
82+
@test sol_std.u[end]sol.u[end] atol=tol
83+
end
84+
else
85+
if isbroken
86+
@test_broken sol_std.u[end]sol.u[end]
87+
else
88+
@test sol_std.u[end]sol.u[end]
89+
end
90+
end
91+
if isbroken
92+
@test_broken length(sol_std.t)==length(sol.t)
93+
else
94+
@test length(sol_std.t)==length(sol.t)
95+
end
96+
end
97+
end
98+
end
99+
end
100+
end

0 commit comments

Comments
 (0)