Skip to content

Commit 32c1435

Browse files
committed
Fix qr(A)'\b
1 parent d73ea16 commit 32c1435

File tree

3 files changed

+102
-34
lines changed

3 files changed

+102
-34
lines changed

stdlib/LinearAlgebra/src/factorization.jl

Lines changed: 34 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -96,26 +96,44 @@ function (/)(B::VecOrMat{Complex{T}}, F::Factorization{T}) where T<:BlasReal
9696
return copy(reinterpret(Complex{T}, x))
9797
end
9898

99-
function \(F::Factorization, B::AbstractVecOrMat)
99+
# convenience methods
100+
## return only the solution of a least squares problem while avoiding promoting
101+
## vectors to matrices.
102+
_cut_B(x::AbstractVector, r::UnitRange) = length(x) > length(r) ? x[r] : x
103+
_cut_B(X::AbstractMatrix, r::UnitRange) = size(X, 1) > length(r) ? X[r,:] : X
104+
105+
## append right hand side with zeros if necessary
106+
_zeros(::Type{T}, b::AbstractVector, n::Integer) where {T} = zeros(T, max(length(b), n))
107+
_zeros(::Type{T}, B::AbstractMatrix, n::Integer) where {T} = zeros(T, max(size(B, 1), n), size(B, 2))
108+
109+
# General fallback definition for handling under- and overdetermined system as well as square problems
110+
function \(adjF::Union{<:Factorization,Adjoint{<:Any,<:Factorization}}, B::AbstractVecOrMat)
100111
require_one_based_indexing(B)
101-
TFB = typeof(oneunit(eltype(B)) / oneunit(eltype(F)))
102-
BB = similar(B, TFB, size(B))
103-
copyto!(BB, B)
104-
ldiv!(F, BB)
105-
end
106-
_rows(b::AbstractVector, r::AbstractVector) = b[r]
107-
_rows(B::AbstractVector, r::AbstractMatrix) = B[r, :]
108-
function \(adjF::Adjoint{<:Any,<:Factorization}, B::AbstractVecOrMat)
109-
require_one_based_indexing(B)
110-
F = adjF.parent
111-
TFB = typeof(oneunit(eltype(B)) / oneunit(eltype(F)))
112-
BB = similar(B, TFB, size(B))
113-
copyto!(BB, B)
114-
ldiv!(adjoint(F), BB)
112+
m, n = size(adjF)
113+
if m != size(B, 1)
114+
throw(DimensionMismatch("arguments must have the same number of rows"))
115+
end
116+
# F = adjF.parent
117+
TFB = typeof(oneunit(eltype(B)) / oneunit(eltype(adjF)))
118+
119+
# For wide problem we (often) compute a minimum norm solution. The solution
120+
# is larger than the right hand side so we use size(adjF, 2).
121+
BB = _zeros(TFB, B, n)
122+
123+
if n > size(B, 1)
124+
# Underdetermined
125+
fill!(BB, 0)
126+
copyto!(view(BB, 1:m, :), B)
127+
else
128+
copyto!(BB, B)
129+
end
130+
131+
ldiv!(adjF, BB)
132+
115133
# For tall problems, we compute a least sqaures solution so only part
116134
# of the rhs should be returned from \ while ldiv! uses (and returns)
117135
# the complete rhs
118-
return (>)(size(adjF)...) ? _rows(BB, 1:size(adjF, 2)) : BB
136+
return _cut_B(BB, 1:n)
119137
end
120138

121139
function /(B::AbstractMatrix, F::Factorization)

stdlib/LinearAlgebra/src/qr.jl

Lines changed: 27 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -465,6 +465,8 @@ end
465465
Base.propertynames(F::QRPivoted, private::Bool=false) =
466466
(:R, :Q, :p, :P, (private ? fieldnames(typeof(F)) : ())...)
467467

468+
adjoint(F::Union{QR,QRPivoted,QRCompactWY}) = Adjoint(F)
469+
468470
abstract type AbstractQ{T} <: AbstractMatrix{T} end
469471

470472
inv(Q::AbstractQ) = Q'
@@ -932,28 +934,35 @@ function ldiv!(A::QRPivoted, B::StridedMatrix)
932934
B
933935
end
934936

