Skip to content

Commit 974ff9e

Browse files
committed
steps towards iterative solvers + tests for general numbers
1 parent fa64510 commit 974ff9e

File tree

2 files changed

+64
-38
lines changed

2 files changed

+64
-38
lines changed

src/iterative_wrappers.jl

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -153,11 +153,13 @@ function SciMLBase.solve(cache::LinearCache, alg::KrylovJL; kwargs...)
153153
M = (M === Identity()) ? I : InvPreconditioner(M)
154154
N = (N === Identity()) ? I : InvPreconditioner(N)
155155

156-
atol = float(cache.abstol)
157-
rtol = float(cache.reltol)
156+
Ta = eltype(cache.A)
157+
158+
atol = Ta(float(cache.abstol))
159+
rtol = Ta(float(cache.reltol))
158160
itmax = cache.maxiters
159161
verbose = cache.verbose ? 1 : 0
160-
162+
161163
args = (cache.cacheval, cache.A, cache.b)
162164
kwargs = (atol = atol, rtol = rtol, itmax = itmax, verbose = verbose,
163165
history = true, alg.kwargs...)

test/basictests.jl

Lines changed: 59 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,8 @@ using Test
33
import Random
44

55
const Dual64 = ForwardDiff.Dual{Nothing, Float64, 1}
6+
Base.:^(x::MultiFloat{T, N}, y::Int) where {T,N} = MultiFloat{T, N}(BigFloat(x)^y)
7+
Base.:^(x::MultiFloat{T, N}, y::Float64) where {T,N} = MultiFloat{T, N}(BigFloat(x)^y)
68

79
n = 8
810
A = Matrix(I, n, n)
@@ -19,18 +21,21 @@ prob2 = LinearProblem(A2, b2; u0 = x2)
1921

2022
cache_kwargs = (; verbose = true, abstol = 1e-8, reltol = 1e-8, maxiter = 30)
2123

22-
function test_interface(alg, prob1, prob2)
23-
A1 = prob1.A
24-
b1 = prob1.b
25-
x1 = prob1.u0
26-
A2 = prob2.A
27-
b2 = prob2.b
28-
x2 = prob2.u0
29-
30-
y = solve(prob1, alg; cache_kwargs...)
24+
function test_interface(alg, prob1, prob2; T=Float64)
25+
A1 = prob1.A .|> T
26+
b1 = prob1.b .|> T
27+
x1 = prob1.u0 .|> T
28+
A2 = prob2.A .|> T
29+
b2 = prob2.b .|> T
30+
x2 = prob2.u0 .|> T
31+
32+
myprob1 = LinearProblem(A1, b1; u0 = x1)
33+
myprob2 = LinearProblem(A2, b2; u0 = x2)
34+
35+
y = solve(myprob1, alg; cache_kwargs...)
3136
@test A1 * y b1
3237

33-
cache = SciMLBase.init(prob1, alg; cache_kwargs...) # initialize cache
38+
cache = SciMLBase.init(myprob1, alg; cache_kwargs...) # initialize cache
3439
y = solve(cache)
3540
@test A1 * y b1
3641

@@ -140,44 +145,44 @@ end
140145
end
141146

142147
@testset "Sparspak Factorization (Float64x1)" begin
143-
A1 = sparse(A / 1) .|> Float64x1
144-
b1 = rand(n) .|> Float64x1
145-
x1 = zero(b) .|> Float64x1
146-
A2 = sparse(A / 2) .|> Float64x1
147-
b2 = rand(n) .|> Float64x1
148-
x2 = zero(b) .|> Float64x1
148+
A1 = sparse(A / 1)
149+
b1 = rand(n)
150+
x1 = zero(b)
151+
A2 = sparse(A / 2)
152+
b2 = rand(n)
153+
x2 = zero(b)
149154

150155
prob1 = LinearProblem(A1, b1; u0 = x1)
151156
prob2 = LinearProblem(A2, b2; u0 = x2)
152-
test_interface(SparspakFactorization(), prob1, prob2)
157+
test_interface(SparspakFactorization(), prob1, prob2; T=Float64x1)
153158
end
154159

