Skip to content

Commit 7d55f9c

Browse files
dkarraschFrancesco Fucci
authored andcommitted
Complete size checks in BLAS.[sy/he]mm! (JuliaLang#45605)
1 parent 53ba43e commit 7d55f9c

File tree

3 files changed

+56
-10
lines changed

3 files changed

+56
-10
lines changed

stdlib/LinearAlgebra/src/blas.jl

Lines changed: 42 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1540,11 +1540,27 @@ for (mfname, elty) in ((:dsymm_,:Float64),
15401540
require_one_based_indexing(A, B, C)
15411541
m, n = size(C)
15421542
j = checksquare(A)
1543-
if j != (side == 'L' ? m : n)
1544-
throw(DimensionMismatch(lazy"A has size $(size(A)), C has size ($m,$n)"))
1545-
end
1546-
if size(B,2) != n
1547-
throw(DimensionMismatch(lazy"B has second dimension $(size(B,2)) but needs to match second dimension of C, $n"))
1543+
M, N = size(B)
1544+
if side == 'L'
1545+
if j != m
1546+
throw(DimensionMismatch(lazy"A has first dimension $j but needs to match first dimension of C, $m"))
1547+
end
1548+
if N != n
1549+
throw(DimensionMismatch(lazy"B has second dimension $N but needs to match second dimension of C, $n"))
1550+
end
1551+
if j != M
1552+
throw(DimensionMismatch(lazy"A has second dimension $j but needs to match first dimension of B, $M"))
1553+
end
1554+
else
1555+
if j != n
1556+
throw(DimensionMismatch(lazy"B has second dimension $j but needs to match second dimension of C, $n"))
1557+
end
1558+
if N != j
1559+
throw(DimensionMismatch(lazy"A has second dimension $N but needs to match first dimension of B, $j"))
1560+
end
1561+
if M != m
1562+
throw(DimensionMismatch(lazy"A has first dimension $M but needs to match first dimension of C, $m"))
1563+
end
15481564
end
15491565
chkstride1(A)
15501566
chkstride1(B)
@@ -1614,11 +1630,27 @@ for (mfname, elty) in ((:zhemm_,:ComplexF64),
16141630
require_one_based_indexing(A, B, C)
16151631
m, n = size(C)
16161632
j = checksquare(A)
1617-
if j != (side == 'L' ? m : n)
1618-
throw(DimensionMismatch(lazy"A has size $(size(A)), C has size ($m,$n)"))
1619-
end
1620-
if size(B,2) != n
1621-
throw(DimensionMismatch(lazy"B has second dimension $(size(B,2)) but needs to match second dimension of C, $n"))
1633+
M, N = size(B)
1634+
if side == 'L'
1635+
if j != m
1636+
throw(DimensionMismatch(lazy"A has first dimension $j but needs to match first dimension of C, $m"))
1637+
end
1638+
if N != n
1639+
throw(DimensionMismatch(lazy"B has second dimension $N but needs to match second dimension of C, $n"))
1640+
end
1641+
if j != M
1642+
throw(DimensionMismatch(lazy"A has second dimension $j but needs to match first dimension of B, $M"))
1643+
end
1644+
else
1645+
if j != n
1646+
throw(DimensionMismatch(lazy"B has second dimension $j but needs to match second dimension of C, $n"))
1647+
end
1648+
if N != j
1649+
throw(DimensionMismatch(lazy"A has second dimension $N but needs to match first dimension of B, $j"))
1650+
end
1651+
if M != m
1652+
throw(DimensionMismatch(lazy"A has first dimension $M but needs to match first dimension of C, $m"))
1653+
end
16221654
end
16231655
chkstride1(A)
16241656
chkstride1(B)

stdlib/LinearAlgebra/test/blas.jl

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -227,11 +227,19 @@ Random.seed!(100)
227227
@test_throws DimensionMismatch BLAS.symm('R','U',Cmn,Cnn)
228228
@test_throws DimensionMismatch BLAS.symm!('L','U',one(elty),Asymm,Cnn,one(elty),Cmn)
229229
@test_throws DimensionMismatch BLAS.symm!('L','U',one(elty),Asymm,Cnn,one(elty),Cnm)
230+
@test_throws DimensionMismatch BLAS.symm!('L','U',one(elty),Asymm,Cmn,one(elty),Cnn)
231+
@test_throws DimensionMismatch BLAS.symm!('R','U',one(elty),Asymm,Cnm,one(elty),Cmn)
232+
@test_throws DimensionMismatch BLAS.symm!('R','U',one(elty),Asymm,Cnn,one(elty),Cnm)
233+
@test_throws DimensionMismatch BLAS.symm!('R','U',one(elty),Asymm,Cmn,one(elty),Cnn)
230234
if elty <: BlasComplex
231235
@test_throws DimensionMismatch BLAS.hemm('L','U',Cnm,Cnn)
232236
@test_throws DimensionMismatch BLAS.hemm('R','U',Cmn,Cnn)
233237
@test_throws DimensionMismatch BLAS.hemm!('L','U',one(elty),Aherm,Cnn,one(elty),Cmn)
234238
@test_throws DimensionMismatch BLAS.hemm!('L','U',one(elty),Aherm,Cnn,one(elty),Cnm)
239+
@test_throws DimensionMismatch BLAS.hemm!('L','U',one(elty),Aherm,Cmn,one(elty),Cnn)
240+
@test_throws DimensionMismatch BLAS.hemm!('R','U',one(elty),Aherm,Cnm,one(elty),Cmn)
241+
@test_throws DimensionMismatch BLAS.hemm!('R','U',one(elty),Aherm,Cnn,one(elty),Cnm)
242+
@test_throws DimensionMismatch BLAS.hemm!('R','U',one(elty),Aherm,Cmn,one(elty),Cnn)
235243
end
236244
end
237245
end

stdlib/LinearAlgebra/test/symmetric.jl

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -352,6 +352,9 @@ end
352352
C = zeros(eltya,n,n)
353353
@test Hermitian(aherm) * a aherm * a
354354
@test a * Hermitian(aherm) a * aherm
355+
# rectangular multiplication
356+
@test [a; a] * Hermitian(aherm) [a; a] * aherm
357+
@test Hermitian(aherm) * [a a] aherm * [a a]
355358
@test Hermitian(aherm) * Hermitian(aherm) aherm*aherm
356359
@test_throws DimensionMismatch Hermitian(aherm) * Vector{eltya}(undef, n+1)
357360
LinearAlgebra.mul!(C,a,Hermitian(aherm))
@@ -360,6 +363,9 @@ end
360363
@test Symmetric(asym) * Symmetric(asym) asym*asym
361364
@test Symmetric(asym) * a asym * a
362365
@test a * Symmetric(asym) a * asym
366+
# rectangular multiplication
367+
@test Symmetric(asym) * [a a] asym * [a a]
368+
@test [a; a] * Symmetric(asym) [a; a] * asym
363369
@test_throws DimensionMismatch Symmetric(asym) * Vector{eltya}(undef, n+1)
364370
LinearAlgebra.mul!(C,a,Symmetric(asym))
365371
@test C a*asym

0 commit comments

Comments
 (0)