Skip to content

Commit 7e147c6

Browse files
authored
Merge pull request #51 from ROCm/lwpck-1010
Additional Navi4x enablement
2 parents 9fa379e + e7e224d commit 7e147c6

File tree

14 files changed

+46
-30
lines changed

14 files changed

+46
-30
lines changed

example/02_gemm_bilinear/CMakeLists.txt

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -6,7 +6,7 @@ foreach(gpu IN LISTS GPU_TARGETS)
66
add_example_executable(example_gemm_bilinear_wmma_fp16 gemm_bilinear_wmma_fp16.cpp)
77
add_example_executable(example_gemm_bilinear_wmma_int8 gemm_bilinear_wmma_int8.cpp)
88
endif()
9-
if(GPU_TARGETS MATCHES "gfx908" OR GPU_TARGETS MATCHES "gfx90a" OR GPU_TARGETS MATCHES "gfx940")
9+
if(GPU_TARGETS MATCHES "gfx908" OR GPU_TARGETS MATCHES "gfx90a" OR GPU_TARGETS MATCHES "gfx94")
1010
set(target 1)
1111
endif()
1212
endforeach()

example/30_grouped_conv_fwd_multiple_d/CMakeLists.txt

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -1,5 +1,5 @@
11
list(APPEND gpu_list1 gfx908 gfx90a gfx940 gfx941 gfx942 gfx950)
2-
list(APPEND gpu_list2 gfx1100 gfx1101 gfx1102)
2+
list(APPEND gpu_list2 gfx1100 gfx1101 gfx1102 gfx1103)
33

44
set(target 0)
55
foreach(gpu IN LISTS GPU_TARGETS)

include/ck/ck.hpp

Lines changed: 0 additions & 7 deletions
Original file line number | Diff line number | Diff line change
@@ -104,13 +104,6 @@
104104
#define CK_USE_AMD_MFMA_GFX940
105105
#endif
106106

107-
// WMMA instruction
108-
#ifndef __HIP_DEVICE_COMPILE__ // for host code
109-
#define CK_USE_AMD_WMMA
110-
#elif defined(__gfx11__) || defined(__gfx12__) // for GPU code
111-
#define CK_USE_AMD_WMMA
112-
#endif
113-
114107
// buffer load
115108
#define CK_USE_AMD_BUFFER_LOAD 1
116109

include/ck/host_utility/device_prop.hpp

Lines changed: 4 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -85,6 +85,9 @@ inline bool is_navi3_supported()
8585
ck::get_device_name() == "gfx1102" || ck::get_device_name() == "gfx1103";
8686
}
8787

88-
inline bool is_navi4_supported() { return ck::get_device_name() == "gfx1200"; }
88+
inline bool is_navi4_supported()
89+
{
90+
return ck::get_device_name() == "gfx1200" || ck::get_device_name() == "gfx1201";
91+
}
8992

9093
} // namespace ck

include/ck/tensor_operation/gpu/block/blockwise_gemm_xdlops.hpp

Lines changed: 7 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -488,7 +488,14 @@ struct BlockwiseGemmXdlopsInterwave_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1
488488
// sync point.
489489
if constexpr(k.value != 0 || KPerInnerLoop == KPerThread)
490490
{
491+
#ifdef __gfx12__
492+
asm volatile("\
493+
s_barrier_signal -1 \n \
494+
s_barrier_wait -1 \
495+
" ::);
496+
#else
491497
asm volatile("s_barrier" ::);
498+
#endif
492499
__builtin_amdgcn_sched_barrier(0);
493500
}
494501
static_for<0, KPerInnerLoop, KPack>{}([&](auto k_) {

include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_multiple_d_dl.hpp

Lines changed: 4 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -70,8 +70,9 @@ __global__ void
7070
const ComputePtrOffsetOfBatch compute_ptr_offset_of_batch,
7171
const Block2CTileMap block_2_ctile_map)
7272
{
73-
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx906__) || defined(__gfx908__) || \
74-
defined(__gfx90a__) || defined(__gfx94__) || defined(__gfx103__) || defined(__gfx11__))
73+
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx906__) || defined(__gfx908__) || \
74+
defined(__gfx90a__) || defined(__gfx94__) || defined(__gfx103__) || defined(__gfx11__) || \
75+
defined(__gfx12__))
7576

