
Commit 78f637e

Merge pull request #58 from ROCm/navi4x_conv_fwd
Navi4x Conv and MHA enablement
2 parents: 7e147c6 + 5cb59d3

15 files changed: +66 -78 lines

example/20_grouped_conv_bwd_weight/CMakeLists.txt

Lines changed: 1 addition & 1 deletion
@@ -1,5 +1,5 @@
 list(APPEND gpu_list_xdl gfx908 gfx90a gfx940 gfx941 gfx942 gfx950)
-list(APPEND gpu_list_wmma gfx1100 gfx1101 gfx1102)
+list(APPEND gpu_list_wmma gfx1100 gfx1101 gfx1102 gfx1200)
 set(target 0)
 foreach(gpu IN LISTS GPU_TARGETS)
     if(gpu IN_LIST gpu_list_xdl AND target EQUAL 0)

example/30_grouped_conv_fwd_multiple_d/CMakeLists.txt

Lines changed: 1 addition & 1 deletion
@@ -1,5 +1,5 @@
 list(APPEND gpu_list1 gfx908 gfx90a gfx940 gfx941 gfx942 gfx950)
-list(APPEND gpu_list2 gfx1100 gfx1101 gfx1102 gfx1103)
+list(APPEND gpu_list2 gfx1100 gfx1101 gfx1102 gfx1103 gfx1200)

 set(target 0)
 foreach(gpu IN LISTS GPU_TARGETS)

example/32_batched_gemm_scale_softmax_gemm/CMakeLists.txt

Lines changed: 1 addition & 1 deletion
@@ -1,4 +1,4 @@
-if(GPU_TARGETS MATCHES "gfx11")
+if(GPU_TARGETS MATCHES "gfx11" OR GPU_TARGETS MATCHES "gfx12")
     add_example_executable(example_batched_gemm_lower_triangle_scale_softmax_gemm_permute_wmma_fp16 batched_gemm_lower_triangle_scale_softmax_gemm_permute_wmma_fp16.cpp)
     add_example_executable(example_batched_gemm_scale_softmax_gemm_permute_wmma_fp16 batched_gemm_scale_softmax_gemm_permute_wmma_fp16.cpp)
     add_example_executable(example_self_attention_forward_wmma_fp16 self_attention_forward_wmma_fp16.cpp)

example/38_grouped_conv_bwd_data_multiple_d/CMakeLists.txt

Lines changed: 1 addition & 1 deletion
@@ -1,5 +1,5 @@
 list(APPEND gpu_list_xdl gfx908 gfx90a gfx940 gfx941 gfx942 gfx950)
-list(APPEND gpu_list_wmma gfx1100 gfx1101 gfx1102)
+list(APPEND gpu_list_wmma gfx1100 gfx1101 gfx1102 gfx1200)
 set(target 0)
 foreach(gpu IN LISTS GPU_TARGETS)
     if(gpu IN_LIST gpu_list_xdl AND target EQUAL 0)

include/ck/tensor_operation/gpu/block/blockwise_gemm_wmma.hpp

Lines changed: 34 additions & 48 deletions
@@ -70,6 +70,9 @@ struct BlockwiseGemmWMMA
     static constexpr index_t A_KRow = 2;
     static constexpr index_t B_KRow = 2;

+    static constexpr index_t A_KRow_ = AEnableLds ? 1 : 2;
+    static constexpr index_t B_KRow_ = BEnableLds ? 1 : 2;
+
     static constexpr index_t A_K1 = ABlockDesc{}.GetLength(I5);
     static constexpr index_t B_K1 = BBlockDesc{}.GetLength(I5);

@@ -191,9 +194,6 @@ struct BlockwiseGemmWMMA
         static_assert(MPerBlock % (MPerWMMA * MRepeat) == 0 &&
                           NPerBlock % (NPerWMMA * NRepeat) == 0,
                       "wrong!");
-
-        static_assert(AEnableLds == true, "only support EnableLds");
-        static_assert(BEnableLds == true, "only support EnableLds");
     }

     // transposed WMMA output C' = B' * A'
@@ -316,7 +316,7 @@ struct BlockwiseGemmWMMA
                     // read A
                     a_thread_copy_.Run(
                         a_block_desc_k0_m0_m1_m2_k1,
-                        make_tuple(Number<k * KPack / A_K1>{}, m0, I0, I0, I0, I0),
+                        make_tuple(Number<k * KPack / A_K1 / A_KRow_>{}, m0, I0, I0, I0, I0),
                         a_block_buf,
                         a_thread_desc_,
                         make_tuple(I0, m0, I0, I0, I0, I0),
@@ -326,7 +326,8 @@ struct BlockwiseGemmWMMA
                     // read B
                     b_thread_copy_.Run(
                         b_block_desc_k0_n0_n1_n2_k1,
-                        make_tuple(Number<k * KPack / B_K1>{}, n0, I0, I0, I0, I0),
+                        make_tuple(
+                            Number<k * KPack / B_K1 / B_KRow_>{}, n0, I0, I0, I0, I0),
                         b_block_buf,
                         b_thread_desc_,
                         make_tuple(I0, n0, I0, I0, I0, I0),
@@ -372,15 +373,15 @@ struct BlockwiseGemmWMMA
                 // read B
                 b_thread_copy_.Run(
                     b_block_desc_k0_n0_n1_n2_k1,
-                    make_tuple(Number<k * KPack / B_K1>{}, n0, I0, I0, I0, I0),
+                    make_tuple(Number<k * KPack / B_K1 / B_KRow_>{}, n0, I0, I0, I0, I0),
                     b_block_buf,
                     b_thread_desc_,
                     make_tuple(I0, n0, I0, I0, I0, I0),
                     b_thread_buf);
                 // read A
                 a_thread_copy_.Run(
                     a_block_desc_k0_m0_m1_m2_k1,
-                    make_tuple(Number<k * KPack / A_K1>{}, m0, I0, I0, I0, I0),
+                    make_tuple(Number<k * KPack / A_K1 / A_KRow_>{}, m0, I0, I0, I0, I0),
                     a_block_buf,
                     a_thread_desc_,
                     make_tuple(I0, m0, I0, I0, I0, I0),
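The divided K0 offset in the three hunks above is the functional core of this file's change: when an operand bypasses LDS (AEnableLds or BEnableLds == false), its block descriptor already splits K across A_KRow_ / B_KRow_ = 2 rows, so the per-iteration K0 coordinate must shrink by the same factor. A standalone sketch of that index arithmetic, using made-up constants (KPack = 8, A_K1 = 4) rather than CK's real tuning parameters:

// Illustration only, not CK code: how the new A_KRow_ divisor changes the
// K0 source coordinate per k-iteration. All constants are hypothetical.
#include <cstdio>
#include <initializer_list>

int main()
{
    constexpr int KPack = 8;
    constexpr int A_K1  = 4;

    for(bool a_enable_lds : {true, false})
    {
        // Mirrors the diff: A_KRow_ = AEnableLds ? 1 : 2.
        const int a_krow_ = a_enable_lds ? 1 : 2;
        for(int k = 0; k < 4; ++k)
        {
            // The K0 coordinate handed to a_thread_copy_.Run(...).
            const int k0 = k * KPack / A_K1 / a_krow_;
            std::printf("AEnableLds=%d k=%d -> K0=%d\n", a_enable_lds, k, k0);
        }
    }
    return 0;
}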
@@ -442,44 +443,30 @@ struct BlockwiseGemmWMMA
     static constexpr auto c_thread_desc_ = make_naive_tensor_descriptor_packed(
         make_tuple(Number<MRepeat>{}, Number<NRepeat>{}, wmma_gemm.GetRegSizePerWmma()));

-    template <bool EnableLds>
-    struct AThreadCopySelector;
-
-    template <>
-    struct AThreadCopySelector<true>
-    {
-        using type =
-            ThreadwiseTensorSliceTransfer_v4<FloatA,
-                                             FloatA,
-                                             decltype(a_block_desc_k0_m0_m1_m2_k1),
-                                             decltype(a_thread_desc_),
-                                             Sequence<KPack / A_K1 / A_KRow, 1, 1, 1, 1, A_K1>,
-                                             Sequence<0, 1, 2, 3, 4, 5>,
-                                             5,
-                                             A_K1,
-                                             A_K1>;
-    };
-
-    template <bool EnableLds>
-    struct BThreadCopySelector;
-
-    template <>
-    struct BThreadCopySelector<true>
-    {
-        using type =
-            ThreadwiseTensorSliceTransfer_v4<FloatB,
-                                             FloatB,
-                                             decltype(b_block_desc_k0_n0_n1_n2_k1),
-                                             decltype(b_thread_desc_),
-                                             Sequence<KPack / B_K1 / B_KRow, 1, 1, 1, 1, B_K1>,
-                                             Sequence<0, 1, 2, 3, 4, 5>,
-                                             5,
-                                             B_K1,
-                                             B_K1>;
-    };
-
-    typename AThreadCopySelector<AEnableLds>::type a_thread_copy_;
-    typename BThreadCopySelector<BEnableLds>::type b_thread_copy_;
+    using AThreadCopyType =
+        ThreadwiseTensorSliceTransfer_v4<FloatA,
+                                         FloatA,
+                                         decltype(a_block_desc_k0_m0_m1_m2_k1),
+                                         decltype(a_thread_desc_),
+                                         Sequence<KPack / A_K1 / A_KRow, 1, 1, 1, 1, A_K1>,
+                                         Sequence<0, 1, 2, 3, 4, 5>,
+                                         5,
+                                         A_K1,
+                                         A_K1>;
+
+    using BThreadCopyType =
+        ThreadwiseTensorSliceTransfer_v4<FloatB,
+                                         FloatB,
+                                         decltype(b_block_desc_k0_n0_n1_n2_k1),
+                                         decltype(b_thread_desc_),
+                                         Sequence<KPack / B_K1 / B_KRow, 1, 1, 1, 1, B_K1>,
+                                         Sequence<0, 1, 2, 3, 4, 5>,
+                                         5,
+                                         B_K1,
+                                         B_K1>;
+
+    AThreadCopyType a_thread_copy_;
+    BThreadCopyType b_thread_copy_;
 };
 #else
 template <index_t BlockSize,
@@ -537,9 +524,8 @@ struct BlockwiseGemmWMMA
     // permutation
     static constexpr index_t A_KRow = AEnableLds ? 1 : 2;
    static constexpr index_t B_KRow = BEnableLds ? 1 : 2;
-
-    static constexpr index_t A_K1 = ABlockDesc{}.GetLength(I5);
-    static constexpr index_t B_K1 = BBlockDesc{}.GetLength(I5);
+    static constexpr index_t A_K1 = ABlockDesc{}.GetLength(I5);
+    static constexpr index_t B_K1 = BBlockDesc{}.GetLength(I5);

     static constexpr auto wmma_gemm =
         WmmaGemm<FloatA, FloatB, FloatAcc, MPerWMMA, NPerWMMA, KPack, TransposeC>{};
include/ck/tensor_operation/gpu/device/impl/device_batched_contraction_multiple_d_wmma_cshuffle.hpp

Lines changed: 1 addition & 1 deletion
@@ -829,7 +829,7 @@ struct DeviceBatchedContractionMultipleD_Wmma_CShuffle

     static bool IsSupportedArgument(const Argument& arg)
     {
-        if(ck::is_navi3_supported())
+        if(ck::is_navi3_supported() || ck::is_navi4_supported())
         {
             if constexpr(!(is_same_v<AccDataType, float> || is_same_v<AccDataType, int32_t>))
             {
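This IsSupportedArgument widening, repeated in the MHA and conv device ops below, is the runtime half of the enablement: the same WMMA device op now also accepts Navi4x GPUs. For reference, a hedged sketch of what such a family check can look like; this is not CK's implementation of ck::is_navi4_supported(), and the helper name is invented:

// Hypothetical sketch, not CK's implementation: match the gfx12 (Navi4x /
// RDNA4) architecture prefix reported by the HIP runtime.
#include <hip/hip_runtime.h>
#include <cstring>

bool is_navi4_device_sketch()
{
    int device = 0;
    if(hipGetDevice(&device) != hipSuccess)
        return false;

    hipDeviceProp_t props{};
    if(hipGetDeviceProperties(&props, device) != hipSuccess)
        return false;

    // gcnArchName looks like "gfx1200:sramecc-:xnack-".
    return std::strncmp(props.gcnArchName, "gfx12", 5) == 0;
}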

include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_softmax_gemm_permute_wmma_cshuffle.hpp

Lines changed: 4 additions & 3 deletions
@@ -56,7 +56,7 @@ __global__ void
         bool input_permute,
         bool output_permute)
 {
-#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx11__))
+#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx11__) || defined(__gfx12__))

     // clang-format off
     // ***************************************************
@@ -159,6 +159,7 @@ __global__ void
     ignore = O;
     ignore = G0;
     ignore = G1;
+    ignore = alpha;
     ignore = input_permute;
     ignore = output_permute;
 #endif // end of if (defined(__gfx11__))
@@ -187,7 +188,7 @@ __global__ void
     index_t head_size,
     float alpha)
 {
-#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx11__))
+#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx11__) || defined(__gfx12__))

     // clang-format off
     // ***************************************************
@@ -858,7 +859,7 @@ struct DeviceBatchedGemmSoftmaxGemmPermute_Wmma_CShuffle

     static bool IsSupportedArgument(const RawArg& arg)
     {
-        if(ck::is_navi3_supported())
+        if(ck::is_navi3_supported() || ck::is_navi4_supported())
         {
             if constexpr(!(is_same_v<Acc0DataType, float> || is_same_v<Acc0DataType, int32_t>))
             {
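The two #if changes above extend the device-compile guard that brackets each kernel body: the body is compiled on the host pass and on gfx11/gfx12 device passes, while on any other target the arguments are consumed via ignore = ... so the build stays warning-free (the added ignore = alpha plugs exactly such a gap). A compilable HIP sketch of the same pattern, with an invented kernel name and body:

// Hedged sketch of the guard pattern; kernel and body are invented.
#include <hip/hip_runtime.h>

__global__ void mha_kernel_sketch(const float* q, float* out, float alpha)
{
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx11__) || defined(__gfx12__))
    // Real WMMA attention code would run here on RDNA3/RDNA4 targets.
    out[threadIdx.x] = q[threadIdx.x] * alpha;
#else
    // On unsupported device targets, mark every argument as used.
    (void)q;
    (void)out;
    (void)alpha;
#endif
}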

include/ck/tensor_operation/gpu/device/impl/device_gemm_wmma.hpp

Lines changed: 2 additions & 2 deletions
@@ -94,8 +94,8 @@ struct DeviceGemmWmma_CShuffle : public DeviceGemm<ALayout,
         (MWaves == 1 && is_same<tensor_layout::gemm::ColumnMajor, BLayout>::value) ? false : true;

     // If true, LDS is used unconditionally
-    static constexpr auto AEnableLds_manu = true;
-    static constexpr auto BEnableLds_manu = true;
+    static constexpr auto AEnableLds_manu = false;
+    static constexpr auto BEnableLds_manu = false;

     static constexpr auto AEnableLds = AEnableLds_auto || AEnableLds_manu || (NumPrefetch > 1);
     static constexpr auto BEnableLds = BEnableLds_auto || BEnableLds_manu || (NumPrefetch > 1);
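This flip is what makes the blockwise-GEMM changes above reachable: with the manual override off, LDS usage is decided by the automatic heuristic and the prefetch depth alone. A tiny standalone sketch of that selection logic, with hypothetical flag values:

// Sketch of the EnableLds derivation in the diff; inputs are hypothetical.
#include <cstdio>

constexpr bool enable_lds(bool auto_flag, bool manu_flag, int num_prefetch)
{
    return auto_flag || manu_flag || (num_prefetch > 1);
}

int main()
{
    // Before this commit: manu = true forced LDS on unconditionally.
    std::printf("before: %d\n", enable_lds(false, true, 1));  // always 1
    // After: manu = false, so the auto heuristic and prefetch depth decide.
    std::printf("after:  %d\n", enable_lds(false, false, 1)); // can be 0
    return 0;
}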

include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_data_multiple_d_wmma_cshuffle.hpp

Lines changed: 1 addition & 1 deletion
@@ -629,7 +629,7 @@ struct DeviceGroupedConvBwdDataMultipleD_Wmma_CShuffle
     static bool IsSupportedArgument(const Argument& arg)
     {
         // check device
-        if(ck::is_navi3_supported())
+        if(ck::is_navi3_supported() || ck::is_navi4_supported())
         {
             if constexpr(!(is_same_v<AccDataType, float> || is_same_v<AccDataType, int32_t>))
             {

include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_wmma_cshuffle.hpp

Lines changed: 1 addition & 1 deletion
@@ -702,7 +702,7 @@ struct DeviceGroupedConvBwdWeight_Wmma_CShuffle
     static bool IsSupportedArgument(const Argument& arg)
     {
         // check device
-        if(ck::is_navi3_supported())
+        if(ck::is_navi3_supported() || ck::is_navi4_supported())
         {
             if constexpr(!(is_same_v<AccDataType, float> || is_same_v<AccDataType, int32_t>))
             {
