@@ -39,7 +39,7 @@ namespace mlir {
3939
4040using namespace mlir ;
4141
42- // / Number of bits that needs to excluded when building matrix descriptor for
42+ // / Number of bits that needs to be excluded when building matrix descriptor for
4343// / wgmma operations.
4444constexpr int exclude4LSB = 4 ;
4545
@@ -1160,137 +1160,276 @@ struct NVGPUWarpgroupMmaOpLowering
11601160 : public ConvertOpToLLVMPattern<nvgpu::WarpgroupMmaOp> {
11611161 using ConvertOpToLLVMPattern<nvgpu::WarpgroupMmaOp>::ConvertOpToLLVMPattern;
11621162
1163- LogicalResult getWgmmaShape (int64_t sizeM, int64_t sizeN, Type inputElemType,
1164- int &wgmmaShapeM, int &wgmmaShapeN,
1165- int &wgmmaShapeK) const {
1166- wgmmaShapeM = 64 ;
1167- wgmmaShapeN = sizeN;
1168- if (inputElemType.isTF32 ()) {
1169- wgmmaShapeK = 8 ;
1170- } else if (inputElemType.isF16 () || inputElemType.isBF16 ()) {
1171- wgmmaShapeK = 16 ;
1172- } else if (inputElemType.isFloat8E4M3FN () || inputElemType.isFloat8E5M2 () ||
1173- inputElemType.isInteger (16 )) {
1174- wgmmaShapeK = 32 ;
1175- } else if (inputElemType.isInteger (1 )) {
1176- wgmmaShapeK = 256 ;
1177- } else {
1178- llvm_unreachable (" msg: not supported K shape" );
1163+ // / This is a helper class to generate required NVVM Ops for warp-group level
1164+ // / matrix multiplication.
1165+ // / When the given GEMM shape is larger than the shape of
1166+ // / a wgmma instruction in PTX, it can generate multiple NVVM::WgmmaMmaAsyncOp
1167+ // / Op(s), group and execute them asynchronously. The class also handles
1168+ // / waiting for completion and iterates through WarpgroupMatrixDescriptor to
1169+ // / create descriptors for each instruction.
1170+ // /
1171+ // / For example this is the case when the shape of GEMM is 128x128x128
1172+ // /
1173+ // / nvvm.wgmma.fence.aligned
1174+ // /
1175+ // / nvvm.wgmma.mma.async descA, descB
1176+ // / iterate(descA, descB)
1177+ // / nvvm.wgmma.mma.async descA, descB
1178+ // / [6 more times]
1179+ // /
1180+ // / nvvm.wgmma.group.sync.aligned
1181+ // / nvvm.wgmma.wait.group.sync [groupId]
1182+ // /
1183+ class WarpgroupGemm {
1184+ nvgpu::WarpgroupMmaOp op;
1185+ ImplicitLocOpBuilder b;
1186+ OpAdaptor adaptor;
1187+ const LLVMTypeConverter &typeConverter;
1188+
1189+ // Entire shape of the given Op
1190+ int64_t totalM, totalN, totalK;
1191+
1192+ // Shape of one wgmma instruction
1193+ int wgmmaM = 0 , wgmmaN = 0 , wgmmaK = 0 ;
1194+
1195+ // Iteration counts for GEMM
1196+ int iterationM = 0 , iterationN = 0 , iterationK = 0 ;
1197+
1198+ // / The function returns the shape of wgmma instruction that is defined in
1199+ // / PTX programming guide.
1200+ // / https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#asynchronous-warpgroup-level-matrix-shape
1201+ void findWgmmaShape (int64_t sizeM, int64_t sizeN, Type inputElemType) {
1202+ wgmmaM = 64 ;
1203+ wgmmaN = sizeN;
1204+ if (inputElemType.isTF32 ()) {
1205+ wgmmaK = 8 ;
1206+ } else if (inputElemType.isF16 () || inputElemType.isBF16 ()) {
1207+ wgmmaK = 16 ;
1208+ } else if (inputElemType.isFloat8E4M3FN () ||
1209+ inputElemType.isFloat8E5M2 () || inputElemType.isInteger (16 )) {
1210+ wgmmaK = 32 ;
1211+ } else if (inputElemType.isInteger (1 )) {
1212+ wgmmaK = 256 ;
1213+ } else {
1214+ llvm_unreachable (" msg: not supported K shape" );
1215+ }
1216+ LLVM_DEBUG (DBGS () << " Generating WgmmaMmaAsyncOp shape[m = " << wgmmaM
1217+ << " , n = " << wgmmaN << " , k = " << wgmmaK << " ]\n " );
11791218 }
1180- LLVM_DEBUG (DBGS () << " Generating wgmma.mma.async shape[m = " << wgmmaShapeM
1181- << " , n = " << wgmmaShapeN << " , k = " << wgmmaShapeK
1182- << " ]\n " );
1183- return success ();
1184- }
11851219
1186- Value generateNVVMWgmmaOp (ImplicitLocOpBuilder &b, int m, int n, int k,
1187- Type resultStructType, Value inout,
1188- Value descriptorA, Value descriptorB) const {
1189- MLIRContext *ctx = b.getContext ();
1190- auto shape = NVVM::MMAShapeAttr::get (ctx, m, n, k);
1191- auto scaleOut = NVVM::WGMMAScaleOutAttr::get (ctx, NVVM::WGMMAScaleOut::one);
1192- auto scaleIn = NVVM::WGMMAScaleInAttr::get (ctx, NVVM::WGMMAScaleIn::one);
1193- auto layoutA = NVVM::MMALayoutAttr::get (ctx, NVVM::MMALayout::row);
1194- auto layoutB = NVVM::MMALayoutAttr::get (ctx, NVVM::MMALayout::col);
1195- // todo: handle other input and output types
1196- auto itype = NVVM::WGMMATypesAttr::get (ctx, NVVM::WGMMATypes::f16 );
1197- auto overflow =
1198- NVVM::MMAIntOverflowAttr::get (ctx, NVVM::MMAIntOverflow::wrapped);
1199- Value res = b.create <NVVM::WgmmaMmaAsyncOp>(
1200- resultStructType, inout, descriptorA, descriptorB, shape, itype, itype,
1201- scaleOut, scaleIn, scaleIn, layoutA, layoutB, overflow);
1202- return res;
1203- }
1204-
1205- LogicalResult
1206- matchAndRewrite (nvgpu::WarpgroupMmaOp op, OpAdaptor adaptor,
1207- ConversionPatternRewriter &rewriter) const override {
1208- ImplicitLocOpBuilder b (op->getLoc (), rewriter);
1209- int64_t sizeM = op.getDescriptorA ().getType ().getTensor ().getDimSize (0 );
1210- int64_t sizeN = op.getDescriptorB ().getType ().getTensor ().getDimSize (1 );
1211- int64_t sizeK = op.getDescriptorA ().getType ().getTensor ().getDimSize (1 );
1212-
1213- LLVM_DEBUG (DBGS () << " ===--- GEMM D[" << sizeM << " ][" << sizeN << " ] += A["
1214- << sizeM << " ][" << sizeK << " ] * B[" << sizeK << " ]["
1215- << sizeN << " ] ---===\n " );
1220+ // / Generates WGMMATypesAttr from MLIR Type
1221+ NVVM::WGMMATypesAttr generateWgmmaType (Type type) const {
1222+ auto getWgmmaType = [](Type elemType) {
1223+ if (elemType.isF32 () || elemType.isTF32 ())
1224+ return NVVM::WGMMATypes::tf32;
1225+ if (elemType.isF16 ())
1226+ return NVVM::WGMMATypes::f16 ;
1227+ if (elemType.isBF16 ())
1228+ return NVVM::WGMMATypes::bf16 ;
1229+ if (elemType.isFloat8E4M3FN ())
1230+ return NVVM::WGMMATypes::e4m3;
1231+ if (elemType.isFloat8E5M2 ())
1232+ return NVVM::WGMMATypes::e5m2;
1233+ if (elemType.isInteger (1 ))
1234+ return NVVM::WGMMATypes::b1;
1235+ if (elemType.isInteger (8 ))
1236+ return NVVM::WGMMATypes::s8;
1237+ if (elemType.isUnsignedInteger (8 ))
1238+ return NVVM::WGMMATypes::u8 ;
1239+ llvm_unreachable (" unsupported type" );
1240+ };
1241+ return NVVM::WGMMATypesAttr::get (op->getContext (), getWgmmaType (type));
1242+ }
12161243
1217- int wgmmaShapeM, wgmmaShapeN, wgmmaShapeK;
1218- if (failed (getWgmmaShape (sizeM, sizeN, rewriter.getF16Type (), wgmmaShapeM,
1219- wgmmaShapeN, wgmmaShapeK))) {
1220- return failure ();
1244+ // / Generates layout attribute for the input matrix for wgmma instruction
1245+ NVVM::MMALayoutAttr
1246+ generateWgmmaLayout (std::optional<bool > transpose) const {
1247+ if (transpose.value_or (false ))
1248+ return NVVM::MMALayoutAttr::get (op->getContext (), NVVM::MMALayout::col);
1249+ return NVVM::MMALayoutAttr::get (op->getContext (), NVVM::MMALayout::row);
12211250 }
12221251
1223- Value descriptorA = adaptor.getDescriptorA ();
1224- Value descriptorB = adaptor.getDescriptorB ();
1252+ // / Generates shape attribute for wgmma instruction
1253+ NVVM::MMAShapeAttr generateWgmmaShape () const {
1254+ return NVVM::MMAShapeAttr::get (op->getContext (), wgmmaM, wgmmaN, wgmmaK);
1255+ }
12251256
1226- // Generate wgmma group
1227- MemRefType typeTensorA = op.getDescriptorA ().getType ().getTensor ();
1228- MemRefType typeTensorB = op.getDescriptorB ().getType ().getTensor ();
1257+ // / Generates scale attributes of output matrix for wgmma instruction
1258+ NVVM::WGMMAScaleOutAttr generateScaleOut () const {
1259+ return NVVM::WGMMAScaleOutAttr::get (op->getContext (),
1260+ NVVM::WGMMAScaleOut::one);
1261+ }
1262+ // / Generates scale attributes of input matrix for wgmma instruction
1263+ NVVM::WGMMAScaleInAttr generateScaleIn () const {
1264+ return NVVM::WGMMAScaleInAttr::get (op->getContext (),
1265+ NVVM::WGMMAScaleIn::one);
1266+ }
12291267
1230- auto makeAdd = [&](Value lhs, Value rhs) -> Value {
1268+ // / Basic function to generate Add
1269+ Value makeAdd (Value lhs, Value rhs) {
12311270 return b.create <LLVM::AddOp>(lhs.getType (), lhs, rhs);
12321271 };
12331272
1234- auto iterateDescA = [&](Value desc, int iterM, int iterN,
1235- int iterK) -> Value {
1236- // todo : Handle column major
1237- int byte = typeTensorA.getElementTypeBitWidth () / 8 ;
1238- int tileShapeA = typeTensorA.getDimSize (1 );
1239- int incrementVal =
1240- ((wgmmaShapeK * iterK) + (sizeK * tileShapeA * iterM)) * byte;
1273+ // / Moves the descriptor pointer of matrix-A for the next wgmma instruction.
1274+ // / Currently, it only handles row-major.
1275+ // /
1276+ // / It moves the pointer like below for [128][64] size:
1277+ // / +2 +4 +6
1278+ // / ↓ ↓ ↓
1279+ // / descA ---> +--+--+--+--+
1280+ // / |->|->|->|->|
1281+ // / | | | | |
1282+ // / | | | | |
1283+ // / | | | | |
1284+ // / descA+512---> +-----------+
1285+ // / | | | | |
1286+ // / | | | | |
1287+ // / | | | | |
1288+ // / | | | | |
1289+ // / +-----------+
1290+ // /
1291+ Value iterateDescriptorA (Value desc, int i, int j, int k) {
1292+ MemRefType matrixTypeA = op.getDescriptorA ().getType ().getTensor ();
1293+ Type elemA = matrixTypeA.getElementType ();
1294+ int byte = elemA.getIntOrFloatBitWidth () / 8 ;
1295+ int tileShapeA = matrixTypeA.getDimSize (1 );
1296+ int incrementVal = ((wgmmaK * k) + (totalK * tileShapeA * i)) * byte;
12411297 incrementVal = incrementVal >> exclude4LSB;
1242- LLVM_DEBUG (DBGS () << " \t\t [m: " << iterM << " n: " << iterN << " k: "
1243- << iterK << " ] [wgmma descriptors] Descriptor A + "
1298+ LLVM_DEBUG (DBGS () << " \t\t [m: " << i << " n: " << j << " k: " << k
1299+ << " ] [wgmma descriptors] Descriptor A + "
12441300 << incrementVal << " | \t " );
12451301 if (!incrementVal)
12461302 return desc;
12471303 return makeAdd (desc, makeI64Const (b, incrementVal));
1248- };
1304+ }
12491305
1250- auto iterateDescB = [&](Value desc, int iterM, int iterN,
1251- int iterK) -> Value {
1252- // todo : Handle row major
1253- int byte = typeTensorB.getElementTypeBitWidth () / 8 ;
1254- int incrementVal = typeTensorB.getDimSize (0 ) * wgmmaShapeK * iterK * byte;
1306+ // / Moves the descriptor pointer of matrix-B for the next wgmma instruction.
1307+ // / Currently, it only handles column-major.
1308+ // /
1309+ // / It moves the pointer like below for [128][64] size:
1310+ // / descB ---> +--+--+--+--+--+--+--+--+
1311+ // / |↓ | | | | | | | |
1312+ // / |↓ | | | | | | | |
1313+ // / |↓ | | | | | | | |
1314+ // / |↓ | | | | | | | |
1315+ // / +--+--+--+--+--+--+--+--+
1316+ // /
1317+ Value iterateDescriptorB (Value desc, int i, int j, int k) {
1318+ MemRefType matrixTypeB = op.getDescriptorB ().getType ().getTensor ();
1319+ Type elemB = matrixTypeB.getElementType ();
1320+ int byte = elemB.getIntOrFloatBitWidth () / 8 ;
1321+ int incrementVal = matrixTypeB.getDimSize (0 ) * wgmmaK * k * byte;
12551322 incrementVal = incrementVal >> exclude4LSB;
12561323 LLVM_DEBUG (DBGSE () << " Descriptor B + " << incrementVal << " \n " );
12571324 if (!incrementVal)
12581325 return desc;
12591326 return makeAdd (desc, makeI64Const (b, incrementVal));
1260- };
1327+ }
1328+
1329+ // / This function generates a WgmmaMmaAsyncOp using provided GMMA matrix
1330+ // / descriptors and arranges them based on induction variables: i, j, and k.
1331+ Value generateWgmma (int i, int j, int k, Value matrixC, Value matrixD) {
1332+ LLVM_DEBUG (DBGS () << " \t wgmma."
1333+ << " m" << wgmmaM << " n" << wgmmaN << " k" << wgmmaK
1334+ << " (A[" << (iterationM * wgmmaM) << " :"
1335+ << (iterationM * wgmmaM) + wgmmaM << " ]["
1336+ << (iterationK * wgmmaK) << " :"
1337+ << (iterationK * wgmmaK + wgmmaK) << " ] * "
1338+ << " B[" << (iterationK * wgmmaK) << " :"
1339+ << (iterationK * wgmmaK + wgmmaK) << " ][" << 0 << " :"
1340+ << wgmmaN << " ])\n " );
1341+
1342+ Value descriptorA = iterateDescriptorA (adaptor.getDescriptorA (), i, j, k);
1343+ Value descriptorB = iterateDescriptorB (adaptor.getDescriptorB (), i, j, k);
1344+
1345+ Type elemA = op.getDescriptorA ().getType ().getTensor ().getElementType ();
1346+ NVVM::WGMMATypesAttr itypeA = generateWgmmaType (elemA);
1347+
1348+ Type elemB = op.getDescriptorB ().getType ().getTensor ().getElementType ();
1349+ NVVM::WGMMATypesAttr itypeB = generateWgmmaType (elemB);
1350+
1351+ NVVM::MMAShapeAttr shape = generateWgmmaShape ();
1352+ NVVM::WGMMAScaleOutAttr scaleOut = generateScaleOut ();
1353+ NVVM::WGMMAScaleInAttr scaleIn = generateScaleIn ();
1354+ NVVM::MMALayoutAttr layoutA = generateWgmmaLayout (op.getTransposeA ());
1355+ NVVM::MMALayoutAttr layoutB = generateWgmmaLayout (op.getTransposeB ());
1356+
1357+ auto overflow = NVVM::MMAIntOverflowAttr::get (
1358+ op->getContext (), NVVM::MMAIntOverflow::wrapped);
1359+
1360+ Type resultStructType = typeConverter.convertType (matrixD.getType ());
1361+
1362+ return b.create <NVVM::WgmmaMmaAsyncOp>(
1363+ resultStructType, matrixC, descriptorA, descriptorB, shape, itypeA,
1364+ itypeB, scaleOut, scaleIn, scaleIn, layoutA, layoutB, overflow);
1365+ }
12611366
1262- b.create <NVVM::WgmmaFenceAlignedOp>();
1263-
1264- SmallVector<Value> wgmmaResults;
1265- for (int iterM = 0 ; iterM < (sizeM / wgmmaShapeM); iterM++) {
1266- Value matrixC = adaptor.getMatrixC ()[iterM];
1267- Value matrixD = op.getMatrixD ()[iterM];
1268- Type structType = getTypeConverter ()->convertType (matrixD.getType ());
1269- LLVM_DEBUG (DBGS () << " D[" << (iterM * wgmmaShapeM) << " :"
1270- << (iterM * wgmmaShapeM) + wgmmaShapeM << " ][" << 0
1271- << " :" << wgmmaShapeN << " ] += \n " );
1272- for (int iterK = 0 ; iterK < (sizeK / wgmmaShapeK); iterK++) {
1273- Value descA = iterateDescA (descriptorA, iterM, 0 , iterK);
1274- Value descB = iterateDescB (descriptorB, iterM, 0 , iterK);
1275- LLVM_DEBUG (DBGS () << " \t wgmma."
1276- << " m" << wgmmaShapeM << " n" << wgmmaShapeN << " k"
1277- << wgmmaShapeK << " (A[" << (iterM * wgmmaShapeM)
1278- << " :" << (iterM * wgmmaShapeM) + wgmmaShapeM << " ]["
1279- << (iterK * wgmmaShapeK) << " :"
1280- << (iterK * wgmmaShapeK + wgmmaShapeK) << " ] * "
1281- << " B[" << (iterK * wgmmaShapeK) << " :"
1282- << (iterK * wgmmaShapeK + wgmmaShapeK) << " ][" << 0
1283- << " :" << wgmmaShapeN << " ])\n " );
1284- matrixC = generateNVVMWgmmaOp (b, wgmmaShapeM, wgmmaShapeN, wgmmaShapeK,
1285- structType, matrixC, descA, descB);
1367+ // / Generates multiple wgmma instructions to complete the given GEMM shape
1368+ SmallVector<Value> generateWgmmaGroup () {
1369+ SmallVector<Value> wgmmaResults;
1370+
1371+ // Perform GEMM
1372+ for (int i = 0 ; i < iterationM; ++i) {
1373+ Value matrixC = adaptor.getMatrixC ()[i];
1374+ Value matrixD = op.getMatrixD ()[i];
1375+ for (int j = 0 ; j < iterationN; ++j)
1376+ for (int k = 0 ; k < iterationK; ++k)
1377+ matrixC = generateWgmma (i, j, k, matrixC, matrixD);
1378+ wgmmaResults.push_back (matrixC);
12861379 }
1287- wgmmaResults.push_back (matrixC);
1380+
1381+ return wgmmaResults;
1382+ }
1383+
1384+ public:
1385+ WarpgroupGemm (nvgpu::WarpgroupMmaOp op, ImplicitLocOpBuilder &b,
1386+ OpAdaptor adaptor, const LLVMTypeConverter &typeConverter)
1387+ : op(op), b(b), adaptor(adaptor), typeConverter(typeConverter) {
1388+ // Find the entire GEMM Shape
1389+ totalM = op.getDescriptorA ().getType ().getTensor ().getDimSize (0 );
1390+ totalN = op.getDescriptorB ().getType ().getTensor ().getDimSize (1 );
1391+ totalK = op.getDescriptorA ().getType ().getTensor ().getDimSize (1 );
1392+ LLVM_DEBUG (DBGS () << " ===--- GEMM D[" << totalM << " ][" << totalN
1393+ << " ] += A[" << totalM << " ][" << totalK << " ] * B["
1394+ << totalK << " ][" << totalN << " ] ---===\n " );
1395+
1396+ // Find the shape for one wgmma instruction
1397+ findWgmmaShape (
1398+ totalM, totalN,
1399+ op.getDescriptorA ().getType ().getTensor ().getElementType ());
1400+
1401+ // Iteration counts needed to cover the given GEMM shape with wgmma-shaped tiles
1402+ iterationM = totalM / wgmmaM;
1403+ iterationN = totalN / wgmmaN;
1404+ iterationK = totalK / wgmmaK;
12881405 }
1289- b.create <NVVM::WgmmaGroupSyncAlignedOp>();
1290- b.create <NVVM::WgmmaWaitGroupSyncOp>(op.getWaitGroup ());
12911406
1292- ValueRange myres (wgmmaResults);
1293- rewriter.replaceOp (op, myres);
1407+ // / Generates WgmmaMmaAsync Ops to complete the specified GEMM shape. It
1408+ // / emits a fence Op (WgmmaFenceAlignedOp) before the instructions, then
1409+ // / commits the generated instructions as a single group
1410+ // / (WgmmaGroupSyncAlignedOp) and waits for that group's completion
1411+ // / (WgmmaWaitGroupSyncOp) after the instructions.
1412+ SmallVector<Value> generateWarpgroupMma () {
1413+ b.create <NVVM::WgmmaFenceAlignedOp>();
1414+ SmallVector<Value> wgmmaResults = generateWgmmaGroup ();
1415+ b.create <NVVM::WgmmaGroupSyncAlignedOp>();
1416+ b.create <NVVM::WgmmaWaitGroupSyncOp>(op.getWaitGroup ());
1417+ return wgmmaResults;
1418+ }
1419+ };
1420+
1421+ LogicalResult
1422+ matchAndRewrite (nvgpu::WarpgroupMmaOp op, OpAdaptor adaptor,
1423+ ConversionPatternRewriter &rewriter) const override {
1424+ ImplicitLocOpBuilder b (op->getLoc (), rewriter);
1425+ // Step 1. Build a helper class
1426+ WarpgroupGemm warpgroupGemm (op, b, adaptor, *this ->getTypeConverter ());
1427+
1428+ // Step 2. Generate wgmma instructions covering the entire GEMM shape
1429+ SmallVector<Value> wgmmaResults = warpgroupGemm.generateWarpgroupMma ();
1430+
1431+ // Step 3. Replace fragmented result struct with the op results
1432+ rewriter.replaceOp (op, wgmmaResults);
12941433 return success ();
12951434 }
12961435};
0 commit comments