From c22f21453057b5937e8371edf28c1f24b0a36267 Mon Sep 17 00:00:00 2001 From: Jishnu Bhattacharya Date: Sun, 18 Aug 2024 16:53:56 +0530 Subject: [PATCH] Fix tr for Symmetric/Hermitian block matrices --- stdlib/LinearAlgebra/src/symmetric.jl | 4 ++-- stdlib/LinearAlgebra/test/symmetric.jl | 11 +++++++++++ 2 files changed, 13 insertions(+), 2 deletions(-) diff --git a/stdlib/LinearAlgebra/src/symmetric.jl b/stdlib/LinearAlgebra/src/symmetric.jl index 55630595f6fb2..c336785792588 100644 --- a/stdlib/LinearAlgebra/src/symmetric.jl +++ b/stdlib/LinearAlgebra/src/symmetric.jl @@ -449,8 +449,8 @@ Base.copy(A::Adjoint{<:Any,<:Symmetric}) = Base.copy(A::Transpose{<:Any,<:Hermitian}) = Hermitian(copy(transpose(A.parent.data)), ifelse(A.parent.uplo == 'U', :L, :U)) -tr(A::Symmetric) = tr(A.data) # to avoid AbstractMatrix fallback (incl. allocations) -tr(A::Hermitian) = real(tr(A.data)) +tr(A::Symmetric{<:Number}) = tr(A.data) # to avoid AbstractMatrix fallback (incl. allocations) +tr(A::Hermitian{<:Number}) = real(tr(A.data)) Base.conj(A::Symmetric) = Symmetric(parentof_applytri(conj, A), sym_uplo(A.uplo)) Base.conj(A::Hermitian) = Hermitian(parentof_applytri(conj, A), sym_uplo(A.uplo)) diff --git a/stdlib/LinearAlgebra/test/symmetric.jl b/stdlib/LinearAlgebra/test/symmetric.jl index 89e9ca0d6a51d..5f1293ab2cdd7 100644 --- a/stdlib/LinearAlgebra/test/symmetric.jl +++ b/stdlib/LinearAlgebra/test/symmetric.jl @@ -1116,4 +1116,15 @@ end end end +@testset "tr for block matrices" begin + m = [1 2; 3 4] + for b in (m, m * (1 + im)) + M = fill(b, 3, 3) + for ST in (Symmetric, Hermitian) + S = ST(M) + @test tr(S) == sum(diag(S)) + end + end +end + end # module TestSymmetric