Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
26 changes: 13 additions & 13 deletions include/ck_tile/ops/gemm/warp/warp_gemm.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -308,50 +308,50 @@ using WarpGemmMfma_f32_16x16x64_bf8_bf8 = WarpGemmImpl<WarpGemmAttributeMfmaIter

template <WGAttrNumAccessEnum AttrNumAccess = WGAttrNumAccessEnum::Single>
using WarpGemmMfma_f32_16x16x128_fp4 = WarpGemmImpl<
WarpGemmAttributeMfma<WarpGemmAttributeMfmaScaleImpl_f32_16x16x128_fp4<WGAttrCtlEnum::Default_>,
WarpGemmAttributeMfma<WarpGemmAttributeMfmaImpl_f32_16x16x128_f8f6f4<pk_fp4_t, pk_fp4_t>,
AttrNumAccess>>;
template <WGAttrNumAccessEnum AttrNumAccess = WGAttrNumAccessEnum::Single>
using WarpGemmMfma_f32_16x16x128_fp8_fp8 = WarpGemmImpl<
WarpGemmAttributeMfma<WarpGemmAttributeMfmaImpl_f32_16x16x128_fp8_fp8<WGAttrCtlEnum::Default_>,
using WarpGemmMfma_f32_16x16x128_fp8_fp8 = WarpGemmImpl< //
WarpGemmAttributeMfma<WarpGemmAttributeMfmaImpl_f32_16x16x128_f8f6f4<fp8_t, fp8_t>,
AttrNumAccess>>;

template <WGAttrNumAccessEnum AttrNumAccess = WGAttrNumAccessEnum::Single>
using WarpGemmMfma_f32_16x16x128_fp8_bf8 = WarpGemmImpl<
WarpGemmAttributeMfma<WarpGemmAttributeMfmaImpl_f32_16x16x128_fp8_bf8<WGAttrCtlEnum::Default_>,
using WarpGemmMfma_f32_16x16x128_fp8_bf8 = WarpGemmImpl< //
WarpGemmAttributeMfma<WarpGemmAttributeMfmaImpl_f32_16x16x128_f8f6f4<fp8_t, bf8_t>,
AttrNumAccess>>;

template <WGAttrNumAccessEnum AttrNumAccess = WGAttrNumAccessEnum::Single>
using WarpGemmMfma_f32_16x16x128_bf8_fp8 = WarpGemmImpl<
WarpGemmAttributeMfma<WarpGemmAttributeMfmaImpl_f32_16x16x128_bf8_fp8<WGAttrCtlEnum::Default_>,
using WarpGemmMfma_f32_16x16x128_bf8_fp8 = WarpGemmImpl< //
WarpGemmAttributeMfma<WarpGemmAttributeMfmaImpl_f32_16x16x128_f8f6f4<bf8_t, fp8_t>,
AttrNumAccess>>;

template <WGAttrNumAccessEnum AttrNumAccess = WGAttrNumAccessEnum::Single>
using WarpGemmMfma_f32_16x16x128_bf8_bf8 = WarpGemmImpl<
WarpGemmAttributeMfma<WarpGemmAttributeMfmaImpl_f32_16x16x128_bf8_bf8<WGAttrCtlEnum::Default_>,
using WarpGemmMfma_f32_16x16x128_bf8_bf8 = WarpGemmImpl< //
WarpGemmAttributeMfma<WarpGemmAttributeMfmaImpl_f32_16x16x128_f8f6f4<bf8_t, bf8_t>,
AttrNumAccess>>;

template <WGAttrNumAccessEnum AttrNumAccess = WGAttrNumAccessEnum::Single>
using WarpGemmMfma_f32_16x16x128_fp8_fp8_CTransposed =
WarpGemmImpl<WarpGemmAttributeMfmaTransposedCDistribution<
WarpGemmAttributeMfmaImpl_f32_16x16x128_fp8_fp8<WGAttrCtlEnum::Default_>,
WarpGemmAttributeMfmaImpl_f32_16x16x128_f8f6f4<fp8_t, fp8_t>,
AttrNumAccess>>;

template <WGAttrNumAccessEnum AttrNumAccess = WGAttrNumAccessEnum::Single>
using WarpGemmMfma_f32_16x16x128_fp8_bf8_CTransposed =
WarpGemmImpl<WarpGemmAttributeMfmaTransposedCDistribution<
WarpGemmAttributeMfmaImpl_f32_16x16x128_fp8_bf8<WGAttrCtlEnum::Default_>,
WarpGemmAttributeMfmaImpl_f32_16x16x128_f8f6f4<fp8_t, bf8_t>,
AttrNumAccess>>;

