diff --git a/test/TritonIntelGPU/blockptr_store.mlir b/test/TritonIntelGPU/blockptr_store.mlir index f76497e85e..e714086a53 100644 --- a/test/TritonIntelGPU/blockptr_store.mlir +++ b/test/TritonIntelGPU/blockptr_store.mlir @@ -98,6 +98,12 @@ module attributes {"ttg.num-warps" = 8 : i32, "ttg.threads-per-warp" = 16 : i32, // CHECK: %[[BASE_PTR:.*]] = llvm.extractvalue %[[BLOCK_PTR]][6] : !llvm.struct<(i32, i32, i64, i64, i64, i64, ptr<1>)> %13 = tt.make_tensor_ptr %base, [%width, %height], [%rowStride, %c1_i64], [%c0_i32, %c0_i32] {order = array} : > + // CHECK: %[[HEIGHT_i32:.*]] = llvm.trunc %[[HEIGHT_i64]] : i64 to i32 + // CHECK: %[[ROW_STRIDE_i32:.*]] = llvm.trunc %[[ROW_STRIDE_i64]] : i64 to i32 + // CHECK: %[[baseWidth:.*]] = llvm.mul %[[HEIGHT_i32]], %[[CST_2]] : i32 + // CHECK: %[[baseHeight:.*]] = llvm.trunc %[[WIDTH_i64]] : i64 to i32 + // CHECK: %[[basePitch:.*]] = llvm.mul %[[ROW_STRIDE_i32]], %[[CST_2]] : i32 + // COM: The decomposed values of the tensor with DPAS layout. // CHECK: %[[VAL_97:.*]] = llvm.extractvalue %[[VAL_71]][0] // CHECK: %[[VAL_98:.*]] = llvm.extractvalue %[[VAL_71]][1] @@ -164,11 +170,6 @@ module attributes {"ttg.num-warps" = 8 : i32, "ttg.threads-per-warp" = 16 : i32, // CHECK: %[[VAL_159:.*]] = llvm.extractvalue %[[VAL_71]][62] // CHECK: %[[VAL_160:.*]] = llvm.extractvalue %[[VAL_71]][63] - // CHECK: %[[HEIGHT_i32:.*]] = llvm.trunc %[[HEIGHT_i64]] : i64 to i32 - // CHECK: %[[baseHeight:.*]] = llvm.trunc %[[WIDTH_i64]] : i64 to i32 - // CHECK: %[[ROW_STRIDE_i32:.*]] = llvm.trunc %[[ROW_STRIDE_i64]] : i64 to i32 - // CHECK: %[[baseWidth:.*]] = llvm.mul %[[HEIGHT_i32]], %[[CST_2]] : i32 - // CHECK: %[[basePitch:.*]] = llvm.mul %[[ROW_STRIDE_i32]], %[[CST_2]] : i32 // CHECK: %[[VAL_166:.*]] = llvm.mlir.constant(1 : i32) : i32 // CHECK: %[[outerDimWarpId:.*]] = llvm.urem %[[SUB_GROUP_ID_M]], %[[VAL_166]] : i32 // CHECK: %[[VAL_168:.*]] = llvm.mlir.constant(1 : i32) : i32 @@ -181,7 +182,6 @@ module attributes {"ttg.num-warps" = 8 : i32, "ttg.threads-per-warp" = 16 : i32, // CHECK: %[[warpId1Offset:.*]] = llvm.add %[[dimWarpId1]], %[[OFFSET_1]] : i32 // CHECK: %[[VAL_176:.*]] = llvm.mlir.constant(0 : i32) : i32 - // COM: The shape of DPAS layout replica is [4, 2] // COM: The replica order are [0, 1] // COM: [2, 3] diff --git a/third_party/intel/lib/TritonIntelGPUToLLVM/LoadStoreOpToLLVM.cpp b/third_party/intel/lib/TritonIntelGPUToLLVM/LoadStoreOpToLLVM.cpp index 6e3cc4b6b5..7482c6de46 100644 --- a/third_party/intel/lib/TritonIntelGPUToLLVM/LoadStoreOpToLLVM.cpp +++ b/third_party/intel/lib/TritonIntelGPUToLLVM/LoadStoreOpToLLVM.cpp @@ -2456,11 +2456,6 @@ struct StoreOpToBlockIOConversion auto b = TritonLLVMOpBuilder(loc, rewriter); Type resultType = op.getValue().getType(); auto tensorType = cast(resultType); - - // Only lower StoreOp with dpas layout encoding. - if (!hasDpasEncoding(tensorType)) - return failure(); - auto dpasLayout = cast(tensorType.getEncoding()); LLVMTypeConverter *typeConverter = getTypeConverter(); MLIRContext *ctx = rewriter.getContext(); @@ -2471,14 +2466,21 @@ struct StoreOpToBlockIOConversion const ArrayRef tensorShape = tensorType.getShape(); size_t rank = tensorShape.size(); unsigned numElems = getTotalElemsPerThread(tensorType); + SmallVector elemsPerInstr = dpasLayout.getDPASInstShapeC(); + // 2D block store supports 8 rows at most. + unsigned tileHeight = std::min(8u, elemsPerInstr[0]); + // 2D block store supports 64 bytes per row at most. + unsigned tileWidth = elemsPerInstr[1]; + unsigned totalBytesPerRowPerMatrix = tileWidth * elemSizeInBits / 8; + if (totalBytesPerRowPerMatrix > 64) + return failure(); + auto warpsPerCTA = dpasLayout.getWarpsPerCTA(); SmallVector numReps = dpasLayout.getDPASRepetitions(tensorShape, 2); SmallVector dpasWarpsOrder = getMatrixOrder(warpsPerCTA.size(), /*rowMajor*/ true); - unsigned threadsPerWarp = - product(getThreadsPerWarp(dpasLayout, tensorShape)); Value warpId = rewriter.create( loc, i32_ty, @@ -2487,25 +2489,34 @@ struct StoreOpToBlockIOConversion SmallVector multiDimWarpId = mlir::LLVM::delinearize( rewriter, loc, warpId, warpsPerCTA, dpasWarpsOrder); - int64_t elemsPerLane = product(elemsPerInstr) / threadsPerWarp; - Type store2DGenXType = - LLVM::getVectorType(IntegerType::get(ctx, elemSizeInBits), - elemsPerLane); // make it opaque type. - Value blockPtr = adaptor.getPtr(); auto [base, width, height, rowStride, colStride, offsetBaseX, offsetBaseY] = getValuesFromBlockPointerStruct(blockPtr, rewriter); - auto vals = unpackLLElements(loc, adaptor.getValue(), rewriter); - assert(vals.size() == numElems); - width = b.trunc(i32_ty, width); - height = b.trunc(i32_ty, height); rowStride = b.trunc(i32_ty, rowStride); // encoded as bytes. Value baseWidth = b.mul(width, elemSizeInBytes); + Value baseHeight = b.trunc(i32_ty, height); // encoded as bytes. - Value basePitch = b.mul(rowStride, elemSizeInBytes); + Value pitch = b.mul(rowStride, elemSizeInBytes); + // 2D block store only supports vBlocks = 1. + unsigned vBlocks = 1; + + // Get the LLVM values for store values + SmallVector valElems = + unpackLLElements(loc, adaptor.getValue(), rewriter); + assert(valElems.size() == numElems && + "the number of store values does not match the number of elements"); + + unsigned threadsPerWarp = + TritonGPUDialect::getThreadsPerWarp(op->getParentOfType()); + + int64_t elemsPerLane = tileHeight * tileWidth / threadsPerWarp; + Type opaqueType = IntegerType::get(ctx, elemSizeInBits); + Type store2DGenXType = + LLVM::getVectorType(opaqueType, + elemsPerLane); // make it opaque type. // A warp stride for the replicates. SmallVector repClusterShape = dpasLayout.getShapeC(); @@ -2538,19 +2549,19 @@ struct StoreOpToBlockIOConversion for (int m = 0; m < numRepOuter; ++m) { for (int n = 0; n < numRepInner; ++n) { for (int repM = 0; repM < repCluster[0]; ++repM) { - Value offsetY = - b.add(warpId0Offset, - b.i32_val(m * replicaStride[0] + repM * elemsPerInstr[0])); + Value offsetY = b.add(warpId0Offset, b.i32_val(m * replicaStride[0] + + repM * tileHeight)); for (int repN = 0; repN < repCluster[1]; ++repN) { Value offsetX = - b.add(warpId1Offset, b.i32_val(n * replicaStride[1] + - repN * elemsPerInstr[1])); + b.add(warpId1Offset, + b.i32_val(n * replicaStride[1] + repN * tileWidth)); + Value storeVal = rewriter.create( loc, LLVM::getVectorType(typeConverter->convertType(eltTy), elemsPerLane)); for (size_t i = 0; i < elemsPerLane; ++i) { storeVal = - b.insert_element(storeVal, vals[valOffset], b.i32_val(i)); + b.insert_element(storeVal, valElems[valOffset], b.i32_val(i)); ++valOffset; } @@ -2558,14 +2569,14 @@ struct StoreOpToBlockIOConversion loc, /*ptr*/ base, /*base_width*/ baseWidth, - /*base_height*/ height, - /*base_pitch*/ basePitch, + /*base_height*/ baseHeight, + /*base_pitch*/ pitch, /*x*/ offsetX, /*y*/ offsetY, /*elem_size_in_bits*/ elemSizeInBits, - /*tile_width*/ elemsPerInstr[1], - /*tile_height*/ elemsPerInstr[0], - /*v_blocks*/ 1, + /*tile_width*/ tileWidth, + /*tile_height*/ tileHeight, + /*v_blocks*/ vBlocks, /*stored_val*/ b.bitcast(storeVal, store2DGenXType)); if (failed(newOp.verify())) {