Skip to content

Commit 6928846

Browse files
authored
Reland "Support broadcasting over structured block matrices #53909" (#54460)
This was reverted in JuliaLang/julia#54332. This needs JuliaLang/julia#54459 for the tests to pass. Opening this now to not forget about it.
1 parent 0204ec1 commit 6928846

File tree

3 files changed

+63
-2
lines changed

3 files changed

+63
-2
lines changed

src/structuredbroadcast.jl

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -8,8 +8,8 @@ struct StructuredMatrixStyle{T} <: Broadcast.AbstractArrayStyle{2} end
88
StructuredMatrixStyle{T}(::Val{2}) where {T} = StructuredMatrixStyle{T}()
99
StructuredMatrixStyle{T}(::Val{N}) where {T,N} = Broadcast.DefaultArrayStyle{N}()
1010

11-
const StructuredMatrix = Union{Diagonal,Bidiagonal,SymTridiagonal,Tridiagonal,LowerTriangular,UnitLowerTriangular,UpperTriangular,UnitUpperTriangular}
12-
for ST in Base.uniontypes(StructuredMatrix)
11+
const StructuredMatrix{T} = Union{Diagonal{T},Bidiagonal{T},SymTridiagonal{T},Tridiagonal{T},LowerTriangular{T},UnitLowerTriangular{T},UpperTriangular{T},UnitUpperTriangular{T}}
12+
for ST in (Diagonal,Bidiagonal,SymTridiagonal,Tridiagonal,LowerTriangular,UnitLowerTriangular,UpperTriangular,UnitUpperTriangular)
1313
@eval Broadcast.BroadcastStyle(::Type{<:$ST}) = $(StructuredMatrixStyle{ST}())
1414
end
1515

@@ -133,6 +133,7 @@ fails as `zero(::Tuple{Int})` is not defined. However,
133133
iszerodefined(::Type) = false
134134
iszerodefined(::Type{<:Number}) = true
135135
iszerodefined(::Type{<:AbstractArray{T}}) where T = iszerodefined(T)
136+
iszerodefined(::Type{<:UniformScaling{T}}) where T = iszerodefined(T)
136137

137138
count_structedmatrix(T, bc::Broadcasted) = sum(Base.Fix2(isa, T), Broadcast.cat_nested(bc); init = 0)
138139

@@ -160,6 +161,7 @@ fzero(::Type{T}) where T = Some(T)
160161
fzero(r::Ref) = Some(r[])
161162
fzero(t::Tuple{Any}) = Some(only(t))
162163
fzero(S::StructuredMatrix) = Some(zero(eltype(S)))
164+
fzero(::StructuredMatrix{<:AbstractMatrix{T}}) where {T<:Number} = Some(haszero(T) ? zero(T)*I : nothing)
163165
fzero(x) = nothing
164166
function fzero(bc::Broadcast.Broadcasted)
165167
args = map(fzero, bc.args)

test/special.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -111,6 +111,7 @@ Random.seed!(1)
111111
struct TypeWithZero end
112112
Base.promote_rule(::Type{TypeWithoutZero}, ::Type{TypeWithZero}) = TypeWithZero
113113
Base.convert(::Type{TypeWithZero}, ::TypeWithoutZero) = TypeWithZero()
114+
Base.zero(x::Union{TypeWithoutZero, TypeWithZero}) = zero(typeof(x))
114115
Base.zero(::Type{<:Union{TypeWithoutZero, TypeWithZero}}) = TypeWithZero()
115116
LinearAlgebra.symmetric(::TypeWithoutZero, ::Symbol) = TypeWithoutZero()
116117
LinearAlgebra.symmetric_type(::Type{TypeWithoutZero}) = TypeWithoutZero

test/structuredbroadcast.jl

Lines changed: 58 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -307,4 +307,62 @@ end
307307
@test select_first.(missing, diag) isa Matrix{Missing}
308308
end
309309

310+
@testset "broadcast over structured matrices with matrix elements" begin
311+
function standardbroadcastingtests(D, T)
312+
M = [x for x in D]
313+
Dsum = D .+ D
314+
@test Dsum isa T
315+
@test Dsum == M .+ M
316+
Dcopy = copy.(D)
317+
@test Dcopy isa T
318+
@test Dcopy == D
319+
Df = float.(D)
320+
@test Df isa T
321+
@test Df == D
322+
@test eltype(eltype(Df)) <: AbstractFloat
323+
@test (x -> (x,)).(D) == (x -> (x,)).(M)
324+
@test (x -> 1).(D) == ones(Int,size(D))
325+
@test all(==(2), ndims.(D))
326+
@test_throws MethodError size.(D)
327+
end
328+
@testset "Diagonal" begin
329+
@testset "square" begin
330+
A = [1 3; 2 4]
331+
D = Diagonal([A, A])
332+
standardbroadcastingtests(D, Diagonal)
333+
@test sincos.(D) == sincos.(Matrix{eltype(D)}(D))
334+
M = [x for x in D]
335+
@test cos.(D) == cos.(M)
336+
end
337+
338+
@testset "different-sized square blocks" begin
339+
D = Diagonal([ones(3,3), fill(3.0,2,2)])
340+
standardbroadcastingtests(D, Diagonal)
341+
end
342+
343+
@testset "rectangular blocks" begin
344+
D = Diagonal([ones(Bool,3,4), ones(Bool,2,3)])
345+
standardbroadcastingtests(D, Diagonal)
346+
end
347+
348+
@testset "incompatible sizes" begin
349+
A = reshape(1:12, 4, 3)
350+
B = reshape(1:12, 3, 4)
351+
D1 = Diagonal(fill(A, 2))
352+
D2 = Diagonal(fill(B, 2))
353+
@test_throws DimensionMismatch D1 .+ D2
354+
end
355+
end
356+
@testset "Bidiagonal" begin
357+
A = [1 3; 2 4]
358+
B = Bidiagonal(fill(A,3), fill(A,2), :U)
359+
standardbroadcastingtests(B, Bidiagonal)
360+
end
361+
@testset "UpperTriangular" begin
362+
A = [1 3; 2 4]
363+
U = UpperTriangular([(i+j)*A for i in 1:3, j in 1:3])
364+
standardbroadcastingtests(U, UpperTriangular)
365+
end
366+
end
367+
310368
end

0 commit comments

Comments
 (0)