@@ -67,6 +67,7 @@ class FloatType : public Type {
67
67
static FloatType getFloat8E4M3FNUZ (MLIRContext *ctx);
68
68
static FloatType getFloat8E4M3B11FNUZ (MLIRContext *ctx);
69
69
static FloatType getFloat8E3M4 (MLIRContext *ctx);
70
+ static FloatType getFloat6E3M2FN (MLIRContext *ctx);
70
71
71
72
// / Methods for support type inquiry through isa, cast, and dyn_cast.
72
73
static bool classof (Type type);
@@ -415,9 +416,9 @@ inline bool BaseMemRefType::isValidElementType(Type type) {
415
416
inline bool FloatType::classof (Type type) {
416
417
return llvm::isa<Float8E5M2Type, Float8E4M3Type, Float8E4M3FNType,
417
418
Float8E5M2FNUZType, Float8E4M3FNUZType,
418
- Float8E4M3B11FNUZType, Float8E3M4Type, BFloat16Type ,
419
- Float16Type, FloatTF32Type, Float32Type, Float64Type ,
420
- Float80Type, Float128Type>(type);
419
+ Float8E4M3B11FNUZType, Float8E3M4Type, Float6E3M2FNType ,
420
+ BFloat16Type, Float16Type, FloatTF32Type, Float32Type,
421
+ Float64Type, Float80Type, Float128Type>(type);
421
422
}
422
423
423
424
inline FloatType FloatType::getFloat8E5M2 (MLIRContext *ctx) {
@@ -448,6 +449,10 @@ inline FloatType FloatType::getFloat8E3M4(MLIRContext *ctx) {
448
449
return Float8E3M4Type::get (ctx);
449
450
}
450
451
452
+ inline FloatType FloatType::getFloat6E3M2FN (MLIRContext *ctx) {
453
+ return Float6E3M2FNType::get (ctx);
454
+ }
455
+
451
456
inline FloatType FloatType::getBF16 (MLIRContext *ctx) {
452
457
return BFloat16Type::get (ctx);
453
458
}
0 commit comments