template <WGAttrNumAccessEnum AttrNumAccess = WGAttrNumAccessEnum::Single>
using WarpGemmMfma_f32_16x16x128_bf8_fp8_CTransposed =
WarpGemmImpl<WarpGemmAttributeMfmaTransposedCDistribution<
WarpGemmAttributeMfmaImpl_f32_16x16x128_bf8_fp8<WGAttrCtlEnum::Default_>,
WarpGemmAttributeMfmaImpl_f32_16x16x128_f8f6f4<bf8_t, fp8_t>,
AttrNumAccess>>;

template <WGAttrNumAccessEnum AttrNumAccess = WGAttrNumAccessEnum::Single>
using WarpGemmMfma_f32_16x16x128_bf8_bf8_CTransposed =
WarpGemmImpl<WarpGemmAttributeMfmaTransposedCDistribution<
WarpGemmAttributeMfmaImpl_f32_16x16x128_bf8_bf8<WGAttrCtlEnum::Default_>,
WarpGemmAttributeMfmaImpl_f32_16x16x128_f8f6f4<bf8_t, bf8_t>,
AttrNumAccess>>;

template <WGAttrNumAccessEnum AttrNumAccess = WGAttrNumAccessEnum::Single>
Expand Down
184 changes: 51 additions & 133 deletions include/ck_tile/ops/gemm/warp/warp_gemm_attribute_mfma_impl.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -1527,15 +1527,15 @@ using WarpGemmAttributeMfmaImpl_f32_32x32x16_bf8_bf8 =
WarpGemmAttributeMfmaImpl_f32_32x32x16_f8_base<bf8_t, bf8_t, Ctrl_>;

