Merged

Changes from all commits (51 commits)
034987b
start adding navi21 GEMM
j4yan Apr 13, 2022
4f5817d
navi_gemm_km_kn_mn_fp32 compiles and passes one test.
j4yan Apr 14, 2022
0d46b40
rename variables and functions in gridwise_gemm_dlops_v1r3
j4yan Apr 14, 2022
27b1c45
add other 3 layouts; format instance
j4yan Apr 14, 2022
e10a262
adding more tuning parameters
j4yan Apr 15, 2022
8450124
add gemm_dlops_f16
j4yan Apr 19, 2022
6b2ef39
tmp
j4yan Apr 20, 2022
2957754
add dependence of DeviceGemm::IsSupportedArg() on arch
j4yan Apr 20, 2022
f70ad26
minor changes
j4yan Apr 20, 2022
6baedf3
minor changes
j4yan Apr 20, 2022
1bcb8cd
minor changes
j4yan Apr 20, 2022
62a792b
minor changes
j4yan Apr 20, 2022
999321d
minor changes
j4yan Apr 20, 2022
45e9862
minor changes
j4yan Apr 20, 2022
d3f3fac
minor changes
j4yan Apr 20, 2022
e5ea6c7
push gemm_dlops into profiler
j4yan Apr 21, 2022
c695dfa
minor changes
j4yan Apr 21, 2022
fc97e9d
whether to use xdl or dlops is moved into profiler_gemm_impl
j4yan Apr 21, 2022
cd2ce92
minor changes
j4yan Apr 21, 2022
bf8cea0
minor changes
j4yan Apr 22, 2022
2f70506
remove is_xdl from profile_gemm_impl
j4yan Apr 22, 2022
4ba880e
make IsSupportedArg dependent on arch for other device_gemm
j4yan Apr 22, 2022
5fd0997
minor changes
j4yan Apr 22, 2022
78ade2d
minor changes
j4yan Apr 22, 2022
1d58d7e
fix a bug in f_generate_tensor_value
j4yan Apr 22, 2022
f06ba36
add 64x64x64 for gemm_dlops_int8
j4yan Apr 22, 2022
0c3f0ba
add 64x64x64 for gemm_dlops_int8
j4yan Apr 22, 2022
578eec7
comment out 3 layouts in gemm_dlops_int8; add 32x32x32 for gemm_dlops…
j4yan Apr 25, 2022
aa0acfa
fix
Apr 30, 2022
2ca774b
start fixing tuning parameters
j4yan May 3, 2022
d9cd2e5
minor
j4yan May 5, 2022
f3bd93a
minor changes
j4yan May 5, 2022
9da908f
minor changes
j4yan May 5, 2022
1ea2ef5
minor changes
j4yan May 5, 2022
90438ea
Merge remote-tracking branch 'origin/develop' into navi21_gemm
May 8, 2022
3623f9c
fixing
May 11, 2022
e95e1bf
adding example
May 12, 2022
3a122cb
adding example
May 12, 2022
0eb6b99
adding example
May 12, 2022
217b836
add gemm fp32 example
May 12, 2022
55ff2c5
Merge remote-tracking branch 'origin/develop' into navi21_gemm_v2
May 12, 2022
9eb5eb2
Merge remote-tracking branch 'origin/develop' into navi21_gemm
May 17, 2022
162ac1d
clean up
May 17, 2022
f4f890a
use 128x128x16 as MNK tile in navi21 gemm example
shaojiewang May 18, 2022
9f602fa
bug fix
May 19, 2022
15c5b67
fix test
May 20, 2022
e79f340
Merge remote-tracking branch 'origin/develop' into navi21_gemm
May 20, 2022
39131c6
use new block c tile
May 20, 2022
a838cb9
clean
May 20, 2022
a295bbf
Merge remote-tracking branch 'origin/develop' into navi21_gemm
May 23, 2022
7c7904e
fix build
May 23, 2022
3 changes: 3 additions & 0 deletions example/01_gemm/CMakeLists.txt
@@ -1,3 +1,6 @@
add_example_executable(example_gemm_dl_fp32 gemm_dl_fp32.cpp)
add_example_executable(example_gemm_dl_fp16 gemm_dl_fp16.cpp)
add_example_executable(example_gemm_dl_int8 gemm_dl_int8.cpp)
add_example_executable(example_gemm_xdl_fp16 gemm_xdl_fp16.cpp)
add_example_executable(example_gemm_xdl_bf16 gemm_xdl_bf16.cpp)
add_example_executable(example_gemm_xdl_int8 gemm_xdl_int8.cpp)
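The three example_gemm_dl_* targets registered above follow the existing add_example_executable pattern, so (assuming the standard composable_kernel CMake flow, which this diff does not show) each can be built and run on its own, e.g. make example_gemm_dl_fp16 followed by ./example_gemm_dl_fp16 1 1 1 for verification with integer initialization and kernel timing; the binary's location depends on the build tree.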
211 changes: 211 additions & 0 deletions example/01_gemm/gemm_dl_fp16.cpp
@@ -0,0 +1,211 @@
#include <iostream>
#include <numeric>
#include <initializer_list>
#include <cstdlib>
#include <half.hpp>

#include "check_err.hpp"
#include "config.hpp"
#include "device.hpp"
#include "host_tensor.hpp"
#include "host_tensor_generator.hpp"
#include "device_tensor.hpp"
#include "device_gemm_dl.hpp"
#include "element_wise_operation.hpp"
#include "reference_gemm.hpp"
#include "gemm_specialization.hpp"

template <ck::index_t... Is>
using S = ck::Sequence<Is...>;

using F16 = ck::half_t;
using F32 = float;

using Row = ck::tensor_layout::gemm::RowMajor;
using Col = ck::tensor_layout::gemm::ColumnMajor;

using PassThrough = ck::tensor_operation::element_wise::PassThrough;

using ADataType = F16;
using BDataType = F16;
using CDataType = F16;
using AccDataType = F32;

using ALayout = Col;
using BLayout = Row;
using CLayout = Row;

using AElementOp = PassThrough;
using BElementOp = PassThrough;
using CElementOp = PassThrough;

static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default;

// clang-format off
using DeviceGemmInstance = ck::tensor_operation::device::
// ########| AData| BData| CData| AccData| ALayout| BLayout| CLayout| A| B| C| GEMM| Block| MPer| NPer| K0Per| K1| M1Per| N1Per| KPer| M11N11Thread| M11N11Thread| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| CThreadTransfer| CThreadTransfer| CThreadTransfer|
// ########| Type| Type| Type| Type| | | | Elementwise| Elementwise| Elementwise| Specialization| Size| Block| Block| Block| | ThreadM111| ThreadN111| Thread| ClusterM110Xs| ClusterN110Xs| ThreadSliceLengths| ThreadClusterLengths| ThreadCluster| SrcAccess| SrcVectorTensor| SrcVectorTensor| DstVectorTensor| ThreadSliceLengths| ThreadClusterLengths| ThreadCluster| SrcAccess| SrcVectorTensor| SrcVectorTensor| DstVectorTensor| SrcDstAccess| SrcDstVectorDim| DstScalarPerVector|
// ########| | | | | | | | Operation| Operation| Operation| | | | | | | | | | | | K0_M0_M1_K1| K0_M0_M1_K1| ArrangeOrder| Order| Lengths_K0_M0_M1_K1| ContiguousDimOrder| Lengths_K0_M0_M1_K1| K0_N0_N1_K1| K0_N0_N1_K1| ArrangeOrder| Order| Lengths_K0_N0_N1_K1| ContiguousDimOrder| Lengths_K0_N0_N1_K1| Order| | |
// ########| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
DeviceGemmDl< F16, F16, F16, F32, Col, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 128, 128, 16, 2, 4, 4, 1, S<8, 2>, S<8, 2>, S<2, 1, 4, 2>, S<8, 1, 32, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 2>, S<2, 1, 4, 2>, S<8, 1, 32, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 2>, S<0, 1, 2, 3, 4, 5>, 5, 4>;
// clang-format on
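// Reading the tile configuration above (a sketch, assuming the usual CK
// scheme of one workgroup per MPerBlock x NPerBlock tile of C, which this
// diff does not spell out): BlockSize = 256 threads cooperatively compute a
// 128x128 tile of C, and K advances by K0PerBlock * K1 = 16 * 2 = 32 per
// main-loop iteration. For the default 3840x4096x4096 problem below, that
// is (3840/128) * (4096/128) = 960 workgroups, each running 4096/32 = 128
// main-loop iterations.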

using ReferenceGemmInstance = ck::tensor_operation::host::
ReferenceGemm<ADataType, BDataType, CDataType, AElementOp, BElementOp, CElementOp>;

int main(int argc, char* argv[])
{
bool do_verification = true;
int init_method = 1;
bool time_kernel = false;

// GEMM shape
ck::index_t M = 3840;
ck::index_t N = 4096;
ck::index_t K = 4096;

ck::index_t StrideA = 4096;
ck::index_t StrideB = 4096;
ck::index_t StrideC = 4096;

if(argc == 1)
{
// do nothing
}
else if(argc == 4)
{
do_verification = std::stoi(argv[1]);
init_method = std::stoi(argv[2]);
time_kernel = std::stoi(argv[3]);
}
else if(argc == 10)
{
do_verification = std::stoi(argv[1]);
init_method = std::stoi(argv[2]);
time_kernel = std::stoi(argv[3]);

M = std::stoi(argv[4]);
N = std::stoi(argv[5]);
K = std::stoi(argv[6]);

StrideA = std::stoi(argv[7]);
StrideB = std::stoi(argv[8]);
StrideC = std::stoi(argv[9]);
}
else
{
printf("arg1: verification (0=no, 1=yes)\n");
printf("arg2: initialization (0=no init, 1=integer value, 2=decimal value)\n");
printf("arg3: time kernel (0=no, 1=yes)\n");
printf("arg4 to arg9: M (256x), N (128x), K (32x), StrideA, StrideB, StrideC\n");
exit(1);
}

// Build a host tensor descriptor: row-major layouts use strides {stride, 1},
// column-major layouts use {1, stride}.
auto f_host_tensor_descriptor =
[](std::size_t row, std::size_t col, std::size_t stride, auto layout) {
if(std::is_same<decltype(layout), ck::tensor_layout::gemm::RowMajor>::value)
{
return HostTensorDescriptor(std::vector<std::size_t>({row, col}),
std::vector<std::size_t>({stride, 1}));
}
else
{
return HostTensorDescriptor(std::vector<std::size_t>({row, col}),
std::vector<std::size_t>({1, stride}));
}
};

Tensor<ADataType> a_m_k(f_host_tensor_descriptor(M, K, StrideA, ALayout{}));
Tensor<BDataType> b_k_n(f_host_tensor_descriptor(K, N, StrideB, BLayout{}));
Tensor<CDataType> c_m_n_host_result(f_host_tensor_descriptor(M, N, StrideC, CLayout{}));
Tensor<CDataType> c_m_n_device_result(f_host_tensor_descriptor(M, N, StrideC, CLayout{}));

std::cout << "a_m_k: " << a_m_k.mDesc << std::endl;
std::cout << "b_k_n: " << b_k_n.mDesc << std::endl;
std::cout << "c_m_n: " << c_m_n_host_result.mDesc << std::endl;

// init_method: 0 = leave uninitialized, 1 = random integers, 2 = random
// decimal values; anything else falls back to sequential fill.
switch(init_method)
{
case 0: break;
case 1:
a_m_k.GenerateTensorValue(GeneratorTensor_2<ADataType>{-5, 5});
b_k_n.GenerateTensorValue(GeneratorTensor_2<BDataType>{-5, 5});
break;
case 2:
a_m_k.GenerateTensorValue(GeneratorTensor_3<ADataType>{0.0, 1.0});
b_k_n.GenerateTensorValue(GeneratorTensor_3<BDataType>{-0.5, 0.5});
break;
default:
a_m_k.GenerateTensorValue(GeneratorTensor_Sequential<0>{});
b_k_n.GenerateTensorValue(GeneratorTensor_Sequential<1>{});
}

DeviceMem a_m_k_device_buf(sizeof(ADataType) * a_m_k.mDesc.GetElementSpace());
DeviceMem b_k_n_device_buf(sizeof(BDataType) * b_k_n.mDesc.GetElementSpace());
DeviceMem c_m_n_device_buf(sizeof(CDataType) * c_m_n_device_result.mDesc.GetElementSpace());

a_m_k_device_buf.ToDevice(a_m_k.mData.data());
b_k_n_device_buf.ToDevice(b_k_n.mData.data());

auto a_element_op = AElementOp{};
auto b_element_op = BElementOp{};
auto c_element_op = CElementOp{};

// do GEMM
auto gemm = DeviceGemmInstance{};
auto invoker = gemm.MakeInvoker();
auto argument = gemm.MakeArgument(static_cast<ADataType*>(a_m_k_device_buf.GetDeviceBuffer()),
static_cast<BDataType*>(b_k_n_device_buf.GetDeviceBuffer()),
static_cast<CDataType*>(c_m_n_device_buf.GetDeviceBuffer()),
M,
N,
K,
StrideA,
StrideB,
StrideC,
a_element_op,
b_element_op,
c_element_op);

if(!gemm.IsSupportedArgument(argument))
{
std::cout << "wrong! device_gemm with the specified compilation parameters does "
"not support this GEMM problem"
<< std::endl;

return 0;
}

float ave_time = invoker.Run(argument, StreamConfig{nullptr, time_kernel});

std::size_t flop = std::size_t(2) * M * N * K;
std::size_t num_btype =
sizeof(ADataType) * M * K + sizeof(BDataType) * K * N + sizeof(CDataType) * M * N;
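
// ave_time is reported in milliseconds, so flop / 1e9 / ave_time below
// yields TFLOP/s and num_btype / 1e6 / ave_time yields GB/s.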

float tflops = static_cast<float>(flop) / 1.E9 / ave_time;

float gb_per_sec = num_btype / 1.E6 / ave_time;

std::cout << "Perf: " << ave_time << " ms, " << tflops << " TFlops, " << gb_per_sec << " GB/s, "
<< gemm.GetTypeString() << std::endl;

c_m_n_device_buf.FromDevice(c_m_n_device_result.mData.data());

bool pass = true;

if(do_verification)
{
auto ref_gemm = ReferenceGemmInstance{};
auto ref_invoker = ref_gemm.MakeInvoker();

auto ref_argument = ref_gemm.MakeArgument(
a_m_k, b_k_n, c_m_n_host_result, a_element_op, b_element_op, c_element_op);

ref_invoker.Run(ref_argument);

pass = ck::utils::check_err(c_m_n_device_result.mData, c_m_n_host_result.mData);
}

return pass ? 0 : 1;
}
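As a sanity check on the reported numbers, the example's formulas work out as follows for the default 3840 x 4096 x 4096 fp16 problem (a standalone sketch, not part of this diff; the 1 ms timing is hypothetical):

#include <cstddef>
#include <cstdio>

int main()
{
    const std::size_t M = 3840, N = 4096, K = 4096;

    // Same accounting as the example: 2*M*N*K flop; 2-byte fp16 A, B and C.
    const std::size_t flop  = 2 * M * N * K;               // 128849018880
    const std::size_t bytes = 2 * (M * K + K * N + M * N); // 96468992

    const float ave_time_ms = 1.0f; // hypothetical kernel time
    std::printf("TFlops: %f, GB/s: %f\n",
                static_cast<float>(flop) / 1.e9f / ave_time_ms,
                static_cast<float>(bytes) / 1.e6f / ave_time_ms);
    return 0;
}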