@@ -293,15 +293,24 @@ true
293293@inline mul! (C:: AbstractMatrix , A:: AbstractVecOrMat , B:: AbstractVecOrMat , α:: Number , β:: Number ) = _mul! (C, A, B, α, β)
294294# Add a level of indirection and specialize _mul! to avoid ambiguities in mul!
295295@inline _mul! (C:: AbstractMatrix , A:: AbstractVecOrMat , B:: AbstractVecOrMat , α:: Number , β:: Number ) =
296- generic_matmatmul ! (
296+ generic_matmatmul_wrapper ! (
297297 C,
298298 wrapper_char (A),
299299 wrapper_char (B),
300300 _unwrap (A),
301301 _unwrap (B),
302- α, β
302+ α, β,
303+ Val (wrapper_char_NTC (A) & wrapper_char_NTC (B))
303304 )
304305
306+ # this indirection allows is to specialize on the types of the wrappers of A and B to some extent,
307+ # even though the wrappers are stripped off in mul!
308+ # By default, we ignore the wrapper info and forward the arguments to generic_matmatmul!
309+ Base. @constprop :aggressive function generic_matmatmul_wrapper! (C, tA, tB, A, B, α, β, @nospecialize (val))
310+ generic_matmatmul! (C, tA, tB, A, B, α, β)
311+ end
312+
313+
305314"""
306315 rmul!(A, B)
307316
@@ -368,9 +377,9 @@ julia> lmul!(F.Q, B)
368377"""
369378lmul! (A, B)
370379
371- # THE one big BLAS dispatch
372- Base. @constprop :aggressive function generic_matmatmul ! (C:: StridedMatrix{T} , tA, tB, A:: StridedVecOrMat{T} , B:: StridedVecOrMat{T} ,
373- α:: Number , β:: Number ) where {T<: BlasFloat }
380+ # THE one big BLAS dispatch. This is split into two methods to improve latency
381+ Base. @constprop :aggressive function generic_matmatmul_wrapper ! (C:: StridedMatrix{T} , tA, tB, A:: StridedVecOrMat{T} , B:: StridedVecOrMat{T} ,
382+ α:: Number , β:: Number , :: Val{true} ) where {T<: BlasFloat }
374383 mA, nA = lapack_size (tA, A)
375384 mB, nB = lapack_size (tB, B)
376385 if any (iszero, size (A)) || any (iszero, size (B)) || iszero (α)
@@ -389,19 +398,37 @@ Base.@constprop :aggressive function generic_matmatmul!(C::StridedMatrix{T}, tA,
389398 # and extract the char corresponding to the wrapper type
390399 tA_uc, tB_uc = uppercase (tA), uppercase (tB)
391400 # the map in all ensures constprop by acting on tA and tB individually, instead of looping over them.
392- if all (map (in ((' N' , ' T' , ' C' )), (tA_uc, tB_uc)))
393- if tA_uc == ' T' && tB_uc == ' N' && A === B
394- return syrk_wrapper! (C, ' T' , A, α, β)
395- elseif tA_uc == ' N' && tB_uc == ' T' && A === B
396- return syrk_wrapper! (C, ' N' , A, α, β)
397- elseif tA_uc == ' C' && tB_uc == ' N' && A === B
398- return herk_wrapper! (C, ' C' , A, α, β)
399- elseif tA_uc == ' N' && tB_uc == ' C' && A === B
400- return herk_wrapper! (C, ' N' , A, α, β)
401- else
402- return gemm_wrapper! (C, tA, tB, A, B, α, β)
401+ if tA_uc == ' T' && tB_uc == ' N' && A === B
402+ return syrk_wrapper! (C, ' T' , A, α, β)
403+ elseif tA_uc == ' N' && tB_uc == ' T' && A === B
404+ return syrk_wrapper! (C, ' N' , A, α, β)
405+ elseif tA_uc == ' C' && tB_uc == ' N' && A === B
406+ return herk_wrapper! (C, ' C' , A, α, β)
407+ elseif tA_uc == ' N' && tB_uc == ' C' && A === B
408+ return herk_wrapper! (C, ' N' , A, α, β)
409+ else
410+ return gemm_wrapper! (C, tA, tB, A, B, α, β)
411+ end
412+ end
413+ Base. @constprop :aggressive function generic_matmatmul_wrapper! (C:: StridedMatrix{T} , tA, tB, A:: StridedVecOrMat{T} , B:: StridedVecOrMat{T} ,
414+ α:: Number , β:: Number , :: Val{false} ) where {T<: BlasFloat }
415+ mA, nA = lapack_size (tA, A)
416+ mB, nB = lapack_size (tB, B)
417+ if any (iszero, size (A)) || any (iszero, size (B)) || iszero (α)
418+ if size (C) != (mA, nB)
419+ throw (DimensionMismatch (lazy " C has dimensions $(size(C)), should have ($mA,$nB)" ))
403420 end
421+ return _rmul_or_fill! (C, β)
422+ end
423+ if size (C) == size (A) == size (B) == (2 ,2 )
424+ return matmul2x2! (C, tA, tB, A, B, α, β)
425+ end
426+ if size (C) == size (A) == size (B) == (3 ,3 )
427+ return matmul3x3! (C, tA, tB, A, B, α, β)
404428 end
429+ # We convert the chars to uppercase to potentially unwrap a WrapperChar,
430+ # and extract the char corresponding to the wrapper type
431+ tA_uc, tB_uc = uppercase (tA), uppercase (tB)
405432 alpha, beta = promote (α, β, zero (T))
406433 if alpha isa Union{Bool,T} && beta isa Union{Bool,T}
407434 if tA_uc == ' S' && tB_uc == ' N'
@@ -421,18 +448,13 @@ Base.@constprop :aggressive generic_matmatmul!(C::StridedMatrix{T}, tA, tB, A::S
421448 _add:: MulAddMul = MulAddMul ()) where {T<: BlasFloat } =
422449 generic_matmatmul! (C, tA, tB, A, B, _add. alpha, _add. beta)
423450
424- # Complex matrix times (transposed) real matrix. Reinterpret the first matrix to real for efficiency.
425- Base. @constprop :aggressive function generic_matmatmul! (C:: StridedVecOrMat{Complex{T}} , tA, tB, A:: StridedVecOrMat{Complex{T}} , B:: StridedVecOrMat{T} ,
426- α:: Number , β:: Number ) where {T<: BlasReal }
427- # We convert the chars to uppercase to potentially unwrap a WrapperChar,
428- # and extract the char corresponding to the wrapper type
429- tA_uc, tB_uc = uppercase (tA), uppercase (tB)
430- # the map in all ensures constprop by acting on tA and tB individually, instead of looping over them.
431- if all (map (in ((' N' , ' T' , ' C' )), (tA_uc, tB_uc)))
432- gemm_wrapper! (C, tA, tB, A, B, α, β)
433- else
434- _generic_matmatmul! (C, wrap (A, tA), wrap (B, tB), MulAddMul (α, β))
435- end
451+ function generic_matmatmul_wrapper! (C:: StridedVecOrMat{Complex{T}} , tA, tB, A:: StridedVecOrMat{Complex{T}} , B:: StridedVecOrMat{T} ,
452+ α:: Number , β:: Number , :: Val{true} ) where {T<: BlasReal }
453+ gemm_wrapper! (C, tA, tB, A, B, α, β)
454+ end
455+ Base. @constprop :aggressive function generic_matmatmul_wrapper! (C:: StridedVecOrMat{Complex{T}} , tA, tB, A:: StridedVecOrMat{Complex{T}} , B:: StridedVecOrMat{T} ,
456+ α:: Number , β:: Number , :: Val{false} ) where {T<: BlasReal }
457+ _generic_matmatmul! (C, wrap (A, tA), wrap (B, tB), MulAddMul (α, β))
436458end
437459# legacy method
438460Base. @constprop :aggressive generic_matmatmul! (C:: StridedVecOrMat{Complex{T}} , tA, tB, A:: StridedVecOrMat{Complex{T}} , B:: StridedVecOrMat{T} ,
0 commit comments