7677
const index_t num_blocks_per_batch =
7778
__builtin_amdgcn_readfirstlane(get_grid_size() / batch_count);
@@ -648,7 +649,7 @@ struct DeviceBatchedGemmMultipleD_Dl : public DeviceBatchedGemmMultiD<ALayout,
648649
static bool IsSupportedArgument(const Argument& arg)
649650
{
650651
if(ck::get_device_name() == "gfx906" || ck::is_xdl_supported() ||
651-
ck::is_navi2_supported() || ck::is_navi3_supported())
652+
ck::is_navi2_supported() || ck::is_navi3_supported() || ck::is_navi4_supported())
652653
{
653654
bool pass = true;
654655
pass = pass && arg.K_ % K1 == 0;

include/ck/tensor_operation/gpu/device/impl/device_convnd_bwd_data_nwc_kxc_nwk_dl.hpp

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -1394,7 +1394,7 @@ struct DeviceConvNdBwdDataNwcKxcNwk_Dl
13941394
{
13951395
// check device
13961396
if(!(ck::get_device_name() == "gfx906" || ck::is_navi2_supported() ||
1397-
ck::is_navi3_supported()))
1397+
ck::is_navi3_supported() || ck::is_navi4_supported()))
13981398
{
13991399
return false;
14001400
}

include/ck/tensor_operation/gpu/device/impl/device_gemm_multiple_d_dl.hpp

Lines changed: 4 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -50,8 +50,9 @@ __global__ void
5050
const CGridDesc_M0_M10_M11_N0_N10_N11 e_grid_desc_m0_m10_m11_n0_n10_n11,
5151
const Block2CTileMap block_2_ctile_map)
5252
{
53-
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx906__) || defined(__gfx908__) || \
54-
defined(__gfx90a__) || defined(__gfx94__) || defined(__gfx103__) || defined(__gfx11__))
53+
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx906__) || defined(__gfx908__) || \
54+
defined(__gfx90a__) || defined(__gfx94__) || defined(__gfx103__) || defined(__gfx11__) || \
55+
defined(__gfx12__))
5556

5657
constexpr index_t shared_block_size =
5758
GridwiseGemm::GetSharedMemoryNumberOfByte() / sizeof(ABDataType);
@@ -552,7 +553,7 @@ struct DeviceGemmMultipleD_Dl : public DeviceGemmMultipleD<ALayout,
552553
static bool IsSupportedArgument(const Argument& arg)
553554
{
554555
if(ck::get_device_name() == "gfx906" || ck::is_xdl_supported() ||
555-
ck::is_navi2_supported() || ck::is_navi3_supported())
556+
ck::is_navi2_supported() || ck::is_navi3_supported() || ck::is_navi4_supported())
556557
{
557558
return GridwiseGemm::CheckValidity(
558559
arg.a_grid_desc_k0_m_k1_, arg.b_grid_desc_k0_n_k1_, arg.e_grid_desc_m_n_);

include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_dl_multiple_d_nhwc_kyxc_nhwk.hpp

Lines changed: 4 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -90,8 +90,9 @@ __global__ void
9090
const Block2CTileMap block_2_ctile_map,
9191
const ComputePtrOffsetOfBatch compute_ptr_offset_of_batch)
9292
{
93-
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx906__) || defined(__gfx103__) || \
94-
defined(__gfx90a__) || defined(__gfx908__) || defined(__gfx94__) || defined(__gfx11__))
93+
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx906__) || defined(__gfx103__) || \
94+
defined(__gfx90a__) || defined(__gfx908__) || defined(__gfx94__) || defined(__gfx11__) || \
95+
defined(__gfx12__))
9596
// offset base pointer for each work-group
9697
const index_t num_blocks_per_batch =
9798
__builtin_amdgcn_readfirstlane(get_grid_size() / batch_count);
@@ -666,7 +667,7 @@ struct DeviceGroupedConvFwdDlMultipleD_NHWC_KYXC_NHWK
666667

667668
// check device
668669
if(!(ck::get_device_name() == "gfx906" || ck::is_xdl_supported() ||
669-
ck::is_navi2_supported() || ck::is_navi3_supported()))
670+
ck::is_navi2_supported() || ck::is_navi3_supported() || ck::is_navi4_supported()))
670671
{
671672
return false;
672673
}

include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_dl_nhwc_kyxc_nhwk.hpp

Lines changed: 2 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -107,7 +107,7 @@ __global__ void
107107
const ComputePtrOffsetOfBatch compute_ptr_offset_of_batch)
108108
{
109109
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx906__) || defined(__gfx103__) || \
110-
defined(__gfx11__))
110+
defined(__gfx11__) || defined(__gfx12__))
111111
// offset base pointer for each work-group
112112
const index_t num_blocks_per_batch =
113113
__builtin_amdgcn_readfirstlane(get_grid_size() / batch_count);
@@ -602,7 +602,7 @@ struct DeviceGroupedConvFwdDl_NHWC_KYXC_NHWK : public DeviceGroupedConvFwd<NDimS
602602

603603
// check device
604604
if(!(ck::get_device_name() == "gfx906" || ck::is_navi2_supported() ||
605-
ck::is_navi3_supported()))
605+
ck::is_navi3_supported() || ck::is_navi4_supported()))
606606
{
607607
return false;
608608
}

0 commit comments

Comments
 (0)