Skip to content

Commit 9fa379e

Browse files
authored
Merge pull request #47 from ROCm/navi4x_wmma
Navi4x wmma GEMM
2 parents 9de6359 + 9a9cb88 commit 9fa379e

22 files changed

+783
-76
lines changed

CMakeLists.txt

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -118,7 +118,7 @@ else()
118118
add_definitions(-DPROFILER_ONLY)
119119
set(GPU_TARGETS "" CACHE STRING "" FORCE)
120120
if(GPU_TARGETS)
121-
message(FATAL_ERROR "For PROFILE_ONLY build, please do not set GPU_TARGETS, use GPU_ARCH = gfx90, gfx94, gfx10, or gfx11")
121+
message(FATAL_ERROR "For PROFILE_ONLY build, please do not set GPU_TARGETS, use GPU_ARCH = gfx90, gfx94, gfx10, gfx11 or gfx12")
122122
endif()
123123
if(GPU_ARCH MATCHES "gfx90")
124124
rocm_check_target_ids(DEFAULT_GPU_TARGETS TARGETS "gfx908;gfx90a")
@@ -128,8 +128,10 @@ else()
128128
rocm_check_target_ids(DEFAULT_GPU_TARGETS TARGETS "gfx1030")
129129
elseif(GPU_ARCH MATCHES "gfx11")
130130
rocm_check_target_ids(DEFAULT_GPU_TARGETS TARGETS "gfx1100;gfx1101;gfx1102")
131+
elseif(GPU_ARCH MATCHES "gfx12")
132+
rocm_check_target_ids(DEFAULT_GPU_TARGETS TARGETS "gfx1200;gfx1201")
131133
else()
132-
message(FATAL_ERROR "For PROFILE_ONLY build, please specify GPU_ARCH as gfx90, gfx94, gfx10, or gfx11")
134+
message(FATAL_ERROR "For PROFILE_ONLY build, please specify GPU_ARCH as gfx90, gfx94, gfx10, gfx11 or gfx12")
133135
endif()
134136
set(GPU_TARGETS "${DEFAULT_GPU_TARGETS}" CACHE STRING " " FORCE)
135137
endif()

cmake/EnableCompilerWarnings.cmake

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -66,7 +66,7 @@ else()
6666
-Wunreachable-code
6767
-Wunused
6868
-Wno-reserved-identifier
69-
-Werror
69+
-Werror
7070
-Wno-option-ignored
7171
-Wsign-compare
7272
-Wno-extra-semi-stmt

example/01_gemm/CMakeLists.txt

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,8 @@ add_example_dependencies(example_gemm_xdl example_gemm_xdl_wavelet_fp16)
2727

2828
add_example_executable(example_gemm_xdl_skip_b_lds_fp16 gemm_xdl_skip_b_lds_fp16.cpp)
2929
add_example_dependencies(example_gemm_xdl example_gemm_xdl_skip_b_lds_fp16)
30-
if(GPU_TARGETS MATCHES "gfx11")
30+
31+
if(GPU_TARGETS MATCHES "gfx11" OR GPU_TARGETS MATCHES "gfx12")
3132
add_custom_target(example_gemm_wmma)
3233
add_example_executable(example_gemm_wmma_fp16 gemm_wmma_fp16.cpp)
3334
add_example_dependencies(example_gemm_wmma example_gemm_wmma_fp16)
@@ -74,4 +75,3 @@ add_example_dependencies(example_gemm_xdl example_gemm_xdl_fp8_bf8)
7475

7576
add_example_executable(example_gemm_xdl_fp16_fp8 gemm_xdl_fp16_fp8.cpp)
7677
add_example_dependencies(example_gemm_xdl example_gemm_xdl_fp16_fp8)
77-

example/01_gemm/gemm_wmma_fp16.cpp

Lines changed: 41 additions & 44 deletions
Original file line numberDiff line numberDiff line change
@@ -19,51 +19,48 @@ using AElementOp = PassThrough;
1919
using BElementOp = PassThrough;
2020
using CElementOp = PassThrough;
2121

22-
static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::MNKPadding;
22+
static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default;
2323

