@@ -772,6 +772,135 @@ func.func @warpgroup_mma_128_128_64(
772772 return
773773}
774774
// Lowering test for nvgpu.warpgroup.mma.store: two 64x128xf32 fragmented
// accumulators are written into one 128x128xf32 memref (address space 3).
// The expected IR computes each thread's base row/column from tid.x
// (urem/udiv by the warp size 32, then by 4, scaled by 2 and 16), then for
// every pair of f32 fragments {d2k, d2k+1} extracts two struct fields and
// emits two memref.store ops at adjacent columns. The second accumulator
// repeats the index computation with an extra row offset of 64.
// NOTE(review): only the first four and the first pair of the second
// accumulator's stores are spelled out; the middle repetitions are elided
// (see the "Pattern continues" comments below).
775+ // CHECK-LABEL: @warpgroup_mma_store(
776+ // CHECK-SAME: %[[arg0:[a-zA-Z0-9_]+]]: !nvgpu.warpgroup.accumulator<fragmented = vector<64x128xf32>>, %[[arg1:[a-zA-Z0-9_]+]]: !nvgpu.warpgroup.accumulator<fragmented = vector<64x128xf32>>, %[[arg2:[a-zA-Z0-9_]+]]: memref<128x128xf32, 3>)
777+ func.func @warpgroup_mma_store (
778+ %result1 : !nvgpu.warpgroup.accumulator < fragmented = vector <64 x128 xf32 >>,
779+ %result2 : !nvgpu.warpgroup.accumulator < fragmented = vector <64 x128 xf32 >>,
780+ %matrixD: memref <128 x128 xf32 ,3 >) {
781+ // CHECK: %[[S0:.+]] = builtin.unrealized_conversion_cast %[[arg0]] : !nvgpu.warpgroup.accumulator<fragmented = vector<64x128xf32>> to !llvm.struct<(f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32)>
782+ // CHECK: %[[S1:.+]] = builtin.unrealized_conversion_cast %[[arg1]] : !nvgpu.warpgroup.accumulator<fragmented = vector<64x128xf32>> to !llvm.struct<(f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32)>
783+ // CHECK: %[[S6:.+]] = llvm.mlir.constant(1 : i32) : i32
784+ // CHECK: %[[S5:.+]] = llvm.mlir.constant(2 : i32) : i32
785+ // CHECK: %[[S2:.+]] = llvm.mlir.constant(4 : i32) : i32
786+ // CHECK: %[[S4:.+]] = llvm.mlir.constant(8 : i32) : i32
787+ // CHECK: %[[S7:.+]] = llvm.mlir.constant(16 : i32) : i32
788+ // CHECK: %[[WarpSize:.+]] = llvm.mlir.constant(32 : i32) : i32
789+
790+ // ### Store {d0, d1} of each thread ###
791+
792+ // CHECK: %[[S8:.+]] = nvvm.read.ptx.sreg.tid.x : i32
793+ // CHECK: %[[S9:.+]] = llvm.urem %[[S8]], %[[WarpSize]] : i32
794+ // CHECK: %[[S10:.+]] = llvm.udiv %[[S8]], %[[WarpSize]] : i32
795+ // CHECK: %[[S11:.+]] = llvm.udiv %[[S9]], %[[S2]] : i32
796+ // CHECK: %[[S12:.+]] = llvm.urem %[[S9]], %[[S2]] : i32
797+ // CHECK: %[[S13:.+]] = llvm.mul %[[S12]], %[[S5]] : i32
798+ // CHECK: %[[S14:.+]] = llvm.mul %[[S10]], %[[S7]] : i32
799+ // CHECK: %[[S15:.+]] = llvm.add %[[S11]], %[[S14]] : i32
800+ // CHECK: %[[S16:.+]] = llvm.mlir.constant(0 : i32) : i32
801+ // CHECK: %[[S17:.+]] = llvm.mul %[[S16]], %[[S4]] : i32
802+ // CHECK: %[[S18:.+]] = llvm.add %[[S15]], %[[S17]] : i32
803+ // CHECK: %[[S19:.+]] = llvm.mlir.constant(0 : i32) : i32
804+ // CHECK: %[[S20:.+]] = llvm.mul %[[S19]], %[[S4]] : i32
805+ // CHECK: %[[S21:.+]] = llvm.add %[[S13]], %[[S20]] : i32
806+ // CHECK: %[[S22:.+]] = arith.index_cast %[[S18]] : i32 to index
807+ // CHECK: %[[S23:.+]] = arith.index_cast %[[S21]] : i32 to index
808+ // CHECK: %[[S24:.+]] = llvm.add %[[S21]], %[[S6]] : i32
809+ // CHECK: %[[S25:.+]] = arith.index_cast %[[S24]] : i32 to index
810+ // CHECK: %[[S26:.+]] = llvm.extractvalue %[[S0]][0] : !llvm.struct
811+ // CHECK: %[[S27:.+]] = llvm.extractvalue %[[S0]][1] : !llvm.struct
812+ // CHECK: memref.store %[[S26]], %[[arg2]][%[[S22]], %[[S23]]] : memref<128x128xf32, 3>
813+ // CHECK: memref.store %[[S27]], %[[arg2]][%[[S22]], %[[S25]]] : memref<128x128xf32, 3>
814+
815+ // ### Store {d2, d3} of each thread ###
816+
817+ // CHECK: %[[S28:.+]] = llvm.mlir.constant(1 : i32) : i32
818+ // CHECK: %[[S29:.+]] = llvm.mul %[[S28]], %[[S4]] : i32
819+ // CHECK: %[[S30:.+]] = llvm.add %[[S13]], %[[S29]] : i32
820+ // CHECK: %[[S31:.+]] = arith.index_cast %[[S18]] : i32 to index
821+ // CHECK: %[[S32:.+]] = arith.index_cast %[[S30]] : i32 to index
822+ // CHECK: %[[S33:.+]] = llvm.add %[[S30]], %[[S6]] : i32
823+ // CHECK: %[[S34:.+]] = arith.index_cast %[[S33]] : i32 to index
824+ // CHECK: %[[S35:.+]] = llvm.extractvalue %[[S0]][4] : !llvm.struct<
825+ // CHECK: %[[S36:.+]] = llvm.extractvalue %[[S0]][5] : !llvm.struct<
826+ // CHECK: memref.store %[[S35]], %[[arg2]][%[[S31]], %[[S32]]] : memref<128x128xf32, 3>
827+ // CHECK: memref.store %[[S36]], %[[arg2]][%[[S31]], %[[S34]]] : memref<128x128xf32, 3>
828+
829+ // ### Store {d4, d5} of each thread ###
830+
831+ // CHECK: %[[S37:.+]] = llvm.mlir.constant(2 : i32) : i32
832+ // CHECK: %[[S38:.+]] = llvm.mul %[[S37]], %[[S4]] : i32
833+ // CHECK: %[[S39:.+]] = llvm.add %[[S13]], %[[S38]] : i32
834+ // CHECK: %[[S40:.+]] = arith.index_cast %[[S18]] : i32 to index
835+ // CHECK: %[[S41:.+]] = arith.index_cast %[[S39]] : i32 to index
836+ // CHECK: %[[S42:.+]] = llvm.add %[[S39]], %[[S6]] : i32
837+ // CHECK: %[[S43:.+]] = arith.index_cast %[[S42]] : i32 to index
838+ // CHECK: %[[S44:.+]] = llvm.extractvalue %[[S0]][8] : !llvm.struct<
839+ // CHECK: %[[S45:.+]] = llvm.extractvalue %[[S0]][9] : !llvm.struct<
840+ // CHECK: memref.store %[[S44]], %[[arg2]][%[[S40]], %[[S41]]] : memref<128x128xf32, 3>
841+ // CHECK: memref.store %[[S45]], %[[arg2]][%[[S40]], %[[S43]]] : memref<128x128xf32, 3>
842+
843+ // ### Store {d6, d7} of each thread ###
844+
845+ // CHECK: %[[S46:.+]] = llvm.mlir.constant(3 : i32) : i32
846+ // CHECK: %[[S47:.+]] = llvm.mul %[[S46]], %[[S4]] : i32
847+ // CHECK: %[[S48:.+]] = llvm.add %[[S13]], %[[S47]] : i32
848+ // CHECK: %[[S49:.+]] = arith.index_cast %[[S18]] : i32 to index
849+ // CHECK: %[[S50:.+]] = arith.index_cast %[[S48]] : i32 to index
850+ // CHECK: %[[S51:.+]] = llvm.add %[[S48]], %[[S6]] : i32
851+ // CHECK: %[[S52:.+]] = arith.index_cast %[[S51]] : i32 to index
852+ // CHECK: %[[S53:.+]] = llvm.extractvalue %[[S0]][12] : !llvm.struct<
853+ // CHECK: %[[S54:.+]] = llvm.extractvalue %[[S0]][13] : !llvm.struct<
854+ // CHECK: memref.store %[[S53]], %[[arg2]][%[[S49]], %[[S50]]] : memref<128x128xf32, 3>
855+ // CHECK: memref.store %[[S54]], %[[arg2]][%[[S49]], %[[S52]]] : memref<128x128xf32, 3>
856+
857+ // Pattern continues similarly 28x times until {... d62, d63}
858+
859+ // CHECK: %[[c1:.+]] = llvm.mlir.constant(1 : i32) : i32
860+ // CHECK: %[[c2:.+]] = llvm.mlir.constant(2 : i32) : i32
861+
862+ // ### Store {d64, d65} of each thread ###
863+
// The second accumulator re-derives the thread's indices from tid.x and
// adds a row offset of 64 (constant %[[S325]]) before storing, since its
// fragments target the lower half of the 128x128 output.
864+ // CHECK: %[[S315:.+]] = llvm.mlir.constant(1 : i32) : i32
865+ // CHECK: %[[S312:.+]] = llvm.mlir.constant(2 : i32) : i32
866+ // CHECK: %[[S311:.+]] = llvm.mlir.constant(4 : i32) : i32
867+ // CHECK: %[[S313:.+]] = llvm.mlir.constant(8 : i32) : i32
868+ // CHECK: %[[S316:.+]] = llvm.mlir.constant(16 : i32) : i32
869+ // CHECK: %[[WS2:.+]] = llvm.mlir.constant(32 : i32) : i32
870+ // CHECK: %[[S317:.+]] = nvvm.read.ptx.sreg.tid.x : i32
871+ // CHECK: %[[S318:.+]] = llvm.urem %[[S317]], %[[WS2]] : i32
872+ // CHECK: %[[S319:.+]] = llvm.udiv %[[S317]], %[[WS2]] : i32
873+ // CHECK: %[[S320:.+]] = llvm.udiv %[[S318]], %[[S311]] : i32
874+ // CHECK: %[[S321:.+]] = llvm.urem %[[S318]], %[[S311]] : i32
875+ // CHECK: %[[S322:.+]] = llvm.mul %[[S321]], %[[S312]] : i32
876+ // CHECK: %[[S323:.+]] = llvm.mul %[[S319]], %[[S316]] : i32
877+ // CHECK: %[[S324:.+]] = llvm.add %[[S320]], %[[S323]] : i32
878+ // CHECK: %[[S325:.+]] = llvm.mlir.constant(64 : i32) : i32
879+ // CHECK: %[[S326:.+]] = llvm.add %[[S324]], %[[S325]] : i32
880+ // CHECK: %[[S327:.+]] = llvm.mlir.constant(0 : i32) : i32
881+ // CHECK: %[[S328:.+]] = llvm.mul %[[S327]], %[[S313]] : i32
882+ // CHECK: %[[S329:.+]] = llvm.add %[[S326]], %[[S328]] : i32
883+ // CHECK: %[[S330:.+]] = llvm.mlir.constant(0 : i32) : i32
884+ // CHECK: %[[S331:.+]] = llvm.mul %[[S330]], %[[S313]] : i32
885+ // CHECK: %[[S332:.+]] = llvm.add %[[S322]], %[[S331]] : i32
886+ // CHECK: %[[S333:.+]] = arith.index_cast %[[S329]] : i32 to index
887+ // CHECK: %[[S334:.+]] = arith.index_cast %[[S332]] : i32 to index
888+ // CHECK: %[[S335:.+]] = llvm.add %[[S332]], %[[S315]] : i32
889+ // CHECK: %[[S336:.+]] = arith.index_cast %[[S335]] : i32 to index
890+ // CHECK: %[[S337:.+]] = llvm.extractvalue %[[S1]][0]
891+ // CHECK: %[[S338:.+]] = llvm.extractvalue %[[S1]][1]
892+ // CHECK: memref.store %[[S337]], %[[arg2]][%[[S333]], %[[S334]]] : memref<128x128xf32, 3>
893+ // CHECK: memref.store %[[S338]], %[[arg2]][%[[S333]], %[[S336]]] : memref<128x128xf32, 3>
894+
895+ // Pattern continues similarly 31x times until {... d126, d127}
896+
897+ nvgpu.warpgroup.mma.store [%result1 , %result2 ], %matrixD :
898+ !nvgpu.warpgroup.accumulator < fragmented = vector <64 x128 xf32 >>,
899+ !nvgpu.warpgroup.accumulator < fragmented = vector <64 x128 xf32 >>
900+ to memref <128 x128 xf32 ,3 >
901+ return
902+ }
903+
775904transform.sequence failures (propagate ) {
776905^bb1 (%arg1: !transform.any_op ):
777906 %0 = transform.structured.match ops {[" func.func" ]} in %arg1
0 commit comments