155160
@testset "Sparspak Factorization (Float64x2)" begin
156-
A1 = sparse(A / 1) .|> Float64x2
157-
b1 = rand(n) .|> Float64x2
158-
x1 = zero(b) .|> Float64x2
159-
A2 = sparse(A / 2) .|> Float64x2
160-
b2 = rand(n) .|> Float64x2
161-
x2 = zero(b) .|> Float64x2
161+
A1 = sparse(A / 1)
162+
b1 = rand(n)
163+
x1 = zero(b)
164+
A2 = sparse(A / 2)
165+
b2 = rand(n)
166+
x2 = zero(b)
162167

163168
prob1 = LinearProblem(A1, b1; u0 = x1)
164169
prob2 = LinearProblem(A2, b2; u0 = x2)
165-
test_interface(SparspakFactorization(), prob1, prob2)
170+
test_interface(SparspakFactorization(), prob1, prob2; T=Float64x2)
166171
end
167172

168173
@testset "Sparspak Factorization (Dual64)" begin
169-
A1 = sparse(A / 1) .|> Dual64
170-
b1 = rand(n) .|> Dual64
171-
x1 = zero(b) .|> Dual64
172-
A2 = sparse(A / 2) .|> Dual64
173-
b2 = rand(n) .|> Dual64
174-
x2 = zero(b) .|> Dual64
174+
A1 = sparse(A / 1)
175+
b1 = rand(n)
176+
x1 = zero(b)
177+
A2 = sparse(A / 2)
178+
b2 = rand(n)
179+
x2 = zero(b)
175180

176181
prob1 = LinearProblem(A1, b1; u0 = x1)
177182
prob2 = LinearProblem(A2, b2; u0 = x2)
178-
test_interface(SparspakFactorization(), prob1, prob2)
183+
test_interface(SparspakFactorization(), prob1, prob2; T=Dual64)
179184
end
180-
185+
181186
@testset "FastLAPACK Factorizations" begin
182187
A1 = A / 1
183188
b1 = rand(n)
@@ -225,7 +230,14 @@ end
225230
("GMRES", KrylovJL_GMRES(kwargs...)),
226231
# ("BICGSTAB",KrylovJL_BICGSTAB(kwargs...)),
227232
("MINRES", KrylovJL_MINRES(kwargs...)))
228-
@testset "$(alg[1])" begin test_interface(alg[2], prob1, prob2) end
233+
@testset "$(alg[1])" begin
234+
test_interface(alg[2], prob1, prob2)
235+
test_interface(alg[2], prob1, prob2; T=Float64x1)
236+
test_interface(alg[2], prob1, prob2; T=Float64x2)
237+
# test_interface(alg[2], prob1, prob2; T=Dual64)
238+
# https://github.com/JuliaSmoothOptimizers/Krylov.jl/issues/646
239+
# ForwardDiff.Dual is a Real, not an AbstractFloat
240+
end
229241
end
230242
end
231243

@@ -237,7 +249,14 @@ end
237249
# ("BICGSTAB",IterativeSolversJL_BICGSTAB(kwargs...)),
238250
# ("MINRES",IterativeSolversJL_MINRES(kwargs...)),
239251
)
240-
@testset "$(alg[1])" begin test_interface(alg[2], prob1, prob2) end
252+
@testset "$(alg[1])" begin
253+
test_interface(alg[2], prob1, prob2)
254+
test_interface(alg[2], prob1, prob2; T=Float64x1)
255+
test_interface(alg[2], prob1, prob2; T=Float64x2)
256+
# test_interface(alg[2], prob1, prob2; T=Dual64)
257+
# https://github.com/JuliaLang/julia/blob/master/stdlib/LinearAlgebra/src/givens.jl#L77
258+
# ForwardDiff.Dual is a Real, not an AbstractFloat
259+
end
241260
end
242261
end
243262

@@ -246,7 +265,12 @@ end
246265
for alg in (("Default", KrylovKitJL(kwargs...)),
247266
("CG", KrylovKitJL_CG(kwargs...)),
248267
("GMRES", KrylovKitJL_GMRES(kwargs...)))
249-
@testset "$(alg[1])" begin test_interface(alg[2], prob1, prob2) end
268+
@testset "$(alg[1])" begin
269+
test_interface(alg[2], prob1, prob2)
270+
test_interface(alg[2], prob1, prob2; T=Float64x1)
271+
test_interface(alg[2], prob1, prob2; T=Float64x2)
272+
test_interface(alg[2], prob1, prob2; T=Dual64)
273+
end
250274
@test alg[2] isa KrylovKitJL
251275
end
252276
end

0 commit comments

Comments
 (0)