Skip to content

Commit 0bcd287

Browse files
jishnubDilumAluthge
authored andcommitted
Split generic_matmul for strided matrices into two halves (#54552)
1 parent ba6d190 commit 0bcd287

File tree

2 files changed

+54
-28
lines changed

2 files changed

+54
-28
lines changed

stdlib/LinearAlgebra/src/LinearAlgebra.jl

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -576,6 +576,10 @@ wrapper_char(A::Hermitian) = WrapperChar('H', A.uplo == 'U')
576576
wrapper_char(A::Hermitian{<:Real}) = WrapperChar('S', A.uplo == 'U')
577577
wrapper_char(A::Symmetric) = WrapperChar('S', A.uplo == 'U')
578578

579+
# Generic fallback: report whether the wrapper character of `A` equals 'N' once uppercased.
wrapper_char_NTC(A::AbstractArray) = uppercase(wrapper_char(A)) == 'N'
580+
# Strided arrays and `Adjoint`/`Transpose` wrappers always answer `true`, without consulting `wrapper_char`.
wrapper_char_NTC(A::Union{StridedArray, Adjoint, Transpose}) = true
581+
# `Symmetric` and `Hermitian` wrappers always answer `false`.
wrapper_char_NTC(A::Union{Symmetric, Hermitian}) = false
582+
579583
Base.@constprop :aggressive function wrap(A::AbstractVecOrMat, tA::AbstractChar)
580584
# merge the result of this before return, so that we can type-assert the return such
581585
# that even if the tmerge is inaccurate, inference can still identify that the

stdlib/LinearAlgebra/src/matmul.jl

Lines changed: 50 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -293,15 +293,24 @@ true
293293
@inline mul!(C::AbstractMatrix, A::AbstractVecOrMat, B::AbstractVecOrMat, α::Number, β::Number) = _mul!(C, A, B, α, β)
294294
# Add a level of indirection and specialize _mul! to avoid ambiguities in mul!
295295
@inline _mul!(C::AbstractMatrix, A::AbstractVecOrMat, B::AbstractVecOrMat, α::Number, β::Number) =
296-
generic_matmatmul!(
296+
generic_matmatmul_wrapper!(
297297
C,
298298
wrapper_char(A),
299299
wrapper_char(B),
300300
_unwrap(A),
301301
_unwrap(B),
302-
α, β
302+
α, β,
303+
Val(wrapper_char_NTC(A) & wrapper_char_NTC(B))
303304
)
304305

306+
# this indirection allows us to specialize on the types of the wrappers of A and B to some extent,
307+
# even though the wrappers are stripped off in mul!
308+
# By default, we ignore the wrapper info and forward the arguments to generic_matmatmul!
309+
Base.@constprop :aggressive function generic_matmatmul_wrapper!(C, tA, tB, A, B, α, β, @nospecialize(val))
310+
generic_matmatmul!(C, tA, tB, A, B, α, β)
311+
end
312+
313+
305314
"""
306315
rmul!(A, B)
307316
@@ -368,9 +377,9 @@ julia> lmul!(F.Q, B)
368377
"""
369378
lmul!(A, B)
370379

371-
# THE one big BLAS dispatch
372-
Base.@constprop :aggressive function generic_matmatmul!(C::StridedMatrix{T}, tA, tB, A::StridedVecOrMat{T}, B::StridedVecOrMat{T},
373-
α::Number, β::Number) where {T<:BlasFloat}
380+
# THE one big BLAS dispatch. This is split into two methods to improve latency
381+
Base.@constprop :aggressive function generic_matmatmul_wrapper!(C::StridedMatrix{T}, tA, tB, A::StridedVecOrMat{T}, B::StridedVecOrMat{T},
382+
α::Number, β::Number, ::Val{true}) where {T<:BlasFloat}
374383
mA, nA = lapack_size(tA, A)
375384
mB, nB = lapack_size(tB, B)
376385
if any(iszero, size(A)) || any(iszero, size(B)) || iszero(α)
@@ -389,19 +398,37 @@ Base.@constprop :aggressive function generic_matmatmul!(C::StridedMatrix{T}, tA,
389398
# and extract the char corresponding to the wrapper type
390399
tA_uc, tB_uc = uppercase(tA), uppercase(tB)
391400
# No membership check on tA/tB is needed here: this method is only selected via Val{true}, which _mul! passes when wrapper_char_NTC holds for both operands.
392-
if all(map(in(('N', 'T', 'C')), (tA_uc, tB_uc)))
393-
if tA_uc == 'T' && tB_uc == 'N' && A === B
394-
return syrk_wrapper!(C, 'T', A, α, β)
395-
elseif tA_uc == 'N' && tB_uc == 'T' && A === B
396-
return syrk_wrapper!(C, 'N', A, α, β)
397-
elseif tA_uc == 'C' && tB_uc == 'N' && A === B
398-
return herk_wrapper!(C, 'C', A, α, β)
399-
elseif tA_uc == 'N' && tB_uc == 'C' && A === B
400-
return herk_wrapper!(C, 'N', A, α, β)
401-
else
402-
return gemm_wrapper!(C, tA, tB, A, B, α, β)
401+
if tA_uc == 'T' && tB_uc == 'N' && A === B
402+
return syrk_wrapper!(C, 'T', A, α, β)
403+
elseif tA_uc == 'N' && tB_uc == 'T' && A === B
404+
return syrk_wrapper!(C, 'N', A, α, β)
405+
elseif tA_uc == 'C' && tB_uc == 'N' && A === B
406+
return herk_wrapper!(C, 'C', A, α, β)
407+
elseif tA_uc == 'N' && tB_uc == 'C' && A === B
408+
return herk_wrapper!(C, 'N', A, α, β)
409+
else
410+
return gemm_wrapper!(C, tA, tB, A, B, α, β)
411+
end
412+
end
413+
Base.@constprop :aggressive function generic_matmatmul_wrapper!(C::StridedMatrix{T}, tA, tB, A::StridedVecOrMat{T}, B::StridedVecOrMat{T},
414+
α::Number, β::Number, ::Val{false}) where {T<:BlasFloat}
415+
mA, nA = lapack_size(tA, A)
416+
mB, nB = lapack_size(tB, B)
417+
if any(iszero, size(A)) || any(iszero, size(B)) || iszero(α)
418+
if size(C) != (mA, nB)
419+
throw(DimensionMismatch(lazy"C has dimensions $(size(C)), should have ($mA,$nB)"))
403420
end
421+
return _rmul_or_fill!(C, β)
422+
end
423+
if size(C) == size(A) == size(B) == (2,2)
424+
return matmul2x2!(C, tA, tB, A, B, α, β)
425+
end
426+
if size(C) == size(A) == size(B) == (3,3)
427+
return matmul3x3!(C, tA, tB, A, B, α, β)
404428
end
429+
# We convert the chars to uppercase to potentially unwrap a WrapperChar,
430+
# and extract the char corresponding to the wrapper type
431+
tA_uc, tB_uc = uppercase(tA), uppercase(tB)
405432
alpha, beta = promote(α, β, zero(T))
406433
if alpha isa Union{Bool,T} && beta isa Union{Bool,T}
407434
if tA_uc == 'S' && tB_uc == 'N'
@@ -421,18 +448,13 @@ Base.@constprop :aggressive generic_matmatmul!(C::StridedMatrix{T}, tA, tB, A::S
421448
_add::MulAddMul = MulAddMul()) where {T<:BlasFloat} =
422449
generic_matmatmul!(C, tA, tB, A, B, _add.alpha, _add.beta)
423450

424-
# Complex matrix times (transposed) real matrix. Reinterpret the first matrix to real for efficiency.
425-
Base.@constprop :aggressive function generic_matmatmul!(C::StridedVecOrMat{Complex{T}}, tA, tB, A::StridedVecOrMat{Complex{T}}, B::StridedVecOrMat{T},
426-
α::Number, β::Number) where {T<:BlasReal}
427-
# We convert the chars to uppercase to potentially unwrap a WrapperChar,
428-
# and extract the char corresponding to the wrapper type
429-
tA_uc, tB_uc = uppercase(tA), uppercase(tB)
430-
# the map in all ensures constprop by acting on tA and tB individually, instead of looping over them.
431-
if all(map(in(('N', 'T', 'C')), (tA_uc, tB_uc)))
432-
gemm_wrapper!(C, tA, tB, A, B, α, β)
433-
else
434-
_generic_matmatmul!(C, wrap(A, tA), wrap(B, tB), MulAddMul(α, β))
435-
end
451+
function generic_matmatmul_wrapper!(C::StridedVecOrMat{Complex{T}}, tA, tB, A::StridedVecOrMat{Complex{T}}, B::StridedVecOrMat{T},
452+
α::Number, β::Number, ::Val{true}) where {T<:BlasReal}
453+
gemm_wrapper!(C, tA, tB, A, B, α, β)
454+
end
455+
Base.@constprop :aggressive function generic_matmatmul_wrapper!(C::StridedVecOrMat{Complex{T}}, tA, tB, A::StridedVecOrMat{Complex{T}}, B::StridedVecOrMat{T},
456+
α::Number, β::Number, ::Val{false}) where {T<:BlasReal}
457+
_generic_matmatmul!(C, wrap(A, tA), wrap(B, tB), MulAddMul(α, β))
436458
end
437459
# legacy method
438460
Base.@constprop :aggressive generic_matmatmul!(C::StridedVecOrMat{Complex{T}}, tA, tB, A::StridedVecOrMat{Complex{T}}, B::StridedVecOrMat{T},

0 commit comments

Comments
 (0)