Skip to content

Commit 243ebc3

Browse files
authored
Support broadcasting over structured block matrices (#53909)
Fix https://github.com/JuliaLang/julia/issues/48664 After this, broadcasting over structured block matrices with matrix-valued elements works: ```julia julia> D = Diagonal([[1 2; 3 4], [5 6; 7 8]]) 2×2 Diagonal{Matrix{Int64}, Vector{Matrix{Int64}}}: [1 2; 3 4] ⋅ ⋅ [5 6; 7 8] julia> D .+ D 2×2 Diagonal{Matrix{Int64}, Vector{Matrix{Int64}}}: [2 4; 6 8] ⋅ ⋅ [10 12; 14 16] julia> cos.(D) 2×2 Matrix{Matrix{Float64}}: [0.855423 -0.110876; -0.166315 0.689109] [1.0 0.0; 0.0 1.0] [1.0 0.0; 0.0 1.0] [0.928384 -0.069963; -0.0816235 0.893403] ``` Such operations show up when using `BlockArrays`. The implementation is a bit hacky as it uses `0I` as the zero element in `fzero`, which isn't really the correct zero if the blocks are rectangular. Nonetheless, this works, as `fzero` is only used to determine if the structure is preserved.
1 parent 1febcd6 commit 243ebc3

File tree

3 files changed

+63
-2
lines changed

3 files changed

+63
-2
lines changed

stdlib/LinearAlgebra/src/structuredbroadcast.jl

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -8,8 +8,8 @@ struct StructuredMatrixStyle{T} <: Broadcast.AbstractArrayStyle{2} end
88
StructuredMatrixStyle{T}(::Val{2}) where {T} = StructuredMatrixStyle{T}()
99
StructuredMatrixStyle{T}(::Val{N}) where {T,N} = Broadcast.DefaultArrayStyle{N}()
1010

11-
const StructuredMatrix = Union{Diagonal,Bidiagonal,SymTridiagonal,Tridiagonal,LowerTriangular,UnitLowerTriangular,UpperTriangular,UnitUpperTriangular}
12-
for ST in Base.uniontypes(StructuredMatrix)
11+
const StructuredMatrix{T} = Union{Diagonal{T},Bidiagonal{T},SymTridiagonal{T},Tridiagonal{T},LowerTriangular{T},UnitLowerTriangular{T},UpperTriangular{T},UnitUpperTriangular{T}}
12+
for ST in (Diagonal,Bidiagonal,SymTridiagonal,Tridiagonal,LowerTriangular,UnitLowerTriangular,UpperTriangular,UnitUpperTriangular)
1313
@eval Broadcast.BroadcastStyle(::Type{<:$ST}) = $(StructuredMatrixStyle{ST}())
1414
end
1515

@@ -133,6 +133,7 @@ fails as `zero(::Tuple{Int})` is not defined. However,
133133
iszerodefined(::Type) = false
134134
iszerodefined(::Type{<:Number}) = true
135135
iszerodefined(::Type{<:AbstractArray{T}}) where T = iszerodefined(T)
136+
iszerodefined(::Type{<:UniformScaling{T}}) where T = iszerodefined(T)
136137

137138
fzeropreserving(bc) = (v = fzero(bc); !ismissing(v) && (iszerodefined(typeof(v)) ? iszero(v) : v == 0))
138139
# Like sparse matrices, we assume that the zero-preservation property of a broadcasted
@@ -144,6 +145,7 @@ fzero(::Type{T}) where T = T
144145
fzero(r::Ref) = r[]
145146
fzero(t::Tuple{Any}) = t[1]
146147
fzero(S::StructuredMatrix) = zero(eltype(S))
148+
fzero(::StructuredMatrix{<:AbstractMatrix{T}}) where {T<:Number} = haszero(T) ? zero(T)*I : missing
147149
fzero(x) = missing
148150
function fzero(bc::Broadcast.Broadcasted)
149151
args = map(fzero, bc.args)

stdlib/LinearAlgebra/test/special.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -111,6 +111,7 @@ Random.seed!(1)
111111
struct TypeWithZero end
112112
Base.promote_rule(::Type{TypeWithoutZero}, ::Type{TypeWithZero}) = TypeWithZero
113113
Base.convert(::Type{TypeWithZero}, ::TypeWithoutZero) = TypeWithZero()
114+
Base.zero(x::Union{TypeWithoutZero, TypeWithZero}) = zero(typeof(x))
114115
Base.zero(::Type{<:Union{TypeWithoutZero, TypeWithZero}}) = TypeWithZero()
115116
LinearAlgebra.symmetric(::TypeWithoutZero, ::Symbol) = TypeWithoutZero()
116117
LinearAlgebra.symmetric_type(::Type{TypeWithoutZero}) = TypeWithoutZero

stdlib/LinearAlgebra/test/structuredbroadcast.jl

Lines changed: 58 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -280,4 +280,62 @@ end
280280
# structured broadcast with function returning non-number type
281281
@test tuple.(Diagonal([1, 2])) == [(1,) (0,); (0,) (2,)]
282282

283+
@testset "broadcast over structured matrices with matrix elements" begin
284+
function standardbroadcastingtests(D, T)
285+
M = [x for x in D]
286+
Dsum = D .+ D
287+
@test Dsum isa T
288+
@test Dsum == M .+ M
289+
Dcopy = copy.(D)
290+
@test Dcopy isa T
291+
@test Dcopy == D
292+
Df = float.(D)
293+
@test Df isa T
294+
@test Df == D
295+
@test eltype(eltype(Df)) <: AbstractFloat
296+
@test (x -> (x,)).(D) == (x -> (x,)).(M)
297+
@test (x -> 1).(D) == ones(Int,size(D))
298+
@test all(==(2), ndims.(D))
299+
@test_throws MethodError size.(D)
300+
end
301+
@testset "Diagonal" begin
302+
@testset "square" begin
303+
A = [1 3; 2 4]
304+
D = Diagonal([A, A])
305+
standardbroadcastingtests(D, Diagonal)
306+
@test sincos.(D) == sincos.(Matrix{eltype(D)}(D))
307+
M = [x for x in D]
308+
@test cos.(D) == cos.(M)
309+
end
310+
311+
@testset "different-sized square blocks" begin
312+
D = Diagonal([ones(3,3), fill(3.0,2,2)])
313+
standardbroadcastingtests(D, Diagonal)
314+
end
315+
316+
@testset "rectangular blocks" begin
317+
D = Diagonal([ones(Bool,3,4), ones(Bool,2,3)])
318+
standardbroadcastingtests(D, Diagonal)
319+
end
320+
321+
@testset "incompatible sizes" begin
322+
A = reshape(1:12, 4, 3)
323+
B = reshape(1:12, 3, 4)
324+
D1 = Diagonal(fill(A, 2))
325+
D2 = Diagonal(fill(B, 2))
326+
@test_throws DimensionMismatch D1 .+ D2
327+
end
328+
end
329+
@testset "Bidiagonal" begin
330+
A = [1 3; 2 4]
331+
B = Bidiagonal(fill(A,3), fill(A,2), :U)
332+
standardbroadcastingtests(B, Bidiagonal)
333+
end
334+
@testset "UpperTriangular" begin
335+
A = [1 3; 2 4]
336+
U = UpperTriangular([(i+j)*A for i in 1:3, j in 1:3])
337+
standardbroadcastingtests(U, UpperTriangular)
338+
end
339+
end
340+
283341
end

0 commit comments

Comments
 (0)