diff --git a/example/01_gemm/CMakeLists.txt b/example/01_gemm/CMakeLists.txt index 696d3bac42..a0fe1fe2fa 100644 --- a/example/01_gemm/CMakeLists.txt +++ b/example/01_gemm/CMakeLists.txt @@ -1,3 +1,6 @@ +add_example_executable(example_gemm_dl_fp32 gemm_dl_fp32.cpp) +add_example_executable(example_gemm_dl_fp16 gemm_dl_fp16.cpp) +add_example_executable(example_gemm_dl_int8 gemm_dl_int8.cpp) add_example_executable(example_gemm_xdl_fp16 gemm_xdl_fp16.cpp) add_example_executable(example_gemm_xdl_bf16 gemm_xdl_bf16.cpp) add_example_executable(example_gemm_xdl_int8 gemm_xdl_int8.cpp) diff --git a/example/01_gemm/gemm_dl_fp16.cpp b/example/01_gemm/gemm_dl_fp16.cpp new file mode 100644 index 0000000000..6e8e04f9e5 --- /dev/null +++ b/example/01_gemm/gemm_dl_fp16.cpp @@ -0,0 +1,211 @@ +#include +#include +#include +#include +#include +#include + +#include "check_err.hpp" +#include "config.hpp" +#include "device.hpp" +#include "host_tensor.hpp" +#include "host_tensor_generator.hpp" +#include "device_tensor.hpp" +#include "device_gemm_dl.hpp" +#include "element_wise_operation.hpp" +#include "reference_gemm.hpp" +#include "gemm_specialization.hpp" + +template +using S = ck::Sequence; + +using F16 = ck::half_t; +using F32 = float; + +using Row = ck::tensor_layout::gemm::RowMajor; +using Col = ck::tensor_layout::gemm::ColumnMajor; + +using PassThrough = ck::tensor_operation::element_wise::PassThrough; + +using ADataType = ck::half_t; +using BDataType = ck::half_t; +using CDataType = ck::half_t; +using AccDataType = float; + +using ALayout = Col; +using BLayout = Row; +using CLayout = Row; + +using AElementOp = ck::tensor_operation::element_wise::PassThrough; +using BElementOp = ck::tensor_operation::element_wise::PassThrough; +using CElementOp = ck::tensor_operation::element_wise::PassThrough; + +static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default; + +// clang-format off +using DeviceGemmInstance = ck::tensor_operation::device:: + // ########| AData| BData| CData| AccData| ALayout| BLayout| CLayout| A| B| C| GEMM| Block| MPer| NPer| K0Per| K1| M1Per| N1Per| KPer| M11N11Thread| M11N11Thread| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| CThreadTransfer| CThreadTransfer| CThreadTransfer| + // ########| Type| Type| Type| Type| | | | Elementwise| Elementwise| Elementwise| Spacialization| Size| Block| Block| Block| | ThreadM111| ThreadN111| Thread| ClusterM110Xs| ClusterN110Xs| ThreadSliceLengths| ThreadClusterLengths| ThreadCluster| SrcAccess| SrcVectorTensor| SrcVectorTensor| DstVectorTensor| ThreadSliceLengths| ThreadClusterLengths| ThreadCluster| SrcAccess| SrcVectorTensor| SrcVectorTensor| DstVectorTensor| SrcDstAccess| SrcDstVectorDim| DstScalarPerVector| + // ########| | | | | | | | Operation| Operation| Operation| | | | | | | | | | | | K0_M0_M1_K1| K0_M0_M1_K1| ArrangeOrder| Order| Lengths_K0_M0_M1_K1| ContiguousDimOrder| Lengths_K0_M0_M1_K1| K0_N0_N1_K1| K0_N0_N1_K1| ArrangeOrder| Order| Lengths_K0_N0_N1_K1| ContiguousDimOrder| Lengths_K0_N0_N1_K1| Order| | | + // ########| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceGemmDl< F16, F16, F16, F32, Col, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 128, 128, 16, 2, 4, 4, 1, S<8, 2>, S<8, 2>, S<2, 1, 4, 2>, S<8, 1, 32, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 
2>, S<1, 1, 4, 2>, S<2, 1, 4, 2>, S<8, 1, 32, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 2>, S<0, 1, 2, 3, 4, 5>, 5, 4>;
+// clang-format on
+
+using ReferenceGemmInstance = ck::tensor_operation::host::
+    ReferenceGemm<ADataType, BDataType, CDataType, AccDataType, AElementOp, BElementOp, CElementOp>;
+
+int main(int argc, char* argv[])
+{
+    bool do_verification = true;
+    int init_method      = 1;
+    bool time_kernel     = false;
+
+    // GEMM shape
+    ck::index_t M = 3840;
+    ck::index_t N = 4096;
+    ck::index_t K = 4096;
+
+    ck::index_t StrideA = 4096;
+    ck::index_t StrideB = 4096;
+    ck::index_t StrideC = 4096;
+
+    if(argc == 1)
+    {
+        // do nothing
+    }
+    else if(argc == 4)
+    {
+        do_verification = std::stoi(argv[1]);
+        init_method     = std::stoi(argv[2]);
+        time_kernel     = std::stoi(argv[3]);
+    }
+    else if(argc == 10)
+    {
+        do_verification = std::stoi(argv[1]);
+        init_method     = std::stoi(argv[2]);
+        time_kernel     = std::stoi(argv[3]);
+
+        M = std::stoi(argv[4]);
+        N = std::stoi(argv[5]);
+        K = std::stoi(argv[6]);
+
+        StrideA = std::stoi(argv[7]);
+        StrideB = std::stoi(argv[8]);
+        StrideC = std::stoi(argv[9]);
+    }
+    else
+    {
+        printf("arg1: verification (0=no, 1=yes)\n");
+        printf("arg2: initialization (0=no init, 1=integer value, 2=decimal value)\n");
+        printf("arg3: time kernel (0=no, 1=yes)\n");
+        printf("arg4 to 9: M (256x), N(128x), K(32x), StrideA, StrideB, StrideC\n");
+        exit(1);
+    }
+
+    auto f_host_tensor_descriptor =
+        [](std::size_t row, std::size_t col, std::size_t stride, auto layout) {
+            if(std::is_same<decltype(layout), Row>::value)
+            {
+                return HostTensorDescriptor(std::vector<std::size_t>({row, col}),
+                                            std::vector<std::size_t>({stride, 1}));
+            }
+            else
+            {
+                return HostTensorDescriptor(std::vector<std::size_t>({row, col}),
+                                            std::vector<std::size_t>({1, stride}));
+            }
+        };
+
+    Tensor<ADataType> a_m_k(f_host_tensor_descriptor(M, K, StrideA, ALayout{}));
+    Tensor<BDataType> b_k_n(f_host_tensor_descriptor(K, N, StrideB, BLayout{}));
+    Tensor<CDataType> c_m_n_host_result(f_host_tensor_descriptor(M, N, StrideC, CLayout{}));
+    Tensor<CDataType> c_m_n_device_result(f_host_tensor_descriptor(M, N, StrideC, CLayout{}));
+
+    std::cout << "a_m_k: " << a_m_k.mDesc << std::endl;
+    std::cout << "b_k_n: " << b_k_n.mDesc << std::endl;
+    std::cout << "c_m_n: " << c_m_n_host_result.mDesc << std::endl;
+
+    switch(init_method)
+    {
+    case 0: break;
+    case 1:
+        a_m_k.GenerateTensorValue(GeneratorTensor_2<ADataType>{-5, 5});
+        b_k_n.GenerateTensorValue(GeneratorTensor_2<BDataType>{-5, 5});
+        break;
+    case 2:
+        a_m_k.GenerateTensorValue(GeneratorTensor_3<ADataType>{0.0, 1.0});
+        b_k_n.GenerateTensorValue(GeneratorTensor_3<BDataType>{-0.5, 0.5});
+        break;
+    default:
+        a_m_k.GenerateTensorValue(GeneratorTensor_Sequential<0>{});
+        b_k_n.GenerateTensorValue(GeneratorTensor_Sequential<1>{});
+    }
+
+    DeviceMem a_m_k_device_buf(sizeof(ADataType) * a_m_k.mDesc.GetElementSpace());
+    DeviceMem b_k_n_device_buf(sizeof(BDataType) * b_k_n.mDesc.GetElementSpace());
+    DeviceMem c_m_n_device_buf(sizeof(CDataType) * c_m_n_device_result.mDesc.GetElementSpace());
+
+    a_m_k_device_buf.ToDevice(a_m_k.mData.data());
+    b_k_n_device_buf.ToDevice(b_k_n.mData.data());
+
+    auto a_element_op = AElementOp{};
+    auto b_element_op = BElementOp{};
+    auto c_element_op = CElementOp{};
+
+    // do GEMM
+    auto gemm     = DeviceGemmInstance{};
+    auto invoker  = gemm.MakeInvoker();
+    auto argument = gemm.MakeArgument(static_cast<ADataType*>(a_m_k_device_buf.GetDeviceBuffer()),
+                                      static_cast<BDataType*>(b_k_n_device_buf.GetDeviceBuffer()),
+                                      static_cast<CDataType*>(c_m_n_device_buf.GetDeviceBuffer()),
+                                      M,
+                                      N,
+                                      K,
+                                      StrideA,
+                                      StrideB,
+                                      StrideC,
+                                      a_element_op,
+                                      b_element_op,
+                                      c_element_op);
+
+    if(!gemm.IsSupportedArgument(argument))
+    {
+        std::cout << "wrong! 
device_gemm with the specified compilation parameters does " + "not support this GEMM problem" + << std::endl; + + return 0; + } + + float ave_time = invoker.Run(argument, StreamConfig{nullptr, time_kernel}); + + std::size_t flop = std::size_t(2) * M * N * K; + std::size_t num_btype = + sizeof(ADataType) * M * K + sizeof(BDataType) * K * N + sizeof(CDataType) * M * N; + + float tflops = static_cast(flop) / 1.E9 / ave_time; + + float gb_per_sec = num_btype / 1.E6 / ave_time; + + std::cout << "Perf: " << ave_time << " ms, " << tflops << " TFlops, " << gb_per_sec << " GB/s, " + << gemm.GetTypeString() << std::endl; + + c_m_n_device_buf.FromDevice(c_m_n_device_result.mData.data()); + + bool pass = true; + + if(do_verification) + { + auto ref_gemm = ReferenceGemmInstance{}; + auto ref_invoker = ref_gemm.MakeInvoker(); + + auto ref_argument = ref_gemm.MakeArgument( + a_m_k, b_k_n, c_m_n_host_result, a_element_op, b_element_op, c_element_op); + + ref_invoker.Run(ref_argument); + + pass = ck::utils::check_err(c_m_n_device_result.mData, c_m_n_host_result.mData); + } + + return pass ? 0 : 1; +} diff --git a/example/01_gemm/gemm_dl_fp32.cpp b/example/01_gemm/gemm_dl_fp32.cpp new file mode 100644 index 0000000000..65c806bf07 --- /dev/null +++ b/example/01_gemm/gemm_dl_fp32.cpp @@ -0,0 +1,210 @@ +#include +#include +#include +#include +#include +#include + +#include "check_err.hpp" +#include "config.hpp" +#include "device.hpp" +#include "host_tensor.hpp" +#include "host_tensor_generator.hpp" +#include "device_tensor.hpp" +#include "device_gemm_dl.hpp" +#include "element_wise_operation.hpp" +#include "reference_gemm.hpp" +#include "gemm_specialization.hpp" + +template +using S = ck::Sequence; + +using F32 = float; + +using Row = ck::tensor_layout::gemm::RowMajor; +using Col = ck::tensor_layout::gemm::ColumnMajor; + +using PassThrough = ck::tensor_operation::element_wise::PassThrough; + +using ADataType = float; +using BDataType = float; +using CDataType = float; +using AccDataType = float; + +using ALayout = Col; +using BLayout = Row; +using CLayout = Row; + +using AElementOp = ck::tensor_operation::element_wise::PassThrough; +using BElementOp = ck::tensor_operation::element_wise::PassThrough; +using CElementOp = ck::tensor_operation::element_wise::PassThrough; + +static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default; + +// clang-format off +using DeviceGemmInstance = ck::tensor_operation::device:: + // ########| AData| BData| CData| AccData| ALayout| BLayout| CLayout| A| B| C| GEMM| Block| MPer| NPer| K0Per| K1| M1Per| N1Per| KPer| M11N11Thread| M11N11Thread| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| CThreadTransfer| CThreadTransfer| CThreadTransfer| + // ########| Type| Type| Type| Type| | | | Elementwise| Elementwise| Elementwise| Spacialization| Size| Block| Block| Block| | ThreadM111| ThreadN111| Thread| ClusterM110Xs| ClusterN110Xs| ThreadSliceLengths| ThreadClusterLengths| ThreadCluster| SrcAccess| SrcVectorTensor| SrcVectorTensor| DstVectorTensor| ThreadSliceLengths| ThreadClusterLengths| ThreadCluster| SrcAccess| SrcVectorTensor| SrcVectorTensor| DstVectorTensor| SrcDstAccess| SrcDstVectorDim| DstScalarPerVector| + // ########| | | | | | | | Operation| Operation| Operation| | | | | | | | | | | | K0_M0_M1_K1| K0_M0_M1_K1| ArrangeOrder| Order| Lengths_K0_M0_M1_K1| 
ContiguousDimOrder| Lengths_K0_M0_M1_K1| K0_N0_N1_K1| K0_N0_N1_K1| ArrangeOrder| Order| Lengths_K0_N0_N1_K1| ContiguousDimOrder| Lengths_K0_N0_N1_K1| Order| | |
+    // ########| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
+    DeviceGemmDl< F32, F32, F32, F32, Col, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 128, 128, 16, 1, 4, 4, 1, S<8, 2>, S<8, 2>, S<2, 1, 4, 1>, S<8, 1, 32, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<2, 1, 4, 1>, S<8, 1, 32, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 1, 2, 3, 4, 5>, 5, 4>;
+// clang-format on
+
+using ReferenceGemmInstance = ck::tensor_operation::host::
+    ReferenceGemm<ADataType, BDataType, CDataType, AccDataType, AElementOp, BElementOp, CElementOp>;
+
+int main(int argc, char* argv[])
+{
+    bool do_verification = true;
+    int init_method      = 1;
+    bool time_kernel     = false;
+
+    // GEMM shape
+    ck::index_t M = 3840;
+    ck::index_t N = 4096;
+    ck::index_t K = 4096;
+
+    ck::index_t StrideA = 4096;
+    ck::index_t StrideB = 4096;
+    ck::index_t StrideC = 4096;
+
+    if(argc == 1)
+    {
+        // do nothing
+    }
+    else if(argc == 4)
+    {
+        do_verification = std::stoi(argv[1]);
+        init_method     = std::stoi(argv[2]);
+        time_kernel     = std::stoi(argv[3]);
+    }
+    else if(argc == 10)
+    {
+        do_verification = std::stoi(argv[1]);
+        init_method     = std::stoi(argv[2]);
+        time_kernel     = std::stoi(argv[3]);
+
+        M = std::stoi(argv[4]);
+        N = std::stoi(argv[5]);
+        K = std::stoi(argv[6]);
+
+        StrideA = std::stoi(argv[7]);
+        StrideB = std::stoi(argv[8]);
+        StrideC = std::stoi(argv[9]);
+    }
+    else
+    {
+        printf("arg1: verification (0=no, 1=yes)\n");
+        printf("arg2: initialization (0=no init, 1=integer value, 2=decimal value)\n");
+        printf("arg3: time kernel (0=no, 1=yes)\n");
+        printf("arg4 to 9: M (256x), N(128x), K(32x), StrideA, StrideB, StrideC\n");
+        exit(1);
+    }
+
+    auto f_host_tensor_descriptor =
+        [](std::size_t row, std::size_t col, std::size_t stride, auto layout) {
+            if(std::is_same<decltype(layout), Row>::value)
+            {
+                return HostTensorDescriptor(std::vector<std::size_t>({row, col}),
+                                            std::vector<std::size_t>({stride, 1}));
+            }
+            else
+            {
+                return HostTensorDescriptor(std::vector<std::size_t>({row, col}),
+                                            std::vector<std::size_t>({1, stride}));
+            }
+        };
+
+    Tensor<ADataType> a_m_k(f_host_tensor_descriptor(M, K, StrideA, ALayout{}));
+    Tensor<BDataType> b_k_n(f_host_tensor_descriptor(K, N, StrideB, BLayout{}));
+    Tensor<CDataType> c_m_n_host_result(f_host_tensor_descriptor(M, N, StrideC, CLayout{}));
+    Tensor<CDataType> c_m_n_device_result(f_host_tensor_descriptor(M, N, StrideC, CLayout{}));
+
+    std::cout << "a_m_k: " << a_m_k.mDesc << std::endl;
+    std::cout << "b_k_n: " << b_k_n.mDesc << std::endl;
+    std::cout << "c_m_n: " << c_m_n_host_result.mDesc << std::endl;
+
+    switch(init_method)
+    {
+    case 0: break;
+    case 1:
+        a_m_k.GenerateTensorValue(GeneratorTensor_2<ADataType>{-5, 5});
+        b_k_n.GenerateTensorValue(GeneratorTensor_2<BDataType>{-5, 5});
+        break;
+    case 2:
+        a_m_k.GenerateTensorValue(GeneratorTensor_3<ADataType>{0.0, 1.0});
+        b_k_n.GenerateTensorValue(GeneratorTensor_3<BDataType>{-0.5, 0.5});
+        break;
+    default:
+        a_m_k.GenerateTensorValue(GeneratorTensor_Sequential<0>{});
+        b_k_n.GenerateTensorValue(GeneratorTensor_Sequential<1>{});
+    }
+
+    DeviceMem a_m_k_device_buf(sizeof(ADataType) * a_m_k.mDesc.GetElementSpace());
+    DeviceMem b_k_n_device_buf(sizeof(BDataType) * b_k_n.mDesc.GetElementSpace());
+    DeviceMem c_m_n_device_buf(sizeof(CDataType) * c_m_n_device_result.mDesc.GetElementSpace());
+
+    a_m_k_device_buf.ToDevice(a_m_k.mData.data());
+    b_k_n_device_buf.ToDevice(b_k_n.mData.data());
+
+    auto a_element_op = AElementOp{};
+    auto b_element_op = BElementOp{};
+    auto c_element_op = 
CElementOp{}; + + // do GEMM + auto gemm = DeviceGemmInstance{}; + auto invoker = gemm.MakeInvoker(); + auto argument = gemm.MakeArgument(static_cast(a_m_k_device_buf.GetDeviceBuffer()), + static_cast(b_k_n_device_buf.GetDeviceBuffer()), + static_cast(c_m_n_device_buf.GetDeviceBuffer()), + M, + N, + K, + StrideA, + StrideB, + StrideC, + a_element_op, + b_element_op, + c_element_op); + + if(!gemm.IsSupportedArgument(argument)) + { + std::cout << "wrong! device_gemm with the specified compilation parameters does " + "not support this GEMM problem" + << std::endl; + + return 0; + } + + float ave_time = invoker.Run(argument, StreamConfig{nullptr, time_kernel}); + + std::size_t flop = std::size_t(2) * M * N * K; + std::size_t num_btype = + sizeof(ADataType) * M * K + sizeof(BDataType) * K * N + sizeof(CDataType) * M * N; + + float tflops = static_cast(flop) / 1.E9 / ave_time; + + float gb_per_sec = num_btype / 1.E6 / ave_time; + + std::cout << "Perf: " << ave_time << " ms, " << tflops << " TFlops, " << gb_per_sec << " GB/s, " + << gemm.GetTypeString() << std::endl; + + c_m_n_device_buf.FromDevice(c_m_n_device_result.mData.data()); + + bool pass = true; + + if(do_verification) + { + auto ref_gemm = ReferenceGemmInstance{}; + auto ref_invoker = ref_gemm.MakeInvoker(); + + auto ref_argument = ref_gemm.MakeArgument( + a_m_k, b_k_n, c_m_n_host_result, a_element_op, b_element_op, c_element_op); + + ref_invoker.Run(ref_argument); + + pass = ck::utils::check_err(c_m_n_device_result.mData, c_m_n_host_result.mData); + } + + return pass ? 0 : 1; +} diff --git a/example/01_gemm/gemm_dl_int8.cpp b/example/01_gemm/gemm_dl_int8.cpp new file mode 100644 index 0000000000..a9590030c7 --- /dev/null +++ b/example/01_gemm/gemm_dl_int8.cpp @@ -0,0 +1,208 @@ +#include +#include +#include +#include +#include +#include + +#include "check_err.hpp" +#include "config.hpp" +#include "device.hpp" +#include "host_tensor.hpp" +#include "host_tensor_generator.hpp" +#include "device_tensor.hpp" +#include "device_gemm_dl.hpp" +#include "element_wise_operation.hpp" +#include "reference_gemm.hpp" +#include "gemm_specialization.hpp" + +template +using S = ck::Sequence; + +using Row = ck::tensor_layout::gemm::RowMajor; +using Col = ck::tensor_layout::gemm::ColumnMajor; + +using PassThrough = ck::tensor_operation::element_wise::PassThrough; + +using ADataType = int8_t; +using BDataType = int8_t; +using CDataType = int8_t; +using AccDataType = int32_t; + +using ALayout = Col; +using BLayout = Row; +using CLayout = Row; + +using AElementOp = ck::tensor_operation::element_wise::PassThrough; +using BElementOp = ck::tensor_operation::element_wise::PassThrough; +using CElementOp = ck::tensor_operation::element_wise::PassThrough; + +static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default; + +// clang-format off +using DeviceGemmInstance = ck::tensor_operation::device:: + // #########| AData| BData| CData| AccData| ALayout| BLayout| CLayout| A| B| C| GEMM| Block| MPer| NPer| K0Per| K1| M1Per| N1Per| KPer| M11N11Thread| M11N11Thread| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| CThreadTransfer| CThreadTransfer| CThreadTransfer| + // #########| Type| Type| Type| Type| | | | Elementwise| Elementwise| Elementwise| Spacialization| Size| Block| Block| Block| | ThreadM111| ThreadN111| Thread| ClusterM110Xs| ClusterN110Xs| 
ThreadSliceLengths| ThreadClusterLengths| ThreadCluster| SrcAccess| SrcVectorTensor| SrcVectorTensor| DstVectorTensor| ThreadSliceLengths| ThreadClusterLengths| ThreadCluster| SrcAccess| SrcVectorTensor| SrcVectorTensor| DstVectorTensor| SrcDstAccess| SrcDstVectorDim| DstScalarPerVector|
+    // #########| | | | | | | | Operation| Operation| Operation| | | | | | | | | | | | K0_M0_M1_K1| K0_M0_M1_K1| ArrangeOrder| Order| Lengths_K0_M0_M1_K1| ContiguousDimOrder| Lengths_K0_M0_M1_K1| K0_N0_N1_K1| K0_N0_N1_K1| ArrangeOrder| Order| Lengths_K0_N0_N1_K1| ContiguousDimOrder| Lengths_K0_N0_N1_K1| Order| | |
+    // #########| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
+    DeviceGemmDl< int8_t, int8_t, int8_t, int32_t, Col, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 128, 128, 16, 4, 4, 4, 1, S<8, 2>, S<8, 2>, S<2, 1, 4, 4>, S<8, 1, 32, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 4>, S<2, 1, 4, 4>, S<8, 1, 32, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 4>, S<0, 1, 2, 3, 4, 5>, 5, 4>;
+// clang-format on
+
+using ReferenceGemmInstance = ck::tensor_operation::host::
+    ReferenceGemm<ADataType, BDataType, CDataType, AccDataType, AElementOp, BElementOp, CElementOp>;
+
+int main(int argc, char* argv[])
+{
+    bool do_verification = true;
+    int init_method      = 1;
+    bool time_kernel     = false;
+
+    // GEMM shape
+    ck::index_t M = 3840;
+    ck::index_t N = 4096;
+    ck::index_t K = 4096;
+
+    ck::index_t StrideA = 4096;
+    ck::index_t StrideB = 4096;
+    ck::index_t StrideC = 4096;
+
+    if(argc == 1)
+    {
+        // do nothing
+    }
+    else if(argc == 4)
+    {
+        do_verification = std::stoi(argv[1]);
+        init_method     = std::stoi(argv[2]);
+        time_kernel     = std::stoi(argv[3]);
+    }
+    else if(argc == 10)
+    {
+        do_verification = std::stoi(argv[1]);
+        init_method     = std::stoi(argv[2]);
+        time_kernel     = std::stoi(argv[3]);
+
+        M = std::stoi(argv[4]);
+        N = std::stoi(argv[5]);
+        K = std::stoi(argv[6]);
+
+        StrideA = std::stoi(argv[7]);
+        StrideB = std::stoi(argv[8]);
+        StrideC = std::stoi(argv[9]);
+    }
+    else
+    {
+        printf("arg1: verification (0=no, 1=yes)\n");
+        printf("arg2: initialization (0=no init, 1=integer value, 2=decimal value)\n");
+        printf("arg3: time kernel (0=no, 1=yes)\n");
+        printf("arg4 to 9: M (256x), N(128x), K(32x), StrideA, StrideB, StrideC\n");
+        exit(1);
+    }
+
+    auto f_host_tensor_descriptor =
+        [](std::size_t row, std::size_t col, std::size_t stride, auto layout) {
+            if(std::is_same<decltype(layout), Row>::value)
+            {
+                return HostTensorDescriptor(std::vector<std::size_t>({row, col}),
+                                            std::vector<std::size_t>({stride, 1}));
+            }
+            else
+            {
+                return HostTensorDescriptor(std::vector<std::size_t>({row, col}),
+                                            std::vector<std::size_t>({1, stride}));
+            }
+        };
+
+    Tensor<ADataType> a_m_k(f_host_tensor_descriptor(M, K, StrideA, ALayout{}));
+    Tensor<BDataType> b_k_n(f_host_tensor_descriptor(K, N, StrideB, BLayout{}));
+    Tensor<CDataType> c_m_n_host_result(f_host_tensor_descriptor(M, N, StrideC, CLayout{}));
+    Tensor<CDataType> c_m_n_device_result(f_host_tensor_descriptor(M, N, StrideC, CLayout{}));
+
+    std::cout << "a_m_k: " << a_m_k.mDesc << std::endl;
+    std::cout << "b_k_n: " << b_k_n.mDesc << std::endl;
+    std::cout << "c_m_n: " << c_m_n_host_result.mDesc << std::endl;
+
+    switch(init_method)
+    {
+    case 0: break;
+    case 1:
+        a_m_k.GenerateTensorValue(GeneratorTensor_2<ADataType>{-5, 5});
+        b_k_n.GenerateTensorValue(GeneratorTensor_2<BDataType>{-5, 5});
+        break;
+    case 2:
+        a_m_k.GenerateTensorValue(GeneratorTensor_3<ADataType>{0.0, 1.0});
+        b_k_n.GenerateTensorValue(GeneratorTensor_3<BDataType>{-0.5, 0.5});
+        break;
+    default:
+        a_m_k.GenerateTensorValue(GeneratorTensor_Sequential<0>{});
+        b_k_n.GenerateTensorValue(GeneratorTensor_Sequential<1>{});
+    }
+
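+    // Note: with init_method 1 both operands stay in [-5, 5], so each int8 product is at
+    // most 25 in magnitude and the int32 accumulator (AccDataType) holds K * 25 with a
+    // wide margin at these sizes -- only the final narrowing to int8 can wrap, and the
+    // host reference below narrows the same way.
+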
+    DeviceMem a_m_k_device_buf(sizeof(ADataType) * a_m_k.mDesc.GetElementSpace());
+    DeviceMem b_k_n_device_buf(sizeof(BDataType) * b_k_n.mDesc.GetElementSpace());
+    DeviceMem c_m_n_device_buf(sizeof(CDataType) * c_m_n_device_result.mDesc.GetElementSpace());
+
+    a_m_k_device_buf.ToDevice(a_m_k.mData.data());
+    b_k_n_device_buf.ToDevice(b_k_n.mData.data());
+
+    auto a_element_op = AElementOp{};
+    auto b_element_op = BElementOp{};
+    auto c_element_op = CElementOp{};
+
+    // do GEMM
+    auto gemm     = DeviceGemmInstance{};
+    auto invoker  = gemm.MakeInvoker();
+    auto argument = gemm.MakeArgument(static_cast<ADataType*>(a_m_k_device_buf.GetDeviceBuffer()),
+                                      static_cast<BDataType*>(b_k_n_device_buf.GetDeviceBuffer()),
+                                      static_cast<CDataType*>(c_m_n_device_buf.GetDeviceBuffer()),
+                                      M,
+                                      N,
+                                      K,
+                                      StrideA,
+                                      StrideB,
+                                      StrideC,
+                                      a_element_op,
+                                      b_element_op,
+                                      c_element_op);
+
+    if(!gemm.IsSupportedArgument(argument))
+    {
+        std::cout << "wrong! device_gemm with the specified compilation parameters does "
+                     "not support this GEMM problem"
+                  << std::endl;
+
+        return 0;
+    }
+
+    float ave_time = invoker.Run(argument, StreamConfig{nullptr, time_kernel});
+
+    std::size_t flop = std::size_t(2) * M * N * K;
+    std::size_t num_btype =
+        sizeof(ADataType) * M * K + sizeof(BDataType) * K * N + sizeof(CDataType) * M * N;
+
+    float tflops = static_cast<float>(flop) / 1.E9 / ave_time;
+
+    float gb_per_sec = num_btype / 1.E6 / ave_time;
+
+    std::cout << "Perf: " << ave_time << " ms, " << tflops << " TFlops, " << gb_per_sec << " GB/s, "
+              << gemm.GetTypeString() << std::endl;
+
+    c_m_n_device_buf.FromDevice(c_m_n_device_result.mData.data());
+
+    bool pass = true;
+
+    if(do_verification)
+    {
+        auto ref_gemm    = ReferenceGemmInstance{};
+        auto ref_invoker = ref_gemm.MakeInvoker();
+
+        auto ref_argument = ref_gemm.MakeArgument(
+            a_m_k, b_k_n, c_m_n_host_result, a_element_op, b_element_op, c_element_op);
+
+        ref_invoker.Run(ref_argument);
+
+        pass = ck::utils::check_err(c_m_n_device_result.mData, c_m_n_host_result.mData);
+    }
+
+    return pass ? 0 : 1;
+}
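The three DL examples share this harness verbatim; only the tuned DeviceGemmDl instance and the data-type aliases differ. As a reading aid (a sketch of the verification semantics, not code from this PR), the int8 check reduces to the following host loop, assuming plain row-major buffers and ignoring the stride handling that HostTensorDescriptor performs:

#include <cstdint>
#include <vector>

// C[m][n] = sum_k A[m][k] * B[k][n]: products are widened to int32 (AccDataType)
// and the finished sum is narrowed to int8 (CDataType), mirroring the aliases above.
void naive_gemm_int8(const std::vector<int8_t>& a, // M x K
                     const std::vector<int8_t>& b, // K x N
                     std::vector<int8_t>& c,       // M x N
                     int M, int N, int K)
{
    for(int m = 0; m < M; ++m)
        for(int n = 0; n < N; ++n)
        {
            int32_t acc = 0;
            for(int k = 0; k < K; ++k)
                acc += int32_t(a[m * K + k]) * int32_t(b[k * N + n]);
            c[m * N + n] = static_cast<int8_t>(acc); // same narrowing as the device path
        }
}
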
diff --git a/example/CMakeLists.txt b/example/CMakeLists.txt index 8461ebb76f..e595ca2333 100644
--- a/example/CMakeLists.txt
+++ b/example/CMakeLists.txt
@@ -1,6 +1,7 @@
 include_directories(BEFORE
     ${PROJECT_SOURCE_DIR}/include/ck
     ${PROJECT_SOURCE_DIR}/include/ck/utility
+    ${PROJECT_SOURCE_DIR}/include/ck/host_utility
     ${PROJECT_SOURCE_DIR}/include/ck/tensor_description
     ${PROJECT_SOURCE_DIR}/include/ck/tensor
     ${PROJECT_SOURCE_DIR}/include/ck/problem_transform
diff --git a/include/ck/host_utility/device_prop.hpp b/include/ck/host_utility/device_prop.hpp new file mode 100644 index 0000000000..74b20acecd
--- /dev/null
+++ b/include/ck/host_utility/device_prop.hpp
@@ -0,0 +1,50 @@
+#pragma once
+
+#include <string>
+#include <map>
+
+namespace ck {
+
+inline std::string get_device_name()
+{
+    hipDeviceProp_t props{};
+    int device;
+    auto status = hipGetDevice(&device);
+    if(status != hipSuccess)
+    {
+        return std::string();
+    }
+
+    status = hipGetDeviceProperties(&props, device);
+    if(status != hipSuccess)
+    {
+        return std::string();
+    }
+    const std::string raw_name(props.gcnArchName);
+
+    // https://github.com/ROCmSoftwarePlatform/MIOpen/blob/8498875aef84878e04c1eabefdf6571514891086/src/target_properties.cpp#L40
+    static std::map<std::string, std::string> device_name_map = {
+        {"Ellesmere", "gfx803"},
+        {"Baffin", "gfx803"},
+        {"RacerX", "gfx803"},
+        {"Polaris10", "gfx803"},
+        {"Polaris11", "gfx803"},
+        {"Tonga", "gfx803"},
+        {"Fiji", "gfx803"},
+        {"gfx800", "gfx803"},
+        {"gfx802", "gfx803"},
+        {"gfx804", "gfx803"},
+        {"Vega10", "gfx900"},
+        {"gfx901", "gfx900"},
+        {"10.3.0 Sienna_Cichlid 18", "gfx1030"},
+    };
+
+    const auto name = raw_name.substr(0, raw_name.find(':')); // str.substr(0, npos) returns str.
+
+    auto match = device_name_map.find(name);
+    if(match != device_name_map.end())
+        return match->second;
+    return name;
+}
+
+} // namespace ck
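get_device_name() normalizes props.gcnArchName (e.g. "gfx908:sramecc+:xnack-") to a bare architecture string, falling back to the MIOpen name map for older marketing names. A minimal host-side sketch of how the IsSupportedArgument overloads later in this patch consume it (hypothetical check_arch.cpp, not part of the PR):

#include <iostream>
#include <string>
#include "device_prop.hpp"

int main()
{
    const std::string arch = ck::get_device_name(); // "" if no HIP device is usable

    // The DL device below is gated on gfx906/gfx1030; the XDL devices need MFMA
    // hardware (gfx908/gfx90a).
    const bool dl_supported  = (arch == "gfx906" || arch == "gfx1030");
    const bool xdl_supported = (arch == "gfx908" || arch == "gfx90a");

    std::cout << "arch=" << arch << " dl=" << dl_supported << " xdl=" << xdl_supported
              << std::endl;
    return 0;
}
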
diff --git a/include/ck/tensor_operation/gpu/block/blockwise_gemm_dlops_v2r3.hpp b/include/ck/tensor_operation/gpu/block/blockwise_gemm_dl_v2r3.hpp similarity index 97% rename from include/ck/tensor_operation/gpu/block/blockwise_gemm_dlops_v2r3.hpp rename to include/ck/tensor_operation/gpu/block/blockwise_gemm_dl_v2r3.hpp index 0a7b8486f4..f7fa867e16 100644
--- a/include/ck/tensor_operation/gpu/block/blockwise_gemm_dlops_v2r3.hpp
+++ b/include/ck/tensor_operation/gpu/block/blockwise_gemm_dl_v2r3.hpp
@@ -1,10 +1,8 @@
-#ifndef CK_BLOCKWISE_GEMM_DLOPS_V2R3_HPP
-#define CK_BLOCKWISE_GEMM_DLOPS_V2R3_HPP
-
+#pragma once
 #include "common_header.hpp"
 #include "tensor_adaptor.hpp"
-#include "threadwise_tensor_slice_transfer_v2.hpp"
-#include "threadwise_contraction_dlops.hpp"
+#include "threadwise_tensor_slice_transfer_v4r1.hpp"
+#include "threadwise_contraction_dl.hpp"
 
 namespace ck {
 
@@ -41,7 +39,7 @@ template 
               ::type = false>
-struct BlockwiseGemmDlops_A_BK0_BM_BK1_B_BK0_BN_BK1_C_BM0_BM1_BN0_BN1_pipeline_BM0_2_BN0_2
+struct BlockwiseGemmDl_A_BK0_BM_BK1_B_BK0_BN_BK1_C_BM0_BM1_BN0_BN1_pipeline_BM0_2_BN0_2
 {
     using AIndex = MultiIndex<3>;
     using BIndex = MultiIndex<3>;
@@ -148,7 +146,7 @@ struct BlockwiseGemmDlops_A_BK0_BM_BK1_B_BK0_BN_BK1_C_BM0_BM1_BN0_BN1_pipeline_B
         MakeBBlockDescriptor_BK0_BN0_BN1_BK1(BBlockDesc_BK0_BN_BK1{});
 
     public:
-    __device__ BlockwiseGemmDlops_A_BK0_BM_BK1_B_BK0_BN_BK1_C_BM0_BM1_BN0_BN1_pipeline_BM0_2_BN0_2()
+    __device__ BlockwiseGemmDl_A_BK0_BM_BK1_B_BK0_BN_BK1_C_BM0_BM1_BN0_BN1_pipeline_BM0_2_BN0_2()
         : c_thread_origin_data_idx_{CalculateCThreadOriginOnBlock_BM0_BM1_BN0_BN1(
               get_thread_local_1d_id())},
           a_thread_copy_{
@@ -175,6 +173,7 @@ struct BlockwiseGemmDlops_A_BK0_BM_BK1_B_BK0_BN_BK1_C_BM0_BM1_BN0_BN1_pipeline_B
                       "wrong!");
 
         // TODO: remove this restriction
+        static_assert(BM0 == 2, "wrong");
         static_assert(BM0 == 2 && BN0 == 2, "wrong");
     }
 
@@ -226,7 +225,7 @@ struct BlockwiseGemmDlops_A_BK0_BM_BK1_B_BK0_BN_BK1_C_BM0_BM1_BN0_BN1_pipeline_B
             b_thread_desc_bk0_bn0_bn1_bk1_.GetElementSpaceSize());
 
         constexpr auto threadwise_contraction =
-            ThreadwiseContractionDlops_A_TK0_TM0_TM1_TK1_B_TK0_TN0_TN1_TK1_C_TM0_TM1_TN0_TN1<
+            ThreadwiseContractionDl_A_TK0_TM0_TM1_TK1_B_TK0_TN0_TN1_TK1_C_TM0_TM1_TN0_TN1<
                 FloatA,
                 FloatB,
                 FloatC,
@@ -407,4 +406,3 @@ struct BlockwiseGemmDlops_A_BK0_BM_BK1_B_BK0_BN_BK1_C_BM0_BM1_BN0_BN1_pipeline_B
 };
 
 } // namespace ck
-#endif
diff --git a/include/ck/tensor_operation/gpu/block/blockwise_tensor_slice_transfer_v5r1.hpp b/include/ck/tensor_operation/gpu/block/blockwise_tensor_slice_transfer_v5r1.hpp index 93fe5da723..e8ec164364 100644
--- a/include/ck/tensor_operation/gpu/block/blockwise_tensor_slice_transfer_v5r1.hpp
+++ b/include/ck/tensor_operation/gpu/block/blockwise_tensor_slice_transfer_v5r1.hpp
@@ -75,14 +75,13 @@ struct BlockwiseTensorSliceTransfer_v5r1
         }
     }
 
-    template <typename SrcBuffer, typename SrcStepHacks>
-    __device__ void
-    RunRead(const SrcDesc& src_desc, const SrcBuffer& src_buf, const SrcStepHacks& src_step_hacks)
+    template <typename SrcBuffer>
+    __device__ void RunRead(const SrcDesc& src_desc, const SrcBuffer& src_buf)
     {
         if(BlockSize == thread_cluster_desc_.GetElementSize() or
            get_thread_local_1d_id() < thread_cluster_desc_.GetElementSize())
         {
-            threadwise_transfer_.RunRead(src_desc, src_buf, src_step_hacks);
+            threadwise_transfer_.RunRead(src_desc, src_buf);
         }
     }
diff --git a/include/ck/tensor_operation/gpu/device/device_gemm_dl.hpp b/include/ck/tensor_operation/gpu/device/device_gemm_dl.hpp new file mode 100644 index 0000000000..a6a059df77
--- /dev/null
+++ b/include/ck/tensor_operation/gpu/device/device_gemm_dl.hpp
@@ -0,0 +1,586 @@
+#pragma once
+
+#include <iostream>
+#include <sstream>
+
+#include "device.hpp"
+#include "device_base.hpp"
+#include "device_gemm.hpp"
+#include "common_header.hpp"
+#include "tensor_layout.hpp"
+#include "tensor_descriptor.hpp"
+#include "tensor_descriptor_helper.hpp"
+#include "gemm_specialization.hpp"
+#include "element_wise_operation.hpp"
+#include "gridwise_gemm_dl_v1r3.hpp"
+#include "device_prop.hpp"
+
+namespace ck {
+namespace tensor_operation {
+namespace device {
+
+template <
+    typename ADataType,
+    typename BDataType,
+    typename CDataType,
+    typename AccDataType,
+    typename ALayout,
+    typename BLayout,
+    typename CLayout,
+    typename AElementwiseOperation,
+    typename BElementwiseOperation,
+    typename CElementwiseOperation,
+    GemmSpecialization GemmSpec,
+    index_t BlockSize,
+    index_t MPerBlock,
+    index_t NPerBlock,
+    index_t K0PerBlock,
+    index_t K1,
+    index_t M1PerThread,
+    index_t N1PerThread,
+    index_t KPerThread,
+    typename M1N1ThreadClusterM1Xs,
+    typename M1N1ThreadClusterN1Xs,
+    typename ABlockTransferThreadSliceLengths_K0_M0_M1_K1,
+    typename ABlockTransferThreadClusterLengths_K0_M0_M1_K1,
+    typename ABlockTransferThreadClusterArrangeOrder,
+    typename ABlockTransferSrcAccessOrder,
+    typename ABlockTransferSrcVectorTensorLengths_K0_M0_M1_K1,
+    typename ABlockTransferSrcVectorTensorContiguousDimOrder,
+    typename ABlockTransferDstVectorTensorLengths_K0_M0_M1_K1,
+    typename BBlockTransferThreadSliceLengths_K0_N0_N1_K1,
+    typename BBlockTransferThreadClusterLengths_K0_N0_N1_K1,
+    typename BBlockTransferThreadClusterArrangeOrder,
+    typename BBlockTransferSrcAccessOrder,
+    typename BBlockTransferSrcVectorTensorLengths_K0_N0_N1_K1,
+ typename BBlockTransferSrcVectorTensorContiguousDimOrder, + typename BBlockTransferDstVectorTensorLengths_K0_N0_N1_K1, + typename CThreadTransferSrcDstAccessOrder, + index_t CThreadTransferSrcDstVectorDim, + index_t CThreadTransferDstScalarPerVector, + enable_if_t< + is_same_v && + is_same_v && + is_same_v, + bool> = false> +struct DeviceGemmDl + : public DeviceGemm +{ + static constexpr auto I0 = Number<0>{}; + static constexpr auto I1 = Number<1>{}; + static constexpr auto I2 = Number<2>{}; + static constexpr auto I3 = Number<3>{}; + static constexpr auto I4 = Number<4>{}; + static constexpr auto I5 = Number<5>{}; + + static constexpr auto K1Number = Number{}; + + static auto MakeAGridDescriptor_K0_M_K1(index_t M, index_t K, index_t StrideA) + { + assert(K % K1 == 0); + + const index_t K0 = K / K1; + + const auto a_grid_desc_m_k = [&]() { + if constexpr(is_same::value) + { + return make_naive_tensor_descriptor(make_tuple(M, K), make_tuple(StrideA, I1)); + } + else if constexpr(is_same::value) + { + return make_naive_tensor_descriptor(make_tuple(M, K), make_tuple(I1, StrideA)); + } + }(); + + if constexpr(GemmSpec == GemmSpecialization::MNPadding) + { + const auto PadM = (MPerBlock - M % MPerBlock) % MPerBlock; + + return transform_tensor_descriptor( + a_grid_desc_m_k, + make_tuple(make_unmerge_transform(make_tuple(K0, K1Number)), + make_right_pad_transform(M, PadM)), + make_tuple(Sequence<1>{}, Sequence<0>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + } + else + { + return transform_tensor_descriptor( + a_grid_desc_m_k, + make_tuple(make_unmerge_transform(make_tuple(K0, K1Number)), + make_pass_through_transform(M)), + make_tuple(Sequence<1>{}, Sequence<0>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + } + } + + static auto MakeBGridDescriptor_K0_N_K1(index_t K, index_t N, index_t StrideB) + { + assert(K % K1 == 0); + + const index_t K0 = K / K1; + + const auto b_grid_desc_k_n = [&]() { + if constexpr(is_same::value) + { + return make_naive_tensor_descriptor(make_tuple(K, N), make_tuple(StrideB, I1)); + } + else if constexpr(is_same::value) + { + return make_naive_tensor_descriptor(make_tuple(K, N), make_tuple(I1, StrideB)); + } + }(); + + if constexpr(GemmSpec == GemmSpecialization::MNPadding) + { + const auto PadN = (NPerBlock - N % NPerBlock) % NPerBlock; + + return transform_tensor_descriptor( + b_grid_desc_k_n, + make_tuple(make_unmerge_transform(make_tuple(K0, K1Number)), + make_right_pad_transform(N, PadN)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + } + else + { + return transform_tensor_descriptor( + b_grid_desc_k_n, + make_tuple(make_unmerge_transform(make_tuple(K0, K1Number)), + make_pass_through_transform(N)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + } + } + + static auto MakeCGridDescriptor_M_N(index_t M, index_t N, index_t StrideC) + { + const auto c_grid_desc_m_n = [&]() { + if constexpr(is_same::value) + { + return make_naive_tensor_descriptor(make_tuple(M, N), make_tuple(StrideC, I1)); + } + else if constexpr(is_same::value) + { + return make_naive_tensor_descriptor(make_tuple(M, N), make_tuple(I1, StrideC)); + } + }(); + + if constexpr(GemmSpec == GemmSpecialization::MNPadding) + { + const auto PadM = (MPerBlock - M % MPerBlock) % MPerBlock; + const auto PadN = (NPerBlock - N % NPerBlock) % NPerBlock; + + return transform_tensor_descriptor( + c_grid_desc_m_n, + make_tuple(make_right_pad_transform(M, PadM), make_right_pad_transform(N, PadN)), + 
make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + } + else + { + + return transform_tensor_descriptor( + c_grid_desc_m_n, + make_tuple(make_pass_through_transform(M), make_pass_through_transform(N)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + } + } + + using AGridDesc_K0_M_K1 = decltype(MakeAGridDescriptor_K0_M_K1(1, 1, 1)); + using BGridDesc_K0_N_K1 = decltype(MakeBGridDescriptor_K0_N_K1(1, 1, 1)); + using CGridDesc_M_N = decltype(MakeCGridDescriptor_M_N(1, 1, 1)); + + // GridwiseGemm + using GridwiseGemm = + GridwiseGemmDl_km_kn_mn_v1r3; + + using AGridDesc_K0_M0_M1_K1 = + decltype(GridwiseGemm::MakeAGridDescriptor_K0_M0_M1_K1(AGridDesc_K0_M_K1{})); + using BGridDesc_K0_N0_N1_K1 = + decltype(GridwiseGemm::MakeBGridDescriptor_K0_N0_N1_K1(BGridDesc_K0_N_K1{})); + using CGridDesc_M0_M10_M11_N0_N10_N11 = + decltype(GridwiseGemm::MakeCGridDescriptor_M0_M10_M11_N0_N10_N11(CGridDesc_M_N{})); + using DefaultBlock2CTileMap = + decltype(GridwiseGemm::MakeDefaultBlock2CTileMap(CGridDesc_M_N{})); + + // Argument + struct Argument : public BaseArgument + { + Argument(const ADataType* p_a_grid, + const BDataType* p_b_grid, + CDataType* p_c_grid, + index_t M, + index_t N, + index_t K, + index_t StrideA, + index_t StrideB, + index_t StrideC, + index_t M01, + index_t N01, + AElementwiseOperation a_element_op, + BElementwiseOperation b_element_op, + CElementwiseOperation c_element_op) + : p_a_grid_{p_a_grid}, + p_b_grid_{p_b_grid}, + p_c_grid_{p_c_grid}, + a_grid_desc_k0_m0_m1_k1_{}, + b_grid_desc_k0_n0_n1_k1_{}, + c_grid_desc_m0_m10_m11_n0_n10_n11_{}, + block_2_ctile_map_{}, + M01_{M01}, + N01_{N01}, + a_element_op_{a_element_op}, + b_element_op_{b_element_op}, + c_element_op_{c_element_op} + { + a_grid_desc_k0_m_k1_ = DeviceGemmDl::MakeAGridDescriptor_K0_M_K1(M, K, StrideA); + b_grid_desc_k0_n_k1_ = DeviceGemmDl::MakeBGridDescriptor_K0_N_K1(K, N, StrideB); + c_grid_desc_m_n_ = DeviceGemmDl::MakeCGridDescriptor_M_N(M, N, StrideC); + + if(GridwiseGemm::CheckValidity( + a_grid_desc_k0_m_k1_, b_grid_desc_k0_n_k1_, c_grid_desc_m_n_)) + { + a_grid_desc_k0_m0_m1_k1_ = + GridwiseGemm::MakeAGridDescriptor_K0_M0_M1_K1(a_grid_desc_k0_m_k1_); + b_grid_desc_k0_n0_n1_k1_ = + GridwiseGemm::MakeBGridDescriptor_K0_N0_N1_K1(b_grid_desc_k0_n_k1_); + c_grid_desc_m0_m10_m11_n0_n10_n11_ = + GridwiseGemm::MakeCGridDescriptor_M0_M10_M11_N0_N10_N11(c_grid_desc_m_n_); + + block_2_ctile_map_ = GridwiseGemm::MakeDefaultBlock2CTileMap(c_grid_desc_m_n_); + } + } + + // private: + const ADataType* p_a_grid_; + const BDataType* p_b_grid_; + CDataType* p_c_grid_; + + AGridDesc_K0_M_K1 a_grid_desc_k0_m_k1_; + BGridDesc_K0_N_K1 b_grid_desc_k0_n_k1_; + CGridDesc_M_N c_grid_desc_m_n_; + + AGridDesc_K0_M0_M1_K1 a_grid_desc_k0_m0_m1_k1_; + BGridDesc_K0_N0_N1_K1 b_grid_desc_k0_n0_n1_k1_; + CGridDesc_M0_M10_M11_N0_N10_N11 c_grid_desc_m0_m10_m11_n0_n10_n11_; + + DefaultBlock2CTileMap block_2_ctile_map_; + + // TODO: unused, but may be useful in future. + index_t M01_; + index_t N01_; + + // TODO: unused since gridwise_gemm_dl_v1r3 does NOT support prologue for the time being. 
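+        // (Illustration: once a prologue/epilogue is supported, the intent would be,
+        //  schematically, c = c_element_op(a_element_op(a) * b_element_op(b)); the
+        //  enable_if on this class currently restricts all three to PassThrough,
+        //  i.e. identity, so the kernel computes plain C = A * B.)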
+        AElementwiseOperation a_element_op_;
+        BElementwiseOperation b_element_op_;
+        CElementwiseOperation c_element_op_;
+    };
+
+    // Invoker
+    struct Invoker : public BaseInvoker
+    {
+        using Argument = DeviceGemmDl::Argument;
+
+        float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{})
+        {
+            {
+                std::cout << "arg.a_grid_desc_k0_m_k1_{" << arg.a_grid_desc_k0_m_k1_.GetLength(I0)
+                          << ", " << arg.a_grid_desc_k0_m_k1_.GetLength(I1) << ", "
+                          << arg.a_grid_desc_k0_m_k1_.GetLength(I2) << "}" << std::endl;
+
+                std::cout << "arg.b_grid_desc_k0_n_k1_{" << arg.b_grid_desc_k0_n_k1_.GetLength(I0)
+                          << ", " << arg.b_grid_desc_k0_n_k1_.GetLength(I1) << ", "
+                          << arg.b_grid_desc_k0_n_k1_.GetLength(I2) << "}" << std::endl;
+
+                std::cout << "arg.c_grid_desc_m_n_{ " << arg.c_grid_desc_m_n_.GetLength(I0) << ", "
+                          << arg.c_grid_desc_m_n_.GetLength(I1) << "}" << std::endl;
+            }
+
+            if(!GridwiseGemm::CheckValidity(
+                   arg.a_grid_desc_k0_m_k1_, arg.b_grid_desc_k0_n_k1_, arg.c_grid_desc_m_n_))
+            {
+                throw std::runtime_error(
+                    "wrong! GridwiseGemmDl_km_kn_mn_v1r3 has invalid setting");
+            }
+
+            const index_t grid_size = GridwiseGemm::CalculateGridSize(
+                arg.c_grid_desc_m_n_.GetLength(I0), arg.c_grid_desc_m_n_.GetLength(I1));
+
+            const auto K0                      = arg.a_grid_desc_k0_m0_m1_k1_.GetLength(I0);
+            const bool has_main_k_block_loop   = GridwiseGemm::CalculateHasMainKBlockLoop(K0);
+            const bool has_double_tail_k_block_loop =
+                GridwiseGemm::CalculateHasDoubleTailKBlockLoop(K0);
+
+            float ave_time = 0;
+
+            if(has_main_k_block_loop && has_double_tail_k_block_loop)
+            {
+                const auto kernel =
+                    kernel_gemm_dl_v1r3<GridwiseGemm,
+                                        ADataType,
+                                        CDataType,
+                                        remove_reference_t<AGridDesc_K0_M0_M1_K1>,
+                                        remove_reference_t<BGridDesc_K0_N0_N1_K1>,
+                                        remove_reference_t<CGridDesc_M0_M10_M11_N0_N10_N11>,
+                                        remove_reference_t<DefaultBlock2CTileMap>,
+                                        true,
+                                        true>;
+
+                ave_time = launch_and_time_kernel(stream_config,
+                                                  kernel,
+                                                  dim3(grid_size),
+                                                  dim3(BlockSize),
+                                                  0,
+                                                  arg.p_a_grid_,
+                                                  arg.p_b_grid_,
+                                                  arg.p_c_grid_,
+                                                  arg.a_grid_desc_k0_m0_m1_k1_,
+                                                  arg.b_grid_desc_k0_n0_n1_k1_,
+                                                  arg.c_grid_desc_m0_m10_m11_n0_n10_n11_,
+                                                  arg.block_2_ctile_map_);
+            }
+            else if(has_main_k_block_loop && !has_double_tail_k_block_loop)
+            {
+                const auto kernel =
+                    kernel_gemm_dl_v1r3<GridwiseGemm,
+                                        ADataType,
+                                        CDataType,
+                                        remove_reference_t<AGridDesc_K0_M0_M1_K1>,
+                                        remove_reference_t<BGridDesc_K0_N0_N1_K1>,
+                                        remove_reference_t<CGridDesc_M0_M10_M11_N0_N10_N11>,
+                                        remove_reference_t<DefaultBlock2CTileMap>,
+                                        true,
+                                        false>;
+
+                ave_time = launch_and_time_kernel(stream_config,
+                                                  kernel,
+                                                  dim3(grid_size),
+                                                  dim3(BlockSize),
+                                                  0,
+                                                  arg.p_a_grid_,
+                                                  arg.p_b_grid_,
+                                                  arg.p_c_grid_,
+                                                  arg.a_grid_desc_k0_m0_m1_k1_,
+                                                  arg.b_grid_desc_k0_n0_n1_k1_,
+                                                  arg.c_grid_desc_m0_m10_m11_n0_n10_n11_,
+                                                  arg.block_2_ctile_map_);
+            }
+            else if(!has_main_k_block_loop && has_double_tail_k_block_loop)
+            {
+                const auto kernel =
+                    kernel_gemm_dl_v1r3<GridwiseGemm,
+                                        ADataType,
+                                        CDataType,
+                                        remove_reference_t<AGridDesc_K0_M0_M1_K1>,
+                                        remove_reference_t<BGridDesc_K0_N0_N1_K1>,
+                                        remove_reference_t<CGridDesc_M0_M10_M11_N0_N10_N11>,
+                                        remove_reference_t<DefaultBlock2CTileMap>,
+                                        false,
+                                        true>;
+
+                ave_time = launch_and_time_kernel(stream_config,
+                                                  kernel,
+                                                  dim3(grid_size),
+                                                  dim3(BlockSize),
+                                                  0,
+                                                  arg.p_a_grid_,
+                                                  arg.p_b_grid_,
+                                                  arg.p_c_grid_,
+                                                  arg.a_grid_desc_k0_m0_m1_k1_,
+                                                  arg.b_grid_desc_k0_n0_n1_k1_,
+                                                  arg.c_grid_desc_m0_m10_m11_n0_n10_n11_,
+                                                  arg.block_2_ctile_map_);
+            }
+            else
+            {
+                const auto kernel =
+                    kernel_gemm_dl_v1r3<GridwiseGemm,
+                                        ADataType,
+                                        CDataType,
+                                        remove_reference_t<AGridDesc_K0_M0_M1_K1>,
+                                        remove_reference_t<BGridDesc_K0_N0_N1_K1>,
+                                        remove_reference_t<CGridDesc_M0_M10_M11_N0_N10_N11>,
+                                        remove_reference_t<DefaultBlock2CTileMap>,
+                                        false,
+                                        false>;
+
+                ave_time = launch_and_time_kernel(stream_config,
+                                                  kernel,
+                                                  dim3(grid_size),
+                                                  dim3(BlockSize),
+                                                  0,
+                                                  arg.p_a_grid_,
+                                                  arg.p_b_grid_,
+                                                  arg.p_c_grid_,
+                                                  arg.a_grid_desc_k0_m0_m1_k1_,
+                                                  arg.b_grid_desc_k0_n0_n1_k1_,
+                                                  arg.c_grid_desc_m0_m10_m11_n0_n10_n11_,
+                                                  arg.block_2_ctile_map_);
+            }
+
+            return ave_time;
+        }
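+
+        // The dispatch above compiles four kernel instantiations and picks one at run
+        // time. For instance, with K0PerBlock = 16 as in the examples: the fp16 default
+        // problem has K = 4096 and K1 = 2, so K0 = 2048, giving (2048 + 16) / 32 = 64 > 1
+        // (main K loop) and 2048 / 16 = 128, even (double-tail); a problem with K0 = 48
+        // would take the main-loop/single-tail path, since 48 / 16 = 3 is odd.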
+
+        // polymorphic
+        float Run(const BaseArgument* p_arg,
+                  const StreamConfig& stream_config = StreamConfig{}) override
+        {
+            return Run(*dynamic_cast<const Argument*>(p_arg), stream_config);
+        }
+    };
+
+    static constexpr bool IsValidCompilationParameter()
+    {
+        // TODO: properly implement this check
+        return true;
+    }
+
+    static bool IsSupportedArgument(const Argument& arg)
+    {
+        if(ck::get_device_name() == "gfx906" || ck::get_device_name() == "gfx1030")
+        {
+            return GridwiseGemm::CheckValidity(
+                arg.a_grid_desc_k0_m_k1_, arg.b_grid_desc_k0_n_k1_, arg.c_grid_desc_m_n_);
+        }
+        else
+        {
+            return false;
+        }
+    }
+
+    // polymorphic
+    bool IsSupportedArgument(const BaseArgument* p_arg) override
+    {
+        return IsSupportedArgument(*dynamic_cast<const Argument*>(p_arg));
+    }
+
+    static auto MakeArgument(const ADataType* p_a,
+                             const BDataType* p_b,
+                             CDataType* p_c,
+                             index_t M,
+                             index_t N,
+                             index_t K,
+                             index_t StrideA,
+                             index_t StrideB,
+                             index_t StrideC,
+                             AElementwiseOperation a_element_op,
+                             BElementwiseOperation b_element_op,
+                             CElementwiseOperation c_element_op)
+    {
+        return Argument{p_a,
+                        p_b,
+                        p_c,
+                        M,
+                        N,
+                        K,
+                        StrideA,
+                        StrideB,
+                        StrideC,
+                        1,
+                        1,
+                        a_element_op,
+                        b_element_op,
+                        c_element_op};
+    }
+
+    static auto MakeInvoker() { return Invoker{}; }
+
+    // polymorphic
+    std::unique_ptr<BaseArgument> MakeArgumentPointer(const void* p_a,
+                                                      const void* p_b,
+                                                      void* p_c,
+                                                      index_t M,
+                                                      index_t N,
+                                                      index_t K,
+                                                      index_t StrideA,
+                                                      index_t StrideB,
+                                                      index_t StrideC,
+                                                      AElementwiseOperation a_element_op,
+                                                      BElementwiseOperation b_element_op,
+                                                      CElementwiseOperation c_element_op,
+                                                      index_t /* KBatch */ = 1) override
+    {
+        return std::make_unique<Argument>(static_cast<const ADataType*>(p_a),
+                                          static_cast<const BDataType*>(p_b),
+                                          static_cast<CDataType*>(p_c),
+                                          M,
+                                          N,
+                                          K,
+                                          StrideA,
+                                          StrideB,
+                                          StrideC,
+                                          1,
+                                          1,
+                                          a_element_op,
+                                          b_element_op,
+                                          c_element_op);
+    }
+
+    // polymorphic
+    std::unique_ptr<BaseInvoker> MakeInvokerPointer() override
+    {
+        return std::make_unique<Invoker>(Invoker{});
+    }
+
+    // polymorphic
+    std::string GetTypeString() const override
+    {
+        auto str = std::stringstream();
+
+        // clang-format off
+        str << "DeviceGemmDl"
+            << "<"
+            << BlockSize << ", "
+            << MPerBlock << ", "
+            << NPerBlock << ", "
+            << K0PerBlock << ", "
+            << K1 << ", "
+            << M1PerThread << ", "
+            << N1PerThread << ", "
+            << KPerThread
+            << ">";
+        // clang-format on
+
+        return str.str();
+    }
+};
+
+} // namespace device
+} // namespace tensor_operation
+} // namespace ck
diff --git a/include/ck/tensor_operation/gpu/device/device_gemm_xdl.hpp b/include/ck/tensor_operation/gpu/device/device_gemm_xdl.hpp index 819aa8f390..31f354358f 100644
--- a/include/ck/tensor_operation/gpu/device/device_gemm_xdl.hpp
+++ b/include/ck/tensor_operation/gpu/device/device_gemm_xdl.hpp
@@ -1,5 +1,4 @@
-#ifndef DEVICE_GEMM_XDL_HPP
-#define DEVICE_GEMM_XDL_HPP
+#pragma once
 
 #include <iostream>
 #include <sstream>
@@ -12,6 +11,7 @@
 #include "tensor_descriptor_helper.hpp"
 #include "gridwise_gemm_xdlops_v2r3.hpp"
 #include "gemm_specialization.hpp"
+#include "device_prop.hpp"
 
 namespace ck {
 namespace tensor_operation {
@@ -408,6 +408,11 @@ struct DeviceGemmXdl
     static bool IsSupportedArgument(const Argument& arg)
     {
+        if(!(ck::get_device_name() == "gfx908" || ck::get_device_name() == "gfx90a"))
+        {
+            return false;
+        }
+
         return GridwiseGemm::CheckValidity(arg.a_grid_desc_k0_m_k1_,
                                            arg.b_grid_desc_k0_n_k1_,
                                            arg.c_grid_desc_m_n_,
@@ -515,4 +520,3 @@ struct DeviceGemmXdl
 } // namespace device
 } // namespace tensor_operation
 } // namespace ck
-#endif
diff --git a/include/ck/tensor_operation/gpu/device/device_gemm_xdl_cshuffle.hpp b/include/ck/tensor_operation/gpu/device/device_gemm_xdl_cshuffle.hpp index e4a3a8e153..a74ee81679 100644
--- 
a/include/ck/tensor_operation/gpu/device/device_gemm_xdl_cshuffle.hpp +++ b/include/ck/tensor_operation/gpu/device/device_gemm_xdl_cshuffle.hpp @@ -9,6 +9,7 @@ #include "tensor_descriptor_helper.hpp" #include "gridwise_gemm_xdl_cshuffle_v1.hpp" #include "tensor_operation/gpu/device/gemm_specialization.hpp" +#include "device_prop.hpp" namespace ck { namespace tensor_operation { @@ -558,6 +559,11 @@ struct DeviceGemm_Xdl_CShuffle static bool IsSupportedArgument(const Argument& arg) { + if(!(ck::get_device_name() == "gfx908" || ck::get_device_name() == "gfx90a")) + { + return false; + } + return GridwiseGemm::CheckValidity(arg.a_grid_desc_ak0_m_ak1_, arg.b_grid_desc_bk0_n_bk1_, arg.c_grid_desc_m_n_, diff --git a/include/ck/tensor_operation/gpu/device/device_gemm_xdl_splitk.hpp b/include/ck/tensor_operation/gpu/device/device_gemm_xdl_splitk.hpp index 97ca8e2f92..d9fc8f7a8a 100644 --- a/include/ck/tensor_operation/gpu/device/device_gemm_xdl_splitk.hpp +++ b/include/ck/tensor_operation/gpu/device/device_gemm_xdl_splitk.hpp @@ -12,6 +12,7 @@ #include "tensor_descriptor_helper.hpp" #include "gridwise_gemm_xdlops_v2r4.hpp" #include "gemm_specialization.hpp" +#include "device_prop.hpp" #ifndef CK_RUN_KERNEL_AND_TIME #define CK_RUN_KERNEL_AND_TIME 1 @@ -528,6 +529,11 @@ struct DeviceGemmXdlSplitK static bool IsSupportedArgument(const Argument& arg) { + if(!(ck::get_device_name() == "gfx908" || ck::get_device_name() == "gfx90a")) + { + return false; + } + return GridwiseGemm::CheckValidity(arg.a_grid_desc_kbatch_k0_m_k1_, arg.b_grid_desc_kbatch_k0_n_k1_, arg.c_grid_desc_m_n_, diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_dlops_v1r3.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_dl_v1r3.hpp similarity index 57% rename from include/ck/tensor_operation/gpu/grid/gridwise_gemm_dlops_v1r3.hpp rename to include/ck/tensor_operation/gpu/grid/gridwise_gemm_dl_v1r3.hpp index 1a66c8ff3f..3b5daf6ead 100644 --- a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_dlops_v1r3.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_dl_v1r3.hpp @@ -1,38 +1,38 @@ -#ifndef CK_GRIDWISE_GEMM_V1R3_HPP -#define CK_GRIDWISE_GEMM_V1R3_HPP +#pragma once #include "common_header.hpp" #include "multi_index_transform_helper.hpp" #include "tensor_descriptor.hpp" #include "tensor_descriptor_helper.hpp" -#include "blockwise_gemm_dlops_v2r3.hpp" +#include "tensor_operation/gpu/grid/block_to_ctile_map.hpp" +#include "blockwise_gemm_dl_v2r3.hpp" #include "blockwise_tensor_slice_transfer_v5r1.hpp" -#include "threadwise_tensor_slice_transfer_v2.hpp" +#include "threadwise_tensor_slice_transfer.hpp" #include "threadwise_tensor_slice_set.hpp" +#include "element_wise_operation.hpp" namespace ck { template __global__ void #if CK_USE_LAUNCH_BOUNDS __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU) #endif - kernel_gemm_dlops_v1r3( - const FloatAB* __restrict__ p_a_grid, - const FloatAB* __restrict__ p_b_grid, - FloatC* __restrict__ p_c_grid, - const AK0M0M1K1GridDesc a_k0_m0_m1_k1_grid_desc, - const BK0N0N1K1GridDesc b_k0_n0_n1_k1_grid_desc, - const CM0M10M11N0N10N11GridDesc c_m0_m10_m11_n0_n10_n11_grid_desc, - const CBlockIdToM0N0BlockClusterAdaptor cblockid_to_m0_n0_block_cluster_adaptor) + kernel_gemm_dl_v1r3(const FloatAB* __restrict__ p_a_grid, + const FloatAB* __restrict__ p_b_grid, + FloatC* __restrict__ p_c_grid, + const AGridDesc_K0_M0_M1_K1 a_grid_desc_k0_m0_m1_k1, + const BGridDesc_K0_N0_N1_K1 b_grid_desc_k0_n0_n1_k1, + const CGridDesc_M0_M10_M11_N0_N10_N11 
c_grid_desc_m0_m10_m11_n0_n10_n11, + const Block2CTileMap block_2_ctile_map) { constexpr index_t shared_block_size = GridwiseGemm::GetSharedMemoryNumberOfByte() / sizeof(FloatAB); @@ -43,10 +43,10 @@ __global__ void p_b_grid, p_c_grid, p_shared_block, - a_k0_m0_m1_k1_grid_desc, - b_k0_n0_n1_k1_grid_desc, - c_m0_m10_m11_n0_n10_n11_grid_desc, - cblockid_to_m0_n0_block_cluster_adaptor, + a_grid_desc_k0_m0_m1_k1, + b_grid_desc_k0_n0_n1_k1, + c_grid_desc_m0_m10_m11_n0_n10_n11, + block_2_ctile_map, integral_constant{}, integral_constant{}); } @@ -56,12 +56,12 @@ template -struct GridwiseGemmDlops_km_kn_mn_v1r3 + index_t CThreadTransferDstScalarPerVector> +struct GridwiseGemmDl_km_kn_mn_v1r3 { static constexpr auto I0 = Number<0>{}; static constexpr auto I1 = Number<1>{}; @@ -97,7 +92,7 @@ struct GridwiseGemmDlops_km_kn_mn_v1r3 static constexpr auto I3 = Number<3>{}; // K1 should be Number<...> - static constexpr auto K1 = AK0MK1GridDesc{}.GetLength(I2); + static constexpr auto K1 = AGridDesc_K0_M_K1{}.GetLength(I2); __host__ __device__ static constexpr index_t GetSharedMemoryNumberOfByte() { @@ -106,112 +101,112 @@ struct GridwiseGemmDlops_km_kn_mn_v1r3 // TODO: check alignment // A matrix in LDS memory, dst of blockwise copy - constexpr auto a_k_m_block_desc = make_naive_tensor_descriptor_aligned( - make_tuple(Number{}, Number{}, K1), max_lds_align); + constexpr auto a_block_desc_k_m = make_naive_tensor_descriptor_aligned( + make_tuple(Number{}, Number{}, K1), max_lds_align); // TODO: check alignment // B matrix in LDS memory, dst of blockwise copy - constexpr auto b_k_n_block_desc = make_naive_tensor_descriptor_aligned( - make_tuple(Number{}, Number{}, K1), max_lds_align); + constexpr auto b_block_desc_k_n = make_naive_tensor_descriptor_aligned( + make_tuple(Number{}, Number{}, K1), max_lds_align); // TODO: check alignment // LDS allocation for A and B: be careful of alignment constexpr auto a_block_aligned_space_size = - math::integer_least_multiple(a_k_m_block_desc.GetElementSpaceSize(), max_lds_align); + math::integer_least_multiple(a_block_desc_k_m.GetElementSpaceSize(), max_lds_align); constexpr auto b_block_aligned_space_size = - math::integer_least_multiple(b_k_n_block_desc.GetElementSpaceSize(), max_lds_align); + math::integer_least_multiple(b_block_desc_k_n.GetElementSpaceSize(), max_lds_align); return 2 * (a_block_aligned_space_size + b_block_aligned_space_size) * sizeof(FloatAB); } __host__ __device__ static constexpr bool - CheckValidity(const AK0MK1GridDesc& a_k0_m_k1_grid_desc, - const BK0NK1GridDesc& b_k0_n_k1_grid_desc, - const CMNGridDesc& c_m_n_grid_desc) + CheckValidity(const AGridDesc_K0_M_K1& a_grid_desc_k0_m_k1, + const BGridDesc_K0_N_K1& b_grid_desc_k0_n_k1, + const CGridDesc_M_N& c_grid_desc_m_n) { - const auto M = a_k0_m_k1_grid_desc.GetLength(I1); - const auto N = b_k0_n_k1_grid_desc.GetLength(I1); - const auto K0 = a_k0_m_k1_grid_desc.GetLength(I0); + const auto M = a_grid_desc_k0_m_k1.GetLength(I1); + const auto N = b_grid_desc_k0_n_k1.GetLength(I1); + const auto K0 = a_grid_desc_k0_m_k1.GetLength(I0); // TODO: also check validity of all components (blockwise-copy, threadwise-copy, etc) - return (M == c_m_n_grid_desc.GetLength(I0) && N == c_m_n_grid_desc.GetLength(I1) && - K0 == b_k0_n_k1_grid_desc.GetLength(I0) && - K1 == a_k0_m_k1_grid_desc.GetLength(I2) && - K1 == b_k0_n_k1_grid_desc.GetLength(I2)) && - (M % MPerBlockM1 == 0 && N % NPerBlockN1 == 0 && K0 % KPerBlock == 0); + return (M == c_grid_desc_m_n.GetLength(I0) && N == c_grid_desc_m_n.GetLength(I1) && + K0 
== b_grid_desc_k0_n_k1.GetLength(I0) && + K1 == a_grid_desc_k0_m_k1.GetLength(I2) && + K1 == b_grid_desc_k0_n_k1.GetLength(I2)) && + (M % MPerBlock == 0 && N % NPerBlock == 0 && K0 % K0PerBlock == 0); } __host__ __device__ static constexpr index_t CalculateGridSize(index_t M, index_t N) { - const index_t grid_size = (M / MPerBlockM1) * (N / NPerBlockN1); + const index_t grid_size = (M / MPerBlock) * (N / NPerBlock); return grid_size; } __host__ __device__ static constexpr bool CalculateHasMainKBlockLoop(index_t K0) { - const bool has_main_k_block_loop = (K0 + KPerBlock) / (2 * KPerBlock) > 1; + const bool has_main_k_block_loop = (K0 + K0PerBlock) / (2 * K0PerBlock) > 1; return has_main_k_block_loop; } __host__ __device__ static constexpr bool CalculateHasDoubleTailKBlockLoop(index_t K0) { - const bool has_double_tail_k_block_loop = (K0 / KPerBlock) % 2 == 0; + const bool has_double_tail_k_block_loop = (K0 / K0PerBlock) % 2 == 0; return has_double_tail_k_block_loop; } __host__ __device__ static constexpr auto - MakeAK0M0M1K1GridDescriptor(const AK0MK1GridDesc& a_k0_m_k1_grid_desc) + MakeAGridDescriptor_K0_M0_M1_K1(const AGridDesc_K0_M_K1& a_grid_desc_k0_m_k1) { - const auto K0 = a_k0_m_k1_grid_desc.GetLength(I0); - const auto M = a_k0_m_k1_grid_desc.GetLength(I1); + const auto K0 = a_grid_desc_k0_m_k1.GetLength(I0); + const auto M = a_grid_desc_k0_m_k1.GetLength(I1); - const auto M1 = Number{}; + const auto M1 = Number{}; const auto M0 = M / M1; - const auto a_k0_m0_m1_k1_grid_desc = - transform_tensor_descriptor(a_k0_m_k1_grid_desc, + const auto a_grid_desc_k0_m0_m1_k1 = + transform_tensor_descriptor(a_grid_desc_k0_m_k1, make_tuple(make_pass_through_transform(K0), make_unmerge_transform(make_tuple(M0, M1)), make_pass_through_transform(K1)), make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}), make_tuple(Sequence<0>{}, Sequence<1, 2>{}, Sequence<3>{})); - return a_k0_m0_m1_k1_grid_desc; + return a_grid_desc_k0_m0_m1_k1; } __host__ __device__ static constexpr auto - MakeBK0N0N1K1GridDescriptor(const BK0NK1GridDesc& b_k0_n_k1_grid_desc) + MakeBGridDescriptor_K0_N0_N1_K1(const BGridDesc_K0_N_K1& b_grid_desc_k0_n_k1) { - const auto K0 = b_k0_n_k1_grid_desc.GetLength(I0); - const auto N = b_k0_n_k1_grid_desc.GetLength(I1); + const auto K0 = b_grid_desc_k0_n_k1.GetLength(I0); + const auto N = b_grid_desc_k0_n_k1.GetLength(I1); - const auto N1 = Number{}; + const auto N1 = Number{}; const auto N0 = N / N1; - const auto b_k0_n0_n1_k1_grid_desc = - transform_tensor_descriptor(b_k0_n_k1_grid_desc, + const auto b_grid_desc_k0_n0_n1_k1 = + transform_tensor_descriptor(b_grid_desc_k0_n_k1, make_tuple(make_pass_through_transform(K0), make_unmerge_transform(make_tuple(N0, N1)), make_pass_through_transform(K1)), make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}), make_tuple(Sequence<0>{}, Sequence<1, 2>{}, Sequence<3>{})); - return b_k0_n0_n1_k1_grid_desc; + return b_grid_desc_k0_n0_n1_k1; } __host__ __device__ static constexpr auto - MakeCM0M10M11N0N10N11GridDescriptor(const CMNGridDesc& c_m_n_grid_desc) + MakeCGridDescriptor_M0_M10_M11_N0_N10_N11(const CGridDesc_M_N& c_grid_desc_m_n) { - const auto M = c_m_n_grid_desc.GetLength(I0); - const auto N = c_m_n_grid_desc.GetLength(I1); + const auto M = c_grid_desc_m_n.GetLength(I0); + const auto N = c_grid_desc_m_n.GetLength(I1); - constexpr auto M1 = Number{}; - constexpr auto N1 = Number{}; + constexpr auto M1 = Number{}; + constexpr auto N1 = Number{}; const auto M0 = M / M1; const auto N0 = N / N1; @@ -226,41 +221,29 @@ struct 
GridwiseGemmDlops_km_kn_mn_v1r3 constexpr auto M10 = M1 / M11; constexpr auto N10 = N1 / N11; - const auto c_m0_m10_m11_n0_n10_n11_grid_desc = transform_tensor_descriptor( - c_m_n_grid_desc, + const auto c_grid_desc_m0_m10_m11_n0_n10_n11 = transform_tensor_descriptor( + c_grid_desc_m_n, make_tuple(make_unmerge_transform(make_tuple(M0, M10, M11)), make_unmerge_transform(make_tuple(N0, N10, N11))), make_tuple(Sequence<0>{}, Sequence<1>{}), make_tuple(Sequence<0, 1, 2>{}, Sequence<3, 4, 5>{})); - return c_m0_m10_m11_n0_n10_n11_grid_desc; + return c_grid_desc_m0_m10_m11_n0_n10_n11; } + // return block_id to C matrix tile idx (m0, n0) mapping __host__ __device__ static constexpr auto - MakeCBlockIdToM0N0BlockClusterAdaptor(const CMNGridDesc& c_m_n_grid_desc) + MakeDefaultBlock2CTileMap(const CGridDesc_M_N& c_grid_desc_m_n) { - const auto M = c_m_n_grid_desc.GetLength(I0); - const auto N = c_m_n_grid_desc.GetLength(I1); - - constexpr auto M1 = Number{}; - constexpr auto N1 = Number{}; - - const auto M0 = M / M1; - const auto N0 = N / N1; - - const auto cblockid_to_m0_n0_block_cluster_adaptor = - make_single_stage_tensor_adaptor(make_tuple(make_merge_transform(make_tuple(M0, N0))), - make_tuple(Sequence<0, 1>{}), - make_tuple(Sequence<0>{})); - - return cblockid_to_m0_n0_block_cluster_adaptor; + return BlockToCTileMap_M00_N00_M01_N01( + c_grid_desc_m_n); } - using AK0M0M1K1GridDesc = decltype(MakeAK0M0M1K1GridDescriptor(AK0MK1GridDesc{})); - using BK0N0N1K1GridDesc = decltype(MakeBK0N0N1K1GridDescriptor(BK0NK1GridDesc{})); - using CM0M10M11N0N10N11GridDesc = decltype(MakeCM0M10M11N0N10N11GridDescriptor(CMNGridDesc{})); - using CBlockIdToM0N0BlockClusterAdaptor = - decltype(MakeCBlockIdToM0N0BlockClusterAdaptor(CMNGridDesc{})); + using AGridDesc_K0_M0_M1_K1 = decltype(MakeAGridDescriptor_K0_M0_M1_K1(AGridDesc_K0_M_K1{})); + using BGridDesc_K0_N0_N1_K1 = decltype(MakeBGridDescriptor_K0_N0_N1_K1(BGridDesc_K0_N_K1{})); + using CGridDesc_M0_M10_M11_N0_N10_N11 = + decltype(MakeCGridDescriptor_M0_M10_M11_N0_N10_N11(CGridDesc_M_N{})); + using Block2CTileMap = decltype(MakeDefaultBlock2CTileMap(CGridDesc_M_N{})); template __device__ static void @@ -268,57 +251,64 @@ struct GridwiseGemmDlops_km_kn_mn_v1r3 const FloatAB* __restrict__ p_b_grid, FloatC* __restrict__ p_c_grid, FloatAB* __restrict__ p_shared_block, - const AK0M0M1K1GridDesc& a_k0_m0_m1_k1_grid_desc, - const BK0N0N1K1GridDesc& b_k0_n0_n1_k1_grid_desc, - const CM0M10M11N0N10N11GridDesc& c_m0_m10_m11_n0_n10_n11_grid_desc, - const CBlockIdToM0N0BlockClusterAdaptor& cblockid_to_m0_n0_block_cluster_adaptor, + const AGridDesc_K0_M0_M1_K1& a_grid_desc_k0_m0_m1_k1, + const BGridDesc_K0_N0_N1_K1& b_grid_desc_k0_n0_n1_k1, + const CGridDesc_M0_M10_M11_N0_N10_N11& c_grid_desc_m0_m10_m11_n0_n10_n11, + const Block2CTileMap& block_2_ctile_map, integral_constant, integral_constant) { const auto a_global_buf = make_dynamic_buffer( - p_a_grid, a_k0_m0_m1_k1_grid_desc.GetElementSpaceSize()); + p_a_grid, a_grid_desc_k0_m0_m1_k1.GetElementSpaceSize()); const auto b_global_buf = make_dynamic_buffer( - p_b_grid, b_k0_n0_n1_k1_grid_desc.GetElementSpaceSize()); + p_b_grid, b_grid_desc_k0_n0_n1_k1.GetElementSpaceSize()); auto c_grid_buf = make_dynamic_buffer( - p_c_grid, c_m0_m10_m11_n0_n10_n11_grid_desc.GetElementSpaceSize()); + p_c_grid, c_grid_desc_m0_m10_m11_n0_n10_n11.GetElementSpaceSize()); // divide block work by [M, N] const auto c_m0_n0_block_cluster_idx = - cblockid_to_m0_n0_block_cluster_adaptor.CalculateBottomIndex( - make_multi_index(get_block_1d_id())); + 
// TODO: change this. I think it needs multi-dimensional alignment constexpr auto max_lds_align = K1; // TODO: check alignment // A matrix in LDS memory, dst of blockwise copy // be careful of LDS alignment - constexpr auto a_k0_m0_m1_k1_block_desc = make_naive_tensor_descriptor_aligned( - make_tuple(Number{}, I1, Number{}, K1), max_lds_align); + constexpr auto a_block_desc_k0_m0_m1_k1 = make_naive_tensor_descriptor_aligned( + make_tuple(Number{}, I1, Number{}, K1), max_lds_align); // TODO: check alignment // B matrix in LDS memory, dst of blockwise copy // be careful of LDS alignment - constexpr auto b_k0_n0_n1_k1_block_desc = make_naive_tensor_descriptor_aligned( - make_tuple(Number{}, I1, Number{}, K1), max_lds_align); + constexpr auto b_block_desc_k0_n0_n1_k1 = make_naive_tensor_descriptor_aligned( + make_tuple(Number{}, I1, Number{}, K1), max_lds_align); // TODO: check alignment // A matrix in LDS memory, for blockwise GEMM constexpr auto a_k0_m_k1_block_desc = make_naive_tensor_descriptor_aligned( - make_tuple(Number{}, Number{}, K1), max_lds_align); + make_tuple(Number{}, Number{}, K1), max_lds_align); // TODO: check alignment // B matrix in LDS memory, for blockwise GEMM constexpr auto b_k0_n_k1_block_desc = make_naive_tensor_descriptor_aligned( - make_tuple(Number{}, Number{}, K1), max_lds_align); + make_tuple(Number{}, Number{}, K1), max_lds_align); - static_assert(a_k0_m0_m1_k1_block_desc.GetElementSpaceSize() == + static_assert(a_block_desc_k0_m0_m1_k1.GetElementSpaceSize() == a_k0_m_k1_block_desc.GetElementSpaceSize() && - b_k0_n0_n1_k1_block_desc.GetElementSpaceSize() == + b_block_desc_k0_n0_n1_k1.GetElementSpaceSize() == b_k0_n_k1_block_desc.GetElementSpaceSize() && "wrong!"); @@ -326,14 +316,14 @@ struct GridwiseGemmDlops_km_kn_mn_v1r3 auto a_blockwise_copy = BlockwiseTensorSliceTransfer_v5r1< BlockSize, InMemoryDataOperationEnum::Set, - Sequence, + Sequence, ABlockTransferThreadSliceLengths_K0_M0_M1_K1, ABlockTransferThreadClusterLengths_K0_M0_M1_K1, ABlockTransferThreadClusterArrangeOrder, FloatAB, FloatAB, - decltype(a_k0_m0_m1_k1_grid_desc), - decltype(a_k0_m0_m1_k1_block_desc), + remove_reference_t, + decltype(a_block_desc_k0_m0_m1_k1), ABlockTransferSrcAccessOrder, Sequence<0, 1, 2, 3>, ABlockTransferSrcVectorTensorLengths_K0_M0_M1_K1, // SrcVectorTensorLengths @@ -341,23 +331,23 @@ struct GridwiseGemmDlops_km_kn_mn_v1r3 ABlockTransferSrcVectorTensorContiguousDimOrder, // SrcVectorTensorContiguousDimOrder Sequence<0, 1, 2, 3>, // DstVectorTensorContiguousDimOrder false, - true>(a_k0_m0_m1_k1_grid_desc, + true>(a_grid_desc_k0_m0_m1_k1, make_multi_index(0, im0, 0, 0), - a_k0_m0_m1_k1_block_desc, + a_block_desc_k0_m0_m1_k1, make_multi_index(0, 0, 0, 0)); // B matrix blockwise copy auto b_blockwise_copy = BlockwiseTensorSliceTransfer_v5r1< BlockSize, InMemoryDataOperationEnum::Set, - Sequence, + Sequence, BBlockTransferThreadSliceLengths_K0_N0_N1_K1, BBlockTransferThreadClusterLengths_K0_N0_N1_K1, BBlockTransferThreadClusterArrangeOrder, FloatAB, FloatAB, - decltype(b_k0_n0_n1_k1_grid_desc), -
decltype(b_k0_n0_n1_k1_block_desc), + remove_reference_t, + decltype(b_block_desc_k0_n0_n1_k1), BBlockTransferSrcAccessOrder, Sequence<0, 1, 2, 3>, BBlockTransferSrcVectorTensorLengths_K0_N0_N1_K1, // SrcVectorTensorLengths @@ -365,19 +355,19 @@ struct GridwiseGemmDlops_km_kn_mn_v1r3 BBlockTransferSrcVectorTensorContiguousDimOrder, // SrcVectorTensorContiguousDimOrder Sequence<0, 1, 2, 3>, // DstVectorTensorContiguousDimOrder false, - true>(b_k0_n0_n1_k1_grid_desc, + true>(b_grid_desc_k0_n0_n1_k1, make_multi_index(0, in0, 0, 0), - b_k0_n0_n1_k1_block_desc, + b_block_desc_k0_n0_n1_k1, make_multi_index(0, 0, 0, 0)); // GEMM definition // c_mtx += transpose(a_mtx) * b_mtx - // a_mtx[KPerBlock, MPerBlockM1] is in LDS - // b_mtx[KPerBlocl, NPerBlockN1] is in LDS - // c_mtx[MPerBlockM1, NPerBlockN1] is distributed among threads, and saved in + // a_mtx[K0PerBlock, MPerBlock] is in LDS + // b_mtx[K0PerBlock, NPerBlock] is in LDS + // c_mtx[MPerBlock, NPerBlock] is distributed among threads, and saved in // register const auto blockwise_gemm = - BlockwiseGemmDlops_A_BK0_BM_BK1_B_BK0_BN_BK1_C_BM0_BM1_BN0_BN1_pipeline_BM0_2_BN0_2< + BlockwiseGemmDl_A_BK0_BM_BK1_B_BK0_BN_BK1_C_BM0_BM1_BN0_BN1_pipeline_BM0_2_BN0_2< BlockSize, FloatAB, FloatAB, @@ -395,58 +385,53 @@ struct GridwiseGemmDlops_km_kn_mn_v1r3 constexpr auto c_m10_m11_n10_n11_thread_tensor_lengths = decltype(blockwise_gemm)::GetCThreadTensorLengths_BM0_BM1_BN0_BN1(); - constexpr auto c_m10_m11_n10_n11_thread_desc = make_naive_tensor_descriptor_packed( + constexpr auto c_thread_desc_m10_m11_n10_n11 = make_naive_tensor_descriptor_packed( sequence_to_tuple_of_number(c_m10_m11_n10_n11_thread_tensor_lengths)); // LDS allocation for A and B: be careful of alignment constexpr auto a_block_aligned_space_size = math::integer_least_multiple( - a_k0_m0_m1_k1_block_desc.GetElementSpaceSize(), max_lds_align); + a_block_desc_k0_m0_m1_k1.GetElementSpaceSize(), max_lds_align); constexpr auto b_block_aligned_space_size = math::integer_least_multiple( - b_k0_n0_n1_k1_block_desc.GetElementSpaceSize(), max_lds_align); + b_block_desc_k0_n0_n1_k1.GetElementSpaceSize(), max_lds_align); FloatAB* p_a_block_double = p_shared_block; FloatAB* p_b_block_double = p_shared_block + 2 * a_block_aligned_space_size; // register allocation for output auto c_thread_buf = make_static_buffer( - c_m10_m11_n10_n11_thread_desc.GetElementSpaceSize()); + c_thread_desc_m10_m11_n10_n11.GetElementSpaceSize()); - ThreadwiseTensorSliceSet_v1{} - .Run(c_m10_m11_n10_n11_thread_desc, - make_tuple(I0, I0, I0, I0), - c_thread_buf, - FloatAcc{0}); + // Initialize C + c_thread_buf.Clear(); - constexpr auto a_block_slice_copy_step = make_multi_index(KPerBlock, 0, 0, 0); - constexpr auto b_block_slice_copy_step = make_multi_index(KPerBlock, 0, 0, 0); + constexpr auto a_block_slice_copy_step = make_multi_index(K0PerBlock, 0, 0, 0); + constexpr auto b_block_slice_copy_step = make_multi_index(K0PerBlock, 0, 0, 0); auto a_block_even_buf = make_dynamic_buffer( - p_a_block_double, a_k0_m0_m1_k1_block_desc.GetElementSpaceSize()); + p_a_block_double, a_block_desc_k0_m0_m1_k1.GetElementSpaceSize()); auto b_block_even_buf = make_dynamic_buffer( - p_b_block_double, b_k0_n0_n1_k1_block_desc.GetElementSpaceSize()); + p_b_block_double, b_block_desc_k0_n0_n1_k1.GetElementSpaceSize()); auto a_block_odd_buf = make_dynamic_buffer( p_a_block_double + a_block_aligned_space_size, - a_k0_m0_m1_k1_block_desc.GetElementSpaceSize()); + a_block_desc_k0_m0_m1_k1.GetElementSpaceSize()); auto b_block_odd_buf = make_dynamic_buffer( p_b_block_double + b_block_aligned_space_size, - b_k0_n0_n1_k1_block_desc.GetElementSpaceSize()); + b_block_desc_k0_n0_n1_k1.GetElementSpaceSize());
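The main loop below consumes two K0PerBlock slices of the reduction dimension per trip, so the kernel's shape is decided up front by the CalculateHasMainKBlockLoop and CalculateHasDoubleTailKBlockLoop predicates defined earlier in this file. A standalone arithmetic check of those two predicates (illustration only; K0PerBlock = 16 matches the DL instances added by this patch):

#include <cstdio>

int main()
{
    const int K0PerBlock = 16;
    for(const int K0 : {16, 32, 48, 64})
    {
        // main loop runs only when more than two K0PerBlock slices are available
        const bool has_main_loop   = (K0 + K0PerBlock) / (2 * K0PerBlock) > 1;
        // an even number of K0 slices leaves a two-iteration (double) tail
        const bool has_double_tail = (K0 / K0PerBlock) % 2 == 0;
        std::printf("K0=%2d -> main=%d double_tail=%d\n", K0, has_main_loop, has_double_tail);
    }
    // prints: K0=16 -> 0/0, K0=32 -> 0/1, K0=48 -> 1/0, K0=64 -> 1/1
    return 0;
}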
// LDS double buffer: preload data into LDS { - a_blockwise_copy.RunRead(a_k0_m0_m1_k1_grid_desc, a_global_buf, AGridStepHacks{}); - b_blockwise_copy.RunRead(b_k0_n0_n1_k1_grid_desc, b_global_buf, BGridStepHacks{}); + a_blockwise_copy.RunRead(a_grid_desc_k0_m0_m1_k1, a_global_buf); + b_blockwise_copy.RunRead(b_grid_desc_k0_n0_n1_k1, b_global_buf); - a_blockwise_copy.RunWrite(a_k0_m0_m1_k1_block_desc, a_block_even_buf); - b_blockwise_copy.RunWrite(b_k0_n0_n1_k1_block_desc, b_block_even_buf); + a_blockwise_copy.RunWrite(a_block_desc_k0_m0_m1_k1, a_block_even_buf); + b_blockwise_copy.RunWrite(b_block_desc_k0_n0_n1_k1, b_block_even_buf); } if constexpr(HasMainKBlockLoop) { - const auto K0 = a_k0_m0_m1_k1_grid_desc.GetLength(I0); + const auto K0 = a_grid_desc_k0_m0_m1_k1.GetLength(I0); index_t k_block_data_begin = 0; @@ -455,82 +440,76 @@ struct GridwiseGemmDlops_km_kn_mn_v1r3 do { // even iteration - a_blockwise_copy.MoveSrcSliceWindow(a_k0_m0_m1_k1_grid_desc, - a_block_slice_copy_step, - AGridMoveSliceWindowStepHacks{}); - b_blockwise_copy.MoveSrcSliceWindow(b_k0_n0_n1_k1_grid_desc, - b_block_slice_copy_step, - BGridMoveSliceWindowStepHacks{}); - - __syncthreads(); + a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc_k0_m0_m1_k1, + a_block_slice_copy_step); + b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc_k0_n0_n1_k1, + b_block_slice_copy_step); // LDS double buffer: load next data from device mem - a_blockwise_copy.RunRead(a_k0_m0_m1_k1_grid_desc, a_global_buf, AGridStepHacks{}); - b_blockwise_copy.RunRead(b_k0_n0_n1_k1_grid_desc, b_global_buf, BGridStepHacks{}); + a_blockwise_copy.RunRead(a_grid_desc_k0_m0_m1_k1, a_global_buf); + b_blockwise_copy.RunRead(b_grid_desc_k0_n0_n1_k1, b_global_buf); + + block_sync_lds(); // LDS double buffer: GEMM on current data - blockwise_gemm.Run(c_m10_m11_n10_n11_thread_desc, + blockwise_gemm.Run(c_thread_desc_m10_m11_n10_n11, a_block_even_buf, b_block_even_buf, c_thread_buf); // LDS double buffer: store next data to LDS - a_blockwise_copy.RunWrite(a_k0_m0_m1_k1_block_desc, a_block_odd_buf); - b_blockwise_copy.RunWrite(b_k0_n0_n1_k1_block_desc, b_block_odd_buf); + a_blockwise_copy.RunWrite(a_block_desc_k0_m0_m1_k1, a_block_odd_buf); + b_blockwise_copy.RunWrite(b_block_desc_k0_n0_n1_k1, b_block_odd_buf); // odd iteration - a_blockwise_copy.MoveSrcSliceWindow(a_k0_m0_m1_k1_grid_desc, - a_block_slice_copy_step, - AGridMoveSliceWindowStepHacks{}); - b_blockwise_copy.MoveSrcSliceWindow(b_k0_n0_n1_k1_grid_desc, - b_block_slice_copy_step, - BGridMoveSliceWindowStepHacks{}); - - __syncthreads(); + a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc_k0_m0_m1_k1, + a_block_slice_copy_step); + b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc_k0_n0_n1_k1, + b_block_slice_copy_step); // LDS double buffer: load next data from device mem - a_blockwise_copy.RunRead(a_k0_m0_m1_k1_grid_desc, a_global_buf, AGridStepHacks{}); - b_blockwise_copy.RunRead(b_k0_n0_n1_k1_grid_desc, b_global_buf, BGridStepHacks{}); + a_blockwise_copy.RunRead(a_grid_desc_k0_m0_m1_k1, a_global_buf); + b_blockwise_copy.RunRead(b_grid_desc_k0_n0_n1_k1, b_global_buf); + + block_sync_lds(); // LDS double buffer: GEMM on current data blockwise_gemm.Run( - c_m10_m11_n10_n11_thread_desc, a_block_odd_buf, b_block_odd_buf, c_thread_buf); + c_thread_desc_m10_m11_n10_n11, a_block_odd_buf, b_block_odd_buf, c_thread_buf); // LDS double buffer: store next
data to LDS - a_blockwise_copy.RunWrite(a_k0_m0_m1_k1_block_desc, a_block_even_buf); - b_blockwise_copy.RunWrite(b_k0_n0_n1_k1_block_desc, b_block_even_buf); + a_blockwise_copy.RunWrite(a_block_desc_k0_m0_m1_k1, a_block_even_buf); + b_blockwise_copy.RunWrite(b_block_desc_k0_n0_n1_k1, b_block_even_buf); - k_block_data_begin += 2 * KPerBlock; - } while(k_block_data_begin < K0 - 2 * KPerBlock); + k_block_data_begin += 2 * K0PerBlock; + } while(k_block_data_begin < K0 - 2 * K0PerBlock); } // LDS double buffer: tail if constexpr(HasDoubleTailKBlockLoop) // if has 2 iterations left { - a_blockwise_copy.MoveSrcSliceWindow( - a_k0_m0_m1_k1_grid_desc, a_block_slice_copy_step, AGridMoveSliceWindowStepHacks{}); - b_blockwise_copy.MoveSrcSliceWindow( - b_k0_n0_n1_k1_grid_desc, b_block_slice_copy_step, BGridMoveSliceWindowStepHacks{}); + a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc_k0_m0_m1_k1, a_block_slice_copy_step); + b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc_k0_n0_n1_k1, b_block_slice_copy_step); - __syncthreads(); + block_sync_lds(); // LDS double buffer: load last data from device mem - a_blockwise_copy.RunRead(a_k0_m0_m1_k1_grid_desc, a_global_buf, AGridStepHacks{}); - b_blockwise_copy.RunRead(b_k0_n0_n1_k1_grid_desc, b_global_buf, BGridStepHacks{}); + a_blockwise_copy.RunRead(a_grid_desc_k0_m0_m1_k1, a_global_buf); + b_blockwise_copy.RunRead(b_grid_desc_k0_n0_n1_k1, b_global_buf); // LDS double buffer: GEMM on 2nd-last data blockwise_gemm.Run( - c_m10_m11_n10_n11_thread_desc, a_block_even_buf, b_block_even_buf, c_thread_buf); + c_thread_desc_m10_m11_n10_n11, a_block_even_buf, b_block_even_buf, c_thread_buf); // LDS double buffer: store last data to LDS - a_blockwise_copy.RunWrite(a_k0_m0_m1_k1_block_desc, a_block_odd_buf); - b_blockwise_copy.RunWrite(b_k0_n0_n1_k1_block_desc, b_block_odd_buf); + a_blockwise_copy.RunWrite(a_block_desc_k0_m0_m1_k1, a_block_odd_buf); + b_blockwise_copy.RunWrite(b_block_desc_k0_n0_n1_k1, b_block_odd_buf); - __syncthreads(); + block_sync_lds(); // LDS double buffer: GEMM on last data blockwise_gemm.Run( - c_m10_m11_n10_n11_thread_desc, a_block_odd_buf, b_block_odd_buf, c_thread_buf); + c_thread_desc_m10_m11_n10_n11, a_block_odd_buf, b_block_odd_buf, c_thread_buf); } else // if has 1 iteration left { @@ -538,12 +517,12 @@ struct GridwiseGemmDlops_km_kn_mn_v1r3 // LDS double buffer: GEMM on last data blockwise_gemm.Run( - c_m10_m11_n10_n11_thread_desc, a_block_even_buf, b_block_even_buf, c_thread_buf); + c_thread_desc_m10_m11_n10_n11, a_block_even_buf, b_block_even_buf, c_thread_buf); } // output: register to global memory { - constexpr auto c_m0_m10_m11_n0_n10_n11_thread_desc = + constexpr auto c_thread_desc_m0_m10_m11_n0_n10_n11 = make_naive_tensor_descriptor_packed( make_tuple(I1, Number{}, @@ -559,8 +538,9 @@ struct GridwiseGemmDlops_km_kn_mn_v1r3 ThreadwiseTensorSliceTransfer_v1r3< FloatAcc, FloatC, - decltype(c_m0_m10_m11_n0_n10_n11_thread_desc), - decltype(c_m0_m10_m11_n0_n10_n11_grid_desc), + decltype(c_thread_desc_m0_m10_m11_n0_n10_n11), + decltype(c_grid_desc_m0_m10_m11_n0_n10_n11), + ck::tensor_operation::element_wise::PassThrough, Sequence<1, c_m10_m11_n10_n11_thread_tensor_lengths[I0], c_m10_m11_n10_n11_thread_tensor_lengths[I1], @@ -572,22 +552,21 @@ struct GridwiseGemmDlops_km_kn_mn_v1r3 CThreadTransferDstScalarPerVector, CGlobalMemoryDataOperation, 1, - true>{c_m0_m10_m11_n0_n10_n11_grid_desc, + true>{c_grid_desc_m0_m10_m11_n0_n10_n11, make_multi_index(im0, c_m10_m11_n10_n11_thread_origin_idx_on_block[I0], c_m10_m11_n10_n11_thread_origin_idx_on_block[I1], in0, c_m10_m11_n10_n11_thread_origin_idx_on_block[I2], - c_m10_m11_n10_n11_thread_origin_idx_on_block[I3])} - .Run(c_m0_m10_m11_n0_n10_n11_thread_desc, + c_m10_m11_n10_n11_thread_origin_idx_on_block[I3]), + ck::tensor_operation::element_wise::PassThrough{}} + .Run(c_thread_desc_m0_m10_m11_n0_n10_n11, make_tuple(I0, I0, I0, I0, I0, I0), c_thread_buf, - c_m0_m10_m11_n0_n10_n11_grid_desc, - c_grid_buf, - CGridStepHacks{}); + c_grid_desc_m0_m10_m11_n0_n10_n11, + c_grid_buf); } } }; } // namespace ck -#endif
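One behavioral note on the write-out hunk above: ThreadwiseTensorSliceTransfer_v1r3 now takes an element-wise operation as an extra template and constructor argument, applied to each accumulator value as it is stored to C. This DL kernel always passes PassThrough, which behaves roughly as the sketch below (illustrative only; the struct name here is hypothetical and the library's actual definition lives in element_wise_operation.hpp):

struct PassThroughSketch
{
    template <typename Y, typename X>
    __host__ __device__ constexpr void operator()(Y& y, const X& x) const
    {
        // identity apart from the FloatAcc -> FloatC type conversion
        y = static_cast<Y>(x);
    }
};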
diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_v2r3.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_v2r3.hpp index bfa93e5866..d60f8c4d07 100644 --- a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_v2r3.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_v2r3.hpp @@ -1,4 +1,5 @@ #pragma once + #include "common_header.hpp" #include "multi_index_transform_helper.hpp" #include "tensor_descriptor.hpp" diff --git a/include/ck/tensor_operation/gpu/thread/threadwise_contraction_dlops.hpp b/include/ck/tensor_operation/gpu/thread/threadwise_contraction_dl.hpp similarity index 96% rename from include/ck/tensor_operation/gpu/thread/threadwise_contraction_dlops.hpp rename to include/ck/tensor_operation/gpu/thread/threadwise_contraction_dl.hpp index 8b75381026..6a532c79f9 100644 --- a/include/ck/tensor_operation/gpu/thread/threadwise_contraction_dlops.hpp +++ b/include/ck/tensor_operation/gpu/thread/threadwise_contraction_dl.hpp @@ -1,6 +1,4 @@ -#ifndef CK_THREADWISE_CONTRACTION_DLOPS_HPP -#define CK_THREADWISE_CONTRACTION_DLOPS_HPP - +#pragma once #include "common_header.hpp" #include "math.hpp" @@ -25,9 +23,9 @@ template ::type = false> -struct ThreadwiseGemmDlops_km0m1_kn0n1_m0m1n0n1 +struct ThreadwiseGemmDl_km0m1_kn0n1_m0m1n0n1 { - __device__ constexpr ThreadwiseGemmDlops_km0m1_kn0n1_m0m1n0n1() + __device__ constexpr ThreadwiseGemmDl_km0m1_kn0n1_m0m1n0n1() { static_assert(AThreadDesc_TK0_TM0_TM1_TK1::IsKnownAtCompileTime() && BThreadDesc_TK0_TN0_TN1_TK1::IsKnownAtCompileTime() && @@ -124,9 +122,9 @@ template ::type = false> -struct ThreadwiseContractionDlops_A_TK0_TM0_TM1_TK1_B_TK0_TN0_TN1_TK1_C_TM0_TM1_TN0_TN1 +struct ThreadwiseContractionDl_A_TK0_TM0_TM1_TK1_B_TK0_TN0_TN1_TK1_C_TM0_TM1_TN0_TN1 { - __device__ constexpr ThreadwiseContractionDlops_A_TK0_TM0_TM1_TK1_B_TK0_TN0_TN1_TK1_C_TM0_TM1_TN0_TN1() + __device__ constexpr ThreadwiseContractionDl_A_TK0_TM0_TM1_TK1_B_TK0_TN0_TN1_TK1_C_TM0_TM1_TN0_TN1() { static_assert(AThreadDesc_TK0_TM0_TM1_TK1::IsKnownAtCompileTime() && BThreadDesc_TK0_TN0_TN1_TK1::IsKnownAtCompileTime() && @@ -220,4 +218,3 @@ struct ThreadwiseContractionDlops_A_TK0_TM0_TM1_TK1_B_TK0_TN0_TN1_TK1_C_TM0_TM1_ }; } // namespace ck -#endif diff --git a/include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer_v5r1.hpp b/include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer_v5r1.hpp index 48338ddfa6..f0e9c7e761 100644 --- a/include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer_v5r1.hpp +++ b/include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer_v5r1.hpp @@ -1,5 +1,4 @@ -#ifndef CK_THREADWISE_TENSOR_SLICE_TRANSFER_V5R1_HPP -#define CK_THREADWISE_TENSOR_SLICE_TRANSFER_V5R1_HPP +#pragma once #include "common_header.hpp" #include "tensor_descriptor.hpp" @@ -609,4 +608,3 @@ struct ThreadwiseTensorSliceTransfer_v5r1 }; } // namespace ck -#endif diff --git a/include/ck/utility/inner_product.hpp
b/include/ck/utility/inner_product.hpp index 3071e45640..59fe17e867 100644 --- a/include/ck/utility/inner_product.hpp +++ b/include/ck/utility/inner_product.hpp @@ -1,6 +1,4 @@ -#ifndef CK_INNER_PRODUCT_HPP -#define CK_INNER_PRODUCT_HPP - +#pragma once #include "data_type.hpp" namespace ck { @@ -138,7 +136,7 @@ template <> __device__ void inner_product(const int8x4_t& a, const int8x4_t& b, int32_t& c) { -#if defined(CK_USE_DOT4_I32_I8) +#if defined(CK_USE_AMD_V_DOT4_I32_I8) #if CK_USE_AMD_INNER_PRODUCT_INLINE_ASM asm volatile("\n \ v_dot4_i32_i8 %0, %1, %2, %0\n \ @@ -202,4 +200,3 @@ inner_product(const int8x16_t& a, const int8x16_t } } // namespace ck -#endif diff --git a/include/ck/utility/static_buffer.hpp b/include/ck/utility/static_buffer.hpp index 1a59f3c81e..ef177e9697 100644 --- a/include/ck/utility/static_buffer.hpp +++ b/include/ck/utility/static_buffer.hpp @@ -36,6 +36,11 @@ struct StaticBuffer : public StaticallyIndexedArray { return base::operator()(i); } + + __host__ __device__ void Clear() + { + static_for<0, N, 1>{}([&](auto i) { operator()(i) = T{0}; }); + } }; // static buffer for vector @@ -146,9 +151,9 @@ struct StaticBufferTupleOfVector __host__ __device__ void Clear() { - const index_t numScalars = NumOfVector * ScalarPerVector; + constexpr index_t NumScalars = NumOfVector * ScalarPerVector; - static_for<0, Number{}, 1>{}([&](auto i) { SetAsType(i, S{0}); }); + static_for<0, NumScalars, 1>{}([&](auto i) { SetAsType(i, S{0}); }); } }; diff --git a/library/include/ck/library/utility/check_err.hpp b/library/include/ck/library/utility/check_err.hpp index 280ac83883..7cd6cc34c9 100644 --- a/library/include/ck/library/utility/check_err.hpp +++ b/library/include/ck/library/utility/check_err.hpp @@ -24,7 +24,7 @@ check_err(const std::vector& out, const std::vector& ref, const std::string& msg = "Error: Incorrect results!", double rtol = 1e-5, - double atol = 1e-8) + double atol = 3e-6) { if(out.size() != ref.size()) { @@ -173,8 +173,8 @@ check_err(const std::vector& out, { if(out[i] != ref[i]) { - std::cout << "out[" << i << "] != ref[" << i << "]: " << out[i] << " != " << ref[i] - << std::endl + std::cout << "out[" << i << "] != ref[" << i << "]: " << static_cast(out[i]) + << " != " << static_cast(ref[i]) << std::endl << msg << std::endl; return false; } diff --git a/library/src/tensor_operation_instance/gpu/CMakeLists.txt b/library/src/tensor_operation_instance/gpu/CMakeLists.txt index 5abfb0c074..b20a4b57e5 100644 --- a/library/src/tensor_operation_instance/gpu/CMakeLists.txt +++ b/library/src/tensor_operation_instance/gpu/CMakeLists.txt @@ -1,6 +1,7 @@ include_directories(BEFORE ${PROJECT_SOURCE_DIR}/include/ck ${PROJECT_SOURCE_DIR}/include/ck/utility + ${PROJECT_SOURCE_DIR}/include/ck/host_utility ${PROJECT_SOURCE_DIR}/include/ck/tensor_description ${PROJECT_SOURCE_DIR}/include/ck/tensor ${PROJECT_SOURCE_DIR}/include/ck/problem_transform diff --git a/library/src/tensor_operation_instance/gpu/gemm/CMakeLists.txt b/library/src/tensor_operation_instance/gpu/gemm/CMakeLists.txt index 556b06d7e1..da769a5626 100644 --- a/library/src/tensor_operation_instance/gpu/gemm/CMakeLists.txt +++ b/library/src/tensor_operation_instance/gpu/gemm/CMakeLists.txt @@ -1,4 +1,3 @@ -# device_gemm_instance set(DEVICE_GEMM_INSTANCE_SOURCE device_gemm_xdl_f32_f32_f32_mk_kn_mn_instance.cpp; device_gemm_xdl_f32_f32_f32_mk_nk_mn_instance.cpp; @@ -8,10 +7,10 @@ set(DEVICE_GEMM_INSTANCE_SOURCE device_gemm_xdl_f16_f16_f16_mk_nk_mn_instance.cpp; device_gemm_xdl_f16_f16_f16_km_kn_mn_instance.cpp; 
device_gemm_xdl_f16_f16_f16_km_nk_mn_instance.cpp; - device_gemm_xdl_c_shuffle_int8_int8_int8_mk_kn_mn_instance.cpp; - device_gemm_xdl_c_shuffle_int8_int8_int8_mk_nk_mn_instance.cpp; - device_gemm_xdl_c_shuffle_int8_int8_int8_km_kn_mn_instance.cpp; - device_gemm_xdl_c_shuffle_int8_int8_int8_km_nk_mn_instance.cpp; + device_gemm_xdl_c_shuffle_i8_i8_i8_mk_kn_mn_instance.cpp; + device_gemm_xdl_c_shuffle_i8_i8_i8_mk_nk_mn_instance.cpp; + device_gemm_xdl_c_shuffle_i8_i8_i8_km_kn_mn_instance.cpp; + device_gemm_xdl_c_shuffle_i8_i8_i8_km_nk_mn_instance.cpp; device_gemm_xdl_c_shuffle_bf16_bf16_bf16_mk_kn_mn_instance.cpp; device_gemm_xdl_c_shuffle_bf16_bf16_bf16_mk_nk_mn_instance.cpp; device_gemm_xdl_c_shuffle_bf16_bf16_bf16_km_kn_mn_instance.cpp; @@ -33,11 +32,21 @@ set(DEVICE_GEMM_INSTANCE_SOURCE device_gemm_xdl_splitk_f16_f16_f16_mk_nk_mn_instance.cpp; device_gemm_xdl_splitk_f16_f16_f16_km_kn_mn_instance.cpp; device_gemm_xdl_splitk_f16_f16_f16_km_nk_mn_instance.cpp; + device_gemm_dl_f32_f32_f32_mk_kn_mn_instance.cpp; + device_gemm_dl_f32_f32_f32_mk_nk_mn_instance.cpp; + device_gemm_dl_f32_f32_f32_km_kn_mn_instance.cpp; + device_gemm_dl_f32_f32_f32_km_nk_mn_instance.cpp; + device_gemm_dl_f16_f16_f16_mk_kn_mn_instance.cpp; + device_gemm_dl_f16_f16_f16_mk_nk_mn_instance.cpp; + device_gemm_dl_f16_f16_f16_km_kn_mn_instance.cpp; + device_gemm_dl_f16_f16_f16_km_nk_mn_instance.cpp; + device_gemm_dl_i8_i8_i8_mk_kn_mn_instance.cpp; + device_gemm_dl_i8_i8_i8_mk_nk_mn_instance.cpp; + device_gemm_dl_i8_i8_i8_km_kn_mn_instance.cpp; + device_gemm_dl_i8_i8_i8_km_nk_mn_instance.cpp; ) add_library(device_gemm_instance OBJECT ${DEVICE_GEMM_INSTANCE_SOURCE}) target_compile_features(device_gemm_instance PUBLIC) set_target_properties(device_gemm_instance PROPERTIES POSITION_INDEPENDENT_CODE ON) - -clang_tidy_check(device_gemm_instance) diff --git a/library/src/tensor_operation_instance/gpu/gemm/device_gemm_dl_f16_f16_f16_km_kn_mn_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm/device_gemm_dl_f16_f16_f16_km_kn_mn_instance.cpp new file mode 100644 index 0000000000..db7f6af04b --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/gemm/device_gemm_dl_f16_f16_f16_km_kn_mn_instance.cpp @@ -0,0 +1,45 @@ +#include +#include "config.hpp" +#include "device_gemm_dl.hpp" +#include "element_wise_operation.hpp" +#include "device_operation_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace device_gemm_instance { + +using F16 = ck::half_t; +using F32 = float; + +using Row = ck::tensor_layout::gemm::RowMajor; +using Col = ck::tensor_layout::gemm::ColumnMajor; + +template +using S = ck::Sequence; + +using PassThrough = ck::tensor_operation::element_wise::PassThrough; + +static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default; + +// Compilation parameters for a[k, m] * b[k, n] = c[m, n] +using device_gemm_dl_f16_f16_f16_km_kn_mn_instances = std::tuple< + // clang-format off + // #########| AData| BData| CData| AccData| ALayout| BLayout| CLayout| A| B| C| GEMM| Block| MPer| NPer| K0Per| K1| M1Per| N1Per| KPer| M11N11Thread| M11N11Thread| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| CThreadTransfer| CThreadTransfer| CThreadTransfer| + // #########| Type| Type| Type| Type| | | | Elementwise| Elementwise| Elementwise| Spacialization| Size| Block| Block| Block| | 
ThreadM111| ThreadN111| Thread| ClusterM110Xs| ClusterN110Xs| ThreadSliceLengths| ThreadClusterLengths| ThreadCluster| SrcAccess| SrcVectorTensor| SrcVectorTensor| DstVectorTensor| ThreadSliceLengths| ThreadClusterLengths| ThreadCluster| SrcAccess| SrcVectorTensor| SrcVectorTensor| DstVectorTensor| SrcDstAccess| SrcDstVectorDim| DstScalarPerVector| + // #########| | | | | | | | Operation| Operation| Operation| | | | | | | | | | | | K0_M0_M1_K1| K0_M0_M1_K1| ArrangeOrder| Order| Lengths_K0_M0_M1_K1| ContiguousDimOrder| Lengths_K0_M0_M1_K1| K0_M0_M1_K1| K0_M0_M1_K1| ArrangeOrder| Order| Lengths_K0_M0_M1_K1| ContiguousDimOrder| Lengths_K0_M0_M1_K1| Order| | | + // #########| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceGemmDl< F16, F16, F16, F32, Col, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 128, 128, 16, 2, 4, 4, 1, S<8, 2>, S<8, 2>, S<2, 1, 4, 2>, S<8, 1, 32, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 2>, S<2, 1, 4, 2>, S<8, 1, 32, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 2>, S<0, 1, 2, 3, 4, 5>, 5, 4> + // clang-format on + >; + +void add_device_gemm_dl_f16_f16_f16_km_kn_mn_instances( + std::vector>& instances) +{ + add_device_operation_instances(instances, device_gemm_dl_f16_f16_f16_km_kn_mn_instances{}); +} + +} // namespace device_gemm_instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/gemm/device_gemm_dl_f16_f16_f16_km_nk_mn_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm/device_gemm_dl_f16_f16_f16_km_nk_mn_instance.cpp new file mode 100644 index 0000000000..c4253bcc4c --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/gemm/device_gemm_dl_f16_f16_f16_km_nk_mn_instance.cpp @@ -0,0 +1,45 @@ +#include +#include "config.hpp" +#include "device_gemm_dl.hpp" +#include "element_wise_operation.hpp" +#include "device_operation_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace device_gemm_instance { + +using F16 = ck::half_t; +using F32 = float; + +using Row = ck::tensor_layout::gemm::RowMajor; +using Col = ck::tensor_layout::gemm::ColumnMajor; + +template +using S = ck::Sequence; + +using PassThrough = ck::tensor_operation::element_wise::PassThrough; + +static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default; + +// Compilation parameters for a[k, m] * b[n, k] = c[m, n] +using device_gemm_dl_f16_f16_f16_km_nk_mn_instances = std::tuple< + // clang-format off + // #########| AData| BData| CData| AccData| ALayout| BLayout| CLayout| A| B| C| GEMM| Block| MPer| NPer| K0Per| K1| M1Per| N1Per| KPer| M11N11Thread| M11N11Thread| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| CThreadTransfer| CThreadTransfer| CThreadTransfer| + // #########| Type| Type| Type| Type| | | | Elementwise| Elementwise| Elementwise| Spacialization| Size| Block| Block| Block| | ThreadM111| ThreadN111| Thread| ClusterM110Xs| ClusterN110Xs| ThreadSliceLengths| ThreadClusterLengths| ThreadCluster| SrcAccess| SrcVectorTensor| SrcVectorTensor| DstVectorTensor| ThreadSliceLengths| ThreadClusterLengths| ThreadCluster| SrcAccess| SrcVectorTensor| SrcVectorTensor| DstVectorTensor| SrcDstAccess| SrcDstVectorDim| DstScalarPerVector| + // #########| 
| | | | | | | Operation| Operation| Operation| | | | | | | | | | | | K0_M0_M1_K1| K0_M0_M1_K1| ArrangeOrder| Order| Lengths_K0_M0_M1_K1| ContiguousDimOrder| Lengths_K0_M0_M1_K1| K0_N0_N1_K1| K0_N0_N1_K1| ArrangeOrder| Order| Lengths_K0_N0_N1_K1| ContiguousDimOrder| Lengths_K0_N0_N1_K1| Order| | | + // #########| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceGemmDl< F16, F16, F16, F32, Col, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 128, 128, 16, 2, 4, 4, 1, S<8, 2>, S<8, 2>, S<2, 1, 4, 2>, S<8, 1, 32, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 2>, S<8, 1, 1, 2>, S<2, 1, 128, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<0, 1, 2, 3, 4, 5>, 5, 4> + // clang-format on + >; + +void add_device_gemm_dl_f16_f16_f16_km_nk_mn_instances( + std::vector>& instances) +{ + add_device_operation_instances(instances, device_gemm_dl_f16_f16_f16_km_nk_mn_instances{}); +} + +} // namespace device_gemm_instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/gemm/device_gemm_dl_f16_f16_f16_mk_kn_mn_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm/device_gemm_dl_f16_f16_f16_mk_kn_mn_instance.cpp new file mode 100644 index 0000000000..d19d11f1f8 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/gemm/device_gemm_dl_f16_f16_f16_mk_kn_mn_instance.cpp @@ -0,0 +1,45 @@ +#include +#include "config.hpp" +#include "device_gemm_dl.hpp" +#include "element_wise_operation.hpp" +#include "device_operation_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace device_gemm_instance { + +using F16 = ck::half_t; +using F32 = float; + +using Row = ck::tensor_layout::gemm::RowMajor; +using Col = ck::tensor_layout::gemm::ColumnMajor; + +template +using S = ck::Sequence; + +using PassThrough = ck::tensor_operation::element_wise::PassThrough; + +static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default; + +// Compilation parameters for a[m, k] * b[k, n] = c[m, n] +using device_gemm_dl_f16_f16_f16_mk_kn_mn_instances = std::tuple< + // clang-format off + // #########| AData| BData| CData| AccData| ALayout| BLayout| CLayout| A| B| C| GEMM| Block| MPer| NPer| K0Per| K1| M1Per| N1Per| KPer| M11N11Thread| M11N11Thread| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| CThreadTransfer| CThreadTransfer| CThreadTransfer| + // #########| Type| Type| Type| Type| | | | Elementwise| Elementwise| Elementwise| Spacialization| Size| Block| Block| Block| | ThreadM111| ThreadN111| Thread| ClusterM110Xs| ClusterN110Xs| ThreadSliceLengths| ThreadClusterLengths| ThreadCluster| SrcAccess| SrcVectorTensor| SrcVectorTensor| DstVectorTensor| ThreadSliceLengths| ThreadClusterLengths| ThreadCluster| SrcAccess| SrcVectorTensor| SrcVectorTensor| DstVectorTensor| SrcDstAccess| SrcDstVectorDim| DstScalarPerVector| + // #########| | | | | | | | Operation| Operation| Operation| | | | | | | | | | | | K0_M0_M1_K1| K0_M0_M1_K1| ArrangeOrder| Order| Lengths_K0_M0_M1_K1| ContiguousDimOrder| Lengths_K0_M0_M1_K1| K0_N0_N1_K1| K0_N0_N1_K1| ArrangeOrder| Order| Lengths_K0_N0_N1_K1| ContiguousDimOrder| Lengths_K0_N0_N1_K1| Order| | | + // #########| | | | | | | | | | | | | | | | | | | | | | | | | | | 
| | | | | | | | | | | | + DeviceGemmDl< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 128, 128, 16, 2, 4, 4, 1, S<8, 2>, S<8, 2>, S<8, 1, 1, 2>, S<2, 1, 128, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<2, 1, 4, 2>, S<8, 1, 32, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 2>, S<0, 1, 2, 3, 4, 5>, 5, 4> + // clang-format on + >; + +void add_device_gemm_dl_f16_f16_f16_mk_kn_mn_instances( + std::vector>& instances) +{ + add_device_operation_instances(instances, device_gemm_dl_f16_f16_f16_mk_kn_mn_instances{}); +} + +} // namespace device_gemm_instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/gemm/device_gemm_dl_f16_f16_f16_mk_nk_mn_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm/device_gemm_dl_f16_f16_f16_mk_nk_mn_instance.cpp new file mode 100644 index 0000000000..cd86e5ceae --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/gemm/device_gemm_dl_f16_f16_f16_mk_nk_mn_instance.cpp @@ -0,0 +1,46 @@ +#include +#include "config.hpp" +#include "device_gemm_dl.hpp" +#include "element_wise_operation.hpp" +#include "device_operation_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace device_gemm_instance { + +using F16 = ck::half_t; +using F32 = float; + +using Row = ck::tensor_layout::gemm::RowMajor; +using Col = ck::tensor_layout::gemm::ColumnMajor; + +template +using S = ck::Sequence; + +using PassThrough = ck::tensor_operation::element_wise::PassThrough; + +static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default; + +// Compilation parameters for a[m, k] * b[n, k] = c[m, n] +using device_gemm_dl_f16_f16_f16_mk_nk_mn_instances = + std::tuple< + // clang-format off + // ########| AData| BData| CData| AccData| ALayout| BLayout| CLayout| A| B| C| GEMM| Block| MPer| NPer| K0Per| K1| M1Per| N1Per| KPer| M11N11Thread| M11N11Thread| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| CThreadTransfer| CThreadTransfer| CThreadTransfer| + // ########| Type| Type| Type| Type| | | | Elementwise| Elementwise| Elementwise| Spacialization| Size| Block| Block| Block| | ThreadM111| ThreadN111| Thread| ClusterM110Xs| ClusterN110Xs| ThreadSliceLengths| ThreadClusterLengths| ThreadCluster| SrcAccess| SrcVectorTensor| SrcVectorTensor| DstVectorTensor| ThreadSliceLengths| ThreadClusterLengths| ThreadCluster| SrcAccess| SrcVectorTensor| SrcVectorTensor| DstVectorTensor| SrcDstAccess| SrcDstVectorDim| DstScalarPerVector| + // ########| | | | | | | | Operation| Operation| Operation| | | | | | | | | | | | K0_M0_M1_K1| K0_M0_M1_K1| ArrangeOrder| Order| Lengths_K0_M0_M1_K1| ContiguousDimOrder| Lengths_K0_M0_M1_K1| K0_N0_N1_K1| K0_N0_N1_K1| ArrangeOrder| Order| Lengths_K0_N0_N1_K1| ContiguousDimOrder| Lengths_K0_N0_N1_K1| Order| | | + // ########| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceGemmDl< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 128, 128, 16, 2, 4, 4, 1, S<8, 2>, S<8, 2>, S<8, 1, 1, 2>, S<2, 1, 128, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<8, 1, 1, 2>, S<2, 1, 128, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 
0, 3>, S<1, 1, 1, 2>, S<0, 1, 2, 3, 4, 5>, 5, 4> + // clang-format on + >; + +void add_device_gemm_dl_f16_f16_f16_mk_nk_mn_instances( + std::vector>& instances) +{ + add_device_operation_instances(instances, device_gemm_dl_f16_f16_f16_mk_nk_mn_instances{}); +} + +} // namespace device_gemm_instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/gemm/device_gemm_dl_f32_f32_f32_km_kn_mn_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm/device_gemm_dl_f32_f32_f32_km_kn_mn_instance.cpp new file mode 100644 index 0000000000..3fcc5fdfdc --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/gemm/device_gemm_dl_f32_f32_f32_km_kn_mn_instance.cpp @@ -0,0 +1,45 @@ +#include +#include "config.hpp" +#include "device_gemm_dl.hpp" +#include "element_wise_operation.hpp" +#include "device_operation_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace device_gemm_instance { + +using F16 = ck::half_t; +using F32 = float; + +using Row = ck::tensor_layout::gemm::RowMajor; +using Col = ck::tensor_layout::gemm::ColumnMajor; + +template +using S = ck::Sequence; + +using PassThrough = ck::tensor_operation::element_wise::PassThrough; + +static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default; + +// Compilation parameters for a[k, m] * b[k, n] = c[m, n] +using device_gemm_dl_f32_f32_f32_km_kn_mn_instances = std::tuple< + // clang-format off + // ########| AData| BData| CData| AccData| ALayout| BLayout| CLayout| A| B| C| GEMM| Block| MPer| NPer| K0Per| K1| M1Per| N1Per| KPer| M11N11Thread| M11N11Thread| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| CThreadTransfer| CThreadTransfer| CThreadTransfer| + // ########| Type| Type| Type| Type| | | | Elementwise| Elementwise| Elementwise| Spacialization| Size| Block| Block| Block| | ThreadM111| ThreadN111| Thread| ClusterM110Xs| ClusterN110Xs| ThreadSliceLengths| ThreadClusterLengths| ThreadCluster| SrcAccess| SrcVectorTensor| SrcVectorTensor| DstVectorTensor| ThreadSliceLengths| ThreadClusterLengths| ThreadCluster| SrcAccess| SrcVectorTensor| SrcVectorTensor| DstVectorTensor| SrcDstAccess| SrcDstVectorDim| DstScalarPerVector| + // ########| | | | | | | | Operation| Operation| Operation| | | | | | | | | | | | K0_M0_M1_K1| K0_M0_M1_K1| ArrangeOrder| Order| Lengths_K0_M0_M1_K1| ContiguousDimOrder| Lengths_K0_M0_M1_K1| K0_N0_N1_K1| K0_N0_N1_K1| ArrangeOrder| Order| Lengths_K0_N0_N1_K1| ContiguousDimOrder| Lengths_K0_N0_N1_K1| Order| | | + // ########| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceGemmDl< F32, F32, F32, F32, Col, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 128, 128, 16, 1, 4, 4, 1, S<8, 2>, S<8, 2>, S<2, 1, 4, 1>, S<8, 1, 32, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<2, 1, 4, 1>, S<8, 1, 32, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 1, 2, 3, 4, 5>, 5, 4> + // clang-format on + >; + +void add_device_gemm_dl_f32_f32_f32_km_kn_mn_instances( + std::vector>& instances) +{ + add_device_operation_instances(instances, device_gemm_dl_f32_f32_f32_km_kn_mn_instances{}); +} + +} // namespace device_gemm_instance +} // namespace device +} // namespace tensor_operation +} // namespace 
ck diff --git a/library/src/tensor_operation_instance/gpu/gemm/device_gemm_dl_f32_f32_f32_km_nk_mn_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm/device_gemm_dl_f32_f32_f32_km_nk_mn_instance.cpp new file mode 100644 index 0000000000..8cd32128b5 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/gemm/device_gemm_dl_f32_f32_f32_km_nk_mn_instance.cpp @@ -0,0 +1,46 @@ +#include +#include "config.hpp" +#include "device_gemm_dl.hpp" +#include "element_wise_operation.hpp" +#include "device_operation_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace device_gemm_instance { + +using F16 = ck::half_t; +using F32 = float; + +using Row = ck::tensor_layout::gemm::RowMajor; +using Col = ck::tensor_layout::gemm::ColumnMajor; + +template +using S = ck::Sequence; + +using PassThrough = ck::tensor_operation::element_wise::PassThrough; + +static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default; + +// Compilation parameters for a[k, m] * b[n, k] = c[m, n] +using device_gemm_dl_f32_f32_f32_km_nk_mn_instances = + std::tuple< + // clang-format off + // ########| AData| BData| CData| AccData| ALayout| BLayout| CLayout| A| B| C| GEMM| Block| MPer| NPer| K0Per| K1| M1Per| N1Per| KPer| M11N11Thread| M11N11Thread| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| CThreadTransfer| CThreadTransfer| CThreadTransfer| + // ########| Type| Type| Type| Type| | | | Elementwise| Elementwise| Elementwise| Spacialization| Size| Block| Block| Block| | ThreadM111| ThreadN111| Thread| ClusterM110Xs| ClusterN110Xs| ThreadSliceLengths| ThreadClusterLengths| ThreadCluster| SrcAccess| SrcVectorTensor| SrcVectorTensor| DstVectorTensor| ThreadSliceLengths| ThreadClusterLengths| ThreadCluster| SrcAccess| SrcVectorTensor| SrcVectorTensor| DstVectorTensor| SrcDstAccess| SrcDstVectorDim| DstScalarPerVector| + // ########| | | | | | | | Operation| Operation| Operation| | | | | | | | | | | | K0_M0_M1_K1| K0_M0_M1_K1| ArrangeOrder| Order| Lengths_K0_M0_M1_K1| ContiguousDimOrder| Lengths_K0_M0_M1_K1| K0_N0_N1_K1| K0_N0_N1_K1| ArrangeOrder| Order| Lengths_K0_N0_N1_K1| ContiguousDimOrder| Lengths_K0_N0_N1_K1| Order| | | + // ########| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceGemmDl< F32, F32, F32, F32, Col, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 128, 128, 16, 1, 4, 4, 1, S<8, 2>, S<8, 2>, S<2, 1, 4, 1>, S<8, 1, 32, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<8, 1, 1, 1>, S<2, 1, 128, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 1>, S<1, 2, 0, 3>, S<1, 1, 1, 1>, S<0, 1, 2, 3, 4, 5>, 5, 4> + // clang-format on + >; + +void add_device_gemm_dl_f32_f32_f32_km_nk_mn_instances( + std::vector>& instances) +{ + add_device_operation_instances(instances, device_gemm_dl_f32_f32_f32_km_nk_mn_instances{}); +} + +} // namespace device_gemm_instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/gemm/device_gemm_dl_f32_f32_f32_mk_kn_mn_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm/device_gemm_dl_f32_f32_f32_mk_kn_mn_instance.cpp new file mode 100644 index 0000000000..4c4bfc440d --- /dev/null +++ 
b/library/src/tensor_operation_instance/gpu/gemm/device_gemm_dl_f32_f32_f32_mk_kn_mn_instance.cpp @@ -0,0 +1,46 @@ +#include +#include "config.hpp" +#include "device_gemm_dl.hpp" +#include "element_wise_operation.hpp" +#include "device_operation_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace device_gemm_instance { + +using F16 = ck::half_t; +using F32 = float; + +using Row = ck::tensor_layout::gemm::RowMajor; +using Col = ck::tensor_layout::gemm::ColumnMajor; + +template +using S = ck::Sequence; + +using PassThrough = ck::tensor_operation::element_wise::PassThrough; + +static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default; + +// Compilation parameters for a[m, k] * b[k, n] = c[m, n] +using device_gemm_dl_f32_f32_f32_mk_kn_mn_instances = + std::tuple< + // clang-format off + // ########| AData| BData| CData| AccData| ALayout| BLayout| CLayout| A| B| C| GEMM| Block| MPer| NPer| K0Per| K1| M1Per| N1Per| KPer| M11N11Thread| M11N11Thread| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| CThreadTransfer| CThreadTransfer| CThreadTransfer| + // ########| Type| Type| Type| Type| | | | Elementwise| Elementwise| Elementwise| Spacialization| Size| Block| Block| Block| | ThreadM111| ThreadN111| Thread| ClusterM110Xs| ClusterN110Xs| ThreadSliceLengths| ThreadClusterLengths| ThreadCluster| SrcAccess| SrcVectorTensor| SrcVectorTensor| DstVectorTensor| ThreadSliceLengths| ThreadClusterLengths| ThreadCluster| SrcAccess| SrcVectorTensor| SrcVectorTensor| DstVectorTensor| SrcDstAccess| SrcDstVectorDim| DstScalarPerVector| + // ########| | | | | | | | Operation| Operation| Operation| | | | | | | | | | | | K0_M0_M1_K1| K0_M0_M1_K1| ArrangeOrder| Order| Lengths_K0_M0_M1_K1| ContiguousDimOrder| Lengths_K0_M0_M1_K1| K0_N0_N1_K1| K0_N0_N1_K1| ArrangeOrder| Order| Lengths_K0_N0_N1_K1| ContiguousDimOrder| Lengths_K0_N0_N1_K1| Order| | | + // ########| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceGemmDl< F32, F32, F32, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 128, 128, 16, 1, 4, 4, 1, S<8, 2>, S<8, 2>, S<8, 1, 1, 1>, S<2, 1, 128, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 1>, S<1, 2, 0, 3>, S<1, 1, 1, 1>, S<2, 1, 4, 1>, S<8, 1, 32, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 1, 2, 3, 4, 5>, 5, 4> + // clang-format on + >; + +void add_device_gemm_dl_f32_f32_f32_mk_kn_mn_instances( + std::vector>& instances) +{ + add_device_operation_instances(instances, device_gemm_dl_f32_f32_f32_mk_kn_mn_instances{}); +} + +} // namespace device_gemm_instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/gemm/device_gemm_dl_f32_f32_f32_mk_nk_mn_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm/device_gemm_dl_f32_f32_f32_mk_nk_mn_instance.cpp new file mode 100644 index 0000000000..c6077341b1 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/gemm/device_gemm_dl_f32_f32_f32_mk_nk_mn_instance.cpp @@ -0,0 +1,46 @@ +#include +#include "config.hpp" +#include "device_gemm_dl.hpp" +#include "element_wise_operation.hpp" +#include "device_operation_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace device_gemm_instance { + 
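One pattern worth noting across all twelve new DL instance files (this one included): the reduction dimension is carried as K = K0 * K1, and each instance picks K1 so that one K1-group of elements spans a 4-byte dword, i.e. K1 = 1 for fp32, K1 = 2 for fp16, and K1 = 4 for int8 (the latter feeding the v_dot4_i32_i8 path whose macro is renamed earlier in this patch). A compile-time restatement of that convention, with std::uint16_t standing in for ck::half_t and the helper name being hypothetical rather than library code:

#include <cstdint>

// true when a K1-group of DataType elements occupies exactly one 4-byte dword
template <typename DataType, int K1>
constexpr bool k1_group_is_one_dword = (K1 * sizeof(DataType) == 4);

static_assert(k1_group_is_one_dword<float, 1>, "fp32 DL instances use K1 = 1");
static_assert(k1_group_is_one_dword<std::uint16_t, 2>, "fp16 DL instances use K1 = 2");
static_assert(k1_group_is_one_dword<std::int8_t, 4>, "int8 DL instances use K1 = 4");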
+using F16 = ck::half_t; +using F32 = float; + +using Row = ck::tensor_layout::gemm::RowMajor; +using Col = ck::tensor_layout::gemm::ColumnMajor; + +template +using S = ck::Sequence; + +using PassThrough = ck::tensor_operation::element_wise::PassThrough; + +static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default; + +// Compilation parameters for a[m, k] * b[n, k] = c[m, n] +using device_gemm_dl_f32_f32_f32_mk_nk_mn_instances = + std::tuple< + // clang-format off + // ########| AData| BData| CData| AccData| ALayout| BLayout| CLayout| A| B| C| GEMM| Block| MPer| NPer| K0Per| K1| M1Per| N1Per| KPer| M11N11Thread| M11N11Thread| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| CThreadTransfer| CThreadTransfer| CThreadTransfer| + // ########| Type| Type| Type| Type| | | | Elementwise| Elementwise| Elementwise| Spacialization| Size| Block| Block| Block| | ThreadM111| ThreadN111| Thread| ClusterM110Xs| ClusterN110Xs| ThreadSliceLengths| ThreadClusterLengths| ThreadCluster| SrcAccess| SrcVectorTensor| SrcVectorTensor| DstVectorTensor| ThreadSliceLengths| ThreadClusterLengths| ThreadCluster| SrcAccess| SrcVectorTensor| SrcVectorTensor| DstVectorTensor| SrcDstAccess| SrcDstVectorDim| DstScalarPerVector| + // ########| | | | | | | | Operation| Operation| Operation| | | | | | | | | | | | K0_M0_M1_K1| K0_M0_M1_K1| ArrangeOrder| Order| Lengths_K0_M0_M1_K1| ContiguousDimOrder| Lengths_K0_M0_M1_K1| K0_N0_N1_K1| K0_N0_N1_K1| ArrangeOrder| Order| Lengths_K0_N0_N1_K1| ContiguousDimOrder| Lengths_K0_N0_N1_K1| Order| | | + // ########| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceGemmDl< F32, F32, F32, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 128, 128, 16, 1, 4, 4, 1, S<8, 2>, S<8, 2>, S<8, 1, 1, 1>, S<2, 1, 128, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 1>, S<1, 2, 0, 3>, S<1, 1, 1, 1>, S<8, 1, 1, 1>, S<2, 1, 128, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 1>, S<1, 2, 0, 3>, S<1, 1, 1, 1>, S<0, 1, 2, 3, 4, 5>, 5, 4> + // clang-format on + >; + +void add_device_gemm_dl_f32_f32_f32_mk_nk_mn_instances( + std::vector>& instances) +{ + add_device_operation_instances(instances, device_gemm_dl_f32_f32_f32_mk_nk_mn_instances{}); +} + +} // namespace device_gemm_instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/gemm/device_gemm_dl_i8_i8_i8_km_kn_mn_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm/device_gemm_dl_i8_i8_i8_km_kn_mn_instance.cpp new file mode 100644 index 0000000000..91b68d4bf2 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/gemm/device_gemm_dl_i8_i8_i8_km_kn_mn_instance.cpp @@ -0,0 +1,42 @@ +#include +#include "config.hpp" +#include "device_gemm_dl.hpp" +#include "element_wise_operation.hpp" +#include "device_operation_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace device_gemm_instance { + +using Row = ck::tensor_layout::gemm::RowMajor; +using Col = ck::tensor_layout::gemm::ColumnMajor; + +template +using S = ck::Sequence; + +using PassThrough = ck::tensor_operation::element_wise::PassThrough; + +static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default; + +// Compilation parameters for a[k, m] * b[k, n] = c[m, n] +using 
device_gemm_dl_i8_i8_i8_km_kn_mn_instances = std::tuple< + // clang-format off + // #########| AData| BData| CData| AccData| ALayout| BLayout| CLayout| A| B| C| GEMM| Block| MPer| NPer| K0Per| K1| M1Per| N1Per| KPer| M11N11Thread| M11N11Thread| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| CThreadTransfer| CThreadTransfer| CThreadTransfer| + // #########| Type| Type| Type| Type| | | | Elementwise| Elementwise| Elementwise| Spacialization| Size| Block| Block| Block| | ThreadM111| ThreadN111| Thread| ClusterM110Xs| ClusterN110Xs| ThreadSliceLengths| ThreadClusterLengths| ThreadCluster| SrcAccess| SrcVectorTensor| SrcVectorTensor| DstVectorTensor| ThreadSliceLengths| ThreadClusterLengths| ThreadCluster| SrcAccess| SrcVectorTensor| SrcVectorTensor| DstVectorTensor| SrcDstAccess| SrcDstVectorDim| DstScalarPerVector| + // #########| | | | | | | | Operation| Operation| Operation| | | | | | | | | | | | K0_M0_M1_K1| K0_M0_M1_K1| ArrangeOrder| Order| Lengths_K0_M0_M1_K1| ContiguousDimOrder| Lengths_K0_M0_M1_K1| K0_N0_N1_K1| K0_N0_N1_K1| ArrangeOrder| Order| Lengths_K0_N0_N1_K1| ContiguousDimOrder| Lengths_K0_N0_N1_K1| Order| | | + // #########| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceGemmDl< int8_t, int8_t, int8_t, int32_t, Col, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 128, 128, 16, 4, 4, 4, 1, S<8, 2>, S<8, 2>, S<2, 1, 4, 4>, S<8, 1, 32, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 4>, S<2, 1, 4, 4>, S<8, 1, 32, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 4>, S<0, 1, 2, 3, 4, 5>, 5, 4> + // clang-format on + >; + +void add_device_gemm_dl_i8_i8_i8_km_kn_mn_instances( + std::vector>& instances) +{ + add_device_operation_instances(instances, device_gemm_dl_i8_i8_i8_km_kn_mn_instances{}); +} + +} // namespace device_gemm_instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/gemm/device_gemm_dl_i8_i8_i8_km_nk_mn_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm/device_gemm_dl_i8_i8_i8_km_nk_mn_instance.cpp new file mode 100644 index 0000000000..13b185fd93 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/gemm/device_gemm_dl_i8_i8_i8_km_nk_mn_instance.cpp @@ -0,0 +1,42 @@ +#include +#include "config.hpp" +#include "device_gemm_dl.hpp" +#include "element_wise_operation.hpp" +#include "device_operation_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace device_gemm_instance { + +using Row = ck::tensor_layout::gemm::RowMajor; +using Col = ck::tensor_layout::gemm::ColumnMajor; + +template +using S = ck::Sequence; + +using PassThrough = ck::tensor_operation::element_wise::PassThrough; + +static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default; + +// Compilation parameters for a[k, m] * b[n, k] = c[m, n] +using device_gemm_dl_i8_i8_i8_km_nk_mn_instances = std::tuple< + // clang-format off + // #########| AData| BData| CData| AccData| ALayout| BLayout| CLayout| A| B| C| GEMM| Block| MPer| NPer| K0Per| K1| M1Per| N1Per| KPer| M11N11Thread| M11N11Thread| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| 
BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| CThreadTransfer| CThreadTransfer| CThreadTransfer| + // #########| Type| Type| Type| Type| | | | Elementwise| Elementwise| Elementwise| Spacialization| Size| Block| Block| Block| | ThreadM111| ThreadN111| Thread| ClusterM110Xs| ClusterN110Xs| ThreadSliceLengths| ThreadClusterLengths| ThreadCluster| SrcAccess| SrcVectorTensor| SrcVectorTensor| DstVectorTensor| ThreadSliceLengths| ThreadClusterLengths| ThreadCluster| SrcAccess| SrcVectorTensor| SrcVectorTensor| DstVectorTensor| SrcDstAccess| SrcDstVectorDim| DstScalarPerVector| + // #########| | | | | | | | Operation| Operation| Operation| | | | | | | | | | | | K0_M0_M1_K1| K0_M0_M1_K1| ArrangeOrder| Order| Lengths_K0_M0_M1_K1| ContiguousDimOrder| Lengths_K0_M0_M1_K1| K0_N0_N1_K1| K0_N0_N1_K1| ArrangeOrder| Order| Lengths_K0_N0_N1_K1| ContiguousDimOrder| Lengths_K0_N0_N1_K1| Order| | | + // #########| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceGemmDl< int8_t, int8_t, int8_t, int32_t, Col, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 128, 128, 16, 4, 4, 4, 1, S<8, 2>, S<8, 2>, S<2, 1, 4, 4>, S<8, 1, 32, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 4>, S<8, 1, 1, 4>, S<2, 1, 128, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 4>, S<1, 2, 0, 3>, S<1, 1, 1, 4>, S<0, 1, 2, 3, 4, 5>, 5, 4> + // clang-format on + >; + +void add_device_gemm_dl_i8_i8_i8_km_nk_mn_instances( + std::vector>& instances) +{ + add_device_operation_instances(instances, device_gemm_dl_i8_i8_i8_km_nk_mn_instances{}); +} + +} // namespace device_gemm_instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/gemm/device_gemm_dl_i8_i8_i8_mk_kn_mn_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm/device_gemm_dl_i8_i8_i8_mk_kn_mn_instance.cpp new file mode 100644 index 0000000000..ff4a89beb4 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/gemm/device_gemm_dl_i8_i8_i8_mk_kn_mn_instance.cpp @@ -0,0 +1,42 @@ +#include +#include "config.hpp" +#include "device_gemm_dl.hpp" +#include "element_wise_operation.hpp" +#include "device_operation_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace device_gemm_instance { + +using Row = ck::tensor_layout::gemm::RowMajor; +using Col = ck::tensor_layout::gemm::ColumnMajor; + +template +using S = ck::Sequence; + +using PassThrough = ck::tensor_operation::element_wise::PassThrough; + +static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default; + +// Compilation parameters for a[m, k] * b[k, n] = c[m, n] +using device_gemm_dl_i8_i8_i8_mk_kn_mn_instances = std::tuple< + // clang-format off + // #########| AData| BData| CData| AccData| ALayout| BLayout| CLayout| A| B| C| GEMM| Block| MPer| NPer| K0Per| K1| M1Per| N1Per| KPer| M11N11Thread| M11N11Thread| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| CThreadTransfer| CThreadTransfer| CThreadTransfer| + // #########| Type| Type| Type| Type| | | | Elementwise| Elementwise| Elementwise| Spacialization| Size| Block| Block| Block| | ThreadM111| ThreadN111| Thread| ClusterM110Xs| ClusterN110Xs| ThreadSliceLengths| ThreadClusterLengths| ThreadCluster| SrcAccess| SrcVectorTensor| SrcVectorTensor| 
DstVectorTensor| ThreadSliceLengths| ThreadClusterLengths| ThreadCluster| SrcAccess| SrcVectorTensor| SrcVectorTensor| DstVectorTensor| SrcDstAccess| SrcDstVectorDim| DstScalarPerVector| + // #########| | | | | | | | Operation| Operation| Operation| | | | | | | | | | | | K0_M0_M1_K1| K0_M0_M1_K1| ArrangeOrder| Order| Lengths_K0_M0_M1_K1| ContiguousDimOrder| Lengths_K0_M0_M1_K1| K0_N0_N1_K1| K0_N0_N1_K1| ArrangeOrder| Order| Lengths_K0_N0_N1_K1| ContiguousDimOrder| Lengths_K0_N0_N1_K1| Order| | | + // #########| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceGemmDl< int8_t, int8_t, int8_t, int32_t, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 128, 128, 16, 4, 4, 4, 1, S<8, 2>, S<8, 2>, S<8, 1, 1, 4>, S<2, 1, 128, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 4>, S<1, 2, 0, 3>, S<1, 1, 1, 4>, S<2, 1, 4, 4>, S<8, 1, 32, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 4>, S<0, 1, 2, 3, 4, 5>, 5, 4> + // clang-format on + >; + +void add_device_gemm_dl_i8_i8_i8_mk_kn_mn_instances( + std::vector>& instances) +{ + add_device_operation_instances(instances, device_gemm_dl_i8_i8_i8_mk_kn_mn_instances{}); +} + +} // namespace device_gemm_instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/gemm/device_gemm_dl_i8_i8_i8_mk_nk_mn_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm/device_gemm_dl_i8_i8_i8_mk_nk_mn_instance.cpp new file mode 100644 index 0000000000..e32158a292 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/gemm/device_gemm_dl_i8_i8_i8_mk_nk_mn_instance.cpp @@ -0,0 +1,42 @@ +#include +#include "config.hpp" +#include "device_gemm_dl.hpp" +#include "element_wise_operation.hpp" +#include "device_operation_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace device_gemm_instance { + +using Row = ck::tensor_layout::gemm::RowMajor; +using Col = ck::tensor_layout::gemm::ColumnMajor; + +template +using S = ck::Sequence; + +using PassThrough = ck::tensor_operation::element_wise::PassThrough; + +static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default; + +// Compilation parameters for a[m, k] * b[n, k] = c[m, n] +using device_gemm_dl_i8_i8_i8_mk_nk_mn_instances = std::tuple< + // clang-format off + // #########| AData| BData| CData| AccData| ALayout| BLayout| CLayout| A| B| C| GEMM| Block| MPer| NPer| K0Per| K1| M1Per| N1Per| KPer| M11N11Thread| M11N11Thread| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| CThreadTransfer| CThreadTransfer| CThreadTransfer| + // #########| Type| Type| Type| Type| | | | Elementwise| Elementwise| Elementwise| Spacialization| Size| Block| Block| Block| | ThreadM111| ThreadN111| Thread| ClusterM110Xs| ClusterN110Xs| ThreadSliceLengths| ThreadClusterLengths| ThreadCluster| SrcAccess| SrcVectorTensor| SrcVectorTensor| DstVectorTensor| ThreadSliceLengths| ThreadClusterLengths| ThreadCluster| SrcAccess| SrcVectorTensor| SrcVectorTensor| DstVectorTensor| SrcDstAccess| SrcDstVectorDim| DstScalarPerVector| + // #########| | | | | | | | Operation| Operation| Operation| | | | | | | | | | | | K0_M0_M1_K1| K0_M0_M1_K1| ArrangeOrder| Order| Lengths_K0_M0_M1_K1| ContiguousDimOrder| Lengths_K0_M0_M1_K1| K0_N0_N1_K1| K0_N0_N1_K1| 
ArrangeOrder| Order| Lengths_K0_N0_N1_K1| ContiguousDimOrder| Lengths_K0_N0_N1_K1| Order| | | + // #########| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceGemmDl< int8_t, int8_t, int8_t, int32_t, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 128, 128, 16, 4, 4, 4, 1, S<8, 2>, S<8, 2>, S<8, 1, 1, 4>, S<2, 1, 128, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 4>, S<1, 2, 0, 3>, S<1, 1, 1, 4>, S<8, 1, 1, 4>, S<2, 1, 128, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 4>, S<1, 2, 0, 3>, S<1, 1, 1, 4>, S<0, 1, 2, 3, 4, 5>, 5, 4> + // clang-format on + >; + +void add_device_gemm_dl_i8_i8_i8_mk_nk_mn_instances( + std::vector>& instances) +{ + add_device_operation_instances(instances, device_gemm_dl_i8_i8_i8_mk_nk_mn_instances{}); +} + +} // namespace device_gemm_instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_c_shuffle_int8_int8_int8_km_kn_mn_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_c_shuffle_i8_i8_i8_km_kn_mn_instance.cpp similarity index 97% rename from library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_c_shuffle_int8_int8_int8_km_kn_mn_instance.cpp rename to library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_c_shuffle_i8_i8_i8_km_kn_mn_instance.cpp index 4530d95c72..2185b55aac 100644 --- a/library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_c_shuffle_int8_int8_int8_km_kn_mn_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_c_shuffle_i8_i8_i8_km_kn_mn_instance.cpp @@ -22,7 +22,7 @@ using PassThrough = ck::tensor_operation::element_wise::PassThrough; static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default; // Compilation parameters for a[k, m] * b[k, n] = c[m, n] -using device_gemm_xdl_c_shuffle_int8_int8_int8_km_kn_mn_instances = +using device_gemm_xdl_c_shuffle_i8_i8_i8_km_kn_mn_instances = std::tuple< // clang-format off //#####################| ALayout| BLayout| CLayout| AData| BData| CData| AccData| CShuffle| A| B| C| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| @@ -48,11 +48,11 @@ using device_gemm_xdl_c_shuffle_int8_int8_int8_km_kn_mn_instances = // clang-format on >; -void add_device_gemm_xdl_c_shuffle_int8_int8_int8_km_kn_mn_instances( +void add_device_gemm_xdl_c_shuffle_i8_i8_i8_km_kn_mn_instances( std::vector>& instances) { add_device_operation_instances(instances, - device_gemm_xdl_c_shuffle_int8_int8_int8_km_kn_mn_instances{}); + device_gemm_xdl_c_shuffle_i8_i8_i8_km_kn_mn_instances{}); } } // namespace device_gemm_instance diff --git a/library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_c_shuffle_int8_int8_int8_km_nk_mn_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_c_shuffle_i8_i8_i8_km_nk_mn_instance.cpp similarity index 97% rename from library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_c_shuffle_int8_int8_int8_km_nk_mn_instance.cpp rename to library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_c_shuffle_i8_i8_i8_km_nk_mn_instance.cpp index 4214c71efb..90966349b2 100644 --- 
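
Comparing the three int8 DL instances above, the per-layout tuning follows one pattern: whichever operand has K as its contiguous dimension gets the S<8, 1, 1, 4> thread-slice with the S<2, 1, 128, 1> cluster, while an operand contiguous in M or N gets S<2, 1, 4, 4> with S<8, 1, 32, 1>. A comment-only sketch of the B side; the values are copied from the instance lines, and the dimension reading follows the column headers, so treat the interpretation as an observation rather than documentation.

    // Observed B-side block-transfer settings, copied from the instances above.
    // Dimension order is K0_N0_N1_K1 per the table headers.
    using BSlice_kn   = S<2, 1, 4, 4>;    // b[k, n] RowMajor: N contiguous
    using BCluster_kn = S<8, 1, 32, 1>;
    using BSlice_nk   = S<8, 1, 1, 4>;    // b[n, k] ColumnMajor: K contiguous
    using BCluster_nk = S<2, 1, 128, 1>;
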
a/library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_c_shuffle_int8_int8_int8_km_nk_mn_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_c_shuffle_i8_i8_i8_km_nk_mn_instance.cpp @@ -22,7 +22,7 @@ using PassThrough = ck::tensor_operation::element_wise::PassThrough; static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default; // Compilation parameters for a[k, m] * b[n, k] = c[m, n] -using device_gemm_xdl_c_shuffle_int8_int8_int8_km_nk_mn_instances = +using device_gemm_xdl_c_shuffle_i8_i8_i8_km_nk_mn_instances = std::tuple< // clang-format off //#####################| ALayout| BLayout| CLayout| AData| BData| CData| AccData| CShuffle| A| B| C| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| @@ -48,11 +48,11 @@ using device_gemm_xdl_c_shuffle_int8_int8_int8_km_nk_mn_instances = // clang-format on >; -void add_device_gemm_xdl_c_shuffle_int8_int8_int8_km_nk_mn_instances( +void add_device_gemm_xdl_c_shuffle_i8_i8_i8_km_nk_mn_instances( std::vector>& instances) { add_device_operation_instances(instances, - device_gemm_xdl_c_shuffle_int8_int8_int8_km_nk_mn_instances{}); + device_gemm_xdl_c_shuffle_i8_i8_i8_km_nk_mn_instances{}); } } // namespace device_gemm_instance diff --git a/library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_c_shuffle_int8_int8_int8_mk_kn_mn_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_c_shuffle_i8_i8_i8_mk_kn_mn_instance.cpp similarity index 97% rename from library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_c_shuffle_int8_int8_int8_mk_kn_mn_instance.cpp rename to library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_c_shuffle_i8_i8_i8_mk_kn_mn_instance.cpp index 39bb7e1473..aa5a13001c 100644 --- a/library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_c_shuffle_int8_int8_int8_mk_kn_mn_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_c_shuffle_i8_i8_i8_mk_kn_mn_instance.cpp @@ -22,7 +22,7 @@ using PassThrough = ck::tensor_operation::element_wise::PassThrough; static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default; // Compilation parameters for a[m, k] * b[k, n] = c[m, n] -using device_gemm_xdl_c_shuffle_int8_int8_int8_mk_kn_mn_instances = +using device_gemm_xdl_c_shuffle_i8_i8_i8_mk_kn_mn_instances = std::tuple< // clang-format off //#####################| ALayout| BLayout| CLayout| AData| BData| CData| AccData| CShuffle| A| B| C| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| @@ -48,11 +48,11 @@ using device_gemm_xdl_c_shuffle_int8_int8_int8_mk_kn_mn_instances = // clang-format on >; -void add_device_gemm_xdl_c_shuffle_int8_int8_int8_mk_kn_mn_instances( +void add_device_gemm_xdl_c_shuffle_i8_i8_i8_mk_kn_mn_instances( std::vector>& instances) { add_device_operation_instances(instances, - device_gemm_xdl_c_shuffle_int8_int8_int8_mk_kn_mn_instances{}); + 
device_gemm_xdl_c_shuffle_i8_i8_i8_mk_kn_mn_instances{}); } } // namespace device_gemm_instance diff --git a/library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_c_shuffle_int8_int8_int8_mk_nk_mn_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_c_shuffle_i8_i8_i8_mk_nk_mn_instance.cpp similarity index 97% rename from library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_c_shuffle_int8_int8_int8_mk_nk_mn_instance.cpp rename to library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_c_shuffle_i8_i8_i8_mk_nk_mn_instance.cpp index 2ddde9e630..82eec1164a 100644 --- a/library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_c_shuffle_int8_int8_int8_mk_nk_mn_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_c_shuffle_i8_i8_i8_mk_nk_mn_instance.cpp @@ -22,7 +22,7 @@ using PassThrough = ck::tensor_operation::element_wise::PassThrough; static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default; // Compilation parameters for a[m, k] * b[n, k] = c[m, n] -using device_gemm_xdl_c_shuffle_int8_int8_int8_mk_nk_mn_instances = +using device_gemm_xdl_c_shuffle_i8_i8_i8_mk_nk_mn_instances = std::tuple< // clang-format off //#####################| ALayout| BLayout| CLayout| AData| BData| CData| AccData| CShuffle| A| B| C| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| @@ -45,11 +45,11 @@ using device_gemm_xdl_c_shuffle_int8_int8_int8_mk_nk_mn_instances = // clang-format on >; -void add_device_gemm_xdl_c_shuffle_int8_int8_int8_mk_nk_mn_instances( +void add_device_gemm_xdl_c_shuffle_i8_i8_i8_mk_nk_mn_instances( std::vector>& instances) { add_device_operation_instances(instances, - device_gemm_xdl_c_shuffle_int8_int8_int8_mk_nk_mn_instances{}); + device_gemm_xdl_c_shuffle_i8_i8_i8_mk_nk_mn_instances{}); } } // namespace device_gemm_instance diff --git a/profiler/CMakeLists.txt b/profiler/CMakeLists.txt index 0525733103..ee0050d200 100644 --- a/profiler/CMakeLists.txt +++ b/profiler/CMakeLists.txt @@ -1,6 +1,7 @@ include_directories(BEFORE ${PROJECT_SOURCE_DIR}/include/ck ${PROJECT_SOURCE_DIR}/include/ck/utility + ${PROJECT_SOURCE_DIR}/include/ck/host_utility ${PROJECT_SOURCE_DIR}/include/ck/tensor_description ${PROJECT_SOURCE_DIR}/include/ck/tensor ${PROJECT_SOURCE_DIR}/include/ck/problem_transform diff --git a/profiler/include/profile_gemm_impl.hpp b/profiler/include/profile_gemm_impl.hpp index 45e6174260..146449f3e8 100644 --- a/profiler/include/profile_gemm_impl.hpp +++ b/profiler/include/profile_gemm_impl.hpp @@ -42,14 +42,10 @@ void add_device_gemm_xdl_c_shuffle_f16_f16_f16_mk_nk_mn_instances(std::vector&); void add_device_gemm_xdl_c_shuffle_f16_f16_f16_km_nk_mn_instances(std::vector&); -void add_device_gemm_xdl_c_shuffle_int8_int8_int8_mk_kn_mn_instances( - std::vector&); -void add_device_gemm_xdl_c_shuffle_int8_int8_int8_mk_nk_mn_instances( - std::vector&); -void add_device_gemm_xdl_c_shuffle_int8_int8_int8_km_kn_mn_instances( - std::vector&); -void add_device_gemm_xdl_c_shuffle_int8_int8_int8_km_nk_mn_instances( - std::vector&); +void add_device_gemm_xdl_c_shuffle_i8_i8_i8_mk_kn_mn_instances(std::vector&); +void 
add_device_gemm_xdl_c_shuffle_i8_i8_i8_mk_nk_mn_instances(std::vector&); +void add_device_gemm_xdl_c_shuffle_i8_i8_i8_km_kn_mn_instances(std::vector&); +void add_device_gemm_xdl_c_shuffle_i8_i8_i8_km_nk_mn_instances(std::vector&); void add_device_gemm_xdl_c_shuffle_2_stage_f16_f16_f16_mk_nk_mn_instances( std::vector&); @@ -74,6 +70,21 @@ void add_device_gemm_xdl_splitk_f16_f16_f16_mk_nk_mn_instances(std::vector&); void add_device_gemm_xdl_splitk_f16_f16_f16_km_nk_mn_instances(std::vector&); +void add_device_gemm_dl_f32_f32_f32_mk_kn_mn_instances(std::vector&); +void add_device_gemm_dl_f32_f32_f32_mk_nk_mn_instances(std::vector&); +void add_device_gemm_dl_f32_f32_f32_km_kn_mn_instances(std::vector&); +void add_device_gemm_dl_f32_f32_f32_km_nk_mn_instances(std::vector&); + +void add_device_gemm_dl_f16_f16_f16_mk_kn_mn_instances(std::vector&); +void add_device_gemm_dl_f16_f16_f16_mk_nk_mn_instances(std::vector&); +void add_device_gemm_dl_f16_f16_f16_km_kn_mn_instances(std::vector&); +void add_device_gemm_dl_f16_f16_f16_km_nk_mn_instances(std::vector&); + +void add_device_gemm_dl_i8_i8_i8_mk_kn_mn_instances(std::vector&); +void add_device_gemm_dl_i8_i8_i8_mk_nk_mn_instances(std::vector&); +void add_device_gemm_dl_i8_i8_i8_km_kn_mn_instances(std::vector&); +void add_device_gemm_dl_i8_i8_i8_km_nk_mn_instances(std::vector&); + } // namespace device_gemm_instance } // namespace device } // namespace tensor_operation @@ -125,7 +136,11 @@ void profile_gemm_impl(int do_verification, std::size_t num_thread = 1; switch(init_method) { - case 0: break; + // case 0: break; + case 0: + a_m_k.GenerateTensorValue(GeneratorTensor_1{}, num_thread); + b_k_n.GenerateTensorValue(GeneratorTensor_1{}, num_thread); + break; case 1: a_m_k.GenerateTensorValue(GeneratorTensor_2{-5, 5}, num_thread); b_k_n.GenerateTensorValue(GeneratorTensor_2{-5, 5}, num_thread); @@ -174,6 +189,9 @@ void profile_gemm_impl(int do_verification, ck::tensor_operation::device::device_gemm_instance:: add_device_gemm_xdl_f32_f32_f32_mk_kn_mn_instances(gemm_ptrs); + ck::tensor_operation::device::device_gemm_instance:: + add_device_gemm_dl_f32_f32_f32_mk_kn_mn_instances(gemm_ptrs); + ck::tensor_operation::device::device_gemm_instance:: add_device_gemm_xdl_c_shuffle_f32_f32_f32_mk_kn_mn_instances(gemm_ptrs); } @@ -192,6 +210,9 @@ void profile_gemm_impl(int do_verification, ck::tensor_operation::device::device_gemm_instance:: add_device_gemm_xdl_f32_f32_f32_mk_nk_mn_instances(gemm_ptrs); + ck::tensor_operation::device::device_gemm_instance:: + add_device_gemm_dl_f32_f32_f32_mk_nk_mn_instances(gemm_ptrs); + ck::tensor_operation::device::device_gemm_instance:: add_device_gemm_xdl_c_shuffle_f32_f32_f32_mk_nk_mn_instances(gemm_ptrs); } @@ -210,6 +231,9 @@ void profile_gemm_impl(int do_verification, ck::tensor_operation::device::device_gemm_instance:: add_device_gemm_xdl_f32_f32_f32_km_kn_mn_instances(gemm_ptrs); + ck::tensor_operation::device::device_gemm_instance:: + add_device_gemm_dl_f32_f32_f32_km_kn_mn_instances(gemm_ptrs); + ck::tensor_operation::device::device_gemm_instance:: add_device_gemm_xdl_c_shuffle_f32_f32_f32_km_kn_mn_instances(gemm_ptrs); } @@ -228,6 +252,9 @@ void profile_gemm_impl(int do_verification, ck::tensor_operation::device::device_gemm_instance:: add_device_gemm_xdl_f32_f32_f32_km_nk_mn_instances(gemm_ptrs); + ck::tensor_operation::device::device_gemm_instance:: + add_device_gemm_dl_f32_f32_f32_km_nk_mn_instances(gemm_ptrs); + ck::tensor_operation::device::device_gemm_instance:: 
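
In the profile_gemm_impl.hpp hunk above, init_method == 0 no longer leaves the tensors uninitialized; it fills them through GeneratorTensor_1 (the template arguments are elided by the diff, so the ones below are assumed). If GeneratorTensor_1's fill value defaults to 1, which its name suggests but this diff does not show, constant inputs make verification trivial:

    // Assumed sketch of the new init_method == 0 path.
    a_m_k.GenerateTensorValue(GeneratorTensor_1<ADataType>{}, num_thread); // constant fill
    b_k_n.GenerateTensorValue(GeneratorTensor_1<BDataType>{}, num_thread);
    // With all-ones inputs and PassThrough element-wise ops, every output
    // element is a sum of K products of 1 * 1, i.e. c[m][n] == K, which
    // makes indexing or layout bugs easy to spot by eye.
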
add_device_gemm_xdl_c_shuffle_f32_f32_f32_km_nk_mn_instances(gemm_ptrs); } @@ -250,6 +277,9 @@ void profile_gemm_impl(int do_verification, ck::tensor_operation::device::device_gemm_instance:: add_device_gemm_xdl_f16_f16_f16_mk_kn_mn_instances(gemm_ptrs); + ck::tensor_operation::device::device_gemm_instance:: + add_device_gemm_dl_f16_f16_f16_mk_kn_mn_instances(gemm_ptrs); + ck::tensor_operation::device::device_gemm_instance:: add_device_gemm_xdl_c_shuffle_f16_f16_f16_mk_kn_mn_instances(gemm_ptrs); } @@ -268,6 +298,9 @@ void profile_gemm_impl(int do_verification, ck::tensor_operation::device::device_gemm_instance:: add_device_gemm_xdl_f16_f16_f16_mk_nk_mn_instances(gemm_ptrs); + ck::tensor_operation::device::device_gemm_instance:: + add_device_gemm_dl_f16_f16_f16_mk_nk_mn_instances(gemm_ptrs); + ck::tensor_operation::device::device_gemm_instance:: add_device_gemm_xdl_c_shuffle_f16_f16_f16_mk_nk_mn_instances(gemm_ptrs); @@ -289,6 +322,9 @@ void profile_gemm_impl(int do_verification, ck::tensor_operation::device::device_gemm_instance:: add_device_gemm_xdl_f16_f16_f16_km_kn_mn_instances(gemm_ptrs); + ck::tensor_operation::device::device_gemm_instance:: + add_device_gemm_dl_f16_f16_f16_km_kn_mn_instances(gemm_ptrs); + ck::tensor_operation::device::device_gemm_instance:: add_device_gemm_xdl_c_shuffle_f16_f16_f16_km_kn_mn_instances(gemm_ptrs); } @@ -307,6 +343,9 @@ void profile_gemm_impl(int do_verification, ck::tensor_operation::device::device_gemm_instance:: add_device_gemm_xdl_f16_f16_f16_km_nk_mn_instances(gemm_ptrs); + ck::tensor_operation::device::device_gemm_instance:: + add_device_gemm_dl_f16_f16_f16_km_nk_mn_instances(gemm_ptrs); + ck::tensor_operation::device::device_gemm_instance:: add_device_gemm_xdl_c_shuffle_f16_f16_f16_km_nk_mn_instances(gemm_ptrs); } @@ -353,28 +392,40 @@ void profile_gemm_impl(int do_verification, is_same::value) { ck::tensor_operation::device::device_gemm_instance:: - add_device_gemm_xdl_c_shuffle_int8_int8_int8_mk_kn_mn_instances(gemm_ptrs); + add_device_gemm_xdl_c_shuffle_i8_i8_i8_mk_kn_mn_instances(gemm_ptrs); + + ck::tensor_operation::device::device_gemm_instance:: + add_device_gemm_dl_i8_i8_i8_mk_kn_mn_instances(gemm_ptrs); } else if constexpr(is_same::value && is_same::value && is_same::value) { ck::tensor_operation::device::device_gemm_instance:: - add_device_gemm_xdl_c_shuffle_int8_int8_int8_mk_nk_mn_instances(gemm_ptrs); + add_device_gemm_xdl_c_shuffle_i8_i8_i8_mk_nk_mn_instances(gemm_ptrs); + + ck::tensor_operation::device::device_gemm_instance:: + add_device_gemm_dl_i8_i8_i8_mk_nk_mn_instances(gemm_ptrs); } else if constexpr(is_same::value && is_same::value && is_same::value) { ck::tensor_operation::device::device_gemm_instance:: - add_device_gemm_xdl_c_shuffle_int8_int8_int8_km_kn_mn_instances(gemm_ptrs); + add_device_gemm_xdl_c_shuffle_i8_i8_i8_km_kn_mn_instances(gemm_ptrs); + + ck::tensor_operation::device::device_gemm_instance:: + add_device_gemm_dl_i8_i8_i8_km_kn_mn_instances(gemm_ptrs); } else if constexpr(is_same::value && is_same::value && is_same::value) { ck::tensor_operation::device::device_gemm_instance:: - add_device_gemm_xdl_c_shuffle_int8_int8_int8_km_nk_mn_instances(gemm_ptrs); + add_device_gemm_xdl_c_shuffle_i8_i8_i8_km_nk_mn_instances(gemm_ptrs); + + ck::tensor_operation::device::device_gemm_instance:: + add_device_gemm_dl_i8_i8_i8_km_nk_mn_instances(gemm_ptrs); } } @@ -523,7 +574,8 @@ void profile_gemm_impl(int do_verification, } else { - std::cout << "does not support this GEMM problem" << std::endl; + std::cout << 
gemm_ptr->GetTypeString() << " does not support this GEMM problem" + << std::endl; } } diff --git a/profiler/src/profiler.cpp b/profiler/src/profiler.cpp index 2a8078ca5f..3d20389963 100644 --- a/profiler/src/profiler.cpp +++ b/profiler/src/profiler.cpp @@ -13,6 +13,7 @@ int profile_gemm_bias_relu_add(int, char*[]); int profile_gemm_reduce(int, char*[]); int profile_batched_gemm(int, char*[]); int profile_grouped_gemm(int, char*[]); +int profile_conv_fwd(int, char*[]); int profile_conv_fwd_bias_relu(int, char*[]); int profile_conv_fwd_bias_relu_add(int, char*[]); int profile_conv_fwd_bias_relu_atomic_add(int, char*[]); @@ -107,7 +108,7 @@ int main(int argc, char* argv[]) " conv1d_bwd_data: BackwardConvolution data 1 dim\n" " conv2d_bwd_data: BackwardConvolution data 2 dim\n" " conv3d_bwd_data: BackwardConvolution data 3 dim\n" - " reduce: REDUCE\n" + " reduce: Reduce\n" " conv2d_bwd_weight: Backward Weight Convolution 2d\n"); // clang-format on } diff --git a/test/CMakeLists.txt b/test/CMakeLists.txt index 382b1f9ed0..b05ec8d328 100644 --- a/test/CMakeLists.txt +++ b/test/CMakeLists.txt @@ -2,6 +2,7 @@ include_directories(BEFORE ${PROJECT_SOURCE_DIR}/ ${PROJECT_SOURCE_DIR}/include/ck ${PROJECT_SOURCE_DIR}/include/ck/utility + ${PROJECT_SOURCE_DIR}/include/ck/host_utility ${PROJECT_SOURCE_DIR}/include/ck/tensor_description ${PROJECT_SOURCE_DIR}/include/ck/tensor ${PROJECT_SOURCE_DIR}/include/ck/problem_transform diff --git a/test/gemm/CMakeLists.txt b/test/gemm/CMakeLists.txt index 83b3c1e2e3..b8679e3715 100644 --- a/test/gemm/CMakeLists.txt +++ b/test/gemm/CMakeLists.txt @@ -1,15 +1,29 @@ -add_test_executable(test_gemm_fp32 gemm_fp32.cpp) -target_link_libraries(test_gemm_fp32 PRIVATE host_tensor) -target_link_libraries(test_gemm_fp32 PRIVATE device_gemm_instance) +# GEMM XDL +add_test_executable(test_gemm_xdl_fp32 gemm_xdl_fp32.cpp) +target_link_libraries(test_gemm_xdl_fp32 PRIVATE host_tensor) +target_link_libraries(test_gemm_xdl_fp32 PRIVATE device_gemm_instance) -add_test_executable(test_gemm_fp16 gemm_fp16.cpp) -target_link_libraries(test_gemm_fp16 PRIVATE host_tensor) -target_link_libraries(test_gemm_fp16 PRIVATE device_gemm_instance) +add_test_executable(test_gemm_xdl_fp16 gemm_xdl_fp16.cpp) +target_link_libraries(test_gemm_xdl_fp16 PRIVATE host_tensor) +target_link_libraries(test_gemm_xdl_fp16 PRIVATE device_gemm_instance) -add_test_executable(test_gemm_bf16 gemm_bf16.cpp) -target_link_libraries(test_gemm_bf16 PRIVATE host_tensor) -target_link_libraries(test_gemm_bf16 PRIVATE device_gemm_instance) +add_test_executable(test_gemm_xdl_bf16 gemm_xdl_bf16.cpp) +target_link_libraries(test_gemm_xdl_bf16 PRIVATE host_tensor) +target_link_libraries(test_gemm_xdl_bf16 PRIVATE device_gemm_instance) -add_test_executable(test_gemm_int8 gemm_int8.cpp) -target_link_libraries(test_gemm_int8 PRIVATE host_tensor) -target_link_libraries(test_gemm_int8 PRIVATE device_gemm_instance) +add_test_executable(test_gemm_xdl_int8 gemm_xdl_int8.cpp) +target_link_libraries(test_gemm_xdl_int8 PRIVATE host_tensor) +target_link_libraries(test_gemm_xdl_int8 PRIVATE device_gemm_instance) + +# GEMM DL +add_test_executable(test_gemm_dl_fp32 gemm_dl_fp32.cpp) +target_link_libraries(test_gemm_dl_fp32 PRIVATE host_tensor) +target_link_libraries(test_gemm_dl_fp32 PRIVATE device_gemm_instance) + +add_test_executable(test_gemm_dl_fp16 gemm_dl_fp16.cpp) +target_link_libraries(test_gemm_dl_fp16 PRIVATE host_tensor) +target_link_libraries(test_gemm_dl_fp16 PRIVATE device_gemm_instance) + +add_test_executable(test_gemm_dl_int8 
gemm_dl_int8.cpp) +target_link_libraries(test_gemm_dl_int8 PRIVATE host_tensor) +target_link_libraries(test_gemm_dl_int8 PRIVATE device_gemm_instance) diff --git a/test/gemm/gemm_dl_fp16.cpp b/test/gemm/gemm_dl_fp16.cpp new file mode 100644 index 0000000000..6165355ec4 --- /dev/null +++ b/test/gemm/gemm_dl_fp16.cpp @@ -0,0 +1,130 @@ +#include +#include +#include +#include +#include +#include +#include + +#include "../gemm/gemm_util.hpp" +#include "config.hpp" +#include "print.hpp" +#include "device.hpp" +#include "host_tensor.hpp" +#include "host_tensor_generator.hpp" +#include "host_gemm.hpp" +#include "device_tensor.hpp" +#include "device_gemm_dl.hpp" +#include "element_wise_operation.hpp" +#include "reference_gemm.hpp" +#include "gemm_specialization.hpp" + +using PassThrough = ck::tensor_operation::element_wise::PassThrough; + +using DeviceGemmNoOpPtr = + ck::tensor_operation::device::DeviceGemmPtr; + +namespace ck { +namespace tensor_operation { +namespace device { +namespace device_gemm_instance { + +void add_device_gemm_dl_f16_f16_f16_km_kn_mn_instances(std::vector&); +void add_device_gemm_dl_f16_f16_f16_km_nk_mn_instances(std::vector&); +void add_device_gemm_dl_f16_f16_f16_mk_nk_mn_instances(std::vector&); +void add_device_gemm_dl_f16_f16_f16_mk_kn_mn_instances(std::vector&); + +} // namespace device_gemm_instance +} // namespace device +} // namespace tensor_operation +} // namespace ck + +int main() +{ + using ADataType = ck::half_t; + using BDataType = ck::half_t; + using CDataType = ck::half_t; + + using RowMajor = ck::tensor_layout::gemm::RowMajor; + using ColumnMajor = ck::tensor_layout::gemm::ColumnMajor; + + bool res = true; + + std::vector gemmPtrs; + + ck::tensor_operation::device::device_gemm_instance:: + add_device_gemm_dl_f16_f16_f16_km_kn_mn_instances(gemmPtrs); + + for(auto& gemmPtr : gemmPtrs) + { + res &= ck::gemm_util::TestGemm{}(gemmPtr); + } + + gemmPtrs.clear(); + ck::tensor_operation::device::device_gemm_instance:: + add_device_gemm_dl_f16_f16_f16_km_nk_mn_instances(gemmPtrs); + + for(auto& gemmPtr : gemmPtrs) + { + res &= ck::gemm_util::TestGemm{}(gemmPtr); + } + + gemmPtrs.clear(); + ck::tensor_operation::device::device_gemm_instance:: + add_device_gemm_dl_f16_f16_f16_mk_kn_mn_instances(gemmPtrs); + + for(auto& gemmPtr : gemmPtrs) + { + res &= ck::gemm_util::TestGemm{}(gemmPtr); + } + + gemmPtrs.clear(); + ck::tensor_operation::device::device_gemm_instance:: + add_device_gemm_dl_f16_f16_f16_mk_nk_mn_instances(gemmPtrs); + + for(auto& gemmPtr : gemmPtrs) + { + res &= ck::gemm_util::TestGemm{}(gemmPtr); + } + + std::cout << "TestGemm ..... " << (res ? "SUCCESS" : "FAILURE") << std::endl; + return res ? 
0 : 1; +} diff --git a/test/gemm/gemm_dl_fp32.cpp b/test/gemm/gemm_dl_fp32.cpp new file mode 100644 index 0000000000..cd0f816731 --- /dev/null +++ b/test/gemm/gemm_dl_fp32.cpp @@ -0,0 +1,128 @@ +#include +#include +#include +#include +#include +#include +#include + +#include "../gemm/gemm_util.hpp" +#include "config.hpp" +#include "print.hpp" +#include "device.hpp" +#include "host_tensor.hpp" +#include "host_tensor_generator.hpp" +#include "host_gemm.hpp" +#include "device_tensor.hpp" +#include "device_gemm_dl.hpp" +#include "element_wise_operation.hpp" +#include "reference_gemm.hpp" +#include "gemm_specialization.hpp" + +using PassThrough = ck::tensor_operation::element_wise::PassThrough; + +using DeviceGemmNoOpPtr = + ck::tensor_operation::device::DeviceGemmPtr; + +namespace ck { +namespace tensor_operation { +namespace device { +namespace device_gemm_instance { + +void add_device_gemm_dl_f32_f32_f32_km_kn_mn_instances(std::vector&); +void add_device_gemm_dl_f32_f32_f32_km_nk_mn_instances(std::vector&); +void add_device_gemm_dl_f32_f32_f32_mk_nk_mn_instances(std::vector&); +void add_device_gemm_dl_f32_f32_f32_mk_kn_mn_instances(std::vector&); + +} // namespace device_gemm_instance +} // namespace device +} // namespace tensor_operation +} // namespace ck + +int main() +{ + using ADataType = float; + using BDataType = float; + using CDataType = float; + + using RowMajor = ck::tensor_layout::gemm::RowMajor; + using ColumnMajor = ck::tensor_layout::gemm::ColumnMajor; + + bool res = true; + std::vector gemmPtrs; + ck::tensor_operation::device::device_gemm_instance:: + add_device_gemm_dl_f32_f32_f32_km_kn_mn_instances(gemmPtrs); + + for(auto& gemmPtr : gemmPtrs) + { + res &= ck::gemm_util::TestGemm{}(gemmPtr); + } + + gemmPtrs.clear(); + ck::tensor_operation::device::device_gemm_instance:: + add_device_gemm_dl_f32_f32_f32_km_nk_mn_instances(gemmPtrs); + + for(auto& gemmPtr : gemmPtrs) + { + res &= ck::gemm_util::TestGemm{}(gemmPtr); + } + + gemmPtrs.clear(); + ck::tensor_operation::device::device_gemm_instance:: + add_device_gemm_dl_f32_f32_f32_mk_kn_mn_instances(gemmPtrs); + + for(auto& gemmPtr : gemmPtrs) + { + res &= ck::gemm_util::TestGemm{}(gemmPtr); + } + + gemmPtrs.clear(); + ck::tensor_operation::device::device_gemm_instance:: + add_device_gemm_dl_f32_f32_f32_mk_nk_mn_instances(gemmPtrs); + + for(auto& gemmPtr : gemmPtrs) + { + res &= ck::gemm_util::TestGemm{}(gemmPtr); + } + + std::cout << "TestGemm ..... " << (res ? "SUCCESS" : "FAILURE") << std::endl; + return res ? 
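
The gemm_dl_fp16.cpp and gemm_dl_fp32.cpp tests above (and gemm_dl_int8.cpp below) repeat one pattern four times: populate gemmPtrs for a layout combination, loop over the instances, and AND each TestGemm result into res. A hypothetical consolidation, assuming TestGemm is parameterized over the pointer type, the three data types, and the three layouts, which its elided template argument lists suggest:

    // Hypothetical helper, not part of this patch.
    template <typename ADataType, typename BDataType, typename CDataType,
              typename ALayout, typename BLayout, typename CLayout,
              typename AddInstancesFn>
    bool run_layout_case(AddInstancesFn add_instances)
    {
        std::vector<DeviceGemmNoOpPtr> gemmPtrs;
        add_instances(gemmPtrs);

        bool ok = true;
        for(auto& gemmPtr : gemmPtrs)
        {
            // Assumed TestGemm parameter order, mirroring the calls above.
            ok &= ck::gemm_util::TestGemm<DeviceGemmNoOpPtr,
                                          ADataType, BDataType, CDataType,
                                          ALayout, BLayout, CLayout>{}(gemmPtr);
        }
        return ok;
    }

Each of the four blocks in main() would then collapse to one call, e.g. res &= run_layout_case<float, float, float, ColumnMajor, RowMajor, RowMajor>(ck::tensor_operation::device::device_gemm_instance::add_device_gemm_dl_f32_f32_f32_km_kn_mn_instances);
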
0 : 1; +} diff --git a/test/gemm/gemm_dl_int8.cpp b/test/gemm/gemm_dl_int8.cpp new file mode 100644 index 0000000000..72b9f1440f --- /dev/null +++ b/test/gemm/gemm_dl_int8.cpp @@ -0,0 +1,128 @@ +#include +#include +#include +#include +#include +#include +#include + +#include "../gemm/gemm_util.hpp" +#include "config.hpp" +#include "print.hpp" +#include "device.hpp" +#include "host_tensor.hpp" +#include "host_tensor_generator.hpp" +#include "host_gemm.hpp" +#include "device_tensor.hpp" +#include "device_gemm_dl.hpp" +#include "element_wise_operation.hpp" +#include "reference_gemm.hpp" +#include "gemm_specialization.hpp" + +using PassThrough = ck::tensor_operation::element_wise::PassThrough; + +using DeviceGemmNoOpPtr = + ck::tensor_operation::device::DeviceGemmPtr; + +namespace ck { +namespace tensor_operation { +namespace device { +namespace device_gemm_instance { + +void add_device_gemm_dl_i8_i8_i8_km_kn_mn_instances(std::vector&); +void add_device_gemm_dl_i8_i8_i8_km_nk_mn_instances(std::vector&); +void add_device_gemm_dl_i8_i8_i8_mk_nk_mn_instances(std::vector&); +void add_device_gemm_dl_i8_i8_i8_mk_kn_mn_instances(std::vector&); + +} // namespace device_gemm_instance +} // namespace device +} // namespace tensor_operation +} // namespace ck + +int main() +{ + using ADataType = int8_t; + using BDataType = int8_t; + using CDataType = int8_t; + + using RowMajor = ck::tensor_layout::gemm::RowMajor; + using ColumnMajor = ck::tensor_layout::gemm::ColumnMajor; + + bool res = true; + std::vector gemmPtrs; + ck::tensor_operation::device::device_gemm_instance:: + add_device_gemm_dl_i8_i8_i8_km_kn_mn_instances(gemmPtrs); + + for(auto& gemmPtr : gemmPtrs) + { + res &= ck::gemm_util::TestGemm{}(gemmPtr); + } + + gemmPtrs.clear(); + ck::tensor_operation::device::device_gemm_instance:: + add_device_gemm_dl_i8_i8_i8_km_nk_mn_instances(gemmPtrs); + + for(auto& gemmPtr : gemmPtrs) + { + res &= ck::gemm_util::TestGemm{}(gemmPtr); + } + + gemmPtrs.clear(); + ck::tensor_operation::device::device_gemm_instance:: + add_device_gemm_dl_i8_i8_i8_mk_kn_mn_instances(gemmPtrs); + + for(auto& gemmPtr : gemmPtrs) + { + res &= ck::gemm_util::TestGemm{}(gemmPtr); + } + + gemmPtrs.clear(); + ck::tensor_operation::device::device_gemm_instance:: + add_device_gemm_dl_i8_i8_i8_mk_nk_mn_instances(gemmPtrs); + + for(auto& gemmPtr : gemmPtrs) + { + res &= ck::gemm_util::TestGemm{}(gemmPtr); + } + + std::cout << "TestGemm ..... " << (res ? "SUCCESS" : "FAILURE") << std::endl; + return res ? 0 : 1; +} diff --git a/test/gemm/gemm_util.hpp b/test/gemm/gemm_util.hpp index 17e954b7f2..258ed60b08 100644 --- a/test/gemm/gemm_util.hpp +++ b/test/gemm/gemm_util.hpp @@ -60,7 +60,7 @@ template -void RunDeviceGEMM(DeviceGemmPtr_& gemmPtr, +bool RunDeviceGEMM(DeviceGemmPtr_& gemmPtr, const ck::gemm_util::GemmParams& params, const Tensor& A, const Tensor& B, @@ -73,9 +73,6 @@ void RunDeviceGEMM(DeviceGemmPtr_& gemmPtr, DeviceMem b_k_n_device_buf(sizeof(BDataType) * B.mDesc.GetElementSpace()); DeviceMem c_m_n_device_buf(sizeof(CDataType) * C.mDesc.GetElementSpace()); - a_m_k_device_buf.ToDevice(A.mData.data()); - b_k_n_device_buf.ToDevice(B.mData.data()); - auto invoker_ptr = gemmPtr->MakeInvokerPointer(); auto argument_ptr = gemmPtr->MakeArgumentPointer(static_cast(a_m_k_device_buf.GetDeviceBuffer()), @@ -91,15 +88,23 @@ void RunDeviceGEMM(DeviceGemmPtr_& gemmPtr, b_element_op, c_element_op); - if(!gemmPtr->IsSupportedArgument(argument_ptr.get())) + if(gemmPtr->IsSupportedArgument(argument_ptr.get())) { - throw std::runtime_error( - "wrong! 
device_gemm with the specified compilation parameters does " - "not support this GEMM problem"); + a_m_k_device_buf.ToDevice(A.mData.data()); + b_k_n_device_buf.ToDevice(B.mData.data()); + invoker_ptr->Run(argument_ptr.get()); + c_m_n_device_buf.FromDevice(C.mData.data()); + + return true; } + else + { + std::cout << "device_gemm with the specified compilation parameters does " + "not support this GEMM problem" + << std::endl; - invoker_ptr->Run(argument_ptr.get()); - c_m_n_device_buf.FromDevice(C.mData.data()); + return false; + } } template ::value) + if(is_supported) { - res = ck::utils::check_err(c_device.mData, c_host.mData); - std::cout << (res ? "SUCCESS" : "FAILURE") << std::endl; + // Assert + bool res = false; + if(std::is_same::value) + { + res = ck::utils::check_err(c_device.mData, c_host.mData); + std::cout << (res ? "SUCCESS" : "FAILURE") << std::endl; + } + else if(std::is_same::value) + { + res = ck::utils::check_err(c_device.mData, c_host.mData); + std::cout << (res ? "SUCCESS" : "FAILURE") << std::endl; + } + else if(std::is_same::value) + { + res = ck::utils::check_err(c_device.mData, c_host.mData); + std::cout << (res ? "SUCCESS" : "FAILURE") << std::endl; + } + + return res; } - else if(std::is_same::value) + else { - res = ck::utils::check_err(c_device.mData, c_host.mData); - std::cout << (res ? "SUCCESS" : "FAILURE") << std::endl; + return true; } - else if(std::is_same::value) - { - res = ck::utils::check_err(c_device.mData, c_host.mData); - std::cout << (res ? "SUCCESS" : "FAILURE") << std::endl; - } - - return res; } }; diff --git a/test/gemm/gemm_bf16.cpp b/test/gemm/gemm_xdl_bf16.cpp similarity index 100% rename from test/gemm/gemm_bf16.cpp rename to test/gemm/gemm_xdl_bf16.cpp diff --git a/test/gemm/gemm_fp16.cpp b/test/gemm/gemm_xdl_fp16.cpp similarity index 100% rename from test/gemm/gemm_fp16.cpp rename to test/gemm/gemm_xdl_fp16.cpp diff --git a/test/gemm/gemm_fp32.cpp b/test/gemm/gemm_xdl_fp32.cpp similarity index 100% rename from test/gemm/gemm_fp32.cpp rename to test/gemm/gemm_xdl_fp32.cpp diff --git a/test/gemm/gemm_int8.cpp b/test/gemm/gemm_xdl_int8.cpp similarity index 82% rename from test/gemm/gemm_int8.cpp rename to test/gemm/gemm_xdl_int8.cpp index 870881dd76..fbb1b1ac98 100644 --- a/test/gemm/gemm_int8.cpp +++ b/test/gemm/gemm_xdl_int8.cpp @@ -31,14 +31,10 @@ namespace ck { namespace tensor_operation { namespace device { namespace device_gemm_instance { -void add_device_gemm_xdl_c_shuffle_int8_int8_int8_km_kn_mn_instances( - std::vector&); -void add_device_gemm_xdl_c_shuffle_int8_int8_int8_km_nk_mn_instances( - std::vector&); -void add_device_gemm_xdl_c_shuffle_int8_int8_int8_mk_nk_mn_instances( - std::vector&); -void add_device_gemm_xdl_c_shuffle_int8_int8_int8_mk_kn_mn_instances( - std::vector&); +void add_device_gemm_xdl_c_shuffle_i8_i8_i8_km_kn_mn_instances(std::vector&); +void add_device_gemm_xdl_c_shuffle_i8_i8_i8_km_nk_mn_instances(std::vector&); +void add_device_gemm_xdl_c_shuffle_i8_i8_i8_mk_nk_mn_instances(std::vector&); +void add_device_gemm_xdl_c_shuffle_i8_i8_i8_mk_kn_mn_instances(std::vector&); } // namespace device_gemm_instance } // namespace device } // namespace tensor_operation @@ -57,7 +53,7 @@ int main() bool res = true; ck::tensor_operation::device::device_gemm_instance:: - add_device_gemm_xdl_c_shuffle_int8_int8_int8_km_kn_mn_instances(gemmPtrs); + add_device_gemm_xdl_c_shuffle_i8_i8_i8_km_kn_mn_instances(gemmPtrs); for(auto& gemmPtr : gemmPtrs) { @@ -75,7 +71,7 @@ int main() gemmPtrs.clear(); 
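
The gemm_util.hpp hunk above changes RunDeviceGEMM's contract: it now returns a bool instead of throwing on an unsupported argument, and the host-to-device copies happen only after IsSupportedArgument passes. A caller-side sketch of the new flow (parameter names assumed, namespace assumed to match TestGemm's):

    // Assumed calling pattern after this change.
    bool is_supported = ck::gemm_util::RunDeviceGEMM(
        gemmPtr, params, a_m_k, b_k_n, c_m_n_device, a_op, b_op, c_op);

    if(!is_supported)
    {
        // Not a failure: the instance simply does not cover this problem,
        // so the test reports true and moves on to the next instance.
        return true;
    }
    // Otherwise c_m_n_device holds the kernel output, ready for check_err.
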
ck::tensor_operation::device::device_gemm_instance:: - add_device_gemm_xdl_c_shuffle_int8_int8_int8_km_nk_mn_instances(gemmPtrs); + add_device_gemm_xdl_c_shuffle_i8_i8_i8_km_nk_mn_instances(gemmPtrs); for(auto& gemmPtr : gemmPtrs) { @@ -93,7 +89,7 @@ int main() gemmPtrs.clear(); ck::tensor_operation::device::device_gemm_instance:: - add_device_gemm_xdl_c_shuffle_int8_int8_int8_mk_kn_mn_instances(gemmPtrs); + add_device_gemm_xdl_c_shuffle_i8_i8_i8_mk_kn_mn_instances(gemmPtrs); for(auto& gemmPtr : gemmPtrs) { @@ -111,7 +107,7 @@ int main() gemmPtrs.clear(); ck::tensor_operation::device::device_gemm_instance:: - add_device_gemm_xdl_c_shuffle_int8_int8_int8_mk_nk_mn_instances(gemmPtrs); + add_device_gemm_xdl_c_shuffle_i8_i8_i8_mk_nk_mn_instances(gemmPtrs); for(auto& gemmPtr : gemmPtrs) {