Skip to content

Commit e30ce3e

Browse files
committed
Using ArrayInterface.lu_instance and added tests
1 parent 343281d commit e30ce3e

File tree

4 files changed

+53
-45
lines changed

4 files changed

+53
-45
lines changed

Project.toml

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,6 @@ Reexport = "189a3867-3050-52da-a836-e630ba90ab69"
2222
SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf"
2323
SparseDiffTools = "47a9eef4-7e08-11e9-0b38-333d64bd3804"
2424
StaticArrays = "90137ffa-7385-5640-81b9-e52037218182"
25-
SuiteSparse = "4607b0f0-06f3-5cda-b6b1-a6196a1729e9"
2625

2726
[compat]
2827
ArrayInterface = "1.1, 2.0"

src/OrdinaryDiffEq.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@ module OrdinaryDiffEq
55

66
using Logging
77

8-
using MuladdMacro, SparseArrays, SuiteSparse
8+
using MuladdMacro, SparseArrays
99

1010
using LinearAlgebra
1111

src/derivative_utils.jl

Lines changed: 5 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -538,23 +538,12 @@ function build_J_W(alg,u,uprev,p,t,dt,f,uEltypeNoUnits,::Val{IIP}) where IIP
538538
elseif IIP
539539
similar(J)
540540
else
541-
W = if u isa StaticArray
542-
lu(J)
543-
elseif u isa Number
544-
u
541+
if f.jac_prototype===nothing
542+
LU{LinearAlgebra.lutype(uEltypeNoUnits)}(Matrix{uEltypeNoUnits}(undef, 0, 0),
543+
Vector{LinearAlgebra.BlasInt}(undef, 0),
544+
zero(LinearAlgebra.BlasInt))
545545
else
546-
# TODO: make this more general, maybe by running calc_W and use its returned value
547-
if f.jac_prototype isa SparseMatrixCSC
548-
SuiteSparse.UMFPACK.UmfpackLU(Ptr{Cvoid}(), Ptr{Cvoid}(), 1, 1,
549-
f.jac_prototype.colptr[1:1],
550-
f.jac_prototype.rowval[1:1],
551-
f.jac_prototype.nzval[1:1],
552-
0)
553-
else
554-
LU{LinearAlgebra.lutype(uEltypeNoUnits)}(Matrix{uEltypeNoUnits}(undef, 0, 0),
555-
Vector{LinearAlgebra.BlasInt}(undef, 0),
556-
zero(LinearAlgebra.BlasInt))
557-
end
546+
ArrayInterface.lu_instance(f.jac_prototype)
558547
end
559548
end # end W
560549
end

test/interface/sparsediff_tests.jl

Lines changed: 47 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -1,17 +1,30 @@
1-
using Test
1+
using Test
22
using OrdinaryDiffEq
33
using SparseArrays
44
using LinearAlgebra
55

6+
## in-place
67
#https://github.com/JuliaDiffEq/SparseDiffTools.jl/blob/master/test/test_integration.jl
7-
function f(dx,x,p,t)
8-
for i in 2:length(x)-1
9-
dx[i] = x[i-1] - 2x[i] + x[i+1]
10-
end
11-
dx[1] = -2x[1] + x[2]
12-
dx[end] = x[end-1] - 2x[end]
13-
nothing
8+
function f_ip(dx,x,p,t)
9+
for i in 2:length(x)-1
10+
dx[i] = x[i-1] - 2x[i] + x[i+1]
11+
end
12+
dx[1] = -2x[1] + x[2]
13+
dx[end] = x[end-1] - 2x[end]
14+
nothing
15+
end
16+
17+
## out-of-place
18+
function f_oop(x,p,t)
19+
dx = similar(x)
20+
for i in 2:length(x)-1
21+
dx[i] = x[i-1] - 2x[i] + x[i+1]
1422
end
23+
dx[1] = -2x[1] + x[2]
24+
dx[end] = x[end-1] - 2x[end]
25+
return dx
26+
end
27+
1528

1629
function second_derivative_stencil(N)
1730
A = zeros(N,N)
@@ -34,22 +47,29 @@ jac_sp = sparse(generate_sparsity_pattern(10))
3447
colors = repeat(1:3,10)[1:10]
3548
u0=[1.,2.,3,4,5,5,4,3,2,1]
3649
tspan=(0.,10.)
37-
odefun_sp= ODEFunction(f,colorvec=colors,jac_prototype=jac_sp)
38-
prob_sp = ODEProblem(odefun_sp,u0,tspan)
39-
prob_std = ODEProblem(f,u0,tspan)
40-
41-
sol_sp=solve(prob_sp,Rodas5(autodiff=false),abstol=1e-10,reltol=1e-10)
42-
@test sol_sp.retcode==:Success#test sparse finitediff
43-
sol=solve(prob_std,Rodas5(autodiff=false),abstol=1e-10,reltol=1e-10)
44-
@test sol_sp.u[end]sol.u[end] atol=1e-10
45-
@test length(sol_sp.t)==length(sol.t)
46-
47-
sol_sp=solve(prob_sp,Rodas5(autodiff=false))
48-
sol=solve(prob_std,Rodas5(autodiff=false))
49-
@test sol_sp.u[end]sol.u[end]
50-
@test length(sol_sp.t)==length(sol.t)
51-
52-
sol_sp=solve(prob_sp,Rodas5())
53-
sol=solve(prob_std,Rodas5())
54-
@test sol_sp.u[end]sol.u[end]
55-
@test length(sol_sp.t)==length(sol.t)
50+
51+
for f in [f_ip, f_oop]
52+
odefun_std = ODEFunction(f)
53+
prob_std = ODEProblem(odefun_std,u0,tspan)
54+
55+
for ad in [true, false]
56+
for tol in [nothing, 1e-10]
57+
sol_std=solve(prob_std,Rodas5(autodiff=ad),reltol=tol,abstol=tol)
58+
@test sol_std.retcode==:Success
59+
for prob in map(f->ODEProblem(f,u0,tspan),
60+
[ODEFunction(f,colorvec=colors,jac_prototype=jac_sp),
61+
ODEFunction(f,jac_prototype=jac_sp),
62+
#ODEFunction(f,colorvec=colors) # this one fails both the u[end] and length tests
63+
])
64+
sol=solve(prob,Rodas5(autodiff=ad),reltol=tol,abstol=tol)
65+
@test sol.retcode==:Success
66+
if tol !=nothing
67+
@test sol_std.u[end]sol.u[end] atol=tol
68+
else
69+
@test sol_std.u[end]sol.u[end]
70+
end
71+
@test length(sol_std.t)==length(sol.t)
72+
end
73+
end
74+
end
75+
end

0 commit comments

Comments
 (0)