12 changes: 6 additions & 6 deletions test/TritonIntelGPU/blockptr_store.mlir
@@ -98,6 +98,12 @@ module attributes {"ttg.num-warps" = 8 : i32, "ttg.threads-per-warp" = 16 : i32,
// CHECK: %[[BASE_PTR:.*]] = llvm.extractvalue %[[BLOCK_PTR]][6] : !llvm.struct<(i32, i32, i64, i64, i64, i64, ptr<1>)>
%13 = tt.make_tensor_ptr %base, [%width, %height], [%rowStride, %c1_i64], [%c0_i32, %c0_i32] {order = array<i32: 1, 0>} : <tensor<32x32xf16, #dpas>>

// CHECK: %[[HEIGHT_i32:.*]] = llvm.trunc %[[HEIGHT_i64]] : i64 to i32
// CHECK: %[[ROW_STRIDE_i32:.*]] = llvm.trunc %[[ROW_STRIDE_i64]] : i64 to i32
// CHECK: %[[baseWidth:.*]] = llvm.mul %[[HEIGHT_i32]], %[[CST_2]] : i32
// CHECK: %[[baseHeight:.*]] = llvm.trunc %[[WIDTH_i64]] : i64 to i32
// CHECK: %[[basePitch:.*]] = llvm.mul %[[ROW_STRIDE_i32]], %[[CST_2]] : i32
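// COM: The base width and pitch are scaled to bytes (CST_2 = 2 for f16);
// COM: the base height stays in rows.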

// COM: The decomposed values of the tensor with DPAS layout.
// CHECK: %[[VAL_97:.*]] = llvm.extractvalue %[[VAL_71]][0]
// CHECK: %[[VAL_98:.*]] = llvm.extractvalue %[[VAL_71]][1]
@@ -164,11 +170,6 @@ module attributes {"ttg.num-warps" = 8 : i32, "ttg.threads-per-warp" = 16 : i32,
// CHECK: %[[VAL_159:.*]] = llvm.extractvalue %[[VAL_71]][62]
// CHECK: %[[VAL_160:.*]] = llvm.extractvalue %[[VAL_71]][63]

// CHECK: %[[HEIGHT_i32:.*]] = llvm.trunc %[[HEIGHT_i64]] : i64 to i32
// CHECK: %[[baseHeight:.*]] = llvm.trunc %[[WIDTH_i64]] : i64 to i32
// CHECK: %[[ROW_STRIDE_i32:.*]] = llvm.trunc %[[ROW_STRIDE_i64]] : i64 to i32
// CHECK: %[[baseWidth:.*]] = llvm.mul %[[HEIGHT_i32]], %[[CST_2]] : i32
// CHECK: %[[basePitch:.*]] = llvm.mul %[[ROW_STRIDE_i32]], %[[CST_2]] : i32
// CHECK: %[[VAL_166:.*]] = llvm.mlir.constant(1 : i32) : i32
// CHECK: %[[outerDimWarpId:.*]] = llvm.urem %[[SUB_GROUP_ID_M]], %[[VAL_166]] : i32
// CHECK: %[[VAL_168:.*]] = llvm.mlir.constant(1 : i32) : i32
@@ -181,7 +182,6 @@ module attributes {"ttg.num-warps" = 8 : i32, "ttg.threads-per-warp" = 16 : i32,
// CHECK: %[[warpId1Offset:.*]] = llvm.add %[[dimWarpId1]], %[[OFFSET_1]] : i32
// CHECK: %[[VAL_176:.*]] = llvm.mlir.constant(0 : i32) : i32


// COM: The shape of DPAS layout replica is [4, 2]
// COM: The replica order are [0, 1]
// COM: [2, 3]
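For context, the new checks reflect how the lowering derives the hardware 2D block store parameters from the block pointer: the base width and pitch are byte-granular, while the base height is a row count, which is why only two of the three values are multiplied by the element size. A minimal sketch of that computation (names hypothetical; the real code builds LLVM IR rather than doing plain integer arithmetic):

#include <cstdint>

struct BlockStoreParams {
  uint32_t baseWidthBytes;  // bytes per row of the destination buffer
  uint32_t baseHeightRows;  // number of rows, element granularity
  uint32_t basePitchBytes;  // byte stride between consecutive rows
};

// rowElems is the contiguous (inner) dimension, numRows the outer one; the
// i64 block-pointer fields are truncated to i32 first, as in the checks.
BlockStoreParams makeParams(uint64_t rowElems, uint64_t numRows,
                            uint64_t rowStride, uint32_t elemSizeInBytes) {
  return {static_cast<uint32_t>(rowElems) * elemSizeInBytes,
          static_cast<uint32_t>(numRows),
          static_cast<uint32_t>(rowStride) * elemSizeInBytes};
}
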
66 changes: 38 additions & 28 deletions third_party/intel/lib/TritonIntelGPUToLLVM/LoadStoreOpToLLVM.cpp
@@ -2456,11 +2456,6 @@ struct StoreOpToBlockIOConversion
auto b = TritonLLVMOpBuilder(loc, rewriter);
Type resultType = op.getValue().getType();
auto tensorType = cast<RankedTensorType>(resultType);

// Only lower StoreOp with dpas layout encoding.
if (!hasDpasEncoding(tensorType))
return failure();

auto dpasLayout = cast<DpasEncodingAttr>(tensorType.getEncoding());
LLVMTypeConverter *typeConverter = getTypeConverter();
MLIRContext *ctx = rewriter.getContext();
@@ -2471,14 +2466,21 @@
const ArrayRef<int64_t> tensorShape = tensorType.getShape();
size_t rank = tensorShape.size();
unsigned numElems = getTotalElemsPerThread(tensorType);

SmallVector<unsigned> elemsPerInstr = dpasLayout.getDPASInstShapeC();
// 2D block store supports 8 rows at most.
unsigned tileHeight = std::min(8u, elemsPerInstr[0]);
// 2D block store supports 64 bytes per row at most.
unsigned tileWidth = elemsPerInstr[1];
unsigned totalBytesPerRowPerMatrix = tileWidth * elemSizeInBits / 8;
if (totalBytesPerRowPerMatrix > 64)
return failure();
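// Example (illustrative): the DPAS C tile is 8x16, so for f16
// (elemSizeInBits = 16) a row occupies 16 * 2 = 32 bytes and passes this
// check; an f32 tile 32 elements wide (128 bytes per row) would bail out.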

auto warpsPerCTA = dpasLayout.getWarpsPerCTA();
SmallVector<int64_t> numReps =
dpasLayout.getDPASRepetitions(tensorShape, 2);
SmallVector<unsigned> dpasWarpsOrder =
getMatrixOrder(warpsPerCTA.size(), /*rowMajor*/ true);
unsigned threadsPerWarp =
product<unsigned>(getThreadsPerWarp(dpasLayout, tensorShape));

Value warpId = rewriter.create<arith::IndexCastOp>(
loc, i32_ty,
@@ -2487,25 +2489,33 @@
SmallVector<Value> multiDimWarpId = mlir::LLVM::delinearize(
rewriter, loc, warpId, warpsPerCTA, dpasWarpsOrder);

int64_t elemsPerLane = product<unsigned>(elemsPerInstr) / threadsPerWarp;
Type store2DGenXType =
LLVM::getVectorType(IntegerType::get(ctx, elemSizeInBits),
elemsPerLane); // make it opaque type.

Value blockPtr = adaptor.getPtr();
auto [base, width, height, rowStride, colStride, offsetBaseX, offsetBaseY] =
getValuesFromBlockPointerStruct(blockPtr, rewriter);

auto vals = unpackLLElements(loc, adaptor.getValue(), rewriter);
assert(vals.size() == numElems);

width = b.trunc(i32_ty, width);
height = b.trunc(i32_ty, height);
rowStride = b.trunc(i32_ty, rowStride);
// encoded as bytes.
Value baseWidth = b.mul(width, elemSizeInBytes);
Value baseHeight = b.trunc(i32_ty, height);
// encoded as bytes.
Value basePitch = b.mul(rowStride, elemSizeInBytes);
Value pitch = b.mul(rowStride, elemSizeInBytes);
// 2D block store only supports vBlocks = 1.
unsigned vBlocks = 1;

// Get the LLVM values for the stored values
auto valElems = unpackLLElements(loc, adaptor.getValue(), rewriter);
assert(valElems.size() == numElems &&
"the number of store values does not match the number of elements");

unsigned threadsPerWarp =
TritonGPUDialect::getThreadsPerWarp(op->getParentOfType<ModuleOp>());

int64_t elemsPerLane = tileHeight * tileWidth / threadsPerWarp;
Type opaqueType = IntegerType::get(ctx, elemSizeInBits);
Type store2DGenXType =
LLVM::getVectorType(opaqueType,
elemsPerLane); // make it an opaque type.
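// Example (illustrative): tileHeight = 8, tileWidth = 16 and
// threadsPerWarp = 16 give elemsPerLane = 8, i.e. each lane contributes a
// vector<8 x i16> to the block store for f16 elements.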

// A warp stride for the replicates.
SmallVector<unsigned> repClusterShape = dpasLayout.getShapeC();
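// Example (illustrative): with repCluster = [4, 2] and an 8x16 C tile,
// repClusterShape is [32, 32]; the loops below emit 4 * 2 = 8 block
// stores per warp in the row-major replica order checked in the test.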
@@ -2538,34 +2548,34 @@
for (int m = 0; m < numRepOuter; ++m) {
for (int n = 0; n < numRepInner; ++n) {
for (int repM = 0; repM < repCluster[0]; ++repM) {
Value offsetY =
b.add(warpId0Offset,
b.i32_val(m * replicaStride[0] + repM * elemsPerInstr[0]));
Value offsetY = b.add(warpId0Offset, b.i32_val(m * replicaStride[0] +
repM * tileHeight));
for (int repN = 0; repN < repCluster[1]; ++repN) {
Value offsetX =
b.add(warpId1Offset, b.i32_val(n * replicaStride[1] +
repN * elemsPerInstr[1]));
b.add(warpId1Offset,
b.i32_val(n * replicaStride[1] + repN * tileWidth));

Value storeVal = rewriter.create<LLVM::UndefOp>(
loc, LLVM::getVectorType(typeConverter->convertType(eltTy),
elemsPerLane));
for (size_t i = 0; i < elemsPerLane; ++i) {
storeVal =
b.insert_element(storeVal, vals[valOffset], b.i32_val(i));
b.insert_element(storeVal, valElems[valOffset], b.i32_val(i));
++valOffset;
}

auto newOp = rewriter.create<TritonGEN::Matrix2DBlockStoreOp>(
loc,
/*ptr*/ base,
/*base_width*/ baseWidth,
/*base_height*/ height,
/*base_pitch*/ basePitch,
/*base_height*/ baseHeight,
/*base_pitch*/ pitch,
/*x*/ offsetX,
/*y*/ offsetY,
/*elem_size_in_bits*/ elemSizeInBits,
/*tile_width*/ elemsPerInstr[1],
/*tile_height*/ elemsPerInstr[0],
/*v_blocks*/ 1,
/*tile_width*/ tileWidth,
/*tile_height*/ tileHeight,
/*v_blocks*/ vBlocks,
/*stored_val*/ b.bitcast(storeVal, store2DGenXType));

if (failed(newOp.verify())) {
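Putting the numbers together for the tensor<32x32xf16, #dpas> case exercised by the updated test (values assumed from the DPAS layout used there), a short worked example:

// Worked example (assumed values matching the 32x32 f16 test):
constexpr unsigned elemSizeInBits = 16;   // f16
constexpr unsigned threadsPerWarp = 16;
constexpr unsigned tileHeight = 8;        // min(8, instShapeC[0])
constexpr unsigned tileWidth = 16;        // instShapeC[1]; 32 bytes/row <= 64
constexpr unsigned elemsPerLane = tileHeight * tileWidth / threadsPerWarp; // 8
// Each warp issues repCluster[0] * repCluster[1] = 4 * 2 = 8 block stores,
// each storing a vector<8 x i16> per lane, with base_width and base_pitch
// in bytes and base_height in rows.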