2020#include " mlir/Dialect/SPIRV/IR/TargetAndABI.h"
2121#include " mlir/Dialect/Utils/IndexingUtils.h"
2222#include " mlir/Dialect/Vector/IR/VectorOps.h"
23+ #include " mlir/Dialect/Vector/Transforms/LoweringPatterns.h"
2324#include " mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h"
2425#include " mlir/IR/BuiltinTypes.h"
2526#include " mlir/IR/Operation.h"
2627#include " mlir/IR/PatternMatch.h"
28+ #include " mlir/Pass/Pass.h"
2729#include " mlir/Support/LLVM.h"
2830#include " mlir/Transforms/DialectConversion.h"
31+ #include " mlir/Transforms/GreedyPatternRewriteDriver.h"
2932#include " mlir/Transforms/OneToNTypeConversion.h"
3033#include " llvm/ADT/STLExtras.h"
3134#include " llvm/ADT/SmallVector.h"
3235#include " llvm/ADT/StringExtras.h"
3336#include " llvm/Support/Debug.h"
37+ #include " llvm/Support/LogicalResult.h"
3438#include " llvm/Support/MathExtras.h"
3539
3640#include < functional>
@@ -46,14 +50,6 @@ namespace {
4650// Utility functions
4751// ===----------------------------------------------------------------------===//
4852
49- static int getComputeVectorSize (int64_t size) {
50- for (int i : {4 , 3 , 2 }) {
51- if (size % i == 0 )
52- return i;
53- }
54- return 1 ;
55- }
56-
5753static std::optional<SmallVector<int64_t >> getTargetShape (VectorType vecType) {
5854 LLVM_DEBUG (llvm::dbgs () << " Get target shape\n " );
5955 if (vecType.isScalable ()) {
@@ -62,8 +58,8 @@ static std::optional<SmallVector<int64_t>> getTargetShape(VectorType vecType) {
6258 return std::nullopt ;
6359 }
6460 SmallVector<int64_t > unrollShape = llvm::to_vector<4 >(vecType.getShape ());
65- std::optional<SmallVector<int64_t >> targetShape =
66- SmallVector< int64_t >( 1 , getComputeVectorSize (vecType.getShape ().back ()));
61+ std::optional<SmallVector<int64_t >> targetShape = SmallVector< int64_t >(
62+ 1 , mlir::spirv:: getComputeVectorSize (vecType.getShape ().back ()));
6763 if (!targetShape) {
6864 LLVM_DEBUG (llvm::dbgs () << " --no unrolling target shape defined\n " );
6965 return std::nullopt ;
@@ -1098,13 +1094,20 @@ struct ReturnOpVectorUnroll final : OpRewritePattern<func::ReturnOp> {
10981094 // the original operand of illegal type.
10991095 auto originalShape =
11001096 llvm::to_vector_of<int64_t , 4 >(origVecType.getShape ());
1101- SmallVector<int64_t > strides (targetShape->size (), 1 );
1097+ SmallVector<int64_t > strides (originalShape.size (), 1 );
1098+ SmallVector<int64_t > extractShape (originalShape.size (), 1 );
1099+ extractShape.back () = targetShape->back ();
11021100 SmallVector<Type> newTypes;
11031101 Value returnValue = returnOp.getOperand (origResultNo);
11041102 for (SmallVector<int64_t > offsets :
11051103 StaticTileOffsetRange (originalShape, *targetShape)) {
11061104 Value result = rewriter.create <vector::ExtractStridedSliceOp>(
1107- loc, returnValue, offsets, *targetShape, strides);
1105+ loc, returnValue, offsets, extractShape, strides);
1106+ if (originalShape.size () > 1 ) {
1107+ SmallVector<int64_t > extractIndices (originalShape.size () - 1 , 0 );
1108+ result =
1109+ rewriter.create <vector::ExtractOp>(loc, result, extractIndices);
1110+ }
11081111 newOperands.push_back (result);
11091112 newTypes.push_back (unrolledType);
11101113 }
@@ -1285,6 +1288,118 @@ Value mlir::spirv::getElementPtr(const SPIRVTypeConverter &typeConverter,
12851288 builder);
12861289}
12871290
1291+ // ===----------------------------------------------------------------------===//
1292+ // Public functions for vector unrolling
1293+ // ===----------------------------------------------------------------------===//
1294+
1295+ int mlir::spirv::getComputeVectorSize (int64_t size) {
1296+ for (int i : {4 , 3 , 2 }) {
1297+ if (size % i == 0 )
1298+ return i;
1299+ }
1300+ return 1 ;
1301+ }
1302+
1303+ SmallVector<int64_t >
1304+ mlir::spirv::getNativeVectorShapeImpl (vector::ReductionOp op) {
1305+ VectorType srcVectorType = op.getSourceVectorType ();
1306+ assert (srcVectorType.getRank () == 1 ); // Guaranteed by semantics
1307+ int64_t vectorSize =
1308+ mlir::spirv::getComputeVectorSize (srcVectorType.getDimSize (0 ));
1309+ return {vectorSize};
1310+ }
1311+
1312+ SmallVector<int64_t >
1313+ mlir::spirv::getNativeVectorShapeImpl (vector::TransposeOp op) {
1314+ VectorType vectorType = op.getResultVectorType ();
1315+ SmallVector<int64_t > nativeSize (vectorType.getRank (), 1 );
1316+ nativeSize.back () =
1317+ mlir::spirv::getComputeVectorSize (vectorType.getShape ().back ());
1318+ return nativeSize;
1319+ }
1320+
1321+ std::optional<SmallVector<int64_t >>
1322+ mlir::spirv::getNativeVectorShape (Operation *op) {
1323+ if (OpTrait::hasElementwiseMappableTraits (op) && op->getNumResults () == 1 ) {
1324+ if (auto vecType = dyn_cast<VectorType>(op->getResultTypes ()[0 ])) {
1325+ SmallVector<int64_t > nativeSize (vecType.getRank (), 1 );
1326+ nativeSize.back () =
1327+ mlir::spirv::getComputeVectorSize (vecType.getShape ().back ());
1328+ return nativeSize;
1329+ }
1330+ }
1331+
1332+ return TypeSwitch<Operation *, std::optional<SmallVector<int64_t >>>(op)
1333+ .Case <vector::ReductionOp, vector::TransposeOp>(
1334+ [](auto typedOp) { return getNativeVectorShapeImpl (typedOp); })
1335+ .Default ([](Operation *) { return std::nullopt ; });
1336+ }
1337+
1338+ LogicalResult mlir::spirv::unrollVectorsInSignatures (Operation *op) {
1339+ MLIRContext *context = op->getContext ();
1340+ RewritePatternSet patterns (context);
1341+ populateFuncOpVectorRewritePatterns (patterns);
1342+ populateReturnOpVectorRewritePatterns (patterns);
1343+ // We only want to apply signature conversion once to the existing func ops.
1344+ // Without specifying strictMode, the greedy pattern rewriter will keep
1345+ // looking for newly created func ops.
1346+ GreedyRewriteConfig config;
1347+ config.strictMode = GreedyRewriteStrictness::ExistingOps;
1348+ return applyPatternsAndFoldGreedily (op, std::move (patterns), config);
1349+ }
1350+
1351+ LogicalResult mlir::spirv::unrollVectorsInFuncBodies (Operation *op) {
1352+ MLIRContext *context = op->getContext ();
1353+
1354+ // Unroll vectors in function bodies to native vector size.
1355+ {
1356+ RewritePatternSet patterns (context);
1357+ auto options = vector::UnrollVectorOptions ().setNativeShapeFn (
1358+ [](auto op) { return mlir::spirv::getNativeVectorShape (op); });
1359+ populateVectorUnrollPatterns (patterns, options);
1360+ if (failed (applyPatternsAndFoldGreedily (op, std::move (patterns))))
1361+ return failure ();
1362+ }
1363+
1364+ // Convert transpose ops into extract and insert pairs, in preparation of
1365+ // further transformations to canonicalize/cancel.
1366+ {
1367+ RewritePatternSet patterns (context);
1368+ auto options = vector::VectorTransformsOptions ().setVectorTransposeLowering (
1369+ vector::VectorTransposeLowering::EltWise);
1370+ vector::populateVectorTransposeLoweringPatterns (patterns, options);
1371+ vector::populateVectorShapeCastLoweringPatterns (patterns);
1372+ if (failed (applyPatternsAndFoldGreedily (op, std::move (patterns))))
1373+ return failure ();
1374+ }
1375+
1376+ // Run canonicalization to cast away leading size-1 dimensions.
1377+ {
1378+ RewritePatternSet patterns (context);
1379+
1380+ // We need to pull in casting way leading one dims.
1381+ vector::populateCastAwayVectorLeadingOneDimPatterns (patterns);
1382+ vector::ReductionOp::getCanonicalizationPatterns (patterns, context);
1383+ vector::TransposeOp::getCanonicalizationPatterns (patterns, context);
1384+
1385+ // Decompose different rank insert_strided_slice and n-D
1386+ // extract_slided_slice.
1387+ vector::populateVectorInsertExtractStridedSliceDecompositionPatterns (
1388+ patterns);
1389+ vector::InsertOp::getCanonicalizationPatterns (patterns, context);
1390+ vector::ExtractOp::getCanonicalizationPatterns (patterns, context);
1391+
1392+ // Trimming leading unit dims may generate broadcast/shape_cast ops. Clean
1393+ // them up.
1394+ vector::BroadcastOp::getCanonicalizationPatterns (patterns, context);
1395+ vector::ShapeCastOp::getCanonicalizationPatterns (patterns, context);
1396+
1397+ if (failed (applyPatternsAndFoldGreedily (op, std::move (patterns))))
1398+ return failure ();
1399+ }
1400+ return success ();
1401+ }
1402+
12881403// ===----------------------------------------------------------------------===//
12891404// SPIR-V TypeConverter
12901405// ===----------------------------------------------------------------------===//
0 commit comments