template <typename AType_, typename BType_, WGAttrCtlEnum Ctrl_ = WGAttrCtlEnum::Default_>
struct WarpGemmAttributeMfmaImpl_f32_16x16x128_f8_bf8_base
struct WarpGemmAttributeMfmaImpl_f32_16x16x128_f8f6f4
{
static constexpr WGAttrCtlEnum Ctrl = Ctrl_;
using ADataType = AType_;
using BDataType = BType_;
using CDataType = float;

using AVecType = ext_vector_t<ADataType, 32>;
using BVecType = ext_vector_t<BDataType, 32>;
using AVecType = ext_vector_t<ADataType, 32 / numeric_traits<ADataType>::PackedSize>;
using BVecType = ext_vector_t<BDataType, 32 / numeric_traits<BDataType>::PackedSize>;
using CVecType = ext_vector_t<CDataType, 4>;

static constexpr index_t kM = 16;
Expand All @@ -1556,163 +1556,81 @@ struct WarpGemmAttributeMfmaImpl_f32_16x16x128_f8_bf8_base
static constexpr index_t kCM1PerLane = 4;

// c_vec += a_vec * b_vec
template <bool post_nop_ = false>
template <index_t opselA, index_t opselB, bool post_nop_ = false>
CK_TILE_DEVICE void operator()(CVecType& c_vec,
const AVecType& a_vec,
const int32_t& a_scale,
const BVecType& b_vec,
const int32_t& b_scale,
bool_constant<post_nop_> = {}) const
{
//__builtin_amdgcn_mfma_scale_f32_16x16x128_f8f6f4(a, b, c, cbsz, blgp, opsel, scale_a,
// opsel, scale_b)
#if defined(__gfx950__)
if constexpr(std::is_same_v<ADataType, fp8_t> && std::is_same_v<BDataType, fp8_t>)
c_vec = __builtin_amdgcn_mfma_scale_f32_16x16x128_f8f6f4(
a_vec, b_vec, c_vec, 0, 0, 0, 0, 0, 0);
else if constexpr(std::is_same_v<ADataType, fp8_t> && std::is_same_v<BDataType, bf8_t>)
c_vec = __builtin_amdgcn_mfma_scale_f32_16x16x128_f8f6f4(
a_vec, b_vec, c_vec, 0, 1, 0, 0, 0, 0);
else if constexpr(std::is_same_v<ADataType, bf8_t> && std::is_same_v<BDataType, fp8_t>)
c_vec = __builtin_amdgcn_mfma_scale_f32_16x16x128_f8f6f4(
a_vec, b_vec, c_vec, 1, 0, 0, 0, 0, 0);
else if constexpr(std::is_same_v<ADataType, bf8_t> && std::is_same_v<BDataType, bf8_t>)
c_vec = __builtin_amdgcn_mfma_scale_f32_16x16x128_f8f6f4(
a_vec, b_vec, c_vec, 1, 1, 0, 0, 0, 0);
auto dtype2conf = [](auto dtype) {
if constexpr(std::is_same_v<decltype(dtype), fp8_t>)
return make_tuple(number<0>{}, int32x8_t{});
else if constexpr(std::is_same_v<decltype(dtype), bf8_t>)
return make_tuple(number<1>{}, int32x8_t{});
// else if e2m3 => make_tuple(number<2>{}, int32x6_t{})
// else if e3m2 => make_tuple(number<3>{}, int32x6_t{})
else if constexpr(std::is_same_v<decltype(dtype), pk_fp4_t>)
return make_tuple(number<4>{}, int32x4_t{});
else
static_assert(false, "Unsupported data type for mfma scale");
};
auto dtype2code = [&](auto dtype) { return dtype2conf(dtype)(number<0>{}); };
auto dtype2vec = [&](auto dtype) { return dtype2conf(dtype)(number<1>{}); };
auto arg256 = [&](auto x) {
if constexpr(sizeof(x) == 16)
return int32x8_t{x[0], x[1], x[2], x[3], 0, 0, 0, 0};
else if constexpr(sizeof(x) == 24)
return int32x8_t{x[0], x[1], x[2], x[3], x[4], x[5], 0, 0};
else if constexpr(sizeof(x) == 32)
return x;
else
static_assert(false, "Unexpected vector size for mfma scale");
};

auto arg_a = bit_cast<decltype(dtype2vec(ADataType{}))>(a_vec);
auto arg_b = bit_cast<decltype(dtype2vec(BDataType{}))>(b_vec);
constexpr int cbsz = decltype(dtype2code(ADataType{}))::value;
constexpr int blgp = decltype(dtype2code(BDataType{}))::value;
c_vec = __builtin_amdgcn_mfma_scale_f32_16x16x128_f8f6f4(
arg256(arg_a), arg256(arg_b), c_vec, cbsz, blgp, opselA, a_scale, opselB, b_scale);
#else
ck_tile::ignore = c_vec;
ck_tile::ignore = a_vec;
ck_tile::ignore = b_vec;
ck_tile::ignore = a_scale;
ck_tile::ignore = b_scale;
#endif
}

// c_vec = a_vec * b_vec
CK_TILE_DEVICE CVecType operator()(const AVecType& a_vec, const BVecType& b_vec) const
template <index_t opselA, index_t opselB>
CK_TILE_DEVICE CVecType operator()(const AVecType& a_vec,
const int32_t& a_scale,
const BVecType& b_vec,
const int32_t& b_scale) const
{
#if defined(__gfx950__)
if constexpr(std::is_same_v<ADataType, fp8_t> && std::is_same_v<BDataType, fp8_t>)
return bit_cast<CVecType>(__builtin_amdgcn_mfma_scale_f32_16x16x128_f8f6f4(
a_vec, b_vec, CVecType{0.f}, 0, 0, 0, 0, 0, 0));
else if constexpr(std::is_same_v<ADataType, fp8_t> && std::is_same_v<BDataType, bf8_t>)
return bit_cast<CVecType>(__builtin_amdgcn_mfma_scale_f32_16x16x128_f8f6f4(
a_vec, b_vec, CVecType{0.f}, 0, 1, 0, 0, 0, 0));
else if constexpr(std::is_same_v<ADataType, bf8_t> && std::is_same_v<BDataType, fp8_t>)
return bit_cast<CVecType>(__builtin_amdgcn_mfma_scale_f32_16x16x128_f8f6f4(
a_vec, b_vec, CVecType{0.f}, 1, 0, 0, 0, 0, 0));
else if constexpr(std::is_same_v<ADataType, bf8_t> && std::is_same_v<BDataType, bf8_t>)
return bit_cast<CVecType>(__builtin_amdgcn_mfma_scale_f32_16x16x128_f8f6f4(
a_vec, b_vec, CVecType{0.f}, 1, 1, 0, 0, 0, 0));
#else
ck_tile::ignore = a_vec;
ck_tile::ignore = b_vec;
return CVecType{0.f};
#endif
CVecType c_vec{0.f};
operator()<opselA, opselB>(c_vec, a_vec, a_scale, b_vec, b_scale);
return c_vec;
}
};

template <WGAttrCtlEnum Ctrl_ = WGAttrCtlEnum::Default_>
using WarpGemmAttributeMfmaImpl_f32_16x16x128_fp8_fp8 =
WarpGemmAttributeMfmaImpl_f32_16x16x128_f8_bf8_base<fp8_t, fp8_t, Ctrl_>;

template <WGAttrCtlEnum Ctrl_ = WGAttrCtlEnum::Default_>
using WarpGemmAttributeMfmaImpl_f32_16x16x128_fp8_bf8 =
WarpGemmAttributeMfmaImpl_f32_16x16x128_f8_bf8_base<fp8_t, bf8_t, Ctrl_>;

