Skip to content

Commit d749f0e

Browse files
authored
Fix zero elements for block-matrix kron involving Diagonal (#55941)
Currently, it's assumed that the zero element is identical for the matrix, but this is not necessary if the elements are matrices themselves and have different sizes. This PR ensures that `kron` for a `Diagonal` has the correct zero elements. Current: ```julia julia> D = Diagonal(1:2) 2×2 Diagonal{Int64, UnitRange{Int64}}: 1 ⋅ ⋅ 2 julia> B = reshape([ones(2,2), ones(3,2), ones(2,3), ones(3,3)], 2, 2); julia> size.(kron(D, B)) 4×4 Matrix{Tuple{Int64, Int64}}: (2, 2) (2, 3) (2, 2) (2, 2) (3, 2) (3, 3) (2, 2) (2, 2) (2, 2) (2, 2) (2, 2) (2, 3) (2, 2) (2, 2) (3, 2) (3, 3) ``` This PR ```julia julia> size.(kron(D, B)) 4×4 Matrix{Tuple{Int64, Int64}}: (2, 2) (2, 3) (2, 2) (2, 3) (3, 2) (3, 3) (3, 2) (3, 3) (2, 2) (2, 3) (2, 2) (2, 3) (3, 2) (3, 3) (3, 2) (3, 3) ``` Note the differences e.g. in the `CartesianIndex(4,1)`, `CartesianIndex(3,2)` and `CartesianIndex(3,3)` elements.
1 parent 3b3a70f commit d749f0e

File tree

2 files changed

+73
-7
lines changed

2 files changed

+73
-7
lines changed

stdlib/LinearAlgebra/src/diagonal.jl

Lines changed: 63 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -686,16 +686,33 @@ for Tri in (:UpperTriangular, :LowerTriangular)
686686
end
687687

688688
@inline function kron!(C::AbstractMatrix, A::Diagonal, B::Diagonal)
689-
valA = A.diag; nA = length(valA)
690-
valB = B.diag; nB = length(valB)
689+
valA = A.diag; mA, nA = size(A)
690+
valB = B.diag; mB, nB = size(B)
691691
nC = checksquare(C)
692692
@boundscheck nC == nA*nB ||
693693
throw(DimensionMismatch(lazy"expect C to be a $(nA*nB)x$(nA*nB) matrix, got size $(nC)x$(nC)"))
694-
isempty(A) || isempty(B) || fill!(C, zero(A[1,1] * B[1,1]))
694+
zerofilled = false
695+
if !(isempty(A) || isempty(B))
696+
z = A[1,1] * B[1,1]
697+
if haszero(typeof(z))
698+
# in this case, the zero is unique
699+
fill!(C, zero(z))
700+
zerofilled = true
701+
end
702+
end
695703
@inbounds for i = 1:nA, j = 1:nB
696704
idx = (i-1)*nB+j
697705
C[idx, idx] = valA[i] * valB[j]
698706
end
707+
if !zerofilled
708+
for j in 1:nA, i in 1:mA
709+
Δrow, Δcol = (i-1)*mB, (j-1)*nB
710+
for k in 1:nB, l in 1:mB
711+
i == j && k == l && continue
712+
C[Δrow + l, Δcol + k] = A[i,j] * B[l,k]
713+
end
714+
end
715+
end
699716
return C
700717
end
701718

@@ -722,7 +739,15 @@ end
722739
(mC, nC) = size(C)
723740
@boundscheck (mC, nC) == (mA * mB, nA * nB) ||
724741
throw(DimensionMismatch(lazy"expect C to be a $(mA * mB)x$(nA * nB) matrix, got size $(mC)x$(nC)"))
725-
isempty(A) || isempty(B) || fill!(C, zero(A[1,1] * B[1,1]))
742+
zerofilled = false
743+
if !(isempty(A) || isempty(B))
744+
z = A[1,1] * B[1,1]
745+
if haszero(typeof(z))
746+
# in this case, the zero is unique
747+
fill!(C, zero(z))
748+
zerofilled = true
749+
end
750+
end
726751
m = 1
727752
@inbounds for j = 1:nA
728753
A_jj = A[j,j]
@@ -733,6 +758,18 @@ end
733758
end
734759
m += (nA - 1) * mB
735760
end
761+
if !zerofilled
762+
# populate the zero elements
763+
for i in 1:mA
764+
i == j && continue
765+
A_ij = A[i, j]
766+
Δrow, Δcol = (i-1)*mB, (j-1)*nB
767+
for k in 1:nB, l in 1:nA
768+
B_lk = B[l, k]
769+
C[Δrow + l, Δcol + k] = A_ij * B_lk
770+
end
771+
end
772+
end
736773
m += mB
737774
end
738775
return C
@@ -745,17 +782,36 @@ end
745782
(mC, nC) = size(C)
746783
@boundscheck (mC, nC) == (mA * mB, nA * nB) ||
747784
throw(DimensionMismatch(lazy"expect C to be a $(mA * mB)x$(nA * nB) matrix, got size $(mC)x$(nC)"))
748-
isempty(A) || isempty(B) || fill!(C, zero(A[1,1] * B[1,1]))
785+
zerofilled = false
786+
if !(isempty(A) || isempty(B))
787+
z = A[1,1] * B[1,1]
788+
if haszero(typeof(z))
789+
# in this case, the zero is unique
790+
fill!(C, zero(z))
791+
zerofilled = true
792+
end
793+
end
749794
m = 1
750795
@inbounds for j = 1:nA
751796
for l = 1:mB
752797
Bll = B[l,l]
753-
for k = 1:mA
754-
C[m] = A[k,j] * Bll
798+
for i = 1:mA
799+
C[m] = A[i,j] * Bll
755800
m += nB
756801
end
757802
m += 1
758803
end
804+
if !zerofilled
805+
for i in 1:mA
806+
A_ij = A[i, j]
807+
Δrow, Δcol = (i-1)*mB, (j-1)*nB
808+
for k in 1:nB, l in 1:mB
809+
l == k && continue
810+
B_lk = B[l, k]
811+
C[Δrow + l, Δcol + k] = A_ij * B_lk
812+
end
813+
end
814+
end
759815
m -= nB
760816
end
761817
return C

stdlib/LinearAlgebra/test/diagonal.jl

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1391,4 +1391,14 @@ end
13911391
@test checkbounds(Bool, D, diagind(D, IndexCartesian()))
13921392
end
13931393

1394+
@testset "zeros in kron with block matrices" begin
1395+
D = Diagonal(1:2)
1396+
B = reshape([ones(2,2), ones(3,2), ones(2,3), ones(3,3)], 2, 2)
1397+
@test kron(D, B) == kron(Array(D), B)
1398+
@test kron(B, D) == kron(B, Array(D))
1399+
D2 = Diagonal([ones(2,2), ones(3,3)])
1400+
@test kron(D, D2) == kron(D, Array{eltype(D2)}(D2))
1401+
@test kron(D2, D) == kron(Array{eltype(D2)}(D2), D)
1402+
end
1403+
13941404
end # module TestDiagonal

0 commit comments

Comments
 (0)