Skip to content

Commit 06ca5c8

Browse files
[mlir][Linalg] Apply fixes to TileReductionUsingForeachThreadOp
In the process, numerous insertion point issues were found and fixed. RAII on insertion points is now used more diligently. Differential Revision: https://reviews.llvm.org/D139714
1 parent cf98e82 commit 06ca5c8

File tree

5 files changed

+183
-79
lines changed

5 files changed

+183
-79
lines changed

mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -796,7 +796,7 @@ def TileReductionUsingForeachThreadOp :
796796
scf.foreach_thread.perform_concurrently {
797797
tensor.parallel_insert_slice %7 into %arg3[0, %arg2] [%dim, 1] [1, 1] : tensor<?xf32> into tensor<?x5xf32>
798798
}
799-
} {thread_dim_mapping = []}
799+
} {mapping = []}
800800
%3 = linalg.generic {indexing_maps = [#map3, #map4], iterator_types = ["parallel", "reduction"]} ins(%2 : tensor<?x5xf32>) outs(%arg1 : tensor<?xf32>) {
801801
^bb0(%in: f32, %out: f32):
802802
%4 = arith.addf %in, %out : f32
@@ -807,7 +807,8 @@ def TileReductionUsingForeachThreadOp :
807807

808808
let arguments = (ins PDL_Operation:$target,
809809
DefaultValuedAttr<I64ArrayAttr, "{}">:$num_threads,
810-
DefaultValuedAttr<I64ArrayAttr, "{}">:$tile_sizes);
810+
DefaultValuedAttr<I64ArrayAttr, "{}">:$tile_sizes,
811+
OptionalAttr<DeviceMappingArrayAttr>:$mapping);
811812
let results = (outs PDL_Operation:$fill_op,
812813
PDL_Operation:$split_linalg_op,
813814
PDL_Operation:$combining_linalg_op);

mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1222,7 +1222,7 @@ transform::TileReductionUsingForeachThreadOp::applyToOne(
12221222
FailureOr<linalg::ForeachThreadReductionTilingResult> result =
12231223
linalg::tileReductionUsingForeachThread(
12241224
rewriter, cast<PartialReductionOpInterface>(target.getOperation()),
1225-
numThreads, tileSizes, /*mapping=*/std::nullopt);
1225+
numThreads, tileSizes, getMapping());
12261226

12271227
if (failed(result)) {
12281228
results.assign(3, nullptr);

mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp

Lines changed: 136 additions & 70 deletions
Original file line numberDiff line numberDiff line change
@@ -25,8 +25,11 @@
2525
#include "mlir/Dialect/Utils/IndexingUtils.h"
2626
#include "mlir/IR/AffineExpr.h"
2727
#include "mlir/IR/AffineMap.h"
28+
#include "mlir/IR/BuiltinOps.h"
29+
#include "mlir/IR/ValueRange.h"
2830
#include "mlir/Transforms/FoldUtils.h"
2931
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
32+
#include "llvm/ADT/STLExtras.h"
3033
#include "llvm/Support/CommandLine.h"
3134
#include <utility>
3235

@@ -221,6 +224,9 @@ static void calculateTileOffsetsAndSizes(
221224
Optional<ArrayRef<OpFoldResult>> nominalTileSizes,
222225
SmallVector<OpFoldResult> &tiledOffsets,
223226
SmallVector<OpFoldResult> &tiledSizes) {
227+
OpBuilder::InsertionGuard g(b);
228+
b.setInsertionPointToStart(foreachThreadOp.getBody(0));
229+
224230
ValueRange threadIds = foreachThreadOp.getThreadIndices();
225231
SmallVector<OpFoldResult> nonZeroNumThreads =
226232
llvm::to_vector(llvm::make_filter_range(numThreads, [](OpFoldResult ofr) {
@@ -300,6 +306,7 @@ static FailureOr<ForeachThreadTilingResult> tileToForeachThreadOpImpl(
300306
Optional<ArrayAttr> mapping, bool omitTileOffsetBoundsCheck) {
301307
Location loc = op->getLoc();
302308
OpBuilder::InsertionGuard g(b);
309+
303310
SmallVector<Range> loopRanges = op.getIterationDomain(b);
304311
if (loopRanges.empty())
305312
return op->emitOpError("expected non-empty loop ranges");
@@ -323,54 +330,64 @@ static FailureOr<ForeachThreadTilingResult> tileToForeachThreadOpImpl(
323330

324331
Operation *tiledOp = nullptr;
325332

326-
// Create the ForeachThreadOp. We don't use the lambda body-builder
333+
// 1. Create the ForeachThreadOp. We don't use the lambda body-builder
327334
// version because we require the use of RewriterBase in the body, so we
328335
// manually move the insertion point to the body below.
329336
scf::ForeachThreadOp foreachThreadOp = b.create<scf::ForeachThreadOp>(
330337
loc, dest, ValueRange(materializedNonZeroNumThreads), mapping);
331338

332-
// Fill out the ForeachThreadOp body.
333-
b.setInsertionPointToStart(foreachThreadOp.getBody(0));
339+
// 2. Fill out the ForeachThreadOp body.
334340
SmallVector<OpFoldResult> tiledOffsets, tiledSizes;
335341
calculateTileOffsetsAndSizes(b, loc, foreachThreadOp, numThreads, loopRanges,
336342
omitTileOffsetBoundsCheck, nominalTileSizes,
337343
tiledOffsets, tiledSizes);
338344

339-
// Clone the tileable op and update its destination operands to use the output
340-
// bbArgs of the ForeachThreadOp.
345+
// 3. Clone the tileable op and update its destination operands to use the
346+
// output bbArgs of the ForeachThreadOp.
341347
ArrayRef<BlockArgument> destBbArgs =
342348
foreachThreadOp.getOutputBlockArguments();
343-
Operation *clonedOp = b.clone(*op.getOperation());
344-
auto destinationStyleOp = dyn_cast<DestinationStyleOpInterface>(clonedOp);
345-
if (destinationStyleOp) {
346-
for (OpOperand *outOperand : destinationStyleOp.getDpsInitOperands()) {
347-
auto *it = llvm::find(dest, outOperand->get());
348-
assert(it != dest.end() && "dest operand not found in dest");
349-
unsigned destNum = std::distance(dest.begin(), it);
350-
outOperand->set(destBbArgs[destNum]);
349+
{
350+
// 3.a. RAII guard, inserting within foreachThreadOp, before terminator.
351+
OpBuilder::InsertionGuard g(b);
352+
b.setInsertionPoint(foreachThreadOp.getTerminator());
353+
Operation *clonedOp = b.clone(*op.getOperation());
354+
auto destinationStyleOp = dyn_cast<DestinationStyleOpInterface>(clonedOp);
355+
if (destinationStyleOp) {
356+
for (OpOperand *outOperand : destinationStyleOp.getDpsInitOperands()) {
357+
auto *it = llvm::find(dest, outOperand->get());
358+
assert(it != dest.end() && "dest operand not found in dest");
359+
unsigned destNum = std::distance(dest.begin(), it);
360+
outOperand->set(destBbArgs[destNum]);
361+
}
351362
}
352-
}
353363

354-
// Tile the cloned op and delete the clone.
355-
SmallVector<Operation *> tiledOps =
356-
cast<TilingInterface>(clonedOp).getTiledImplementation(b, tiledOffsets,
357-
tiledSizes);
358-
b.eraseOp(clonedOp);
359-
assert(tiledOps.size() == 1 && "expected a single produced tiled op");
360-
tiledOp = tiledOps.front();
364+
// 4. Tile the cloned op and delete the clone.
365+
SmallVector<Operation *> tiledOps =
366+
cast<TilingInterface>(clonedOp).getTiledImplementation(b, tiledOffsets,
367+
tiledSizes);
368+
b.eraseOp(clonedOp);
369+
assert(tiledOps.size() == 1 && "expected a single produced tiled op");
370+
tiledOp = tiledOps.front();
371+
}
361372

373+
// 5. Parallel insert back into the result tensor.
362374
auto tilingInterfaceOp = dyn_cast<TilingInterface>(tiledOp);
363375
assert(tilingInterfaceOp && "Tiled op does not implement TilingInterface");
364-
OpBuilder::InsertPoint insertPt = b.saveInsertionPoint();
365376
for (auto it : llvm::zip(llvm::seq(unsigned(0), unsigned(dest.size())),
366377
tilingInterfaceOp->getResults(), destBbArgs)) {
367-
b.setInsertionPoint(insertPt.getBlock(), insertPt.getPoint());
378+
// 5.a. Partial subset information is inserted just before the terminator.
379+
OpBuilder::InsertionGuard g(b);
380+
b.setInsertionPoint(foreachThreadOp.getTerminator());
381+
368382
SmallVector<OpFoldResult> resultOffsets, resultSizes;
369383
if (failed(op.getResultTilePosition(b, std::get<0>(it), tiledOffsets,
370384
tiledSizes, resultOffsets,
371385
resultSizes)))
372386
return op->emitOpError("output offsets couldn't be calculated");
373387
SmallVector<OpFoldResult> strides(resultSizes.size(), b.getIndexAttr(1));
388+
389+
// 5.b. Parallel insertions are inserted at the end of the combining
390+
// terminator.
374391
b.setInsertionPointToEnd(foreachThreadOp.getTerminator().getBody());
375392
b.create<tensor::ParallelInsertSliceOp>(loc, std::get<1>(it),
376393
std::get<2>(it), resultOffsets,
@@ -415,6 +432,8 @@ template <typename LoopTy>
415432
static FailureOr<TiledLinalgOp>
416433
tileLinalgOpImpl(RewriterBase &b, LinalgOp op, ArrayRef<OpFoldResult> tileSizes,
417434
const LinalgTilingOptions &options) {
435+
OpBuilder::InsertionGuard g(b);
436+
418437
auto nLoops = op.getNumLoops();
419438
// Initial tile sizes may be too big, only take the first nLoops.
420439
tileSizes = tileSizes.take_front(nLoops);
@@ -570,25 +589,44 @@ linalg::tileReductionUsingForeachThread(RewriterBase &b,
570589
Optional<ArrayAttr> mapping) {
571590
Location loc = op.getLoc();
572591
OpBuilder::InsertionGuard g(b);
592+
573593
// Ops implementing PartialReductionOpInterface are expected to implement
574594
// TilingInterface.
595+
// TODO: proper core mechanism to tie interfaces together.
575596
auto tilingInterfaceOp = cast<TilingInterface>(op.getOperation());
597+
598+
// Ops implementing PartialReductionOpInterface are not necessarily expected
599+
// to implement TilingInterface.. This cast is unsafe atm.
600+
// TODO: proper core mechanism to tie interfaces together.
601+
// TODO: this function requires a pair of interfaces ..
602+
auto destinationStyleOp =
603+
dyn_cast<DestinationStyleOpInterface>(op.getOperation());
604+
if (!destinationStyleOp)
605+
return b.notifyMatchFailure(op, "not a destination style op");
606+
607+
// Actually this only work for Linalg ops atm.
608+
auto linalgOp = dyn_cast<linalg::LinalgOp>(op.getOperation());
609+
if (!linalgOp)
610+
return b.notifyMatchFailure(op, "not a linalg op");
611+
576612
SmallVector<Range> iterationDomain = tilingInterfaceOp.getIterationDomain(b);
577613
if (op->getNumResults() != 1)
578614
return b.notifyMatchFailure(
579615
op, "don't support ops with multiple results for now");
616+
580617
SmallVector<utils::IteratorType> iterators =
581618
tilingInterfaceOp.getLoopIteratorTypes();
582619
SmallVector<unsigned> redDims;
583-
cast<linalg::LinalgOp>(op.getOperation()).getReductionDims(redDims);
620+
linalgOp.getReductionDims(redDims);
584621
if (redDims.size() != 1)
585622
return b.notifyMatchFailure(
586623
op, "only support ops with one reduction dimension.");
587624
if (!tileSizes.empty() && tileSizes.size() != numThreads.size())
588625
return b.notifyMatchFailure(op, "if tile sizes are present it must have as "
589626
"many elements as number of threads");
590627
int reductionDim = static_cast<int>(redDims.front());
591-
// 1. create the inital tensor value.
628+
629+
// 1. Create the inital tensor value.
592630
FailureOr<Operation *> identityTensor =
593631
op.generateInitialTensorForPartialReduction(b, loc, numThreads,
594632
reductionDim);
@@ -615,8 +653,8 @@ linalg::tileReductionUsingForeachThread(RewriterBase &b,
615653
loc, identityTensor.value()->getResults(),
616654
ValueRange(materializedNonZeroNumThreads), mapping);
617655

618-
// 3. calculate the tile offsets and sizes.
619-
b.setInsertionPointToStart(foreachThreadOp.getBody(0));
656+
// 3. Calculate the tile offsets and sizes for the subsequent loop that will
657+
// be nested under `foreachThreadOp`.
620658
SmallVector<OpFoldResult> tiledOffsets, tiledSizes;
621659
calculateTileOffsetsAndSizes(
622660
b, loc, foreachThreadOp, numThreads, iterationDomain,
@@ -625,54 +663,77 @@ linalg::tileReductionUsingForeachThread(RewriterBase &b,
625663

626664
// 4. Clone the tileable op and update its destination operands to use the
627665
// output bbArgs of the ForeachThreadOp.
666+
ValueRange tilingResults;
628667
ArrayRef<BlockArgument> destBbArgs =
629668
foreachThreadOp.getOutputBlockArguments();
630-
Operation *clonedOp = b.clone(*op.getOperation());
631-
b.setInsertionPointToStart(foreachThreadOp.getBody(0));
632-
auto destinationStyleOp = cast<DestinationStyleOpInterface>(clonedOp);
633-
for (OpOperand *initOperand : destinationStyleOp.getDpsInitOperands()) {
634-
auto *it = llvm::find(dest, initOperand->get());
635-
assert(it != dest.end() && "dest operand not found in dest");
636-
unsigned destNum = std::distance(dest.begin(), it);
637-
SmallVector<OpFoldResult> strides(numThreads.size(), b.getIndexAttr(1));
638-
SmallVector<OpFoldResult> outOffsets(numThreads.size(), b.getIndexAttr(0));
639-
SmallVector<OpFoldResult> sizes = tiledSizes;
640-
sizes[reductionDim] = b.getIndexAttr(1);
641-
outOffsets[reductionDim] = foreachThreadOp.getThreadIndices().front();
642-
// TODO: use SubsetExtractOpInterface once it is available.
643-
Value patial = b.create<tensor::ExtractSliceOp>(
644-
loc, initOperand->get().getType().cast<RankedTensorType>(),
645-
destBbArgs[destNum], outOffsets, sizes, strides);
646-
initOperand->set(patial);
647-
}
648-
b.setInsertionPoint(clonedOp);
669+
{
670+
// 4.a. RAII guard, inserting within foreachThreadOp, before terminator.
671+
OpBuilder::InsertionGuard g(b);
672+
b.setInsertionPoint(foreachThreadOp.getTerminator());
673+
674+
SmallVector<Value> tiledDpsInitOperands;
675+
for (OpOperand *initOperand : destinationStyleOp.getDpsInitOperands()) {
676+
auto *it = llvm::find(dest, initOperand->get());
677+
assert(it != dest.end() && "dest operand not found in dest");
678+
unsigned destNum = std::distance(dest.begin(), it);
679+
SmallVector<OpFoldResult> strides(numThreads.size(), b.getIndexAttr(1));
680+
SmallVector<OpFoldResult> outOffsets(numThreads.size(),
681+
b.getIndexAttr(0));
682+
SmallVector<OpFoldResult> sizes = tiledSizes;
683+
sizes[reductionDim] = b.getIndexAttr(1);
684+
outOffsets[reductionDim] = foreachThreadOp.getThreadIndices().front();
685+
// TODO: use SubsetExtractOpInterface once it is available.
686+
tiledDpsInitOperands.push_back(b.create<tensor::ExtractSliceOp>(
687+
loc, initOperand->get().getType().cast<RankedTensorType>(),
688+
destBbArgs[destNum], outOffsets, sizes, strides));
689+
}
649690

650-
// 5. Tile the cloned op and delete the clone.
651-
if (tileSizes.empty()) {
652-
SmallVector<Operation *> tiledOps =
653-
cast<TilingInterface>(clonedOp).getTiledImplementation(b, tiledOffsets,
654-
tiledSizes);
655-
assert(tiledOps.size() == 1 && "expected a single produced tiled op");
656-
tiledOp = tiledOps.front();
657-
} else {
658-
LinalgTilingOptions options;
659-
auto tiled = tileLinalgOpImpl<scf::ForOp>(b, cast<LinalgOp>(clonedOp),
660-
tileSizes, options);
661-
SmallVector<Value> ids = foreachThreadOp.getThreadIndices();
662-
mapLoopToProcessorIds(cast<scf::ForOp>(tiled->loops.back()), ids,
663-
materializedNonZeroNumThreads);
664-
assert(tiled->loops.size() == 1 && "expected a single produced loop");
665-
tiledOp = tiled->loops.front();
691+
// 4.b. Clone the op and update init operands.
692+
// We cannot use a BlockAndValueMapping here because it can replace
693+
// different OpOperands with the same value.
694+
Operation *clonedOp = b.clone(*op.getOperation());
695+
b.updateRootInPlace(clonedOp, [&]() {
696+
for (auto [initOperandPtr, tiledInitValue] : llvm::zip_equal(
697+
cast<DestinationStyleOpInterface>(clonedOp).getDpsInitOperands(),
698+
tiledDpsInitOperands)) {
699+
initOperandPtr->set(tiledInitValue);
700+
}
701+
});
702+
703+
// 5. Tile the cloned op and delete the clone.
704+
if (tileSizes.empty()) {
705+
SmallVector<Operation *> tiledOps =
706+
cast<TilingInterface>(clonedOp).getTiledImplementation(
707+
b, tiledOffsets, tiledSizes);
708+
assert(tiledOps.size() == 1 && "expected a single produced tiled op");
709+
tiledOp = tiledOps.front();
710+
tilingResults = tiledOp->getResults();
711+
} else {
712+
LinalgTilingOptions options;
713+
FailureOr<TiledLinalgOp> maybeTiled = tileLinalgOpImpl<scf::ForOp>(
714+
b, cast<LinalgOp>(clonedOp), tileSizes, options);
715+
if (failed(maybeTiled))
716+
return b.notifyMatchFailure(op, "failed tileLinalgOpImpl");
717+
718+
SmallVector<Value> ids = foreachThreadOp.getThreadIndices();
719+
mapLoopToProcessorIds(cast<scf::ForOp>(maybeTiled->loops.back()), ids,
720+
materializedNonZeroNumThreads);
721+
assert(maybeTiled->loops.size() == 1 &&
722+
"expected a single produced loop");
723+
tiledOp = maybeTiled->op;
724+
tilingResults = maybeTiled->loops.front()->getResults();
725+
}
726+
727+
b.eraseOp(clonedOp);
666728
}
667-
b.eraseOp(clonedOp);
668729

669730
// 6. Insert the partial reductions back into a new tensor.
670-
b.setInsertionPointAfter(tiledOp);
671-
OpBuilder::InsertPoint insertPt = b.saveInsertionPoint();
672-
for (auto [index, result, bbArg] :
673-
llvm::zip(llvm::seq<unsigned>(0, dest.size()), tiledOp->getResults(),
674-
destBbArgs)) {
675-
b.setInsertionPoint(insertPt.getBlock(), insertPt.getPoint());
731+
for (auto [index, result, bbArg] : llvm::zip(
732+
llvm::seq<unsigned>(0, dest.size()), tilingResults, destBbArgs)) {
733+
// 6.a. Partial subset information is inserted just before the terminator.
734+
OpBuilder::InsertionGuard g(b);
735+
b.setInsertionPoint(foreachThreadOp.getTerminator());
736+
676737
SmallVector<OpFoldResult> resultOffsets, resultSizes;
677738
if (failed(tilingInterfaceOp.getResultTilePosition(
678739
b, index, tiledOffsets, tiledSizes, resultOffsets, resultSizes)))
@@ -689,18 +750,23 @@ linalg::tileReductionUsingForeachThread(RewriterBase &b,
689750
resultOffsetsRank.push_back(resultOffsets[offIdx++]);
690751
resultSizesRank.push_back(resultSizes[sizeIdx++]);
691752
}
692-
693753
SmallVector<OpFoldResult> strides(resultSizesRank.size(),
694754
b.getIndexAttr(1));
755+
756+
// 6.b. Parallel insertions are inserted at the end of the combining
757+
// terminator.
695758
b.setInsertionPointToEnd(foreachThreadOp.getTerminator().getBody());
696759
b.create<tensor::ParallelInsertSliceOp>(
697760
loc, result, bbArg, resultOffsetsRank, resultSizesRank, strides);
698761
}
762+
699763
// 7. Merge the partial reductions.
700764
b.setInsertionPointAfter(foreachThreadOp);
701765
Operation *mergeOp =
702766
op.mergeReductions(b, loc, foreachThreadOp->getResults(), reductionDim);
703767
b.replaceOp(op, mergeOp->getResults());
768+
769+
// 8. Return.
704770
ForeachThreadReductionTilingResult results;
705771
results.initialOp = identityTensor.value();
706772
results.loops = foreachThreadOp;

mlir/lib/Dialect/Transform/IR/TransformOps.cpp

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -874,19 +874,19 @@ void transform::PrintOp::build(OpBuilder &builder, OperationState &result,
874874
DiagnosedSilenceableFailure
875875
transform::PrintOp::apply(transform::TransformResults &results,
876876
transform::TransformState &state) {
877-
llvm::errs() << "[[[ IR printer: ";
877+
llvm::outs() << "[[[ IR printer: ";
878878
if (getName().has_value())
879-
llvm::errs() << *getName() << " ";
879+
llvm::outs() << *getName() << " ";
880880

881881
if (!getTarget()) {
882-
llvm::errs() << "top-level ]]]\n" << *state.getTopLevel() << "\n";
882+
llvm::outs() << "top-level ]]]\n" << *state.getTopLevel() << "\n";
883883
return DiagnosedSilenceableFailure::success();
884884
}
885885

886-
llvm::errs() << "]]]\n";
886+
llvm::outs() << "]]]\n";
887887
ArrayRef<Operation *> targets = state.getPayloadOps(getTarget());
888888
for (Operation *target : targets)
889-
llvm::errs() << *target << "\n";
889+
llvm::outs() << *target << "\n";
890890

891891
return DiagnosedSilenceableFailure::success();
892892
}

0 commit comments

Comments
 (0)