template <WGAttrCtlEnum Ctrl_ = WGAttrCtlEnum::Default_>
using WarpGemmAttributeMfmaImpl_f32_16x16x128_bf8_fp8 =
WarpGemmAttributeMfmaImpl_f32_16x16x128_f8_bf8_base<bf8_t, fp8_t, Ctrl_>;

template <WGAttrCtlEnum Ctrl_ = WGAttrCtlEnum::Default_>
using WarpGemmAttributeMfmaImpl_f32_16x16x128_bf8_bf8 =
WarpGemmAttributeMfmaImpl_f32_16x16x128_f8_bf8_base<bf8_t, bf8_t, Ctrl_>;

template <WGAttrCtlEnum Ctrl_ = WGAttrCtlEnum::Default_>
struct WarpGemmAttributeMfmaScaleImpl_f32_16x16x128_fp4
{
static constexpr WGAttrCtlEnum Ctrl = Ctrl_;
using ADataType = pk_fp4_t;
using BDataType = pk_fp4_t;
using CDataType = float;

using AVecType = ext_vector_t<ADataType, 16>;
using BVecType = ext_vector_t<BDataType, 16>;
using CVecType = ext_vector_t<CDataType, 4>;

static constexpr index_t kM = 16;
static constexpr index_t kN = 16;
static constexpr index_t kK = 128;

static constexpr index_t kAMBlock = 1;
static constexpr index_t kBNBlock = 1;

static constexpr index_t kAMLane = 16;
static constexpr index_t kBNLane = 16;
static constexpr index_t kABKLane = 4;
static constexpr index_t kABKPerLane = 32;

static constexpr index_t kCMLane = 4;
static constexpr index_t kCNLane = 16;
static constexpr index_t kCM0PerLane = 1;
static constexpr index_t kCM1PerLane = 4;

// c_vec += a_vec * b_vec
template <index_t opselA, index_t opselB, bool post_nop_ = false>
template <bool post_nop_ = false>
CK_TILE_DEVICE void operator()(CVecType& c_vec,
const AVecType& a_vec,
const int32_t& a_scale,
const BVecType& b_vec,
const int32_t& b_scale,
bool_constant<post_nop_> = {}) const
{
//__builtin_amdgcn_mfma_scale_f32_16x16x128_f8f6f4(a, b, c, cbsz, blgp, opsel, scale_a,
// opsel, scale_b)
#if defined(__gfx950__)
auto arg_a = bit_cast<int32x4_t>(a_vec);
auto arg_b = bit_cast<int32x4_t>(b_vec);
c_vec = __builtin_amdgcn_mfma_scale_f32_16x16x128_f8f6f4(
int32x8_t{arg_a[0], arg_a[1], arg_a[2], arg_a[3], 0, 0, 0, 0},
int32x8_t{arg_b[0], arg_b[1], arg_b[2], arg_b[3], 0, 0, 0, 0},
c_vec,
4,
4,
opselA,
a_scale,
opselB,
b_scale);
#else
ck_tile::ignore = c_vec;
ck_tile::ignore = a_vec;
ck_tile::ignore = b_vec;
ck_tile::ignore = a_scale;
ck_tile::ignore = b_scale;
#endif
operator()<0, 0>(c_vec, a_vec, 0, b_vec, 0);
}

// c_vec = a_vec * b_vec
template <index_t opselA, index_t opselB>
CK_TILE_DEVICE CVecType operator()(const AVecType& a_vec,
const int32_t& a_scale,
const BVecType& b_vec,
const int32_t& b_scale) const
CK_TILE_DEVICE CVecType operator()(const AVecType& a_vec, const BVecType& b_vec) const
{
#if defined(__gfx950__)
auto arg_a = bit_cast<int32x4_t>(a_vec);
auto arg_b = bit_cast<int32x4_t>(b_vec);
return bit_cast<CVecType>(__builtin_amdgcn_mfma_scale_f32_16x16x128_f8f6f4(
int32x8_t{arg_a[0], arg_a[1], arg_a[2], arg_a[3], 0, 0, 0, 0},
int32x8_t{arg_b[0], arg_b[1], arg_b[2], arg_b[3], 0, 0, 0, 0},
CVecType{0.f},
4,
4,
opselA,
a_scale,
opselB,
b_scale));
#else
ck_tile::ignore = a_vec;
ck_tile::ignore = b_vec;
ck_tile::ignore = a_scale;
ck_tile::ignore = b_scale;
return CVecType{0.f};
#endif
return operator()<0, 0>(a_vec, 0, b_vec, 0);
}
};

Expand Down
Loading