1010//
1111//===----------------------------------------------------------------------===//
1212
13- #ifndef VECTOR_OPS
14- #define VECTOR_OPS
13+ #ifndef MLIR_DIALECT_VECTOR_IR_VECTOR_OPS
14+ #define MLIR_DIALECT_VECTOR_IR_VECTOR_OPS
1515
16+ include "mlir/Dialect/Vector/IR/Vector.td"
17+ include "mlir/Dialect/Vector/IR/VectorAttributes.td"
18+ include "mlir/Dialect/Arith/IR/ArithBase.td"
19+ include "mlir/Dialect/Arith/IR/ArithOpsInterfaces.td"
1620include "mlir/Dialect/Vector/Interfaces/MaskableOpInterface.td"
1721include "mlir/Dialect/Vector/Interfaces/MaskingOpInterface.td"
1822include "mlir/IR/EnumAttr.td"
@@ -23,69 +27,6 @@ include "mlir/Interfaces/SideEffectInterfaces.td"
2327include "mlir/Interfaces/VectorInterfaces.td"
2428include "mlir/Interfaces/ViewLikeInterface.td"
2529
26- def Vector_Dialect : Dialect {
27- let name = "vector";
28- let cppNamespace = "::mlir::vector";
29-
30- let useDefaultAttributePrinterParser = 1;
31- let hasConstantMaterializer = 1;
32- let dependentDialects = ["arith::ArithDialect"];
33- }
34-
35- // Base class for Vector dialect ops.
36- class Vector_Op<string mnemonic, list<Trait> traits = []> :
37- Op<Vector_Dialect, mnemonic, traits>;
38-
39- // The "kind" of combining function for contractions and reductions.
40- def COMBINING_KIND_ADD : I32BitEnumAttrCaseBit<"ADD", 0, "add">;
41- def COMBINING_KIND_MUL : I32BitEnumAttrCaseBit<"MUL", 1, "mul">;
42- def COMBINING_KIND_MINUI : I32BitEnumAttrCaseBit<"MINUI", 2, "minui">;
43- def COMBINING_KIND_MINSI : I32BitEnumAttrCaseBit<"MINSI", 3, "minsi">;
44- def COMBINING_KIND_MINF : I32BitEnumAttrCaseBit<"MINF", 4, "minf">;
45- def COMBINING_KIND_MAXUI : I32BitEnumAttrCaseBit<"MAXUI", 5, "maxui">;
46- def COMBINING_KIND_MAXSI : I32BitEnumAttrCaseBit<"MAXSI", 6, "maxsi">;
47- def COMBINING_KIND_MAXF : I32BitEnumAttrCaseBit<"MAXF", 7, "maxf">;
48- def COMBINING_KIND_AND : I32BitEnumAttrCaseBit<"AND", 8, "and">;
49- def COMBINING_KIND_OR : I32BitEnumAttrCaseBit<"OR", 9, "or">;
50- def COMBINING_KIND_XOR : I32BitEnumAttrCaseBit<"XOR", 10, "xor">;
51- def COMBINING_KIND_MINIMUMF : I32BitEnumAttrCaseBit<"MINIMUMF", 11, "minimumf">;
52- def COMBINING_KIND_MAXIMUMF : I32BitEnumAttrCaseBit<"MAXIMUMF", 12, "maximumf">;
53-
54- def CombiningKind : I32BitEnumAttr<
55- "CombiningKind",
56- "Kind of combining function for contractions and reductions",
57- [COMBINING_KIND_ADD, COMBINING_KIND_MUL, COMBINING_KIND_MINUI,
58- COMBINING_KIND_MINSI, COMBINING_KIND_MINF, COMBINING_KIND_MAXUI,
59- COMBINING_KIND_MAXSI, COMBINING_KIND_MAXF, COMBINING_KIND_AND,
60- COMBINING_KIND_OR, COMBINING_KIND_XOR,
61- COMBINING_KIND_MAXIMUMF, COMBINING_KIND_MINIMUMF]> {
62- let cppNamespace = "::mlir::vector";
63- let genSpecializedAttr = 0;
64- }
65-
66- /// An attribute that specifies the combining function for `vector.contract`,
67- /// and `vector.reduction`.
68- def Vector_CombiningKindAttr : EnumAttr<Vector_Dialect, CombiningKind, "kind"> {
69- let assemblyFormat = "`<` $value `>`";
70- }
71-
72- def Vector_IteratorType : I32EnumAttr<"IteratorType", "Iterator type", [
73- I32EnumAttrCase<"parallel", 0>,
74- I32EnumAttrCase<"reduction", 1>
75- ]> {
76- let genSpecializedAttr = 0;
77- let cppNamespace = "::mlir::vector";
78- }
79-
80- def Vector_IteratorTypeEnum
81- : EnumAttr<Vector_Dialect, Vector_IteratorType, "iterator_type"> {
82- let assemblyFormat = "`<` $value `>`";
83- }
84-
85- def Vector_IteratorTypeArrayAttr
86- : TypedArrayAttrBase<Vector_IteratorTypeEnum,
87- "Iterator type should be an enum.">;
88-
8930// TODO: Add an attribute to specify a different algebra with operators other
9031// than the current set: {*, +}.
9132def Vector_ContractionOp :
@@ -274,12 +215,16 @@ def Vector_ReductionOp :
274215 Vector_Op<"reduction", [Pure,
275216 PredOpTrait<"source operand and result have same element type",
276217 TCresVTEtIsSameAsOpBase<0, 0>>,
218+ DeclareOpInterfaceMethods<ArithFastMathInterface>,
277219 DeclareOpInterfaceMethods<MaskableOpInterface>,
278- DeclareOpInterfaceMethods<VectorUnrollOpInterface,
279- ["getShapeForUnroll"]> ]>,
220+ DeclareOpInterfaceMethods<VectorUnrollOpInterface, ["getShapeForUnroll"]>
221+ ]>,
280222 Arguments<(ins Vector_CombiningKindAttr:$kind,
281223 AnyVectorOfAnyRank:$vector,
282- Optional<AnyType>:$acc)>,
224+ Optional<AnyType>:$acc,
225+ DefaultValuedAttr<
226+ Arith_FastMathAttr,
227+ "::mlir::arith::FastMathFlags::none">:$fastmath)>,
283228 Results<(outs AnyType:$dest)> {
284229 let summary = "reduction operation";
285230 let description = [{
@@ -309,9 +254,13 @@ def Vector_ReductionOp :
309254 }];
310255 let builders = [
311256 // Builder that infers the type of `dest`.
312- OpBuilder<(ins "CombiningKind":$kind, "Value":$vector, "Value":$acc)>,
257+ OpBuilder<(ins "CombiningKind":$kind, "Value":$vector, "Value":$acc,
258+ CArg<"::mlir::arith::FastMathFlags",
259+ "::mlir::arith::FastMathFlags::none">:$fastMathFlags)>,
313260 // Builder that infers the type of `dest` and has no accumulator.
314- OpBuilder<(ins "CombiningKind":$kind, "Value":$vector)>
261+ OpBuilder<(ins "CombiningKind":$kind, "Value":$vector,
262+ CArg<"::mlir::arith::FastMathFlags",
263+ "::mlir::arith::FastMathFlags::none">:$fastMathFlags)>
315264 ];
316265
317266 // TODO: Migrate to assemblyFormat once `AllTypesMatch` supports optional
@@ -2469,22 +2418,6 @@ def Vector_TransposeOp :
24692418 let hasVerifier = 1;
24702419}
24712420
2472- def PrintPunctuation : I32EnumAttr<"PrintPunctuation",
2473- "Punctuation for separating vectors or vector elements", [
2474- I32EnumAttrCase<"NoPunctuation", 0, "no_punctuation">,
2475- I32EnumAttrCase<"NewLine", 1, "newline">,
2476- I32EnumAttrCase<"Comma", 2, "comma">,
2477- I32EnumAttrCase<"Open", 3, "open">,
2478- I32EnumAttrCase<"Close", 4, "close">
2479- ]> {
2480- let cppNamespace = "::mlir::vector";
2481- let genSpecializedAttr = 0;
2482- }
2483-
2484- def Vector_PrintPunctuation : EnumAttr<Vector_Dialect, PrintPunctuation, "punctuation"> {
2485- let assemblyFormat = "`<` $value `>`";
2486- }
2487-
24882421def Vector_PrintOp :
24892422 Vector_Op<"print", []>,
24902423 Arguments<(ins Optional<Type<Or<[
@@ -2939,4 +2872,4 @@ def Vector_WarpExecuteOnLane0Op : Vector_Op<"warp_execute_on_lane_0",
29392872 }];
29402873}
29412874
2942- #endif // VECTOR_OPS
2875+ #endif // MLIR_DIALECT_VECTOR_IR_VECTOR_OPS
0 commit comments