24-
// clang-format off
25-
using DeviceGemmInstance = ck::tensor_operation::device::DeviceGemmWmma_CShuffle
26-
< ALayout,
27-
BLayout,
28-
CLayout,
29-
ADataType,
30-
BDataType,
31-
CDataType,
32-
AccDataType,
33-
CShuffleDataType,
34-
AElementOp,
35-
BElementOp,
36-
CElementOp,
37-
GemmDefault,
38-
1, // Prefetch stage
39-
128, // BlockSize
40-
64, // MPerBlock
41-
128, // NPerBlock
42-
64, // KPerBlock
43-
8, // K1
44-
16, // MPerWmma
45-
16, // NPerWmma
46-
2, // M-Repeat // M-PerWmma / M-Repeat = M-Wave
47-
4, // N-Repeat // N-PerWmma / N-Repeat = N-Wave
48-
S<4, 32, 1>,
49-
S<1, 0, 2>,
50-
S<1, 0, 2>,
51-
2,
52-
8,
53-
8,
54-
true,
55-
S<4, 32, 1>,
56-
S<1, 0, 2>,
57-
S<1, 0, 2>,
58-
2,
59-
8,
60-
8,
61-
true,
62-
1, // C shuffle (M Repeat) Per store
63-
1, // C shuffle (N Repeat) Per store
64-
S<1, 32, 1, 4>,
65-
8>;
66-
// clang-format on
24+
using DeviceGemmInstance = ck::tensor_operation::device::DeviceGemmWmma_CShuffle<ALayout,
25+
BLayout,
26+
CLayout,
27+
ADataType,
28+
BDataType,
29+
CDataType,
30+
AccDataType,
31+
CShuffleDataType,
32+
AElementOp,
33+
BElementOp,
34+
CElementOp,
35+
GemmDefault,
36+
1,
37+
32,
38+
16,
39+
32,
40+
64,
41+
8,
42+
16,
43+
16,
44+
1,
45+
2,
46+
S<2, 16, 1>,
47+
S<1, 0, 2>,
48+
S<1, 0, 2>,
49+
2,
50+
8,
51+
8,
52+
true,
53+
S<2, 16, 1>,
54+
S<1, 0, 2>,
55+
S<1, 0, 2>,
56+
2,
57+
8,
58+
8,
59+
true,
60+
1,
61+
1,
62+
S<1, 16, 1, 2>,
63+
8>;
6764

6865
using ReferenceGemmInstance = ck::tensor_operation::host::
6966
ReferenceGemm<ADataType, BDataType, CDataType, AccDataType, AElementOp, BElementOp, CElementOp>;

example/02_gemm_bilinear/CMakeLists.txt

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
list(APPEND gpu_list1 gfx1100 gfx1101 gfx1102)
1+
list(APPEND gpu_list1 gfx1100 gfx1101 gfx1102 gfx1103 gfx1200 gfx1201)
22
list(APPEND gpu_list2 gfx908 gfx90a gfx940 gfx941 gfx942 gfx950)
33
set(target 0)
44
foreach(gpu IN LISTS GPU_TARGETS)
Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
add_example_executable(example_batched_gemm_bias_e_permute_xdl_fp16 batched_gemm_bias_e_permute_xdl_fp16.cpp)
22

3-
if(GPU_TARGETS MATCHES "gfx11")
3+
if(GPU_TARGETS MATCHES "gfx11" OR GPU_TARGETS MATCHES "gfx12")
44
add_example_executable(example_batched_gemm_bias_e_permute_wmma_fp16 batched_gemm_bias_e_permute_wmma_fp16.cpp)
55
endif()

example/64_fpAintB_gemm/CMakeLists.txt

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
if(GPU_TARGETS MATCHES "gfx11")
1+
if(GPU_TARGETS MATCHES "gfx11" OR GPU_TARGETS MATCHES "gfx12")
22
add_custom_target(example_fpAintB_gemm_wmma)
33
add_example_executable(example_fp16int8_gemm_wmma fp16int8_gemm_wmma.cpp)
44
add_dependencies(example_fpAintB_gemm_wmma example_fp16int8_gemm_wmma)

include/ck/ck.hpp

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -58,6 +58,9 @@
5858
#if defined(__gfx1100__) || defined(__gfx1101__) || defined(__gfx1102__) || defined(__gfx1103__)
5959
#define __gfx11__
6060
#endif
61+
#if defined(__gfx1200__) || defined(__gfx1201__)
62+
#define __gfx12__
63+
#endif
6164

6265
// buffer resource
6366
#ifndef __HIP_DEVICE_COMPILE__ // for host code
@@ -67,7 +70,7 @@
6770
#define CK_BUFFER_RESOURCE_3RD_DWORD 0x00020000
6871
#elif defined(__gfx103__)
6972
#define CK_BUFFER_RESOURCE_3RD_DWORD 0x31014000
70-
#elif defined(__gfx11__)
73+
#elif defined(__gfx11__) || defined(__gfx12__)
7174
#define CK_BUFFER_RESOURCE_3RD_DWORD 0x31004000
7275
#endif
7376

@@ -80,7 +83,7 @@
8083
#define CK_USE_AMD_V_FMAC_F32
8184
#define CK_USE_AMD_V_DOT2_F32_F16
8285
#define CK_USE_AMD_V_DOT4_I32_I8
83-
#elif defined(__gfx11__)
86+
#elif defined(__gfx11__) || defined(__gfx12__)
8487
#define CK_USE_AMD_V_FMAC_F32
8588
#define CK_USE_AMD_V_DOT2_F32_F16
8689
#define CK_USE_AMD_V_DOT4_I32_I8_GFX11
@@ -104,7 +107,7 @@
104107
// WMMA instruction
105108
#ifndef __HIP_DEVICE_COMPILE__ // for host code
106109
#define CK_USE_AMD_WMMA
107-
#elif defined(__gfx11__) // for GPU code
110+
#elif defined(__gfx11__) || defined(__gfx12__) // for GPU code
108111
#define CK_USE_AMD_WMMA
109112
#endif
110113

include/ck/host_utility/device_prop.hpp

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -85,4 +85,6 @@ inline bool is_navi3_supported()
8585
ck::get_device_name() == "gfx1102" || ck::get_device_name() == "gfx1103";
8686
}
8787

88+
inline bool is_navi4_supported() { return ck::get_device_name() == "gfx1200"; }
89+
8890
} // namespace ck

0 commit comments

Comments
 (0)