2525#include " mlir/Dialect/Utils/IndexingUtils.h"
2626#include " mlir/IR/AffineExpr.h"
2727#include " mlir/IR/AffineMap.h"
28+ #include " mlir/IR/BuiltinOps.h"
29+ #include " mlir/IR/ValueRange.h"
2830#include " mlir/Transforms/FoldUtils.h"
2931#include " mlir/Transforms/GreedyPatternRewriteDriver.h"
32+ #include " llvm/ADT/STLExtras.h"
3033#include " llvm/Support/CommandLine.h"
3134#include < utility>
3235
@@ -221,6 +224,9 @@ static void calculateTileOffsetsAndSizes(
221224 Optional<ArrayRef<OpFoldResult>> nominalTileSizes,
222225 SmallVector<OpFoldResult> &tiledOffsets,
223226 SmallVector<OpFoldResult> &tiledSizes) {
227+ OpBuilder::InsertionGuard g (b);
228+ b.setInsertionPointToStart (foreachThreadOp.getBody (0 ));
229+
224230 ValueRange threadIds = foreachThreadOp.getThreadIndices ();
225231 SmallVector<OpFoldResult> nonZeroNumThreads =
226232 llvm::to_vector (llvm::make_filter_range (numThreads, [](OpFoldResult ofr) {
@@ -300,6 +306,7 @@ static FailureOr<ForeachThreadTilingResult> tileToForeachThreadOpImpl(
300306 Optional<ArrayAttr> mapping, bool omitTileOffsetBoundsCheck) {
301307 Location loc = op->getLoc ();
302308 OpBuilder::InsertionGuard g (b);
309+
303310 SmallVector<Range> loopRanges = op.getIterationDomain (b);
304311 if (loopRanges.empty ())
305312 return op->emitOpError (" expected non-empty loop ranges" );
@@ -323,54 +330,64 @@ static FailureOr<ForeachThreadTilingResult> tileToForeachThreadOpImpl(
323330
324331 Operation *tiledOp = nullptr ;
325332
326- // Create the ForeachThreadOp. We don't use the lambda body-builder
333+ // 1. Create the ForeachThreadOp. We don't use the lambda body-builder
327334 // version because we require the use of RewriterBase in the body, so we
328335 // manually move the insertion point to the body below.
329336 scf::ForeachThreadOp foreachThreadOp = b.create <scf::ForeachThreadOp>(
330337 loc, dest, ValueRange (materializedNonZeroNumThreads), mapping);
331338
332- // Fill out the ForeachThreadOp body.
333- b.setInsertionPointToStart (foreachThreadOp.getBody (0 ));
339+ // 2. Fill out the ForeachThreadOp body.
334340 SmallVector<OpFoldResult> tiledOffsets, tiledSizes;
335341 calculateTileOffsetsAndSizes (b, loc, foreachThreadOp, numThreads, loopRanges,
336342 omitTileOffsetBoundsCheck, nominalTileSizes,
337343 tiledOffsets, tiledSizes);
338344
339- // Clone the tileable op and update its destination operands to use the output
340- // bbArgs of the ForeachThreadOp.
345+ // 3. Clone the tileable op and update its destination operands to use the
346+ // output bbArgs of the ForeachThreadOp.
341347 ArrayRef<BlockArgument> destBbArgs =
342348 foreachThreadOp.getOutputBlockArguments ();
343- Operation *clonedOp = b.clone (*op.getOperation ());
344- auto destinationStyleOp = dyn_cast<DestinationStyleOpInterface>(clonedOp);
345- if (destinationStyleOp) {
346- for (OpOperand *outOperand : destinationStyleOp.getDpsInitOperands ()) {
347- auto *it = llvm::find (dest, outOperand->get ());
348- assert (it != dest.end () && " dest operand not found in dest" );
349- unsigned destNum = std::distance (dest.begin (), it);
350- outOperand->set (destBbArgs[destNum]);
349+ {
350+ // 3.a. RAII guard, inserting within foreachThreadOp, before terminator.
351+ OpBuilder::InsertionGuard g (b);
352+ b.setInsertionPoint (foreachThreadOp.getTerminator ());
353+ Operation *clonedOp = b.clone (*op.getOperation ());
354+ auto destinationStyleOp = dyn_cast<DestinationStyleOpInterface>(clonedOp);
355+ if (destinationStyleOp) {
356+ for (OpOperand *outOperand : destinationStyleOp.getDpsInitOperands ()) {
357+ auto *it = llvm::find (dest, outOperand->get ());
358+ assert (it != dest.end () && " dest operand not found in dest" );
359+ unsigned destNum = std::distance (dest.begin (), it);
360+ outOperand->set (destBbArgs[destNum]);
361+ }
351362 }
352- }
353363
354- // Tile the cloned op and delete the clone.
355- SmallVector<Operation *> tiledOps =
356- cast<TilingInterface>(clonedOp).getTiledImplementation (b, tiledOffsets,
357- tiledSizes);
358- b.eraseOp (clonedOp);
359- assert (tiledOps.size () == 1 && " expected a single produced tiled op" );
360- tiledOp = tiledOps.front ();
364+ // 4. Tile the cloned op and delete the clone.
365+ SmallVector<Operation *> tiledOps =
366+ cast<TilingInterface>(clonedOp).getTiledImplementation (b, tiledOffsets,
367+ tiledSizes);
368+ b.eraseOp (clonedOp);
369+ assert (tiledOps.size () == 1 && " expected a single produced tiled op" );
370+ tiledOp = tiledOps.front ();
371+ }
361372
373+ // 5. Parallel insert back into the result tensor.
362374 auto tilingInterfaceOp = dyn_cast<TilingInterface>(tiledOp);
363375 assert (tilingInterfaceOp && " Tiled op does not implement TilingInterface" );
364- OpBuilder::InsertPoint insertPt = b.saveInsertionPoint ();
365376 for (auto it : llvm::zip (llvm::seq (unsigned (0 ), unsigned (dest.size ())),
366377 tilingInterfaceOp->getResults (), destBbArgs)) {
367- b.setInsertionPoint (insertPt.getBlock (), insertPt.getPoint ());
378+ // 5.a. Partial subset information is inserted just before the terminator.
379+ OpBuilder::InsertionGuard g (b);
380+ b.setInsertionPoint (foreachThreadOp.getTerminator ());
381+
368382 SmallVector<OpFoldResult> resultOffsets, resultSizes;
369383 if (failed (op.getResultTilePosition (b, std::get<0 >(it), tiledOffsets,
370384 tiledSizes, resultOffsets,
371385 resultSizes)))
372386 return op->emitOpError (" output offsets couldn't be calculated" );
373387 SmallVector<OpFoldResult> strides (resultSizes.size (), b.getIndexAttr (1 ));
388+
389+ // 5.b. Parallel insertions are inserted at the end of the combining
390+ // terminator.
374391 b.setInsertionPointToEnd (foreachThreadOp.getTerminator ().getBody ());
375392 b.create <tensor::ParallelInsertSliceOp>(loc, std::get<1 >(it),
376393 std::get<2 >(it), resultOffsets,
@@ -415,6 +432,8 @@ template <typename LoopTy>
415432static FailureOr<TiledLinalgOp>
416433tileLinalgOpImpl (RewriterBase &b, LinalgOp op, ArrayRef<OpFoldResult> tileSizes,
417434 const LinalgTilingOptions &options) {
435+ OpBuilder::InsertionGuard g (b);
436+
418437 auto nLoops = op.getNumLoops ();
419438 // Initial tile sizes may be too big, only take the first nLoops.
420439 tileSizes = tileSizes.take_front (nLoops);
@@ -570,25 +589,44 @@ linalg::tileReductionUsingForeachThread(RewriterBase &b,
570589 Optional<ArrayAttr> mapping) {
571590 Location loc = op.getLoc ();
572591 OpBuilder::InsertionGuard g (b);
592+
573593 // Ops implementing PartialReductionOpInterface are expected to implement
574594 // TilingInterface.
595+ // TODO: proper core mechanism to tie interfaces together.
575596 auto tilingInterfaceOp = cast<TilingInterface>(op.getOperation ());
597+
598+ // Ops implementing PartialReductionOpInterface are not necessarily expected
599+ // to implement TilingInterface.. This cast is unsafe atm.
600+ // TODO: proper core mechanism to tie interfaces together.
601+ // TODO: this function requires a pair of interfaces ..
602+ auto destinationStyleOp =
603+ dyn_cast<DestinationStyleOpInterface>(op.getOperation ());
604+ if (!destinationStyleOp)
605+ return b.notifyMatchFailure (op, " not a destination style op" );
606+
607+ // Actually this only work for Linalg ops atm.
608+ auto linalgOp = dyn_cast<linalg::LinalgOp>(op.getOperation ());
609+ if (!linalgOp)
610+ return b.notifyMatchFailure (op, " not a linalg op" );
611+
576612 SmallVector<Range> iterationDomain = tilingInterfaceOp.getIterationDomain (b);
577613 if (op->getNumResults () != 1 )
578614 return b.notifyMatchFailure (
579615 op, " don't support ops with multiple results for now" );
616+
580617 SmallVector<utils::IteratorType> iterators =
581618 tilingInterfaceOp.getLoopIteratorTypes ();
582619 SmallVector<unsigned > redDims;
583- cast<linalg::LinalgOp>(op. getOperation ()) .getReductionDims (redDims);
620+ linalgOp .getReductionDims (redDims);
584621 if (redDims.size () != 1 )
585622 return b.notifyMatchFailure (
586623 op, " only support ops with one reduction dimension." );
587624 if (!tileSizes.empty () && tileSizes.size () != numThreads.size ())
588625 return b.notifyMatchFailure (op, " if tile sizes are present it must have as "
589626 " many elements as number of threads" );
590627 int reductionDim = static_cast <int >(redDims.front ());
591- // 1. create the inital tensor value.
628+
629+ // 1. Create the inital tensor value.
592630 FailureOr<Operation *> identityTensor =
593631 op.generateInitialTensorForPartialReduction (b, loc, numThreads,
594632 reductionDim);
@@ -615,8 +653,8 @@ linalg::tileReductionUsingForeachThread(RewriterBase &b,
615653 loc, identityTensor.value ()->getResults (),
616654 ValueRange (materializedNonZeroNumThreads), mapping);
617655
618- // 3. calculate the tile offsets and sizes.
619- b. setInsertionPointToStart ( foreachThreadOp. getBody ( 0 ));
656+ // 3. Calculate the tile offsets and sizes for the subsequent loop that will
657+ // be nested under ` foreachThreadOp`.
620658 SmallVector<OpFoldResult> tiledOffsets, tiledSizes;
621659 calculateTileOffsetsAndSizes (
622660 b, loc, foreachThreadOp, numThreads, iterationDomain,
@@ -625,54 +663,77 @@ linalg::tileReductionUsingForeachThread(RewriterBase &b,
625663
626664 // 4. Clone the tileable op and update its destination operands to use the
627665 // output bbArgs of the ForeachThreadOp.
666+ ValueRange tilingResults;
628667 ArrayRef<BlockArgument> destBbArgs =
629668 foreachThreadOp.getOutputBlockArguments ();
630- Operation *clonedOp = b.clone (*op.getOperation ());
631- b.setInsertionPointToStart (foreachThreadOp.getBody (0 ));
632- auto destinationStyleOp = cast<DestinationStyleOpInterface>(clonedOp);
633- for (OpOperand *initOperand : destinationStyleOp.getDpsInitOperands ()) {
634- auto *it = llvm::find (dest, initOperand->get ());
635- assert (it != dest.end () && " dest operand not found in dest" );
636- unsigned destNum = std::distance (dest.begin (), it);
637- SmallVector<OpFoldResult> strides (numThreads.size (), b.getIndexAttr (1 ));
638- SmallVector<OpFoldResult> outOffsets (numThreads.size (), b.getIndexAttr (0 ));
639- SmallVector<OpFoldResult> sizes = tiledSizes;
640- sizes[reductionDim] = b.getIndexAttr (1 );
641- outOffsets[reductionDim] = foreachThreadOp.getThreadIndices ().front ();
642- // TODO: use SubsetExtractOpInterface once it is available.
643- Value patial = b.create <tensor::ExtractSliceOp>(
644- loc, initOperand->get ().getType ().cast <RankedTensorType>(),
645- destBbArgs[destNum], outOffsets, sizes, strides);
646- initOperand->set (patial);
647- }
648- b.setInsertionPoint (clonedOp);
669+ {
670+ // 4.a. RAII guard, inserting within foreachThreadOp, before terminator.
671+ OpBuilder::InsertionGuard g (b);
672+ b.setInsertionPoint (foreachThreadOp.getTerminator ());
673+
674+ SmallVector<Value> tiledDpsInitOperands;
675+ for (OpOperand *initOperand : destinationStyleOp.getDpsInitOperands ()) {
676+ auto *it = llvm::find (dest, initOperand->get ());
677+ assert (it != dest.end () && " dest operand not found in dest" );
678+ unsigned destNum = std::distance (dest.begin (), it);
679+ SmallVector<OpFoldResult> strides (numThreads.size (), b.getIndexAttr (1 ));
680+ SmallVector<OpFoldResult> outOffsets (numThreads.size (),
681+ b.getIndexAttr (0 ));
682+ SmallVector<OpFoldResult> sizes = tiledSizes;
683+ sizes[reductionDim] = b.getIndexAttr (1 );
684+ outOffsets[reductionDim] = foreachThreadOp.getThreadIndices ().front ();
685+ // TODO: use SubsetExtractOpInterface once it is available.
686+ tiledDpsInitOperands.push_back (b.create <tensor::ExtractSliceOp>(
687+ loc, initOperand->get ().getType ().cast <RankedTensorType>(),
688+ destBbArgs[destNum], outOffsets, sizes, strides));
689+ }
649690
650- // 5. Tile the cloned op and delete the clone.
651- if (tileSizes.empty ()) {
652- SmallVector<Operation *> tiledOps =
653- cast<TilingInterface>(clonedOp).getTiledImplementation (b, tiledOffsets,
654- tiledSizes);
655- assert (tiledOps.size () == 1 && " expected a single produced tiled op" );
656- tiledOp = tiledOps.front ();
657- } else {
658- LinalgTilingOptions options;
659- auto tiled = tileLinalgOpImpl<scf::ForOp>(b, cast<LinalgOp>(clonedOp),
660- tileSizes, options);
661- SmallVector<Value> ids = foreachThreadOp.getThreadIndices ();
662- mapLoopToProcessorIds (cast<scf::ForOp>(tiled->loops .back ()), ids,
663- materializedNonZeroNumThreads);
664- assert (tiled->loops .size () == 1 && " expected a single produced loop" );
665- tiledOp = tiled->loops .front ();
691+ // 4.b. Clone the op and update init operands.
692+ // We cannot use a BlockAndValueMapping here because it can replace
693+ // different OpOperands with the same value.
694+ Operation *clonedOp = b.clone (*op.getOperation ());
695+ b.updateRootInPlace (clonedOp, [&]() {
696+ for (auto [initOperandPtr, tiledInitValue] : llvm::zip_equal (
697+ cast<DestinationStyleOpInterface>(clonedOp).getDpsInitOperands (),
698+ tiledDpsInitOperands)) {
699+ initOperandPtr->set (tiledInitValue);
700+ }
701+ });
702+
703+ // 5. Tile the cloned op and delete the clone.
704+ if (tileSizes.empty ()) {
705+ SmallVector<Operation *> tiledOps =
706+ cast<TilingInterface>(clonedOp).getTiledImplementation (
707+ b, tiledOffsets, tiledSizes);
708+ assert (tiledOps.size () == 1 && " expected a single produced tiled op" );
709+ tiledOp = tiledOps.front ();
710+ tilingResults = tiledOp->getResults ();
711+ } else {
712+ LinalgTilingOptions options;
713+ FailureOr<TiledLinalgOp> maybeTiled = tileLinalgOpImpl<scf::ForOp>(
714+ b, cast<LinalgOp>(clonedOp), tileSizes, options);
715+ if (failed (maybeTiled))
716+ return b.notifyMatchFailure (op, " failed tileLinalgOpImpl" );
717+
718+ SmallVector<Value> ids = foreachThreadOp.getThreadIndices ();
719+ mapLoopToProcessorIds (cast<scf::ForOp>(maybeTiled->loops .back ()), ids,
720+ materializedNonZeroNumThreads);
721+ assert (maybeTiled->loops .size () == 1 &&
722+ " expected a single produced loop" );
723+ tiledOp = maybeTiled->op ;
724+ tilingResults = maybeTiled->loops .front ()->getResults ();
725+ }
726+
727+ b.eraseOp (clonedOp);
666728 }
667- b.eraseOp (clonedOp);
668729
669730 // 6. Insert the partial reductions back into a new tensor.
670- b. setInsertionPointAfter (tiledOp);
671- OpBuilder::InsertPoint insertPt = b. saveInsertionPoint ();
672- for ( auto [index, result, bbArg] :
673- llvm::zip (llvm::seq< unsigned >( 0 , dest. size ()), tiledOp-> getResults (),
674- destBbArgs)) {
675- b. setInsertionPoint (insertPt. getBlock (), insertPt. getPoint ());
731+ for ( auto [index, result, bbArg] : llvm::zip (
732+ llvm::seq< unsigned >( 0 , dest. size ()), tilingResults, destBbArgs)) {
733+ // 6.a. Partial subset information is inserted just before the terminator.
734+ OpBuilder::InsertionGuard g (b);
735+ b. setInsertionPoint (foreachThreadOp. getTerminator ());
736+
676737 SmallVector<OpFoldResult> resultOffsets, resultSizes;
677738 if (failed (tilingInterfaceOp.getResultTilePosition (
678739 b, index, tiledOffsets, tiledSizes, resultOffsets, resultSizes)))
@@ -689,18 +750,23 @@ linalg::tileReductionUsingForeachThread(RewriterBase &b,
689750 resultOffsetsRank.push_back (resultOffsets[offIdx++]);
690751 resultSizesRank.push_back (resultSizes[sizeIdx++]);
691752 }
692-
693753 SmallVector<OpFoldResult> strides (resultSizesRank.size (),
694754 b.getIndexAttr (1 ));
755+
756+ // 6.b. Parallel insertions are inserted at the end of the combining
757+ // terminator.
695758 b.setInsertionPointToEnd (foreachThreadOp.getTerminator ().getBody ());
696759 b.create <tensor::ParallelInsertSliceOp>(
697760 loc, result, bbArg, resultOffsetsRank, resultSizesRank, strides);
698761 }
762+
699763 // 7. Merge the partial reductions.
700764 b.setInsertionPointAfter (foreachThreadOp);
701765 Operation *mergeOp =
702766 op.mergeReductions (b, loc, foreachThreadOp->getResults (), reductionDim);
703767 b.replaceOp (op, mergeOp->getResults ());
768+
769+ // 8. Return.
704770 ForeachThreadReductionTilingResult results;
705771 results.initialOp = identityTensor.value ();
706772 results.loops = foreachThreadOp;
0 commit comments