935-
# convenience methods
936-
## return only the solution of a least squares problem while avoiding promoting
937-
## vectors to matrices.
938-
_cut_B(x::AbstractVector, r::UnitRange) = length(x) > length(r) ? x[r] : x
939-
_cut_B(X::AbstractMatrix, r::UnitRange) = size(X, 1) > length(r) ? X[r,:] : X
940-
941-
## append right hand side with zeros if necessary
942-
_zeros(::Type{T}, b::AbstractVector, n::Integer) where {T} = zeros(T, max(length(b), n))
943-
_zeros(::Type{T}, B::AbstractMatrix, n::Integer) where {T} = zeros(T, max(size(B, 1), n), size(B, 2))
937+
function _apply_permutation!(F::QRPivoted, B::AbstractVecOrMat)
938+
# Apply permutation but only to the top part of the solution vector since
939+
# it's padded with zeros for underdetermined problems
940+
B[1:length(F.p), :] = B[F.p, :]
941+
return B
942+
end
943+
_apply_permutation!(F::Factorization, B::AbstractVecOrMat) = B
944944

945-
function (\)(A::Union{QR{TA},QRCompactWY{TA},QRPivoted{TA}}, B::AbstractVecOrMat{TB}) where {TA,TB}
945+
function ldiv!(Fadj::Adjoint{<:Any,<:Union{QR,QRCompactWY,QRPivoted}}, B::AbstractVecOrMat)
946946
require_one_based_indexing(B)
947-
S = promote_type(TA,TB)
948-
m, n = size(A)
949-
m == size(B,1) || throw(DimensionMismatch("Both inputs should have the same number of rows"))
947+
m, n = size(Fadj)
950948

951-
AA = Factorization{S}(A)
949+
# We don't allow solutions overdetermined systems. It would at least be
950+
if m > n
951+
throw(DimensionMismatch("overdetermined systems are not supported"))
952+
end
953+
if n != size(B, 1)
954+
throw(DimensionMismatch("inputs should have the same number of rows"))
955+
end
956+
F = parent(Fadj)
952957

953-
X = _zeros(S, B, n)
954-
X[1:size(B, 1), :] = B
955-
ldiv!(AA, X)
956-
return _cut_B(X, 1:n)
958+
B = _apply_permutation!(F, B)
959+
960+
# For underdetermined system, the triangular solve should only be applied to the top
961+
# part of B that contains the rhs. For square problems, the view corresponds to B itself
962+
ldiv!(LowerTriangular(adjoint(F.R)), view(B, 1:size(F.R, 2), :))
963+
lmul!(F.Q, B)
964+
965+
return B
957966
end
958967

959968
# With a real lhs and complex rhs with the same precision, we can reinterpret the complex

stdlib/LinearAlgebra/test/qr.jl

Lines changed: 41 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -371,4 +371,45 @@ end
371371
end
372372
end
373373

374+
@testset "adjoint of QR" begin
375+
n = 5
376+
B = randn(5, 2)
377+
378+
@testset "size(b)=$(size(b))" for b in (B[:, 1], B)
379+
@testset "size(A)=$(size(A))" for A in (
380+
randn(n, n),
381+
# Wide problems become minimum norm (in x) problems similarly to LQ
382+
randn(n + 2, n),
383+
complex.(randn(n, n), randn(n, n)))
384+
385+
@testset "QRCompactWY" begin
386+
F = qr(A)
387+
x = F'\b
388+
@test x A'\b
389+
@test length(size(x)) == length(size(b))
390+
end
391+
392+
@testset "QR" begin
393+
F = LinearAlgebra.qrfactUnblocked!(copy(A))
394+
x = F'\b
395+
@test x A'\b
396+
@test length(size(x)) == length(size(b))
397+
end
398+
399+
@testset "QRPivoted" begin
400+
F = LinearAlgebra.qr(A, Val(true))
401+
x = F'\b
402+
@test x A'\b
403+
@test length(size(x)) == length(size(b))
404+
end
405+
end
406+
@test_throws DimensionMismatch("overdetermined systems are not supported") qr(randn(n - 2, n))'\b
407+
@test_throws DimensionMismatch("arguments must have the same number of rows") qr(randn(n, n + 1))'\b
408+
@test_throws DimensionMismatch("overdetermined systems are not supported") LinearAlgebra.qrfactUnblocked!(randn(n - 2, n))'\b
409+
@test_throws DimensionMismatch("arguments must have the same number of rows") LinearAlgebra.qrfactUnblocked!(randn(n, n + 1))'\b
410+
@test_throws DimensionMismatch("overdetermined systems are not supported") qr(randn(n - 2, n), Val(true))'\b
411+
@test_throws DimensionMismatch("arguments must have the same number of rows") qr(randn(n, n + 1), Val(true))'\b
412+
end
413+
end
414+
374415
end # module TestQR

0 commit comments

Comments
 (0)