From 034987b172dc3dd077ae2306911d846e73e7fe95 Mon Sep 17 00:00:00 2001 From: j4yan Date: Wed, 13 Apr 2022 01:16:36 -0500 Subject: [PATCH 01/46] start adding navi21 GEMM --- .../blockwise_tensor_slice_transfer_v5r1.hpp | 11 + .../gpu/device/device_gemm_dlops.hpp | 579 ++++++++++++++++++ .../gpu/grid/gridwise_gemm_dlops_v1r3.hpp | 42 +- .../gpu/gemm/CMakeLists.txt | 3 + ...mm_dlops_f32_f32_f32_km_kn_mn_instance.cpp | 46 ++ test/gemm_dlops/CMakeLists.txt | 15 + test/gemm_dlops/gemm_dlops_fp32.cpp | 74 +++ 7 files changed, 744 insertions(+), 26 deletions(-) create mode 100644 include/ck/tensor_operation/gpu/device/device_gemm_dlops.hpp create mode 100644 library/src/tensor_operation_instance/gpu/gemm/device_gemm_dlops_f32_f32_f32_km_kn_mn_instance.cpp create mode 100644 test/gemm_dlops/CMakeLists.txt create mode 100644 test/gemm_dlops/gemm_dlops_fp32.cpp diff --git a/include/ck/tensor_operation/gpu/block/blockwise_tensor_slice_transfer_v5r1.hpp b/include/ck/tensor_operation/gpu/block/blockwise_tensor_slice_transfer_v5r1.hpp index acd99132cc..2c3b4438c2 100644 --- a/include/ck/tensor_operation/gpu/block/blockwise_tensor_slice_transfer_v5r1.hpp +++ b/include/ck/tensor_operation/gpu/block/blockwise_tensor_slice_transfer_v5r1.hpp @@ -86,6 +86,17 @@ struct BlockwiseTensorSliceTransfer_v5r1 } } + template + __device__ void + RunRead(const SrcDesc& src_desc, const SrcBuffer& src_buf) + { + if(BlockSize == thread_cluster_desc_.GetElementSize() or + get_thread_local_1d_id() < thread_cluster_desc_.GetElementSize()) + { + threadwise_transfer_.RunRead(src_desc, src_buf); + } + } + template __device__ void RunWrite(const DstDesc& dst_desc, DstBuffer& dst_buf) { diff --git a/include/ck/tensor_operation/gpu/device/device_gemm_dlops.hpp b/include/ck/tensor_operation/gpu/device/device_gemm_dlops.hpp new file mode 100644 index 0000000000..4483d53889 --- /dev/null +++ b/include/ck/tensor_operation/gpu/device/device_gemm_dlops.hpp @@ -0,0 +1,579 @@ +#pragma once + +#include +#include + 
+#include "device.hpp" +#include "device_base.hpp" +#include "device_gemm.hpp" +#include "common_header.hpp" +#include "tensor_layout.hpp" +#include "tensor_descriptor.hpp" +#include "tensor_descriptor_helper.hpp" +#include "gemm_specialization.hpp" +#include "element_wise_operation.hpp" +#include "gridwise_gemm_dlops_v1r3.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { + +template < + typename ADataType, + typename BDataType, + typename AccDataType, + typename CDataType, + typename ALayout, + typename BLayout, + typename CLayout, + typename AElementwiseOperation, + typename BElementwiseOperation, + typename CElementwiseOperation, + GemmSpecialization GemmSpec, + index_t BlockSize, + index_t MPerBlock, + index_t NPerBlock, + index_t KPerBlock, + index_t M1PerThread, + index_t N1PerThread, + index_t KPerThread, + typename M1N1ThreadClusterM1Xs, + typename M1N1ThreadClusterN1Xs, + typename ABlockTransferThreadSliceLengths_K0_M0_M1_K1, + typename ABlockTransferThreadClusterLengths_K0_M0_M1_K1, + typename ABlockTransferThreadClusterArrangeOrder, + typename ABlockTransferSrcAccessOrder, + typename ABlockTransferSrcVectorTensorLengths_K0_M0_M1_K1, + typename ABlockTransferSrcVectorTensorContiguousDimOrder, + typename ABlockTransferDstVectorTensorLengths_K0_M0_M1_K1, + typename BBlockTransferThreadSliceLengths_K0_N0_N1_K1, + typename BBlockTransferThreadClusterLengths_K0_N0_N1_K1, + typename BBlockTransferThreadClusterArrangeOrder, + typename BBlockTransferSrcAccessOrder, + typename BBlockTransferSrcVectorTensorLengths_K0_N0_N1_K1, + typename BBlockTransferSrcVectorTensorContiguousDimOrder, + typename BBlockTransferDstVectorTensorLengths_K0_N0_N1_K1, + typename CThreadTransferSrcDstAccessOrder, + index_t CThreadTransferSrcDstVectorDim, + index_t CThreadTransferDstScalarPerVector, + typename AGridStepHacks, + typename BGridStepHacks, + typename CGridStepHacks, + typename AGridMoveSliceWindowStepHacks, + typename BGridMoveSliceWindowStepHacks, + 
enable_if_t< + is_same_v && + is_same_v && + is_same_v, + bool> = false> +struct DeviceGemmDlops : public DeviceGemm +{ + static constexpr auto I0 = Number<0>{}; + static constexpr auto I1 = Number<1>{}; + static constexpr auto I2 = Number<2>{}; + static constexpr auto I3 = Number<3>{}; + static constexpr auto I4 = Number<4>{}; + static constexpr auto I5 = Number<5>{}; + + static constexpr auto K1Number = Number{}; + + static auto MakeAGridDescriptor_K0_M_K1(index_t M, index_t K, index_t StrideA) + { + assert(K % K1 == 0); + + const index_t K0 = K / K1; + + const auto a_grid_desc_m_k = [&]() { + if constexpr(is_same::value) + { + return make_naive_tensor_descriptor(make_tuple(M, K), make_tuple(StrideA, I1)); + } + else if constexpr(is_same::value) + { + return make_naive_tensor_descriptor(make_tuple(M, K), make_tuple(I1, StrideA)); + } + }(); + + if constexpr(GemmSpec == GemmSpecialization::MNPadding) + { + const auto PadM = (MPerBlock - M % MPerBlock) % MPerBlock; + + return transform_tensor_descriptor( + a_grid_desc_m_k, + make_tuple(make_unmerge_transform(make_tuple(K0, K1Number)), + make_right_pad_transform(M, PadM)), + make_tuple(Sequence<1>{}, Sequence<0>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + } + else + { + return transform_tensor_descriptor( + a_grid_desc_m_k, + make_tuple(make_unmerge_transform(make_tuple(K0, K1Number)), + make_pass_through_transform(M)), + make_tuple(Sequence<1>{}, Sequence<0>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + } + } + + static auto MakeBGridDescriptor_K0_N_K1(index_t K, index_t N, index_t StrideB) + { + assert(K % K1 == 0); + + const index_t K0 = K / K1; + + const auto b_grid_desc_k_n = [&]() { + if constexpr(is_same::value) + { + return make_naive_tensor_descriptor(make_tuple(K, N), make_tuple(StrideB, I1)); + } + else if constexpr(is_same::value) + { + return make_naive_tensor_descriptor(make_tuple(K, N), make_tuple(I1, StrideB)); + } + }(); + + if constexpr(GemmSpec == GemmSpecialization::MNPadding) + 
{ + const auto PadN = (NPerBlock - N % NPerBlock) % NPerBlock; + + return transform_tensor_descriptor( + b_grid_desc_k_n, + make_tuple(make_unmerge_transform(make_tuple(K0, K1Number)), + make_right_pad_transform(N, PadN)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + } + else + { + return transform_tensor_descriptor( + b_grid_desc_k_n, + make_tuple(make_unmerge_transform(make_tuple(K0, K1Number)), + make_pass_through_transform(N)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + } + } + + static auto MakeCGridDescriptor_M_N(index_t M, index_t N, index_t StrideC) + { + const auto c_grid_desc_m_n = [&]() { + if constexpr(is_same::value) + { + return make_naive_tensor_descriptor(make_tuple(M, N), make_tuple(StrideC, I1)); + } + else if constexpr(is_same::value) + { + return make_naive_tensor_descriptor(make_tuple(M, N), make_tuple(I1, StrideC)); + } + }(); + + if constexpr(GemmSpec == GemmSpecialization::MNPadding) + { + const auto PadM = (MPerBlock - M % MPerBlock) % MPerBlock; + const auto PadN = (NPerBlock - N % NPerBlock) % NPerBlock; + + return transform_tensor_descriptor( + c_grid_desc_m_n, + make_tuple(make_right_pad_transform(M, PadM), make_right_pad_transform(N, PadN)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + } + else + { + + return transform_tensor_descriptor( + c_grid_desc_m_n, + make_tuple(make_pass_through_transform(M), make_pass_through_transform(N)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + } + } + + using AGridDesc_K0_M_K1 = decltype(MakeAGridDescriptor_K0_M_K1(1, 1, 1)); + using BGridDesc_K0_N_K1 = decltype(MakeBGridDescriptor_K0_N_K1(1, 1, 1)); + using CGridDesc_M_N = decltype(MakeCGridDescriptor_M_N(1, 1, 1)); + + // GridwiseGemm + using GridwiseGemm = + GridwiseGemmDlops_km_kn_mn_v1r3; + + using AK0M0M1K1GridDesc = + 
decltype(GridwiseGemm::MakeAK0M0M1K1GridDescriptor(AGridDesc_K0_M_K1{})); + using BK0N0N1K1GridDesc = decltype(GridwiseGemm::MakeBKN0N1GridDescriptor(BGridDesc_K0_N_K1{})); + using CM0M10M11N0N10N11GridDesc = + decltype(GridwiseGemm::MakeCM0M10M11N0N10N11GridDescriptor(CGridDesc_M_N{})); + + // Argument + struct Argument : public BaseArgument + { + Argument(const ADataType* p_a_grid, + const BDataType* p_b_grid, + CDataType* p_c_grid, + index_t M, + index_t N, + index_t K, + index_t StrideA, + index_t StrideB, + index_t StrideC, + index_t M01, + index_t N01, + AElementwiseOperation a_element_op, + BElementwiseOperation b_element_op, + CElementwiseOperation c_element_op) + : p_a_grid_{p_a_grid}, + p_b_grid_{p_b_grid}, + p_c_grid_{p_c_grid}, + a_grid_desc_k0_m_k1_{}, + b_grid_desc_k0_n_k1_{}, + c_grid_desc_m_n_{}, + c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2_{}, + block_2_ctile_map_{}, + M01_{M01}, + N01_{N01} + // a_element_op_{a_element_op}, + // b_element_op_{b_element_op}, + // c_element_op_{c_element_op} + { + a_grid_desc_k0_m_k1_ = DeviceGemmXdl::MakeAGridDescriptor_K0_M_K1(M, K, StrideA); + b_grid_desc_k0_n_k1_ = DeviceGemmXdl::MakeBGridDescriptor_K0_N_K1(K, N, StrideB); + c_grid_desc_m_n_ = DeviceGemmXdl::MakeCGridDescriptor_M_N(M, N, StrideC); + + if(GridwiseGemm::CheckValidity( + a_grid_desc_k0_m_k1_, b_grid_desc_k0_n_k1_, c_grid_desc_m_n_, M01_, N01_)) + { + c_m0_m10_m11_n0_n10_n11_grid_desc = + GridwiseGemm::MakeCM0M10M11N0N10N11GridDescriptor(c_m_n_grid_desc); + + block_2_ctile_map_ = + GridwiseGemm::MakeCBlockIdToM0N0BlockClusterAdaptor(c_grid_desc_m_n_); + } + } + + // private: + const ADataType* p_a_grid_; + const BDataType* p_b_grid_; + CDataType* p_c_grid_; + + AK0M0M1K1GridDesc a_k0_m0_m1_k1_grid_desc; + BK0N0N1K1GridDesc b_k0_n0_n1_k1_grid_desc; + CM0M10M11N0N10N11GridDesc c_m0_m10_m11_n0_n10_n11_grid_desc; + + typename GridwiseGemm::DefaultBlock2CTileMap block_2_ctile_map_; + + index_t M01_; + index_t N01_; + + // AElementwiseOperation a_element_op_; + 
// BElementwiseOperation b_element_op_; + // CElementwiseOperation c_element_op_; + }; + + // Invoker + struct Invoker : public BaseInvoker + { + using Argument = DeviceGemmXdl::Argument; + + float Run(const Argument& arg, int nrepeat = 1) + { + { + std::cout << "arg.a_grid_desc_k0_m_k1_{" << arg.a_grid_desc_k0_m_k1_.GetLength(I0) + << ", " << arg.a_grid_desc_k0_m_k1_.GetLength(I1) << ", " + << arg.a_grid_desc_k0_m_k1_.GetLength(I2) << "}" << std::endl; + + std::cout << "arg.b_grid_desc_k0_n_k1_{" << arg.b_grid_desc_k0_n_k1_.GetLength(I0) + << ", " << arg.b_grid_desc_k0_n_k1_.GetLength(I1) << ", " + << arg.b_grid_desc_k0_n_k1_.GetLength(I2) << "}" << std::endl; + + std::cout << "arg.c_grid_desc_m_n_{ " << arg.c_grid_desc_m_n_.GetLength(I0) << ", " + << arg.c_grid_desc_m_n_.GetLength(I1) << "}" << std::endl; + } + + if(!GridwiseGemm::CheckValidity(arg.a_grid_desc_k0_m_k1_, + arg.b_grid_desc_k0_n_k1_, + arg.c_grid_desc_m_n_, + arg.M01_, + arg.N01_)) + { + throw std::runtime_error( + "wrong! 
GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3 has invalid setting"); + } + + const index_t grid_size = GridwiseGemm::CalculateGridSize(arg.c_grid_desc_m_n_); + + const auto K0 = arg.a_grid_desc_k0_m_k1_.GetLength(I0); + const bool has_main_k_block_loop = GridwiseGemm::CalculateHasMainKBlockLoop(K0); + const bool has_double_tail_k_block_loop = + GridwiseGemm::CalculateHasDoubleTailKBlockLoop(K0); + + float ave_time = 0; + + if(has_main_k_block_loop && has_double_tail_k_block_loop) + { + const auto kernel = + kernel_gemm_dlops_v1r3, + remove_reference_t, + remove_reference_t, + remove_reference_t, + true, + true>; + + ave_time = launch_and_time_kernel(kernel, + nrepeat, + dim3(grid_size), + dim3(BlockSize), + 0, + arg.p_a_grid_, + arg.p_b_grid_, + arg.p_c_grid_, + arg.a_grid_desc_k0_m0_m1_k1_, + arg.b_grid_desc_k0_n0_n1_k1_, + arg.c_grid_desc_m0_m10_m11_n0_n10_n11_, + arg.cblockid_to_m0_n0_block_cluster_adaptor_); + } + else if(has_main_k_block_loop && !has_double_tail_k_block_loop) + { + const auto kernel = + kernel_gemm_dlops_v1r3, + remove_reference_t, + remove_reference_t, + remove_reference_t, + true, + false>; + + ave_time = launch_and_time_kernel(kernel, + nrepeat, + dim3(grid_size), + dim3(BlockSize), + 0, + arg.p_a_grid_, + arg.p_b_grid_, + arg.p_c_grid_, + arg.a_grid_desc_k0_m0_m1_k1_, + arg.b_grid_desc_k0_n0_n1_k1_, + arg.c_grid_desc_m0_m10_m11_n0_n10_n11_, + arg.cblockid_to_m0_n0_block_cluster_adaptor_); + } + else if(!has_main_k_block_loop && has_double_tail_k_block_loop) + { + const auto kernel = + kernel_gemm_dlops_v1r3, + remove_reference_t, + remove_reference_t, + remove_reference_t, + false, + true>; + + ave_time = launch_and_time_kernel(kernel, + nrepeat, + dim3(grid_size), + dim3(BlockSize), + 0, + arg.p_a_grid_, + arg.p_b_grid_, + arg.p_c_grid_, + arg.a_grid_desc_k0_m0_m1_k1_, + arg.b_grid_desc_k0_n0_n1_k1_, + arg.c_grid_desc_m0_m10_m11_n0_n10_n11_, + arg.cblockid_to_m0_n0_block_cluster_adaptor_); + } + else + { + const auto kernel = + 
kernel_gemm_dlops_v1r3, + remove_reference_t, + remove_reference_t, + remove_reference_t, + false, + false>; + + ave_time = launch_and_time_kernel(kernel, + nrepeat, + dim3(grid_size), + dim3(BlockSize), + 0, + arg.p_a_grid_, + arg.p_b_grid_, + arg.p_c_grid_, + arg.a_grid_desc_k0_m0_m1_k1_, + arg.b_grid_desc_k0_n0_n1_k1_, + arg.c_grid_desc_m0_m10_m11_n0_n10_n11_, + arg.cblockid_to_m0_n0_block_cluster_adaptor_); + } + + return ave_time; + } + + // polymorphic + float Run(const BaseArgument* p_arg, int nrepeat = 1) override + { + return Run(*dynamic_cast(p_arg), nrepeat); + } + }; + + static constexpr bool IsValidCompilationParameter() + { + // TODO: properly implement this check + return true; + } + + static bool IsSupportedArgument(const Argument& arg) + { + return GridwiseGemm::CheckValidity(arg.a_grid_desc_k0_m_k1_, + arg.b_grid_desc_k0_n_k1_, + arg.c_grid_desc_m_n_, + arg.M01_, + arg.N01_); + } + + // polymorphic + bool IsSupportedArgument(const BaseArgument* p_arg) override + { + return IsSupportedArgument(*dynamic_cast(p_arg)); + } + + static auto MakeArgument(const ADataType* p_a, + const BDataType* p_b, + CDataType* p_c, + index_t M, + index_t N, + index_t K, + index_t StrideA, + index_t StrideB, + index_t StrideC, + AElementwiseOperation a_element_op, + BElementwiseOperation b_element_op, + CElementwiseOperation c_element_op) + { + return Argument{p_a, + p_b, + p_c, + M, + N, + K, + StrideA, + StrideB, + StrideC, + 1, + 1, + a_element_op, + b_element_op, + c_element_op}; + } + + static auto MakeInvoker() { return Invoker{}; } + + // polymorphic + std::unique_ptr MakeArgumentPointer(const void* p_a, + const void* p_b, + void* p_c, + index_t M, + index_t N, + index_t K, + index_t StrideA, + index_t StrideB, + index_t StrideC, + AElementwiseOperation a_element_op, + BElementwiseOperation b_element_op, + CElementwiseOperation c_element_op, + index_t /* KBatch */ = 1) override + { + return std::make_unique(static_cast(p_a), + static_cast(p_b), + 
static_cast(p_c), + M, + N, + K, + StrideA, + StrideB, + StrideC, + 1, + 1, + a_element_op, + b_element_op, + c_element_op); + } + + // polymorphic + std::unique_ptr MakeInvokerPointer() override + { + return std::make_unique(Invoker{}); + } + + // polymorphic + std::string GetTypeString() const override + { + auto str = std::stringstream(); + + // clang-format off + str << "DeviceGemmXdl" + << "<" + << BlockSize << ", " + << MPerBlock << ", " + << NPerBlock << ", " + << K0PerBlock << ", " + << K1 << ", " + << MPerXDL << ", " + << NPerXDL << ", " + << MXdlPerWave << ", " + << NXdlPerWave + << ">"; + // clang-format on + + return str.str(); + } +}; + +} // namespace device +} // namespace tensor_operation +} // namespace ck + diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_dlops_v1r3.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_dlops_v1r3.hpp index 1a66c8ff3f..932674be87 100644 --- a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_dlops_v1r3.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_dlops_v1r3.hpp @@ -83,12 +83,7 @@ template + index_t CThreadTransferDstScalarPerVector> struct GridwiseGemmDlops_km_kn_mn_v1r3 { static constexpr auto I0 = Number<0>{}; @@ -437,8 +432,8 @@ struct GridwiseGemmDlops_km_kn_mn_v1r3 // LDS double buffer: preload data into LDS { - a_blockwise_copy.RunRead(a_k0_m0_m1_k1_grid_desc, a_global_buf, AGridStepHacks{}); - b_blockwise_copy.RunRead(b_k0_n0_n1_k1_grid_desc, b_global_buf, BGridStepHacks{}); + a_blockwise_copy.RunRead(a_k0_m0_m1_k1_grid_desc, a_global_buf); + b_blockwise_copy.RunRead(b_k0_n0_n1_k1_grid_desc, b_global_buf); a_blockwise_copy.RunWrite(a_k0_m0_m1_k1_block_desc, a_block_even_buf); b_blockwise_copy.RunWrite(b_k0_n0_n1_k1_block_desc, b_block_even_buf); @@ -456,17 +451,15 @@ struct GridwiseGemmDlops_km_kn_mn_v1r3 { // even iteration a_blockwise_copy.MoveSrcSliceWindow(a_k0_m0_m1_k1_grid_desc, - a_block_slice_copy_step, - AGridMoveSliceWindowStepHacks{}); + a_block_slice_copy_step); 
b_blockwise_copy.MoveSrcSliceWindow(b_k0_n0_n1_k1_grid_desc, - b_block_slice_copy_step, - BGridMoveSliceWindowStepHacks{}); + b_block_slice_copy_step); __syncthreads(); // LDS doubel buffer: load next data from device mem - a_blockwise_copy.RunRead(a_k0_m0_m1_k1_grid_desc, a_global_buf, AGridStepHacks{}); - b_blockwise_copy.RunRead(b_k0_n0_n1_k1_grid_desc, b_global_buf, BGridStepHacks{}); + a_blockwise_copy.RunRead(a_k0_m0_m1_k1_grid_desc, a_global_buf); + b_blockwise_copy.RunRead(b_k0_n0_n1_k1_grid_desc, b_global_buf); // LDS double buffer: GEMM on current data blockwise_gemm.Run(c_m10_m11_n10_n11_thread_desc, @@ -480,17 +473,15 @@ struct GridwiseGemmDlops_km_kn_mn_v1r3 // odd iteration a_blockwise_copy.MoveSrcSliceWindow(a_k0_m0_m1_k1_grid_desc, - a_block_slice_copy_step, - AGridMoveSliceWindowStepHacks{}); + a_block_slice_copy_step); b_blockwise_copy.MoveSrcSliceWindow(b_k0_n0_n1_k1_grid_desc, - b_block_slice_copy_step, - BGridMoveSliceWindowStepHacks{}); + b_block_slice_copy_step); __syncthreads(); // LDS doubel buffer: load next data from device mem - a_blockwise_copy.RunRead(a_k0_m0_m1_k1_grid_desc, a_global_buf, AGridStepHacks{}); - b_blockwise_copy.RunRead(b_k0_n0_n1_k1_grid_desc, b_global_buf, BGridStepHacks{}); + a_blockwise_copy.RunRead(a_k0_m0_m1_k1_grid_desc, a_global_buf); + b_blockwise_copy.RunRead(b_k0_n0_n1_k1_grid_desc, b_global_buf); // LDS double buffer: GEMM on current data blockwise_gemm.Run( @@ -508,15 +499,15 @@ struct GridwiseGemmDlops_km_kn_mn_v1r3 if constexpr(HasDoubleTailKBlockLoop) // if has 2 iteration left { a_blockwise_copy.MoveSrcSliceWindow( - a_k0_m0_m1_k1_grid_desc, a_block_slice_copy_step, AGridMoveSliceWindowStepHacks{}); + a_k0_m0_m1_k1_grid_desc, a_block_slice_copy_step); b_blockwise_copy.MoveSrcSliceWindow( - b_k0_n0_n1_k1_grid_desc, b_block_slice_copy_step, BGridMoveSliceWindowStepHacks{}); + b_k0_n0_n1_k1_grid_desc, b_block_slice_copy_step); __syncthreads(); // LDS double buffer: load last data from device mem - 
a_blockwise_copy.RunRead(a_k0_m0_m1_k1_grid_desc, a_global_buf, AGridStepHacks{}); - b_blockwise_copy.RunRead(b_k0_n0_n1_k1_grid_desc, b_global_buf, BGridStepHacks{}); + a_blockwise_copy.RunRead(a_k0_m0_m1_k1_grid_desc, a_global_buf); + b_blockwise_copy.RunRead(b_k0_n0_n1_k1_grid_desc, b_global_buf); // LDS double buffer: GEMM on 2nd-last data blockwise_gemm.Run( @@ -583,8 +574,7 @@ struct GridwiseGemmDlops_km_kn_mn_v1r3 make_tuple(I0, I0, I0, I0, I0, I0), c_thread_buf, c_m0_m10_m11_n0_n10_n11_grid_desc, - c_grid_buf, - CGridStepHacks{}); + c_grid_buf); } } }; diff --git a/library/src/tensor_operation_instance/gpu/gemm/CMakeLists.txt b/library/src/tensor_operation_instance/gpu/gemm/CMakeLists.txt index 5f057adcc5..6bf5d3064d 100644 --- a/library/src/tensor_operation_instance/gpu/gemm/CMakeLists.txt +++ b/library/src/tensor_operation_instance/gpu/gemm/CMakeLists.txt @@ -33,6 +33,9 @@ set(DEVICE_GEMM_INSTANCE_SOURCE device_gemm_xdl_splitk_f16_f16_f16_mk_nk_mn_instance.cpp; device_gemm_xdl_splitk_f16_f16_f16_km_kn_mn_instance.cpp; device_gemm_xdl_splitk_f16_f16_f16_km_nk_mn_instance.cpp; + + device_gemm_dlops_f32_f32_f32_km_kn_mn_instance.cpp; + ) add_library(device_gemm_instance SHARED ${DEVICE_GEMM_INSTANCE_SOURCE}) diff --git a/library/src/tensor_operation_instance/gpu/gemm/device_gemm_dlops_f32_f32_f32_km_kn_mn_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm/device_gemm_dlops_f32_f32_f32_km_kn_mn_instance.cpp new file mode 100644 index 0000000000..33c69f78c6 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/gemm/device_gemm_dlops_f32_f32_f32_km_kn_mn_instance.cpp @@ -0,0 +1,46 @@ +#include +#include "config.hpp" +#include "device_gemm_dlops.hpp" +#include "element_wise_operation.hpp" +#include "device_operation_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace device_gemm_instance { + +using F16 = ck::half_t; +using F32 = float; + +using Row = ck::tensor_layout::gemm::RowMajor; +using Col = 
ck::tensor_layout::gemm::ColumnMajor; + +template +using S = ck::Sequence; + +using PassThrough = ck::tensor_operation::element_wise::PassThrough; + +static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default; + +// Compilation parameters for a[k, m] * b[k, n] = c[m, n] +using device_gemm_dlops_f32_f32_f32_km_kn_mn_instances = + std::tuple< + // clang-format off + // ##########| AData| BData| CData| AccData| ALayout| BLayout| CLayout| A| B| C| GEMM| Block| MPer| NPer| K0Per| K1| M1Per| N1Per| KPer| M11N11Thread| M11N11Thread| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| + // ##########| Type| Type| Type| Type| | | | Elementwise| Elementwise| Elementwise|Spacialization| Size| Block| Block| Block| | ThreadM111| ThreadN111| Thread| ClusterM110Xs| ClusterN110Xs| ThreadSliceLengths| ThreadClusterLengths| ThreadCluster| SrcAccess| + // ##########| | | | | | | | Operation| Operation| Operation| | | | | | | | | | | | K0_M0_M1_K1| K0_M0_M1_K1| ArrangeOrder| Order| + // ##########| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceGemmDlops< F32, F32, F32, F32, Col, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 128, 128, 8, 2, 4, 4, 1, S<8, 2>, S<8, 2>, S<4, 1, 1, 2>, S<2, 1, 128, 1>, S<1, 2, 0, 3>, S<1, 2, 0 ,3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<4, 1, 1, 2>, S<2, 1, 128, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0 3>, S<1, 1, 1, 2>, S<0, 1, 2, 3, 4, 5>, 5, 4> + // clang-format on + >; + +void add_device_gemm_dlops_f32_f32_f32_km_kn_mn_instances( + std::vector>& instances) +{ + add_device_operation_instances(instances, device_gemm_dlops_f32_f32_f32_km_kn_mn_instances{}); +} + +} // namespace device_gemm_instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/test/gemm_dlops/CMakeLists.txt b/test/gemm_dlops/CMakeLists.txt new file mode 100644 index 0000000000..96c3e9e5bf --- /dev/null +++ 
b/test/gemm_dlops/CMakeLists.txt @@ -0,0 +1,15 @@ +add_test_executable(test_gemm_dlops_fp32 gemm_fp32.cpp) +target_link_libraries(test_gemm_dlops_fp32 PRIVATE host_tensor) +target_link_libraries(test_gemm_dlops_fp32 PRIVATE device_gemm_dlops_instance) + +# add_test_executable(test_gemm_dlops_fp16 gemm_fp16.cpp) +# target_link_libraries(test_gemm_dlops_fp16 PRIVATE host_tensor) +# target_link_libraries(test_gemm_dlops_fp16 PRIVATE device_gemm_dlops_instance) +# +# add_test_executable(test_gemm_dlops_bf16 gemm_bf16.cpp) +# target_link_libraries(test_gemm_dlops_bf16 PRIVATE host_tensor) +# target_link_libraries(test_gemm_dlops_bf16 PRIVATE device_gemm_dlops_instance) +# +# add_test_executable(test_gemm_dlops_int8 gemm_int8.cpp) +# target_link_libraries(test_gemm_dlops_int8 PRIVATE host_tensor) +# target_link_libraries(test_gemm_dlops_int8 PRIVATE device_gemm_dlops_instance) diff --git a/test/gemm_dlops/gemm_dlops_fp32.cpp b/test/gemm_dlops/gemm_dlops_fp32.cpp new file mode 100644 index 0000000000..1e8d721ab8 --- /dev/null +++ b/test/gemm_dlops/gemm_dlops_fp32.cpp @@ -0,0 +1,74 @@ +#include +#include +#include +#include +#include +#include +#include + +#include "gemm_util.hpp" +#include "config.hpp" +#include "print.hpp" +#include "device.hpp" +#include "host_tensor.hpp" +#include "host_tensor_generator.hpp" +#include "host_gemm.hpp" +#include "device_tensor.hpp" +#include "device_gemm_xdl.hpp" +#include "device_gemm_dlops_c_shuffle.hpp" +#include "element_wise_operation.hpp" +#include "reference_gemm.hpp" +#include "gemm_specialization.hpp" + +using PassThrough = ck::tensor_operation::element_wise::PassThrough; + +using DeviceGemmNoOpPtr = + ck::tensor_operation::device::DeviceGemmPtr; + +namespace ck { +namespace tensor_operation { +namespace device { +namespace device_gemm_instance { +void add_device_gemm_dlops_f32_f32_f32_km_kn_mn_instances(std::vector&); +// void add_device_gemm_dlops_f32_f32_f32_km_nk_mn_instances(std::vector&); +// void 
add_device_gemm_dlops_f32_f32_f32_mk_nk_mn_instances(std::vector&); +// void add_device_gemm_dlops_f32_f32_f32_mk_kn_mn_instances(std::vector&); + +} // namespace device_gemm_instance +} // namespace device +} // namespace tensor_operation +} // namespace ck + +int main() +{ + using ADataType = float; + using BDataType = float; + using CDataType = float; + + using RowMajor = ck::tensor_layout::gemm::RowMajor; + using ColumnMajor = ck::tensor_layout::gemm::ColumnMajor; + + bool res = true; + std::vector gemmPtrs; + ck::tensor_operation::device::device_gemm_instance:: + add_device_gemm_dlops_f32_f32_f32_km_kn_mn_instances(gemmPtrs); + + for(auto& gemmPtr : gemmPtrs) + { + res &= ck::gemm_util::TestGemm{}(gemmPtr); + } + + std::cout << "TestGemm ..... " << (res ? "SUCCESS" : "FAILURE") << std::endl; + return res ? 0 : 1; +} From 4f5817da278bac9afa990bfa1d0942ccdcef283c Mon Sep 17 00:00:00 2001 From: j4yan Date: Wed, 13 Apr 2022 23:26:25 -0500 Subject: [PATCH 02/46] navi_gemm_km_kn_mn_fp32 compiles and passes one test. 
--- .../gpu/block/blockwise_gemm_dlops_v2r3.hpp | 2 +- .../gpu/device/device_gemm_dlops.hpp | 155 +++++++++--------- .../gpu/grid/gridwise_gemm_dlops_v1r3.hpp | 17 +- .../threadwise_tensor_slice_transfer_v5r1.hpp | 4 +- .../gpu/gemm/CMakeLists.txt | 9 + ...mm_dlops_f32_f32_f32_km_kn_mn_instance.cpp | 48 +++++- test/CMakeLists.txt | 1 + test/gemm_dlops/CMakeLists.txt | 2 +- test/gemm_dlops/gemm_dlops_fp32.cpp | 3 +- 9 files changed, 141 insertions(+), 100 deletions(-) diff --git a/include/ck/tensor_operation/gpu/block/blockwise_gemm_dlops_v2r3.hpp b/include/ck/tensor_operation/gpu/block/blockwise_gemm_dlops_v2r3.hpp index 0a7b8486f4..fa52d1749e 100644 --- a/include/ck/tensor_operation/gpu/block/blockwise_gemm_dlops_v2r3.hpp +++ b/include/ck/tensor_operation/gpu/block/blockwise_gemm_dlops_v2r3.hpp @@ -3,7 +3,7 @@ #include "common_header.hpp" #include "tensor_adaptor.hpp" -#include "threadwise_tensor_slice_transfer_v2.hpp" +#include "threadwise_tensor_slice_transfer_v4r1.hpp" #include "threadwise_contraction_dlops.hpp" namespace ck { diff --git a/include/ck/tensor_operation/gpu/device/device_gemm_dlops.hpp b/include/ck/tensor_operation/gpu/device/device_gemm_dlops.hpp index 4483d53889..0bb3daf5bb 100644 --- a/include/ck/tensor_operation/gpu/device/device_gemm_dlops.hpp +++ b/include/ck/tensor_operation/gpu/device/device_gemm_dlops.hpp @@ -33,7 +33,8 @@ template < index_t BlockSize, index_t MPerBlock, index_t NPerBlock, - index_t KPerBlock, + index_t K0PerBlock, + index_t K1, index_t M1PerThread, index_t N1PerThread, index_t KPerThread, @@ -56,17 +57,13 @@ template < typename CThreadTransferSrcDstAccessOrder, index_t CThreadTransferSrcDstVectorDim, index_t CThreadTransferDstScalarPerVector, - typename AGridStepHacks, - typename BGridStepHacks, - typename CGridStepHacks, - typename AGridMoveSliceWindowStepHacks, - typename BGridMoveSliceWindowStepHacks, enable_if_t< is_same_v && is_same_v && is_same_v, bool> = false> -struct DeviceGemmDlops : public DeviceGemm +struct 
DeviceGemmDlops + : public DeviceGemm { static constexpr auto I0 = Number<0>{}; static constexpr auto I1 = Number<1>{}; @@ -201,12 +198,12 @@ struct DeviceGemmDlops : public DeviceGemm; - - using AK0M0M1K1GridDesc = + CThreadTransferDstScalarPerVector>; + + using AGridDesc_K0_M0_M1_K1 = decltype(GridwiseGemm::MakeAK0M0M1K1GridDescriptor(AGridDesc_K0_M_K1{})); - using BK0N0N1K1GridDesc = decltype(GridwiseGemm::MakeBKN0N1GridDescriptor(BGridDesc_K0_N_K1{})); - using CM0M10M11N0N10N11GridDesc = + using BGridDesc_K0_N0_N1_K1 = + decltype(GridwiseGemm::MakeBK0N0N1K1GridDescriptor(BGridDesc_K0_N_K1{})); + using CGridDesc_M0_M10_M11_N0_N10_N11 = decltype(GridwiseGemm::MakeCM0M10M11N0N10N11GridDescriptor(CGridDesc_M_N{})); + using DefaultBlock2CTileMap = + decltype(GridwiseGemm::MakeCBlockIdToM0N0BlockClusterAdaptor(CGridDesc_M_N{})); // Argument struct Argument : public BaseArgument @@ -261,10 +256,9 @@ struct DeviceGemmDlops : public DeviceGemm, - remove_reference_t, - remove_reference_t, - remove_reference_t, + remove_reference_t, + remove_reference_t, + remove_reference_t, + remove_reference_t, true, true>; @@ -369,7 +370,7 @@ struct DeviceGemmDlops : public DeviceGemm, - remove_reference_t, - remove_reference_t, - remove_reference_t, + remove_reference_t, + remove_reference_t, + remove_reference_t, + remove_reference_t, true, false>; @@ -395,7 +396,7 @@ struct DeviceGemmDlops : public DeviceGemm, - remove_reference_t, - remove_reference_t, - remove_reference_t, + remove_reference_t, + remove_reference_t, + remove_reference_t, + remove_reference_t, false, true>; @@ -421,7 +422,7 @@ struct DeviceGemmDlops : public DeviceGemm, - remove_reference_t, - remove_reference_t, - remove_reference_t, + remove_reference_t, + remove_reference_t, + remove_reference_t, + remove_reference_t, false, false>; @@ -447,7 +448,7 @@ struct DeviceGemmDlops : public DeviceGemm"; // clang-format on diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_dlops_v1r3.hpp 
b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_dlops_v1r3.hpp index 932674be87..5167018d74 100644 --- a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_dlops_v1r3.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_dlops_v1r3.hpp @@ -7,8 +7,9 @@ #include "tensor_descriptor_helper.hpp" #include "blockwise_gemm_dlops_v2r3.hpp" #include "blockwise_tensor_slice_transfer_v5r1.hpp" -#include "threadwise_tensor_slice_transfer_v2.hpp" +#include "threadwise_tensor_slice_transfer.hpp" #include "threadwise_tensor_slice_set.hpp" +#include "element_wise_operation.hpp" namespace ck { @@ -327,7 +328,7 @@ struct GridwiseGemmDlops_km_kn_mn_v1r3 ABlockTransferThreadClusterArrangeOrder, FloatAB, FloatAB, - decltype(a_k0_m0_m1_k1_grid_desc), + remove_reference_t, decltype(a_k0_m0_m1_k1_block_desc), ABlockTransferSrcAccessOrder, Sequence<0, 1, 2, 3>, @@ -351,7 +352,7 @@ struct GridwiseGemmDlops_km_kn_mn_v1r3 BBlockTransferThreadClusterArrangeOrder, FloatAB, FloatAB, - decltype(b_k0_n0_n1_k1_grid_desc), + remove_reference_t, decltype(b_k0_n0_n1_k1_block_desc), BBlockTransferSrcAccessOrder, Sequence<0, 1, 2, 3>, @@ -498,10 +499,8 @@ struct GridwiseGemmDlops_km_kn_mn_v1r3 // LDS double buffer: tail if constexpr(HasDoubleTailKBlockLoop) // if has 2 iteration left { - a_blockwise_copy.MoveSrcSliceWindow( - a_k0_m0_m1_k1_grid_desc, a_block_slice_copy_step); - b_blockwise_copy.MoveSrcSliceWindow( - b_k0_n0_n1_k1_grid_desc, b_block_slice_copy_step); + a_blockwise_copy.MoveSrcSliceWindow(a_k0_m0_m1_k1_grid_desc, a_block_slice_copy_step); + b_blockwise_copy.MoveSrcSliceWindow(b_k0_n0_n1_k1_grid_desc, b_block_slice_copy_step); __syncthreads(); @@ -552,6 +551,7 @@ struct GridwiseGemmDlops_km_kn_mn_v1r3 FloatC, decltype(c_m0_m10_m11_n0_n10_n11_thread_desc), decltype(c_m0_m10_m11_n0_n10_n11_grid_desc), + ck::tensor_operation::element_wise::PassThrough, Sequence<1, c_m10_m11_n10_n11_thread_tensor_lengths[I0], c_m10_m11_n10_n11_thread_tensor_lengths[I1], @@ -569,7 +569,8 @@ struct 
GridwiseGemmDlops_km_kn_mn_v1r3 c_m10_m11_n10_n11_thread_origin_idx_on_block[I1], in0, c_m10_m11_n10_n11_thread_origin_idx_on_block[I2], - c_m10_m11_n10_n11_thread_origin_idx_on_block[I3])} + c_m10_m11_n10_n11_thread_origin_idx_on_block[I3]), + ck::tensor_operation::element_wise::PassThrough{}} .Run(c_m0_m10_m11_n0_n10_n11_thread_desc, make_tuple(I0, I0, I0, I0, I0, I0), c_thread_buf, diff --git a/include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer_v5r1.hpp b/include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer_v5r1.hpp index 48338ddfa6..f0e9c7e761 100644 --- a/include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer_v5r1.hpp +++ b/include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer_v5r1.hpp @@ -1,5 +1,4 @@ -#ifndef CK_THREADWISE_TENSOR_SLICE_TRANSFER_V5R1_HPP -#define CK_THREADWISE_TENSOR_SLICE_TRANSFER_V5R1_HPP +#pragma once #include "common_header.hpp" #include "tensor_descriptor.hpp" @@ -609,4 +608,3 @@ struct ThreadwiseTensorSliceTransfer_v5r1 }; } // namespace ck -#endif diff --git a/library/src/tensor_operation_instance/gpu/gemm/CMakeLists.txt b/library/src/tensor_operation_instance/gpu/gemm/CMakeLists.txt index 6bf5d3064d..d17f01da42 100644 --- a/library/src/tensor_operation_instance/gpu/gemm/CMakeLists.txt +++ b/library/src/tensor_operation_instance/gpu/gemm/CMakeLists.txt @@ -45,3 +45,12 @@ set_target_properties(device_gemm_instance PROPERTIES POSITION_INDEPENDENT_CODE install(TARGETS device_gemm_instance LIBRARY DESTINATION lib) clang_tidy_check(device_gemm_instance) + + +add_library(device_gemm_dlops_instance SHARED device_gemm_dlops_f32_f32_f32_km_kn_mn_instance.cpp) + +target_compile_features(device_gemm_dlops_instance PUBLIC) +set_target_properties(device_gemm_dlops_instance PROPERTIES POSITION_INDEPENDENT_CODE ON) +install(TARGETS device_gemm_dlops_instance LIBRARY DESTINATION lib) + +clang_tidy_check(device_gemm_dlops_instance) diff --git 
a/library/src/tensor_operation_instance/gpu/gemm/device_gemm_dlops_f32_f32_f32_km_kn_mn_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm/device_gemm_dlops_f32_f32_f32_km_kn_mn_instance.cpp index 33c69f78c6..c7cff6e518 100644 --- a/library/src/tensor_operation_instance/gpu/gemm/device_gemm_dlops_f32_f32_f32_km_kn_mn_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/gemm/device_gemm_dlops_f32_f32_f32_km_kn_mn_instance.cpp @@ -23,16 +23,52 @@ using PassThrough = ck::tensor_operation::element_wise::PassThrough; static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default; // Compilation parameters for a[k, m] * b[k, n] = c[m, n] -using device_gemm_dlops_f32_f32_f32_km_kn_mn_instances = - std::tuple< - // clang-format off +using device_gemm_dlops_f32_f32_f32_km_kn_mn_instances = std::tuple< + // clang-format off // ##########| AData| BData| CData| AccData| ALayout| BLayout| CLayout| A| B| C| GEMM| Block| MPer| NPer| K0Per| K1| M1Per| N1Per| KPer| M11N11Thread| M11N11Thread| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| // ##########| Type| Type| Type| Type| | | | Elementwise| Elementwise| Elementwise|Spacialization| Size| Block| Block| Block| | ThreadM111| ThreadN111| Thread| ClusterM110Xs| ClusterN110Xs| ThreadSliceLengths| ThreadClusterLengths| ThreadCluster| SrcAccess| // ##########| | | | | | | | Operation| Operation| Operation| | | | | | | | | | | | K0_M0_M1_K1| K0_M0_M1_K1| ArrangeOrder| Order| // ##########| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | - DeviceGemmDlops< F32, F32, F32, F32, Col, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 128, 128, 8, 2, 4, 4, 1, S<8, 2>, S<8, 2>, S<4, 1, 1, 2>, S<2, 1, 128, 1>, S<1, 2, 0, 3>, S<1, 2, 0 ,3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<4, 1, 1, 2>, S<2, 1, 128, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0 3>, S<1, 1, 1, 2>, S<0, 1, 2, 3, 4, 5>, 5, 4> - // clang-format on 
- >; + // DeviceGemmDlops< F32, F32, F32, F32, Col, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 128, 128, 8, 2, 4, 4, 1, S<8, 2>, S<8, 2>, S<4, 1, 1, 2>, S<2, 1, 128, 1>, S<1, 2, 0, 3>, S<1, 2, 0 ,3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<4, 1, 1, 2>, S<2, 1, 128, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0 3>, S<1, 1, 1, 2>, S<0, 1, 2, 3, 4, 5>, 5, 4> + // clang-format on + DeviceGemmDlops, + S<8, 2>, + S<4, 1, 1, 2>, + S<2, 1, 128, 1>, + S<1, 2, 0, 3>, + S<1, 2, 0, 3>, + S<4, 1, 1, 2>, + S<1, 2, 0, 3>, + S<1, 1, 1, 2>, + S<4, 1, 1, 2>, + S<2, 1, 128, 1>, + S<1, 2, 0, 3>, + S<1, 2, 0, 3>, + S<4, 1, 1, 2>, + S<1, 2, 0, 3>, + S<1, 1, 1, 2>, + S<0, 1, 2, 3, 4, 5>, + 5, + 4>>; void add_device_gemm_dlops_f32_f32_f32_km_kn_mn_instances( std::vector>& instances) diff --git a/test/CMakeLists.txt b/test/CMakeLists.txt index ae9949b8ce..e4f75df092 100644 --- a/test/CMakeLists.txt +++ b/test/CMakeLists.txt @@ -37,6 +37,7 @@ add_subdirectory(space_filling_curve) add_subdirectory(conv_util) add_subdirectory(reference_conv_fwd) add_subdirectory(gemm) +add_subdirectory(gemm_dlops) add_subdirectory(gemm_split_k) add_subdirectory(gemm_reduce) add_subdirectory(batched_gemm) diff --git a/test/gemm_dlops/CMakeLists.txt b/test/gemm_dlops/CMakeLists.txt index 96c3e9e5bf..4d137ff5d5 100644 --- a/test/gemm_dlops/CMakeLists.txt +++ b/test/gemm_dlops/CMakeLists.txt @@ -1,4 +1,4 @@ -add_test_executable(test_gemm_dlops_fp32 gemm_fp32.cpp) +add_test_executable(test_gemm_dlops_fp32 gemm_dlops_fp32.cpp) target_link_libraries(test_gemm_dlops_fp32 PRIVATE host_tensor) target_link_libraries(test_gemm_dlops_fp32 PRIVATE device_gemm_dlops_instance) diff --git a/test/gemm_dlops/gemm_dlops_fp32.cpp b/test/gemm_dlops/gemm_dlops_fp32.cpp index 1e8d721ab8..f8bc1086db 100644 --- a/test/gemm_dlops/gemm_dlops_fp32.cpp +++ b/test/gemm_dlops/gemm_dlops_fp32.cpp @@ -6,7 +6,7 @@ #include #include -#include "gemm_util.hpp" +#include "../gemm/gemm_util.hpp" #include 
"config.hpp" #include "print.hpp" #include "device.hpp" @@ -15,7 +15,6 @@ #include "host_gemm.hpp" #include "device_tensor.hpp" #include "device_gemm_xdl.hpp" -#include "device_gemm_dlops_c_shuffle.hpp" #include "element_wise_operation.hpp" #include "reference_gemm.hpp" #include "gemm_specialization.hpp" From 0d46b40c30668d4c5d921a26b379d227b3cedd3f Mon Sep 17 00:00:00 2001 From: j4yan Date: Thu, 14 Apr 2022 11:43:09 -0500 Subject: [PATCH 03/46] rename variables and functions in gridwise_gemm_dlops_v1r3 --- .../gpu/device/device_gemm_dlops.hpp | 44 +-- .../gpu/grid/gridwise_gemm_dlops_v1r3.hpp | 299 +++++++++--------- 2 files changed, 172 insertions(+), 171 deletions(-) diff --git a/include/ck/tensor_operation/gpu/device/device_gemm_dlops.hpp b/include/ck/tensor_operation/gpu/device/device_gemm_dlops.hpp index 0bb3daf5bb..517ac2d26e 100644 --- a/include/ck/tensor_operation/gpu/device/device_gemm_dlops.hpp +++ b/include/ck/tensor_operation/gpu/device/device_gemm_dlops.hpp @@ -228,13 +228,13 @@ struct DeviceGemmDlops CThreadTransferDstScalarPerVector>; using AGridDesc_K0_M0_M1_K1 = - decltype(GridwiseGemm::MakeAK0M0M1K1GridDescriptor(AGridDesc_K0_M_K1{})); + decltype(GridwiseGemm::MakeAGridDescriptor_K0_M0_M1_K1(AGridDesc_K0_M_K1{})); using BGridDesc_K0_N0_N1_K1 = - decltype(GridwiseGemm::MakeBK0N0N1K1GridDescriptor(BGridDesc_K0_N_K1{})); + decltype(GridwiseGemm::MakeBGridDescriptor_K0_N0_N1_K1(BGridDesc_K0_N_K1{})); using CGridDesc_M0_M10_M11_N0_N10_N11 = - decltype(GridwiseGemm::MakeCM0M10M11N0N10N11GridDescriptor(CGridDesc_M_N{})); + decltype(GridwiseGemm::MakeCGridDescriptor_M0_M10_M11_N0_N10_N11(CGridDesc_M_N{})); using DefaultBlock2CTileMap = - decltype(GridwiseGemm::MakeCBlockIdToM0N0BlockClusterAdaptor(CGridDesc_M_N{})); + decltype(GridwiseGemm::MakeDefaultBlock2CTileMap(CGridDesc_M_N{})); // Argument struct Argument : public BaseArgument @@ -261,10 +261,10 @@ struct DeviceGemmDlops c_grid_desc_m0_m10_m11_n0_n10_n11_{}, block_2_ctile_map_{}, M01_{M01}, - 
N01_{N01} - // a_element_op_{a_element_op}, - // b_element_op_{b_element_op}, - // c_element_op_{c_element_op} + N01_{N01}, + a_element_op_{a_element_op}, + b_element_op_{b_element_op}, + c_element_op_{c_element_op} { a_grid_desc_k0_m_k1_ = DeviceGemmDlops::MakeAGridDescriptor_K0_M_K1(M, K, StrideA); b_grid_desc_k0_n_k1_ = DeviceGemmDlops::MakeBGridDescriptor_K0_N_K1(K, N, StrideB); @@ -274,14 +274,14 @@ struct DeviceGemmDlops a_grid_desc_k0_m_k1_, b_grid_desc_k0_n_k1_, c_grid_desc_m_n_)) { a_grid_desc_k0_m0_m1_k1_ = - GridwiseGemm::MakeAK0M0M1K1GridDescriptor(a_grid_desc_k0_m_k1_); + GridwiseGemm::MakeAGridDescriptor_K0_M0_M1_K1(a_grid_desc_k0_m_k1_); b_grid_desc_k0_n0_n1_k1_ = - GridwiseGemm::MakeBK0N0N1K1GridDescriptor(b_grid_desc_k0_n_k1_); + GridwiseGemm::MakeBGridDescriptor_K0_N0_N1_K1(b_grid_desc_k0_n_k1_); c_grid_desc_m0_m10_m11_n0_n10_n11_ = - GridwiseGemm::MakeCM0M10M11N0N10N11GridDescriptor(c_grid_desc_m_n_); + GridwiseGemm::MakeCGridDescriptor_M0_M10_M11_N0_N10_N11(c_grid_desc_m_n_); block_2_ctile_map_ = - GridwiseGemm::MakeCBlockIdToM0N0BlockClusterAdaptor(c_grid_desc_m_n_); + GridwiseGemm::MakeDefaultBlock2CTileMap(c_grid_desc_m_n_); } } @@ -300,12 +300,14 @@ struct DeviceGemmDlops DefaultBlock2CTileMap block_2_ctile_map_; + // TODO: unused, but may be useful in future. index_t M01_; index_t N01_; - // AElementwiseOperation a_element_op_; - // BElementwiseOperation b_element_op_; - // CElementwiseOperation c_element_op_; + // TODO: unused since gridwise_gemm_dlops_v1r3 does NOT support prologue for the time being. 
+ AElementwiseOperation a_element_op_; + BElementwiseOperation b_element_op_; + CElementwiseOperation c_element_op_; }; // Invoker @@ -317,14 +319,14 @@ struct DeviceGemmDlops { { std::cout << "arg.a_grid_desc_k0_m0_m1_k1_{" - << arg.a_grid_desc_k0_m0_m1_k1_.GetLength(I0) << ", " - << arg.a_grid_desc_k0_m0_m1_k1_.GetLength(I1) << ", " - << arg.a_grid_desc_k0_m0_m1_k1_.GetLength(I2) << "}" << std::endl; + << arg.a_grid_desc_k0_m_k1_.GetLength(I0) << ", " + << arg.a_grid_desc_k0_m_k1_.GetLength(I1) << ", " + << arg.a_grid_desc_k0_m_k1_.GetLength(I2) << "}" << std::endl; std::cout << "arg.b_grid_desc_k0_n0_n1_k1_{" - << arg.b_grid_desc_k0_n0_n1_k1_.GetLength(I0) << ", " - << arg.b_grid_desc_k0_n0_n1_k1_.GetLength(I1) << ", " - << arg.b_grid_desc_k0_n0_n1_k1_.GetLength(I2) << "}" << std::endl; + << arg.b_grid_desc_k0_n_k1_.GetLength(I0) << ", " + << arg.b_grid_desc_k0_n_k1_.GetLength(I1) << ", " + << arg.b_grid_desc_k0_n_k1_.GetLength(I2) << "}" << std::endl; std::cout << "arg.c_grid_desc_m_n_{ " << arg.c_grid_desc_m_n_.GetLength(I0) << ", " << arg.c_grid_desc_m_n_.GetLength(I1) << "}" << std::endl; diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_dlops_v1r3.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_dlops_v1r3.hpp index 5167018d74..f3668a0c28 100644 --- a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_dlops_v1r3.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_dlops_v1r3.hpp @@ -16,10 +16,10 @@ namespace ck { template __global__ void @@ -30,10 +30,10 @@ __global__ void const FloatAB* __restrict__ p_a_grid, const FloatAB* __restrict__ p_b_grid, FloatC* __restrict__ p_c_grid, - const AK0M0M1K1GridDesc a_k0_m0_m1_k1_grid_desc, - const BK0N0N1K1GridDesc b_k0_n0_n1_k1_grid_desc, - const CM0M10M11N0N10N11GridDesc c_m0_m10_m11_n0_n10_n11_grid_desc, - const CBlockIdToM0N0BlockClusterAdaptor cblockid_to_m0_n0_block_cluster_adaptor) + const AGridDesc_K0_M0_M1_K1 a_grid_desc_k0_m0_m1_k1, + const BGridDesc_K0_N0_N1_K1 
b_grid_desc_k0_n0_n1_k1, + const CGridDesc_M0_M10_M11_N0_N10_N11 c_grid_desc_m0_m10_m11_n0_n10_n11, + const Block2CTileMap block_2_ctile_map) { constexpr index_t shared_block_size = GridwiseGemm::GetSharedMemoryNumberOfByte() / sizeof(FloatAB); @@ -44,10 +44,10 @@ __global__ void p_b_grid, p_c_grid, p_shared_block, - a_k0_m0_m1_k1_grid_desc, - b_k0_n0_n1_k1_grid_desc, - c_m0_m10_m11_n0_n10_n11_grid_desc, - cblockid_to_m0_n0_block_cluster_adaptor, + a_grid_desc_k0_m0_m1_k1, + b_grid_desc_k0_n0_n1_k1, + c_grid_desc_m0_m10_m11_n0_n10_n11, + block_2_ctile_map, integral_constant{}, integral_constant{}); } @@ -57,12 +57,12 @@ template {}; // K1 should be Number<...> - static constexpr auto K1 = AK0MK1GridDesc{}.GetLength(I2); + static constexpr auto K1 = AGridDesc_K0_M_K1{}.GetLength(I2); __host__ __device__ static constexpr index_t GetSharedMemoryNumberOfByte() { @@ -102,112 +102,112 @@ struct GridwiseGemmDlops_km_kn_mn_v1r3 // TODO: check alignment // A matrix in LDS memory, dst of blockwise copy - constexpr auto a_k_m_block_desc = make_naive_tensor_descriptor_aligned( - make_tuple(Number{}, Number{}, K1), max_lds_align); + constexpr auto a_block_desc_k_m = make_naive_tensor_descriptor_aligned( + make_tuple(Number{}, Number{}, K1), max_lds_align); // TODO: check alignment // B matrix in LDS memory, dst of blockwise copy - constexpr auto b_k_n_block_desc = make_naive_tensor_descriptor_aligned( - make_tuple(Number{}, Number{}, K1), max_lds_align); + constexpr auto b_block_desc_k_n = make_naive_tensor_descriptor_aligned( + make_tuple(Number{}, Number{}, K1), max_lds_align); // TODO: check alignment // LDS allocation for A and B: be careful of alignment constexpr auto a_block_aligned_space_size = - math::integer_least_multiple(a_k_m_block_desc.GetElementSpaceSize(), max_lds_align); + math::integer_least_multiple(a_block_desc_k_m.GetElementSpaceSize(), max_lds_align); constexpr auto b_block_aligned_space_size = - 
math::integer_least_multiple(b_k_n_block_desc.GetElementSpaceSize(), max_lds_align); + math::integer_least_multiple(b_block_desc_k_n.GetElementSpaceSize(), max_lds_align); return 2 * (a_block_aligned_space_size + b_block_aligned_space_size) * sizeof(FloatAB); } __host__ __device__ static constexpr bool - CheckValidity(const AK0MK1GridDesc& a_k0_m_k1_grid_desc, - const BK0NK1GridDesc& b_k0_n_k1_grid_desc, - const CMNGridDesc& c_m_n_grid_desc) + CheckValidity(const AGridDesc_K0_M_K1& a_grid_desc_k0_m_k1, + const BGridDesc_K0_N_K1& b_grid_desc_k0_n_k1, + const CMNGridDesc& c_grid_desc_m_n) { - const auto M = a_k0_m_k1_grid_desc.GetLength(I1); - const auto N = b_k0_n_k1_grid_desc.GetLength(I1); - const auto K0 = a_k0_m_k1_grid_desc.GetLength(I0); + const auto M = a_grid_desc_k0_m_k1.GetLength(I1); + const auto N = b_grid_desc_k0_n_k1.GetLength(I1); + const auto K0 = a_grid_desc_k0_m_k1.GetLength(I0); // TODO: also check validity of all components (blockwise-copy, threadwise-copy, etc) - return (M == c_m_n_grid_desc.GetLength(I0) && N == c_m_n_grid_desc.GetLength(I1) && - K0 == b_k0_n_k1_grid_desc.GetLength(I0) && - K1 == a_k0_m_k1_grid_desc.GetLength(I2) && - K1 == b_k0_n_k1_grid_desc.GetLength(I2)) && - (M % MPerBlockM1 == 0 && N % NPerBlockN1 == 0 && K0 % KPerBlock == 0); + return (M == c_grid_desc_m_n.GetLength(I0) && N == c_grid_desc_m_n.GetLength(I1) && + K0 == b_grid_desc_k0_n_k1.GetLength(I0) && + K1 == a_grid_desc_k0_m_k1.GetLength(I2) && + K1 == b_grid_desc_k0_n_k1.GetLength(I2)) && + (M % MPerBlock == 0 && N % NPerBlock == 0 && K0 % K0PerBlock == 0); } __host__ __device__ static constexpr index_t CalculateGridSize(index_t M, index_t N) { - const index_t grid_size = (M / MPerBlockM1) * (N / NPerBlockN1); + const index_t grid_size = (M / MPerBlock) * (N / NPerBlock); return grid_size; } __host__ __device__ static constexpr bool CalculateHasMainKBlockLoop(index_t K0) { - const bool has_main_k_block_loop = (K0 + KPerBlock) / (2 * KPerBlock) > 1; + const bool 
has_main_k_block_loop = (K0 + K0PerBlock) / (2 * K0PerBlock) > 1; return has_main_k_block_loop; } __host__ __device__ static constexpr bool CalculateHasDoubleTailKBlockLoop(index_t K0) { - const bool has_double_tail_k_block_loop = (K0 / KPerBlock) % 2 == 0; + const bool has_double_tail_k_block_loop = (K0 / K0PerBlock) % 2 == 0; return has_double_tail_k_block_loop; } __host__ __device__ static constexpr auto - MakeAK0M0M1K1GridDescriptor(const AK0MK1GridDesc& a_k0_m_k1_grid_desc) + MakeAGridDescriptor_K0_M0_M1_K1(const AGridDesc_K0_M_K1& a_grid_desc_k0_m_k1) { - const auto K0 = a_k0_m_k1_grid_desc.GetLength(I0); - const auto M = a_k0_m_k1_grid_desc.GetLength(I1); + const auto K0 = a_grid_desc_k0_m_k1.GetLength(I0); + const auto M = a_grid_desc_k0_m_k1.GetLength(I1); - const auto M1 = Number{}; + const auto M1 = Number{}; const auto M0 = M / M1; - const auto a_k0_m0_m1_k1_grid_desc = - transform_tensor_descriptor(a_k0_m_k1_grid_desc, + const auto a_grid_desc_k0_m0_m1_k1 = + transform_tensor_descriptor(a_grid_desc_k0_m_k1, make_tuple(make_pass_through_transform(K0), make_unmerge_transform(make_tuple(M0, M1)), make_pass_through_transform(K1)), make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}), make_tuple(Sequence<0>{}, Sequence<1, 2>{}, Sequence<3>{})); - return a_k0_m0_m1_k1_grid_desc; + return a_grid_desc_k0_m0_m1_k1; } __host__ __device__ static constexpr auto - MakeBK0N0N1K1GridDescriptor(const BK0NK1GridDesc& b_k0_n_k1_grid_desc) + MakeBGridDescriptor_K0_N0_N1_K1(const BGridDesc_K0_N_K1& b_grid_desc_k0_n_k1) { - const auto K0 = b_k0_n_k1_grid_desc.GetLength(I0); - const auto N = b_k0_n_k1_grid_desc.GetLength(I1); + const auto K0 = b_grid_desc_k0_n_k1.GetLength(I0); + const auto N = b_grid_desc_k0_n_k1.GetLength(I1); - const auto N1 = Number{}; + const auto N1 = Number{}; const auto N0 = N / N1; - const auto b_k0_n0_n1_k1_grid_desc = - transform_tensor_descriptor(b_k0_n_k1_grid_desc, + const auto b_grid_desc_k0_n0_n1_k1 = + 
transform_tensor_descriptor(b_grid_desc_k0_n_k1, make_tuple(make_pass_through_transform(K0), make_unmerge_transform(make_tuple(N0, N1)), make_pass_through_transform(K1)), make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}), make_tuple(Sequence<0>{}, Sequence<1, 2>{}, Sequence<3>{})); - return b_k0_n0_n1_k1_grid_desc; + return b_grid_desc_k0_n0_n1_k1; } __host__ __device__ static constexpr auto - MakeCM0M10M11N0N10N11GridDescriptor(const CMNGridDesc& c_m_n_grid_desc) + MakeCGridDescriptor_M0_M10_M11_N0_N10_N11(const CMNGridDesc& c_grid_desc_m_n) { - const auto M = c_m_n_grid_desc.GetLength(I0); - const auto N = c_m_n_grid_desc.GetLength(I1); + const auto M = c_grid_desc_m_n.GetLength(I0); + const auto N = c_grid_desc_m_n.GetLength(I1); - constexpr auto M1 = Number{}; - constexpr auto N1 = Number{}; + constexpr auto M1 = Number{}; + constexpr auto N1 = Number{}; const auto M0 = M / M1; const auto N0 = N / N1; @@ -222,41 +222,41 @@ struct GridwiseGemmDlops_km_kn_mn_v1r3 constexpr auto M10 = M1 / M11; constexpr auto N10 = N1 / N11; - const auto c_m0_m10_m11_n0_n10_n11_grid_desc = transform_tensor_descriptor( - c_m_n_grid_desc, + const auto c_grid_desc_m0_m10_m11_n0_n10_n11 = transform_tensor_descriptor( + c_grid_desc_m_n, make_tuple(make_unmerge_transform(make_tuple(M0, M10, M11)), make_unmerge_transform(make_tuple(N0, N10, N11))), make_tuple(Sequence<0>{}, Sequence<1>{}), make_tuple(Sequence<0, 1, 2>{}, Sequence<3, 4, 5>{})); - return c_m0_m10_m11_n0_n10_n11_grid_desc; + return c_grid_desc_m0_m10_m11_n0_n10_n11; } __host__ __device__ static constexpr auto - MakeCBlockIdToM0N0BlockClusterAdaptor(const CMNGridDesc& c_m_n_grid_desc) + MakeDefaultBlock2CTileMap(const CMNGridDesc& c_grid_desc_m_n) { - const auto M = c_m_n_grid_desc.GetLength(I0); - const auto N = c_m_n_grid_desc.GetLength(I1); + const auto M = c_grid_desc_m_n.GetLength(I0); + const auto N = c_grid_desc_m_n.GetLength(I1); - constexpr auto M1 = Number{}; - constexpr auto N1 = Number{}; + constexpr auto 
M1 = Number{}; + constexpr auto N1 = Number{}; const auto M0 = M / M1; const auto N0 = N / N1; - const auto cblockid_to_m0_n0_block_cluster_adaptor = + const auto block_2_ctile_map = make_single_stage_tensor_adaptor(make_tuple(make_merge_transform(make_tuple(M0, N0))), make_tuple(Sequence<0, 1>{}), make_tuple(Sequence<0>{})); - return cblockid_to_m0_n0_block_cluster_adaptor; + return block_2_ctile_map; } - using AK0M0M1K1GridDesc = decltype(MakeAK0M0M1K1GridDescriptor(AK0MK1GridDesc{})); - using BK0N0N1K1GridDesc = decltype(MakeBK0N0N1K1GridDescriptor(BK0NK1GridDesc{})); - using CM0M10M11N0N10N11GridDesc = decltype(MakeCM0M10M11N0N10N11GridDescriptor(CMNGridDesc{})); - using CBlockIdToM0N0BlockClusterAdaptor = - decltype(MakeCBlockIdToM0N0BlockClusterAdaptor(CMNGridDesc{})); + using AGridDesc_K0_M0_M1_K1 = decltype(MakeAGridDescriptor_K0_M0_M1_K1(AGridDesc_K0_M_K1{})); + using BGridDesc_K0_N0_N1_K1 = decltype(MakeBGridDescriptor_K0_N0_N1_K1(BGridDesc_K0_N_K1{})); + using CGridDesc_M0_M10_M11_N0_N10_N11 = + decltype(MakeCGridDescriptor_M0_M10_M11_N0_N10_N11(CMNGridDesc{})); + using Block2CTileMap = decltype(MakeDefaultBlock2CTileMap(CMNGridDesc{})); template __device__ static void @@ -264,24 +264,23 @@ struct GridwiseGemmDlops_km_kn_mn_v1r3 const FloatAB* __restrict__ p_b_grid, FloatC* __restrict__ p_c_grid, FloatAB* __restrict__ p_shared_block, - const AK0M0M1K1GridDesc& a_k0_m0_m1_k1_grid_desc, - const BK0N0N1K1GridDesc& b_k0_n0_n1_k1_grid_desc, - const CM0M10M11N0N10N11GridDesc& c_m0_m10_m11_n0_n10_n11_grid_desc, - const CBlockIdToM0N0BlockClusterAdaptor& cblockid_to_m0_n0_block_cluster_adaptor, + const AGridDesc_K0_M0_M1_K1& a_grid_desc_k0_m0_m1_k1, + const BGridDesc_K0_N0_N1_K1& b_grid_desc_k0_n0_n1_k1, + const CGridDesc_M0_M10_M11_N0_N10_N11& c_grid_desc_m0_m10_m11_n0_n10_n11, + const Block2CTileMap& block_2_ctile_map, integral_constant, integral_constant) { const auto a_global_buf = make_dynamic_buffer( - p_a_grid, 
a_k0_m0_m1_k1_grid_desc.GetElementSpaceSize()); + p_a_grid, a_grid_desc_k0_m0_m1_k1.GetElementSpaceSize()); const auto b_global_buf = make_dynamic_buffer( - p_b_grid, b_k0_n0_n1_k1_grid_desc.GetElementSpaceSize()); + p_b_grid, b_grid_desc_k0_n0_n1_k1.GetElementSpaceSize()); auto c_grid_buf = make_dynamic_buffer( - p_c_grid, c_m0_m10_m11_n0_n10_n11_grid_desc.GetElementSpaceSize()); + p_c_grid, c_grid_desc_m0_m10_m11_n0_n10_n11.GetElementSpaceSize()); // divide block work by [M, N] const auto c_m0_n0_block_cluster_idx = - cblockid_to_m0_n0_block_cluster_adaptor.CalculateBottomIndex( - make_multi_index(get_block_1d_id())); + block_2_ctile_map.CalculateBottomIndex(make_multi_index(get_block_1d_id())); // HACK: this force index data into SGPR const index_t im0 = __builtin_amdgcn_readfirstlane(c_m0_n0_block_cluster_idx[I0]); @@ -293,28 +292,28 @@ struct GridwiseGemmDlops_km_kn_mn_v1r3 // TODO: check alignment // A matrix in LDS memory, dst of blockwise copy // be careful of LDS alignment - constexpr auto a_k0_m0_m1_k1_block_desc = make_naive_tensor_descriptor_aligned( - make_tuple(Number{}, I1, Number{}, K1), max_lds_align); + constexpr auto a_block_desc_k0_m0_m1_k1 = make_naive_tensor_descriptor_aligned( + make_tuple(Number{}, I1, Number{}, K1), max_lds_align); // TODO: check alignment // B matrix in LDS memory, dst of blockwise copy // be careful of LDS alignment - constexpr auto b_k0_n0_n1_k1_block_desc = make_naive_tensor_descriptor_aligned( - make_tuple(Number{}, I1, Number{}, K1), max_lds_align); + constexpr auto b_block_desc_k0_n0_n1_k1 = make_naive_tensor_descriptor_aligned( + make_tuple(Number{}, I1, Number{}, K1), max_lds_align); // TODO: check alignment // A matrix in LDS memory, for blockwise GEMM constexpr auto a_k0_m_k1_block_desc = make_naive_tensor_descriptor_aligned( - make_tuple(Number{}, Number{}, K1), max_lds_align); + make_tuple(Number{}, Number{}, K1), max_lds_align); // TODO: check alignment // B matrix in LDS memory, for blockwise GEMM constexpr 
auto b_k0_n_k1_block_desc = make_naive_tensor_descriptor_aligned( - make_tuple(Number{}, Number{}, K1), max_lds_align); + make_tuple(Number{}, Number{}, K1), max_lds_align); - static_assert(a_k0_m0_m1_k1_block_desc.GetElementSpaceSize() == + static_assert(a_block_desc_k0_m0_m1_k1.GetElementSpaceSize() == a_k0_m_k1_block_desc.GetElementSpaceSize() && - b_k0_n0_n1_k1_block_desc.GetElementSpaceSize() == + b_block_desc_k0_n0_n1_k1.GetElementSpaceSize() == b_k0_n_k1_block_desc.GetElementSpaceSize() && "wrong!"); @@ -322,14 +321,14 @@ struct GridwiseGemmDlops_km_kn_mn_v1r3 auto a_blockwise_copy = BlockwiseTensorSliceTransfer_v5r1< BlockSize, InMemoryDataOperationEnum::Set, - Sequence, + Sequence, ABlockTransferThreadSliceLengths_K0_M0_M1_K1, ABlockTransferThreadClusterLengths_K0_M0_M1_K1, ABlockTransferThreadClusterArrangeOrder, FloatAB, FloatAB, - remove_reference_t, - decltype(a_k0_m0_m1_k1_block_desc), + remove_reference_t, + decltype(a_block_desc_k0_m0_m1_k1), ABlockTransferSrcAccessOrder, Sequence<0, 1, 2, 3>, ABlockTransferSrcVectorTensorLengths_K0_M0_M1_K1, // SrcVectorTensorLengths @@ -337,23 +336,23 @@ struct GridwiseGemmDlops_km_kn_mn_v1r3 ABlockTransferSrcVectorTensorContiguousDimOrder, // SrcVectorTensorContiguousDimOrder Sequence<0, 1, 2, 3>, // DstVectorTensorContiguousDimOrder false, - true>(a_k0_m0_m1_k1_grid_desc, + true>(a_grid_desc_k0_m0_m1_k1, make_multi_index(0, im0, 0, 0), - a_k0_m0_m1_k1_block_desc, + a_block_desc_k0_m0_m1_k1, make_multi_index(0, 0, 0, 0)); // B matrix blockwise copy auto b_blockwise_copy = BlockwiseTensorSliceTransfer_v5r1< BlockSize, InMemoryDataOperationEnum::Set, - Sequence, + Sequence, BBlockTransferThreadSliceLengths_K0_N0_N1_K1, BBlockTransferThreadClusterLengths_K0_N0_N1_K1, BBlockTransferThreadClusterArrangeOrder, FloatAB, FloatAB, - remove_reference_t, - decltype(b_k0_n0_n1_k1_block_desc), + remove_reference_t, + decltype(b_block_desc_k0_n0_n1_k1), BBlockTransferSrcAccessOrder, Sequence<0, 1, 2, 3>, 
BBlockTransferSrcVectorTensorLengths_K0_N0_N1_K1, // SrcVectorTensorLengths @@ -361,16 +360,16 @@ struct GridwiseGemmDlops_km_kn_mn_v1r3 BBlockTransferSrcVectorTensorContiguousDimOrder, // SrcVectorTensorContiguousDimOrder Sequence<0, 1, 2, 3>, // DstVectorTensorContiguousDimOrder false, - true>(b_k0_n0_n1_k1_grid_desc, + true>(b_grid_desc_k0_n0_n1_k1, make_multi_index(0, in0, 0, 0), - b_k0_n0_n1_k1_block_desc, + b_block_desc_k0_n0_n1_k1, make_multi_index(0, 0, 0, 0)); // GEMM definition // c_mtx += transpose(a_mtx) * b_mtx - // a_mtx[KPerBlock, MPerBlockM1] is in LDS - // b_mtx[KPerBlocl, NPerBlockN1] is in LDS - // c_mtx[MPerBlockM1, NPerBlockN1] is distributed among threads, and saved in + // a_mtx[K0PerBlock, MPerBlock] is in LDS + // b_mtx[K0PerBlock, NPerBlock] is in LDS + // c_mtx[MPerBlock, NPerBlock] is distributed among threads, and saved in // register const auto blockwise_gemm = BlockwiseGemmDlops_A_BK0_BM_BK1_B_BK0_BN_BK1_C_BM0_BM1_BN0_BN1_pipeline_BM0_2_BN0_2< @@ -391,58 +390,58 @@ struct GridwiseGemmDlops_km_kn_mn_v1r3 constexpr auto c_m10_m11_n10_n11_thread_tensor_lengths = decltype(blockwise_gemm)::GetCThreadTensorLengths_BM0_BM1_BN0_BN1(); - constexpr auto c_m10_m11_n10_n11_thread_desc = make_naive_tensor_descriptor_packed( + constexpr auto c_thread_desc_m10_m11_n10_n11 = make_naive_tensor_descriptor_packed( sequence_to_tuple_of_number(c_m10_m11_n10_n11_thread_tensor_lengths)); // LDS allocation for A and B: be careful of alignment constexpr auto a_block_aligned_space_size = math::integer_least_multiple( - a_k0_m0_m1_k1_block_desc.GetElementSpaceSize(), max_lds_align); + a_block_desc_k0_m0_m1_k1.GetElementSpaceSize(), max_lds_align); constexpr auto b_block_aligned_space_size = math::integer_least_multiple( - b_k0_n0_n1_k1_block_desc.GetElementSpaceSize(), max_lds_align); + b_block_desc_k0_n0_n1_k1.GetElementSpaceSize(), max_lds_align); FloatAB* p_a_block_double = p_shared_block; FloatAB* p_b_block_double = p_shared_block + 2 *
a_block_aligned_space_size; // register allocation for output auto c_thread_buf = make_static_buffer( - c_m10_m11_n10_n11_thread_desc.GetElementSpaceSize()); + c_thread_desc_m10_m11_n10_n11.GetElementSpaceSize()); ThreadwiseTensorSliceSet_v1{} - .Run(c_m10_m11_n10_n11_thread_desc, + .Run(c_thread_desc_m10_m11_n10_n11, make_tuple(I0, I0, I0, I0), c_thread_buf, FloatAcc{0}); - constexpr auto a_block_slice_copy_step = make_multi_index(KPerBlock, 0, 0, 0); - constexpr auto b_block_slice_copy_step = make_multi_index(KPerBlock, 0, 0, 0); + constexpr auto a_block_slice_copy_step = make_multi_index(K0PerBlock, 0, 0, 0); + constexpr auto b_block_slice_copy_step = make_multi_index(K0PerBlock, 0, 0, 0); auto a_block_even_buf = make_dynamic_buffer( - p_a_block_double, a_k0_m0_m1_k1_block_desc.GetElementSpaceSize()); + p_a_block_double, a_block_desc_k0_m0_m1_k1.GetElementSpaceSize()); auto b_block_even_buf = make_dynamic_buffer( - p_b_block_double, b_k0_n0_n1_k1_block_desc.GetElementSpaceSize()); + p_b_block_double, b_block_desc_k0_n0_n1_k1.GetElementSpaceSize()); auto a_block_odd_buf = make_dynamic_buffer( p_a_block_double + a_block_aligned_space_size, - a_k0_m0_m1_k1_block_desc.GetElementSpaceSize()); + a_block_desc_k0_m0_m1_k1.GetElementSpaceSize()); auto b_block_odd_buf = make_dynamic_buffer( p_b_block_double + b_block_aligned_space_size, - b_k0_n0_n1_k1_block_desc.GetElementSpaceSize()); + b_block_desc_k0_n0_n1_k1.GetElementSpaceSize()); // LDS double buffer: preload data into LDS { - a_blockwise_copy.RunRead(a_k0_m0_m1_k1_grid_desc, a_global_buf); - b_blockwise_copy.RunRead(b_k0_n0_n1_k1_grid_desc, b_global_buf); + a_blockwise_copy.RunRead(a_grid_desc_k0_m0_m1_k1, a_global_buf); + b_blockwise_copy.RunRead(b_grid_desc_k0_n0_n1_k1, b_global_buf); - a_blockwise_copy.RunWrite(a_k0_m0_m1_k1_block_desc, a_block_even_buf); - b_blockwise_copy.RunWrite(b_k0_n0_n1_k1_block_desc, b_block_even_buf); + a_blockwise_copy.RunWrite(a_block_desc_k0_m0_m1_k1, a_block_even_buf); + 
b_blockwise_copy.RunWrite(b_block_desc_k0_n0_n1_k1, b_block_even_buf); } if constexpr(HasMainKBlockLoop) { - const auto K0 = a_k0_m0_m1_k1_grid_desc.GetLength(I0); + const auto K0 = a_grid_desc_k0_m0_m1_k1.GetLength(I0); index_t k_block_data_begin = 0; @@ -451,76 +450,76 @@ struct GridwiseGemmDlops_km_kn_mn_v1r3 do { // even iteration - a_blockwise_copy.MoveSrcSliceWindow(a_k0_m0_m1_k1_grid_desc, + a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc_k0_m0_m1_k1, a_block_slice_copy_step); - b_blockwise_copy.MoveSrcSliceWindow(b_k0_n0_n1_k1_grid_desc, + b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc_k0_n0_n1_k1, b_block_slice_copy_step); __syncthreads(); // LDS doubel buffer: load next data from device mem - a_blockwise_copy.RunRead(a_k0_m0_m1_k1_grid_desc, a_global_buf); - b_blockwise_copy.RunRead(b_k0_n0_n1_k1_grid_desc, b_global_buf); + a_blockwise_copy.RunRead(a_grid_desc_k0_m0_m1_k1, a_global_buf); + b_blockwise_copy.RunRead(b_grid_desc_k0_n0_n1_k1, b_global_buf); // LDS double buffer: GEMM on current data - blockwise_gemm.Run(c_m10_m11_n10_n11_thread_desc, + blockwise_gemm.Run(c_thread_desc_m10_m11_n10_n11, a_block_even_buf, b_block_even_buf, c_thread_buf); // LDS double buffer: store next data to LDS - a_blockwise_copy.RunWrite(a_k0_m0_m1_k1_block_desc, a_block_odd_buf); - b_blockwise_copy.RunWrite(b_k0_n0_n1_k1_block_desc, b_block_odd_buf); + a_blockwise_copy.RunWrite(a_block_desc_k0_m0_m1_k1, a_block_odd_buf); + b_blockwise_copy.RunWrite(b_block_desc_k0_n0_n1_k1, b_block_odd_buf); // odd iteration - a_blockwise_copy.MoveSrcSliceWindow(a_k0_m0_m1_k1_grid_desc, + a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc_k0_m0_m1_k1, a_block_slice_copy_step); - b_blockwise_copy.MoveSrcSliceWindow(b_k0_n0_n1_k1_grid_desc, + b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc_k0_n0_n1_k1, b_block_slice_copy_step); __syncthreads(); // LDS doubel buffer: load next data from device mem - a_blockwise_copy.RunRead(a_k0_m0_m1_k1_grid_desc, a_global_buf); - 
b_blockwise_copy.RunRead(b_k0_n0_n1_k1_grid_desc, b_global_buf); + a_blockwise_copy.RunRead(a_grid_desc_k0_m0_m1_k1, a_global_buf); + b_blockwise_copy.RunRead(b_grid_desc_k0_n0_n1_k1, b_global_buf); // LDS double buffer: GEMM on current data blockwise_gemm.Run( - c_m10_m11_n10_n11_thread_desc, a_block_odd_buf, b_block_odd_buf, c_thread_buf); + c_thread_desc_m10_m11_n10_n11, a_block_odd_buf, b_block_odd_buf, c_thread_buf); // LDS double buffer: store next data to LDS - a_blockwise_copy.RunWrite(a_k0_m0_m1_k1_block_desc, a_block_even_buf); - b_blockwise_copy.RunWrite(b_k0_n0_n1_k1_block_desc, b_block_even_buf); + a_blockwise_copy.RunWrite(a_block_desc_k0_m0_m1_k1, a_block_even_buf); + b_blockwise_copy.RunWrite(b_block_desc_k0_n0_n1_k1, b_block_even_buf); - k_block_data_begin += 2 * KPerBlock; - } while(k_block_data_begin < K0 - 2 * KPerBlock); + k_block_data_begin += 2 * K0PerBlock; + } while(k_block_data_begin < K0 - 2 * K0PerBlock); } // LDS double buffer: tail if constexpr(HasDoubleTailKBlockLoop) // if has 2 iteration left { - a_blockwise_copy.MoveSrcSliceWindow(a_k0_m0_m1_k1_grid_desc, a_block_slice_copy_step); - b_blockwise_copy.MoveSrcSliceWindow(b_k0_n0_n1_k1_grid_desc, b_block_slice_copy_step); + a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc_k0_m0_m1_k1, a_block_slice_copy_step); + b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc_k0_n0_n1_k1, b_block_slice_copy_step); __syncthreads(); // LDS double buffer: load last data from device mem - a_blockwise_copy.RunRead(a_k0_m0_m1_k1_grid_desc, a_global_buf); - b_blockwise_copy.RunRead(b_k0_n0_n1_k1_grid_desc, b_global_buf); + a_blockwise_copy.RunRead(a_grid_desc_k0_m0_m1_k1, a_global_buf); + b_blockwise_copy.RunRead(b_grid_desc_k0_n0_n1_k1, b_global_buf); // LDS double buffer: GEMM on 2nd-last data blockwise_gemm.Run( - c_m10_m11_n10_n11_thread_desc, a_block_even_buf, b_block_even_buf, c_thread_buf); + c_thread_desc_m10_m11_n10_n11, a_block_even_buf, b_block_even_buf, c_thread_buf); // LDS double buffer: store 
last data to LDS - a_blockwise_copy.RunWrite(a_k0_m0_m1_k1_block_desc, a_block_odd_buf); - b_blockwise_copy.RunWrite(b_k0_n0_n1_k1_block_desc, b_block_odd_buf); + a_blockwise_copy.RunWrite(a_block_desc_k0_m0_m1_k1, a_block_odd_buf); + b_blockwise_copy.RunWrite(b_block_desc_k0_n0_n1_k1, b_block_odd_buf); __syncthreads(); // LDS double buffer: GEMM on last data blockwise_gemm.Run( - c_m10_m11_n10_n11_thread_desc, a_block_odd_buf, b_block_odd_buf, c_thread_buf); + c_thread_desc_m10_m11_n10_n11, a_block_odd_buf, b_block_odd_buf, c_thread_buf); } else // if has 1 iteration left { @@ -528,12 +527,12 @@ struct GridwiseGemmDlops_km_kn_mn_v1r3 // LDS double buffer: GEMM on last data blockwise_gemm.Run( - c_m10_m11_n10_n11_thread_desc, a_block_even_buf, b_block_even_buf, c_thread_buf); + c_thread_desc_m10_m11_n10_n11, a_block_even_buf, b_block_even_buf, c_thread_buf); } // output: register to global memory { - constexpr auto c_m0_m10_m11_n0_n10_n11_thread_desc = + constexpr auto c_thread_desc_m0_m10_m11_n0_n10_n11 = make_naive_tensor_descriptor_packed( make_tuple(I1, Number{}, @@ -549,8 +548,8 @@ struct GridwiseGemmDlops_km_kn_mn_v1r3 ThreadwiseTensorSliceTransfer_v1r3< FloatAcc, FloatC, - decltype(c_m0_m10_m11_n0_n10_n11_thread_desc), - decltype(c_m0_m10_m11_n0_n10_n11_grid_desc), + decltype(c_thread_desc_m0_m10_m11_n0_n10_n11), + decltype(c_grid_desc_m0_m10_m11_n0_n10_n11), ck::tensor_operation::element_wise::PassThrough, Sequence<1, c_m10_m11_n10_n11_thread_tensor_lengths[I0], @@ -563,7 +562,7 @@ struct GridwiseGemmDlops_km_kn_mn_v1r3 CThreadTransferDstScalarPerVector, CGlobalMemoryDataOperation, 1, - true>{c_m0_m10_m11_n0_n10_n11_grid_desc, + true>{c_grid_desc_m0_m10_m11_n0_n10_n11, make_multi_index(im0, c_m10_m11_n10_n11_thread_origin_idx_on_block[I0], c_m10_m11_n10_n11_thread_origin_idx_on_block[I1], @@ -571,10 +570,10 @@ struct GridwiseGemmDlops_km_kn_mn_v1r3 c_m10_m11_n10_n11_thread_origin_idx_on_block[I2], c_m10_m11_n10_n11_thread_origin_idx_on_block[I3]), 
ck::tensor_operation::element_wise::PassThrough{}} - .Run(c_m0_m10_m11_n0_n10_n11_thread_desc, + .Run(c_thread_desc_m0_m10_m11_n0_n10_n11, make_tuple(I0, I0, I0, I0, I0, I0), c_thread_buf, - c_m0_m10_m11_n0_n10_n11_grid_desc, + c_grid_desc_m0_m10_m11_n0_n10_n11, c_grid_buf); } } From 27b1c45bdf0ea8f95c4d846757cc158d359854ad Mon Sep 17 00:00:00 2001 From: j4yan Date: Thu, 14 Apr 2022 14:56:42 -0500 Subject: [PATCH 04/46] add other 3 layouts; format instance --- .../gpu/gemm/CMakeLists.txt | 16 +++-- ...mm_dlops_f32_f32_f32_km_kn_mn_instance.cpp | 50 ++------------- ...mm_dlops_f32_f32_f32_km_nk_mn_instance.cpp | 48 ++++++++++++++ ...mm_dlops_f32_f32_f32_mk_kn_mn_instance.cpp | 47 ++++++++++++++ ...mm_dlops_f32_f32_f32_mk_nk_mn_instance.cpp | 48 ++++++++++++++ test/gemm_dlops/gemm_dlops_fp32.cpp | 64 +++++++++++++++++-- 6 files changed, 220 insertions(+), 53 deletions(-) create mode 100644 library/src/tensor_operation_instance/gpu/gemm/device_gemm_dlops_f32_f32_f32_km_nk_mn_instance.cpp create mode 100644 library/src/tensor_operation_instance/gpu/gemm/device_gemm_dlops_f32_f32_f32_mk_kn_mn_instance.cpp create mode 100644 library/src/tensor_operation_instance/gpu/gemm/device_gemm_dlops_f32_f32_f32_mk_nk_mn_instance.cpp diff --git a/library/src/tensor_operation_instance/gpu/gemm/CMakeLists.txt b/library/src/tensor_operation_instance/gpu/gemm/CMakeLists.txt index d17f01da42..b3910b319b 100644 --- a/library/src/tensor_operation_instance/gpu/gemm/CMakeLists.txt +++ b/library/src/tensor_operation_instance/gpu/gemm/CMakeLists.txt @@ -1,5 +1,5 @@ # device_gemm_instance -set(DEVICE_GEMM_INSTANCE_SOURCE +set(DEVICE_GEMM_XDL_INSTANCE_SOURCE device_gemm_xdl_f32_f32_f32_mk_kn_mn_instance.cpp; device_gemm_xdl_f32_f32_f32_mk_nk_mn_instance.cpp; device_gemm_xdl_f32_f32_f32_km_kn_mn_instance.cpp; @@ -33,12 +33,9 @@ set(DEVICE_GEMM_INSTANCE_SOURCE device_gemm_xdl_splitk_f16_f16_f16_mk_nk_mn_instance.cpp; device_gemm_xdl_splitk_f16_f16_f16_km_kn_mn_instance.cpp; 
device_gemm_xdl_splitk_f16_f16_f16_km_nk_mn_instance.cpp; - - device_gemm_dlops_f32_f32_f32_km_kn_mn_instance.cpp; - ) -add_library(device_gemm_instance SHARED ${DEVICE_GEMM_INSTANCE_SOURCE}) +add_library(device_gemm_instance SHARED ${DEVICE_GEMM_XDL_INSTANCE_SOURCE}) target_compile_features(device_gemm_instance PUBLIC) set_target_properties(device_gemm_instance PROPERTIES POSITION_INDEPENDENT_CODE ON) @@ -47,7 +44,14 @@ install(TARGETS device_gemm_instance LIBRARY DESTINATION lib) clang_tidy_check(device_gemm_instance) -add_library(device_gemm_dlops_instance SHARED device_gemm_dlops_f32_f32_f32_km_kn_mn_instance.cpp) +set(DEVICE_GEMM_DLOPS_INSTANCE_SOURCE + device_gemm_dlops_f32_f32_f32_mk_kn_mn_instance.cpp; + device_gemm_dlops_f32_f32_f32_mk_nk_mn_instance.cpp; + device_gemm_dlops_f32_f32_f32_km_kn_mn_instance.cpp; + device_gemm_dlops_f32_f32_f32_km_nk_mn_instance.cpp; +) + +add_library(device_gemm_dlops_instance SHARED ${DEVICE_GEMM_DLOPS_INSTANCE_SOURCE}) target_compile_features(device_gemm_dlops_instance PUBLIC) set_target_properties(device_gemm_dlops_instance PROPERTIES POSITION_INDEPENDENT_CODE ON) diff --git a/library/src/tensor_operation_instance/gpu/gemm/device_gemm_dlops_f32_f32_f32_km_kn_mn_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm/device_gemm_dlops_f32_f32_f32_km_kn_mn_instance.cpp index c7cff6e518..fe2dda3c25 100644 --- a/library/src/tensor_operation_instance/gpu/gemm/device_gemm_dlops_f32_f32_f32_km_kn_mn_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/gemm/device_gemm_dlops_f32_f32_f32_km_kn_mn_instance.cpp @@ -25,50 +25,14 @@ static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecializa // Compilation parameters for a[k, m] * b[k, n] = c[m, n] using device_gemm_dlops_f32_f32_f32_km_kn_mn_instances = std::tuple< // clang-format off - // ##########| AData| BData| CData| AccData| ALayout| BLayout| CLayout| A| B| C| GEMM| Block| MPer| NPer| K0Per| K1| M1Per| N1Per| KPer| M11N11Thread| M11N11Thread| 
ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| - // ##########| Type| Type| Type| Type| | | | Elementwise| Elementwise| Elementwise|Spacialization| Size| Block| Block| Block| | ThreadM111| ThreadN111| Thread| ClusterM110Xs| ClusterN110Xs| ThreadSliceLengths| ThreadClusterLengths| ThreadCluster| SrcAccess| - // ##########| | | | | | | | Operation| Operation| Operation| | | | | | | | | | | | K0_M0_M1_K1| K0_M0_M1_K1| ArrangeOrder| Order| - // ##########| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | - // DeviceGemmDlops< F32, F32, F32, F32, Col, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 128, 128, 8, 2, 4, 4, 1, S<8, 2>, S<8, 2>, S<4, 1, 1, 2>, S<2, 1, 128, 1>, S<1, 2, 0, 3>, S<1, 2, 0 ,3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<4, 1, 1, 2>, S<2, 1, 128, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0 3>, S<1, 1, 1, 2>, S<0, 1, 2, 3, 4, 5>, 5, 4> + // ##########| AData| BData| CData| AccData| ALayout| BLayout| CLayout| A| B| C| GEMM| Block| MPer| NPer| K0Per| K1| M1Per| N1Per| KPer| M11N11Thread| M11N11Thread| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| CThreadTransfer| CThreadTransfer| CThreadTransfer| + // ##########| Type| Type| Type| Type| | | | Elementwise| Elementwise| Elementwise| Spacialization| Size| Block| Block| Block| | ThreadM111| ThreadN111| Thread| ClusterM110Xs| ClusterN110Xs| ThreadSliceLengths| ThreadClusterLengths| ThreadCluster| SrcAccess| SrcVectorTensor| SrcVectorTensor| DstVectorTensor| ThreadSliceLengths| ThreadClusterLengths| ThreadCluster| SrcAccess| SrcVectorTensor| SrcVectorTensor| DstVectorTensor| SrcDstAccess| SrcDstVectorDim| DstScalarPerVector| + // ##########| | | | | | | | Operation| Operation| Operation| | | | | | | | | | | | K0_M0_M1_K1| K0_M0_M1_K1| ArrangeOrder| 
Order| Lengths_K0_M0_M1_K1| ContiguousDimOrder| Lengths_K0_M0_M1_K1| K0_M0_M1_K1| K0_M0_M1_K1| ArrangeOrder| Order| Lengths_K0_M0_M1_K1| ContiguousDimOrder| Lengths_K0_M0_M1_K1| Order| | | + // ##########| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceGemmDlops< F32, F32, F32, F32, Col, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 128, 128, 8, 2, 4, 4, 1, S<8, 2>, S<8, 2>, S<4, 1, 1, 2>, S<2, 1, 128, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<4, 1, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 1, 2>, S<4, 1, 1, 2>, S<2, 1, 128, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<4, 1, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 1, 2>, S<0, 1, 2, 3, 4, 5>, 5, 4> + // DeviceGemmDlops< F32, F32, F32, F32, Col, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 128, 128, 128, 8, 2, 4, 4, 1, S<8, 1>, S<8, 2>, S<4, 1, 1, 2>, S<1, 1, 128, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<4, 1, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 1, 2>, S<4, 1, 1, 2>, S<1, 1, 128, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<4, 1, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 1, 2>, S<0, 1, 2, 3, 4, 5>, 5, 4> // clang-format on - DeviceGemmDlops, - S<8, 2>, - S<4, 1, 1, 2>, - S<2, 1, 128, 1>, - S<1, 2, 0, 3>, - S<1, 2, 0, 3>, - S<4, 1, 1, 2>, - S<1, 2, 0, 3>, - S<1, 1, 1, 2>, - S<4, 1, 1, 2>, - S<2, 1, 128, 1>, - S<1, 2, 0, 3>, - S<1, 2, 0, 3>, - S<4, 1, 1, 2>, - S<1, 2, 0, 3>, - S<1, 1, 1, 2>, - S<0, 1, 2, 3, 4, 5>, - 5, - 4>>; + >; void add_device_gemm_dlops_f32_f32_f32_km_kn_mn_instances( std::vector>& instances) diff --git a/library/src/tensor_operation_instance/gpu/gemm/device_gemm_dlops_f32_f32_f32_km_nk_mn_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm/device_gemm_dlops_f32_f32_f32_km_nk_mn_instance.cpp new file mode 100644 index 0000000000..0b996630eb --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/gemm/device_gemm_dlops_f32_f32_f32_km_nk_mn_instance.cpp @@ -0,0 +1,48 @@ +#include +#include "config.hpp" +#include "device_gemm_dlops.hpp" +#include "element_wise_operation.hpp" 
+#include "device_operation_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace device_gemm_instance { + +using F16 = ck::half_t; +using F32 = float; + +using Row = ck::tensor_layout::gemm::RowMajor; +using Col = ck::tensor_layout::gemm::ColumnMajor; + +template +using S = ck::Sequence; + +using PassThrough = ck::tensor_operation::element_wise::PassThrough; + +static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default; + +// Compilation parameters for a[k, m] * b[n, k] = c[m, n] +using device_gemm_dlops_f32_f32_f32_km_nk_mn_instances = std::tuple< + // clang-format off + // ##########| AData| BData| CData| AccData| ALayout| BLayout| CLayout| A| B| C| GEMM| Block| MPer| NPer| K0Per| K1| M1Per| N1Per| KPer| M11N11Thread| M11N11Thread| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| CThreadTransfer| CThreadTransfer| CThreadTransfer| + // ##########| Type| Type| Type| Type| | | | Elementwise| Elementwise| Elementwise| Spacialization| Size| Block| Block| Block| | ThreadM111| ThreadN111| Thread| ClusterM110Xs| ClusterN110Xs| ThreadSliceLengths| ThreadClusterLengths| ThreadCluster| SrcAccess| SrcVectorTensor| SrcVectorTensor| DstVectorTensor| ThreadSliceLengths| ThreadClusterLengths| ThreadCluster| SrcAccess| SrcVectorTensor| SrcVectorTensor| DstVectorTensor| SrcDstAccess| SrcDstVectorDim| DstScalarPerVector| + // ##########| | | | | | | | Operation| Operation| Operation| | | | | | | | | | | | K0_M0_M1_K1| K0_M0_M1_K1| ArrangeOrder| Order| Lengths_K0_M0_M1_K1| ContiguousDimOrder| Lengths_K0_M0_M1_K1| K0_M0_M1_K1| K0_M0_M1_K1| ArrangeOrder| Order| Lengths_K0_M0_M1_K1| ContiguousDimOrder| Lengths_K0_M0_M1_K1| Order| | | + // ##########| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + 
DeviceGemmDlops< F32, F32, F32, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 128, 128, 8, 2, 4, 4, 1, S<8, 2>, S<8, 2>, S<4, 1, 1, 2>, S<2, 1, 128, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<4, 1, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 1, 2>, S<4, 1, 1, 2>, S<2, 1, 128, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<0, 1, 2, 3, 4, 5>, 5, 4> + // DeviceGemmDlops< F32, F32, F32, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 128, 128, 128, 8, 2, 4, 4, 1, S<8, 1>, S<8, 2>, S<4, 1, 1, 2>, S<1, 1, 128, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<4, 1, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 1, 2>, S<4, 1, 1, 2>, S<1, 1, 128, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<0, 1, 2, 3, 4, 5>, 5, 4> + // clang-format on + >; + +void add_device_gemm_dlops_f32_f32_f32_km_nk_mn_instances( + std::vector>& instances) +{ + add_device_operation_instances(instances, device_gemm_dlops_f32_f32_f32_km_nk_mn_instances{}); +} + +} // namespace device_gemm_instance +} // namespace device +} // namespace tensor_operation +} // namespace ck + + diff --git a/library/src/tensor_operation_instance/gpu/gemm/device_gemm_dlops_f32_f32_f32_mk_kn_mn_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm/device_gemm_dlops_f32_f32_f32_mk_kn_mn_instance.cpp new file mode 100644 index 0000000000..2f3f631d89 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/gemm/device_gemm_dlops_f32_f32_f32_mk_kn_mn_instance.cpp @@ -0,0 +1,47 @@ +#include +#include "config.hpp" +#include "device_gemm_dlops.hpp" +#include "element_wise_operation.hpp" +#include "device_operation_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace device_gemm_instance { + +using F16 = ck::half_t; +using F32 = float; + +using Row = ck::tensor_layout::gemm::RowMajor; +using Col = ck::tensor_layout::gemm::ColumnMajor; + +template +using S = ck::Sequence; + +using PassThrough = 
ck::tensor_operation::element_wise::PassThrough; + +static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default; + +// Compilation parameters for a[m, k] * b[k, n] = c[m, n] +using device_gemm_dlops_f32_f32_f32_mk_kn_mn_instances = std::tuple< + // clang-format off + // ##########| AData| BData| CData| AccData| ALayout| BLayout| CLayout| A| B| C| GEMM| Block| MPer| NPer| K0Per| K1| M1Per| N1Per| KPer| M11N11Thread| M11N11Thread| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| CThreadTransfer| CThreadTransfer| CThreadTransfer| + // ##########| Type| Type| Type| Type| | | | Elementwise| Elementwise| Elementwise| Spacialization| Size| Block| Block| Block| | ThreadM111| ThreadN111| Thread| ClusterM110Xs| ClusterN110Xs| ThreadSliceLengths| ThreadClusterLengths| ThreadCluster| SrcAccess| SrcVectorTensor| SrcVectorTensor| DstVectorTensor| ThreadSliceLengths| ThreadClusterLengths| ThreadCluster| SrcAccess| SrcVectorTensor| SrcVectorTensor| DstVectorTensor| SrcDstAccess| SrcDstVectorDim| DstScalarPerVector| + // ##########| | | | | | | | Operation| Operation| Operation| | | | | | | | | | | | K0_M0_M1_K1| K0_M0_M1_K1| ArrangeOrder| Order| Lengths_K0_M0_M1_K1| ContiguousDimOrder| Lengths_K0_M0_M1_K1| K0_M0_M1_K1| K0_M0_M1_K1| ArrangeOrder| Order| Lengths_K0_M0_M1_K1| ContiguousDimOrder| Lengths_K0_M0_M1_K1| Order| | | + // ##########| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceGemmDlops< F32, F32, F32, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 128, 128, 8, 2, 4, 4, 1, S<8, 2>, S<8, 2>, S<4, 1, 1, 2>, S<2, 1, 128, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<4, 1, 1, 2>, S<2, 1, 128, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<4, 1, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 1, 2>, 
S<0, 1, 2, 3, 4, 5>, 5, 4> + // DeviceGemmDlops< F32, F32, F32, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 128, 128, 128, 8, 2, 4, 4, 1, S<8, 1>, S<8, 2>, S<4, 1, 1, 2>, S<1, 1, 128, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<4, 1, 1, 2>, S<1, 1, 128, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<4, 1, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 1, 2>, S<0, 1, 2, 3, 4, 5>, 5, 4> + // clang-format on + >; + +void add_device_gemm_dlops_f32_f32_f32_mk_kn_mn_instances( + std::vector>& instances) +{ + add_device_operation_instances(instances, device_gemm_dlops_f32_f32_f32_mk_kn_mn_instances{}); +} + +} // namespace device_gemm_instance +} // namespace device +} // namespace tensor_operation +} // namespace ck + diff --git a/library/src/tensor_operation_instance/gpu/gemm/device_gemm_dlops_f32_f32_f32_mk_nk_mn_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm/device_gemm_dlops_f32_f32_f32_mk_nk_mn_instance.cpp new file mode 100644 index 0000000000..e27c15cd95 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/gemm/device_gemm_dlops_f32_f32_f32_mk_nk_mn_instance.cpp @@ -0,0 +1,48 @@ +#include +#include "config.hpp" +#include "device_gemm_dlops.hpp" +#include "element_wise_operation.hpp" +#include "device_operation_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace device_gemm_instance { + +using F16 = ck::half_t; +using F32 = float; + +using Row = ck::tensor_layout::gemm::RowMajor; +using Col = ck::tensor_layout::gemm::ColumnMajor; + +template +using S = ck::Sequence; + +using PassThrough = ck::tensor_operation::element_wise::PassThrough; + +static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default; + +// Compilation parameters for a[m, k] * b[n, k] = c[m, n] +using device_gemm_dlops_f32_f32_f32_mk_nk_mn_instances = std::tuple< + // clang-format off + // ##########| AData| BData| CData| AccData| ALayout| BLayout| CLayout| A| B| C| 
GEMM| Block| MPer| NPer| K0Per| K1| M1Per| N1Per| KPer| M11N11Thread| M11N11Thread| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| CThreadTransfer| CThreadTransfer| CThreadTransfer| + // ##########| Type| Type| Type| Type| | | | Elementwise| Elementwise| Elementwise| Spacialization| Size| Block| Block| Block| | ThreadM111| ThreadN111| Thread| ClusterM110Xs| ClusterN110Xs| ThreadSliceLengths| ThreadClusterLengths| ThreadCluster| SrcAccess| SrcVectorTensor| SrcVectorTensor| DstVectorTensor| ThreadSliceLengths| ThreadClusterLengths| ThreadCluster| SrcAccess| SrcVectorTensor| SrcVectorTensor| DstVectorTensor| SrcDstAccess| SrcDstVectorDim| DstScalarPerVector| + // ##########| | | | | | | | Operation| Operation| Operation| | | | | | | | | | | | K0_M0_M1_K1| K0_M0_M1_K1| ArrangeOrder| Order| Lengths_K0_M0_M1_K1| ContiguousDimOrder| Lengths_K0_M0_M1_K1| K0_M0_M1_K1| K0_M0_M1_K1| ArrangeOrder| Order| Lengths_K0_M0_M1_K1| ContiguousDimOrder| Lengths_K0_M0_M1_K1| Order| | | + // ##########| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceGemmDlops< F32, F32, F32, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 128, 128, 8, 2, 4, 4, 1, S<8, 2>, S<8, 2>, S<4, 1, 1, 2>, S<2, 1, 128, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<4, 1, 1, 2>, S<2, 1, 128, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<0, 1, 2, 3, 4, 5>, 5, 4> + // DeviceGemmDlops< F32, F32, F32, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 128, 128, 128, 8, 2, 4, 4, 1, S<8, 1>, S<8, 2>, S<4, 1, 1, 2>, S<1, 1, 128, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<4, 1, 1, 2>, S<1, 1, 128, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, 
S<1, 1, 1, 2>, S<0, 1, 2, 3, 4, 5>, 5, 4> + // clang-format on + >; + +void add_device_gemm_dlops_f32_f32_f32_mk_nk_mn_instances( + std::vector>& instances) +{ + add_device_operation_instances(instances, device_gemm_dlops_f32_f32_f32_mk_nk_mn_instances{}); +} + +} // namespace device_gemm_instance +} // namespace device +} // namespace tensor_operation +} // namespace ck + + diff --git a/test/gemm_dlops/gemm_dlops_fp32.cpp b/test/gemm_dlops/gemm_dlops_fp32.cpp index f8bc1086db..3cd2775bb6 100644 --- a/test/gemm_dlops/gemm_dlops_fp32.cpp +++ b/test/gemm_dlops/gemm_dlops_fp32.cpp @@ -14,7 +14,7 @@ #include "host_tensor_generator.hpp" #include "host_gemm.hpp" #include "device_tensor.hpp" -#include "device_gemm_xdl.hpp" +#include "device_gemm_dlops.hpp" #include "element_wise_operation.hpp" #include "reference_gemm.hpp" #include "gemm_specialization.hpp" @@ -30,10 +30,11 @@ namespace ck { namespace tensor_operation { namespace device { namespace device_gemm_instance { + void add_device_gemm_dlops_f32_f32_f32_km_kn_mn_instances(std::vector&); -// void add_device_gemm_dlops_f32_f32_f32_km_nk_mn_instances(std::vector&); -// void add_device_gemm_dlops_f32_f32_f32_mk_nk_mn_instances(std::vector&); -// void add_device_gemm_dlops_f32_f32_f32_mk_kn_mn_instances(std::vector&); +void add_device_gemm_dlops_f32_f32_f32_km_nk_mn_instances(std::vector&); +void add_device_gemm_dlops_f32_f32_f32_mk_nk_mn_instances(std::vector&); +void add_device_gemm_dlops_f32_f32_f32_mk_kn_mn_instances(std::vector&); } // namespace device_gemm_instance } // namespace device @@ -68,6 +69,61 @@ int main() PassThrough>{}(gemmPtr); } + gemmPtrs.clear(); + ck::tensor_operation::device::device_gemm_instance:: + add_device_gemm_dlops_f32_f32_f32_km_nk_mn_instances(gemmPtrs); + + for(auto& gemmPtr : gemmPtrs) + { + res &= ck::gemm_util::TestGemm{}(gemmPtr); + } + + gemmPtrs.clear(); + ck::tensor_operation::device::device_gemm_instance:: + add_device_gemm_dlops_f32_f32_f32_mk_kn_mn_instances(gemmPtrs); + + 
for(auto& gemmPtr : gemmPtrs) + { + res &= ck::gemm_util::TestGemm{}(gemmPtr); + } + + gemmPtrs.clear(); + ck::tensor_operation::device::device_gemm_instance:: + add_device_gemm_dlops_f32_f32_f32_mk_nk_mn_instances(gemmPtrs); + + for(auto& gemmPtr : gemmPtrs) + { + res &= ck::gemm_util::TestGemm{}(gemmPtr); + } + + std::cout << "TestGemm ..... " << (res ? "SUCCESS" : "FAILURE") << std::endl; return res ? 0 : 1; } From e10a2624478b2a99f7e3f3bc8182d1e4260bc65e Mon Sep 17 00:00:00 2001 From: j4yan Date: Fri, 15 Apr 2022 11:17:51 -0500 Subject: [PATCH 05/46] adding more tuning parameters add tuning parameters for other 3 layouts --- .../gpu/block/blockwise_gemm_dlops_v2r3.hpp | 1 + ...mm_dlops_f32_f32_f32_km_kn_mn_instance.cpp | 38 ++++++++++++++++++- ...mm_dlops_f32_f32_f32_km_nk_mn_instance.cpp | 25 +++++++++++- ...mm_dlops_f32_f32_f32_mk_kn_mn_instance.cpp | 25 +++++++++++- ...mm_dlops_f32_f32_f32_mk_nk_mn_instance.cpp | 24 +++++++++++- 5 files changed, 106 insertions(+), 7 deletions(-) diff --git a/include/ck/tensor_operation/gpu/block/blockwise_gemm_dlops_v2r3.hpp b/include/ck/tensor_operation/gpu/block/blockwise_gemm_dlops_v2r3.hpp index fa52d1749e..15e7fd9028 100644 --- a/include/ck/tensor_operation/gpu/block/blockwise_gemm_dlops_v2r3.hpp +++ b/include/ck/tensor_operation/gpu/block/blockwise_gemm_dlops_v2r3.hpp @@ -175,6 +175,7 @@ struct BlockwiseGemmDlops_A_BK0_BM_BK1_B_BK0_BN_BK1_C_BM0_BM1_BN0_BN1_pipeline_B "wrong!"); // TODO: remove this restriction + static_assert(BM0 == 2, "wrong"); static_assert(BM0 == 2 && BN0 == 2, "wrong"); } diff --git a/library/src/tensor_operation_instance/gpu/gemm/device_gemm_dlops_f32_f32_f32_km_kn_mn_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm/device_gemm_dlops_f32_f32_f32_km_kn_mn_instance.cpp index fe2dda3c25..e46afa5830 100644 --- a/library/src/tensor_operation_instance/gpu/gemm/device_gemm_dlops_f32_f32_f32_km_kn_mn_instance.cpp +++ 
b/library/src/tensor_operation_instance/gpu/gemm/device_gemm_dlops_f32_f32_f32_km_kn_mn_instance.cpp @@ -24,13 +24,47 @@ static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecializa // Compilation parameters for a[k, m] * b[k, n] = c[m, n] using device_gemm_dlops_f32_f32_f32_km_kn_mn_instances = std::tuple< + // clang-format off // ##########| AData| BData| CData| AccData| ALayout| BLayout| CLayout| A| B| C| GEMM| Block| MPer| NPer| K0Per| K1| M1Per| N1Per| KPer| M11N11Thread| M11N11Thread| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| CThreadTransfer| CThreadTransfer| CThreadTransfer| // ##########| Type| Type| Type| Type| | | | Elementwise| Elementwise| Elementwise| Spacialization| Size| Block| Block| Block| | ThreadM111| ThreadN111| Thread| ClusterM110Xs| ClusterN110Xs| ThreadSliceLengths| ThreadClusterLengths| ThreadCluster| SrcAccess| SrcVectorTensor| SrcVectorTensor| DstVectorTensor| ThreadSliceLengths| ThreadClusterLengths| ThreadCluster| SrcAccess| SrcVectorTensor| SrcVectorTensor| DstVectorTensor| SrcDstAccess| SrcDstVectorDim| DstScalarPerVector| // ##########| | | | | | | | Operation| Operation| Operation| | | | | | | | | | | | K0_M0_M1_K1| K0_M0_M1_K1| ArrangeOrder| Order| Lengths_K0_M0_M1_K1| ContiguousDimOrder| Lengths_K0_M0_M1_K1| K0_M0_M1_K1| K0_M0_M1_K1| ArrangeOrder| Order| Lengths_K0_M0_M1_K1| ContiguousDimOrder| Lengths_K0_M0_M1_K1| Order| | | // ##########| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | - DeviceGemmDlops< F32, F32, F32, F32, Col, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 128, 128, 8, 2, 4, 4, 1, S<8, 2>, S<8, 2>, S<4, 1, 1, 2>, S<2, 1, 128, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<4, 1, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 1, 2>, S<4, 1, 1, 2>, S<2, 1, 128, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, 
S<4, 1, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 1, 2>, S<0, 1, 2, 3, 4, 5>, 5, 4> - // DeviceGemmDlops< F32, F32, F32, F32, Col, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 128, 128, 128, 8, 2, 4, 4, 1, S<8, 1>, S<8, 2>, S<4, 1, 1, 2>, S<1, 1, 128, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<4, 1, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 1, 2>, S<4, 1, 1, 2>, S<1, 1, 128, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<4, 1, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 1, 2>, S<0, 1, 2, 3, 4, 5>, 5, 4> + /* + * K1 = 1 + */ + DeviceGemmDlops< F32, F32, F32, F32, Col, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 256, 128, 8, 1, 8, 4, 1, S<8, 2>, S<8, 2>, S<4, 1, 2, 1>, S<2, 1, 128, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<4, 1, 1, 1>, S<0, 3, 1, 2>, S<1, 1, 1, 1>, S<4, 1, 1, 1>, S<2, 1, 128, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<4, 1, 1, 1>, S<0, 3, 1, 2>, S<1, 1, 1, 1>, S<0, 1, 2, 3, 4, 5>, 5, 4>, + DeviceGemmDlops< F32, F32, F32, F32, Col, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 128, 256, 8, 1, 4, 8, 1, S<8, 2>, S<8, 2>, S<4, 1, 1, 1>, S<2, 1, 128, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<4, 1, 1, 1>, S<0, 3, 1, 2>, S<1, 1, 1, 1>, S<4, 1, 2, 1>, S<2, 1, 128, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<4, 1, 1, 1>, S<0, 3, 1, 2>, S<1, 1, 1, 1>, S<0, 1, 2, 3, 4, 5>, 5, 4>, + DeviceGemmDlops< F32, F32, F32, F32, Col, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 128, 128, 8, 1, 4, 4, 1, S<8, 2>, S<8, 2>, S<4, 1, 1, 1>, S<2, 1, 128, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<4, 1, 1, 1>, S<0, 3, 1, 2>, S<1, 1, 1, 1>, S<4, 1, 1, 1>, S<2, 1, 128, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<4, 1, 1, 1>, S<0, 3, 1, 2>, S<1, 1, 1, 1>, S<0, 1, 2, 3, 4, 5>, 5, 4>, + DeviceGemmDlops< F32, F32, F32, F32, Col, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 128, 128, 128, 8, 1, 8, 4, 1, S<4, 2>, S<8, 2>, S<8, 1, 1, 1>, S<1, 1, 128, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<4, 1, 1, 1>, S<0, 3, 1, 2>, S<1, 1, 1, 1>, S<8, 1, 1, 1>, S<1, 1, 128, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 
2>, S<4, 1, 1, 1>, S<0, 3, 1, 2>, S<1, 1, 1, 1>, S<0, 1, 2, 3, 4, 5>, 5, 4>, + DeviceGemmDlops< F32, F32, F32, F32, Col, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 128, 128, 128, 8, 1, 4, 8, 1, S<8, 2>, S<4, 2>, S<8, 1, 1, 1>, S<1, 1, 128, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<4, 1, 1, 1>, S<0, 3, 1, 2>, S<1, 1, 1, 1>, S<8, 1, 1, 1>, S<1, 1, 128, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<4, 1, 1, 1>, S<0, 3, 1, 2>, S<1, 1, 1, 1>, S<0, 1, 2, 3, 4, 5>, 5, 4>, + DeviceGemmDlops< F32, F32, F32, F32, Col, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 64, 128, 128, 8, 1, 8, 8, 1, S<4, 2>, S<4, 2>, S<4, 1, 4, 1>, S<2, 1, 32, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<4, 1, 1, 1>, S<0, 3, 1, 2>, S<1, 1, 1, 1>, S<4, 1, 4, 1>, S<2, 1, 32, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<4, 1, 1, 1>, S<0, 3, 1, 2>, S<1, 1, 1, 1>, S<0, 1, 2, 3, 4, 5>, 5, 4>, + DeviceGemmDlops< F32, F32, F32, F32, Col, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 64, 128, 64, 8, 1, 8, 4, 1, S<4, 2>, S<4, 2>, S<4, 1, 4, 1>, S<2, 1, 32, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<4, 1, 1, 1>, S<0, 3, 1, 2>, S<1, 1, 1, 1>, S<4, 1, 2, 1>, S<2, 1, 32, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<4, 1, 1, 1>, S<0, 3, 1, 2>, S<1, 1, 1, 1>, S<0, 1, 2, 3, 4, 5>, 5, 4>, + DeviceGemmDlops< F32, F32, F32, F32, Col, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 64, 64, 128, 8, 1, 4, 8, 1, S<4, 2>, S<4, 2>, S<4, 1, 2, 1>, S<2, 1, 32, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<4, 1, 1, 1>, S<0, 3, 1, 2>, S<1, 1, 1, 1>, S<4, 1, 4, 1>, S<2, 1, 32, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<4, 1, 1, 1>, S<0, 3, 1, 2>, S<1, 1, 1, 1>, S<0, 1, 2, 3, 4, 5>, 5, 4>, + + /* + * K1 = 2 + */ + DeviceGemmDlops< F32, F32, F32, F32, Col, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 256, 128, 8, 2, 8, 4, 1, S<8, 2>, S<8, 2>, S<4, 1, 2, 2>, S<2, 1, 128, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<4, 1, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 1, 2>, S<4, 1, 1, 2>, S<2, 1, 128, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, 
S<4, 1, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 1, 2>, S<0, 1, 2, 3, 4, 5>, 5, 4>, + DeviceGemmDlops< F32, F32, F32, F32, Col, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 128, 256, 8, 2, 4, 8, 1, S<8, 2>, S<8, 2>, S<4, 1, 1, 2>, S<2, 1, 128, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<4, 1, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 1, 2>, S<4, 1, 2, 2>, S<2, 1, 128, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<4, 1, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 1, 2>, S<0, 1, 2, 3, 4, 5>, 5, 4>, + DeviceGemmDlops< F32, F32, F32, F32, Col, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 128, 128, 8, 2, 4, 4, 1, S<8, 2>, S<8, 2>, S<4, 1, 1, 2>, S<2, 1, 128, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<4, 1, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 1, 2>, S<4, 1, 1, 2>, S<2, 1, 128, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<4, 1, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 1, 2>, S<0, 1, 2, 3, 4, 5>, 5, 4>, + DeviceGemmDlops< F32, F32, F32, F32, Col, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 128, 128, 128, 8, 2, 8, 4, 1, S<4, 2>, S<8, 2>, S<8, 1, 1, 2>, S<1, 1, 128, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<4, 1, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 1, 2>, S<8, 1, 1, 2>, S<1, 1, 128, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<4, 1, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 1, 2>, S<0, 1, 2, 3, 4, 5>, 5, 4>, + DeviceGemmDlops< F32, F32, F32, F32, Col, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 128, 128, 128, 8, 2, 4, 8, 1, S<8, 2>, S<4, 2>, S<8, 1, 1, 2>, S<1, 1, 128, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<4, 1, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 1, 2>, S<8, 1, 1, 2>, S<1, 1, 128, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<4, 1, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 1, 2>, S<0, 1, 2, 3, 4, 5>, 5, 4>, + DeviceGemmDlops< F32, F32, F32, F32, Col, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 64, 128, 128, 8, 2, 8, 8, 1, S<4, 2>, S<4, 2>, S<4, 1, 4, 2>, S<2, 1, 32, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<4, 1, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 1, 2>, S<4, 1, 4, 2>, S<2, 1, 32, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<4, 1, 1, 2>, S<0, 3, 
1, 2>, S<1, 1, 1, 2>, S<0, 1, 2, 3, 4, 5>, 5, 4>, + DeviceGemmDlops< F32, F32, F32, F32, Col, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 64, 128, 64, 8, 2, 8, 4, 1, S<4, 2>, S<4, 2>, S<4, 1, 4, 2>, S<2, 1, 32, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<4, 1, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 1, 2>, S<4, 1, 2, 2>, S<2, 1, 32, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<4, 1, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 1, 2>, S<0, 1, 2, 3, 4, 5>, 5, 4>, + DeviceGemmDlops< F32, F32, F32, F32, Col, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 64, 64, 128, 8, 2, 4, 8, 1, S<4, 2>, S<4, 2>, S<4, 1, 2, 2>, S<2, 1, 32, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<4, 1, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 1, 2>, S<4, 1, 4, 2>, S<2, 1, 32, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<4, 1, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 1, 2>, S<0, 1, 2, 3, 4, 5>, 5, 4> + + // repeat the above configurartion, but changing K1 to 4, NOT working for fp32 + // DeviceGemmDlops< F32, F32, F32, F32, Col, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 128, 128, 8, 4, 4, 4, 1, S<8, 2>, S<8, 2>, S<4, 1, 1, 4>, S<2, 1, 128, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<4, 1, 1, 4>, S<0, 3, 1, 2>, S<1, 1, 1, 4>, S<4, 1, 1, 4>, S<2, 1, 128, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<4, 1, 1, 4>, S<0, 3, 1, 2>, S<1, 1, 1, 4>, S<0, 1, 2, 3, 4, 5>, 5, 4> + // DeviceGemmDlops< F32, F32, F32, F32, Col, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 128, 128, 128, 8, 4, 8, 4, 1, S<4, 2>, S<8, 2>, S<8, 1, 1, 4>, S<1, 1, 128, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<4, 1, 1, 4>, S<0, 3, 1, 2>, S<1, 1, 1, 4>, S<8, 1, 1, 4>, S<1, 1, 128, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<4, 1, 1, 4>, S<0, 3, 1, 2>, S<1, 1, 1, 4>, S<0, 1, 2, 3, 4, 5>, 5, 4>, + // DeviceGemmDlops< F32, F32, F32, F32, Col, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 128, 128, 128, 8, 4, 4, 8, 1, S<8, 2>, S<4, 2>, S<8, 1, 1, 4>, S<1, 1, 128, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<4, 1, 1, 4>, S<0, 3, 1, 2>, S<1, 1, 1, 4>, S<8, 1, 1, 4>, S<1, 1, 
128, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<4, 1, 1, 4>, S<0, 3, 1, 2>, S<1, 1, 1, 2>, S<0, 1, 2, 3, 4, 5>, 5, 4>, + // DeviceGemmDlops< F32, F32, F32, F32, Col, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 64, 128, 128, 8, 4, 8, 8, 1, S<4, 2>, S<4, 2>, S<4, 1, 4, 4>, S<2, 1, 32, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<4, 1, 1, 4>, S<0, 3, 1, 2>, S<1, 1, 1, 4>, S<4, 1, 4, 2>, S<2, 1, 32, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<4, 1, 1, 4>, S<0, 3, 1, 2>, S<1, 1, 1, 2>, S<0, 1, 2, 3, 4, 5>, 5, 4>, + // DeviceGemmDlops< F32, F32, F32, F32, Col, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 64, 128, 64, 8, 4, 8, 4, 1, S<4, 2>, S<4, 2>, S<4, 1, 4, 4>, S<2, 1, 32, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<4, 1, 1, 4>, S<0, 3, 1, 2>, S<1, 1, 1, 4>, S<4, 1, 2, 2>, S<2, 1, 32, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<4, 1, 1, 4>, S<0, 3, 1, 2>, S<1, 1, 1, 2>, S<0, 1, 2, 3, 4, 5>, 5, 4>, + // DeviceGemmDlops< F32, F32, F32, F32, Col, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 64, 64, 128, 8, 4, 4, 8, 1, S<4, 2>, S<4, 2>, S<4, 1, 2, 4>, S<2, 1, 32, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<4, 1, 1, 4>, S<0, 3, 1, 2>, S<1, 1, 1, 4>, S<4, 1, 4, 2>, S<2, 1, 32, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<4, 1, 1, 4>, S<0, 3, 1, 2>, S<1, 1, 1, 2>, S<0, 1, 2, 3, 4, 5>, 5, 4> + + + + // clang-format on >; diff --git a/library/src/tensor_operation_instance/gpu/gemm/device_gemm_dlops_f32_f32_f32_km_nk_mn_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm/device_gemm_dlops_f32_f32_f32_km_nk_mn_instance.cpp index 0b996630eb..9c4318274d 100644 --- a/library/src/tensor_operation_instance/gpu/gemm/device_gemm_dlops_f32_f32_f32_km_nk_mn_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/gemm/device_gemm_dlops_f32_f32_f32_km_nk_mn_instance.cpp @@ -29,8 +29,29 @@ using device_gemm_dlops_f32_f32_f32_km_nk_mn_instances = std::tuple< // ##########| Type| Type| Type| Type| | | | Elementwise| Elementwise| Elementwise| Spacialization| Size| Block| Block| 
Block| | ThreadM111| ThreadN111| Thread| ClusterM110Xs| ClusterN110Xs| ThreadSliceLengths| ThreadClusterLengths| ThreadCluster| SrcAccess| SrcVectorTensor| SrcVectorTensor| DstVectorTensor| ThreadSliceLengths| ThreadClusterLengths| ThreadCluster| SrcAccess| SrcVectorTensor| SrcVectorTensor| DstVectorTensor| SrcDstAccess| SrcDstVectorDim| DstScalarPerVector| // ##########| | | | | | | | Operation| Operation| Operation| | | | | | | | | | | | K0_M0_M1_K1| K0_M0_M1_K1| ArrangeOrder| Order| Lengths_K0_M0_M1_K1| ContiguousDimOrder| Lengths_K0_M0_M1_K1| K0_M0_M1_K1| K0_M0_M1_K1| ArrangeOrder| Order| Lengths_K0_M0_M1_K1| ContiguousDimOrder| Lengths_K0_M0_M1_K1| Order| | | // ##########| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | - DeviceGemmDlops< F32, F32, F32, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 128, 128, 8, 2, 4, 4, 1, S<8, 2>, S<8, 2>, S<4, 1, 1, 2>, S<2, 1, 128, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<4, 1, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 1, 2>, S<4, 1, 1, 2>, S<2, 1, 128, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<0, 1, 2, 3, 4, 5>, 5, 4> - // DeviceGemmDlops< F32, F32, F32, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 128, 128, 128, 8, 2, 4, 4, 1, S<8, 1>, S<8, 2>, S<4, 1, 1, 2>, S<1, 1, 128, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<4, 1, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 1, 2>, S<4, 1, 1, 2>, S<1, 1, 128, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<0, 1, 2, 3, 4, 5>, 5, 4> + /* + * K1 = 1 + */ + DeviceGemmDlops< F32, F32, F32, F32, Col, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 256, 128, 8, 1, 8, 4, 1, S<8, 2>, S<8, 2>, S<4, 1, 2, 1>, S<2, 1, 128, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<4, 1, 1, 1>, S<0, 3, 1, 2>, S<1, 1, 1, 1>, S<4, 1, 1, 1>, S<2, 1, 128, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 1>, S<1, 2, 0, 3>, S<1, 1, 1, 1>, S<0, 1, 2, 3, 4, 5>, 5, 4>, + DeviceGemmDlops< F32, 
F32, F32, F32, Col, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 128, 256, 8, 1, 4, 8, 1, S<8, 2>, S<8, 2>, S<4, 1, 1, 1>, S<2, 1, 128, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<4, 1, 1, 1>, S<0, 3, 1, 2>, S<1, 1, 1, 1>, S<4, 1, 2, 1>, S<2, 1, 128, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 1>, S<1, 2, 0, 3>, S<1, 1, 1, 1>, S<0, 1, 2, 3, 4, 5>, 5, 4>, + DeviceGemmDlops< F32, F32, F32, F32, Col, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 128, 128, 8, 1, 4, 4, 1, S<8, 2>, S<8, 2>, S<4, 1, 1, 1>, S<2, 1, 128, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<4, 1, 1, 1>, S<0, 3, 1, 2>, S<1, 1, 1, 1>, S<4, 1, 1, 1>, S<2, 1, 128, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 1>, S<1, 2, 0, 3>, S<1, 1, 1, 1>, S<0, 1, 2, 3, 4, 5>, 5, 4>, + DeviceGemmDlops< F32, F32, F32, F32, Col, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 128, 128, 128, 8, 1, 8, 4, 1, S<4, 2>, S<8, 2>, S<8, 1, 1, 1>, S<1, 1, 128, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<4, 1, 1, 1>, S<0, 3, 1, 2>, S<1, 1, 1, 1>, S<8, 1, 1, 1>, S<1, 1, 128, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 1>, S<1, 2, 0, 3>, S<1, 1, 1, 1>, S<0, 1, 2, 3, 4, 5>, 5, 4>, + DeviceGemmDlops< F32, F32, F32, F32, Col, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 128, 128, 128, 8, 1, 4, 8, 1, S<8, 2>, S<4, 2>, S<8, 1, 1, 1>, S<1, 1, 128, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<4, 1, 1, 1>, S<0, 3, 1, 2>, S<1, 1, 1, 1>, S<8, 1, 1, 1>, S<1, 1, 128, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 1>, S<1, 2, 0, 3>, S<1, 1, 1, 1>, S<0, 1, 2, 3, 4, 5>, 5, 4>, + DeviceGemmDlops< F32, F32, F32, F32, Col, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 64, 128, 128, 8, 1, 8, 8, 1, S<4, 2>, S<4, 2>, S<4, 1, 4, 1>, S<2, 1, 32, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<4, 1, 1, 1>, S<0, 3, 1, 2>, S<1, 1, 1, 1>, S<4, 1, 4, 1>, S<2, 1, 32, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 1>, S<1, 2, 0, 3>, S<1, 1, 1, 1>, S<0, 1, 2, 3, 4, 5>, 5, 4>, + DeviceGemmDlops< F32, F32, F32, F32, Col, 
Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 64, 128, 64, 8, 1, 8, 4, 1, S<4, 2>, S<4, 2>, S<4, 1, 4, 1>, S<2, 1, 32, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<4, 1, 1, 1>, S<0, 3, 1, 2>, S<1, 1, 1, 1>, S<4, 1, 2, 1>, S<2, 1, 32, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 1>, S<1, 2, 0, 3>, S<1, 1, 1, 1>, S<0, 1, 2, 3, 4, 5>, 5, 4>, + DeviceGemmDlops< F32, F32, F32, F32, Col, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 64, 64, 128, 8, 1, 4, 8, 1, S<4, 2>, S<4, 2>, S<4, 1, 2, 1>, S<2, 1, 32, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<4, 1, 1, 1>, S<0, 3, 1, 2>, S<1, 1, 1, 1>, S<4, 1, 4, 1>, S<2, 1, 32, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 1>, S<1, 2, 0, 3>, S<1, 1, 1, 1>, S<0, 1, 2, 3, 4, 5>, 5, 4>, + + /* + * K1 = 2 + */ + DeviceGemmDlops< F32, F32, F32, F32, Col, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 256, 128, 8, 2, 8, 4, 1, S<8, 2>, S<8, 2>, S<4, 1, 2, 2>, S<2, 1, 128, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<4, 1, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 1, 2>, S<4, 1, 1, 2>, S<2, 1, 128, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<0, 1, 2, 3, 4, 5>, 5, 4>, + DeviceGemmDlops< F32, F32, F32, F32, Col, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 128, 256, 8, 2, 4, 8, 1, S<8, 2>, S<8, 2>, S<4, 1, 1, 2>, S<2, 1, 128, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<4, 1, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 1, 2>, S<4, 1, 2, 2>, S<2, 1, 128, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<0, 1, 2, 3, 4, 5>, 5, 4>, + DeviceGemmDlops< F32, F32, F32, F32, Col, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 128, 128, 8, 2, 4, 4, 1, S<8, 2>, S<8, 2>, S<4, 1, 1, 2>, S<2, 1, 128, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<4, 1, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 1, 2>, S<4, 1, 1, 2>, S<2, 1, 128, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<0, 1, 2, 3, 4, 5>, 5, 4>, + DeviceGemmDlops< F32, F32, F32, F32, Col, Col, 
Row, PassThrough, PassThrough, PassThrough, GemmDefault, 128, 128, 128, 8, 2, 8, 4, 1, S<4, 2>, S<8, 2>, S<8, 1, 1, 2>, S<1, 1, 128, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<4, 1, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 1, 2>, S<8, 1, 1, 2>, S<1, 1, 128, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<0, 1, 2, 3, 4, 5>, 5, 4>, + DeviceGemmDlops< F32, F32, F32, F32, Col, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 128, 128, 128, 8, 2, 4, 8, 1, S<8, 2>, S<4, 2>, S<8, 1, 1, 2>, S<1, 1, 128, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<4, 1, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 1, 2>, S<8, 1, 1, 2>, S<1, 1, 128, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<0, 1, 2, 3, 4, 5>, 5, 4>, + DeviceGemmDlops< F32, F32, F32, F32, Col, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 64, 128, 128, 8, 2, 8, 8, 1, S<4, 2>, S<4, 2>, S<4, 1, 4, 2>, S<2, 1, 32, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<4, 1, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 1, 2>, S<4, 1, 4, 2>, S<2, 1, 32, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<0, 1, 2, 3, 4, 5>, 5, 4>, + DeviceGemmDlops< F32, F32, F32, F32, Col, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 64, 128, 64, 8, 2, 8, 4, 1, S<4, 2>, S<4, 2>, S<4, 1, 4, 2>, S<2, 1, 32, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<4, 1, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 1, 2>, S<4, 1, 2, 2>, S<2, 1, 32, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<0, 1, 2, 3, 4, 5>, 5, 4>, + DeviceGemmDlops< F32, F32, F32, F32, Col, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 64, 64, 128, 8, 2, 4, 8, 1, S<4, 2>, S<4, 2>, S<4, 1, 2, 2>, S<2, 1, 32, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<4, 1, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 1, 2>, S<4, 1, 4, 2>, S<2, 1, 32, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<0, 1, 2, 3, 4, 5>, 5, 4> // clang-format on >; diff --git 
a/library/src/tensor_operation_instance/gpu/gemm/device_gemm_dlops_f32_f32_f32_mk_kn_mn_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm/device_gemm_dlops_f32_f32_f32_mk_kn_mn_instance.cpp index 2f3f631d89..0cf50e66dc 100644 --- a/library/src/tensor_operation_instance/gpu/gemm/device_gemm_dlops_f32_f32_f32_mk_kn_mn_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/gemm/device_gemm_dlops_f32_f32_f32_mk_kn_mn_instance.cpp @@ -29,8 +29,29 @@ using device_gemm_dlops_f32_f32_f32_mk_kn_mn_instances = std::tuple< // ##########| Type| Type| Type| Type| | | | Elementwise| Elementwise| Elementwise| Spacialization| Size| Block| Block| Block| | ThreadM111| ThreadN111| Thread| ClusterM110Xs| ClusterN110Xs| ThreadSliceLengths| ThreadClusterLengths| ThreadCluster| SrcAccess| SrcVectorTensor| SrcVectorTensor| DstVectorTensor| ThreadSliceLengths| ThreadClusterLengths| ThreadCluster| SrcAccess| SrcVectorTensor| SrcVectorTensor| DstVectorTensor| SrcDstAccess| SrcDstVectorDim| DstScalarPerVector| // ##########| | | | | | | | Operation| Operation| Operation| | | | | | | | | | | | K0_M0_M1_K1| K0_M0_M1_K1| ArrangeOrder| Order| Lengths_K0_M0_M1_K1| ContiguousDimOrder| Lengths_K0_M0_M1_K1| K0_M0_M1_K1| K0_M0_M1_K1| ArrangeOrder| Order| Lengths_K0_M0_M1_K1| ContiguousDimOrder| Lengths_K0_M0_M1_K1| Order| | | // ##########| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | - DeviceGemmDlops< F32, F32, F32, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 128, 128, 8, 2, 4, 4, 1, S<8, 2>, S<8, 2>, S<4, 1, 1, 2>, S<2, 1, 128, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<4, 1, 1, 2>, S<2, 1, 128, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<4, 1, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 1, 2>, S<0, 1, 2, 3, 4, 5>, 5, 4> - // DeviceGemmDlops< F32, F32, F32, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 128, 128, 128, 8, 2, 4, 4, 1, S<8, 1>, S<8, 2>, S<4, 1, 1, 2>, S<1, 1, 128, 
1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<4, 1, 1, 2>, S<1, 1, 128, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<4, 1, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 1, 2>, S<0, 1, 2, 3, 4, 5>, 5, 4> + /* + * K1 = 1 + */ + DeviceGemmDlops< F32, F32, F32, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 256, 128, 8, 1, 8, 4, 1, S<8, 2>, S<8, 2>, S<4, 1, 2, 1>, S<2, 1, 128, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 1>, S<1, 2, 0, 3>, S<1, 1, 1, 1>, S<4, 1, 1, 1>, S<2, 1, 128, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<4, 1, 1, 1>, S<0, 3, 1, 2>, S<1, 1, 1, 1>, S<0, 1, 2, 3, 4, 5>, 5, 4>, + DeviceGemmDlops< F32, F32, F32, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 128, 256, 8, 1, 4, 8, 1, S<8, 2>, S<8, 2>, S<4, 1, 1, 1>, S<2, 1, 128, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 1>, S<1, 2, 0, 3>, S<1, 1, 1, 1>, S<4, 1, 2, 1>, S<2, 1, 128, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<4, 1, 1, 1>, S<0, 3, 1, 2>, S<1, 1, 1, 1>, S<0, 1, 2, 3, 4, 5>, 5, 4>, + DeviceGemmDlops< F32, F32, F32, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 128, 128, 8, 1, 4, 4, 1, S<8, 2>, S<8, 2>, S<4, 1, 1, 1>, S<2, 1, 128, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 1>, S<1, 2, 0, 3>, S<1, 1, 1, 1>, S<4, 1, 1, 1>, S<2, 1, 128, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<4, 1, 1, 1>, S<0, 3, 1, 2>, S<1, 1, 1, 1>, S<0, 1, 2, 3, 4, 5>, 5, 4>, + DeviceGemmDlops< F32, F32, F32, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 128, 128, 128, 8, 1, 8, 4, 1, S<4, 2>, S<8, 2>, S<8, 1, 1, 1>, S<1, 1, 128, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 1>, S<1, 2, 0, 3>, S<1, 1, 1, 1>, S<8, 1, 1, 1>, S<1, 1, 128, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<4, 1, 1, 1>, S<0, 3, 1, 2>, S<1, 1, 1, 1>, S<0, 1, 2, 3, 4, 5>, 5, 4>, + DeviceGemmDlops< F32, F32, F32, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 128, 128, 128, 8, 1, 4, 8, 1, S<8, 2>, S<4, 2>, S<8, 1, 1, 1>, S<1, 1, 128, 
1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 1>, S<1, 2, 0, 3>, S<1, 1, 1, 1>, S<8, 1, 1, 1>, S<1, 1, 128, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<4, 1, 1, 1>, S<0, 3, 1, 2>, S<1, 1, 1, 1>, S<0, 1, 2, 3, 4, 5>, 5, 4>, + DeviceGemmDlops< F32, F32, F32, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 64, 128, 128, 8, 1, 8, 8, 1, S<4, 2>, S<4, 2>, S<4, 1, 4, 1>, S<2, 1, 32, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 1>, S<1, 2, 0, 3>, S<1, 1, 1, 1>, S<4, 1, 4, 1>, S<2, 1, 32, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<4, 1, 1, 1>, S<0, 3, 1, 2>, S<1, 1, 1, 1>, S<0, 1, 2, 3, 4, 5>, 5, 4>, + DeviceGemmDlops< F32, F32, F32, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 64, 128, 64, 8, 1, 8, 4, 1, S<4, 2>, S<4, 2>, S<4, 1, 4, 1>, S<2, 1, 32, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 1>, S<1, 2, 0, 3>, S<1, 1, 1, 1>, S<4, 1, 2, 1>, S<2, 1, 32, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<4, 1, 1, 1>, S<0, 3, 1, 2>, S<1, 1, 1, 1>, S<0, 1, 2, 3, 4, 5>, 5, 4>, + DeviceGemmDlops< F32, F32, F32, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 64, 64, 128, 8, 1, 4, 8, 1, S<4, 2>, S<4, 2>, S<4, 1, 2, 1>, S<2, 1, 32, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 1>, S<1, 2, 0, 3>, S<1, 1, 1, 1>, S<4, 1, 4, 1>, S<2, 1, 32, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<4, 1, 1, 1>, S<0, 3, 1, 2>, S<1, 1, 1, 1>, S<0, 1, 2, 3, 4, 5>, 5, 4>, + + /* + * K1 = 2 + */ + DeviceGemmDlops< F32, F32, F32, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 256, 128, 8, 2, 8, 4, 1, S<8, 2>, S<8, 2>, S<4, 1, 2, 2>, S<2, 1, 128, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<4, 1, 1, 2>, S<2, 1, 128, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<4, 1, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 1, 2>, S<0, 1, 2, 3, 4, 5>, 5, 4>, + DeviceGemmDlops< F32, F32, F32, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 128, 256, 8, 2, 4, 8, 1, S<8, 2>, S<8, 2>, S<4, 1, 1, 2>, S<2, 1, 128, 1>, 
S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<4, 1, 2, 2>, S<2, 1, 128, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<4, 1, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 1, 2>, S<0, 1, 2, 3, 4, 5>, 5, 4>, + DeviceGemmDlops< F32, F32, F32, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 128, 128, 8, 2, 4, 4, 1, S<8, 2>, S<8, 2>, S<4, 1, 1, 2>, S<2, 1, 128, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<4, 1, 1, 2>, S<2, 1, 128, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<4, 1, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 1, 2>, S<0, 1, 2, 3, 4, 5>, 5, 4>, + DeviceGemmDlops< F32, F32, F32, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 128, 128, 128, 8, 2, 8, 4, 1, S<4, 2>, S<8, 2>, S<8, 1, 1, 2>, S<1, 1, 128, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<8, 1, 1, 2>, S<1, 1, 128, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<4, 1, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 1, 2>, S<0, 1, 2, 3, 4, 5>, 5, 4>, + DeviceGemmDlops< F32, F32, F32, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 128, 128, 128, 8, 2, 4, 8, 1, S<8, 2>, S<4, 2>, S<8, 1, 1, 2>, S<1, 1, 128, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<8, 1, 1, 2>, S<1, 1, 128, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<4, 1, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 1, 2>, S<0, 1, 2, 3, 4, 5>, 5, 4>, + DeviceGemmDlops< F32, F32, F32, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 64, 128, 128, 8, 2, 8, 8, 1, S<4, 2>, S<4, 2>, S<4, 1, 4, 2>, S<2, 1, 32, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<4, 1, 4, 2>, S<2, 1, 32, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<4, 1, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 1, 2>, S<0, 1, 2, 3, 4, 5>, 5, 4>, + DeviceGemmDlops< F32, F32, F32, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 64, 128, 64, 8, 2, 8, 4, 1, S<4, 2>, S<4, 2>, S<4, 1, 4, 2>, S<2, 1, 32, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 
3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<4, 1, 2, 2>, S<2, 1, 32, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<4, 1, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 1, 2>, S<0, 1, 2, 3, 4, 5>, 5, 4>, + DeviceGemmDlops< F32, F32, F32, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 64, 64, 128, 8, 2, 4, 8, 1, S<4, 2>, S<4, 2>, S<4, 1, 2, 2>, S<2, 1, 32, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<4, 1, 4, 2>, S<2, 1, 32, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<4, 1, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 1, 2>, S<0, 1, 2, 3, 4, 5>, 5, 4> // clang-format on >; diff --git a/library/src/tensor_operation_instance/gpu/gemm/device_gemm_dlops_f32_f32_f32_mk_nk_mn_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm/device_gemm_dlops_f32_f32_f32_mk_nk_mn_instance.cpp index e27c15cd95..68fd1ac0c1 100644 --- a/library/src/tensor_operation_instance/gpu/gemm/device_gemm_dlops_f32_f32_f32_mk_nk_mn_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/gemm/device_gemm_dlops_f32_f32_f32_mk_nk_mn_instance.cpp @@ -29,7 +29,29 @@ using device_gemm_dlops_f32_f32_f32_mk_nk_mn_instances = std::tuple< // ##########| Type| Type| Type| Type| | | | Elementwise| Elementwise| Elementwise| Spacialization| Size| Block| Block| Block| | ThreadM111| ThreadN111| Thread| ClusterM110Xs| ClusterN110Xs| ThreadSliceLengths| ThreadClusterLengths| ThreadCluster| SrcAccess| SrcVectorTensor| SrcVectorTensor| DstVectorTensor| ThreadSliceLengths| ThreadClusterLengths| ThreadCluster| SrcAccess| SrcVectorTensor| SrcVectorTensor| DstVectorTensor| SrcDstAccess| SrcDstVectorDim| DstScalarPerVector| // ##########| | | | | | | | Operation| Operation| Operation| | | | | | | | | | | | K0_M0_M1_K1| K0_M0_M1_K1| ArrangeOrder| Order| Lengths_K0_M0_M1_K1| ContiguousDimOrder| Lengths_K0_M0_M1_K1| K0_M0_M1_K1| K0_M0_M1_K1| ArrangeOrder| Order| Lengths_K0_M0_M1_K1| ContiguousDimOrder| Lengths_K0_M0_M1_K1| Order| | | // ##########| | | | | | | | | | | | | | | | | | | | | | | | | 
| | | | | | | | | | | | | | - DeviceGemmDlops< F32, F32, F32, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 128, 128, 8, 2, 4, 4, 1, S<8, 2>, S<8, 2>, S<4, 1, 1, 2>, S<2, 1, 128, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<4, 1, 1, 2>, S<2, 1, 128, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<0, 1, 2, 3, 4, 5>, 5, 4> + /* + * K1 = 1 + */ + DeviceGemmDlops< F32, F32, F32, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 256, 128, 8, 1, 8, 4, 1, S<8, 2>, S<8, 2>, S<4, 1, 2, 1>, S<2, 1, 128, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 1>, S<1, 2, 0, 3>, S<1, 1, 1, 1>, S<4, 1, 1, 1>, S<2, 1, 128, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 1>, S<1, 2, 0, 3>, S<1, 1, 1, 1>, S<0, 1, 2, 3, 4, 5>, 5, 4>, + DeviceGemmDlops< F32, F32, F32, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 128, 256, 8, 1, 4, 8, 1, S<8, 2>, S<8, 2>, S<4, 1, 1, 1>, S<2, 1, 128, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 1>, S<1, 2, 0, 3>, S<1, 1, 1, 1>, S<4, 1, 2, 1>, S<2, 1, 128, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 1>, S<1, 2, 0, 3>, S<1, 1, 1, 1>, S<0, 1, 2, 3, 4, 5>, 5, 4>, + DeviceGemmDlops< F32, F32, F32, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 128, 128, 8, 1, 4, 4, 1, S<8, 2>, S<8, 2>, S<4, 1, 1, 1>, S<2, 1, 128, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 1>, S<1, 2, 0, 3>, S<1, 1, 1, 1>, S<4, 1, 1, 1>, S<2, 1, 128, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 1>, S<1, 2, 0, 3>, S<1, 1, 1, 1>, S<0, 1, 2, 3, 4, 5>, 5, 4>, + DeviceGemmDlops< F32, F32, F32, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 128, 128, 128, 8, 1, 8, 4, 1, S<4, 2>, S<8, 2>, S<8, 1, 1, 1>, S<1, 1, 128, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 1>, S<1, 2, 0, 3>, S<1, 1, 1, 1>, S<8, 1, 1, 1>, S<1, 1, 128, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 1>, S<1, 2, 0, 3>, S<1, 1, 1, 1>, 
S<0, 1, 2, 3, 4, 5>, 5, 4>, + DeviceGemmDlops< F32, F32, F32, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 128, 128, 128, 8, 1, 4, 8, 1, S<8, 2>, S<4, 2>, S<8, 1, 1, 1>, S<1, 1, 128, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 1>, S<1, 2, 0, 3>, S<1, 1, 1, 1>, S<8, 1, 1, 1>, S<1, 1, 128, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 1>, S<1, 2, 0, 3>, S<1, 1, 1, 1>, S<0, 1, 2, 3, 4, 5>, 5, 4>, + DeviceGemmDlops< F32, F32, F32, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 64, 128, 128, 8, 1, 8, 8, 1, S<4, 2>, S<4, 2>, S<4, 1, 4, 1>, S<2, 1, 32, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 1>, S<1, 2, 0, 3>, S<1, 1, 1, 1>, S<4, 1, 4, 1>, S<2, 1, 32, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 1>, S<1, 2, 0, 3>, S<1, 1, 1, 1>, S<0, 1, 2, 3, 4, 5>, 5, 4>, + DeviceGemmDlops< F32, F32, F32, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 64, 128, 64, 8, 1, 8, 4, 1, S<4, 2>, S<4, 2>, S<4, 1, 4, 1>, S<2, 1, 32, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 1>, S<1, 2, 0, 3>, S<1, 1, 1, 1>, S<4, 1, 2, 1>, S<2, 1, 32, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 1>, S<1, 2, 0, 3>, S<1, 1, 1, 1>, S<0, 1, 2, 3, 4, 5>, 5, 4>, + DeviceGemmDlops< F32, F32, F32, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 64, 64, 128, 8, 1, 4, 8, 1, S<4, 2>, S<4, 2>, S<4, 1, 2, 1>, S<2, 1, 32, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 1>, S<1, 2, 0, 3>, S<1, 1, 1, 1>, S<4, 1, 4, 1>, S<2, 1, 32, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 1>, S<1, 2, 0, 3>, S<1, 1, 1, 1>, S<0, 1, 2, 3, 4, 5>, 5, 4>, + + /* + * K1 = 2 + */ + DeviceGemmDlops< F32, F32, F32, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 256, 128, 8, 2, 8, 4, 1, S<8, 2>, S<8, 2>, S<4, 1, 2, 2>, S<2, 1, 128, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<4, 1, 1, 2>, S<2, 1, 128, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<0, 1, 
2, 3, 4, 5>, 5, 4>, + DeviceGemmDlops< F32, F32, F32, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 128, 256, 8, 2, 4, 8, 1, S<8, 2>, S<8, 2>, S<4, 1, 1, 2>, S<2, 1, 128, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<4, 1, 2, 2>, S<2, 1, 128, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<0, 1, 2, 3, 4, 5>, 5, 4>, + DeviceGemmDlops< F32, F32, F32, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 128, 128, 8, 2, 4, 4, 1, S<8, 2>, S<8, 2>, S<4, 1, 1, 2>, S<2, 1, 128, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<4, 1, 1, 2>, S<2, 1, 128, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<0, 1, 2, 3, 4, 5>, 5, 4>, + DeviceGemmDlops< F32, F32, F32, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 128, 128, 128, 8, 2, 8, 4, 1, S<4, 2>, S<8, 2>, S<8, 1, 1, 2>, S<1, 1, 128, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<8, 1, 1, 2>, S<1, 1, 128, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<0, 1, 2, 3, 4, 5>, 5, 4>, + DeviceGemmDlops< F32, F32, F32, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 128, 128, 128, 8, 2, 4, 8, 1, S<8, 2>, S<4, 2>, S<8, 1, 1, 2>, S<1, 1, 128, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<8, 1, 1, 2>, S<1, 1, 128, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<0, 1, 2, 3, 4, 5>, 5, 4>, + DeviceGemmDlops< F32, F32, F32, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 64, 128, 128, 8, 2, 8, 8, 1, S<4, 2>, S<4, 2>, S<4, 1, 4, 2>, S<2, 1, 32, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<4, 1, 4, 2>, S<2, 1, 32, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<0, 1, 2, 3, 4, 5>, 5, 4>, + 
DeviceGemmDlops< F32, F32, F32, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 64, 128, 64, 8, 2, 8, 4, 1, S<4, 2>, S<4, 2>, S<4, 1, 4, 2>, S<2, 1, 32, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<4, 1, 2, 2>, S<2, 1, 32, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<0, 1, 2, 3, 4, 5>, 5, 4>, + DeviceGemmDlops< F32, F32, F32, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 64, 64, 128, 8, 2, 4, 8, 1, S<4, 2>, S<4, 2>, S<4, 1, 2, 2>, S<2, 1, 32, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<4, 1, 4, 2>, S<2, 1, 32, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<0, 1, 2, 3, 4, 5>, 5, 4> // DeviceGemmDlops< F32, F32, F32, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 128, 128, 128, 8, 2, 4, 4, 1, S<8, 1>, S<8, 2>, S<4, 1, 1, 2>, S<1, 1, 128, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<4, 1, 1, 2>, S<1, 1, 128, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<0, 1, 2, 3, 4, 5>, 5, 4> // clang-format on >; From 84501243a10c6cd53ccf21604e5aa115e90e9280 Mon Sep 17 00:00:00 2001 From: j4yan Date: Tue, 19 Apr 2022 00:21:30 -0500 Subject: [PATCH 06/46] add gemm_dlops_f16 --- .../gpu/device/device_gemm_dlops.hpp | 2 +- include/ck/utility/inner_product.hpp | 36 +++++ .../gpu/gemm/CMakeLists.txt | 8 ++ ...mm_dlops_f16_f16_f16_km_kn_mn_instance.cpp | 69 +++++++++ ...mm_dlops_f16_f16_f16_km_nk_mn_instance.cpp | 70 ++++++++++ ...mm_dlops_f16_f16_f16_mk_kn_mn_instance.cpp | 69 +++++++++ ...mm_dlops_f16_f16_f16_mk_nk_mn_instance.cpp | 70 ++++++++++ ...mm_dlops_f32_f32_f32_km_kn_mn_instance.cpp | 3 +- ...mm_dlops_f32_f32_f32_km_nk_mn_instance.cpp | 2 +- ...mm_dlops_f32_f32_f32_mk_kn_mn_instance.cpp | 2 +- ...mm_dlops_f32_f32_f32_mk_nk_mn_instance.cpp | 3 +- ...dlops_int8_int8_int8_km_kn_mn_instance.cpp | 
79 +++++++++++ ...dlops_int8_int8_int8_km_nk_mn_instance.cpp | 80 +++++++++++ ...dlops_int8_int8_int8_mk_kn_mn_instance.cpp | 79 +++++++++++ ...dlops_int8_int8_int8_mk_nk_mn_instance.cpp | 80 +++++++++++ test/gemm_dlops/CMakeLists.txt | 18 +-- test/gemm_dlops/gemm_dlops_fp16.cpp | 130 +++++++++++++++++ test/gemm_dlops/gemm_dlops_int8.cpp | 131 ++++++++++++++++++ 18 files changed, 915 insertions(+), 16 deletions(-) create mode 100644 library/src/tensor_operation_instance/gpu/gemm/device_gemm_dlops_f16_f16_f16_km_kn_mn_instance.cpp create mode 100644 library/src/tensor_operation_instance/gpu/gemm/device_gemm_dlops_f16_f16_f16_km_nk_mn_instance.cpp create mode 100644 library/src/tensor_operation_instance/gpu/gemm/device_gemm_dlops_f16_f16_f16_mk_kn_mn_instance.cpp create mode 100644 library/src/tensor_operation_instance/gpu/gemm/device_gemm_dlops_f16_f16_f16_mk_nk_mn_instance.cpp create mode 100644 library/src/tensor_operation_instance/gpu/gemm/device_gemm_dlops_int8_int8_int8_km_kn_mn_instance.cpp create mode 100644 library/src/tensor_operation_instance/gpu/gemm/device_gemm_dlops_int8_int8_int8_km_nk_mn_instance.cpp create mode 100644 library/src/tensor_operation_instance/gpu/gemm/device_gemm_dlops_int8_int8_int8_mk_kn_mn_instance.cpp create mode 100644 library/src/tensor_operation_instance/gpu/gemm/device_gemm_dlops_int8_int8_int8_mk_nk_mn_instance.cpp create mode 100644 test/gemm_dlops/gemm_dlops_fp16.cpp create mode 100644 test/gemm_dlops/gemm_dlops_int8.cpp diff --git a/include/ck/tensor_operation/gpu/device/device_gemm_dlops.hpp b/include/ck/tensor_operation/gpu/device/device_gemm_dlops.hpp index 517ac2d26e..b3ca4db678 100644 --- a/include/ck/tensor_operation/gpu/device/device_gemm_dlops.hpp +++ b/include/ck/tensor_operation/gpu/device/device_gemm_dlops.hpp @@ -21,8 +21,8 @@ namespace device { template < typename ADataType, typename BDataType, - typename AccDataType, typename CDataType, + typename AccDataType, typename ALayout, typename BLayout, typename 
CLayout, diff --git a/include/ck/utility/inner_product.hpp b/include/ck/utility/inner_product.hpp index 3071e45640..03ec2fdc47 100644 --- a/include/ck/utility/inner_product.hpp +++ b/include/ck/utility/inner_product.hpp @@ -70,6 +70,12 @@ inner_product(const float4_t& a, const float4_t& b, f c); } +template <> +__device__ void inner_product(const half_t& a, const half_t& b, float& c) +{ + c += a * b; +} + template <> __device__ void inner_product(const half2_t& a, const half2_t& b, float& c) { @@ -134,6 +140,36 @@ __device__ void inner_product(const half8_t& a, const h c); } +template <> +__device__ void inner_product(const int8_t& a, const int8_t& b, int32_t& c) +{ + c += a * b; +} + +template <> +__device__ void +inner_product(const int8x2_t& a, const int8x2_t& b, int32_t& c) +{ +// #if defined(CK_USE_DOT2_I32_I8) +// #if CK_USE_AMD_INNER_PRODUCT_INLINE_ASM +// asm volatile("\n \ +// v_dot2_i32_i8 %0, %1, %2, %0\n \ +// " +// : "=v"(c) +// : "v"(bit_cast(a)), "v"(bit_cast(b)), "0"(c)); +// #else +// c = __builtin_amdgcn_sdot2(bit_cast(a), bit_cast(b), c, false); +// #endif +// #else + const vector_type a_vector{a}; + const vector_type b_vector{b}; + + static_for<0, 2, 1>{}([&](auto i) { + c += type_convert(a_vector.AsType()[i]) * + type_convert(b_vector.AsType()[i]); + }); +// #endif +} template <> __device__ void inner_product(const int8x4_t& a, const int8x4_t& b, int32_t& c) diff --git a/library/src/tensor_operation_instance/gpu/gemm/CMakeLists.txt b/library/src/tensor_operation_instance/gpu/gemm/CMakeLists.txt index b3910b319b..5cbdc5d421 100644 --- a/library/src/tensor_operation_instance/gpu/gemm/CMakeLists.txt +++ b/library/src/tensor_operation_instance/gpu/gemm/CMakeLists.txt @@ -49,6 +49,14 @@ set(DEVICE_GEMM_DLOPS_INSTANCE_SOURCE device_gemm_dlops_f32_f32_f32_mk_nk_mn_instance.cpp; device_gemm_dlops_f32_f32_f32_km_kn_mn_instance.cpp; device_gemm_dlops_f32_f32_f32_km_nk_mn_instance.cpp; + device_gemm_dlops_f16_f16_f16_mk_kn_mn_instance.cpp; + 
device_gemm_dlops_f16_f16_f16_mk_nk_mn_instance.cpp; + device_gemm_dlops_f16_f16_f16_km_kn_mn_instance.cpp; + device_gemm_dlops_f16_f16_f16_km_nk_mn_instance.cpp; + device_gemm_dlops_int8_int8_int8_mk_kn_mn_instance.cpp; + device_gemm_dlops_int8_int8_int8_mk_nk_mn_instance.cpp; + device_gemm_dlops_int8_int8_int8_km_kn_mn_instance.cpp; + device_gemm_dlops_int8_int8_int8_km_nk_mn_instance.cpp; ) add_library(device_gemm_dlops_instance SHARED ${DEVICE_GEMM_DLOPS_INSTANCE_SOURCE}) diff --git a/library/src/tensor_operation_instance/gpu/gemm/device_gemm_dlops_f16_f16_f16_km_kn_mn_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm/device_gemm_dlops_f16_f16_f16_km_kn_mn_instance.cpp new file mode 100644 index 0000000000..9d01e92597 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/gemm/device_gemm_dlops_f16_f16_f16_km_kn_mn_instance.cpp @@ -0,0 +1,69 @@ +#include +#include "config.hpp" +#include "device_gemm_dlops.hpp" +#include "element_wise_operation.hpp" +#include "device_operation_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace device_gemm_instance { + +using F16 = ck::half_t; +using F32 = float; + +using Row = ck::tensor_layout::gemm::RowMajor; +using Col = ck::tensor_layout::gemm::ColumnMajor; + +template +using S = ck::Sequence; + +using PassThrough = ck::tensor_operation::element_wise::PassThrough; + +static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default; + +// Compilation parameters for a[k, m] * b[k, n] = c[m, n] +using device_gemm_dlops_f16_f16_f16_km_kn_mn_instances = std::tuple< + + // clang-format off + // ##########| AData| BData| CData| AccData| ALayout| BLayout| CLayout| A| B| C| GEMM| Block| MPer| NPer| K0Per| K1| M1Per| N1Per| KPer| M11N11Thread| M11N11Thread| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| 
BBlockTransfer| BBlockTransfer| CThreadTransfer| CThreadTransfer| CThreadTransfer| + // ##########| Type| Type| Type| Type| | | | Elementwise| Elementwise| Elementwise| Spacialization| Size| Block| Block| Block| | ThreadM111| ThreadN111| Thread| ClusterM110Xs| ClusterN110Xs| ThreadSliceLengths| ThreadClusterLengths| ThreadCluster| SrcAccess| SrcVectorTensor| SrcVectorTensor| DstVectorTensor| ThreadSliceLengths| ThreadClusterLengths| ThreadCluster| SrcAccess| SrcVectorTensor| SrcVectorTensor| DstVectorTensor| SrcDstAccess| SrcDstVectorDim| DstScalarPerVector| + // ##########| | | | | | | | Operation| Operation| Operation| | | | | | | | | | | | K0_M0_M1_K1| K0_M0_M1_K1| ArrangeOrder| Order| Lengths_K0_M0_M1_K1| ContiguousDimOrder| Lengths_K0_M0_M1_K1| K0_M0_M1_K1| K0_M0_M1_K1| ArrangeOrder| Order| Lengths_K0_M0_M1_K1| ContiguousDimOrder| Lengths_K0_M0_M1_K1| Order| | | + // ##########| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + /* + * K1 = 1 + */ + DeviceGemmDlops< F16, F16, F16, F32, Col, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 256, 128, 8, 1, 8, 4, 1, S<8, 2>, S<8, 2>, S<4, 1, 2, 1>, S<2, 1, 128, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<4, 1, 1, 1>, S<0, 3, 1, 2>, S<1, 1, 1, 1>, S<4, 1, 1, 1>, S<2, 1, 128, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<4, 1, 1, 1>, S<0, 3, 1, 2>, S<1, 1, 1, 1>, S<0, 1, 2, 3, 4, 5>, 5, 4>, + DeviceGemmDlops< F16, F16, F16, F32, Col, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 128, 256, 8, 1, 4, 8, 1, S<8, 2>, S<8, 2>, S<4, 1, 1, 1>, S<2, 1, 128, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<4, 1, 1, 1>, S<0, 3, 1, 2>, S<1, 1, 1, 1>, S<4, 1, 2, 1>, S<2, 1, 128, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<4, 1, 1, 1>, S<0, 3, 1, 2>, S<1, 1, 1, 1>, S<0, 1, 2, 3, 4, 5>, 5, 4>, + DeviceGemmDlops< F16, F16, F16, F32, Col, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 128, 128, 8, 1, 4, 4, 1, S<8, 2>, S<8, 2>, S<4, 1, 1, 1>, S<2, 1, 128, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 
2>, S<4, 1, 1, 1>, S<0, 3, 1, 2>, S<1, 1, 1, 1>, S<4, 1, 1, 1>, S<2, 1, 128, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<4, 1, 1, 1>, S<0, 3, 1, 2>, S<1, 1, 1, 1>, S<0, 1, 2, 3, 4, 5>, 5, 4>, + DeviceGemmDlops< F16, F16, F16, F32, Col, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 128, 128, 128, 8, 1, 8, 4, 1, S<4, 2>, S<8, 2>, S<8, 1, 1, 1>, S<1, 1, 128, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<4, 1, 1, 1>, S<0, 3, 1, 2>, S<1, 1, 1, 1>, S<8, 1, 1, 1>, S<1, 1, 128, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<4, 1, 1, 1>, S<0, 3, 1, 2>, S<1, 1, 1, 1>, S<0, 1, 2, 3, 4, 5>, 5, 4>, + DeviceGemmDlops< F16, F16, F16, F32, Col, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 128, 128, 128, 8, 1, 4, 8, 1, S<8, 2>, S<4, 2>, S<8, 1, 1, 1>, S<1, 1, 128, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<4, 1, 1, 1>, S<0, 3, 1, 2>, S<1, 1, 1, 1>, S<8, 1, 1, 1>, S<1, 1, 128, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<4, 1, 1, 1>, S<0, 3, 1, 2>, S<1, 1, 1, 1>, S<0, 1, 2, 3, 4, 5>, 5, 4>, + DeviceGemmDlops< F16, F16, F16, F32, Col, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 64, 128, 128, 8, 1, 8, 8, 1, S<4, 2>, S<4, 2>, S<4, 1, 4, 1>, S<2, 1, 32, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<4, 1, 1, 1>, S<0, 3, 1, 2>, S<1, 1, 1, 1>, S<4, 1, 4, 1>, S<2, 1, 32, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<4, 1, 1, 1>, S<0, 3, 1, 2>, S<1, 1, 1, 1>, S<0, 1, 2, 3, 4, 5>, 5, 4>, + DeviceGemmDlops< F16, F16, F16, F32, Col, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 64, 128, 64, 8, 1, 8, 4, 1, S<4, 2>, S<4, 2>, S<4, 1, 4, 1>, S<2, 1, 32, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<4, 1, 1, 1>, S<0, 3, 1, 2>, S<1, 1, 1, 1>, S<4, 1, 2, 1>, S<2, 1, 32, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<4, 1, 1, 1>, S<0, 3, 1, 2>, S<1, 1, 1, 1>, S<0, 1, 2, 3, 4, 5>, 5, 4>, + DeviceGemmDlops< F16, F16, F16, F32, Col, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 64, 64, 128, 8, 1, 4, 8, 1, S<4, 2>, S<4, 2>, S<4, 1, 2, 1>, S<2, 1, 32, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<4, 1, 1, 1>, S<0, 3, 1, 
2>, S<1, 1, 1, 1>, S<4, 1, 4, 1>, S<2, 1, 32, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<4, 1, 1, 1>, S<0, 3, 1, 2>, S<1, 1, 1, 1>, S<0, 1, 2, 3, 4, 5>, 5, 4>, + + /* + * K1 = 2 + */ + DeviceGemmDlops< F16, F16, F16, F32, Col, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 256, 128, 8, 2, 8, 4, 1, S<8, 2>, S<8, 2>, S<4, 1, 2, 2>, S<2, 1, 128, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<4, 1, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 1, 2>, S<4, 1, 1, 2>, S<2, 1, 128, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<4, 1, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 1, 2>, S<0, 1, 2, 3, 4, 5>, 5, 4>, + DeviceGemmDlops< F16, F16, F16, F32, Col, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 128, 256, 8, 2, 4, 8, 1, S<8, 2>, S<8, 2>, S<4, 1, 1, 2>, S<2, 1, 128, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<4, 1, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 1, 2>, S<4, 1, 2, 2>, S<2, 1, 128, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<4, 1, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 1, 2>, S<0, 1, 2, 3, 4, 5>, 5, 4>, + DeviceGemmDlops< F16, F16, F16, F32, Col, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 128, 128, 8, 2, 4, 4, 1, S<8, 2>, S<8, 2>, S<4, 1, 1, 2>, S<2, 1, 128, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<4, 1, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 1, 2>, S<4, 1, 1, 2>, S<2, 1, 128, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<4, 1, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 1, 2>, S<0, 1, 2, 3, 4, 5>, 5, 4>, + DeviceGemmDlops< F16, F16, F16, F32, Col, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 128, 128, 128, 8, 2, 8, 4, 1, S<4, 2>, S<8, 2>, S<8, 1, 1, 2>, S<1, 1, 128, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<4, 1, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 1, 2>, S<8, 1, 1, 2>, S<1, 1, 128, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<4, 1, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 1, 2>, S<0, 1, 2, 3, 4, 5>, 5, 4>, + DeviceGemmDlops< F16, F16, F16, F32, Col, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 128, 128, 128, 8, 2, 4, 8, 1, S<8, 2>, S<4, 2>, S<8, 1, 1, 2>, S<1, 1, 128, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<4, 1, 1, 2>, S<0, 3, 
1, 2>, S<1, 1, 1, 2>, S<8, 1, 1, 2>, S<1, 1, 128, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<4, 1, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 1, 2>, S<0, 1, 2, 3, 4, 5>, 5, 4>, + DeviceGemmDlops< F16, F16, F16, F32, Col, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 64, 128, 128, 8, 2, 8, 8, 1, S<4, 2>, S<4, 2>, S<4, 1, 4, 2>, S<2, 1, 32, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<4, 1, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 1, 2>, S<4, 1, 4, 2>, S<2, 1, 32, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<4, 1, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 1, 2>, S<0, 1, 2, 3, 4, 5>, 5, 4>, + DeviceGemmDlops< F16, F16, F16, F32, Col, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 64, 128, 64, 8, 2, 8, 4, 1, S<4, 2>, S<4, 2>, S<4, 1, 4, 2>, S<2, 1, 32, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<4, 1, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 1, 2>, S<4, 1, 2, 2>, S<2, 1, 32, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<4, 1, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 1, 2>, S<0, 1, 2, 3, 4, 5>, 5, 4>, + DeviceGemmDlops< F16, F16, F16, F32, Col, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 64, 64, 128, 8, 2, 4, 8, 1, S<4, 2>, S<4, 2>, S<4, 1, 2, 2>, S<2, 1, 32, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<4, 1, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 1, 2>, S<4, 1, 4, 2>, S<2, 1, 32, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<4, 1, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 1, 2>, S<0, 1, 2, 3, 4, 5>, 5, 4> + // clang-format on + >; + +void add_device_gemm_dlops_f16_f16_f16_km_kn_mn_instances( + std::vector>& instances) +{ + add_device_operation_instances(instances, device_gemm_dlops_f16_f16_f16_km_kn_mn_instances{}); +} + +} // namespace device_gemm_instance +} // namespace device +} // namespace tensor_operation +} // namespace ck + diff --git a/library/src/tensor_operation_instance/gpu/gemm/device_gemm_dlops_f16_f16_f16_km_nk_mn_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm/device_gemm_dlops_f16_f16_f16_km_nk_mn_instance.cpp new file mode 100644 index 0000000000..24cf237efe --- /dev/null +++ 
b/library/src/tensor_operation_instance/gpu/gemm/device_gemm_dlops_f16_f16_f16_km_nk_mn_instance.cpp @@ -0,0 +1,70 @@ +#include +#include "config.hpp" +#include "device_gemm_dlops.hpp" +#include "element_wise_operation.hpp" +#include "device_operation_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace device_gemm_instance { + +using F16 = ck::half_t; +using F32 = float; + +using Row = ck::tensor_layout::gemm::RowMajor; +using Col = ck::tensor_layout::gemm::ColumnMajor; + +template +using S = ck::Sequence; + +using PassThrough = ck::tensor_operation::element_wise::PassThrough; + +static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default; + +// Compilation parameters for a[k, m] * b[n, k] = c[m, n] +using device_gemm_dlops_f16_f16_f16_km_nk_mn_instances = std::tuple< + // clang-format off + // ##########| AData| BData| CData| AccData| ALayout| BLayout| CLayout| A| B| C| GEMM| Block| MPer| NPer| K0Per| K1| M1Per| N1Per| KPer| M11N11Thread| M11N11Thread| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| CThreadTransfer| CThreadTransfer| CThreadTransfer| + // ##########| Type| Type| Type| Type| | | | Elementwise| Elementwise| Elementwise| Spacialization| Size| Block| Block| Block| | ThreadM111| ThreadN111| Thread| ClusterM110Xs| ClusterN110Xs| ThreadSliceLengths| ThreadClusterLengths| ThreadCluster| SrcAccess| SrcVectorTensor| SrcVectorTensor| DstVectorTensor| ThreadSliceLengths| ThreadClusterLengths| ThreadCluster| SrcAccess| SrcVectorTensor| SrcVectorTensor| DstVectorTensor| SrcDstAccess| SrcDstVectorDim| DstScalarPerVector| + // ##########| | | | | | | | Operation| Operation| Operation| | | | | | | | | | | | K0_M0_M1_K1| K0_M0_M1_K1| ArrangeOrder| Order| Lengths_K0_M0_M1_K1| ContiguousDimOrder| Lengths_K0_M0_M1_K1| 
K0_N0_N1_K1| K0_N0_N1_K1| ArrangeOrder| Order| Lengths_K0_N0_N1_K1| ContiguousDimOrder| Lengths_K0_N0_N1_K1| Order| | | + // ##########| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + /* + * K1 = 1 + */ + DeviceGemmDlops< F16, F16, F16, F32, Col, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 256, 128, 8, 1, 8, 4, 1, S<8, 2>, S<8, 2>, S<4, 1, 2, 1>, S<2, 1, 128, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<4, 1, 1, 1>, S<0, 3, 1, 2>, S<1, 1, 1, 1>, S<4, 1, 1, 1>, S<2, 1, 128, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 1>, S<1, 2, 0, 3>, S<1, 1, 1, 1>, S<0, 1, 2, 3, 4, 5>, 5, 4>, + DeviceGemmDlops< F16, F16, F16, F32, Col, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 128, 256, 8, 1, 4, 8, 1, S<8, 2>, S<8, 2>, S<4, 1, 1, 1>, S<2, 1, 128, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<4, 1, 1, 1>, S<0, 3, 1, 2>, S<1, 1, 1, 1>, S<4, 1, 2, 1>, S<2, 1, 128, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 1>, S<1, 2, 0, 3>, S<1, 1, 1, 1>, S<0, 1, 2, 3, 4, 5>, 5, 4>, + DeviceGemmDlops< F16, F16, F16, F32, Col, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 128, 128, 8, 1, 4, 4, 1, S<8, 2>, S<8, 2>, S<4, 1, 1, 1>, S<2, 1, 128, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<4, 1, 1, 1>, S<0, 3, 1, 2>, S<1, 1, 1, 1>, S<4, 1, 1, 1>, S<2, 1, 128, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 1>, S<1, 2, 0, 3>, S<1, 1, 1, 1>, S<0, 1, 2, 3, 4, 5>, 5, 4>, + DeviceGemmDlops< F16, F16, F16, F32, Col, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 128, 128, 128, 8, 1, 8, 4, 1, S<4, 2>, S<8, 2>, S<8, 1, 1, 1>, S<1, 1, 128, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<4, 1, 1, 1>, S<0, 3, 1, 2>, S<1, 1, 1, 1>, S<8, 1, 1, 1>, S<1, 1, 128, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 1>, S<1, 2, 0, 3>, S<1, 1, 1, 1>, S<0, 1, 2, 3, 4, 5>, 5, 4>, + DeviceGemmDlops< F16, F16, F16, F32, Col, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 128, 128, 128, 8, 1, 4, 8, 1, S<8, 2>, S<4, 2>, S<8, 1, 1, 1>, S<1, 1, 128, 
1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<4, 1, 1, 1>, S<0, 3, 1, 2>, S<1, 1, 1, 1>, S<8, 1, 1, 1>, S<1, 1, 128, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 1>, S<1, 2, 0, 3>, S<1, 1, 1, 1>, S<0, 1, 2, 3, 4, 5>, 5, 4>, + DeviceGemmDlops< F16, F16, F16, F32, Col, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 64, 128, 128, 8, 1, 8, 8, 1, S<4, 2>, S<4, 2>, S<4, 1, 4, 1>, S<2, 1, 32, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<4, 1, 1, 1>, S<0, 3, 1, 2>, S<1, 1, 1, 1>, S<4, 1, 4, 1>, S<2, 1, 32, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 1>, S<1, 2, 0, 3>, S<1, 1, 1, 1>, S<0, 1, 2, 3, 4, 5>, 5, 4>, + DeviceGemmDlops< F16, F16, F16, F32, Col, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 64, 128, 64, 8, 1, 8, 4, 1, S<4, 2>, S<4, 2>, S<4, 1, 4, 1>, S<2, 1, 32, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<4, 1, 1, 1>, S<0, 3, 1, 2>, S<1, 1, 1, 1>, S<4, 1, 2, 1>, S<2, 1, 32, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 1>, S<1, 2, 0, 3>, S<1, 1, 1, 1>, S<0, 1, 2, 3, 4, 5>, 5, 4>, + DeviceGemmDlops< F16, F16, F16, F32, Col, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 64, 64, 128, 8, 1, 4, 8, 1, S<4, 2>, S<4, 2>, S<4, 1, 2, 1>, S<2, 1, 32, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<4, 1, 1, 1>, S<0, 3, 1, 2>, S<1, 1, 1, 1>, S<4, 1, 4, 1>, S<2, 1, 32, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 1>, S<1, 2, 0, 3>, S<1, 1, 1, 1>, S<0, 1, 2, 3, 4, 5>, 5, 4>, + + /* + * K1 = 2 + */ + DeviceGemmDlops< F16, F16, F16, F32, Col, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 256, 128, 8, 2, 8, 4, 1, S<8, 2>, S<8, 2>, S<4, 1, 2, 2>, S<2, 1, 128, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<4, 1, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 1, 2>, S<4, 1, 1, 2>, S<2, 1, 128, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<0, 1, 2, 3, 4, 5>, 5, 4>, + DeviceGemmDlops< F16, F16, F16, F32, Col, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 128, 256, 8, 2, 4, 8, 1, S<8, 2>, S<8, 2>, S<4, 1, 1, 2>, S<2, 1, 128, 1>, 
S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<4, 1, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 1, 2>, S<4, 1, 2, 2>, S<2, 1, 128, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<0, 1, 2, 3, 4, 5>, 5, 4>, + DeviceGemmDlops< F16, F16, F16, F32, Col, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 128, 128, 8, 2, 4, 4, 1, S<8, 2>, S<8, 2>, S<4, 1, 1, 2>, S<2, 1, 128, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<4, 1, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 1, 2>, S<4, 1, 1, 2>, S<2, 1, 128, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<0, 1, 2, 3, 4, 5>, 5, 4>, + DeviceGemmDlops< F16, F16, F16, F32, Col, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 128, 128, 128, 8, 2, 8, 4, 1, S<4, 2>, S<8, 2>, S<8, 1, 1, 2>, S<1, 1, 128, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<4, 1, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 1, 2>, S<8, 1, 1, 2>, S<1, 1, 128, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<0, 1, 2, 3, 4, 5>, 5, 4>, + DeviceGemmDlops< F16, F16, F16, F32, Col, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 128, 128, 128, 8, 2, 4, 8, 1, S<8, 2>, S<4, 2>, S<8, 1, 1, 2>, S<1, 1, 128, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<4, 1, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 1, 2>, S<8, 1, 1, 2>, S<1, 1, 128, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<0, 1, 2, 3, 4, 5>, 5, 4>, + DeviceGemmDlops< F16, F16, F16, F32, Col, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 64, 128, 128, 8, 2, 8, 8, 1, S<4, 2>, S<4, 2>, S<4, 1, 4, 2>, S<2, 1, 32, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<4, 1, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 1, 2>, S<4, 1, 4, 2>, S<2, 1, 32, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<0, 1, 2, 3, 4, 5>, 5, 4>, + DeviceGemmDlops< F16, F16, F16, F32, Col, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 64, 128, 64, 8, 2, 8, 4, 1, S<4, 2>, S<4, 2>, S<4, 1, 4, 2>, S<2, 1, 32, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 
2>, S<4, 1, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 1, 2>, S<4, 1, 2, 2>, S<2, 1, 32, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<0, 1, 2, 3, 4, 5>, 5, 4>, + DeviceGemmDlops< F16, F16, F16, F32, Col, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 64, 64, 128, 8, 2, 4, 8, 1, S<4, 2>, S<4, 2>, S<4, 1, 2, 2>, S<2, 1, 32, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<4, 1, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 1, 2>, S<4, 1, 4, 2>, S<2, 1, 32, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<0, 1, 2, 3, 4, 5>, 5, 4> + // clang-format on + >; + +void add_device_gemm_dlops_f16_f16_f16_km_nk_mn_instances( + std::vector>& instances) +{ + add_device_operation_instances(instances, device_gemm_dlops_f16_f16_f16_km_nk_mn_instances{}); +} + +} // namespace device_gemm_instance +} // namespace device +} // namespace tensor_operation +} // namespace ck + + + diff --git a/library/src/tensor_operation_instance/gpu/gemm/device_gemm_dlops_f16_f16_f16_mk_kn_mn_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm/device_gemm_dlops_f16_f16_f16_mk_kn_mn_instance.cpp new file mode 100644 index 0000000000..1fdfa22033 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/gemm/device_gemm_dlops_f16_f16_f16_mk_kn_mn_instance.cpp @@ -0,0 +1,69 @@ +#include +#include "config.hpp" +#include "device_gemm_dlops.hpp" +#include "element_wise_operation.hpp" +#include "device_operation_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace device_gemm_instance { + +using F16 = ck::half_t; +using F32 = float; + +using Row = ck::tensor_layout::gemm::RowMajor; +using Col = ck::tensor_layout::gemm::ColumnMajor; + +template +using S = ck::Sequence; + +using PassThrough = ck::tensor_operation::element_wise::PassThrough; + +static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default; + +// Compilation parameters for a[m, k] * b[k, n] = c[m, n] +using 
device_gemm_dlops_f16_f16_f16_mk_kn_mn_instances = std::tuple< + // clang-format off + // ##########| AData| BData| CData| AccData| ALayout| BLayout| CLayout| A| B| C| GEMM| Block| MPer| NPer| K0Per| K1| M1Per| N1Per| KPer| M11N11Thread| M11N11Thread| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| CThreadTransfer| CThreadTransfer| CThreadTransfer| + // ##########| Type| Type| Type| Type| | | | Elementwise| Elementwise| Elementwise| Spacialization| Size| Block| Block| Block| | ThreadM111| ThreadN111| Thread| ClusterM110Xs| ClusterN110Xs| ThreadSliceLengths| ThreadClusterLengths| ThreadCluster| SrcAccess| SrcVectorTensor| SrcVectorTensor| DstVectorTensor| ThreadSliceLengths| ThreadClusterLengths| ThreadCluster| SrcAccess| SrcVectorTensor| SrcVectorTensor| DstVectorTensor| SrcDstAccess| SrcDstVectorDim| DstScalarPerVector| + // ##########| | | | | | | | Operation| Operation| Operation| | | | | | | | | | | | K0_M0_M1_K1| K0_M0_M1_K1| ArrangeOrder| Order| Lengths_K0_M0_M1_K1| ContiguousDimOrder| Lengths_K0_M0_M1_K1| K0_N0_N1_K1| K0_N0_N1_K1| ArrangeOrder| Order| Lengths_K0_N0_N1_K1| ContiguousDimOrder| Lengths_K0_N0_N1_K1| Order| | | + // ##########| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + /* + * K1 = 1 + */ + DeviceGemmDlops< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 256, 128, 8, 1, 8, 4, 1, S<8, 2>, S<8, 2>, S<4, 1, 2, 1>, S<2, 1, 128, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 1>, S<1, 2, 0, 3>, S<1, 1, 1, 1>, S<4, 1, 1, 1>, S<2, 1, 128, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<4, 1, 1, 1>, S<0, 3, 1, 2>, S<1, 1, 1, 1>, S<0, 1, 2, 3, 4, 5>, 5, 4>, + DeviceGemmDlops< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 128, 256, 8, 1, 4, 8, 1, S<8, 2>, S<8, 2>, S<4, 1, 1, 
1>, S<2, 1, 128, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 1>, S<1, 2, 0, 3>, S<1, 1, 1, 1>, S<4, 1, 2, 1>, S<2, 1, 128, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<4, 1, 1, 1>, S<0, 3, 1, 2>, S<1, 1, 1, 1>, S<0, 1, 2, 3, 4, 5>, 5, 4>, + DeviceGemmDlops< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 128, 128, 8, 1, 4, 4, 1, S<8, 2>, S<8, 2>, S<4, 1, 1, 1>, S<2, 1, 128, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 1>, S<1, 2, 0, 3>, S<1, 1, 1, 1>, S<4, 1, 1, 1>, S<2, 1, 128, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<4, 1, 1, 1>, S<0, 3, 1, 2>, S<1, 1, 1, 1>, S<0, 1, 2, 3, 4, 5>, 5, 4>, + DeviceGemmDlops< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 128, 128, 128, 8, 1, 8, 4, 1, S<4, 2>, S<8, 2>, S<8, 1, 1, 1>, S<1, 1, 128, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 1>, S<1, 2, 0, 3>, S<1, 1, 1, 1>, S<8, 1, 1, 1>, S<1, 1, 128, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<4, 1, 1, 1>, S<0, 3, 1, 2>, S<1, 1, 1, 1>, S<0, 1, 2, 3, 4, 5>, 5, 4>, + DeviceGemmDlops< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 128, 128, 128, 8, 1, 4, 8, 1, S<8, 2>, S<4, 2>, S<8, 1, 1, 1>, S<1, 1, 128, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 1>, S<1, 2, 0, 3>, S<1, 1, 1, 1>, S<8, 1, 1, 1>, S<1, 1, 128, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<4, 1, 1, 1>, S<0, 3, 1, 2>, S<1, 1, 1, 1>, S<0, 1, 2, 3, 4, 5>, 5, 4>, + DeviceGemmDlops< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 64, 128, 128, 8, 1, 8, 8, 1, S<4, 2>, S<4, 2>, S<4, 1, 4, 1>, S<2, 1, 32, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 1>, S<1, 2, 0, 3>, S<1, 1, 1, 1>, S<4, 1, 4, 1>, S<2, 1, 32, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<4, 1, 1, 1>, S<0, 3, 1, 2>, S<1, 1, 1, 1>, S<0, 1, 2, 3, 4, 5>, 5, 4>, + DeviceGemmDlops< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 64, 128, 64, 8, 1, 8, 4, 1, S<4, 2>, S<4, 2>, S<4, 1, 4, 1>, S<2, 1, 32, 1>, S<1, 
2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 1>, S<1, 2, 0, 3>, S<1, 1, 1, 1>, S<4, 1, 2, 1>, S<2, 1, 32, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<4, 1, 1, 1>, S<0, 3, 1, 2>, S<1, 1, 1, 1>, S<0, 1, 2, 3, 4, 5>, 5, 4>, + DeviceGemmDlops< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 64, 64, 128, 8, 1, 4, 8, 1, S<4, 2>, S<4, 2>, S<4, 1, 2, 1>, S<2, 1, 32, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 1>, S<1, 2, 0, 3>, S<1, 1, 1, 1>, S<4, 1, 4, 1>, S<2, 1, 32, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<4, 1, 1, 1>, S<0, 3, 1, 2>, S<1, 1, 1, 1>, S<0, 1, 2, 3, 4, 5>, 5, 4>, + + /* + * K1 = 2 + */ + DeviceGemmDlops< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 256, 128, 8, 2, 8, 4, 1, S<8, 2>, S<8, 2>, S<4, 1, 2, 2>, S<2, 1, 128, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<4, 1, 1, 2>, S<2, 1, 128, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<4, 1, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 1, 2>, S<0, 1, 2, 3, 4, 5>, 5, 4>, + DeviceGemmDlops< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 128, 256, 8, 2, 4, 8, 1, S<8, 2>, S<8, 2>, S<4, 1, 1, 2>, S<2, 1, 128, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<4, 1, 2, 2>, S<2, 1, 128, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<4, 1, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 1, 2>, S<0, 1, 2, 3, 4, 5>, 5, 4>, + DeviceGemmDlops< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 128, 128, 8, 2, 4, 4, 1, S<8, 2>, S<8, 2>, S<4, 1, 1, 2>, S<2, 1, 128, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<4, 1, 1, 2>, S<2, 1, 128, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<4, 1, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 1, 2>, S<0, 1, 2, 3, 4, 5>, 5, 4>, + DeviceGemmDlops< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 128, 128, 128, 8, 2, 8, 4, 1, S<4, 2>, S<8, 2>, S<8, 1, 1, 2>, S<1, 1, 128, 1>, S<1, 
2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<8, 1, 1, 2>, S<1, 1, 128, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<4, 1, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 1, 2>, S<0, 1, 2, 3, 4, 5>, 5, 4>, + DeviceGemmDlops< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 128, 128, 128, 8, 2, 4, 8, 1, S<8, 2>, S<4, 2>, S<8, 1, 1, 2>, S<1, 1, 128, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<8, 1, 1, 2>, S<1, 1, 128, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<4, 1, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 1, 2>, S<0, 1, 2, 3, 4, 5>, 5, 4>, + DeviceGemmDlops< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 64, 128, 128, 8, 2, 8, 8, 1, S<4, 2>, S<4, 2>, S<4, 1, 4, 2>, S<2, 1, 32, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<4, 1, 4, 2>, S<2, 1, 32, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<4, 1, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 1, 2>, S<0, 1, 2, 3, 4, 5>, 5, 4>, + DeviceGemmDlops< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 64, 128, 64, 8, 2, 8, 4, 1, S<4, 2>, S<4, 2>, S<4, 1, 4, 2>, S<2, 1, 32, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<4, 1, 2, 2>, S<2, 1, 32, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<4, 1, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 1, 2>, S<0, 1, 2, 3, 4, 5>, 5, 4>, + DeviceGemmDlops< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 64, 64, 128, 8, 2, 4, 8, 1, S<4, 2>, S<4, 2>, S<4, 1, 2, 2>, S<2, 1, 32, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<4, 1, 4, 2>, S<2, 1, 32, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<4, 1, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 1, 2>, S<0, 1, 2, 3, 4, 5>, 5, 4> + // clang-format on + >; + +void add_device_gemm_dlops_f16_f16_f16_mk_kn_mn_instances( + std::vector>& instances) +{ + add_device_operation_instances(instances, device_gemm_dlops_f16_f16_f16_mk_kn_mn_instances{}); +} + +} 
// namespace device_gemm_instance +} // namespace device +} // namespace tensor_operation +} // namespace ck + + diff --git a/library/src/tensor_operation_instance/gpu/gemm/device_gemm_dlops_f16_f16_f16_mk_nk_mn_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm/device_gemm_dlops_f16_f16_f16_mk_nk_mn_instance.cpp new file mode 100644 index 0000000000..5ca4bd7e35 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/gemm/device_gemm_dlops_f16_f16_f16_mk_nk_mn_instance.cpp @@ -0,0 +1,70 @@ +#include +#include "config.hpp" +#include "device_gemm_dlops.hpp" +#include "element_wise_operation.hpp" +#include "device_operation_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace device_gemm_instance { + +using F16 = ck::half_t; +using F32 = float; + +using Row = ck::tensor_layout::gemm::RowMajor; +using Col = ck::tensor_layout::gemm::ColumnMajor; + +template +using S = ck::Sequence; + +using PassThrough = ck::tensor_operation::element_wise::PassThrough; + +static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default; + +// Compilation parameters for a[m, k] * b[n, k] = c[m, n] +using device_gemm_dlops_f16_f16_f16_mk_nk_mn_instances = std::tuple< + // clang-format off + // ##########| AData| BData| CData| AccData| ALayout| BLayout| CLayout| A| B| C| GEMM| Block| MPer| NPer| K0Per| K1| M1Per| N1Per| KPer| M11N11Thread| M11N11Thread| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| CThreadTransfer| CThreadTransfer| CThreadTransfer| + // ##########| Type| Type| Type| Type| | | | Elementwise| Elementwise| Elementwise| Spacialization| Size| Block| Block| Block| | ThreadM111| ThreadN111| Thread| ClusterM110Xs| ClusterN110Xs| ThreadSliceLengths| ThreadClusterLengths| ThreadCluster| SrcAccess| SrcVectorTensor| 
SrcVectorTensor| DstVectorTensor| ThreadSliceLengths| ThreadClusterLengths| ThreadCluster| SrcAccess| SrcVectorTensor| SrcVectorTensor| DstVectorTensor| SrcDstAccess| SrcDstVectorDim| DstScalarPerVector| + // ##########| | | | | | | | Operation| Operation| Operation| | | | | | | | | | | | K0_M0_M1_K1| K0_M0_M1_K1| ArrangeOrder| Order| Lengths_K0_M0_M1_K1| ContiguousDimOrder| Lengths_K0_M0_M1_K1| K0_N0_N1_K1| K0_N0_N1_K1| ArrangeOrder| Order| Lengths_K0_N0_N1_K1| ContiguousDimOrder| Lengths_K0_N0_N1_K1| Order| | | + // ##########| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + /* + * K1 = 1 + */ + DeviceGemmDlops< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 256, 128, 8, 1, 8, 4, 1, S<8, 2>, S<8, 2>, S<4, 1, 2, 1>, S<2, 1, 128, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 1>, S<1, 2, 0, 3>, S<1, 1, 1, 1>, S<4, 1, 1, 1>, S<2, 1, 128, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 1>, S<1, 2, 0, 3>, S<1, 1, 1, 1>, S<0, 1, 2, 3, 4, 5>, 5, 4>, + DeviceGemmDlops< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 128, 256, 8, 1, 4, 8, 1, S<8, 2>, S<8, 2>, S<4, 1, 1, 1>, S<2, 1, 128, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 1>, S<1, 2, 0, 3>, S<1, 1, 1, 1>, S<4, 1, 2, 1>, S<2, 1, 128, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 1>, S<1, 2, 0, 3>, S<1, 1, 1, 1>, S<0, 1, 2, 3, 4, 5>, 5, 4>, + DeviceGemmDlops< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 128, 128, 8, 1, 4, 4, 1, S<8, 2>, S<8, 2>, S<4, 1, 1, 1>, S<2, 1, 128, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 1>, S<1, 2, 0, 3>, S<1, 1, 1, 1>, S<4, 1, 1, 1>, S<2, 1, 128, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 1>, S<1, 2, 0, 3>, S<1, 1, 1, 1>, S<0, 1, 2, 3, 4, 5>, 5, 4>, + DeviceGemmDlops< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 128, 128, 128, 8, 1, 8, 4, 1, S<4, 2>, S<8, 2>, S<8, 1, 1, 1>, S<1, 1, 
128, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 1>, S<1, 2, 0, 3>, S<1, 1, 1, 1>, S<8, 1, 1, 1>, S<1, 1, 128, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 1>, S<1, 2, 0, 3>, S<1, 1, 1, 1>, S<0, 1, 2, 3, 4, 5>, 5, 4>, + DeviceGemmDlops< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 128, 128, 128, 8, 1, 4, 8, 1, S<8, 2>, S<4, 2>, S<8, 1, 1, 1>, S<1, 1, 128, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 1>, S<1, 2, 0, 3>, S<1, 1, 1, 1>, S<8, 1, 1, 1>, S<1, 1, 128, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 1>, S<1, 2, 0, 3>, S<1, 1, 1, 1>, S<0, 1, 2, 3, 4, 5>, 5, 4>, + DeviceGemmDlops< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 64, 128, 128, 8, 1, 8, 8, 1, S<4, 2>, S<4, 2>, S<4, 1, 4, 1>, S<2, 1, 32, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 1>, S<1, 2, 0, 3>, S<1, 1, 1, 1>, S<4, 1, 4, 1>, S<2, 1, 32, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 1>, S<1, 2, 0, 3>, S<1, 1, 1, 1>, S<0, 1, 2, 3, 4, 5>, 5, 4>, + DeviceGemmDlops< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 64, 128, 64, 8, 1, 8, 4, 1, S<4, 2>, S<4, 2>, S<4, 1, 4, 1>, S<2, 1, 32, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 1>, S<1, 2, 0, 3>, S<1, 1, 1, 1>, S<4, 1, 2, 1>, S<2, 1, 32, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 1>, S<1, 2, 0, 3>, S<1, 1, 1, 1>, S<0, 1, 2, 3, 4, 5>, 5, 4>, + DeviceGemmDlops< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 64, 64, 128, 8, 1, 4, 8, 1, S<4, 2>, S<4, 2>, S<4, 1, 2, 1>, S<2, 1, 32, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 1>, S<1, 2, 0, 3>, S<1, 1, 1, 1>, S<4, 1, 4, 1>, S<2, 1, 32, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 1>, S<1, 2, 0, 3>, S<1, 1, 1, 1>, S<0, 1, 2, 3, 4, 5>, 5, 4>, + + /* + * K1 = 2 + */ + DeviceGemmDlops< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 256, 128, 8, 2, 8, 4, 1, S<8, 2>, S<8, 2>, S<4, 1, 2, 2>, S<2, 1, 128, 
1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<4, 1, 1, 2>, S<2, 1, 128, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<0, 1, 2, 3, 4, 5>, 5, 4>, + DeviceGemmDlops< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 128, 256, 8, 2, 4, 8, 1, S<8, 2>, S<8, 2>, S<4, 1, 1, 2>, S<2, 1, 128, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<4, 1, 2, 2>, S<2, 1, 128, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<0, 1, 2, 3, 4, 5>, 5, 4>, + DeviceGemmDlops< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 128, 128, 8, 2, 4, 4, 1, S<8, 2>, S<8, 2>, S<4, 1, 1, 2>, S<2, 1, 128, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<4, 1, 1, 2>, S<2, 1, 128, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<0, 1, 2, 3, 4, 5>, 5, 4>, + DeviceGemmDlops< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 128, 128, 128, 8, 2, 8, 4, 1, S<4, 2>, S<8, 2>, S<8, 1, 1, 2>, S<1, 1, 128, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<8, 1, 1, 2>, S<1, 1, 128, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<0, 1, 2, 3, 4, 5>, 5, 4>, + DeviceGemmDlops< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 128, 128, 128, 8, 2, 4, 8, 1, S<8, 2>, S<4, 2>, S<8, 1, 1, 2>, S<1, 1, 128, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<8, 1, 1, 2>, S<1, 1, 128, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<0, 1, 2, 3, 4, 5>, 5, 4>, + DeviceGemmDlops< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 64, 128, 128, 8, 2, 8, 8, 1, S<4, 2>, S<4, 2>, S<4, 1, 4, 2>, S<2, 1, 32, 1>, S<1, 2, 0, 3>, 
S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<4, 1, 4, 2>, S<2, 1, 32, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<0, 1, 2, 3, 4, 5>, 5, 4>, + DeviceGemmDlops< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 64, 128, 64, 8, 2, 8, 4, 1, S<4, 2>, S<4, 2>, S<4, 1, 4, 2>, S<2, 1, 32, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<4, 1, 2, 2>, S<2, 1, 32, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<0, 1, 2, 3, 4, 5>, 5, 4>, + DeviceGemmDlops< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 64, 64, 128, 8, 2, 4, 8, 1, S<4, 2>, S<4, 2>, S<4, 1, 2, 2>, S<2, 1, 32, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<4, 1, 4, 2>, S<2, 1, 32, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<0, 1, 2, 3, 4, 5>, 5, 4> + // clang-format on + >; + +void add_device_gemm_dlops_f16_f16_f16_mk_nk_mn_instances( + std::vector>& instances) +{ + add_device_operation_instances(instances, device_gemm_dlops_f16_f16_f16_mk_nk_mn_instances{}); +} + +} // namespace device_gemm_instance +} // namespace device +} // namespace tensor_operation +} // namespace ck + + + diff --git a/library/src/tensor_operation_instance/gpu/gemm/device_gemm_dlops_f32_f32_f32_km_kn_mn_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm/device_gemm_dlops_f32_f32_f32_km_kn_mn_instance.cpp index e46afa5830..75b3d75670 100644 --- a/library/src/tensor_operation_instance/gpu/gemm/device_gemm_dlops_f32_f32_f32_km_kn_mn_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/gemm/device_gemm_dlops_f32_f32_f32_km_kn_mn_instance.cpp @@ -24,11 +24,10 @@ static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecializa // Compilation parameters for a[k, m] * b[k, n] = c[m, n] using device_gemm_dlops_f32_f32_f32_km_kn_mn_instances = 
std::tuple< - // clang-format off // ##########| AData| BData| CData| AccData| ALayout| BLayout| CLayout| A| B| C| GEMM| Block| MPer| NPer| K0Per| K1| M1Per| N1Per| KPer| M11N11Thread| M11N11Thread| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| CThreadTransfer| CThreadTransfer| CThreadTransfer| // ##########| Type| Type| Type| Type| | | | Elementwise| Elementwise| Elementwise| Spacialization| Size| Block| Block| Block| | ThreadM111| ThreadN111| Thread| ClusterM110Xs| ClusterN110Xs| ThreadSliceLengths| ThreadClusterLengths| ThreadCluster| SrcAccess| SrcVectorTensor| SrcVectorTensor| DstVectorTensor| ThreadSliceLengths| ThreadClusterLengths| ThreadCluster| SrcAccess| SrcVectorTensor| SrcVectorTensor| DstVectorTensor| SrcDstAccess| SrcDstVectorDim| DstScalarPerVector| - // ##########| | | | | | | | Operation| Operation| Operation| | | | | | | | | | | | K0_M0_M1_K1| K0_M0_M1_K1| ArrangeOrder| Order| Lengths_K0_M0_M1_K1| ContiguousDimOrder| Lengths_K0_M0_M1_K1| K0_M0_M1_K1| K0_M0_M1_K1| ArrangeOrder| Order| Lengths_K0_M0_M1_K1| ContiguousDimOrder| Lengths_K0_M0_M1_K1| Order| | | + // ##########| | | | | | | | Operation| Operation| Operation| | | | | | | | | | | | K0_M0_M1_K1| K0_M0_M1_K1| ArrangeOrder| Order| Lengths_K0_M0_M1_K1| ContiguousDimOrder| Lengths_K0_M0_M1_K1| K0_N0_N1_K1| K0_N0_N1_K1| ArrangeOrder| Order| Lengths_K0_N0_N1_K1| ContiguousDimOrder| Lengths_K0_N0_N1_K1| Order| | | // ##########| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | /* * K1 = 1 diff --git a/library/src/tensor_operation_instance/gpu/gemm/device_gemm_dlops_f32_f32_f32_km_nk_mn_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm/device_gemm_dlops_f32_f32_f32_km_nk_mn_instance.cpp index 9c4318274d..325424ab81 100644 --- 
a/library/src/tensor_operation_instance/gpu/gemm/device_gemm_dlops_f32_f32_f32_km_nk_mn_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/gemm/device_gemm_dlops_f32_f32_f32_km_nk_mn_instance.cpp @@ -27,7 +27,7 @@ using device_gemm_dlops_f32_f32_f32_km_nk_mn_instances = std::tuple< // clang-format off // ##########| AData| BData| CData| AccData| ALayout| BLayout| CLayout| A| B| C| GEMM| Block| MPer| NPer| K0Per| K1| M1Per| N1Per| KPer| M11N11Thread| M11N11Thread| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| CThreadTransfer| CThreadTransfer| CThreadTransfer| // ##########| Type| Type| Type| Type| | | | Elementwise| Elementwise| Elementwise| Spacialization| Size| Block| Block| Block| | ThreadM111| ThreadN111| Thread| ClusterM110Xs| ClusterN110Xs| ThreadSliceLengths| ThreadClusterLengths| ThreadCluster| SrcAccess| SrcVectorTensor| SrcVectorTensor| DstVectorTensor| ThreadSliceLengths| ThreadClusterLengths| ThreadCluster| SrcAccess| SrcVectorTensor| SrcVectorTensor| DstVectorTensor| SrcDstAccess| SrcDstVectorDim| DstScalarPerVector| - // ##########| | | | | | | | Operation| Operation| Operation| | | | | | | | | | | | K0_M0_M1_K1| K0_M0_M1_K1| ArrangeOrder| Order| Lengths_K0_M0_M1_K1| ContiguousDimOrder| Lengths_K0_M0_M1_K1| K0_M0_M1_K1| K0_M0_M1_K1| ArrangeOrder| Order| Lengths_K0_M0_M1_K1| ContiguousDimOrder| Lengths_K0_M0_M1_K1| Order| | | + // ##########| | | | | | | | Operation| Operation| Operation| | | | | | | | | | | | K0_M0_M1_K1| K0_M0_M1_K1| ArrangeOrder| Order| Lengths_K0_M0_M1_K1| ContiguousDimOrder| Lengths_K0_M0_M1_K1| K0_N0_N1_K1| K0_N0_N1_K1| ArrangeOrder| Order| Lengths_K0_N0_N1_K1| ContiguousDimOrder| Lengths_K0_N0_N1_K1| Order| | | // ##########| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | /* * K1 = 1 diff --git 
a/library/src/tensor_operation_instance/gpu/gemm/device_gemm_dlops_f32_f32_f32_mk_kn_mn_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm/device_gemm_dlops_f32_f32_f32_mk_kn_mn_instance.cpp index 0cf50e66dc..f33a384a61 100644 --- a/library/src/tensor_operation_instance/gpu/gemm/device_gemm_dlops_f32_f32_f32_mk_kn_mn_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/gemm/device_gemm_dlops_f32_f32_f32_mk_kn_mn_instance.cpp @@ -27,7 +27,7 @@ using device_gemm_dlops_f32_f32_f32_mk_kn_mn_instances = std::tuple< // clang-format off // ##########| AData| BData| CData| AccData| ALayout| BLayout| CLayout| A| B| C| GEMM| Block| MPer| NPer| K0Per| K1| M1Per| N1Per| KPer| M11N11Thread| M11N11Thread| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| CThreadTransfer| CThreadTransfer| CThreadTransfer| // ##########| Type| Type| Type| Type| | | | Elementwise| Elementwise| Elementwise| Spacialization| Size| Block| Block| Block| | ThreadM111| ThreadN111| Thread| ClusterM110Xs| ClusterN110Xs| ThreadSliceLengths| ThreadClusterLengths| ThreadCluster| SrcAccess| SrcVectorTensor| SrcVectorTensor| DstVectorTensor| ThreadSliceLengths| ThreadClusterLengths| ThreadCluster| SrcAccess| SrcVectorTensor| SrcVectorTensor| DstVectorTensor| SrcDstAccess| SrcDstVectorDim| DstScalarPerVector| - // ##########| | | | | | | | Operation| Operation| Operation| | | | | | | | | | | | K0_M0_M1_K1| K0_M0_M1_K1| ArrangeOrder| Order| Lengths_K0_M0_M1_K1| ContiguousDimOrder| Lengths_K0_M0_M1_K1| K0_M0_M1_K1| K0_M0_M1_K1| ArrangeOrder| Order| Lengths_K0_M0_M1_K1| ContiguousDimOrder| Lengths_K0_M0_M1_K1| Order| | | + // ##########| | | | | | | | Operation| Operation| Operation| | | | | | | | | | | | K0_M0_M1_K1| K0_M0_M1_K1| ArrangeOrder| Order| Lengths_K0_M0_M1_K1| ContiguousDimOrder| Lengths_K0_M0_M1_K1| K0_N0_N1_K1| 
K0_N0_N1_K1| ArrangeOrder| Order| Lengths_K0_N0_N1_K1| ContiguousDimOrder| Lengths_K0_N0_N1_K1| Order| | | // ##########| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | /* * K1 = 1 diff --git a/library/src/tensor_operation_instance/gpu/gemm/device_gemm_dlops_f32_f32_f32_mk_nk_mn_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm/device_gemm_dlops_f32_f32_f32_mk_nk_mn_instance.cpp index 68fd1ac0c1..cd3069d493 100644 --- a/library/src/tensor_operation_instance/gpu/gemm/device_gemm_dlops_f32_f32_f32_mk_nk_mn_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/gemm/device_gemm_dlops_f32_f32_f32_mk_nk_mn_instance.cpp @@ -27,7 +27,7 @@ using device_gemm_dlops_f32_f32_f32_mk_nk_mn_instances = std::tuple< // clang-format off // ##########| AData| BData| CData| AccData| ALayout| BLayout| CLayout| A| B| C| GEMM| Block| MPer| NPer| K0Per| K1| M1Per| N1Per| KPer| M11N11Thread| M11N11Thread| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| CThreadTransfer| CThreadTransfer| CThreadTransfer| // ##########| Type| Type| Type| Type| | | | Elementwise| Elementwise| Elementwise| Spacialization| Size| Block| Block| Block| | ThreadM111| ThreadN111| Thread| ClusterM110Xs| ClusterN110Xs| ThreadSliceLengths| ThreadClusterLengths| ThreadCluster| SrcAccess| SrcVectorTensor| SrcVectorTensor| DstVectorTensor| ThreadSliceLengths| ThreadClusterLengths| ThreadCluster| SrcAccess| SrcVectorTensor| SrcVectorTensor| DstVectorTensor| SrcDstAccess| SrcDstVectorDim| DstScalarPerVector| - // ##########| | | | | | | | Operation| Operation| Operation| | | | | | | | | | | | K0_M0_M1_K1| K0_M0_M1_K1| ArrangeOrder| Order| Lengths_K0_M0_M1_K1| ContiguousDimOrder| Lengths_K0_M0_M1_K1| K0_M0_M1_K1| K0_M0_M1_K1| ArrangeOrder| Order| Lengths_K0_M0_M1_K1| ContiguousDimOrder| Lengths_K0_M0_M1_K1| 
Order| | | + // ##########| | | | | | | | Operation| Operation| Operation| | | | | | | | | | | | K0_M0_M1_K1| K0_M0_M1_K1| ArrangeOrder| Order| Lengths_K0_M0_M1_K1| ContiguousDimOrder| Lengths_K0_M0_M1_K1| K0_N0_N1_K1| K0_N0_N1_K1| ArrangeOrder| Order| Lengths_K0_N0_N1_K1| ContiguousDimOrder| Lengths_K0_N0_N1_K1| Order| | | // ##########| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | /* * K1 = 1 @@ -52,7 +52,6 @@ using device_gemm_dlops_f32_f32_f32_mk_nk_mn_instances = std::tuple< DeviceGemmDlops< F32, F32, F32, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 64, 128, 128, 8, 2, 8, 8, 1, S<4, 2>, S<4, 2>, S<4, 1, 4, 2>, S<2, 1, 32, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<4, 1, 4, 2>, S<2, 1, 32, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<0, 1, 2, 3, 4, 5>, 5, 4>, DeviceGemmDlops< F32, F32, F32, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 64, 128, 64, 8, 2, 8, 4, 1, S<4, 2>, S<4, 2>, S<4, 1, 4, 2>, S<2, 1, 32, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<4, 1, 2, 2>, S<2, 1, 32, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<0, 1, 2, 3, 4, 5>, 5, 4>, DeviceGemmDlops< F32, F32, F32, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 64, 64, 128, 8, 2, 4, 8, 1, S<4, 2>, S<4, 2>, S<4, 1, 2, 2>, S<2, 1, 32, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<4, 1, 4, 2>, S<2, 1, 32, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<0, 1, 2, 3, 4, 5>, 5, 4> - // DeviceGemmDlops< F32, F32, F32, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 128, 128, 128, 8, 2, 4, 4, 1, S<8, 1>, S<8, 2>, S<4, 1, 1, 2>, S<1, 1, 128, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<4, 1, 1, 2>, S<1, 1, 128, 1>, S<1, 2, 0, 3>, 
S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<0, 1, 2, 3, 4, 5>, 5, 4> // clang-format on >; diff --git a/library/src/tensor_operation_instance/gpu/gemm/device_gemm_dlops_int8_int8_int8_km_kn_mn_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm/device_gemm_dlops_int8_int8_int8_km_kn_mn_instance.cpp new file mode 100644 index 0000000000..3ce8e6afdc --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/gemm/device_gemm_dlops_int8_int8_int8_km_kn_mn_instance.cpp @@ -0,0 +1,79 @@ +#include +#include "config.hpp" +#include "device_gemm_dlops.hpp" +#include "element_wise_operation.hpp" +#include "device_operation_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace device_gemm_instance { + +using Row = ck::tensor_layout::gemm::RowMajor; +using Col = ck::tensor_layout::gemm::ColumnMajor; + +template +using S = ck::Sequence; + +using PassThrough = ck::tensor_operation::element_wise::PassThrough; + +static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default; + +// Compilation parameters for a[k, m] * b[k, n] = c[m, n] +using device_gemm_dlops_int8_int8_int8_km_kn_mn_instances = std::tuple< + + // clang-format off + // ##########| AData| BData| CData| AccData| ALayout| BLayout| CLayout| A| B| C| GEMM| Block| MPer| NPer| K0Per| K1| M1Per| N1Per| KPer| M11N11Thread| M11N11Thread| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| CThreadTransfer| CThreadTransfer| CThreadTransfer| + // ##########| Type| Type| Type| Type| | | | Elementwise| Elementwise| Elementwise| Specialization| Size| Block| Block| Block| | ThreadM111| ThreadN111| Thread| ClusterM110Xs| ClusterN110Xs| ThreadSliceLengths| ThreadClusterLengths| ThreadCluster| SrcAccess| SrcVectorTensor| SrcVectorTensor| DstVectorTensor|
ThreadSliceLengths| ThreadClusterLengths| ThreadCluster| SrcAccess| SrcVectorTensor| SrcVectorTensor| DstVectorTensor| SrcDstAccess| SrcDstVectorDim| DstScalarPerVector| + // ##########| | | | | | | | Operation| Operation| Operation| | | | | | | | | | | | K0_M0_M1_K1| K0_M0_M1_K1| ArrangeOrder| Order| Lengths_K0_M0_M1_K1| ContiguousDimOrder| Lengths_K0_M0_M1_K1| K0_N0_N1_K1| K0_N0_N1_K1| ArrangeOrder| Order| Lengths_K0_N0_N1_K1| ContiguousDimOrder| Lengths_K0_N0_N1_K1| Order| | | + // ##########| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + /* + * K1 = 1 + */ + DeviceGemmDlops< int8_t, int8_t, int8_t, int32_t, Col, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 256, 128, 8, 1, 8, 4, 1, S<8, 2>, S<8, 2>, S<4, 1, 2, 1>, S<2, 1, 128, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<4, 1, 1, 1>, S<0, 3, 1, 2>, S<1, 1, 1, 1>, S<4, 1, 1, 1>, S<2, 1, 128, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<4, 1, 1, 1>, S<0, 3, 1, 2>, S<1, 1, 1, 1>, S<0, 1, 2, 3, 4, 5>, 5, 4>, + DeviceGemmDlops< int8_t, int8_t, int8_t, int32_t, Col, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 128, 256, 8, 1, 4, 8, 1, S<8, 2>, S<8, 2>, S<4, 1, 1, 1>, S<2, 1, 128, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<4, 1, 1, 1>, S<0, 3, 1, 2>, S<1, 1, 1, 1>, S<4, 1, 2, 1>, S<2, 1, 128, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<4, 1, 1, 1>, S<0, 3, 1, 2>, S<1, 1, 1, 1>, S<0, 1, 2, 3, 4, 5>, 5, 4>, + DeviceGemmDlops< int8_t, int8_t, int8_t, int32_t, Col, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 128, 128, 8, 1, 4, 4, 1, S<8, 2>, S<8, 2>, S<4, 1, 1, 1>, S<2, 1, 128, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<4, 1, 1, 1>, S<0, 3, 1, 2>, S<1, 1, 1, 1>, S<4, 1, 1, 1>, S<2, 1, 128, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<4, 1, 1, 1>, S<0, 3, 1, 2>, S<1, 1, 1, 1>, S<0, 1, 2, 3, 4, 5>, 5, 4>, + DeviceGemmDlops< int8_t, int8_t, int8_t, int32_t, Col, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 128, 128, 128, 8, 1, 8, 4, 1, S<4, 2>, S<8, 2>, S<8, 
1, 1, 1>, S<1, 1, 128, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<4, 1, 1, 1>, S<0, 3, 1, 2>, S<1, 1, 1, 1>, S<8, 1, 1, 1>, S<1, 1, 128, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<4, 1, 1, 1>, S<0, 3, 1, 2>, S<1, 1, 1, 1>, S<0, 1, 2, 3, 4, 5>, 5, 4>, + DeviceGemmDlops< int8_t, int8_t, int8_t, int32_t, Col, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 128, 128, 128, 8, 1, 4, 8, 1, S<8, 2>, S<4, 2>, S<8, 1, 1, 1>, S<1, 1, 128, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<4, 1, 1, 1>, S<0, 3, 1, 2>, S<1, 1, 1, 1>, S<8, 1, 1, 1>, S<1, 1, 128, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<4, 1, 1, 1>, S<0, 3, 1, 2>, S<1, 1, 1, 1>, S<0, 1, 2, 3, 4, 5>, 5, 4>, + DeviceGemmDlops< int8_t, int8_t, int8_t, int32_t, Col, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 64, 128, 128, 8, 1, 8, 8, 1, S<4, 2>, S<4, 2>, S<4, 1, 4, 1>, S<2, 1, 32, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<4, 1, 1, 1>, S<0, 3, 1, 2>, S<1, 1, 1, 1>, S<4, 1, 4, 1>, S<2, 1, 32, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<4, 1, 1, 1>, S<0, 3, 1, 2>, S<1, 1, 1, 1>, S<0, 1, 2, 3, 4, 5>, 5, 4>, + DeviceGemmDlops< int8_t, int8_t, int8_t, int32_t, Col, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 64, 128, 64, 8, 1, 8, 4, 1, S<4, 2>, S<4, 2>, S<4, 1, 4, 1>, S<2, 1, 32, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<4, 1, 1, 1>, S<0, 3, 1, 2>, S<1, 1, 1, 1>, S<4, 1, 2, 1>, S<2, 1, 32, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<4, 1, 1, 1>, S<0, 3, 1, 2>, S<1, 1, 1, 1>, S<0, 1, 2, 3, 4, 5>, 5, 4>, + DeviceGemmDlops< int8_t, int8_t, int8_t, int32_t, Col, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 64, 64, 128, 8, 1, 4, 8, 1, S<4, 2>, S<4, 2>, S<4, 1, 2, 1>, S<2, 1, 32, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<4, 1, 1, 1>, S<0, 3, 1, 2>, S<1, 1, 1, 1>, S<4, 1, 4, 1>, S<2, 1, 32, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<4, 1, 1, 1>, S<0, 3, 1, 2>, S<1, 1, 1, 1>, S<0, 1, 2, 3, 4, 5>, 5, 4>, + + /* + * K1 = 2 + */ + DeviceGemmDlops< int8_t, int8_t, int8_t, int32_t, Col, Row, Row, PassThrough, PassThrough, PassThrough, 
GemmDefault, 256, 256, 128, 8, 2, 8, 4, 1, S<8, 2>, S<8, 2>, S<4, 1, 2, 2>, S<2, 1, 128, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<4, 1, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 1, 2>, S<4, 1, 1, 2>, S<2, 1, 128, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<4, 1, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 1, 2>, S<0, 1, 2, 3, 4, 5>, 5, 4>, + DeviceGemmDlops< int8_t, int8_t, int8_t, int32_t, Col, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 128, 256, 8, 2, 4, 8, 1, S<8, 2>, S<8, 2>, S<4, 1, 1, 2>, S<2, 1, 128, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<4, 1, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 1, 2>, S<4, 1, 2, 2>, S<2, 1, 128, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<4, 1, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 1, 2>, S<0, 1, 2, 3, 4, 5>, 5, 4>, + DeviceGemmDlops< int8_t, int8_t, int8_t, int32_t, Col, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 128, 128, 8, 2, 4, 4, 1, S<8, 2>, S<8, 2>, S<4, 1, 1, 2>, S<2, 1, 128, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<4, 1, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 1, 2>, S<4, 1, 1, 2>, S<2, 1, 128, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<4, 1, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 1, 2>, S<0, 1, 2, 3, 4, 5>, 5, 4>, + DeviceGemmDlops< int8_t, int8_t, int8_t, int32_t, Col, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 128, 128, 128, 8, 2, 8, 4, 1, S<4, 2>, S<8, 2>, S<8, 1, 1, 2>, S<1, 1, 128, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<4, 1, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 1, 2>, S<8, 1, 1, 2>, S<1, 1, 128, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<4, 1, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 1, 2>, S<0, 1, 2, 3, 4, 5>, 5, 4>, + DeviceGemmDlops< int8_t, int8_t, int8_t, int32_t, Col, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 128, 128, 128, 8, 2, 4, 8, 1, S<8, 2>, S<4, 2>, S<8, 1, 1, 2>, S<1, 1, 128, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<4, 1, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 1, 2>, S<8, 1, 1, 2>, S<1, 1, 128, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<4, 1, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 1, 2>, S<0, 1, 2, 3, 4, 5>, 5, 4>, + DeviceGemmDlops< int8_t, int8_t, int8_t, int32_t, Col, 
Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 64, 128, 128, 8, 2, 8, 8, 1, S<4, 2>, S<4, 2>, S<4, 1, 4, 2>, S<2, 1, 32, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<4, 1, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 1, 2>, S<4, 1, 4, 2>, S<2, 1, 32, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<4, 1, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 1, 2>, S<0, 1, 2, 3, 4, 5>, 5, 4>, + DeviceGemmDlops< int8_t, int8_t, int8_t, int32_t, Col, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 64, 128, 64, 8, 2, 8, 4, 1, S<4, 2>, S<4, 2>, S<4, 1, 4, 2>, S<2, 1, 32, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<4, 1, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 1, 2>, S<4, 1, 2, 2>, S<2, 1, 32, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<4, 1, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 1, 2>, S<0, 1, 2, 3, 4, 5>, 5, 4>, + DeviceGemmDlops< int8_t, int8_t, int8_t, int32_t, Col, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 64, 64, 128, 8, 2, 4, 8, 1, S<4, 2>, S<4, 2>, S<4, 1, 2, 2>, S<2, 1, 32, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<4, 1, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 1, 2>, S<4, 1, 4, 2>, S<2, 1, 32, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<4, 1, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 1, 2>, S<0, 1, 2, 3, 4, 5>, 5, 4>, + + /* + * K1 = 4 + */ + DeviceGemmDlops< int8_t, int8_t, int8_t, int32_t, Col, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 256, 128, 8, 4, 8, 4, 1, S<8, 2>, S<8, 2>, S<4, 1, 2, 4>, S<2, 1, 128, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<4, 1, 1, 4>, S<0, 3, 1, 2>, S<1, 1, 1, 4>, S<4, 1, 1, 4>, S<2, 1, 128, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<4, 1, 1, 4>, S<0, 3, 1, 2>, S<1, 1, 1, 4>, S<0, 1, 2, 3, 4, 5>, 5, 4>, + DeviceGemmDlops< int8_t, int8_t, int8_t, int32_t, Col, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 128, 256, 8, 4, 4, 8, 1, S<8, 2>, S<8, 2>, S<4, 1, 1, 4>, S<2, 1, 128, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<4, 1, 1, 4>, S<0, 3, 1, 2>, S<1, 1, 1, 4>, S<4, 1, 2, 4>, S<2, 1, 128, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<4, 1, 1, 4>, S<0, 3, 1, 2>, S<1, 1, 1, 4>, S<0, 1, 2, 3, 4, 5>, 5, 4>,
+ DeviceGemmDlops< int8_t, int8_t, int8_t, int32_t, Col, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 128, 128, 8, 4, 4, 4, 1, S<8, 2>, S<8, 2>, S<4, 1, 1, 4>, S<2, 1, 128, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<4, 1, 1, 4>, S<0, 3, 1, 2>, S<1, 1, 1, 4>, S<4, 1, 1, 4>, S<2, 1, 128, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<4, 1, 1, 4>, S<0, 3, 1, 2>, S<1, 1, 1, 4>, S<0, 1, 2, 3, 4, 5>, 5, 4>, + DeviceGemmDlops< int8_t, int8_t, int8_t, int32_t, Col, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 128, 128, 128, 8, 4, 8, 4, 1, S<4, 2>, S<8, 2>, S<8, 1, 1, 4>, S<1, 1, 128, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<4, 1, 1, 4>, S<0, 3, 1, 2>, S<1, 1, 1, 4>, S<8, 1, 1, 4>, S<1, 1, 128, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<4, 1, 1, 4>, S<0, 3, 1, 2>, S<1, 1, 1, 4>, S<0, 1, 2, 3, 4, 5>, 5, 4>, + DeviceGemmDlops< int8_t, int8_t, int8_t, int32_t, Col, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 128, 128, 128, 8, 4, 4, 8, 1, S<8, 2>, S<4, 2>, S<8, 1, 1, 4>, S<1, 1, 128, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<4, 1, 1, 4>, S<0, 3, 1, 2>, S<1, 1, 1, 4>, S<8, 1, 1, 4>, S<1, 1, 128, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<4, 1, 1, 4>, S<0, 3, 1, 2>, S<1, 1, 1, 4>, S<0, 1, 2, 3, 4, 5>, 5, 4>, + DeviceGemmDlops< int8_t, int8_t, int8_t, int32_t, Col, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 64, 128, 128, 8, 4, 8, 8, 1, S<4, 2>, S<4, 2>, S<4, 1, 4, 4>, S<2, 1, 32, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<4, 1, 1, 4>, S<0, 3, 1, 2>, S<1, 1, 1, 4>, S<4, 1, 4, 4>, S<2, 1, 32, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<4, 1, 1, 4>, S<0, 3, 1, 2>, S<1, 1, 1, 4>, S<0, 1, 2, 3, 4, 5>, 5, 4>, + DeviceGemmDlops< int8_t, int8_t, int8_t, int32_t, Col, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 64, 128, 64, 8, 4, 8, 4, 1, S<4, 2>, S<4, 2>, S<4, 1, 4, 4>, S<2, 1, 32, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<4, 1, 1, 4>, S<0, 3, 1, 2>, S<1, 1, 1, 4>, S<4, 1, 2, 4>, S<2, 1, 32, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<4, 1, 1, 4>, S<0, 3, 1, 2>, S<1, 
1, 1, 4>, S<0, 1, 2, 3, 4, 5>, 5, 4>, + DeviceGemmDlops< int8_t, int8_t, int8_t, int32_t, Col, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 64, 64, 128, 8, 4, 4, 8, 1, S<4, 2>, S<4, 2>, S<4, 1, 2, 4>, S<2, 1, 32, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<4, 1, 1, 4>, S<0, 3, 1, 2>, S<1, 1, 1, 4>, S<4, 1, 4, 4>, S<2, 1, 32, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<4, 1, 1, 4>, S<0, 3, 1, 2>, S<1, 1, 1, 4>, S<0, 1, 2, 3, 4, 5>, 5, 4> + // clang-format on + >; + +void add_device_gemm_dlops_int8_int8_int8_km_kn_mn_instances( + std::vector>& instances) +{ + add_device_operation_instances(instances, device_gemm_dlops_int8_int8_int8_km_kn_mn_instances{}); +} + +} // namespace device_gemm_instance +} // namespace device +} // namespace tensor_operation +} // namespace ck + + diff --git a/library/src/tensor_operation_instance/gpu/gemm/device_gemm_dlops_int8_int8_int8_km_nk_mn_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm/device_gemm_dlops_int8_int8_int8_km_nk_mn_instance.cpp new file mode 100644 index 0000000000..eb43247ec4 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/gemm/device_gemm_dlops_int8_int8_int8_km_nk_mn_instance.cpp @@ -0,0 +1,80 @@ +#include +#include "config.hpp" +#include "device_gemm_dlops.hpp" +#include "element_wise_operation.hpp" +#include "device_operation_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace device_gemm_instance { + +using Row = ck::tensor_layout::gemm::RowMajor; +using Col = ck::tensor_layout::gemm::ColumnMajor; + +template +using S = ck::Sequence; + +using PassThrough = ck::tensor_operation::element_wise::PassThrough; + +static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default; + +// Compilation parameters for a[k, m] * b[n, k] = c[m, n] +using device_gemm_dlops_int8_int8_int8_km_nk_mn_instances = std::tuple< + // clang-format off + // ##########| AData| BData| CData| AccData| ALayout| BLayout| CLayout| A| B| C| GEMM| 
Block| MPer| NPer| K0Per| K1| M1Per| N1Per| KPer| M11N11Thread| M11N11Thread| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| CThreadTransfer| CThreadTransfer| CThreadTransfer| + // ##########| Type| Type| Type| Type| | | | Elementwise| Elementwise| Elementwise| Specialization| Size| Block| Block| Block| | ThreadM111| ThreadN111| Thread| ClusterM110Xs| ClusterN110Xs| ThreadSliceLengths| ThreadClusterLengths| ThreadCluster| SrcAccess| SrcVectorTensor| SrcVectorTensor| DstVectorTensor| ThreadSliceLengths| ThreadClusterLengths| ThreadCluster| SrcAccess| SrcVectorTensor| SrcVectorTensor| DstVectorTensor| SrcDstAccess| SrcDstVectorDim| DstScalarPerVector| + // ##########| | | | | | | | Operation| Operation| Operation| | | | | | | | | | | | K0_M0_M1_K1| K0_M0_M1_K1| ArrangeOrder| Order| Lengths_K0_M0_M1_K1| ContiguousDimOrder| Lengths_K0_M0_M1_K1| K0_N0_N1_K1| K0_N0_N1_K1| ArrangeOrder| Order| Lengths_K0_N0_N1_K1| ContiguousDimOrder| Lengths_K0_N0_N1_K1| Order| | | + // ##########| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + /* + * K1 = 1 + */ + // DeviceGemmDlops< int8_t, int8_t, int8_t, int32_t, Col, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 256, 128, 8, 1, 8, 4, 1, S<8, 2>, S<8, 2>, S<4, 1, 2, 1>, S<2, 1, 128, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<4, 1, 1, 1>, S<0, 3, 1, 2>, S<1, 1, 1, 1>, S<4, 1, 1, 1>, S<2, 1, 128, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 1>, S<1, 2, 0, 3>, S<1, 1, 1, 1>, S<0, 1, 2, 3, 4, 5>, 5, 4>, + // DeviceGemmDlops< int8_t, int8_t, int8_t, int32_t, Col, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 128, 256, 8, 1, 4, 8, 1, S<8, 2>, S<8, 2>, S<4, 1, 1, 1>, S<2, 1, 128, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<4, 1, 1, 1>, S<0, 3, 1, 2>, S<1, 1, 1, 1>, S<4, 1, 2, 1>, S<2, 1, 128, 1>, S<1, 2, 0, 3>,
S<1, 2, 0, 3>, S<4, 1, 1, 1>, S<1, 2, 0, 3>, S<1, 1, 1, 1>, S<0, 1, 2, 3, 4, 5>, 5, 4>, + // DeviceGemmDlops< int8_t, int8_t, int8_t, int32_t, Col, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 128, 128, 8, 1, 4, 4, 1, S<8, 2>, S<8, 2>, S<4, 1, 1, 1>, S<2, 1, 128, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<4, 1, 1, 1>, S<0, 3, 1, 2>, S<1, 1, 1, 1>, S<4, 1, 1, 1>, S<2, 1, 128, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 1>, S<1, 2, 0, 3>, S<1, 1, 1, 1>, S<0, 1, 2, 3, 4, 5>, 5, 4>, + // DeviceGemmDlops< int8_t, int8_t, int8_t, int32_t, Col, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 128, 128, 128, 8, 1, 8, 4, 1, S<4, 2>, S<8, 2>, S<8, 1, 1, 1>, S<1, 1, 128, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<4, 1, 1, 1>, S<0, 3, 1, 2>, S<1, 1, 1, 1>, S<8, 1, 1, 1>, S<1, 1, 128, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 1>, S<1, 2, 0, 3>, S<1, 1, 1, 1>, S<0, 1, 2, 3, 4, 5>, 5, 4>, + // DeviceGemmDlops< int8_t, int8_t, int8_t, int32_t, Col, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 128, 128, 128, 8, 1, 4, 8, 1, S<8, 2>, S<4, 2>, S<8, 1, 1, 1>, S<1, 1, 128, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<4, 1, 1, 1>, S<0, 3, 1, 2>, S<1, 1, 1, 1>, S<8, 1, 1, 1>, S<1, 1, 128, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 1>, S<1, 2, 0, 3>, S<1, 1, 1, 1>, S<0, 1, 2, 3, 4, 5>, 5, 4>, + // DeviceGemmDlops< int8_t, int8_t, int8_t, int32_t, Col, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 64, 128, 128, 8, 1, 8, 8, 1, S<4, 2>, S<4, 2>, S<4, 1, 4, 1>, S<2, 1, 32, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<4, 1, 1, 1>, S<0, 3, 1, 2>, S<1, 1, 1, 1>, S<4, 1, 4, 1>, S<2, 1, 32, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 1>, S<1, 2, 0, 3>, S<1, 1, 1, 1>, S<0, 1, 2, 3, 4, 5>, 5, 4>, + // DeviceGemmDlops< int8_t, int8_t, int8_t, int32_t, Col, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 64, 128, 64, 8, 1, 8, 4, 1, S<4, 2>, S<4, 2>, S<4, 1, 4, 1>, S<2, 1, 32, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<4, 1, 1, 1>, S<0, 3, 1, 2>, S<1, 1, 
1, 1>, S<4, 1, 2, 1>, S<2, 1, 32, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 1>, S<1, 2, 0, 3>, S<1, 1, 1, 1>, S<0, 1, 2, 3, 4, 5>, 5, 4>, + // DeviceGemmDlops< int8_t, int8_t, int8_t, int32_t, Col, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 64, 64, 128, 8, 1, 4, 8, 1, S<4, 2>, S<4, 2>, S<4, 1, 2, 1>, S<2, 1, 32, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<4, 1, 1, 1>, S<0, 3, 1, 2>, S<1, 1, 1, 1>, S<4, 1, 4, 1>, S<2, 1, 32, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 1>, S<1, 2, 0, 3>, S<1, 1, 1, 1>, S<0, 1, 2, 3, 4, 5>, 5, 4>, + // + /* + * K1 = 2 + */ + // DeviceGemmDlops< int8_t, int8_t, int8_t, int32_t, Col, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 256, 128, 8, 2, 8, 4, 1, S<8, 2>, S<8, 2>, S<4, 1, 2, 2>, S<2, 1, 128, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<4, 1, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 1, 2>, S<4, 1, 1, 2>, S<2, 1, 128, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<0, 1, 2, 3, 4, 5>, 5, 4>, + // DeviceGemmDlops< int8_t, int8_t, int8_t, int32_t, Col, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 128, 256, 8, 2, 4, 8, 1, S<8, 2>, S<8, 2>, S<4, 1, 1, 2>, S<2, 1, 128, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<4, 1, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 1, 2>, S<4, 1, 2, 2>, S<2, 1, 128, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<0, 1, 2, 3, 4, 5>, 5, 4>, + // DeviceGemmDlops< int8_t, int8_t, int8_t, int32_t, Col, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 128, 128, 8, 2, 4, 4, 1, S<8, 2>, S<8, 2>, S<4, 1, 1, 2>, S<2, 1, 128, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<4, 1, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 1, 2>, S<4, 1, 1, 2>, S<2, 1, 128, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<0, 1, 2, 3, 4, 5>, 5, 4>, + // DeviceGemmDlops< int8_t, int8_t, int8_t, int32_t, Col, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 128, 128, 128, 8, 2, 8, 4, 1, S<4, 2>, S<8, 2>, S<8, 1, 1, 2>, 
S<1, 1, 128, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<4, 1, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 1, 2>, S<8, 1, 1, 2>, S<1, 1, 128, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<0, 1, 2, 3, 4, 5>, 5, 4>, + // DeviceGemmDlops< int8_t, int8_t, int8_t, int32_t, Col, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 128, 128, 128, 8, 2, 4, 8, 1, S<8, 2>, S<4, 2>, S<8, 1, 1, 2>, S<1, 1, 128, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<4, 1, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 1, 2>, S<8, 1, 1, 2>, S<1, 1, 128, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<0, 1, 2, 3, 4, 5>, 5, 4>, + // DeviceGemmDlops< int8_t, int8_t, int8_t, int32_t, Col, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 64, 128, 128, 8, 2, 8, 8, 1, S<4, 2>, S<4, 2>, S<4, 1, 4, 2>, S<2, 1, 32, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<4, 1, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 1, 2>, S<4, 1, 4, 2>, S<2, 1, 32, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<0, 1, 2, 3, 4, 5>, 5, 4>, + // DeviceGemmDlops< int8_t, int8_t, int8_t, int32_t, Col, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 64, 128, 64, 8, 2, 8, 4, 1, S<4, 2>, S<4, 2>, S<4, 1, 4, 2>, S<2, 1, 32, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<4, 1, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 1, 2>, S<4, 1, 2, 2>, S<2, 1, 32, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<0, 1, 2, 3, 4, 5>, 5, 4>, + // DeviceGemmDlops< int8_t, int8_t, int8_t, int32_t, Col, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 64, 64, 128, 8, 2, 4, 8, 1, S<4, 2>, S<4, 2>, S<4, 1, 2, 2>, S<2, 1, 32, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<4, 1, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 1, 2>, S<4, 1, 4, 2>, S<2, 1, 32, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<0, 1, 2, 3, 4, 5>, 5, 4>, + + /* + * K1 = 4 + */ + DeviceGemmDlops< int8_t, int8_t, int8_t, int32_t, Col, Col, Row, PassThrough, PassThrough, PassThrough, 
GemmDefault, 256, 256, 128, 8, 4, 8, 4, 1, S<8, 2>, S<8, 2>, S<4, 1, 2, 4>, S<2, 1, 128, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<4, 1, 1, 4>, S<0, 3, 1, 2>, S<1, 1, 1, 4>, S<4, 1, 1, 4>, S<2, 1, 128, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 4>, S<1, 2, 0, 3>, S<1, 1, 1, 4>, S<0, 1, 2, 3, 4, 5>, 5, 4>, + DeviceGemmDlops< int8_t, int8_t, int8_t, int32_t, Col, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 128, 256, 8, 4, 4, 8, 1, S<8, 2>, S<8, 2>, S<4, 1, 1, 4>, S<2, 1, 128, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<4, 1, 1, 4>, S<0, 3, 1, 2>, S<1, 1, 1, 4>, S<4, 1, 2, 4>, S<2, 1, 128, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 4>, S<1, 2, 0, 3>, S<1, 1, 1, 4>, S<0, 1, 2, 3, 4, 5>, 5, 4>, + DeviceGemmDlops< int8_t, int8_t, int8_t, int32_t, Col, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 128, 128, 8, 4, 4, 4, 1, S<8, 2>, S<8, 2>, S<4, 1, 1, 4>, S<2, 1, 128, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<4, 1, 1, 4>, S<0, 3, 1, 2>, S<1, 1, 1, 4>, S<4, 1, 1, 4>, S<2, 1, 128, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 4>, S<1, 2, 0, 3>, S<1, 1, 1, 4>, S<0, 1, 2, 3, 4, 5>, 5, 4>, + DeviceGemmDlops< int8_t, int8_t, int8_t, int32_t, Col, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 128, 128, 128, 8, 4, 8, 4, 1, S<4, 2>, S<8, 2>, S<8, 1, 1, 4>, S<1, 1, 128, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<4, 1, 1, 4>, S<0, 3, 1, 2>, S<1, 1, 1, 4>, S<8, 1, 1, 4>, S<1, 1, 128, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 4>, S<1, 2, 0, 3>, S<1, 1, 1, 4>, S<0, 1, 2, 3, 4, 5>, 5, 4>, + DeviceGemmDlops< int8_t, int8_t, int8_t, int32_t, Col, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 128, 128, 128, 8, 4, 4, 8, 1, S<8, 2>, S<4, 2>, S<8, 1, 1, 4>, S<1, 1, 128, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<4, 1, 1, 4>, S<0, 3, 1, 2>, S<1, 1, 1, 4>, S<8, 1, 1, 4>, S<1, 1, 128, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 4>, S<1, 2, 0, 3>, S<1, 1, 1, 4>, S<0, 1, 2, 3, 4, 5>, 5, 4>, + DeviceGemmDlops< int8_t, int8_t, int8_t, int32_t, Col, 
Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 64, 128, 128, 8, 4, 8, 8, 1, S<4, 2>, S<4, 2>, S<4, 1, 4, 4>, S<2, 1, 32, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<4, 1, 1, 4>, S<0, 3, 1, 2>, S<1, 1, 1, 4>, S<4, 1, 4, 4>, S<2, 1, 32, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 4>, S<1, 2, 0, 3>, S<1, 1, 1, 4>, S<0, 1, 2, 3, 4, 5>, 5, 4>, + DeviceGemmDlops< int8_t, int8_t, int8_t, int32_t, Col, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 64, 128, 64, 8, 4, 8, 4, 1, S<4, 2>, S<4, 2>, S<4, 1, 4, 4>, S<2, 1, 32, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<4, 1, 1, 4>, S<0, 3, 1, 2>, S<1, 1, 1, 4>, S<4, 1, 2, 4>, S<2, 1, 32, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 4>, S<1, 2, 0, 3>, S<1, 1, 1, 4>, S<0, 1, 2, 3, 4, 5>, 5, 4>, + DeviceGemmDlops< int8_t, int8_t, int8_t, int32_t, Col, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 64, 64, 128, 8, 4, 4, 8, 1, S<4, 2>, S<4, 2>, S<4, 1, 2, 4>, S<2, 1, 32, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<4, 1, 1, 4>, S<0, 3, 1, 2>, S<1, 1, 1, 4>, S<4, 1, 4, 4>, S<2, 1, 32, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 4>, S<1, 2, 0, 3>, S<1, 1, 1, 4>, S<0, 1, 2, 3, 4, 5>, 5, 4> + // clang-format on + >; + +void add_device_gemm_dlops_int8_int8_int8_km_nk_mn_instances( + std::vector>& instances) +{ + add_device_operation_instances(instances, device_gemm_dlops_int8_int8_int8_km_nk_mn_instances{}); +} + +} // namespace device_gemm_instance +} // namespace device +} // namespace tensor_operation +} // namespace ck + + + + diff --git a/library/src/tensor_operation_instance/gpu/gemm/device_gemm_dlops_int8_int8_int8_mk_kn_mn_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm/device_gemm_dlops_int8_int8_int8_mk_kn_mn_instance.cpp new file mode 100644 index 0000000000..fb7fcc557a --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/gemm/device_gemm_dlops_int8_int8_int8_mk_kn_mn_instance.cpp @@ -0,0 +1,79 @@ +#include +#include "config.hpp" +#include "device_gemm_dlops.hpp" +#include 
"element_wise_operation.hpp" +#include "device_operation_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace device_gemm_instance { + +using Row = ck::tensor_layout::gemm::RowMajor; +using Col = ck::tensor_layout::gemm::ColumnMajor; + +template +using S = ck::Sequence; + +using PassThrough = ck::tensor_operation::element_wise::PassThrough; + +static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default; + +// Compilation parameters for a[m, k] * b[k, n] = c[m, n] +using device_gemm_dlops_int8_int8_int8_mk_kn_mn_instances = std::tuple< + // clang-format off + // ##########| AData| BData| CData| AccData| ALayout| BLayout| CLayout| A| B| C| GEMM| Block| MPer| NPer| K0Per| K1| M1Per| N1Per| KPer| M11N11Thread| M11N11Thread| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| CThreadTransfer| CThreadTransfer| CThreadTransfer| + // ##########| Type| Type| Type| Type| | | | Elementwise| Elementwise| Elementwise| Spacialization| Size| Block| Block| Block| | ThreadM111| ThreadN111| Thread| ClusterM110Xs| ClusterN110Xs| ThreadSliceLengths| ThreadClusterLengths| ThreadCluster| SrcAccess| SrcVectorTensor| SrcVectorTensor| DstVectorTensor| ThreadSliceLengths| ThreadClusterLengths| ThreadCluster| SrcAccess| SrcVectorTensor| SrcVectorTensor| DstVectorTensor| SrcDstAccess| SrcDstVectorDim| DstScalarPerVector| + // ##########| | | | | | | | Operation| Operation| Operation| | | | | | | | | | | | K0_M0_M1_K1| K0_M0_M1_K1| ArrangeOrder| Order| Lengths_K0_M0_M1_K1| ContiguousDimOrder| Lengths_K0_M0_M1_K1| K0_N0_N1_K1| K0_N0_N1_K1| ArrangeOrder| Order| Lengths_K0_N0_N1_K1| ContiguousDimOrder| Lengths_K0_N0_N1_K1| Order| | | + // ##########| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + /* + * K1 = 1 + */ + // 
DeviceGemmDlops< int8_t, int8_t, int8_t, int32_t, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 256, 128, 8, 1, 8, 4, 1, S<8, 2>, S<8, 2>, S<4, 1, 2, 1>, S<2, 1, 128, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 1>, S<1, 2, 0, 3>, S<1, 1, 1, 1>, S<4, 1, 1, 1>, S<2, 1, 128, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<4, 1, 1, 1>, S<0, 3, 1, 2>, S<1, 1, 1, 1>, S<0, 1, 2, 3, 4, 5>, 5, 4>, + // DeviceGemmDlops< int8_t, int8_t, int8_t, int32_t, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 128, 256, 8, 1, 4, 8, 1, S<8, 2>, S<8, 2>, S<4, 1, 1, 1>, S<2, 1, 128, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 1>, S<1, 2, 0, 3>, S<1, 1, 1, 1>, S<4, 1, 2, 1>, S<2, 1, 128, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<4, 1, 1, 1>, S<0, 3, 1, 2>, S<1, 1, 1, 1>, S<0, 1, 2, 3, 4, 5>, 5, 4>, + // DeviceGemmDlops< int8_t, int8_t, int8_t, int32_t, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 128, 128, 8, 1, 4, 4, 1, S<8, 2>, S<8, 2>, S<4, 1, 1, 1>, S<2, 1, 128, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 1>, S<1, 2, 0, 3>, S<1, 1, 1, 1>, S<4, 1, 1, 1>, S<2, 1, 128, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<4, 1, 1, 1>, S<0, 3, 1, 2>, S<1, 1, 1, 1>, S<0, 1, 2, 3, 4, 5>, 5, 4>, + // DeviceGemmDlops< int8_t, int8_t, int8_t, int32_t, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 128, 128, 128, 8, 1, 8, 4, 1, S<4, 2>, S<8, 2>, S<8, 1, 1, 1>, S<1, 1, 128, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 1>, S<1, 2, 0, 3>, S<1, 1, 1, 1>, S<8, 1, 1, 1>, S<1, 1, 128, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<4, 1, 1, 1>, S<0, 3, 1, 2>, S<1, 1, 1, 1>, S<0, 1, 2, 3, 4, 5>, 5, 4>, + // DeviceGemmDlops< int8_t, int8_t, int8_t, int32_t, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 128, 128, 128, 8, 1, 4, 8, 1, S<8, 2>, S<4, 2>, S<8, 1, 1, 1>, S<1, 1, 128, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 1>, S<1, 2, 0, 3>, S<1, 1, 1, 1>, S<8, 1, 1, 1>, S<1, 1, 128, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<4, 1, 1, 1>, 
S<0, 3, 1, 2>, S<1, 1, 1, 1>, S<0, 1, 2, 3, 4, 5>, 5, 4>, + // DeviceGemmDlops< int8_t, int8_t, int8_t, int32_t, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 64, 128, 128, 8, 1, 8, 8, 1, S<4, 2>, S<4, 2>, S<4, 1, 4, 1>, S<2, 1, 32, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 1>, S<1, 2, 0, 3>, S<1, 1, 1, 1>, S<4, 1, 4, 1>, S<2, 1, 32, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<4, 1, 1, 1>, S<0, 3, 1, 2>, S<1, 1, 1, 1>, S<0, 1, 2, 3, 4, 5>, 5, 4>, + // DeviceGemmDlops< int8_t, int8_t, int8_t, int32_t, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 64, 128, 64, 8, 1, 8, 4, 1, S<4, 2>, S<4, 2>, S<4, 1, 4, 1>, S<2, 1, 32, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 1>, S<1, 2, 0, 3>, S<1, 1, 1, 1>, S<4, 1, 2, 1>, S<2, 1, 32, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<4, 1, 1, 1>, S<0, 3, 1, 2>, S<1, 1, 1, 1>, S<0, 1, 2, 3, 4, 5>, 5, 4>, + // DeviceGemmDlops< int8_t, int8_t, int8_t, int32_t, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 64, 64, 128, 8, 1, 4, 8, 1, S<4, 2>, S<4, 2>, S<4, 1, 2, 1>, S<2, 1, 32, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 1>, S<1, 2, 0, 3>, S<1, 1, 1, 1>, S<4, 1, 4, 1>, S<2, 1, 32, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<4, 1, 1, 1>, S<0, 3, 1, 2>, S<1, 1, 1, 1>, S<0, 1, 2, 3, 4, 5>, 5, 4>, + // + /* + * K1 = 2 + */ + // DeviceGemmDlops< int8_t, int8_t, int8_t, int32_t, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 256, 128, 8, 2, 8, 4, 1, S<8, 2>, S<8, 2>, S<4, 1, 2, 2>, S<2, 1, 128, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<4, 1, 1, 2>, S<2, 1, 128, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<4, 1, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 1, 2>, S<0, 1, 2, 3, 4, 5>, 5, 4>, + // DeviceGemmDlops< int8_t, int8_t, int8_t, int32_t, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 128, 256, 8, 2, 4, 8, 1, S<8, 2>, S<8, 2>, S<4, 1, 1, 2>, S<2, 1, 128, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, 
S<4, 1, 2, 2>, S<2, 1, 128, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<4, 1, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 1, 2>, S<0, 1, 2, 3, 4, 5>, 5, 4>, + // DeviceGemmDlops< int8_t, int8_t, int8_t, int32_t, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 128, 128, 8, 2, 4, 4, 1, S<8, 2>, S<8, 2>, S<4, 1, 1, 2>, S<2, 1, 128, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<4, 1, 1, 2>, S<2, 1, 128, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<4, 1, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 1, 2>, S<0, 1, 2, 3, 4, 5>, 5, 4>, + // DeviceGemmDlops< int8_t, int8_t, int8_t, int32_t, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 128, 128, 128, 8, 2, 8, 4, 1, S<4, 2>, S<8, 2>, S<8, 1, 1, 2>, S<1, 1, 128, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<8, 1, 1, 2>, S<1, 1, 128, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<4, 1, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 1, 2>, S<0, 1, 2, 3, 4, 5>, 5, 4>, + // DeviceGemmDlops< int8_t, int8_t, int8_t, int32_t, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 128, 128, 128, 8, 2, 4, 8, 1, S<8, 2>, S<4, 2>, S<8, 1, 1, 2>, S<1, 1, 128, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<8, 1, 1, 2>, S<1, 1, 128, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<4, 1, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 1, 2>, S<0, 1, 2, 3, 4, 5>, 5, 4>, + // DeviceGemmDlops< int8_t, int8_t, int8_t, int32_t, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 64, 128, 128, 8, 2, 8, 8, 1, S<4, 2>, S<4, 2>, S<4, 1, 4, 2>, S<2, 1, 32, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<4, 1, 4, 2>, S<2, 1, 32, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<4, 1, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 1, 2>, S<0, 1, 2, 3, 4, 5>, 5, 4>, + // DeviceGemmDlops< int8_t, int8_t, int8_t, int32_t, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 64, 128, 64, 8, 2, 8, 4, 1, S<4, 2>, S<4, 2>, S<4, 1, 4, 2>, S<2, 1, 32, 1>, S<1, 2, 0, 3>, S<1, 
2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<4, 1, 2, 2>, S<2, 1, 32, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<4, 1, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 1, 2>, S<0, 1, 2, 3, 4, 5>, 5, 4>, + // DeviceGemmDlops< int8_t, int8_t, int8_t, int32_t, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 64, 64, 128, 8, 2, 4, 8, 1, S<4, 2>, S<4, 2>, S<4, 1, 2, 2>, S<2, 1, 32, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<4, 1, 4, 2>, S<2, 1, 32, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<4, 1, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 1, 2>, S<0, 1, 2, 3, 4, 5>, 5, 4>, + + /* + * K1 = 4 + */ + DeviceGemmDlops< int8_t, int8_t, int8_t, int32_t, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 256, 128, 8, 4, 8, 4, 1, S<8, 2>, S<8, 2>, S<4, 1, 2, 4>, S<2, 1, 128, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 4>, S<1, 2, 0, 3>, S<1, 1, 1, 4>, S<4, 1, 1, 4>, S<2, 1, 128, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<4, 1, 1, 4>, S<0, 3, 1, 2>, S<1, 1, 1, 4>, S<0, 1, 2, 3, 4, 5>, 5, 4>, + DeviceGemmDlops< int8_t, int8_t, int8_t, int32_t, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 128, 256, 8, 4, 4, 8, 1, S<8, 2>, S<8, 2>, S<4, 1, 1, 4>, S<2, 1, 128, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 4>, S<1, 2, 0, 3>, S<1, 1, 1, 4>, S<4, 1, 2, 4>, S<2, 1, 128, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<4, 1, 1, 4>, S<0, 3, 1, 2>, S<1, 1, 1, 4>, S<0, 1, 2, 3, 4, 5>, 5, 4>, + DeviceGemmDlops< int8_t, int8_t, int8_t, int32_t, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 128, 128, 8, 4, 4, 4, 1, S<8, 2>, S<8, 2>, S<4, 1, 1, 4>, S<2, 1, 128, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 4>, S<1, 2, 0, 3>, S<1, 1, 1, 4>, S<4, 1, 1, 4>, S<2, 1, 128, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<4, 1, 1, 4>, S<0, 3, 1, 2>, S<1, 1, 1, 4>, S<0, 1, 2, 3, 4, 5>, 5, 4>, + DeviceGemmDlops< int8_t, int8_t, int8_t, int32_t, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 128, 128, 128, 8, 4, 8, 4, 1, 
S<4, 2>, S<8, 2>, S<8, 1, 1, 4>, S<1, 1, 128, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 4>, S<1, 2, 0, 3>, S<1, 1, 1, 4>, S<8, 1, 1, 4>, S<1, 1, 128, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<4, 1, 1, 4>, S<0, 3, 1, 2>, S<1, 1, 1, 4>, S<0, 1, 2, 3, 4, 5>, 5, 4>, + DeviceGemmDlops< int8_t, int8_t, int8_t, int32_t, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 128, 128, 128, 8, 4, 4, 8, 1, S<8, 2>, S<4, 2>, S<8, 1, 1, 4>, S<1, 1, 128, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 4>, S<1, 2, 0, 3>, S<1, 1, 1, 4>, S<8, 1, 1, 4>, S<1, 1, 128, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<4, 1, 1, 4>, S<0, 3, 1, 2>, S<1, 1, 1, 4>, S<0, 1, 2, 3, 4, 5>, 5, 4>, + DeviceGemmDlops< int8_t, int8_t, int8_t, int32_t, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 64, 128, 128, 8, 4, 8, 8, 1, S<4, 2>, S<4, 2>, S<4, 1, 4, 4>, S<2, 1, 32, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 4>, S<1, 2, 0, 3>, S<1, 1, 1, 4>, S<4, 1, 4, 4>, S<2, 1, 32, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<4, 1, 1, 4>, S<0, 3, 1, 2>, S<1, 1, 1, 4>, S<0, 1, 2, 3, 4, 5>, 5, 4>, + DeviceGemmDlops< int8_t, int8_t, int8_t, int32_t, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 64, 128, 64, 8, 4, 8, 4, 1, S<4, 2>, S<4, 2>, S<4, 1, 4, 4>, S<2, 1, 32, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 4>, S<1, 2, 0, 3>, S<1, 1, 1, 4>, S<4, 1, 2, 4>, S<2, 1, 32, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<4, 1, 1, 4>, S<0, 3, 1, 2>, S<1, 1, 1, 4>, S<0, 1, 2, 3, 4, 5>, 5, 4>, + DeviceGemmDlops< int8_t, int8_t, int8_t, int32_t, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 64, 64, 128, 8, 4, 4, 8, 1, S<4, 2>, S<4, 2>, S<4, 1, 2, 4>, S<2, 1, 32, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 4>, S<1, 2, 0, 3>, S<1, 1, 1, 4>, S<4, 1, 4, 4>, S<2, 1, 32, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<4, 1, 1, 4>, S<0, 3, 1, 2>, S<1, 1, 1, 4>, S<0, 1, 2, 3, 4, 5>, 5, 4> + // clang-format on + >; + +void add_device_gemm_dlops_int8_int8_int8_mk_kn_mn_instances( + std::vector>& 
instances) +{ + add_device_operation_instances(instances, device_gemm_dlops_int8_int8_int8_mk_kn_mn_instances{}); +} + +} // namespace device_gemm_instance +} // namespace device +} // namespace tensor_operation +} // namespace ck + + + diff --git a/library/src/tensor_operation_instance/gpu/gemm/device_gemm_dlops_int8_int8_int8_mk_nk_mn_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm/device_gemm_dlops_int8_int8_int8_mk_nk_mn_instance.cpp new file mode 100644 index 0000000000..e8b7d89373 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/gemm/device_gemm_dlops_int8_int8_int8_mk_nk_mn_instance.cpp @@ -0,0 +1,80 @@ +#include +#include "config.hpp" +#include "device_gemm_dlops.hpp" +#include "element_wise_operation.hpp" +#include "device_operation_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace device_gemm_instance { + +using Row = ck::tensor_layout::gemm::RowMajor; +using Col = ck::tensor_layout::gemm::ColumnMajor; + +template +using S = ck::Sequence; + +using PassThrough = ck::tensor_operation::element_wise::PassThrough; + +static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default; + +// Compilation parameters for a[m, k] * b[n, k] = c[m, n] +using device_gemm_dlops_int8_int8_int8_mk_nk_mn_instances = std::tuple< + // clang-format off + // ##########| AData| BData| CData| AccData| ALayout| BLayout| CLayout| A| B| C| GEMM| Block| MPer| NPer| K0Per| K1| M1Per| N1Per| KPer| M11N11Thread| M11N11Thread| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| CThreadTransfer| CThreadTransfer| CThreadTransfer| + // ##########| Type| Type| Type| Type| | | | Elementwise| Elementwise| Elementwise| Spacialization| Size| Block| Block| Block| | ThreadM111| ThreadN111| Thread| ClusterM110Xs| ClusterN110Xs| 
ThreadSliceLengths| ThreadClusterLengths| ThreadCluster| SrcAccess| SrcVectorTensor| SrcVectorTensor| DstVectorTensor| ThreadSliceLengths| ThreadClusterLengths| ThreadCluster| SrcAccess| SrcVectorTensor| SrcVectorTensor| DstVectorTensor| SrcDstAccess| SrcDstVectorDim| DstScalarPerVector| + // ##########| | | | | | | | Operation| Operation| Operation| | | | | | | | | | | | K0_M0_M1_K1| K0_M0_M1_K1| ArrangeOrder| Order| Lengths_K0_M0_M1_K1| ContiguousDimOrder| Lengths_K0_M0_M1_K1| K0_N0_N1_K1| K0_N0_N1_K1| ArrangeOrder| Order| Lengths_K0_N0_N1_K1| ContiguousDimOrder| Lengths_K0_N0_N1_K1| Order| | | + // ##########| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + /* + * K1 = 1 + */ + // DeviceGemmDlops< int8_t, int8_t, int8_t, int32_t, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 256, 128, 8, 1, 8, 4, 1, S<8, 2>, S<8, 2>, S<4, 1, 2, 1>, S<2, 1, 128, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 1>, S<1, 2, 0, 3>, S<1, 1, 1, 1>, S<4, 1, 1, 1>, S<2, 1, 128, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 1>, S<1, 2, 0, 3>, S<1, 1, 1, 1>, S<0, 1, 2, 3, 4, 5>, 5, 4>, + // DeviceGemmDlops< int8_t, int8_t, int8_t, int32_t, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 128, 256, 8, 1, 4, 8, 1, S<8, 2>, S<8, 2>, S<4, 1, 1, 1>, S<2, 1, 128, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 1>, S<1, 2, 0, 3>, S<1, 1, 1, 1>, S<4, 1, 2, 1>, S<2, 1, 128, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 1>, S<1, 2, 0, 3>, S<1, 1, 1, 1>, S<0, 1, 2, 3, 4, 5>, 5, 4>, + // DeviceGemmDlops< int8_t, int8_t, int8_t, int32_t, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 128, 128, 8, 1, 4, 4, 1, S<8, 2>, S<8, 2>, S<4, 1, 1, 1>, S<2, 1, 128, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 1>, S<1, 2, 0, 3>, S<1, 1, 1, 1>, S<4, 1, 1, 1>, S<2, 1, 128, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 1>, S<1, 2, 0, 3>, S<1, 1, 1, 1>, S<0, 1, 2, 3, 4, 5>, 5, 4>, + // DeviceGemmDlops< int8_t, int8_t, int8_t, 
int32_t, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 128, 128, 128, 8, 1, 8, 4, 1, S<4, 2>, S<8, 2>, S<8, 1, 1, 1>, S<1, 1, 128, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 1>, S<1, 2, 0, 3>, S<1, 1, 1, 1>, S<8, 1, 1, 1>, S<1, 1, 128, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 1>, S<1, 2, 0, 3>, S<1, 1, 1, 1>, S<0, 1, 2, 3, 4, 5>, 5, 4>, + // DeviceGemmDlops< int8_t, int8_t, int8_t, int32_t, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 128, 128, 128, 8, 1, 4, 8, 1, S<8, 2>, S<4, 2>, S<8, 1, 1, 1>, S<1, 1, 128, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 1>, S<1, 2, 0, 3>, S<1, 1, 1, 1>, S<8, 1, 1, 1>, S<1, 1, 128, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 1>, S<1, 2, 0, 3>, S<1, 1, 1, 1>, S<0, 1, 2, 3, 4, 5>, 5, 4>, + // DeviceGemmDlops< int8_t, int8_t, int8_t, int32_t, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 64, 128, 128, 8, 1, 8, 8, 1, S<4, 2>, S<4, 2>, S<4, 1, 4, 1>, S<2, 1, 32, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 1>, S<1, 2, 0, 3>, S<1, 1, 1, 1>, S<4, 1, 4, 1>, S<2, 1, 32, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 1>, S<1, 2, 0, 3>, S<1, 1, 1, 1>, S<0, 1, 2, 3, 4, 5>, 5, 4>, + // DeviceGemmDlops< int8_t, int8_t, int8_t, int32_t, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 64, 128, 64, 8, 1, 8, 4, 1, S<4, 2>, S<4, 2>, S<4, 1, 4, 1>, S<2, 1, 32, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 1>, S<1, 2, 0, 3>, S<1, 1, 1, 1>, S<4, 1, 2, 1>, S<2, 1, 32, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 1>, S<1, 2, 0, 3>, S<1, 1, 1, 1>, S<0, 1, 2, 3, 4, 5>, 5, 4>, + // DeviceGemmDlops< int8_t, int8_t, int8_t, int32_t, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 64, 64, 128, 8, 1, 4, 8, 1, S<4, 2>, S<4, 2>, S<4, 1, 2, 1>, S<2, 1, 32, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 1>, S<1, 2, 0, 3>, S<1, 1, 1, 1>, S<4, 1, 4, 1>, S<2, 1, 32, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 1>, S<1, 2, 0, 3>, S<1, 1, 1, 1>, S<0, 1, 2, 3, 4, 5>, 5, 
4>, + // + /* + * K1 = 2 + */ + // DeviceGemmDlops< int8_t, int8_t, int8_t, int32_t, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 256, 128, 8, 2, 8, 4, 1, S<8, 2>, S<8, 2>, S<4, 1, 2, 2>, S<2, 1, 128, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<4, 1, 1, 2>, S<2, 1, 128, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<0, 1, 2, 3, 4, 5>, 5, 4>, + // DeviceGemmDlops< int8_t, int8_t, int8_t, int32_t, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 128, 256, 8, 2, 4, 8, 1, S<8, 2>, S<8, 2>, S<4, 1, 1, 2>, S<2, 1, 128, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<4, 1, 2, 2>, S<2, 1, 128, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<0, 1, 2, 3, 4, 5>, 5, 4>, + // DeviceGemmDlops< int8_t, int8_t, int8_t, int32_t, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 128, 128, 8, 2, 4, 4, 1, S<8, 2>, S<8, 2>, S<4, 1, 1, 2>, S<2, 1, 128, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<4, 1, 1, 2>, S<2, 1, 128, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<0, 1, 2, 3, 4, 5>, 5, 4>, + // DeviceGemmDlops< int8_t, int8_t, int8_t, int32_t, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 128, 128, 128, 8, 2, 8, 4, 1, S<4, 2>, S<8, 2>, S<8, 1, 1, 2>, S<1, 1, 128, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<8, 1, 1, 2>, S<1, 1, 128, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<0, 1, 2, 3, 4, 5>, 5, 4>, + // DeviceGemmDlops< int8_t, int8_t, int8_t, int32_t, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 128, 128, 128, 8, 2, 4, 8, 1, S<8, 2>, S<4, 2>, S<8, 1, 1, 2>, S<1, 1, 128, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<8, 1, 1, 2>, S<1, 1, 128, 1>, S<1, 2, 0, 
3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<0, 1, 2, 3, 4, 5>, 5, 4>, + // DeviceGemmDlops< int8_t, int8_t, int8_t, int32_t, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 64, 128, 128, 8, 2, 8, 8, 1, S<4, 2>, S<4, 2>, S<4, 1, 4, 2>, S<2, 1, 32, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<4, 1, 4, 2>, S<2, 1, 32, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<0, 1, 2, 3, 4, 5>, 5, 4>, + // DeviceGemmDlops< int8_t, int8_t, int8_t, int32_t, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 64, 128, 64, 8, 2, 8, 4, 1, S<4, 2>, S<4, 2>, S<4, 1, 4, 2>, S<2, 1, 32, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<4, 1, 2, 2>, S<2, 1, 32, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<0, 1, 2, 3, 4, 5>, 5, 4>, + // DeviceGemmDlops< int8_t, int8_t, int8_t, int32_t, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 64, 64, 128, 8, 2, 4, 8, 1, S<4, 2>, S<4, 2>, S<4, 1, 2, 2>, S<2, 1, 32, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<4, 1, 4, 2>, S<2, 1, 32, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<0, 1, 2, 3, 4, 5>, 5, 4>, + + /* + * K1 = 4 + */ + DeviceGemmDlops< int8_t, int8_t, int8_t, int32_t, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 256, 128, 8, 4, 8, 4, 1, S<8, 2>, S<8, 2>, S<4, 1, 2, 4>, S<2, 1, 128, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 4>, S<1, 2, 0, 3>, S<1, 1, 1, 4>, S<4, 1, 1, 4>, S<2, 1, 128, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 4>, S<1, 2, 0, 3>, S<1, 1, 1, 4>, S<0, 1, 2, 3, 4, 5>, 5, 4>, + DeviceGemmDlops< int8_t, int8_t, int8_t, int32_t, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 128, 256, 8, 4, 4, 8, 1, S<8, 2>, S<8, 2>, S<4, 1, 1, 4>, S<2, 1, 128, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 4>, S<1, 
2, 0, 3>, S<1, 1, 1, 4>, S<4, 1, 2, 4>, S<2, 1, 128, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 4>, S<1, 2, 0, 3>, S<1, 1, 1, 4>, S<0, 1, 2, 3, 4, 5>, 5, 4>, + DeviceGemmDlops< int8_t, int8_t, int8_t, int32_t, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 128, 128, 8, 4, 4, 4, 1, S<8, 2>, S<8, 2>, S<4, 1, 1, 4>, S<2, 1, 128, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 4>, S<1, 2, 0, 3>, S<1, 1, 1, 4>, S<4, 1, 1, 4>, S<2, 1, 128, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 4>, S<1, 2, 0, 3>, S<1, 1, 1, 4>, S<0, 1, 2, 3, 4, 5>, 5, 4>, + DeviceGemmDlops< int8_t, int8_t, int8_t, int32_t, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 128, 128, 128, 8, 4, 8, 4, 1, S<4, 2>, S<8, 2>, S<8, 1, 1, 4>, S<1, 1, 128, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 4>, S<1, 2, 0, 3>, S<1, 1, 1, 4>, S<8, 1, 1, 4>, S<1, 1, 128, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 4>, S<1, 2, 0, 3>, S<1, 1, 1, 4>, S<0, 1, 2, 3, 4, 5>, 5, 4>, + DeviceGemmDlops< int8_t, int8_t, int8_t, int32_t, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 128, 128, 128, 8, 4, 4, 8, 1, S<8, 2>, S<4, 2>, S<8, 1, 1, 4>, S<1, 1, 128, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 4>, S<1, 2, 0, 3>, S<1, 1, 1, 4>, S<8, 1, 1, 4>, S<1, 1, 128, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 4>, S<1, 2, 0, 3>, S<1, 1, 1, 4>, S<0, 1, 2, 3, 4, 5>, 5, 4>, + DeviceGemmDlops< int8_t, int8_t, int8_t, int32_t, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 64, 128, 128, 8, 4, 8, 8, 1, S<4, 2>, S<4, 2>, S<4, 1, 4, 4>, S<2, 1, 32, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 4>, S<1, 2, 0, 3>, S<1, 1, 1, 4>, S<4, 1, 4, 4>, S<2, 1, 32, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 4>, S<1, 2, 0, 3>, S<1, 1, 1, 4>, S<0, 1, 2, 3, 4, 5>, 5, 4>, + DeviceGemmDlops< int8_t, int8_t, int8_t, int32_t, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 64, 128, 64, 8, 4, 8, 4, 1, S<4, 2>, S<4, 2>, S<4, 1, 4, 4>, S<2, 1, 32, 1>, S<1, 2, 0, 
3>, S<1, 2, 0, 3>, S<4, 1, 1, 4>, S<1, 2, 0, 3>, S<1, 1, 1, 4>, S<4, 1, 2, 4>, S<2, 1, 32, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 4>, S<1, 2, 0, 3>, S<1, 1, 1, 4>, S<0, 1, 2, 3, 4, 5>, 5, 4>, + DeviceGemmDlops< int8_t, int8_t, int8_t, int32_t, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 64, 64, 128, 8, 4, 4, 8, 1, S<4, 2>, S<4, 2>, S<4, 1, 2, 4>, S<2, 1, 32, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 4>, S<1, 2, 0, 3>, S<1, 1, 1, 4>, S<4, 1, 4, 4>, S<2, 1, 32, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 4>, S<1, 2, 0, 3>, S<1, 1, 1, 4>, S<0, 1, 2, 3, 4, 5>, 5, 4> + // clang-format on + >; + +void add_device_gemm_dlops_int8_int8_int8_mk_nk_mn_instances( + std::vector>& instances) +{ + add_device_operation_instances(instances, device_gemm_dlops_int8_int8_int8_mk_nk_mn_instances{}); +} + +} // namespace device_gemm_instance +} // namespace device +} // namespace tensor_operation +} // namespace ck + + + + diff --git a/test/gemm_dlops/CMakeLists.txt b/test/gemm_dlops/CMakeLists.txt index 4d137ff5d5..4d1e8d53bf 100644 --- a/test/gemm_dlops/CMakeLists.txt +++ b/test/gemm_dlops/CMakeLists.txt @@ -2,14 +2,14 @@ add_test_executable(test_gemm_dlops_fp32 gemm_dlops_fp32.cpp) target_link_libraries(test_gemm_dlops_fp32 PRIVATE host_tensor) target_link_libraries(test_gemm_dlops_fp32 PRIVATE device_gemm_dlops_instance) -# add_test_executable(test_gemm_dlops_fp16 gemm_fp16.cpp) -# target_link_libraries(test_gemm_dlops_fp16 PRIVATE host_tensor) -# target_link_libraries(test_gemm_dlops_fp16 PRIVATE device_gemm_dlops_instance) -# -# add_test_executable(test_gemm_dlops_bf16 gemm_bf16.cpp) +add_test_executable(test_gemm_dlops_fp16 gemm_dlops_fp16.cpp) +target_link_libraries(test_gemm_dlops_fp16 PRIVATE host_tensor) +target_link_libraries(test_gemm_dlops_fp16 PRIVATE device_gemm_dlops_instance) + +# add_test_executable(test_gemm_dlops_bf16 gemm_dlops_bf16.cpp) # target_link_libraries(test_gemm_dlops_bf16 PRIVATE host_tensor) # 
target_link_libraries(test_gemm_dlops_bf16 PRIVATE device_gemm_dlops_instance) -# -# add_test_executable(test_gemm_dlops_int8 gemm_int8.cpp) -# target_link_libraries(test_gemm_dlops_int8 PRIVATE host_tensor) -# target_link_libraries(test_gemm_dlops_int8 PRIVATE device_gemm_dlops_instance) + +add_test_executable(test_gemm_dlops_int8 gemm_dlops_int8.cpp) +target_link_libraries(test_gemm_dlops_int8 PRIVATE host_tensor) +target_link_libraries(test_gemm_dlops_int8 PRIVATE device_gemm_dlops_instance) diff --git a/test/gemm_dlops/gemm_dlops_fp16.cpp b/test/gemm_dlops/gemm_dlops_fp16.cpp new file mode 100644 index 0000000000..387af56419 --- /dev/null +++ b/test/gemm_dlops/gemm_dlops_fp16.cpp @@ -0,0 +1,130 @@ +#include +#include +#include +#include +#include +#include +#include + +#include "../gemm/gemm_util.hpp" +#include "config.hpp" +#include "print.hpp" +#include "device.hpp" +#include "host_tensor.hpp" +#include "host_tensor_generator.hpp" +#include "host_gemm.hpp" +#include "device_tensor.hpp" +#include "device_gemm_dlops.hpp" +#include "element_wise_operation.hpp" +#include "reference_gemm.hpp" +#include "gemm_specialization.hpp" + +using PassThrough = ck::tensor_operation::element_wise::PassThrough; + +using DeviceGemmNoOpPtr = + ck::tensor_operation::device::DeviceGemmPtr; + +namespace ck { +namespace tensor_operation { +namespace device { +namespace device_gemm_instance { + +void add_device_gemm_dlops_f16_f16_f16_km_kn_mn_instances(std::vector&); +void add_device_gemm_dlops_f16_f16_f16_km_nk_mn_instances(std::vector&); +void add_device_gemm_dlops_f16_f16_f16_mk_nk_mn_instances(std::vector&); +void add_device_gemm_dlops_f16_f16_f16_mk_kn_mn_instances(std::vector&); + +} // namespace device_gemm_instance +} // namespace device +} // namespace tensor_operation +} // namespace ck + +int main() +{ + using ADataType = ck::half_t; + using BDataType = ck::half_t; + using CDataType = ck::half_t; + + using RowMajor = ck::tensor_layout::gemm::RowMajor; + using ColumnMajor = 
ck::tensor_layout::gemm::ColumnMajor; + + bool res = true; + std::vector gemmPtrs; + ck::tensor_operation::device::device_gemm_instance:: + add_device_gemm_dlops_f16_f16_f16_km_kn_mn_instances(gemmPtrs); + + for(auto& gemmPtr : gemmPtrs) + { + res &= ck::gemm_util::TestGemm{}(gemmPtr); + } + + gemmPtrs.clear(); + ck::tensor_operation::device::device_gemm_instance:: + add_device_gemm_dlops_f16_f16_f16_km_nk_mn_instances(gemmPtrs); + + for(auto& gemmPtr : gemmPtrs) + { + res &= ck::gemm_util::TestGemm{}(gemmPtr); + } + + gemmPtrs.clear(); + ck::tensor_operation::device::device_gemm_instance:: + add_device_gemm_dlops_f16_f16_f16_mk_kn_mn_instances(gemmPtrs); + + for(auto& gemmPtr : gemmPtrs) + { + res &= ck::gemm_util::TestGemm{}(gemmPtr); + } + + gemmPtrs.clear(); + ck::tensor_operation::device::device_gemm_instance:: + add_device_gemm_dlops_f16_f16_f16_mk_nk_mn_instances(gemmPtrs); + + for(auto& gemmPtr : gemmPtrs) + { + res &= ck::gemm_util::TestGemm{}(gemmPtr); + } + + + std::cout << "TestGemm ..... " << (res ? "SUCCESS" : "FAILURE") << std::endl; + return res ? 
0 : 1; +} + diff --git a/test/gemm_dlops/gemm_dlops_int8.cpp b/test/gemm_dlops/gemm_dlops_int8.cpp new file mode 100644 index 0000000000..e9591ce5b2 --- /dev/null +++ b/test/gemm_dlops/gemm_dlops_int8.cpp @@ -0,0 +1,131 @@ +#include +#include +#include +#include +#include +#include +#include + +#include "../gemm/gemm_util.hpp" +#include "config.hpp" +#include "print.hpp" +#include "device.hpp" +#include "host_tensor.hpp" +#include "host_tensor_generator.hpp" +#include "host_gemm.hpp" +#include "device_tensor.hpp" +#include "device_gemm_dlops.hpp" +#include "element_wise_operation.hpp" +#include "reference_gemm.hpp" +#include "gemm_specialization.hpp" + +using PassThrough = ck::tensor_operation::element_wise::PassThrough; + +using DeviceGemmNoOpPtr = + ck::tensor_operation::device::DeviceGemmPtr; + +namespace ck { +namespace tensor_operation { +namespace device { +namespace device_gemm_instance { + +void add_device_gemm_dlops_int8_int8_int8_km_kn_mn_instances(std::vector&); +void add_device_gemm_dlops_int8_int8_int8_km_nk_mn_instances(std::vector&); +void add_device_gemm_dlops_int8_int8_int8_mk_nk_mn_instances(std::vector&); +void add_device_gemm_dlops_int8_int8_int8_mk_kn_mn_instances(std::vector&); + +} // namespace device_gemm_instance +} // namespace device +} // namespace tensor_operation +} // namespace ck + +int main() +{ + using ADataType = int8_t; + using BDataType = int8_t; + using CDataType = int8_t; + + using RowMajor = ck::tensor_layout::gemm::RowMajor; + using ColumnMajor = ck::tensor_layout::gemm::ColumnMajor; + + bool res = true; + std::vector gemmPtrs; + ck::tensor_operation::device::device_gemm_instance:: + add_device_gemm_dlops_int8_int8_int8_km_kn_mn_instances(gemmPtrs); + + for(auto& gemmPtr : gemmPtrs) + { + res &= ck::gemm_util::TestGemm{}(gemmPtr); + } + + gemmPtrs.clear(); + ck::tensor_operation::device::device_gemm_instance:: + add_device_gemm_dlops_int8_int8_int8_km_nk_mn_instances(gemmPtrs); + + for(auto& gemmPtr : gemmPtrs) + { + res &= 
ck::gemm_util::TestGemm{}(gemmPtr); + } + + gemmPtrs.clear(); + ck::tensor_operation::device::device_gemm_instance:: + add_device_gemm_dlops_int8_int8_int8_mk_kn_mn_instances(gemmPtrs); + + for(auto& gemmPtr : gemmPtrs) + { + res &= ck::gemm_util::TestGemm{}(gemmPtr); + } + + gemmPtrs.clear(); + ck::tensor_operation::device::device_gemm_instance:: + add_device_gemm_dlops_int8_int8_int8_mk_nk_mn_instances(gemmPtrs); + + for(auto& gemmPtr : gemmPtrs) + { + res &= ck::gemm_util::TestGemm{}(gemmPtr); + } + + + std::cout << "TestGemm ..... " << (res ? "SUCCESS" : "FAILURE") << std::endl; + return res ? 0 : 1; +} + + From 6b2ef39000a4451b0fe4b3e99d29675e4f0d2ce5 Mon Sep 17 00:00:00 2001 From: j4yan Date: Tue, 19 Apr 2022 19:34:36 -0500 Subject: [PATCH 07/46] tmp --- .../device_gemm_dlops_int8_int8_int8_km_kn_mn_instance.cpp | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/library/src/tensor_operation_instance/gpu/gemm/device_gemm_dlops_int8_int8_int8_km_kn_mn_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm/device_gemm_dlops_int8_int8_int8_km_kn_mn_instance.cpp index 3ce8e6afdc..ed602e5c28 100644 --- a/library/src/tensor_operation_instance/gpu/gemm/device_gemm_dlops_int8_int8_int8_km_kn_mn_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/gemm/device_gemm_dlops_int8_int8_int8_km_kn_mn_instance.cpp @@ -52,7 +52,7 @@ using device_gemm_dlops_int8_int8_int8_km_kn_mn_instances = std::tuple< DeviceGemmDlops< int8_t, int8_t, int8_t, int32_t, Col, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 64, 64, 128, 8, 2, 4, 8, 1, S<4, 2>, S<4, 2>, S<4, 1, 2, 2>, S<2, 1, 32, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<4, 1, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 1, 2>, S<4, 1, 4, 2>, S<2, 1, 32, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<4, 1, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 1, 2>, S<0, 1, 2, 3, 4, 5>, 5, 4>, /* - * K1 = 2 + * K1 = 4 */ DeviceGemmDlops< int8_t, int8_t, int8_t, int32_t, Col, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 
256, 128, 8, 4, 8, 4, 1, S<8, 2>, S<8, 2>, S<4, 1, 2, 4>, S<2, 1, 128, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<4, 1, 1, 4>, S<0, 3, 1, 2>, S<1, 1, 1, 4>, S<4, 1, 1, 4>, S<2, 1, 128, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<4, 1, 1, 4>, S<0, 3, 1, 2>, S<1, 1, 1, 4>, S<0, 1, 2, 3, 4, 5>, 5, 4>, DeviceGemmDlops< int8_t, int8_t, int8_t, int32_t, Col, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 128, 256, 8, 4, 4, 8, 1, S<8, 2>, S<8, 2>, S<4, 1, 1, 4>, S<2, 1, 128, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<4, 1, 1, 4>, S<0, 3, 1, 2>, S<1, 1, 1, 4>, S<4, 1, 2, 4>, S<2, 1, 128, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<4, 1, 1, 4>, S<0, 3, 1, 2>, S<1, 1, 1, 4>, S<0, 1, 2, 3, 4, 5>, 5, 4>, @@ -76,4 +76,3 @@ void add_device_gemm_dlops_int8_int8_int8_km_kn_mn_instances( } // namespace tensor_operation } // namespace ck - From 29577543a207e3e81aa90f9e50b17902c88c17b9 Mon Sep 17 00:00:00 2001 From: j4yan Date: Tue, 19 Apr 2022 19:51:15 -0500 Subject: [PATCH 08/46] add dependence of DeviceGemm::IsSupportedArg() on arch --- .../gpu/device/device_gemm_dlops.hpp | 1 + .../gpu/device/device_gemm_xdl.hpp | 17 +++++++++++++++++ include/ck/utility/device_prop.hpp | 16 ++++++++++++++++ 3 files changed, 34 insertions(+) create mode 100644 include/ck/utility/device_prop.hpp diff --git a/include/ck/tensor_operation/gpu/device/device_gemm_dlops.hpp b/include/ck/tensor_operation/gpu/device/device_gemm_dlops.hpp index b3ca4db678..6723a1e06f 100644 --- a/include/ck/tensor_operation/gpu/device/device_gemm_dlops.hpp +++ b/include/ck/tensor_operation/gpu/device/device_gemm_dlops.hpp @@ -13,6 +13,7 @@ #include "gemm_specialization.hpp" #include "element_wise_operation.hpp" #include "gridwise_gemm_dlops_v1r3.hpp" +#include "device_prop.hpp" namespace ck { namespace tensor_operation { diff --git a/include/ck/tensor_operation/gpu/device/device_gemm_xdl.hpp b/include/ck/tensor_operation/gpu/device/device_gemm_xdl.hpp index 0d0e463bb0..14f7441856 100644 --- 
a/include/ck/tensor_operation/gpu/device/device_gemm_xdl.hpp +++ b/include/ck/tensor_operation/gpu/device/device_gemm_xdl.hpp @@ -404,11 +404,28 @@ struct DeviceGemmXdl static bool IsSupportedArgument(const Argument& arg) { +#ifdef __gfx1030__ return GridwiseGemm::CheckValidity(arg.a_grid_desc_k0_m_k1_, arg.b_grid_desc_k0_n_k1_, arg.c_grid_desc_m_n_, arg.M01_, arg.N01_); +#else + return false; +#endif + + // if (ck:get_device_name() == "gfx1030") + // { + // return GridwiseGemm::CheckValidity(arg.a_grid_desc_k0_m_k1_, + // arg.b_grid_desc_k0_n_k1_, + // arg.c_grid_desc_m_n_, + // arg.M01_, + // arg.N01_); + // } + // else + // { + // return false; + // } } // polymorphic diff --git a/include/ck/utility/device_prop.hpp b/include/ck/utility/device_prop.hpp new file mode 100644 index 0000000000..65362995b2 --- /dev/null +++ b/include/ck/utility/device_prop.hpp @@ -0,0 +1,16 @@ +#pragma once + +#include + +namespace ck { + +std::string get_device_name() +{ + hipDeviceProp_t props{}; + hipGetDeviceProperties(&props, device); + const std::string name(props.gcnArchName); + + return name; +} + +} // namespace ck From f70ad268373b8130cee35bbe39c65f71e4566275 Mon Sep 17 00:00:00 2001 From: j4yan Date: Tue, 19 Apr 2022 20:59:57 -0500 Subject: [PATCH 09/46] minor changes --- ...dlops_int8_int8_int8_km_kn_mn_instance.cpp | 35 +++++++++---------- 1 file changed, 17 insertions(+), 18 deletions(-) diff --git a/library/src/tensor_operation_instance/gpu/gemm/device_gemm_dlops_int8_int8_int8_km_kn_mn_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm/device_gemm_dlops_int8_int8_int8_km_kn_mn_instance.cpp index ed602e5c28..4779270cf5 100644 --- a/library/src/tensor_operation_instance/gpu/gemm/device_gemm_dlops_int8_int8_int8_km_kn_mn_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/gemm/device_gemm_dlops_int8_int8_int8_km_kn_mn_instance.cpp @@ -21,7 +21,6 @@ static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecializa // Compilation 
parameters for a[k, m] * b[k, n] = c[m, n] using device_gemm_dlops_int8_int8_int8_km_kn_mn_instances = std::tuple< - // clang-format off // ##########| AData| BData| CData| AccData| ALayout| BLayout| CLayout| A| B| C| GEMM| Block| MPer| NPer| K0Per| K1| M1Per| N1Per| KPer| M11N11Thread| M11N11Thread| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| CThreadTransfer| CThreadTransfer| CThreadTransfer| // ##########| Type| Type| Type| Type| | | | Elementwise| Elementwise| Elementwise| Spacialization| Size| Block| Block| Block| | ThreadM111| ThreadN111| Thread| ClusterM110Xs| ClusterN110Xs| ThreadSliceLengths| ThreadClusterLengths| ThreadCluster| SrcAccess| SrcVectorTensor| SrcVectorTensor| DstVectorTensor| ThreadSliceLengths| ThreadClusterLengths| ThreadCluster| SrcAccess| SrcVectorTensor| SrcVectorTensor| DstVectorTensor| SrcDstAccess| SrcDstVectorDim| DstScalarPerVector| @@ -30,26 +29,26 @@ using device_gemm_dlops_int8_int8_int8_km_kn_mn_instances = std::tuple< /* * K1 = 1 */ - DeviceGemmDlops< int8_t, int8_t, int8_t, int32_t, Col, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 256, 128, 8, 1, 8, 4, 1, S<8, 2>, S<8, 2>, S<4, 1, 2, 1>, S<2, 1, 128, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<4, 1, 1, 1>, S<0, 3, 1, 2>, S<1, 1, 1, 1>, S<4, 1, 1, 1>, S<2, 1, 128, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<4, 1, 1, 1>, S<0, 3, 1, 2>, S<1, 1, 1, 1>, S<0, 1, 2, 3, 4, 5>, 5, 4>, - DeviceGemmDlops< int8_t, int8_t, int8_t, int32_t, Col, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 128, 256, 8, 1, 4, 8, 1, S<8, 2>, S<8, 2>, S<4, 1, 1, 1>, S<2, 1, 128, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<4, 1, 1, 1>, S<0, 3, 1, 2>, S<1, 1, 1, 1>, S<4, 1, 2, 1>, S<2, 1, 128, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<4, 1, 1, 1>, S<0, 3, 1, 2>, S<1, 1, 1, 1>, S<0, 1, 2, 3, 4, 5>, 5, 4>, - 
DeviceGemmDlops< int8_t, int8_t, int8_t, int32_t, Col, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 128, 128, 8, 1, 4, 4, 1, S<8, 2>, S<8, 2>, S<4, 1, 1, 1>, S<2, 1, 128, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<4, 1, 1, 1>, S<0, 3, 1, 2>, S<1, 1, 1, 1>, S<4, 1, 1, 1>, S<2, 1, 128, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<4, 1, 1, 1>, S<0, 3, 1, 2>, S<1, 1, 1, 1>, S<0, 1, 2, 3, 4, 5>, 5, 4>, - DeviceGemmDlops< int8_t, int8_t, int8_t, int32_t, Col, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 128, 128, 128, 8, 1, 8, 4, 1, S<4, 2>, S<8, 2>, S<8, 1, 1, 1>, S<1, 1, 128, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<4, 1, 1, 1>, S<0, 3, 1, 2>, S<1, 1, 1, 1>, S<8, 1, 1, 1>, S<1, 1, 128, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<4, 1, 1, 1>, S<0, 3, 1, 2>, S<1, 1, 1, 1>, S<0, 1, 2, 3, 4, 5>, 5, 4>, - DeviceGemmDlops< int8_t, int8_t, int8_t, int32_t, Col, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 128, 128, 128, 8, 1, 4, 8, 1, S<8, 2>, S<4, 2>, S<8, 1, 1, 1>, S<1, 1, 128, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<4, 1, 1, 1>, S<0, 3, 1, 2>, S<1, 1, 1, 1>, S<8, 1, 1, 1>, S<1, 1, 128, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<4, 1, 1, 1>, S<0, 3, 1, 2>, S<1, 1, 1, 1>, S<0, 1, 2, 3, 4, 5>, 5, 4>, - DeviceGemmDlops< int8_t, int8_t, int8_t, int32_t, Col, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 64, 128, 128, 8, 1, 8, 8, 1, S<4, 2>, S<4, 2>, S<4, 1, 4, 1>, S<2, 1, 32, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<4, 1, 1, 1>, S<0, 3, 1, 2>, S<1, 1, 1, 1>, S<4, 1, 4, 1>, S<2, 1, 32, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<4, 1, 1, 1>, S<0, 3, 1, 2>, S<1, 1, 1, 1>, S<0, 1, 2, 3, 4, 5>, 5, 4>, - DeviceGemmDlops< int8_t, int8_t, int8_t, int32_t, Col, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 64, 128, 64, 8, 1, 8, 4, 1, S<4, 2>, S<4, 2>, S<4, 1, 4, 1>, S<2, 1, 32, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<4, 1, 1, 1>, S<0, 3, 1, 2>, S<1, 1, 1, 1>, S<4, 1, 2, 1>, S<2, 1, 32, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<4, 1, 1, 1>, S<0, 3, 1, 2>, S<1, 
1, 1, 1>, S<0, 1, 2, 3, 4, 5>, 5, 4>, - DeviceGemmDlops< int8_t, int8_t, int8_t, int32_t, Col, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 64, 64, 128, 8, 1, 4, 8, 1, S<4, 2>, S<4, 2>, S<4, 1, 2, 1>, S<2, 1, 32, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<4, 1, 1, 1>, S<0, 3, 1, 2>, S<1, 1, 1, 1>, S<4, 1, 4, 1>, S<2, 1, 32, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<4, 1, 1, 1>, S<0, 3, 1, 2>, S<1, 1, 1, 1>, S<0, 1, 2, 3, 4, 5>, 5, 4>, - + // DeviceGemmDlops< int8_t, int8_t, int8_t, int32_t, Col, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 256, 128, 8, 1, 8, 4, 1, S<8, 2>, S<8, 2>, S<4, 1, 2, 1>, S<2, 1, 128, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<4, 1, 1, 1>, S<0, 3, 1, 2>, S<1, 1, 1, 1>, S<4, 1, 1, 1>, S<2, 1, 128, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<4, 1, 1, 1>, S<0, 3, 1, 2>, S<1, 1, 1, 1>, S<0, 1, 2, 3, 4, 5>, 5, 4>, + // DeviceGemmDlops< int8_t, int8_t, int8_t, int32_t, Col, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 128, 256, 8, 1, 4, 8, 1, S<8, 2>, S<8, 2>, S<4, 1, 1, 1>, S<2, 1, 128, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<4, 1, 1, 1>, S<0, 3, 1, 2>, S<1, 1, 1, 1>, S<4, 1, 2, 1>, S<2, 1, 128, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<4, 1, 1, 1>, S<0, 3, 1, 2>, S<1, 1, 1, 1>, S<0, 1, 2, 3, 4, 5>, 5, 4>, + // DeviceGemmDlops< int8_t, int8_t, int8_t, int32_t, Col, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 128, 128, 8, 1, 4, 4, 1, S<8, 2>, S<8, 2>, S<4, 1, 1, 1>, S<2, 1, 128, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<4, 1, 1, 1>, S<0, 3, 1, 2>, S<1, 1, 1, 1>, S<4, 1, 1, 1>, S<2, 1, 128, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<4, 1, 1, 1>, S<0, 3, 1, 2>, S<1, 1, 1, 1>, S<0, 1, 2, 3, 4, 5>, 5, 4>, + // DeviceGemmDlops< int8_t, int8_t, int8_t, int32_t, Col, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 128, 128, 128, 8, 1, 8, 4, 1, S<4, 2>, S<8, 2>, S<8, 1, 1, 1>, S<1, 1, 128, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<4, 1, 1, 1>, S<0, 3, 1, 2>, S<1, 1, 1, 1>, S<8, 1, 1, 1>, S<1, 1, 128, 1>, S<0, 3, 
1, 2>, S<0, 3, 1, 2>, S<4, 1, 1, 1>, S<0, 3, 1, 2>, S<1, 1, 1, 1>, S<0, 1, 2, 3, 4, 5>, 5, 4>, + // DeviceGemmDlops< int8_t, int8_t, int8_t, int32_t, Col, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 128, 128, 128, 8, 1, 4, 8, 1, S<8, 2>, S<4, 2>, S<8, 1, 1, 1>, S<1, 1, 128, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<4, 1, 1, 1>, S<0, 3, 1, 2>, S<1, 1, 1, 1>, S<8, 1, 1, 1>, S<1, 1, 128, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<4, 1, 1, 1>, S<0, 3, 1, 2>, S<1, 1, 1, 1>, S<0, 1, 2, 3, 4, 5>, 5, 4>, + // DeviceGemmDlops< int8_t, int8_t, int8_t, int32_t, Col, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 64, 128, 128, 8, 1, 8, 8, 1, S<4, 2>, S<4, 2>, S<4, 1, 4, 1>, S<2, 1, 32, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<4, 1, 1, 1>, S<0, 3, 1, 2>, S<1, 1, 1, 1>, S<4, 1, 4, 1>, S<2, 1, 32, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<4, 1, 1, 1>, S<0, 3, 1, 2>, S<1, 1, 1, 1>, S<0, 1, 2, 3, 4, 5>, 5, 4>, + // DeviceGemmDlops< int8_t, int8_t, int8_t, int32_t, Col, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 64, 128, 64, 8, 1, 8, 4, 1, S<4, 2>, S<4, 2>, S<4, 1, 4, 1>, S<2, 1, 32, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<4, 1, 1, 1>, S<0, 3, 1, 2>, S<1, 1, 1, 1>, S<4, 1, 2, 1>, S<2, 1, 32, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<4, 1, 1, 1>, S<0, 3, 1, 2>, S<1, 1, 1, 1>, S<0, 1, 2, 3, 4, 5>, 5, 4>, + // DeviceGemmDlops< int8_t, int8_t, int8_t, int32_t, Col, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 64, 64, 128, 8, 1, 4, 8, 1, S<4, 2>, S<4, 2>, S<4, 1, 2, 1>, S<2, 1, 32, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<4, 1, 1, 1>, S<0, 3, 1, 2>, S<1, 1, 1, 1>, S<4, 1, 4, 1>, S<2, 1, 32, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<4, 1, 1, 1>, S<0, 3, 1, 2>, S<1, 1, 1, 1>, S<0, 1, 2, 3, 4, 5>, 5, 4>, + // /* * K1 = 2 */ - DeviceGemmDlops< int8_t, int8_t, int8_t, int32_t, Col, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 256, 128, 8, 2, 8, 4, 1, S<8, 2>, S<8, 2>, S<4, 1, 2, 2>, S<2, 1, 128, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<4, 1, 1, 2>, 
S<0, 3, 1, 2>, S<1, 1, 1, 2>, S<4, 1, 1, 2>, S<2, 1, 128, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<4, 1, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 1, 2>, S<0, 1, 2, 3, 4, 5>, 5, 4>, - DeviceGemmDlops< int8_t, int8_t, int8_t, int32_t, Col, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 128, 256, 8, 2, 4, 8, 1, S<8, 2>, S<8, 2>, S<4, 1, 1, 2>, S<2, 1, 128, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<4, 1, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 1, 2>, S<4, 1, 2, 2>, S<2, 1, 128, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<4, 1, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 1, 2>, S<0, 1, 2, 3, 4, 5>, 5, 4>, - DeviceGemmDlops< int8_t, int8_t, int8_t, int32_t, Col, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 128, 128, 8, 2, 4, 4, 1, S<8, 2>, S<8, 2>, S<4, 1, 1, 2>, S<2, 1, 128, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<4, 1, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 1, 2>, S<4, 1, 1, 2>, S<2, 1, 128, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<4, 1, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 1, 2>, S<0, 1, 2, 3, 4, 5>, 5, 4>, - DeviceGemmDlops< int8_t, int8_t, int8_t, int32_t, Col, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 128, 128, 128, 8, 2, 8, 4, 1, S<4, 2>, S<8, 2>, S<8, 1, 1, 2>, S<1, 1, 128, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<4, 1, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 1, 2>, S<8, 1, 1, 2>, S<1, 1, 128, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<4, 1, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 1, 2>, S<0, 1, 2, 3, 4, 5>, 5, 4>, - DeviceGemmDlops< int8_t, int8_t, int8_t, int32_t, Col, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 128, 128, 128, 8, 2, 4, 8, 1, S<8, 2>, S<4, 2>, S<8, 1, 1, 2>, S<1, 1, 128, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<4, 1, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 1, 2>, S<8, 1, 1, 2>, S<1, 1, 128, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<4, 1, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 1, 2>, S<0, 1, 2, 3, 4, 5>, 5, 4>, - DeviceGemmDlops< int8_t, int8_t, int8_t, int32_t, Col, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 64, 128, 128, 8, 2, 8, 8, 1, S<4, 2>, S<4, 2>, S<4, 1, 4, 2>, S<2, 1, 32, 1>, 
S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<4, 1, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 1, 2>, S<4, 1, 4, 2>, S<2, 1, 32, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<4, 1, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 1, 2>, S<0, 1, 2, 3, 4, 5>, 5, 4>, - DeviceGemmDlops< int8_t, int8_t, int8_t, int32_t, Col, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 64, 128, 64, 8, 2, 8, 4, 1, S<4, 2>, S<4, 2>, S<4, 1, 4, 2>, S<2, 1, 32, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<4, 1, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 1, 2>, S<4, 1, 2, 2>, S<2, 1, 32, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<4, 1, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 1, 2>, S<0, 1, 2, 3, 4, 5>, 5, 4>, - DeviceGemmDlops< int8_t, int8_t, int8_t, int32_t, Col, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 64, 64, 128, 8, 2, 4, 8, 1, S<4, 2>, S<4, 2>, S<4, 1, 2, 2>, S<2, 1, 32, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<4, 1, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 1, 2>, S<4, 1, 4, 2>, S<2, 1, 32, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<4, 1, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 1, 2>, S<0, 1, 2, 3, 4, 5>, 5, 4>, + // DeviceGemmDlops< int8_t, int8_t, int8_t, int32_t, Col, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 256, 128, 8, 2, 8, 4, 1, S<8, 2>, S<8, 2>, S<4, 1, 2, 2>, S<2, 1, 128, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<4, 1, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 1, 2>, S<4, 1, 1, 2>, S<2, 1, 128, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<4, 1, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 1, 2>, S<0, 1, 2, 3, 4, 5>, 5, 4>, + // DeviceGemmDlops< int8_t, int8_t, int8_t, int32_t, Col, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 128, 256, 8, 2, 4, 8, 1, S<8, 2>, S<8, 2>, S<4, 1, 1, 2>, S<2, 1, 128, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<4, 1, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 1, 2>, S<4, 1, 2, 2>, S<2, 1, 128, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<4, 1, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 1, 2>, S<0, 1, 2, 3, 4, 5>, 5, 4>, + // DeviceGemmDlops< int8_t, int8_t, int8_t, int32_t, Col, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 128, 128, 8, 2, 4, 4, 1, S<8, 
2>, S<8, 2>, S<4, 1, 1, 2>, S<2, 1, 128, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<4, 1, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 1, 2>, S<4, 1, 1, 2>, S<2, 1, 128, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<4, 1, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 1, 2>, S<0, 1, 2, 3, 4, 5>, 5, 4>, + // DeviceGemmDlops< int8_t, int8_t, int8_t, int32_t, Col, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 128, 128, 128, 8, 2, 8, 4, 1, S<4, 2>, S<8, 2>, S<8, 1, 1, 2>, S<1, 1, 128, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<4, 1, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 1, 2>, S<8, 1, 1, 2>, S<1, 1, 128, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<4, 1, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 1, 2>, S<0, 1, 2, 3, 4, 5>, 5, 4>, + // DeviceGemmDlops< int8_t, int8_t, int8_t, int32_t, Col, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 128, 128, 128, 8, 2, 4, 8, 1, S<8, 2>, S<4, 2>, S<8, 1, 1, 2>, S<1, 1, 128, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<4, 1, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 1, 2>, S<8, 1, 1, 2>, S<1, 1, 128, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<4, 1, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 1, 2>, S<0, 1, 2, 3, 4, 5>, 5, 4>, + // DeviceGemmDlops< int8_t, int8_t, int8_t, int32_t, Col, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 64, 128, 128, 8, 2, 8, 8, 1, S<4, 2>, S<4, 2>, S<4, 1, 4, 2>, S<2, 1, 32, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<4, 1, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 1, 2>, S<4, 1, 4, 2>, S<2, 1, 32, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<4, 1, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 1, 2>, S<0, 1, 2, 3, 4, 5>, 5, 4>, + // DeviceGemmDlops< int8_t, int8_t, int8_t, int32_t, Col, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 64, 128, 64, 8, 2, 8, 4, 1, S<4, 2>, S<4, 2>, S<4, 1, 4, 2>, S<2, 1, 32, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<4, 1, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 1, 2>, S<4, 1, 2, 2>, S<2, 1, 32, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<4, 1, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 1, 2>, S<0, 1, 2, 3, 4, 5>, 5, 4>, + // DeviceGemmDlops< int8_t, int8_t, int8_t, int32_t, Col, Row, Row, PassThrough, PassThrough, 
PassThrough, GemmDefault, 64, 64, 128, 8, 2, 4, 8, 1, S<4, 2>, S<4, 2>, S<4, 1, 2, 2>, S<2, 1, 32, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<4, 1, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 1, 2>, S<4, 1, 4, 2>, S<2, 1, 32, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<4, 1, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 1, 2>, S<0, 1, 2, 3, 4, 5>, 5, 4>, /* * K1 = 4 From 6baedf3fc7da0f04137db75e04b93f20c519afd4 Mon Sep 17 00:00:00 2001 From: j4yan Date: Tue, 19 Apr 2022 21:04:57 -0500 Subject: [PATCH 10/46] minor changes --- include/ck/utility/device_prop.hpp | 2 ++ 1 file changed, 2 insertions(+) diff --git a/include/ck/utility/device_prop.hpp b/include/ck/utility/device_prop.hpp index 65362995b2..665ffb9b76 100644 --- a/include/ck/utility/device_prop.hpp +++ b/include/ck/utility/device_prop.hpp @@ -7,6 +7,8 @@ namespace ck { std::string get_device_name() { hipDeviceProp_t props{}; + int device; + hipGetDevice(&device); hipGetDeviceProperties(&props, device); const std::string name(props.gcnArchName); From 1bcb8cd537c756775b32062aaf9986759a3387ba Mon Sep 17 00:00:00 2001 From: j4yan Date: Tue, 19 Apr 2022 21:06:20 -0500 Subject: [PATCH 11/46] minor changes --- include/ck/utility/device_prop.hpp | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/include/ck/utility/device_prop.hpp b/include/ck/utility/device_prop.hpp index 665ffb9b76..b7de386f0c 100644 --- a/include/ck/utility/device_prop.hpp +++ b/include/ck/utility/device_prop.hpp @@ -9,6 +9,11 @@ std::string get_device_name() hipDeviceProp_t props{}; int device; hipGetDevice(&device); + if(status != hipSuccess) + { + return std::string(); + } + hipGetDeviceProperties(&props, device); const std::string name(props.gcnArchName); From 62a792b9977f822c151f2754eacf7d95a43eac9f Mon Sep 17 00:00:00 2001 From: j4yan Date: Tue, 19 Apr 2022 21:06:56 -0500 Subject: [PATCH 12/46] minor changes --- include/ck/utility/device_prop.hpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/include/ck/utility/device_prop.hpp 
b/include/ck/utility/device_prop.hpp index b7de386f0c..1037d9668a 100644 --- a/include/ck/utility/device_prop.hpp +++ b/include/ck/utility/device_prop.hpp @@ -8,7 +8,7 @@ std::string get_device_name() { hipDeviceProp_t props{}; int device; - hipGetDevice(&device); + auto status = hipGetDevice(&device); if(status != hipSuccess) { return std::string(); From 999321dc0051e95d1eb7a0e314d3f11b558777ea Mon Sep 17 00:00:00 2001 From: j4yan Date: Tue, 19 Apr 2022 21:07:33 -0500 Subject: [PATCH 13/46] minor changes --- include/ck/utility/device_prop.hpp | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/include/ck/utility/device_prop.hpp b/include/ck/utility/device_prop.hpp index 1037d9668a..48577c846a 100644 --- a/include/ck/utility/device_prop.hpp +++ b/include/ck/utility/device_prop.hpp @@ -15,6 +15,10 @@ std::string get_device_name() } hipGetDeviceProperties(&props, device); + if(status != hipSuccess) + { + return std::string(); + } const std::string name(props.gcnArchName); return name; From 45e9862d84c0a7bae0ff245684d4fdb4a14e39fa Mon Sep 17 00:00:00 2001 From: j4yan Date: Tue, 19 Apr 2022 21:08:07 -0500 Subject: [PATCH 14/46] minor changes --- include/ck/utility/device_prop.hpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/include/ck/utility/device_prop.hpp b/include/ck/utility/device_prop.hpp index 48577c846a..a880172246 100644 --- a/include/ck/utility/device_prop.hpp +++ b/include/ck/utility/device_prop.hpp @@ -14,7 +14,7 @@ std::string get_device_name() return std::string(); } - hipGetDeviceProperties(&props, device); + status = hipGetDeviceProperties(&props, device); if(status != hipSuccess) { return std::string(); From d3f3face9b5daa9ec8d621ec64c85f25988f6c82 Mon Sep 17 00:00:00 2001 From: j4yan Date: Wed, 20 Apr 2022 17:08:34 -0500 Subject: [PATCH 15/46] minor changes --- .../ck/tensor_operation/gpu/device/device_gemm_dlops.hpp | 1 + .../ck/tensor_operation/gpu/device/device_gemm_xdl.hpp | 9 +++++---- 2 files changed, 6 insertions(+), 4 
deletions(-) diff --git a/include/ck/tensor_operation/gpu/device/device_gemm_dlops.hpp b/include/ck/tensor_operation/gpu/device/device_gemm_dlops.hpp index 6723a1e06f..305d905772 100644 --- a/include/ck/tensor_operation/gpu/device/device_gemm_dlops.hpp +++ b/include/ck/tensor_operation/gpu/device/device_gemm_dlops.hpp @@ -472,6 +472,7 @@ struct DeviceGemmDlops static bool IsSupportedArgument(const Argument& arg) { + std::cout << ck::get_device_name() << std::endl; return GridwiseGemm::CheckValidity( arg.a_grid_desc_k0_m_k1_, arg.b_grid_desc_k0_n_k1_, arg.c_grid_desc_m_n_); } diff --git a/include/ck/tensor_operation/gpu/device/device_gemm_xdl.hpp b/include/ck/tensor_operation/gpu/device/device_gemm_xdl.hpp index 14f7441856..f9ae911964 100644 --- a/include/ck/tensor_operation/gpu/device/device_gemm_xdl.hpp +++ b/include/ck/tensor_operation/gpu/device/device_gemm_xdl.hpp @@ -1,5 +1,4 @@ -#ifndef DEVICE_GEMM_XDL_HPP -#define DEVICE_GEMM_XDL_HPP +#pragma once #include #include @@ -12,6 +11,7 @@ #include "tensor_descriptor_helper.hpp" #include "gridwise_gemm_xdlops_v2r3.hpp" #include "gemm_specialization.hpp" +#include "device_prop.hpp" namespace ck { namespace tensor_operation { @@ -414,7 +414,9 @@ struct DeviceGemmXdl return false; #endif - // if (ck:get_device_name() == "gfx1030") + std::cout << ck::get_device_name() << std::endl; + + // if (ck::get_device_name() == "gfx1030") // { // return GridwiseGemm::CheckValidity(arg.a_grid_desc_k0_m_k1_, // arg.b_grid_desc_k0_n_k1_, @@ -529,4 +531,3 @@ struct DeviceGemmXdl } // namespace device } // namespace tensor_operation } // namespace ck -#endif From e5ea6c7021559177fa25faa7bcf8a7f451cff404 Mon Sep 17 00:00:00 2001 From: j4yan Date: Thu, 21 Apr 2022 10:54:37 -0500 Subject: [PATCH 16/46] push gemm_dlops into profiler --- .../conv2d_bwd_weight_xdl.cpp | 9 +++- .../blockwise_tensor_slice_transfer_v5r1.hpp | 3 +- .../gpu/device/device_gemm_dlops.hpp | 25 +++++---- .../gpu/device/device_gemm_xdl.hpp | 10 ---- 
.../gpu/element/element_wise_operation.hpp | 10 ++-- include/ck/utility/device_prop.hpp | 2 +- include/ck/utility/dynamic_buffer.hpp | 2 +- include/ck/utility/inner_product.hpp | 20 ++++---- ...icit_gemm_v4r4r4_xdlops_nhwc_kyxc_nhwk.hpp | 2 +- .../conv_add_fwd_driver_offline_nchwc.cpp | 6 +-- .../conv_bwd_driver_offline.cpp | 6 +-- .../conv_fwd_driver_offline.cpp | 6 +-- .../conv_fwd_driver_offline_nchwc.cpp | 32 ++++++------ .../conv_maxpool_fwd_driver_offline_nchwc.cpp | 24 ++++----- .../conv_wrw_driver_offline.cpp | 8 +-- ...mm_dlops_f16_f16_f16_km_kn_mn_instance.cpp | 1 - ...mm_dlops_f16_f16_f16_km_nk_mn_instance.cpp | 3 -- ...mm_dlops_f16_f16_f16_mk_kn_mn_instance.cpp | 2 - ...mm_dlops_f16_f16_f16_mk_nk_mn_instance.cpp | 3 -- ...mm_dlops_f32_f32_f32_km_kn_mn_instance.cpp | 3 -- ...mm_dlops_f32_f32_f32_km_nk_mn_instance.cpp | 2 - ...mm_dlops_f32_f32_f32_mk_kn_mn_instance.cpp | 1 - ...mm_dlops_f32_f32_f32_mk_nk_mn_instance.cpp | 2 - ...dlops_int8_int8_int8_km_kn_mn_instance.cpp | 13 ++--- ...dlops_int8_int8_int8_km_nk_mn_instance.cpp | 16 +++--- ...dlops_int8_int8_int8_mk_kn_mn_instance.cpp | 15 +++--- ...dlops_int8_int8_int8_mk_nk_mn_instance.cpp | 16 +++--- profiler/include/profile_gemm_impl.hpp | 51 +++++++++++++++++++ script/clang-format-overwrite.sh | 2 +- test/gemm_dlops/gemm_dlops_fp16.cpp | 2 - test/gemm_dlops/gemm_dlops_fp32.cpp | 1 - test/gemm_dlops/gemm_dlops_int8.cpp | 3 -- 32 files changed, 163 insertions(+), 138 deletions(-) diff --git a/example/11_conv2d_bwd_weight/conv2d_bwd_weight_xdl.cpp b/example/11_conv2d_bwd_weight/conv2d_bwd_weight_xdl.cpp index 7b74b40d32..bf78cc87e0 100644 --- a/example/11_conv2d_bwd_weight/conv2d_bwd_weight_xdl.cpp +++ b/example/11_conv2d_bwd_weight/conv2d_bwd_weight_xdl.cpp @@ -72,8 +72,13 @@ using DeviceConvBwdWeightInstance = ck::tensor_operation::device:: 8>; // CBlockTransferScalarPerVector_NWaveNPerXdl // clang-format on -using ReferenceConvBwdWeightInstance = ck::tensor_operation::host:: - ReferenceConvBwdWeight; 
+using ReferenceConvBwdWeightInstance = + ck::tensor_operation::host::ReferenceConvBwdWeight; int main(int argc, char* argv[]) { diff --git a/include/ck/tensor_operation/gpu/block/blockwise_tensor_slice_transfer_v5r1.hpp b/include/ck/tensor_operation/gpu/block/blockwise_tensor_slice_transfer_v5r1.hpp index 2c3b4438c2..0b737153b0 100644 --- a/include/ck/tensor_operation/gpu/block/blockwise_tensor_slice_transfer_v5r1.hpp +++ b/include/ck/tensor_operation/gpu/block/blockwise_tensor_slice_transfer_v5r1.hpp @@ -87,8 +87,7 @@ struct BlockwiseTensorSliceTransfer_v5r1 } template - __device__ void - RunRead(const SrcDesc& src_desc, const SrcBuffer& src_buf) + __device__ void RunRead(const SrcDesc& src_desc, const SrcBuffer& src_buf) { if(BlockSize == thread_cluster_desc_.GetElementSize() or get_thread_local_1d_id() < thread_cluster_desc_.GetElementSize()) diff --git a/include/ck/tensor_operation/gpu/device/device_gemm_dlops.hpp b/include/ck/tensor_operation/gpu/device/device_gemm_dlops.hpp index 305d905772..339acc4de0 100644 --- a/include/ck/tensor_operation/gpu/device/device_gemm_dlops.hpp +++ b/include/ck/tensor_operation/gpu/device/device_gemm_dlops.hpp @@ -263,9 +263,9 @@ struct DeviceGemmDlops block_2_ctile_map_{}, M01_{M01}, N01_{N01}, - a_element_op_{a_element_op}, - b_element_op_{b_element_op}, - c_element_op_{c_element_op} + a_element_op_{a_element_op}, + b_element_op_{b_element_op}, + c_element_op_{c_element_op} { a_grid_desc_k0_m_k1_ = DeviceGemmDlops::MakeAGridDescriptor_K0_M_K1(M, K, StrideA); b_grid_desc_k0_n_k1_ = DeviceGemmDlops::MakeBGridDescriptor_K0_N_K1(K, N, StrideB); @@ -281,8 +281,7 @@ struct DeviceGemmDlops c_grid_desc_m0_m10_m11_n0_n10_n11_ = GridwiseGemm::MakeCGridDescriptor_M0_M10_M11_N0_N10_N11(c_grid_desc_m_n_); - block_2_ctile_map_ = - GridwiseGemm::MakeDefaultBlock2CTileMap(c_grid_desc_m_n_); + block_2_ctile_map_ = GridwiseGemm::MakeDefaultBlock2CTileMap(c_grid_desc_m_n_); } } @@ -340,7 +339,8 @@ struct DeviceGemmDlops "wrong! 
GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3 has invalid setting"); } - const index_t grid_size = GridwiseGemm::CalculateGridSize(arg.c_grid_desc_m_n_.GetLength(I0), arg.c_grid_desc_m_n_.GetLength(I1)); + const index_t grid_size = GridwiseGemm::CalculateGridSize( + arg.c_grid_desc_m_n_.GetLength(I0), arg.c_grid_desc_m_n_.GetLength(I1)); const auto K0 = arg.a_grid_desc_k0_m0_m1_k1_.GetLength(I0); const bool has_main_k_block_loop = GridwiseGemm::CalculateHasMainKBlockLoop(K0); @@ -472,9 +472,15 @@ struct DeviceGemmDlops static bool IsSupportedArgument(const Argument& arg) { - std::cout << ck::get_device_name() << std::endl; - return GridwiseGemm::CheckValidity( - arg.a_grid_desc_k0_m_k1_, arg.b_grid_desc_k0_n_k1_, arg.c_grid_desc_m_n_); + if(ck::get_device_name() == "gfx1030") + { + return GridwiseGemm::CheckValidity( + arg.a_grid_desc_k0_m_k1_, arg.b_grid_desc_k0_n_k1_, arg.c_grid_desc_m_n_); + } + else + { + return false; + } } // polymorphic @@ -577,4 +583,3 @@ struct DeviceGemmDlops } // namespace device } // namespace tensor_operation } // namespace ck - diff --git a/include/ck/tensor_operation/gpu/device/device_gemm_xdl.hpp b/include/ck/tensor_operation/gpu/device/device_gemm_xdl.hpp index f9ae911964..71d476cb93 100644 --- a/include/ck/tensor_operation/gpu/device/device_gemm_xdl.hpp +++ b/include/ck/tensor_operation/gpu/device/device_gemm_xdl.hpp @@ -404,16 +404,6 @@ struct DeviceGemmXdl static bool IsSupportedArgument(const Argument& arg) { -#ifdef __gfx1030__ - return GridwiseGemm::CheckValidity(arg.a_grid_desc_k0_m_k1_, - arg.b_grid_desc_k0_n_k1_, - arg.c_grid_desc_m_n_, - arg.M01_, - arg.N01_); -#else - return false; -#endif - std::cout << ck::get_device_name() << std::endl; // if (ck::get_device_name() == "gfx1030") diff --git a/include/ck/tensor_operation/gpu/element/element_wise_operation.hpp b/include/ck/tensor_operation/gpu/element/element_wise_operation.hpp index 5b3606e859..7fd0b7a36f 100644 --- 
a/include/ck/tensor_operation/gpu/element/element_wise_operation.hpp +++ b/include/ck/tensor_operation/gpu/element/element_wise_operation.hpp @@ -156,8 +156,9 @@ struct RequantReluRequant float gemm_requant = scaleGemm_ * static_cast(x); float relu = gemm_requant > 0 ? gemm_requant : 0; float relu_requant = scaleRelu_ * relu; - y = static_cast(relu_requant > 127 ? 127 - : relu_requant < -128 ? -128 : relu_requant); + y = static_cast(relu_requant > 127 ? 127 + : relu_requant < -128 ? -128 + : relu_requant); } // for reference_gemm @@ -166,8 +167,9 @@ struct RequantReluRequant float gemm_requant = scaleGemm_ * x; float relu = gemm_requant > 0 ? gemm_requant : 0; float relu_requant = scaleRelu_ * relu; - y = static_cast(relu_requant > 127 ? 127 - : relu_requant < -128 ? -128 : relu_requant); + y = static_cast(relu_requant > 127 ? 127 + : relu_requant < -128 ? -128 + : relu_requant); } float scaleGemm_; diff --git a/include/ck/utility/device_prop.hpp b/include/ck/utility/device_prop.hpp index a880172246..5f13d6cb22 100644 --- a/include/ck/utility/device_prop.hpp +++ b/include/ck/utility/device_prop.hpp @@ -4,7 +4,7 @@ namespace ck { -std::string get_device_name() +inline std::string get_device_name() { hipDeviceProp_t props{}; int device; diff --git a/include/ck/utility/dynamic_buffer.hpp b/include/ck/utility/dynamic_buffer.hpp index c00982dfff..1f52855293 100644 --- a/include/ck/utility/dynamic_buffer.hpp +++ b/include/ck/utility/dynamic_buffer.hpp @@ -151,7 +151,7 @@ struct DynamicBuffer #if CK_USE_AMD_BUFFER_STORE bool constexpr use_amd_buffer_addressing = true; #else - bool constexpr use_amd_buffer_addressing = false; + bool constexpr use_amd_buffer_addressing = false; #endif #if CK_WORKAROUND_SWDEV_XXXXXX_INT8_DS_WRITE_ISSUE diff --git a/include/ck/utility/inner_product.hpp b/include/ck/utility/inner_product.hpp index 03ec2fdc47..d84879ff8f 100644 --- a/include/ck/utility/inner_product.hpp +++ b/include/ck/utility/inner_product.hpp @@ -150,17 +150,17 @@ template 
<> __device__ void inner_product(const int8x2_t& a, const int8x2_t& b, int32_t& c) { -// #if defined(CK_USE_DOT2_I32_I8) -// #if CK_USE_AMD_INNER_PRODUCT_INLINE_ASM -// asm volatile("\n \ + // #if defined(CK_USE_DOT2_I32_I8) + // #if CK_USE_AMD_INNER_PRODUCT_INLINE_ASM + // asm volatile("\n \ // v_dot2_i32_i8 %0, %1, %2, %0\n \ // " -// : "=v"(c) -// : "v"(bit_cast(a)), "v"(bit_cast(b)), "0"(c)); -// #else -// c = __builtin_amdgcn_sdot2(bit_cast(a), bit_cast(b), c, false); -// #endif -// #else + // : "=v"(c) + // : "v"(bit_cast(a)), "v"(bit_cast(b)), "0"(c)); + // #else + // c = __builtin_amdgcn_sdot2(bit_cast(a), bit_cast(b), c, false); + // #endif + // #else const vector_type a_vector{a}; const vector_type b_vector{b}; @@ -168,7 +168,7 @@ inner_product(const int8x2_t& a, const int8x2_t& b, c += type_convert(a_vector.AsType()[i]) * type_convert(b_vector.AsType()[i]); }); -// #endif + // #endif } template <> __device__ void diff --git a/library/include/ck/library/obselete_driver_offline/device_convolution_forward_implicit_gemm_v4r4r4_xdlops_nhwc_kyxc_nhwk.hpp b/library/include/ck/library/obselete_driver_offline/device_convolution_forward_implicit_gemm_v4r4r4_xdlops_nhwc_kyxc_nhwk.hpp index 18e712fb47..de3489f924 100644 --- a/library/include/ck/library/obselete_driver_offline/device_convolution_forward_implicit_gemm_v4r4r4_xdlops_nhwc_kyxc_nhwk.hpp +++ b/library/include/ck/library/obselete_driver_offline/device_convolution_forward_implicit_gemm_v4r4r4_xdlops_nhwc_kyxc_nhwk.hpp @@ -424,7 +424,7 @@ void device_convolution_forward_implicit_gemm_v4r4r4_xdlops_nhwc_kyxc_nhwk( constexpr auto in_gemmk0_gemmm_gemmk1_grid_move_slice_window_step_hacks = Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 2, 0, 0>{}; #else - const auto in_gemmk0_gemmmraw_gemmk1_grid_desc = descs[I0]; + const auto in_gemmk0_gemmmraw_gemmk1_grid_desc = descs[I0]; const auto GemmK0 = in_gemmk0_gemmmraw_gemmk1_grid_desc.GetLength(I0); const auto GemmMRaw = in_gemmk0_gemmmraw_gemmk1_grid_desc.GetLength(I1); 
diff --git a/library/src/obselete_driver_offline/conv_add_fwd_driver_offline_nchwc.cpp b/library/src/obselete_driver_offline/conv_add_fwd_driver_offline_nchwc.cpp index a7541f03de..82d92fa64d 100644 --- a/library/src/obselete_driver_offline/conv_add_fwd_driver_offline_nchwc.cpp +++ b/library/src/obselete_driver_offline/conv_add_fwd_driver_offline_nchwc.cpp @@ -248,9 +248,9 @@ int main(int argc, char* argv[]) using acc_data_t = float; using out_data_t = float; #elif 1 - using in_data_t = half_t; - using acc_data_t = float; - using out_data_t = half_t; + using in_data_t = half_t; + using acc_data_t = float; + using out_data_t = half_t; #elif 1 using in_data_t = int8_t; using acc_data_t = int32_t; diff --git a/library/src/obselete_driver_offline/conv_bwd_driver_offline.cpp b/library/src/obselete_driver_offline/conv_bwd_driver_offline.cpp index c4dcb7c085..c130cd609c 100644 --- a/library/src/obselete_driver_offline/conv_bwd_driver_offline.cpp +++ b/library/src/obselete_driver_offline/conv_bwd_driver_offline.cpp @@ -263,9 +263,9 @@ int main(int argc, char* argv[]) using acc_data_t = float; using out_data_t = float; #elif 1 - using in_data_t = half_t; - using acc_data_t = float; - using out_data_t = half_t; + using in_data_t = half_t; + using acc_data_t = float; + using out_data_t = half_t; #endif std::vector in_lengths_host(4), wei_lengths_host(4), out_lengths_host(4); diff --git a/library/src/obselete_driver_offline/conv_fwd_driver_offline.cpp b/library/src/obselete_driver_offline/conv_fwd_driver_offline.cpp index ab8beec87b..94c5fd9ca9 100644 --- a/library/src/obselete_driver_offline/conv_fwd_driver_offline.cpp +++ b/library/src/obselete_driver_offline/conv_fwd_driver_offline.cpp @@ -257,9 +257,9 @@ int main(int argc, char* argv[]) using acc_data_t = float; using out_data_t = float; #elif 1 - using in_data_t = half_t; - using acc_data_t = float; - using out_data_t = half_t; + using in_data_t = half_t; + using acc_data_t = float; + using out_data_t = half_t; #elif 0 
using in_data_t = bhalf_t; using acc_data_t = float; diff --git a/library/src/obselete_driver_offline/conv_fwd_driver_offline_nchwc.cpp b/library/src/obselete_driver_offline/conv_fwd_driver_offline_nchwc.cpp index 6fb8b4c2aa..ff7e5c7a15 100644 --- a/library/src/obselete_driver_offline/conv_fwd_driver_offline_nchwc.cpp +++ b/library/src/obselete_driver_offline/conv_fwd_driver_offline_nchwc.cpp @@ -165,15 +165,15 @@ int main(int argc, char* argv[]) constexpr auto K0 = Number<1>{}; constexpr auto K1 = Number<4>{}; #elif 1 - constexpr auto N = Number<1>{}; - constexpr auto Hi = Number<1080>{}; - constexpr auto Wi = Number<1920>{}; - constexpr auto Y = Number<3>{}; - constexpr auto X = Number<3>{}; - constexpr auto C0 = Number<2>{}; - constexpr auto C1 = Number<8>{}; - constexpr auto K0 = Number<2>{}; - constexpr auto K1 = Number<8>{}; + constexpr auto N = Number<1>{}; + constexpr auto Hi = Number<1080>{}; + constexpr auto Wi = Number<1920>{}; + constexpr auto Y = Number<3>{}; + constexpr auto X = Number<3>{}; + constexpr auto C0 = Number<2>{}; + constexpr auto C1 = Number<8>{}; + constexpr auto K0 = Number<2>{}; + constexpr auto K1 = Number<8>{}; #elif 0 constexpr auto N = Number<1>{}; constexpr auto Hi = Number<1080>{}; @@ -212,10 +212,10 @@ int main(int argc, char* argv[]) constexpr auto conv_dilation_w = I1; #if 1 - constexpr auto in_left_pad_h = I1; - constexpr auto in_left_pad_w = I1; - constexpr auto in_right_pad_h = I1; - constexpr auto in_right_pad_w = I1; + constexpr auto in_left_pad_h = I1; + constexpr auto in_left_pad_w = I1; + constexpr auto in_right_pad_h = I1; + constexpr auto in_right_pad_w = I1; #else constexpr auto in_left_pad_h = I0; constexpr auto in_left_pad_w = I0; @@ -235,9 +235,9 @@ int main(int argc, char* argv[]) using acc_data_t = float; using out_data_t = float; #elif 1 - using in_data_t = half_t; - using acc_data_t = float; - using out_data_t = half_t; + using in_data_t = half_t; + using acc_data_t = float; + using out_data_t = half_t; #elif 
1 using in_data_t = int8_t; using acc_data_t = int32_t; diff --git a/library/src/obselete_driver_offline/conv_maxpool_fwd_driver_offline_nchwc.cpp b/library/src/obselete_driver_offline/conv_maxpool_fwd_driver_offline_nchwc.cpp index fb7e8e975b..388656e747 100644 --- a/library/src/obselete_driver_offline/conv_maxpool_fwd_driver_offline_nchwc.cpp +++ b/library/src/obselete_driver_offline/conv_maxpool_fwd_driver_offline_nchwc.cpp @@ -181,15 +181,15 @@ int main(int argc, char* argv[]) constexpr ck::ActivTypeEnum activ_type = ActivTypeEnum::LeakyRelu; #if 1 - constexpr auto N = Number<1>{}; - constexpr auto Hi = Number<1080>{}; - constexpr auto Wi = Number<1920>{}; - constexpr auto Y = Number<3>{}; - constexpr auto X = Number<3>{}; - constexpr auto C0 = Number<2>{}; - constexpr auto C1 = Number<8>{}; - constexpr auto K0 = Number<2>{}; - constexpr auto K1 = Number<8>{}; + constexpr auto N = Number<1>{}; + constexpr auto Hi = Number<1080>{}; + constexpr auto Wi = Number<1920>{}; + constexpr auto Y = Number<3>{}; + constexpr auto X = Number<3>{}; + constexpr auto C0 = Number<2>{}; + constexpr auto C1 = Number<8>{}; + constexpr auto K0 = Number<2>{}; + constexpr auto K1 = Number<8>{}; #elif 0 constexpr auto N = Number<1>{}; constexpr auto Hi = Number<1080>{}; @@ -247,9 +247,9 @@ int main(int argc, char* argv[]) using acc_data_t = float; using out_data_t = float; #elif 1 - using in_data_t = half_t; - using acc_data_t = float; - using out_data_t = half_t; + using in_data_t = half_t; + using acc_data_t = float; + using out_data_t = half_t; #elif 1 using in_data_t = int8_t; using acc_data_t = int32_t; diff --git a/library/src/obselete_driver_offline/conv_wrw_driver_offline.cpp b/library/src/obselete_driver_offline/conv_wrw_driver_offline.cpp index 1ac974202c..23b1039fec 100644 --- a/library/src/obselete_driver_offline/conv_wrw_driver_offline.cpp +++ b/library/src/obselete_driver_offline/conv_wrw_driver_offline.cpp @@ -229,10 +229,10 @@ int main(int argc, char* argv[]) using 
acc_data_t = float; using out_data_t = float; #elif 1 - using in_data_t = half_t; - using out_data_t = half_t; - using acc_data_t = float; - using wei_data_t = float; + using in_data_t = half_t; + using out_data_t = half_t; + using acc_data_t = float; + using wei_data_t = float; #elif 1 using in_data_t = int8_t; using out_data_t = int8_t; diff --git a/library/src/tensor_operation_instance/gpu/gemm/device_gemm_dlops_f16_f16_f16_km_kn_mn_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm/device_gemm_dlops_f16_f16_f16_km_kn_mn_instance.cpp index 9d01e92597..c060010a1b 100644 --- a/library/src/tensor_operation_instance/gpu/gemm/device_gemm_dlops_f16_f16_f16_km_kn_mn_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/gemm/device_gemm_dlops_f16_f16_f16_km_kn_mn_instance.cpp @@ -66,4 +66,3 @@ void add_device_gemm_dlops_f16_f16_f16_km_kn_mn_instances( } // namespace device } // namespace tensor_operation } // namespace ck - diff --git a/library/src/tensor_operation_instance/gpu/gemm/device_gemm_dlops_f16_f16_f16_km_nk_mn_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm/device_gemm_dlops_f16_f16_f16_km_nk_mn_instance.cpp index 24cf237efe..0963b73f3d 100644 --- a/library/src/tensor_operation_instance/gpu/gemm/device_gemm_dlops_f16_f16_f16_km_nk_mn_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/gemm/device_gemm_dlops_f16_f16_f16_km_nk_mn_instance.cpp @@ -65,6 +65,3 @@ void add_device_gemm_dlops_f16_f16_f16_km_nk_mn_instances( } // namespace device } // namespace tensor_operation } // namespace ck - - - diff --git a/library/src/tensor_operation_instance/gpu/gemm/device_gemm_dlops_f16_f16_f16_mk_kn_mn_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm/device_gemm_dlops_f16_f16_f16_mk_kn_mn_instance.cpp index 1fdfa22033..5d36ac4182 100644 --- a/library/src/tensor_operation_instance/gpu/gemm/device_gemm_dlops_f16_f16_f16_mk_kn_mn_instance.cpp +++ 
b/library/src/tensor_operation_instance/gpu/gemm/device_gemm_dlops_f16_f16_f16_mk_kn_mn_instance.cpp @@ -65,5 +65,3 @@ void add_device_gemm_dlops_f16_f16_f16_mk_kn_mn_instances( } // namespace device } // namespace tensor_operation } // namespace ck - - diff --git a/library/src/tensor_operation_instance/gpu/gemm/device_gemm_dlops_f16_f16_f16_mk_nk_mn_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm/device_gemm_dlops_f16_f16_f16_mk_nk_mn_instance.cpp index 5ca4bd7e35..3da69e7014 100644 --- a/library/src/tensor_operation_instance/gpu/gemm/device_gemm_dlops_f16_f16_f16_mk_nk_mn_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/gemm/device_gemm_dlops_f16_f16_f16_mk_nk_mn_instance.cpp @@ -65,6 +65,3 @@ void add_device_gemm_dlops_f16_f16_f16_mk_nk_mn_instances( } // namespace device } // namespace tensor_operation } // namespace ck - - - diff --git a/library/src/tensor_operation_instance/gpu/gemm/device_gemm_dlops_f32_f32_f32_km_kn_mn_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm/device_gemm_dlops_f32_f32_f32_km_kn_mn_instance.cpp index 75b3d75670..365db66cc4 100644 --- a/library/src/tensor_operation_instance/gpu/gemm/device_gemm_dlops_f32_f32_f32_km_kn_mn_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/gemm/device_gemm_dlops_f32_f32_f32_km_kn_mn_instance.cpp @@ -61,9 +61,6 @@ using device_gemm_dlops_f32_f32_f32_km_kn_mn_instances = std::tuple< // DeviceGemmDlops< F32, F32, F32, F32, Col, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 64, 128, 64, 8, 4, 8, 4, 1, S<4, 2>, S<4, 2>, S<4, 1, 4, 4>, S<2, 1, 32, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<4, 1, 1, 4>, S<0, 3, 1, 2>, S<1, 1, 1, 4>, S<4, 1, 2, 2>, S<2, 1, 32, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<4, 1, 1, 4>, S<0, 3, 1, 2>, S<1, 1, 1, 2>, S<0, 1, 2, 3, 4, 5>, 5, 4>, // DeviceGemmDlops< F32, F32, F32, F32, Col, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 64, 64, 128, 8, 4, 4, 8, 1, S<4, 2>, S<4, 2>, S<4, 1, 2, 4>, S<2, 1, 32, 1>, S<0, 3, 
1, 2>, S<0, 3, 1, 2>, S<4, 1, 1, 4>, S<0, 3, 1, 2>, S<1, 1, 1, 4>, S<4, 1, 4, 2>, S<2, 1, 32, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<4, 1, 1, 4>, S<0, 3, 1, 2>, S<1, 1, 1, 2>, S<0, 1, 2, 3, 4, 5>, 5, 4> - - - // clang-format on >; diff --git a/library/src/tensor_operation_instance/gpu/gemm/device_gemm_dlops_f32_f32_f32_km_nk_mn_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm/device_gemm_dlops_f32_f32_f32_km_nk_mn_instance.cpp index 325424ab81..919d9d0d13 100644 --- a/library/src/tensor_operation_instance/gpu/gemm/device_gemm_dlops_f32_f32_f32_km_nk_mn_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/gemm/device_gemm_dlops_f32_f32_f32_km_nk_mn_instance.cpp @@ -65,5 +65,3 @@ void add_device_gemm_dlops_f32_f32_f32_km_nk_mn_instances( } // namespace device } // namespace tensor_operation } // namespace ck - - diff --git a/library/src/tensor_operation_instance/gpu/gemm/device_gemm_dlops_f32_f32_f32_mk_kn_mn_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm/device_gemm_dlops_f32_f32_f32_mk_kn_mn_instance.cpp index f33a384a61..30ec69692c 100644 --- a/library/src/tensor_operation_instance/gpu/gemm/device_gemm_dlops_f32_f32_f32_mk_kn_mn_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/gemm/device_gemm_dlops_f32_f32_f32_mk_kn_mn_instance.cpp @@ -65,4 +65,3 @@ void add_device_gemm_dlops_f32_f32_f32_mk_kn_mn_instances( } // namespace device } // namespace tensor_operation } // namespace ck - diff --git a/library/src/tensor_operation_instance/gpu/gemm/device_gemm_dlops_f32_f32_f32_mk_nk_mn_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm/device_gemm_dlops_f32_f32_f32_mk_nk_mn_instance.cpp index cd3069d493..9a6a9ac5ea 100644 --- a/library/src/tensor_operation_instance/gpu/gemm/device_gemm_dlops_f32_f32_f32_mk_nk_mn_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/gemm/device_gemm_dlops_f32_f32_f32_mk_nk_mn_instance.cpp @@ -65,5 +65,3 @@ void add_device_gemm_dlops_f32_f32_f32_mk_nk_mn_instances( } // 
namespace device } // namespace tensor_operation } // namespace ck - - diff --git a/library/src/tensor_operation_instance/gpu/gemm/device_gemm_dlops_int8_int8_int8_km_kn_mn_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm/device_gemm_dlops_int8_int8_int8_km_kn_mn_instance.cpp index 4779270cf5..2920fc4ba2 100644 --- a/library/src/tensor_operation_instance/gpu/gemm/device_gemm_dlops_int8_int8_int8_km_kn_mn_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/gemm/device_gemm_dlops_int8_int8_int8_km_kn_mn_instance.cpp @@ -20,8 +20,9 @@ using PassThrough = ck::tensor_operation::element_wise::PassThrough; static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default; // Compilation parameters for a[k, m] * b[k, n] = c[m, n] -using device_gemm_dlops_int8_int8_int8_km_kn_mn_instances = std::tuple< - // clang-format off +using device_gemm_dlops_int8_int8_int8_km_kn_mn_instances = + std::tuple< + // clang-format off // ##########| AData| BData| CData| AccData| ALayout| BLayout| CLayout| A| B| C| GEMM| Block| MPer| NPer| K0Per| K1| M1Per| N1Per| KPer| M11N11Thread| M11N11Thread| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| CThreadTransfer| CThreadTransfer| CThreadTransfer| // ##########| Type| Type| Type| Type| | | | Elementwise| Elementwise| Elementwise| Spacialization| Size| Block| Block| Block| | ThreadM111| ThreadN111| Thread| ClusterM110Xs| ClusterN110Xs| ThreadSliceLengths| ThreadClusterLengths| ThreadCluster| SrcAccess| SrcVectorTensor| SrcVectorTensor| DstVectorTensor| ThreadSliceLengths| ThreadClusterLengths| ThreadCluster| SrcAccess| SrcVectorTensor| SrcVectorTensor| DstVectorTensor| SrcDstAccess| SrcDstVectorDim| DstScalarPerVector| // ##########| | | | | | | | Operation| Operation| Operation| | | | | | | | | | | | K0_M0_M1_K1| 
K0_M0_M1_K1| ArrangeOrder| Order| Lengths_K0_M0_M1_K1| ContiguousDimOrder| Lengths_K0_M0_M1_K1| K0_N0_N1_K1| K0_N0_N1_K1| ArrangeOrder| Order| Lengths_K0_N0_N1_K1| ContiguousDimOrder| Lengths_K0_N0_N1_K1| Order| | | @@ -61,17 +62,17 @@ using device_gemm_dlops_int8_int8_int8_km_kn_mn_instances = std::tuple< DeviceGemmDlops< int8_t, int8_t, int8_t, int32_t, Col, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 64, 128, 128, 8, 4, 8, 8, 1, S<4, 2>, S<4, 2>, S<4, 1, 4, 4>, S<2, 1, 32, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<4, 1, 1, 4>, S<0, 3, 1, 2>, S<1, 1, 1, 4>, S<4, 1, 4, 4>, S<2, 1, 32, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<4, 1, 1, 4>, S<0, 3, 1, 2>, S<1, 1, 1, 4>, S<0, 1, 2, 3, 4, 5>, 5, 4>, DeviceGemmDlops< int8_t, int8_t, int8_t, int32_t, Col, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 64, 128, 64, 8, 4, 8, 4, 1, S<4, 2>, S<4, 2>, S<4, 1, 4, 4>, S<2, 1, 32, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<4, 1, 1, 4>, S<0, 3, 1, 2>, S<1, 1, 1, 4>, S<4, 1, 2, 4>, S<2, 1, 32, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<4, 1, 1, 4>, S<0, 3, 1, 2>, S<1, 1, 1, 4>, S<0, 1, 2, 3, 4, 5>, 5, 4>, DeviceGemmDlops< int8_t, int8_t, int8_t, int32_t, Col, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 64, 64, 128, 8, 4, 4, 8, 1, S<4, 2>, S<4, 2>, S<4, 1, 2, 4>, S<2, 1, 32, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<4, 1, 1, 4>, S<0, 3, 1, 2>, S<1, 1, 1, 4>, S<4, 1, 4, 4>, S<2, 1, 32, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<4, 1, 1, 4>, S<0, 3, 1, 2>, S<1, 1, 1, 4>, S<0, 1, 2, 3, 4, 5>, 5, 4> - // clang-format on - >; + // clang-format on + >; void add_device_gemm_dlops_int8_int8_int8_km_kn_mn_instances( std::vector>& instances) { - add_device_operation_instances(instances, device_gemm_dlops_int8_int8_int8_km_kn_mn_instances{}); + add_device_operation_instances(instances, + device_gemm_dlops_int8_int8_int8_km_kn_mn_instances{}); } } // namespace device_gemm_instance } // namespace device } // namespace tensor_operation } // namespace ck - diff --git 
a/library/src/tensor_operation_instance/gpu/gemm/device_gemm_dlops_int8_int8_int8_km_nk_mn_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm/device_gemm_dlops_int8_int8_int8_km_nk_mn_instance.cpp index eb43247ec4..a794e06484 100644 --- a/library/src/tensor_operation_instance/gpu/gemm/device_gemm_dlops_int8_int8_int8_km_nk_mn_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/gemm/device_gemm_dlops_int8_int8_int8_km_nk_mn_instance.cpp @@ -20,8 +20,9 @@ using PassThrough = ck::tensor_operation::element_wise::PassThrough; static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default; // Compilation parameters for a[k, m] * b[n, k] = c[m, n] -using device_gemm_dlops_int8_int8_int8_km_nk_mn_instances = std::tuple< - // clang-format off +using device_gemm_dlops_int8_int8_int8_km_nk_mn_instances = + std::tuple< + // clang-format off // ##########| AData| BData| CData| AccData| ALayout| BLayout| CLayout| A| B| C| GEMM| Block| MPer| NPer| K0Per| K1| M1Per| N1Per| KPer| M11N11Thread| M11N11Thread| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| CThreadTransfer| CThreadTransfer| CThreadTransfer| // ##########| Type| Type| Type| Type| | | | Elementwise| Elementwise| Elementwise| Spacialization| Size| Block| Block| Block| | ThreadM111| ThreadN111| Thread| ClusterM110Xs| ClusterN110Xs| ThreadSliceLengths| ThreadClusterLengths| ThreadCluster| SrcAccess| SrcVectorTensor| SrcVectorTensor| DstVectorTensor| ThreadSliceLengths| ThreadClusterLengths| ThreadCluster| SrcAccess| SrcVectorTensor| SrcVectorTensor| DstVectorTensor| SrcDstAccess| SrcDstVectorDim| DstScalarPerVector| // ##########| | | | | | | | Operation| Operation| Operation| | | | | | | | | | | | K0_M0_M1_K1| K0_M0_M1_K1| ArrangeOrder| Order| Lengths_K0_M0_M1_K1| ContiguousDimOrder| 
Lengths_K0_M0_M1_K1| K0_N0_N1_K1| K0_N0_N1_K1| ArrangeOrder| Order| Lengths_K0_N0_N1_K1| ContiguousDimOrder| Lengths_K0_N0_N1_K1| Order| | | @@ -61,20 +62,17 @@ using device_gemm_dlops_int8_int8_int8_km_nk_mn_instances = std::tuple< DeviceGemmDlops< int8_t, int8_t, int8_t, int32_t, Col, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 64, 128, 128, 8, 4, 8, 8, 1, S<4, 2>, S<4, 2>, S<4, 1, 4, 4>, S<2, 1, 32, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<4, 1, 1, 4>, S<0, 3, 1, 2>, S<1, 1, 1, 4>, S<4, 1, 4, 4>, S<2, 1, 32, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 4>, S<1, 2, 0, 3>, S<1, 1, 1, 4>, S<0, 1, 2, 3, 4, 5>, 5, 4>, DeviceGemmDlops< int8_t, int8_t, int8_t, int32_t, Col, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 64, 128, 64, 8, 4, 8, 4, 1, S<4, 2>, S<4, 2>, S<4, 1, 4, 4>, S<2, 1, 32, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<4, 1, 1, 4>, S<0, 3, 1, 2>, S<1, 1, 1, 4>, S<4, 1, 2, 4>, S<2, 1, 32, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 4>, S<1, 2, 0, 3>, S<1, 1, 1, 4>, S<0, 1, 2, 3, 4, 5>, 5, 4>, DeviceGemmDlops< int8_t, int8_t, int8_t, int32_t, Col, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 64, 64, 128, 8, 4, 4, 8, 1, S<4, 2>, S<4, 2>, S<4, 1, 2, 4>, S<2, 1, 32, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<4, 1, 1, 4>, S<0, 3, 1, 2>, S<1, 1, 1, 4>, S<4, 1, 4, 4>, S<2, 1, 32, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 4>, S<1, 2, 0, 3>, S<1, 1, 1, 4>, S<0, 1, 2, 3, 4, 5>, 5, 4> - // clang-format on - >; + // clang-format on + >; void add_device_gemm_dlops_int8_int8_int8_km_nk_mn_instances( std::vector>& instances) { - add_device_operation_instances(instances, device_gemm_dlops_int8_int8_int8_km_nk_mn_instances{}); + add_device_operation_instances(instances, + device_gemm_dlops_int8_int8_int8_km_nk_mn_instances{}); } } // namespace device_gemm_instance } // namespace device } // namespace tensor_operation } // namespace ck - - - - diff --git 
a/library/src/tensor_operation_instance/gpu/gemm/device_gemm_dlops_int8_int8_int8_mk_kn_mn_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm/device_gemm_dlops_int8_int8_int8_mk_kn_mn_instance.cpp index fb7fcc557a..36d35b2ce8 100644 --- a/library/src/tensor_operation_instance/gpu/gemm/device_gemm_dlops_int8_int8_int8_mk_kn_mn_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/gemm/device_gemm_dlops_int8_int8_int8_mk_kn_mn_instance.cpp @@ -20,8 +20,9 @@ using PassThrough = ck::tensor_operation::element_wise::PassThrough; static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default; // Compilation parameters for a[m, k] * b[k, n] = c[m, n] -using device_gemm_dlops_int8_int8_int8_mk_kn_mn_instances = std::tuple< - // clang-format off +using device_gemm_dlops_int8_int8_int8_mk_kn_mn_instances = + std::tuple< + // clang-format off // ##########| AData| BData| CData| AccData| ALayout| BLayout| CLayout| A| B| C| GEMM| Block| MPer| NPer| K0Per| K1| M1Per| N1Per| KPer| M11N11Thread| M11N11Thread| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| CThreadTransfer| CThreadTransfer| CThreadTransfer| // ##########| Type| Type| Type| Type| | | | Elementwise| Elementwise| Elementwise| Spacialization| Size| Block| Block| Block| | ThreadM111| ThreadN111| Thread| ClusterM110Xs| ClusterN110Xs| ThreadSliceLengths| ThreadClusterLengths| ThreadCluster| SrcAccess| SrcVectorTensor| SrcVectorTensor| DstVectorTensor| ThreadSliceLengths| ThreadClusterLengths| ThreadCluster| SrcAccess| SrcVectorTensor| SrcVectorTensor| DstVectorTensor| SrcDstAccess| SrcDstVectorDim| DstScalarPerVector| // ##########| | | | | | | | Operation| Operation| Operation| | | | | | | | | | | | K0_M0_M1_K1| K0_M0_M1_K1| ArrangeOrder| Order| Lengths_K0_M0_M1_K1| ContiguousDimOrder| 
Lengths_K0_M0_M1_K1| K0_N0_N1_K1| K0_N0_N1_K1| ArrangeOrder| Order| Lengths_K0_N0_N1_K1| ContiguousDimOrder| Lengths_K0_N0_N1_K1| Order| | | @@ -61,19 +62,17 @@ using device_gemm_dlops_int8_int8_int8_mk_kn_mn_instances = std::tuple< DeviceGemmDlops< int8_t, int8_t, int8_t, int32_t, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 64, 128, 128, 8, 4, 8, 8, 1, S<4, 2>, S<4, 2>, S<4, 1, 4, 4>, S<2, 1, 32, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 4>, S<1, 2, 0, 3>, S<1, 1, 1, 4>, S<4, 1, 4, 4>, S<2, 1, 32, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<4, 1, 1, 4>, S<0, 3, 1, 2>, S<1, 1, 1, 4>, S<0, 1, 2, 3, 4, 5>, 5, 4>, DeviceGemmDlops< int8_t, int8_t, int8_t, int32_t, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 64, 128, 64, 8, 4, 8, 4, 1, S<4, 2>, S<4, 2>, S<4, 1, 4, 4>, S<2, 1, 32, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 4>, S<1, 2, 0, 3>, S<1, 1, 1, 4>, S<4, 1, 2, 4>, S<2, 1, 32, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<4, 1, 1, 4>, S<0, 3, 1, 2>, S<1, 1, 1, 4>, S<0, 1, 2, 3, 4, 5>, 5, 4>, DeviceGemmDlops< int8_t, int8_t, int8_t, int32_t, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 64, 64, 128, 8, 4, 4, 8, 1, S<4, 2>, S<4, 2>, S<4, 1, 2, 4>, S<2, 1, 32, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 4>, S<1, 2, 0, 3>, S<1, 1, 1, 4>, S<4, 1, 4, 4>, S<2, 1, 32, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<4, 1, 1, 4>, S<0, 3, 1, 2>, S<1, 1, 1, 4>, S<0, 1, 2, 3, 4, 5>, 5, 4> - // clang-format on - >; + // clang-format on + >; void add_device_gemm_dlops_int8_int8_int8_mk_kn_mn_instances( std::vector>& instances) { - add_device_operation_instances(instances, device_gemm_dlops_int8_int8_int8_mk_kn_mn_instances{}); + add_device_operation_instances(instances, + device_gemm_dlops_int8_int8_int8_mk_kn_mn_instances{}); } } // namespace device_gemm_instance } // namespace device } // namespace tensor_operation } // namespace ck - - - diff --git 
a/library/src/tensor_operation_instance/gpu/gemm/device_gemm_dlops_int8_int8_int8_mk_nk_mn_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm/device_gemm_dlops_int8_int8_int8_mk_nk_mn_instance.cpp index e8b7d89373..e219aacdfb 100644 --- a/library/src/tensor_operation_instance/gpu/gemm/device_gemm_dlops_int8_int8_int8_mk_nk_mn_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/gemm/device_gemm_dlops_int8_int8_int8_mk_nk_mn_instance.cpp @@ -20,8 +20,9 @@ using PassThrough = ck::tensor_operation::element_wise::PassThrough; static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default; // Compilation parameters for a[m, k] * b[n, k] = c[m, n] -using device_gemm_dlops_int8_int8_int8_mk_nk_mn_instances = std::tuple< - // clang-format off +using device_gemm_dlops_int8_int8_int8_mk_nk_mn_instances = + std::tuple< + // clang-format off // ##########| AData| BData| CData| AccData| ALayout| BLayout| CLayout| A| B| C| GEMM| Block| MPer| NPer| K0Per| K1| M1Per| N1Per| KPer| M11N11Thread| M11N11Thread| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| CThreadTransfer| CThreadTransfer| CThreadTransfer| // ##########| Type| Type| Type| Type| | | | Elementwise| Elementwise| Elementwise| Spacialization| Size| Block| Block| Block| | ThreadM111| ThreadN111| Thread| ClusterM110Xs| ClusterN110Xs| ThreadSliceLengths| ThreadClusterLengths| ThreadCluster| SrcAccess| SrcVectorTensor| SrcVectorTensor| DstVectorTensor| ThreadSliceLengths| ThreadClusterLengths| ThreadCluster| SrcAccess| SrcVectorTensor| SrcVectorTensor| DstVectorTensor| SrcDstAccess| SrcDstVectorDim| DstScalarPerVector| // ##########| | | | | | | | Operation| Operation| Operation| | | | | | | | | | | | K0_M0_M1_K1| K0_M0_M1_K1| ArrangeOrder| Order| Lengths_K0_M0_M1_K1| ContiguousDimOrder| 
Lengths_K0_M0_M1_K1| K0_N0_N1_K1| K0_N0_N1_K1| ArrangeOrder| Order| Lengths_K0_N0_N1_K1| ContiguousDimOrder| Lengths_K0_N0_N1_K1| Order| | | @@ -61,20 +62,17 @@ using device_gemm_dlops_int8_int8_int8_mk_nk_mn_instances = std::tuple< DeviceGemmDlops< int8_t, int8_t, int8_t, int32_t, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 64, 128, 128, 8, 4, 8, 8, 1, S<4, 2>, S<4, 2>, S<4, 1, 4, 4>, S<2, 1, 32, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 4>, S<1, 2, 0, 3>, S<1, 1, 1, 4>, S<4, 1, 4, 4>, S<2, 1, 32, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 4>, S<1, 2, 0, 3>, S<1, 1, 1, 4>, S<0, 1, 2, 3, 4, 5>, 5, 4>, DeviceGemmDlops< int8_t, int8_t, int8_t, int32_t, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 64, 128, 64, 8, 4, 8, 4, 1, S<4, 2>, S<4, 2>, S<4, 1, 4, 4>, S<2, 1, 32, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 4>, S<1, 2, 0, 3>, S<1, 1, 1, 4>, S<4, 1, 2, 4>, S<2, 1, 32, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 4>, S<1, 2, 0, 3>, S<1, 1, 1, 4>, S<0, 1, 2, 3, 4, 5>, 5, 4>, DeviceGemmDlops< int8_t, int8_t, int8_t, int32_t, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 64, 64, 128, 8, 4, 4, 8, 1, S<4, 2>, S<4, 2>, S<4, 1, 2, 4>, S<2, 1, 32, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 4>, S<1, 2, 0, 3>, S<1, 1, 1, 4>, S<4, 1, 4, 4>, S<2, 1, 32, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 4>, S<1, 2, 0, 3>, S<1, 1, 1, 4>, S<0, 1, 2, 3, 4, 5>, 5, 4> - // clang-format on - >; + // clang-format on + >; void add_device_gemm_dlops_int8_int8_int8_mk_nk_mn_instances( std::vector>& instances) { - add_device_operation_instances(instances, device_gemm_dlops_int8_int8_int8_mk_nk_mn_instances{}); + add_device_operation_instances(instances, + device_gemm_dlops_int8_int8_int8_mk_nk_mn_instances{}); } } // namespace device_gemm_instance } // namespace device } // namespace tensor_operation } // namespace ck - - - - diff --git a/profiler/include/profile_gemm_impl.hpp b/profiler/include/profile_gemm_impl.hpp 
index f266188844..d143828c54 100644 --- a/profiler/include/profile_gemm_impl.hpp +++ b/profiler/include/profile_gemm_impl.hpp @@ -74,6 +74,21 @@ void add_device_gemm_xdl_splitk_f16_f16_f16_mk_nk_mn_instances(std::vector&); void add_device_gemm_xdl_splitk_f16_f16_f16_km_nk_mn_instances(std::vector&); +void add_device_gemm_dlops_f32_f32_f32_mk_kn_mn_instances(std::vector&); +void add_device_gemm_dlops_f32_f32_f32_mk_nk_mn_instances(std::vector&); +void add_device_gemm_dlops_f32_f32_f32_km_kn_mn_instances(std::vector&); +void add_device_gemm_dlops_f32_f32_f32_km_nk_mn_instances(std::vector&); + +void add_device_gemm_dlops_f16_f16_f16_mk_kn_mn_instances(std::vector&); +void add_device_gemm_dlops_f16_f16_f16_mk_nk_mn_instances(std::vector&); +void add_device_gemm_dlops_f16_f16_f16_km_kn_mn_instances(std::vector&); +void add_device_gemm_dlops_f16_f16_f16_km_nk_mn_instances(std::vector&); + +void add_device_gemm_dlops_int8_int8_int8_mk_kn_mn_instances(std::vector&); +void add_device_gemm_dlops_int8_int8_int8_mk_nk_mn_instances(std::vector&); +void add_device_gemm_dlops_int8_int8_int8_km_kn_mn_instances(std::vector&); +void add_device_gemm_dlops_int8_int8_int8_km_nk_mn_instances(std::vector&); + } // namespace device_gemm_instance } // namespace device } // namespace tensor_operation @@ -174,6 +189,9 @@ void profile_gemm_impl(int do_verification, ck::tensor_operation::device::device_gemm_instance:: add_device_gemm_xdl_f32_f32_f32_mk_kn_mn_instances(gemm_ptrs); + ck::tensor_operation::device::device_gemm_instance:: + add_device_gemm_dlops_f32_f32_f32_mk_kn_mn_instances(gemm_ptrs); + ck::tensor_operation::device::device_gemm_instance:: add_device_gemm_xdl_c_shuffle_f32_f32_f32_mk_kn_mn_instances(gemm_ptrs); } @@ -192,6 +210,9 @@ void profile_gemm_impl(int do_verification, ck::tensor_operation::device::device_gemm_instance:: add_device_gemm_xdl_f32_f32_f32_mk_nk_mn_instances(gemm_ptrs); + ck::tensor_operation::device::device_gemm_instance:: + 
add_device_gemm_dlops_f32_f32_f32_mk_nk_mn_instances(gemm_ptrs); + ck::tensor_operation::device::device_gemm_instance:: add_device_gemm_xdl_c_shuffle_f32_f32_f32_mk_nk_mn_instances(gemm_ptrs); } @@ -210,6 +231,9 @@ void profile_gemm_impl(int do_verification, ck::tensor_operation::device::device_gemm_instance:: add_device_gemm_xdl_f32_f32_f32_km_kn_mn_instances(gemm_ptrs); + ck::tensor_operation::device::device_gemm_instance:: + add_device_gemm_dlops_f32_f32_f32_km_kn_mn_instances(gemm_ptrs); + ck::tensor_operation::device::device_gemm_instance:: add_device_gemm_xdl_c_shuffle_f32_f32_f32_km_kn_mn_instances(gemm_ptrs); } @@ -228,6 +252,9 @@ void profile_gemm_impl(int do_verification, ck::tensor_operation::device::device_gemm_instance:: add_device_gemm_xdl_f32_f32_f32_km_nk_mn_instances(gemm_ptrs); + ck::tensor_operation::device::device_gemm_instance:: + add_device_gemm_dlops_f32_f32_f32_km_nk_mn_instances(gemm_ptrs); + ck::tensor_operation::device::device_gemm_instance:: add_device_gemm_xdl_c_shuffle_f32_f32_f32_km_nk_mn_instances(gemm_ptrs); } @@ -250,6 +277,9 @@ void profile_gemm_impl(int do_verification, ck::tensor_operation::device::device_gemm_instance:: add_device_gemm_xdl_f16_f16_f16_mk_kn_mn_instances(gemm_ptrs); + ck::tensor_operation::device::device_gemm_instance:: + add_device_gemm_dlops_f16_f16_f16_mk_kn_mn_instances(gemm_ptrs); + ck::tensor_operation::device::device_gemm_instance:: add_device_gemm_xdl_c_shuffle_f16_f16_f16_mk_kn_mn_instances(gemm_ptrs); } @@ -268,6 +298,9 @@ void profile_gemm_impl(int do_verification, ck::tensor_operation::device::device_gemm_instance:: add_device_gemm_xdl_f16_f16_f16_mk_nk_mn_instances(gemm_ptrs); + ck::tensor_operation::device::device_gemm_instance:: + add_device_gemm_dlops_f16_f16_f16_mk_nk_mn_instances(gemm_ptrs); + ck::tensor_operation::device::device_gemm_instance:: add_device_gemm_xdl_c_shuffle_f16_f16_f16_mk_nk_mn_instances(gemm_ptrs); @@ -289,6 +322,9 @@ void profile_gemm_impl(int do_verification, 
ck::tensor_operation::device::device_gemm_instance:: add_device_gemm_xdl_f16_f16_f16_km_kn_mn_instances(gemm_ptrs); + ck::tensor_operation::device::device_gemm_instance:: + add_device_gemm_dlops_f16_f16_f16_km_kn_mn_instances(gemm_ptrs); + ck::tensor_operation::device::device_gemm_instance:: add_device_gemm_xdl_c_shuffle_f16_f16_f16_km_kn_mn_instances(gemm_ptrs); } @@ -307,6 +343,9 @@ void profile_gemm_impl(int do_verification, ck::tensor_operation::device::device_gemm_instance:: add_device_gemm_xdl_f16_f16_f16_km_nk_mn_instances(gemm_ptrs); + ck::tensor_operation::device::device_gemm_instance:: + add_device_gemm_dlops_f16_f16_f16_km_nk_mn_instances(gemm_ptrs); + ck::tensor_operation::device::device_gemm_instance:: add_device_gemm_xdl_c_shuffle_f16_f16_f16_km_nk_mn_instances(gemm_ptrs); } @@ -354,6 +393,9 @@ void profile_gemm_impl(int do_verification, { ck::tensor_operation::device::device_gemm_instance:: add_device_gemm_xdl_c_shuffle_int8_int8_int8_mk_kn_mn_instances(gemm_ptrs); + + ck::tensor_operation::device::device_gemm_instance:: + add_device_gemm_dlops_int8_int8_int8_mk_kn_mn_instances(gemm_ptrs); } else if constexpr(is_same::value && is_same::value && @@ -361,6 +403,9 @@ void profile_gemm_impl(int do_verification, { ck::tensor_operation::device::device_gemm_instance:: add_device_gemm_xdl_c_shuffle_int8_int8_int8_mk_nk_mn_instances(gemm_ptrs); + + ck::tensor_operation::device::device_gemm_instance:: + add_device_gemm_dlops_int8_int8_int8_mk_nk_mn_instances(gemm_ptrs); } else if constexpr(is_same::value && is_same::value && @@ -368,6 +413,9 @@ void profile_gemm_impl(int do_verification, { ck::tensor_operation::device::device_gemm_instance:: add_device_gemm_xdl_c_shuffle_int8_int8_int8_km_kn_mn_instances(gemm_ptrs); + + ck::tensor_operation::device::device_gemm_instance:: + add_device_gemm_dlops_int8_int8_int8_km_kn_mn_instances(gemm_ptrs); } else if constexpr(is_same::value && is_same::value && @@ -375,6 +423,9 @@ void profile_gemm_impl(int do_verification, { 
ck::tensor_operation::device::device_gemm_instance:: add_device_gemm_xdl_c_shuffle_int8_int8_int8_km_nk_mn_instances(gemm_ptrs); + + ck::tensor_operation::device::device_gemm_instance:: + add_device_gemm_dlops_int8_int8_int8_km_nk_mn_instances(gemm_ptrs); } } diff --git a/script/clang-format-overwrite.sh b/script/clang-format-overwrite.sh index fab19f1b8e..009d2da48c 100644 --- a/script/clang-format-overwrite.sh +++ b/script/clang-format-overwrite.sh @@ -1,2 +1,2 @@ -find . -name deps -prune -o -name build -prune -o -iname '*.h' -o -iname '*.hpp' -o -iname '*.cpp' -o -iname '*.h.in' -o -iname '*.hpp.in' -o -iname '*.cpp.in' -o -iname '*.cl' -o -iname '*.cuh' -o -iname '*.cu' | xargs -n 1 -P 16 -I{} -t sh -c 'clang-format-10 -i -style=file {}' +find . -name deps -prune -o -name build -prune -o -iname '*.h' -o -iname '*.hpp' -o -iname '*.cpp' -o -iname '*.h.in' -o -iname '*.hpp.in' -o -iname '*.cpp.in' -o -iname '*.cl' -o -iname '*.cuh' -o -iname '*.cu' | xargs -n 1 -P 16 -I{} -t sh -c 'clang-format -i -style=file {}' diff --git a/test/gemm_dlops/gemm_dlops_fp16.cpp b/test/gemm_dlops/gemm_dlops_fp16.cpp index 387af56419..e6e7a4b1b7 100644 --- a/test/gemm_dlops/gemm_dlops_fp16.cpp +++ b/test/gemm_dlops/gemm_dlops_fp16.cpp @@ -123,8 +123,6 @@ int main() PassThrough>{}(gemmPtr); } - std::cout << "TestGemm ..... " << (res ? "SUCCESS" : "FAILURE") << std::endl; return res ? 0 : 1; } - diff --git a/test/gemm_dlops/gemm_dlops_fp32.cpp b/test/gemm_dlops/gemm_dlops_fp32.cpp index 3cd2775bb6..aaece4c39d 100644 --- a/test/gemm_dlops/gemm_dlops_fp32.cpp +++ b/test/gemm_dlops/gemm_dlops_fp32.cpp @@ -123,7 +123,6 @@ int main() PassThrough>{}(gemmPtr); } - std::cout << "TestGemm ..... " << (res ? "SUCCESS" : "FAILURE") << std::endl; return res ? 
0 : 1; } diff --git a/test/gemm_dlops/gemm_dlops_int8.cpp b/test/gemm_dlops/gemm_dlops_int8.cpp index e9591ce5b2..7103468f3e 100644 --- a/test/gemm_dlops/gemm_dlops_int8.cpp +++ b/test/gemm_dlops/gemm_dlops_int8.cpp @@ -123,9 +123,6 @@ int main() PassThrough>{}(gemmPtr); } - std::cout << "TestGemm ..... " << (res ? "SUCCESS" : "FAILURE") << std::endl; return res ? 0 : 1; } - - From c695dfa1a8784182ee5e513fd84e18ead7f41a8e Mon Sep 17 00:00:00 2001 From: j4yan Date: Thu, 21 Apr 2022 15:33:49 -0500 Subject: [PATCH 17/46] minor changes --- .../gpu/device/device_gemm_xdl.hpp | 26 +++++++++---------- 1 file changed, 12 insertions(+), 14 deletions(-) diff --git a/include/ck/tensor_operation/gpu/device/device_gemm_xdl.hpp b/include/ck/tensor_operation/gpu/device/device_gemm_xdl.hpp index 71d476cb93..b4b70161dc 100644 --- a/include/ck/tensor_operation/gpu/device/device_gemm_xdl.hpp +++ b/include/ck/tensor_operation/gpu/device/device_gemm_xdl.hpp @@ -404,20 +404,18 @@ struct DeviceGemmXdl static bool IsSupportedArgument(const Argument& arg) { - std::cout << ck::get_device_name() << std::endl; - - // if (ck::get_device_name() == "gfx1030") - // { - // return GridwiseGemm::CheckValidity(arg.a_grid_desc_k0_m_k1_, - // arg.b_grid_desc_k0_n_k1_, - // arg.c_grid_desc_m_n_, - // arg.M01_, - // arg.N01_); - // } - // else - // { - // return false; - // } + if (ck::get_device_name() == "gfx1030") + { + return false; + } + else + { + return GridwiseGemm::CheckValidity(arg.a_grid_desc_k0_m_k1_, + arg.b_grid_desc_k0_n_k1_, + arg.c_grid_desc_m_n_, + arg.M01_, + arg.N01_); + } } // polymorphic From fc97e9db4ae7bba5492c190297b2174e0df8962f Mon Sep 17 00:00:00 2001 From: Jianfeng yan Date: Thu, 21 Apr 2022 23:16:12 +0000 Subject: [PATCH 18/46] if using xdl or dlops is moved into profiler_gemm_impl --- example/CMakeLists.txt | 1 + .../{utility => host_utility}/device_prop.hpp | 0 profiler/CMakeLists.txt | 60 ++-- profiler/include/profile_gemm_impl.hpp | 258 ++++++++++++------ 
profiler/src/profiler.cpp | 156 +++++------ test/CMakeLists.txt | 1 + 6 files changed, 278 insertions(+), 198 deletions(-) rename include/ck/{utility => host_utility}/device_prop.hpp (100%) diff --git a/example/CMakeLists.txt b/example/CMakeLists.txt index 5f04125305..5292c5248d 100644 --- a/example/CMakeLists.txt +++ b/example/CMakeLists.txt @@ -1,6 +1,7 @@ include_directories(BEFORE ${PROJECT_SOURCE_DIR}/include/ck ${PROJECT_SOURCE_DIR}/include/ck/utility + ${PROJECT_SOURCE_DIR}/include/ck/host_utility ${PROJECT_SOURCE_DIR}/include/ck/tensor_description ${PROJECT_SOURCE_DIR}/include/ck/tensor ${PROJECT_SOURCE_DIR}/include/ck/problem_transform diff --git a/include/ck/utility/device_prop.hpp b/include/ck/host_utility/device_prop.hpp similarity index 100% rename from include/ck/utility/device_prop.hpp rename to include/ck/host_utility/device_prop.hpp diff --git a/profiler/CMakeLists.txt b/profiler/CMakeLists.txt index a2cf6eeb62..48850e2679 100644 --- a/profiler/CMakeLists.txt +++ b/profiler/CMakeLists.txt @@ -1,6 +1,7 @@ include_directories(BEFORE ${PROJECT_SOURCE_DIR}/include/ck ${PROJECT_SOURCE_DIR}/include/ck/utility + ${PROJECT_SOURCE_DIR}/include/ck/host_utility ${PROJECT_SOURCE_DIR}/include/ck/tensor_description ${PROJECT_SOURCE_DIR}/include/ck/tensor ${PROJECT_SOURCE_DIR}/include/ck/problem_transform @@ -24,38 +25,39 @@ include_directories(BEFORE set(PROFILER_SOURCE src/profiler.cpp src/profile_gemm.cpp - src/profile_gemm_bias_2d.cpp - src/profile_gemm_bias_relu.cpp - src/profile_gemm_bias_relu_add.cpp - src/profile_gemm_reduce.cpp - src/profile_batched_gemm.cpp - src/profile_conv_fwd.cpp - src/profile_conv_fwd_bias_relu.cpp - src/profile_conv_fwd_bias_relu_add.cpp - src/profile_conv_fwd_bias_relu_atomic_add.cpp - src/profile_convnd_bwd_data.cpp - src/profile_reduce.cpp - src/profile_grouped_gemm.cpp - src/profile_conv_bwd_weight.cpp - src/profile_batched_gemm_reduce.cpp + # src/profile_gemm_bias_2d.cpp + # src/profile_gemm_bias_relu.cpp + # 
src/profile_gemm_bias_relu_add.cpp + # src/profile_gemm_reduce.cpp + # src/profile_batched_gemm.cpp + # src/profile_conv_fwd.cpp + # src/profile_conv_fwd_bias_relu.cpp + # src/profile_conv_fwd_bias_relu_add.cpp + # src/profile_conv_fwd_bias_relu_atomic_add.cpp + # src/profile_convnd_bwd_data.cpp + # src/profile_reduce.cpp + # src/profile_grouped_gemm.cpp + # src/profile_conv_bwd_weight.cpp + # src/profile_batched_gemm_reduce.cpp ) add_executable(ckProfiler ${PROFILER_SOURCE}) target_link_libraries(ckProfiler PRIVATE host_tensor) -target_link_libraries(ckProfiler PRIVATE device_gemm_reduce_instance) +# target_link_libraries(ckProfiler PRIVATE device_gemm_reduce_instance) target_link_libraries(ckProfiler PRIVATE device_gemm_instance) -target_link_libraries(ckProfiler PRIVATE device_gemm_bias2d_instance) -target_link_libraries(ckProfiler PRIVATE device_gemm_bias_relu_instance) -target_link_libraries(ckProfiler PRIVATE device_gemm_bias_relu_add_instance) -target_link_libraries(ckProfiler PRIVATE device_batched_gemm_instance) -target_link_libraries(ckProfiler PRIVATE device_conv2d_fwd_instance) -target_link_libraries(ckProfiler PRIVATE device_conv2d_fwd_bias_relu_instance) -target_link_libraries(ckProfiler PRIVATE device_conv2d_fwd_bias_relu_add_instance) -target_link_libraries(ckProfiler PRIVATE device_conv2d_fwd_bias_relu_atomic_add_instance) -target_link_libraries(ckProfiler PRIVATE device_convnd_bwd_data_instance) -target_link_libraries(ckProfiler PRIVATE device_reduce_instance) -target_link_libraries(ckProfiler PRIVATE device_reduce_instance) -target_link_libraries(ckProfiler PRIVATE device_grouped_gemm_instance) -target_link_libraries(ckProfiler PRIVATE device_conv2d_bwd_weight_instance) -target_link_libraries(ckProfiler PRIVATE device_batched_gemm_reduce_instance) +target_link_libraries(ckProfiler PRIVATE device_gemm_dlops_instance) +# target_link_libraries(ckProfiler PRIVATE device_gemm_bias2d_instance) +# target_link_libraries(ckProfiler PRIVATE 
device_gemm_bias_relu_instance) +# target_link_libraries(ckProfiler PRIVATE device_gemm_bias_relu_add_instance) +# target_link_libraries(ckProfiler PRIVATE device_batched_gemm_instance) +# target_link_libraries(ckProfiler PRIVATE device_conv2d_fwd_instance) +# target_link_libraries(ckProfiler PRIVATE device_conv2d_fwd_bias_relu_instance) +# target_link_libraries(ckProfiler PRIVATE device_conv2d_fwd_bias_relu_add_instance) +# target_link_libraries(ckProfiler PRIVATE device_conv2d_fwd_bias_relu_atomic_add_instance) +# target_link_libraries(ckProfiler PRIVATE device_convnd_bwd_data_instance) +# target_link_libraries(ckProfiler PRIVATE device_reduce_instance) +# target_link_libraries(ckProfiler PRIVATE device_reduce_instance) +# target_link_libraries(ckProfiler PRIVATE device_grouped_gemm_instance) +# target_link_libraries(ckProfiler PRIVATE device_conv2d_bwd_weight_instance) +# target_link_libraries(ckProfiler PRIVATE device_batched_gemm_reduce_instance) diff --git a/profiler/include/profile_gemm_impl.hpp b/profiler/include/profile_gemm_impl.hpp index d143828c54..66deb475c0 100644 --- a/profiler/include/profile_gemm_impl.hpp +++ b/profiler/include/profile_gemm_impl.hpp @@ -12,6 +12,7 @@ #include "element_wise_operation.hpp" #include "device_gemm.hpp" #include "reference_gemm.hpp" +#include "device_prop.hpp" namespace ck { namespace tensor_operation { @@ -115,6 +116,8 @@ void profile_gemm_impl(int do_verification, int StrideC, int KBatch) { + const bool is_xdl = cd::get_device_name() != "gfx1030"; + auto f_host_tensor_descriptor = [](std::size_t row, std::size_t col, std::size_t stride, auto layout) { if(is_same::value) @@ -179,84 +182,108 @@ void profile_gemm_impl(int do_verification, is_same::value && is_same::value) { - if(KBatch > 1) + if(is_xdl) { - ck::tensor_operation::device::device_gemm_instance:: - add_device_gemm_xdl_splitk_f32_f32_f32_mk_kn_mn_instances(gemm_ptrs); + if(KBatch > 1) + { + ck::tensor_operation::device::device_gemm_instance:: + 
add_device_gemm_xdl_splitk_f32_f32_f32_mk_kn_mn_instances(gemm_ptrs); + } + else + { + ck::tensor_operation::device::device_gemm_instance:: + add_device_gemm_xdl_f32_f32_f32_mk_kn_mn_instances(gemm_ptrs); + + ck::tensor_operation::device::device_gemm_instance:: + add_device_gemm_xdl_c_shuffle_f32_f32_f32_mk_kn_mn_instances(gemm_ptrs); + } } else { - ck::tensor_operation::device::device_gemm_instance:: - add_device_gemm_xdl_f32_f32_f32_mk_kn_mn_instances(gemm_ptrs); - ck::tensor_operation::device::device_gemm_instance:: add_device_gemm_dlops_f32_f32_f32_mk_kn_mn_instances(gemm_ptrs); - - ck::tensor_operation::device::device_gemm_instance:: - add_device_gemm_xdl_c_shuffle_f32_f32_f32_mk_kn_mn_instances(gemm_ptrs); } } else if constexpr(is_same::value && is_same::value && is_same::value) { - if(KBatch > 1) + if(is_xdl) { - ck::tensor_operation::device::device_gemm_instance:: - add_device_gemm_xdl_splitk_f32_f32_f32_mk_nk_mn_instances(gemm_ptrs); + + if(KBatch > 1) + { + ck::tensor_operation::device::device_gemm_instance:: + add_device_gemm_xdl_splitk_f32_f32_f32_mk_nk_mn_instances(gemm_ptrs); + } + else + { + ck::tensor_operation::device::device_gemm_instance:: + add_device_gemm_xdl_f32_f32_f32_mk_nk_mn_instances(gemm_ptrs); + + ck::tensor_operation::device::device_gemm_instance:: + add_device_gemm_xdl_c_shuffle_f32_f32_f32_mk_nk_mn_instances(gemm_ptrs); + } } else { - ck::tensor_operation::device::device_gemm_instance:: - add_device_gemm_xdl_f32_f32_f32_mk_nk_mn_instances(gemm_ptrs); - ck::tensor_operation::device::device_gemm_instance:: add_device_gemm_dlops_f32_f32_f32_mk_nk_mn_instances(gemm_ptrs); - - ck::tensor_operation::device::device_gemm_instance:: - add_device_gemm_xdl_c_shuffle_f32_f32_f32_mk_nk_mn_instances(gemm_ptrs); } } else if constexpr(is_same::value && is_same::value && is_same::value) { - if(KBatch > 1) + if(is_xdl) { - ck::tensor_operation::device::device_gemm_instance:: - add_device_gemm_xdl_splitk_f32_f32_f32_km_kn_mn_instances(gemm_ptrs); + + 
if(KBatch > 1) + { + ck::tensor_operation::device::device_gemm_instance:: + add_device_gemm_xdl_splitk_f32_f32_f32_km_kn_mn_instances(gemm_ptrs); + } + else + { + ck::tensor_operation::device::device_gemm_instance:: + add_device_gemm_xdl_f32_f32_f32_km_kn_mn_instances(gemm_ptrs); + + ck::tensor_operation::device::device_gemm_instance:: + add_device_gemm_xdl_c_shuffle_f32_f32_f32_km_kn_mn_instances(gemm_ptrs); + } } else { - ck::tensor_operation::device::device_gemm_instance:: - add_device_gemm_xdl_f32_f32_f32_km_kn_mn_instances(gemm_ptrs); - ck::tensor_operation::device::device_gemm_instance:: add_device_gemm_dlops_f32_f32_f32_km_kn_mn_instances(gemm_ptrs); - - ck::tensor_operation::device::device_gemm_instance:: - add_device_gemm_xdl_c_shuffle_f32_f32_f32_km_kn_mn_instances(gemm_ptrs); } } else if constexpr(is_same::value && is_same::value && is_same::value) { - if(KBatch > 1) + if(is_xdl) { - ck::tensor_operation::device::device_gemm_instance:: - add_device_gemm_xdl_splitk_f32_f32_f32_km_nk_mn_instances(gemm_ptrs); + + if(KBatch > 1) + { + ck::tensor_operation::device::device_gemm_instance:: + add_device_gemm_xdl_splitk_f32_f32_f32_km_nk_mn_instances(gemm_ptrs); + } + else + { + ck::tensor_operation::device::device_gemm_instance:: + add_device_gemm_xdl_f32_f32_f32_km_nk_mn_instances(gemm_ptrs); + + ck::tensor_operation::device::device_gemm_instance:: + add_device_gemm_xdl_c_shuffle_f32_f32_f32_km_nk_mn_instances(gemm_ptrs); + } } else { - ck::tensor_operation::device::device_gemm_instance:: - add_device_gemm_xdl_f32_f32_f32_km_nk_mn_instances(gemm_ptrs); ck::tensor_operation::device::device_gemm_instance:: add_device_gemm_dlops_f32_f32_f32_km_nk_mn_instances(gemm_ptrs); - - ck::tensor_operation::device::device_gemm_instance:: - add_device_gemm_xdl_c_shuffle_f32_f32_f32_km_nk_mn_instances(gemm_ptrs); } } } @@ -267,87 +294,115 @@ void profile_gemm_impl(int do_verification, is_same::value && is_same::value) { - if(KBatch > 1) + if(is_xdl) { - 
ck::tensor_operation::device::device_gemm_instance:: - add_device_gemm_xdl_splitk_f16_f16_f16_mk_kn_mn_instances(gemm_ptrs); + + if(KBatch > 1) + { + ck::tensor_operation::device::device_gemm_instance:: + add_device_gemm_xdl_splitk_f16_f16_f16_mk_kn_mn_instances(gemm_ptrs); + } + else + { + ck::tensor_operation::device::device_gemm_instance:: + add_device_gemm_xdl_f16_f16_f16_mk_kn_mn_instances(gemm_ptrs); + + ck::tensor_operation::device::device_gemm_instance:: + add_device_gemm_xdl_c_shuffle_f16_f16_f16_mk_kn_mn_instances(gemm_ptrs); + } } else { - ck::tensor_operation::device::device_gemm_instance:: - add_device_gemm_xdl_f16_f16_f16_mk_kn_mn_instances(gemm_ptrs); - ck::tensor_operation::device::device_gemm_instance:: add_device_gemm_dlops_f16_f16_f16_mk_kn_mn_instances(gemm_ptrs); - - ck::tensor_operation::device::device_gemm_instance:: - add_device_gemm_xdl_c_shuffle_f16_f16_f16_mk_kn_mn_instances(gemm_ptrs); } } else if constexpr(is_same::value && is_same::value && is_same::value) { - if(KBatch > 1) + if(is_xdl) { - ck::tensor_operation::device::device_gemm_instance:: - add_device_gemm_xdl_splitk_f16_f16_f16_mk_nk_mn_instances(gemm_ptrs); + + if(KBatch > 1) + { + ck::tensor_operation::device::device_gemm_instance:: + add_device_gemm_xdl_splitk_f16_f16_f16_mk_nk_mn_instances(gemm_ptrs); + } + else + { + ck::tensor_operation::device::device_gemm_instance:: + add_device_gemm_xdl_f16_f16_f16_mk_nk_mn_instances(gemm_ptrs); + + ck::tensor_operation::device::device_gemm_instance:: + add_device_gemm_xdl_c_shuffle_f16_f16_f16_mk_nk_mn_instances(gemm_ptrs); + + ck::tensor_operation::device::device_gemm_instance:: + add_device_gemm_xdl_c_shuffle_2_stage_f16_f16_f16_mk_nk_mn_instances( + gemm_ptrs); + } } else { - ck::tensor_operation::device::device_gemm_instance:: - add_device_gemm_xdl_f16_f16_f16_mk_nk_mn_instances(gemm_ptrs); ck::tensor_operation::device::device_gemm_instance:: add_device_gemm_dlops_f16_f16_f16_mk_nk_mn_instances(gemm_ptrs); - - 
ck::tensor_operation::device::device_gemm_instance:: - add_device_gemm_xdl_c_shuffle_f16_f16_f16_mk_nk_mn_instances(gemm_ptrs); - - ck::tensor_operation::device::device_gemm_instance:: - add_device_gemm_xdl_c_shuffle_2_stage_f16_f16_f16_mk_nk_mn_instances(gemm_ptrs); } } else if constexpr(is_same::value && is_same::value && is_same::value) { - if(KBatch > 1) + if(is_xdl) { - ck::tensor_operation::device::device_gemm_instance:: - add_device_gemm_xdl_splitk_f16_f16_f16_km_kn_mn_instances(gemm_ptrs); + + if(KBatch > 1) + { + ck::tensor_operation::device::device_gemm_instance:: + add_device_gemm_xdl_splitk_f16_f16_f16_km_kn_mn_instances(gemm_ptrs); + } + else + { + ck::tensor_operation::device::device_gemm_instance:: + add_device_gemm_xdl_f16_f16_f16_km_kn_mn_instances(gemm_ptrs); + + ck::tensor_operation::device::device_gemm_instance:: + add_device_gemm_xdl_c_shuffle_f16_f16_f16_km_kn_mn_instances(gemm_ptrs); + } } + else { - ck::tensor_operation::device::device_gemm_instance:: - add_device_gemm_xdl_f16_f16_f16_km_kn_mn_instances(gemm_ptrs); - ck::tensor_operation::device::device_gemm_instance:: add_device_gemm_dlops_f16_f16_f16_km_kn_mn_instances(gemm_ptrs); - - ck::tensor_operation::device::device_gemm_instance:: - add_device_gemm_xdl_c_shuffle_f16_f16_f16_km_kn_mn_instances(gemm_ptrs); } } else if constexpr(is_same::value && is_same::value && is_same::value) { - if(KBatch > 1) + if(is_xdl) { - ck::tensor_operation::device::device_gemm_instance:: - add_device_gemm_xdl_splitk_f16_f16_f16_km_nk_mn_instances(gemm_ptrs); + + if(KBatch > 1) + { + ck::tensor_operation::device::device_gemm_instance:: + add_device_gemm_xdl_splitk_f16_f16_f16_km_nk_mn_instances(gemm_ptrs); + } + else + { + ck::tensor_operation::device::device_gemm_instance:: + add_device_gemm_xdl_f16_f16_f16_km_nk_mn_instances(gemm_ptrs); + + ck::tensor_operation::device::device_gemm_instance:: + add_device_gemm_xdl_c_shuffle_f16_f16_f16_km_nk_mn_instances(gemm_ptrs); + } } else { - 
ck::tensor_operation::device::device_gemm_instance:: - add_device_gemm_xdl_f16_f16_f16_km_nk_mn_instances(gemm_ptrs); ck::tensor_operation::device::device_gemm_instance:: add_device_gemm_dlops_f16_f16_f16_km_nk_mn_instances(gemm_ptrs); - - ck::tensor_operation::device::device_gemm_instance:: - add_device_gemm_xdl_c_shuffle_f16_f16_f16_km_nk_mn_instances(gemm_ptrs); } } } @@ -391,41 +446,62 @@ void profile_gemm_impl(int do_verification, is_same::value && is_same::value) { - ck::tensor_operation::device::device_gemm_instance:: - add_device_gemm_xdl_c_shuffle_int8_int8_int8_mk_kn_mn_instances(gemm_ptrs); - - ck::tensor_operation::device::device_gemm_instance:: - add_device_gemm_dlops_int8_int8_int8_mk_kn_mn_instances(gemm_ptrs); + if(is_xdl) + { + ck::tensor_operation::device::device_gemm_instance:: + add_device_gemm_xdl_c_shuffle_int8_int8_int8_mk_kn_mn_instances(gemm_ptrs); + } + else + { + ck::tensor_operation::device::device_gemm_instance:: + add_device_gemm_dlops_int8_int8_int8_mk_kn_mn_instances(gemm_ptrs); + } } else if constexpr(is_same::value && is_same::value && is_same::value) { - ck::tensor_operation::device::device_gemm_instance:: - add_device_gemm_xdl_c_shuffle_int8_int8_int8_mk_nk_mn_instances(gemm_ptrs); - - ck::tensor_operation::device::device_gemm_instance:: - add_device_gemm_dlops_int8_int8_int8_mk_nk_mn_instances(gemm_ptrs); + if(is_xdl) + { + ck::tensor_operation::device::device_gemm_instance:: + add_device_gemm_xdl_c_shuffle_int8_int8_int8_mk_nk_mn_instances(gemm_ptrs); + } + else + { + ck::tensor_operation::device::device_gemm_instance:: + add_device_gemm_dlops_int8_int8_int8_mk_nk_mn_instances(gemm_ptrs); + } } else if constexpr(is_same::value && is_same::value && is_same::value) { - ck::tensor_operation::device::device_gemm_instance:: - add_device_gemm_xdl_c_shuffle_int8_int8_int8_km_kn_mn_instances(gemm_ptrs); + if(is_xdl) + { - ck::tensor_operation::device::device_gemm_instance:: - 
add_device_gemm_dlops_int8_int8_int8_km_kn_mn_instances(gemm_ptrs); + ck::tensor_operation::device::device_gemm_instance:: + add_device_gemm_xdl_c_shuffle_int8_int8_int8_km_kn_mn_instances(gemm_ptrs); + } + else + { + ck::tensor_operation::device::device_gemm_instance:: + add_device_gemm_dlops_int8_int8_int8_km_kn_mn_instances(gemm_ptrs); + } } else if constexpr(is_same::value && is_same::value && is_same::value) { - ck::tensor_operation::device::device_gemm_instance:: - add_device_gemm_xdl_c_shuffle_int8_int8_int8_km_nk_mn_instances(gemm_ptrs); - - ck::tensor_operation::device::device_gemm_instance:: - add_device_gemm_dlops_int8_int8_int8_km_nk_mn_instances(gemm_ptrs); + if(is_xdl) + { + ck::tensor_operation::device::device_gemm_instance:: + add_device_gemm_xdl_c_shuffle_int8_int8_int8_km_nk_mn_instances(gemm_ptrs); + } + else + { + ck::tensor_operation::device::device_gemm_instance:: + add_device_gemm_dlops_int8_int8_int8_km_nk_mn_instances(gemm_ptrs); + } } } diff --git a/profiler/src/profiler.cpp b/profiler/src/profiler.cpp index 3cd454e351..6bf59341e8 100644 --- a/profiler/src/profiler.cpp +++ b/profiler/src/profiler.cpp @@ -5,20 +5,20 @@ #include int profile_gemm(int, char*[]); -int profile_gemm_bias_2d(int, char*[]); -int profile_gemm_bias_relu(int, char*[]); -int profile_gemm_bias_relu_add(int, char*[]); -int profile_gemm_reduce(int, char*[]); -int profile_batched_gemm(int, char*[]); -int profile_grouped_gemm(int, char*[]); -int profile_conv_fwd(int, char*[]); -int profile_conv_fwd_bias_relu(int, char*[]); -int profile_conv_fwd_bias_relu_add(int, char*[]); -int profile_conv_fwd_bias_relu_atomic_add(int, char*[]); -int profile_convnd_bwd_data(int, char*[], int); -int profile_reduce(int, char*[]); -int profile_conv_bwd_weight(int, char*[]); -int profile_batched_gemm_reduce(int, char*[]); +// int profile_gemm_bias_2d(int, char*[]); +// int profile_gemm_bias_relu(int, char*[]); +// int profile_gemm_bias_relu_add(int, char*[]); +// int profile_gemm_reduce(int, 
char*[]); +// int profile_batched_gemm(int, char*[]); +// int profile_grouped_gemm(int, char*[]); +// int profile_conv_fwd(int, char*[]); +// int profile_conv_fwd_bias_relu(int, char*[]); +// int profile_conv_fwd_bias_relu_add(int, char*[]); +// int profile_conv_fwd_bias_relu_atomic_add(int, char*[]); +// int profile_convnd_bwd_data(int, char*[], int); +// int profile_reduce(int, char*[]); +// int profile_conv_bwd_weight(int, char*[]); +// int profile_batched_gemm_reduce(int, char*[]); int main(int argc, char* argv[]) { @@ -26,70 +26,70 @@ int main(int argc, char* argv[]) { return profile_gemm(argc, argv); } - else if(strcmp(argv[1], "gemm_bias_2d") == 0) - { - return profile_gemm_bias_2d(argc, argv); - } - else if(strcmp(argv[1], "gemm_bias_relu") == 0) - { - return profile_gemm_bias_relu(argc, argv); - } - else if(strcmp(argv[1], "gemm_bias_relu_add") == 0) - { - return profile_gemm_bias_relu_add(argc, argv); - } - else if(strcmp(argv[1], "gemm_reduce") == 0) - { - return profile_gemm_reduce(argc, argv); - } - else if(strcmp(argv[1], "batched_gemm") == 0) - { - return profile_batched_gemm(argc, argv); - } - else if(strcmp(argv[1], "batched_gemm_reduce") == 0) - { - return profile_batched_gemm_reduce(argc, argv); - } - else if(strcmp(argv[1], "grouped_gemm") == 0) - { - profile_grouped_gemm(argc, argv); - } - else if(strcmp(argv[1], "conv_fwd") == 0) - { - return profile_conv_fwd(argc, argv); - } - else if(strcmp(argv[1], "conv_fwd_bias_relu") == 0) - { - return profile_conv_fwd_bias_relu(argc, argv); - } - else if(strcmp(argv[1], "conv_fwd_bias_relu_add") == 0) - { - return profile_conv_fwd_bias_relu_add(argc, argv); - } - else if(strcmp(argv[1], "conv_fwd_bias_relu_atomic_add") == 0) - { - return profile_conv_fwd_bias_relu_atomic_add(argc, argv); - } - else if(strcmp(argv[1], "conv1d_bwd_data") == 0) - { - return profile_convnd_bwd_data(argc, argv, 1); - } - else if(strcmp(argv[1], "conv2d_bwd_data") == 0) - { - return profile_convnd_bwd_data(argc, argv, 2); - } 
- else if(strcmp(argv[1], "conv3d_bwd_data") == 0) - { - return profile_convnd_bwd_data(argc, argv, 3); - } - else if(strcmp(argv[1], "reduce") == 0) - { - return profile_reduce(argc, argv); - } - else if(strcmp(argv[1], "conv2d_bwd_weight") == 0) - { - return profile_conv_bwd_weight(argc, argv); - } + // else if(strcmp(argv[1], "gemm_bias_2d") == 0) + // { + // return profile_gemm_bias_2d(argc, argv); + // } + // else if(strcmp(argv[1], "gemm_bias_relu") == 0) + // { + // return profile_gemm_bias_relu(argc, argv); + // } + // else if(strcmp(argv[1], "gemm_bias_relu_add") == 0) + // { + // return profile_gemm_bias_relu_add(argc, argv); + // } + // else if(strcmp(argv[1], "gemm_reduce") == 0) + // { + // return profile_gemm_reduce(argc, argv); + // } + // else if(strcmp(argv[1], "batched_gemm") == 0) + // { + // return profile_batched_gemm(argc, argv); + // } + // else if(strcmp(argv[1], "batched_gemm_reduce") == 0) + // { + // return profile_batched_gemm_reduce(argc, argv); + // } + // else if(strcmp(argv[1], "grouped_gemm") == 0) + // { + // profile_grouped_gemm(argc, argv); + // } + // else if(strcmp(argv[1], "conv_fwd") == 0) + // { + // return profile_conv_fwd(argc, argv); + // } + // else if(strcmp(argv[1], "conv_fwd_bias_relu") == 0) + // { + // return profile_conv_fwd_bias_relu(argc, argv); + // } + // else if(strcmp(argv[1], "conv_fwd_bias_relu_add") == 0) + // { + // return profile_conv_fwd_bias_relu_add(argc, argv); + // } + // else if(strcmp(argv[1], "conv_fwd_bias_relu_atomic_add") == 0) + // { + // return profile_conv_fwd_bias_relu_atomic_add(argc, argv); + // } + // else if(strcmp(argv[1], "conv1d_bwd_data") == 0) + // { + // return profile_convnd_bwd_data(argc, argv, 1); + // } + // else if(strcmp(argv[1], "conv2d_bwd_data") == 0) + // { + // return profile_convnd_bwd_data(argc, argv, 2); + // } + // else if(strcmp(argv[1], "conv3d_bwd_data") == 0) + // { + // return profile_convnd_bwd_data(argc, argv, 3); + // } + // else if(strcmp(argv[1], 
"reduce") == 0) + // { + // return profile_reduce(argc, argv); + // } + // else if(strcmp(argv[1], "conv2d_bwd_weight") == 0) + // { + // return profile_conv_bwd_weight(argc, argv); + // } else { // clang-format off diff --git a/test/CMakeLists.txt b/test/CMakeLists.txt index e4f75df092..e023374f6c 100644 --- a/test/CMakeLists.txt +++ b/test/CMakeLists.txt @@ -1,6 +1,7 @@ include_directories(BEFORE ${PROJECT_SOURCE_DIR}/include/ck ${PROJECT_SOURCE_DIR}/include/ck/utility + ${PROJECT_SOURCE_DIR}/include/ck/host_utility ${PROJECT_SOURCE_DIR}/include/ck/tensor_description ${PROJECT_SOURCE_DIR}/include/ck/tensor ${PROJECT_SOURCE_DIR}/include/ck/problem_transform From cd2ce92f06a47fdca3ac6caf81eea909526e204b Mon Sep 17 00:00:00 2001 From: Jianfeng yan Date: Thu, 21 Apr 2022 23:22:06 +0000 Subject: [PATCH 19/46] minor changes --- library/src/tensor_operation_instance/gpu/CMakeLists.txt | 1 + 1 file changed, 1 insertion(+) diff --git a/library/src/tensor_operation_instance/gpu/CMakeLists.txt b/library/src/tensor_operation_instance/gpu/CMakeLists.txt index 7b361b48bd..827c56c461 100644 --- a/library/src/tensor_operation_instance/gpu/CMakeLists.txt +++ b/library/src/tensor_operation_instance/gpu/CMakeLists.txt @@ -1,6 +1,7 @@ include_directories(BEFORE ${PROJECT_SOURCE_DIR}/include/ck ${PROJECT_SOURCE_DIR}/include/ck/utility + ${PROJECT_SOURCE_DIR}/include/ck/host_utility ${PROJECT_SOURCE_DIR}/include/ck/tensor_description ${PROJECT_SOURCE_DIR}/include/ck/tensor ${PROJECT_SOURCE_DIR}/include/ck/problem_transform From bf8cea0769da210d559628dc39fbb7d352d7ce23 Mon Sep 17 00:00:00 2001 From: Jianfeng yan Date: Fri, 22 Apr 2022 01:17:43 +0000 Subject: [PATCH 20/46] minor changes --- profiler/include/profile_gemm_impl.hpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/profiler/include/profile_gemm_impl.hpp b/profiler/include/profile_gemm_impl.hpp index 66deb475c0..0556756798 100644 --- a/profiler/include/profile_gemm_impl.hpp +++ 
b/profiler/include/profile_gemm_impl.hpp @@ -116,7 +116,7 @@ void profile_gemm_impl(int do_verification, int StrideC, int KBatch) { - const bool is_xdl = cd::get_device_name() != "gfx1030"; + const bool is_xdl = ck::get_device_name() != "gfx1030"; auto f_host_tensor_descriptor = [](std::size_t row, std::size_t col, std::size_t stride, auto layout) { From 2f705068921cb504edf74e3081451b5c01b1db36 Mon Sep 17 00:00:00 2001 From: Jianfeng yan Date: Fri, 22 Apr 2022 05:35:41 +0000 Subject: [PATCH 21/46] remove is_xdl from profile_gemm_impl --- ...dlops_int8_int8_int8_km_kn_mn_instance.cpp | 12 +- ...dlops_int8_int8_int8_km_nk_mn_instance.cpp | 12 +- ...dlops_int8_int8_int8_mk_kn_mn_instance.cpp | 12 +- ...dlops_int8_int8_int8_mk_nk_mn_instance.cpp | 12 +- profiler/include/profile_gemm_impl.hpp | 258 ++++++------------ 5 files changed, 115 insertions(+), 191 deletions(-) diff --git a/library/src/tensor_operation_instance/gpu/gemm/device_gemm_dlops_int8_int8_int8_km_kn_mn_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm/device_gemm_dlops_int8_int8_int8_km_kn_mn_instance.cpp index 2920fc4ba2..8d055682d9 100644 --- a/library/src/tensor_operation_instance/gpu/gemm/device_gemm_dlops_int8_int8_int8_km_kn_mn_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/gemm/device_gemm_dlops_int8_int8_int8_km_kn_mn_instance.cpp @@ -54,12 +54,12 @@ using device_gemm_dlops_int8_int8_int8_km_kn_mn_instances = /* * K1 = 4 */ - DeviceGemmDlops< int8_t, int8_t, int8_t, int32_t, Col, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 256, 128, 8, 4, 8, 4, 1, S<8, 2>, S<8, 2>, S<4, 1, 2, 4>, S<2, 1, 128, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<4, 1, 1, 4>, S<0, 3, 1, 2>, S<1, 1, 1, 4>, S<4, 1, 1, 4>, S<2, 1, 128, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<4, 1, 1, 4>, S<0, 3, 1, 2>, S<1, 1, 1, 4>, S<0, 1, 2, 3, 4, 5>, 5, 4>, - DeviceGemmDlops< int8_t, int8_t, int8_t, int32_t, Col, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 128, 256, 8, 4, 4, 8, 
1, S<8, 2>, S<8, 2>, S<4, 1, 1, 4>, S<2, 1, 128, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<4, 1, 1, 4>, S<0, 3, 1, 2>, S<1, 1, 1, 4>, S<4, 1, 2, 4>, S<2, 1, 128, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<4, 1, 1, 4>, S<0, 3, 1, 2>, S<1, 1, 1, 4>, S<0, 1, 2, 3, 4, 5>, 5, 4>, - DeviceGemmDlops< int8_t, int8_t, int8_t, int32_t, Col, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 128, 128, 8, 4, 4, 4, 1, S<8, 2>, S<8, 2>, S<4, 1, 1, 4>, S<2, 1, 128, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<4, 1, 1, 4>, S<0, 3, 1, 2>, S<1, 1, 1, 4>, S<4, 1, 1, 4>, S<2, 1, 128, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<4, 1, 1, 4>, S<0, 3, 1, 2>, S<1, 1, 1, 4>, S<0, 1, 2, 3, 4, 5>, 5, 4>, - DeviceGemmDlops< int8_t, int8_t, int8_t, int32_t, Col, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 128, 128, 128, 8, 4, 8, 4, 1, S<4, 2>, S<8, 2>, S<8, 1, 1, 4>, S<1, 1, 128, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<4, 1, 1, 4>, S<0, 3, 1, 2>, S<1, 1, 1, 4>, S<8, 1, 1, 4>, S<1, 1, 128, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<4, 1, 1, 4>, S<0, 3, 1, 2>, S<1, 1, 1, 4>, S<0, 1, 2, 3, 4, 5>, 5, 4>, - DeviceGemmDlops< int8_t, int8_t, int8_t, int32_t, Col, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 128, 128, 128, 8, 4, 4, 8, 1, S<8, 2>, S<4, 2>, S<8, 1, 1, 4>, S<1, 1, 128, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<4, 1, 1, 4>, S<0, 3, 1, 2>, S<1, 1, 1, 4>, S<8, 1, 1, 4>, S<1, 1, 128, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<4, 1, 1, 4>, S<0, 3, 1, 2>, S<1, 1, 1, 4>, S<0, 1, 2, 3, 4, 5>, 5, 4>, - DeviceGemmDlops< int8_t, int8_t, int8_t, int32_t, Col, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 64, 128, 128, 8, 4, 8, 8, 1, S<4, 2>, S<4, 2>, S<4, 1, 4, 4>, S<2, 1, 32, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<4, 1, 1, 4>, S<0, 3, 1, 2>, S<1, 1, 1, 4>, S<4, 1, 4, 4>, S<2, 1, 32, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<4, 1, 1, 4>, S<0, 3, 1, 2>, S<1, 1, 1, 4>, S<0, 1, 2, 3, 4, 5>, 5, 4>, + // DeviceGemmDlops< int8_t, int8_t, int8_t, int32_t, Col, Row, Row, PassThrough, PassThrough, 
PassThrough, GemmDefault, 256, 256, 128, 8, 4, 8, 4, 1, S<8, 2>, S<8, 2>, S<4, 1, 2, 4>, S<2, 1, 128, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<4, 1, 1, 4>, S<0, 3, 1, 2>, S<1, 1, 1, 4>, S<4, 1, 1, 4>, S<2, 1, 128, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<4, 1, 1, 4>, S<0, 3, 1, 2>, S<1, 1, 1, 4>, S<0, 1, 2, 3, 4, 5>, 5, 4>, + // DeviceGemmDlops< int8_t, int8_t, int8_t, int32_t, Col, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 128, 256, 8, 4, 4, 8, 1, S<8, 2>, S<8, 2>, S<4, 1, 1, 4>, S<2, 1, 128, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<4, 1, 1, 4>, S<0, 3, 1, 2>, S<1, 1, 1, 4>, S<4, 1, 2, 4>, S<2, 1, 128, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<4, 1, 1, 4>, S<0, 3, 1, 2>, S<1, 1, 1, 4>, S<0, 1, 2, 3, 4, 5>, 5, 4>, + // DeviceGemmDlops< int8_t, int8_t, int8_t, int32_t, Col, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 128, 128, 8, 4, 4, 4, 1, S<8, 2>, S<8, 2>, S<4, 1, 1, 4>, S<2, 1, 128, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<4, 1, 1, 4>, S<0, 3, 1, 2>, S<1, 1, 1, 4>, S<4, 1, 1, 4>, S<2, 1, 128, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<4, 1, 1, 4>, S<0, 3, 1, 2>, S<1, 1, 1, 4>, S<0, 1, 2, 3, 4, 5>, 5, 4>, + // DeviceGemmDlops< int8_t, int8_t, int8_t, int32_t, Col, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 128, 128, 128, 8, 4, 8, 4, 1, S<4, 2>, S<8, 2>, S<8, 1, 1, 4>, S<1, 1, 128, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<4, 1, 1, 4>, S<0, 3, 1, 2>, S<1, 1, 1, 4>, S<8, 1, 1, 4>, S<1, 1, 128, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<4, 1, 1, 4>, S<0, 3, 1, 2>, S<1, 1, 1, 4>, S<0, 1, 2, 3, 4, 5>, 5, 4>, + // DeviceGemmDlops< int8_t, int8_t, int8_t, int32_t, Col, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 128, 128, 128, 8, 4, 4, 8, 1, S<8, 2>, S<4, 2>, S<8, 1, 1, 4>, S<1, 1, 128, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<4, 1, 1, 4>, S<0, 3, 1, 2>, S<1, 1, 1, 4>, S<8, 1, 1, 4>, S<1, 1, 128, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<4, 1, 1, 4>, S<0, 3, 1, 2>, S<1, 1, 1, 4>, S<0, 1, 2, 3, 4, 5>, 5, 4>, + // DeviceGemmDlops< int8_t, 
int8_t, int8_t, int32_t, Col, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 64, 128, 128, 8, 4, 8, 8, 1, S<4, 2>, S<4, 2>, S<4, 1, 4, 4>, S<2, 1, 32, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<4, 1, 1, 4>, S<0, 3, 1, 2>, S<1, 1, 1, 4>, S<4, 1, 4, 4>, S<2, 1, 32, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<4, 1, 1, 4>, S<0, 3, 1, 2>, S<1, 1, 1, 4>, S<0, 1, 2, 3, 4, 5>, 5, 4>, DeviceGemmDlops< int8_t, int8_t, int8_t, int32_t, Col, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 64, 128, 64, 8, 4, 8, 4, 1, S<4, 2>, S<4, 2>, S<4, 1, 4, 4>, S<2, 1, 32, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<4, 1, 1, 4>, S<0, 3, 1, 2>, S<1, 1, 1, 4>, S<4, 1, 2, 4>, S<2, 1, 32, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<4, 1, 1, 4>, S<0, 3, 1, 2>, S<1, 1, 1, 4>, S<0, 1, 2, 3, 4, 5>, 5, 4>, DeviceGemmDlops< int8_t, int8_t, int8_t, int32_t, Col, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 64, 64, 128, 8, 4, 4, 8, 1, S<4, 2>, S<4, 2>, S<4, 1, 2, 4>, S<2, 1, 32, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<4, 1, 1, 4>, S<0, 3, 1, 2>, S<1, 1, 1, 4>, S<4, 1, 4, 4>, S<2, 1, 32, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<4, 1, 1, 4>, S<0, 3, 1, 2>, S<1, 1, 1, 4>, S<0, 1, 2, 3, 4, 5>, 5, 4> // clang-format on diff --git a/library/src/tensor_operation_instance/gpu/gemm/device_gemm_dlops_int8_int8_int8_km_nk_mn_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm/device_gemm_dlops_int8_int8_int8_km_nk_mn_instance.cpp index a794e06484..a036afaf59 100644 --- a/library/src/tensor_operation_instance/gpu/gemm/device_gemm_dlops_int8_int8_int8_km_nk_mn_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/gemm/device_gemm_dlops_int8_int8_int8_km_nk_mn_instance.cpp @@ -54,12 +54,12 @@ using device_gemm_dlops_int8_int8_int8_km_nk_mn_instances = /* * K1 = 4 */ - DeviceGemmDlops< int8_t, int8_t, int8_t, int32_t, Col, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 256, 128, 8, 4, 8, 4, 1, S<8, 2>, S<8, 2>, S<4, 1, 2, 4>, S<2, 1, 128, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, 
S<4, 1, 1, 4>, S<0, 3, 1, 2>, S<1, 1, 1, 4>, S<4, 1, 1, 4>, S<2, 1, 128, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 4>, S<1, 2, 0, 3>, S<1, 1, 1, 4>, S<0, 1, 2, 3, 4, 5>, 5, 4>, - DeviceGemmDlops< int8_t, int8_t, int8_t, int32_t, Col, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 128, 256, 8, 4, 4, 8, 1, S<8, 2>, S<8, 2>, S<4, 1, 1, 4>, S<2, 1, 128, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<4, 1, 1, 4>, S<0, 3, 1, 2>, S<1, 1, 1, 4>, S<4, 1, 2, 4>, S<2, 1, 128, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 4>, S<1, 2, 0, 3>, S<1, 1, 1, 4>, S<0, 1, 2, 3, 4, 5>, 5, 4>, - DeviceGemmDlops< int8_t, int8_t, int8_t, int32_t, Col, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 128, 128, 8, 4, 4, 4, 1, S<8, 2>, S<8, 2>, S<4, 1, 1, 4>, S<2, 1, 128, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<4, 1, 1, 4>, S<0, 3, 1, 2>, S<1, 1, 1, 4>, S<4, 1, 1, 4>, S<2, 1, 128, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 4>, S<1, 2, 0, 3>, S<1, 1, 1, 4>, S<0, 1, 2, 3, 4, 5>, 5, 4>, - DeviceGemmDlops< int8_t, int8_t, int8_t, int32_t, Col, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 128, 128, 128, 8, 4, 8, 4, 1, S<4, 2>, S<8, 2>, S<8, 1, 1, 4>, S<1, 1, 128, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<4, 1, 1, 4>, S<0, 3, 1, 2>, S<1, 1, 1, 4>, S<8, 1, 1, 4>, S<1, 1, 128, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 4>, S<1, 2, 0, 3>, S<1, 1, 1, 4>, S<0, 1, 2, 3, 4, 5>, 5, 4>, - DeviceGemmDlops< int8_t, int8_t, int8_t, int32_t, Col, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 128, 128, 128, 8, 4, 4, 8, 1, S<8, 2>, S<4, 2>, S<8, 1, 1, 4>, S<1, 1, 128, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<4, 1, 1, 4>, S<0, 3, 1, 2>, S<1, 1, 1, 4>, S<8, 1, 1, 4>, S<1, 1, 128, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 4>, S<1, 2, 0, 3>, S<1, 1, 1, 4>, S<0, 1, 2, 3, 4, 5>, 5, 4>, - DeviceGemmDlops< int8_t, int8_t, int8_t, int32_t, Col, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 64, 128, 128, 8, 4, 8, 8, 1, S<4, 2>, S<4, 2>, S<4, 1, 4, 4>, 
S<2, 1, 32, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<4, 1, 1, 4>, S<0, 3, 1, 2>, S<1, 1, 1, 4>, S<4, 1, 4, 4>, S<2, 1, 32, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 4>, S<1, 2, 0, 3>, S<1, 1, 1, 4>, S<0, 1, 2, 3, 4, 5>, 5, 4>, + // DeviceGemmDlops< int8_t, int8_t, int8_t, int32_t, Col, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 256, 128, 8, 4, 8, 4, 1, S<8, 2>, S<8, 2>, S<4, 1, 2, 4>, S<2, 1, 128, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<4, 1, 1, 4>, S<0, 3, 1, 2>, S<1, 1, 1, 4>, S<4, 1, 1, 4>, S<2, 1, 128, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 4>, S<1, 2, 0, 3>, S<1, 1, 1, 4>, S<0, 1, 2, 3, 4, 5>, 5, 4>, + // DeviceGemmDlops< int8_t, int8_t, int8_t, int32_t, Col, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 128, 256, 8, 4, 4, 8, 1, S<8, 2>, S<8, 2>, S<4, 1, 1, 4>, S<2, 1, 128, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<4, 1, 1, 4>, S<0, 3, 1, 2>, S<1, 1, 1, 4>, S<4, 1, 2, 4>, S<2, 1, 128, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 4>, S<1, 2, 0, 3>, S<1, 1, 1, 4>, S<0, 1, 2, 3, 4, 5>, 5, 4>, + // DeviceGemmDlops< int8_t, int8_t, int8_t, int32_t, Col, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 128, 128, 8, 4, 4, 4, 1, S<8, 2>, S<8, 2>, S<4, 1, 1, 4>, S<2, 1, 128, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<4, 1, 1, 4>, S<0, 3, 1, 2>, S<1, 1, 1, 4>, S<4, 1, 1, 4>, S<2, 1, 128, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 4>, S<1, 2, 0, 3>, S<1, 1, 1, 4>, S<0, 1, 2, 3, 4, 5>, 5, 4>, + // DeviceGemmDlops< int8_t, int8_t, int8_t, int32_t, Col, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 128, 128, 128, 8, 4, 8, 4, 1, S<4, 2>, S<8, 2>, S<8, 1, 1, 4>, S<1, 1, 128, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<4, 1, 1, 4>, S<0, 3, 1, 2>, S<1, 1, 1, 4>, S<8, 1, 1, 4>, S<1, 1, 128, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 4>, S<1, 2, 0, 3>, S<1, 1, 1, 4>, S<0, 1, 2, 3, 4, 5>, 5, 4>, + // DeviceGemmDlops< int8_t, int8_t, int8_t, int32_t, Col, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 128, 
128, 128, 8, 4, 4, 8, 1, S<8, 2>, S<4, 2>, S<8, 1, 1, 4>, S<1, 1, 128, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<4, 1, 1, 4>, S<0, 3, 1, 2>, S<1, 1, 1, 4>, S<8, 1, 1, 4>, S<1, 1, 128, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 4>, S<1, 2, 0, 3>, S<1, 1, 1, 4>, S<0, 1, 2, 3, 4, 5>, 5, 4>, + // DeviceGemmDlops< int8_t, int8_t, int8_t, int32_t, Col, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 64, 128, 128, 8, 4, 8, 8, 1, S<4, 2>, S<4, 2>, S<4, 1, 4, 4>, S<2, 1, 32, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<4, 1, 1, 4>, S<0, 3, 1, 2>, S<1, 1, 1, 4>, S<4, 1, 4, 4>, S<2, 1, 32, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 4>, S<1, 2, 0, 3>, S<1, 1, 1, 4>, S<0, 1, 2, 3, 4, 5>, 5, 4>, DeviceGemmDlops< int8_t, int8_t, int8_t, int32_t, Col, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 64, 128, 64, 8, 4, 8, 4, 1, S<4, 2>, S<4, 2>, S<4, 1, 4, 4>, S<2, 1, 32, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<4, 1, 1, 4>, S<0, 3, 1, 2>, S<1, 1, 1, 4>, S<4, 1, 2, 4>, S<2, 1, 32, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 4>, S<1, 2, 0, 3>, S<1, 1, 1, 4>, S<0, 1, 2, 3, 4, 5>, 5, 4>, DeviceGemmDlops< int8_t, int8_t, int8_t, int32_t, Col, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 64, 64, 128, 8, 4, 4, 8, 1, S<4, 2>, S<4, 2>, S<4, 1, 2, 4>, S<2, 1, 32, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<4, 1, 1, 4>, S<0, 3, 1, 2>, S<1, 1, 1, 4>, S<4, 1, 4, 4>, S<2, 1, 32, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 4>, S<1, 2, 0, 3>, S<1, 1, 1, 4>, S<0, 1, 2, 3, 4, 5>, 5, 4> // clang-format on diff --git a/library/src/tensor_operation_instance/gpu/gemm/device_gemm_dlops_int8_int8_int8_mk_kn_mn_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm/device_gemm_dlops_int8_int8_int8_mk_kn_mn_instance.cpp index 36d35b2ce8..2ef12d33b1 100644 --- a/library/src/tensor_operation_instance/gpu/gemm/device_gemm_dlops_int8_int8_int8_mk_kn_mn_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/gemm/device_gemm_dlops_int8_int8_int8_mk_kn_mn_instance.cpp @@ -54,12 
+54,12 @@ using device_gemm_dlops_int8_int8_int8_mk_kn_mn_instances = /* * K1 = 4 */ - DeviceGemmDlops< int8_t, int8_t, int8_t, int32_t, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 256, 128, 8, 4, 8, 4, 1, S<8, 2>, S<8, 2>, S<4, 1, 2, 4>, S<2, 1, 128, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 4>, S<1, 2, 0, 3>, S<1, 1, 1, 4>, S<4, 1, 1, 4>, S<2, 1, 128, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<4, 1, 1, 4>, S<0, 3, 1, 2>, S<1, 1, 1, 4>, S<0, 1, 2, 3, 4, 5>, 5, 4>, - DeviceGemmDlops< int8_t, int8_t, int8_t, int32_t, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 128, 256, 8, 4, 4, 8, 1, S<8, 2>, S<8, 2>, S<4, 1, 1, 4>, S<2, 1, 128, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 4>, S<1, 2, 0, 3>, S<1, 1, 1, 4>, S<4, 1, 2, 4>, S<2, 1, 128, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<4, 1, 1, 4>, S<0, 3, 1, 2>, S<1, 1, 1, 4>, S<0, 1, 2, 3, 4, 5>, 5, 4>, - DeviceGemmDlops< int8_t, int8_t, int8_t, int32_t, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 128, 128, 8, 4, 4, 4, 1, S<8, 2>, S<8, 2>, S<4, 1, 1, 4>, S<2, 1, 128, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 4>, S<1, 2, 0, 3>, S<1, 1, 1, 4>, S<4, 1, 1, 4>, S<2, 1, 128, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<4, 1, 1, 4>, S<0, 3, 1, 2>, S<1, 1, 1, 4>, S<0, 1, 2, 3, 4, 5>, 5, 4>, - DeviceGemmDlops< int8_t, int8_t, int8_t, int32_t, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 128, 128, 128, 8, 4, 8, 4, 1, S<4, 2>, S<8, 2>, S<8, 1, 1, 4>, S<1, 1, 128, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 4>, S<1, 2, 0, 3>, S<1, 1, 1, 4>, S<8, 1, 1, 4>, S<1, 1, 128, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<4, 1, 1, 4>, S<0, 3, 1, 2>, S<1, 1, 1, 4>, S<0, 1, 2, 3, 4, 5>, 5, 4>, - DeviceGemmDlops< int8_t, int8_t, int8_t, int32_t, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 128, 128, 128, 8, 4, 4, 8, 1, S<8, 2>, S<4, 2>, S<8, 1, 1, 4>, S<1, 1, 128, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 4>, S<1, 2, 0, 3>, S<1, 1, 1, 4>, S<8, 
1, 1, 4>, S<1, 1, 128, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<4, 1, 1, 4>, S<0, 3, 1, 2>, S<1, 1, 1, 4>, S<0, 1, 2, 3, 4, 5>, 5, 4>, - DeviceGemmDlops< int8_t, int8_t, int8_t, int32_t, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 64, 128, 128, 8, 4, 8, 8, 1, S<4, 2>, S<4, 2>, S<4, 1, 4, 4>, S<2, 1, 32, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 4>, S<1, 2, 0, 3>, S<1, 1, 1, 4>, S<4, 1, 4, 4>, S<2, 1, 32, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<4, 1, 1, 4>, S<0, 3, 1, 2>, S<1, 1, 1, 4>, S<0, 1, 2, 3, 4, 5>, 5, 4>, + // DeviceGemmDlops< int8_t, int8_t, int8_t, int32_t, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 256, 128, 8, 4, 8, 4, 1, S<8, 2>, S<8, 2>, S<4, 1, 2, 4>, S<2, 1, 128, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 4>, S<1, 2, 0, 3>, S<1, 1, 1, 4>, S<4, 1, 1, 4>, S<2, 1, 128, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<4, 1, 1, 4>, S<0, 3, 1, 2>, S<1, 1, 1, 4>, S<0, 1, 2, 3, 4, 5>, 5, 4>, + // DeviceGemmDlops< int8_t, int8_t, int8_t, int32_t, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 128, 256, 8, 4, 4, 8, 1, S<8, 2>, S<8, 2>, S<4, 1, 1, 4>, S<2, 1, 128, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 4>, S<1, 2, 0, 3>, S<1, 1, 1, 4>, S<4, 1, 2, 4>, S<2, 1, 128, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<4, 1, 1, 4>, S<0, 3, 1, 2>, S<1, 1, 1, 4>, S<0, 1, 2, 3, 4, 5>, 5, 4>, + // DeviceGemmDlops< int8_t, int8_t, int8_t, int32_t, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 128, 128, 8, 4, 4, 4, 1, S<8, 2>, S<8, 2>, S<4, 1, 1, 4>, S<2, 1, 128, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 4>, S<1, 2, 0, 3>, S<1, 1, 1, 4>, S<4, 1, 1, 4>, S<2, 1, 128, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<4, 1, 1, 4>, S<0, 3, 1, 2>, S<1, 1, 1, 4>, S<0, 1, 2, 3, 4, 5>, 5, 4>, + // DeviceGemmDlops< int8_t, int8_t, int8_t, int32_t, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 128, 128, 128, 8, 4, 8, 4, 1, S<4, 2>, S<8, 2>, S<8, 1, 1, 4>, S<1, 1, 128, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 
3>, S<4, 1, 1, 4>, S<1, 2, 0, 3>, S<1, 1, 1, 4>, S<8, 1, 1, 4>, S<1, 1, 128, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<4, 1, 1, 4>, S<0, 3, 1, 2>, S<1, 1, 1, 4>, S<0, 1, 2, 3, 4, 5>, 5, 4>, + // DeviceGemmDlops< int8_t, int8_t, int8_t, int32_t, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 128, 128, 128, 8, 4, 4, 8, 1, S<8, 2>, S<4, 2>, S<8, 1, 1, 4>, S<1, 1, 128, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 4>, S<1, 2, 0, 3>, S<1, 1, 1, 4>, S<8, 1, 1, 4>, S<1, 1, 128, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<4, 1, 1, 4>, S<0, 3, 1, 2>, S<1, 1, 1, 4>, S<0, 1, 2, 3, 4, 5>, 5, 4>, + // DeviceGemmDlops< int8_t, int8_t, int8_t, int32_t, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 64, 128, 128, 8, 4, 8, 8, 1, S<4, 2>, S<4, 2>, S<4, 1, 4, 4>, S<2, 1, 32, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 4>, S<1, 2, 0, 3>, S<1, 1, 1, 4>, S<4, 1, 4, 4>, S<2, 1, 32, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<4, 1, 1, 4>, S<0, 3, 1, 2>, S<1, 1, 1, 4>, S<0, 1, 2, 3, 4, 5>, 5, 4>, DeviceGemmDlops< int8_t, int8_t, int8_t, int32_t, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 64, 128, 64, 8, 4, 8, 4, 1, S<4, 2>, S<4, 2>, S<4, 1, 4, 4>, S<2, 1, 32, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 4>, S<1, 2, 0, 3>, S<1, 1, 1, 4>, S<4, 1, 2, 4>, S<2, 1, 32, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<4, 1, 1, 4>, S<0, 3, 1, 2>, S<1, 1, 1, 4>, S<0, 1, 2, 3, 4, 5>, 5, 4>, DeviceGemmDlops< int8_t, int8_t, int8_t, int32_t, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 64, 64, 128, 8, 4, 4, 8, 1, S<4, 2>, S<4, 2>, S<4, 1, 2, 4>, S<2, 1, 32, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 4>, S<1, 2, 0, 3>, S<1, 1, 1, 4>, S<4, 1, 4, 4>, S<2, 1, 32, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<4, 1, 1, 4>, S<0, 3, 1, 2>, S<1, 1, 1, 4>, S<0, 1, 2, 3, 4, 5>, 5, 4> // clang-format on diff --git a/library/src/tensor_operation_instance/gpu/gemm/device_gemm_dlops_int8_int8_int8_mk_nk_mn_instance.cpp 
b/library/src/tensor_operation_instance/gpu/gemm/device_gemm_dlops_int8_int8_int8_mk_nk_mn_instance.cpp index e219aacdfb..8a2a114536 100644 --- a/library/src/tensor_operation_instance/gpu/gemm/device_gemm_dlops_int8_int8_int8_mk_nk_mn_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/gemm/device_gemm_dlops_int8_int8_int8_mk_nk_mn_instance.cpp @@ -54,12 +54,12 @@ using device_gemm_dlops_int8_int8_int8_mk_nk_mn_instances = /* * K1 = 4 */ - DeviceGemmDlops< int8_t, int8_t, int8_t, int32_t, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 256, 128, 8, 4, 8, 4, 1, S<8, 2>, S<8, 2>, S<4, 1, 2, 4>, S<2, 1, 128, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 4>, S<1, 2, 0, 3>, S<1, 1, 1, 4>, S<4, 1, 1, 4>, S<2, 1, 128, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 4>, S<1, 2, 0, 3>, S<1, 1, 1, 4>, S<0, 1, 2, 3, 4, 5>, 5, 4>, - DeviceGemmDlops< int8_t, int8_t, int8_t, int32_t, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 128, 256, 8, 4, 4, 8, 1, S<8, 2>, S<8, 2>, S<4, 1, 1, 4>, S<2, 1, 128, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 4>, S<1, 2, 0, 3>, S<1, 1, 1, 4>, S<4, 1, 2, 4>, S<2, 1, 128, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 4>, S<1, 2, 0, 3>, S<1, 1, 1, 4>, S<0, 1, 2, 3, 4, 5>, 5, 4>, - DeviceGemmDlops< int8_t, int8_t, int8_t, int32_t, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 128, 128, 8, 4, 4, 4, 1, S<8, 2>, S<8, 2>, S<4, 1, 1, 4>, S<2, 1, 128, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 4>, S<1, 2, 0, 3>, S<1, 1, 1, 4>, S<4, 1, 1, 4>, S<2, 1, 128, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 4>, S<1, 2, 0, 3>, S<1, 1, 1, 4>, S<0, 1, 2, 3, 4, 5>, 5, 4>, - DeviceGemmDlops< int8_t, int8_t, int8_t, int32_t, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 128, 128, 128, 8, 4, 8, 4, 1, S<4, 2>, S<8, 2>, S<8, 1, 1, 4>, S<1, 1, 128, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 4>, S<1, 2, 0, 3>, S<1, 1, 1, 4>, S<8, 1, 1, 4>, S<1, 1, 128, 1>, S<1, 2, 0, 3>, 
S<1, 2, 0, 3>, S<4, 1, 1, 4>, S<1, 2, 0, 3>, S<1, 1, 1, 4>, S<0, 1, 2, 3, 4, 5>, 5, 4>, - DeviceGemmDlops< int8_t, int8_t, int8_t, int32_t, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 128, 128, 128, 8, 4, 4, 8, 1, S<8, 2>, S<4, 2>, S<8, 1, 1, 4>, S<1, 1, 128, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 4>, S<1, 2, 0, 3>, S<1, 1, 1, 4>, S<8, 1, 1, 4>, S<1, 1, 128, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 4>, S<1, 2, 0, 3>, S<1, 1, 1, 4>, S<0, 1, 2, 3, 4, 5>, 5, 4>, - DeviceGemmDlops< int8_t, int8_t, int8_t, int32_t, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 64, 128, 128, 8, 4, 8, 8, 1, S<4, 2>, S<4, 2>, S<4, 1, 4, 4>, S<2, 1, 32, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 4>, S<1, 2, 0, 3>, S<1, 1, 1, 4>, S<4, 1, 4, 4>, S<2, 1, 32, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 4>, S<1, 2, 0, 3>, S<1, 1, 1, 4>, S<0, 1, 2, 3, 4, 5>, 5, 4>, + // DeviceGemmDlops< int8_t, int8_t, int8_t, int32_t, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 256, 128, 8, 4, 8, 4, 1, S<8, 2>, S<8, 2>, S<4, 1, 2, 4>, S<2, 1, 128, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 4>, S<1, 2, 0, 3>, S<1, 1, 1, 4>, S<4, 1, 1, 4>, S<2, 1, 128, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 4>, S<1, 2, 0, 3>, S<1, 1, 1, 4>, S<0, 1, 2, 3, 4, 5>, 5, 4>, + // DeviceGemmDlops< int8_t, int8_t, int8_t, int32_t, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 128, 256, 8, 4, 4, 8, 1, S<8, 2>, S<8, 2>, S<4, 1, 1, 4>, S<2, 1, 128, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 4>, S<1, 2, 0, 3>, S<1, 1, 1, 4>, S<4, 1, 2, 4>, S<2, 1, 128, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 4>, S<1, 2, 0, 3>, S<1, 1, 1, 4>, S<0, 1, 2, 3, 4, 5>, 5, 4>, + // DeviceGemmDlops< int8_t, int8_t, int8_t, int32_t, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 128, 128, 8, 4, 4, 4, 1, S<8, 2>, S<8, 2>, S<4, 1, 1, 4>, S<2, 1, 128, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 4>, S<1, 2, 0, 3>, S<1, 1, 1, 
4>, S<4, 1, 1, 4>, S<2, 1, 128, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 4>, S<1, 2, 0, 3>, S<1, 1, 1, 4>, S<0, 1, 2, 3, 4, 5>, 5, 4>, + // DeviceGemmDlops< int8_t, int8_t, int8_t, int32_t, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 128, 128, 128, 8, 4, 8, 4, 1, S<4, 2>, S<8, 2>, S<8, 1, 1, 4>, S<1, 1, 128, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 4>, S<1, 2, 0, 3>, S<1, 1, 1, 4>, S<8, 1, 1, 4>, S<1, 1, 128, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 4>, S<1, 2, 0, 3>, S<1, 1, 1, 4>, S<0, 1, 2, 3, 4, 5>, 5, 4>, + // DeviceGemmDlops< int8_t, int8_t, int8_t, int32_t, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 128, 128, 128, 8, 4, 4, 8, 1, S<8, 2>, S<4, 2>, S<8, 1, 1, 4>, S<1, 1, 128, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 4>, S<1, 2, 0, 3>, S<1, 1, 1, 4>, S<8, 1, 1, 4>, S<1, 1, 128, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 4>, S<1, 2, 0, 3>, S<1, 1, 1, 4>, S<0, 1, 2, 3, 4, 5>, 5, 4>, + // DeviceGemmDlops< int8_t, int8_t, int8_t, int32_t, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 64, 128, 128, 8, 4, 8, 8, 1, S<4, 2>, S<4, 2>, S<4, 1, 4, 4>, S<2, 1, 32, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 4>, S<1, 2, 0, 3>, S<1, 1, 1, 4>, S<4, 1, 4, 4>, S<2, 1, 32, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 4>, S<1, 2, 0, 3>, S<1, 1, 1, 4>, S<0, 1, 2, 3, 4, 5>, 5, 4>, DeviceGemmDlops< int8_t, int8_t, int8_t, int32_t, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 64, 128, 64, 8, 4, 8, 4, 1, S<4, 2>, S<4, 2>, S<4, 1, 4, 4>, S<2, 1, 32, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 4>, S<1, 2, 0, 3>, S<1, 1, 1, 4>, S<4, 1, 2, 4>, S<2, 1, 32, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 4>, S<1, 2, 0, 3>, S<1, 1, 1, 4>, S<0, 1, 2, 3, 4, 5>, 5, 4>, DeviceGemmDlops< int8_t, int8_t, int8_t, int32_t, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 64, 64, 128, 8, 4, 4, 8, 1, S<4, 2>, S<4, 2>, S<4, 1, 2, 4>, S<2, 1, 32, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, 
S<4, 1, 1, 4>, S<1, 2, 0, 3>, S<1, 1, 1, 4>, S<4, 1, 4, 4>, S<2, 1, 32, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 4>, S<1, 2, 0, 3>, S<1, 1, 1, 4>, S<0, 1, 2, 3, 4, 5>, 5, 4> // clang-format on diff --git a/profiler/include/profile_gemm_impl.hpp b/profiler/include/profile_gemm_impl.hpp index 0556756798..d143828c54 100644 --- a/profiler/include/profile_gemm_impl.hpp +++ b/profiler/include/profile_gemm_impl.hpp @@ -12,7 +12,6 @@ #include "element_wise_operation.hpp" #include "device_gemm.hpp" #include "reference_gemm.hpp" -#include "device_prop.hpp" namespace ck { namespace tensor_operation { @@ -116,8 +115,6 @@ void profile_gemm_impl(int do_verification, int StrideC, int KBatch) { - const bool is_xdl = ck::get_device_name() != "gfx1030"; - auto f_host_tensor_descriptor = [](std::size_t row, std::size_t col, std::size_t stride, auto layout) { if(is_same::value) @@ -182,108 +179,84 @@ void profile_gemm_impl(int do_verification, is_same::value && is_same::value) { - if(is_xdl) + if(KBatch > 1) { - if(KBatch > 1) - { - ck::tensor_operation::device::device_gemm_instance:: - add_device_gemm_xdl_splitk_f32_f32_f32_mk_kn_mn_instances(gemm_ptrs); - } - else - { - ck::tensor_operation::device::device_gemm_instance:: - add_device_gemm_xdl_f32_f32_f32_mk_kn_mn_instances(gemm_ptrs); - - ck::tensor_operation::device::device_gemm_instance:: - add_device_gemm_xdl_c_shuffle_f32_f32_f32_mk_kn_mn_instances(gemm_ptrs); - } + ck::tensor_operation::device::device_gemm_instance:: + add_device_gemm_xdl_splitk_f32_f32_f32_mk_kn_mn_instances(gemm_ptrs); } else { + ck::tensor_operation::device::device_gemm_instance:: + add_device_gemm_xdl_f32_f32_f32_mk_kn_mn_instances(gemm_ptrs); + ck::tensor_operation::device::device_gemm_instance:: add_device_gemm_dlops_f32_f32_f32_mk_kn_mn_instances(gemm_ptrs); + + ck::tensor_operation::device::device_gemm_instance:: + add_device_gemm_xdl_c_shuffle_f32_f32_f32_mk_kn_mn_instances(gemm_ptrs); } } else if constexpr(is_same::value && is_same::value && 
is_same::value) { - if(is_xdl) + if(KBatch > 1) { - - if(KBatch > 1) - { - ck::tensor_operation::device::device_gemm_instance:: - add_device_gemm_xdl_splitk_f32_f32_f32_mk_nk_mn_instances(gemm_ptrs); - } - else - { - ck::tensor_operation::device::device_gemm_instance:: - add_device_gemm_xdl_f32_f32_f32_mk_nk_mn_instances(gemm_ptrs); - - ck::tensor_operation::device::device_gemm_instance:: - add_device_gemm_xdl_c_shuffle_f32_f32_f32_mk_nk_mn_instances(gemm_ptrs); - } + ck::tensor_operation::device::device_gemm_instance:: + add_device_gemm_xdl_splitk_f32_f32_f32_mk_nk_mn_instances(gemm_ptrs); } else { + ck::tensor_operation::device::device_gemm_instance:: + add_device_gemm_xdl_f32_f32_f32_mk_nk_mn_instances(gemm_ptrs); + ck::tensor_operation::device::device_gemm_instance:: add_device_gemm_dlops_f32_f32_f32_mk_nk_mn_instances(gemm_ptrs); + + ck::tensor_operation::device::device_gemm_instance:: + add_device_gemm_xdl_c_shuffle_f32_f32_f32_mk_nk_mn_instances(gemm_ptrs); } } else if constexpr(is_same::value && is_same::value && is_same::value) { - if(is_xdl) + if(KBatch > 1) { - - if(KBatch > 1) - { - ck::tensor_operation::device::device_gemm_instance:: - add_device_gemm_xdl_splitk_f32_f32_f32_km_kn_mn_instances(gemm_ptrs); - } - else - { - ck::tensor_operation::device::device_gemm_instance:: - add_device_gemm_xdl_f32_f32_f32_km_kn_mn_instances(gemm_ptrs); - - ck::tensor_operation::device::device_gemm_instance:: - add_device_gemm_xdl_c_shuffle_f32_f32_f32_km_kn_mn_instances(gemm_ptrs); - } + ck::tensor_operation::device::device_gemm_instance:: + add_device_gemm_xdl_splitk_f32_f32_f32_km_kn_mn_instances(gemm_ptrs); } else { + ck::tensor_operation::device::device_gemm_instance:: + add_device_gemm_xdl_f32_f32_f32_km_kn_mn_instances(gemm_ptrs); + ck::tensor_operation::device::device_gemm_instance:: add_device_gemm_dlops_f32_f32_f32_km_kn_mn_instances(gemm_ptrs); + + ck::tensor_operation::device::device_gemm_instance:: + 
add_device_gemm_xdl_c_shuffle_f32_f32_f32_km_kn_mn_instances(gemm_ptrs); } } else if constexpr(is_same::value && is_same::value && is_same::value) { - if(is_xdl) + if(KBatch > 1) { - - if(KBatch > 1) - { - ck::tensor_operation::device::device_gemm_instance:: - add_device_gemm_xdl_splitk_f32_f32_f32_km_nk_mn_instances(gemm_ptrs); - } - else - { - ck::tensor_operation::device::device_gemm_instance:: - add_device_gemm_xdl_f32_f32_f32_km_nk_mn_instances(gemm_ptrs); - - ck::tensor_operation::device::device_gemm_instance:: - add_device_gemm_xdl_c_shuffle_f32_f32_f32_km_nk_mn_instances(gemm_ptrs); - } + ck::tensor_operation::device::device_gemm_instance:: + add_device_gemm_xdl_splitk_f32_f32_f32_km_nk_mn_instances(gemm_ptrs); } else { + ck::tensor_operation::device::device_gemm_instance:: + add_device_gemm_xdl_f32_f32_f32_km_nk_mn_instances(gemm_ptrs); ck::tensor_operation::device::device_gemm_instance:: add_device_gemm_dlops_f32_f32_f32_km_nk_mn_instances(gemm_ptrs); + + ck::tensor_operation::device::device_gemm_instance:: + add_device_gemm_xdl_c_shuffle_f32_f32_f32_km_nk_mn_instances(gemm_ptrs); } } } @@ -294,115 +267,87 @@ void profile_gemm_impl(int do_verification, is_same::value && is_same::value) { - if(is_xdl) + if(KBatch > 1) { - - if(KBatch > 1) - { - ck::tensor_operation::device::device_gemm_instance:: - add_device_gemm_xdl_splitk_f16_f16_f16_mk_kn_mn_instances(gemm_ptrs); - } - else - { - ck::tensor_operation::device::device_gemm_instance:: - add_device_gemm_xdl_f16_f16_f16_mk_kn_mn_instances(gemm_ptrs); - - ck::tensor_operation::device::device_gemm_instance:: - add_device_gemm_xdl_c_shuffle_f16_f16_f16_mk_kn_mn_instances(gemm_ptrs); - } + ck::tensor_operation::device::device_gemm_instance:: + add_device_gemm_xdl_splitk_f16_f16_f16_mk_kn_mn_instances(gemm_ptrs); } else { + ck::tensor_operation::device::device_gemm_instance:: + add_device_gemm_xdl_f16_f16_f16_mk_kn_mn_instances(gemm_ptrs); + ck::tensor_operation::device::device_gemm_instance:: 
add_device_gemm_dlops_f16_f16_f16_mk_kn_mn_instances(gemm_ptrs); + + ck::tensor_operation::device::device_gemm_instance:: + add_device_gemm_xdl_c_shuffle_f16_f16_f16_mk_kn_mn_instances(gemm_ptrs); } } else if constexpr(is_same::value && is_same::value && is_same::value) { - if(is_xdl) + if(KBatch > 1) { - - if(KBatch > 1) - { - ck::tensor_operation::device::device_gemm_instance:: - add_device_gemm_xdl_splitk_f16_f16_f16_mk_nk_mn_instances(gemm_ptrs); - } - else - { - ck::tensor_operation::device::device_gemm_instance:: - add_device_gemm_xdl_f16_f16_f16_mk_nk_mn_instances(gemm_ptrs); - - ck::tensor_operation::device::device_gemm_instance:: - add_device_gemm_xdl_c_shuffle_f16_f16_f16_mk_nk_mn_instances(gemm_ptrs); - - ck::tensor_operation::device::device_gemm_instance:: - add_device_gemm_xdl_c_shuffle_2_stage_f16_f16_f16_mk_nk_mn_instances( - gemm_ptrs); - } + ck::tensor_operation::device::device_gemm_instance:: + add_device_gemm_xdl_splitk_f16_f16_f16_mk_nk_mn_instances(gemm_ptrs); } else { + ck::tensor_operation::device::device_gemm_instance:: + add_device_gemm_xdl_f16_f16_f16_mk_nk_mn_instances(gemm_ptrs); ck::tensor_operation::device::device_gemm_instance:: add_device_gemm_dlops_f16_f16_f16_mk_nk_mn_instances(gemm_ptrs); + + ck::tensor_operation::device::device_gemm_instance:: + add_device_gemm_xdl_c_shuffle_f16_f16_f16_mk_nk_mn_instances(gemm_ptrs); + + ck::tensor_operation::device::device_gemm_instance:: + add_device_gemm_xdl_c_shuffle_2_stage_f16_f16_f16_mk_nk_mn_instances(gemm_ptrs); } } else if constexpr(is_same::value && is_same::value && is_same::value) { - if(is_xdl) + if(KBatch > 1) { - - if(KBatch > 1) - { - ck::tensor_operation::device::device_gemm_instance:: - add_device_gemm_xdl_splitk_f16_f16_f16_km_kn_mn_instances(gemm_ptrs); - } - else - { - ck::tensor_operation::device::device_gemm_instance:: - add_device_gemm_xdl_f16_f16_f16_km_kn_mn_instances(gemm_ptrs); - - ck::tensor_operation::device::device_gemm_instance:: - 
add_device_gemm_xdl_c_shuffle_f16_f16_f16_km_kn_mn_instances(gemm_ptrs); - } + ck::tensor_operation::device::device_gemm_instance:: + add_device_gemm_xdl_splitk_f16_f16_f16_km_kn_mn_instances(gemm_ptrs); } - else { + ck::tensor_operation::device::device_gemm_instance:: + add_device_gemm_xdl_f16_f16_f16_km_kn_mn_instances(gemm_ptrs); + ck::tensor_operation::device::device_gemm_instance:: add_device_gemm_dlops_f16_f16_f16_km_kn_mn_instances(gemm_ptrs); + + ck::tensor_operation::device::device_gemm_instance:: + add_device_gemm_xdl_c_shuffle_f16_f16_f16_km_kn_mn_instances(gemm_ptrs); } } else if constexpr(is_same::value && is_same::value && is_same::value) { - if(is_xdl) + if(KBatch > 1) { - - if(KBatch > 1) - { - ck::tensor_operation::device::device_gemm_instance:: - add_device_gemm_xdl_splitk_f16_f16_f16_km_nk_mn_instances(gemm_ptrs); - } - else - { - ck::tensor_operation::device::device_gemm_instance:: - add_device_gemm_xdl_f16_f16_f16_km_nk_mn_instances(gemm_ptrs); - - ck::tensor_operation::device::device_gemm_instance:: - add_device_gemm_xdl_c_shuffle_f16_f16_f16_km_nk_mn_instances(gemm_ptrs); - } + ck::tensor_operation::device::device_gemm_instance:: + add_device_gemm_xdl_splitk_f16_f16_f16_km_nk_mn_instances(gemm_ptrs); } else { + ck::tensor_operation::device::device_gemm_instance:: + add_device_gemm_xdl_f16_f16_f16_km_nk_mn_instances(gemm_ptrs); ck::tensor_operation::device::device_gemm_instance:: add_device_gemm_dlops_f16_f16_f16_km_nk_mn_instances(gemm_ptrs); + + ck::tensor_operation::device::device_gemm_instance:: + add_device_gemm_xdl_c_shuffle_f16_f16_f16_km_nk_mn_instances(gemm_ptrs); } } } @@ -446,62 +391,41 @@ void profile_gemm_impl(int do_verification, is_same::value && is_same::value) { - if(is_xdl) - { - ck::tensor_operation::device::device_gemm_instance:: - add_device_gemm_xdl_c_shuffle_int8_int8_int8_mk_kn_mn_instances(gemm_ptrs); - } - else - { - ck::tensor_operation::device::device_gemm_instance:: - 
add_device_gemm_dlops_int8_int8_int8_mk_kn_mn_instances(gemm_ptrs); - } + ck::tensor_operation::device::device_gemm_instance:: + add_device_gemm_xdl_c_shuffle_int8_int8_int8_mk_kn_mn_instances(gemm_ptrs); + + ck::tensor_operation::device::device_gemm_instance:: + add_device_gemm_dlops_int8_int8_int8_mk_kn_mn_instances(gemm_ptrs); } else if constexpr(is_same::value && is_same::value && is_same::value) { - if(is_xdl) - { - ck::tensor_operation::device::device_gemm_instance:: - add_device_gemm_xdl_c_shuffle_int8_int8_int8_mk_nk_mn_instances(gemm_ptrs); - } - else - { - ck::tensor_operation::device::device_gemm_instance:: - add_device_gemm_dlops_int8_int8_int8_mk_nk_mn_instances(gemm_ptrs); - } + ck::tensor_operation::device::device_gemm_instance:: + add_device_gemm_xdl_c_shuffle_int8_int8_int8_mk_nk_mn_instances(gemm_ptrs); + + ck::tensor_operation::device::device_gemm_instance:: + add_device_gemm_dlops_int8_int8_int8_mk_nk_mn_instances(gemm_ptrs); } else if constexpr(is_same::value && is_same::value && is_same::value) { - if(is_xdl) - { + ck::tensor_operation::device::device_gemm_instance:: + add_device_gemm_xdl_c_shuffle_int8_int8_int8_km_kn_mn_instances(gemm_ptrs); - ck::tensor_operation::device::device_gemm_instance:: - add_device_gemm_xdl_c_shuffle_int8_int8_int8_km_kn_mn_instances(gemm_ptrs); - } - else - { - ck::tensor_operation::device::device_gemm_instance:: - add_device_gemm_dlops_int8_int8_int8_km_kn_mn_instances(gemm_ptrs); - } + ck::tensor_operation::device::device_gemm_instance:: + add_device_gemm_dlops_int8_int8_int8_km_kn_mn_instances(gemm_ptrs); } else if constexpr(is_same::value && is_same::value && is_same::value) { - if(is_xdl) - { - ck::tensor_operation::device::device_gemm_instance:: - add_device_gemm_xdl_c_shuffle_int8_int8_int8_km_nk_mn_instances(gemm_ptrs); - } - else - { - ck::tensor_operation::device::device_gemm_instance:: - add_device_gemm_dlops_int8_int8_int8_km_nk_mn_instances(gemm_ptrs); - } + 
ck::tensor_operation::device::device_gemm_instance:: + add_device_gemm_xdl_c_shuffle_int8_int8_int8_km_nk_mn_instances(gemm_ptrs); + + ck::tensor_operation::device::device_gemm_instance:: + add_device_gemm_dlops_int8_int8_int8_km_nk_mn_instances(gemm_ptrs); } } From 4ba880eeef86be9e3f438b51cd9ea6bd4a62d8fc Mon Sep 17 00:00:00 2001 From: Jianfeng yan Date: Fri, 22 Apr 2022 05:43:58 +0000 Subject: [PATCH 22/46] make IsSupportedArg dependent on arch for other device_gemm --- .../gpu/device/device_gemm_xdl.hpp | 2 +- .../gpu/device/device_gemm_xdl_c_shuffle.hpp | 17 ++++++++++++----- .../gpu/device/device_gemm_xdl_cshuffle.hpp | 11 +++++++++-- .../gpu/device/device_gemm_xdl_splitk.hpp | 18 +++++++++++++----- script/clang-format-overwrite.sh | 2 +- 5 files changed, 36 insertions(+), 14 deletions(-) diff --git a/include/ck/tensor_operation/gpu/device/device_gemm_xdl.hpp b/include/ck/tensor_operation/gpu/device/device_gemm_xdl.hpp index b4b70161dc..bdf1f43b37 100644 --- a/include/ck/tensor_operation/gpu/device/device_gemm_xdl.hpp +++ b/include/ck/tensor_operation/gpu/device/device_gemm_xdl.hpp @@ -404,7 +404,7 @@ struct DeviceGemmXdl static bool IsSupportedArgument(const Argument& arg) { - if (ck::get_device_name() == "gfx1030") + if(ck::get_device_name() == "gfx1030") { return false; } diff --git a/include/ck/tensor_operation/gpu/device/device_gemm_xdl_c_shuffle.hpp b/include/ck/tensor_operation/gpu/device/device_gemm_xdl_c_shuffle.hpp index 155eb5225c..f033fbfcb6 100644 --- a/include/ck/tensor_operation/gpu/device/device_gemm_xdl_c_shuffle.hpp +++ b/include/ck/tensor_operation/gpu/device/device_gemm_xdl_c_shuffle.hpp @@ -369,11 +369,18 @@ struct DeviceGemmXdl_C_Shuffle static bool IsSupportedArgument(const Argument& arg) { - return GridwiseGemm::CheckValidity(arg.a_grid_desc_k0_m_k1_, - arg.b_grid_desc_k0_n_k1_, - arg.c_grid_desc_m_n_, - arg.M01_, - arg.N01_); + if(ck::get_device_name() == "gfx1030") + { + return false; + } + else + { + return 
GridwiseGemm::CheckValidity(arg.a_grid_desc_k0_m_k1_, + arg.b_grid_desc_k0_n_k1_, + arg.c_grid_desc_m_n_, + arg.M01_, + arg.N01_); + } } // polymorphic diff --git a/include/ck/tensor_operation/gpu/device/device_gemm_xdl_cshuffle.hpp b/include/ck/tensor_operation/gpu/device/device_gemm_xdl_cshuffle.hpp index 324b33ffb2..ca5b2c5e09 100644 --- a/include/ck/tensor_operation/gpu/device/device_gemm_xdl_cshuffle.hpp +++ b/include/ck/tensor_operation/gpu/device/device_gemm_xdl_cshuffle.hpp @@ -591,8 +591,15 @@ struct DeviceGemm_Xdl_CShuffle static bool IsSupportedArgument(const Argument& arg) { - return GridwiseGemm::CheckValidity( - arg.a_grid_desc_ak0_m_ak1_, arg.b_grid_desc_bk0_n_bk1_, arg.c_grid_desc_m_n_); + if(ck::get_device_name() == "gfx1030") + { + return false; + } + else + { + return GridwiseGemm::CheckValidity( + arg.a_grid_desc_ak0_m_ak1_, arg.b_grid_desc_bk0_n_bk1_, arg.c_grid_desc_m_n_); + } } // polymorphic diff --git a/include/ck/tensor_operation/gpu/device/device_gemm_xdl_splitk.hpp b/include/ck/tensor_operation/gpu/device/device_gemm_xdl_splitk.hpp index db6c884739..fd4ca0da4e 100644 --- a/include/ck/tensor_operation/gpu/device/device_gemm_xdl_splitk.hpp +++ b/include/ck/tensor_operation/gpu/device/device_gemm_xdl_splitk.hpp @@ -545,11 +545,19 @@ struct DeviceGemmXdlSplitK static bool IsSupportedArgument(const Argument& arg) { - return GridwiseGemm::CheckValidity(arg.a_grid_desc_kbatch_k0_m_k1_, - arg.b_grid_desc_kbatch_k0_n_k1_, - arg.c_grid_desc_m_n_, - arg.M01_, - arg.N01_); + if(ck::get_device_name() == "gfx1030") + { + return false; + } + else + { + + return GridwiseGemm::CheckValidity(arg.a_grid_desc_kbatch_k0_m_k1_, + arg.b_grid_desc_kbatch_k0_n_k1_, + arg.c_grid_desc_m_n_, + arg.M01_, + arg.N01_); + } } // polymorphic diff --git a/script/clang-format-overwrite.sh b/script/clang-format-overwrite.sh index 009d2da48c..fab19f1b8e 100644 --- a/script/clang-format-overwrite.sh +++ b/script/clang-format-overwrite.sh @@ -1,2 +1,2 @@ -find . 
-name deps -prune -o -name build -prune -o -iname '*.h' -o -iname '*.hpp' -o -iname '*.cpp' -o -iname '*.h.in' -o -iname '*.hpp.in' -o -iname '*.cpp.in' -o -iname '*.cl' -o -iname '*.cuh' -o -iname '*.cu' | xargs -n 1 -P 16 -I{} -t sh -c 'clang-format -i -style=file {}' +find . -name deps -prune -o -name build -prune -o -iname '*.h' -o -iname '*.hpp' -o -iname '*.cpp' -o -iname '*.h.in' -o -iname '*.hpp.in' -o -iname '*.cpp.in' -o -iname '*.cl' -o -iname '*.cuh' -o -iname '*.cu' | xargs -n 1 -P 16 -I{} -t sh -c 'clang-format-10 -i -style=file {}' From 5fd0997e0d783f9ea0fed3b4a8c5f7ce25d6cb54 Mon Sep 17 00:00:00 2001 From: Jianfeng yan Date: Fri, 22 Apr 2022 05:46:30 +0000 Subject: [PATCH 23/46] minor changes --- .../ck/tensor_operation/gpu/device/device_gemm_xdl_c_shuffle.hpp | 1 + .../ck/tensor_operation/gpu/device/device_gemm_xdl_cshuffle.hpp | 1 + .../ck/tensor_operation/gpu/device/device_gemm_xdl_splitk.hpp | 1 + 3 files changed, 3 insertions(+) diff --git a/include/ck/tensor_operation/gpu/device/device_gemm_xdl_c_shuffle.hpp b/include/ck/tensor_operation/gpu/device/device_gemm_xdl_c_shuffle.hpp index f033fbfcb6..c4c62e89f6 100644 --- a/include/ck/tensor_operation/gpu/device/device_gemm_xdl_c_shuffle.hpp +++ b/include/ck/tensor_operation/gpu/device/device_gemm_xdl_c_shuffle.hpp @@ -10,6 +10,7 @@ #include "tensor_descriptor.hpp" #include "tensor_descriptor_helper.hpp" #include "gridwise_gemm_xdlops_v3r1.hpp" +#include "device_prop.hpp" namespace ck { namespace tensor_operation { diff --git a/include/ck/tensor_operation/gpu/device/device_gemm_xdl_cshuffle.hpp b/include/ck/tensor_operation/gpu/device/device_gemm_xdl_cshuffle.hpp index ca5b2c5e09..30af71bbc0 100644 --- a/include/ck/tensor_operation/gpu/device/device_gemm_xdl_cshuffle.hpp +++ b/include/ck/tensor_operation/gpu/device/device_gemm_xdl_cshuffle.hpp @@ -9,6 +9,7 @@ #include "tensor_descriptor_helper.hpp" #include "gridwise_gemm_xdl_cshuffle_v1.hpp" #include 
"tensor_operation/gpu/device/gemm_specialization.hpp" +#include "device_prop.hpp" namespace ck { namespace tensor_operation { diff --git a/include/ck/tensor_operation/gpu/device/device_gemm_xdl_splitk.hpp b/include/ck/tensor_operation/gpu/device/device_gemm_xdl_splitk.hpp index fd4ca0da4e..4cf019ed35 100644 --- a/include/ck/tensor_operation/gpu/device/device_gemm_xdl_splitk.hpp +++ b/include/ck/tensor_operation/gpu/device/device_gemm_xdl_splitk.hpp @@ -12,6 +12,7 @@ #include "tensor_descriptor_helper.hpp" #include "gridwise_gemm_xdlops_v2r4.hpp" #include "gemm_specialization.hpp" +#include "device_prop.hpp" #ifndef CK_RUN_KERNEL_AND_TIME #define CK_RUN_KERNEL_AND_TIME 1 From 78ade2d74aaf247df89199ba68fb69b29b98d089 Mon Sep 17 00:00:00 2001 From: Jianfeng yan Date: Fri, 22 Apr 2022 15:10:30 +0000 Subject: [PATCH 24/46] minor changes --- profiler/include/profile_gemm_impl.hpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/profiler/include/profile_gemm_impl.hpp b/profiler/include/profile_gemm_impl.hpp index d143828c54..f0ab2f43ca 100644 --- a/profiler/include/profile_gemm_impl.hpp +++ b/profiler/include/profile_gemm_impl.hpp @@ -573,7 +573,7 @@ void profile_gemm_impl(int do_verification, } else { - std::cout << "does not support this GEMM problem" << std::endl; + std::cout << gemm_ptr->GetTypeString() << " does not support this GEMM problem" << std::endl; } } From 1d58d7ea554784dba6c804407f11639fcbbfccbf Mon Sep 17 00:00:00 2001 From: Jianfeng yan Date: Fri, 22 Apr 2022 18:20:09 +0000 Subject: [PATCH 25/46] fix a bug in f_generate_tensor_value --- .../include/ck/library/utility/check_err.hpp | 2 +- profiler/include/profile_gemm_impl.hpp | 6 ++++- test/gemm/gemm_util.hpp | 25 +++++++++++++------ 3 files changed, 24 insertions(+), 9 deletions(-) diff --git a/library/include/ck/library/utility/check_err.hpp b/library/include/ck/library/utility/check_err.hpp index 280ac83883..5ca1605f29 100644 --- a/library/include/ck/library/utility/check_err.hpp +++ 
b/library/include/ck/library/utility/check_err.hpp @@ -24,7 +24,7 @@ check_err(const std::vector& out, const std::vector& ref, const std::string& msg = "Error: Incorrect results!", double rtol = 1e-5, - double atol = 1e-8) + double atol = 3e-6) { if(out.size() != ref.size()) { diff --git a/profiler/include/profile_gemm_impl.hpp b/profiler/include/profile_gemm_impl.hpp index f0ab2f43ca..e0bedc6573 100644 --- a/profiler/include/profile_gemm_impl.hpp +++ b/profiler/include/profile_gemm_impl.hpp @@ -140,7 +140,11 @@ void profile_gemm_impl(int do_verification, std::size_t num_thread = 1; switch(init_method) { - case 0: break; + // case 0: break; + case 0: + a_m_k.GenerateTensorValue(GeneratorTensor_1{}, num_thread); + b_k_n.GenerateTensorValue(GeneratorTensor_1{}, num_thread); + break; case 1: a_m_k.GenerateTensorValue(GeneratorTensor_2{-5, 5}, num_thread); b_k_n.GenerateTensorValue(GeneratorTensor_2{-5, 5}, num_thread); diff --git a/test/gemm/gemm_util.hpp b/test/gemm/gemm_util.hpp index 08c8edfb94..631e2875f2 100644 --- a/test/gemm/gemm_util.hpp +++ b/test/gemm/gemm_util.hpp @@ -139,7 +139,7 @@ struct TestGemm Tensor c_m_n_device_result( f_host_tensor_descriptor(params.M, params.N, params.StrideC, CLayout{})); - auto f_generate_tensor_value = [](auto desc, auto type) { + auto f_generate_tensor_value = [](auto& desc, auto type) { using dataType = decltype(type); if(std::is_same::value) @@ -166,12 +166,19 @@ struct TestGemm // Arrange ck::gemm_util::GemmParams params; - params.M = 1024; - params.N = 1024; - params.K = 1024; - params.StrideA = 1024; - params.StrideB = 1024; - params.StrideC = 1024; + // params.M = 1024; + // params.N = 1024; + // params.K = 1024; + // params.StrideA = 1024; + // params.StrideB = 1024; + // params.StrideC = 1024; + + params.M = 256; + params.N = 256; + params.K = 256; + params.StrideA = 256; + params.StrideB = 256; + params.StrideC = 256; auto host_tensors = PrepareGemmTensor(params); @@ -216,6 +223,10 @@ struct TestGemm std::cout << (res 
? "SUCCESS" : "FAILURE") << std::endl; } + LogRangeAsType(std::cout << gemmPtr->GetTypeString() + " a_host: \n", a.mData, ", ") << std::endl; + LogRangeAsType(std::cout << gemmPtr->GetTypeString() + " b_host: \n", b.mData, ", ") << std::endl; + LogRangeAsType(std::cout << gemmPtr->GetTypeString() + " c_host: \n", c_host.mData, ", ") << std::endl; + LogRangeAsType(std::cout << gemmPtr->GetTypeString() + " c_device: \n", c_device.mData, ", ") << std::endl; return res; } }; From f06ba361917fabe6c36ce30560525496723de977 Mon Sep 17 00:00:00 2001 From: Jianfeng yan Date: Fri, 22 Apr 2022 18:25:17 +0000 Subject: [PATCH 26/46] add 64x64x64 for gemm_dlops_int8 --- .../device_gemm_dlops_int8_int8_int8_mk_kn_mn_instance.cpp | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/library/src/tensor_operation_instance/gpu/gemm/device_gemm_dlops_int8_int8_int8_mk_kn_mn_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm/device_gemm_dlops_int8_int8_int8_mk_kn_mn_instance.cpp index 2ef12d33b1..65a7a1ab79 100644 --- a/library/src/tensor_operation_instance/gpu/gemm/device_gemm_dlops_int8_int8_int8_mk_kn_mn_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/gemm/device_gemm_dlops_int8_int8_int8_mk_kn_mn_instance.cpp @@ -61,7 +61,8 @@ using device_gemm_dlops_int8_int8_int8_mk_kn_mn_instances = // DeviceGemmDlops< int8_t, int8_t, int8_t, int32_t, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 128, 128, 128, 8, 4, 4, 8, 1, S<8, 2>, S<4, 2>, S<8, 1, 1, 4>, S<1, 1, 128, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 4>, S<1, 2, 0, 3>, S<1, 1, 1, 4>, S<8, 1, 1, 4>, S<1, 1, 128, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<4, 1, 1, 4>, S<0, 3, 1, 2>, S<1, 1, 1, 4>, S<0, 1, 2, 3, 4, 5>, 5, 4>, // DeviceGemmDlops< int8_t, int8_t, int8_t, int32_t, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 64, 128, 128, 8, 4, 8, 8, 1, S<4, 2>, S<4, 2>, S<4, 1, 4, 4>, S<2, 1, 32, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 4>, S<1, 2, 0, 3>, S<1, 
1, 1, 4>, S<4, 1, 4, 4>, S<2, 1, 32, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<4, 1, 1, 4>, S<0, 3, 1, 2>, S<1, 1, 1, 4>, S<0, 1, 2, 3, 4, 5>, 5, 4>, DeviceGemmDlops< int8_t, int8_t, int8_t, int32_t, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 64, 128, 64, 8, 4, 8, 4, 1, S<4, 2>, S<4, 2>, S<4, 1, 4, 4>, S<2, 1, 32, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 4>, S<1, 2, 0, 3>, S<1, 1, 1, 4>, S<4, 1, 2, 4>, S<2, 1, 32, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<4, 1, 1, 4>, S<0, 3, 1, 2>, S<1, 1, 1, 4>, S<0, 1, 2, 3, 4, 5>, 5, 4>, - DeviceGemmDlops< int8_t, int8_t, int8_t, int32_t, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 64, 64, 128, 8, 4, 4, 8, 1, S<4, 2>, S<4, 2>, S<4, 1, 2, 4>, S<2, 1, 32, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 4>, S<1, 2, 0, 3>, S<1, 1, 1, 4>, S<4, 1, 4, 4>, S<2, 1, 32, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<4, 1, 1, 4>, S<0, 3, 1, 2>, S<1, 1, 1, 4>, S<0, 1, 2, 3, 4, 5>, 5, 4> + DeviceGemmDlops< int8_t, int8_t, int8_t, int32_t, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 64, 64, 128, 8, 4, 4, 8, 1, S<4, 2>, S<4, 2>, S<4, 1, 2, 4>, S<2, 1, 32, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 4>, S<1, 2, 0, 3>, S<1, 1, 1, 4>, S<4, 1, 4, 4>, S<2, 1, 32, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<4, 1, 1, 4>, S<0, 3, 1, 2>, S<1, 1, 1, 4>, S<0, 1, 2, 3, 4, 5>, 5, 4>, + DeviceGemmDlops< int8_t, int8_t, int8_t, int32_t, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 64, 64, 64, 8, 4, 4, 4, 1, S<4, 2>, S<4, 2>, S<4, 1, 2, 4>, S<2, 1, 32, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 4>, S<1, 2, 0, 3>, S<1, 1, 1, 4>, S<4, 1, 2, 4>, S<2, 1, 32, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<4, 1, 1, 4>, S<0, 3, 1, 2>, S<1, 1, 1, 4>, S<0, 1, 2, 3, 4, 5>, 5, 4> // clang-format on >; From 0c3f0babaacc697fc04c908bdbfe3e782a173bcf Mon Sep 17 00:00:00 2001 From: Jianfeng yan Date: Fri, 22 Apr 2022 18:32:54 +0000 Subject: [PATCH 27/46] add 64x64x64 for gemm_dlops_int8 --- 
.../device_gemm_dlops_int8_int8_int8_km_kn_mn_instance.cpp | 3 ++- .../device_gemm_dlops_int8_int8_int8_km_nk_mn_instance.cpp | 3 ++- .../device_gemm_dlops_int8_int8_int8_mk_nk_mn_instance.cpp | 3 ++- 3 files changed, 6 insertions(+), 3 deletions(-) diff --git a/library/src/tensor_operation_instance/gpu/gemm/device_gemm_dlops_int8_int8_int8_km_kn_mn_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm/device_gemm_dlops_int8_int8_int8_km_kn_mn_instance.cpp index 8d055682d9..f7c2f63d56 100644 --- a/library/src/tensor_operation_instance/gpu/gemm/device_gemm_dlops_int8_int8_int8_km_kn_mn_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/gemm/device_gemm_dlops_int8_int8_int8_km_kn_mn_instance.cpp @@ -61,7 +61,8 @@ using device_gemm_dlops_int8_int8_int8_km_kn_mn_instances = // DeviceGemmDlops< int8_t, int8_t, int8_t, int32_t, Col, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 128, 128, 128, 8, 4, 4, 8, 1, S<8, 2>, S<4, 2>, S<8, 1, 1, 4>, S<1, 1, 128, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<4, 1, 1, 4>, S<0, 3, 1, 2>, S<1, 1, 1, 4>, S<8, 1, 1, 4>, S<1, 1, 128, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<4, 1, 1, 4>, S<0, 3, 1, 2>, S<1, 1, 1, 4>, S<0, 1, 2, 3, 4, 5>, 5, 4>, // DeviceGemmDlops< int8_t, int8_t, int8_t, int32_t, Col, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 64, 128, 128, 8, 4, 8, 8, 1, S<4, 2>, S<4, 2>, S<4, 1, 4, 4>, S<2, 1, 32, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<4, 1, 1, 4>, S<0, 3, 1, 2>, S<1, 1, 1, 4>, S<4, 1, 4, 4>, S<2, 1, 32, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<4, 1, 1, 4>, S<0, 3, 1, 2>, S<1, 1, 1, 4>, S<0, 1, 2, 3, 4, 5>, 5, 4>, DeviceGemmDlops< int8_t, int8_t, int8_t, int32_t, Col, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 64, 128, 64, 8, 4, 8, 4, 1, S<4, 2>, S<4, 2>, S<4, 1, 4, 4>, S<2, 1, 32, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<4, 1, 1, 4>, S<0, 3, 1, 2>, S<1, 1, 1, 4>, S<4, 1, 2, 4>, S<2, 1, 32, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<4, 1, 1, 4>, S<0, 3, 1, 2>, S<1, 1, 1, 4>, S<0, 1, 
2, 3, 4, 5>, 5, 4>, - DeviceGemmDlops< int8_t, int8_t, int8_t, int32_t, Col, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 64, 64, 128, 8, 4, 4, 8, 1, S<4, 2>, S<4, 2>, S<4, 1, 2, 4>, S<2, 1, 32, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<4, 1, 1, 4>, S<0, 3, 1, 2>, S<1, 1, 1, 4>, S<4, 1, 4, 4>, S<2, 1, 32, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<4, 1, 1, 4>, S<0, 3, 1, 2>, S<1, 1, 1, 4>, S<0, 1, 2, 3, 4, 5>, 5, 4> + DeviceGemmDlops< int8_t, int8_t, int8_t, int32_t, Col, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 64, 64, 128, 8, 4, 4, 8, 1, S<4, 2>, S<4, 2>, S<4, 1, 2, 4>, S<2, 1, 32, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<4, 1, 1, 4>, S<0, 3, 1, 2>, S<1, 1, 1, 4>, S<4, 1, 4, 4>, S<2, 1, 32, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<4, 1, 1, 4>, S<0, 3, 1, 2>, S<1, 1, 1, 4>, S<0, 1, 2, 3, 4, 5>, 5, 4>, + DeviceGemmDlops< int8_t, int8_t, int8_t, int32_t, Col, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 64, 64, 64, 8, 4, 4, 4, 1, S<4, 2>, S<4, 2>, S<4, 1, 2, 4>, S<2, 1, 32, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<4, 1, 1, 4>, S<0, 3, 1, 2>, S<1, 1, 1, 4>, S<4, 1, 2, 4>, S<2, 1, 32, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<4, 1, 1, 4>, S<0, 3, 1, 2>, S<1, 1, 1, 4>, S<0, 1, 2, 3, 4, 5>, 5, 4> // clang-format on >; diff --git a/library/src/tensor_operation_instance/gpu/gemm/device_gemm_dlops_int8_int8_int8_km_nk_mn_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm/device_gemm_dlops_int8_int8_int8_km_nk_mn_instance.cpp index a036afaf59..4ed8191ae3 100644 --- a/library/src/tensor_operation_instance/gpu/gemm/device_gemm_dlops_int8_int8_int8_km_nk_mn_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/gemm/device_gemm_dlops_int8_int8_int8_km_nk_mn_instance.cpp @@ -61,7 +61,8 @@ using device_gemm_dlops_int8_int8_int8_km_nk_mn_instances = // DeviceGemmDlops< int8_t, int8_t, int8_t, int32_t, Col, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 128, 128, 128, 8, 4, 4, 8, 1, S<8, 2>, S<4, 2>, S<8, 1, 1, 4>, S<1, 1, 128, 
1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<4, 1, 1, 4>, S<0, 3, 1, 2>, S<1, 1, 1, 4>, S<8, 1, 1, 4>, S<1, 1, 128, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 4>, S<1, 2, 0, 3>, S<1, 1, 1, 4>, S<0, 1, 2, 3, 4, 5>, 5, 4>, // DeviceGemmDlops< int8_t, int8_t, int8_t, int32_t, Col, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 64, 128, 128, 8, 4, 8, 8, 1, S<4, 2>, S<4, 2>, S<4, 1, 4, 4>, S<2, 1, 32, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<4, 1, 1, 4>, S<0, 3, 1, 2>, S<1, 1, 1, 4>, S<4, 1, 4, 4>, S<2, 1, 32, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 4>, S<1, 2, 0, 3>, S<1, 1, 1, 4>, S<0, 1, 2, 3, 4, 5>, 5, 4>, DeviceGemmDlops< int8_t, int8_t, int8_t, int32_t, Col, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 64, 128, 64, 8, 4, 8, 4, 1, S<4, 2>, S<4, 2>, S<4, 1, 4, 4>, S<2, 1, 32, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<4, 1, 1, 4>, S<0, 3, 1, 2>, S<1, 1, 1, 4>, S<4, 1, 2, 4>, S<2, 1, 32, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 4>, S<1, 2, 0, 3>, S<1, 1, 1, 4>, S<0, 1, 2, 3, 4, 5>, 5, 4>, - DeviceGemmDlops< int8_t, int8_t, int8_t, int32_t, Col, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 64, 64, 128, 8, 4, 4, 8, 1, S<4, 2>, S<4, 2>, S<4, 1, 2, 4>, S<2, 1, 32, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<4, 1, 1, 4>, S<0, 3, 1, 2>, S<1, 1, 1, 4>, S<4, 1, 4, 4>, S<2, 1, 32, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 4>, S<1, 2, 0, 3>, S<1, 1, 1, 4>, S<0, 1, 2, 3, 4, 5>, 5, 4> + DeviceGemmDlops< int8_t, int8_t, int8_t, int32_t, Col, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 64, 64, 128, 8, 4, 4, 8, 1, S<4, 2>, S<4, 2>, S<4, 1, 2, 4>, S<2, 1, 32, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<4, 1, 1, 4>, S<0, 3, 1, 2>, S<1, 1, 1, 4>, S<4, 1, 4, 4>, S<2, 1, 32, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 4>, S<1, 2, 0, 3>, S<1, 1, 1, 4>, S<0, 1, 2, 3, 4, 5>, 5, 4>, + DeviceGemmDlops< int8_t, int8_t, int8_t, int32_t, Col, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 64, 64, 64, 8, 4, 4, 4, 1, S<4, 2>, S<4, 2>, 
S<4, 1, 2, 4>, S<2, 1, 32, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<4, 1, 1, 4>, S<0, 3, 1, 2>, S<1, 1, 1, 4>, S<4, 1, 2, 4>, S<2, 1, 32, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 4>, S<1, 2, 0, 3>, S<1, 1, 1, 4>, S<0, 1, 2, 3, 4, 5>, 5, 4> // clang-format on >; diff --git a/library/src/tensor_operation_instance/gpu/gemm/device_gemm_dlops_int8_int8_int8_mk_nk_mn_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm/device_gemm_dlops_int8_int8_int8_mk_nk_mn_instance.cpp index 8a2a114536..55b2e78462 100644 --- a/library/src/tensor_operation_instance/gpu/gemm/device_gemm_dlops_int8_int8_int8_mk_nk_mn_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/gemm/device_gemm_dlops_int8_int8_int8_mk_nk_mn_instance.cpp @@ -61,7 +61,8 @@ using device_gemm_dlops_int8_int8_int8_mk_nk_mn_instances = // DeviceGemmDlops< int8_t, int8_t, int8_t, int32_t, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 128, 128, 128, 8, 4, 4, 8, 1, S<8, 2>, S<4, 2>, S<8, 1, 1, 4>, S<1, 1, 128, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 4>, S<1, 2, 0, 3>, S<1, 1, 1, 4>, S<8, 1, 1, 4>, S<1, 1, 128, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 4>, S<1, 2, 0, 3>, S<1, 1, 1, 4>, S<0, 1, 2, 3, 4, 5>, 5, 4>, // DeviceGemmDlops< int8_t, int8_t, int8_t, int32_t, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 64, 128, 128, 8, 4, 8, 8, 1, S<4, 2>, S<4, 2>, S<4, 1, 4, 4>, S<2, 1, 32, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 4>, S<1, 2, 0, 3>, S<1, 1, 1, 4>, S<4, 1, 4, 4>, S<2, 1, 32, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 4>, S<1, 2, 0, 3>, S<1, 1, 1, 4>, S<0, 1, 2, 3, 4, 5>, 5, 4>, DeviceGemmDlops< int8_t, int8_t, int8_t, int32_t, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 64, 128, 64, 8, 4, 8, 4, 1, S<4, 2>, S<4, 2>, S<4, 1, 4, 4>, S<2, 1, 32, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 4>, S<1, 2, 0, 3>, S<1, 1, 1, 4>, S<4, 1, 2, 4>, S<2, 1, 32, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 4>, S<1, 2, 0, 3>, S<1, 1, 1, 
4>, S<0, 1, 2, 3, 4, 5>, 5, 4>, - DeviceGemmDlops< int8_t, int8_t, int8_t, int32_t, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 64, 64, 128, 8, 4, 4, 8, 1, S<4, 2>, S<4, 2>, S<4, 1, 2, 4>, S<2, 1, 32, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 4>, S<1, 2, 0, 3>, S<1, 1, 1, 4>, S<4, 1, 4, 4>, S<2, 1, 32, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 4>, S<1, 2, 0, 3>, S<1, 1, 1, 4>, S<0, 1, 2, 3, 4, 5>, 5, 4> + DeviceGemmDlops< int8_t, int8_t, int8_t, int32_t, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 64, 64, 128, 8, 4, 4, 8, 1, S<4, 2>, S<4, 2>, S<4, 1, 2, 4>, S<2, 1, 32, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 4>, S<1, 2, 0, 3>, S<1, 1, 1, 4>, S<4, 1, 4, 4>, S<2, 1, 32, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 4>, S<1, 2, 0, 3>, S<1, 1, 1, 4>, S<0, 1, 2, 3, 4, 5>, 5, 4>, + DeviceGemmDlops< int8_t, int8_t, int8_t, int32_t, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 64, 64, 64, 8, 4, 4, 4, 1, S<4, 2>, S<4, 2>, S<4, 1, 2, 4>, S<2, 1, 32, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 4>, S<1, 2, 0, 3>, S<1, 1, 1, 4>, S<4, 1, 2, 4>, S<2, 1, 32, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 4>, S<1, 2, 0, 3>, S<1, 1, 1, 4>, S<0, 1, 2, 3, 4, 5>, 5, 4> // clang-format on >; From 578eec736827b3dd4316eea616312dd6557e72f1 Mon Sep 17 00:00:00 2001 From: Jianfeng yan Date: Mon, 25 Apr 2022 22:42:34 +0000 Subject: [PATCH 28/46] comment out 3 layouts in gemm_dlops_int8; add 32x32x32 for gemm_dlops_int8; init A values to 1 --- library/include/ck/library/utility/check_err.hpp | 4 ++-- ...ice_gemm_dlops_int8_int8_int8_km_kn_mn_instance.cpp | 7 ++++--- test/gemm/gemm_util.hpp | 10 ++++------ test/gemm_dlops/gemm_dlops_int8.cpp | 2 ++ 4 files changed, 12 insertions(+), 11 deletions(-) diff --git a/library/include/ck/library/utility/check_err.hpp b/library/include/ck/library/utility/check_err.hpp index 5ca1605f29..7cd6cc34c9 100644 --- a/library/include/ck/library/utility/check_err.hpp +++ 
b/library/include/ck/library/utility/check_err.hpp @@ -173,8 +173,8 @@ check_err(const std::vector& out, { if(out[i] != ref[i]) { - std::cout << "out[" << i << "] != ref[" << i << "]: " << out[i] << " != " << ref[i] - << std::endl + std::cout << "out[" << i << "] != ref[" << i << "]: " << static_cast(out[i]) + << " != " << static_cast(ref[i]) << std::endl << msg << std::endl; return false; } diff --git a/library/src/tensor_operation_instance/gpu/gemm/device_gemm_dlops_int8_int8_int8_km_kn_mn_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm/device_gemm_dlops_int8_int8_int8_km_kn_mn_instance.cpp index f7c2f63d56..1dafade20f 100644 --- a/library/src/tensor_operation_instance/gpu/gemm/device_gemm_dlops_int8_int8_int8_km_kn_mn_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/gemm/device_gemm_dlops_int8_int8_int8_km_kn_mn_instance.cpp @@ -38,6 +38,7 @@ using device_gemm_dlops_int8_int8_int8_km_kn_mn_instances = // DeviceGemmDlops< int8_t, int8_t, int8_t, int32_t, Col, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 64, 128, 128, 8, 1, 8, 8, 1, S<4, 2>, S<4, 2>, S<4, 1, 4, 1>, S<2, 1, 32, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<4, 1, 1, 1>, S<0, 3, 1, 2>, S<1, 1, 1, 1>, S<4, 1, 4, 1>, S<2, 1, 32, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<4, 1, 1, 1>, S<0, 3, 1, 2>, S<1, 1, 1, 1>, S<0, 1, 2, 3, 4, 5>, 5, 4>, // DeviceGemmDlops< int8_t, int8_t, int8_t, int32_t, Col, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 64, 128, 64, 8, 1, 8, 4, 1, S<4, 2>, S<4, 2>, S<4, 1, 4, 1>, S<2, 1, 32, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<4, 1, 1, 1>, S<0, 3, 1, 2>, S<1, 1, 1, 1>, S<4, 1, 2, 1>, S<2, 1, 32, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<4, 1, 1, 1>, S<0, 3, 1, 2>, S<1, 1, 1, 1>, S<0, 1, 2, 3, 4, 5>, 5, 4>, // DeviceGemmDlops< int8_t, int8_t, int8_t, int32_t, Col, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 64, 64, 128, 8, 1, 4, 8, 1, S<4, 2>, S<4, 2>, S<4, 1, 2, 1>, S<2, 1, 32, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<4, 1, 1, 1>, 
S<0, 3, 1, 2>, S<1, 1, 1, 1>, S<4, 1, 4, 1>, S<2, 1, 32, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<4, 1, 1, 1>, S<0, 3, 1, 2>, S<1, 1, 1, 1>, S<0, 1, 2, 3, 4, 5>, 5, 4>, + DeviceGemmDlops< int8_t, int8_t, int8_t, int32_t, Col, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 64, 32, 32, 8, 1, 2, 2, 1, S<4, 2>, S<4, 2>, S<4, 1, 1, 1>, S<2, 1, 32, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<4, 1, 1, 1>, S<0, 3, 1, 2>, S<1, 1, 1, 1>, S<4, 1, 1, 1>, S<2, 1, 32, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<4, 1, 1, 1>, S<0, 3, 1, 2>, S<1, 1, 1, 1>, S<0, 1, 2, 3, 4, 5>, 5, 2> // /* * K1 = 2 @@ -60,9 +61,9 @@ using device_gemm_dlops_int8_int8_int8_km_kn_mn_instances = // DeviceGemmDlops< int8_t, int8_t, int8_t, int32_t, Col, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 128, 128, 128, 8, 4, 8, 4, 1, S<4, 2>, S<8, 2>, S<8, 1, 1, 4>, S<1, 1, 128, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<4, 1, 1, 4>, S<0, 3, 1, 2>, S<1, 1, 1, 4>, S<8, 1, 1, 4>, S<1, 1, 128, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<4, 1, 1, 4>, S<0, 3, 1, 2>, S<1, 1, 1, 4>, S<0, 1, 2, 3, 4, 5>, 5, 4>, // DeviceGemmDlops< int8_t, int8_t, int8_t, int32_t, Col, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 128, 128, 128, 8, 4, 4, 8, 1, S<8, 2>, S<4, 2>, S<8, 1, 1, 4>, S<1, 1, 128, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<4, 1, 1, 4>, S<0, 3, 1, 2>, S<1, 1, 1, 4>, S<8, 1, 1, 4>, S<1, 1, 128, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<4, 1, 1, 4>, S<0, 3, 1, 2>, S<1, 1, 1, 4>, S<0, 1, 2, 3, 4, 5>, 5, 4>, // DeviceGemmDlops< int8_t, int8_t, int8_t, int32_t, Col, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 64, 128, 128, 8, 4, 8, 8, 1, S<4, 2>, S<4, 2>, S<4, 1, 4, 4>, S<2, 1, 32, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<4, 1, 1, 4>, S<0, 3, 1, 2>, S<1, 1, 1, 4>, S<4, 1, 4, 4>, S<2, 1, 32, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<4, 1, 1, 4>, S<0, 3, 1, 2>, S<1, 1, 1, 4>, S<0, 1, 2, 3, 4, 5>, 5, 4>, - DeviceGemmDlops< int8_t, int8_t, int8_t, int32_t, Col, Row, Row, PassThrough, PassThrough, PassThrough, 
GemmDefault, 64, 128, 64, 8, 4, 8, 4, 1, S<4, 2>, S<4, 2>, S<4, 1, 4, 4>, S<2, 1, 32, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<4, 1, 1, 4>, S<0, 3, 1, 2>, S<1, 1, 1, 4>, S<4, 1, 2, 4>, S<2, 1, 32, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<4, 1, 1, 4>, S<0, 3, 1, 2>, S<1, 1, 1, 4>, S<0, 1, 2, 3, 4, 5>, 5, 4>, - DeviceGemmDlops< int8_t, int8_t, int8_t, int32_t, Col, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 64, 64, 128, 8, 4, 4, 8, 1, S<4, 2>, S<4, 2>, S<4, 1, 2, 4>, S<2, 1, 32, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<4, 1, 1, 4>, S<0, 3, 1, 2>, S<1, 1, 1, 4>, S<4, 1, 4, 4>, S<2, 1, 32, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<4, 1, 1, 4>, S<0, 3, 1, 2>, S<1, 1, 1, 4>, S<0, 1, 2, 3, 4, 5>, 5, 4>, - DeviceGemmDlops< int8_t, int8_t, int8_t, int32_t, Col, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 64, 64, 64, 8, 4, 4, 4, 1, S<4, 2>, S<4, 2>, S<4, 1, 2, 4>, S<2, 1, 32, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<4, 1, 1, 4>, S<0, 3, 1, 2>, S<1, 1, 1, 4>, S<4, 1, 2, 4>, S<2, 1, 32, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<4, 1, 1, 4>, S<0, 3, 1, 2>, S<1, 1, 1, 4>, S<0, 1, 2, 3, 4, 5>, 5, 4> + // DeviceGemmDlops< int8_t, int8_t, int8_t, int32_t, Col, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 64, 128, 64, 8, 4, 8, 4, 1, S<4, 2>, S<4, 2>, S<4, 1, 4, 4>, S<2, 1, 32, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<4, 1, 1, 4>, S<0, 3, 1, 2>, S<1, 1, 1, 4>, S<4, 1, 2, 4>, S<2, 1, 32, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<4, 1, 1, 4>, S<0, 3, 1, 2>, S<1, 1, 1, 4>, S<0, 1, 2, 3, 4, 5>, 5, 4>, + // DeviceGemmDlops< int8_t, int8_t, int8_t, int32_t, Col, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 64, 64, 128, 8, 4, 4, 8, 1, S<4, 2>, S<4, 2>, S<4, 1, 2, 4>, S<2, 1, 32, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<4, 1, 1, 4>, S<0, 3, 1, 2>, S<1, 1, 1, 4>, S<4, 1, 4, 4>, S<2, 1, 32, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<4, 1, 1, 4>, S<0, 3, 1, 2>, S<1, 1, 1, 4>, S<0, 1, 2, 3, 4, 5>, 5, 4>, + // DeviceGemmDlops< int8_t, int8_t, int8_t, int32_t, Col, Row, Row, 
PassThrough, PassThrough, PassThrough, GemmDefault, 64, 64, 64, 8, 4, 4, 4, 1, S<4, 2>, S<4, 2>, S<4, 1, 2, 4>, S<2, 1, 32, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<4, 1, 1, 4>, S<0, 3, 1, 2>, S<1, 1, 1, 4>, S<4, 1, 2, 4>, S<2, 1, 32, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<4, 1, 1, 4>, S<0, 3, 1, 2>, S<1, 1, 1, 4>, S<0, 1, 2, 3, 4, 5>, 5, 4> // clang-format on >; diff --git a/test/gemm/gemm_util.hpp b/test/gemm/gemm_util.hpp index 631e2875f2..79e437d4da 100644 --- a/test/gemm/gemm_util.hpp +++ b/test/gemm/gemm_util.hpp @@ -152,8 +152,10 @@ struct TestGemm } }; - f_generate_tensor_value(a_m_k, ADataType{}); - f_generate_tensor_value(b_k_n, BDataType{}); + // f_generate_tensor_value(a_m_k, ADataType{}); + // f_generate_tensor_value(b_k_n, BDataType{}); + a_m_k.GenerateTensorValue(GeneratorTensor_1{}); + b_k_n.GenerateTensorValue(GeneratorTensor_2{-5, 5}); return std::make_tuple(a_m_k, b_k_n, c_m_n_host_result, c_m_n_device_result); } @@ -223,10 +225,6 @@ struct TestGemm std::cout << (res ? "SUCCESS" : "FAILURE") << std::endl; } - LogRangeAsType(std::cout << gemmPtr->GetTypeString() + " a_host: \n", a.mData, ", ") << std::endl; - LogRangeAsType(std::cout << gemmPtr->GetTypeString() + " b_host: \n", b.mData, ", ") << std::endl; - LogRangeAsType(std::cout << gemmPtr->GetTypeString() + " c_host: \n", c_host.mData, ", ") << std::endl; - LogRangeAsType(std::cout << gemmPtr->GetTypeString() + " c_device: \n", c_device.mData, ", ") << std::endl; return res; } }; diff --git a/test/gemm_dlops/gemm_dlops_int8.cpp b/test/gemm_dlops/gemm_dlops_int8.cpp index 7103468f3e..bff29111c3 100644 --- a/test/gemm_dlops/gemm_dlops_int8.cpp +++ b/test/gemm_dlops/gemm_dlops_int8.cpp @@ -69,6 +69,7 @@ int main() PassThrough>{}(gemmPtr); } +#if 0 gemmPtrs.clear(); ck::tensor_operation::device::device_gemm_instance:: add_device_gemm_dlops_int8_int8_int8_km_nk_mn_instances(gemmPtrs); @@ -123,6 +124,7 @@ int main() PassThrough>{}(gemmPtr); } +#endif std::cout << "TestGemm ..... " << (res ? 
"SUCCESS" : "FAILURE") << std::endl; return res ? 0 : 1; } From aa0acfa2a87e68c8a136ed6a0e621d3a2a21bb4c Mon Sep 17 00:00:00 2001 From: Chao Liu Date: Sat, 30 Apr 2022 03:38:29 +0000 Subject: [PATCH 29/46] fix --- .../gpu/element/element_wise_operation.hpp | 10 ++--- include/ck/utility/dynamic_buffer.hpp | 2 +- ...icit_gemm_v4r4r4_xdlops_nhwc_kyxc_nhwk.hpp | 2 +- .../conv_add_fwd_driver_offline_nchwc.cpp | 6 +-- .../conv_bwd_driver_offline.cpp | 6 +-- .../conv_fwd_driver_offline.cpp | 6 +-- .../conv_fwd_driver_offline_nchwc.cpp | 32 +++++++-------- .../conv_maxpool_fwd_driver_offline_nchwc.cpp | 24 +++++------ .../conv_wrw_driver_offline.cpp | 8 ++-- ...mm_dlops_f32_f32_f32_km_kn_mn_instance.cpp | 40 ++++--------------- ...mm_dlops_f32_f32_f32_km_nk_mn_instance.cpp | 31 ++++---------- ...mm_dlops_f32_f32_f32_mk_kn_mn_instance.cpp | 31 ++++---------- ...mm_dlops_f32_f32_f32_mk_nk_mn_instance.cpp | 31 ++++---------- profiler/include/profile_gemm_impl.hpp | 5 ++- test/gemm/gemm_util.hpp | 17 +------- 15 files changed, 83 insertions(+), 168 deletions(-) diff --git a/include/ck/tensor_operation/gpu/element/element_wise_operation.hpp b/include/ck/tensor_operation/gpu/element/element_wise_operation.hpp index 7fd0b7a36f..5b3606e859 100644 --- a/include/ck/tensor_operation/gpu/element/element_wise_operation.hpp +++ b/include/ck/tensor_operation/gpu/element/element_wise_operation.hpp @@ -156,9 +156,8 @@ struct RequantReluRequant float gemm_requant = scaleGemm_ * static_cast(x); float relu = gemm_requant > 0 ? gemm_requant : 0; float relu_requant = scaleRelu_ * relu; - y = static_cast(relu_requant > 127 ? 127 - : relu_requant < -128 ? -128 - : relu_requant); + y = static_cast(relu_requant > 127 ? 127 + : relu_requant < -128 ? -128 : relu_requant); } // for reference_gemm @@ -167,9 +166,8 @@ struct RequantReluRequant float gemm_requant = scaleGemm_ * x; float relu = gemm_requant > 0 ? 
gemm_requant : 0; float relu_requant = scaleRelu_ * relu; - y = static_cast(relu_requant > 127 ? 127 - : relu_requant < -128 ? -128 - : relu_requant); + y = static_cast(relu_requant > 127 ? 127 + : relu_requant < -128 ? -128 : relu_requant); } float scaleGemm_; diff --git a/include/ck/utility/dynamic_buffer.hpp b/include/ck/utility/dynamic_buffer.hpp index 1f52855293..c00982dfff 100644 --- a/include/ck/utility/dynamic_buffer.hpp +++ b/include/ck/utility/dynamic_buffer.hpp @@ -151,7 +151,7 @@ struct DynamicBuffer #if CK_USE_AMD_BUFFER_STORE bool constexpr use_amd_buffer_addressing = true; #else - bool constexpr use_amd_buffer_addressing = false; + bool constexpr use_amd_buffer_addressing = false; #endif #if CK_WORKAROUND_SWDEV_XXXXXX_INT8_DS_WRITE_ISSUE diff --git a/library/include/ck/library/obselete_driver_offline/device_convolution_forward_implicit_gemm_v4r4r4_xdlops_nhwc_kyxc_nhwk.hpp b/library/include/ck/library/obselete_driver_offline/device_convolution_forward_implicit_gemm_v4r4r4_xdlops_nhwc_kyxc_nhwk.hpp index de3489f924..18e712fb47 100644 --- a/library/include/ck/library/obselete_driver_offline/device_convolution_forward_implicit_gemm_v4r4r4_xdlops_nhwc_kyxc_nhwk.hpp +++ b/library/include/ck/library/obselete_driver_offline/device_convolution_forward_implicit_gemm_v4r4r4_xdlops_nhwc_kyxc_nhwk.hpp @@ -424,7 +424,7 @@ void device_convolution_forward_implicit_gemm_v4r4r4_xdlops_nhwc_kyxc_nhwk( constexpr auto in_gemmk0_gemmm_gemmk1_grid_move_slice_window_step_hacks = Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 2, 0, 0>{}; #else - const auto in_gemmk0_gemmmraw_gemmk1_grid_desc = descs[I0]; + const auto in_gemmk0_gemmmraw_gemmk1_grid_desc = descs[I0]; const auto GemmK0 = in_gemmk0_gemmmraw_gemmk1_grid_desc.GetLength(I0); const auto GemmMRaw = in_gemmk0_gemmmraw_gemmk1_grid_desc.GetLength(I1); diff --git a/library/src/obselete_driver_offline/conv_add_fwd_driver_offline_nchwc.cpp b/library/src/obselete_driver_offline/conv_add_fwd_driver_offline_nchwc.cpp index 
82d92fa64d..a7541f03de 100644 --- a/library/src/obselete_driver_offline/conv_add_fwd_driver_offline_nchwc.cpp +++ b/library/src/obselete_driver_offline/conv_add_fwd_driver_offline_nchwc.cpp @@ -248,9 +248,9 @@ int main(int argc, char* argv[]) using acc_data_t = float; using out_data_t = float; #elif 1 - using in_data_t = half_t; - using acc_data_t = float; - using out_data_t = half_t; + using in_data_t = half_t; + using acc_data_t = float; + using out_data_t = half_t; #elif 1 using in_data_t = int8_t; using acc_data_t = int32_t; diff --git a/library/src/obselete_driver_offline/conv_bwd_driver_offline.cpp b/library/src/obselete_driver_offline/conv_bwd_driver_offline.cpp index c130cd609c..c4dcb7c085 100644 --- a/library/src/obselete_driver_offline/conv_bwd_driver_offline.cpp +++ b/library/src/obselete_driver_offline/conv_bwd_driver_offline.cpp @@ -263,9 +263,9 @@ int main(int argc, char* argv[]) using acc_data_t = float; using out_data_t = float; #elif 1 - using in_data_t = half_t; - using acc_data_t = float; - using out_data_t = half_t; + using in_data_t = half_t; + using acc_data_t = float; + using out_data_t = half_t; #endif std::vector in_lengths_host(4), wei_lengths_host(4), out_lengths_host(4); diff --git a/library/src/obselete_driver_offline/conv_fwd_driver_offline.cpp b/library/src/obselete_driver_offline/conv_fwd_driver_offline.cpp index 94c5fd9ca9..ab8beec87b 100644 --- a/library/src/obselete_driver_offline/conv_fwd_driver_offline.cpp +++ b/library/src/obselete_driver_offline/conv_fwd_driver_offline.cpp @@ -257,9 +257,9 @@ int main(int argc, char* argv[]) using acc_data_t = float; using out_data_t = float; #elif 1 - using in_data_t = half_t; - using acc_data_t = float; - using out_data_t = half_t; + using in_data_t = half_t; + using acc_data_t = float; + using out_data_t = half_t; #elif 0 using in_data_t = bhalf_t; using acc_data_t = float; diff --git a/library/src/obselete_driver_offline/conv_fwd_driver_offline_nchwc.cpp 
b/library/src/obselete_driver_offline/conv_fwd_driver_offline_nchwc.cpp index ff7e5c7a15..6fb8b4c2aa 100644 --- a/library/src/obselete_driver_offline/conv_fwd_driver_offline_nchwc.cpp +++ b/library/src/obselete_driver_offline/conv_fwd_driver_offline_nchwc.cpp @@ -165,15 +165,15 @@ int main(int argc, char* argv[]) constexpr auto K0 = Number<1>{}; constexpr auto K1 = Number<4>{}; #elif 1 - constexpr auto N = Number<1>{}; - constexpr auto Hi = Number<1080>{}; - constexpr auto Wi = Number<1920>{}; - constexpr auto Y = Number<3>{}; - constexpr auto X = Number<3>{}; - constexpr auto C0 = Number<2>{}; - constexpr auto C1 = Number<8>{}; - constexpr auto K0 = Number<2>{}; - constexpr auto K1 = Number<8>{}; + constexpr auto N = Number<1>{}; + constexpr auto Hi = Number<1080>{}; + constexpr auto Wi = Number<1920>{}; + constexpr auto Y = Number<3>{}; + constexpr auto X = Number<3>{}; + constexpr auto C0 = Number<2>{}; + constexpr auto C1 = Number<8>{}; + constexpr auto K0 = Number<2>{}; + constexpr auto K1 = Number<8>{}; #elif 0 constexpr auto N = Number<1>{}; constexpr auto Hi = Number<1080>{}; @@ -212,10 +212,10 @@ int main(int argc, char* argv[]) constexpr auto conv_dilation_w = I1; #if 1 - constexpr auto in_left_pad_h = I1; - constexpr auto in_left_pad_w = I1; - constexpr auto in_right_pad_h = I1; - constexpr auto in_right_pad_w = I1; + constexpr auto in_left_pad_h = I1; + constexpr auto in_left_pad_w = I1; + constexpr auto in_right_pad_h = I1; + constexpr auto in_right_pad_w = I1; #else constexpr auto in_left_pad_h = I0; constexpr auto in_left_pad_w = I0; @@ -235,9 +235,9 @@ int main(int argc, char* argv[]) using acc_data_t = float; using out_data_t = float; #elif 1 - using in_data_t = half_t; - using acc_data_t = float; - using out_data_t = half_t; + using in_data_t = half_t; + using acc_data_t = float; + using out_data_t = half_t; #elif 1 using in_data_t = int8_t; using acc_data_t = int32_t; diff --git 
a/library/src/obselete_driver_offline/conv_maxpool_fwd_driver_offline_nchwc.cpp b/library/src/obselete_driver_offline/conv_maxpool_fwd_driver_offline_nchwc.cpp index 388656e747..fb7e8e975b 100644 --- a/library/src/obselete_driver_offline/conv_maxpool_fwd_driver_offline_nchwc.cpp +++ b/library/src/obselete_driver_offline/conv_maxpool_fwd_driver_offline_nchwc.cpp @@ -181,15 +181,15 @@ int main(int argc, char* argv[]) constexpr ck::ActivTypeEnum activ_type = ActivTypeEnum::LeakyRelu; #if 1 - constexpr auto N = Number<1>{}; - constexpr auto Hi = Number<1080>{}; - constexpr auto Wi = Number<1920>{}; - constexpr auto Y = Number<3>{}; - constexpr auto X = Number<3>{}; - constexpr auto C0 = Number<2>{}; - constexpr auto C1 = Number<8>{}; - constexpr auto K0 = Number<2>{}; - constexpr auto K1 = Number<8>{}; + constexpr auto N = Number<1>{}; + constexpr auto Hi = Number<1080>{}; + constexpr auto Wi = Number<1920>{}; + constexpr auto Y = Number<3>{}; + constexpr auto X = Number<3>{}; + constexpr auto C0 = Number<2>{}; + constexpr auto C1 = Number<8>{}; + constexpr auto K0 = Number<2>{}; + constexpr auto K1 = Number<8>{}; #elif 0 constexpr auto N = Number<1>{}; constexpr auto Hi = Number<1080>{}; @@ -247,9 +247,9 @@ int main(int argc, char* argv[]) using acc_data_t = float; using out_data_t = float; #elif 1 - using in_data_t = half_t; - using acc_data_t = float; - using out_data_t = half_t; + using in_data_t = half_t; + using acc_data_t = float; + using out_data_t = half_t; #elif 1 using in_data_t = int8_t; using acc_data_t = int32_t; diff --git a/library/src/obselete_driver_offline/conv_wrw_driver_offline.cpp b/library/src/obselete_driver_offline/conv_wrw_driver_offline.cpp index 23b1039fec..1ac974202c 100644 --- a/library/src/obselete_driver_offline/conv_wrw_driver_offline.cpp +++ b/library/src/obselete_driver_offline/conv_wrw_driver_offline.cpp @@ -229,10 +229,10 @@ int main(int argc, char* argv[]) using acc_data_t = float; using out_data_t = float; #elif 1 - using 
in_data_t = half_t; - using out_data_t = half_t; - using acc_data_t = float; - using wei_data_t = float; + using in_data_t = half_t; + using out_data_t = half_t; + using acc_data_t = float; + using wei_data_t = float; #elif 1 using in_data_t = int8_t; using out_data_t = int8_t; diff --git a/library/src/tensor_operation_instance/gpu/gemm/device_gemm_dlops_f32_f32_f32_km_kn_mn_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm/device_gemm_dlops_f32_f32_f32_km_kn_mn_instance.cpp index 365db66cc4..4f73d679dd 100644 --- a/library/src/tensor_operation_instance/gpu/gemm/device_gemm_dlops_f32_f32_f32_km_kn_mn_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/gemm/device_gemm_dlops_f32_f32_f32_km_kn_mn_instance.cpp @@ -29,38 +29,14 @@ using device_gemm_dlops_f32_f32_f32_km_kn_mn_instances = std::tuple< // ##########| Type| Type| Type| Type| | | | Elementwise| Elementwise| Elementwise| Spacialization| Size| Block| Block| Block| | ThreadM111| ThreadN111| Thread| ClusterM110Xs| ClusterN110Xs| ThreadSliceLengths| ThreadClusterLengths| ThreadCluster| SrcAccess| SrcVectorTensor| SrcVectorTensor| DstVectorTensor| ThreadSliceLengths| ThreadClusterLengths| ThreadCluster| SrcAccess| SrcVectorTensor| SrcVectorTensor| DstVectorTensor| SrcDstAccess| SrcDstVectorDim| DstScalarPerVector| // ##########| | | | | | | | Operation| Operation| Operation| | | | | | | | | | | | K0_M0_M1_K1| K0_M0_M1_K1| ArrangeOrder| Order| Lengths_K0_M0_M1_K1| ContiguousDimOrder| Lengths_K0_M0_M1_K1| K0_N0_N1_K1| K0_N0_N1_K1| ArrangeOrder| Order| Lengths_K0_N0_N1_K1| ContiguousDimOrder| Lengths_K0_N0_N1_K1| Order| | | // ##########| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | - /* - * K1 = 1 - */ - DeviceGemmDlops< F32, F32, F32, F32, Col, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 256, 128, 8, 1, 8, 4, 1, S<8, 2>, S<8, 2>, S<4, 1, 2, 1>, S<2, 1, 128, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<4, 1, 1, 1>, S<0, 3, 1, 2>, S<1, 1, 1, 1>, S<4, 
1, 1, 1>, S<2, 1, 128, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<4, 1, 1, 1>, S<0, 3, 1, 2>, S<1, 1, 1, 1>, S<0, 1, 2, 3, 4, 5>, 5, 4>, - DeviceGemmDlops< F32, F32, F32, F32, Col, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 128, 256, 8, 1, 4, 8, 1, S<8, 2>, S<8, 2>, S<4, 1, 1, 1>, S<2, 1, 128, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<4, 1, 1, 1>, S<0, 3, 1, 2>, S<1, 1, 1, 1>, S<4, 1, 2, 1>, S<2, 1, 128, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<4, 1, 1, 1>, S<0, 3, 1, 2>, S<1, 1, 1, 1>, S<0, 1, 2, 3, 4, 5>, 5, 4>, - DeviceGemmDlops< F32, F32, F32, F32, Col, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 128, 128, 8, 1, 4, 4, 1, S<8, 2>, S<8, 2>, S<4, 1, 1, 1>, S<2, 1, 128, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<4, 1, 1, 1>, S<0, 3, 1, 2>, S<1, 1, 1, 1>, S<4, 1, 1, 1>, S<2, 1, 128, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<4, 1, 1, 1>, S<0, 3, 1, 2>, S<1, 1, 1, 1>, S<0, 1, 2, 3, 4, 5>, 5, 4>, - DeviceGemmDlops< F32, F32, F32, F32, Col, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 128, 128, 128, 8, 1, 8, 4, 1, S<4, 2>, S<8, 2>, S<8, 1, 1, 1>, S<1, 1, 128, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<4, 1, 1, 1>, S<0, 3, 1, 2>, S<1, 1, 1, 1>, S<8, 1, 1, 1>, S<1, 1, 128, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<4, 1, 1, 1>, S<0, 3, 1, 2>, S<1, 1, 1, 1>, S<0, 1, 2, 3, 4, 5>, 5, 4>, - DeviceGemmDlops< F32, F32, F32, F32, Col, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 128, 128, 128, 8, 1, 4, 8, 1, S<8, 2>, S<4, 2>, S<8, 1, 1, 1>, S<1, 1, 128, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<4, 1, 1, 1>, S<0, 3, 1, 2>, S<1, 1, 1, 1>, S<8, 1, 1, 1>, S<1, 1, 128, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<4, 1, 1, 1>, S<0, 3, 1, 2>, S<1, 1, 1, 1>, S<0, 1, 2, 3, 4, 5>, 5, 4>, - DeviceGemmDlops< F32, F32, F32, F32, Col, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 64, 128, 128, 8, 1, 8, 8, 1, S<4, 2>, S<4, 2>, S<4, 1, 4, 1>, S<2, 1, 32, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<4, 1, 1, 1>, S<0, 3, 1, 2>, S<1, 1, 1, 1>, S<4, 1, 4, 1>, S<2, 1, 32, 
1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<4, 1, 1, 1>, S<0, 3, 1, 2>, S<1, 1, 1, 1>, S<0, 1, 2, 3, 4, 5>, 5, 4>, - DeviceGemmDlops< F32, F32, F32, F32, Col, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 64, 128, 64, 8, 1, 8, 4, 1, S<4, 2>, S<4, 2>, S<4, 1, 4, 1>, S<2, 1, 32, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<4, 1, 1, 1>, S<0, 3, 1, 2>, S<1, 1, 1, 1>, S<4, 1, 2, 1>, S<2, 1, 32, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<4, 1, 1, 1>, S<0, 3, 1, 2>, S<1, 1, 1, 1>, S<0, 1, 2, 3, 4, 5>, 5, 4>, - DeviceGemmDlops< F32, F32, F32, F32, Col, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 64, 64, 128, 8, 1, 4, 8, 1, S<4, 2>, S<4, 2>, S<4, 1, 2, 1>, S<2, 1, 32, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<4, 1, 1, 1>, S<0, 3, 1, 2>, S<1, 1, 1, 1>, S<4, 1, 4, 1>, S<2, 1, 32, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<4, 1, 1, 1>, S<0, 3, 1, 2>, S<1, 1, 1, 1>, S<0, 1, 2, 3, 4, 5>, 5, 4>, - - /* - * K1 = 2 - */ - DeviceGemmDlops< F32, F32, F32, F32, Col, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 256, 128, 8, 2, 8, 4, 1, S<8, 2>, S<8, 2>, S<4, 1, 2, 2>, S<2, 1, 128, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<4, 1, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 1, 2>, S<4, 1, 1, 2>, S<2, 1, 128, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<4, 1, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 1, 2>, S<0, 1, 2, 3, 4, 5>, 5, 4>, - DeviceGemmDlops< F32, F32, F32, F32, Col, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 128, 256, 8, 2, 4, 8, 1, S<8, 2>, S<8, 2>, S<4, 1, 1, 2>, S<2, 1, 128, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<4, 1, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 1, 2>, S<4, 1, 2, 2>, S<2, 1, 128, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<4, 1, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 1, 2>, S<0, 1, 2, 3, 4, 5>, 5, 4>, - DeviceGemmDlops< F32, F32, F32, F32, Col, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 128, 128, 8, 2, 4, 4, 1, S<8, 2>, S<8, 2>, S<4, 1, 1, 2>, S<2, 1, 128, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<4, 1, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 1, 2>, S<4, 1, 1, 2>, S<2, 1, 128, 1>, 
S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<4, 1, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 1, 2>, S<0, 1, 2, 3, 4, 5>, 5, 4>, - DeviceGemmDlops< F32, F32, F32, F32, Col, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 128, 128, 128, 8, 2, 8, 4, 1, S<4, 2>, S<8, 2>, S<8, 1, 1, 2>, S<1, 1, 128, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<4, 1, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 1, 2>, S<8, 1, 1, 2>, S<1, 1, 128, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<4, 1, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 1, 2>, S<0, 1, 2, 3, 4, 5>, 5, 4>, - DeviceGemmDlops< F32, F32, F32, F32, Col, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 128, 128, 128, 8, 2, 4, 8, 1, S<8, 2>, S<4, 2>, S<8, 1, 1, 2>, S<1, 1, 128, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<4, 1, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 1, 2>, S<8, 1, 1, 2>, S<1, 1, 128, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<4, 1, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 1, 2>, S<0, 1, 2, 3, 4, 5>, 5, 4>, - DeviceGemmDlops< F32, F32, F32, F32, Col, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 64, 128, 128, 8, 2, 8, 8, 1, S<4, 2>, S<4, 2>, S<4, 1, 4, 2>, S<2, 1, 32, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<4, 1, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 1, 2>, S<4, 1, 4, 2>, S<2, 1, 32, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<4, 1, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 1, 2>, S<0, 1, 2, 3, 4, 5>, 5, 4>, - DeviceGemmDlops< F32, F32, F32, F32, Col, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 64, 128, 64, 8, 2, 8, 4, 1, S<4, 2>, S<4, 2>, S<4, 1, 4, 2>, S<2, 1, 32, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<4, 1, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 1, 2>, S<4, 1, 2, 2>, S<2, 1, 32, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<4, 1, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 1, 2>, S<0, 1, 2, 3, 4, 5>, 5, 4>, - DeviceGemmDlops< F32, F32, F32, F32, Col, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 64, 64, 128, 8, 2, 4, 8, 1, S<4, 2>, S<4, 2>, S<4, 1, 2, 2>, S<2, 1, 32, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<4, 1, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 1, 2>, S<4, 1, 4, 2>, S<2, 1, 32, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, 
S<4, 1, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 1, 2>, S<0, 1, 2, 3, 4, 5>, 5, 4> - - // repeat the above configurartion, but changing K1 to 4, NOT working for fp32 - // DeviceGemmDlops< F32, F32, F32, F32, Col, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 128, 128, 8, 4, 4, 4, 1, S<8, 2>, S<8, 2>, S<4, 1, 1, 4>, S<2, 1, 128, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<4, 1, 1, 4>, S<0, 3, 1, 2>, S<1, 1, 1, 4>, S<4, 1, 1, 4>, S<2, 1, 128, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<4, 1, 1, 4>, S<0, 3, 1, 2>, S<1, 1, 1, 4>, S<0, 1, 2, 3, 4, 5>, 5, 4> - // DeviceGemmDlops< F32, F32, F32, F32, Col, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 128, 128, 128, 8, 4, 8, 4, 1, S<4, 2>, S<8, 2>, S<8, 1, 1, 4>, S<1, 1, 128, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<4, 1, 1, 4>, S<0, 3, 1, 2>, S<1, 1, 1, 4>, S<8, 1, 1, 4>, S<1, 1, 128, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<4, 1, 1, 4>, S<0, 3, 1, 2>, S<1, 1, 1, 4>, S<0, 1, 2, 3, 4, 5>, 5, 4>, - // DeviceGemmDlops< F32, F32, F32, F32, Col, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 128, 128, 128, 8, 4, 4, 8, 1, S<8, 2>, S<4, 2>, S<8, 1, 1, 4>, S<1, 1, 128, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<4, 1, 1, 4>, S<0, 3, 1, 2>, S<1, 1, 1, 4>, S<8, 1, 1, 4>, S<1, 1, 128, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<4, 1, 1, 4>, S<0, 3, 1, 2>, S<1, 1, 1, 2>, S<0, 1, 2, 3, 4, 5>, 5, 4>, - // DeviceGemmDlops< F32, F32, F32, F32, Col, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 64, 128, 128, 8, 4, 8, 8, 1, S<4, 2>, S<4, 2>, S<4, 1, 4, 4>, S<2, 1, 32, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<4, 1, 1, 4>, S<0, 3, 1, 2>, S<1, 1, 1, 4>, S<4, 1, 4, 2>, S<2, 1, 32, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<4, 1, 1, 4>, S<0, 3, 1, 2>, S<1, 1, 1, 2>, S<0, 1, 2, 3, 4, 5>, 5, 4>, - // DeviceGemmDlops< F32, F32, F32, F32, Col, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 64, 128, 64, 8, 4, 8, 4, 1, S<4, 2>, S<4, 2>, S<4, 1, 4, 4>, S<2, 1, 32, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<4, 1, 1, 4>, S<0, 3, 1, 2>, S<1, 
1, 1, 4>, S<4, 1, 2, 2>, S<2, 1, 32, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<4, 1, 1, 4>, S<0, 3, 1, 2>, S<1, 1, 1, 2>, S<0, 1, 2, 3, 4, 5>, 5, 4>, - // DeviceGemmDlops< F32, F32, F32, F32, Col, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 64, 64, 128, 8, 4, 4, 8, 1, S<4, 2>, S<4, 2>, S<4, 1, 2, 4>, S<2, 1, 32, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<4, 1, 1, 4>, S<0, 3, 1, 2>, S<1, 1, 1, 4>, S<4, 1, 4, 2>, S<2, 1, 32, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<4, 1, 1, 4>, S<0, 3, 1, 2>, S<1, 1, 1, 2>, S<0, 1, 2, 3, 4, 5>, 5, 4> - + DeviceGemmDlops< F32, F32, F32, F32, Col, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 256, 128, 8, 1, 8, 4, 1, S<8, 2>, S<8, 2>, S<2, 1, 4, 1>, S<4, 1, 64, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 1, 1>, S<1, 1, 4, 1>, S<8, 1, 32, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 1, 1>, S<0, 1, 2, 3, 4, 5>, 5, 4> +// DeviceGemmDlops< F32, F32, F32, F32, Col, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 128, 256, 8, 1, 4, 8, 1, S<8, 2>, S<8, 2>, S<4, 1, 1, 1>, S<2, 1, 128, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<4, 1, 1, 1>, S<0, 3, 1, 2>, S<1, 1, 1, 1>, S<4, 1, 2, 1>, S<2, 1, 128, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<4, 1, 1, 1>, S<0, 3, 1, 2>, S<1, 1, 1, 1>, S<0, 1, 2, 3, 4, 5>, 5, 4>, +// DeviceGemmDlops< F32, F32, F32, F32, Col, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 128, 128, 8, 1, 4, 4, 1, S<8, 2>, S<8, 2>, S<4, 1, 1, 1>, S<2, 1, 128, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<4, 1, 1, 1>, S<0, 3, 1, 2>, S<1, 1, 1, 1>, S<4, 1, 1, 1>, S<2, 1, 128, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<4, 1, 1, 1>, S<0, 3, 1, 2>, S<1, 1, 1, 1>, S<0, 1, 2, 3, 4, 5>, 5, 4>, +// DeviceGemmDlops< F32, F32, F32, F32, Col, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 128, 128, 128, 8, 1, 8, 4, 1, S<4, 2>, S<8, 2>, S<8, 1, 1, 1>, S<1, 1, 128, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<4, 1, 1, 1>, S<0, 3, 1, 2>, S<1, 1, 1, 1>, S<8, 1, 
1, 1>, S<1, 1, 128, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<4, 1, 1, 1>, S<0, 3, 1, 2>, S<1, 1, 1, 1>, S<0, 1, 2, 3, 4, 5>, 5, 4>, +// DeviceGemmDlops< F32, F32, F32, F32, Col, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 128, 128, 128, 8, 1, 4, 8, 1, S<8, 2>, S<4, 2>, S<8, 1, 1, 1>, S<1, 1, 128, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<4, 1, 1, 1>, S<0, 3, 1, 2>, S<1, 1, 1, 1>, S<8, 1, 1, 1>, S<1, 1, 128, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<4, 1, 1, 1>, S<0, 3, 1, 2>, S<1, 1, 1, 1>, S<0, 1, 2, 3, 4, 5>, 5, 4>, +// DeviceGemmDlops< F32, F32, F32, F32, Col, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 64, 128, 128, 8, 1, 8, 8, 1, S<4, 2>, S<4, 2>, S<4, 1, 4, 1>, S<2, 1, 32, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<4, 1, 1, 1>, S<0, 3, 1, 2>, S<1, 1, 1, 1>, S<4, 1, 4, 1>, S<2, 1, 32, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<4, 1, 1, 1>, S<0, 3, 1, 2>, S<1, 1, 1, 1>, S<0, 1, 2, 3, 4, 5>, 5, 4>, +// DeviceGemmDlops< F32, F32, F32, F32, Col, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 64, 128, 64, 8, 1, 8, 4, 1, S<4, 2>, S<4, 2>, S<4, 1, 4, 1>, S<2, 1, 32, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<4, 1, 1, 1>, S<0, 3, 1, 2>, S<1, 1, 1, 1>, S<4, 1, 2, 1>, S<2, 1, 32, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<4, 1, 1, 1>, S<0, 3, 1, 2>, S<1, 1, 1, 1>, S<0, 1, 2, 3, 4, 5>, 5, 4>, +// DeviceGemmDlops< F32, F32, F32, F32, Col, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 64, 64, 128, 8, 1, 4, 8, 1, S<4, 2>, S<4, 2>, S<4, 1, 2, 1>, S<2, 1, 32, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<4, 1, 1, 1>, S<0, 3, 1, 2>, S<1, 1, 1, 1>, S<4, 1, 4, 1>, S<2, 1, 32, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<4, 1, 1, 1>, S<0, 3, 1, 2>, S<1, 1, 1, 1>, S<0, 1, 2, 3, 4, 5>, 5, 4>, // clang-format on >; diff --git a/library/src/tensor_operation_instance/gpu/gemm/device_gemm_dlops_f32_f32_f32_km_nk_mn_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm/device_gemm_dlops_f32_f32_f32_km_nk_mn_instance.cpp index 919d9d0d13..07441f1d17 100644 --- 
a/library/src/tensor_operation_instance/gpu/gemm/device_gemm_dlops_f32_f32_f32_km_nk_mn_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/gemm/device_gemm_dlops_f32_f32_f32_km_nk_mn_instance.cpp @@ -29,29 +29,14 @@ using device_gemm_dlops_f32_f32_f32_km_nk_mn_instances = std::tuple< // ##########| Type| Type| Type| Type| | | | Elementwise| Elementwise| Elementwise| Spacialization| Size| Block| Block| Block| | ThreadM111| ThreadN111| Thread| ClusterM110Xs| ClusterN110Xs| ThreadSliceLengths| ThreadClusterLengths| ThreadCluster| SrcAccess| SrcVectorTensor| SrcVectorTensor| DstVectorTensor| ThreadSliceLengths| ThreadClusterLengths| ThreadCluster| SrcAccess| SrcVectorTensor| SrcVectorTensor| DstVectorTensor| SrcDstAccess| SrcDstVectorDim| DstScalarPerVector| // ##########| | | | | | | | Operation| Operation| Operation| | | | | | | | | | | | K0_M0_M1_K1| K0_M0_M1_K1| ArrangeOrder| Order| Lengths_K0_M0_M1_K1| ContiguousDimOrder| Lengths_K0_M0_M1_K1| K0_N0_N1_K1| K0_N0_N1_K1| ArrangeOrder| Order| Lengths_K0_N0_N1_K1| ContiguousDimOrder| Lengths_K0_N0_N1_K1| Order| | | // ##########| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | - /* - * K1 = 1 - */ - DeviceGemmDlops< F32, F32, F32, F32, Col, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 256, 128, 8, 1, 8, 4, 1, S<8, 2>, S<8, 2>, S<4, 1, 2, 1>, S<2, 1, 128, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<4, 1, 1, 1>, S<0, 3, 1, 2>, S<1, 1, 1, 1>, S<4, 1, 1, 1>, S<2, 1, 128, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 1>, S<1, 2, 0, 3>, S<1, 1, 1, 1>, S<0, 1, 2, 3, 4, 5>, 5, 4>, - DeviceGemmDlops< F32, F32, F32, F32, Col, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 128, 256, 8, 1, 4, 8, 1, S<8, 2>, S<8, 2>, S<4, 1, 1, 1>, S<2, 1, 128, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<4, 1, 1, 1>, S<0, 3, 1, 2>, S<1, 1, 1, 1>, S<4, 1, 2, 1>, S<2, 1, 128, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 1>, S<1, 2, 0, 3>, S<1, 1, 1, 1>, S<0, 1, 2, 3, 4, 5>, 5, 4>, - 
DeviceGemmDlops< F32, F32, F32, F32, Col, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 128, 128, 8, 1, 4, 4, 1, S<8, 2>, S<8, 2>, S<4, 1, 1, 1>, S<2, 1, 128, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<4, 1, 1, 1>, S<0, 3, 1, 2>, S<1, 1, 1, 1>, S<4, 1, 1, 1>, S<2, 1, 128, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 1>, S<1, 2, 0, 3>, S<1, 1, 1, 1>, S<0, 1, 2, 3, 4, 5>, 5, 4>, - DeviceGemmDlops< F32, F32, F32, F32, Col, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 128, 128, 128, 8, 1, 8, 4, 1, S<4, 2>, S<8, 2>, S<8, 1, 1, 1>, S<1, 1, 128, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<4, 1, 1, 1>, S<0, 3, 1, 2>, S<1, 1, 1, 1>, S<8, 1, 1, 1>, S<1, 1, 128, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 1>, S<1, 2, 0, 3>, S<1, 1, 1, 1>, S<0, 1, 2, 3, 4, 5>, 5, 4>, - DeviceGemmDlops< F32, F32, F32, F32, Col, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 128, 128, 128, 8, 1, 4, 8, 1, S<8, 2>, S<4, 2>, S<8, 1, 1, 1>, S<1, 1, 128, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<4, 1, 1, 1>, S<0, 3, 1, 2>, S<1, 1, 1, 1>, S<8, 1, 1, 1>, S<1, 1, 128, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 1>, S<1, 2, 0, 3>, S<1, 1, 1, 1>, S<0, 1, 2, 3, 4, 5>, 5, 4>, - DeviceGemmDlops< F32, F32, F32, F32, Col, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 64, 128, 128, 8, 1, 8, 8, 1, S<4, 2>, S<4, 2>, S<4, 1, 4, 1>, S<2, 1, 32, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<4, 1, 1, 1>, S<0, 3, 1, 2>, S<1, 1, 1, 1>, S<4, 1, 4, 1>, S<2, 1, 32, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 1>, S<1, 2, 0, 3>, S<1, 1, 1, 1>, S<0, 1, 2, 3, 4, 5>, 5, 4>, - DeviceGemmDlops< F32, F32, F32, F32, Col, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 64, 128, 64, 8, 1, 8, 4, 1, S<4, 2>, S<4, 2>, S<4, 1, 4, 1>, S<2, 1, 32, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<4, 1, 1, 1>, S<0, 3, 1, 2>, S<1, 1, 1, 1>, S<4, 1, 2, 1>, S<2, 1, 32, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 1>, S<1, 2, 0, 3>, S<1, 1, 1, 1>, S<0, 1, 2, 3, 4, 5>, 5, 4>, - DeviceGemmDlops< F32, F32, 
F32, F32, Col, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 64, 64, 128, 8, 1, 4, 8, 1, S<4, 2>, S<4, 2>, S<4, 1, 2, 1>, S<2, 1, 32, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<4, 1, 1, 1>, S<0, 3, 1, 2>, S<1, 1, 1, 1>, S<4, 1, 4, 1>, S<2, 1, 32, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 1>, S<1, 2, 0, 3>, S<1, 1, 1, 1>, S<0, 1, 2, 3, 4, 5>, 5, 4>, - - /* - * K1 = 2 - */ - DeviceGemmDlops< F32, F32, F32, F32, Col, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 256, 128, 8, 2, 8, 4, 1, S<8, 2>, S<8, 2>, S<4, 1, 2, 2>, S<2, 1, 128, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<4, 1, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 1, 2>, S<4, 1, 1, 2>, S<2, 1, 128, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<0, 1, 2, 3, 4, 5>, 5, 4>, - DeviceGemmDlops< F32, F32, F32, F32, Col, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 128, 256, 8, 2, 4, 8, 1, S<8, 2>, S<8, 2>, S<4, 1, 1, 2>, S<2, 1, 128, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<4, 1, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 1, 2>, S<4, 1, 2, 2>, S<2, 1, 128, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<0, 1, 2, 3, 4, 5>, 5, 4>, - DeviceGemmDlops< F32, F32, F32, F32, Col, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 128, 128, 8, 2, 4, 4, 1, S<8, 2>, S<8, 2>, S<4, 1, 1, 2>, S<2, 1, 128, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<4, 1, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 1, 2>, S<4, 1, 1, 2>, S<2, 1, 128, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<0, 1, 2, 3, 4, 5>, 5, 4>, - DeviceGemmDlops< F32, F32, F32, F32, Col, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 128, 128, 128, 8, 2, 8, 4, 1, S<4, 2>, S<8, 2>, S<8, 1, 1, 2>, S<1, 1, 128, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<4, 1, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 1, 2>, S<8, 1, 1, 2>, S<1, 1, 128, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<0, 1, 2, 3, 4, 5>, 5, 4>, - DeviceGemmDlops< F32, F32, 
F32, F32, Col, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 128, 128, 128, 8, 2, 4, 8, 1, S<8, 2>, S<4, 2>, S<8, 1, 1, 2>, S<1, 1, 128, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<4, 1, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 1, 2>, S<8, 1, 1, 2>, S<1, 1, 128, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<0, 1, 2, 3, 4, 5>, 5, 4>, - DeviceGemmDlops< F32, F32, F32, F32, Col, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 64, 128, 128, 8, 2, 8, 8, 1, S<4, 2>, S<4, 2>, S<4, 1, 4, 2>, S<2, 1, 32, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<4, 1, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 1, 2>, S<4, 1, 4, 2>, S<2, 1, 32, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<0, 1, 2, 3, 4, 5>, 5, 4>, - DeviceGemmDlops< F32, F32, F32, F32, Col, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 64, 128, 64, 8, 2, 8, 4, 1, S<4, 2>, S<4, 2>, S<4, 1, 4, 2>, S<2, 1, 32, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<4, 1, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 1, 2>, S<4, 1, 2, 2>, S<2, 1, 32, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<0, 1, 2, 3, 4, 5>, 5, 4>, - DeviceGemmDlops< F32, F32, F32, F32, Col, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 64, 64, 128, 8, 2, 4, 8, 1, S<4, 2>, S<4, 2>, S<4, 1, 2, 2>, S<2, 1, 32, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<4, 1, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 1, 2>, S<4, 1, 4, 2>, S<2, 1, 32, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<0, 1, 2, 3, 4, 5>, 5, 4> + DeviceGemmDlops< F32, F32, F32, F32, Col, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 256, 128, 8, 1, 8, 4, 1, S<8, 2>, S<8, 2>, S<2, 1, 4, 1>, S<4, 1, 64, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 1, 1>, S<4, 1, 1, 1>, S<2, 1, 128, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 1>, S<1, 2, 0, 3>, S<1, 1, 1, 1>, S<0, 1, 2, 3, 4, 5>, 5, 4> +// DeviceGemmDlops< F32, F32, F32, F32, Col, Col, Row, 
PassThrough, PassThrough, PassThrough, GemmDefault, 256, 128, 256, 8, 1, 4, 8, 1, S<8, 2>, S<8, 2>, S<4, 1, 1, 1>, S<2, 1, 128, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<4, 1, 1, 1>, S<0, 3, 1, 2>, S<1, 1, 1, 1>, S<4, 1, 2, 1>, S<2, 1, 128, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 1>, S<1, 2, 0, 3>, S<1, 1, 1, 1>, S<0, 1, 2, 3, 4, 5>, 5, 4>, +// DeviceGemmDlops< F32, F32, F32, F32, Col, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 128, 128, 8, 1, 4, 4, 1, S<8, 2>, S<8, 2>, S<4, 1, 1, 1>, S<2, 1, 128, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<4, 1, 1, 1>, S<0, 3, 1, 2>, S<1, 1, 1, 1>, S<4, 1, 1, 1>, S<2, 1, 128, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 1>, S<1, 2, 0, 3>, S<1, 1, 1, 1>, S<0, 1, 2, 3, 4, 5>, 5, 4>, +// DeviceGemmDlops< F32, F32, F32, F32, Col, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 128, 128, 128, 8, 1, 8, 4, 1, S<4, 2>, S<8, 2>, S<8, 1, 1, 1>, S<1, 1, 128, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<4, 1, 1, 1>, S<0, 3, 1, 2>, S<1, 1, 1, 1>, S<8, 1, 1, 1>, S<1, 1, 128, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 1>, S<1, 2, 0, 3>, S<1, 1, 1, 1>, S<0, 1, 2, 3, 4, 5>, 5, 4>, +// DeviceGemmDlops< F32, F32, F32, F32, Col, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 128, 128, 128, 8, 1, 4, 8, 1, S<8, 2>, S<4, 2>, S<8, 1, 1, 1>, S<1, 1, 128, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<4, 1, 1, 1>, S<0, 3, 1, 2>, S<1, 1, 1, 1>, S<8, 1, 1, 1>, S<1, 1, 128, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 1>, S<1, 2, 0, 3>, S<1, 1, 1, 1>, S<0, 1, 2, 3, 4, 5>, 5, 4>, +// DeviceGemmDlops< F32, F32, F32, F32, Col, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 64, 128, 128, 8, 1, 8, 8, 1, S<4, 2>, S<4, 2>, S<4, 1, 4, 1>, S<2, 1, 32, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<4, 1, 1, 1>, S<0, 3, 1, 2>, S<1, 1, 1, 1>, S<4, 1, 4, 1>, S<2, 1, 32, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 1>, S<1, 2, 0, 3>, S<1, 1, 1, 1>, S<0, 1, 2, 3, 4, 5>, 5, 4>, +// DeviceGemmDlops< F32, F32, F32, F32, Col, Col, Row, PassThrough, 
PassThrough, PassThrough, GemmDefault, 64, 128, 64, 8, 1, 8, 4, 1, S<4, 2>, S<4, 2>, S<4, 1, 4, 1>, S<2, 1, 32, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<4, 1, 1, 1>, S<0, 3, 1, 2>, S<1, 1, 1, 1>, S<4, 1, 2, 1>, S<2, 1, 32, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 1>, S<1, 2, 0, 3>, S<1, 1, 1, 1>, S<0, 1, 2, 3, 4, 5>, 5, 4>, +// DeviceGemmDlops< F32, F32, F32, F32, Col, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 64, 64, 128, 8, 1, 4, 8, 1, S<4, 2>, S<4, 2>, S<4, 1, 2, 1>, S<2, 1, 32, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<4, 1, 1, 1>, S<0, 3, 1, 2>, S<1, 1, 1, 1>, S<4, 1, 4, 1>, S<2, 1, 32, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 1>, S<1, 2, 0, 3>, S<1, 1, 1, 1>, S<0, 1, 2, 3, 4, 5>, 5, 4>, // clang-format on >; diff --git a/library/src/tensor_operation_instance/gpu/gemm/device_gemm_dlops_f32_f32_f32_mk_kn_mn_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm/device_gemm_dlops_f32_f32_f32_mk_kn_mn_instance.cpp index 30ec69692c..5776c5426f 100644 --- a/library/src/tensor_operation_instance/gpu/gemm/device_gemm_dlops_f32_f32_f32_mk_kn_mn_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/gemm/device_gemm_dlops_f32_f32_f32_mk_kn_mn_instance.cpp @@ -29,29 +29,14 @@ using device_gemm_dlops_f32_f32_f32_mk_kn_mn_instances = std::tuple< // ##########| Type| Type| Type| Type| | | | Elementwise| Elementwise| Elementwise| Spacialization| Size| Block| Block| Block| | ThreadM111| ThreadN111| Thread| ClusterM110Xs| ClusterN110Xs| ThreadSliceLengths| ThreadClusterLengths| ThreadCluster| SrcAccess| SrcVectorTensor| SrcVectorTensor| DstVectorTensor| ThreadSliceLengths| ThreadClusterLengths| ThreadCluster| SrcAccess| SrcVectorTensor| SrcVectorTensor| DstVectorTensor| SrcDstAccess| SrcDstVectorDim| DstScalarPerVector| // ##########| | | | | | | | Operation| Operation| Operation| | | | | | | | | | | | K0_M0_M1_K1| K0_M0_M1_K1| ArrangeOrder| Order| Lengths_K0_M0_M1_K1| ContiguousDimOrder| Lengths_K0_M0_M1_K1| K0_N0_N1_K1| K0_N0_N1_K1| 
ArrangeOrder| Order| Lengths_K0_N0_N1_K1| ContiguousDimOrder| Lengths_K0_N0_N1_K1| Order| | | // ##########| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | - /* - * K1 = 1 - */ - DeviceGemmDlops< F32, F32, F32, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 256, 128, 8, 1, 8, 4, 1, S<8, 2>, S<8, 2>, S<4, 1, 2, 1>, S<2, 1, 128, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 1>, S<1, 2, 0, 3>, S<1, 1, 1, 1>, S<4, 1, 1, 1>, S<2, 1, 128, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<4, 1, 1, 1>, S<0, 3, 1, 2>, S<1, 1, 1, 1>, S<0, 1, 2, 3, 4, 5>, 5, 4>, - DeviceGemmDlops< F32, F32, F32, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 128, 256, 8, 1, 4, 8, 1, S<8, 2>, S<8, 2>, S<4, 1, 1, 1>, S<2, 1, 128, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 1>, S<1, 2, 0, 3>, S<1, 1, 1, 1>, S<4, 1, 2, 1>, S<2, 1, 128, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<4, 1, 1, 1>, S<0, 3, 1, 2>, S<1, 1, 1, 1>, S<0, 1, 2, 3, 4, 5>, 5, 4>, - DeviceGemmDlops< F32, F32, F32, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 128, 128, 8, 1, 4, 4, 1, S<8, 2>, S<8, 2>, S<4, 1, 1, 1>, S<2, 1, 128, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 1>, S<1, 2, 0, 3>, S<1, 1, 1, 1>, S<4, 1, 1, 1>, S<2, 1, 128, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<4, 1, 1, 1>, S<0, 3, 1, 2>, S<1, 1, 1, 1>, S<0, 1, 2, 3, 4, 5>, 5, 4>, - DeviceGemmDlops< F32, F32, F32, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 128, 128, 128, 8, 1, 8, 4, 1, S<4, 2>, S<8, 2>, S<8, 1, 1, 1>, S<1, 1, 128, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 1>, S<1, 2, 0, 3>, S<1, 1, 1, 1>, S<8, 1, 1, 1>, S<1, 1, 128, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<4, 1, 1, 1>, S<0, 3, 1, 2>, S<1, 1, 1, 1>, S<0, 1, 2, 3, 4, 5>, 5, 4>, - DeviceGemmDlops< F32, F32, F32, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 128, 128, 128, 8, 1, 4, 8, 1, S<8, 2>, S<4, 2>, S<8, 1, 1, 1>, S<1, 1, 128, 1>, S<1, 2, 0, 3>, S<1, 2, 
0, 3>, S<4, 1, 1, 1>, S<1, 2, 0, 3>, S<1, 1, 1, 1>, S<8, 1, 1, 1>, S<1, 1, 128, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<4, 1, 1, 1>, S<0, 3, 1, 2>, S<1, 1, 1, 1>, S<0, 1, 2, 3, 4, 5>, 5, 4>, - DeviceGemmDlops< F32, F32, F32, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 64, 128, 128, 8, 1, 8, 8, 1, S<4, 2>, S<4, 2>, S<4, 1, 4, 1>, S<2, 1, 32, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 1>, S<1, 2, 0, 3>, S<1, 1, 1, 1>, S<4, 1, 4, 1>, S<2, 1, 32, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<4, 1, 1, 1>, S<0, 3, 1, 2>, S<1, 1, 1, 1>, S<0, 1, 2, 3, 4, 5>, 5, 4>, - DeviceGemmDlops< F32, F32, F32, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 64, 128, 64, 8, 1, 8, 4, 1, S<4, 2>, S<4, 2>, S<4, 1, 4, 1>, S<2, 1, 32, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 1>, S<1, 2, 0, 3>, S<1, 1, 1, 1>, S<4, 1, 2, 1>, S<2, 1, 32, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<4, 1, 1, 1>, S<0, 3, 1, 2>, S<1, 1, 1, 1>, S<0, 1, 2, 3, 4, 5>, 5, 4>, - DeviceGemmDlops< F32, F32, F32, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 64, 64, 128, 8, 1, 4, 8, 1, S<4, 2>, S<4, 2>, S<4, 1, 2, 1>, S<2, 1, 32, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 1>, S<1, 2, 0, 3>, S<1, 1, 1, 1>, S<4, 1, 4, 1>, S<2, 1, 32, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<4, 1, 1, 1>, S<0, 3, 1, 2>, S<1, 1, 1, 1>, S<0, 1, 2, 3, 4, 5>, 5, 4>, - - /* - * K1 = 2 - */ - DeviceGemmDlops< F32, F32, F32, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 256, 128, 8, 2, 8, 4, 1, S<8, 2>, S<8, 2>, S<4, 1, 2, 2>, S<2, 1, 128, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<4, 1, 1, 2>, S<2, 1, 128, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<4, 1, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 1, 2>, S<0, 1, 2, 3, 4, 5>, 5, 4>, - DeviceGemmDlops< F32, F32, F32, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 128, 256, 8, 2, 4, 8, 1, S<8, 2>, S<8, 2>, S<4, 1, 1, 2>, S<2, 1, 128, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, 
S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<4, 1, 2, 2>, S<2, 1, 128, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<4, 1, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 1, 2>, S<0, 1, 2, 3, 4, 5>, 5, 4>, - DeviceGemmDlops< F32, F32, F32, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 128, 128, 8, 2, 4, 4, 1, S<8, 2>, S<8, 2>, S<4, 1, 1, 2>, S<2, 1, 128, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<4, 1, 1, 2>, S<2, 1, 128, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<4, 1, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 1, 2>, S<0, 1, 2, 3, 4, 5>, 5, 4>, - DeviceGemmDlops< F32, F32, F32, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 128, 128, 128, 8, 2, 8, 4, 1, S<4, 2>, S<8, 2>, S<8, 1, 1, 2>, S<1, 1, 128, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<8, 1, 1, 2>, S<1, 1, 128, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<4, 1, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 1, 2>, S<0, 1, 2, 3, 4, 5>, 5, 4>, - DeviceGemmDlops< F32, F32, F32, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 128, 128, 128, 8, 2, 4, 8, 1, S<8, 2>, S<4, 2>, S<8, 1, 1, 2>, S<1, 1, 128, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<8, 1, 1, 2>, S<1, 1, 128, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<4, 1, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 1, 2>, S<0, 1, 2, 3, 4, 5>, 5, 4>, - DeviceGemmDlops< F32, F32, F32, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 64, 128, 128, 8, 2, 8, 8, 1, S<4, 2>, S<4, 2>, S<4, 1, 4, 2>, S<2, 1, 32, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<4, 1, 4, 2>, S<2, 1, 32, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<4, 1, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 1, 2>, S<0, 1, 2, 3, 4, 5>, 5, 4>, - DeviceGemmDlops< F32, F32, F32, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 64, 128, 64, 8, 2, 8, 4, 1, S<4, 2>, S<4, 2>, S<4, 1, 4, 2>, S<2, 1, 32, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 
3>, S<1, 1, 1, 2>, S<4, 1, 2, 2>, S<2, 1, 32, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<4, 1, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 1, 2>, S<0, 1, 2, 3, 4, 5>, 5, 4>, - DeviceGemmDlops< F32, F32, F32, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 64, 64, 128, 8, 2, 4, 8, 1, S<4, 2>, S<4, 2>, S<4, 1, 2, 2>, S<2, 1, 32, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<4, 1, 4, 2>, S<2, 1, 32, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<4, 1, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 1, 2>, S<0, 1, 2, 3, 4, 5>, 5, 4> + DeviceGemmDlops< F32, F32, F32, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 256, 128, 8, 1, 8, 4, 1, S<8, 2>, S<8, 2>, S<4, 1, 2, 1>, S<2, 1, 128, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 1>, S<1, 2, 0, 3>, S<1, 1, 1, 1>, S<1, 1, 4, 1>, S<8, 1, 32, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 1, 1>, S<0, 1, 2, 3, 4, 5>, 5, 4> +// DeviceGemmDlops< F32, F32, F32, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 128, 256, 8, 1, 4, 8, 1, S<8, 2>, S<8, 2>, S<4, 1, 1, 1>, S<2, 1, 128, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 1>, S<1, 2, 0, 3>, S<1, 1, 1, 1>, S<4, 1, 2, 1>, S<2, 1, 128, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<4, 1, 1, 1>, S<0, 3, 1, 2>, S<1, 1, 1, 1>, S<0, 1, 2, 3, 4, 5>, 5, 4>, +// DeviceGemmDlops< F32, F32, F32, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 128, 128, 8, 1, 4, 4, 1, S<8, 2>, S<8, 2>, S<4, 1, 1, 1>, S<2, 1, 128, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 1>, S<1, 2, 0, 3>, S<1, 1, 1, 1>, S<4, 1, 1, 1>, S<2, 1, 128, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<4, 1, 1, 1>, S<0, 3, 1, 2>, S<1, 1, 1, 1>, S<0, 1, 2, 3, 4, 5>, 5, 4>, +// DeviceGemmDlops< F32, F32, F32, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 128, 128, 128, 8, 1, 8, 4, 1, S<4, 2>, S<8, 2>, S<8, 1, 1, 1>, S<1, 1, 128, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 1>, S<1, 2, 0, 3>, S<1, 1, 1, 1>, 
S<8, 1, 1, 1>, S<1, 1, 128, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<4, 1, 1, 1>, S<0, 3, 1, 2>, S<1, 1, 1, 1>, S<0, 1, 2, 3, 4, 5>, 5, 4>, +// DeviceGemmDlops< F32, F32, F32, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 128, 128, 128, 8, 1, 4, 8, 1, S<8, 2>, S<4, 2>, S<8, 1, 1, 1>, S<1, 1, 128, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 1>, S<1, 2, 0, 3>, S<1, 1, 1, 1>, S<8, 1, 1, 1>, S<1, 1, 128, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<4, 1, 1, 1>, S<0, 3, 1, 2>, S<1, 1, 1, 1>, S<0, 1, 2, 3, 4, 5>, 5, 4>, +// DeviceGemmDlops< F32, F32, F32, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 64, 128, 128, 8, 1, 8, 8, 1, S<4, 2>, S<4, 2>, S<4, 1, 4, 1>, S<2, 1, 32, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 1>, S<1, 2, 0, 3>, S<1, 1, 1, 1>, S<4, 1, 4, 1>, S<2, 1, 32, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<4, 1, 1, 1>, S<0, 3, 1, 2>, S<1, 1, 1, 1>, S<0, 1, 2, 3, 4, 5>, 5, 4>, +// DeviceGemmDlops< F32, F32, F32, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 64, 128, 64, 8, 1, 8, 4, 1, S<4, 2>, S<4, 2>, S<4, 1, 4, 1>, S<2, 1, 32, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 1>, S<1, 2, 0, 3>, S<1, 1, 1, 1>, S<4, 1, 2, 1>, S<2, 1, 32, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<4, 1, 1, 1>, S<0, 3, 1, 2>, S<1, 1, 1, 1>, S<0, 1, 2, 3, 4, 5>, 5, 4>, +// DeviceGemmDlops< F32, F32, F32, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 64, 64, 128, 8, 1, 4, 8, 1, S<4, 2>, S<4, 2>, S<4, 1, 2, 1>, S<2, 1, 32, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 1>, S<1, 2, 0, 3>, S<1, 1, 1, 1>, S<4, 1, 4, 1>, S<2, 1, 32, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<4, 1, 1, 1>, S<0, 3, 1, 2>, S<1, 1, 1, 1>, S<0, 1, 2, 3, 4, 5>, 5, 4>, // clang-format on >; diff --git a/library/src/tensor_operation_instance/gpu/gemm/device_gemm_dlops_f32_f32_f32_mk_nk_mn_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm/device_gemm_dlops_f32_f32_f32_mk_nk_mn_instance.cpp index 9a6a9ac5ea..e95177a69d 100644 --- 
a/library/src/tensor_operation_instance/gpu/gemm/device_gemm_dlops_f32_f32_f32_mk_nk_mn_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/gemm/device_gemm_dlops_f32_f32_f32_mk_nk_mn_instance.cpp @@ -29,29 +29,14 @@ using device_gemm_dlops_f32_f32_f32_mk_nk_mn_instances = std::tuple< // ##########| Type| Type| Type| Type| | | | Elementwise| Elementwise| Elementwise| Spacialization| Size| Block| Block| Block| | ThreadM111| ThreadN111| Thread| ClusterM110Xs| ClusterN110Xs| ThreadSliceLengths| ThreadClusterLengths| ThreadCluster| SrcAccess| SrcVectorTensor| SrcVectorTensor| DstVectorTensor| ThreadSliceLengths| ThreadClusterLengths| ThreadCluster| SrcAccess| SrcVectorTensor| SrcVectorTensor| DstVectorTensor| SrcDstAccess| SrcDstVectorDim| DstScalarPerVector| // ##########| | | | | | | | Operation| Operation| Operation| | | | | | | | | | | | K0_M0_M1_K1| K0_M0_M1_K1| ArrangeOrder| Order| Lengths_K0_M0_M1_K1| ContiguousDimOrder| Lengths_K0_M0_M1_K1| K0_N0_N1_K1| K0_N0_N1_K1| ArrangeOrder| Order| Lengths_K0_N0_N1_K1| ContiguousDimOrder| Lengths_K0_N0_N1_K1| Order| | | // ##########| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | - /* - * K1 = 1 - */ - DeviceGemmDlops< F32, F32, F32, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 256, 128, 8, 1, 8, 4, 1, S<8, 2>, S<8, 2>, S<4, 1, 2, 1>, S<2, 1, 128, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 1>, S<1, 2, 0, 3>, S<1, 1, 1, 1>, S<4, 1, 1, 1>, S<2, 1, 128, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 1>, S<1, 2, 0, 3>, S<1, 1, 1, 1>, S<0, 1, 2, 3, 4, 5>, 5, 4>, - DeviceGemmDlops< F32, F32, F32, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 128, 256, 8, 1, 4, 8, 1, S<8, 2>, S<8, 2>, S<4, 1, 1, 1>, S<2, 1, 128, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 1>, S<1, 2, 0, 3>, S<1, 1, 1, 1>, S<4, 1, 2, 1>, S<2, 1, 128, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 1>, S<1, 2, 0, 3>, S<1, 1, 1, 1>, S<0, 1, 2, 3, 4, 5>, 5, 4>, - 
DeviceGemmDlops< F32, F32, F32, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 128, 128, 8, 1, 4, 4, 1, S<8, 2>, S<8, 2>, S<4, 1, 1, 1>, S<2, 1, 128, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 1>, S<1, 2, 0, 3>, S<1, 1, 1, 1>, S<4, 1, 1, 1>, S<2, 1, 128, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 1>, S<1, 2, 0, 3>, S<1, 1, 1, 1>, S<0, 1, 2, 3, 4, 5>, 5, 4>, - DeviceGemmDlops< F32, F32, F32, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 128, 128, 128, 8, 1, 8, 4, 1, S<4, 2>, S<8, 2>, S<8, 1, 1, 1>, S<1, 1, 128, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 1>, S<1, 2, 0, 3>, S<1, 1, 1, 1>, S<8, 1, 1, 1>, S<1, 1, 128, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 1>, S<1, 2, 0, 3>, S<1, 1, 1, 1>, S<0, 1, 2, 3, 4, 5>, 5, 4>, - DeviceGemmDlops< F32, F32, F32, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 128, 128, 128, 8, 1, 4, 8, 1, S<8, 2>, S<4, 2>, S<8, 1, 1, 1>, S<1, 1, 128, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 1>, S<1, 2, 0, 3>, S<1, 1, 1, 1>, S<8, 1, 1, 1>, S<1, 1, 128, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 1>, S<1, 2, 0, 3>, S<1, 1, 1, 1>, S<0, 1, 2, 3, 4, 5>, 5, 4>, - DeviceGemmDlops< F32, F32, F32, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 64, 128, 128, 8, 1, 8, 8, 1, S<4, 2>, S<4, 2>, S<4, 1, 4, 1>, S<2, 1, 32, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 1>, S<1, 2, 0, 3>, S<1, 1, 1, 1>, S<4, 1, 4, 1>, S<2, 1, 32, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 1>, S<1, 2, 0, 3>, S<1, 1, 1, 1>, S<0, 1, 2, 3, 4, 5>, 5, 4>, - DeviceGemmDlops< F32, F32, F32, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 64, 128, 64, 8, 1, 8, 4, 1, S<4, 2>, S<4, 2>, S<4, 1, 4, 1>, S<2, 1, 32, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 1>, S<1, 2, 0, 3>, S<1, 1, 1, 1>, S<4, 1, 2, 1>, S<2, 1, 32, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 1>, S<1, 2, 0, 3>, S<1, 1, 1, 1>, S<0, 1, 2, 3, 4, 5>, 5, 4>, - DeviceGemmDlops< F32, F32, 
F32, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 64, 64, 128, 8, 1, 4, 8, 1, S<4, 2>, S<4, 2>, S<4, 1, 2, 1>, S<2, 1, 32, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 1>, S<1, 2, 0, 3>, S<1, 1, 1, 1>, S<4, 1, 4, 1>, S<2, 1, 32, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 1>, S<1, 2, 0, 3>, S<1, 1, 1, 1>, S<0, 1, 2, 3, 4, 5>, 5, 4>, - - /* - * K1 = 2 - */ - DeviceGemmDlops< F32, F32, F32, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 256, 128, 8, 2, 8, 4, 1, S<8, 2>, S<8, 2>, S<4, 1, 2, 2>, S<2, 1, 128, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<4, 1, 1, 2>, S<2, 1, 128, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<0, 1, 2, 3, 4, 5>, 5, 4>, - DeviceGemmDlops< F32, F32, F32, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 128, 256, 8, 2, 4, 8, 1, S<8, 2>, S<8, 2>, S<4, 1, 1, 2>, S<2, 1, 128, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<4, 1, 2, 2>, S<2, 1, 128, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<0, 1, 2, 3, 4, 5>, 5, 4>, - DeviceGemmDlops< F32, F32, F32, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 128, 128, 8, 2, 4, 4, 1, S<8, 2>, S<8, 2>, S<4, 1, 1, 2>, S<2, 1, 128, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<4, 1, 1, 2>, S<2, 1, 128, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<0, 1, 2, 3, 4, 5>, 5, 4>, - DeviceGemmDlops< F32, F32, F32, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 128, 128, 128, 8, 2, 8, 4, 1, S<4, 2>, S<8, 2>, S<8, 1, 1, 2>, S<1, 1, 128, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<8, 1, 1, 2>, S<1, 1, 128, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<0, 1, 2, 3, 4, 5>, 5, 4>, - DeviceGemmDlops< F32, F32, 
F32, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 128, 128, 128, 8, 2, 4, 8, 1, S<8, 2>, S<4, 2>, S<8, 1, 1, 2>, S<1, 1, 128, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<8, 1, 1, 2>, S<1, 1, 128, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<0, 1, 2, 3, 4, 5>, 5, 4>, - DeviceGemmDlops< F32, F32, F32, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 64, 128, 128, 8, 2, 8, 8, 1, S<4, 2>, S<4, 2>, S<4, 1, 4, 2>, S<2, 1, 32, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<4, 1, 4, 2>, S<2, 1, 32, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<0, 1, 2, 3, 4, 5>, 5, 4>, - DeviceGemmDlops< F32, F32, F32, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 64, 128, 64, 8, 2, 8, 4, 1, S<4, 2>, S<4, 2>, S<4, 1, 4, 2>, S<2, 1, 32, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<4, 1, 2, 2>, S<2, 1, 32, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<0, 1, 2, 3, 4, 5>, 5, 4>, - DeviceGemmDlops< F32, F32, F32, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 64, 64, 128, 8, 2, 4, 8, 1, S<4, 2>, S<4, 2>, S<4, 1, 2, 2>, S<2, 1, 32, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<4, 1, 4, 2>, S<2, 1, 32, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<0, 1, 2, 3, 4, 5>, 5, 4> + DeviceGemmDlops< F32, F32, F32, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 256, 128, 8, 1, 8, 4, 1, S<8, 2>, S<8, 2>, S<4, 1, 2, 1>, S<2, 1, 128, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 1>, S<1, 2, 0, 3>, S<1, 1, 1, 1>, S<4, 1, 1, 1>, S<2, 1, 128, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 1>, S<1, 2, 0, 3>, S<1, 1, 1, 1>, S<0, 1, 2, 3, 4, 5>, 5, 4>, + DeviceGemmDlops< F32, F32, F32, F32, Row, Col, Row, 
PassThrough, PassThrough, PassThrough, GemmDefault, 256, 128, 256, 8, 1, 4, 8, 1, S<8, 2>, S<8, 2>, S<4, 1, 1, 1>, S<2, 1, 128, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 1>, S<1, 2, 0, 3>, S<1, 1, 1, 1>, S<4, 1, 2, 1>, S<2, 1, 128, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 1>, S<1, 2, 0, 3>, S<1, 1, 1, 1>, S<0, 1, 2, 3, 4, 5>, 5, 4>, + DeviceGemmDlops< F32, F32, F32, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 128, 128, 8, 1, 4, 4, 1, S<8, 2>, S<8, 2>, S<4, 1, 1, 1>, S<2, 1, 128, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 1>, S<1, 2, 0, 3>, S<1, 1, 1, 1>, S<4, 1, 1, 1>, S<2, 1, 128, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 1>, S<1, 2, 0, 3>, S<1, 1, 1, 1>, S<0, 1, 2, 3, 4, 5>, 5, 4>, + DeviceGemmDlops< F32, F32, F32, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 128, 128, 128, 8, 1, 8, 4, 1, S<4, 2>, S<8, 2>, S<8, 1, 1, 1>, S<1, 1, 128, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 1>, S<1, 2, 0, 3>, S<1, 1, 1, 1>, S<8, 1, 1, 1>, S<1, 1, 128, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 1>, S<1, 2, 0, 3>, S<1, 1, 1, 1>, S<0, 1, 2, 3, 4, 5>, 5, 4>, + DeviceGemmDlops< F32, F32, F32, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 128, 128, 128, 8, 1, 4, 8, 1, S<8, 2>, S<4, 2>, S<8, 1, 1, 1>, S<1, 1, 128, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 1>, S<1, 2, 0, 3>, S<1, 1, 1, 1>, S<8, 1, 1, 1>, S<1, 1, 128, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 1>, S<1, 2, 0, 3>, S<1, 1, 1, 1>, S<0, 1, 2, 3, 4, 5>, 5, 4>, + DeviceGemmDlops< F32, F32, F32, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 64, 128, 128, 8, 1, 8, 8, 1, S<4, 2>, S<4, 2>, S<4, 1, 4, 1>, S<2, 1, 32, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 1>, S<1, 2, 0, 3>, S<1, 1, 1, 1>, S<4, 1, 4, 1>, S<2, 1, 32, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 1>, S<1, 2, 0, 3>, S<1, 1, 1, 1>, S<0, 1, 2, 3, 4, 5>, 5, 4>, + DeviceGemmDlops< F32, F32, F32, F32, Row, Col, Row, PassThrough, 
PassThrough, PassThrough, GemmDefault, 64, 128, 64, 8, 1, 8, 4, 1, S<4, 2>, S<4, 2>, S<4, 1, 4, 1>, S<2, 1, 32, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 1>, S<1, 2, 0, 3>, S<1, 1, 1, 1>, S<4, 1, 2, 1>, S<2, 1, 32, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 1>, S<1, 2, 0, 3>, S<1, 1, 1, 1>, S<0, 1, 2, 3, 4, 5>, 5, 4>, + DeviceGemmDlops< F32, F32, F32, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 64, 64, 128, 8, 1, 4, 8, 1, S<4, 2>, S<4, 2>, S<4, 1, 2, 1>, S<2, 1, 32, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 1>, S<1, 2, 0, 3>, S<1, 1, 1, 1>, S<4, 1, 4, 1>, S<2, 1, 32, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 1>, S<1, 2, 0, 3>, S<1, 1, 1, 1>, S<0, 1, 2, 3, 4, 5>, 5, 4> // clang-format on >; diff --git a/profiler/include/profile_gemm_impl.hpp b/profiler/include/profile_gemm_impl.hpp index e0bedc6573..569d28c657 100644 --- a/profiler/include/profile_gemm_impl.hpp +++ b/profiler/include/profile_gemm_impl.hpp @@ -141,7 +141,7 @@ void profile_gemm_impl(int do_verification, switch(init_method) { // case 0: break; - case 0: + case 0: a_m_k.GenerateTensorValue(GeneratorTensor_1{}, num_thread); b_k_n.GenerateTensorValue(GeneratorTensor_1{}, num_thread); break; @@ -577,7 +577,8 @@ void profile_gemm_impl(int do_verification, } else { - std::cout << gemm_ptr->GetTypeString() << " does not support this GEMM problem" << std::endl; + std::cout << gemm_ptr->GetTypeString() << " does not support this GEMM problem" + << std::endl; } } diff --git a/test/gemm/gemm_util.hpp b/test/gemm/gemm_util.hpp index 79e437d4da..673ddec0a8 100644 --- a/test/gemm/gemm_util.hpp +++ b/test/gemm/gemm_util.hpp @@ -139,22 +139,7 @@ struct TestGemm Tensor c_m_n_device_result( f_host_tensor_descriptor(params.M, params.N, params.StrideC, CLayout{})); - auto f_generate_tensor_value = [](auto& desc, auto type) { - using dataType = decltype(type); - - if(std::is_same::value) - { - desc.GenerateTensorValue(GeneratorTensor_2{-5, 5}); - } - else - { - 
desc.GenerateTensorValue(GeneratorTensor_3{-0.5, 0.5}); - } - }; - - // f_generate_tensor_value(a_m_k, ADataType{}); - // f_generate_tensor_value(b_k_n, BDataType{}); - a_m_k.GenerateTensorValue(GeneratorTensor_1{}); + a_m_k.GenerateTensorValue(GeneratorTensor_2{-5, 5}); b_k_n.GenerateTensorValue(GeneratorTensor_2{-5, 5}); return std::make_tuple(a_m_k, b_k_n, c_m_n_host_result, c_m_n_device_result); From 2ca774b8fc04f1202910e8e1a5aeb1e30fd83107 Mon Sep 17 00:00:00 2001 From: Jianfeng yan Date: Tue, 3 May 2022 20:28:38 +0000 Subject: [PATCH 30/46] start fixing tuning parameters --- ..._gemm_dlops_f32_f32_f32_km_kn_mn_instance.cpp | 16 ++++++++-------- 1 file changed, 8 insertions(+), 8 deletions(-) diff --git a/library/src/tensor_operation_instance/gpu/gemm/device_gemm_dlops_f32_f32_f32_km_kn_mn_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm/device_gemm_dlops_f32_f32_f32_km_kn_mn_instance.cpp index 4f73d679dd..429a6ad6cf 100644 --- a/library/src/tensor_operation_instance/gpu/gemm/device_gemm_dlops_f32_f32_f32_km_kn_mn_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/gemm/device_gemm_dlops_f32_f32_f32_km_kn_mn_instance.cpp @@ -29,14 +29,14 @@ using device_gemm_dlops_f32_f32_f32_km_kn_mn_instances = std::tuple< // ##########| Type| Type| Type| Type| | | | Elementwise| Elementwise| Elementwise| Spacialization| Size| Block| Block| Block| | ThreadM111| ThreadN111| Thread| ClusterM110Xs| ClusterN110Xs| ThreadSliceLengths| ThreadClusterLengths| ThreadCluster| SrcAccess| SrcVectorTensor| SrcVectorTensor| DstVectorTensor| ThreadSliceLengths| ThreadClusterLengths| ThreadCluster| SrcAccess| SrcVectorTensor| SrcVectorTensor| DstVectorTensor| SrcDstAccess| SrcDstVectorDim| DstScalarPerVector| // ##########| | | | | | | | Operation| Operation| Operation| | | | | | | | | | | | K0_M0_M1_K1| K0_M0_M1_K1| ArrangeOrder| Order| Lengths_K0_M0_M1_K1| ContiguousDimOrder| Lengths_K0_M0_M1_K1| K0_N0_N1_K1| K0_N0_N1_K1| ArrangeOrder| Order| Lengths_K0_N0_N1_K1| 
ContiguousDimOrder| Lengths_K0_N0_N1_K1| Order| | | // ##########| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | - DeviceGemmDlops< F32, F32, F32, F32, Col, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 256, 128, 8, 1, 8, 4, 1, S<8, 2>, S<8, 2>, S<2, 1, 4, 1>, S<4, 1, 64, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 1, 1>, S<1, 1, 4, 1>, S<8, 1, 32, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 1, 1>, S<0, 1, 2, 3, 4, 5>, 5, 4> -// DeviceGemmDlops< F32, F32, F32, F32, Col, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 128, 256, 8, 1, 4, 8, 1, S<8, 2>, S<8, 2>, S<4, 1, 1, 1>, S<2, 1, 128, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<4, 1, 1, 1>, S<0, 3, 1, 2>, S<1, 1, 1, 1>, S<4, 1, 2, 1>, S<2, 1, 128, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<4, 1, 1, 1>, S<0, 3, 1, 2>, S<1, 1, 1, 1>, S<0, 1, 2, 3, 4, 5>, 5, 4>, -// DeviceGemmDlops< F32, F32, F32, F32, Col, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 128, 128, 8, 1, 4, 4, 1, S<8, 2>, S<8, 2>, S<4, 1, 1, 1>, S<2, 1, 128, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<4, 1, 1, 1>, S<0, 3, 1, 2>, S<1, 1, 1, 1>, S<4, 1, 1, 1>, S<2, 1, 128, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<4, 1, 1, 1>, S<0, 3, 1, 2>, S<1, 1, 1, 1>, S<0, 1, 2, 3, 4, 5>, 5, 4>, -// DeviceGemmDlops< F32, F32, F32, F32, Col, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 128, 128, 128, 8, 1, 8, 4, 1, S<4, 2>, S<8, 2>, S<8, 1, 1, 1>, S<1, 1, 128, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<4, 1, 1, 1>, S<0, 3, 1, 2>, S<1, 1, 1, 1>, S<8, 1, 1, 1>, S<1, 1, 128, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<4, 1, 1, 1>, S<0, 3, 1, 2>, S<1, 1, 1, 1>, S<0, 1, 2, 3, 4, 5>, 5, 4>, -// DeviceGemmDlops< F32, F32, F32, F32, Col, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 128, 128, 128, 8, 1, 4, 8, 1, S<8, 2>, S<4, 2>, S<8, 1, 1, 1>, S<1, 1, 128, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<4, 1, 1, 1>, S<0, 3, 1, 2>, S<1, 1, 1, 1>, S<8, 
1, 1, 1>, S<1, 1, 128, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<4, 1, 1, 1>, S<0, 3, 1, 2>, S<1, 1, 1, 1>, S<0, 1, 2, 3, 4, 5>, 5, 4>, -// DeviceGemmDlops< F32, F32, F32, F32, Col, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 64, 128, 128, 8, 1, 8, 8, 1, S<4, 2>, S<4, 2>, S<4, 1, 4, 1>, S<2, 1, 32, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<4, 1, 1, 1>, S<0, 3, 1, 2>, S<1, 1, 1, 1>, S<4, 1, 4, 1>, S<2, 1, 32, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<4, 1, 1, 1>, S<0, 3, 1, 2>, S<1, 1, 1, 1>, S<0, 1, 2, 3, 4, 5>, 5, 4>, -// DeviceGemmDlops< F32, F32, F32, F32, Col, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 64, 128, 64, 8, 1, 8, 4, 1, S<4, 2>, S<4, 2>, S<4, 1, 4, 1>, S<2, 1, 32, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<4, 1, 1, 1>, S<0, 3, 1, 2>, S<1, 1, 1, 1>, S<4, 1, 2, 1>, S<2, 1, 32, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<4, 1, 1, 1>, S<0, 3, 1, 2>, S<1, 1, 1, 1>, S<0, 1, 2, 3, 4, 5>, 5, 4>, -// DeviceGemmDlops< F32, F32, F32, F32, Col, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 64, 64, 128, 8, 1, 4, 8, 1, S<4, 2>, S<4, 2>, S<4, 1, 2, 1>, S<2, 1, 32, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<4, 1, 1, 1>, S<0, 3, 1, 2>, S<1, 1, 1, 1>, S<4, 1, 4, 1>, S<2, 1, 32, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<4, 1, 1, 1>, S<0, 3, 1, 2>, S<1, 1, 1, 1>, S<0, 1, 2, 3, 4, 5>, 5, 4>, + DeviceGemmDlops< F32, F32, F32, F32, Col, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 256, 128, 8, 1, 8, 4, 1, S<8, 2>, S<8, 2>, S<2, 1, 4, 1>, S<4, 1, 64, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 1, 1>, S<1, 1, 4, 1>, S<8, 1, 32, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 1, 1>, S<0, 1, 2, 3, 4, 5>, 5, 4>, + DeviceGemmDlops< F32, F32, F32, F32, Col, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 128, 256, 8, 1, 4, 8, 1, S<8, 2>, S<8, 2>, S<1, 1, 4, 1>, S<4, 1, 64, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 1, 1>, S<2, 1, 4, 1>, S<8, 1, 32, 1>, 
S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 1, 1>, S<0, 1, 2, 3, 4, 5>, 5, 4> +// DeviceGemmDlops< F32, F32, F32, F32, Col, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 128, 128, 8, 1, 4, 4, 1, S<8, 2>, S<8, 2>, S<4, 1, 1, 1>, S<2, 1, 128, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 1, 1>, S<4, 1, 1, 1>, S<2, 1, 128, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 1, 1>, S<0, 1, 2, 3, 4, 5>, 5, 4>, +// DeviceGemmDlops< F32, F32, F32, F32, Col, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 128, 128, 128, 8, 1, 8, 4, 1, S<4, 2>, S<8, 2>, S<8, 1, 1, 1>, S<1, 1, 128, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 1, 1>, S<8, 1, 1, 1>, S<1, 1, 128, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 1, 1>, S<0, 1, 2, 3, 4, 5>, 5, 4>, +// DeviceGemmDlops< F32, F32, F32, F32, Col, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 128, 128, 128, 8, 1, 4, 8, 1, S<8, 2>, S<4, 2>, S<8, 1, 1, 1>, S<1, 1, 128, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 1, 1>, S<8, 1, 1, 1>, S<1, 1, 128, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 1, 1>, S<0, 1, 2, 3, 4, 5>, 5, 4>, +// DeviceGemmDlops< F32, F32, F32, F32, Col, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 64, 128, 128, 8, 1, 8, 8, 1, S<4, 2>, S<4, 2>, S<4, 1, 4, 1>, S<2, 1, 32, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 1, 1>, S<4, 1, 4, 1>, S<2, 1, 32, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 1, 1>, S<0, 1, 2, 3, 4, 5>, 5, 4>, +// DeviceGemmDlops< F32, F32, F32, F32, Col, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 64, 128, 64, 8, 1, 8, 4, 1, S<4, 2>, S<4, 2>, S<4, 1, 4, 1>, S<2, 1, 32, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 1, 1>, S<4, 1, 2, 1>, S<2, 1, 32, 1>, S<0, 3, 1, 2>, 
S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 1, 1>, S<0, 1, 2, 3, 4, 5>, 5, 4>, +// DeviceGemmDlops< F32, F32, F32, F32, Col, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 64, 64, 128, 8, 1, 4, 8, 1, S<4, 2>, S<4, 2>, S<4, 1, 2, 1>, S<2, 1, 32, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 1, 1>, S<4, 1, 4, 1>, S<2, 1, 32, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 1, 1>, S<0, 1, 2, 3, 4, 5>, 5, 4>, // clang-format on >; From d9cd2e56ac5c7bbb08e6ce3af8315fe7df0f8afa Mon Sep 17 00:00:00 2001 From: Jianfeng yan Date: Thu, 5 May 2022 15:26:27 +0000 Subject: [PATCH 31/46] monir --- .../gemm/device_gemm_dlops_f32_f32_f32_km_kn_mn_instance.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/library/src/tensor_operation_instance/gpu/gemm/device_gemm_dlops_f32_f32_f32_km_kn_mn_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm/device_gemm_dlops_f32_f32_f32_km_kn_mn_instance.cpp index 429a6ad6cf..a45f11a3a0 100644 --- a/library/src/tensor_operation_instance/gpu/gemm/device_gemm_dlops_f32_f32_f32_km_kn_mn_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/gemm/device_gemm_dlops_f32_f32_f32_km_kn_mn_instance.cpp @@ -30,7 +30,7 @@ using device_gemm_dlops_f32_f32_f32_km_kn_mn_instances = std::tuple< // ##########| | | | | | | | Operation| Operation| Operation| | | | | | | | | | | | K0_M0_M1_K1| K0_M0_M1_K1| ArrangeOrder| Order| Lengths_K0_M0_M1_K1| ContiguousDimOrder| Lengths_K0_M0_M1_K1| K0_N0_N1_K1| K0_N0_N1_K1| ArrangeOrder| Order| Lengths_K0_N0_N1_K1| ContiguousDimOrder| Lengths_K0_N0_N1_K1| Order| | | // ##########| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | DeviceGemmDlops< F32, F32, F32, F32, Col, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 256, 128, 8, 1, 8, 4, 1, S<8, 2>, S<8, 2>, S<2, 1, 4, 1>, S<4, 1, 64, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 1, 1>, S<1, 1, 4, 
1>, S<8, 1, 32, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 1, 1>, S<0, 1, 2, 3, 4, 5>, 5, 4>, - DeviceGemmDlops< F32, F32, F32, F32, Col, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 128, 256, 8, 1, 4, 8, 1, S<8, 2>, S<8, 2>, S<1, 1, 4, 1>, S<4, 1, 64, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 1, 1>, S<2, 1, 4, 1>, S<8, 1, 32, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 1, 1>, S<0, 1, 2, 3, 4, 5>, 5, 4> + DeviceGemmDlops< F32, F32, F32, F32, Col, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 128, 256, 8, 1, 4, 8, 1, S<8, 2>, S<8, 2>, S<2, 1, 2, 1>, S<4, 1, 64, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 1, 1>, S<1, 1, 4, 1>, S<8, 1, 64, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 1, 1>, S<0, 1, 2, 3, 4, 5>, 5, 4> // DeviceGemmDlops< F32, F32, F32, F32, Col, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 128, 128, 8, 1, 4, 4, 1, S<8, 2>, S<8, 2>, S<4, 1, 1, 1>, S<2, 1, 128, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 1, 1>, S<4, 1, 1, 1>, S<2, 1, 128, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 1, 1>, S<0, 1, 2, 3, 4, 5>, 5, 4>, // DeviceGemmDlops< F32, F32, F32, F32, Col, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 128, 128, 128, 8, 1, 8, 4, 1, S<4, 2>, S<8, 2>, S<8, 1, 1, 1>, S<1, 1, 128, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 1, 1>, S<8, 1, 1, 1>, S<1, 1, 128, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 1, 1>, S<0, 1, 2, 3, 4, 5>, 5, 4>, // DeviceGemmDlops< F32, F32, F32, F32, Col, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 128, 128, 128, 8, 1, 4, 8, 1, S<8, 2>, S<4, 2>, S<8, 1, 1, 1>, S<1, 1, 128, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 1, 1>, S<8, 1, 1, 1>, S<1, 1, 128, 1>, 
S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 1, 1>, S<0, 1, 2, 3, 4, 5>, 5, 4>, From f3bd93a23b75283de41020cac6c6e79912385ea9 Mon Sep 17 00:00:00 2001 From: Jianfeng yan Date: Thu, 5 May 2022 15:32:56 +0000 Subject: [PATCH 32/46] minor changes --- .../gemm/device_gemm_dlops_f32_f32_f32_km_kn_mn_instance.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/library/src/tensor_operation_instance/gpu/gemm/device_gemm_dlops_f32_f32_f32_km_kn_mn_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm/device_gemm_dlops_f32_f32_f32_km_kn_mn_instance.cpp index a45f11a3a0..36884b9ca5 100644 --- a/library/src/tensor_operation_instance/gpu/gemm/device_gemm_dlops_f32_f32_f32_km_kn_mn_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/gemm/device_gemm_dlops_f32_f32_f32_km_kn_mn_instance.cpp @@ -30,7 +30,7 @@ using device_gemm_dlops_f32_f32_f32_km_kn_mn_instances = std::tuple< // ##########| | | | | | | | Operation| Operation| Operation| | | | | | | | | | | | K0_M0_M1_K1| K0_M0_M1_K1| ArrangeOrder| Order| Lengths_K0_M0_M1_K1| ContiguousDimOrder| Lengths_K0_M0_M1_K1| K0_N0_N1_K1| K0_N0_N1_K1| ArrangeOrder| Order| Lengths_K0_N0_N1_K1| ContiguousDimOrder| Lengths_K0_N0_N1_K1| Order| | | // ##########| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | DeviceGemmDlops< F32, F32, F32, F32, Col, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 256, 128, 8, 1, 8, 4, 1, S<8, 2>, S<8, 2>, S<2, 1, 4, 1>, S<4, 1, 64, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 1, 1>, S<1, 1, 4, 1>, S<8, 1, 32, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 1, 1>, S<0, 1, 2, 3, 4, 5>, 5, 4>, - DeviceGemmDlops< F32, F32, F32, F32, Col, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 128, 256, 8, 1, 4, 8, 1, S<8, 2>, S<8, 2>, S<2, 1, 2, 1>, S<4, 1, 64, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 1, 1>, S<1, 1, 4, 
1>, S<8, 1, 64, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 1, 1>, S<0, 1, 2, 3, 4, 5>, 5, 4> + DeviceGemmDlops< F32, F32, F32, F32, Col, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 128, 256, 8, 1, 4, 8, 1, S<8, 2>, S<8, 2>, S<2, 1, 4, 1>, S<4, 1, 32, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 1, 1>, S<1, 1, 4, 1>, S<8, 1, 64, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 1, 1>, S<0, 1, 2, 3, 4, 5>, 5, 4> // DeviceGemmDlops< F32, F32, F32, F32, Col, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 128, 128, 8, 1, 4, 4, 1, S<8, 2>, S<8, 2>, S<4, 1, 1, 1>, S<2, 1, 128, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 1, 1>, S<4, 1, 1, 1>, S<2, 1, 128, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 1, 1>, S<0, 1, 2, 3, 4, 5>, 5, 4>, // DeviceGemmDlops< F32, F32, F32, F32, Col, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 128, 128, 128, 8, 1, 8, 4, 1, S<4, 2>, S<8, 2>, S<8, 1, 1, 1>, S<1, 1, 128, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 1, 1>, S<8, 1, 1, 1>, S<1, 1, 128, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 1, 1>, S<0, 1, 2, 3, 4, 5>, 5, 4>, // DeviceGemmDlops< F32, F32, F32, F32, Col, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 128, 128, 128, 8, 1, 4, 8, 1, S<8, 2>, S<4, 2>, S<8, 1, 1, 1>, S<1, 1, 128, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 1, 1>, S<8, 1, 1, 1>, S<1, 1, 128, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 1, 1>, S<0, 1, 2, 3, 4, 5>, 5, 4>, From 9da908f05b3c97b06e362a2510bd170c442e4b7b Mon Sep 17 00:00:00 2001 From: Jianfeng yan Date: Thu, 5 May 2022 15:38:12 +0000 Subject: [PATCH 33/46] minor changes --- .../gemm/device_gemm_dlops_f32_f32_f32_km_kn_mn_instance.cpp | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff 
--git a/library/src/tensor_operation_instance/gpu/gemm/device_gemm_dlops_f32_f32_f32_km_kn_mn_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm/device_gemm_dlops_f32_f32_f32_km_kn_mn_instance.cpp index 36884b9ca5..e3f3478932 100644 --- a/library/src/tensor_operation_instance/gpu/gemm/device_gemm_dlops_f32_f32_f32_km_kn_mn_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/gemm/device_gemm_dlops_f32_f32_f32_km_kn_mn_instance.cpp @@ -29,8 +29,8 @@ using device_gemm_dlops_f32_f32_f32_km_kn_mn_instances = std::tuple< // ##########| Type| Type| Type| Type| | | | Elementwise| Elementwise| Elementwise| Spacialization| Size| Block| Block| Block| | ThreadM111| ThreadN111| Thread| ClusterM110Xs| ClusterN110Xs| ThreadSliceLengths| ThreadClusterLengths| ThreadCluster| SrcAccess| SrcVectorTensor| SrcVectorTensor| DstVectorTensor| ThreadSliceLengths| ThreadClusterLengths| ThreadCluster| SrcAccess| SrcVectorTensor| SrcVectorTensor| DstVectorTensor| SrcDstAccess| SrcDstVectorDim| DstScalarPerVector| // ##########| | | | | | | | Operation| Operation| Operation| | | | | | | | | | | | K0_M0_M1_K1| K0_M0_M1_K1| ArrangeOrder| Order| Lengths_K0_M0_M1_K1| ContiguousDimOrder| Lengths_K0_M0_M1_K1| K0_N0_N1_K1| K0_N0_N1_K1| ArrangeOrder| Order| Lengths_K0_N0_N1_K1| ContiguousDimOrder| Lengths_K0_N0_N1_K1| Order| | | // ##########| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | - DeviceGemmDlops< F32, F32, F32, F32, Col, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 256, 128, 8, 1, 8, 4, 1, S<8, 2>, S<8, 2>, S<2, 1, 4, 1>, S<4, 1, 64, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 1, 1>, S<1, 1, 4, 1>, S<8, 1, 32, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 1, 1>, S<0, 1, 2, 3, 4, 5>, 5, 4>, - DeviceGemmDlops< F32, F32, F32, F32, Col, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 128, 256, 8, 1, 4, 8, 1, S<8, 2>, S<8, 2>, S<2, 1, 4, 1>, S<4, 1, 32, 
1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 1, 1>, S<1, 1, 4, 1>, S<8, 1, 64, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 1, 1>, S<0, 1, 2, 3, 4, 5>, 5, 4> + DeviceGemmDlops< F32, F32, F32, F32, Col, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 256, 128, 8, 1, 8, 4, 1, S<8, 2>, S<8, 2>, S<2, 1, 4, 1>, S<4, 1, 64, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 1, 1>, S<1, 1, 4, 1>, S<8, 1, 32, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 1, 1>, S<0, 1, 2, 3, 4, 5>, 5, 4> + // DeviceGemmDlops< F32, F32, F32, F32, Col, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 128, 256, 8, 1, 4, 8, 1, S<8, 2>, S<8, 2>, S<2, 1, 4, 1>, S<4, 1, 32, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 1, 1>, S<1, 1, 4, 1>, S<8, 1, 64, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 1, 1>, S<0, 1, 2, 3, 4, 5>, 5, 4> // DeviceGemmDlops< F32, F32, F32, F32, Col, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 128, 128, 8, 1, 4, 4, 1, S<8, 2>, S<8, 2>, S<4, 1, 1, 1>, S<2, 1, 128, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 1, 1>, S<4, 1, 1, 1>, S<2, 1, 128, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 1, 1>, S<0, 1, 2, 3, 4, 5>, 5, 4>, // DeviceGemmDlops< F32, F32, F32, F32, Col, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 128, 128, 128, 8, 1, 8, 4, 1, S<4, 2>, S<8, 2>, S<8, 1, 1, 1>, S<1, 1, 128, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 1, 1>, S<8, 1, 1, 1>, S<1, 1, 128, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 1, 1>, S<0, 1, 2, 3, 4, 5>, 5, 4>, // DeviceGemmDlops< F32, F32, F32, F32, Col, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 128, 128, 128, 8, 1, 4, 8, 1, S<8, 2>, S<4, 2>, S<8, 1, 1, 1>, S<1, 1, 128, 1>, S<0, 3, 1, 2>, 
S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 1, 1>, S<8, 1, 1, 1>, S<1, 1, 128, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 1, 1>, S<0, 1, 2, 3, 4, 5>, 5, 4>, From 1ea2ef56f281fa57f0f3da54b52fbcadad23b17e Mon Sep 17 00:00:00 2001 From: Jianfeng yan Date: Thu, 5 May 2022 15:44:26 +0000 Subject: [PATCH 34/46] minor changes --- .../gemm/device_gemm_dlops_f32_f32_f32_km_kn_mn_instance.cpp | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/library/src/tensor_operation_instance/gpu/gemm/device_gemm_dlops_f32_f32_f32_km_kn_mn_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm/device_gemm_dlops_f32_f32_f32_km_kn_mn_instance.cpp index e3f3478932..f9e902aed8 100644 --- a/library/src/tensor_operation_instance/gpu/gemm/device_gemm_dlops_f32_f32_f32_km_kn_mn_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/gemm/device_gemm_dlops_f32_f32_f32_km_kn_mn_instance.cpp @@ -29,11 +29,10 @@ using device_gemm_dlops_f32_f32_f32_km_kn_mn_instances = std::tuple< // ##########| Type| Type| Type| Type| | | | Elementwise| Elementwise| Elementwise| Spacialization| Size| Block| Block| Block| | ThreadM111| ThreadN111| Thread| ClusterM110Xs| ClusterN110Xs| ThreadSliceLengths| ThreadClusterLengths| ThreadCluster| SrcAccess| SrcVectorTensor| SrcVectorTensor| DstVectorTensor| ThreadSliceLengths| ThreadClusterLengths| ThreadCluster| SrcAccess| SrcVectorTensor| SrcVectorTensor| DstVectorTensor| SrcDstAccess| SrcDstVectorDim| DstScalarPerVector| // ##########| | | | | | | | Operation| Operation| Operation| | | | | | | | | | | | K0_M0_M1_K1| K0_M0_M1_K1| ArrangeOrder| Order| Lengths_K0_M0_M1_K1| ContiguousDimOrder| Lengths_K0_M0_M1_K1| K0_N0_N1_K1| K0_N0_N1_K1| ArrangeOrder| Order| Lengths_K0_N0_N1_K1| ContiguousDimOrder| Lengths_K0_N0_N1_K1| Order| | | // ##########| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | - DeviceGemmDlops< F32, F32, F32, F32, Col, Row, Row, PassThrough, PassThrough, 
PassThrough, GemmDefault, 256, 256, 128, 8, 1, 8, 4, 1, S<8, 2>, S<8, 2>, S<2, 1, 4, 1>, S<4, 1, 64, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 1, 1>, S<1, 1, 4, 1>, S<8, 1, 32, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 1, 1>, S<0, 1, 2, 3, 4, 5>, 5, 4> - // DeviceGemmDlops< F32, F32, F32, F32, Col, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 128, 256, 8, 1, 4, 8, 1, S<8, 2>, S<8, 2>, S<2, 1, 4, 1>, S<4, 1, 32, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 1, 1>, S<1, 1, 4, 1>, S<8, 1, 64, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 1, 1>, S<0, 1, 2, 3, 4, 5>, 5, 4> + DeviceGemmDlops< F32, F32, F32, F32, Col, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 256, 128, 8, 1, 8, 4, 1, S<8, 2>, S<8, 2>, S<2, 1, 4, 1>, S<4, 1, 64, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 1, 1>, S<1, 1, 4, 1>, S<8, 1, 32, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 1, 1>, S<0, 1, 2, 3, 4, 5>, 5, 4>, + DeviceGemmDlops< F32, F32, F32, F32, Col, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 128, 256, 8, 1, 4, 8, 1, S<8, 2>, S<8, 2>, S<2, 1, 2, 1>, S<4, 1, 64, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 1, 1>, S<1, 1, 4, 1>, S<8, 1, 64, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 1, 1>, S<0, 1, 2, 3, 4, 5>, 5, 4> // DeviceGemmDlops< F32, F32, F32, F32, Col, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 128, 128, 8, 1, 4, 4, 1, S<8, 2>, S<8, 2>, S<4, 1, 1, 1>, S<2, 1, 128, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 1, 1>, S<4, 1, 1, 1>, S<2, 1, 128, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 1, 1>, S<0, 1, 2, 3, 4, 5>, 5, 4>, // DeviceGemmDlops< F32, F32, F32, F32, Col, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 
128, 128, 128, 8, 1, 8, 4, 1, S<4, 2>, S<8, 2>, S<8, 1, 1, 1>, S<1, 1, 128, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 1, 1>, S<8, 1, 1, 1>, S<1, 1, 128, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 1, 1>, S<0, 1, 2, 3, 4, 5>, 5, 4>, -// DeviceGemmDlops< F32, F32, F32, F32, Col, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 128, 128, 128, 8, 1, 4, 8, 1, S<8, 2>, S<4, 2>, S<8, 1, 1, 1>, S<1, 1, 128, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 1, 1>, S<8, 1, 1, 1>, S<1, 1, 128, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 1, 1>, S<0, 1, 2, 3, 4, 5>, 5, 4>, // DeviceGemmDlops< F32, F32, F32, F32, Col, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 64, 128, 128, 8, 1, 8, 8, 1, S<4, 2>, S<4, 2>, S<4, 1, 4, 1>, S<2, 1, 32, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 1, 1>, S<4, 1, 4, 1>, S<2, 1, 32, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 1, 1>, S<0, 1, 2, 3, 4, 5>, 5, 4>, // DeviceGemmDlops< F32, F32, F32, F32, Col, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 64, 128, 64, 8, 1, 8, 4, 1, S<4, 2>, S<4, 2>, S<4, 1, 4, 1>, S<2, 1, 32, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 1, 1>, S<4, 1, 2, 1>, S<2, 1, 32, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 1, 1>, S<0, 1, 2, 3, 4, 5>, 5, 4>, // DeviceGemmDlops< F32, F32, F32, F32, Col, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 64, 64, 128, 8, 1, 4, 8, 1, S<4, 2>, S<4, 2>, S<4, 1, 2, 1>, S<2, 1, 32, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 1, 1>, S<4, 1, 4, 1>, S<2, 1, 32, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 1, 1>, S<0, 1, 2, 3, 4, 5>, 5, 4>, From 3623f9c315f3ca971533d713c37b7b71940c7f03 Mon Sep 17 00:00:00 2001 From: Chao Liu Date: Wed, 11 May 2022 15:13:14 +0000 
Subject: [PATCH 35/46] fixing --- ...ps_v2r3.hpp => blockwise_gemm_dl_v2r3.hpp} | 13 +- .../blockwise_tensor_slice_transfer_v5r1.hpp | 11 -- ...vice_gemm_dlops.hpp => device_gemm_dl.hpp} | 156 +++++++++--------- .../gpu/device/device_gemm_xdl.hpp | 16 +- ...ops_v1r3.hpp => gridwise_gemm_dl_v1r3.hpp} | 26 ++- ...lops.hpp => threadwise_contraction_dl.hpp} | 13 +- .../gpu/gemm/CMakeLists.txt | 63 ++++--- ..._gemm_dl_f16_f16_f16_km_kn_mn_instance.cpp | 53 ++++++ ..._gemm_dl_f16_f16_f16_km_nk_mn_instance.cpp | 53 ++++++ ..._gemm_dl_f16_f16_f16_mk_kn_mn_instance.cpp | 53 ++++++ ..._gemm_dl_f16_f16_f16_mk_nk_mn_instance.cpp | 53 ++++++ ..._gemm_dl_f32_f32_f32_km_kn_mn_instance.cpp | 52 ++++++ ..._gemm_dl_f32_f32_f32_km_nk_mn_instance.cpp | 52 ++++++ ..._gemm_dl_f32_f32_f32_mk_kn_mn_instance.cpp | 53 ++++++ ..._gemm_dl_f32_f32_f32_mk_nk_mn_instance.cpp | 53 ++++++ ...ice_gemm_dl_i8_i8_i8_km_kn_mn_instance.cpp | 51 ++++++ ...ice_gemm_dl_i8_i8_i8_km_nk_mn_instance.cpp | 51 ++++++ ...ice_gemm_dl_i8_i8_i8_mk_kn_mn_instance.cpp | 51 ++++++ ...ice_gemm_dl_i8_i8_i8_mk_nk_mn_instance.cpp | 51 ++++++ ...mm_dlops_f16_f16_f16_km_kn_mn_instance.cpp | 68 -------- ...mm_dlops_f16_f16_f16_km_nk_mn_instance.cpp | 67 -------- ...mm_dlops_f16_f16_f16_mk_kn_mn_instance.cpp | 67 -------- ...mm_dlops_f16_f16_f16_mk_nk_mn_instance.cpp | 67 -------- ...mm_dlops_f32_f32_f32_km_kn_mn_instance.cpp | 51 ------ ...mm_dlops_f32_f32_f32_km_nk_mn_instance.cpp | 52 ------ ...mm_dlops_f32_f32_f32_mk_kn_mn_instance.cpp | 52 ------ ...mm_dlops_f32_f32_f32_mk_nk_mn_instance.cpp | 52 ------ ...dlops_int8_int8_int8_km_kn_mn_instance.cpp | 80 --------- ...dlops_int8_int8_int8_km_nk_mn_instance.cpp | 79 --------- ...dlops_int8_int8_int8_mk_kn_mn_instance.cpp | 79 --------- ...dlops_int8_int8_int8_mk_nk_mn_instance.cpp | 79 --------- ..._c_shuffle_i8_i8_i8_km_kn_mn_instance.cpp} | 6 +- ..._c_shuffle_i8_i8_i8_km_nk_mn_instance.cpp} | 6 +- ..._c_shuffle_i8_i8_i8_mk_kn_mn_instance.cpp} | 6 +- 
..._c_shuffle_i8_i8_i8_mk_nk_mn_instance.cpp} | 6 +- profiler/CMakeLists.txt | 3 +- test/CMakeLists.txt | 2 +- test/gemm/CMakeLists.txt | 8 +- test/gemm_dl/CMakeLists.txt | 11 ++ .../gemm_dl_fp16.cpp} | 18 +- .../gemm_dl_fp32.cpp} | 0 .../gemm_dl_int8.cpp} | 0 test/gemm_dlops/CMakeLists.txt | 15 -- 43 files changed, 800 insertions(+), 998 deletions(-) rename include/ck/tensor_operation/gpu/block/{blockwise_gemm_dlops_v2r3.hpp => blockwise_gemm_dl_v2r3.hpp} (97%) rename include/ck/tensor_operation/gpu/device/{device_gemm_dlops.hpp => device_gemm_dl.hpp} (79%) rename include/ck/tensor_operation/gpu/grid/{gridwise_gemm_dlops_v1r3.hpp => gridwise_gemm_dl_v1r3.hpp} (97%) rename include/ck/tensor_operation/gpu/thread/{threadwise_contraction_dlops.hpp => threadwise_contraction_dl.hpp} (96%) create mode 100644 library/src/tensor_operation_instance/gpu/gemm/device_gemm_dl_f16_f16_f16_km_kn_mn_instance.cpp create mode 100644 library/src/tensor_operation_instance/gpu/gemm/device_gemm_dl_f16_f16_f16_km_nk_mn_instance.cpp create mode 100644 library/src/tensor_operation_instance/gpu/gemm/device_gemm_dl_f16_f16_f16_mk_kn_mn_instance.cpp create mode 100644 library/src/tensor_operation_instance/gpu/gemm/device_gemm_dl_f16_f16_f16_mk_nk_mn_instance.cpp create mode 100644 library/src/tensor_operation_instance/gpu/gemm/device_gemm_dl_f32_f32_f32_km_kn_mn_instance.cpp create mode 100644 library/src/tensor_operation_instance/gpu/gemm/device_gemm_dl_f32_f32_f32_km_nk_mn_instance.cpp create mode 100644 library/src/tensor_operation_instance/gpu/gemm/device_gemm_dl_f32_f32_f32_mk_kn_mn_instance.cpp create mode 100644 library/src/tensor_operation_instance/gpu/gemm/device_gemm_dl_f32_f32_f32_mk_nk_mn_instance.cpp create mode 100644 library/src/tensor_operation_instance/gpu/gemm/device_gemm_dl_i8_i8_i8_km_kn_mn_instance.cpp create mode 100644 library/src/tensor_operation_instance/gpu/gemm/device_gemm_dl_i8_i8_i8_km_nk_mn_instance.cpp create mode 100644 
library/src/tensor_operation_instance/gpu/gemm/device_gemm_dl_i8_i8_i8_mk_kn_mn_instance.cpp create mode 100644 library/src/tensor_operation_instance/gpu/gemm/device_gemm_dl_i8_i8_i8_mk_nk_mn_instance.cpp delete mode 100644 library/src/tensor_operation_instance/gpu/gemm/device_gemm_dlops_f16_f16_f16_km_kn_mn_instance.cpp delete mode 100644 library/src/tensor_operation_instance/gpu/gemm/device_gemm_dlops_f16_f16_f16_km_nk_mn_instance.cpp delete mode 100644 library/src/tensor_operation_instance/gpu/gemm/device_gemm_dlops_f16_f16_f16_mk_kn_mn_instance.cpp delete mode 100644 library/src/tensor_operation_instance/gpu/gemm/device_gemm_dlops_f16_f16_f16_mk_nk_mn_instance.cpp delete mode 100644 library/src/tensor_operation_instance/gpu/gemm/device_gemm_dlops_f32_f32_f32_km_kn_mn_instance.cpp delete mode 100644 library/src/tensor_operation_instance/gpu/gemm/device_gemm_dlops_f32_f32_f32_km_nk_mn_instance.cpp delete mode 100644 library/src/tensor_operation_instance/gpu/gemm/device_gemm_dlops_f32_f32_f32_mk_kn_mn_instance.cpp delete mode 100644 library/src/tensor_operation_instance/gpu/gemm/device_gemm_dlops_f32_f32_f32_mk_nk_mn_instance.cpp delete mode 100644 library/src/tensor_operation_instance/gpu/gemm/device_gemm_dlops_int8_int8_int8_km_kn_mn_instance.cpp delete mode 100644 library/src/tensor_operation_instance/gpu/gemm/device_gemm_dlops_int8_int8_int8_km_nk_mn_instance.cpp delete mode 100644 library/src/tensor_operation_instance/gpu/gemm/device_gemm_dlops_int8_int8_int8_mk_kn_mn_instance.cpp delete mode 100644 library/src/tensor_operation_instance/gpu/gemm/device_gemm_dlops_int8_int8_int8_mk_nk_mn_instance.cpp rename library/src/tensor_operation_instance/gpu/gemm/{device_gemm_xdl_c_shuffle_int8_int8_int8_km_kn_mn_instance.cpp => device_gemm_xdl_c_shuffle_i8_i8_i8_km_kn_mn_instance.cpp} (97%) rename library/src/tensor_operation_instance/gpu/gemm/{device_gemm_xdl_c_shuffle_int8_int8_int8_km_nk_mn_instance.cpp => device_gemm_xdl_c_shuffle_i8_i8_i8_km_nk_mn_instance.cpp} 
(97%) rename library/src/tensor_operation_instance/gpu/gemm/{device_gemm_xdl_c_shuffle_int8_int8_int8_mk_kn_mn_instance.cpp => device_gemm_xdl_c_shuffle_i8_i8_i8_mk_kn_mn_instance.cpp} (97%) rename library/src/tensor_operation_instance/gpu/gemm/{device_gemm_xdl_c_shuffle_int8_int8_int8_mk_nk_mn_instance.cpp => device_gemm_xdl_c_shuffle_i8_i8_i8_mk_nk_mn_instance.cpp} (97%) create mode 100644 test/gemm_dl/CMakeLists.txt rename test/{gemm_dlops/gemm_dlops_fp16.cpp => gemm_dl/gemm_dl_fp16.cpp} (85%) rename test/{gemm_dlops/gemm_dlops_fp32.cpp => gemm_dl/gemm_dl_fp32.cpp} (100%) rename test/{gemm_dlops/gemm_dlops_int8.cpp => gemm_dl/gemm_dl_int8.cpp} (100%) delete mode 100644 test/gemm_dlops/CMakeLists.txt diff --git a/include/ck/tensor_operation/gpu/block/blockwise_gemm_dlops_v2r3.hpp b/include/ck/tensor_operation/gpu/block/blockwise_gemm_dl_v2r3.hpp similarity index 97% rename from include/ck/tensor_operation/gpu/block/blockwise_gemm_dlops_v2r3.hpp rename to include/ck/tensor_operation/gpu/block/blockwise_gemm_dl_v2r3.hpp index 15e7fd9028..f7fa867e16 100644 --- a/include/ck/tensor_operation/gpu/block/blockwise_gemm_dlops_v2r3.hpp +++ b/include/ck/tensor_operation/gpu/block/blockwise_gemm_dl_v2r3.hpp @@ -1,10 +1,8 @@ -#ifndef CK_BLOCKWISE_GEMM_DLOPS_V2R3_HPP -#define CK_BLOCKWISE_GEMM_DLOPS_V2R3_HPP - +#pragma once #include "common_header.hpp" #include "tensor_adaptor.hpp" #include "threadwise_tensor_slice_transfer_v4r1.hpp" -#include "threadwise_contraction_dlops.hpp" +#include "threadwise_contraction_dl.hpp" namespace ck { @@ -41,7 +39,7 @@ template ::type = false> -struct BlockwiseGemmDlops_A_BK0_BM_BK1_B_BK0_BN_BK1_C_BM0_BM1_BN0_BN1_pipeline_BM0_2_BN0_2 +struct BlockwiseGemmDl_A_BK0_BM_BK1_B_BK0_BN_BK1_C_BM0_BM1_BN0_BN1_pipeline_BM0_2_BN0_2 { using AIndex = MultiIndex<3>; using BIndex = MultiIndex<3>; @@ -148,7 +146,7 @@ struct BlockwiseGemmDlops_A_BK0_BM_BK1_B_BK0_BN_BK1_C_BM0_BM1_BN0_BN1_pipeline_B MakeBBlockDescriptor_BK0_BN0_BN1_BK1(BBlockDesc_BK0_BN_BK1{}); 
public: - __device__ BlockwiseGemmDlops_A_BK0_BM_BK1_B_BK0_BN_BK1_C_BM0_BM1_BN0_BN1_pipeline_BM0_2_BN0_2() + __device__ BlockwiseGemmDl_A_BK0_BM_BK1_B_BK0_BN_BK1_C_BM0_BM1_BN0_BN1_pipeline_BM0_2_BN0_2() : c_thread_origin_data_idx_{CalculateCThreadOriginOnBlock_BM0_BM1_BN0_BN1( get_thread_local_1d_id())}, a_thread_copy_{ @@ -227,7 +225,7 @@ struct BlockwiseGemmDlops_A_BK0_BM_BK1_B_BK0_BN_BK1_C_BM0_BM1_BN0_BN1_pipeline_B b_thread_desc_bk0_bn0_bn1_bk1_.GetElementSpaceSize()); constexpr auto threadwise_contraction = - ThreadwiseContractionDlops_A_TK0_TM0_TM1_TK1_B_TK0_TN0_TN1_TK1_C_TM0_TM1_TN0_TN1< + ThreadwiseContractionDl_A_TK0_TM0_TM1_TK1_B_TK0_TN0_TN1_TK1_C_TM0_TM1_TN0_TN1< FloatA, FloatB, FloatC, @@ -408,4 +406,3 @@ struct BlockwiseGemmDlops_A_BK0_BM_BK1_B_BK0_BN_BK1_C_BM0_BM1_BN0_BN1_pipeline_B }; } // namespace ck -#endif diff --git a/include/ck/tensor_operation/gpu/block/blockwise_tensor_slice_transfer_v5r1.hpp b/include/ck/tensor_operation/gpu/block/blockwise_tensor_slice_transfer_v5r1.hpp index 0b737153b0..3a0bfa2b76 100644 --- a/include/ck/tensor_operation/gpu/block/blockwise_tensor_slice_transfer_v5r1.hpp +++ b/include/ck/tensor_operation/gpu/block/blockwise_tensor_slice_transfer_v5r1.hpp @@ -75,17 +75,6 @@ struct BlockwiseTensorSliceTransfer_v5r1 } } - template - __device__ void - RunRead(const SrcDesc& src_desc, const SrcBuffer& src_buf, const SrcStepHacks& src_step_hacks) - { - if(BlockSize == thread_cluster_desc_.GetElementSize() or - get_thread_local_1d_id() < thread_cluster_desc_.GetElementSize()) - { - threadwise_transfer_.RunRead(src_desc, src_buf, src_step_hacks); - } - } - template __device__ void RunRead(const SrcDesc& src_desc, const SrcBuffer& src_buf) { diff --git a/include/ck/tensor_operation/gpu/device/device_gemm_dlops.hpp b/include/ck/tensor_operation/gpu/device/device_gemm_dl.hpp similarity index 79% rename from include/ck/tensor_operation/gpu/device/device_gemm_dlops.hpp rename to include/ck/tensor_operation/gpu/device/device_gemm_dl.hpp 
index 339acc4de0..984f5b7560 100644 --- a/include/ck/tensor_operation/gpu/device/device_gemm_dlops.hpp +++ b/include/ck/tensor_operation/gpu/device/device_gemm_dl.hpp @@ -12,7 +12,7 @@ #include "tensor_descriptor_helper.hpp" #include "gemm_specialization.hpp" #include "element_wise_operation.hpp" -#include "gridwise_gemm_dlops_v1r3.hpp" +#include "gridwise_gemm_dl_v1r3.hpp" #include "device_prop.hpp" namespace ck { @@ -63,7 +63,7 @@ template < is_same_v && is_same_v, bool> = false> -struct DeviceGemmDlops +struct DeviceGemmDl : public DeviceGemm { static constexpr auto I0 = Number<0>{}; @@ -194,39 +194,39 @@ struct DeviceGemmDlops // GridwiseGemm using GridwiseGemm = - GridwiseGemmDlops_km_kn_mn_v1r3; + GridwiseGemmDl_km_kn_mn_v1r3; using AGridDesc_K0_M0_M1_K1 = decltype(GridwiseGemm::MakeAGridDescriptor_K0_M0_M1_K1(AGridDesc_K0_M_K1{})); @@ -267,9 +267,9 @@ struct DeviceGemmDlops b_element_op_{b_element_op}, c_element_op_{c_element_op} { - a_grid_desc_k0_m_k1_ = DeviceGemmDlops::MakeAGridDescriptor_K0_M_K1(M, K, StrideA); - b_grid_desc_k0_n_k1_ = DeviceGemmDlops::MakeBGridDescriptor_K0_N_K1(K, N, StrideB); - c_grid_desc_m_n_ = DeviceGemmDlops::MakeCGridDescriptor_M_N(M, N, StrideC); + a_grid_desc_k0_m_k1_ = DeviceGemmDl::MakeAGridDescriptor_K0_M_K1(M, K, StrideA); + b_grid_desc_k0_n_k1_ = DeviceGemmDl::MakeBGridDescriptor_K0_N_K1(K, N, StrideB); + c_grid_desc_m_n_ = DeviceGemmDl::MakeCGridDescriptor_M_N(M, N, StrideC); if(GridwiseGemm::CheckValidity( a_grid_desc_k0_m_k1_, b_grid_desc_k0_n_k1_, c_grid_desc_m_n_)) @@ -304,7 +304,7 @@ struct DeviceGemmDlops index_t M01_; index_t N01_; - // TODO: unused since gridwise_gemm_dlops_v1r3 does NOT support prologue for the time being. + // TODO: unused since gridwise_gemm_dl_v1r3 does NOT support prologue for the time being. 
AElementwiseOperation a_element_op_; BElementwiseOperation b_element_op_; CElementwiseOperation c_element_op_; @@ -313,7 +313,7 @@ struct DeviceGemmDlops // Invoker struct Invoker : public BaseInvoker { - using Argument = DeviceGemmDlops::Argument; + using Argument = DeviceGemmDl::Argument; float Run(const Argument& arg, int nrepeat = 1) { @@ -336,7 +336,7 @@ struct DeviceGemmDlops arg.a_grid_desc_k0_m_k1_, arg.b_grid_desc_k0_n_k1_, arg.c_grid_desc_m_n_)) { throw std::runtime_error( - "wrong! GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3 has invalid setting"); + "wrong! GridwiseGemm_k0mk1_k0nk1_mn_xdl_v2r3 has invalid setting"); } const index_t grid_size = GridwiseGemm::CalculateGridSize( @@ -352,15 +352,15 @@ struct DeviceGemmDlops if(has_main_k_block_loop && has_double_tail_k_block_loop) { const auto kernel = - kernel_gemm_dlops_v1r3, - remove_reference_t, - remove_reference_t, - remove_reference_t, - true, - true>; + kernel_gemm_dl_v1r3, + remove_reference_t, + remove_reference_t, + remove_reference_t, + true, + true>; ave_time = launch_and_time_kernel(kernel, nrepeat, @@ -378,15 +378,15 @@ struct DeviceGemmDlops else if(has_main_k_block_loop && !has_double_tail_k_block_loop) { const auto kernel = - kernel_gemm_dlops_v1r3, - remove_reference_t, - remove_reference_t, - remove_reference_t, - true, - false>; + kernel_gemm_dl_v1r3, + remove_reference_t, + remove_reference_t, + remove_reference_t, + true, + false>; ave_time = launch_and_time_kernel(kernel, nrepeat, @@ -404,15 +404,15 @@ struct DeviceGemmDlops else if(!has_main_k_block_loop && has_double_tail_k_block_loop) { const auto kernel = - kernel_gemm_dlops_v1r3, - remove_reference_t, - remove_reference_t, - remove_reference_t, - false, - true>; + kernel_gemm_dl_v1r3, + remove_reference_t, + remove_reference_t, + remove_reference_t, + false, + true>; ave_time = launch_and_time_kernel(kernel, nrepeat, @@ -430,15 +430,15 @@ struct DeviceGemmDlops else { const auto kernel = - kernel_gemm_dlops_v1r3, - 
remove_reference_t, - remove_reference_t, - remove_reference_t, - false, - false>; + kernel_gemm_dl_v1r3, + remove_reference_t, + remove_reference_t, + remove_reference_t, + false, + false>; ave_time = launch_and_time_kernel(kernel, nrepeat, @@ -563,7 +563,7 @@ struct DeviceGemmDlops auto str = std::stringstream(); // clang-format off - str << "DeviceGemmDlops" + str << "DeviceGemmDl" << "<" << BlockSize << ", " << MPerBlock << ", " diff --git a/include/ck/tensor_operation/gpu/device/device_gemm_xdl.hpp b/include/ck/tensor_operation/gpu/device/device_gemm_xdl.hpp index bdf1f43b37..eee6beb5b9 100644 --- a/include/ck/tensor_operation/gpu/device/device_gemm_xdl.hpp +++ b/include/ck/tensor_operation/gpu/device/device_gemm_xdl.hpp @@ -404,18 +404,16 @@ struct DeviceGemmXdl static bool IsSupportedArgument(const Argument& arg) { - if(ck::get_device_name() == "gfx1030") + if(!(ck::get_device_name() == "gfx908" || ck::get_device_name() == "gfx90a")) { return false; } - else - { - return GridwiseGemm::CheckValidity(arg.a_grid_desc_k0_m_k1_, - arg.b_grid_desc_k0_n_k1_, - arg.c_grid_desc_m_n_, - arg.M01_, - arg.N01_); - } + + return GridwiseGemm::CheckValidity(arg.a_grid_desc_k0_m_k1_, + arg.b_grid_desc_k0_n_k1_, + arg.c_grid_desc_m_n_, + arg.M01_, + arg.N01_); } // polymorphic diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_dlops_v1r3.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_dl_v1r3.hpp similarity index 97% rename from include/ck/tensor_operation/gpu/grid/gridwise_gemm_dlops_v1r3.hpp rename to include/ck/tensor_operation/gpu/grid/gridwise_gemm_dl_v1r3.hpp index f3668a0c28..07a556d996 100644 --- a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_dlops_v1r3.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_dl_v1r3.hpp @@ -1,11 +1,9 @@ -#ifndef CK_GRIDWISE_GEMM_V1R3_HPP -#define CK_GRIDWISE_GEMM_V1R3_HPP - +#pragma once #include "common_header.hpp" #include "multi_index_transform_helper.hpp" #include "tensor_descriptor.hpp" #include 
"tensor_descriptor_helper.hpp" -#include "blockwise_gemm_dlops_v2r3.hpp" +#include "blockwise_gemm_dl_v2r3.hpp" #include "blockwise_tensor_slice_transfer_v5r1.hpp" #include "threadwise_tensor_slice_transfer.hpp" #include "threadwise_tensor_slice_set.hpp" @@ -26,14 +24,13 @@ __global__ void #if CK_USE_LAUNCH_BOUNDS __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU) #endif - kernel_gemm_dlops_v1r3( - const FloatAB* __restrict__ p_a_grid, - const FloatAB* __restrict__ p_b_grid, - FloatC* __restrict__ p_c_grid, - const AGridDesc_K0_M0_M1_K1 a_grid_desc_k0_m0_m1_k1, - const BGridDesc_K0_N0_N1_K1 b_grid_desc_k0_n0_n1_k1, - const CGridDesc_M0_M10_M11_N0_N10_N11 c_grid_desc_m0_m10_m11_n0_n10_n11, - const Block2CTileMap block_2_ctile_map) + kernel_gemm_dl_v1r3(const FloatAB* __restrict__ p_a_grid, + const FloatAB* __restrict__ p_b_grid, + FloatC* __restrict__ p_c_grid, + const AGridDesc_K0_M0_M1_K1 a_grid_desc_k0_m0_m1_k1, + const BGridDesc_K0_N0_N1_K1 b_grid_desc_k0_n0_n1_k1, + const CGridDesc_M0_M10_M11_N0_N10_N11 c_grid_desc_m0_m10_m11_n0_n10_n11, + const Block2CTileMap block_2_ctile_map) { constexpr index_t shared_block_size = GridwiseGemm::GetSharedMemoryNumberOfByte() / sizeof(FloatAB); @@ -85,7 +82,7 @@ template -struct GridwiseGemmDlops_km_kn_mn_v1r3 +struct GridwiseGemmDl_km_kn_mn_v1r3 { static constexpr auto I0 = Number<0>{}; static constexpr auto I1 = Number<1>{}; @@ -372,7 +369,7 @@ struct GridwiseGemmDlops_km_kn_mn_v1r3 // c_mtx[MPerBlock, NPerBlock] is distributed among threads, and saved in // register const auto blockwise_gemm = - BlockwiseGemmDlops_A_BK0_BM_BK1_B_BK0_BN_BK1_C_BM0_BM1_BN0_BN1_pipeline_BM0_2_BN0_2< + BlockwiseGemmDl_A_BK0_BM_BK1_B_BK0_BN_BK1_C_BM0_BM1_BN0_BN1_pipeline_BM0_2_BN0_2< BlockSize, FloatAB, FloatAB, @@ -580,4 +577,3 @@ struct GridwiseGemmDlops_km_kn_mn_v1r3 }; } // namespace ck -#endif diff --git a/include/ck/tensor_operation/gpu/thread/threadwise_contraction_dlops.hpp 
b/include/ck/tensor_operation/gpu/thread/threadwise_contraction_dl.hpp similarity index 96% rename from include/ck/tensor_operation/gpu/thread/threadwise_contraction_dlops.hpp rename to include/ck/tensor_operation/gpu/thread/threadwise_contraction_dl.hpp index 8b75381026..6a532c79f9 100644 --- a/include/ck/tensor_operation/gpu/thread/threadwise_contraction_dlops.hpp +++ b/include/ck/tensor_operation/gpu/thread/threadwise_contraction_dl.hpp @@ -1,6 +1,4 @@ -#ifndef CK_THREADWISE_CONTRACTION_DLOPS_HPP -#define CK_THREADWISE_CONTRACTION_DLOPS_HPP - +#pragma once #include "common_header.hpp" #include "math.hpp" @@ -25,9 +23,9 @@ template ::type = false> -struct ThreadwiseGemmDlops_km0m1_kn0n1_m0m1n0n1 +struct ThreadwiseGemmDl_km0m1_kn0n1_m0m1n0n1 { - __device__ constexpr ThreadwiseGemmDlops_km0m1_kn0n1_m0m1n0n1() + __device__ constexpr ThreadwiseGemmDl_km0m1_kn0n1_m0m1n0n1() { static_assert(AThreadDesc_TK0_TM0_TM1_TK1::IsKnownAtCompileTime() && BThreadDesc_TK0_TN0_TN1_TK1::IsKnownAtCompileTime() && @@ -124,9 +122,9 @@ template ::type = false> -struct ThreadwiseContractionDlops_A_TK0_TM0_TM1_TK1_B_TK0_TN0_TN1_TK1_C_TM0_TM1_TN0_TN1 +struct ThreadwiseContractionDl_A_TK0_TM0_TM1_TK1_B_TK0_TN0_TN1_TK1_C_TM0_TM1_TN0_TN1 { - __device__ constexpr ThreadwiseContractionDlops_A_TK0_TM0_TM1_TK1_B_TK0_TN0_TN1_TK1_C_TM0_TM1_TN0_TN1() + __device__ constexpr ThreadwiseContractionDl_A_TK0_TM0_TM1_TK1_B_TK0_TN0_TN1_TK1_C_TM0_TM1_TN0_TN1() { static_assert(AThreadDesc_TK0_TM0_TM1_TK1::IsKnownAtCompileTime() && BThreadDesc_TK0_TN0_TN1_TK1::IsKnownAtCompileTime() && @@ -220,4 +218,3 @@ struct ThreadwiseContractionDlops_A_TK0_TM0_TM1_TK1_B_TK0_TN0_TN1_TK1_C_TM0_TM1_ }; } // namespace ck -#endif diff --git a/library/src/tensor_operation_instance/gpu/gemm/CMakeLists.txt b/library/src/tensor_operation_instance/gpu/gemm/CMakeLists.txt index 5cbdc5d421..2ae89c8ff1 100644 --- a/library/src/tensor_operation_instance/gpu/gemm/CMakeLists.txt +++ 
b/library/src/tensor_operation_instance/gpu/gemm/CMakeLists.txt @@ -1,4 +1,3 @@ -# device_gemm_instance set(DEVICE_GEMM_XDL_INSTANCE_SOURCE device_gemm_xdl_f32_f32_f32_mk_kn_mn_instance.cpp; device_gemm_xdl_f32_f32_f32_mk_nk_mn_instance.cpp; @@ -8,10 +7,10 @@ set(DEVICE_GEMM_XDL_INSTANCE_SOURCE device_gemm_xdl_f16_f16_f16_mk_nk_mn_instance.cpp; device_gemm_xdl_f16_f16_f16_km_kn_mn_instance.cpp; device_gemm_xdl_f16_f16_f16_km_nk_mn_instance.cpp; - device_gemm_xdl_c_shuffle_int8_int8_int8_mk_kn_mn_instance.cpp; - device_gemm_xdl_c_shuffle_int8_int8_int8_mk_nk_mn_instance.cpp; - device_gemm_xdl_c_shuffle_int8_int8_int8_km_kn_mn_instance.cpp; - device_gemm_xdl_c_shuffle_int8_int8_int8_km_nk_mn_instance.cpp; + device_gemm_xdl_c_shuffle_i8_i8_i8_mk_kn_mn_instance.cpp; + device_gemm_xdl_c_shuffle_i8_i8_i8_mk_nk_mn_instance.cpp; + device_gemm_xdl_c_shuffle_i8_i8_i8_km_kn_mn_instance.cpp; + device_gemm_xdl_c_shuffle_i8_i8_i8_km_nk_mn_instance.cpp; device_gemm_xdl_c_shuffle_bf16_bf16_bf16_mk_kn_mn_instance.cpp; device_gemm_xdl_c_shuffle_bf16_bf16_bf16_mk_nk_mn_instance.cpp; device_gemm_xdl_c_shuffle_bf16_bf16_bf16_km_kn_mn_instance.cpp; @@ -35,34 +34,30 @@ set(DEVICE_GEMM_XDL_INSTANCE_SOURCE device_gemm_xdl_splitk_f16_f16_f16_km_nk_mn_instance.cpp; ) -add_library(device_gemm_instance SHARED ${DEVICE_GEMM_XDL_INSTANCE_SOURCE}) - -target_compile_features(device_gemm_instance PUBLIC) -set_target_properties(device_gemm_instance PROPERTIES POSITION_INDEPENDENT_CODE ON) -install(TARGETS device_gemm_instance LIBRARY DESTINATION lib) - -clang_tidy_check(device_gemm_instance) - - -set(DEVICE_GEMM_DLOPS_INSTANCE_SOURCE - device_gemm_dlops_f32_f32_f32_mk_kn_mn_instance.cpp; - device_gemm_dlops_f32_f32_f32_mk_nk_mn_instance.cpp; - device_gemm_dlops_f32_f32_f32_km_kn_mn_instance.cpp; - device_gemm_dlops_f32_f32_f32_km_nk_mn_instance.cpp; - device_gemm_dlops_f16_f16_f16_mk_kn_mn_instance.cpp; - device_gemm_dlops_f16_f16_f16_mk_nk_mn_instance.cpp; - 
device_gemm_dlops_f16_f16_f16_km_kn_mn_instance.cpp; - device_gemm_dlops_f16_f16_f16_km_nk_mn_instance.cpp; - device_gemm_dlops_int8_int8_int8_mk_kn_mn_instance.cpp; - device_gemm_dlops_int8_int8_int8_mk_nk_mn_instance.cpp; - device_gemm_dlops_int8_int8_int8_km_kn_mn_instance.cpp; - device_gemm_dlops_int8_int8_int8_km_nk_mn_instance.cpp; +add_library(device_gemm_xdl_instance SHARED ${DEVICE_GEMM_XDL_INSTANCE_SOURCE}) +target_compile_features(device_gemm_xdl_instance PUBLIC) +set_target_properties(device_gemm_xdl_instance PROPERTIES POSITION_INDEPENDENT_CODE ON) +install(TARGETS device_gemm_xdl_instance LIBRARY DESTINATION lib) +clang_tidy_check(device_gemm_xdl_instance) + + +set(DEVICE_GEMM_DL_INSTANCE_SOURCE +# device_gemm_dl_f32_f32_f32_mk_kn_mn_instance.cpp; + device_gemm_dl_f32_f32_f32_mk_nk_mn_instance.cpp; +# device_gemm_dl_f32_f32_f32_km_kn_mn_instance.cpp; +# device_gemm_dl_f32_f32_f32_km_nk_mn_instance.cpp; +# device_gemm_dl_f16_f16_f16_mk_kn_mn_instance.cpp; + device_gemm_dl_f16_f16_f16_mk_nk_mn_instance.cpp; +# device_gemm_dl_f16_f16_f16_km_kn_mn_instance.cpp; +# device_gemm_dl_f16_f16_f16_km_nk_mn_instance.cpp; +# device_gemm_dl_i8_i8_i8_mk_kn_mn_instance.cpp; +# device_gemm_dl_i8_i8_i8_mk_nk_mn_instance.cpp; +# device_gemm_dl_i8_i8_i8_km_kn_mn_instance.cpp; +# device_gemm_dl_i8_i8_i8_km_nk_mn_instance.cpp; ) -add_library(device_gemm_dlops_instance SHARED ${DEVICE_GEMM_DLOPS_INSTANCE_SOURCE}) - -target_compile_features(device_gemm_dlops_instance PUBLIC) -set_target_properties(device_gemm_dlops_instance PROPERTIES POSITION_INDEPENDENT_CODE ON) -install(TARGETS device_gemm_dlops_instance LIBRARY DESTINATION lib) - -clang_tidy_check(device_gemm_dlops_instance) +add_library(device_gemm_dl_instance SHARED ${DEVICE_GEMM_DL_INSTANCE_SOURCE}) +target_compile_features(device_gemm_dl_instance PUBLIC) +set_target_properties(device_gemm_dl_instance PROPERTIES POSITION_INDEPENDENT_CODE ON) +install(TARGETS device_gemm_dl_instance LIBRARY DESTINATION lib) 
+clang_tidy_check(device_gemm_dl_instance) diff --git a/library/src/tensor_operation_instance/gpu/gemm/device_gemm_dl_f16_f16_f16_km_kn_mn_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm/device_gemm_dl_f16_f16_f16_km_kn_mn_instance.cpp new file mode 100644 index 0000000000..2b9a1c140a --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/gemm/device_gemm_dl_f16_f16_f16_km_kn_mn_instance.cpp @@ -0,0 +1,53 @@ +#include +#include "config.hpp" +#include "device_gemm_dl.hpp" +#include "element_wise_operation.hpp" +#include "device_operation_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace device_gemm_instance { + +using F16 = ck::half_t; +using F32 = float; + +using Row = ck::tensor_layout::gemm::RowMajor; +using Col = ck::tensor_layout::gemm::ColumnMajor; + +template +using S = ck::Sequence; + +using PassThrough = ck::tensor_operation::element_wise::PassThrough; + +static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default; + +// Compilation parameters for a[k, m] * b[k, n] = c[m, n] +using device_gemm_dl_f16_f16_f16_km_kn_mn_instances = std::tuple< + + // clang-format off + // ##########| AData| BData| CData| AccData| ALayout| BLayout| CLayout| A| B| C| GEMM| Block| MPer| NPer| K0Per| K1| M1Per| N1Per| KPer| M11N11Thread| M11N11Thread| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| CThreadTransfer| CThreadTransfer| CThreadTransfer| + // ##########| Type| Type| Type| Type| | | | Elementwise| Elementwise| Elementwise| Spacialization| Size| Block| Block| Block| | ThreadM111| ThreadN111| Thread| ClusterM110Xs| ClusterN110Xs| ThreadSliceLengths| ThreadClusterLengths| ThreadCluster| SrcAccess| SrcVectorTensor| SrcVectorTensor| DstVectorTensor| ThreadSliceLengths| ThreadClusterLengths| ThreadCluster| 
SrcAccess| SrcVectorTensor| SrcVectorTensor| DstVectorTensor| SrcDstAccess| SrcDstVectorDim| DstScalarPerVector| + // ##########| | | | | | | | Operation| Operation| Operation| | | | | | | | | | | | K0_M0_M1_K1| K0_M0_M1_K1| ArrangeOrder| Order| Lengths_K0_M0_M1_K1| ContiguousDimOrder| Lengths_K0_M0_M1_K1| K0_M0_M1_K1| K0_M0_M1_K1| ArrangeOrder| Order| Lengths_K0_M0_M1_K1| ContiguousDimOrder| Lengths_K0_M0_M1_K1| Order| | | + // ##########| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceGemmDl< F16, F16, F16, F32, Col, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 256, 128, 8, 2, 8, 4, 1, S<8, 2>, S<8, 2>, S<4, 1, 2, 2>, S<2, 1, 128, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<4, 1, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 1, 2>, S<4, 1, 1, 2>, S<2, 1, 128, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<4, 1, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 1, 2>, S<0, 1, 2, 3, 4, 5>, 5, 4>, + DeviceGemmDl< F16, F16, F16, F32, Col, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 128, 256, 8, 2, 4, 8, 1, S<8, 2>, S<8, 2>, S<4, 1, 1, 2>, S<2, 1, 128, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<4, 1, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 1, 2>, S<4, 1, 2, 2>, S<2, 1, 128, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<4, 1, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 1, 2>, S<0, 1, 2, 3, 4, 5>, 5, 4>, + DeviceGemmDl< F16, F16, F16, F32, Col, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 128, 128, 8, 2, 4, 4, 1, S<8, 2>, S<8, 2>, S<4, 1, 1, 2>, S<2, 1, 128, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<4, 1, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 1, 2>, S<4, 1, 1, 2>, S<2, 1, 128, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<4, 1, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 1, 2>, S<0, 1, 2, 3, 4, 5>, 5, 4>, + DeviceGemmDl< F16, F16, F16, F32, Col, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 128, 128, 128, 8, 2, 8, 4, 1, S<4, 2>, S<8, 2>, S<8, 1, 1, 2>, S<1, 1, 128, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<4, 1, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 1, 2>, S<8, 1, 1, 2>, S<1, 1, 128, 1>, S<0, 3, 
1, 2>, S<0, 3, 1, 2>, S<4, 1, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 1, 2>, S<0, 1, 2, 3, 4, 5>, 5, 4>, + DeviceGemmDl< F16, F16, F16, F32, Col, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 128, 128, 128, 8, 2, 4, 8, 1, S<8, 2>, S<4, 2>, S<8, 1, 1, 2>, S<1, 1, 128, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<4, 1, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 1, 2>, S<8, 1, 1, 2>, S<1, 1, 128, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<4, 1, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 1, 2>, S<0, 1, 2, 3, 4, 5>, 5, 4>, + DeviceGemmDl< F16, F16, F16, F32, Col, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 64, 128, 128, 8, 2, 8, 8, 1, S<4, 2>, S<4, 2>, S<4, 1, 4, 2>, S<2, 1, 32, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<4, 1, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 1, 2>, S<4, 1, 4, 2>, S<2, 1, 32, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<4, 1, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 1, 2>, S<0, 1, 2, 3, 4, 5>, 5, 4>, + DeviceGemmDl< F16, F16, F16, F32, Col, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 64, 128, 64, 8, 2, 8, 4, 1, S<4, 2>, S<4, 2>, S<4, 1, 4, 2>, S<2, 1, 32, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<4, 1, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 1, 2>, S<4, 1, 2, 2>, S<2, 1, 32, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<4, 1, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 1, 2>, S<0, 1, 2, 3, 4, 5>, 5, 4>, + DeviceGemmDl< F16, F16, F16, F32, Col, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 64, 64, 128, 8, 2, 4, 8, 1, S<4, 2>, S<4, 2>, S<4, 1, 2, 2>, S<2, 1, 32, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<4, 1, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 1, 2>, S<4, 1, 4, 2>, S<2, 1, 32, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<4, 1, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 1, 2>, S<0, 1, 2, 3, 4, 5>, 5, 4> + // clang-format on + >; + +void add_device_gemm_dl_f16_f16_f16_km_kn_mn_instances( + std::vector>& instances) +{ + add_device_operation_instances(instances, device_gemm_dl_f16_f16_f16_km_kn_mn_instances{}); +} + +} // namespace device_gemm_instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git 
a/library/src/tensor_operation_instance/gpu/gemm/device_gemm_dl_f16_f16_f16_km_nk_mn_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm/device_gemm_dl_f16_f16_f16_km_nk_mn_instance.cpp new file mode 100644 index 0000000000..a82dfa90b2 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/gemm/device_gemm_dl_f16_f16_f16_km_nk_mn_instance.cpp @@ -0,0 +1,53 @@ +#include +#include "config.hpp" +#include "device_gemm_dl.hpp" +#include "element_wise_operation.hpp" +#include "device_operation_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace device_gemm_instance { + +using F16 = ck::half_t; +using F32 = float; + +using Row = ck::tensor_layout::gemm::RowMajor; +using Col = ck::tensor_layout::gemm::ColumnMajor; + +template +using S = ck::Sequence; + +using PassThrough = ck::tensor_operation::element_wise::PassThrough; + +static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default; + +// Compilation parameters for a[k, m] * b[n, k] = c[m, n] +using device_gemm_dl_f16_f16_f16_km_nk_mn_instances = + std::tuple< + // clang-format off + // ########| AData| BData| CData| AccData| ALayout| BLayout| CLayout| A| B| C| GEMM| Block| MPer| NPer| K0Per| K1| M1Per| N1Per| KPer| M11N11Thread| M11N11Thread| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| CThreadTransfer| CThreadTransfer| CThreadTransfer| + // ########| Type| Type| Type| Type| | | | Elementwise| Elementwise| Elementwise| Spacialization| Size| Block| Block| Block| | ThreadM111| ThreadN111| Thread| ClusterM110Xs| ClusterN110Xs| ThreadSliceLengths| ThreadClusterLengths| ThreadCluster| SrcAccess| SrcVectorTensor| SrcVectorTensor| DstVectorTensor| ThreadSliceLengths| ThreadClusterLengths| ThreadCluster| SrcAccess| SrcVectorTensor| SrcVectorTensor| 
DstVectorTensor| SrcDstAccess| SrcDstVectorDim| DstScalarPerVector| + // ########| | | | | | | | Operation| Operation| Operation| | | | | | | | | | | | K0_M0_M1_K1| K0_M0_M1_K1| ArrangeOrder| Order| Lengths_K0_M0_M1_K1| ContiguousDimOrder| Lengths_K0_M0_M1_K1| K0_N0_N1_K1| K0_N0_N1_K1| ArrangeOrder| Order| Lengths_K0_N0_N1_K1| ContiguousDimOrder| Lengths_K0_N0_N1_K1| Order| | | + // ########| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceGemmDl< F16, F16, F16, F32, Col, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 256, 128, 8, 2, 8, 4, 1, S<8, 2>, S<8, 2>, S<4, 1, 2, 2>, S<2, 1, 128, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<4, 1, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 1, 2>, S<4, 1, 1, 2>, S<2, 1, 128, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<0, 1, 2, 3, 4, 5>, 5, 4>, + DeviceGemmDl< F16, F16, F16, F32, Col, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 128, 256, 8, 2, 4, 8, 1, S<8, 2>, S<8, 2>, S<4, 1, 1, 2>, S<2, 1, 128, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<4, 1, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 1, 2>, S<4, 1, 2, 2>, S<2, 1, 128, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<0, 1, 2, 3, 4, 5>, 5, 4>, + DeviceGemmDl< F16, F16, F16, F32, Col, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 128, 128, 8, 2, 4, 4, 1, S<8, 2>, S<8, 2>, S<4, 1, 1, 2>, S<2, 1, 128, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<4, 1, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 1, 2>, S<4, 1, 1, 2>, S<2, 1, 128, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<0, 1, 2, 3, 4, 5>, 5, 4>, + DeviceGemmDl< F16, F16, F16, F32, Col, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 128, 128, 128, 8, 2, 8, 4, 1, S<4, 2>, S<8, 2>, S<8, 1, 1, 2>, S<1, 1, 128, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<4, 1, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 1, 2>, S<8, 1, 1, 2>, S<1, 1, 128, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 
3>, S<1, 1, 1, 2>, S<0, 1, 2, 3, 4, 5>, 5, 4>, + DeviceGemmDl< F16, F16, F16, F32, Col, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 128, 128, 128, 8, 2, 4, 8, 1, S<8, 2>, S<4, 2>, S<8, 1, 1, 2>, S<1, 1, 128, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<4, 1, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 1, 2>, S<8, 1, 1, 2>, S<1, 1, 128, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<0, 1, 2, 3, 4, 5>, 5, 4>, + DeviceGemmDl< F16, F16, F16, F32, Col, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 64, 128, 128, 8, 2, 8, 8, 1, S<4, 2>, S<4, 2>, S<4, 1, 4, 2>, S<2, 1, 32, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<4, 1, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 1, 2>, S<4, 1, 4, 2>, S<2, 1, 32, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<0, 1, 2, 3, 4, 5>, 5, 4>, + DeviceGemmDl< F16, F16, F16, F32, Col, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 64, 128, 64, 8, 2, 8, 4, 1, S<4, 2>, S<4, 2>, S<4, 1, 4, 2>, S<2, 1, 32, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<4, 1, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 1, 2>, S<4, 1, 2, 2>, S<2, 1, 32, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<0, 1, 2, 3, 4, 5>, 5, 4>, + DeviceGemmDl< F16, F16, F16, F32, Col, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 64, 64, 128, 8, 2, 4, 8, 1, S<4, 2>, S<4, 2>, S<4, 1, 2, 2>, S<2, 1, 32, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<4, 1, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 1, 2>, S<4, 1, 4, 2>, S<2, 1, 32, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<0, 1, 2, 3, 4, 5>, 5, 4> + // clang-format on + >; + +void add_device_gemm_dl_f16_f16_f16_km_nk_mn_instances( + std::vector>& instances) +{ + add_device_operation_instances(instances, device_gemm_dl_f16_f16_f16_km_nk_mn_instances{}); +} + +} // namespace device_gemm_instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git 
a/library/src/tensor_operation_instance/gpu/gemm/device_gemm_dl_f16_f16_f16_mk_kn_mn_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm/device_gemm_dl_f16_f16_f16_mk_kn_mn_instance.cpp new file mode 100644 index 0000000000..37d51e9410 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/gemm/device_gemm_dl_f16_f16_f16_mk_kn_mn_instance.cpp @@ -0,0 +1,53 @@ +#include +#include "config.hpp" +#include "device_gemm_dl.hpp" +#include "element_wise_operation.hpp" +#include "device_operation_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace device_gemm_instance { + +using F16 = ck::half_t; +using F32 = float; + +using Row = ck::tensor_layout::gemm::RowMajor; +using Col = ck::tensor_layout::gemm::ColumnMajor; + +template +using S = ck::Sequence; + +using PassThrough = ck::tensor_operation::element_wise::PassThrough; + +static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default; + +// Compilation parameters for a[m, k] * b[k, n] = c[m, n] +using device_gemm_dl_f16_f16_f16_mk_kn_mn_instances = + std::tuple< + // clang-format off + // ########| AData| BData| CData| AccData| ALayout| BLayout| CLayout| A| B| C| GEMM| Block| MPer| NPer| K0Per| K1| M1Per| N1Per| KPer| M11N11Thread| M11N11Thread| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| CThreadTransfer| CThreadTransfer| CThreadTransfer| + // ########| Type| Type| Type| Type| | | | Elementwise| Elementwise| Elementwise| Spacialization| Size| Block| Block| Block| | ThreadM111| ThreadN111| Thread| ClusterM110Xs| ClusterN110Xs| ThreadSliceLengths| ThreadClusterLengths| ThreadCluster| SrcAccess| SrcVectorTensor| SrcVectorTensor| DstVectorTensor| ThreadSliceLengths| ThreadClusterLengths| ThreadCluster| SrcAccess| SrcVectorTensor| SrcVectorTensor| 
DstVectorTensor| SrcDstAccess| SrcDstVectorDim| DstScalarPerVector| + // ########| | | | | | | | Operation| Operation| Operation| | | | | | | | | | | | K0_M0_M1_K1| K0_M0_M1_K1| ArrangeOrder| Order| Lengths_K0_M0_M1_K1| ContiguousDimOrder| Lengths_K0_M0_M1_K1| K0_N0_N1_K1| K0_N0_N1_K1| ArrangeOrder| Order| Lengths_K0_N0_N1_K1| ContiguousDimOrder| Lengths_K0_N0_N1_K1| Order| | | + // ########| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceGemmDl< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 256, 128, 8, 2, 8, 4, 1, S<8, 2>, S<8, 2>, S<4, 1, 2, 2>, S<2, 1, 128, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<4, 1, 1, 2>, S<2, 1, 128, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<4, 1, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 1, 2>, S<0, 1, 2, 3, 4, 5>, 5, 4>, + DeviceGemmDl< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 128, 256, 8, 2, 4, 8, 1, S<8, 2>, S<8, 2>, S<4, 1, 1, 2>, S<2, 1, 128, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<4, 1, 2, 2>, S<2, 1, 128, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<4, 1, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 1, 2>, S<0, 1, 2, 3, 4, 5>, 5, 4>, + DeviceGemmDl< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 128, 128, 8, 2, 4, 4, 1, S<8, 2>, S<8, 2>, S<4, 1, 1, 2>, S<2, 1, 128, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<4, 1, 1, 2>, S<2, 1, 128, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<4, 1, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 1, 2>, S<0, 1, 2, 3, 4, 5>, 5, 4>, + DeviceGemmDl< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 128, 128, 128, 8, 2, 8, 4, 1, S<4, 2>, S<8, 2>, S<8, 1, 1, 2>, S<1, 1, 128, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<8, 1, 1, 2>, S<1, 1, 128, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<4, 1, 1, 2>, S<0, 3, 1, 
2>, S<1, 1, 1, 2>, S<0, 1, 2, 3, 4, 5>, 5, 4>, + DeviceGemmDl< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 128, 128, 128, 8, 2, 4, 8, 1, S<8, 2>, S<4, 2>, S<8, 1, 1, 2>, S<1, 1, 128, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<8, 1, 1, 2>, S<1, 1, 128, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<4, 1, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 1, 2>, S<0, 1, 2, 3, 4, 5>, 5, 4>, + DeviceGemmDl< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 64, 128, 128, 8, 2, 8, 8, 1, S<4, 2>, S<4, 2>, S<4, 1, 4, 2>, S<2, 1, 32, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<4, 1, 4, 2>, S<2, 1, 32, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<4, 1, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 1, 2>, S<0, 1, 2, 3, 4, 5>, 5, 4>, + DeviceGemmDl< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 64, 128, 64, 8, 2, 8, 4, 1, S<4, 2>, S<4, 2>, S<4, 1, 4, 2>, S<2, 1, 32, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<4, 1, 2, 2>, S<2, 1, 32, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<4, 1, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 1, 2>, S<0, 1, 2, 3, 4, 5>, 5, 4>, + DeviceGemmDl< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 64, 64, 128, 8, 2, 4, 8, 1, S<4, 2>, S<4, 2>, S<4, 1, 2, 2>, S<2, 1, 32, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<4, 1, 4, 2>, S<2, 1, 32, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<4, 1, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 1, 2>, S<0, 1, 2, 3, 4, 5>, 5, 4> + // clang-format on + >; + +void add_device_gemm_dl_f16_f16_f16_mk_kn_mn_instances( + std::vector>& instances) +{ + add_device_operation_instances(instances, device_gemm_dl_f16_f16_f16_mk_kn_mn_instances{}); +} + +} // namespace device_gemm_instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git 
a/library/src/tensor_operation_instance/gpu/gemm/device_gemm_dl_f16_f16_f16_mk_nk_mn_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm/device_gemm_dl_f16_f16_f16_mk_nk_mn_instance.cpp new file mode 100644 index 0000000000..d51191a3d6 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/gemm/device_gemm_dl_f16_f16_f16_mk_nk_mn_instance.cpp @@ -0,0 +1,53 @@ +#include +#include "config.hpp" +#include "device_gemm_dl.hpp" +#include "element_wise_operation.hpp" +#include "device_operation_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace device_gemm_instance { + +using F16 = ck::half_t; +using F32 = float; + +using Row = ck::tensor_layout::gemm::RowMajor; +using Col = ck::tensor_layout::gemm::ColumnMajor; + +template +using S = ck::Sequence; + +using PassThrough = ck::tensor_operation::element_wise::PassThrough; + +static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default; + +// Compilation parameters for a[m, k] * b[n, k] = c[m, n] +using device_gemm_dl_f16_f16_f16_mk_nk_mn_instances = + std::tuple< + // clang-format off + // ########| AData| BData| CData| AccData| ALayout| BLayout| CLayout| A| B| C| GEMM| Block| MPer| NPer| K0Per| K1| M1Per| N1Per| KPer| M11N11Thread| M11N11Thread| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| CThreadTransfer| CThreadTransfer| CThreadTransfer| + // ########| Type| Type| Type| Type| | | | Elementwise| Elementwise| Elementwise| Spacialization| Size| Block| Block| Block| | ThreadM111| ThreadN111| Thread| ClusterM110Xs| ClusterN110Xs| ThreadSliceLengths| ThreadClusterLengths| ThreadCluster| SrcAccess| SrcVectorTensor| SrcVectorTensor| DstVectorTensor| ThreadSliceLengths| ThreadClusterLengths| ThreadCluster| SrcAccess| SrcVectorTensor| SrcVectorTensor| 
DstVectorTensor| SrcDstAccess| SrcDstVectorDim| DstScalarPerVector| + // ########| | | | | | | | Operation| Operation| Operation| | | | | | | | | | | | K0_M0_M1_K1| K0_M0_M1_K1| ArrangeOrder| Order| Lengths_K0_M0_M1_K1| ContiguousDimOrder| Lengths_K0_M0_M1_K1| K0_N0_N1_K1| K0_N0_N1_K1| ArrangeOrder| Order| Lengths_K0_N0_N1_K1| ContiguousDimOrder| Lengths_K0_N0_N1_K1| Order| | | + // ########| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceGemmDl< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 256, 128, 8, 2, 8, 4, 1, S<8, 2>, S<8, 2>, S<4, 1, 2, 2>, S<2, 1, 128, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<4, 1, 1, 2>, S<2, 1, 128, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<0, 1, 2, 3, 4, 5>, 5, 4>, + DeviceGemmDl< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 128, 256, 8, 2, 4, 8, 1, S<8, 2>, S<8, 2>, S<4, 1, 1, 2>, S<2, 1, 128, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<4, 1, 2, 2>, S<2, 1, 128, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<0, 1, 2, 3, 4, 5>, 5, 4>, + DeviceGemmDl< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 128, 128, 8, 2, 4, 4, 1, S<8, 2>, S<8, 2>, S<4, 1, 1, 2>, S<2, 1, 128, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<4, 1, 1, 2>, S<2, 1, 128, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<0, 1, 2, 3, 4, 5>, 5, 4>, + DeviceGemmDl< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 128, 128, 128, 8, 2, 8, 4, 1, S<4, 2>, S<8, 2>, S<8, 1, 1, 2>, S<1, 1, 128, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<8, 1, 1, 2>, S<1, 1, 128, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 
3>, S<1, 1, 1, 2>, S<0, 1, 2, 3, 4, 5>, 5, 4>, + DeviceGemmDl< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 128, 128, 128, 8, 2, 4, 8, 1, S<8, 2>, S<4, 2>, S<8, 1, 1, 2>, S<1, 1, 128, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<8, 1, 1, 2>, S<1, 1, 128, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<0, 1, 2, 3, 4, 5>, 5, 4>, + DeviceGemmDl< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 64, 128, 128, 8, 2, 8, 8, 1, S<4, 2>, S<4, 2>, S<4, 1, 4, 2>, S<2, 1, 32, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<4, 1, 4, 2>, S<2, 1, 32, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<0, 1, 2, 3, 4, 5>, 5, 4>, + DeviceGemmDl< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 64, 128, 64, 8, 2, 8, 4, 1, S<4, 2>, S<4, 2>, S<4, 1, 4, 2>, S<2, 1, 32, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<4, 1, 2, 2>, S<2, 1, 32, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<0, 1, 2, 3, 4, 5>, 5, 4>, + DeviceGemmDl< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 64, 64, 128, 8, 2, 4, 8, 1, S<4, 2>, S<4, 2>, S<4, 1, 2, 2>, S<2, 1, 32, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<4, 1, 4, 2>, S<2, 1, 32, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<0, 1, 2, 3, 4, 5>, 5, 4> + // clang-format on + >; + +void add_device_gemm_dl_f16_f16_f16_mk_nk_mn_instances( + std::vector>& instances) +{ + add_device_operation_instances(instances, device_gemm_dl_f16_f16_f16_mk_nk_mn_instances{}); +} + +} // namespace device_gemm_instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git 
a/library/src/tensor_operation_instance/gpu/gemm/device_gemm_dl_f32_f32_f32_km_kn_mn_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm/device_gemm_dl_f32_f32_f32_km_kn_mn_instance.cpp new file mode 100644 index 0000000000..1dd3829676 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/gemm/device_gemm_dl_f32_f32_f32_km_kn_mn_instance.cpp @@ -0,0 +1,52 @@ +#include +#include "config.hpp" +#include "device_gemm_dl.hpp" +#include "element_wise_operation.hpp" +#include "device_operation_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace device_gemm_instance { + +using F16 = ck::half_t; +using F32 = float; + +using Row = ck::tensor_layout::gemm::RowMajor; +using Col = ck::tensor_layout::gemm::ColumnMajor; + +template +using S = ck::Sequence; + +using PassThrough = ck::tensor_operation::element_wise::PassThrough; + +static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default; + +// Compilation parameters for a[k, m] * b[k, n] = c[m, n] +using device_gemm_dl_f32_f32_f32_km_kn_mn_instances = + std::tuple< + // clang-format off + // #######| AData| BData| CData| AccData| ALayout| BLayout| CLayout| A| B| C| GEMM| Block| MPer| NPer| K0Per| K1| M1Per| N1Per| KPer| M11N11Thread| M11N11Thread| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| CThreadTransfer| CThreadTransfer| CThreadTransfer| + // #######| Type| Type| Type| Type| | | | Elementwise| Elementwise| Elementwise| Spacialization| Size| Block| Block| Block| | ThreadM111| ThreadN111| Thread| ClusterM110Xs| ClusterN110Xs| ThreadSliceLengths| ThreadClusterLengths| ThreadCluster| SrcAccess| SrcVectorTensor| SrcVectorTensor| DstVectorTensor| ThreadSliceLengths| ThreadClusterLengths| ThreadCluster| SrcAccess| SrcVectorTensor| SrcVectorTensor| DstVectorTensor| 
SrcDstAccess| SrcDstVectorDim| DstScalarPerVector| + // #######| | | | | | | | Operation| Operation| Operation| | | | | | | | | | | | K0_M0_M1_K1| K0_M0_M1_K1| ArrangeOrder| Order| Lengths_K0_M0_M1_K1| ContiguousDimOrder| Lengths_K0_M0_M1_K1| K0_N0_N1_K1| K0_N0_N1_K1| ArrangeOrder| Order| Lengths_K0_N0_N1_K1| ContiguousDimOrder| Lengths_K0_N0_N1_K1| Order| | | + // #######| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceGemmDl< F32, F32, F32, F32, Col, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 256, 128, 8, 1, 8, 4, 1, S<8, 2>, S<8, 2>, S<2, 1, 4, 1>, S<4, 1, 64, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 1, 1>, S<1, 1, 4, 1>, S<8, 1, 32, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 1, 1>, S<0, 1, 2, 3, 4, 5>, 5, 4>, + DeviceGemmDl< F32, F32, F32, F32, Col, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 128, 256, 8, 1, 4, 8, 1, S<8, 2>, S<8, 2>, S<2, 1, 2, 1>, S<4, 1, 64, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 1, 1>, S<1, 1, 4, 1>, S<8, 1, 64, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 1, 1>, S<0, 1, 2, 3, 4, 5>, 5, 4> +// DeviceGemmDl< F32, F32, F32, F32, Col, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 128, 128, 8, 1, 4, 4, 1, S<8, 2>, S<8, 2>, S<4, 1, 1, 1>, S<2, 1, 128, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 1, 1>, S<4, 1, 1, 1>, S<2, 1, 128, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 1, 1>, S<0, 1, 2, 3, 4, 5>, 5, 4>, +// DeviceGemmDl< F32, F32, F32, F32, Col, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 128, 128, 128, 8, 1, 8, 4, 1, S<4, 2>, S<8, 2>, S<8, 1, 1, 1>, S<1, 1, 128, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 1, 1>, S<8, 1, 1, 1>, S<1, 1, 128, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 1, 1>, 
S<0, 1, 2, 3, 4, 5>, 5, 4>, +// DeviceGemmDl< F32, F32, F32, F32, Col, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 64, 128, 128, 8, 1, 8, 8, 1, S<4, 2>, S<4, 2>, S<4, 1, 4, 1>, S<2, 1, 32, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 1, 1>, S<4, 1, 4, 1>, S<2, 1, 32, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 1, 1>, S<0, 1, 2, 3, 4, 5>, 5, 4>, +// DeviceGemmDl< F32, F32, F32, F32, Col, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 64, 128, 64, 8, 1, 8, 4, 1, S<4, 2>, S<4, 2>, S<4, 1, 4, 1>, S<2, 1, 32, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 1, 1>, S<4, 1, 2, 1>, S<2, 1, 32, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 1, 1>, S<0, 1, 2, 3, 4, 5>, 5, 4>, +// DeviceGemmDl< F32, F32, F32, F32, Col, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 64, 64, 128, 8, 1, 4, 8, 1, S<4, 2>, S<4, 2>, S<4, 1, 2, 1>, S<2, 1, 32, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 1, 1>, S<4, 1, 4, 1>, S<2, 1, 32, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 1, 1>, S<0, 1, 2, 3, 4, 5>, 5, 4>, + // clang-format on + >; + +void add_device_gemm_dl_f32_f32_f32_km_kn_mn_instances( + std::vector>& instances) +{ + add_device_operation_instances(instances, device_gemm_dl_f32_f32_f32_km_kn_mn_instances{}); +} + +} // namespace device_gemm_instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/gemm/device_gemm_dl_f32_f32_f32_km_nk_mn_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm/device_gemm_dl_f32_f32_f32_km_nk_mn_instance.cpp new file mode 100644 index 0000000000..3514650a22 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/gemm/device_gemm_dl_f32_f32_f32_km_nk_mn_instance.cpp @@ -0,0 +1,52 @@ +#include +#include "config.hpp" +#include "device_gemm_dl.hpp" +#include 
"element_wise_operation.hpp" +#include "device_operation_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace device_gemm_instance { + +using F16 = ck::half_t; +using F32 = float; + +using Row = ck::tensor_layout::gemm::RowMajor; +using Col = ck::tensor_layout::gemm::ColumnMajor; + +template +using S = ck::Sequence; + +using PassThrough = ck::tensor_operation::element_wise::PassThrough; + +static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default; + +// Compilation parameters for a[k, m] * b[n, k] = c[m, n] +using device_gemm_dl_f32_f32_f32_km_nk_mn_instances = std::tuple< + // clang-format off + // ##########| AData| BData| CData| AccData| ALayout| BLayout| CLayout| A| B| C| GEMM| Block| MPer| NPer| K0Per| K1| M1Per| N1Per| KPer| M11N11Thread| M11N11Thread| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| CThreadTransfer| CThreadTransfer| CThreadTransfer| + // ##########| Type| Type| Type| Type| | | | Elementwise| Elementwise| Elementwise| Spacialization| Size| Block| Block| Block| | ThreadM111| ThreadN111| Thread| ClusterM110Xs| ClusterN110Xs| ThreadSliceLengths| ThreadClusterLengths| ThreadCluster| SrcAccess| SrcVectorTensor| SrcVectorTensor| DstVectorTensor| ThreadSliceLengths| ThreadClusterLengths| ThreadCluster| SrcAccess| SrcVectorTensor| SrcVectorTensor| DstVectorTensor| SrcDstAccess| SrcDstVectorDim| DstScalarPerVector| + // ##########| | | | | | | | Operation| Operation| Operation| | | | | | | | | | | | K0_M0_M1_K1| K0_M0_M1_K1| ArrangeOrder| Order| Lengths_K0_M0_M1_K1| ContiguousDimOrder| Lengths_K0_M0_M1_K1| K0_N0_N1_K1| K0_N0_N1_K1| ArrangeOrder| Order| Lengths_K0_N0_N1_K1| ContiguousDimOrder| Lengths_K0_N0_N1_K1| Order| | | + // ##########| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | 
| | | | | + DeviceGemmDl< F32, F32, F32, F32, Col, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 256, 128, 8, 1, 8, 4, 1, S<8, 2>, S<8, 2>, S<2, 1, 4, 1>, S<4, 1, 64, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 1, 1>, S<4, 1, 1, 1>, S<2, 1, 128, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 1>, S<1, 2, 0, 3>, S<1, 1, 1, 1>, S<0, 1, 2, 3, 4, 5>, 5, 4> +// DeviceGemmDl< F32, F32, F32, F32, Col, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 128, 256, 8, 1, 4, 8, 1, S<8, 2>, S<8, 2>, S<4, 1, 1, 1>, S<2, 1, 128, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<4, 1, 1, 1>, S<0, 3, 1, 2>, S<1, 1, 1, 1>, S<4, 1, 2, 1>, S<2, 1, 128, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 1>, S<1, 2, 0, 3>, S<1, 1, 1, 1>, S<0, 1, 2, 3, 4, 5>, 5, 4>, +// DeviceGemmDl< F32, F32, F32, F32, Col, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 128, 128, 8, 1, 4, 4, 1, S<8, 2>, S<8, 2>, S<4, 1, 1, 1>, S<2, 1, 128, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<4, 1, 1, 1>, S<0, 3, 1, 2>, S<1, 1, 1, 1>, S<4, 1, 1, 1>, S<2, 1, 128, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 1>, S<1, 2, 0, 3>, S<1, 1, 1, 1>, S<0, 1, 2, 3, 4, 5>, 5, 4>, +// DeviceGemmDl< F32, F32, F32, F32, Col, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 128, 128, 128, 8, 1, 8, 4, 1, S<4, 2>, S<8, 2>, S<8, 1, 1, 1>, S<1, 1, 128, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<4, 1, 1, 1>, S<0, 3, 1, 2>, S<1, 1, 1, 1>, S<8, 1, 1, 1>, S<1, 1, 128, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 1>, S<1, 2, 0, 3>, S<1, 1, 1, 1>, S<0, 1, 2, 3, 4, 5>, 5, 4>, +// DeviceGemmDl< F32, F32, F32, F32, Col, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 128, 128, 128, 8, 1, 4, 8, 1, S<8, 2>, S<4, 2>, S<8, 1, 1, 1>, S<1, 1, 128, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<4, 1, 1, 1>, S<0, 3, 1, 2>, S<1, 1, 1, 1>, S<8, 1, 1, 1>, S<1, 1, 128, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 1>, S<1, 2, 0, 3>, S<1, 1, 1, 1>, S<0, 1, 2, 3, 4, 5>, 5, 4>, +// DeviceGemmDl< 
F32, F32, F32, F32, Col, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 64, 128, 128, 8, 1, 8, 8, 1, S<4, 2>, S<4, 2>, S<4, 1, 4, 1>, S<2, 1, 32, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<4, 1, 1, 1>, S<0, 3, 1, 2>, S<1, 1, 1, 1>, S<4, 1, 4, 1>, S<2, 1, 32, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 1>, S<1, 2, 0, 3>, S<1, 1, 1, 1>, S<0, 1, 2, 3, 4, 5>, 5, 4>, +// DeviceGemmDl< F32, F32, F32, F32, Col, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 64, 128, 64, 8, 1, 8, 4, 1, S<4, 2>, S<4, 2>, S<4, 1, 4, 1>, S<2, 1, 32, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<4, 1, 1, 1>, S<0, 3, 1, 2>, S<1, 1, 1, 1>, S<4, 1, 2, 1>, S<2, 1, 32, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 1>, S<1, 2, 0, 3>, S<1, 1, 1, 1>, S<0, 1, 2, 3, 4, 5>, 5, 4>, +// DeviceGemmDl< F32, F32, F32, F32, Col, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 64, 64, 128, 8, 1, 4, 8, 1, S<4, 2>, S<4, 2>, S<4, 1, 2, 1>, S<2, 1, 32, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<4, 1, 1, 1>, S<0, 3, 1, 2>, S<1, 1, 1, 1>, S<4, 1, 4, 1>, S<2, 1, 32, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 1>, S<1, 2, 0, 3>, S<1, 1, 1, 1>, S<0, 1, 2, 3, 4, 5>, 5, 4>, + // clang-format on + >; + +void add_device_gemm_dl_f32_f32_f32_km_nk_mn_instances( + std::vector>& instances) +{ + add_device_operation_instances(instances, device_gemm_dl_f32_f32_f32_km_nk_mn_instances{}); +} + +} // namespace device_gemm_instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/gemm/device_gemm_dl_f32_f32_f32_mk_kn_mn_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm/device_gemm_dl_f32_f32_f32_mk_kn_mn_instance.cpp new file mode 100644 index 0000000000..26e0f98a54 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/gemm/device_gemm_dl_f32_f32_f32_mk_kn_mn_instance.cpp @@ -0,0 +1,53 @@ +#include +#include "config.hpp" +#include "device_gemm_dl.hpp" +#include "element_wise_operation.hpp" +#include 
"device_operation_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace device_gemm_instance { + +using F16 = ck::half_t; +using F32 = float; + +using Row = ck::tensor_layout::gemm::RowMajor; +using Col = ck::tensor_layout::gemm::ColumnMajor; + +template +using S = ck::Sequence; + +using PassThrough = ck::tensor_operation::element_wise::PassThrough; + +static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default; + +// Compilation parameters for a[m, k] * b[k, n] = c[m, n] +using device_gemm_dl_f32_f32_f32_mk_kn_mn_instances = + std::tuple< + // clang-format off + // ########| AData| BData| CData| AccData| ALayout| BLayout| CLayout| A| B| C| GEMM| Block| MPer| NPer| K0Per| K1| M1Per| N1Per| KPer| M11N11Thread| M11N11Thread| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| CThreadTransfer| CThreadTransfer| CThreadTransfer| + // ########| Type| Type| Type| Type| | | | Elementwise| Elementwise| Elementwise| Spacialization| Size| Block| Block| Block| | ThreadM111| ThreadN111| Thread| ClusterM110Xs| ClusterN110Xs| ThreadSliceLengths| ThreadClusterLengths| ThreadCluster| SrcAccess| SrcVectorTensor| SrcVectorTensor| DstVectorTensor| ThreadSliceLengths| ThreadClusterLengths| ThreadCluster| SrcAccess| SrcVectorTensor| SrcVectorTensor| DstVectorTensor| SrcDstAccess| SrcDstVectorDim| DstScalarPerVector| + // ########| | | | | | | | Operation| Operation| Operation| | | | | | | | | | | | K0_M0_M1_K1| K0_M0_M1_K1| ArrangeOrder| Order| Lengths_K0_M0_M1_K1| ContiguousDimOrder| Lengths_K0_M0_M1_K1| K0_N0_N1_K1| K0_N0_N1_K1| ArrangeOrder| Order| Lengths_K0_N0_N1_K1| ContiguousDimOrder| Lengths_K0_N0_N1_K1| Order| | | + // ########| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceGemmDl< F32, F32, F32, 
F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 256, 128, 8, 1, 8, 4, 1, S<8, 2>, S<8, 2>, S<4, 1, 2, 1>, S<2, 1, 128, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 1>, S<1, 2, 0, 3>, S<1, 1, 1, 1>, S<1, 1, 4, 1>, S<8, 1, 32, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 1, 1>, S<0, 1, 2, 3, 4, 5>, 5, 4> +// DeviceGemmDl< F32, F32, F32, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 128, 256, 8, 1, 4, 8, 1, S<8, 2>, S<8, 2>, S<4, 1, 1, 1>, S<2, 1, 128, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 1>, S<1, 2, 0, 3>, S<1, 1, 1, 1>, S<4, 1, 2, 1>, S<2, 1, 128, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<4, 1, 1, 1>, S<0, 3, 1, 2>, S<1, 1, 1, 1>, S<0, 1, 2, 3, 4, 5>, 5, 4>, +// DeviceGemmDl< F32, F32, F32, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 128, 128, 8, 1, 4, 4, 1, S<8, 2>, S<8, 2>, S<4, 1, 1, 1>, S<2, 1, 128, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 1>, S<1, 2, 0, 3>, S<1, 1, 1, 1>, S<4, 1, 1, 1>, S<2, 1, 128, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<4, 1, 1, 1>, S<0, 3, 1, 2>, S<1, 1, 1, 1>, S<0, 1, 2, 3, 4, 5>, 5, 4>, +// DeviceGemmDl< F32, F32, F32, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 128, 128, 128, 8, 1, 8, 4, 1, S<4, 2>, S<8, 2>, S<8, 1, 1, 1>, S<1, 1, 128, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 1>, S<1, 2, 0, 3>, S<1, 1, 1, 1>, S<8, 1, 1, 1>, S<1, 1, 128, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<4, 1, 1, 1>, S<0, 3, 1, 2>, S<1, 1, 1, 1>, S<0, 1, 2, 3, 4, 5>, 5, 4>, +// DeviceGemmDl< F32, F32, F32, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 128, 128, 128, 8, 1, 4, 8, 1, S<8, 2>, S<4, 2>, S<8, 1, 1, 1>, S<1, 1, 128, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 1>, S<1, 2, 0, 3>, S<1, 1, 1, 1>, S<8, 1, 1, 1>, S<1, 1, 128, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<4, 1, 1, 1>, S<0, 3, 1, 2>, S<1, 1, 1, 1>, S<0, 1, 2, 3, 4, 5>, 5, 4>, +// DeviceGemmDl< F32, F32, F32, F32, Row, Row, Row, 
PassThrough, PassThrough, PassThrough, GemmDefault, 64, 128, 128, 8, 1, 8, 8, 1, S<4, 2>, S<4, 2>, S<4, 1, 4, 1>, S<2, 1, 32, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 1>, S<1, 2, 0, 3>, S<1, 1, 1, 1>, S<4, 1, 4, 1>, S<2, 1, 32, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<4, 1, 1, 1>, S<0, 3, 1, 2>, S<1, 1, 1, 1>, S<0, 1, 2, 3, 4, 5>, 5, 4>, +// DeviceGemmDl< F32, F32, F32, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 64, 128, 64, 8, 1, 8, 4, 1, S<4, 2>, S<4, 2>, S<4, 1, 4, 1>, S<2, 1, 32, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 1>, S<1, 2, 0, 3>, S<1, 1, 1, 1>, S<4, 1, 2, 1>, S<2, 1, 32, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<4, 1, 1, 1>, S<0, 3, 1, 2>, S<1, 1, 1, 1>, S<0, 1, 2, 3, 4, 5>, 5, 4>, +// DeviceGemmDl< F32, F32, F32, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 64, 64, 128, 8, 1, 4, 8, 1, S<4, 2>, S<4, 2>, S<4, 1, 2, 1>, S<2, 1, 32, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 1>, S<1, 2, 0, 3>, S<1, 1, 1, 1>, S<4, 1, 4, 1>, S<2, 1, 32, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<4, 1, 1, 1>, S<0, 3, 1, 2>, S<1, 1, 1, 1>, S<0, 1, 2, 3, 4, 5>, 5, 4>, + // clang-format on + >; + +void add_device_gemm_dl_f32_f32_f32_mk_kn_mn_instances( + std::vector>& instances) +{ + add_device_operation_instances(instances, device_gemm_dl_f32_f32_f32_mk_kn_mn_instances{}); +} + +} // namespace device_gemm_instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/gemm/device_gemm_dl_f32_f32_f32_mk_nk_mn_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm/device_gemm_dl_f32_f32_f32_mk_nk_mn_instance.cpp new file mode 100644 index 0000000000..c21f0736d1 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/gemm/device_gemm_dl_f32_f32_f32_mk_nk_mn_instance.cpp @@ -0,0 +1,53 @@ +#include +#include "config.hpp" +#include "device_gemm_dl.hpp" +#include "element_wise_operation.hpp" +#include "device_operation_instance.hpp" + +namespace ck { 
+namespace tensor_operation { +namespace device { +namespace device_gemm_instance { + +using F16 = ck::half_t; +using F32 = float; + +using Row = ck::tensor_layout::gemm::RowMajor; +using Col = ck::tensor_layout::gemm::ColumnMajor; + +template +using S = ck::Sequence; + +using PassThrough = ck::tensor_operation::element_wise::PassThrough; + +static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default; + +// Compilation parameters for a[m, k] * b[n, k] = c[m, n] +using device_gemm_dl_f32_f32_f32_mk_nk_mn_instances = + std::tuple< + // clang-format off + // ########| AData| BData| CData| AccData| ALayout| BLayout| CLayout| A| B| C| GEMM| Block| MPer| NPer| K0Per| K1| M1Per| N1Per| KPer| M11N11Thread| M11N11Thread| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| CThreadTransfer| CThreadTransfer| CThreadTransfer| + // ########| Type| Type| Type| Type| | | | Elementwise| Elementwise| Elementwise| Spacialization| Size| Block| Block| Block| | ThreadM111| ThreadN111| Thread| ClusterM110Xs| ClusterN110Xs| ThreadSliceLengths| ThreadClusterLengths| ThreadCluster| SrcAccess| SrcVectorTensor| SrcVectorTensor| DstVectorTensor| ThreadSliceLengths| ThreadClusterLengths| ThreadCluster| SrcAccess| SrcVectorTensor| SrcVectorTensor| DstVectorTensor| SrcDstAccess| SrcDstVectorDim| DstScalarPerVector| + // ########| | | | | | | | Operation| Operation| Operation| | | | | | | | | | | | K0_M0_M1_K1| K0_M0_M1_K1| ArrangeOrder| Order| Lengths_K0_M0_M1_K1| ContiguousDimOrder| Lengths_K0_M0_M1_K1| K0_N0_N1_K1| K0_N0_N1_K1| ArrangeOrder| Order| Lengths_K0_N0_N1_K1| ContiguousDimOrder| Lengths_K0_N0_N1_K1| Order| | | + // ########| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceGemmDl< F32, F32, F32, F32, Row, Col, Row, PassThrough, PassThrough, 
PassThrough, GemmDefault, 256, 256, 128, 8, 1, 8, 4, 1, S<8, 2>, S<8, 2>, S<4, 1, 2, 1>, S<2, 1, 128, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 1>, S<1, 2, 0, 3>, S<1, 1, 1, 1>, S<4, 1, 1, 1>, S<2, 1, 128, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 1>, S<1, 2, 0, 3>, S<1, 1, 1, 1>, S<0, 1, 2, 3, 4, 5>, 5, 4>, + DeviceGemmDl< F32, F32, F32, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 128, 256, 8, 1, 4, 8, 1, S<8, 2>, S<8, 2>, S<4, 1, 1, 1>, S<2, 1, 128, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 1>, S<1, 2, 0, 3>, S<1, 1, 1, 1>, S<4, 1, 2, 1>, S<2, 1, 128, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 1>, S<1, 2, 0, 3>, S<1, 1, 1, 1>, S<0, 1, 2, 3, 4, 5>, 5, 4>, + DeviceGemmDl< F32, F32, F32, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 128, 128, 8, 1, 4, 4, 1, S<8, 2>, S<8, 2>, S<4, 1, 1, 1>, S<2, 1, 128, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 1>, S<1, 2, 0, 3>, S<1, 1, 1, 1>, S<4, 1, 1, 1>, S<2, 1, 128, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 1>, S<1, 2, 0, 3>, S<1, 1, 1, 1>, S<0, 1, 2, 3, 4, 5>, 5, 4>, + DeviceGemmDl< F32, F32, F32, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 128, 128, 128, 8, 1, 8, 4, 1, S<4, 2>, S<8, 2>, S<8, 1, 1, 1>, S<1, 1, 128, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 1>, S<1, 2, 0, 3>, S<1, 1, 1, 1>, S<8, 1, 1, 1>, S<1, 1, 128, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 1>, S<1, 2, 0, 3>, S<1, 1, 1, 1>, S<0, 1, 2, 3, 4, 5>, 5, 4>, + DeviceGemmDl< F32, F32, F32, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 128, 128, 128, 8, 1, 4, 8, 1, S<8, 2>, S<4, 2>, S<8, 1, 1, 1>, S<1, 1, 128, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 1>, S<1, 2, 0, 3>, S<1, 1, 1, 1>, S<8, 1, 1, 1>, S<1, 1, 128, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 1>, S<1, 2, 0, 3>, S<1, 1, 1, 1>, S<0, 1, 2, 3, 4, 5>, 5, 4>, + DeviceGemmDl< F32, F32, F32, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 64, 128, 
128, 8, 1, 8, 8, 1, S<4, 2>, S<4, 2>, S<4, 1, 4, 1>, S<2, 1, 32, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 1>, S<1, 2, 0, 3>, S<1, 1, 1, 1>, S<4, 1, 4, 1>, S<2, 1, 32, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 1>, S<1, 2, 0, 3>, S<1, 1, 1, 1>, S<0, 1, 2, 3, 4, 5>, 5, 4>, + DeviceGemmDl< F32, F32, F32, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 64, 128, 64, 8, 1, 8, 4, 1, S<4, 2>, S<4, 2>, S<4, 1, 4, 1>, S<2, 1, 32, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 1>, S<1, 2, 0, 3>, S<1, 1, 1, 1>, S<4, 1, 2, 1>, S<2, 1, 32, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 1>, S<1, 2, 0, 3>, S<1, 1, 1, 1>, S<0, 1, 2, 3, 4, 5>, 5, 4>, + DeviceGemmDl< F32, F32, F32, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 64, 64, 128, 8, 1, 4, 8, 1, S<4, 2>, S<4, 2>, S<4, 1, 2, 1>, S<2, 1, 32, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 1>, S<1, 2, 0, 3>, S<1, 1, 1, 1>, S<4, 1, 4, 1>, S<2, 1, 32, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 1>, S<1, 2, 0, 3>, S<1, 1, 1, 1>, S<0, 1, 2, 3, 4, 5>, 5, 4> + // clang-format on + >; + +void add_device_gemm_dl_f32_f32_f32_mk_nk_mn_instances( + std::vector>& instances) +{ + add_device_operation_instances(instances, device_gemm_dl_f32_f32_f32_mk_nk_mn_instances{}); +} + +} // namespace device_gemm_instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/gemm/device_gemm_dl_i8_i8_i8_km_kn_mn_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm/device_gemm_dl_i8_i8_i8_km_kn_mn_instance.cpp new file mode 100644 index 0000000000..0fa75a0f43 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/gemm/device_gemm_dl_i8_i8_i8_km_kn_mn_instance.cpp @@ -0,0 +1,51 @@ +#include +#include "config.hpp" +#include "device_gemm_dl.hpp" +#include "element_wise_operation.hpp" +#include "device_operation_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace 
device_gemm_instance { + +using Row = ck::tensor_layout::gemm::RowMajor; +using Col = ck::tensor_layout::gemm::ColumnMajor; + +template +using S = ck::Sequence; + +using PassThrough = ck::tensor_operation::element_wise::PassThrough; + +// static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default; + +// Compilation parameters for a[k, m] * b[k, n] = c[m, n] +using device_gemm_dl_i8_i8_i8_km_kn_mn_instances = + std::tuple< + // clang-format off + // ##########| AData| BData| CData| AccData| ALayout| BLayout| CLayout| A| B| C| GEMM| Block| MPer| NPer| K0Per| K1| M1Per| N1Per| KPer| M11N11Thread| M11N11Thread| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| CThreadTransfer| CThreadTransfer| CThreadTransfer| + // ##########| Type| Type| Type| Type| | | | Elementwise| Elementwise| Elementwise| Spacialization| Size| Block| Block| Block| | ThreadM111| ThreadN111| Thread| ClusterM110Xs| ClusterN110Xs| ThreadSliceLengths| ThreadClusterLengths| ThreadCluster| SrcAccess| SrcVectorTensor| SrcVectorTensor| DstVectorTensor| ThreadSliceLengths| ThreadClusterLengths| ThreadCluster| SrcAccess| SrcVectorTensor| SrcVectorTensor| DstVectorTensor| SrcDstAccess| SrcDstVectorDim| DstScalarPerVector| + // ##########| | | | | | | | Operation| Operation| Operation| | | | | | | | | | | | K0_M0_M1_K1| K0_M0_M1_K1| ArrangeOrder| Order| Lengths_K0_M0_M1_K1| ContiguousDimOrder| Lengths_K0_M0_M1_K1| K0_N0_N1_K1| K0_N0_N1_K1| ArrangeOrder| Order| Lengths_K0_N0_N1_K1| ContiguousDimOrder| Lengths_K0_N0_N1_K1| Order| | | + // ##########| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + // DeviceGemmDl< int8_t, int8_t, int8_t, int32_t, Col, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 256, 128, 8, 4, 8, 4, 1, S<8, 2>, S<8, 2>, S<4, 1, 2, 4>, 
S<2, 1, 128, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<4, 1, 1, 4>, S<0, 3, 1, 2>, S<1, 1, 1, 4>, S<4, 1, 1, 4>, S<2, 1, 128, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<4, 1, 1, 4>, S<0, 3, 1, 2>, S<1, 1, 1, 4>, S<0, 1, 2, 3, 4, 5>, 5, 4>, + // DeviceGemmDl< int8_t, int8_t, int8_t, int32_t, Col, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 128, 256, 8, 4, 4, 8, 1, S<8, 2>, S<8, 2>, S<4, 1, 1, 4>, S<2, 1, 128, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<4, 1, 1, 4>, S<0, 3, 1, 2>, S<1, 1, 1, 4>, S<4, 1, 2, 4>, S<2, 1, 128, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<4, 1, 1, 4>, S<0, 3, 1, 2>, S<1, 1, 1, 4>, S<0, 1, 2, 3, 4, 5>, 5, 4>, + // DeviceGemmDl< int8_t, int8_t, int8_t, int32_t, Col, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 128, 128, 8, 4, 4, 4, 1, S<8, 2>, S<8, 2>, S<4, 1, 1, 4>, S<2, 1, 128, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<4, 1, 1, 4>, S<0, 3, 1, 2>, S<1, 1, 1, 4>, S<4, 1, 1, 4>, S<2, 1, 128, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<4, 1, 1, 4>, S<0, 3, 1, 2>, S<1, 1, 1, 4>, S<0, 1, 2, 3, 4, 5>, 5, 4>, + // DeviceGemmDl< int8_t, int8_t, int8_t, int32_t, Col, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 128, 128, 128, 8, 4, 8, 4, 1, S<4, 2>, S<8, 2>, S<8, 1, 1, 4>, S<1, 1, 128, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<4, 1, 1, 4>, S<0, 3, 1, 2>, S<1, 1, 1, 4>, S<8, 1, 1, 4>, S<1, 1, 128, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<4, 1, 1, 4>, S<0, 3, 1, 2>, S<1, 1, 1, 4>, S<0, 1, 2, 3, 4, 5>, 5, 4>, + // DeviceGemmDl< int8_t, int8_t, int8_t, int32_t, Col, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 128, 128, 128, 8, 4, 4, 8, 1, S<8, 2>, S<4, 2>, S<8, 1, 1, 4>, S<1, 1, 128, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<4, 1, 1, 4>, S<0, 3, 1, 2>, S<1, 1, 1, 4>, S<8, 1, 1, 4>, S<1, 1, 128, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<4, 1, 1, 4>, S<0, 3, 1, 2>, S<1, 1, 1, 4>, S<0, 1, 2, 3, 4, 5>, 5, 4>, + // DeviceGemmDl< int8_t, int8_t, int8_t, int32_t, Col, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 64, 128, 128, 8, 
4, 8, 8, 1, S<4, 2>, S<4, 2>, S<4, 1, 4, 4>, S<2, 1, 32, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<4, 1, 1, 4>, S<0, 3, 1, 2>, S<1, 1, 1, 4>, S<4, 1, 4, 4>, S<2, 1, 32, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<4, 1, 1, 4>, S<0, 3, 1, 2>, S<1, 1, 1, 4>, S<0, 1, 2, 3, 4, 5>, 5, 4>, + // DeviceGemmDl< int8_t, int8_t, int8_t, int32_t, Col, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 64, 128, 64, 8, 4, 8, 4, 1, S<4, 2>, S<4, 2>, S<4, 1, 4, 4>, S<2, 1, 32, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<4, 1, 1, 4>, S<0, 3, 1, 2>, S<1, 1, 1, 4>, S<4, 1, 2, 4>, S<2, 1, 32, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<4, 1, 1, 4>, S<0, 3, 1, 2>, S<1, 1, 1, 4>, S<0, 1, 2, 3, 4, 5>, 5, 4>, + // DeviceGemmDl< int8_t, int8_t, int8_t, int32_t, Col, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 64, 64, 128, 8, 4, 4, 8, 1, S<4, 2>, S<4, 2>, S<4, 1, 2, 4>, S<2, 1, 32, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<4, 1, 1, 4>, S<0, 3, 1, 2>, S<1, 1, 1, 4>, S<4, 1, 4, 4>, S<2, 1, 32, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<4, 1, 1, 4>, S<0, 3, 1, 2>, S<1, 1, 1, 4>, S<0, 1, 2, 3, 4, 5>, 5, 4>, + // DeviceGemmDl< int8_t, int8_t, int8_t, int32_t, Col, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 64, 64, 64, 8, 4, 4, 4, 1, S<4, 2>, S<4, 2>, S<4, 1, 2, 4>, S<2, 1, 32, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<4, 1, 1, 4>, S<0, 3, 1, 2>, S<1, 1, 1, 4>, S<4, 1, 2, 4>, S<2, 1, 32, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<4, 1, 1, 4>, S<0, 3, 1, 2>, S<1, 1, 1, 4>, S<0, 1, 2, 3, 4, 5>, 5, 4> + // clang-format on + >; + +void add_device_gemm_dl_i8_i8_i8_km_kn_mn_instances( + std::vector>& instances) +{ + add_device_operation_instances(instances, device_gemm_dl_i8_i8_i8_km_kn_mn_instances{}); +} + +} // namespace device_gemm_instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/gemm/device_gemm_dl_i8_i8_i8_km_nk_mn_instance.cpp 
b/library/src/tensor_operation_instance/gpu/gemm/device_gemm_dl_i8_i8_i8_km_nk_mn_instance.cpp new file mode 100644 index 0000000000..ab63994987 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/gemm/device_gemm_dl_i8_i8_i8_km_nk_mn_instance.cpp @@ -0,0 +1,51 @@ +#include +#include "config.hpp" +#include "device_gemm_dl.hpp" +#include "element_wise_operation.hpp" +#include "device_operation_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace device_gemm_instance { + +using Row = ck::tensor_layout::gemm::RowMajor; +using Col = ck::tensor_layout::gemm::ColumnMajor; + +template +using S = ck::Sequence; + +using PassThrough = ck::tensor_operation::element_wise::PassThrough; + +// static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default; + +// Compilation parameters for a[k, m] * b[n, k] = c[m, n] +using device_gemm_dl_i8_i8_i8_km_nk_mn_instances = + std::tuple< + // clang-format off + // ##########| AData| BData| CData| AccData| ALayout| BLayout| CLayout| A| B| C| GEMM| Block| MPer| NPer| K0Per| K1| M1Per| N1Per| KPer| M11N11Thread| M11N11Thread| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| CThreadTransfer| CThreadTransfer| CThreadTransfer| + // ##########| Type| Type| Type| Type| | | | Elementwise| Elementwise| Elementwise| Spacialization| Size| Block| Block| Block| | ThreadM111| ThreadN111| Thread| ClusterM110Xs| ClusterN110Xs| ThreadSliceLengths| ThreadClusterLengths| ThreadCluster| SrcAccess| SrcVectorTensor| SrcVectorTensor| DstVectorTensor| ThreadSliceLengths| ThreadClusterLengths| ThreadCluster| SrcAccess| SrcVectorTensor| SrcVectorTensor| DstVectorTensor| SrcDstAccess| SrcDstVectorDim| DstScalarPerVector| + // ##########| | | | | | | | Operation| Operation| Operation| | | | | | | | | | | | 
K0_M0_M1_K1| K0_M0_M1_K1| ArrangeOrder| Order| Lengths_K0_M0_M1_K1| ContiguousDimOrder| Lengths_K0_M0_M1_K1| K0_N0_N1_K1| K0_N0_N1_K1| ArrangeOrder| Order| Lengths_K0_N0_N1_K1| ContiguousDimOrder| Lengths_K0_N0_N1_K1| Order| | | + // ##########| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + // DeviceGemmDl< int8_t, int8_t, int8_t, int32_t, Col, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 256, 128, 8, 4, 8, 4, 1, S<8, 2>, S<8, 2>, S<4, 1, 2, 4>, S<2, 1, 128, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<4, 1, 1, 4>, S<0, 3, 1, 2>, S<1, 1, 1, 4>, S<4, 1, 1, 4>, S<2, 1, 128, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 4>, S<1, 2, 0, 3>, S<1, 1, 1, 4>, S<0, 1, 2, 3, 4, 5>, 5, 4>, + // DeviceGemmDl< int8_t, int8_t, int8_t, int32_t, Col, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 128, 256, 8, 4, 4, 8, 1, S<8, 2>, S<8, 2>, S<4, 1, 1, 4>, S<2, 1, 128, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<4, 1, 1, 4>, S<0, 3, 1, 2>, S<1, 1, 1, 4>, S<4, 1, 2, 4>, S<2, 1, 128, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 4>, S<1, 2, 0, 3>, S<1, 1, 1, 4>, S<0, 1, 2, 3, 4, 5>, 5, 4>, + // DeviceGemmDl< int8_t, int8_t, int8_t, int32_t, Col, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 128, 128, 8, 4, 4, 4, 1, S<8, 2>, S<8, 2>, S<4, 1, 1, 4>, S<2, 1, 128, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<4, 1, 1, 4>, S<0, 3, 1, 2>, S<1, 1, 1, 4>, S<4, 1, 1, 4>, S<2, 1, 128, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 4>, S<1, 2, 0, 3>, S<1, 1, 1, 4>, S<0, 1, 2, 3, 4, 5>, 5, 4>, + // DeviceGemmDl< int8_t, int8_t, int8_t, int32_t, Col, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 128, 128, 128, 8, 4, 8, 4, 1, S<4, 2>, S<8, 2>, S<8, 1, 1, 4>, S<1, 1, 128, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<4, 1, 1, 4>, S<0, 3, 1, 2>, S<1, 1, 1, 4>, S<8, 1, 1, 4>, S<1, 1, 128, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 4>, S<1, 2, 0, 3>, S<1, 1, 1, 4>, S<0, 1, 2, 3, 4, 5>, 5, 4>, + // DeviceGemmDl< int8_t, int8_t, 
int8_t, int32_t, Col, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 128, 128, 128, 8, 4, 4, 8, 1, S<8, 2>, S<4, 2>, S<8, 1, 1, 4>, S<1, 1, 128, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<4, 1, 1, 4>, S<0, 3, 1, 2>, S<1, 1, 1, 4>, S<8, 1, 1, 4>, S<1, 1, 128, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 4>, S<1, 2, 0, 3>, S<1, 1, 1, 4>, S<0, 1, 2, 3, 4, 5>, 5, 4>, + // DeviceGemmDl< int8_t, int8_t, int8_t, int32_t, Col, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 64, 128, 128, 8, 4, 8, 8, 1, S<4, 2>, S<4, 2>, S<4, 1, 4, 4>, S<2, 1, 32, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<4, 1, 1, 4>, S<0, 3, 1, 2>, S<1, 1, 1, 4>, S<4, 1, 4, 4>, S<2, 1, 32, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 4>, S<1, 2, 0, 3>, S<1, 1, 1, 4>, S<0, 1, 2, 3, 4, 5>, 5, 4>, + // DeviceGemmDl< int8_t, int8_t, int8_t, int32_t, Col, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 64, 128, 64, 8, 4, 8, 4, 1, S<4, 2>, S<4, 2>, S<4, 1, 4, 4>, S<2, 1, 32, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<4, 1, 1, 4>, S<0, 3, 1, 2>, S<1, 1, 1, 4>, S<4, 1, 2, 4>, S<2, 1, 32, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 4>, S<1, 2, 0, 3>, S<1, 1, 1, 4>, S<0, 1, 2, 3, 4, 5>, 5, 4>, + // DeviceGemmDl< int8_t, int8_t, int8_t, int32_t, Col, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 64, 64, 128, 8, 4, 4, 8, 1, S<4, 2>, S<4, 2>, S<4, 1, 2, 4>, S<2, 1, 32, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<4, 1, 1, 4>, S<0, 3, 1, 2>, S<1, 1, 1, 4>, S<4, 1, 4, 4>, S<2, 1, 32, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 4>, S<1, 2, 0, 3>, S<1, 1, 1, 4>, S<0, 1, 2, 3, 4, 5>, 5, 4>, + // DeviceGemmDl< int8_t, int8_t, int8_t, int32_t, Col, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 64, 64, 64, 8, 4, 4, 4, 1, S<4, 2>, S<4, 2>, S<4, 1, 2, 4>, S<2, 1, 32, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<4, 1, 1, 4>, S<0, 3, 1, 2>, S<1, 1, 1, 4>, S<4, 1, 2, 4>, S<2, 1, 32, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 4>, S<1, 2, 0, 3>, S<1, 1, 1, 4>, S<0, 1, 2, 3, 4, 5>, 5, 4> + // 
clang-format on + >; + +void add_device_gemm_dl_i8_i8_i8_km_nk_mn_instances( + std::vector>& instances) +{ + add_device_operation_instances(instances, device_gemm_dl_i8_i8_i8_km_nk_mn_instances{}); +} + +} // namespace device_gemm_instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/gemm/device_gemm_dl_i8_i8_i8_mk_kn_mn_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm/device_gemm_dl_i8_i8_i8_mk_kn_mn_instance.cpp new file mode 100644 index 0000000000..5d98de4650 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/gemm/device_gemm_dl_i8_i8_i8_mk_kn_mn_instance.cpp @@ -0,0 +1,51 @@ +#include +#include "config.hpp" +#include "device_gemm_dl.hpp" +#include "element_wise_operation.hpp" +#include "device_operation_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace device_gemm_instance { + +using Row = ck::tensor_layout::gemm::RowMajor; +using Col = ck::tensor_layout::gemm::ColumnMajor; + +template +using S = ck::Sequence; + +using PassThrough = ck::tensor_operation::element_wise::PassThrough; + +// static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default; + +// Compilation parameters for a[m, k] * b[k, n] = c[m, n] +using device_gemm_dl_i8_i8_i8_mk_kn_mn_instances = + std::tuple< + // clang-format off + // ##########| AData| BData| CData| AccData| ALayout| BLayout| CLayout| A| B| C| GEMM| Block| MPer| NPer| K0Per| K1| M1Per| N1Per| KPer| M11N11Thread| M11N11Thread| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| CThreadTransfer| CThreadTransfer| CThreadTransfer| + // ##########| Type| Type| Type| Type| | | | Elementwise| Elementwise| Elementwise| Spacialization| Size| Block| Block| Block| | ThreadM111| ThreadN111| 
Thread| ClusterM110Xs| ClusterN110Xs| ThreadSliceLengths| ThreadClusterLengths| ThreadCluster| SrcAccess| SrcVectorTensor| SrcVectorTensor| DstVectorTensor| ThreadSliceLengths| ThreadClusterLengths| ThreadCluster| SrcAccess| SrcVectorTensor| SrcVectorTensor| DstVectorTensor| SrcDstAccess| SrcDstVectorDim| DstScalarPerVector| + // ##########| | | | | | | | Operation| Operation| Operation| | | | | | | | | | | | K0_M0_M1_K1| K0_M0_M1_K1| ArrangeOrder| Order| Lengths_K0_M0_M1_K1| ContiguousDimOrder| Lengths_K0_M0_M1_K1| K0_N0_N1_K1| K0_N0_N1_K1| ArrangeOrder| Order| Lengths_K0_N0_N1_K1| ContiguousDimOrder| Lengths_K0_N0_N1_K1| Order| | | + // ##########| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + // DeviceGemmDl< int8_t, int8_t, int8_t, int32_t, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 256, 128, 8, 4, 8, 4, 1, S<8, 2>, S<8, 2>, S<4, 1, 2, 4>, S<2, 1, 128, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 4>, S<1, 2, 0, 3>, S<1, 1, 1, 4>, S<4, 1, 1, 4>, S<2, 1, 128, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<4, 1, 1, 4>, S<0, 3, 1, 2>, S<1, 1, 1, 4>, S<0, 1, 2, 3, 4, 5>, 5, 4>, + // DeviceGemmDl< int8_t, int8_t, int8_t, int32_t, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 128, 256, 8, 4, 4, 8, 1, S<8, 2>, S<8, 2>, S<4, 1, 1, 4>, S<2, 1, 128, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 4>, S<1, 2, 0, 3>, S<1, 1, 1, 4>, S<4, 1, 2, 4>, S<2, 1, 128, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<4, 1, 1, 4>, S<0, 3, 1, 2>, S<1, 1, 1, 4>, S<0, 1, 2, 3, 4, 5>, 5, 4>, + // DeviceGemmDl< int8_t, int8_t, int8_t, int32_t, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 128, 128, 8, 4, 4, 4, 1, S<8, 2>, S<8, 2>, S<4, 1, 1, 4>, S<2, 1, 128, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 4>, S<1, 2, 0, 3>, S<1, 1, 1, 4>, S<4, 1, 1, 4>, S<2, 1, 128, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<4, 1, 1, 4>, S<0, 3, 1, 2>, S<1, 1, 1, 4>, S<0, 1, 2, 3, 4, 5>, 5, 4>, + // DeviceGemmDl< int8_t, int8_t, 
int8_t, int32_t, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 128, 128, 128, 8, 4, 8, 4, 1, S<4, 2>, S<8, 2>, S<8, 1, 1, 4>, S<1, 1, 128, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 4>, S<1, 2, 0, 3>, S<1, 1, 1, 4>, S<8, 1, 1, 4>, S<1, 1, 128, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<4, 1, 1, 4>, S<0, 3, 1, 2>, S<1, 1, 1, 4>, S<0, 1, 2, 3, 4, 5>, 5, 4>, + // DeviceGemmDl< int8_t, int8_t, int8_t, int32_t, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 128, 128, 128, 8, 4, 4, 8, 1, S<8, 2>, S<4, 2>, S<8, 1, 1, 4>, S<1, 1, 128, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 4>, S<1, 2, 0, 3>, S<1, 1, 1, 4>, S<8, 1, 1, 4>, S<1, 1, 128, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<4, 1, 1, 4>, S<0, 3, 1, 2>, S<1, 1, 1, 4>, S<0, 1, 2, 3, 4, 5>, 5, 4>, + // DeviceGemmDl< int8_t, int8_t, int8_t, int32_t, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 64, 128, 128, 8, 4, 8, 8, 1, S<4, 2>, S<4, 2>, S<4, 1, 4, 4>, S<2, 1, 32, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 4>, S<1, 2, 0, 3>, S<1, 1, 1, 4>, S<4, 1, 4, 4>, S<2, 1, 32, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<4, 1, 1, 4>, S<0, 3, 1, 2>, S<1, 1, 1, 4>, S<0, 1, 2, 3, 4, 5>, 5, 4>, + // DeviceGemmDl< int8_t, int8_t, int8_t, int32_t, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 64, 128, 64, 8, 4, 8, 4, 1, S<4, 2>, S<4, 2>, S<4, 1, 4, 4>, S<2, 1, 32, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 4>, S<1, 2, 0, 3>, S<1, 1, 1, 4>, S<4, 1, 2, 4>, S<2, 1, 32, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<4, 1, 1, 4>, S<0, 3, 1, 2>, S<1, 1, 1, 4>, S<0, 1, 2, 3, 4, 5>, 5, 4>, + // DeviceGemmDl< int8_t, int8_t, int8_t, int32_t, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 64, 64, 128, 8, 4, 4, 8, 1, S<4, 2>, S<4, 2>, S<4, 1, 2, 4>, S<2, 1, 32, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 4>, S<1, 2, 0, 3>, S<1, 1, 1, 4>, S<4, 1, 4, 4>, S<2, 1, 32, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<4, 1, 1, 4>, S<0, 3, 1, 2>, S<1, 1, 1, 4>, S<0, 1, 2, 3, 4, 5>, 5, 4>, 
+ // DeviceGemmDl< int8_t, int8_t, int8_t, int32_t, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 64, 64, 64, 8, 4, 4, 4, 1, S<4, 2>, S<4, 2>, S<4, 1, 2, 4>, S<2, 1, 32, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 4>, S<1, 2, 0, 3>, S<1, 1, 1, 4>, S<4, 1, 2, 4>, S<2, 1, 32, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<4, 1, 1, 4>, S<0, 3, 1, 2>, S<1, 1, 1, 4>, S<0, 1, 2, 3, 4, 5>, 5, 4> + // clang-format on + >; + +void add_device_gemm_dl_i8_i8_i8_mk_kn_mn_instances( + std::vector>& instances) +{ + add_device_operation_instances(instances, device_gemm_dl_i8_i8_i8_mk_kn_mn_instances{}); +} + +} // namespace device_gemm_instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/gemm/device_gemm_dl_i8_i8_i8_mk_nk_mn_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm/device_gemm_dl_i8_i8_i8_mk_nk_mn_instance.cpp new file mode 100644 index 0000000000..afc378a93d --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/gemm/device_gemm_dl_i8_i8_i8_mk_nk_mn_instance.cpp @@ -0,0 +1,51 @@ +#include +#include "config.hpp" +#include "device_gemm_dl.hpp" +#include "element_wise_operation.hpp" +#include "device_operation_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace device_gemm_instance { + +using Row = ck::tensor_layout::gemm::RowMajor; +using Col = ck::tensor_layout::gemm::ColumnMajor; + +template +using S = ck::Sequence; + +using PassThrough = ck::tensor_operation::element_wise::PassThrough; + +// static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default; + +// Compilation parameters for a[m, k] * b[n, k] = c[m, n] +using device_gemm_dl_i8_i8_i8_mk_nk_mn_instances = + std::tuple< + // clang-format off + // ##########| AData| BData| CData| AccData| ALayout| BLayout| CLayout| A| B| C| GEMM| Block| MPer| NPer| K0Per| K1| M1Per| N1Per| KPer| M11N11Thread| M11N11Thread| ABlockTransfer| 
ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| CThreadTransfer| CThreadTransfer| CThreadTransfer| + // ##########| Type| Type| Type| Type| | | | Elementwise| Elementwise| Elementwise| Spacialization| Size| Block| Block| Block| | ThreadM111| ThreadN111| Thread| ClusterM110Xs| ClusterN110Xs| ThreadSliceLengths| ThreadClusterLengths| ThreadCluster| SrcAccess| SrcVectorTensor| SrcVectorTensor| DstVectorTensor| ThreadSliceLengths| ThreadClusterLengths| ThreadCluster| SrcAccess| SrcVectorTensor| SrcVectorTensor| DstVectorTensor| SrcDstAccess| SrcDstVectorDim| DstScalarPerVector| + // ##########| | | | | | | | Operation| Operation| Operation| | | | | | | | | | | | K0_M0_M1_K1| K0_M0_M1_K1| ArrangeOrder| Order| Lengths_K0_M0_M1_K1| ContiguousDimOrder| Lengths_K0_M0_M1_K1| K0_N0_N1_K1| K0_N0_N1_K1| ArrangeOrder| Order| Lengths_K0_N0_N1_K1| ContiguousDimOrder| Lengths_K0_N0_N1_K1| Order| | | + // ##########| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + // DeviceGemmDl< int8_t, int8_t, int8_t, int32_t, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 256, 128, 8, 4, 8, 4, 1, S<8, 2>, S<8, 2>, S<4, 1, 2, 4>, S<2, 1, 128, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 4>, S<1, 2, 0, 3>, S<1, 1, 1, 4>, S<4, 1, 1, 4>, S<2, 1, 128, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 4>, S<1, 2, 0, 3>, S<1, 1, 1, 4>, S<0, 1, 2, 3, 4, 5>, 5, 4>, + // DeviceGemmDl< int8_t, int8_t, int8_t, int32_t, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 128, 256, 8, 4, 4, 8, 1, S<8, 2>, S<8, 2>, S<4, 1, 1, 4>, S<2, 1, 128, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 4>, S<1, 2, 0, 3>, S<1, 1, 1, 4>, S<4, 1, 2, 4>, S<2, 1, 128, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 4>, S<1, 2, 0, 3>, S<1, 1, 1, 4>, S<0, 1, 2, 3, 4, 5>, 5, 4>, + // DeviceGemmDl< int8_t, 
int8_t, int8_t, int32_t, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 128, 128, 8, 4, 4, 4, 1, S<8, 2>, S<8, 2>, S<4, 1, 1, 4>, S<2, 1, 128, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 4>, S<1, 2, 0, 3>, S<1, 1, 1, 4>, S<4, 1, 1, 4>, S<2, 1, 128, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 4>, S<1, 2, 0, 3>, S<1, 1, 1, 4>, S<0, 1, 2, 3, 4, 5>, 5, 4>, + // DeviceGemmDl< int8_t, int8_t, int8_t, int32_t, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 128, 128, 128, 8, 4, 8, 4, 1, S<4, 2>, S<8, 2>, S<8, 1, 1, 4>, S<1, 1, 128, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 4>, S<1, 2, 0, 3>, S<1, 1, 1, 4>, S<8, 1, 1, 4>, S<1, 1, 128, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 4>, S<1, 2, 0, 3>, S<1, 1, 1, 4>, S<0, 1, 2, 3, 4, 5>, 5, 4>, + // DeviceGemmDl< int8_t, int8_t, int8_t, int32_t, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 128, 128, 128, 8, 4, 4, 8, 1, S<8, 2>, S<4, 2>, S<8, 1, 1, 4>, S<1, 1, 128, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 4>, S<1, 2, 0, 3>, S<1, 1, 1, 4>, S<8, 1, 1, 4>, S<1, 1, 128, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 4>, S<1, 2, 0, 3>, S<1, 1, 1, 4>, S<0, 1, 2, 3, 4, 5>, 5, 4>, + // DeviceGemmDl< int8_t, int8_t, int8_t, int32_t, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 64, 128, 128, 8, 4, 8, 8, 1, S<4, 2>, S<4, 2>, S<4, 1, 4, 4>, S<2, 1, 32, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 4>, S<1, 2, 0, 3>, S<1, 1, 1, 4>, S<4, 1, 4, 4>, S<2, 1, 32, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 4>, S<1, 2, 0, 3>, S<1, 1, 1, 4>, S<0, 1, 2, 3, 4, 5>, 5, 4>, + // DeviceGemmDl< int8_t, int8_t, int8_t, int32_t, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 64, 128, 64, 8, 4, 8, 4, 1, S<4, 2>, S<4, 2>, S<4, 1, 4, 4>, S<2, 1, 32, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 4>, S<1, 2, 0, 3>, S<1, 1, 1, 4>, S<4, 1, 2, 4>, S<2, 1, 32, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 4>, S<1, 2, 0, 3>, S<1, 1, 1, 4>, S<0, 1, 2, 3, 4, 
5>, 5, 4>, + // DeviceGemmDl< int8_t, int8_t, int8_t, int32_t, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 64, 64, 128, 8, 4, 4, 8, 1, S<4, 2>, S<4, 2>, S<4, 1, 2, 4>, S<2, 1, 32, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 4>, S<1, 2, 0, 3>, S<1, 1, 1, 4>, S<4, 1, 4, 4>, S<2, 1, 32, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 4>, S<1, 2, 0, 3>, S<1, 1, 1, 4>, S<0, 1, 2, 3, 4, 5>, 5, 4>, + // DeviceGemmDl< int8_t, int8_t, int8_t, int32_t, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 64, 64, 64, 8, 4, 4, 4, 1, S<4, 2>, S<4, 2>, S<4, 1, 2, 4>, S<2, 1, 32, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 4>, S<1, 2, 0, 3>, S<1, 1, 1, 4>, S<4, 1, 2, 4>, S<2, 1, 32, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 4>, S<1, 2, 0, 3>, S<1, 1, 1, 4>, S<0, 1, 2, 3, 4, 5>, 5, 4> + // clang-format on + >; + +void add_device_gemm_dl_i8_i8_i8_mk_nk_mn_instances( + std::vector>& instances) +{ + add_device_operation_instances(instances, device_gemm_dl_i8_i8_i8_mk_nk_mn_instances{}); +} + +} // namespace device_gemm_instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/gemm/device_gemm_dlops_f16_f16_f16_km_kn_mn_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm/device_gemm_dlops_f16_f16_f16_km_kn_mn_instance.cpp deleted file mode 100644 index c060010a1b..0000000000 --- a/library/src/tensor_operation_instance/gpu/gemm/device_gemm_dlops_f16_f16_f16_km_kn_mn_instance.cpp +++ /dev/null @@ -1,68 +0,0 @@ -#include -#include "config.hpp" -#include "device_gemm_dlops.hpp" -#include "element_wise_operation.hpp" -#include "device_operation_instance.hpp" - -namespace ck { -namespace tensor_operation { -namespace device { -namespace device_gemm_instance { - -using F16 = ck::half_t; -using F32 = float; - -using Row = ck::tensor_layout::gemm::RowMajor; -using Col = ck::tensor_layout::gemm::ColumnMajor; - -template -using S = ck::Sequence; - -using PassThrough = 
ck::tensor_operation::element_wise::PassThrough; - -static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default; - -// Compilation parameters for a[k, m] * b[k, n] = c[m, n] -using device_gemm_dlops_f16_f16_f16_km_kn_mn_instances = std::tuple< - - // clang-format off - // ##########| AData| BData| CData| AccData| ALayout| BLayout| CLayout| A| B| C| GEMM| Block| MPer| NPer| K0Per| K1| M1Per| N1Per| KPer| M11N11Thread| M11N11Thread| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| CThreadTransfer| CThreadTransfer| CThreadTransfer| - // ##########| Type| Type| Type| Type| | | | Elementwise| Elementwise| Elementwise| Spacialization| Size| Block| Block| Block| | ThreadM111| ThreadN111| Thread| ClusterM110Xs| ClusterN110Xs| ThreadSliceLengths| ThreadClusterLengths| ThreadCluster| SrcAccess| SrcVectorTensor| SrcVectorTensor| DstVectorTensor| ThreadSliceLengths| ThreadClusterLengths| ThreadCluster| SrcAccess| SrcVectorTensor| SrcVectorTensor| DstVectorTensor| SrcDstAccess| SrcDstVectorDim| DstScalarPerVector| - // ##########| | | | | | | | Operation| Operation| Operation| | | | | | | | | | | | K0_M0_M1_K1| K0_M0_M1_K1| ArrangeOrder| Order| Lengths_K0_M0_M1_K1| ContiguousDimOrder| Lengths_K0_M0_M1_K1| K0_M0_M1_K1| K0_M0_M1_K1| ArrangeOrder| Order| Lengths_K0_M0_M1_K1| ContiguousDimOrder| Lengths_K0_M0_M1_K1| Order| | | - // ##########| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | - /* - * K1 = 1 - */ - DeviceGemmDlops< F16, F16, F16, F32, Col, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 256, 128, 8, 1, 8, 4, 1, S<8, 2>, S<8, 2>, S<4, 1, 2, 1>, S<2, 1, 128, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<4, 1, 1, 1>, S<0, 3, 1, 2>, S<1, 1, 1, 1>, S<4, 1, 1, 1>, S<2, 1, 128, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<4, 1, 1, 1>, S<0, 
3, 1, 2>, S<1, 1, 1, 1>, S<0, 1, 2, 3, 4, 5>, 5, 4>, - DeviceGemmDlops< F16, F16, F16, F32, Col, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 128, 256, 8, 1, 4, 8, 1, S<8, 2>, S<8, 2>, S<4, 1, 1, 1>, S<2, 1, 128, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<4, 1, 1, 1>, S<0, 3, 1, 2>, S<1, 1, 1, 1>, S<4, 1, 2, 1>, S<2, 1, 128, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<4, 1, 1, 1>, S<0, 3, 1, 2>, S<1, 1, 1, 1>, S<0, 1, 2, 3, 4, 5>, 5, 4>, - DeviceGemmDlops< F16, F16, F16, F32, Col, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 128, 128, 8, 1, 4, 4, 1, S<8, 2>, S<8, 2>, S<4, 1, 1, 1>, S<2, 1, 128, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<4, 1, 1, 1>, S<0, 3, 1, 2>, S<1, 1, 1, 1>, S<4, 1, 1, 1>, S<2, 1, 128, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<4, 1, 1, 1>, S<0, 3, 1, 2>, S<1, 1, 1, 1>, S<0, 1, 2, 3, 4, 5>, 5, 4>, - DeviceGemmDlops< F16, F16, F16, F32, Col, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 128, 128, 128, 8, 1, 8, 4, 1, S<4, 2>, S<8, 2>, S<8, 1, 1, 1>, S<1, 1, 128, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<4, 1, 1, 1>, S<0, 3, 1, 2>, S<1, 1, 1, 1>, S<8, 1, 1, 1>, S<1, 1, 128, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<4, 1, 1, 1>, S<0, 3, 1, 2>, S<1, 1, 1, 1>, S<0, 1, 2, 3, 4, 5>, 5, 4>, - DeviceGemmDlops< F16, F16, F16, F32, Col, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 128, 128, 128, 8, 1, 4, 8, 1, S<8, 2>, S<4, 2>, S<8, 1, 1, 1>, S<1, 1, 128, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<4, 1, 1, 1>, S<0, 3, 1, 2>, S<1, 1, 1, 1>, S<8, 1, 1, 1>, S<1, 1, 128, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<4, 1, 1, 1>, S<0, 3, 1, 2>, S<1, 1, 1, 1>, S<0, 1, 2, 3, 4, 5>, 5, 4>, - DeviceGemmDlops< F16, F16, F16, F32, Col, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 64, 128, 128, 8, 1, 8, 8, 1, S<4, 2>, S<4, 2>, S<4, 1, 4, 1>, S<2, 1, 32, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<4, 1, 1, 1>, S<0, 3, 1, 2>, S<1, 1, 1, 1>, S<4, 1, 4, 1>, S<2, 1, 32, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<4, 1, 1, 1>, S<0, 3, 1, 2>, S<1, 1, 1, 
1>, S<0, 1, 2, 3, 4, 5>, 5, 4>, - DeviceGemmDlops< F16, F16, F16, F32, Col, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 64, 128, 64, 8, 1, 8, 4, 1, S<4, 2>, S<4, 2>, S<4, 1, 4, 1>, S<2, 1, 32, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<4, 1, 1, 1>, S<0, 3, 1, 2>, S<1, 1, 1, 1>, S<4, 1, 2, 1>, S<2, 1, 32, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<4, 1, 1, 1>, S<0, 3, 1, 2>, S<1, 1, 1, 1>, S<0, 1, 2, 3, 4, 5>, 5, 4>, - DeviceGemmDlops< F16, F16, F16, F32, Col, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 64, 64, 128, 8, 1, 4, 8, 1, S<4, 2>, S<4, 2>, S<4, 1, 2, 1>, S<2, 1, 32, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<4, 1, 1, 1>, S<0, 3, 1, 2>, S<1, 1, 1, 1>, S<4, 1, 4, 1>, S<2, 1, 32, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<4, 1, 1, 1>, S<0, 3, 1, 2>, S<1, 1, 1, 1>, S<0, 1, 2, 3, 4, 5>, 5, 4>, - - /* - * K1 = 2 - */ - DeviceGemmDlops< F16, F16, F16, F32, Col, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 256, 128, 8, 2, 8, 4, 1, S<8, 2>, S<8, 2>, S<4, 1, 2, 2>, S<2, 1, 128, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<4, 1, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 1, 2>, S<4, 1, 1, 2>, S<2, 1, 128, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<4, 1, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 1, 2>, S<0, 1, 2, 3, 4, 5>, 5, 4>, - DeviceGemmDlops< F16, F16, F16, F32, Col, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 128, 256, 8, 2, 4, 8, 1, S<8, 2>, S<8, 2>, S<4, 1, 1, 2>, S<2, 1, 128, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<4, 1, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 1, 2>, S<4, 1, 2, 2>, S<2, 1, 128, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<4, 1, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 1, 2>, S<0, 1, 2, 3, 4, 5>, 5, 4>, - DeviceGemmDlops< F16, F16, F16, F32, Col, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 128, 128, 8, 2, 4, 4, 1, S<8, 2>, S<8, 2>, S<4, 1, 1, 2>, S<2, 1, 128, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<4, 1, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 1, 2>, S<4, 1, 1, 2>, S<2, 1, 128, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<4, 1, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 1, 2>, 
S<0, 1, 2, 3, 4, 5>, 5, 4>, - DeviceGemmDlops< F16, F16, F16, F32, Col, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 128, 128, 128, 8, 2, 8, 4, 1, S<4, 2>, S<8, 2>, S<8, 1, 1, 2>, S<1, 1, 128, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<4, 1, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 1, 2>, S<8, 1, 1, 2>, S<1, 1, 128, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<4, 1, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 1, 2>, S<0, 1, 2, 3, 4, 5>, 5, 4>, - DeviceGemmDlops< F16, F16, F16, F32, Col, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 128, 128, 128, 8, 2, 4, 8, 1, S<8, 2>, S<4, 2>, S<8, 1, 1, 2>, S<1, 1, 128, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<4, 1, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 1, 2>, S<8, 1, 1, 2>, S<1, 1, 128, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<4, 1, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 1, 2>, S<0, 1, 2, 3, 4, 5>, 5, 4>, - DeviceGemmDlops< F16, F16, F16, F32, Col, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 64, 128, 128, 8, 2, 8, 8, 1, S<4, 2>, S<4, 2>, S<4, 1, 4, 2>, S<2, 1, 32, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<4, 1, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 1, 2>, S<4, 1, 4, 2>, S<2, 1, 32, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<4, 1, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 1, 2>, S<0, 1, 2, 3, 4, 5>, 5, 4>, - DeviceGemmDlops< F16, F16, F16, F32, Col, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 64, 128, 64, 8, 2, 8, 4, 1, S<4, 2>, S<4, 2>, S<4, 1, 4, 2>, S<2, 1, 32, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<4, 1, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 1, 2>, S<4, 1, 2, 2>, S<2, 1, 32, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<4, 1, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 1, 2>, S<0, 1, 2, 3, 4, 5>, 5, 4>, - DeviceGemmDlops< F16, F16, F16, F32, Col, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 64, 64, 128, 8, 2, 4, 8, 1, S<4, 2>, S<4, 2>, S<4, 1, 2, 2>, S<2, 1, 32, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<4, 1, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 1, 2>, S<4, 1, 4, 2>, S<2, 1, 32, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<4, 1, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 1, 2>, S<0, 1, 2, 3, 4, 5>, 5, 4> - 
// clang-format on - >; - -void add_device_gemm_dlops_f16_f16_f16_km_kn_mn_instances( - std::vector>& instances) -{ - add_device_operation_instances(instances, device_gemm_dlops_f16_f16_f16_km_kn_mn_instances{}); -} - -} // namespace device_gemm_instance -} // namespace device -} // namespace tensor_operation -} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/gemm/device_gemm_dlops_f16_f16_f16_km_nk_mn_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm/device_gemm_dlops_f16_f16_f16_km_nk_mn_instance.cpp deleted file mode 100644 index 0963b73f3d..0000000000 --- a/library/src/tensor_operation_instance/gpu/gemm/device_gemm_dlops_f16_f16_f16_km_nk_mn_instance.cpp +++ /dev/null @@ -1,67 +0,0 @@ -#include -#include "config.hpp" -#include "device_gemm_dlops.hpp" -#include "element_wise_operation.hpp" -#include "device_operation_instance.hpp" - -namespace ck { -namespace tensor_operation { -namespace device { -namespace device_gemm_instance { - -using F16 = ck::half_t; -using F32 = float; - -using Row = ck::tensor_layout::gemm::RowMajor; -using Col = ck::tensor_layout::gemm::ColumnMajor; - -template -using S = ck::Sequence; - -using PassThrough = ck::tensor_operation::element_wise::PassThrough; - -static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default; - -// Compilation parameters for a[k, m] * b[n, k] = c[m, n] -using device_gemm_dlops_f16_f16_f16_km_nk_mn_instances = std::tuple< - // clang-format off - // ##########| AData| BData| CData| AccData| ALayout| BLayout| CLayout| A| B| C| GEMM| Block| MPer| NPer| K0Per| K1| M1Per| N1Per| KPer| M11N11Thread| M11N11Thread| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| CThreadTransfer| CThreadTransfer| CThreadTransfer| - // ##########| Type| Type| Type| Type| | | | Elementwise| 
Elementwise| Elementwise| Spacialization| Size| Block| Block| Block| | ThreadM111| ThreadN111| Thread| ClusterM110Xs| ClusterN110Xs| ThreadSliceLengths| ThreadClusterLengths| ThreadCluster| SrcAccess| SrcVectorTensor| SrcVectorTensor| DstVectorTensor| ThreadSliceLengths| ThreadClusterLengths| ThreadCluster| SrcAccess| SrcVectorTensor| SrcVectorTensor| DstVectorTensor| SrcDstAccess| SrcDstVectorDim| DstScalarPerVector| - // ##########| | | | | | | | Operation| Operation| Operation| | | | | | | | | | | | K0_M0_M1_K1| K0_M0_M1_K1| ArrangeOrder| Order| Lengths_K0_M0_M1_K1| ContiguousDimOrder| Lengths_K0_M0_M1_K1| K0_N0_N1_K1| K0_N0_N1_K1| ArrangeOrder| Order| Lengths_K0_N0_N1_K1| ContiguousDimOrder| Lengths_K0_N0_N1_K1| Order| | | - // ##########| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | - /* - * K1 = 1 - */ - DeviceGemmDlops< F16, F16, F16, F32, Col, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 256, 128, 8, 1, 8, 4, 1, S<8, 2>, S<8, 2>, S<4, 1, 2, 1>, S<2, 1, 128, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<4, 1, 1, 1>, S<0, 3, 1, 2>, S<1, 1, 1, 1>, S<4, 1, 1, 1>, S<2, 1, 128, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 1>, S<1, 2, 0, 3>, S<1, 1, 1, 1>, S<0, 1, 2, 3, 4, 5>, 5, 4>, - DeviceGemmDlops< F16, F16, F16, F32, Col, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 128, 256, 8, 1, 4, 8, 1, S<8, 2>, S<8, 2>, S<4, 1, 1, 1>, S<2, 1, 128, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<4, 1, 1, 1>, S<0, 3, 1, 2>, S<1, 1, 1, 1>, S<4, 1, 2, 1>, S<2, 1, 128, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 1>, S<1, 2, 0, 3>, S<1, 1, 1, 1>, S<0, 1, 2, 3, 4, 5>, 5, 4>, - DeviceGemmDlops< F16, F16, F16, F32, Col, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 128, 128, 8, 1, 4, 4, 1, S<8, 2>, S<8, 2>, S<4, 1, 1, 1>, S<2, 1, 128, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<4, 1, 1, 1>, S<0, 3, 1, 2>, S<1, 1, 1, 1>, S<4, 1, 1, 1>, S<2, 1, 128, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 1>, S<1, 2, 0, 3>, 
S<1, 1, 1, 1>, S<0, 1, 2, 3, 4, 5>, 5, 4>, - DeviceGemmDlops< F16, F16, F16, F32, Col, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 128, 128, 128, 8, 1, 8, 4, 1, S<4, 2>, S<8, 2>, S<8, 1, 1, 1>, S<1, 1, 128, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<4, 1, 1, 1>, S<0, 3, 1, 2>, S<1, 1, 1, 1>, S<8, 1, 1, 1>, S<1, 1, 128, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 1>, S<1, 2, 0, 3>, S<1, 1, 1, 1>, S<0, 1, 2, 3, 4, 5>, 5, 4>, - DeviceGemmDlops< F16, F16, F16, F32, Col, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 128, 128, 128, 8, 1, 4, 8, 1, S<8, 2>, S<4, 2>, S<8, 1, 1, 1>, S<1, 1, 128, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<4, 1, 1, 1>, S<0, 3, 1, 2>, S<1, 1, 1, 1>, S<8, 1, 1, 1>, S<1, 1, 128, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 1>, S<1, 2, 0, 3>, S<1, 1, 1, 1>, S<0, 1, 2, 3, 4, 5>, 5, 4>, - DeviceGemmDlops< F16, F16, F16, F32, Col, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 64, 128, 128, 8, 1, 8, 8, 1, S<4, 2>, S<4, 2>, S<4, 1, 4, 1>, S<2, 1, 32, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<4, 1, 1, 1>, S<0, 3, 1, 2>, S<1, 1, 1, 1>, S<4, 1, 4, 1>, S<2, 1, 32, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 1>, S<1, 2, 0, 3>, S<1, 1, 1, 1>, S<0, 1, 2, 3, 4, 5>, 5, 4>, - DeviceGemmDlops< F16, F16, F16, F32, Col, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 64, 128, 64, 8, 1, 8, 4, 1, S<4, 2>, S<4, 2>, S<4, 1, 4, 1>, S<2, 1, 32, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<4, 1, 1, 1>, S<0, 3, 1, 2>, S<1, 1, 1, 1>, S<4, 1, 2, 1>, S<2, 1, 32, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 1>, S<1, 2, 0, 3>, S<1, 1, 1, 1>, S<0, 1, 2, 3, 4, 5>, 5, 4>, - DeviceGemmDlops< F16, F16, F16, F32, Col, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 64, 64, 128, 8, 1, 4, 8, 1, S<4, 2>, S<4, 2>, S<4, 1, 2, 1>, S<2, 1, 32, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<4, 1, 1, 1>, S<0, 3, 1, 2>, S<1, 1, 1, 1>, S<4, 1, 4, 1>, S<2, 1, 32, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 1>, S<1, 2, 0, 3>, S<1, 1, 1, 1>, S<0, 1, 2, 3, 
4, 5>, 5, 4>, - - /* - * K1 = 2 - */ - DeviceGemmDlops< F16, F16, F16, F32, Col, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 256, 128, 8, 2, 8, 4, 1, S<8, 2>, S<8, 2>, S<4, 1, 2, 2>, S<2, 1, 128, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<4, 1, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 1, 2>, S<4, 1, 1, 2>, S<2, 1, 128, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<0, 1, 2, 3, 4, 5>, 5, 4>, - DeviceGemmDlops< F16, F16, F16, F32, Col, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 128, 256, 8, 2, 4, 8, 1, S<8, 2>, S<8, 2>, S<4, 1, 1, 2>, S<2, 1, 128, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<4, 1, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 1, 2>, S<4, 1, 2, 2>, S<2, 1, 128, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<0, 1, 2, 3, 4, 5>, 5, 4>, - DeviceGemmDlops< F16, F16, F16, F32, Col, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 128, 128, 8, 2, 4, 4, 1, S<8, 2>, S<8, 2>, S<4, 1, 1, 2>, S<2, 1, 128, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<4, 1, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 1, 2>, S<4, 1, 1, 2>, S<2, 1, 128, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<0, 1, 2, 3, 4, 5>, 5, 4>, - DeviceGemmDlops< F16, F16, F16, F32, Col, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 128, 128, 128, 8, 2, 8, 4, 1, S<4, 2>, S<8, 2>, S<8, 1, 1, 2>, S<1, 1, 128, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<4, 1, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 1, 2>, S<8, 1, 1, 2>, S<1, 1, 128, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<0, 1, 2, 3, 4, 5>, 5, 4>, - DeviceGemmDlops< F16, F16, F16, F32, Col, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 128, 128, 128, 8, 2, 4, 8, 1, S<8, 2>, S<4, 2>, S<8, 1, 1, 2>, S<1, 1, 128, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<4, 1, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 1, 2>, S<8, 1, 1, 2>, S<1, 1, 128, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<0, 1, 2, 
3, 4, 5>, 5, 4>, - DeviceGemmDlops< F16, F16, F16, F32, Col, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 64, 128, 128, 8, 2, 8, 8, 1, S<4, 2>, S<4, 2>, S<4, 1, 4, 2>, S<2, 1, 32, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<4, 1, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 1, 2>, S<4, 1, 4, 2>, S<2, 1, 32, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<0, 1, 2, 3, 4, 5>, 5, 4>, - DeviceGemmDlops< F16, F16, F16, F32, Col, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 64, 128, 64, 8, 2, 8, 4, 1, S<4, 2>, S<4, 2>, S<4, 1, 4, 2>, S<2, 1, 32, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<4, 1, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 1, 2>, S<4, 1, 2, 2>, S<2, 1, 32, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<0, 1, 2, 3, 4, 5>, 5, 4>, - DeviceGemmDlops< F16, F16, F16, F32, Col, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 64, 64, 128, 8, 2, 4, 8, 1, S<4, 2>, S<4, 2>, S<4, 1, 2, 2>, S<2, 1, 32, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<4, 1, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 1, 2>, S<4, 1, 4, 2>, S<2, 1, 32, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<0, 1, 2, 3, 4, 5>, 5, 4> - // clang-format on - >; - -void add_device_gemm_dlops_f16_f16_f16_km_nk_mn_instances( - std::vector>& instances) -{ - add_device_operation_instances(instances, device_gemm_dlops_f16_f16_f16_km_nk_mn_instances{}); -} - -} // namespace device_gemm_instance -} // namespace device -} // namespace tensor_operation -} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/gemm/device_gemm_dlops_f16_f16_f16_mk_kn_mn_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm/device_gemm_dlops_f16_f16_f16_mk_kn_mn_instance.cpp deleted file mode 100644 index 5d36ac4182..0000000000 --- a/library/src/tensor_operation_instance/gpu/gemm/device_gemm_dlops_f16_f16_f16_mk_kn_mn_instance.cpp +++ /dev/null @@ -1,67 +0,0 @@ -#include -#include "config.hpp" -#include "device_gemm_dlops.hpp" 
-#include "element_wise_operation.hpp" -#include "device_operation_instance.hpp" - -namespace ck { -namespace tensor_operation { -namespace device { -namespace device_gemm_instance { - -using F16 = ck::half_t; -using F32 = float; - -using Row = ck::tensor_layout::gemm::RowMajor; -using Col = ck::tensor_layout::gemm::ColumnMajor; - -template -using S = ck::Sequence; - -using PassThrough = ck::tensor_operation::element_wise::PassThrough; - -static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default; - -// Compilation parameters for a[m, k] * b[k, n] = c[m, n] -using device_gemm_dlops_f16_f16_f16_mk_kn_mn_instances = std::tuple< - // clang-format off - // ##########| AData| BData| CData| AccData| ALayout| BLayout| CLayout| A| B| C| GEMM| Block| MPer| NPer| K0Per| K1| M1Per| N1Per| KPer| M11N11Thread| M11N11Thread| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| CThreadTransfer| CThreadTransfer| CThreadTransfer| - // ##########| Type| Type| Type| Type| | | | Elementwise| Elementwise| Elementwise| Spacialization| Size| Block| Block| Block| | ThreadM111| ThreadN111| Thread| ClusterM110Xs| ClusterN110Xs| ThreadSliceLengths| ThreadClusterLengths| ThreadCluster| SrcAccess| SrcVectorTensor| SrcVectorTensor| DstVectorTensor| ThreadSliceLengths| ThreadClusterLengths| ThreadCluster| SrcAccess| SrcVectorTensor| SrcVectorTensor| DstVectorTensor| SrcDstAccess| SrcDstVectorDim| DstScalarPerVector| - // ##########| | | | | | | | Operation| Operation| Operation| | | | | | | | | | | | K0_M0_M1_K1| K0_M0_M1_K1| ArrangeOrder| Order| Lengths_K0_M0_M1_K1| ContiguousDimOrder| Lengths_K0_M0_M1_K1| K0_N0_N1_K1| K0_N0_N1_K1| ArrangeOrder| Order| Lengths_K0_N0_N1_K1| ContiguousDimOrder| Lengths_K0_N0_N1_K1| Order| | | - // ##########| | | | | | | | | | | | | | | | | | | | | | | | | | | 
| | | | | | | | | | | | - /* - * K1 = 1 - */ - DeviceGemmDlops< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 256, 128, 8, 1, 8, 4, 1, S<8, 2>, S<8, 2>, S<4, 1, 2, 1>, S<2, 1, 128, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 1>, S<1, 2, 0, 3>, S<1, 1, 1, 1>, S<4, 1, 1, 1>, S<2, 1, 128, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<4, 1, 1, 1>, S<0, 3, 1, 2>, S<1, 1, 1, 1>, S<0, 1, 2, 3, 4, 5>, 5, 4>, - DeviceGemmDlops< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 128, 256, 8, 1, 4, 8, 1, S<8, 2>, S<8, 2>, S<4, 1, 1, 1>, S<2, 1, 128, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 1>, S<1, 2, 0, 3>, S<1, 1, 1, 1>, S<4, 1, 2, 1>, S<2, 1, 128, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<4, 1, 1, 1>, S<0, 3, 1, 2>, S<1, 1, 1, 1>, S<0, 1, 2, 3, 4, 5>, 5, 4>, - DeviceGemmDlops< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 128, 128, 8, 1, 4, 4, 1, S<8, 2>, S<8, 2>, S<4, 1, 1, 1>, S<2, 1, 128, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 1>, S<1, 2, 0, 3>, S<1, 1, 1, 1>, S<4, 1, 1, 1>, S<2, 1, 128, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<4, 1, 1, 1>, S<0, 3, 1, 2>, S<1, 1, 1, 1>, S<0, 1, 2, 3, 4, 5>, 5, 4>, - DeviceGemmDlops< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 128, 128, 128, 8, 1, 8, 4, 1, S<4, 2>, S<8, 2>, S<8, 1, 1, 1>, S<1, 1, 128, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 1>, S<1, 2, 0, 3>, S<1, 1, 1, 1>, S<8, 1, 1, 1>, S<1, 1, 128, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<4, 1, 1, 1>, S<0, 3, 1, 2>, S<1, 1, 1, 1>, S<0, 1, 2, 3, 4, 5>, 5, 4>, - DeviceGemmDlops< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 128, 128, 128, 8, 1, 4, 8, 1, S<8, 2>, S<4, 2>, S<8, 1, 1, 1>, S<1, 1, 128, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 1>, S<1, 2, 0, 3>, S<1, 1, 1, 1>, S<8, 1, 1, 1>, S<1, 1, 128, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<4, 1, 1, 1>, S<0, 3, 1, 2>, S<1, 1, 1, 1>, 
S<0, 1, 2, 3, 4, 5>, 5, 4>, - DeviceGemmDlops< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 64, 128, 128, 8, 1, 8, 8, 1, S<4, 2>, S<4, 2>, S<4, 1, 4, 1>, S<2, 1, 32, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 1>, S<1, 2, 0, 3>, S<1, 1, 1, 1>, S<4, 1, 4, 1>, S<2, 1, 32, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<4, 1, 1, 1>, S<0, 3, 1, 2>, S<1, 1, 1, 1>, S<0, 1, 2, 3, 4, 5>, 5, 4>, - DeviceGemmDlops< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 64, 128, 64, 8, 1, 8, 4, 1, S<4, 2>, S<4, 2>, S<4, 1, 4, 1>, S<2, 1, 32, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 1>, S<1, 2, 0, 3>, S<1, 1, 1, 1>, S<4, 1, 2, 1>, S<2, 1, 32, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<4, 1, 1, 1>, S<0, 3, 1, 2>, S<1, 1, 1, 1>, S<0, 1, 2, 3, 4, 5>, 5, 4>, - DeviceGemmDlops< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 64, 64, 128, 8, 1, 4, 8, 1, S<4, 2>, S<4, 2>, S<4, 1, 2, 1>, S<2, 1, 32, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 1>, S<1, 2, 0, 3>, S<1, 1, 1, 1>, S<4, 1, 4, 1>, S<2, 1, 32, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<4, 1, 1, 1>, S<0, 3, 1, 2>, S<1, 1, 1, 1>, S<0, 1, 2, 3, 4, 5>, 5, 4>, - - /* - * K1 = 2 - */ - DeviceGemmDlops< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 256, 128, 8, 2, 8, 4, 1, S<8, 2>, S<8, 2>, S<4, 1, 2, 2>, S<2, 1, 128, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<4, 1, 1, 2>, S<2, 1, 128, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<4, 1, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 1, 2>, S<0, 1, 2, 3, 4, 5>, 5, 4>, - DeviceGemmDlops< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 128, 256, 8, 2, 4, 8, 1, S<8, 2>, S<8, 2>, S<4, 1, 1, 2>, S<2, 1, 128, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<4, 1, 2, 2>, S<2, 1, 128, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<4, 1, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 1, 2>, S<0, 1, 
2, 3, 4, 5>, 5, 4>, - DeviceGemmDlops< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 128, 128, 8, 2, 4, 4, 1, S<8, 2>, S<8, 2>, S<4, 1, 1, 2>, S<2, 1, 128, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<4, 1, 1, 2>, S<2, 1, 128, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<4, 1, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 1, 2>, S<0, 1, 2, 3, 4, 5>, 5, 4>, - DeviceGemmDlops< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 128, 128, 128, 8, 2, 8, 4, 1, S<4, 2>, S<8, 2>, S<8, 1, 1, 2>, S<1, 1, 128, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<8, 1, 1, 2>, S<1, 1, 128, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<4, 1, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 1, 2>, S<0, 1, 2, 3, 4, 5>, 5, 4>, - DeviceGemmDlops< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 128, 128, 128, 8, 2, 4, 8, 1, S<8, 2>, S<4, 2>, S<8, 1, 1, 2>, S<1, 1, 128, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<8, 1, 1, 2>, S<1, 1, 128, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<4, 1, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 1, 2>, S<0, 1, 2, 3, 4, 5>, 5, 4>, - DeviceGemmDlops< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 64, 128, 128, 8, 2, 8, 8, 1, S<4, 2>, S<4, 2>, S<4, 1, 4, 2>, S<2, 1, 32, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<4, 1, 4, 2>, S<2, 1, 32, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<4, 1, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 1, 2>, S<0, 1, 2, 3, 4, 5>, 5, 4>, - DeviceGemmDlops< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 64, 128, 64, 8, 2, 8, 4, 1, S<4, 2>, S<4, 2>, S<4, 1, 4, 2>, S<2, 1, 32, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<4, 1, 2, 2>, S<2, 1, 32, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<4, 1, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 1, 2>, S<0, 1, 2, 3, 4, 5>, 5, 4>, - 
DeviceGemmDlops< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 64, 64, 128, 8, 2, 4, 8, 1, S<4, 2>, S<4, 2>, S<4, 1, 2, 2>, S<2, 1, 32, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<4, 1, 4, 2>, S<2, 1, 32, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<4, 1, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 1, 2>, S<0, 1, 2, 3, 4, 5>, 5, 4> - // clang-format on - >; - -void add_device_gemm_dlops_f16_f16_f16_mk_kn_mn_instances( - std::vector>& instances) -{ - add_device_operation_instances(instances, device_gemm_dlops_f16_f16_f16_mk_kn_mn_instances{}); -} - -} // namespace device_gemm_instance -} // namespace device -} // namespace tensor_operation -} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/gemm/device_gemm_dlops_f16_f16_f16_mk_nk_mn_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm/device_gemm_dlops_f16_f16_f16_mk_nk_mn_instance.cpp deleted file mode 100644 index 3da69e7014..0000000000 --- a/library/src/tensor_operation_instance/gpu/gemm/device_gemm_dlops_f16_f16_f16_mk_nk_mn_instance.cpp +++ /dev/null @@ -1,67 +0,0 @@ -#include -#include "config.hpp" -#include "device_gemm_dlops.hpp" -#include "element_wise_operation.hpp" -#include "device_operation_instance.hpp" - -namespace ck { -namespace tensor_operation { -namespace device { -namespace device_gemm_instance { - -using F16 = ck::half_t; -using F32 = float; - -using Row = ck::tensor_layout::gemm::RowMajor; -using Col = ck::tensor_layout::gemm::ColumnMajor; - -template -using S = ck::Sequence; - -using PassThrough = ck::tensor_operation::element_wise::PassThrough; - -static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default; - -// Compilation parameters for a[m, k] * b[n, k] = c[m, n] -using device_gemm_dlops_f16_f16_f16_mk_nk_mn_instances = std::tuple< - // clang-format off - // ##########| AData| BData| CData| AccData| ALayout| BLayout| CLayout| A| B| C| GEMM| Block| MPer| NPer| K0Per| 
K1| M1Per| N1Per| KPer| M11N11Thread| M11N11Thread| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| CThreadTransfer| CThreadTransfer| CThreadTransfer| - // ##########| Type| Type| Type| Type| | | | Elementwise| Elementwise| Elementwise| Spacialization| Size| Block| Block| Block| | ThreadM111| ThreadN111| Thread| ClusterM110Xs| ClusterN110Xs| ThreadSliceLengths| ThreadClusterLengths| ThreadCluster| SrcAccess| SrcVectorTensor| SrcVectorTensor| DstVectorTensor| ThreadSliceLengths| ThreadClusterLengths| ThreadCluster| SrcAccess| SrcVectorTensor| SrcVectorTensor| DstVectorTensor| SrcDstAccess| SrcDstVectorDim| DstScalarPerVector| - // ##########| | | | | | | | Operation| Operation| Operation| | | | | | | | | | | | K0_M0_M1_K1| K0_M0_M1_K1| ArrangeOrder| Order| Lengths_K0_M0_M1_K1| ContiguousDimOrder| Lengths_K0_M0_M1_K1| K0_N0_N1_K1| K0_N0_N1_K1| ArrangeOrder| Order| Lengths_K0_N0_N1_K1| ContiguousDimOrder| Lengths_K0_N0_N1_K1| Order| | | - // ##########| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | - /* - * K1 = 1 - */ - DeviceGemmDlops< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 256, 128, 8, 1, 8, 4, 1, S<8, 2>, S<8, 2>, S<4, 1, 2, 1>, S<2, 1, 128, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 1>, S<1, 2, 0, 3>, S<1, 1, 1, 1>, S<4, 1, 1, 1>, S<2, 1, 128, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 1>, S<1, 2, 0, 3>, S<1, 1, 1, 1>, S<0, 1, 2, 3, 4, 5>, 5, 4>, - DeviceGemmDlops< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 128, 256, 8, 1, 4, 8, 1, S<8, 2>, S<8, 2>, S<4, 1, 1, 1>, S<2, 1, 128, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 1>, S<1, 2, 0, 3>, S<1, 1, 1, 1>, S<4, 1, 2, 1>, S<2, 1, 128, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 1>, S<1, 2, 0, 3>, S<1, 1, 1, 
1>, S<0, 1, 2, 3, 4, 5>, 5, 4>, - DeviceGemmDlops< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 128, 128, 8, 1, 4, 4, 1, S<8, 2>, S<8, 2>, S<4, 1, 1, 1>, S<2, 1, 128, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 1>, S<1, 2, 0, 3>, S<1, 1, 1, 1>, S<4, 1, 1, 1>, S<2, 1, 128, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 1>, S<1, 2, 0, 3>, S<1, 1, 1, 1>, S<0, 1, 2, 3, 4, 5>, 5, 4>, - DeviceGemmDlops< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 128, 128, 128, 8, 1, 8, 4, 1, S<4, 2>, S<8, 2>, S<8, 1, 1, 1>, S<1, 1, 128, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 1>, S<1, 2, 0, 3>, S<1, 1, 1, 1>, S<8, 1, 1, 1>, S<1, 1, 128, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 1>, S<1, 2, 0, 3>, S<1, 1, 1, 1>, S<0, 1, 2, 3, 4, 5>, 5, 4>, - DeviceGemmDlops< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 128, 128, 128, 8, 1, 4, 8, 1, S<8, 2>, S<4, 2>, S<8, 1, 1, 1>, S<1, 1, 128, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 1>, S<1, 2, 0, 3>, S<1, 1, 1, 1>, S<8, 1, 1, 1>, S<1, 1, 128, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 1>, S<1, 2, 0, 3>, S<1, 1, 1, 1>, S<0, 1, 2, 3, 4, 5>, 5, 4>, - DeviceGemmDlops< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 64, 128, 128, 8, 1, 8, 8, 1, S<4, 2>, S<4, 2>, S<4, 1, 4, 1>, S<2, 1, 32, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 1>, S<1, 2, 0, 3>, S<1, 1, 1, 1>, S<4, 1, 4, 1>, S<2, 1, 32, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 1>, S<1, 2, 0, 3>, S<1, 1, 1, 1>, S<0, 1, 2, 3, 4, 5>, 5, 4>, - DeviceGemmDlops< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 64, 128, 64, 8, 1, 8, 4, 1, S<4, 2>, S<4, 2>, S<4, 1, 4, 1>, S<2, 1, 32, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 1>, S<1, 2, 0, 3>, S<1, 1, 1, 1>, S<4, 1, 2, 1>, S<2, 1, 32, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 1>, S<1, 2, 0, 3>, S<1, 1, 1, 1>, S<0, 1, 2, 3, 4, 5>, 
5, 4>, - DeviceGemmDlops< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 64, 64, 128, 8, 1, 4, 8, 1, S<4, 2>, S<4, 2>, S<4, 1, 2, 1>, S<2, 1, 32, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 1>, S<1, 2, 0, 3>, S<1, 1, 1, 1>, S<4, 1, 4, 1>, S<2, 1, 32, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 1>, S<1, 2, 0, 3>, S<1, 1, 1, 1>, S<0, 1, 2, 3, 4, 5>, 5, 4>, - - /* - * K1 = 2 - */ - DeviceGemmDlops< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 256, 128, 8, 2, 8, 4, 1, S<8, 2>, S<8, 2>, S<4, 1, 2, 2>, S<2, 1, 128, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<4, 1, 1, 2>, S<2, 1, 128, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<0, 1, 2, 3, 4, 5>, 5, 4>, - DeviceGemmDlops< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 128, 256, 8, 2, 4, 8, 1, S<8, 2>, S<8, 2>, S<4, 1, 1, 2>, S<2, 1, 128, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<4, 1, 2, 2>, S<2, 1, 128, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<0, 1, 2, 3, 4, 5>, 5, 4>, - DeviceGemmDlops< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 128, 128, 8, 2, 4, 4, 1, S<8, 2>, S<8, 2>, S<4, 1, 1, 2>, S<2, 1, 128, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<4, 1, 1, 2>, S<2, 1, 128, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<0, 1, 2, 3, 4, 5>, 5, 4>, - DeviceGemmDlops< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 128, 128, 128, 8, 2, 8, 4, 1, S<4, 2>, S<8, 2>, S<8, 1, 1, 2>, S<1, 1, 128, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<8, 1, 1, 2>, S<1, 1, 128, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<0, 1, 2, 3, 4, 5>, 
5, 4>, - DeviceGemmDlops< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 128, 128, 128, 8, 2, 4, 8, 1, S<8, 2>, S<4, 2>, S<8, 1, 1, 2>, S<1, 1, 128, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<8, 1, 1, 2>, S<1, 1, 128, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<0, 1, 2, 3, 4, 5>, 5, 4>, - DeviceGemmDlops< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 64, 128, 128, 8, 2, 8, 8, 1, S<4, 2>, S<4, 2>, S<4, 1, 4, 2>, S<2, 1, 32, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<4, 1, 4, 2>, S<2, 1, 32, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<0, 1, 2, 3, 4, 5>, 5, 4>, - DeviceGemmDlops< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 64, 128, 64, 8, 2, 8, 4, 1, S<4, 2>, S<4, 2>, S<4, 1, 4, 2>, S<2, 1, 32, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<4, 1, 2, 2>, S<2, 1, 32, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<0, 1, 2, 3, 4, 5>, 5, 4>, - DeviceGemmDlops< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 64, 64, 128, 8, 2, 4, 8, 1, S<4, 2>, S<4, 2>, S<4, 1, 2, 2>, S<2, 1, 32, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<4, 1, 4, 2>, S<2, 1, 32, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<0, 1, 2, 3, 4, 5>, 5, 4> - // clang-format on - >; - -void add_device_gemm_dlops_f16_f16_f16_mk_nk_mn_instances( - std::vector>& instances) -{ - add_device_operation_instances(instances, device_gemm_dlops_f16_f16_f16_mk_nk_mn_instances{}); -} - -} // namespace device_gemm_instance -} // namespace device -} // namespace tensor_operation -} // namespace ck diff --git 
a/library/src/tensor_operation_instance/gpu/gemm/device_gemm_dlops_f32_f32_f32_km_kn_mn_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm/device_gemm_dlops_f32_f32_f32_km_kn_mn_instance.cpp deleted file mode 100644 index f9e902aed8..0000000000 --- a/library/src/tensor_operation_instance/gpu/gemm/device_gemm_dlops_f32_f32_f32_km_kn_mn_instance.cpp +++ /dev/null @@ -1,51 +0,0 @@ -#include -#include "config.hpp" -#include "device_gemm_dlops.hpp" -#include "element_wise_operation.hpp" -#include "device_operation_instance.hpp" - -namespace ck { -namespace tensor_operation { -namespace device { -namespace device_gemm_instance { - -using F16 = ck::half_t; -using F32 = float; - -using Row = ck::tensor_layout::gemm::RowMajor; -using Col = ck::tensor_layout::gemm::ColumnMajor; - -template -using S = ck::Sequence; - -using PassThrough = ck::tensor_operation::element_wise::PassThrough; - -static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default; - -// Compilation parameters for a[k, m] * b[k, n] = c[m, n] -using device_gemm_dlops_f32_f32_f32_km_kn_mn_instances = std::tuple< - // clang-format off - // ##########| AData| BData| CData| AccData| ALayout| BLayout| CLayout| A| B| C| GEMM| Block| MPer| NPer| K0Per| K1| M1Per| N1Per| KPer| M11N11Thread| M11N11Thread| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| CThreadTransfer| CThreadTransfer| CThreadTransfer| - // ##########| Type| Type| Type| Type| | | | Elementwise| Elementwise| Elementwise| Spacialization| Size| Block| Block| Block| | ThreadM111| ThreadN111| Thread| ClusterM110Xs| ClusterN110Xs| ThreadSliceLengths| ThreadClusterLengths| ThreadCluster| SrcAccess| SrcVectorTensor| SrcVectorTensor| DstVectorTensor| ThreadSliceLengths| ThreadClusterLengths| ThreadCluster| SrcAccess| SrcVectorTensor| 
SrcVectorTensor| DstVectorTensor| SrcDstAccess| SrcDstVectorDim| DstScalarPerVector| - // ##########| | | | | | | | Operation| Operation| Operation| | | | | | | | | | | | K0_M0_M1_K1| K0_M0_M1_K1| ArrangeOrder| Order| Lengths_K0_M0_M1_K1| ContiguousDimOrder| Lengths_K0_M0_M1_K1| K0_N0_N1_K1| K0_N0_N1_K1| ArrangeOrder| Order| Lengths_K0_N0_N1_K1| ContiguousDimOrder| Lengths_K0_N0_N1_K1| Order| | | - // ##########| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | - DeviceGemmDlops< F32, F32, F32, F32, Col, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 256, 128, 8, 1, 8, 4, 1, S<8, 2>, S<8, 2>, S<2, 1, 4, 1>, S<4, 1, 64, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 1, 1>, S<1, 1, 4, 1>, S<8, 1, 32, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 1, 1>, S<0, 1, 2, 3, 4, 5>, 5, 4>, - DeviceGemmDlops< F32, F32, F32, F32, Col, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 128, 256, 8, 1, 4, 8, 1, S<8, 2>, S<8, 2>, S<2, 1, 2, 1>, S<4, 1, 64, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 1, 1>, S<1, 1, 4, 1>, S<8, 1, 64, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 1, 1>, S<0, 1, 2, 3, 4, 5>, 5, 4> -// DeviceGemmDlops< F32, F32, F32, F32, Col, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 128, 128, 8, 1, 4, 4, 1, S<8, 2>, S<8, 2>, S<4, 1, 1, 1>, S<2, 1, 128, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 1, 1>, S<4, 1, 1, 1>, S<2, 1, 128, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 1, 1>, S<0, 1, 2, 3, 4, 5>, 5, 4>, -// DeviceGemmDlops< F32, F32, F32, F32, Col, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 128, 128, 128, 8, 1, 8, 4, 1, S<4, 2>, S<8, 2>, S<8, 1, 1, 1>, S<1, 1, 128, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 1, 1>, S<8, 1, 1, 1>, S<1, 1, 128, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 
2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 1, 1>, S<0, 1, 2, 3, 4, 5>, 5, 4>, -// DeviceGemmDlops< F32, F32, F32, F32, Col, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 64, 128, 128, 8, 1, 8, 8, 1, S<4, 2>, S<4, 2>, S<4, 1, 4, 1>, S<2, 1, 32, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 1, 1>, S<4, 1, 4, 1>, S<2, 1, 32, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 1, 1>, S<0, 1, 2, 3, 4, 5>, 5, 4>, -// DeviceGemmDlops< F32, F32, F32, F32, Col, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 64, 128, 64, 8, 1, 8, 4, 1, S<4, 2>, S<4, 2>, S<4, 1, 4, 1>, S<2, 1, 32, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 1, 1>, S<4, 1, 2, 1>, S<2, 1, 32, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 1, 1>, S<0, 1, 2, 3, 4, 5>, 5, 4>, -// DeviceGemmDlops< F32, F32, F32, F32, Col, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 64, 64, 128, 8, 1, 4, 8, 1, S<4, 2>, S<4, 2>, S<4, 1, 2, 1>, S<2, 1, 32, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 1, 1>, S<4, 1, 4, 1>, S<2, 1, 32, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 1, 1>, S<0, 1, 2, 3, 4, 5>, 5, 4>, - // clang-format on - >; - -void add_device_gemm_dlops_f32_f32_f32_km_kn_mn_instances( - std::vector>& instances) -{ - add_device_operation_instances(instances, device_gemm_dlops_f32_f32_f32_km_kn_mn_instances{}); -} - -} // namespace device_gemm_instance -} // namespace device -} // namespace tensor_operation -} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/gemm/device_gemm_dlops_f32_f32_f32_km_nk_mn_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm/device_gemm_dlops_f32_f32_f32_km_nk_mn_instance.cpp deleted file mode 100644 index 07441f1d17..0000000000 --- a/library/src/tensor_operation_instance/gpu/gemm/device_gemm_dlops_f32_f32_f32_km_nk_mn_instance.cpp +++ /dev/null @@ -1,52 +0,0 @@ 
-#include -#include "config.hpp" -#include "device_gemm_dlops.hpp" -#include "element_wise_operation.hpp" -#include "device_operation_instance.hpp" - -namespace ck { -namespace tensor_operation { -namespace device { -namespace device_gemm_instance { - -using F16 = ck::half_t; -using F32 = float; - -using Row = ck::tensor_layout::gemm::RowMajor; -using Col = ck::tensor_layout::gemm::ColumnMajor; - -template -using S = ck::Sequence; - -using PassThrough = ck::tensor_operation::element_wise::PassThrough; - -static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default; - -// Compilation parameters for a[k, m] * b[n, k] = c[m, n] -using device_gemm_dlops_f32_f32_f32_km_nk_mn_instances = std::tuple< - // clang-format off - // ##########| AData| BData| CData| AccData| ALayout| BLayout| CLayout| A| B| C| GEMM| Block| MPer| NPer| K0Per| K1| M1Per| N1Per| KPer| M11N11Thread| M11N11Thread| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| CThreadTransfer| CThreadTransfer| CThreadTransfer| - // ##########| Type| Type| Type| Type| | | | Elementwise| Elementwise| Elementwise| Spacialization| Size| Block| Block| Block| | ThreadM111| ThreadN111| Thread| ClusterM110Xs| ClusterN110Xs| ThreadSliceLengths| ThreadClusterLengths| ThreadCluster| SrcAccess| SrcVectorTensor| SrcVectorTensor| DstVectorTensor| ThreadSliceLengths| ThreadClusterLengths| ThreadCluster| SrcAccess| SrcVectorTensor| SrcVectorTensor| DstVectorTensor| SrcDstAccess| SrcDstVectorDim| DstScalarPerVector| - // ##########| | | | | | | | Operation| Operation| Operation| | | | | | | | | | | | K0_M0_M1_K1| K0_M0_M1_K1| ArrangeOrder| Order| Lengths_K0_M0_M1_K1| ContiguousDimOrder| Lengths_K0_M0_M1_K1| K0_N0_N1_K1| K0_N0_N1_K1| ArrangeOrder| Order| Lengths_K0_N0_N1_K1| ContiguousDimOrder| Lengths_K0_N0_N1_K1| Order| | | - 
// ##########| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | - DeviceGemmDlops< F32, F32, F32, F32, Col, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 256, 128, 8, 1, 8, 4, 1, S<8, 2>, S<8, 2>, S<2, 1, 4, 1>, S<4, 1, 64, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 1, 1>, S<4, 1, 1, 1>, S<2, 1, 128, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 1>, S<1, 2, 0, 3>, S<1, 1, 1, 1>, S<0, 1, 2, 3, 4, 5>, 5, 4> -// DeviceGemmDlops< F32, F32, F32, F32, Col, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 128, 256, 8, 1, 4, 8, 1, S<8, 2>, S<8, 2>, S<4, 1, 1, 1>, S<2, 1, 128, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<4, 1, 1, 1>, S<0, 3, 1, 2>, S<1, 1, 1, 1>, S<4, 1, 2, 1>, S<2, 1, 128, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 1>, S<1, 2, 0, 3>, S<1, 1, 1, 1>, S<0, 1, 2, 3, 4, 5>, 5, 4>, -// DeviceGemmDlops< F32, F32, F32, F32, Col, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 128, 128, 8, 1, 4, 4, 1, S<8, 2>, S<8, 2>, S<4, 1, 1, 1>, S<2, 1, 128, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<4, 1, 1, 1>, S<0, 3, 1, 2>, S<1, 1, 1, 1>, S<4, 1, 1, 1>, S<2, 1, 128, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 1>, S<1, 2, 0, 3>, S<1, 1, 1, 1>, S<0, 1, 2, 3, 4, 5>, 5, 4>, -// DeviceGemmDlops< F32, F32, F32, F32, Col, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 128, 128, 128, 8, 1, 8, 4, 1, S<4, 2>, S<8, 2>, S<8, 1, 1, 1>, S<1, 1, 128, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<4, 1, 1, 1>, S<0, 3, 1, 2>, S<1, 1, 1, 1>, S<8, 1, 1, 1>, S<1, 1, 128, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 1>, S<1, 2, 0, 3>, S<1, 1, 1, 1>, S<0, 1, 2, 3, 4, 5>, 5, 4>, -// DeviceGemmDlops< F32, F32, F32, F32, Col, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 128, 128, 128, 8, 1, 4, 8, 1, S<8, 2>, S<4, 2>, S<8, 1, 1, 1>, S<1, 1, 128, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<4, 1, 1, 1>, S<0, 3, 1, 2>, S<1, 1, 1, 1>, S<8, 1, 1, 1>, S<1, 1, 128, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 
3>, S<4, 1, 1, 1>, S<1, 2, 0, 3>, S<1, 1, 1, 1>, S<0, 1, 2, 3, 4, 5>, 5, 4>, -// DeviceGemmDlops< F32, F32, F32, F32, Col, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 64, 128, 128, 8, 1, 8, 8, 1, S<4, 2>, S<4, 2>, S<4, 1, 4, 1>, S<2, 1, 32, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<4, 1, 1, 1>, S<0, 3, 1, 2>, S<1, 1, 1, 1>, S<4, 1, 4, 1>, S<2, 1, 32, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 1>, S<1, 2, 0, 3>, S<1, 1, 1, 1>, S<0, 1, 2, 3, 4, 5>, 5, 4>, -// DeviceGemmDlops< F32, F32, F32, F32, Col, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 64, 128, 64, 8, 1, 8, 4, 1, S<4, 2>, S<4, 2>, S<4, 1, 4, 1>, S<2, 1, 32, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<4, 1, 1, 1>, S<0, 3, 1, 2>, S<1, 1, 1, 1>, S<4, 1, 2, 1>, S<2, 1, 32, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 1>, S<1, 2, 0, 3>, S<1, 1, 1, 1>, S<0, 1, 2, 3, 4, 5>, 5, 4>, -// DeviceGemmDlops< F32, F32, F32, F32, Col, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 64, 64, 128, 8, 1, 4, 8, 1, S<4, 2>, S<4, 2>, S<4, 1, 2, 1>, S<2, 1, 32, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<4, 1, 1, 1>, S<0, 3, 1, 2>, S<1, 1, 1, 1>, S<4, 1, 4, 1>, S<2, 1, 32, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 1>, S<1, 2, 0, 3>, S<1, 1, 1, 1>, S<0, 1, 2, 3, 4, 5>, 5, 4>, - // clang-format on - >; - -void add_device_gemm_dlops_f32_f32_f32_km_nk_mn_instances( - std::vector>& instances) -{ - add_device_operation_instances(instances, device_gemm_dlops_f32_f32_f32_km_nk_mn_instances{}); -} - -} // namespace device_gemm_instance -} // namespace device -} // namespace tensor_operation -} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/gemm/device_gemm_dlops_f32_f32_f32_mk_kn_mn_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm/device_gemm_dlops_f32_f32_f32_mk_kn_mn_instance.cpp deleted file mode 100644 index 5776c5426f..0000000000 --- a/library/src/tensor_operation_instance/gpu/gemm/device_gemm_dlops_f32_f32_f32_mk_kn_mn_instance.cpp +++ /dev/null @@ -1,52 +0,0 @@ 
-#include -#include "config.hpp" -#include "device_gemm_dlops.hpp" -#include "element_wise_operation.hpp" -#include "device_operation_instance.hpp" - -namespace ck { -namespace tensor_operation { -namespace device { -namespace device_gemm_instance { - -using F16 = ck::half_t; -using F32 = float; - -using Row = ck::tensor_layout::gemm::RowMajor; -using Col = ck::tensor_layout::gemm::ColumnMajor; - -template -using S = ck::Sequence; - -using PassThrough = ck::tensor_operation::element_wise::PassThrough; - -static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default; - -// Compilation parameters for a[m, k] * b[k, n] = c[m, n] -using device_gemm_dlops_f32_f32_f32_mk_kn_mn_instances = std::tuple< - // clang-format off - // ##########| AData| BData| CData| AccData| ALayout| BLayout| CLayout| A| B| C| GEMM| Block| MPer| NPer| K0Per| K1| M1Per| N1Per| KPer| M11N11Thread| M11N11Thread| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| CThreadTransfer| CThreadTransfer| CThreadTransfer| - // ##########| Type| Type| Type| Type| | | | Elementwise| Elementwise| Elementwise| Spacialization| Size| Block| Block| Block| | ThreadM111| ThreadN111| Thread| ClusterM110Xs| ClusterN110Xs| ThreadSliceLengths| ThreadClusterLengths| ThreadCluster| SrcAccess| SrcVectorTensor| SrcVectorTensor| DstVectorTensor| ThreadSliceLengths| ThreadClusterLengths| ThreadCluster| SrcAccess| SrcVectorTensor| SrcVectorTensor| DstVectorTensor| SrcDstAccess| SrcDstVectorDim| DstScalarPerVector| - // ##########| | | | | | | | Operation| Operation| Operation| | | | | | | | | | | | K0_M0_M1_K1| K0_M0_M1_K1| ArrangeOrder| Order| Lengths_K0_M0_M1_K1| ContiguousDimOrder| Lengths_K0_M0_M1_K1| K0_N0_N1_K1| K0_N0_N1_K1| ArrangeOrder| Order| Lengths_K0_N0_N1_K1| ContiguousDimOrder| Lengths_K0_N0_N1_K1| Order| | | - 
// ##########| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | - DeviceGemmDlops< F32, F32, F32, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 256, 128, 8, 1, 8, 4, 1, S<8, 2>, S<8, 2>, S<4, 1, 2, 1>, S<2, 1, 128, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 1>, S<1, 2, 0, 3>, S<1, 1, 1, 1>, S<1, 1, 4, 1>, S<8, 1, 32, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 1, 1>, S<0, 1, 2, 3, 4, 5>, 5, 4> -// DeviceGemmDlops< F32, F32, F32, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 128, 256, 8, 1, 4, 8, 1, S<8, 2>, S<8, 2>, S<4, 1, 1, 1>, S<2, 1, 128, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 1>, S<1, 2, 0, 3>, S<1, 1, 1, 1>, S<4, 1, 2, 1>, S<2, 1, 128, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<4, 1, 1, 1>, S<0, 3, 1, 2>, S<1, 1, 1, 1>, S<0, 1, 2, 3, 4, 5>, 5, 4>, -// DeviceGemmDlops< F32, F32, F32, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 128, 128, 8, 1, 4, 4, 1, S<8, 2>, S<8, 2>, S<4, 1, 1, 1>, S<2, 1, 128, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 1>, S<1, 2, 0, 3>, S<1, 1, 1, 1>, S<4, 1, 1, 1>, S<2, 1, 128, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<4, 1, 1, 1>, S<0, 3, 1, 2>, S<1, 1, 1, 1>, S<0, 1, 2, 3, 4, 5>, 5, 4>, -// DeviceGemmDlops< F32, F32, F32, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 128, 128, 128, 8, 1, 8, 4, 1, S<4, 2>, S<8, 2>, S<8, 1, 1, 1>, S<1, 1, 128, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 1>, S<1, 2, 0, 3>, S<1, 1, 1, 1>, S<8, 1, 1, 1>, S<1, 1, 128, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<4, 1, 1, 1>, S<0, 3, 1, 2>, S<1, 1, 1, 1>, S<0, 1, 2, 3, 4, 5>, 5, 4>, -// DeviceGemmDlops< F32, F32, F32, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 128, 128, 128, 8, 1, 4, 8, 1, S<8, 2>, S<4, 2>, S<8, 1, 1, 1>, S<1, 1, 128, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 1>, S<1, 2, 0, 3>, S<1, 1, 1, 1>, S<8, 1, 1, 1>, S<1, 1, 128, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 
2>, S<4, 1, 1, 1>, S<0, 3, 1, 2>, S<1, 1, 1, 1>, S<0, 1, 2, 3, 4, 5>, 5, 4>, -// DeviceGemmDlops< F32, F32, F32, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 64, 128, 128, 8, 1, 8, 8, 1, S<4, 2>, S<4, 2>, S<4, 1, 4, 1>, S<2, 1, 32, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 1>, S<1, 2, 0, 3>, S<1, 1, 1, 1>, S<4, 1, 4, 1>, S<2, 1, 32, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<4, 1, 1, 1>, S<0, 3, 1, 2>, S<1, 1, 1, 1>, S<0, 1, 2, 3, 4, 5>, 5, 4>, -// DeviceGemmDlops< F32, F32, F32, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 64, 128, 64, 8, 1, 8, 4, 1, S<4, 2>, S<4, 2>, S<4, 1, 4, 1>, S<2, 1, 32, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 1>, S<1, 2, 0, 3>, S<1, 1, 1, 1>, S<4, 1, 2, 1>, S<2, 1, 32, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<4, 1, 1, 1>, S<0, 3, 1, 2>, S<1, 1, 1, 1>, S<0, 1, 2, 3, 4, 5>, 5, 4>, -// DeviceGemmDlops< F32, F32, F32, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 64, 64, 128, 8, 1, 4, 8, 1, S<4, 2>, S<4, 2>, S<4, 1, 2, 1>, S<2, 1, 32, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 1>, S<1, 2, 0, 3>, S<1, 1, 1, 1>, S<4, 1, 4, 1>, S<2, 1, 32, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<4, 1, 1, 1>, S<0, 3, 1, 2>, S<1, 1, 1, 1>, S<0, 1, 2, 3, 4, 5>, 5, 4>, - // clang-format on - >; - -void add_device_gemm_dlops_f32_f32_f32_mk_kn_mn_instances( - std::vector>& instances) -{ - add_device_operation_instances(instances, device_gemm_dlops_f32_f32_f32_mk_kn_mn_instances{}); -} - -} // namespace device_gemm_instance -} // namespace device -} // namespace tensor_operation -} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/gemm/device_gemm_dlops_f32_f32_f32_mk_nk_mn_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm/device_gemm_dlops_f32_f32_f32_mk_nk_mn_instance.cpp deleted file mode 100644 index e95177a69d..0000000000 --- a/library/src/tensor_operation_instance/gpu/gemm/device_gemm_dlops_f32_f32_f32_mk_nk_mn_instance.cpp +++ /dev/null @@ -1,52 +0,0 @@ 
-#include -#include "config.hpp" -#include "device_gemm_dlops.hpp" -#include "element_wise_operation.hpp" -#include "device_operation_instance.hpp" - -namespace ck { -namespace tensor_operation { -namespace device { -namespace device_gemm_instance { - -using F16 = ck::half_t; -using F32 = float; - -using Row = ck::tensor_layout::gemm::RowMajor; -using Col = ck::tensor_layout::gemm::ColumnMajor; - -template -using S = ck::Sequence; - -using PassThrough = ck::tensor_operation::element_wise::PassThrough; - -static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default; - -// Compilation parameters for a[m, k] * b[n, k] = c[m, n] -using device_gemm_dlops_f32_f32_f32_mk_nk_mn_instances = std::tuple< - // clang-format off - // ##########| AData| BData| CData| AccData| ALayout| BLayout| CLayout| A| B| C| GEMM| Block| MPer| NPer| K0Per| K1| M1Per| N1Per| KPer| M11N11Thread| M11N11Thread| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| CThreadTransfer| CThreadTransfer| CThreadTransfer| - // ##########| Type| Type| Type| Type| | | | Elementwise| Elementwise| Elementwise| Spacialization| Size| Block| Block| Block| | ThreadM111| ThreadN111| Thread| ClusterM110Xs| ClusterN110Xs| ThreadSliceLengths| ThreadClusterLengths| ThreadCluster| SrcAccess| SrcVectorTensor| SrcVectorTensor| DstVectorTensor| ThreadSliceLengths| ThreadClusterLengths| ThreadCluster| SrcAccess| SrcVectorTensor| SrcVectorTensor| DstVectorTensor| SrcDstAccess| SrcDstVectorDim| DstScalarPerVector| - // ##########| | | | | | | | Operation| Operation| Operation| | | | | | | | | | | | K0_M0_M1_K1| K0_M0_M1_K1| ArrangeOrder| Order| Lengths_K0_M0_M1_K1| ContiguousDimOrder| Lengths_K0_M0_M1_K1| K0_N0_N1_K1| K0_N0_N1_K1| ArrangeOrder| Order| Lengths_K0_N0_N1_K1| ContiguousDimOrder| Lengths_K0_N0_N1_K1| Order| | | - 
// ##########| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | - DeviceGemmDlops< F32, F32, F32, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 256, 128, 8, 1, 8, 4, 1, S<8, 2>, S<8, 2>, S<4, 1, 2, 1>, S<2, 1, 128, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 1>, S<1, 2, 0, 3>, S<1, 1, 1, 1>, S<4, 1, 1, 1>, S<2, 1, 128, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 1>, S<1, 2, 0, 3>, S<1, 1, 1, 1>, S<0, 1, 2, 3, 4, 5>, 5, 4>, - DeviceGemmDlops< F32, F32, F32, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 128, 256, 8, 1, 4, 8, 1, S<8, 2>, S<8, 2>, S<4, 1, 1, 1>, S<2, 1, 128, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 1>, S<1, 2, 0, 3>, S<1, 1, 1, 1>, S<4, 1, 2, 1>, S<2, 1, 128, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 1>, S<1, 2, 0, 3>, S<1, 1, 1, 1>, S<0, 1, 2, 3, 4, 5>, 5, 4>, - DeviceGemmDlops< F32, F32, F32, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 128, 128, 8, 1, 4, 4, 1, S<8, 2>, S<8, 2>, S<4, 1, 1, 1>, S<2, 1, 128, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 1>, S<1, 2, 0, 3>, S<1, 1, 1, 1>, S<4, 1, 1, 1>, S<2, 1, 128, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 1>, S<1, 2, 0, 3>, S<1, 1, 1, 1>, S<0, 1, 2, 3, 4, 5>, 5, 4>, - DeviceGemmDlops< F32, F32, F32, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 128, 128, 128, 8, 1, 8, 4, 1, S<4, 2>, S<8, 2>, S<8, 1, 1, 1>, S<1, 1, 128, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 1>, S<1, 2, 0, 3>, S<1, 1, 1, 1>, S<8, 1, 1, 1>, S<1, 1, 128, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 1>, S<1, 2, 0, 3>, S<1, 1, 1, 1>, S<0, 1, 2, 3, 4, 5>, 5, 4>, - DeviceGemmDlops< F32, F32, F32, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 128, 128, 128, 8, 1, 4, 8, 1, S<8, 2>, S<4, 2>, S<8, 1, 1, 1>, S<1, 1, 128, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 1>, S<1, 2, 0, 3>, S<1, 1, 1, 1>, S<8, 1, 1, 1>, S<1, 1, 128, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, 
S<4, 1, 1, 1>, S<1, 2, 0, 3>, S<1, 1, 1, 1>, S<0, 1, 2, 3, 4, 5>, 5, 4>, - DeviceGemmDlops< F32, F32, F32, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 64, 128, 128, 8, 1, 8, 8, 1, S<4, 2>, S<4, 2>, S<4, 1, 4, 1>, S<2, 1, 32, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 1>, S<1, 2, 0, 3>, S<1, 1, 1, 1>, S<4, 1, 4, 1>, S<2, 1, 32, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 1>, S<1, 2, 0, 3>, S<1, 1, 1, 1>, S<0, 1, 2, 3, 4, 5>, 5, 4>, - DeviceGemmDlops< F32, F32, F32, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 64, 128, 64, 8, 1, 8, 4, 1, S<4, 2>, S<4, 2>, S<4, 1, 4, 1>, S<2, 1, 32, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 1>, S<1, 2, 0, 3>, S<1, 1, 1, 1>, S<4, 1, 2, 1>, S<2, 1, 32, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 1>, S<1, 2, 0, 3>, S<1, 1, 1, 1>, S<0, 1, 2, 3, 4, 5>, 5, 4>, - DeviceGemmDlops< F32, F32, F32, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 64, 64, 128, 8, 1, 4, 8, 1, S<4, 2>, S<4, 2>, S<4, 1, 2, 1>, S<2, 1, 32, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 1>, S<1, 2, 0, 3>, S<1, 1, 1, 1>, S<4, 1, 4, 1>, S<2, 1, 32, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 1>, S<1, 2, 0, 3>, S<1, 1, 1, 1>, S<0, 1, 2, 3, 4, 5>, 5, 4> - // clang-format on - >; - -void add_device_gemm_dlops_f32_f32_f32_mk_nk_mn_instances( - std::vector>& instances) -{ - add_device_operation_instances(instances, device_gemm_dlops_f32_f32_f32_mk_nk_mn_instances{}); -} - -} // namespace device_gemm_instance -} // namespace device -} // namespace tensor_operation -} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/gemm/device_gemm_dlops_int8_int8_int8_km_kn_mn_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm/device_gemm_dlops_int8_int8_int8_km_kn_mn_instance.cpp deleted file mode 100644 index 1dafade20f..0000000000 --- a/library/src/tensor_operation_instance/gpu/gemm/device_gemm_dlops_int8_int8_int8_km_kn_mn_instance.cpp +++ /dev/null @@ -1,80 +0,0 @@ 
-#include -#include "config.hpp" -#include "device_gemm_dlops.hpp" -#include "element_wise_operation.hpp" -#include "device_operation_instance.hpp" - -namespace ck { -namespace tensor_operation { -namespace device { -namespace device_gemm_instance { - -using Row = ck::tensor_layout::gemm::RowMajor; -using Col = ck::tensor_layout::gemm::ColumnMajor; - -template -using S = ck::Sequence; - -using PassThrough = ck::tensor_operation::element_wise::PassThrough; - -static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default; - -// Compilation parameters for a[k, m] * b[k, n] = c[m, n] -using device_gemm_dlops_int8_int8_int8_km_kn_mn_instances = - std::tuple< - // clang-format off - // ##########| AData| BData| CData| AccData| ALayout| BLayout| CLayout| A| B| C| GEMM| Block| MPer| NPer| K0Per| K1| M1Per| N1Per| KPer| M11N11Thread| M11N11Thread| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| CThreadTransfer| CThreadTransfer| CThreadTransfer| - // ##########| Type| Type| Type| Type| | | | Elementwise| Elementwise| Elementwise| Spacialization| Size| Block| Block| Block| | ThreadM111| ThreadN111| Thread| ClusterM110Xs| ClusterN110Xs| ThreadSliceLengths| ThreadClusterLengths| ThreadCluster| SrcAccess| SrcVectorTensor| SrcVectorTensor| DstVectorTensor| ThreadSliceLengths| ThreadClusterLengths| ThreadCluster| SrcAccess| SrcVectorTensor| SrcVectorTensor| DstVectorTensor| SrcDstAccess| SrcDstVectorDim| DstScalarPerVector| - // ##########| | | | | | | | Operation| Operation| Operation| | | | | | | | | | | | K0_M0_M1_K1| K0_M0_M1_K1| ArrangeOrder| Order| Lengths_K0_M0_M1_K1| ContiguousDimOrder| Lengths_K0_M0_M1_K1| K0_N0_N1_K1| K0_N0_N1_K1| ArrangeOrder| Order| Lengths_K0_N0_N1_K1| ContiguousDimOrder| Lengths_K0_N0_N1_K1| Order| | | - // ##########| | | | | | | | | | | | | | | 
| | | | | | | | | | | | | | | | | | | | | | | | - /* - * K1 = 1 - */ - // DeviceGemmDlops< int8_t, int8_t, int8_t, int32_t, Col, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 256, 128, 8, 1, 8, 4, 1, S<8, 2>, S<8, 2>, S<4, 1, 2, 1>, S<2, 1, 128, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<4, 1, 1, 1>, S<0, 3, 1, 2>, S<1, 1, 1, 1>, S<4, 1, 1, 1>, S<2, 1, 128, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<4, 1, 1, 1>, S<0, 3, 1, 2>, S<1, 1, 1, 1>, S<0, 1, 2, 3, 4, 5>, 5, 4>, - // DeviceGemmDlops< int8_t, int8_t, int8_t, int32_t, Col, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 128, 256, 8, 1, 4, 8, 1, S<8, 2>, S<8, 2>, S<4, 1, 1, 1>, S<2, 1, 128, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<4, 1, 1, 1>, S<0, 3, 1, 2>, S<1, 1, 1, 1>, S<4, 1, 2, 1>, S<2, 1, 128, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<4, 1, 1, 1>, S<0, 3, 1, 2>, S<1, 1, 1, 1>, S<0, 1, 2, 3, 4, 5>, 5, 4>, - // DeviceGemmDlops< int8_t, int8_t, int8_t, int32_t, Col, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 128, 128, 8, 1, 4, 4, 1, S<8, 2>, S<8, 2>, S<4, 1, 1, 1>, S<2, 1, 128, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<4, 1, 1, 1>, S<0, 3, 1, 2>, S<1, 1, 1, 1>, S<4, 1, 1, 1>, S<2, 1, 128, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<4, 1, 1, 1>, S<0, 3, 1, 2>, S<1, 1, 1, 1>, S<0, 1, 2, 3, 4, 5>, 5, 4>, - // DeviceGemmDlops< int8_t, int8_t, int8_t, int32_t, Col, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 128, 128, 128, 8, 1, 8, 4, 1, S<4, 2>, S<8, 2>, S<8, 1, 1, 1>, S<1, 1, 128, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<4, 1, 1, 1>, S<0, 3, 1, 2>, S<1, 1, 1, 1>, S<8, 1, 1, 1>, S<1, 1, 128, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<4, 1, 1, 1>, S<0, 3, 1, 2>, S<1, 1, 1, 1>, S<0, 1, 2, 3, 4, 5>, 5, 4>, - // DeviceGemmDlops< int8_t, int8_t, int8_t, int32_t, Col, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 128, 128, 128, 8, 1, 4, 8, 1, S<8, 2>, S<4, 2>, S<8, 1, 1, 1>, S<1, 1, 128, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<4, 1, 1, 1>, S<0, 3, 1, 2>, S<1, 1, 1, 1>, S<8, 
1, 1, 1>, S<1, 1, 128, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<4, 1, 1, 1>, S<0, 3, 1, 2>, S<1, 1, 1, 1>, S<0, 1, 2, 3, 4, 5>, 5, 4>, - // DeviceGemmDlops< int8_t, int8_t, int8_t, int32_t, Col, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 64, 128, 128, 8, 1, 8, 8, 1, S<4, 2>, S<4, 2>, S<4, 1, 4, 1>, S<2, 1, 32, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<4, 1, 1, 1>, S<0, 3, 1, 2>, S<1, 1, 1, 1>, S<4, 1, 4, 1>, S<2, 1, 32, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<4, 1, 1, 1>, S<0, 3, 1, 2>, S<1, 1, 1, 1>, S<0, 1, 2, 3, 4, 5>, 5, 4>, - // DeviceGemmDlops< int8_t, int8_t, int8_t, int32_t, Col, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 64, 128, 64, 8, 1, 8, 4, 1, S<4, 2>, S<4, 2>, S<4, 1, 4, 1>, S<2, 1, 32, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<4, 1, 1, 1>, S<0, 3, 1, 2>, S<1, 1, 1, 1>, S<4, 1, 2, 1>, S<2, 1, 32, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<4, 1, 1, 1>, S<0, 3, 1, 2>, S<1, 1, 1, 1>, S<0, 1, 2, 3, 4, 5>, 5, 4>, - // DeviceGemmDlops< int8_t, int8_t, int8_t, int32_t, Col, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 64, 64, 128, 8, 1, 4, 8, 1, S<4, 2>, S<4, 2>, S<4, 1, 2, 1>, S<2, 1, 32, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<4, 1, 1, 1>, S<0, 3, 1, 2>, S<1, 1, 1, 1>, S<4, 1, 4, 1>, S<2, 1, 32, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<4, 1, 1, 1>, S<0, 3, 1, 2>, S<1, 1, 1, 1>, S<0, 1, 2, 3, 4, 5>, 5, 4>, - DeviceGemmDlops< int8_t, int8_t, int8_t, int32_t, Col, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 64, 32, 32, 8, 1, 2, 2, 1, S<4, 2>, S<4, 2>, S<4, 1, 1, 1>, S<2, 1, 32, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<4, 1, 1, 1>, S<0, 3, 1, 2>, S<1, 1, 1, 1>, S<4, 1, 1, 1>, S<2, 1, 32, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<4, 1, 1, 1>, S<0, 3, 1, 2>, S<1, 1, 1, 1>, S<0, 1, 2, 3, 4, 5>, 5, 2> - // - /* - * K1 = 2 - */ - // DeviceGemmDlops< int8_t, int8_t, int8_t, int32_t, Col, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 256, 128, 8, 2, 8, 4, 1, S<8, 2>, S<8, 2>, S<4, 1, 2, 2>, S<2, 1, 128, 1>, S<0, 3, 1, 
2>, S<0, 3, 1, 2>, S<4, 1, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 1, 2>, S<4, 1, 1, 2>, S<2, 1, 128, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<4, 1, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 1, 2>, S<0, 1, 2, 3, 4, 5>, 5, 4>, - // DeviceGemmDlops< int8_t, int8_t, int8_t, int32_t, Col, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 128, 256, 8, 2, 4, 8, 1, S<8, 2>, S<8, 2>, S<4, 1, 1, 2>, S<2, 1, 128, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<4, 1, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 1, 2>, S<4, 1, 2, 2>, S<2, 1, 128, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<4, 1, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 1, 2>, S<0, 1, 2, 3, 4, 5>, 5, 4>, - // DeviceGemmDlops< int8_t, int8_t, int8_t, int32_t, Col, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 128, 128, 8, 2, 4, 4, 1, S<8, 2>, S<8, 2>, S<4, 1, 1, 2>, S<2, 1, 128, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<4, 1, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 1, 2>, S<4, 1, 1, 2>, S<2, 1, 128, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<4, 1, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 1, 2>, S<0, 1, 2, 3, 4, 5>, 5, 4>, - // DeviceGemmDlops< int8_t, int8_t, int8_t, int32_t, Col, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 128, 128, 128, 8, 2, 8, 4, 1, S<4, 2>, S<8, 2>, S<8, 1, 1, 2>, S<1, 1, 128, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<4, 1, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 1, 2>, S<8, 1, 1, 2>, S<1, 1, 128, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<4, 1, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 1, 2>, S<0, 1, 2, 3, 4, 5>, 5, 4>, - // DeviceGemmDlops< int8_t, int8_t, int8_t, int32_t, Col, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 128, 128, 128, 8, 2, 4, 8, 1, S<8, 2>, S<4, 2>, S<8, 1, 1, 2>, S<1, 1, 128, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<4, 1, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 1, 2>, S<8, 1, 1, 2>, S<1, 1, 128, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<4, 1, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 1, 2>, S<0, 1, 2, 3, 4, 5>, 5, 4>, - // DeviceGemmDlops< int8_t, int8_t, int8_t, int32_t, Col, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 64, 128, 128, 8, 2, 8, 8, 1, 
S<4, 2>, S<4, 2>, S<4, 1, 4, 2>, S<2, 1, 32, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<4, 1, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 1, 2>, S<4, 1, 4, 2>, S<2, 1, 32, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<4, 1, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 1, 2>, S<0, 1, 2, 3, 4, 5>, 5, 4>, - // DeviceGemmDlops< int8_t, int8_t, int8_t, int32_t, Col, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 64, 128, 64, 8, 2, 8, 4, 1, S<4, 2>, S<4, 2>, S<4, 1, 4, 2>, S<2, 1, 32, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<4, 1, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 1, 2>, S<4, 1, 2, 2>, S<2, 1, 32, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<4, 1, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 1, 2>, S<0, 1, 2, 3, 4, 5>, 5, 4>, - // DeviceGemmDlops< int8_t, int8_t, int8_t, int32_t, Col, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 64, 64, 128, 8, 2, 4, 8, 1, S<4, 2>, S<4, 2>, S<4, 1, 2, 2>, S<2, 1, 32, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<4, 1, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 1, 2>, S<4, 1, 4, 2>, S<2, 1, 32, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<4, 1, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 1, 2>, S<0, 1, 2, 3, 4, 5>, 5, 4>, - - /* - * K1 = 4 - */ - // DeviceGemmDlops< int8_t, int8_t, int8_t, int32_t, Col, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 256, 128, 8, 4, 8, 4, 1, S<8, 2>, S<8, 2>, S<4, 1, 2, 4>, S<2, 1, 128, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<4, 1, 1, 4>, S<0, 3, 1, 2>, S<1, 1, 1, 4>, S<4, 1, 1, 4>, S<2, 1, 128, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<4, 1, 1, 4>, S<0, 3, 1, 2>, S<1, 1, 1, 4>, S<0, 1, 2, 3, 4, 5>, 5, 4>, - // DeviceGemmDlops< int8_t, int8_t, int8_t, int32_t, Col, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 128, 256, 8, 4, 4, 8, 1, S<8, 2>, S<8, 2>, S<4, 1, 1, 4>, S<2, 1, 128, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<4, 1, 1, 4>, S<0, 3, 1, 2>, S<1, 1, 1, 4>, S<4, 1, 2, 4>, S<2, 1, 128, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<4, 1, 1, 4>, S<0, 3, 1, 2>, S<1, 1, 1, 4>, S<0, 1, 2, 3, 4, 5>, 5, 4>, - // DeviceGemmDlops< int8_t, int8_t, int8_t, int32_t, Col, Row, Row, 
PassThrough, PassThrough, PassThrough, GemmDefault, 256, 128, 128, 8, 4, 4, 4, 1, S<8, 2>, S<8, 2>, S<4, 1, 1, 4>, S<2, 1, 128, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<4, 1, 1, 4>, S<0, 3, 1, 2>, S<1, 1, 1, 4>, S<4, 1, 1, 4>, S<2, 1, 128, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<4, 1, 1, 4>, S<0, 3, 1, 2>, S<1, 1, 1, 4>, S<0, 1, 2, 3, 4, 5>, 5, 4>, - // DeviceGemmDlops< int8_t, int8_t, int8_t, int32_t, Col, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 128, 128, 128, 8, 4, 8, 4, 1, S<4, 2>, S<8, 2>, S<8, 1, 1, 4>, S<1, 1, 128, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<4, 1, 1, 4>, S<0, 3, 1, 2>, S<1, 1, 1, 4>, S<8, 1, 1, 4>, S<1, 1, 128, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<4, 1, 1, 4>, S<0, 3, 1, 2>, S<1, 1, 1, 4>, S<0, 1, 2, 3, 4, 5>, 5, 4>, - // DeviceGemmDlops< int8_t, int8_t, int8_t, int32_t, Col, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 128, 128, 128, 8, 4, 4, 8, 1, S<8, 2>, S<4, 2>, S<8, 1, 1, 4>, S<1, 1, 128, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<4, 1, 1, 4>, S<0, 3, 1, 2>, S<1, 1, 1, 4>, S<8, 1, 1, 4>, S<1, 1, 128, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<4, 1, 1, 4>, S<0, 3, 1, 2>, S<1, 1, 1, 4>, S<0, 1, 2, 3, 4, 5>, 5, 4>, - // DeviceGemmDlops< int8_t, int8_t, int8_t, int32_t, Col, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 64, 128, 128, 8, 4, 8, 8, 1, S<4, 2>, S<4, 2>, S<4, 1, 4, 4>, S<2, 1, 32, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<4, 1, 1, 4>, S<0, 3, 1, 2>, S<1, 1, 1, 4>, S<4, 1, 4, 4>, S<2, 1, 32, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<4, 1, 1, 4>, S<0, 3, 1, 2>, S<1, 1, 1, 4>, S<0, 1, 2, 3, 4, 5>, 5, 4>, - // DeviceGemmDlops< int8_t, int8_t, int8_t, int32_t, Col, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 64, 128, 64, 8, 4, 8, 4, 1, S<4, 2>, S<4, 2>, S<4, 1, 4, 4>, S<2, 1, 32, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<4, 1, 1, 4>, S<0, 3, 1, 2>, S<1, 1, 1, 4>, S<4, 1, 2, 4>, S<2, 1, 32, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<4, 1, 1, 4>, S<0, 3, 1, 2>, S<1, 1, 1, 4>, S<0, 1, 2, 3, 4, 5>, 5, 4>, - // 
DeviceGemmDlops< int8_t, int8_t, int8_t, int32_t, Col, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 64, 64, 128, 8, 4, 4, 8, 1, S<4, 2>, S<4, 2>, S<4, 1, 2, 4>, S<2, 1, 32, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<4, 1, 1, 4>, S<0, 3, 1, 2>, S<1, 1, 1, 4>, S<4, 1, 4, 4>, S<2, 1, 32, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<4, 1, 1, 4>, S<0, 3, 1, 2>, S<1, 1, 1, 4>, S<0, 1, 2, 3, 4, 5>, 5, 4>, - // DeviceGemmDlops< int8_t, int8_t, int8_t, int32_t, Col, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 64, 64, 64, 8, 4, 4, 4, 1, S<4, 2>, S<4, 2>, S<4, 1, 2, 4>, S<2, 1, 32, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<4, 1, 1, 4>, S<0, 3, 1, 2>, S<1, 1, 1, 4>, S<4, 1, 2, 4>, S<2, 1, 32, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<4, 1, 1, 4>, S<0, 3, 1, 2>, S<1, 1, 1, 4>, S<0, 1, 2, 3, 4, 5>, 5, 4> - // clang-format on - >; - -void add_device_gemm_dlops_int8_int8_int8_km_kn_mn_instances( - std::vector>& instances) -{ - add_device_operation_instances(instances, - device_gemm_dlops_int8_int8_int8_km_kn_mn_instances{}); -} - -} // namespace device_gemm_instance -} // namespace device -} // namespace tensor_operation -} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/gemm/device_gemm_dlops_int8_int8_int8_km_nk_mn_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm/device_gemm_dlops_int8_int8_int8_km_nk_mn_instance.cpp deleted file mode 100644 index 4ed8191ae3..0000000000 --- a/library/src/tensor_operation_instance/gpu/gemm/device_gemm_dlops_int8_int8_int8_km_nk_mn_instance.cpp +++ /dev/null @@ -1,79 +0,0 @@ -#include -#include "config.hpp" -#include "device_gemm_dlops.hpp" -#include "element_wise_operation.hpp" -#include "device_operation_instance.hpp" - -namespace ck { -namespace tensor_operation { -namespace device { -namespace device_gemm_instance { - -using Row = ck::tensor_layout::gemm::RowMajor; -using Col = ck::tensor_layout::gemm::ColumnMajor; - -template -using S = ck::Sequence; - -using PassThrough = 
ck::tensor_operation::element_wise::PassThrough; - -static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default; - -// Compilation parameters for a[k, m] * b[n, k] = c[m, n] -using device_gemm_dlops_int8_int8_int8_km_nk_mn_instances = - std::tuple< - // clang-format off - // ##########| AData| BData| CData| AccData| ALayout| BLayout| CLayout| A| B| C| GEMM| Block| MPer| NPer| K0Per| K1| M1Per| N1Per| KPer| M11N11Thread| M11N11Thread| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| CThreadTransfer| CThreadTransfer| CThreadTransfer| - // ##########| Type| Type| Type| Type| | | | Elementwise| Elementwise| Elementwise| Spacialization| Size| Block| Block| Block| | ThreadM111| ThreadN111| Thread| ClusterM110Xs| ClusterN110Xs| ThreadSliceLengths| ThreadClusterLengths| ThreadCluster| SrcAccess| SrcVectorTensor| SrcVectorTensor| DstVectorTensor| ThreadSliceLengths| ThreadClusterLengths| ThreadCluster| SrcAccess| SrcVectorTensor| SrcVectorTensor| DstVectorTensor| SrcDstAccess| SrcDstVectorDim| DstScalarPerVector| - // ##########| | | | | | | | Operation| Operation| Operation| | | | | | | | | | | | K0_M0_M1_K1| K0_M0_M1_K1| ArrangeOrder| Order| Lengths_K0_M0_M1_K1| ContiguousDimOrder| Lengths_K0_M0_M1_K1| K0_N0_N1_K1| K0_N0_N1_K1| ArrangeOrder| Order| Lengths_K0_N0_N1_K1| ContiguousDimOrder| Lengths_K0_N0_N1_K1| Order| | | - // ##########| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | - /* - * K1 = 1 - */ - // DeviceGemmDlops< int8_t, int8_t, int8_t, int32_t, Col, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 256, 128, 8, 1, 8, 4, 1, S<8, 2>, S<8, 2>, S<4, 1, 2, 1>, S<2, 1, 128, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<4, 1, 1, 1>, S<0, 3, 1, 2>, S<1, 1, 1, 1>, S<4, 1, 1, 1>, S<2, 1, 128, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, 
S<4, 1, 1, 1>, S<1, 2, 0, 3>, S<1, 1, 1, 1>, S<0, 1, 2, 3, 4, 5>, 5, 4>, - // DeviceGemmDlops< int8_t, int8_t, int8_t, int32_t, Col, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 128, 256, 8, 1, 4, 8, 1, S<8, 2>, S<8, 2>, S<4, 1, 1, 1>, S<2, 1, 128, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<4, 1, 1, 1>, S<0, 3, 1, 2>, S<1, 1, 1, 1>, S<4, 1, 2, 1>, S<2, 1, 128, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 1>, S<1, 2, 0, 3>, S<1, 1, 1, 1>, S<0, 1, 2, 3, 4, 5>, 5, 4>, - // DeviceGemmDlops< int8_t, int8_t, int8_t, int32_t, Col, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 128, 128, 8, 1, 4, 4, 1, S<8, 2>, S<8, 2>, S<4, 1, 1, 1>, S<2, 1, 128, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<4, 1, 1, 1>, S<0, 3, 1, 2>, S<1, 1, 1, 1>, S<4, 1, 1, 1>, S<2, 1, 128, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 1>, S<1, 2, 0, 3>, S<1, 1, 1, 1>, S<0, 1, 2, 3, 4, 5>, 5, 4>, - // DeviceGemmDlops< int8_t, int8_t, int8_t, int32_t, Col, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 128, 128, 128, 8, 1, 8, 4, 1, S<4, 2>, S<8, 2>, S<8, 1, 1, 1>, S<1, 1, 128, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<4, 1, 1, 1>, S<0, 3, 1, 2>, S<1, 1, 1, 1>, S<8, 1, 1, 1>, S<1, 1, 128, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 1>, S<1, 2, 0, 3>, S<1, 1, 1, 1>, S<0, 1, 2, 3, 4, 5>, 5, 4>, - // DeviceGemmDlops< int8_t, int8_t, int8_t, int32_t, Col, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 128, 128, 128, 8, 1, 4, 8, 1, S<8, 2>, S<4, 2>, S<8, 1, 1, 1>, S<1, 1, 128, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<4, 1, 1, 1>, S<0, 3, 1, 2>, S<1, 1, 1, 1>, S<8, 1, 1, 1>, S<1, 1, 128, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 1>, S<1, 2, 0, 3>, S<1, 1, 1, 1>, S<0, 1, 2, 3, 4, 5>, 5, 4>, - // DeviceGemmDlops< int8_t, int8_t, int8_t, int32_t, Col, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 64, 128, 128, 8, 1, 8, 8, 1, S<4, 2>, S<4, 2>, S<4, 1, 4, 1>, S<2, 1, 32, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<4, 1, 1, 1>, S<0, 3, 1, 2>, S<1, 1, 1, 1>, 
S<4, 1, 4, 1>, S<2, 1, 32, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 1>, S<1, 2, 0, 3>, S<1, 1, 1, 1>, S<0, 1, 2, 3, 4, 5>, 5, 4>, - // DeviceGemmDlops< int8_t, int8_t, int8_t, int32_t, Col, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 64, 128, 64, 8, 1, 8, 4, 1, S<4, 2>, S<4, 2>, S<4, 1, 4, 1>, S<2, 1, 32, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<4, 1, 1, 1>, S<0, 3, 1, 2>, S<1, 1, 1, 1>, S<4, 1, 2, 1>, S<2, 1, 32, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 1>, S<1, 2, 0, 3>, S<1, 1, 1, 1>, S<0, 1, 2, 3, 4, 5>, 5, 4>, - // DeviceGemmDlops< int8_t, int8_t, int8_t, int32_t, Col, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 64, 64, 128, 8, 1, 4, 8, 1, S<4, 2>, S<4, 2>, S<4, 1, 2, 1>, S<2, 1, 32, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<4, 1, 1, 1>, S<0, 3, 1, 2>, S<1, 1, 1, 1>, S<4, 1, 4, 1>, S<2, 1, 32, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 1>, S<1, 2, 0, 3>, S<1, 1, 1, 1>, S<0, 1, 2, 3, 4, 5>, 5, 4>, - // - /* - * K1 = 2 - */ - // DeviceGemmDlops< int8_t, int8_t, int8_t, int32_t, Col, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 256, 128, 8, 2, 8, 4, 1, S<8, 2>, S<8, 2>, S<4, 1, 2, 2>, S<2, 1, 128, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<4, 1, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 1, 2>, S<4, 1, 1, 2>, S<2, 1, 128, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<0, 1, 2, 3, 4, 5>, 5, 4>, - // DeviceGemmDlops< int8_t, int8_t, int8_t, int32_t, Col, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 128, 256, 8, 2, 4, 8, 1, S<8, 2>, S<8, 2>, S<4, 1, 1, 2>, S<2, 1, 128, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<4, 1, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 1, 2>, S<4, 1, 2, 2>, S<2, 1, 128, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<0, 1, 2, 3, 4, 5>, 5, 4>, - // DeviceGemmDlops< int8_t, int8_t, int8_t, int32_t, Col, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 128, 128, 8, 2, 4, 4, 1, S<8, 2>, S<8, 2>, S<4, 1, 1, 2>, S<2, 1, 128, 
1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<4, 1, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 1, 2>, S<4, 1, 1, 2>, S<2, 1, 128, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<0, 1, 2, 3, 4, 5>, 5, 4>, - // DeviceGemmDlops< int8_t, int8_t, int8_t, int32_t, Col, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 128, 128, 128, 8, 2, 8, 4, 1, S<4, 2>, S<8, 2>, S<8, 1, 1, 2>, S<1, 1, 128, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<4, 1, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 1, 2>, S<8, 1, 1, 2>, S<1, 1, 128, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<0, 1, 2, 3, 4, 5>, 5, 4>, - // DeviceGemmDlops< int8_t, int8_t, int8_t, int32_t, Col, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 128, 128, 128, 8, 2, 4, 8, 1, S<8, 2>, S<4, 2>, S<8, 1, 1, 2>, S<1, 1, 128, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<4, 1, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 1, 2>, S<8, 1, 1, 2>, S<1, 1, 128, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<0, 1, 2, 3, 4, 5>, 5, 4>, - // DeviceGemmDlops< int8_t, int8_t, int8_t, int32_t, Col, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 64, 128, 128, 8, 2, 8, 8, 1, S<4, 2>, S<4, 2>, S<4, 1, 4, 2>, S<2, 1, 32, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<4, 1, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 1, 2>, S<4, 1, 4, 2>, S<2, 1, 32, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<0, 1, 2, 3, 4, 5>, 5, 4>, - // DeviceGemmDlops< int8_t, int8_t, int8_t, int32_t, Col, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 64, 128, 64, 8, 2, 8, 4, 1, S<4, 2>, S<4, 2>, S<4, 1, 4, 2>, S<2, 1, 32, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<4, 1, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 1, 2>, S<4, 1, 2, 2>, S<2, 1, 32, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<0, 1, 2, 3, 4, 5>, 5, 4>, - // DeviceGemmDlops< int8_t, int8_t, int8_t, int32_t, Col, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 64, 64, 128, 8, 2, 4, 
8, 1, S<4, 2>, S<4, 2>, S<4, 1, 2, 2>, S<2, 1, 32, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<4, 1, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 1, 2>, S<4, 1, 4, 2>, S<2, 1, 32, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<0, 1, 2, 3, 4, 5>, 5, 4>, - - /* - * K1 = 4 - */ - // DeviceGemmDlops< int8_t, int8_t, int8_t, int32_t, Col, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 256, 128, 8, 4, 8, 4, 1, S<8, 2>, S<8, 2>, S<4, 1, 2, 4>, S<2, 1, 128, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<4, 1, 1, 4>, S<0, 3, 1, 2>, S<1, 1, 1, 4>, S<4, 1, 1, 4>, S<2, 1, 128, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 4>, S<1, 2, 0, 3>, S<1, 1, 1, 4>, S<0, 1, 2, 3, 4, 5>, 5, 4>, - // DeviceGemmDlops< int8_t, int8_t, int8_t, int32_t, Col, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 128, 256, 8, 4, 4, 8, 1, S<8, 2>, S<8, 2>, S<4, 1, 1, 4>, S<2, 1, 128, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<4, 1, 1, 4>, S<0, 3, 1, 2>, S<1, 1, 1, 4>, S<4, 1, 2, 4>, S<2, 1, 128, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 4>, S<1, 2, 0, 3>, S<1, 1, 1, 4>, S<0, 1, 2, 3, 4, 5>, 5, 4>, - // DeviceGemmDlops< int8_t, int8_t, int8_t, int32_t, Col, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 128, 128, 8, 4, 4, 4, 1, S<8, 2>, S<8, 2>, S<4, 1, 1, 4>, S<2, 1, 128, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<4, 1, 1, 4>, S<0, 3, 1, 2>, S<1, 1, 1, 4>, S<4, 1, 1, 4>, S<2, 1, 128, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 4>, S<1, 2, 0, 3>, S<1, 1, 1, 4>, S<0, 1, 2, 3, 4, 5>, 5, 4>, - // DeviceGemmDlops< int8_t, int8_t, int8_t, int32_t, Col, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 128, 128, 128, 8, 4, 8, 4, 1, S<4, 2>, S<8, 2>, S<8, 1, 1, 4>, S<1, 1, 128, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<4, 1, 1, 4>, S<0, 3, 1, 2>, S<1, 1, 1, 4>, S<8, 1, 1, 4>, S<1, 1, 128, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 4>, S<1, 2, 0, 3>, S<1, 1, 1, 4>, S<0, 1, 2, 3, 4, 5>, 5, 4>, - // DeviceGemmDlops< int8_t, int8_t, int8_t, int32_t, Col, Col, 
Row, PassThrough, PassThrough, PassThrough, GemmDefault, 128, 128, 128, 8, 4, 4, 8, 1, S<8, 2>, S<4, 2>, S<8, 1, 1, 4>, S<1, 1, 128, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<4, 1, 1, 4>, S<0, 3, 1, 2>, S<1, 1, 1, 4>, S<8, 1, 1, 4>, S<1, 1, 128, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 4>, S<1, 2, 0, 3>, S<1, 1, 1, 4>, S<0, 1, 2, 3, 4, 5>, 5, 4>, - // DeviceGemmDlops< int8_t, int8_t, int8_t, int32_t, Col, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 64, 128, 128, 8, 4, 8, 8, 1, S<4, 2>, S<4, 2>, S<4, 1, 4, 4>, S<2, 1, 32, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<4, 1, 1, 4>, S<0, 3, 1, 2>, S<1, 1, 1, 4>, S<4, 1, 4, 4>, S<2, 1, 32, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 4>, S<1, 2, 0, 3>, S<1, 1, 1, 4>, S<0, 1, 2, 3, 4, 5>, 5, 4>, - DeviceGemmDlops< int8_t, int8_t, int8_t, int32_t, Col, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 64, 128, 64, 8, 4, 8, 4, 1, S<4, 2>, S<4, 2>, S<4, 1, 4, 4>, S<2, 1, 32, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<4, 1, 1, 4>, S<0, 3, 1, 2>, S<1, 1, 1, 4>, S<4, 1, 2, 4>, S<2, 1, 32, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 4>, S<1, 2, 0, 3>, S<1, 1, 1, 4>, S<0, 1, 2, 3, 4, 5>, 5, 4>, - DeviceGemmDlops< int8_t, int8_t, int8_t, int32_t, Col, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 64, 64, 128, 8, 4, 4, 8, 1, S<4, 2>, S<4, 2>, S<4, 1, 2, 4>, S<2, 1, 32, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<4, 1, 1, 4>, S<0, 3, 1, 2>, S<1, 1, 1, 4>, S<4, 1, 4, 4>, S<2, 1, 32, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 4>, S<1, 2, 0, 3>, S<1, 1, 1, 4>, S<0, 1, 2, 3, 4, 5>, 5, 4>, - DeviceGemmDlops< int8_t, int8_t, int8_t, int32_t, Col, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 64, 64, 64, 8, 4, 4, 4, 1, S<4, 2>, S<4, 2>, S<4, 1, 2, 4>, S<2, 1, 32, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<4, 1, 1, 4>, S<0, 3, 1, 2>, S<1, 1, 1, 4>, S<4, 1, 2, 4>, S<2, 1, 32, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 4>, S<1, 2, 0, 3>, S<1, 1, 1, 4>, S<0, 1, 2, 3, 4, 5>, 5, 4> - // clang-format on - >; - 
-void add_device_gemm_dlops_int8_int8_int8_km_nk_mn_instances( - std::vector>& instances) -{ - add_device_operation_instances(instances, - device_gemm_dlops_int8_int8_int8_km_nk_mn_instances{}); -} - -} // namespace device_gemm_instance -} // namespace device -} // namespace tensor_operation -} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/gemm/device_gemm_dlops_int8_int8_int8_mk_kn_mn_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm/device_gemm_dlops_int8_int8_int8_mk_kn_mn_instance.cpp deleted file mode 100644 index 65a7a1ab79..0000000000 --- a/library/src/tensor_operation_instance/gpu/gemm/device_gemm_dlops_int8_int8_int8_mk_kn_mn_instance.cpp +++ /dev/null @@ -1,79 +0,0 @@ -#include -#include "config.hpp" -#include "device_gemm_dlops.hpp" -#include "element_wise_operation.hpp" -#include "device_operation_instance.hpp" - -namespace ck { -namespace tensor_operation { -namespace device { -namespace device_gemm_instance { - -using Row = ck::tensor_layout::gemm::RowMajor; -using Col = ck::tensor_layout::gemm::ColumnMajor; - -template -using S = ck::Sequence; - -using PassThrough = ck::tensor_operation::element_wise::PassThrough; - -static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default; - -// Compilation parameters for a[m, k] * b[k, n] = c[m, n] -using device_gemm_dlops_int8_int8_int8_mk_kn_mn_instances = - std::tuple< - // clang-format off - // ##########| AData| BData| CData| AccData| ALayout| BLayout| CLayout| A| B| C| GEMM| Block| MPer| NPer| K0Per| K1| M1Per| N1Per| KPer| M11N11Thread| M11N11Thread| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| CThreadTransfer| CThreadTransfer| CThreadTransfer| - // ##########| Type| Type| Type| Type| | | | Elementwise| Elementwise| Elementwise| Spacialization| Size| Block| 
Block| Block| | ThreadM111| ThreadN111| Thread| ClusterM110Xs| ClusterN110Xs| ThreadSliceLengths| ThreadClusterLengths| ThreadCluster| SrcAccess| SrcVectorTensor| SrcVectorTensor| DstVectorTensor| ThreadSliceLengths| ThreadClusterLengths| ThreadCluster| SrcAccess| SrcVectorTensor| SrcVectorTensor| DstVectorTensor| SrcDstAccess| SrcDstVectorDim| DstScalarPerVector| - // ##########| | | | | | | | Operation| Operation| Operation| | | | | | | | | | | | K0_M0_M1_K1| K0_M0_M1_K1| ArrangeOrder| Order| Lengths_K0_M0_M1_K1| ContiguousDimOrder| Lengths_K0_M0_M1_K1| K0_N0_N1_K1| K0_N0_N1_K1| ArrangeOrder| Order| Lengths_K0_N0_N1_K1| ContiguousDimOrder| Lengths_K0_N0_N1_K1| Order| | | - // ##########| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | - /* - * K1 = 1 - */ - // DeviceGemmDlops< int8_t, int8_t, int8_t, int32_t, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 256, 128, 8, 1, 8, 4, 1, S<8, 2>, S<8, 2>, S<4, 1, 2, 1>, S<2, 1, 128, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 1>, S<1, 2, 0, 3>, S<1, 1, 1, 1>, S<4, 1, 1, 1>, S<2, 1, 128, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<4, 1, 1, 1>, S<0, 3, 1, 2>, S<1, 1, 1, 1>, S<0, 1, 2, 3, 4, 5>, 5, 4>, - // DeviceGemmDlops< int8_t, int8_t, int8_t, int32_t, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 128, 256, 8, 1, 4, 8, 1, S<8, 2>, S<8, 2>, S<4, 1, 1, 1>, S<2, 1, 128, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 1>, S<1, 2, 0, 3>, S<1, 1, 1, 1>, S<4, 1, 2, 1>, S<2, 1, 128, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<4, 1, 1, 1>, S<0, 3, 1, 2>, S<1, 1, 1, 1>, S<0, 1, 2, 3, 4, 5>, 5, 4>, - // DeviceGemmDlops< int8_t, int8_t, int8_t, int32_t, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 128, 128, 8, 1, 4, 4, 1, S<8, 2>, S<8, 2>, S<4, 1, 1, 1>, S<2, 1, 128, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 1>, S<1, 2, 0, 3>, S<1, 1, 1, 1>, S<4, 1, 1, 1>, S<2, 1, 128, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<4, 1, 1, 1>, S<0, 3, 1, 2>, S<1, 1, 1, 
1>, S<0, 1, 2, 3, 4, 5>, 5, 4>, - // DeviceGemmDlops< int8_t, int8_t, int8_t, int32_t, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 128, 128, 128, 8, 1, 8, 4, 1, S<4, 2>, S<8, 2>, S<8, 1, 1, 1>, S<1, 1, 128, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 1>, S<1, 2, 0, 3>, S<1, 1, 1, 1>, S<8, 1, 1, 1>, S<1, 1, 128, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<4, 1, 1, 1>, S<0, 3, 1, 2>, S<1, 1, 1, 1>, S<0, 1, 2, 3, 4, 5>, 5, 4>, - // DeviceGemmDlops< int8_t, int8_t, int8_t, int32_t, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 128, 128, 128, 8, 1, 4, 8, 1, S<8, 2>, S<4, 2>, S<8, 1, 1, 1>, S<1, 1, 128, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 1>, S<1, 2, 0, 3>, S<1, 1, 1, 1>, S<8, 1, 1, 1>, S<1, 1, 128, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<4, 1, 1, 1>, S<0, 3, 1, 2>, S<1, 1, 1, 1>, S<0, 1, 2, 3, 4, 5>, 5, 4>, - // DeviceGemmDlops< int8_t, int8_t, int8_t, int32_t, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 64, 128, 128, 8, 1, 8, 8, 1, S<4, 2>, S<4, 2>, S<4, 1, 4, 1>, S<2, 1, 32, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 1>, S<1, 2, 0, 3>, S<1, 1, 1, 1>, S<4, 1, 4, 1>, S<2, 1, 32, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<4, 1, 1, 1>, S<0, 3, 1, 2>, S<1, 1, 1, 1>, S<0, 1, 2, 3, 4, 5>, 5, 4>, - // DeviceGemmDlops< int8_t, int8_t, int8_t, int32_t, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 64, 128, 64, 8, 1, 8, 4, 1, S<4, 2>, S<4, 2>, S<4, 1, 4, 1>, S<2, 1, 32, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 1>, S<1, 2, 0, 3>, S<1, 1, 1, 1>, S<4, 1, 2, 1>, S<2, 1, 32, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<4, 1, 1, 1>, S<0, 3, 1, 2>, S<1, 1, 1, 1>, S<0, 1, 2, 3, 4, 5>, 5, 4>, - // DeviceGemmDlops< int8_t, int8_t, int8_t, int32_t, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 64, 64, 128, 8, 1, 4, 8, 1, S<4, 2>, S<4, 2>, S<4, 1, 2, 1>, S<2, 1, 32, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 1>, S<1, 2, 0, 3>, S<1, 1, 1, 1>, S<4, 1, 4, 1>, S<2, 1, 32, 1>, S<0, 3, 1, 2>, S<0, 
3, 1, 2>, S<4, 1, 1, 1>, S<0, 3, 1, 2>, S<1, 1, 1, 1>, S<0, 1, 2, 3, 4, 5>, 5, 4>, - // - /* - * K1 = 2 - */ - // DeviceGemmDlops< int8_t, int8_t, int8_t, int32_t, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 256, 128, 8, 2, 8, 4, 1, S<8, 2>, S<8, 2>, S<4, 1, 2, 2>, S<2, 1, 128, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<4, 1, 1, 2>, S<2, 1, 128, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<4, 1, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 1, 2>, S<0, 1, 2, 3, 4, 5>, 5, 4>, - // DeviceGemmDlops< int8_t, int8_t, int8_t, int32_t, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 128, 256, 8, 2, 4, 8, 1, S<8, 2>, S<8, 2>, S<4, 1, 1, 2>, S<2, 1, 128, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<4, 1, 2, 2>, S<2, 1, 128, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<4, 1, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 1, 2>, S<0, 1, 2, 3, 4, 5>, 5, 4>, - // DeviceGemmDlops< int8_t, int8_t, int8_t, int32_t, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 128, 128, 8, 2, 4, 4, 1, S<8, 2>, S<8, 2>, S<4, 1, 1, 2>, S<2, 1, 128, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<4, 1, 1, 2>, S<2, 1, 128, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<4, 1, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 1, 2>, S<0, 1, 2, 3, 4, 5>, 5, 4>, - // DeviceGemmDlops< int8_t, int8_t, int8_t, int32_t, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 128, 128, 128, 8, 2, 8, 4, 1, S<4, 2>, S<8, 2>, S<8, 1, 1, 2>, S<1, 1, 128, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<8, 1, 1, 2>, S<1, 1, 128, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<4, 1, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 1, 2>, S<0, 1, 2, 3, 4, 5>, 5, 4>, - // DeviceGemmDlops< int8_t, int8_t, int8_t, int32_t, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 128, 128, 128, 8, 2, 4, 8, 1, S<8, 2>, S<4, 2>, S<8, 1, 1, 2>, S<1, 1, 128, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 
2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<8, 1, 1, 2>, S<1, 1, 128, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<4, 1, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 1, 2>, S<0, 1, 2, 3, 4, 5>, 5, 4>, - // DeviceGemmDlops< int8_t, int8_t, int8_t, int32_t, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 64, 128, 128, 8, 2, 8, 8, 1, S<4, 2>, S<4, 2>, S<4, 1, 4, 2>, S<2, 1, 32, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<4, 1, 4, 2>, S<2, 1, 32, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<4, 1, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 1, 2>, S<0, 1, 2, 3, 4, 5>, 5, 4>, - // DeviceGemmDlops< int8_t, int8_t, int8_t, int32_t, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 64, 128, 64, 8, 2, 8, 4, 1, S<4, 2>, S<4, 2>, S<4, 1, 4, 2>, S<2, 1, 32, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<4, 1, 2, 2>, S<2, 1, 32, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<4, 1, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 1, 2>, S<0, 1, 2, 3, 4, 5>, 5, 4>, - // DeviceGemmDlops< int8_t, int8_t, int8_t, int32_t, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 64, 64, 128, 8, 2, 4, 8, 1, S<4, 2>, S<4, 2>, S<4, 1, 2, 2>, S<2, 1, 32, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<4, 1, 4, 2>, S<2, 1, 32, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<4, 1, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 1, 2>, S<0, 1, 2, 3, 4, 5>, 5, 4>, - - /* - * K1 = 4 - */ - // DeviceGemmDlops< int8_t, int8_t, int8_t, int32_t, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 256, 128, 8, 4, 8, 4, 1, S<8, 2>, S<8, 2>, S<4, 1, 2, 4>, S<2, 1, 128, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 4>, S<1, 2, 0, 3>, S<1, 1, 1, 4>, S<4, 1, 1, 4>, S<2, 1, 128, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<4, 1, 1, 4>, S<0, 3, 1, 2>, S<1, 1, 1, 4>, S<0, 1, 2, 3, 4, 5>, 5, 4>, - // DeviceGemmDlops< int8_t, int8_t, int8_t, int32_t, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 128, 256, 8, 4, 4, 8, 1, S<8, 2>, S<8, 2>, 
S<4, 1, 1, 4>, S<2, 1, 128, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 4>, S<1, 2, 0, 3>, S<1, 1, 1, 4>, S<4, 1, 2, 4>, S<2, 1, 128, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<4, 1, 1, 4>, S<0, 3, 1, 2>, S<1, 1, 1, 4>, S<0, 1, 2, 3, 4, 5>, 5, 4>, - // DeviceGemmDlops< int8_t, int8_t, int8_t, int32_t, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 128, 128, 8, 4, 4, 4, 1, S<8, 2>, S<8, 2>, S<4, 1, 1, 4>, S<2, 1, 128, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 4>, S<1, 2, 0, 3>, S<1, 1, 1, 4>, S<4, 1, 1, 4>, S<2, 1, 128, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<4, 1, 1, 4>, S<0, 3, 1, 2>, S<1, 1, 1, 4>, S<0, 1, 2, 3, 4, 5>, 5, 4>, - // DeviceGemmDlops< int8_t, int8_t, int8_t, int32_t, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 128, 128, 128, 8, 4, 8, 4, 1, S<4, 2>, S<8, 2>, S<8, 1, 1, 4>, S<1, 1, 128, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 4>, S<1, 2, 0, 3>, S<1, 1, 1, 4>, S<8, 1, 1, 4>, S<1, 1, 128, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<4, 1, 1, 4>, S<0, 3, 1, 2>, S<1, 1, 1, 4>, S<0, 1, 2, 3, 4, 5>, 5, 4>, - // DeviceGemmDlops< int8_t, int8_t, int8_t, int32_t, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 128, 128, 128, 8, 4, 4, 8, 1, S<8, 2>, S<4, 2>, S<8, 1, 1, 4>, S<1, 1, 128, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 4>, S<1, 2, 0, 3>, S<1, 1, 1, 4>, S<8, 1, 1, 4>, S<1, 1, 128, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<4, 1, 1, 4>, S<0, 3, 1, 2>, S<1, 1, 1, 4>, S<0, 1, 2, 3, 4, 5>, 5, 4>, - // DeviceGemmDlops< int8_t, int8_t, int8_t, int32_t, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 64, 128, 128, 8, 4, 8, 8, 1, S<4, 2>, S<4, 2>, S<4, 1, 4, 4>, S<2, 1, 32, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 4>, S<1, 2, 0, 3>, S<1, 1, 1, 4>, S<4, 1, 4, 4>, S<2, 1, 32, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<4, 1, 1, 4>, S<0, 3, 1, 2>, S<1, 1, 1, 4>, S<0, 1, 2, 3, 4, 5>, 5, 4>, - DeviceGemmDlops< int8_t, int8_t, int8_t, int32_t, Row, Row, Row, PassThrough, PassThrough, PassThrough, 
GemmDefault, 64, 128, 64, 8, 4, 8, 4, 1, S<4, 2>, S<4, 2>, S<4, 1, 4, 4>, S<2, 1, 32, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 4>, S<1, 2, 0, 3>, S<1, 1, 1, 4>, S<4, 1, 2, 4>, S<2, 1, 32, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<4, 1, 1, 4>, S<0, 3, 1, 2>, S<1, 1, 1, 4>, S<0, 1, 2, 3, 4, 5>, 5, 4>, - DeviceGemmDlops< int8_t, int8_t, int8_t, int32_t, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 64, 64, 128, 8, 4, 4, 8, 1, S<4, 2>, S<4, 2>, S<4, 1, 2, 4>, S<2, 1, 32, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 4>, S<1, 2, 0, 3>, S<1, 1, 1, 4>, S<4, 1, 4, 4>, S<2, 1, 32, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<4, 1, 1, 4>, S<0, 3, 1, 2>, S<1, 1, 1, 4>, S<0, 1, 2, 3, 4, 5>, 5, 4>, - DeviceGemmDlops< int8_t, int8_t, int8_t, int32_t, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 64, 64, 64, 8, 4, 4, 4, 1, S<4, 2>, S<4, 2>, S<4, 1, 2, 4>, S<2, 1, 32, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 4>, S<1, 2, 0, 3>, S<1, 1, 1, 4>, S<4, 1, 2, 4>, S<2, 1, 32, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<4, 1, 1, 4>, S<0, 3, 1, 2>, S<1, 1, 1, 4>, S<0, 1, 2, 3, 4, 5>, 5, 4> - // clang-format on - >; - -void add_device_gemm_dlops_int8_int8_int8_mk_kn_mn_instances( - std::vector>& instances) -{ - add_device_operation_instances(instances, - device_gemm_dlops_int8_int8_int8_mk_kn_mn_instances{}); -} - -} // namespace device_gemm_instance -} // namespace device -} // namespace tensor_operation -} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/gemm/device_gemm_dlops_int8_int8_int8_mk_nk_mn_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm/device_gemm_dlops_int8_int8_int8_mk_nk_mn_instance.cpp deleted file mode 100644 index 55b2e78462..0000000000 --- a/library/src/tensor_operation_instance/gpu/gemm/device_gemm_dlops_int8_int8_int8_mk_nk_mn_instance.cpp +++ /dev/null @@ -1,79 +0,0 @@ -#include -#include "config.hpp" -#include "device_gemm_dlops.hpp" -#include "element_wise_operation.hpp" -#include 
"device_operation_instance.hpp" - -namespace ck { -namespace tensor_operation { -namespace device { -namespace device_gemm_instance { - -using Row = ck::tensor_layout::gemm::RowMajor; -using Col = ck::tensor_layout::gemm::ColumnMajor; - -template -using S = ck::Sequence; - -using PassThrough = ck::tensor_operation::element_wise::PassThrough; - -static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default; - -// Compilation parameters for a[m, k] * b[n, k] = c[m, n] -using device_gemm_dlops_int8_int8_int8_mk_nk_mn_instances = - std::tuple< - // clang-format off - // ##########| AData| BData| CData| AccData| ALayout| BLayout| CLayout| A| B| C| GEMM| Block| MPer| NPer| K0Per| K1| M1Per| N1Per| KPer| M11N11Thread| M11N11Thread| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| CThreadTransfer| CThreadTransfer| CThreadTransfer| - // ##########| Type| Type| Type| Type| | | | Elementwise| Elementwise| Elementwise| Spacialization| Size| Block| Block| Block| | ThreadM111| ThreadN111| Thread| ClusterM110Xs| ClusterN110Xs| ThreadSliceLengths| ThreadClusterLengths| ThreadCluster| SrcAccess| SrcVectorTensor| SrcVectorTensor| DstVectorTensor| ThreadSliceLengths| ThreadClusterLengths| ThreadCluster| SrcAccess| SrcVectorTensor| SrcVectorTensor| DstVectorTensor| SrcDstAccess| SrcDstVectorDim| DstScalarPerVector| - // ##########| | | | | | | | Operation| Operation| Operation| | | | | | | | | | | | K0_M0_M1_K1| K0_M0_M1_K1| ArrangeOrder| Order| Lengths_K0_M0_M1_K1| ContiguousDimOrder| Lengths_K0_M0_M1_K1| K0_N0_N1_K1| K0_N0_N1_K1| ArrangeOrder| Order| Lengths_K0_N0_N1_K1| ContiguousDimOrder| Lengths_K0_N0_N1_K1| Order| | | - // ##########| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | - /* - * K1 = 1 - */ - // DeviceGemmDlops< int8_t, int8_t, int8_t, 
int32_t, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 256, 128, 8, 1, 8, 4, 1, S<8, 2>, S<8, 2>, S<4, 1, 2, 1>, S<2, 1, 128, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 1>, S<1, 2, 0, 3>, S<1, 1, 1, 1>, S<4, 1, 1, 1>, S<2, 1, 128, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 1>, S<1, 2, 0, 3>, S<1, 1, 1, 1>, S<0, 1, 2, 3, 4, 5>, 5, 4>, - // DeviceGemmDlops< int8_t, int8_t, int8_t, int32_t, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 128, 256, 8, 1, 4, 8, 1, S<8, 2>, S<8, 2>, S<4, 1, 1, 1>, S<2, 1, 128, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 1>, S<1, 2, 0, 3>, S<1, 1, 1, 1>, S<4, 1, 2, 1>, S<2, 1, 128, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 1>, S<1, 2, 0, 3>, S<1, 1, 1, 1>, S<0, 1, 2, 3, 4, 5>, 5, 4>, - // DeviceGemmDlops< int8_t, int8_t, int8_t, int32_t, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 128, 128, 8, 1, 4, 4, 1, S<8, 2>, S<8, 2>, S<4, 1, 1, 1>, S<2, 1, 128, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 1>, S<1, 2, 0, 3>, S<1, 1, 1, 1>, S<4, 1, 1, 1>, S<2, 1, 128, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 1>, S<1, 2, 0, 3>, S<1, 1, 1, 1>, S<0, 1, 2, 3, 4, 5>, 5, 4>, - // DeviceGemmDlops< int8_t, int8_t, int8_t, int32_t, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 128, 128, 128, 8, 1, 8, 4, 1, S<4, 2>, S<8, 2>, S<8, 1, 1, 1>, S<1, 1, 128, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 1>, S<1, 2, 0, 3>, S<1, 1, 1, 1>, S<8, 1, 1, 1>, S<1, 1, 128, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 1>, S<1, 2, 0, 3>, S<1, 1, 1, 1>, S<0, 1, 2, 3, 4, 5>, 5, 4>, - // DeviceGemmDlops< int8_t, int8_t, int8_t, int32_t, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 128, 128, 128, 8, 1, 4, 8, 1, S<8, 2>, S<4, 2>, S<8, 1, 1, 1>, S<1, 1, 128, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 1>, S<1, 2, 0, 3>, S<1, 1, 1, 1>, S<8, 1, 1, 1>, S<1, 1, 128, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 1>, S<1, 2, 0, 3>, S<1, 1, 1, 1>, S<0, 1, 2, 3, 
4, 5>, 5, 4>, - // DeviceGemmDlops< int8_t, int8_t, int8_t, int32_t, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 64, 128, 128, 8, 1, 8, 8, 1, S<4, 2>, S<4, 2>, S<4, 1, 4, 1>, S<2, 1, 32, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 1>, S<1, 2, 0, 3>, S<1, 1, 1, 1>, S<4, 1, 4, 1>, S<2, 1, 32, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 1>, S<1, 2, 0, 3>, S<1, 1, 1, 1>, S<0, 1, 2, 3, 4, 5>, 5, 4>, - // DeviceGemmDlops< int8_t, int8_t, int8_t, int32_t, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 64, 128, 64, 8, 1, 8, 4, 1, S<4, 2>, S<4, 2>, S<4, 1, 4, 1>, S<2, 1, 32, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 1>, S<1, 2, 0, 3>, S<1, 1, 1, 1>, S<4, 1, 2, 1>, S<2, 1, 32, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 1>, S<1, 2, 0, 3>, S<1, 1, 1, 1>, S<0, 1, 2, 3, 4, 5>, 5, 4>, - // DeviceGemmDlops< int8_t, int8_t, int8_t, int32_t, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 64, 64, 128, 8, 1, 4, 8, 1, S<4, 2>, S<4, 2>, S<4, 1, 2, 1>, S<2, 1, 32, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 1>, S<1, 2, 0, 3>, S<1, 1, 1, 1>, S<4, 1, 4, 1>, S<2, 1, 32, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 1>, S<1, 2, 0, 3>, S<1, 1, 1, 1>, S<0, 1, 2, 3, 4, 5>, 5, 4>, - // - /* - * K1 = 2 - */ - // DeviceGemmDlops< int8_t, int8_t, int8_t, int32_t, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 256, 128, 8, 2, 8, 4, 1, S<8, 2>, S<8, 2>, S<4, 1, 2, 2>, S<2, 1, 128, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<4, 1, 1, 2>, S<2, 1, 128, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<0, 1, 2, 3, 4, 5>, 5, 4>, - // DeviceGemmDlops< int8_t, int8_t, int8_t, int32_t, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 128, 256, 8, 2, 4, 8, 1, S<8, 2>, S<8, 2>, S<4, 1, 1, 2>, S<2, 1, 128, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<4, 1, 2, 2>, S<2, 1, 128, 1>, S<1, 2, 0, 
3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<0, 1, 2, 3, 4, 5>, 5, 4>, - // DeviceGemmDlops< int8_t, int8_t, int8_t, int32_t, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 128, 128, 8, 2, 4, 4, 1, S<8, 2>, S<8, 2>, S<4, 1, 1, 2>, S<2, 1, 128, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<4, 1, 1, 2>, S<2, 1, 128, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<0, 1, 2, 3, 4, 5>, 5, 4>, - // DeviceGemmDlops< int8_t, int8_t, int8_t, int32_t, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 128, 128, 128, 8, 2, 8, 4, 1, S<4, 2>, S<8, 2>, S<8, 1, 1, 2>, S<1, 1, 128, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<8, 1, 1, 2>, S<1, 1, 128, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<0, 1, 2, 3, 4, 5>, 5, 4>, - // DeviceGemmDlops< int8_t, int8_t, int8_t, int32_t, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 128, 128, 128, 8, 2, 4, 8, 1, S<8, 2>, S<4, 2>, S<8, 1, 1, 2>, S<1, 1, 128, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<8, 1, 1, 2>, S<1, 1, 128, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<0, 1, 2, 3, 4, 5>, 5, 4>, - // DeviceGemmDlops< int8_t, int8_t, int8_t, int32_t, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 64, 128, 128, 8, 2, 8, 8, 1, S<4, 2>, S<4, 2>, S<4, 1, 4, 2>, S<2, 1, 32, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<4, 1, 4, 2>, S<2, 1, 32, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<0, 1, 2, 3, 4, 5>, 5, 4>, - // DeviceGemmDlops< int8_t, int8_t, int8_t, int32_t, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 64, 128, 64, 8, 2, 8, 4, 1, S<4, 2>, S<4, 2>, S<4, 1, 4, 2>, S<2, 1, 32, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, 
S<1, 1, 1, 2>, S<4, 1, 2, 2>, S<2, 1, 32, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<0, 1, 2, 3, 4, 5>, 5, 4>, - // DeviceGemmDlops< int8_t, int8_t, int8_t, int32_t, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 64, 64, 128, 8, 2, 4, 8, 1, S<4, 2>, S<4, 2>, S<4, 1, 2, 2>, S<2, 1, 32, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<4, 1, 4, 2>, S<2, 1, 32, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<0, 1, 2, 3, 4, 5>, 5, 4>, - - /* - * K1 = 4 - */ - // DeviceGemmDlops< int8_t, int8_t, int8_t, int32_t, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 256, 128, 8, 4, 8, 4, 1, S<8, 2>, S<8, 2>, S<4, 1, 2, 4>, S<2, 1, 128, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 4>, S<1, 2, 0, 3>, S<1, 1, 1, 4>, S<4, 1, 1, 4>, S<2, 1, 128, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 4>, S<1, 2, 0, 3>, S<1, 1, 1, 4>, S<0, 1, 2, 3, 4, 5>, 5, 4>, - // DeviceGemmDlops< int8_t, int8_t, int8_t, int32_t, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 128, 256, 8, 4, 4, 8, 1, S<8, 2>, S<8, 2>, S<4, 1, 1, 4>, S<2, 1, 128, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 4>, S<1, 2, 0, 3>, S<1, 1, 1, 4>, S<4, 1, 2, 4>, S<2, 1, 128, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 4>, S<1, 2, 0, 3>, S<1, 1, 1, 4>, S<0, 1, 2, 3, 4, 5>, 5, 4>, - // DeviceGemmDlops< int8_t, int8_t, int8_t, int32_t, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 128, 128, 8, 4, 4, 4, 1, S<8, 2>, S<8, 2>, S<4, 1, 1, 4>, S<2, 1, 128, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 4>, S<1, 2, 0, 3>, S<1, 1, 1, 4>, S<4, 1, 1, 4>, S<2, 1, 128, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 4>, S<1, 2, 0, 3>, S<1, 1, 1, 4>, S<0, 1, 2, 3, 4, 5>, 5, 4>, - // DeviceGemmDlops< int8_t, int8_t, int8_t, int32_t, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 128, 128, 128, 8, 4, 8, 4, 1, S<4, 2>, S<8, 2>, S<8, 1, 1, 
4>, S<1, 1, 128, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 4>, S<1, 2, 0, 3>, S<1, 1, 1, 4>, S<8, 1, 1, 4>, S<1, 1, 128, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 4>, S<1, 2, 0, 3>, S<1, 1, 1, 4>, S<0, 1, 2, 3, 4, 5>, 5, 4>, - // DeviceGemmDlops< int8_t, int8_t, int8_t, int32_t, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 128, 128, 128, 8, 4, 4, 8, 1, S<8, 2>, S<4, 2>, S<8, 1, 1, 4>, S<1, 1, 128, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 4>, S<1, 2, 0, 3>, S<1, 1, 1, 4>, S<8, 1, 1, 4>, S<1, 1, 128, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 4>, S<1, 2, 0, 3>, S<1, 1, 1, 4>, S<0, 1, 2, 3, 4, 5>, 5, 4>, - // DeviceGemmDlops< int8_t, int8_t, int8_t, int32_t, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 64, 128, 128, 8, 4, 8, 8, 1, S<4, 2>, S<4, 2>, S<4, 1, 4, 4>, S<2, 1, 32, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 4>, S<1, 2, 0, 3>, S<1, 1, 1, 4>, S<4, 1, 4, 4>, S<2, 1, 32, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 4>, S<1, 2, 0, 3>, S<1, 1, 1, 4>, S<0, 1, 2, 3, 4, 5>, 5, 4>, - DeviceGemmDlops< int8_t, int8_t, int8_t, int32_t, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 64, 128, 64, 8, 4, 8, 4, 1, S<4, 2>, S<4, 2>, S<4, 1, 4, 4>, S<2, 1, 32, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 4>, S<1, 2, 0, 3>, S<1, 1, 1, 4>, S<4, 1, 2, 4>, S<2, 1, 32, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 4>, S<1, 2, 0, 3>, S<1, 1, 1, 4>, S<0, 1, 2, 3, 4, 5>, 5, 4>, - DeviceGemmDlops< int8_t, int8_t, int8_t, int32_t, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 64, 64, 128, 8, 4, 4, 8, 1, S<4, 2>, S<4, 2>, S<4, 1, 2, 4>, S<2, 1, 32, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 4>, S<1, 2, 0, 3>, S<1, 1, 1, 4>, S<4, 1, 4, 4>, S<2, 1, 32, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 4>, S<1, 2, 0, 3>, S<1, 1, 1, 4>, S<0, 1, 2, 3, 4, 5>, 5, 4>, - DeviceGemmDlops< int8_t, int8_t, int8_t, int32_t, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 64, 64, 64, 8, 4, 
4, 4, 1, S<4, 2>, S<4, 2>, S<4, 1, 2, 4>, S<2, 1, 32, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 4>, S<1, 2, 0, 3>, S<1, 1, 1, 4>, S<4, 1, 2, 4>, S<2, 1, 32, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 4>, S<1, 2, 0, 3>, S<1, 1, 1, 4>, S<0, 1, 2, 3, 4, 5>, 5, 4> - // clang-format on - >; - -void add_device_gemm_dlops_int8_int8_int8_mk_nk_mn_instances( - std::vector>& instances) -{ - add_device_operation_instances(instances, - device_gemm_dlops_int8_int8_int8_mk_nk_mn_instances{}); -} - -} // namespace device_gemm_instance -} // namespace device -} // namespace tensor_operation -} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_c_shuffle_int8_int8_int8_km_kn_mn_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_c_shuffle_i8_i8_i8_km_kn_mn_instance.cpp similarity index 97% rename from library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_c_shuffle_int8_int8_int8_km_kn_mn_instance.cpp rename to library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_c_shuffle_i8_i8_i8_km_kn_mn_instance.cpp index 4530d95c72..2185b55aac 100644 --- a/library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_c_shuffle_int8_int8_int8_km_kn_mn_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_c_shuffle_i8_i8_i8_km_kn_mn_instance.cpp @@ -22,7 +22,7 @@ using PassThrough = ck::tensor_operation::element_wise::PassThrough; static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default; // Compilation parameters for a[k, m] * b[k, n] = c[m, n] -using device_gemm_xdl_c_shuffle_int8_int8_int8_km_kn_mn_instances = +using device_gemm_xdl_c_shuffle_i8_i8_i8_km_kn_mn_instances = std::tuple< // clang-format off //#####################| ALayout| BLayout| CLayout| AData| BData| CData| AccData| CShuffle| A| B| C| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| 
ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| @@ -48,11 +48,11 @@ using device_gemm_xdl_c_shuffle_int8_int8_int8_km_kn_mn_instances = // clang-format on >; -void add_device_gemm_xdl_c_shuffle_int8_int8_int8_km_kn_mn_instances( +void add_device_gemm_xdl_c_shuffle_i8_i8_i8_km_kn_mn_instances( std::vector>& instances) { add_device_operation_instances(instances, - device_gemm_xdl_c_shuffle_int8_int8_int8_km_kn_mn_instances{}); + device_gemm_xdl_c_shuffle_i8_i8_i8_km_kn_mn_instances{}); } } // namespace device_gemm_instance diff --git a/library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_c_shuffle_int8_int8_int8_km_nk_mn_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_c_shuffle_i8_i8_i8_km_nk_mn_instance.cpp similarity index 97% rename from library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_c_shuffle_int8_int8_int8_km_nk_mn_instance.cpp rename to library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_c_shuffle_i8_i8_i8_km_nk_mn_instance.cpp index 4214c71efb..90966349b2 100644 --- a/library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_c_shuffle_int8_int8_int8_km_nk_mn_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_c_shuffle_i8_i8_i8_km_nk_mn_instance.cpp @@ -22,7 +22,7 @@ using PassThrough = ck::tensor_operation::element_wise::PassThrough; static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default; // Compilation parameters for a[k, m] * b[n, k] = c[m, n] -using device_gemm_xdl_c_shuffle_int8_int8_int8_km_nk_mn_instances = +using device_gemm_xdl_c_shuffle_i8_i8_i8_km_nk_mn_instances = std::tuple< // clang-format off //#####################| ALayout| BLayout| CLayout| AData| BData| CData| AccData| CShuffle| A| B| C| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| 
MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| @@ -48,11 +48,11 @@ using device_gemm_xdl_c_shuffle_int8_int8_int8_km_nk_mn_instances = // clang-format on >; -void add_device_gemm_xdl_c_shuffle_int8_int8_int8_km_nk_mn_instances( +void add_device_gemm_xdl_c_shuffle_i8_i8_i8_km_nk_mn_instances( std::vector>& instances) { add_device_operation_instances(instances, - device_gemm_xdl_c_shuffle_int8_int8_int8_km_nk_mn_instances{}); + device_gemm_xdl_c_shuffle_i8_i8_i8_km_nk_mn_instances{}); } } // namespace device_gemm_instance diff --git a/library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_c_shuffle_int8_int8_int8_mk_kn_mn_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_c_shuffle_i8_i8_i8_mk_kn_mn_instance.cpp similarity index 97% rename from library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_c_shuffle_int8_int8_int8_mk_kn_mn_instance.cpp rename to library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_c_shuffle_i8_i8_i8_mk_kn_mn_instance.cpp index 39bb7e1473..aa5a13001c 100644 --- a/library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_c_shuffle_int8_int8_int8_mk_kn_mn_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_c_shuffle_i8_i8_i8_mk_kn_mn_instance.cpp @@ -22,7 +22,7 @@ using PassThrough = ck::tensor_operation::element_wise::PassThrough; static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default; // Compilation parameters for a[m, k] * b[k, n] = c[m, n] -using device_gemm_xdl_c_shuffle_int8_int8_int8_mk_kn_mn_instances = +using device_gemm_xdl_c_shuffle_i8_i8_i8_mk_kn_mn_instances = std::tuple< // clang-format off //#####################| ALayout| BLayout| CLayout| AData| BData| 
CData| AccData| CShuffle| A| B| C| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| @@ -48,11 +48,11 @@ using device_gemm_xdl_c_shuffle_int8_int8_int8_mk_kn_mn_instances = // clang-format on >; -void add_device_gemm_xdl_c_shuffle_int8_int8_int8_mk_kn_mn_instances( +void add_device_gemm_xdl_c_shuffle_i8_i8_i8_mk_kn_mn_instances( std::vector>& instances) { add_device_operation_instances(instances, - device_gemm_xdl_c_shuffle_int8_int8_int8_mk_kn_mn_instances{}); + device_gemm_xdl_c_shuffle_i8_i8_i8_mk_kn_mn_instances{}); } } // namespace device_gemm_instance diff --git a/library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_c_shuffle_int8_int8_int8_mk_nk_mn_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_c_shuffle_i8_i8_i8_mk_nk_mn_instance.cpp similarity index 97% rename from library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_c_shuffle_int8_int8_int8_mk_nk_mn_instance.cpp rename to library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_c_shuffle_i8_i8_i8_mk_nk_mn_instance.cpp index 2ddde9e630..82eec1164a 100644 --- a/library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_c_shuffle_int8_int8_int8_mk_nk_mn_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_c_shuffle_i8_i8_i8_mk_nk_mn_instance.cpp @@ -22,7 +22,7 @@ using PassThrough = ck::tensor_operation::element_wise::PassThrough; static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default; // Compilation parameters for a[m, k] * b[n, k] = c[m, n] -using device_gemm_xdl_c_shuffle_int8_int8_int8_mk_nk_mn_instances = +using device_gemm_xdl_c_shuffle_i8_i8_i8_mk_nk_mn_instances = std::tuple< 
// clang-format off //#####################| ALayout| BLayout| CLayout| AData| BData| CData| AccData| CShuffle| A| B| C| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| @@ -45,11 +45,11 @@ using device_gemm_xdl_c_shuffle_int8_int8_int8_mk_nk_mn_instances = // clang-format on >; -void add_device_gemm_xdl_c_shuffle_int8_int8_int8_mk_nk_mn_instances( +void add_device_gemm_xdl_c_shuffle_i8_i8_i8_mk_nk_mn_instances( std::vector>& instances) { add_device_operation_instances(instances, - device_gemm_xdl_c_shuffle_int8_int8_int8_mk_nk_mn_instances{}); + device_gemm_xdl_c_shuffle_i8_i8_i8_mk_nk_mn_instances{}); } } // namespace device_gemm_instance diff --git a/profiler/CMakeLists.txt b/profiler/CMakeLists.txt index 18f1befaf4..8ecccc7e98 100644 --- a/profiler/CMakeLists.txt +++ b/profiler/CMakeLists.txt @@ -46,7 +46,8 @@ add_executable(ckProfiler ${PROFILER_SOURCE}) target_link_libraries(ckProfiler PRIVATE host_tensor) target_link_libraries(ckProfiler PRIVATE conv_fwd_util) target_link_libraries(ckProfiler PRIVATE device_gemm_reduce_instance) -target_link_libraries(ckProfiler PRIVATE device_gemm_instance) +target_link_libraries(ckProfiler PRIVATE device_gemm_dl_instance) +target_link_libraries(ckProfiler PRIVATE device_gemm_xdl_instance) target_link_libraries(ckProfiler PRIVATE device_gemm_dlops_instance) target_link_libraries(ckProfiler PRIVATE device_gemm_bias2d_instance) target_link_libraries(ckProfiler PRIVATE device_gemm_bias_relu_instance) diff --git a/test/CMakeLists.txt b/test/CMakeLists.txt index e49c1a1318..1c95efa0d0 100644 --- a/test/CMakeLists.txt +++ b/test/CMakeLists.txt @@ -53,7 +53,7 @@ add_subdirectory(space_filling_curve) add_subdirectory(conv_util) 
add_subdirectory(reference_conv_fwd) add_subdirectory(gemm) -add_subdirectory(gemm_dlops) +add_subdirectory(gemm_dl) add_subdirectory(gemm_split_k) add_subdirectory(gemm_reduce) add_subdirectory(batched_gemm) diff --git a/test/gemm/CMakeLists.txt b/test/gemm/CMakeLists.txt index 83b3c1e2e3..ea581ee781 100644 --- a/test/gemm/CMakeLists.txt +++ b/test/gemm/CMakeLists.txt @@ -1,15 +1,15 @@ add_test_executable(test_gemm_fp32 gemm_fp32.cpp) target_link_libraries(test_gemm_fp32 PRIVATE host_tensor) -target_link_libraries(test_gemm_fp32 PRIVATE device_gemm_instance) +target_link_libraries(test_gemm_fp32 PRIVATE device_gemm_xdl_instance) add_test_executable(test_gemm_fp16 gemm_fp16.cpp) target_link_libraries(test_gemm_fp16 PRIVATE host_tensor) -target_link_libraries(test_gemm_fp16 PRIVATE device_gemm_instance) +target_link_libraries(test_gemm_fp16 PRIVATE device_gemm_xdl_instance) add_test_executable(test_gemm_bf16 gemm_bf16.cpp) target_link_libraries(test_gemm_bf16 PRIVATE host_tensor) -target_link_libraries(test_gemm_bf16 PRIVATE device_gemm_instance) +target_link_libraries(test_gemm_bf16 PRIVATE device_gemm_xdl_instance) add_test_executable(test_gemm_int8 gemm_int8.cpp) target_link_libraries(test_gemm_int8 PRIVATE host_tensor) -target_link_libraries(test_gemm_int8 PRIVATE device_gemm_instance) +target_link_libraries(test_gemm_int8 PRIVATE device_gemm_xdl_instance) diff --git a/test/gemm_dl/CMakeLists.txt b/test/gemm_dl/CMakeLists.txt new file mode 100644 index 0000000000..6486474771 --- /dev/null +++ b/test/gemm_dl/CMakeLists.txt @@ -0,0 +1,11 @@ +add_test_executable(test_gemm_dl_fp32 gemm_dl_fp32.cpp) +target_link_libraries(test_gemm_dl_fp32 PRIVATE host_tensor) +target_link_libraries(test_gemm_dl_fp32 PRIVATE device_gemm_dl_instance) + +add_test_executable(test_gemm_dl_fp16 gemm_dl_fp16.cpp) +target_link_libraries(test_gemm_dl_fp16 PRIVATE host_tensor) +target_link_libraries(test_gemm_dl_fp16 PRIVATE device_gemm_dl_instance) + +add_test_executable(test_gemm_dl_int8 
gemm_dl_int8.cpp) +target_link_libraries(test_gemm_dl_int8 PRIVATE host_tensor) +target_link_libraries(test_gemm_dl_int8 PRIVATE device_gemm_dl_instance) diff --git a/test/gemm_dlops/gemm_dlops_fp16.cpp b/test/gemm_dl/gemm_dl_fp16.cpp similarity index 85% rename from test/gemm_dlops/gemm_dlops_fp16.cpp rename to test/gemm_dl/gemm_dl_fp16.cpp index e6e7a4b1b7..1e78cf0d64 100644 --- a/test/gemm_dlops/gemm_dlops_fp16.cpp +++ b/test/gemm_dl/gemm_dl_fp16.cpp @@ -14,7 +14,7 @@ #include "host_tensor_generator.hpp" #include "host_gemm.hpp" #include "device_tensor.hpp" -#include "device_gemm_dlops.hpp" +#include "device_gemm_dl.hpp" #include "element_wise_operation.hpp" #include "reference_gemm.hpp" #include "gemm_specialization.hpp" @@ -31,10 +31,10 @@ namespace tensor_operation { namespace device { namespace device_gemm_instance { -void add_device_gemm_dlops_f16_f16_f16_km_kn_mn_instances(std::vector&); -void add_device_gemm_dlops_f16_f16_f16_km_nk_mn_instances(std::vector&); -void add_device_gemm_dlops_f16_f16_f16_mk_nk_mn_instances(std::vector&); -void add_device_gemm_dlops_f16_f16_f16_mk_kn_mn_instances(std::vector&); +//void add_device_gemm_dl_f16_f16_f16_km_kn_mn_instances(std::vector&); +//void add_device_gemm_dl_f16_f16_f16_km_nk_mn_instances(std::vector&); +void add_device_gemm_dl_f16_f16_f16_mk_nk_mn_instances(std::vector&); +//void add_device_gemm_dl_f16_f16_f16_mk_kn_mn_instances(std::vector&); } // namespace device_gemm_instance } // namespace device @@ -53,7 +53,7 @@ int main() bool res = true; std::vector gemmPtrs; ck::tensor_operation::device::device_gemm_instance:: - add_device_gemm_dlops_f16_f16_f16_km_kn_mn_instances(gemmPtrs); + add_device_gemm_dl_f16_f16_f16_km_kn_mn_instances(gemmPtrs); for(auto& gemmPtr : gemmPtrs) { @@ -71,7 +71,7 @@ int main() gemmPtrs.clear(); ck::tensor_operation::device::device_gemm_instance:: - add_device_gemm_dlops_f16_f16_f16_km_nk_mn_instances(gemmPtrs); + add_device_gemm_dl_f16_f16_f16_km_nk_mn_instances(gemmPtrs);
for(auto& gemmPtr : gemmPtrs) { @@ -89,7 +89,7 @@ int main() gemmPtrs.clear(); ck::tensor_operation::device::device_gemm_instance:: - add_device_gemm_dlops_f16_f16_f16_mk_kn_mn_instances(gemmPtrs); + add_device_gemm_dl_f16_f16_f16_mk_kn_mn_instances(gemmPtrs); for(auto& gemmPtr : gemmPtrs) { @@ -107,7 +107,7 @@ int main() gemmPtrs.clear(); ck::tensor_operation::device::device_gemm_instance:: - add_device_gemm_dlops_f16_f16_f16_mk_nk_mn_instances(gemmPtrs); + add_device_gemm_dl_f16_f16_f16_mk_nk_mn_instances(gemmPtrs); for(auto& gemmPtr : gemmPtrs) { diff --git a/test/gemm_dlops/gemm_dlops_fp32.cpp b/test/gemm_dl/gemm_dl_fp32.cpp similarity index 100% rename from test/gemm_dlops/gemm_dlops_fp32.cpp rename to test/gemm_dl/gemm_dl_fp32.cpp diff --git a/test/gemm_dlops/gemm_dlops_int8.cpp b/test/gemm_dl/gemm_dl_int8.cpp similarity index 100% rename from test/gemm_dlops/gemm_dlops_int8.cpp rename to test/gemm_dl/gemm_dl_int8.cpp diff --git a/test/gemm_dlops/CMakeLists.txt b/test/gemm_dlops/CMakeLists.txt deleted file mode 100644 index 4d1e8d53bf..0000000000 --- a/test/gemm_dlops/CMakeLists.txt +++ /dev/null @@ -1,15 +0,0 @@ -add_test_executable(test_gemm_dlops_fp32 gemm_dlops_fp32.cpp) -target_link_libraries(test_gemm_dlops_fp32 PRIVATE host_tensor) -target_link_libraries(test_gemm_dlops_fp32 PRIVATE device_gemm_dlops_instance) - -add_test_executable(test_gemm_dlops_fp16 gemm_dlops_fp16.cpp) -target_link_libraries(test_gemm_dlops_fp16 PRIVATE host_tensor) -target_link_libraries(test_gemm_dlops_fp16 PRIVATE device_gemm_dlops_instance) - -# add_test_executable(test_gemm_dlops_bf16 gemm_dlops_bf16.cpp) -# target_link_libraries(test_gemm_dlops_bf16 PRIVATE host_tensor) -# target_link_libraries(test_gemm_dlops_bf16 PRIVATE device_gemm_dlops_instance) - -add_test_executable(test_gemm_dlops_int8 gemm_dlops_int8.cpp) -target_link_libraries(test_gemm_dlops_int8 PRIVATE host_tensor) -target_link_libraries(test_gemm_dlops_int8 PRIVATE device_gemm_dlops_instance) From 
e95e1bf13af599d3ec4353fdc0046630bd9c4234 Mon Sep 17 00:00:00 2001 From: Chao Liu Date: Thu, 12 May 2022 04:48:22 +0000 Subject: [PATCH 36/46] adding example --- example/01_gemm/CMakeLists.txt | 1 + include/ck/config.hpp | 2 +- .../ck/tensor_operation/gpu/device/device_gemm_dl.hpp | 9 +++++---- .../gpu/grid/gridwise_gemm_dl_v1r3.hpp | 9 ++------- include/ck/utility/static_buffer.hpp | 9 +++++++-- test/gemm_dl/gemm_dl_fp16.cpp | 10 +++++++--- 6 files changed, 23 insertions(+), 17 deletions(-) diff --git a/example/01_gemm/CMakeLists.txt b/example/01_gemm/CMakeLists.txt index 696d3bac42..f5c0c233b5 100644 --- a/example/01_gemm/CMakeLists.txt +++ b/example/01_gemm/CMakeLists.txt @@ -1,3 +1,4 @@ +add_example_executable(example_gemm_dl_fp16 gemm_dl_fp16.cpp) add_example_executable(example_gemm_xdl_fp16 gemm_xdl_fp16.cpp) add_example_executable(example_gemm_xdl_bf16 gemm_xdl_bf16.cpp) add_example_executable(example_gemm_xdl_int8 gemm_xdl_int8.cpp) diff --git a/include/ck/config.hpp b/include/ck/config.hpp index eedeb7e136..919af1e6dd 100644 --- a/include/ck/config.hpp +++ b/include/ck/config.hpp @@ -15,7 +15,7 @@ #ifdef CK_USE_LAUNCH_BOUNDS #define CK_MAX_THREAD_PER_BLOCK 256 -#define CK_MIN_BLOCK_PER_CU 2 +#define CK_MIN_BLOCK_PER_CU 1 #endif // check GPU target diff --git a/include/ck/tensor_operation/gpu/device/device_gemm_dl.hpp b/include/ck/tensor_operation/gpu/device/device_gemm_dl.hpp index 984f5b7560..473b006a81 100644 --- a/include/ck/tensor_operation/gpu/device/device_gemm_dl.hpp +++ b/include/ck/tensor_operation/gpu/device/device_gemm_dl.hpp @@ -472,15 +472,16 @@ struct DeviceGemmDl static bool IsSupportedArgument(const Argument& arg) { +#if 0 if(ck::get_device_name() == "gfx1030") +#else + if(true) +#endif { return GridwiseGemm::CheckValidity( arg.a_grid_desc_k0_m_k1_, arg.b_grid_desc_k0_n_k1_, arg.c_grid_desc_m_n_); } - else - { - return false; - } + else { return false; } } // polymorphic diff --git 
a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_dl_v1r3.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_dl_v1r3.hpp index 07a556d996..fb8ac230ff 100644 --- a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_dl_v1r3.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_dl_v1r3.hpp @@ -404,13 +404,8 @@ struct GridwiseGemmDl_km_kn_mn_v1r3 auto c_thread_buf = make_static_buffer( c_thread_desc_m10_m11_n10_n11.GetElementSpaceSize()); - ThreadwiseTensorSliceSet_v1{} - .Run(c_thread_desc_m10_m11_n10_n11, - make_tuple(I0, I0, I0, I0), - c_thread_buf, - FloatAcc{0}); + // Initialize C + c_thread_buf.Clear(); constexpr auto a_block_slice_copy_step = make_multi_index(K0PerBlock, 0, 0, 0); constexpr auto b_block_slice_copy_step = make_multi_index(K0PerBlock, 0, 0, 0); diff --git a/include/ck/utility/static_buffer.hpp b/include/ck/utility/static_buffer.hpp index f36328fa5f..4e0e965efc 100644 --- a/include/ck/utility/static_buffer.hpp +++ b/include/ck/utility/static_buffer.hpp @@ -36,6 +36,11 @@ struct StaticBuffer : public StaticallyIndexedArray { return base::operator()(i); } + + __host__ __device__ void Clear() + { + static_for<0, N, 1>{}([&](auto i) { operator()(i) = T{0}; }); + } }; // static buffer for vector @@ -146,9 +151,9 @@ struct StaticBufferTupleOfVector __host__ __device__ void Clear() { - const index_t numScalars = NumOfVector * ScalarPerVector; + constexpr index_t NumScalars = NumOfVector * ScalarPerVector; - static_for<0, Number{}, 1>{}([&](auto i) { SetAsType(i, S{0}); }); + static_for<0, NumScalars, 1>{}([&](auto i) { SetAsType(i, S{0}); }); } }; diff --git a/test/gemm_dl/gemm_dl_fp16.cpp b/test/gemm_dl/gemm_dl_fp16.cpp index 1e78cf0d64..d5369a5846 100644 --- a/test/gemm_dl/gemm_dl_fp16.cpp +++ b/test/gemm_dl/gemm_dl_fp16.cpp @@ -31,10 +31,10 @@ namespace tensor_operation { namespace device { namespace device_gemm_instance { -//void add_device_gemm_dl_f16_f16_f16_km_kn_mn_instances(std::vector&); -//void 
add_device_gemm_dl_f16_f16_f16_km_nk_mn_instances(std::vector&); +// void add_device_gemm_dl_f16_f16_f16_km_kn_mn_instances(std::vector&); +// void add_device_gemm_dl_f16_f16_f16_km_nk_mn_instances(std::vector&); void add_device_gemm_dl_f16_f16_f16_mk_nk_mn_instances(std::vector&); -//void add_device_gemm_dl_f16_f16_f16_mk_kn_mn_instances(std::vector&); +// void add_device_gemm_dl_f16_f16_f16_mk_kn_mn_instances(std::vector&); } // namespace device_gemm_instance } // namespace device @@ -51,7 +51,10 @@ int main() using ColumnMajor = ck::tensor_layout::gemm::ColumnMajor; bool res = true; + std::vector gemmPtrs; + +#if 0 ck::tensor_operation::device::device_gemm_instance:: add_device_gemm_dl_f16_f16_f16_km_kn_mn_instances(gemmPtrs); @@ -104,6 +107,7 @@ int main() PassThrough, PassThrough>{}(gemmPtr); } +#endif gemmPtrs.clear(); ck::tensor_operation::device::device_gemm_instance:: From 3a122cb5fa2812d19e1663c91f05f27f81784f52 Mon Sep 17 00:00:00 2001 From: Chao Liu Date: Thu, 12 May 2022 04:48:44 +0000 Subject: [PATCH 37/46] adding example --- example/01_gemm/gemm_dl_fp16.cpp | 216 +++++++++++++++++++++++++++++++ 1 file changed, 216 insertions(+) create mode 100644 example/01_gemm/gemm_dl_fp16.cpp diff --git a/example/01_gemm/gemm_dl_fp16.cpp b/example/01_gemm/gemm_dl_fp16.cpp new file mode 100644 index 0000000000..a32fc8c9a4 --- /dev/null +++ b/example/01_gemm/gemm_dl_fp16.cpp @@ -0,0 +1,216 @@ +#include +#include +#include +#include +#include +#include + +#include "check_err.hpp" +#include "config.hpp" +#include "device.hpp" +#include "host_tensor.hpp" +#include "host_tensor_generator.hpp" +#include "device_tensor.hpp" +#include "device_gemm_dl.hpp" +#include "element_wise_operation.hpp" +#include "reference_gemm.hpp" +#include "gemm_specialization.hpp" + +template +using S = ck::Sequence; + +using F16 = ck::half_t; +using F32 = float; + +using Row = ck::tensor_layout::gemm::RowMajor; +using Col = ck::tensor_layout::gemm::ColumnMajor; + +using PassThrough = 
ck::tensor_operation::element_wise::PassThrough; + +using ADataType = ck::half_t; +using BDataType = ck::half_t; +using CDataType = ck::half_t; +using AccDataType = float; + +using ALayout = ck::tensor_layout::gemm::RowMajor; +using BLayout = ck::tensor_layout::gemm::ColumnMajor; +using CLayout = ck::tensor_layout::gemm::RowMajor; + +using AElementOp = ck::tensor_operation::element_wise::PassThrough; +using BElementOp = ck::tensor_operation::element_wise::PassThrough; +using CElementOp = ck::tensor_operation::element_wise::PassThrough; + +static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default; + +// clang-format off +using DeviceGemmInstance = ck::tensor_operation::device:: + // ########| AData| BData| CData| AccData| ALayout| BLayout| CLayout| A| B| C| GEMM| Block| MPer| NPer| K0Per| K1| M1Per| N1Per| KPer| M11N11Thread| M11N11Thread| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| CThreadTransfer| CThreadTransfer| CThreadTransfer| + // ########| Type| Type| Type| Type| | | | Elementwise| Elementwise| Elementwise| Spacialization| Size| Block| Block| Block| | ThreadM111| ThreadN111| Thread| ClusterM110Xs| ClusterN110Xs| ThreadSliceLengths| ThreadClusterLengths| ThreadCluster| SrcAccess| SrcVectorTensor| SrcVectorTensor| DstVectorTensor| ThreadSliceLengths| ThreadClusterLengths| ThreadCluster| SrcAccess| SrcVectorTensor| SrcVectorTensor| DstVectorTensor| SrcDstAccess| SrcDstVectorDim| DstScalarPerVector| + // ########| | | | | | | | Operation| Operation| Operation| | | | | | | | | | | | K0_M0_M1_K1| K0_M0_M1_K1| ArrangeOrder| Order| Lengths_K0_M0_M1_K1| ContiguousDimOrder| Lengths_K0_M0_M1_K1| K0_N0_N1_K1| K0_N0_N1_K1| ArrangeOrder| Order| Lengths_K0_N0_N1_K1| ContiguousDimOrder| Lengths_K0_N0_N1_K1| Order| | | + // ########| | | | | | | | | | | | | | | | 
| | | | | | | | | | | | | | | | | | | | | | | +// DeviceGemmDl< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 256, 128, 8, 2, 8, 4, 1, S<8, 2>, S<8, 2>, S<4, 1, 2, 2>, S<2, 1, 128, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<4, 1, 1, 2>, S<2, 1, 128, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<0, 1, 2, 3, 4, 5>, 5, 4>; +// DeviceGemmDl< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 128, 256, 8, 2, 4, 8, 1, S<8, 2>, S<8, 2>, S<4, 1, 1, 2>, S<2, 1, 128, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<4, 1, 2, 2>, S<2, 1, 128, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<0, 1, 2, 3, 4, 5>, 5, 4>; +// DeviceGemmDl< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 128, 128, 8, 2, 4, 4, 1, S<8, 2>, S<8, 2>, S<4, 1, 1, 2>, S<2, 1, 128, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<4, 1, 1, 2>, S<2, 1, 128, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<0, 1, 2, 3, 4, 5>, 5, 4>; + DeviceGemmDl< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 128, 128, 128, 8, 2, 8, 4, 1, S<4, 2>, S<8, 2>, S<8, 1, 1, 2>, S<1, 1, 128, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<8, 1, 1, 2>, S<1, 1, 128, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<0, 1, 2, 3, 4, 5>, 5, 4>; +// DeviceGemmDl< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 128, 128, 128, 8, 2, 4, 8, 1, S<8, 2>, S<4, 2>, S<8, 1, 1, 2>, S<1, 1, 128, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<8, 1, 1, 2>, S<1, 1, 128, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<0, 1, 
2, 3, 4, 5>, 5, 4>; +// DeviceGemmDl< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 64, 128, 128, 8, 2, 8, 8, 1, S<4, 2>, S<4, 2>, S<4, 1, 4, 2>, S<2, 1, 32, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<4, 1, 4, 2>, S<2, 1, 32, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<0, 1, 2, 3, 4, 5>, 5, 4>; +// DeviceGemmDl< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 64, 128, 64, 8, 2, 8, 4, 1, S<4, 2>, S<4, 2>, S<4, 1, 4, 2>, S<2, 1, 32, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<4, 1, 2, 2>, S<2, 1, 32, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<0, 1, 2, 3, 4, 5>, 5, 4>; +// DeviceGemmDl< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 64, 64, 128, 8, 2, 4, 8, 1, S<4, 2>, S<4, 2>, S<4, 1, 2, 2>, S<2, 1, 32, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<4, 1, 4, 2>, S<2, 1, 32, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<0, 1, 2, 3, 4, 5>, 5, 4>; +// clang-format on + +using ReferenceGemmInstance = ck::tensor_operation::host:: + ReferenceGemm; + +int main(int argc, char* argv[]) +{ + bool do_verification = 0; + int init_method = 0; + int nrepeat = 5; + + // GEMM shape + ck::index_t M = 3840; + ck::index_t N = 4096; + ck::index_t K = 4096; + + ck::index_t StrideA = 4096; + ck::index_t StrideB = 4096; + ck::index_t StrideC = 4096; + + if(argc == 4) + { + do_verification = std::stoi(argv[1]); + init_method = std::stoi(argv[2]); + nrepeat = std::stoi(argv[3]); + } + else if(argc == 10) + { + do_verification = std::stoi(argv[1]); + init_method = std::stoi(argv[2]); + nrepeat = std::stoi(argv[3]); + + M = std::stoi(argv[4]); + N = std::stoi(argv[5]); + K = std::stoi(argv[6]); + + StrideA = std::stoi(argv[7]); + StrideB = std::stoi(argv[8]); + 
StrideC = std::stoi(argv[9]); + } + else + { + printf("arg1: verification (0=no, 1=yes)\n"); + printf("arg2: initialization (0=no init, 1=integer value, 2=decimal value)\n"); + printf("arg3: run kernel # of times (>1)\n"); + printf("arg4 to 9: M (256x), N(128x), K(32x), StrideA, StrideB, StrideC\n"); + exit(0); + } + + auto f_host_tensor_descriptor = + [](std::size_t row, std::size_t col, std::size_t stride, auto layout) { + if(std::is_same::value) + { + return HostTensorDescriptor(std::vector({row, col}), + std::vector({stride, 1})); + } + else + { + return HostTensorDescriptor(std::vector({row, col}), + std::vector({1, stride})); + } + }; + + Tensor a_m_k(f_host_tensor_descriptor(M, K, StrideA, ALayout{})); + Tensor b_k_n(f_host_tensor_descriptor(K, N, StrideB, BLayout{})); + Tensor c_m_n_host_result(f_host_tensor_descriptor(M, N, StrideC, CLayout{})); + Tensor c_m_n_device_result(f_host_tensor_descriptor(M, N, StrideC, CLayout{})); + + std::cout << "a_m_k: " << a_m_k.mDesc << std::endl; + std::cout << "b_k_n: " << b_k_n.mDesc << std::endl; + std::cout << "c_m_n: " << c_m_n_host_result.mDesc << std::endl; + + switch(init_method) + { + case 0: break; + case 1: + a_m_k.GenerateTensorValue(GeneratorTensor_2{-5, 5}); + b_k_n.GenerateTensorValue(GeneratorTensor_2{-5, 5}); + break; + case 2: + a_m_k.GenerateTensorValue(GeneratorTensor_3{0.0, 1.0}); + b_k_n.GenerateTensorValue(GeneratorTensor_3{-0.5, 0.5}); + break; + default: + a_m_k.GenerateTensorValue(GeneratorTensor_Sequential<0>{}); + b_k_n.GenerateTensorValue(GeneratorTensor_Sequential<1>{}); + } + + DeviceMem a_m_k_device_buf(sizeof(ADataType) * a_m_k.mDesc.GetElementSpace()); + DeviceMem b_k_n_device_buf(sizeof(BDataType) * b_k_n.mDesc.GetElementSpace()); + DeviceMem c_m_n_device_buf(sizeof(CDataType) * c_m_n_device_result.mDesc.GetElementSpace()); + + a_m_k_device_buf.ToDevice(a_m_k.mData.data()); + b_k_n_device_buf.ToDevice(b_k_n.mData.data()); + + auto a_element_op = AElementOp{}; + auto b_element_op = 
BElementOp{}; + auto c_element_op = CElementOp{}; + + // do GEMM + auto gemm = DeviceGemmInstance{}; + auto invoker = gemm.MakeInvoker(); + auto argument = gemm.MakeArgument(static_cast(a_m_k_device_buf.GetDeviceBuffer()), + static_cast(b_k_n_device_buf.GetDeviceBuffer()), + static_cast(c_m_n_device_buf.GetDeviceBuffer()), + M, + N, + K, + StrideA, + StrideB, + StrideC, + a_element_op, + b_element_op, + c_element_op); + + if(!gemm.IsSupportedArgument(argument)) + { + throw std::runtime_error( + "wrong! device_gemm with the specified compilation parameters does " + "not support this GEMM problem"); + } + + float ave_time = invoker.Run(argument, nrepeat); + + std::size_t flop = std::size_t(2) * M * N * K; + std::size_t num_btype = + sizeof(ADataType) * M * K + sizeof(BDataType) * K * N + sizeof(CDataType) * M * N; + + float tflops = static_cast(flop) / 1.E9 / ave_time; + + float gb_per_sec = num_btype / 1.E6 / ave_time; + + std::cout << "Perf: " << ave_time << " ms, " << tflops << " TFlops, " << gb_per_sec << " GB/s, " + << gemm.GetTypeString() << std::endl; + + c_m_n_device_buf.FromDevice(c_m_n_device_result.mData.data()); + + if(do_verification) + { + auto ref_gemm = ReferenceGemmInstance{}; + auto ref_invoker = ref_gemm.MakeInvoker(); + + auto ref_argument = ref_gemm.MakeArgument( + a_m_k, b_k_n, c_m_n_host_result, a_element_op, b_element_op, c_element_op); + + ref_invoker.Run(ref_argument); + + ck::utils::check_err(c_m_n_device_result.mData, c_m_n_host_result.mData); + +#if 0 + LogRangeAsType(std::cout << "c_host : ", c_m_n_host_result.mData, ",") << std::endl; + LogRangeAsType(std::cout << "c_device : ", c_m_n_device_result.mData, ",") + << std::endl; +#endif + } + + return 0; +} From 0eb6b99b5be566c5a9182886252aa010772ac0d4 Mon Sep 17 00:00:00 2001 From: Chao Liu Date: Thu, 12 May 2022 05:36:36 +0000 Subject: [PATCH 38/46] adding example --- example/01_gemm/CMakeLists.txt | 1 + example/01_gemm/gemm_dl_fp16.cpp | 4 +- example/01_gemm/gemm_dl_fp32.cpp | 215 
++++++++++++++++++ include/ck/config.hpp | 2 +- .../gpu/grid/gridwise_gemm_dl_v1r3.hpp | 26 ++- 5 files changed, 242 insertions(+), 6 deletions(-) create mode 100644 example/01_gemm/gemm_dl_fp32.cpp diff --git a/example/01_gemm/CMakeLists.txt b/example/01_gemm/CMakeLists.txt index f5c0c233b5..650dcb1866 100644 --- a/example/01_gemm/CMakeLists.txt +++ b/example/01_gemm/CMakeLists.txt @@ -1,3 +1,4 @@ +add_example_executable(example_gemm_dl_fp32 gemm_dl_fp32.cpp) add_example_executable(example_gemm_dl_fp16 gemm_dl_fp16.cpp) add_example_executable(example_gemm_xdl_fp16 gemm_xdl_fp16.cpp) add_example_executable(example_gemm_xdl_bf16 gemm_xdl_bf16.cpp) diff --git a/example/01_gemm/gemm_dl_fp16.cpp b/example/01_gemm/gemm_dl_fp16.cpp index a32fc8c9a4..550d4260ab 100644 --- a/example/01_gemm/gemm_dl_fp16.cpp +++ b/example/01_gemm/gemm_dl_fp16.cpp @@ -50,8 +50,8 @@ using DeviceGemmInstance = ck::tensor_operation::device:: // ########| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | // DeviceGemmDl< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 256, 128, 8, 2, 8, 4, 1, S<8, 2>, S<8, 2>, S<4, 1, 2, 2>, S<2, 1, 128, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<4, 1, 1, 2>, S<2, 1, 128, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<0, 1, 2, 3, 4, 5>, 5, 4>; // DeviceGemmDl< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 128, 256, 8, 2, 4, 8, 1, S<8, 2>, S<8, 2>, S<4, 1, 1, 2>, S<2, 1, 128, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<4, 1, 2, 2>, S<2, 1, 128, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<0, 1, 2, 3, 4, 5>, 5, 4>; -// DeviceGemmDl< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 128, 128, 8, 2, 4, 4, 1, S<8, 2>, S<8, 2>, S<4, 1, 1, 2>, S<2, 1, 128, 1>, S<1, 2, 0, 
3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<4, 1, 1, 2>, S<2, 1, 128, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<0, 1, 2, 3, 4, 5>, 5, 4>; - DeviceGemmDl< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 128, 128, 128, 8, 2, 8, 4, 1, S<4, 2>, S<8, 2>, S<8, 1, 1, 2>, S<1, 1, 128, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<8, 1, 1, 2>, S<1, 1, 128, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<0, 1, 2, 3, 4, 5>, 5, 4>; + DeviceGemmDl< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 128, 128, 8, 2, 4, 4, 1, S<8, 2>, S<8, 2>, S<4, 1, 1, 2>, S<2, 1, 128, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<4, 1, 1, 2>, S<2, 1, 128, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<0, 1, 2, 3, 4, 5>, 5, 4>; +// DeviceGemmDl< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 128, 128, 128, 8, 2, 8, 4, 1, S<4, 2>, S<8, 2>, S<8, 1, 1, 2>, S<1, 1, 128, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<8, 1, 1, 2>, S<1, 1, 128, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<0, 1, 2, 3, 4, 5>, 5, 4>; // DeviceGemmDl< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 128, 128, 128, 8, 2, 4, 8, 1, S<8, 2>, S<4, 2>, S<8, 1, 1, 2>, S<1, 1, 128, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<8, 1, 1, 2>, S<1, 1, 128, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<0, 1, 2, 3, 4, 5>, 5, 4>; // DeviceGemmDl< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 64, 128, 128, 8, 2, 8, 8, 1, S<4, 2>, S<4, 2>, S<4, 1, 4, 2>, S<2, 1, 32, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 
2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<4, 1, 4, 2>, S<2, 1, 32, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<0, 1, 2, 3, 4, 5>, 5, 4>; // DeviceGemmDl< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 64, 128, 64, 8, 2, 8, 4, 1, S<4, 2>, S<4, 2>, S<4, 1, 4, 2>, S<2, 1, 32, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<4, 1, 2, 2>, S<2, 1, 32, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<0, 1, 2, 3, 4, 5>, 5, 4>; diff --git a/example/01_gemm/gemm_dl_fp32.cpp b/example/01_gemm/gemm_dl_fp32.cpp new file mode 100644 index 0000000000..77172e29e6 --- /dev/null +++ b/example/01_gemm/gemm_dl_fp32.cpp @@ -0,0 +1,215 @@ +#include +#include +#include +#include +#include +#include + +#include "check_err.hpp" +#include "config.hpp" +#include "device.hpp" +#include "host_tensor.hpp" +#include "host_tensor_generator.hpp" +#include "device_tensor.hpp" +#include "device_gemm_dl.hpp" +#include "element_wise_operation.hpp" +#include "reference_gemm.hpp" +#include "gemm_specialization.hpp" + +template +using S = ck::Sequence; + +using F32 = float; + +using Row = ck::tensor_layout::gemm::RowMajor; +using Col = ck::tensor_layout::gemm::ColumnMajor; + +using PassThrough = ck::tensor_operation::element_wise::PassThrough; + +using ADataType = float; +using BDataType = float; +using CDataType = float; +using AccDataType = float; + +using ALayout = ck::tensor_layout::gemm::RowMajor; +using BLayout = ck::tensor_layout::gemm::ColumnMajor; +using CLayout = ck::tensor_layout::gemm::RowMajor; + +using AElementOp = ck::tensor_operation::element_wise::PassThrough; +using BElementOp = ck::tensor_operation::element_wise::PassThrough; +using CElementOp = ck::tensor_operation::element_wise::PassThrough; + +static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default; + +// clang-format off +using DeviceGemmInstance = 
ck::tensor_operation::device:: + // ########| AData| BData| CData| AccData| ALayout| BLayout| CLayout| A| B| C| GEMM| Block| MPer| NPer| K0Per| K1| M1Per| N1Per| KPer| M11N11Thread| M11N11Thread| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| CThreadTransfer| CThreadTransfer| CThreadTransfer| + // ########| Type| Type| Type| Type| | | | Elementwise| Elementwise| Elementwise| Spacialization| Size| Block| Block| Block| | ThreadM111| ThreadN111| Thread| ClusterM110Xs| ClusterN110Xs| ThreadSliceLengths| ThreadClusterLengths| ThreadCluster| SrcAccess| SrcVectorTensor| SrcVectorTensor| DstVectorTensor| ThreadSliceLengths| ThreadClusterLengths| ThreadCluster| SrcAccess| SrcVectorTensor| SrcVectorTensor| DstVectorTensor| SrcDstAccess| SrcDstVectorDim| DstScalarPerVector| + // ########| | | | | | | | Operation| Operation| Operation| | | | | | | | | | | | K0_M0_M1_K1| K0_M0_M1_K1| ArrangeOrder| Order| Lengths_K0_M0_M1_K1| ContiguousDimOrder| Lengths_K0_M0_M1_K1| K0_N0_N1_K1| K0_N0_N1_K1| ArrangeOrder| Order| Lengths_K0_N0_N1_K1| ContiguousDimOrder| Lengths_K0_N0_N1_K1| Order| | | + // ########| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceGemmDl< F32, F32, F32, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 256, 128, 8, 1, 8, 4, 1, S<8, 2>, S<8, 2>, S<4, 1, 2, 1>, S<2, 1, 128, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 1>, S<1, 2, 0, 3>, S<1, 1, 1, 1>, S<4, 1, 1, 1>, S<2, 1, 128, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 1>, S<1, 2, 0, 3>, S<1, 1, 1, 1>, S<0, 1, 2, 3, 4, 5>, 5, 4>; +// DeviceGemmDl< F32, F32, F32, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 128, 256, 8, 1, 4, 8, 1, S<8, 2>, S<8, 2>, S<4, 1, 1, 1>, S<2, 1, 128, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 1>, S<1, 2, 0, 3>, S<1, 
1, 1, 1>, S<4, 1, 2, 1>, S<2, 1, 128, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 1>, S<1, 2, 0, 3>, S<1, 1, 1, 1>, S<0, 1, 2, 3, 4, 5>, 5, 4>; +// DeviceGemmDl< F32, F32, F32, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 128, 128, 8, 1, 4, 4, 1, S<8, 2>, S<8, 2>, S<4, 1, 1, 1>, S<2, 1, 128, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 1>, S<1, 2, 0, 3>, S<1, 1, 1, 1>, S<4, 1, 1, 1>, S<2, 1, 128, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 1>, S<1, 2, 0, 3>, S<1, 1, 1, 1>, S<0, 1, 2, 3, 4, 5>, 5, 4>; +// DeviceGemmDl< F32, F32, F32, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 128, 128, 128, 8, 1, 8, 4, 1, S<4, 2>, S<8, 2>, S<8, 1, 1, 1>, S<1, 1, 128, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 1>, S<1, 2, 0, 3>, S<1, 1, 1, 1>, S<8, 1, 1, 1>, S<1, 1, 128, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 1>, S<1, 2, 0, 3>, S<1, 1, 1, 1>, S<0, 1, 2, 3, 4, 5>, 5, 4>; +// DeviceGemmDl< F32, F32, F32, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 128, 128, 128, 8, 1, 4, 8, 1, S<8, 2>, S<4, 2>, S<8, 1, 1, 1>, S<1, 1, 128, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 1>, S<1, 2, 0, 3>, S<1, 1, 1, 1>, S<8, 1, 1, 1>, S<1, 1, 128, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 1>, S<1, 2, 0, 3>, S<1, 1, 1, 1>, S<0, 1, 2, 3, 4, 5>, 5, 4>; +// DeviceGemmDl< F32, F32, F32, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 64, 128, 128, 8, 1, 8, 8, 1, S<4, 2>, S<4, 2>, S<4, 1, 4, 1>, S<2, 1, 32, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 1>, S<1, 2, 0, 3>, S<1, 1, 1, 1>, S<4, 1, 4, 1>, S<2, 1, 32, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 1>, S<1, 2, 0, 3>, S<1, 1, 1, 1>, S<0, 1, 2, 3, 4, 5>, 5, 4>; +// DeviceGemmDl< F32, F32, F32, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 64, 128, 64, 8, 1, 8, 4, 1, S<4, 2>, S<4, 2>, S<4, 1, 4, 1>, S<2, 1, 32, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 1>, S<1, 2, 0, 3>, S<1, 1, 1, 1>, S<4, 1, 2, 1>, S<2, 
1, 32, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 1>, S<1, 2, 0, 3>, S<1, 1, 1, 1>, S<0, 1, 2, 3, 4, 5>, 5, 4>; +// DeviceGemmDl< F32, F32, F32, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 64, 64, 128, 8, 1, 4, 8, 1, S<4, 2>, S<4, 2>, S<4, 1, 2, 1>, S<2, 1, 32, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 1>, S<1, 2, 0, 3>, S<1, 1, 1, 1>, S<4, 1, 4, 1>, S<2, 1, 32, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 1>, S<1, 2, 0, 3>, S<1, 1, 1, 1>, S<0, 1, 2, 3, 4, 5>, 5, 4>; +// clang-format on + +using ReferenceGemmInstance = ck::tensor_operation::host:: + ReferenceGemm; + +int main(int argc, char* argv[]) +{ + bool do_verification = 0; + int init_method = 0; + int nrepeat = 5; + + // GEMM shape + ck::index_t M = 3840; + ck::index_t N = 4096; + ck::index_t K = 4096; + + ck::index_t StrideA = 4096; + ck::index_t StrideB = 4096; + ck::index_t StrideC = 4096; + + if(argc == 4) + { + do_verification = std::stoi(argv[1]); + init_method = std::stoi(argv[2]); + nrepeat = std::stoi(argv[3]); + } + else if(argc == 10) + { + do_verification = std::stoi(argv[1]); + init_method = std::stoi(argv[2]); + nrepeat = std::stoi(argv[3]); + + M = std::stoi(argv[4]); + N = std::stoi(argv[5]); + K = std::stoi(argv[6]); + + StrideA = std::stoi(argv[7]); + StrideB = std::stoi(argv[8]); + StrideC = std::stoi(argv[9]); + } + else + { + printf("arg1: verification (0=no, 1=yes)\n"); + printf("arg2: initialization (0=no init, 1=integer value, 2=decimal value)\n"); + printf("arg3: run kernel # of times (>1)\n"); + printf("arg4 to 9: M (256x), N(128x), K(32x), StrideA, StrideB, StrideC\n"); + exit(0); + } + + auto f_host_tensor_descriptor = + [](std::size_t row, std::size_t col, std::size_t stride, auto layout) { + if(std::is_same::value) + { + return HostTensorDescriptor(std::vector({row, col}), + std::vector({stride, 1})); + } + else + { + return HostTensorDescriptor(std::vector({row, col}), + std::vector({1, stride})); + } + }; + + Tensor 
a_m_k(f_host_tensor_descriptor(M, K, StrideA, ALayout{})); + Tensor b_k_n(f_host_tensor_descriptor(K, N, StrideB, BLayout{})); + Tensor c_m_n_host_result(f_host_tensor_descriptor(M, N, StrideC, CLayout{})); + Tensor c_m_n_device_result(f_host_tensor_descriptor(M, N, StrideC, CLayout{})); + + std::cout << "a_m_k: " << a_m_k.mDesc << std::endl; + std::cout << "b_k_n: " << b_k_n.mDesc << std::endl; + std::cout << "c_m_n: " << c_m_n_host_result.mDesc << std::endl; + + switch(init_method) + { + case 0: break; + case 1: + a_m_k.GenerateTensorValue(GeneratorTensor_2{-5, 5}); + b_k_n.GenerateTensorValue(GeneratorTensor_2{-5, 5}); + break; + case 2: + a_m_k.GenerateTensorValue(GeneratorTensor_3{0.0, 1.0}); + b_k_n.GenerateTensorValue(GeneratorTensor_3{-0.5, 0.5}); + break; + default: + a_m_k.GenerateTensorValue(GeneratorTensor_Sequential<0>{}); + b_k_n.GenerateTensorValue(GeneratorTensor_Sequential<1>{}); + } + + DeviceMem a_m_k_device_buf(sizeof(ADataType) * a_m_k.mDesc.GetElementSpace()); + DeviceMem b_k_n_device_buf(sizeof(BDataType) * b_k_n.mDesc.GetElementSpace()); + DeviceMem c_m_n_device_buf(sizeof(CDataType) * c_m_n_device_result.mDesc.GetElementSpace()); + + a_m_k_device_buf.ToDevice(a_m_k.mData.data()); + b_k_n_device_buf.ToDevice(b_k_n.mData.data()); + + auto a_element_op = AElementOp{}; + auto b_element_op = BElementOp{}; + auto c_element_op = CElementOp{}; + + // do GEMM + auto gemm = DeviceGemmInstance{}; + auto invoker = gemm.MakeInvoker(); + auto argument = gemm.MakeArgument(static_cast(a_m_k_device_buf.GetDeviceBuffer()), + static_cast(b_k_n_device_buf.GetDeviceBuffer()), + static_cast(c_m_n_device_buf.GetDeviceBuffer()), + M, + N, + K, + StrideA, + StrideB, + StrideC, + a_element_op, + b_element_op, + c_element_op); + + if(!gemm.IsSupportedArgument(argument)) + { + throw std::runtime_error( + "wrong! 
device_gemm with the specified compilation parameters does " + "not support this GEMM problem"); + } + + float ave_time = invoker.Run(argument, nrepeat); + + std::size_t flop = std::size_t(2) * M * N * K; + std::size_t num_btype = + sizeof(ADataType) * M * K + sizeof(BDataType) * K * N + sizeof(CDataType) * M * N; + + float tflops = static_cast(flop) / 1.E9 / ave_time; + + float gb_per_sec = num_btype / 1.E6 / ave_time; + + std::cout << "Perf: " << ave_time << " ms, " << tflops << " TFlops, " << gb_per_sec << " GB/s, " + << gemm.GetTypeString() << std::endl; + + c_m_n_device_buf.FromDevice(c_m_n_device_result.mData.data()); + + if(do_verification) + { + auto ref_gemm = ReferenceGemmInstance{}; + auto ref_invoker = ref_gemm.MakeInvoker(); + + auto ref_argument = ref_gemm.MakeArgument( + a_m_k, b_k_n, c_m_n_host_result, a_element_op, b_element_op, c_element_op); + + ref_invoker.Run(ref_argument); + + ck::utils::check_err(c_m_n_device_result.mData, c_m_n_host_result.mData); + +#if 0 + LogRangeAsType(std::cout << "c_host : ", c_m_n_host_result.mData, ",") << std::endl; + LogRangeAsType(std::cout << "c_device : ", c_m_n_device_result.mData, ",") + << std::endl; +#endif + } + + return 0; +} diff --git a/include/ck/config.hpp b/include/ck/config.hpp index 919af1e6dd..eedeb7e136 100644 --- a/include/ck/config.hpp +++ b/include/ck/config.hpp @@ -15,7 +15,7 @@ #ifdef CK_USE_LAUNCH_BOUNDS #define CK_MAX_THREAD_PER_BLOCK 256 -#define CK_MIN_BLOCK_PER_CU 1 +#define CK_MIN_BLOCK_PER_CU 2 #endif // check GPU target diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_dl_v1r3.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_dl_v1r3.hpp index fb8ac230ff..cb69945d4e 100644 --- a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_dl_v1r3.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_dl_v1r3.hpp @@ -447,7 +447,11 @@ struct GridwiseGemmDl_km_kn_mn_v1r3 b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc_k0_n0_n1_k1, b_block_slice_copy_step); +#if 0 
__syncthreads(); +#else + block_sync_lds(); +#endif // LDS doubel buffer: load next data from device mem a_blockwise_copy.RunRead(a_grid_desc_k0_m0_m1_k1, a_global_buf); @@ -469,7 +473,11 @@ struct GridwiseGemmDl_km_kn_mn_v1r3 b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc_k0_n0_n1_k1, b_block_slice_copy_step); +#if 0 __syncthreads(); +#else + block_sync_lds(); +#endif // LDS doubel buffer: load next data from device mem a_blockwise_copy.RunRead(a_grid_desc_k0_m0_m1_k1, a_global_buf); @@ -493,7 +501,11 @@ struct GridwiseGemmDl_km_kn_mn_v1r3 a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc_k0_m0_m1_k1, a_block_slice_copy_step); b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc_k0_n0_n1_k1, b_block_slice_copy_step); - __syncthreads(); +#if 0 + __syncthreads(); +#else + block_sync_lds(); +#endif // LDS double buffer: load last data from device mem a_blockwise_copy.RunRead(a_grid_desc_k0_m0_m1_k1, a_global_buf); @@ -507,7 +519,11 @@ struct GridwiseGemmDl_km_kn_mn_v1r3 a_blockwise_copy.RunWrite(a_block_desc_k0_m0_m1_k1, a_block_odd_buf); b_blockwise_copy.RunWrite(b_block_desc_k0_n0_n1_k1, b_block_odd_buf); - __syncthreads(); +#if 0 + __syncthreads(); +#else + block_sync_lds(); +#endif // LDS double buffer: GEMM on last data blockwise_gemm.Run( @@ -515,7 +531,11 @@ struct GridwiseGemmDl_km_kn_mn_v1r3 } else // if has 1 iteration left { - __syncthreads(); +#if 0 + __syncthreads(); +#else + block_sync_lds(); +#endif // LDS double buffer: GEMM on last data blockwise_gemm.Run( From 217b836167a3bb0768db9506202ee92aeda80889 Mon Sep 17 00:00:00 2001 From: Chao Liu Date: Thu, 12 May 2022 20:51:05 +0000 Subject: [PATCH 39/46] add gemm fp32 example --- example/01_gemm/CMakeLists.txt | 1 + example/01_gemm/gemm_dl_fp32.cpp | 4 +- example/01_gemm/gemm_dl_int8.cpp | 214 ++++++++++++++++++ .../gpu/grid/gridwise_gemm_dl_v1r3.hpp | 6 +- 4 files changed, 220 insertions(+), 5 deletions(-) create mode 100644 example/01_gemm/gemm_dl_int8.cpp diff --git a/example/01_gemm/CMakeLists.txt 
b/example/01_gemm/CMakeLists.txt index 650dcb1866..a0fe1fe2fa 100644 --- a/example/01_gemm/CMakeLists.txt +++ b/example/01_gemm/CMakeLists.txt @@ -1,5 +1,6 @@ add_example_executable(example_gemm_dl_fp32 gemm_dl_fp32.cpp) add_example_executable(example_gemm_dl_fp16 gemm_dl_fp16.cpp) +add_example_executable(example_gemm_dl_int8 gemm_dl_int8.cpp) add_example_executable(example_gemm_xdl_fp16 gemm_xdl_fp16.cpp) add_example_executable(example_gemm_xdl_bf16 gemm_xdl_bf16.cpp) add_example_executable(example_gemm_xdl_int8 gemm_xdl_int8.cpp) diff --git a/example/01_gemm/gemm_dl_fp32.cpp b/example/01_gemm/gemm_dl_fp32.cpp index 77172e29e6..a11c71623b 100644 --- a/example/01_gemm/gemm_dl_fp32.cpp +++ b/example/01_gemm/gemm_dl_fp32.cpp @@ -47,9 +47,9 @@ using DeviceGemmInstance = ck::tensor_operation::device:: // ########| Type| Type| Type| Type| | | | Elementwise| Elementwise| Elementwise| Spacialization| Size| Block| Block| Block| | ThreadM111| ThreadN111| Thread| ClusterM110Xs| ClusterN110Xs| ThreadSliceLengths| ThreadClusterLengths| ThreadCluster| SrcAccess| SrcVectorTensor| SrcVectorTensor| DstVectorTensor| ThreadSliceLengths| ThreadClusterLengths| ThreadCluster| SrcAccess| SrcVectorTensor| SrcVectorTensor| DstVectorTensor| SrcDstAccess| SrcDstVectorDim| DstScalarPerVector| // ########| | | | | | | | Operation| Operation| Operation| | | | | | | | | | | | K0_M0_M1_K1| K0_M0_M1_K1| ArrangeOrder| Order| Lengths_K0_M0_M1_K1| ContiguousDimOrder| Lengths_K0_M0_M1_K1| K0_N0_N1_K1| K0_N0_N1_K1| ArrangeOrder| Order| Lengths_K0_N0_N1_K1| ContiguousDimOrder| Lengths_K0_N0_N1_K1| Order| | | // ########| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | - DeviceGemmDl< F32, F32, F32, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 256, 128, 8, 1, 8, 4, 1, S<8, 2>, S<8, 2>, S<4, 1, 2, 1>, S<2, 1, 128, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 1>, S<1, 2, 0, 3>, S<1, 1, 1, 1>, S<4, 1, 1, 1>, S<2, 1, 128, 1>, S<1, 2, 0, 3>, S<1, 2, 
0, 3>, S<4, 1, 1, 1>, S<1, 2, 0, 3>, S<1, 1, 1, 1>, S<0, 1, 2, 3, 4, 5>, 5, 4>; +// DeviceGemmDl< F32, F32, F32, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 256, 128, 8, 1, 8, 4, 1, S<8, 2>, S<8, 2>, S<4, 1, 2, 1>, S<2, 1, 128, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 1>, S<1, 2, 0, 3>, S<1, 1, 1, 1>, S<4, 1, 1, 1>, S<2, 1, 128, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 1>, S<1, 2, 0, 3>, S<1, 1, 1, 1>, S<0, 1, 2, 3, 4, 5>, 5, 4>; // DeviceGemmDl< F32, F32, F32, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 128, 256, 8, 1, 4, 8, 1, S<8, 2>, S<8, 2>, S<4, 1, 1, 1>, S<2, 1, 128, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 1>, S<1, 2, 0, 3>, S<1, 1, 1, 1>, S<4, 1, 2, 1>, S<2, 1, 128, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 1>, S<1, 2, 0, 3>, S<1, 1, 1, 1>, S<0, 1, 2, 3, 4, 5>, 5, 4>; -// DeviceGemmDl< F32, F32, F32, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 128, 128, 8, 1, 4, 4, 1, S<8, 2>, S<8, 2>, S<4, 1, 1, 1>, S<2, 1, 128, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 1>, S<1, 2, 0, 3>, S<1, 1, 1, 1>, S<4, 1, 1, 1>, S<2, 1, 128, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 1>, S<1, 2, 0, 3>, S<1, 1, 1, 1>, S<0, 1, 2, 3, 4, 5>, 5, 4>; + DeviceGemmDl< F32, F32, F32, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 128, 128, 8, 1, 4, 4, 1, S<8, 2>, S<8, 2>, S<4, 1, 1, 1>, S<2, 1, 128, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 1>, S<1, 2, 0, 3>, S<1, 1, 1, 1>, S<4, 1, 1, 1>, S<2, 1, 128, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 1>, S<1, 2, 0, 3>, S<1, 1, 1, 1>, S<0, 1, 2, 3, 4, 5>, 5, 4>; // DeviceGemmDl< F32, F32, F32, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 128, 128, 128, 8, 1, 8, 4, 1, S<4, 2>, S<8, 2>, S<8, 1, 1, 1>, S<1, 1, 128, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 1>, S<1, 2, 0, 3>, S<1, 1, 1, 1>, S<8, 1, 1, 1>, S<1, 1, 128, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 1>, S<1, 
2, 0, 3>, S<1, 1, 1, 1>, S<0, 1, 2, 3, 4, 5>, 5, 4>; // DeviceGemmDl< F32, F32, F32, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 128, 128, 128, 8, 1, 4, 8, 1, S<8, 2>, S<4, 2>, S<8, 1, 1, 1>, S<1, 1, 128, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 1>, S<1, 2, 0, 3>, S<1, 1, 1, 1>, S<8, 1, 1, 1>, S<1, 1, 128, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 1>, S<1, 2, 0, 3>, S<1, 1, 1, 1>, S<0, 1, 2, 3, 4, 5>, 5, 4>; // DeviceGemmDl< F32, F32, F32, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 64, 128, 128, 8, 1, 8, 8, 1, S<4, 2>, S<4, 2>, S<4, 1, 4, 1>, S<2, 1, 32, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 1>, S<1, 2, 0, 3>, S<1, 1, 1, 1>, S<4, 1, 4, 1>, S<2, 1, 32, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 1>, S<1, 2, 0, 3>, S<1, 1, 1, 1>, S<0, 1, 2, 3, 4, 5>, 5, 4>; diff --git a/example/01_gemm/gemm_dl_int8.cpp b/example/01_gemm/gemm_dl_int8.cpp new file mode 100644 index 0000000000..36688b6986 --- /dev/null +++ b/example/01_gemm/gemm_dl_int8.cpp @@ -0,0 +1,214 @@ +#include +#include +#include +#include +#include +#include + +#include "check_err.hpp" +#include "config.hpp" +#include "device.hpp" +#include "host_tensor.hpp" +#include "host_tensor_generator.hpp" +#include "device_tensor.hpp" +#include "device_gemm_dl.hpp" +#include "element_wise_operation.hpp" +#include "reference_gemm.hpp" +#include "gemm_specialization.hpp" + +template +using S = ck::Sequence; + +using Row = ck::tensor_layout::gemm::RowMajor; +using Col = ck::tensor_layout::gemm::ColumnMajor; + +using PassThrough = ck::tensor_operation::element_wise::PassThrough; + +using ADataType = int8_t; +using BDataType = int8_t; +using CDataType = int8_t; +using AccDataType = int32_t; + +using ALayout = ck::tensor_layout::gemm::RowMajor; +using BLayout = ck::tensor_layout::gemm::ColumnMajor; +using CLayout = ck::tensor_layout::gemm::RowMajor; + +using AElementOp = ck::tensor_operation::element_wise::PassThrough; +using BElementOp = 
ck::tensor_operation::element_wise::PassThrough; +using CElementOp = ck::tensor_operation::element_wise::PassThrough; + +static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default; + +// clang-format off +using DeviceGemmInstance = ck::tensor_operation::device:: + // #########| AData| BData| CData| AccData| ALayout| BLayout| CLayout| A| B| C| GEMM| Block| MPer| NPer| K0Per| K1| M1Per| N1Per| KPer| M11N11Thread| M11N11Thread| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| CThreadTransfer| CThreadTransfer| CThreadTransfer| + // #########| Type| Type| Type| Type| | | | Elementwise| Elementwise| Elementwise| Spacialization| Size| Block| Block| Block| | ThreadM111| ThreadN111| Thread| ClusterM110Xs| ClusterN110Xs| ThreadSliceLengths| ThreadClusterLengths| ThreadCluster| SrcAccess| SrcVectorTensor| SrcVectorTensor| DstVectorTensor| ThreadSliceLengths| ThreadClusterLengths| ThreadCluster| SrcAccess| SrcVectorTensor| SrcVectorTensor| DstVectorTensor| SrcDstAccess| SrcDstVectorDim| DstScalarPerVector| + // #########| | | | | | | | Operation| Operation| Operation| | | | | | | | | | | | K0_M0_M1_K1| K0_M0_M1_K1| ArrangeOrder| Order| Lengths_K0_M0_M1_K1| ContiguousDimOrder| Lengths_K0_M0_M1_K1| K0_N0_N1_K1| K0_N0_N1_K1| ArrangeOrder| Order| Lengths_K0_N0_N1_K1| ContiguousDimOrder| Lengths_K0_N0_N1_K1| Order| | | + // #########| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | +// DeviceGemmDl< int8_t, int8_t, int8_t, int32_t, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 256, 128, 8, 4, 8, 4, 1, S<8, 2>, S<8, 2>, S<4, 1, 2, 4>, S<2, 1, 128, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 4>, S<1, 2, 0, 3>, S<1, 1, 1, 4>, S<4, 1, 1, 4>, S<2, 1, 128, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 4>, S<1, 2, 0, 3>, S<1, 1, 
1, 4>, S<0, 1, 2, 3, 4, 5>, 5, 4>; +// DeviceGemmDl< int8_t, int8_t, int8_t, int32_t, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 128, 256, 8, 4, 4, 8, 1, S<8, 2>, S<8, 2>, S<4, 1, 1, 4>, S<2, 1, 128, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 4>, S<1, 2, 0, 3>, S<1, 1, 1, 4>, S<4, 1, 2, 4>, S<2, 1, 128, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 4>, S<1, 2, 0, 3>, S<1, 1, 1, 4>, S<0, 1, 2, 3, 4, 5>, 5, 4>; + DeviceGemmDl< int8_t, int8_t, int8_t, int32_t, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 128, 128, 8, 4, 4, 4, 1, S<8, 2>, S<8, 2>, S<4, 1, 1, 4>, S<2, 1, 128, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 4>, S<1, 2, 0, 3>, S<1, 1, 1, 4>, S<4, 1, 1, 4>, S<2, 1, 128, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 4>, S<1, 2, 0, 3>, S<1, 1, 1, 4>, S<0, 1, 2, 3, 4, 5>, 5, 4>; +// DeviceGemmDl< int8_t, int8_t, int8_t, int32_t, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 128, 128, 128, 8, 4, 8, 4, 1, S<4, 2>, S<8, 2>, S<8, 1, 1, 4>, S<1, 1, 128, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 4>, S<1, 2, 0, 3>, S<1, 1, 1, 4>, S<8, 1, 1, 4>, S<1, 1, 128, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 4>, S<1, 2, 0, 3>, S<1, 1, 1, 4>, S<0, 1, 2, 3, 4, 5>, 5, 4>; +// DeviceGemmDl< int8_t, int8_t, int8_t, int32_t, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 128, 128, 128, 8, 4, 4, 8, 1, S<8, 2>, S<4, 2>, S<8, 1, 1, 4>, S<1, 1, 128, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 4>, S<1, 2, 0, 3>, S<1, 1, 1, 4>, S<8, 1, 1, 4>, S<1, 1, 128, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 4>, S<1, 2, 0, 3>, S<1, 1, 1, 4>, S<0, 1, 2, 3, 4, 5>, 5, 4>; +// DeviceGemmDl< int8_t, int8_t, int8_t, int32_t, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 64, 128, 128, 8, 4, 8, 8, 1, S<4, 2>, S<4, 2>, S<4, 1, 4, 4>, S<2, 1, 32, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 4>, S<1, 2, 0, 3>, S<1, 1, 1, 4>, S<4, 1, 4, 4>, S<2, 1, 32, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, 
S<4, 1, 1, 4>, S<1, 2, 0, 3>, S<1, 1, 1, 4>, S<0, 1, 2, 3, 4, 5>, 5, 4>; +// DeviceGemmDl< int8_t, int8_t, int8_t, int32_t, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 64, 128, 64, 8, 4, 8, 4, 1, S<4, 2>, S<4, 2>, S<4, 1, 4, 4>, S<2, 1, 32, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 4>, S<1, 2, 0, 3>, S<1, 1, 1, 4>, S<4, 1, 2, 4>, S<2, 1, 32, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 4>, S<1, 2, 0, 3>, S<1, 1, 1, 4>, S<0, 1, 2, 3, 4, 5>, 5, 4>; +// DeviceGemmDl< int8_t, int8_t, int8_t, int32_t, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 64, 64, 128, 8, 4, 4, 8, 1, S<4, 2>, S<4, 2>, S<4, 1, 2, 4>, S<2, 1, 32, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 4>, S<1, 2, 0, 3>, S<1, 1, 1, 4>, S<4, 1, 4, 4>, S<2, 1, 32, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 4>, S<1, 2, 0, 3>, S<1, 1, 1, 4>, S<0, 1, 2, 3, 4, 5>, 5, 4>; +// DeviceGemmDl< int8_t, int8_t, int8_t, int32_t, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 64, 64, 64, 8, 4, 4, 4, 1, S<4, 2>, S<4, 2>, S<4, 1, 2, 4>, S<2, 1, 32, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 4>, S<1, 2, 0, 3>, S<1, 1, 1, 4>, S<4, 1, 2, 4>, S<2, 1, 32, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 4>, S<1, 2, 0, 3>, S<1, 1, 1, 4>, S<0, 1, 2, 3, 4, 5>, 5, 4>; +// clang-format on + +using ReferenceGemmInstance = ck::tensor_operation::host:: + ReferenceGemm; + +int main(int argc, char* argv[]) +{ + bool do_verification = 0; + int init_method = 0; + int nrepeat = 5; + + // GEMM shape + ck::index_t M = 3840; + ck::index_t N = 4096; + ck::index_t K = 4096; + + ck::index_t StrideA = 4096; + ck::index_t StrideB = 4096; + ck::index_t StrideC = 4096; + + if(argc == 4) + { + do_verification = std::stoi(argv[1]); + init_method = std::stoi(argv[2]); + nrepeat = std::stoi(argv[3]); + } + else if(argc == 10) + { + do_verification = std::stoi(argv[1]); + init_method = std::stoi(argv[2]); + nrepeat = std::stoi(argv[3]); + + M = std::stoi(argv[4]); + N = std::stoi(argv[5]); + K 
= std::stoi(argv[6]); + + StrideA = std::stoi(argv[7]); + StrideB = std::stoi(argv[8]); + StrideC = std::stoi(argv[9]); + } + else + { + printf("arg1: verification (0=no, 1=yes)\n"); + printf("arg2: initialization (0=no init, 1=integer value, 2=decimal value)\n"); + printf("arg3: run kernel # of times (>1)\n"); + printf("arg4 to 9: M (256x), N(128x), K(32x), StrideA, StrideB, StrideC\n"); + exit(0); + } + + auto f_host_tensor_descriptor = + [](std::size_t row, std::size_t col, std::size_t stride, auto layout) { + if(std::is_same::value) + { + return HostTensorDescriptor(std::vector({row, col}), + std::vector({stride, 1})); + } + else + { + return HostTensorDescriptor(std::vector({row, col}), + std::vector({1, stride})); + } + }; + + Tensor a_m_k(f_host_tensor_descriptor(M, K, StrideA, ALayout{})); + Tensor b_k_n(f_host_tensor_descriptor(K, N, StrideB, BLayout{})); + Tensor c_m_n_host_result(f_host_tensor_descriptor(M, N, StrideC, CLayout{})); + Tensor c_m_n_device_result(f_host_tensor_descriptor(M, N, StrideC, CLayout{})); + + std::cout << "a_m_k: " << a_m_k.mDesc << std::endl; + std::cout << "b_k_n: " << b_k_n.mDesc << std::endl; + std::cout << "c_m_n: " << c_m_n_host_result.mDesc << std::endl; + + switch(init_method) + { + case 0: break; + case 1: + a_m_k.GenerateTensorValue(GeneratorTensor_2{-5, 5}); + b_k_n.GenerateTensorValue(GeneratorTensor_2{-5, 5}); + break; + case 2: + a_m_k.GenerateTensorValue(GeneratorTensor_3{0.0, 1.0}); + b_k_n.GenerateTensorValue(GeneratorTensor_3{-0.5, 0.5}); + break; + default: + a_m_k.GenerateTensorValue(GeneratorTensor_Sequential<0>{}); + b_k_n.GenerateTensorValue(GeneratorTensor_Sequential<1>{}); + } + + DeviceMem a_m_k_device_buf(sizeof(ADataType) * a_m_k.mDesc.GetElementSpace()); + DeviceMem b_k_n_device_buf(sizeof(BDataType) * b_k_n.mDesc.GetElementSpace()); + DeviceMem c_m_n_device_buf(sizeof(CDataType) * c_m_n_device_result.mDesc.GetElementSpace()); + + a_m_k_device_buf.ToDevice(a_m_k.mData.data()); + 
b_k_n_device_buf.ToDevice(b_k_n.mData.data()); + + auto a_element_op = AElementOp{}; + auto b_element_op = BElementOp{}; + auto c_element_op = CElementOp{}; + + // do GEMM + auto gemm = DeviceGemmInstance{}; + auto invoker = gemm.MakeInvoker(); + auto argument = gemm.MakeArgument(static_cast(a_m_k_device_buf.GetDeviceBuffer()), + static_cast(b_k_n_device_buf.GetDeviceBuffer()), + static_cast(c_m_n_device_buf.GetDeviceBuffer()), + M, + N, + K, + StrideA, + StrideB, + StrideC, + a_element_op, + b_element_op, + c_element_op); + + if(!gemm.IsSupportedArgument(argument)) + { + throw std::runtime_error( + "wrong! device_gemm with the specified compilation parameters does " + "not support this GEMM problem"); + } + + float ave_time = invoker.Run(argument, nrepeat); + + std::size_t flop = std::size_t(2) * M * N * K; + std::size_t num_btype = + sizeof(ADataType) * M * K + sizeof(BDataType) * K * N + sizeof(CDataType) * M * N; + + float tflops = static_cast(flop) / 1.E9 / ave_time; + + float gb_per_sec = num_btype / 1.E6 / ave_time; + + std::cout << "Perf: " << ave_time << " ms, " << tflops << " TFlops, " << gb_per_sec << " GB/s, " + << gemm.GetTypeString() << std::endl; + + c_m_n_device_buf.FromDevice(c_m_n_device_result.mData.data()); + + if(do_verification) + { + auto ref_gemm = ReferenceGemmInstance{}; + auto ref_invoker = ref_gemm.MakeInvoker(); + + auto ref_argument = ref_gemm.MakeArgument( + a_m_k, b_k_n, c_m_n_host_result, a_element_op, b_element_op, c_element_op); + + ref_invoker.Run(ref_argument); + + ck::utils::check_err(c_m_n_device_result.mData, c_m_n_host_result.mData); + +#if 0 + LogRangeAsType(std::cout << "c_host : ", c_m_n_host_result.mData, ",") << std::endl; + LogRangeAsType(std::cout << "c_device : ", c_m_n_device_result.mData, ",") + << std::endl; +#endif + } + + return 0; +} diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_dl_v1r3.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_dl_v1r3.hpp index cb69945d4e..65d1091f55 100644 --- 
a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_dl_v1r3.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_dl_v1r3.hpp @@ -504,7 +504,7 @@ struct GridwiseGemmDl_km_kn_mn_v1r3 #if 0 __syncthreads(); #else - block_sync_lds(); + block_sync_lds(); #endif // LDS double buffer: load last data from device mem @@ -522,7 +522,7 @@ struct GridwiseGemmDl_km_kn_mn_v1r3 #if 0 __syncthreads(); #else - block_sync_lds(); + block_sync_lds(); #endif // LDS double buffer: GEMM on last data @@ -534,7 +534,7 @@ struct GridwiseGemmDl_km_kn_mn_v1r3 #if 0 __syncthreads(); #else - block_sync_lds(); + block_sync_lds(); #endif // LDS double buffer: GEMM on last data From 162ac1dd30092bfe89b173ee154be7aeb62a2f0f Mon Sep 17 00:00:00 2001 From: Chao Liu Date: Tue, 17 May 2022 23:48:28 +0000 Subject: [PATCH 40/46] clean up --- .../gpu/device/device_gemm_xdl_c_shuffle.hpp | 16 +++---- .../gpu/device/device_gemm_xdl_cshuffle.hpp | 10 ++--- .../gpu/device/device_gemm_xdl_splitk.hpp | 15 +++---- include/ck/utility/inner_product.hpp | 43 +------------------ 4 files changed, 19 insertions(+), 65 deletions(-) diff --git a/include/ck/tensor_operation/gpu/device/device_gemm_xdl_c_shuffle.hpp b/include/ck/tensor_operation/gpu/device/device_gemm_xdl_c_shuffle.hpp index c4c62e89f6..23177d6556 100644 --- a/include/ck/tensor_operation/gpu/device/device_gemm_xdl_c_shuffle.hpp +++ b/include/ck/tensor_operation/gpu/device/device_gemm_xdl_c_shuffle.hpp @@ -370,18 +370,16 @@ struct DeviceGemmXdl_C_Shuffle static bool IsSupportedArgument(const Argument& arg) { - if(ck::get_device_name() == "gfx1030") + if(!(ck::get_device_name() == "gfx908" || ck::get_device_name() == "gfx90a")) { return false; } - else - { - return GridwiseGemm::CheckValidity(arg.a_grid_desc_k0_m_k1_, - arg.b_grid_desc_k0_n_k1_, - arg.c_grid_desc_m_n_, - arg.M01_, - arg.N01_); - } + + return GridwiseGemm::CheckValidity(arg.a_grid_desc_k0_m_k1_, + arg.b_grid_desc_k0_n_k1_, + arg.c_grid_desc_m_n_, + arg.M01_, + arg.N01_); } // 
polymorphic diff --git a/include/ck/tensor_operation/gpu/device/device_gemm_xdl_cshuffle.hpp b/include/ck/tensor_operation/gpu/device/device_gemm_xdl_cshuffle.hpp index ce8cad94ce..8bd4375c55 100644 --- a/include/ck/tensor_operation/gpu/device/device_gemm_xdl_cshuffle.hpp +++ b/include/ck/tensor_operation/gpu/device/device_gemm_xdl_cshuffle.hpp @@ -556,15 +556,13 @@ struct DeviceGemm_Xdl_CShuffle static bool IsSupportedArgument(const Argument& arg) { - if(ck::get_device_name() == "gfx1030") + if(!(ck::get_device_name() == "gfx908" || ck::get_device_name() == "gfx90a")) { return false; } - else - { - return GridwiseGemm::CheckValidity( - arg.a_grid_desc_ak0_m_ak1_, arg.b_grid_desc_bk0_n_bk1_, arg.c_grid_desc_m_n_); - } + + return GridwiseGemm::CheckValidity( + arg.a_grid_desc_ak0_m_ak1_, arg.b_grid_desc_bk0_n_bk1_, arg.c_grid_desc_m_n_); } // polymorphic diff --git a/include/ck/tensor_operation/gpu/device/device_gemm_xdl_splitk.hpp b/include/ck/tensor_operation/gpu/device/device_gemm_xdl_splitk.hpp index 4b88e9dfc3..3618ea8948 100644 --- a/include/ck/tensor_operation/gpu/device/device_gemm_xdl_splitk.hpp +++ b/include/ck/tensor_operation/gpu/device/device_gemm_xdl_splitk.hpp @@ -530,19 +530,16 @@ struct DeviceGemmXdlSplitK static bool IsSupportedArgument(const Argument& arg) { - if(ck::get_device_name() == "gfx1030") + if(!(ck::get_device_name() == "gfx908" || ck::get_device_name() == "gfx90a")) { return false; } - else - { - return GridwiseGemm::CheckValidity(arg.a_grid_desc_kbatch_k0_m_k1_, - arg.b_grid_desc_kbatch_k0_n_k1_, - arg.c_grid_desc_m_n_, - arg.M01_, - arg.N01_); - } + return GridwiseGemm::CheckValidity(arg.a_grid_desc_kbatch_k0_m_k1_, + arg.b_grid_desc_kbatch_k0_n_k1_, + arg.c_grid_desc_m_n_, + arg.M01_, + arg.N01_); } // polymorphic diff --git a/include/ck/utility/inner_product.hpp b/include/ck/utility/inner_product.hpp index d84879ff8f..59fe17e867 100644 --- a/include/ck/utility/inner_product.hpp +++ b/include/ck/utility/inner_product.hpp @@ -1,6 +1,4 
@@ -#ifndef CK_INNER_PRODUCT_HPP -#define CK_INNER_PRODUCT_HPP - +#pragma once #include "data_type.hpp" namespace ck { @@ -70,12 +68,6 @@ inner_product(const float4_t& a, const float4_t& b, f c); } -template <> -__device__ void inner_product(const half_t& a, const half_t& b, float& c) -{ - c += a * b; -} - template <> __device__ void inner_product(const half2_t& a, const half2_t& b, float& c) { @@ -140,41 +132,11 @@ __device__ void inner_product(const half8_t& a, const h c); } -template <> -__device__ void inner_product(const int8_t& a, const int8_t& b, int32_t& c) -{ - c += a * b; -} - -template <> -__device__ void -inner_product(const int8x2_t& a, const int8x2_t& b, int32_t& c) -{ - // #if defined(CK_USE_DOT2_I32_I8) - // #if CK_USE_AMD_INNER_PRODUCT_INLINE_ASM - // asm volatile("\n \ -// v_dot2_i32_i8 %0, %1, %2, %0\n \ -// " - // : "=v"(c) - // : "v"(bit_cast(a)), "v"(bit_cast(b)), "0"(c)); - // #else - // c = __builtin_amdgcn_sdot2(bit_cast(a), bit_cast(b), c, false); - // #endif - // #else - const vector_type a_vector{a}; - const vector_type b_vector{b}; - - static_for<0, 2, 1>{}([&](auto i) { - c += type_convert(a_vector.AsType()[i]) * - type_convert(b_vector.AsType()[i]); - }); - // #endif -} template <> __device__ void inner_product(const int8x4_t& a, const int8x4_t& b, int32_t& c) { -#if defined(CK_USE_DOT4_I32_I8) +#if defined(CK_USE_AMD_V_DOT4_I32_I8) #if CK_USE_AMD_INNER_PRODUCT_INLINE_ASM asm volatile("\n \ v_dot4_i32_i8 %0, %1, %2, %0\n \ @@ -238,4 +200,3 @@ inner_product(const int8x16_t& a, const int8x16_t } } // namespace ck -#endif From f4f890a5ad04b4b729d1419056b44dca8f9b2737 Mon Sep 17 00:00:00 2001 From: shaojiewang Date: Wed, 18 May 2022 22:03:22 +0800 Subject: [PATCH 41/46] use 128x128x16 as MNK tile in navi21 gemm example --- example/01_gemm/gemm_dl_fp32.cpp | 2 +- .../gpu/grid/gridwise_gemm_dl_v1r3.hpp | 16 ++++++++++++---- 2 files changed, 13 insertions(+), 5 deletions(-) diff --git a/example/01_gemm/gemm_dl_fp32.cpp 
b/example/01_gemm/gemm_dl_fp32.cpp index 5aa73dde5f..f0a95772f0 100644 --- a/example/01_gemm/gemm_dl_fp32.cpp +++ b/example/01_gemm/gemm_dl_fp32.cpp @@ -49,7 +49,7 @@ using DeviceGemmInstance = ck::tensor_operation::device:: // ########| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | // DeviceGemmDl< F32, F32, F32, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 256, 128, 8, 1, 8, 4, 1, S<8, 2>, S<8, 2>, S<4, 1, 2, 1>, S<2, 1, 128, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 1>, S<1, 2, 0, 3>, S<1, 1, 1, 1>, S<4, 1, 1, 1>, S<2, 1, 128, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 1>, S<1, 2, 0, 3>, S<1, 1, 1, 1>, S<0, 1, 2, 3, 4, 5>, 5, 4>; // DeviceGemmDl< F32, F32, F32, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 128, 256, 8, 1, 4, 8, 1, S<8, 2>, S<8, 2>, S<4, 1, 1, 1>, S<2, 1, 128, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 1>, S<1, 2, 0, 3>, S<1, 1, 1, 1>, S<4, 1, 2, 1>, S<2, 1, 128, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 1>, S<1, 2, 0, 3>, S<1, 1, 1, 1>, S<0, 1, 2, 3, 4, 5>, 5, 4>; - DeviceGemmDl< F32, F32, F32, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 128, 128, 8, 1, 4, 4, 1, S<8, 2>, S<8, 2>, S<4, 1, 1, 1>, S<2, 1, 128, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 1>, S<1, 2, 0, 3>, S<1, 1, 1, 1>, S<4, 1, 1, 1>, S<2, 1, 128, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 1>, S<1, 2, 0, 3>, S<1, 1, 1, 1>, S<0, 1, 2, 3, 4, 5>, 5, 4>; + DeviceGemmDl< F32, F32, F32, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 128, 128, 16, 1, 4, 4, 1, S<8, 2>, S<8, 2>, S<8, 1, 1, 1>, S<2, 1, 128, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 1>, S<1, 2, 0, 3>, S<1, 1, 1, 1>, S<8, 1, 1, 1>, S<2, 1, 128, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 1>, S<1, 2, 0, 3>, S<1, 1, 1, 1>, S<0, 1, 2, 3, 4, 5>, 5, 4>; // DeviceGemmDl< F32, F32, F32, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 128, 
128, 128, 8, 1, 8, 4, 1, S<4, 2>, S<8, 2>, S<8, 1, 1, 1>, S<1, 1, 128, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 1>, S<1, 2, 0, 3>, S<1, 1, 1, 1>, S<8, 1, 1, 1>, S<1, 1, 128, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 1>, S<1, 2, 0, 3>, S<1, 1, 1, 1>, S<0, 1, 2, 3, 4, 5>, 5, 4>; // DeviceGemmDl< F32, F32, F32, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 128, 128, 128, 8, 1, 4, 8, 1, S<8, 2>, S<4, 2>, S<8, 1, 1, 1>, S<1, 1, 128, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 1>, S<1, 2, 0, 3>, S<1, 1, 1, 1>, S<8, 1, 1, 1>, S<1, 1, 128, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 1>, S<1, 2, 0, 3>, S<1, 1, 1, 1>, S<0, 1, 2, 3, 4, 5>, 5, 4>; // DeviceGemmDl< F32, F32, F32, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 64, 128, 128, 8, 1, 8, 8, 1, S<4, 2>, S<4, 2>, S<4, 1, 4, 1>, S<2, 1, 32, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 1>, S<1, 2, 0, 3>, S<1, 1, 1, 1>, S<4, 1, 4, 1>, S<2, 1, 32, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 1>, S<1, 2, 0, 3>, S<1, 1, 1, 1>, S<0, 1, 2, 3, 4, 5>, 5, 4>; diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_dl_v1r3.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_dl_v1r3.hpp index 65d1091f55..0360692abc 100644 --- a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_dl_v1r3.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_dl_v1r3.hpp @@ -447,6 +447,10 @@ struct GridwiseGemmDl_km_kn_mn_v1r3 b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc_k0_n0_n1_k1, b_block_slice_copy_step); + // LDS doubel buffer: load next data from device mem + a_blockwise_copy.RunRead(a_grid_desc_k0_m0_m1_k1, a_global_buf); + b_blockwise_copy.RunRead(b_grid_desc_k0_n0_n1_k1, b_global_buf); + #if 0 __syncthreads(); #else @@ -454,8 +458,8 @@ struct GridwiseGemmDl_km_kn_mn_v1r3 #endif // LDS doubel buffer: load next data from device mem - a_blockwise_copy.RunRead(a_grid_desc_k0_m0_m1_k1, a_global_buf); - b_blockwise_copy.RunRead(b_grid_desc_k0_n0_n1_k1, b_global_buf); 
+ //a_blockwise_copy.RunRead(a_grid_desc_k0_m0_m1_k1, a_global_buf); + //b_blockwise_copy.RunRead(b_grid_desc_k0_n0_n1_k1, b_global_buf); // LDS double buffer: GEMM on current data blockwise_gemm.Run(c_thread_desc_m10_m11_n10_n11, @@ -473,6 +477,10 @@ struct GridwiseGemmDl_km_kn_mn_v1r3 b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc_k0_n0_n1_k1, b_block_slice_copy_step); + // LDS doubel buffer: load next data from device mem + a_blockwise_copy.RunRead(a_grid_desc_k0_m0_m1_k1, a_global_buf); + b_blockwise_copy.RunRead(b_grid_desc_k0_n0_n1_k1, b_global_buf); + #if 0 __syncthreads(); #else @@ -480,8 +488,8 @@ struct GridwiseGemmDl_km_kn_mn_v1r3 #endif // LDS doubel buffer: load next data from device mem - a_blockwise_copy.RunRead(a_grid_desc_k0_m0_m1_k1, a_global_buf); - b_blockwise_copy.RunRead(b_grid_desc_k0_n0_n1_k1, b_global_buf); + //a_blockwise_copy.RunRead(a_grid_desc_k0_m0_m1_k1, a_global_buf); + //b_blockwise_copy.RunRead(b_grid_desc_k0_n0_n1_k1, b_global_buf); // LDS double buffer: GEMM on current data blockwise_gemm.Run( From 9f602fa79fc856723652ae2ee2a5ebcbd7ba805f Mon Sep 17 00:00:00 2001 From: Chao Liu Date: Thu, 19 May 2022 00:29:40 +0000 Subject: [PATCH 42/46] bug fix --- example/01_gemm/gemm_dl_fp16.cpp | 15 ++-- example/01_gemm/gemm_dl_fp32.cpp | 15 ++-- example/01_gemm/gemm_dl_int8.cpp | 16 ++--- example/CMakeLists.txt | 2 +- .../gpu/device/device_gemm_dl.hpp | 11 ++- .../gpu/grid/gridwise_gemm_dl_v1r3.hpp | 30 +------- .../gpu/gemm/CMakeLists.txt | 24 +++---- ..._gemm_dl_f16_f16_f16_km_kn_mn_instance.cpp | 18 ++--- ..._gemm_dl_f16_f16_f16_km_nk_mn_instance.cpp | 26 +++---- ..._gemm_dl_f16_f16_f16_mk_kn_mn_instance.cpp | 26 +++---- ..._gemm_dl_f16_f16_f16_mk_nk_mn_instance.cpp | 9 +-- ..._gemm_dl_f32_f32_f32_km_kn_mn_instance.cpp | 25 +++---- ..._gemm_dl_f32_f32_f32_km_nk_mn_instance.cpp | 26 +++---- ..._gemm_dl_f32_f32_f32_mk_kn_mn_instance.cpp | 9 +-- ..._gemm_dl_f32_f32_f32_mk_nk_mn_instance.cpp | 9 +-- 
...ice_gemm_dl_i8_i8_i8_km_kn_mn_instance.cpp | 29 +++----- ...ice_gemm_dl_i8_i8_i8_km_nk_mn_instance.cpp | 29 +++----- ...ice_gemm_dl_i8_i8_i8_mk_kn_mn_instance.cpp | 29 +++----- ...ice_gemm_dl_i8_i8_i8_mk_nk_mn_instance.cpp | 29 +++----- profiler/CMakeLists.txt | 4 +- profiler/include/profile_gemm_impl.hpp | 68 +++++++++---------- test/CMakeLists.txt | 3 +- test/gemm/CMakeLists.txt | 38 +++++++---- test/{gemm_dl => gemm}/gemm_dl_fp16.cpp | 8 +-- test/{gemm_dl => gemm}/gemm_dl_fp32.cpp | 18 ++--- test/{gemm_dl => gemm}/gemm_dl_int8.cpp | 20 +++--- .../gemm/{gemm_bf16.cpp => gemm_xdl_bf16.cpp} | 0 .../gemm/{gemm_fp16.cpp => gemm_xdl_fp16.cpp} | 0 .../gemm/{gemm_fp32.cpp => gemm_xdl_fp32.cpp} | 0 .../gemm/{gemm_int8.cpp => gemm_xdl_int8.cpp} | 20 +++--- test/gemm_dl/CMakeLists.txt | 11 --- 31 files changed, 205 insertions(+), 362 deletions(-) rename test/{gemm_dl => gemm}/gemm_dl_fp16.cpp (93%) rename test/{gemm_dl => gemm}/gemm_dl_fp32.cpp (83%) rename test/{gemm_dl => gemm}/gemm_dl_int8.cpp (84%) rename test/gemm/{gemm_bf16.cpp => gemm_xdl_bf16.cpp} (100%) rename test/gemm/{gemm_fp16.cpp => gemm_xdl_fp16.cpp} (100%) rename test/gemm/{gemm_fp32.cpp => gemm_xdl_fp32.cpp} (100%) rename test/gemm/{gemm_int8.cpp => gemm_xdl_int8.cpp} (82%) delete mode 100644 test/gemm_dl/CMakeLists.txt diff --git a/example/01_gemm/gemm_dl_fp16.cpp b/example/01_gemm/gemm_dl_fp16.cpp index 1dd85e4fa2..18b57e80bc 100644 --- a/example/01_gemm/gemm_dl_fp16.cpp +++ b/example/01_gemm/gemm_dl_fp16.cpp @@ -32,9 +32,9 @@ using BDataType = ck::half_t; using CDataType = ck::half_t; using AccDataType = float; -using ALayout = ck::tensor_layout::gemm::RowMajor; -using BLayout = ck::tensor_layout::gemm::ColumnMajor; -using CLayout = ck::tensor_layout::gemm::RowMajor; +using ALayout = Col; +using BLayout = Row; +using CLayout = Row; using AElementOp = ck::tensor_operation::element_wise::PassThrough; using BElementOp = ck::tensor_operation::element_wise::PassThrough; @@ -48,14 +48,7 @@ using 
DeviceGemmInstance = ck::tensor_operation::device:: // ########| Type| Type| Type| Type| | | | Elementwise| Elementwise| Elementwise| Spacialization| Size| Block| Block| Block| | ThreadM111| ThreadN111| Thread| ClusterM110Xs| ClusterN110Xs| ThreadSliceLengths| ThreadClusterLengths| ThreadCluster| SrcAccess| SrcVectorTensor| SrcVectorTensor| DstVectorTensor| ThreadSliceLengths| ThreadClusterLengths| ThreadCluster| SrcAccess| SrcVectorTensor| SrcVectorTensor| DstVectorTensor| SrcDstAccess| SrcDstVectorDim| DstScalarPerVector| // ########| | | | | | | | Operation| Operation| Operation| | | | | | | | | | | | K0_M0_M1_K1| K0_M0_M1_K1| ArrangeOrder| Order| Lengths_K0_M0_M1_K1| ContiguousDimOrder| Lengths_K0_M0_M1_K1| K0_N0_N1_K1| K0_N0_N1_K1| ArrangeOrder| Order| Lengths_K0_N0_N1_K1| ContiguousDimOrder| Lengths_K0_N0_N1_K1| Order| | | // ########| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | -// DeviceGemmDl< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 256, 128, 8, 2, 8, 4, 1, S<8, 2>, S<8, 2>, S<4, 1, 2, 2>, S<2, 1, 128, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<4, 1, 1, 2>, S<2, 1, 128, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<0, 1, 2, 3, 4, 5>, 5, 4>; -// DeviceGemmDl< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 128, 256, 8, 2, 4, 8, 1, S<8, 2>, S<8, 2>, S<4, 1, 1, 2>, S<2, 1, 128, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<4, 1, 2, 2>, S<2, 1, 128, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<0, 1, 2, 3, 4, 5>, 5, 4>; - DeviceGemmDl< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 128, 128, 8, 2, 4, 4, 1, S<8, 2>, S<8, 2>, S<4, 1, 1, 2>, S<2, 1, 128, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<4, 1, 1, 2>, S<2, 
1, 128, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<0, 1, 2, 3, 4, 5>, 5, 4>; -// DeviceGemmDl< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 128, 128, 128, 8, 2, 8, 4, 1, S<4, 2>, S<8, 2>, S<8, 1, 1, 2>, S<1, 1, 128, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<8, 1, 1, 2>, S<1, 1, 128, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<0, 1, 2, 3, 4, 5>, 5, 4>; -// DeviceGemmDl< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 128, 128, 128, 8, 2, 4, 8, 1, S<8, 2>, S<4, 2>, S<8, 1, 1, 2>, S<1, 1, 128, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<8, 1, 1, 2>, S<1, 1, 128, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<0, 1, 2, 3, 4, 5>, 5, 4>; -// DeviceGemmDl< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 64, 128, 128, 8, 2, 8, 8, 1, S<4, 2>, S<4, 2>, S<4, 1, 4, 2>, S<2, 1, 32, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<4, 1, 4, 2>, S<2, 1, 32, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<0, 1, 2, 3, 4, 5>, 5, 4>; -// DeviceGemmDl< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 64, 128, 64, 8, 2, 8, 4, 1, S<4, 2>, S<4, 2>, S<4, 1, 4, 2>, S<2, 1, 32, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<4, 1, 2, 2>, S<2, 1, 32, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<0, 1, 2, 3, 4, 5>, 5, 4>; -// DeviceGemmDl< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 64, 64, 128, 8, 2, 4, 8, 1, S<4, 2>, S<4, 2>, S<4, 1, 2, 2>, S<2, 1, 32, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<4, 1, 4, 2>, S<2, 1, 32, 1>, S<1, 2, 0, 3>, S<1, 2, 
0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<0, 1, 2, 3, 4, 5>, 5, 4>; + DeviceGemmDl< F16, F16, F16, F32, Col, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 128, 128, 16, 2, 4, 4, 1, S<8, 2>, S<8, 2>, S<2, 1, 4, 2>, S<8, 1, 32, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 2>, S<2, 1, 4, 2>, S<8, 1, 32, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 2>, S<0, 1, 2, 3, 4, 5>, 5, 4>; // clang-format on using ReferenceGemmInstance = ck::tensor_operation::host:: diff --git a/example/01_gemm/gemm_dl_fp32.cpp b/example/01_gemm/gemm_dl_fp32.cpp index f0a95772f0..f934a9e8fa 100644 --- a/example/01_gemm/gemm_dl_fp32.cpp +++ b/example/01_gemm/gemm_dl_fp32.cpp @@ -31,9 +31,9 @@ using BDataType = float; using CDataType = float; using AccDataType = float; -using ALayout = ck::tensor_layout::gemm::RowMajor; -using BLayout = ck::tensor_layout::gemm::ColumnMajor; -using CLayout = ck::tensor_layout::gemm::RowMajor; +using ALayout = Col; +using BLayout = Row; +using CLayout = Row; using AElementOp = ck::tensor_operation::element_wise::PassThrough; using BElementOp = ck::tensor_operation::element_wise::PassThrough; @@ -47,14 +47,7 @@ using DeviceGemmInstance = ck::tensor_operation::device:: // ########| Type| Type| Type| Type| | | | Elementwise| Elementwise| Elementwise| Spacialization| Size| Block| Block| Block| | ThreadM111| ThreadN111| Thread| ClusterM110Xs| ClusterN110Xs| ThreadSliceLengths| ThreadClusterLengths| ThreadCluster| SrcAccess| SrcVectorTensor| SrcVectorTensor| DstVectorTensor| ThreadSliceLengths| ThreadClusterLengths| ThreadCluster| SrcAccess| SrcVectorTensor| SrcVectorTensor| DstVectorTensor| SrcDstAccess| SrcDstVectorDim| DstScalarPerVector| // ########| | | | | | | | Operation| Operation| Operation| | | | | | | | | | | | K0_M0_M1_K1| K0_M0_M1_K1| ArrangeOrder| Order| Lengths_K0_M0_M1_K1| ContiguousDimOrder| Lengths_K0_M0_M1_K1| K0_N0_N1_K1| K0_N0_N1_K1| ArrangeOrder| Order| 
Lengths_K0_N0_N1_K1| ContiguousDimOrder| Lengths_K0_N0_N1_K1| Order| | | // ########| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | -// DeviceGemmDl< F32, F32, F32, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 256, 128, 8, 1, 8, 4, 1, S<8, 2>, S<8, 2>, S<4, 1, 2, 1>, S<2, 1, 128, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 1>, S<1, 2, 0, 3>, S<1, 1, 1, 1>, S<4, 1, 1, 1>, S<2, 1, 128, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 1>, S<1, 2, 0, 3>, S<1, 1, 1, 1>, S<0, 1, 2, 3, 4, 5>, 5, 4>; -// DeviceGemmDl< F32, F32, F32, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 128, 256, 8, 1, 4, 8, 1, S<8, 2>, S<8, 2>, S<4, 1, 1, 1>, S<2, 1, 128, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 1>, S<1, 2, 0, 3>, S<1, 1, 1, 1>, S<4, 1, 2, 1>, S<2, 1, 128, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 1>, S<1, 2, 0, 3>, S<1, 1, 1, 1>, S<0, 1, 2, 3, 4, 5>, 5, 4>; - DeviceGemmDl< F32, F32, F32, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 128, 128, 16, 1, 4, 4, 1, S<8, 2>, S<8, 2>, S<8, 1, 1, 1>, S<2, 1, 128, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 1>, S<1, 2, 0, 3>, S<1, 1, 1, 1>, S<8, 1, 1, 1>, S<2, 1, 128, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 1>, S<1, 2, 0, 3>, S<1, 1, 1, 1>, S<0, 1, 2, 3, 4, 5>, 5, 4>; -// DeviceGemmDl< F32, F32, F32, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 128, 128, 128, 8, 1, 8, 4, 1, S<4, 2>, S<8, 2>, S<8, 1, 1, 1>, S<1, 1, 128, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 1>, S<1, 2, 0, 3>, S<1, 1, 1, 1>, S<8, 1, 1, 1>, S<1, 1, 128, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 1>, S<1, 2, 0, 3>, S<1, 1, 1, 1>, S<0, 1, 2, 3, 4, 5>, 5, 4>; -// DeviceGemmDl< F32, F32, F32, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 128, 128, 128, 8, 1, 4, 8, 1, S<8, 2>, S<4, 2>, S<8, 1, 1, 1>, S<1, 1, 128, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 1>, S<1, 2, 0, 3>, S<1, 1, 1, 
1>, S<8, 1, 1, 1>, S<1, 1, 128, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 1>, S<1, 2, 0, 3>, S<1, 1, 1, 1>, S<0, 1, 2, 3, 4, 5>, 5, 4>; -// DeviceGemmDl< F32, F32, F32, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 64, 128, 128, 8, 1, 8, 8, 1, S<4, 2>, S<4, 2>, S<4, 1, 4, 1>, S<2, 1, 32, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 1>, S<1, 2, 0, 3>, S<1, 1, 1, 1>, S<4, 1, 4, 1>, S<2, 1, 32, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 1>, S<1, 2, 0, 3>, S<1, 1, 1, 1>, S<0, 1, 2, 3, 4, 5>, 5, 4>; -// DeviceGemmDl< F32, F32, F32, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 64, 128, 64, 8, 1, 8, 4, 1, S<4, 2>, S<4, 2>, S<4, 1, 4, 1>, S<2, 1, 32, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 1>, S<1, 2, 0, 3>, S<1, 1, 1, 1>, S<4, 1, 2, 1>, S<2, 1, 32, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 1>, S<1, 2, 0, 3>, S<1, 1, 1, 1>, S<0, 1, 2, 3, 4, 5>, 5, 4>; -// DeviceGemmDl< F32, F32, F32, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 64, 64, 128, 8, 1, 4, 8, 1, S<4, 2>, S<4, 2>, S<4, 1, 2, 1>, S<2, 1, 32, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 1>, S<1, 2, 0, 3>, S<1, 1, 1, 1>, S<4, 1, 4, 1>, S<2, 1, 32, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 1>, S<1, 2, 0, 3>, S<1, 1, 1, 1>, S<0, 1, 2, 3, 4, 5>, 5, 4>; + DeviceGemmDl< F32, F32, F32, F32, Col, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 128, 128, 16, 1, 4, 4, 1, S<8, 2>, S<8, 2>, S<2, 1, 4, 1>, S<8, 1, 32, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<2, 1, 4, 1>, S<8, 1, 32, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 1, 2, 3, 4, 5>, 5, 4>; // clang-format on using ReferenceGemmInstance = ck::tensor_operation::host:: diff --git a/example/01_gemm/gemm_dl_int8.cpp b/example/01_gemm/gemm_dl_int8.cpp index 9722d2606b..932728fcea 100644 --- a/example/01_gemm/gemm_dl_int8.cpp +++ b/example/01_gemm/gemm_dl_int8.cpp @@ -29,9 +29,9 @@ 
using BDataType = int8_t; using CDataType = int8_t; using AccDataType = int32_t; -using ALayout = ck::tensor_layout::gemm::RowMajor; -using BLayout = ck::tensor_layout::gemm::ColumnMajor; -using CLayout = ck::tensor_layout::gemm::RowMajor; +using ALayout = Col; +using BLayout = Row; +using CLayout = Row; using AElementOp = ck::tensor_operation::element_wise::PassThrough; using BElementOp = ck::tensor_operation::element_wise::PassThrough; @@ -45,15 +45,7 @@ using DeviceGemmInstance = ck::tensor_operation::device:: // #########| Type| Type| Type| Type| | | | Elementwise| Elementwise| Elementwise| Spacialization| Size| Block| Block| Block| | ThreadM111| ThreadN111| Thread| ClusterM110Xs| ClusterN110Xs| ThreadSliceLengths| ThreadClusterLengths| ThreadCluster| SrcAccess| SrcVectorTensor| SrcVectorTensor| DstVectorTensor| ThreadSliceLengths| ThreadClusterLengths| ThreadCluster| SrcAccess| SrcVectorTensor| SrcVectorTensor| DstVectorTensor| SrcDstAccess| SrcDstVectorDim| DstScalarPerVector| // #########| | | | | | | | Operation| Operation| Operation| | | | | | | | | | | | K0_M0_M1_K1| K0_M0_M1_K1| ArrangeOrder| Order| Lengths_K0_M0_M1_K1| ContiguousDimOrder| Lengths_K0_M0_M1_K1| K0_N0_N1_K1| K0_N0_N1_K1| ArrangeOrder| Order| Lengths_K0_N0_N1_K1| ContiguousDimOrder| Lengths_K0_N0_N1_K1| Order| | | // #########| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | -// DeviceGemmDl< int8_t, int8_t, int8_t, int32_t, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 256, 128, 8, 4, 8, 4, 1, S<8, 2>, S<8, 2>, S<4, 1, 2, 4>, S<2, 1, 128, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 4>, S<1, 2, 0, 3>, S<1, 1, 1, 4>, S<4, 1, 1, 4>, S<2, 1, 128, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 4>, S<1, 2, 0, 3>, S<1, 1, 1, 4>, S<0, 1, 2, 3, 4, 5>, 5, 4>; -// DeviceGemmDl< int8_t, int8_t, int8_t, int32_t, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 128, 256, 8, 4, 4, 8, 1, S<8, 2>, S<8, 2>, S<4, 1, 1, 4>, S<2, 1, 
128, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 4>, S<1, 2, 0, 3>, S<1, 1, 1, 4>, S<4, 1, 2, 4>, S<2, 1, 128, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 4>, S<1, 2, 0, 3>, S<1, 1, 1, 4>, S<0, 1, 2, 3, 4, 5>, 5, 4>; - DeviceGemmDl< int8_t, int8_t, int8_t, int32_t, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 128, 128, 8, 4, 4, 4, 1, S<8, 2>, S<8, 2>, S<4, 1, 1, 4>, S<2, 1, 128, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 4>, S<1, 2, 0, 3>, S<1, 1, 1, 4>, S<4, 1, 1, 4>, S<2, 1, 128, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 4>, S<1, 2, 0, 3>, S<1, 1, 1, 4>, S<0, 1, 2, 3, 4, 5>, 5, 4>; -// DeviceGemmDl< int8_t, int8_t, int8_t, int32_t, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 128, 128, 128, 8, 4, 8, 4, 1, S<4, 2>, S<8, 2>, S<8, 1, 1, 4>, S<1, 1, 128, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 4>, S<1, 2, 0, 3>, S<1, 1, 1, 4>, S<8, 1, 1, 4>, S<1, 1, 128, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 4>, S<1, 2, 0, 3>, S<1, 1, 1, 4>, S<0, 1, 2, 3, 4, 5>, 5, 4>; -// DeviceGemmDl< int8_t, int8_t, int8_t, int32_t, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 128, 128, 128, 8, 4, 4, 8, 1, S<8, 2>, S<4, 2>, S<8, 1, 1, 4>, S<1, 1, 128, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 4>, S<1, 2, 0, 3>, S<1, 1, 1, 4>, S<8, 1, 1, 4>, S<1, 1, 128, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 4>, S<1, 2, 0, 3>, S<1, 1, 1, 4>, S<0, 1, 2, 3, 4, 5>, 5, 4>; -// DeviceGemmDl< int8_t, int8_t, int8_t, int32_t, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 64, 128, 128, 8, 4, 8, 8, 1, S<4, 2>, S<4, 2>, S<4, 1, 4, 4>, S<2, 1, 32, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 4>, S<1, 2, 0, 3>, S<1, 1, 1, 4>, S<4, 1, 4, 4>, S<2, 1, 32, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 4>, S<1, 2, 0, 3>, S<1, 1, 1, 4>, S<0, 1, 2, 3, 4, 5>, 5, 4>; -// DeviceGemmDl< int8_t, int8_t, int8_t, int32_t, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 64, 128, 64, 8, 4, 8, 4, 1, S<4, 
2>, S<4, 2>, S<4, 1, 4, 4>, S<2, 1, 32, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 4>, S<1, 2, 0, 3>, S<1, 1, 1, 4>, S<4, 1, 2, 4>, S<2, 1, 32, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 4>, S<1, 2, 0, 3>, S<1, 1, 1, 4>, S<0, 1, 2, 3, 4, 5>, 5, 4>; -// DeviceGemmDl< int8_t, int8_t, int8_t, int32_t, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 64, 64, 128, 8, 4, 4, 8, 1, S<4, 2>, S<4, 2>, S<4, 1, 2, 4>, S<2, 1, 32, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 4>, S<1, 2, 0, 3>, S<1, 1, 1, 4>, S<4, 1, 4, 4>, S<2, 1, 32, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 4>, S<1, 2, 0, 3>, S<1, 1, 1, 4>, S<0, 1, 2, 3, 4, 5>, 5, 4>; -// DeviceGemmDl< int8_t, int8_t, int8_t, int32_t, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 64, 64, 64, 8, 4, 4, 4, 1, S<4, 2>, S<4, 2>, S<4, 1, 2, 4>, S<2, 1, 32, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 4>, S<1, 2, 0, 3>, S<1, 1, 1, 4>, S<4, 1, 2, 4>, S<2, 1, 32, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 4>, S<1, 2, 0, 3>, S<1, 1, 1, 4>, S<0, 1, 2, 3, 4, 5>, 5, 4>; + DeviceGemmDl< int8_t, int8_t, int8_t, int32_t, Col, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 128, 128, 16, 4, 4, 4, 1, S<8, 2>, S<8, 2>, S<2, 1, 4, 4>, S<8, 1, 32, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 4>, S<2, 1, 4, 4>, S<8, 1, 32, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 4>, S<0, 1, 2, 3, 4, 5>, 5, 4>; // clang-format on using ReferenceGemmInstance = ck::tensor_operation::host:: diff --git a/example/CMakeLists.txt b/example/CMakeLists.txt index 3e276b92ea..d3e5b95831 100644 --- a/example/CMakeLists.txt +++ b/example/CMakeLists.txt @@ -34,7 +34,7 @@ function(add_example_executable_no_testing EXAMPLE_NAME FILE_NAME) add_executable(${EXAMPLE_NAME} ${FILE_NAME}) target_link_libraries(${EXAMPLE_NAME} PRIVATE host_tensor) add_dependencies(examples ${EXAMPLE_NAME}) -endfunction(add_example_executable EXAMPLE_NAME) 
+endfunction(add_example_executable_no_testing EXAMPLE_NAME) add_subdirectory(01_gemm) add_subdirectory(02_gemm_alpha_beta) diff --git a/include/ck/tensor_operation/gpu/device/device_gemm_dl.hpp b/include/ck/tensor_operation/gpu/device/device_gemm_dl.hpp index 17c6c02629..a6a059df77 100644 --- a/include/ck/tensor_operation/gpu/device/device_gemm_dl.hpp +++ b/include/ck/tensor_operation/gpu/device/device_gemm_dl.hpp @@ -473,16 +473,15 @@ struct DeviceGemmDl static bool IsSupportedArgument(const Argument& arg) { -#if 0 - if(ck::get_device_name() == "gfx1030") -#else - if(true) -#endif + if(ck::get_device_name() == "gfx906" || ck::get_device_name() == "gfx1030") { return GridwiseGemm::CheckValidity( arg.a_grid_desc_k0_m_k1_, arg.b_grid_desc_k0_n_k1_, arg.c_grid_desc_m_n_); } - else { return false; } + else + { + return false; + } } // polymorphic diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_dl_v1r3.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_dl_v1r3.hpp index 0360692abc..3ae5e7a2ea 100644 --- a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_dl_v1r3.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_dl_v1r3.hpp @@ -451,15 +451,7 @@ struct GridwiseGemmDl_km_kn_mn_v1r3 a_blockwise_copy.RunRead(a_grid_desc_k0_m0_m1_k1, a_global_buf); b_blockwise_copy.RunRead(b_grid_desc_k0_n0_n1_k1, b_global_buf); -#if 0 - __syncthreads(); -#else block_sync_lds(); -#endif - - // LDS doubel buffer: load next data from device mem - //a_blockwise_copy.RunRead(a_grid_desc_k0_m0_m1_k1, a_global_buf); - //b_blockwise_copy.RunRead(b_grid_desc_k0_n0_n1_k1, b_global_buf); // LDS double buffer: GEMM on current data blockwise_gemm.Run(c_thread_desc_m10_m11_n10_n11, @@ -481,15 +473,7 @@ struct GridwiseGemmDl_km_kn_mn_v1r3 a_blockwise_copy.RunRead(a_grid_desc_k0_m0_m1_k1, a_global_buf); b_blockwise_copy.RunRead(b_grid_desc_k0_n0_n1_k1, b_global_buf); -#if 0 - __syncthreads(); -#else block_sync_lds(); -#endif - - // LDS doubel buffer: load next data from 
device mem - //a_blockwise_copy.RunRead(a_grid_desc_k0_m0_m1_k1, a_global_buf); - //b_blockwise_copy.RunRead(b_grid_desc_k0_n0_n1_k1, b_global_buf); // LDS double buffer: GEMM on current data blockwise_gemm.Run( @@ -509,11 +493,7 @@ struct GridwiseGemmDl_km_kn_mn_v1r3 a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc_k0_m0_m1_k1, a_block_slice_copy_step); b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc_k0_n0_n1_k1, b_block_slice_copy_step); -#if 0 - __syncthreads(); -#else block_sync_lds(); -#endif // LDS double buffer: load last data from device mem a_blockwise_copy.RunRead(a_grid_desc_k0_m0_m1_k1, a_global_buf); @@ -527,11 +507,7 @@ struct GridwiseGemmDl_km_kn_mn_v1r3 a_blockwise_copy.RunWrite(a_block_desc_k0_m0_m1_k1, a_block_odd_buf); b_blockwise_copy.RunWrite(b_block_desc_k0_n0_n1_k1, b_block_odd_buf); -#if 0 - __syncthreads(); -#else block_sync_lds(); -#endif // LDS double buffer: GEMM on last data blockwise_gemm.Run( @@ -539,11 +515,7 @@ struct GridwiseGemmDl_km_kn_mn_v1r3 } else // if has 1 iteration left { -#if 0 - __syncthreads(); -#else - block_sync_lds(); -#endif + __syncthreads(); // LDS double buffer: GEMM on last data blockwise_gemm.Run( diff --git a/library/src/tensor_operation_instance/gpu/gemm/CMakeLists.txt b/library/src/tensor_operation_instance/gpu/gemm/CMakeLists.txt index b989873dc3..da769a5626 100644 --- a/library/src/tensor_operation_instance/gpu/gemm/CMakeLists.txt +++ b/library/src/tensor_operation_instance/gpu/gemm/CMakeLists.txt @@ -32,18 +32,18 @@ set(DEVICE_GEMM_INSTANCE_SOURCE device_gemm_xdl_splitk_f16_f16_f16_mk_nk_mn_instance.cpp; device_gemm_xdl_splitk_f16_f16_f16_km_kn_mn_instance.cpp; device_gemm_xdl_splitk_f16_f16_f16_km_nk_mn_instance.cpp; -# device_gemm_dl_f32_f32_f32_mk_kn_mn_instance.cpp; - device_gemm_dl_f32_f32_f32_mk_nk_mn_instance.cpp; -# device_gemm_dl_f32_f32_f32_km_kn_mn_instance.cpp; -# device_gemm_dl_f32_f32_f32_km_nk_mn_instance.cpp; -# device_gemm_dl_f16_f16_f16_mk_kn_mn_instance.cpp; - 
device_gemm_dl_f16_f16_f16_mk_nk_mn_instance.cpp; -# device_gemm_dl_f16_f16_f16_km_kn_mn_instance.cpp; -# device_gemm_dl_f16_f16_f16_km_nk_mn_instance.cpp; -# device_gemm_dl_i8_i8_i8_mk_kn_mn_instance.cpp; -# device_gemm_dl_i8_i8_i8_mk_nk_mn_instance.cpp; -# device_gemm_dl_i8_i8_i8_km_kn_mn_instance.cpp; -# device_gemm_dl_i8_i8_i8_km_nk_mn_instance.cpp; + device_gemm_dl_f32_f32_f32_mk_kn_mn_instance.cpp; + device_gemm_dl_f32_f32_f32_mk_nk_mn_instance.cpp; + device_gemm_dl_f32_f32_f32_km_kn_mn_instance.cpp; + device_gemm_dl_f32_f32_f32_km_nk_mn_instance.cpp; + device_gemm_dl_f16_f16_f16_mk_kn_mn_instance.cpp; + device_gemm_dl_f16_f16_f16_mk_nk_mn_instance.cpp; + device_gemm_dl_f16_f16_f16_km_kn_mn_instance.cpp; + device_gemm_dl_f16_f16_f16_km_nk_mn_instance.cpp; + device_gemm_dl_i8_i8_i8_mk_kn_mn_instance.cpp; + device_gemm_dl_i8_i8_i8_mk_nk_mn_instance.cpp; + device_gemm_dl_i8_i8_i8_km_kn_mn_instance.cpp; + device_gemm_dl_i8_i8_i8_km_nk_mn_instance.cpp; ) add_library(device_gemm_instance OBJECT ${DEVICE_GEMM_INSTANCE_SOURCE}) diff --git a/library/src/tensor_operation_instance/gpu/gemm/device_gemm_dl_f16_f16_f16_km_kn_mn_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm/device_gemm_dl_f16_f16_f16_km_kn_mn_instance.cpp index 2b9a1c140a..db7f6af04b 100644 --- a/library/src/tensor_operation_instance/gpu/gemm/device_gemm_dl_f16_f16_f16_km_kn_mn_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/gemm/device_gemm_dl_f16_f16_f16_km_kn_mn_instance.cpp @@ -24,20 +24,12 @@ static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecializa // Compilation parameters for a[k, m] * b[k, n] = c[m, n] using device_gemm_dl_f16_f16_f16_km_kn_mn_instances = std::tuple< - // clang-format off - // ##########| AData| BData| CData| AccData| ALayout| BLayout| CLayout| A| B| C| GEMM| Block| MPer| NPer| K0Per| K1| M1Per| N1Per| KPer| M11N11Thread| M11N11Thread| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| 
ABlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| CThreadTransfer| CThreadTransfer| CThreadTransfer| - // ##########| Type| Type| Type| Type| | | | Elementwise| Elementwise| Elementwise| Spacialization| Size| Block| Block| Block| | ThreadM111| ThreadN111| Thread| ClusterM110Xs| ClusterN110Xs| ThreadSliceLengths| ThreadClusterLengths| ThreadCluster| SrcAccess| SrcVectorTensor| SrcVectorTensor| DstVectorTensor| ThreadSliceLengths| ThreadClusterLengths| ThreadCluster| SrcAccess| SrcVectorTensor| SrcVectorTensor| DstVectorTensor| SrcDstAccess| SrcDstVectorDim| DstScalarPerVector| - // ##########| | | | | | | | Operation| Operation| Operation| | | | | | | | | | | | K0_M0_M1_K1| K0_M0_M1_K1| ArrangeOrder| Order| Lengths_K0_M0_M1_K1| ContiguousDimOrder| Lengths_K0_M0_M1_K1| K0_M0_M1_K1| K0_M0_M1_K1| ArrangeOrder| Order| Lengths_K0_M0_M1_K1| ContiguousDimOrder| Lengths_K0_M0_M1_K1| Order| | | - // ##########| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | - DeviceGemmDl< F16, F16, F16, F32, Col, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 256, 128, 8, 2, 8, 4, 1, S<8, 2>, S<8, 2>, S<4, 1, 2, 2>, S<2, 1, 128, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<4, 1, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 1, 2>, S<4, 1, 1, 2>, S<2, 1, 128, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<4, 1, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 1, 2>, S<0, 1, 2, 3, 4, 5>, 5, 4>, - DeviceGemmDl< F16, F16, F16, F32, Col, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 128, 256, 8, 2, 4, 8, 1, S<8, 2>, S<8, 2>, S<4, 1, 1, 2>, S<2, 1, 128, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<4, 1, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 1, 2>, S<4, 1, 2, 2>, S<2, 1, 128, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<4, 1, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 1, 2>, S<0, 1, 2, 3, 4, 5>, 5, 4>, - DeviceGemmDl< F16, F16, F16, F32, Col, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 128, 128, 8, 2, 4, 4, 1, S<8, 2>, 
S<8, 2>, S<4, 1, 1, 2>, S<2, 1, 128, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<4, 1, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 1, 2>, S<4, 1, 1, 2>, S<2, 1, 128, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<4, 1, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 1, 2>, S<0, 1, 2, 3, 4, 5>, 5, 4>, - DeviceGemmDl< F16, F16, F16, F32, Col, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 128, 128, 128, 8, 2, 8, 4, 1, S<4, 2>, S<8, 2>, S<8, 1, 1, 2>, S<1, 1, 128, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<4, 1, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 1, 2>, S<8, 1, 1, 2>, S<1, 1, 128, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<4, 1, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 1, 2>, S<0, 1, 2, 3, 4, 5>, 5, 4>, - DeviceGemmDl< F16, F16, F16, F32, Col, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 128, 128, 128, 8, 2, 4, 8, 1, S<8, 2>, S<4, 2>, S<8, 1, 1, 2>, S<1, 1, 128, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<4, 1, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 1, 2>, S<8, 1, 1, 2>, S<1, 1, 128, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<4, 1, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 1, 2>, S<0, 1, 2, 3, 4, 5>, 5, 4>, - DeviceGemmDl< F16, F16, F16, F32, Col, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 64, 128, 128, 8, 2, 8, 8, 1, S<4, 2>, S<4, 2>, S<4, 1, 4, 2>, S<2, 1, 32, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<4, 1, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 1, 2>, S<4, 1, 4, 2>, S<2, 1, 32, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<4, 1, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 1, 2>, S<0, 1, 2, 3, 4, 5>, 5, 4>, - DeviceGemmDl< F16, F16, F16, F32, Col, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 64, 128, 64, 8, 2, 8, 4, 1, S<4, 2>, S<4, 2>, S<4, 1, 4, 2>, S<2, 1, 32, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<4, 1, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 1, 2>, S<4, 1, 2, 2>, S<2, 1, 32, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<4, 1, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 1, 2>, S<0, 1, 2, 3, 4, 5>, 5, 4>, - DeviceGemmDl< F16, F16, F16, F32, Col, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 64, 64, 128, 8, 2, 4, 8, 1, S<4, 2>, S<4, 2>, S<4, 1, 2, 2>, S<2, 1, 32, 1>, 
S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<4, 1, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 1, 2>, S<4, 1, 4, 2>, S<2, 1, 32, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<4, 1, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 1, 2>, S<0, 1, 2, 3, 4, 5>, 5, 4> + // #########| AData| BData| CData| AccData| ALayout| BLayout| CLayout| A| B| C| GEMM| Block| MPer| NPer| K0Per| K1| M1Per| N1Per| KPer| M11N11Thread| M11N11Thread| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| CThreadTransfer| CThreadTransfer| CThreadTransfer| + // #########| Type| Type| Type| Type| | | | Elementwise| Elementwise| Elementwise| Spacialization| Size| Block| Block| Block| | ThreadM111| ThreadN111| Thread| ClusterM110Xs| ClusterN110Xs| ThreadSliceLengths| ThreadClusterLengths| ThreadCluster| SrcAccess| SrcVectorTensor| SrcVectorTensor| DstVectorTensor| ThreadSliceLengths| ThreadClusterLengths| ThreadCluster| SrcAccess| SrcVectorTensor| SrcVectorTensor| DstVectorTensor| SrcDstAccess| SrcDstVectorDim| DstScalarPerVector| + // #########| | | | | | | | Operation| Operation| Operation| | | | | | | | | | | | K0_M0_M1_K1| K0_M0_M1_K1| ArrangeOrder| Order| Lengths_K0_M0_M1_K1| ContiguousDimOrder| Lengths_K0_M0_M1_K1| K0_M0_M1_K1| K0_M0_M1_K1| ArrangeOrder| Order| Lengths_K0_M0_M1_K1| ContiguousDimOrder| Lengths_K0_M0_M1_K1| Order| | | + // #########| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceGemmDl< F16, F16, F16, F32, Col, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 128, 128, 16, 2, 4, 4, 1, S<8, 2>, S<8, 2>, S<2, 1, 4, 2>, S<8, 1, 32, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 2>, S<2, 1, 4, 2>, S<8, 1, 32, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 2>, S<0, 1, 2, 3, 4, 5>, 5, 4> // clang-format on >; diff --git 
a/library/src/tensor_operation_instance/gpu/gemm/device_gemm_dl_f16_f16_f16_km_nk_mn_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm/device_gemm_dl_f16_f16_f16_km_nk_mn_instance.cpp index a82dfa90b2..c4253bcc4c 100644 --- a/library/src/tensor_operation_instance/gpu/gemm/device_gemm_dl_f16_f16_f16_km_nk_mn_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/gemm/device_gemm_dl_f16_f16_f16_km_nk_mn_instance.cpp @@ -23,23 +23,15 @@ using PassThrough = ck::tensor_operation::element_wise::PassThrough; static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default; // Compilation parameters for a[k, m] * b[n, k] = c[m, n] -using device_gemm_dl_f16_f16_f16_km_nk_mn_instances = - std::tuple< - // clang-format off - // ########| AData| BData| CData| AccData| ALayout| BLayout| CLayout| A| B| C| GEMM| Block| MPer| NPer| K0Per| K1| M1Per| N1Per| KPer| M11N11Thread| M11N11Thread| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| CThreadTransfer| CThreadTransfer| CThreadTransfer| - // ########| Type| Type| Type| Type| | | | Elementwise| Elementwise| Elementwise| Spacialization| Size| Block| Block| Block| | ThreadM111| ThreadN111| Thread| ClusterM110Xs| ClusterN110Xs| ThreadSliceLengths| ThreadClusterLengths| ThreadCluster| SrcAccess| SrcVectorTensor| SrcVectorTensor| DstVectorTensor| ThreadSliceLengths| ThreadClusterLengths| ThreadCluster| SrcAccess| SrcVectorTensor| SrcVectorTensor| DstVectorTensor| SrcDstAccess| SrcDstVectorDim| DstScalarPerVector| - // ########| | | | | | | | Operation| Operation| Operation| | | | | | | | | | | | K0_M0_M1_K1| K0_M0_M1_K1| ArrangeOrder| Order| Lengths_K0_M0_M1_K1| ContiguousDimOrder| Lengths_K0_M0_M1_K1| K0_N0_N1_K1| K0_N0_N1_K1| ArrangeOrder| Order| Lengths_K0_N0_N1_K1| ContiguousDimOrder| Lengths_K0_N0_N1_K1| Order| | 
| - // ########| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | - DeviceGemmDl< F16, F16, F16, F32, Col, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 256, 128, 8, 2, 8, 4, 1, S<8, 2>, S<8, 2>, S<4, 1, 2, 2>, S<2, 1, 128, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<4, 1, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 1, 2>, S<4, 1, 1, 2>, S<2, 1, 128, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<0, 1, 2, 3, 4, 5>, 5, 4>, - DeviceGemmDl< F16, F16, F16, F32, Col, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 128, 256, 8, 2, 4, 8, 1, S<8, 2>, S<8, 2>, S<4, 1, 1, 2>, S<2, 1, 128, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<4, 1, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 1, 2>, S<4, 1, 2, 2>, S<2, 1, 128, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<0, 1, 2, 3, 4, 5>, 5, 4>, - DeviceGemmDl< F16, F16, F16, F32, Col, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 128, 128, 8, 2, 4, 4, 1, S<8, 2>, S<8, 2>, S<4, 1, 1, 2>, S<2, 1, 128, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<4, 1, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 1, 2>, S<4, 1, 1, 2>, S<2, 1, 128, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<0, 1, 2, 3, 4, 5>, 5, 4>, - DeviceGemmDl< F16, F16, F16, F32, Col, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 128, 128, 128, 8, 2, 8, 4, 1, S<4, 2>, S<8, 2>, S<8, 1, 1, 2>, S<1, 1, 128, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<4, 1, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 1, 2>, S<8, 1, 1, 2>, S<1, 1, 128, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<0, 1, 2, 3, 4, 5>, 5, 4>, - DeviceGemmDl< F16, F16, F16, F32, Col, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 128, 128, 128, 8, 2, 4, 8, 1, S<8, 2>, S<4, 2>, S<8, 1, 1, 2>, S<1, 1, 128, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<4, 1, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 1, 2>, S<8, 1, 1, 2>, S<1, 1, 128, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, 
S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<0, 1, 2, 3, 4, 5>, 5, 4>, - DeviceGemmDl< F16, F16, F16, F32, Col, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 64, 128, 128, 8, 2, 8, 8, 1, S<4, 2>, S<4, 2>, S<4, 1, 4, 2>, S<2, 1, 32, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<4, 1, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 1, 2>, S<4, 1, 4, 2>, S<2, 1, 32, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<0, 1, 2, 3, 4, 5>, 5, 4>, - DeviceGemmDl< F16, F16, F16, F32, Col, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 64, 128, 64, 8, 2, 8, 4, 1, S<4, 2>, S<4, 2>, S<4, 1, 4, 2>, S<2, 1, 32, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<4, 1, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 1, 2>, S<4, 1, 2, 2>, S<2, 1, 32, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<0, 1, 2, 3, 4, 5>, 5, 4>, - DeviceGemmDl< F16, F16, F16, F32, Col, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 64, 64, 128, 8, 2, 4, 8, 1, S<4, 2>, S<4, 2>, S<4, 1, 2, 2>, S<2, 1, 32, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<4, 1, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 1, 2>, S<4, 1, 4, 2>, S<2, 1, 32, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<0, 1, 2, 3, 4, 5>, 5, 4> - // clang-format on - >; +using device_gemm_dl_f16_f16_f16_km_nk_mn_instances = std::tuple< + // clang-format off + // #########| AData| BData| CData| AccData| ALayout| BLayout| CLayout| A| B| C| GEMM| Block| MPer| NPer| K0Per| K1| M1Per| N1Per| KPer| M11N11Thread| M11N11Thread| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| CThreadTransfer| CThreadTransfer| CThreadTransfer| + // #########| Type| Type| Type| Type| | | | Elementwise| Elementwise| Elementwise| Spacialization| Size| Block| Block| Block| | ThreadM111| ThreadN111| Thread| ClusterM110Xs| ClusterN110Xs| ThreadSliceLengths| 
ThreadClusterLengths| ThreadCluster| SrcAccess| SrcVectorTensor| SrcVectorTensor| DstVectorTensor| ThreadSliceLengths| ThreadClusterLengths| ThreadCluster| SrcAccess| SrcVectorTensor| SrcVectorTensor| DstVectorTensor| SrcDstAccess| SrcDstVectorDim| DstScalarPerVector| + // #########| | | | | | | | Operation| Operation| Operation| | | | | | | | | | | | K0_M0_M1_K1| K0_M0_M1_K1| ArrangeOrder| Order| Lengths_K0_M0_M1_K1| ContiguousDimOrder| Lengths_K0_M0_M1_K1| K0_N0_N1_K1| K0_N0_N1_K1| ArrangeOrder| Order| Lengths_K0_N0_N1_K1| ContiguousDimOrder| Lengths_K0_N0_N1_K1| Order| | | + // #########| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceGemmDl< F16, F16, F16, F32, Col, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 128, 128, 16, 2, 4, 4, 1, S<8, 2>, S<8, 2>, S<2, 1, 4, 2>, S<8, 1, 32, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 2>, S<8, 1, 1, 2>, S<2, 1, 128, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<0, 1, 2, 3, 4, 5>, 5, 4> + // clang-format on + >; void add_device_gemm_dl_f16_f16_f16_km_nk_mn_instances( std::vector>& instances) diff --git a/library/src/tensor_operation_instance/gpu/gemm/device_gemm_dl_f16_f16_f16_mk_kn_mn_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm/device_gemm_dl_f16_f16_f16_mk_kn_mn_instance.cpp index 37d51e9410..d19d11f1f8 100644 --- a/library/src/tensor_operation_instance/gpu/gemm/device_gemm_dl_f16_f16_f16_mk_kn_mn_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/gemm/device_gemm_dl_f16_f16_f16_mk_kn_mn_instance.cpp @@ -23,23 +23,15 @@ using PassThrough = ck::tensor_operation::element_wise::PassThrough; static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default; // Compilation parameters for a[m, k] * b[k, n] = c[m, n] -using device_gemm_dl_f16_f16_f16_mk_kn_mn_instances = - std::tuple< - // clang-format off - // ########| AData| BData| CData| AccData| 
ALayout| BLayout| CLayout| A| B| C| GEMM| Block| MPer| NPer| K0Per| K1| M1Per| N1Per| KPer| M11N11Thread| M11N11Thread| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| CThreadTransfer| CThreadTransfer| CThreadTransfer| - // ########| Type| Type| Type| Type| | | | Elementwise| Elementwise| Elementwise| Spacialization| Size| Block| Block| Block| | ThreadM111| ThreadN111| Thread| ClusterM110Xs| ClusterN110Xs| ThreadSliceLengths| ThreadClusterLengths| ThreadCluster| SrcAccess| SrcVectorTensor| SrcVectorTensor| DstVectorTensor| ThreadSliceLengths| ThreadClusterLengths| ThreadCluster| SrcAccess| SrcVectorTensor| SrcVectorTensor| DstVectorTensor| SrcDstAccess| SrcDstVectorDim| DstScalarPerVector| - // ########| | | | | | | | Operation| Operation| Operation| | | | | | | | | | | | K0_M0_M1_K1| K0_M0_M1_K1| ArrangeOrder| Order| Lengths_K0_M0_M1_K1| ContiguousDimOrder| Lengths_K0_M0_M1_K1| K0_N0_N1_K1| K0_N0_N1_K1| ArrangeOrder| Order| Lengths_K0_N0_N1_K1| ContiguousDimOrder| Lengths_K0_N0_N1_K1| Order| | | - // ########| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | - DeviceGemmDl< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 256, 128, 8, 2, 8, 4, 1, S<8, 2>, S<8, 2>, S<4, 1, 2, 2>, S<2, 1, 128, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<4, 1, 1, 2>, S<2, 1, 128, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<4, 1, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 1, 2>, S<0, 1, 2, 3, 4, 5>, 5, 4>, - DeviceGemmDl< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 128, 256, 8, 2, 4, 8, 1, S<8, 2>, S<8, 2>, S<4, 1, 1, 2>, S<2, 1, 128, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<4, 1, 2, 2>, S<2, 1, 128, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<4, 1, 
1, 2>, S<0, 3, 1, 2>, S<1, 1, 1, 2>, S<0, 1, 2, 3, 4, 5>, 5, 4>, - DeviceGemmDl< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 128, 128, 8, 2, 4, 4, 1, S<8, 2>, S<8, 2>, S<4, 1, 1, 2>, S<2, 1, 128, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<4, 1, 1, 2>, S<2, 1, 128, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<4, 1, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 1, 2>, S<0, 1, 2, 3, 4, 5>, 5, 4>, - DeviceGemmDl< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 128, 128, 128, 8, 2, 8, 4, 1, S<4, 2>, S<8, 2>, S<8, 1, 1, 2>, S<1, 1, 128, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<8, 1, 1, 2>, S<1, 1, 128, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<4, 1, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 1, 2>, S<0, 1, 2, 3, 4, 5>, 5, 4>, - DeviceGemmDl< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 128, 128, 128, 8, 2, 4, 8, 1, S<8, 2>, S<4, 2>, S<8, 1, 1, 2>, S<1, 1, 128, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<8, 1, 1, 2>, S<1, 1, 128, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<4, 1, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 1, 2>, S<0, 1, 2, 3, 4, 5>, 5, 4>, - DeviceGemmDl< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 64, 128, 128, 8, 2, 8, 8, 1, S<4, 2>, S<4, 2>, S<4, 1, 4, 2>, S<2, 1, 32, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<4, 1, 4, 2>, S<2, 1, 32, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<4, 1, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 1, 2>, S<0, 1, 2, 3, 4, 5>, 5, 4>, - DeviceGemmDl< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 64, 128, 64, 8, 2, 8, 4, 1, S<4, 2>, S<4, 2>, S<4, 1, 4, 2>, S<2, 1, 32, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<4, 1, 2, 2>, S<2, 1, 32, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<4, 1, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 1, 2>, S<0, 
1, 2, 3, 4, 5>, 5, 4>, - DeviceGemmDl< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 64, 64, 128, 8, 2, 4, 8, 1, S<4, 2>, S<4, 2>, S<4, 1, 2, 2>, S<2, 1, 32, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<4, 1, 4, 2>, S<2, 1, 32, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<4, 1, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 1, 2>, S<0, 1, 2, 3, 4, 5>, 5, 4> - // clang-format on - >; +using device_gemm_dl_f16_f16_f16_mk_kn_mn_instances = std::tuple< + // clang-format off + // #########| AData| BData| CData| AccData| ALayout| BLayout| CLayout| A| B| C| GEMM| Block| MPer| NPer| K0Per| K1| M1Per| N1Per| KPer| M11N11Thread| M11N11Thread| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| CThreadTransfer| CThreadTransfer| CThreadTransfer| + // #########| Type| Type| Type| Type| | | | Elementwise| Elementwise| Elementwise| Spacialization| Size| Block| Block| Block| | ThreadM111| ThreadN111| Thread| ClusterM110Xs| ClusterN110Xs| ThreadSliceLengths| ThreadClusterLengths| ThreadCluster| SrcAccess| SrcVectorTensor| SrcVectorTensor| DstVectorTensor| ThreadSliceLengths| ThreadClusterLengths| ThreadCluster| SrcAccess| SrcVectorTensor| SrcVectorTensor| DstVectorTensor| SrcDstAccess| SrcDstVectorDim| DstScalarPerVector| + // #########| | | | | | | | Operation| Operation| Operation| | | | | | | | | | | | K0_M0_M1_K1| K0_M0_M1_K1| ArrangeOrder| Order| Lengths_K0_M0_M1_K1| ContiguousDimOrder| Lengths_K0_M0_M1_K1| K0_N0_N1_K1| K0_N0_N1_K1| ArrangeOrder| Order| Lengths_K0_N0_N1_K1| ContiguousDimOrder| Lengths_K0_N0_N1_K1| Order| | | + // #########| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceGemmDl< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 128, 128, 16, 2, 4, 4, 1, S<8, 2>, 
S<8, 2>, S<8, 1, 1, 2>, S<2, 1, 128, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<2, 1, 4, 2>, S<8, 1, 32, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 2>, S<0, 1, 2, 3, 4, 5>, 5, 4> + // clang-format on + >; void add_device_gemm_dl_f16_f16_f16_mk_kn_mn_instances( std::vector>& instances) diff --git a/library/src/tensor_operation_instance/gpu/gemm/device_gemm_dl_f16_f16_f16_mk_nk_mn_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm/device_gemm_dl_f16_f16_f16_mk_nk_mn_instance.cpp index d51191a3d6..cd86e5ceae 100644 --- a/library/src/tensor_operation_instance/gpu/gemm/device_gemm_dl_f16_f16_f16_mk_nk_mn_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/gemm/device_gemm_dl_f16_f16_f16_mk_nk_mn_instance.cpp @@ -30,14 +30,7 @@ using device_gemm_dl_f16_f16_f16_mk_nk_mn_instances = // ########| Type| Type| Type| Type| | | | Elementwise| Elementwise| Elementwise| Spacialization| Size| Block| Block| Block| | ThreadM111| ThreadN111| Thread| ClusterM110Xs| ClusterN110Xs| ThreadSliceLengths| ThreadClusterLengths| ThreadCluster| SrcAccess| SrcVectorTensor| SrcVectorTensor| DstVectorTensor| ThreadSliceLengths| ThreadClusterLengths| ThreadCluster| SrcAccess| SrcVectorTensor| SrcVectorTensor| DstVectorTensor| SrcDstAccess| SrcDstVectorDim| DstScalarPerVector| // ########| | | | | | | | Operation| Operation| Operation| | | | | | | | | | | | K0_M0_M1_K1| K0_M0_M1_K1| ArrangeOrder| Order| Lengths_K0_M0_M1_K1| ContiguousDimOrder| Lengths_K0_M0_M1_K1| K0_N0_N1_K1| K0_N0_N1_K1| ArrangeOrder| Order| Lengths_K0_N0_N1_K1| ContiguousDimOrder| Lengths_K0_N0_N1_K1| Order| | | // ########| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | - DeviceGemmDl< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 256, 128, 8, 2, 8, 4, 1, S<8, 2>, S<8, 2>, S<4, 1, 2, 2>, S<2, 1, 128, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, 
S<1, 1, 1, 2>, S<4, 1, 1, 2>, S<2, 1, 128, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<0, 1, 2, 3, 4, 5>, 5, 4>, - DeviceGemmDl< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 128, 256, 8, 2, 4, 8, 1, S<8, 2>, S<8, 2>, S<4, 1, 1, 2>, S<2, 1, 128, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<4, 1, 2, 2>, S<2, 1, 128, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<0, 1, 2, 3, 4, 5>, 5, 4>, - DeviceGemmDl< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 128, 128, 8, 2, 4, 4, 1, S<8, 2>, S<8, 2>, S<4, 1, 1, 2>, S<2, 1, 128, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<4, 1, 1, 2>, S<2, 1, 128, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<0, 1, 2, 3, 4, 5>, 5, 4>, - DeviceGemmDl< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 128, 128, 128, 8, 2, 8, 4, 1, S<4, 2>, S<8, 2>, S<8, 1, 1, 2>, S<1, 1, 128, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<8, 1, 1, 2>, S<1, 1, 128, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<0, 1, 2, 3, 4, 5>, 5, 4>, - DeviceGemmDl< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 128, 128, 128, 8, 2, 4, 8, 1, S<8, 2>, S<4, 2>, S<8, 1, 1, 2>, S<1, 1, 128, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<8, 1, 1, 2>, S<1, 1, 128, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<0, 1, 2, 3, 4, 5>, 5, 4>, - DeviceGemmDl< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 64, 128, 128, 8, 2, 8, 8, 1, S<4, 2>, S<4, 2>, S<4, 1, 4, 2>, S<2, 1, 32, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<4, 1, 4, 2>, S<2, 
1, 32, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<0, 1, 2, 3, 4, 5>, 5, 4>, - DeviceGemmDl< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 64, 128, 64, 8, 2, 8, 4, 1, S<4, 2>, S<4, 2>, S<4, 1, 4, 2>, S<2, 1, 32, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<4, 1, 2, 2>, S<2, 1, 32, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<0, 1, 2, 3, 4, 5>, 5, 4>, - DeviceGemmDl< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 64, 64, 128, 8, 2, 4, 8, 1, S<4, 2>, S<4, 2>, S<4, 1, 2, 2>, S<2, 1, 32, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<4, 1, 4, 2>, S<2, 1, 32, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<0, 1, 2, 3, 4, 5>, 5, 4> + DeviceGemmDl< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 128, 128, 16, 2, 4, 4, 1, S<8, 2>, S<8, 2>, S<8, 1, 1, 2>, S<2, 1, 128, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<8, 1, 1, 2>, S<2, 1, 128, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<0, 1, 2, 3, 4, 5>, 5, 4> // clang-format on >; diff --git a/library/src/tensor_operation_instance/gpu/gemm/device_gemm_dl_f32_f32_f32_km_kn_mn_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm/device_gemm_dl_f32_f32_f32_km_kn_mn_instance.cpp index 1dd3829676..3fcc5fdfdc 100644 --- a/library/src/tensor_operation_instance/gpu/gemm/device_gemm_dl_f32_f32_f32_km_kn_mn_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/gemm/device_gemm_dl_f32_f32_f32_km_kn_mn_instance.cpp @@ -23,22 +23,15 @@ using PassThrough = ck::tensor_operation::element_wise::PassThrough; static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default; // Compilation parameters for a[k, m] * b[k, n] = c[m, n] 
-using device_gemm_dl_f32_f32_f32_km_kn_mn_instances = - std::tuple< - // clang-format off - // #######| AData| BData| CData| AccData| ALayout| BLayout| CLayout| A| B| C| GEMM| Block| MPer| NPer| K0Per| K1| M1Per| N1Per| KPer| M11N11Thread| M11N11Thread| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| CThreadTransfer| CThreadTransfer| CThreadTransfer| - // #######| Type| Type| Type| Type| | | | Elementwise| Elementwise| Elementwise| Spacialization| Size| Block| Block| Block| | ThreadM111| ThreadN111| Thread| ClusterM110Xs| ClusterN110Xs| ThreadSliceLengths| ThreadClusterLengths| ThreadCluster| SrcAccess| SrcVectorTensor| SrcVectorTensor| DstVectorTensor| ThreadSliceLengths| ThreadClusterLengths| ThreadCluster| SrcAccess| SrcVectorTensor| SrcVectorTensor| DstVectorTensor| SrcDstAccess| SrcDstVectorDim| DstScalarPerVector| - // #######| | | | | | | | Operation| Operation| Operation| | | | | | | | | | | | K0_M0_M1_K1| K0_M0_M1_K1| ArrangeOrder| Order| Lengths_K0_M0_M1_K1| ContiguousDimOrder| Lengths_K0_M0_M1_K1| K0_N0_N1_K1| K0_N0_N1_K1| ArrangeOrder| Order| Lengths_K0_N0_N1_K1| ContiguousDimOrder| Lengths_K0_N0_N1_K1| Order| | | - // #######| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | - DeviceGemmDl< F32, F32, F32, F32, Col, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 256, 128, 8, 1, 8, 4, 1, S<8, 2>, S<8, 2>, S<2, 1, 4, 1>, S<4, 1, 64, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 1, 1>, S<1, 1, 4, 1>, S<8, 1, 32, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 1, 1>, S<0, 1, 2, 3, 4, 5>, 5, 4>, - DeviceGemmDl< F32, F32, F32, F32, Col, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 128, 256, 8, 1, 4, 8, 1, S<8, 2>, S<8, 2>, S<2, 1, 2, 1>, S<4, 1, 64, 1>, S<0, 3, 1, 2>, 
S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 1, 1>, S<1, 1, 4, 1>, S<8, 1, 64, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 1, 1>, S<0, 1, 2, 3, 4, 5>, 5, 4> -// DeviceGemmDl< F32, F32, F32, F32, Col, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 128, 128, 8, 1, 4, 4, 1, S<8, 2>, S<8, 2>, S<4, 1, 1, 1>, S<2, 1, 128, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 1, 1>, S<4, 1, 1, 1>, S<2, 1, 128, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 1, 1>, S<0, 1, 2, 3, 4, 5>, 5, 4>, -// DeviceGemmDl< F32, F32, F32, F32, Col, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 128, 128, 128, 8, 1, 8, 4, 1, S<4, 2>, S<8, 2>, S<8, 1, 1, 1>, S<1, 1, 128, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 1, 1>, S<8, 1, 1, 1>, S<1, 1, 128, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 1, 1>, S<0, 1, 2, 3, 4, 5>, 5, 4>, -// DeviceGemmDl< F32, F32, F32, F32, Col, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 64, 128, 128, 8, 1, 8, 8, 1, S<4, 2>, S<4, 2>, S<4, 1, 4, 1>, S<2, 1, 32, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 1, 1>, S<4, 1, 4, 1>, S<2, 1, 32, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 1, 1>, S<0, 1, 2, 3, 4, 5>, 5, 4>, -// DeviceGemmDl< F32, F32, F32, F32, Col, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 64, 128, 64, 8, 1, 8, 4, 1, S<4, 2>, S<4, 2>, S<4, 1, 4, 1>, S<2, 1, 32, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 1, 1>, S<4, 1, 2, 1>, S<2, 1, 32, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 1, 1>, S<0, 1, 2, 3, 4, 5>, 5, 4>, -// DeviceGemmDl< F32, F32, F32, F32, Col, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 64, 64, 128, 8, 1, 4, 8, 1, S<4, 2>, S<4, 2>, S<4, 1, 2, 1>, S<2, 1, 32, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 
3, 1, 2>, S<1, 1, 1, 1>, S<4, 1, 4, 1>, S<2, 1, 32, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 1, 1>, S<0, 1, 2, 3, 4, 5>, 5, 4>, - // clang-format on - >; +using device_gemm_dl_f32_f32_f32_km_kn_mn_instances = std::tuple< + // clang-format off + // ########| AData| BData| CData| AccData| ALayout| BLayout| CLayout| A| B| C| GEMM| Block| MPer| NPer| K0Per| K1| M1Per| N1Per| KPer| M11N11Thread| M11N11Thread| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| CThreadTransfer| CThreadTransfer| CThreadTransfer| + // ########| Type| Type| Type| Type| | | | Elementwise| Elementwise| Elementwise| Spacialization| Size| Block| Block| Block| | ThreadM111| ThreadN111| Thread| ClusterM110Xs| ClusterN110Xs| ThreadSliceLengths| ThreadClusterLengths| ThreadCluster| SrcAccess| SrcVectorTensor| SrcVectorTensor| DstVectorTensor| ThreadSliceLengths| ThreadClusterLengths| ThreadCluster| SrcAccess| SrcVectorTensor| SrcVectorTensor| DstVectorTensor| SrcDstAccess| SrcDstVectorDim| DstScalarPerVector| + // ########| | | | | | | | Operation| Operation| Operation| | | | | | | | | | | | K0_M0_M1_K1| K0_M0_M1_K1| ArrangeOrder| Order| Lengths_K0_M0_M1_K1| ContiguousDimOrder| Lengths_K0_M0_M1_K1| K0_N0_N1_K1| K0_N0_N1_K1| ArrangeOrder| Order| Lengths_K0_N0_N1_K1| ContiguousDimOrder| Lengths_K0_N0_N1_K1| Order| | | + // ########| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceGemmDl< F32, F32, F32, F32, Col, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 128, 128, 16, 1, 4, 4, 1, S<8, 2>, S<8, 2>, S<2, 1, 4, 1>, S<8, 1, 32, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<2, 1, 4, 1>, S<8, 1, 32, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 1, 2, 3, 4, 5>, 5, 4> + // 
clang-format on + >; void add_device_gemm_dl_f32_f32_f32_km_kn_mn_instances( std::vector>& instances) diff --git a/library/src/tensor_operation_instance/gpu/gemm/device_gemm_dl_f32_f32_f32_km_nk_mn_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm/device_gemm_dl_f32_f32_f32_km_nk_mn_instance.cpp index 3514650a22..8cd32128b5 100644 --- a/library/src/tensor_operation_instance/gpu/gemm/device_gemm_dl_f32_f32_f32_km_nk_mn_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/gemm/device_gemm_dl_f32_f32_f32_km_nk_mn_instance.cpp @@ -23,22 +23,16 @@ using PassThrough = ck::tensor_operation::element_wise::PassThrough; static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default; // Compilation parameters for a[k, m] * b[n, k] = c[m, n] -using device_gemm_dl_f32_f32_f32_km_nk_mn_instances = std::tuple< - // clang-format off - // ##########| AData| BData| CData| AccData| ALayout| BLayout| CLayout| A| B| C| GEMM| Block| MPer| NPer| K0Per| K1| M1Per| N1Per| KPer| M11N11Thread| M11N11Thread| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| CThreadTransfer| CThreadTransfer| CThreadTransfer| - // ##########| Type| Type| Type| Type| | | | Elementwise| Elementwise| Elementwise| Spacialization| Size| Block| Block| Block| | ThreadM111| ThreadN111| Thread| ClusterM110Xs| ClusterN110Xs| ThreadSliceLengths| ThreadClusterLengths| ThreadCluster| SrcAccess| SrcVectorTensor| SrcVectorTensor| DstVectorTensor| ThreadSliceLengths| ThreadClusterLengths| ThreadCluster| SrcAccess| SrcVectorTensor| SrcVectorTensor| DstVectorTensor| SrcDstAccess| SrcDstVectorDim| DstScalarPerVector| - // ##########| | | | | | | | Operation| Operation| Operation| | | | | | | | | | | | K0_M0_M1_K1| K0_M0_M1_K1| ArrangeOrder| Order| Lengths_K0_M0_M1_K1| ContiguousDimOrder| Lengths_K0_M0_M1_K1| 
K0_N0_N1_K1| K0_N0_N1_K1| ArrangeOrder| Order| Lengths_K0_N0_N1_K1| ContiguousDimOrder| Lengths_K0_N0_N1_K1| Order| | | - // ##########| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | - DeviceGemmDl< F32, F32, F32, F32, Col, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 256, 128, 8, 1, 8, 4, 1, S<8, 2>, S<8, 2>, S<2, 1, 4, 1>, S<4, 1, 64, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 1, 1>, S<4, 1, 1, 1>, S<2, 1, 128, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 1>, S<1, 2, 0, 3>, S<1, 1, 1, 1>, S<0, 1, 2, 3, 4, 5>, 5, 4> -// DeviceGemmDl< F32, F32, F32, F32, Col, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 128, 256, 8, 1, 4, 8, 1, S<8, 2>, S<8, 2>, S<4, 1, 1, 1>, S<2, 1, 128, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<4, 1, 1, 1>, S<0, 3, 1, 2>, S<1, 1, 1, 1>, S<4, 1, 2, 1>, S<2, 1, 128, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 1>, S<1, 2, 0, 3>, S<1, 1, 1, 1>, S<0, 1, 2, 3, 4, 5>, 5, 4>, -// DeviceGemmDl< F32, F32, F32, F32, Col, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 128, 128, 8, 1, 4, 4, 1, S<8, 2>, S<8, 2>, S<4, 1, 1, 1>, S<2, 1, 128, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<4, 1, 1, 1>, S<0, 3, 1, 2>, S<1, 1, 1, 1>, S<4, 1, 1, 1>, S<2, 1, 128, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 1>, S<1, 2, 0, 3>, S<1, 1, 1, 1>, S<0, 1, 2, 3, 4, 5>, 5, 4>, -// DeviceGemmDl< F32, F32, F32, F32, Col, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 128, 128, 128, 8, 1, 8, 4, 1, S<4, 2>, S<8, 2>, S<8, 1, 1, 1>, S<1, 1, 128, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<4, 1, 1, 1>, S<0, 3, 1, 2>, S<1, 1, 1, 1>, S<8, 1, 1, 1>, S<1, 1, 128, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 1>, S<1, 2, 0, 3>, S<1, 1, 1, 1>, S<0, 1, 2, 3, 4, 5>, 5, 4>, -// DeviceGemmDl< F32, F32, F32, F32, Col, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 128, 128, 128, 8, 1, 4, 8, 1, S<8, 2>, S<4, 2>, S<8, 1, 1, 1>, S<1, 1, 128, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 
2>, S<4, 1, 1, 1>, S<0, 3, 1, 2>, S<1, 1, 1, 1>, S<8, 1, 1, 1>, S<1, 1, 128, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 1>, S<1, 2, 0, 3>, S<1, 1, 1, 1>, S<0, 1, 2, 3, 4, 5>, 5, 4>, -// DeviceGemmDl< F32, F32, F32, F32, Col, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 64, 128, 128, 8, 1, 8, 8, 1, S<4, 2>, S<4, 2>, S<4, 1, 4, 1>, S<2, 1, 32, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<4, 1, 1, 1>, S<0, 3, 1, 2>, S<1, 1, 1, 1>, S<4, 1, 4, 1>, S<2, 1, 32, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 1>, S<1, 2, 0, 3>, S<1, 1, 1, 1>, S<0, 1, 2, 3, 4, 5>, 5, 4>, -// DeviceGemmDl< F32, F32, F32, F32, Col, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 64, 128, 64, 8, 1, 8, 4, 1, S<4, 2>, S<4, 2>, S<4, 1, 4, 1>, S<2, 1, 32, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<4, 1, 1, 1>, S<0, 3, 1, 2>, S<1, 1, 1, 1>, S<4, 1, 2, 1>, S<2, 1, 32, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 1>, S<1, 2, 0, 3>, S<1, 1, 1, 1>, S<0, 1, 2, 3, 4, 5>, 5, 4>, -// DeviceGemmDl< F32, F32, F32, F32, Col, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 64, 64, 128, 8, 1, 4, 8, 1, S<4, 2>, S<4, 2>, S<4, 1, 2, 1>, S<2, 1, 32, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<4, 1, 1, 1>, S<0, 3, 1, 2>, S<1, 1, 1, 1>, S<4, 1, 4, 1>, S<2, 1, 32, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 1>, S<1, 2, 0, 3>, S<1, 1, 1, 1>, S<0, 1, 2, 3, 4, 5>, 5, 4>, - // clang-format on - >; +using device_gemm_dl_f32_f32_f32_km_nk_mn_instances = + std::tuple< + // clang-format off + // ########| AData| BData| CData| AccData| ALayout| BLayout| CLayout| A| B| C| GEMM| Block| MPer| NPer| K0Per| K1| M1Per| N1Per| KPer| M11N11Thread| M11N11Thread| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| CThreadTransfer| CThreadTransfer| CThreadTransfer| + // ########| Type| Type| Type| Type| | | | Elementwise| Elementwise| Elementwise| 
Spacialization| Size| Block| Block| Block| | ThreadM111| ThreadN111| Thread| ClusterM110Xs| ClusterN110Xs| ThreadSliceLengths| ThreadClusterLengths| ThreadCluster| SrcAccess| SrcVectorTensor| SrcVectorTensor| DstVectorTensor| ThreadSliceLengths| ThreadClusterLengths| ThreadCluster| SrcAccess| SrcVectorTensor| SrcVectorTensor| DstVectorTensor| SrcDstAccess| SrcDstVectorDim| DstScalarPerVector| + // ########| | | | | | | | Operation| Operation| Operation| | | | | | | | | | | | K0_M0_M1_K1| K0_M0_M1_K1| ArrangeOrder| Order| Lengths_K0_M0_M1_K1| ContiguousDimOrder| Lengths_K0_M0_M1_K1| K0_N0_N1_K1| K0_N0_N1_K1| ArrangeOrder| Order| Lengths_K0_N0_N1_K1| ContiguousDimOrder| Lengths_K0_N0_N1_K1| Order| | | + // ########| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceGemmDl< F32, F32, F32, F32, Col, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 128, 128, 16, 1, 4, 4, 1, S<8, 2>, S<8, 2>, S<2, 1, 4, 1>, S<8, 1, 32, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<8, 1, 1, 1>, S<2, 1, 128, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 1>, S<1, 2, 0, 3>, S<1, 1, 1, 1>, S<0, 1, 2, 3, 4, 5>, 5, 4> + // clang-format on + >; void add_device_gemm_dl_f32_f32_f32_km_nk_mn_instances( std::vector>& instances) diff --git a/library/src/tensor_operation_instance/gpu/gemm/device_gemm_dl_f32_f32_f32_mk_kn_mn_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm/device_gemm_dl_f32_f32_f32_mk_kn_mn_instance.cpp index 26e0f98a54..4c4bfc440d 100644 --- a/library/src/tensor_operation_instance/gpu/gemm/device_gemm_dl_f32_f32_f32_mk_kn_mn_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/gemm/device_gemm_dl_f32_f32_f32_mk_kn_mn_instance.cpp @@ -30,14 +30,7 @@ using device_gemm_dl_f32_f32_f32_mk_kn_mn_instances = // ########| Type| Type| Type| Type| | | | Elementwise| Elementwise| Elementwise| Spacialization| Size| Block| Block| Block| | ThreadM111| ThreadN111| Thread| ClusterM110Xs| 
ClusterN110Xs| ThreadSliceLengths| ThreadClusterLengths| ThreadCluster| SrcAccess| SrcVectorTensor| SrcVectorTensor| DstVectorTensor| ThreadSliceLengths| ThreadClusterLengths| ThreadCluster| SrcAccess| SrcVectorTensor| SrcVectorTensor| DstVectorTensor| SrcDstAccess| SrcDstVectorDim| DstScalarPerVector| // ########| | | | | | | | Operation| Operation| Operation| | | | | | | | | | | | K0_M0_M1_K1| K0_M0_M1_K1| ArrangeOrder| Order| Lengths_K0_M0_M1_K1| ContiguousDimOrder| Lengths_K0_M0_M1_K1| K0_N0_N1_K1| K0_N0_N1_K1| ArrangeOrder| Order| Lengths_K0_N0_N1_K1| ContiguousDimOrder| Lengths_K0_N0_N1_K1| Order| | | // ########| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | - DeviceGemmDl< F32, F32, F32, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 256, 128, 8, 1, 8, 4, 1, S<8, 2>, S<8, 2>, S<4, 1, 2, 1>, S<2, 1, 128, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 1>, S<1, 2, 0, 3>, S<1, 1, 1, 1>, S<1, 1, 4, 1>, S<8, 1, 32, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 1, 1>, S<0, 1, 2, 3, 4, 5>, 5, 4> -// DeviceGemmDl< F32, F32, F32, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 128, 256, 8, 1, 4, 8, 1, S<8, 2>, S<8, 2>, S<4, 1, 1, 1>, S<2, 1, 128, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 1>, S<1, 2, 0, 3>, S<1, 1, 1, 1>, S<4, 1, 2, 1>, S<2, 1, 128, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<4, 1, 1, 1>, S<0, 3, 1, 2>, S<1, 1, 1, 1>, S<0, 1, 2, 3, 4, 5>, 5, 4>, -// DeviceGemmDl< F32, F32, F32, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 128, 128, 8, 1, 4, 4, 1, S<8, 2>, S<8, 2>, S<4, 1, 1, 1>, S<2, 1, 128, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 1>, S<1, 2, 0, 3>, S<1, 1, 1, 1>, S<4, 1, 1, 1>, S<2, 1, 128, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<4, 1, 1, 1>, S<0, 3, 1, 2>, S<1, 1, 1, 1>, S<0, 1, 2, 3, 4, 5>, 5, 4>, -// DeviceGemmDl< F32, F32, F32, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 128, 128, 
128, 8, 1, 8, 4, 1, S<4, 2>, S<8, 2>, S<8, 1, 1, 1>, S<1, 1, 128, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 1>, S<1, 2, 0, 3>, S<1, 1, 1, 1>, S<8, 1, 1, 1>, S<1, 1, 128, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<4, 1, 1, 1>, S<0, 3, 1, 2>, S<1, 1, 1, 1>, S<0, 1, 2, 3, 4, 5>, 5, 4>, -// DeviceGemmDl< F32, F32, F32, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 128, 128, 128, 8, 1, 4, 8, 1, S<8, 2>, S<4, 2>, S<8, 1, 1, 1>, S<1, 1, 128, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 1>, S<1, 2, 0, 3>, S<1, 1, 1, 1>, S<8, 1, 1, 1>, S<1, 1, 128, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<4, 1, 1, 1>, S<0, 3, 1, 2>, S<1, 1, 1, 1>, S<0, 1, 2, 3, 4, 5>, 5, 4>, -// DeviceGemmDl< F32, F32, F32, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 64, 128, 128, 8, 1, 8, 8, 1, S<4, 2>, S<4, 2>, S<4, 1, 4, 1>, S<2, 1, 32, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 1>, S<1, 2, 0, 3>, S<1, 1, 1, 1>, S<4, 1, 4, 1>, S<2, 1, 32, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<4, 1, 1, 1>, S<0, 3, 1, 2>, S<1, 1, 1, 1>, S<0, 1, 2, 3, 4, 5>, 5, 4>, -// DeviceGemmDl< F32, F32, F32, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 64, 128, 64, 8, 1, 8, 4, 1, S<4, 2>, S<4, 2>, S<4, 1, 4, 1>, S<2, 1, 32, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 1>, S<1, 2, 0, 3>, S<1, 1, 1, 1>, S<4, 1, 2, 1>, S<2, 1, 32, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<4, 1, 1, 1>, S<0, 3, 1, 2>, S<1, 1, 1, 1>, S<0, 1, 2, 3, 4, 5>, 5, 4>, -// DeviceGemmDl< F32, F32, F32, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 64, 64, 128, 8, 1, 4, 8, 1, S<4, 2>, S<4, 2>, S<4, 1, 2, 1>, S<2, 1, 32, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 1>, S<1, 2, 0, 3>, S<1, 1, 1, 1>, S<4, 1, 4, 1>, S<2, 1, 32, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<4, 1, 1, 1>, S<0, 3, 1, 2>, S<1, 1, 1, 1>, S<0, 1, 2, 3, 4, 5>, 5, 4>, + DeviceGemmDl< F32, F32, F32, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 128, 128, 16, 1, 4, 4, 1, S<8, 2>, S<8, 
2>, S<8, 1, 1, 1>, S<2, 1, 128, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 1>, S<1, 2, 0, 3>, S<1, 1, 1, 1>, S<2, 1, 4, 1>, S<8, 1, 32, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 1, 2, 3, 4, 5>, 5, 4> // clang-format on >; diff --git a/library/src/tensor_operation_instance/gpu/gemm/device_gemm_dl_f32_f32_f32_mk_nk_mn_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm/device_gemm_dl_f32_f32_f32_mk_nk_mn_instance.cpp index c21f0736d1..c6077341b1 100644 --- a/library/src/tensor_operation_instance/gpu/gemm/device_gemm_dl_f32_f32_f32_mk_nk_mn_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/gemm/device_gemm_dl_f32_f32_f32_mk_nk_mn_instance.cpp @@ -30,14 +30,7 @@ using device_gemm_dl_f32_f32_f32_mk_nk_mn_instances = // ########| Type| Type| Type| Type| | | | Elementwise| Elementwise| Elementwise| Spacialization| Size| Block| Block| Block| | ThreadM111| ThreadN111| Thread| ClusterM110Xs| ClusterN110Xs| ThreadSliceLengths| ThreadClusterLengths| ThreadCluster| SrcAccess| SrcVectorTensor| SrcVectorTensor| DstVectorTensor| ThreadSliceLengths| ThreadClusterLengths| ThreadCluster| SrcAccess| SrcVectorTensor| SrcVectorTensor| DstVectorTensor| SrcDstAccess| SrcDstVectorDim| DstScalarPerVector| // ########| | | | | | | | Operation| Operation| Operation| | | | | | | | | | | | K0_M0_M1_K1| K0_M0_M1_K1| ArrangeOrder| Order| Lengths_K0_M0_M1_K1| ContiguousDimOrder| Lengths_K0_M0_M1_K1| K0_N0_N1_K1| K0_N0_N1_K1| ArrangeOrder| Order| Lengths_K0_N0_N1_K1| ContiguousDimOrder| Lengths_K0_N0_N1_K1| Order| | | // ########| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | - DeviceGemmDl< F32, F32, F32, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 256, 128, 8, 1, 8, 4, 1, S<8, 2>, S<8, 2>, S<4, 1, 2, 1>, S<2, 1, 128, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 1>, S<1, 2, 0, 3>, S<1, 1, 1, 1>, S<4, 1, 1, 1>, S<2, 1, 128, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 1>, 
S<1, 2, 0, 3>, S<1, 1, 1, 1>, S<0, 1, 2, 3, 4, 5>, 5, 4>, - DeviceGemmDl< F32, F32, F32, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 128, 256, 8, 1, 4, 8, 1, S<8, 2>, S<8, 2>, S<4, 1, 1, 1>, S<2, 1, 128, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 1>, S<1, 2, 0, 3>, S<1, 1, 1, 1>, S<4, 1, 2, 1>, S<2, 1, 128, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 1>, S<1, 2, 0, 3>, S<1, 1, 1, 1>, S<0, 1, 2, 3, 4, 5>, 5, 4>, - DeviceGemmDl< F32, F32, F32, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 128, 128, 8, 1, 4, 4, 1, S<8, 2>, S<8, 2>, S<4, 1, 1, 1>, S<2, 1, 128, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 1>, S<1, 2, 0, 3>, S<1, 1, 1, 1>, S<4, 1, 1, 1>, S<2, 1, 128, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 1>, S<1, 2, 0, 3>, S<1, 1, 1, 1>, S<0, 1, 2, 3, 4, 5>, 5, 4>, - DeviceGemmDl< F32, F32, F32, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 128, 128, 128, 8, 1, 8, 4, 1, S<4, 2>, S<8, 2>, S<8, 1, 1, 1>, S<1, 1, 128, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 1>, S<1, 2, 0, 3>, S<1, 1, 1, 1>, S<8, 1, 1, 1>, S<1, 1, 128, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 1>, S<1, 2, 0, 3>, S<1, 1, 1, 1>, S<0, 1, 2, 3, 4, 5>, 5, 4>, - DeviceGemmDl< F32, F32, F32, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 128, 128, 128, 8, 1, 4, 8, 1, S<8, 2>, S<4, 2>, S<8, 1, 1, 1>, S<1, 1, 128, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 1>, S<1, 2, 0, 3>, S<1, 1, 1, 1>, S<8, 1, 1, 1>, S<1, 1, 128, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 1>, S<1, 2, 0, 3>, S<1, 1, 1, 1>, S<0, 1, 2, 3, 4, 5>, 5, 4>, - DeviceGemmDl< F32, F32, F32, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 64, 128, 128, 8, 1, 8, 8, 1, S<4, 2>, S<4, 2>, S<4, 1, 4, 1>, S<2, 1, 32, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 1>, S<1, 2, 0, 3>, S<1, 1, 1, 1>, S<4, 1, 4, 1>, S<2, 1, 32, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 1>, S<1, 2, 0, 3>, S<1, 1, 1, 1>, S<0, 1, 
2, 3, 4, 5>, 5, 4>, - DeviceGemmDl< F32, F32, F32, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 64, 128, 64, 8, 1, 8, 4, 1, S<4, 2>, S<4, 2>, S<4, 1, 4, 1>, S<2, 1, 32, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 1>, S<1, 2, 0, 3>, S<1, 1, 1, 1>, S<4, 1, 2, 1>, S<2, 1, 32, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 1>, S<1, 2, 0, 3>, S<1, 1, 1, 1>, S<0, 1, 2, 3, 4, 5>, 5, 4>, - DeviceGemmDl< F32, F32, F32, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 64, 64, 128, 8, 1, 4, 8, 1, S<4, 2>, S<4, 2>, S<4, 1, 2, 1>, S<2, 1, 32, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 1>, S<1, 2, 0, 3>, S<1, 1, 1, 1>, S<4, 1, 4, 1>, S<2, 1, 32, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 1>, S<1, 2, 0, 3>, S<1, 1, 1, 1>, S<0, 1, 2, 3, 4, 5>, 5, 4> + DeviceGemmDl< F32, F32, F32, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 128, 128, 16, 1, 4, 4, 1, S<8, 2>, S<8, 2>, S<8, 1, 1, 1>, S<2, 1, 128, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 1>, S<1, 2, 0, 3>, S<1, 1, 1, 1>, S<8, 1, 1, 1>, S<2, 1, 128, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 1>, S<1, 2, 0, 3>, S<1, 1, 1, 1>, S<0, 1, 2, 3, 4, 5>, 5, 4> // clang-format on >; diff --git a/library/src/tensor_operation_instance/gpu/gemm/device_gemm_dl_i8_i8_i8_km_kn_mn_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm/device_gemm_dl_i8_i8_i8_km_kn_mn_instance.cpp index 0fa75a0f43..91b68d4bf2 100644 --- a/library/src/tensor_operation_instance/gpu/gemm/device_gemm_dl_i8_i8_i8_km_kn_mn_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/gemm/device_gemm_dl_i8_i8_i8_km_kn_mn_instance.cpp @@ -17,27 +17,18 @@ using S = ck::Sequence; using PassThrough = ck::tensor_operation::element_wise::PassThrough; -// static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default; +static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default; // Compilation parameters for a[k, m] * b[k, 
n] = c[m, n] -using device_gemm_dl_i8_i8_i8_km_kn_mn_instances = - std::tuple< - // clang-format off - // ##########| AData| BData| CData| AccData| ALayout| BLayout| CLayout| A| B| C| GEMM| Block| MPer| NPer| K0Per| K1| M1Per| N1Per| KPer| M11N11Thread| M11N11Thread| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| CThreadTransfer| CThreadTransfer| CThreadTransfer| - // ##########| Type| Type| Type| Type| | | | Elementwise| Elementwise| Elementwise| Spacialization| Size| Block| Block| Block| | ThreadM111| ThreadN111| Thread| ClusterM110Xs| ClusterN110Xs| ThreadSliceLengths| ThreadClusterLengths| ThreadCluster| SrcAccess| SrcVectorTensor| SrcVectorTensor| DstVectorTensor| ThreadSliceLengths| ThreadClusterLengths| ThreadCluster| SrcAccess| SrcVectorTensor| SrcVectorTensor| DstVectorTensor| SrcDstAccess| SrcDstVectorDim| DstScalarPerVector| - // ##########| | | | | | | | Operation| Operation| Operation| | | | | | | | | | | | K0_M0_M1_K1| K0_M0_M1_K1| ArrangeOrder| Order| Lengths_K0_M0_M1_K1| ContiguousDimOrder| Lengths_K0_M0_M1_K1| K0_N0_N1_K1| K0_N0_N1_K1| ArrangeOrder| Order| Lengths_K0_N0_N1_K1| ContiguousDimOrder| Lengths_K0_N0_N1_K1| Order| | | - // ##########| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | - // DeviceGemmDl< int8_t, int8_t, int8_t, int32_t, Col, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 256, 128, 8, 4, 8, 4, 1, S<8, 2>, S<8, 2>, S<4, 1, 2, 4>, S<2, 1, 128, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<4, 1, 1, 4>, S<0, 3, 1, 2>, S<1, 1, 1, 4>, S<4, 1, 1, 4>, S<2, 1, 128, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<4, 1, 1, 4>, S<0, 3, 1, 2>, S<1, 1, 1, 4>, S<0, 1, 2, 3, 4, 5>, 5, 4>, - // DeviceGemmDl< int8_t, int8_t, int8_t, int32_t, Col, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 128, 256, 8, 4, 4, 8, 1, S<8, 2>, 
S<8, 2>, S<4, 1, 1, 4>, S<2, 1, 128, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<4, 1, 1, 4>, S<0, 3, 1, 2>, S<1, 1, 1, 4>, S<4, 1, 2, 4>, S<2, 1, 128, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<4, 1, 1, 4>, S<0, 3, 1, 2>, S<1, 1, 1, 4>, S<0, 1, 2, 3, 4, 5>, 5, 4>, - // DeviceGemmDl< int8_t, int8_t, int8_t, int32_t, Col, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 128, 128, 8, 4, 4, 4, 1, S<8, 2>, S<8, 2>, S<4, 1, 1, 4>, S<2, 1, 128, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<4, 1, 1, 4>, S<0, 3, 1, 2>, S<1, 1, 1, 4>, S<4, 1, 1, 4>, S<2, 1, 128, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<4, 1, 1, 4>, S<0, 3, 1, 2>, S<1, 1, 1, 4>, S<0, 1, 2, 3, 4, 5>, 5, 4>, - // DeviceGemmDl< int8_t, int8_t, int8_t, int32_t, Col, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 128, 128, 128, 8, 4, 8, 4, 1, S<4, 2>, S<8, 2>, S<8, 1, 1, 4>, S<1, 1, 128, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<4, 1, 1, 4>, S<0, 3, 1, 2>, S<1, 1, 1, 4>, S<8, 1, 1, 4>, S<1, 1, 128, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<4, 1, 1, 4>, S<0, 3, 1, 2>, S<1, 1, 1, 4>, S<0, 1, 2, 3, 4, 5>, 5, 4>, - // DeviceGemmDl< int8_t, int8_t, int8_t, int32_t, Col, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 128, 128, 128, 8, 4, 4, 8, 1, S<8, 2>, S<4, 2>, S<8, 1, 1, 4>, S<1, 1, 128, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<4, 1, 1, 4>, S<0, 3, 1, 2>, S<1, 1, 1, 4>, S<8, 1, 1, 4>, S<1, 1, 128, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<4, 1, 1, 4>, S<0, 3, 1, 2>, S<1, 1, 1, 4>, S<0, 1, 2, 3, 4, 5>, 5, 4>, - // DeviceGemmDl< int8_t, int8_t, int8_t, int32_t, Col, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 64, 128, 128, 8, 4, 8, 8, 1, S<4, 2>, S<4, 2>, S<4, 1, 4, 4>, S<2, 1, 32, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<4, 1, 1, 4>, S<0, 3, 1, 2>, S<1, 1, 1, 4>, S<4, 1, 4, 4>, S<2, 1, 32, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<4, 1, 1, 4>, S<0, 3, 1, 2>, S<1, 1, 1, 4>, S<0, 1, 2, 3, 4, 5>, 5, 4>, - // DeviceGemmDl< int8_t, int8_t, int8_t, int32_t, Col, Row, Row, PassThrough, PassThrough, PassThrough, 
GemmDefault, 64, 128, 64, 8, 4, 8, 4, 1, S<4, 2>, S<4, 2>, S<4, 1, 4, 4>, S<2, 1, 32, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<4, 1, 1, 4>, S<0, 3, 1, 2>, S<1, 1, 1, 4>, S<4, 1, 2, 4>, S<2, 1, 32, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<4, 1, 1, 4>, S<0, 3, 1, 2>, S<1, 1, 1, 4>, S<0, 1, 2, 3, 4, 5>, 5, 4>, - // DeviceGemmDl< int8_t, int8_t, int8_t, int32_t, Col, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 64, 64, 128, 8, 4, 4, 8, 1, S<4, 2>, S<4, 2>, S<4, 1, 2, 4>, S<2, 1, 32, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<4, 1, 1, 4>, S<0, 3, 1, 2>, S<1, 1, 1, 4>, S<4, 1, 4, 4>, S<2, 1, 32, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<4, 1, 1, 4>, S<0, 3, 1, 2>, S<1, 1, 1, 4>, S<0, 1, 2, 3, 4, 5>, 5, 4>, - // DeviceGemmDl< int8_t, int8_t, int8_t, int32_t, Col, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 64, 64, 64, 8, 4, 4, 4, 1, S<4, 2>, S<4, 2>, S<4, 1, 2, 4>, S<2, 1, 32, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<4, 1, 1, 4>, S<0, 3, 1, 2>, S<1, 1, 1, 4>, S<4, 1, 2, 4>, S<2, 1, 32, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<4, 1, 1, 4>, S<0, 3, 1, 2>, S<1, 1, 1, 4>, S<0, 1, 2, 3, 4, 5>, 5, 4> - // clang-format on - >; +using device_gemm_dl_i8_i8_i8_km_kn_mn_instances = std::tuple< + // clang-format off + // #########| AData| BData| CData| AccData| ALayout| BLayout| CLayout| A| B| C| GEMM| Block| MPer| NPer| K0Per| K1| M1Per| N1Per| KPer| M11N11Thread| M11N11Thread| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| CThreadTransfer| CThreadTransfer| CThreadTransfer| + // #########| Type| Type| Type| Type| | | | Elementwise| Elementwise| Elementwise| Spacialization| Size| Block| Block| Block| | ThreadM111| ThreadN111| Thread| ClusterM110Xs| ClusterN110Xs| ThreadSliceLengths| ThreadClusterLengths| ThreadCluster| SrcAccess| SrcVectorTensor| SrcVectorTensor| DstVectorTensor| ThreadSliceLengths| 
ThreadClusterLengths| ThreadCluster| SrcAccess| SrcVectorTensor| SrcVectorTensor| DstVectorTensor| SrcDstAccess| SrcDstVectorDim| DstScalarPerVector| + // #########| | | | | | | | Operation| Operation| Operation| | | | | | | | | | | | K0_M0_M1_K1| K0_M0_M1_K1| ArrangeOrder| Order| Lengths_K0_M0_M1_K1| ContiguousDimOrder| Lengths_K0_M0_M1_K1| K0_N0_N1_K1| K0_N0_N1_K1| ArrangeOrder| Order| Lengths_K0_N0_N1_K1| ContiguousDimOrder| Lengths_K0_N0_N1_K1| Order| | | + // #########| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceGemmDl< int8_t, int8_t, int8_t, int32_t, Col, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 128, 128, 16, 4, 4, 4, 1, S<8, 2>, S<8, 2>, S<2, 1, 4, 4>, S<8, 1, 32, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 4>, S<2, 1, 4, 4>, S<8, 1, 32, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 4>, S<0, 1, 2, 3, 4, 5>, 5, 4> + // clang-format on + >; void add_device_gemm_dl_i8_i8_i8_km_kn_mn_instances( std::vector>& instances) diff --git a/library/src/tensor_operation_instance/gpu/gemm/device_gemm_dl_i8_i8_i8_km_nk_mn_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm/device_gemm_dl_i8_i8_i8_km_nk_mn_instance.cpp index ab63994987..13b185fd93 100644 --- a/library/src/tensor_operation_instance/gpu/gemm/device_gemm_dl_i8_i8_i8_km_nk_mn_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/gemm/device_gemm_dl_i8_i8_i8_km_nk_mn_instance.cpp @@ -17,27 +17,18 @@ using S = ck::Sequence; using PassThrough = ck::tensor_operation::element_wise::PassThrough; -// static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default; +static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default; // Compilation parameters for a[k, m] * b[n, k] = c[m, n] -using device_gemm_dl_i8_i8_i8_km_nk_mn_instances = - std::tuple< - // clang-format off - // ##########| AData| BData| CData| AccData| 
ALayout| BLayout| CLayout| A| B| C| GEMM| Block| MPer| NPer| K0Per| K1| M1Per| N1Per| KPer| M11N11Thread| M11N11Thread| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| CThreadTransfer| CThreadTransfer| CThreadTransfer| - // ##########| Type| Type| Type| Type| | | | Elementwise| Elementwise| Elementwise| Spacialization| Size| Block| Block| Block| | ThreadM111| ThreadN111| Thread| ClusterM110Xs| ClusterN110Xs| ThreadSliceLengths| ThreadClusterLengths| ThreadCluster| SrcAccess| SrcVectorTensor| SrcVectorTensor| DstVectorTensor| ThreadSliceLengths| ThreadClusterLengths| ThreadCluster| SrcAccess| SrcVectorTensor| SrcVectorTensor| DstVectorTensor| SrcDstAccess| SrcDstVectorDim| DstScalarPerVector| - // ##########| | | | | | | | Operation| Operation| Operation| | | | | | | | | | | | K0_M0_M1_K1| K0_M0_M1_K1| ArrangeOrder| Order| Lengths_K0_M0_M1_K1| ContiguousDimOrder| Lengths_K0_M0_M1_K1| K0_N0_N1_K1| K0_N0_N1_K1| ArrangeOrder| Order| Lengths_K0_N0_N1_K1| ContiguousDimOrder| Lengths_K0_N0_N1_K1| Order| | | - // ##########| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | - // DeviceGemmDl< int8_t, int8_t, int8_t, int32_t, Col, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 256, 128, 8, 4, 8, 4, 1, S<8, 2>, S<8, 2>, S<4, 1, 2, 4>, S<2, 1, 128, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<4, 1, 1, 4>, S<0, 3, 1, 2>, S<1, 1, 1, 4>, S<4, 1, 1, 4>, S<2, 1, 128, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 4>, S<1, 2, 0, 3>, S<1, 1, 1, 4>, S<0, 1, 2, 3, 4, 5>, 5, 4>, - // DeviceGemmDl< int8_t, int8_t, int8_t, int32_t, Col, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 128, 256, 8, 4, 4, 8, 1, S<8, 2>, S<8, 2>, S<4, 1, 1, 4>, S<2, 1, 128, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<4, 1, 1, 4>, S<0, 3, 1, 2>, S<1, 1, 1, 4>, S<4, 1, 2, 4>, S<2, 1, 128, 1>, 
S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 4>, S<1, 2, 0, 3>, S<1, 1, 1, 4>, S<0, 1, 2, 3, 4, 5>, 5, 4>, - // DeviceGemmDl< int8_t, int8_t, int8_t, int32_t, Col, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 128, 128, 8, 4, 4, 4, 1, S<8, 2>, S<8, 2>, S<4, 1, 1, 4>, S<2, 1, 128, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<4, 1, 1, 4>, S<0, 3, 1, 2>, S<1, 1, 1, 4>, S<4, 1, 1, 4>, S<2, 1, 128, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 4>, S<1, 2, 0, 3>, S<1, 1, 1, 4>, S<0, 1, 2, 3, 4, 5>, 5, 4>, - // DeviceGemmDl< int8_t, int8_t, int8_t, int32_t, Col, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 128, 128, 128, 8, 4, 8, 4, 1, S<4, 2>, S<8, 2>, S<8, 1, 1, 4>, S<1, 1, 128, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<4, 1, 1, 4>, S<0, 3, 1, 2>, S<1, 1, 1, 4>, S<8, 1, 1, 4>, S<1, 1, 128, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 4>, S<1, 2, 0, 3>, S<1, 1, 1, 4>, S<0, 1, 2, 3, 4, 5>, 5, 4>, - // DeviceGemmDl< int8_t, int8_t, int8_t, int32_t, Col, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 128, 128, 128, 8, 4, 4, 8, 1, S<8, 2>, S<4, 2>, S<8, 1, 1, 4>, S<1, 1, 128, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<4, 1, 1, 4>, S<0, 3, 1, 2>, S<1, 1, 1, 4>, S<8, 1, 1, 4>, S<1, 1, 128, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 4>, S<1, 2, 0, 3>, S<1, 1, 1, 4>, S<0, 1, 2, 3, 4, 5>, 5, 4>, - // DeviceGemmDl< int8_t, int8_t, int8_t, int32_t, Col, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 64, 128, 128, 8, 4, 8, 8, 1, S<4, 2>, S<4, 2>, S<4, 1, 4, 4>, S<2, 1, 32, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<4, 1, 1, 4>, S<0, 3, 1, 2>, S<1, 1, 1, 4>, S<4, 1, 4, 4>, S<2, 1, 32, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 4>, S<1, 2, 0, 3>, S<1, 1, 1, 4>, S<0, 1, 2, 3, 4, 5>, 5, 4>, - // DeviceGemmDl< int8_t, int8_t, int8_t, int32_t, Col, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 64, 128, 64, 8, 4, 8, 4, 1, S<4, 2>, S<4, 2>, S<4, 1, 4, 4>, S<2, 1, 32, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<4, 1, 1, 4>, S<0, 3, 1, 2>, S<1, 1, 
1, 4>, S<4, 1, 2, 4>, S<2, 1, 32, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 4>, S<1, 2, 0, 3>, S<1, 1, 1, 4>, S<0, 1, 2, 3, 4, 5>, 5, 4>, - // DeviceGemmDl< int8_t, int8_t, int8_t, int32_t, Col, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 64, 64, 128, 8, 4, 4, 8, 1, S<4, 2>, S<4, 2>, S<4, 1, 2, 4>, S<2, 1, 32, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<4, 1, 1, 4>, S<0, 3, 1, 2>, S<1, 1, 1, 4>, S<4, 1, 4, 4>, S<2, 1, 32, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 4>, S<1, 2, 0, 3>, S<1, 1, 1, 4>, S<0, 1, 2, 3, 4, 5>, 5, 4>, - // DeviceGemmDl< int8_t, int8_t, int8_t, int32_t, Col, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 64, 64, 64, 8, 4, 4, 4, 1, S<4, 2>, S<4, 2>, S<4, 1, 2, 4>, S<2, 1, 32, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<4, 1, 1, 4>, S<0, 3, 1, 2>, S<1, 1, 1, 4>, S<4, 1, 2, 4>, S<2, 1, 32, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 4>, S<1, 2, 0, 3>, S<1, 1, 1, 4>, S<0, 1, 2, 3, 4, 5>, 5, 4> - // clang-format on - >; +using device_gemm_dl_i8_i8_i8_km_nk_mn_instances = std::tuple< + // clang-format off + // #########| AData| BData| CData| AccData| ALayout| BLayout| CLayout| A| B| C| GEMM| Block| MPer| NPer| K0Per| K1| M1Per| N1Per| KPer| M11N11Thread| M11N11Thread| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| CThreadTransfer| CThreadTransfer| CThreadTransfer| + // #########| Type| Type| Type| Type| | | | Elementwise| Elementwise| Elementwise| Spacialization| Size| Block| Block| Block| | ThreadM111| ThreadN111| Thread| ClusterM110Xs| ClusterN110Xs| ThreadSliceLengths| ThreadClusterLengths| ThreadCluster| SrcAccess| SrcVectorTensor| SrcVectorTensor| DstVectorTensor| ThreadSliceLengths| ThreadClusterLengths| ThreadCluster| SrcAccess| SrcVectorTensor| SrcVectorTensor| DstVectorTensor| SrcDstAccess| SrcDstVectorDim| DstScalarPerVector| + // #########| 
| | | | | | | Operation| Operation| Operation| | | | | | | | | | | | K0_M0_M1_K1| K0_M0_M1_K1| ArrangeOrder| Order| Lengths_K0_M0_M1_K1| ContiguousDimOrder| Lengths_K0_M0_M1_K1| K0_N0_N1_K1| K0_N0_N1_K1| ArrangeOrder| Order| Lengths_K0_N0_N1_K1| ContiguousDimOrder| Lengths_K0_N0_N1_K1| Order| | | + // #########| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceGemmDl< int8_t, int8_t, int8_t, int32_t, Col, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 128, 128, 16, 4, 4, 4, 1, S<8, 2>, S<8, 2>, S<2, 1, 4, 4>, S<8, 1, 32, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 4>, S<8, 1, 1, 4>, S<2, 1, 128, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 4>, S<1, 2, 0, 3>, S<1, 1, 1, 4>, S<0, 1, 2, 3, 4, 5>, 5, 4> + // clang-format on + >; void add_device_gemm_dl_i8_i8_i8_km_nk_mn_instances( std::vector>& instances) diff --git a/library/src/tensor_operation_instance/gpu/gemm/device_gemm_dl_i8_i8_i8_mk_kn_mn_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm/device_gemm_dl_i8_i8_i8_mk_kn_mn_instance.cpp index 5d98de4650..ff4a89beb4 100644 --- a/library/src/tensor_operation_instance/gpu/gemm/device_gemm_dl_i8_i8_i8_mk_kn_mn_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/gemm/device_gemm_dl_i8_i8_i8_mk_kn_mn_instance.cpp @@ -17,27 +17,18 @@ using S = ck::Sequence; using PassThrough = ck::tensor_operation::element_wise::PassThrough; -// static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default; +static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default; // Compilation parameters for a[m, k] * b[k, n] = c[m, n] -using device_gemm_dl_i8_i8_i8_mk_kn_mn_instances = - std::tuple< - // clang-format off - // ##########| AData| BData| CData| AccData| ALayout| BLayout| CLayout| A| B| C| GEMM| Block| MPer| NPer| K0Per| K1| M1Per| N1Per| KPer| M11N11Thread| M11N11Thread| ABlockTransfer| ABlockTransfer| 
ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| CThreadTransfer| CThreadTransfer| CThreadTransfer| - // ##########| Type| Type| Type| Type| | | | Elementwise| Elementwise| Elementwise| Spacialization| Size| Block| Block| Block| | ThreadM111| ThreadN111| Thread| ClusterM110Xs| ClusterN110Xs| ThreadSliceLengths| ThreadClusterLengths| ThreadCluster| SrcAccess| SrcVectorTensor| SrcVectorTensor| DstVectorTensor| ThreadSliceLengths| ThreadClusterLengths| ThreadCluster| SrcAccess| SrcVectorTensor| SrcVectorTensor| DstVectorTensor| SrcDstAccess| SrcDstVectorDim| DstScalarPerVector| - // ##########| | | | | | | | Operation| Operation| Operation| | | | | | | | | | | | K0_M0_M1_K1| K0_M0_M1_K1| ArrangeOrder| Order| Lengths_K0_M0_M1_K1| ContiguousDimOrder| Lengths_K0_M0_M1_K1| K0_N0_N1_K1| K0_N0_N1_K1| ArrangeOrder| Order| Lengths_K0_N0_N1_K1| ContiguousDimOrder| Lengths_K0_N0_N1_K1| Order| | | - // ##########| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | - // DeviceGemmDl< int8_t, int8_t, int8_t, int32_t, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 256, 128, 8, 4, 8, 4, 1, S<8, 2>, S<8, 2>, S<4, 1, 2, 4>, S<2, 1, 128, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 4>, S<1, 2, 0, 3>, S<1, 1, 1, 4>, S<4, 1, 1, 4>, S<2, 1, 128, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<4, 1, 1, 4>, S<0, 3, 1, 2>, S<1, 1, 1, 4>, S<0, 1, 2, 3, 4, 5>, 5, 4>, - // DeviceGemmDl< int8_t, int8_t, int8_t, int32_t, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 128, 256, 8, 4, 4, 8, 1, S<8, 2>, S<8, 2>, S<4, 1, 1, 4>, S<2, 1, 128, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 4>, S<1, 2, 0, 3>, S<1, 1, 1, 4>, S<4, 1, 2, 4>, S<2, 1, 128, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<4, 1, 1, 4>, S<0, 3, 1, 2>, S<1, 1, 1, 4>, S<0, 1, 2, 3, 4, 5>, 5, 4>, - // DeviceGemmDl< int8_t, int8_t, int8_t, 
int32_t, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 128, 128, 8, 4, 4, 4, 1, S<8, 2>, S<8, 2>, S<4, 1, 1, 4>, S<2, 1, 128, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 4>, S<1, 2, 0, 3>, S<1, 1, 1, 4>, S<4, 1, 1, 4>, S<2, 1, 128, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<4, 1, 1, 4>, S<0, 3, 1, 2>, S<1, 1, 1, 4>, S<0, 1, 2, 3, 4, 5>, 5, 4>, - // DeviceGemmDl< int8_t, int8_t, int8_t, int32_t, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 128, 128, 128, 8, 4, 8, 4, 1, S<4, 2>, S<8, 2>, S<8, 1, 1, 4>, S<1, 1, 128, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 4>, S<1, 2, 0, 3>, S<1, 1, 1, 4>, S<8, 1, 1, 4>, S<1, 1, 128, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<4, 1, 1, 4>, S<0, 3, 1, 2>, S<1, 1, 1, 4>, S<0, 1, 2, 3, 4, 5>, 5, 4>, - // DeviceGemmDl< int8_t, int8_t, int8_t, int32_t, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 128, 128, 128, 8, 4, 4, 8, 1, S<8, 2>, S<4, 2>, S<8, 1, 1, 4>, S<1, 1, 128, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 4>, S<1, 2, 0, 3>, S<1, 1, 1, 4>, S<8, 1, 1, 4>, S<1, 1, 128, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<4, 1, 1, 4>, S<0, 3, 1, 2>, S<1, 1, 1, 4>, S<0, 1, 2, 3, 4, 5>, 5, 4>, - // DeviceGemmDl< int8_t, int8_t, int8_t, int32_t, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 64, 128, 128, 8, 4, 8, 8, 1, S<4, 2>, S<4, 2>, S<4, 1, 4, 4>, S<2, 1, 32, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 4>, S<1, 2, 0, 3>, S<1, 1, 1, 4>, S<4, 1, 4, 4>, S<2, 1, 32, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<4, 1, 1, 4>, S<0, 3, 1, 2>, S<1, 1, 1, 4>, S<0, 1, 2, 3, 4, 5>, 5, 4>, - // DeviceGemmDl< int8_t, int8_t, int8_t, int32_t, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 64, 128, 64, 8, 4, 8, 4, 1, S<4, 2>, S<4, 2>, S<4, 1, 4, 4>, S<2, 1, 32, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 4>, S<1, 2, 0, 3>, S<1, 1, 1, 4>, S<4, 1, 2, 4>, S<2, 1, 32, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<4, 1, 1, 4>, S<0, 3, 1, 2>, S<1, 1, 1, 4>, S<0, 1, 2, 3, 4, 5>, 5, 4>, - // 
DeviceGemmDl< int8_t, int8_t, int8_t, int32_t, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 64, 64, 128, 8, 4, 4, 8, 1, S<4, 2>, S<4, 2>, S<4, 1, 2, 4>, S<2, 1, 32, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 4>, S<1, 2, 0, 3>, S<1, 1, 1, 4>, S<4, 1, 4, 4>, S<2, 1, 32, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<4, 1, 1, 4>, S<0, 3, 1, 2>, S<1, 1, 1, 4>, S<0, 1, 2, 3, 4, 5>, 5, 4>, - // DeviceGemmDl< int8_t, int8_t, int8_t, int32_t, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 64, 64, 64, 8, 4, 4, 4, 1, S<4, 2>, S<4, 2>, S<4, 1, 2, 4>, S<2, 1, 32, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 4>, S<1, 2, 0, 3>, S<1, 1, 1, 4>, S<4, 1, 2, 4>, S<2, 1, 32, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<4, 1, 1, 4>, S<0, 3, 1, 2>, S<1, 1, 1, 4>, S<0, 1, 2, 3, 4, 5>, 5, 4> - // clang-format on - >; +using device_gemm_dl_i8_i8_i8_mk_kn_mn_instances = std::tuple< + // clang-format off + // #########| AData| BData| CData| AccData| ALayout| BLayout| CLayout| A| B| C| GEMM| Block| MPer| NPer| K0Per| K1| M1Per| N1Per| KPer| M11N11Thread| M11N11Thread| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| CThreadTransfer| CThreadTransfer| CThreadTransfer| + // #########| Type| Type| Type| Type| | | | Elementwise| Elementwise| Elementwise| Spacialization| Size| Block| Block| Block| | ThreadM111| ThreadN111| Thread| ClusterM110Xs| ClusterN110Xs| ThreadSliceLengths| ThreadClusterLengths| ThreadCluster| SrcAccess| SrcVectorTensor| SrcVectorTensor| DstVectorTensor| ThreadSliceLengths| ThreadClusterLengths| ThreadCluster| SrcAccess| SrcVectorTensor| SrcVectorTensor| DstVectorTensor| SrcDstAccess| SrcDstVectorDim| DstScalarPerVector| + // #########| | | | | | | | Operation| Operation| Operation| | | | | | | | | | | | K0_M0_M1_K1| K0_M0_M1_K1| ArrangeOrder| Order| Lengths_K0_M0_M1_K1| 
ContiguousDimOrder| Lengths_K0_M0_M1_K1| K0_N0_N1_K1| K0_N0_N1_K1| ArrangeOrder| Order| Lengths_K0_N0_N1_K1| ContiguousDimOrder| Lengths_K0_N0_N1_K1| Order| | | + // #########| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceGemmDl< int8_t, int8_t, int8_t, int32_t, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 128, 128, 16, 4, 4, 4, 1, S<8, 2>, S<8, 2>, S<8, 1, 1, 4>, S<2, 1, 128, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 4>, S<1, 2, 0, 3>, S<1, 1, 1, 4>, S<2, 1, 4, 4>, S<8, 1, 32, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 4>, S<0, 1, 2, 3, 4, 5>, 5, 4> + // clang-format on + >; void add_device_gemm_dl_i8_i8_i8_mk_kn_mn_instances( std::vector>& instances) diff --git a/library/src/tensor_operation_instance/gpu/gemm/device_gemm_dl_i8_i8_i8_mk_nk_mn_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm/device_gemm_dl_i8_i8_i8_mk_nk_mn_instance.cpp index afc378a93d..e32158a292 100644 --- a/library/src/tensor_operation_instance/gpu/gemm/device_gemm_dl_i8_i8_i8_mk_nk_mn_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/gemm/device_gemm_dl_i8_i8_i8_mk_nk_mn_instance.cpp @@ -17,27 +17,18 @@ using S = ck::Sequence; using PassThrough = ck::tensor_operation::element_wise::PassThrough; -// static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default; +static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default; // Compilation parameters for a[m, k] * b[n, k] = c[m, n] -using device_gemm_dl_i8_i8_i8_mk_nk_mn_instances = - std::tuple< - // clang-format off - // ##########| AData| BData| CData| AccData| ALayout| BLayout| CLayout| A| B| C| GEMM| Block| MPer| NPer| K0Per| K1| M1Per| N1Per| KPer| M11N11Thread| M11N11Thread| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| 
BBlockTransfer| BBlockTransfer| BBlockTransfer| CThreadTransfer| CThreadTransfer| CThreadTransfer| - // ##########| Type| Type| Type| Type| | | | Elementwise| Elementwise| Elementwise| Spacialization| Size| Block| Block| Block| | ThreadM111| ThreadN111| Thread| ClusterM110Xs| ClusterN110Xs| ThreadSliceLengths| ThreadClusterLengths| ThreadCluster| SrcAccess| SrcVectorTensor| SrcVectorTensor| DstVectorTensor| ThreadSliceLengths| ThreadClusterLengths| ThreadCluster| SrcAccess| SrcVectorTensor| SrcVectorTensor| DstVectorTensor| SrcDstAccess| SrcDstVectorDim| DstScalarPerVector| - // ##########| | | | | | | | Operation| Operation| Operation| | | | | | | | | | | | K0_M0_M1_K1| K0_M0_M1_K1| ArrangeOrder| Order| Lengths_K0_M0_M1_K1| ContiguousDimOrder| Lengths_K0_M0_M1_K1| K0_N0_N1_K1| K0_N0_N1_K1| ArrangeOrder| Order| Lengths_K0_N0_N1_K1| ContiguousDimOrder| Lengths_K0_N0_N1_K1| Order| | | - // ##########| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | - // DeviceGemmDl< int8_t, int8_t, int8_t, int32_t, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 256, 128, 8, 4, 8, 4, 1, S<8, 2>, S<8, 2>, S<4, 1, 2, 4>, S<2, 1, 128, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 4>, S<1, 2, 0, 3>, S<1, 1, 1, 4>, S<4, 1, 1, 4>, S<2, 1, 128, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 4>, S<1, 2, 0, 3>, S<1, 1, 1, 4>, S<0, 1, 2, 3, 4, 5>, 5, 4>, - // DeviceGemmDl< int8_t, int8_t, int8_t, int32_t, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 128, 256, 8, 4, 4, 8, 1, S<8, 2>, S<8, 2>, S<4, 1, 1, 4>, S<2, 1, 128, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 4>, S<1, 2, 0, 3>, S<1, 1, 1, 4>, S<4, 1, 2, 4>, S<2, 1, 128, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 4>, S<1, 2, 0, 3>, S<1, 1, 1, 4>, S<0, 1, 2, 3, 4, 5>, 5, 4>, - // DeviceGemmDl< int8_t, int8_t, int8_t, int32_t, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 128, 128, 8, 4, 4, 4, 1, S<8, 2>, S<8, 2>, S<4, 1, 1, 4>, S<2, 1, 
128, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 4>, S<1, 2, 0, 3>, S<1, 1, 1, 4>, S<4, 1, 1, 4>, S<2, 1, 128, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 4>, S<1, 2, 0, 3>, S<1, 1, 1, 4>, S<0, 1, 2, 3, 4, 5>, 5, 4>, - // DeviceGemmDl< int8_t, int8_t, int8_t, int32_t, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 128, 128, 128, 8, 4, 8, 4, 1, S<4, 2>, S<8, 2>, S<8, 1, 1, 4>, S<1, 1, 128, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 4>, S<1, 2, 0, 3>, S<1, 1, 1, 4>, S<8, 1, 1, 4>, S<1, 1, 128, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 4>, S<1, 2, 0, 3>, S<1, 1, 1, 4>, S<0, 1, 2, 3, 4, 5>, 5, 4>, - // DeviceGemmDl< int8_t, int8_t, int8_t, int32_t, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 128, 128, 128, 8, 4, 4, 8, 1, S<8, 2>, S<4, 2>, S<8, 1, 1, 4>, S<1, 1, 128, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 4>, S<1, 2, 0, 3>, S<1, 1, 1, 4>, S<8, 1, 1, 4>, S<1, 1, 128, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 4>, S<1, 2, 0, 3>, S<1, 1, 1, 4>, S<0, 1, 2, 3, 4, 5>, 5, 4>, - // DeviceGemmDl< int8_t, int8_t, int8_t, int32_t, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 64, 128, 128, 8, 4, 8, 8, 1, S<4, 2>, S<4, 2>, S<4, 1, 4, 4>, S<2, 1, 32, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 4>, S<1, 2, 0, 3>, S<1, 1, 1, 4>, S<4, 1, 4, 4>, S<2, 1, 32, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 4>, S<1, 2, 0, 3>, S<1, 1, 1, 4>, S<0, 1, 2, 3, 4, 5>, 5, 4>, - // DeviceGemmDl< int8_t, int8_t, int8_t, int32_t, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 64, 128, 64, 8, 4, 8, 4, 1, S<4, 2>, S<4, 2>, S<4, 1, 4, 4>, S<2, 1, 32, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 4>, S<1, 2, 0, 3>, S<1, 1, 1, 4>, S<4, 1, 2, 4>, S<2, 1, 32, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 4>, S<1, 2, 0, 3>, S<1, 1, 1, 4>, S<0, 1, 2, 3, 4, 5>, 5, 4>, - // DeviceGemmDl< int8_t, int8_t, int8_t, int32_t, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 64, 64, 128, 8, 4, 4, 8, 1, S<4, 
2>, S<4, 2>, S<4, 1, 2, 4>, S<2, 1, 32, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 4>, S<1, 2, 0, 3>, S<1, 1, 1, 4>, S<4, 1, 4, 4>, S<2, 1, 32, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 4>, S<1, 2, 0, 3>, S<1, 1, 1, 4>, S<0, 1, 2, 3, 4, 5>, 5, 4>, - // DeviceGemmDl< int8_t, int8_t, int8_t, int32_t, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 64, 64, 64, 8, 4, 4, 4, 1, S<4, 2>, S<4, 2>, S<4, 1, 2, 4>, S<2, 1, 32, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 4>, S<1, 2, 0, 3>, S<1, 1, 1, 4>, S<4, 1, 2, 4>, S<2, 1, 32, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 4>, S<1, 2, 0, 3>, S<1, 1, 1, 4>, S<0, 1, 2, 3, 4, 5>, 5, 4> - // clang-format on - >; +using device_gemm_dl_i8_i8_i8_mk_nk_mn_instances = std::tuple< + // clang-format off + // #########| AData| BData| CData| AccData| ALayout| BLayout| CLayout| A| B| C| GEMM| Block| MPer| NPer| K0Per| K1| M1Per| N1Per| KPer| M11N11Thread| M11N11Thread| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| CThreadTransfer| CThreadTransfer| CThreadTransfer| + // #########| Type| Type| Type| Type| | | | Elementwise| Elementwise| Elementwise| Spacialization| Size| Block| Block| Block| | ThreadM111| ThreadN111| Thread| ClusterM110Xs| ClusterN110Xs| ThreadSliceLengths| ThreadClusterLengths| ThreadCluster| SrcAccess| SrcVectorTensor| SrcVectorTensor| DstVectorTensor| ThreadSliceLengths| ThreadClusterLengths| ThreadCluster| SrcAccess| SrcVectorTensor| SrcVectorTensor| DstVectorTensor| SrcDstAccess| SrcDstVectorDim| DstScalarPerVector| + // #########| | | | | | | | Operation| Operation| Operation| | | | | | | | | | | | K0_M0_M1_K1| K0_M0_M1_K1| ArrangeOrder| Order| Lengths_K0_M0_M1_K1| ContiguousDimOrder| Lengths_K0_M0_M1_K1| K0_N0_N1_K1| K0_N0_N1_K1| ArrangeOrder| Order| Lengths_K0_N0_N1_K1| ContiguousDimOrder| Lengths_K0_N0_N1_K1| Order| 
| | + // #########| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceGemmDl< int8_t, int8_t, int8_t, int32_t, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 128, 128, 16, 4, 4, 4, 1, S<8, 2>, S<8, 2>, S<8, 1, 1, 4>, S<2, 1, 128, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 4>, S<1, 2, 0, 3>, S<1, 1, 1, 4>, S<8, 1, 1, 4>, S<2, 1, 128, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 4>, S<1, 2, 0, 3>, S<1, 1, 1, 4>, S<0, 1, 2, 3, 4, 5>, 5, 4> + // clang-format on + >; void add_device_gemm_dl_i8_i8_i8_mk_nk_mn_instances( std::vector>& instances) diff --git a/profiler/CMakeLists.txt b/profiler/CMakeLists.txt index 79aae487e4..ee0050d200 100644 --- a/profiler/CMakeLists.txt +++ b/profiler/CMakeLists.txt @@ -46,9 +46,7 @@ add_executable(ckProfiler ${PROFILER_SOURCE}) target_link_libraries(ckProfiler PRIVATE host_tensor) target_link_libraries(ckProfiler PRIVATE conv_util) target_link_libraries(ckProfiler PRIVATE device_gemm_reduce_instance) -target_link_libraries(ckProfiler PRIVATE device_gemm_dl_instance) -target_link_libraries(ckProfiler PRIVATE device_gemm_xdl_instance) -target_link_libraries(ckProfiler PRIVATE device_gemm_dlops_instance) +target_link_libraries(ckProfiler PRIVATE device_gemm_instance) target_link_libraries(ckProfiler PRIVATE device_gemm_bias2d_instance) target_link_libraries(ckProfiler PRIVATE device_gemm_bias_relu_instance) target_link_libraries(ckProfiler PRIVATE device_gemm_bias_relu_add_instance) diff --git a/profiler/include/profile_gemm_impl.hpp b/profiler/include/profile_gemm_impl.hpp index e5b6583993..146449f3e8 100644 --- a/profiler/include/profile_gemm_impl.hpp +++ b/profiler/include/profile_gemm_impl.hpp @@ -42,14 +42,10 @@ void add_device_gemm_xdl_c_shuffle_f16_f16_f16_mk_nk_mn_instances(std::vector&); void add_device_gemm_xdl_c_shuffle_f16_f16_f16_km_nk_mn_instances(std::vector&); -void add_device_gemm_xdl_c_shuffle_int8_int8_int8_mk_kn_mn_instances( - std::vector&); -void 
add_device_gemm_xdl_c_shuffle_int8_int8_int8_mk_nk_mn_instances( - std::vector&); -void add_device_gemm_xdl_c_shuffle_int8_int8_int8_km_kn_mn_instances( - std::vector&); -void add_device_gemm_xdl_c_shuffle_int8_int8_int8_km_nk_mn_instances( - std::vector&); +void add_device_gemm_xdl_c_shuffle_i8_i8_i8_mk_kn_mn_instances(std::vector&); +void add_device_gemm_xdl_c_shuffle_i8_i8_i8_mk_nk_mn_instances(std::vector&); +void add_device_gemm_xdl_c_shuffle_i8_i8_i8_km_kn_mn_instances(std::vector&); +void add_device_gemm_xdl_c_shuffle_i8_i8_i8_km_nk_mn_instances(std::vector&); void add_device_gemm_xdl_c_shuffle_2_stage_f16_f16_f16_mk_nk_mn_instances( std::vector&); @@ -74,20 +70,20 @@ void add_device_gemm_xdl_splitk_f16_f16_f16_mk_nk_mn_instances(std::vector&); void add_device_gemm_xdl_splitk_f16_f16_f16_km_nk_mn_instances(std::vector&); -void add_device_gemm_dlops_f32_f32_f32_mk_kn_mn_instances(std::vector&); -void add_device_gemm_dlops_f32_f32_f32_mk_nk_mn_instances(std::vector&); -void add_device_gemm_dlops_f32_f32_f32_km_kn_mn_instances(std::vector&); -void add_device_gemm_dlops_f32_f32_f32_km_nk_mn_instances(std::vector&); +void add_device_gemm_dl_f32_f32_f32_mk_kn_mn_instances(std::vector&); +void add_device_gemm_dl_f32_f32_f32_mk_nk_mn_instances(std::vector&); +void add_device_gemm_dl_f32_f32_f32_km_kn_mn_instances(std::vector&); +void add_device_gemm_dl_f32_f32_f32_km_nk_mn_instances(std::vector&); -void add_device_gemm_dlops_f16_f16_f16_mk_kn_mn_instances(std::vector&); -void add_device_gemm_dlops_f16_f16_f16_mk_nk_mn_instances(std::vector&); -void add_device_gemm_dlops_f16_f16_f16_km_kn_mn_instances(std::vector&); -void add_device_gemm_dlops_f16_f16_f16_km_nk_mn_instances(std::vector&); +void add_device_gemm_dl_f16_f16_f16_mk_kn_mn_instances(std::vector&); +void add_device_gemm_dl_f16_f16_f16_mk_nk_mn_instances(std::vector&); +void add_device_gemm_dl_f16_f16_f16_km_kn_mn_instances(std::vector&); +void add_device_gemm_dl_f16_f16_f16_km_nk_mn_instances(std::vector&); 
-void add_device_gemm_dlops_int8_int8_int8_mk_kn_mn_instances(std::vector&); -void add_device_gemm_dlops_int8_int8_int8_mk_nk_mn_instances(std::vector&); -void add_device_gemm_dlops_int8_int8_int8_km_kn_mn_instances(std::vector&); -void add_device_gemm_dlops_int8_int8_int8_km_nk_mn_instances(std::vector&); +void add_device_gemm_dl_i8_i8_i8_mk_kn_mn_instances(std::vector&); +void add_device_gemm_dl_i8_i8_i8_mk_nk_mn_instances(std::vector&); +void add_device_gemm_dl_i8_i8_i8_km_kn_mn_instances(std::vector&); +void add_device_gemm_dl_i8_i8_i8_km_nk_mn_instances(std::vector&); } // namespace device_gemm_instance } // namespace device @@ -194,7 +190,7 @@ void profile_gemm_impl(int do_verification, add_device_gemm_xdl_f32_f32_f32_mk_kn_mn_instances(gemm_ptrs); ck::tensor_operation::device::device_gemm_instance:: - add_device_gemm_dlops_f32_f32_f32_mk_kn_mn_instances(gemm_ptrs); + add_device_gemm_dl_f32_f32_f32_mk_kn_mn_instances(gemm_ptrs); ck::tensor_operation::device::device_gemm_instance:: add_device_gemm_xdl_c_shuffle_f32_f32_f32_mk_kn_mn_instances(gemm_ptrs); @@ -215,7 +211,7 @@ void profile_gemm_impl(int do_verification, add_device_gemm_xdl_f32_f32_f32_mk_nk_mn_instances(gemm_ptrs); ck::tensor_operation::device::device_gemm_instance:: - add_device_gemm_dlops_f32_f32_f32_mk_nk_mn_instances(gemm_ptrs); + add_device_gemm_dl_f32_f32_f32_mk_nk_mn_instances(gemm_ptrs); ck::tensor_operation::device::device_gemm_instance:: add_device_gemm_xdl_c_shuffle_f32_f32_f32_mk_nk_mn_instances(gemm_ptrs); @@ -236,7 +232,7 @@ void profile_gemm_impl(int do_verification, add_device_gemm_xdl_f32_f32_f32_km_kn_mn_instances(gemm_ptrs); ck::tensor_operation::device::device_gemm_instance:: - add_device_gemm_dlops_f32_f32_f32_km_kn_mn_instances(gemm_ptrs); + add_device_gemm_dl_f32_f32_f32_km_kn_mn_instances(gemm_ptrs); ck::tensor_operation::device::device_gemm_instance:: add_device_gemm_xdl_c_shuffle_f32_f32_f32_km_kn_mn_instances(gemm_ptrs); @@ -257,7 +253,7 @@ void profile_gemm_impl(int 
do_verification, add_device_gemm_xdl_f32_f32_f32_km_nk_mn_instances(gemm_ptrs); ck::tensor_operation::device::device_gemm_instance:: - add_device_gemm_dlops_f32_f32_f32_km_nk_mn_instances(gemm_ptrs); + add_device_gemm_dl_f32_f32_f32_km_nk_mn_instances(gemm_ptrs); ck::tensor_operation::device::device_gemm_instance:: add_device_gemm_xdl_c_shuffle_f32_f32_f32_km_nk_mn_instances(gemm_ptrs); @@ -282,7 +278,7 @@ void profile_gemm_impl(int do_verification, add_device_gemm_xdl_f16_f16_f16_mk_kn_mn_instances(gemm_ptrs); ck::tensor_operation::device::device_gemm_instance:: - add_device_gemm_dlops_f16_f16_f16_mk_kn_mn_instances(gemm_ptrs); + add_device_gemm_dl_f16_f16_f16_mk_kn_mn_instances(gemm_ptrs); ck::tensor_operation::device::device_gemm_instance:: add_device_gemm_xdl_c_shuffle_f16_f16_f16_mk_kn_mn_instances(gemm_ptrs); @@ -303,7 +299,7 @@ void profile_gemm_impl(int do_verification, add_device_gemm_xdl_f16_f16_f16_mk_nk_mn_instances(gemm_ptrs); ck::tensor_operation::device::device_gemm_instance:: - add_device_gemm_dlops_f16_f16_f16_mk_nk_mn_instances(gemm_ptrs); + add_device_gemm_dl_f16_f16_f16_mk_nk_mn_instances(gemm_ptrs); ck::tensor_operation::device::device_gemm_instance:: add_device_gemm_xdl_c_shuffle_f16_f16_f16_mk_nk_mn_instances(gemm_ptrs); @@ -327,7 +323,7 @@ void profile_gemm_impl(int do_verification, add_device_gemm_xdl_f16_f16_f16_km_kn_mn_instances(gemm_ptrs); ck::tensor_operation::device::device_gemm_instance:: - add_device_gemm_dlops_f16_f16_f16_km_kn_mn_instances(gemm_ptrs); + add_device_gemm_dl_f16_f16_f16_km_kn_mn_instances(gemm_ptrs); ck::tensor_operation::device::device_gemm_instance:: add_device_gemm_xdl_c_shuffle_f16_f16_f16_km_kn_mn_instances(gemm_ptrs); @@ -348,7 +344,7 @@ void profile_gemm_impl(int do_verification, add_device_gemm_xdl_f16_f16_f16_km_nk_mn_instances(gemm_ptrs); ck::tensor_operation::device::device_gemm_instance:: - add_device_gemm_dlops_f16_f16_f16_km_nk_mn_instances(gemm_ptrs); + 
add_device_gemm_dl_f16_f16_f16_km_nk_mn_instances(gemm_ptrs); ck::tensor_operation::device::device_gemm_instance:: add_device_gemm_xdl_c_shuffle_f16_f16_f16_km_nk_mn_instances(gemm_ptrs); @@ -396,40 +392,40 @@ void profile_gemm_impl(int do_verification, is_same::value) { ck::tensor_operation::device::device_gemm_instance:: - add_device_gemm_xdl_c_shuffle_int8_int8_int8_mk_kn_mn_instances(gemm_ptrs); + add_device_gemm_xdl_c_shuffle_i8_i8_i8_mk_kn_mn_instances(gemm_ptrs); ck::tensor_operation::device::device_gemm_instance:: - add_device_gemm_dlops_int8_int8_int8_mk_kn_mn_instances(gemm_ptrs); + add_device_gemm_dl_i8_i8_i8_mk_kn_mn_instances(gemm_ptrs); } else if constexpr(is_same::value && is_same::value && is_same::value) { ck::tensor_operation::device::device_gemm_instance:: - add_device_gemm_xdl_c_shuffle_int8_int8_int8_mk_nk_mn_instances(gemm_ptrs); + add_device_gemm_xdl_c_shuffle_i8_i8_i8_mk_nk_mn_instances(gemm_ptrs); ck::tensor_operation::device::device_gemm_instance:: - add_device_gemm_dlops_int8_int8_int8_mk_nk_mn_instances(gemm_ptrs); + add_device_gemm_dl_i8_i8_i8_mk_nk_mn_instances(gemm_ptrs); } else if constexpr(is_same::value && is_same::value && is_same::value) { ck::tensor_operation::device::device_gemm_instance:: - add_device_gemm_xdl_c_shuffle_int8_int8_int8_km_kn_mn_instances(gemm_ptrs); + add_device_gemm_xdl_c_shuffle_i8_i8_i8_km_kn_mn_instances(gemm_ptrs); ck::tensor_operation::device::device_gemm_instance:: - add_device_gemm_dlops_int8_int8_int8_km_kn_mn_instances(gemm_ptrs); + add_device_gemm_dl_i8_i8_i8_km_kn_mn_instances(gemm_ptrs); } else if constexpr(is_same::value && is_same::value && is_same::value) { ck::tensor_operation::device::device_gemm_instance:: - add_device_gemm_xdl_c_shuffle_int8_int8_int8_km_nk_mn_instances(gemm_ptrs); + add_device_gemm_xdl_c_shuffle_i8_i8_i8_km_nk_mn_instances(gemm_ptrs); ck::tensor_operation::device::device_gemm_instance:: - add_device_gemm_dlops_int8_int8_int8_km_nk_mn_instances(gemm_ptrs); + 
add_device_gemm_dl_i8_i8_i8_km_nk_mn_instances(gemm_ptrs); } } diff --git a/test/CMakeLists.txt b/test/CMakeLists.txt index d0eb98c118..a8e8c7347b 100644 --- a/test/CMakeLists.txt +++ b/test/CMakeLists.txt @@ -55,7 +55,6 @@ add_subdirectory(space_filling_curve) add_subdirectory(conv_util) add_subdirectory(reference_conv_fwd) add_subdirectory(gemm) -add_subdirectory(gemm_dl) add_subdirectory(gemm_split_k) add_subdirectory(gemm_reduce) add_subdirectory(batched_gemm) @@ -65,4 +64,4 @@ add_subdirectory(convnd_fwd) add_subdirectory(reduce) add_subdirectory(conv2d_bwd_weight) add_subdirectory(convnd_bwd_data) -# DONOT add client_app, that is tested via CI independently \ No newline at end of file +# DONOT add client_app, that is tested via CI independently diff --git a/test/gemm/CMakeLists.txt b/test/gemm/CMakeLists.txt index ea581ee781..b8679e3715 100644 --- a/test/gemm/CMakeLists.txt +++ b/test/gemm/CMakeLists.txt @@ -1,15 +1,29 @@ -add_test_executable(test_gemm_fp32 gemm_fp32.cpp) -target_link_libraries(test_gemm_fp32 PRIVATE host_tensor) -target_link_libraries(test_gemm_fp32 PRIVATE device_gemm_xdl_instance) +# GEMM XDL +add_test_executable(test_gemm_xdl_fp32 gemm_xdl_fp32.cpp) +target_link_libraries(test_gemm_xdl_fp32 PRIVATE host_tensor) +target_link_libraries(test_gemm_xdl_fp32 PRIVATE device_gemm_instance) -add_test_executable(test_gemm_fp16 gemm_fp16.cpp) -target_link_libraries(test_gemm_fp16 PRIVATE host_tensor) -target_link_libraries(test_gemm_fp16 PRIVATE device_gemm_xdl_instance) +add_test_executable(test_gemm_xdl_fp16 gemm_xdl_fp16.cpp) +target_link_libraries(test_gemm_xdl_fp16 PRIVATE host_tensor) +target_link_libraries(test_gemm_xdl_fp16 PRIVATE device_gemm_instance) -add_test_executable(test_gemm_bf16 gemm_bf16.cpp) -target_link_libraries(test_gemm_bf16 PRIVATE host_tensor) -target_link_libraries(test_gemm_bf16 PRIVATE device_gemm_xdl_instance) +add_test_executable(test_gemm_xdl_bf16 gemm_xdl_bf16.cpp) +target_link_libraries(test_gemm_xdl_bf16 PRIVATE 
host_tensor) +target_link_libraries(test_gemm_xdl_bf16 PRIVATE device_gemm_instance) -add_test_executable(test_gemm_int8 gemm_int8.cpp) -target_link_libraries(test_gemm_int8 PRIVATE host_tensor) -target_link_libraries(test_gemm_int8 PRIVATE device_gemm_xdl_instance) +add_test_executable(test_gemm_xdl_int8 gemm_xdl_int8.cpp) +target_link_libraries(test_gemm_xdl_int8 PRIVATE host_tensor) +target_link_libraries(test_gemm_xdl_int8 PRIVATE device_gemm_instance) + +# GEMM DL +add_test_executable(test_gemm_dl_fp32 gemm_dl_fp32.cpp) +target_link_libraries(test_gemm_dl_fp32 PRIVATE host_tensor) +target_link_libraries(test_gemm_dl_fp32 PRIVATE device_gemm_instance) + +add_test_executable(test_gemm_dl_fp16 gemm_dl_fp16.cpp) +target_link_libraries(test_gemm_dl_fp16 PRIVATE host_tensor) +target_link_libraries(test_gemm_dl_fp16 PRIVATE device_gemm_instance) + +add_test_executable(test_gemm_dl_int8 gemm_dl_int8.cpp) +target_link_libraries(test_gemm_dl_int8 PRIVATE host_tensor) +TArget_link_libraries(test_gemm_dl_int8 PRIVATE device_gemm_instance) diff --git a/test/gemm_dl/gemm_dl_fp16.cpp b/test/gemm/gemm_dl_fp16.cpp similarity index 93% rename from test/gemm_dl/gemm_dl_fp16.cpp rename to test/gemm/gemm_dl_fp16.cpp index d5369a5846..6165355ec4 100644 --- a/test/gemm_dl/gemm_dl_fp16.cpp +++ b/test/gemm/gemm_dl_fp16.cpp @@ -31,10 +31,10 @@ namespace tensor_operation { namespace device { namespace device_gemm_instance { -// void add_device_gemm_dl_f16_f16_f16_km_kn_mn_instances(std::vector&); -// void add_device_gemm_dl_f16_f16_f16_km_nk_mn_instances(std::vector&); +void add_device_gemm_dl_f16_f16_f16_km_kn_mn_instances(std::vector&); +void add_device_gemm_dl_f16_f16_f16_km_nk_mn_instances(std::vector&); void add_device_gemm_dl_f16_f16_f16_mk_nk_mn_instances(std::vector&); -// void add_device_gemm_dl_f16_f16_f16_mk_kn_mn_instances(std::vector&); +void add_device_gemm_dl_f16_f16_f16_mk_kn_mn_instances(std::vector&); } // namespace device_gemm_instance } // namespace device @@ -54,7 
+54,6 @@ int main() std::vector gemmPtrs; -#if 0 ck::tensor_operation::device::device_gemm_instance:: add_device_gemm_dl_f16_f16_f16_km_kn_mn_instances(gemmPtrs); @@ -107,7 +106,6 @@ int main() PassThrough, PassThrough>{}(gemmPtr); } -#endif gemmPtrs.clear(); ck::tensor_operation::device::device_gemm_instance:: diff --git a/test/gemm_dl/gemm_dl_fp32.cpp b/test/gemm/gemm_dl_fp32.cpp similarity index 83% rename from test/gemm_dl/gemm_dl_fp32.cpp rename to test/gemm/gemm_dl_fp32.cpp index aaece4c39d..cd0f816731 100644 --- a/test/gemm_dl/gemm_dl_fp32.cpp +++ b/test/gemm/gemm_dl_fp32.cpp @@ -14,7 +14,7 @@ #include "host_tensor_generator.hpp" #include "host_gemm.hpp" #include "device_tensor.hpp" -#include "device_gemm_dlops.hpp" +#include "device_gemm_dl.hpp" #include "element_wise_operation.hpp" #include "reference_gemm.hpp" #include "gemm_specialization.hpp" @@ -31,10 +31,10 @@ namespace tensor_operation { namespace device { namespace device_gemm_instance { -void add_device_gemm_dlops_f32_f32_f32_km_kn_mn_instances(std::vector&); -void add_device_gemm_dlops_f32_f32_f32_km_nk_mn_instances(std::vector&); -void add_device_gemm_dlops_f32_f32_f32_mk_nk_mn_instances(std::vector&); -void add_device_gemm_dlops_f32_f32_f32_mk_kn_mn_instances(std::vector&); +void add_device_gemm_dl_f32_f32_f32_km_kn_mn_instances(std::vector&); +void add_device_gemm_dl_f32_f32_f32_km_nk_mn_instances(std::vector&); +void add_device_gemm_dl_f32_f32_f32_mk_nk_mn_instances(std::vector&); +void add_device_gemm_dl_f32_f32_f32_mk_kn_mn_instances(std::vector&); } // namespace device_gemm_instance } // namespace device @@ -53,7 +53,7 @@ int main() bool res = true; std::vector gemmPtrs; ck::tensor_operation::device::device_gemm_instance:: - add_device_gemm_dlops_f32_f32_f32_km_kn_mn_instances(gemmPtrs); + add_device_gemm_dl_f32_f32_f32_km_kn_mn_instances(gemmPtrs); for(auto& gemmPtr : gemmPtrs) { @@ -71,7 +71,7 @@ int main() gemmPtrs.clear(); ck::tensor_operation::device::device_gemm_instance:: - 
add_device_gemm_dlops_f32_f32_f32_km_nk_mn_instances(gemmPtrs); + add_device_gemm_dl_f32_f32_f32_km_nk_mn_instances(gemmPtrs); for(auto& gemmPtr : gemmPtrs) { @@ -89,7 +89,7 @@ int main() gemmPtrs.clear(); ck::tensor_operation::device::device_gemm_instance:: - add_device_gemm_dlops_f32_f32_f32_mk_kn_mn_instances(gemmPtrs); + add_device_gemm_dl_f32_f32_f32_mk_kn_mn_instances(gemmPtrs); for(auto& gemmPtr : gemmPtrs) { @@ -107,7 +107,7 @@ int main() gemmPtrs.clear(); ck::tensor_operation::device::device_gemm_instance:: - add_device_gemm_dlops_f32_f32_f32_mk_nk_mn_instances(gemmPtrs); + add_device_gemm_dl_f32_f32_f32_mk_nk_mn_instances(gemmPtrs); for(auto& gemmPtr : gemmPtrs) { diff --git a/test/gemm_dl/gemm_dl_int8.cpp b/test/gemm/gemm_dl_int8.cpp similarity index 84% rename from test/gemm_dl/gemm_dl_int8.cpp rename to test/gemm/gemm_dl_int8.cpp index bff29111c3..72b9f1440f 100644 --- a/test/gemm_dl/gemm_dl_int8.cpp +++ b/test/gemm/gemm_dl_int8.cpp @@ -14,7 +14,7 @@ #include "host_tensor_generator.hpp" #include "host_gemm.hpp" #include "device_tensor.hpp" -#include "device_gemm_dlops.hpp" +#include "device_gemm_dl.hpp" #include "element_wise_operation.hpp" #include "reference_gemm.hpp" #include "gemm_specialization.hpp" @@ -31,10 +31,10 @@ namespace tensor_operation { namespace device { namespace device_gemm_instance { -void add_device_gemm_dlops_int8_int8_int8_km_kn_mn_instances(std::vector&); -void add_device_gemm_dlops_int8_int8_int8_km_nk_mn_instances(std::vector&); -void add_device_gemm_dlops_int8_int8_int8_mk_nk_mn_instances(std::vector&); -void add_device_gemm_dlops_int8_int8_int8_mk_kn_mn_instances(std::vector&); +void add_device_gemm_dl_i8_i8_i8_km_kn_mn_instances(std::vector&); +void add_device_gemm_dl_i8_i8_i8_km_nk_mn_instances(std::vector&); +void add_device_gemm_dl_i8_i8_i8_mk_nk_mn_instances(std::vector&); +void add_device_gemm_dl_i8_i8_i8_mk_kn_mn_instances(std::vector&); } // namespace device_gemm_instance } // namespace device @@ -53,7 +53,7 @@ int 
main() bool res = true; std::vector gemmPtrs; ck::tensor_operation::device::device_gemm_instance:: - add_device_gemm_dlops_int8_int8_int8_km_kn_mn_instances(gemmPtrs); + add_device_gemm_dl_i8_i8_i8_km_kn_mn_instances(gemmPtrs); for(auto& gemmPtr : gemmPtrs) { @@ -69,10 +69,9 @@ int main() PassThrough>{}(gemmPtr); } -#if 0 gemmPtrs.clear(); ck::tensor_operation::device::device_gemm_instance:: - add_device_gemm_dlops_int8_int8_int8_km_nk_mn_instances(gemmPtrs); + add_device_gemm_dl_i8_i8_i8_km_nk_mn_instances(gemmPtrs); for(auto& gemmPtr : gemmPtrs) { @@ -90,7 +89,7 @@ int main() gemmPtrs.clear(); ck::tensor_operation::device::device_gemm_instance:: - add_device_gemm_dlops_int8_int8_int8_mk_kn_mn_instances(gemmPtrs); + add_device_gemm_dl_i8_i8_i8_mk_kn_mn_instances(gemmPtrs); for(auto& gemmPtr : gemmPtrs) { @@ -108,7 +107,7 @@ int main() gemmPtrs.clear(); ck::tensor_operation::device::device_gemm_instance:: - add_device_gemm_dlops_int8_int8_int8_mk_nk_mn_instances(gemmPtrs); + add_device_gemm_dl_i8_i8_i8_mk_nk_mn_instances(gemmPtrs); for(auto& gemmPtr : gemmPtrs) { @@ -124,7 +123,6 @@ int main() PassThrough>{}(gemmPtr); } -#endif std::cout << "TestGemm ..... " << (res ? "SUCCESS" : "FAILURE") << std::endl; return res ? 
0 : 1; } diff --git a/test/gemm/gemm_bf16.cpp b/test/gemm/gemm_xdl_bf16.cpp similarity index 100% rename from test/gemm/gemm_bf16.cpp rename to test/gemm/gemm_xdl_bf16.cpp diff --git a/test/gemm/gemm_fp16.cpp b/test/gemm/gemm_xdl_fp16.cpp similarity index 100% rename from test/gemm/gemm_fp16.cpp rename to test/gemm/gemm_xdl_fp16.cpp diff --git a/test/gemm/gemm_fp32.cpp b/test/gemm/gemm_xdl_fp32.cpp similarity index 100% rename from test/gemm/gemm_fp32.cpp rename to test/gemm/gemm_xdl_fp32.cpp diff --git a/test/gemm/gemm_int8.cpp b/test/gemm/gemm_xdl_int8.cpp similarity index 82% rename from test/gemm/gemm_int8.cpp rename to test/gemm/gemm_xdl_int8.cpp index 870881dd76..fbb1b1ac98 100644 --- a/test/gemm/gemm_int8.cpp +++ b/test/gemm/gemm_xdl_int8.cpp @@ -31,14 +31,10 @@ namespace ck { namespace tensor_operation { namespace device { namespace device_gemm_instance { -void add_device_gemm_xdl_c_shuffle_int8_int8_int8_km_kn_mn_instances( - std::vector&); -void add_device_gemm_xdl_c_shuffle_int8_int8_int8_km_nk_mn_instances( - std::vector&); -void add_device_gemm_xdl_c_shuffle_int8_int8_int8_mk_nk_mn_instances( - std::vector&); -void add_device_gemm_xdl_c_shuffle_int8_int8_int8_mk_kn_mn_instances( - std::vector&); +void add_device_gemm_xdl_c_shuffle_i8_i8_i8_km_kn_mn_instances(std::vector&); +void add_device_gemm_xdl_c_shuffle_i8_i8_i8_km_nk_mn_instances(std::vector&); +void add_device_gemm_xdl_c_shuffle_i8_i8_i8_mk_nk_mn_instances(std::vector&); +void add_device_gemm_xdl_c_shuffle_i8_i8_i8_mk_kn_mn_instances(std::vector&); } // namespace device_gemm_instance } // namespace device } // namespace tensor_operation @@ -57,7 +53,7 @@ int main() bool res = true; ck::tensor_operation::device::device_gemm_instance:: - add_device_gemm_xdl_c_shuffle_int8_int8_int8_km_kn_mn_instances(gemmPtrs); + add_device_gemm_xdl_c_shuffle_i8_i8_i8_km_kn_mn_instances(gemmPtrs); for(auto& gemmPtr : gemmPtrs) { @@ -75,7 +71,7 @@ int main() gemmPtrs.clear(); 
ck::tensor_operation::device::device_gemm_instance:: - add_device_gemm_xdl_c_shuffle_int8_int8_int8_km_nk_mn_instances(gemmPtrs); + add_device_gemm_xdl_c_shuffle_i8_i8_i8_km_nk_mn_instances(gemmPtrs); for(auto& gemmPtr : gemmPtrs) { @@ -93,7 +89,7 @@ int main() gemmPtrs.clear(); ck::tensor_operation::device::device_gemm_instance:: - add_device_gemm_xdl_c_shuffle_int8_int8_int8_mk_kn_mn_instances(gemmPtrs); + add_device_gemm_xdl_c_shuffle_i8_i8_i8_mk_kn_mn_instances(gemmPtrs); for(auto& gemmPtr : gemmPtrs) { @@ -111,7 +107,7 @@ int main() gemmPtrs.clear(); ck::tensor_operation::device::device_gemm_instance:: - add_device_gemm_xdl_c_shuffle_int8_int8_int8_mk_nk_mn_instances(gemmPtrs); + add_device_gemm_xdl_c_shuffle_i8_i8_i8_mk_nk_mn_instances(gemmPtrs); for(auto& gemmPtr : gemmPtrs) { diff --git a/test/gemm_dl/CMakeLists.txt b/test/gemm_dl/CMakeLists.txt deleted file mode 100644 index 6486474771..0000000000 --- a/test/gemm_dl/CMakeLists.txt +++ /dev/null @@ -1,11 +0,0 @@ -add_test_executable(test_gemm_dl_fp32 gemm_dl_fp32.cpp) -target_link_libraries(test_gemm_dl_fp32 PRIVATE host_tensor) -target_link_libraries(test_gemm_dl_fp32 PRIVATE device_gemm_dl_instance) - -add_test_executable(test_gemm_dl_fp16 gemm_dl_fp16.cpp) -target_link_libraries(test_gemm_dl_fp16 PRIVATE host_tensor) -target_link_libraries(test_gemm_dl_fp16 PRIVATE device_gemm_dl_instance) - -add_test_executable(test_gemm_dl_int8 gemm_dl_int8.cpp) -target_link_libraries(test_gemm_dl_int8 PRIVATE host_tensor) -TArget_link_libraries(test_gemm_dl_int8 PRIVATE device_gemm_dl_instance) From 15c5b67c5a4dd80382688f34f1bcfe152cb68747 Mon Sep 17 00:00:00 2001 From: Chao Liu Date: Fri, 20 May 2022 17:57:10 -0400 Subject: [PATCH 43/46] fix test --- example/01_gemm/gemm_dl_fp16.cpp | 8 ++-- example/01_gemm/gemm_dl_fp32.cpp | 8 ++-- example/01_gemm/gemm_dl_int8.cpp | 8 ++-- include/ck/host_utility/device_prop.hpp | 25 +++++++++- test/gemm/gemm_util.hpp | 64 +++++++++++++++---------- 5 files changed, 77 insertions(+), 
36 deletions(-) diff --git a/example/01_gemm/gemm_dl_fp16.cpp b/example/01_gemm/gemm_dl_fp16.cpp index 18b57e80bc..6e8e04f9e5 100644 --- a/example/01_gemm/gemm_dl_fp16.cpp +++ b/example/01_gemm/gemm_dl_fp16.cpp @@ -170,9 +170,11 @@ int main(int argc, char* argv[]) if(!gemm.IsSupportedArgument(argument)) { - throw std::runtime_error( - "wrong! device_gemm with the specified compilation parameters does " - "not support this GEMM problem"); + std::cout << "wrong! device_gemm with the specified compilation parameters does " + "not support this GEMM problem" + << std::endl; + + return 0; } float ave_time = invoker.Run(argument, StreamConfig{nullptr, time_kernel}); diff --git a/example/01_gemm/gemm_dl_fp32.cpp b/example/01_gemm/gemm_dl_fp32.cpp index f934a9e8fa..65c806bf07 100644 --- a/example/01_gemm/gemm_dl_fp32.cpp +++ b/example/01_gemm/gemm_dl_fp32.cpp @@ -169,9 +169,11 @@ int main(int argc, char* argv[]) if(!gemm.IsSupportedArgument(argument)) { - throw std::runtime_error( - "wrong! device_gemm with the specified compilation parameters does " - "not support this GEMM problem"); + std::cout << "wrong! device_gemm with the specified compilation parameters does " + "not support this GEMM problem" + << std::endl; + + return 0; } float ave_time = invoker.Run(argument, StreamConfig{nullptr, time_kernel}); diff --git a/example/01_gemm/gemm_dl_int8.cpp b/example/01_gemm/gemm_dl_int8.cpp index 932728fcea..a9590030c7 100644 --- a/example/01_gemm/gemm_dl_int8.cpp +++ b/example/01_gemm/gemm_dl_int8.cpp @@ -167,9 +167,11 @@ int main(int argc, char* argv[]) if(!gemm.IsSupportedArgument(argument)) { - throw std::runtime_error( - "wrong! device_gemm with the specified compilation parameters does " - "not support this GEMM problem"); + std::cout << "wrong! 
device_gemm with the specified compilation parameters does " + "not support this GEMM problem" + << std::endl; + + return 0; } float ave_time = invoker.Run(argument, StreamConfig{nullptr, time_kernel}); diff --git a/include/ck/host_utility/device_prop.hpp b/include/ck/host_utility/device_prop.hpp index 5f13d6cb22..74b20acecd 100644 --- a/include/ck/host_utility/device_prop.hpp +++ b/include/ck/host_utility/device_prop.hpp @@ -1,6 +1,7 @@ #pragma once #include +#include namespace ck { @@ -19,8 +20,30 @@ inline std::string get_device_name() { return std::string(); } - const std::string name(props.gcnArchName); + const std::string raw_name(props.gcnArchName); + // https://github.com/ROCmSoftwarePlatform/MIOpen/blob/8498875aef84878e04c1eabefdf6571514891086/src/target_properties.cpp#L40 + static std::map device_name_map = { + {"Ellesmere", "gfx803"}, + {"Baffin", "gfx803"}, + {"RacerX", "gfx803"}, + {"Polaris10", "gfx803"}, + {"Polaris11", "gfx803"}, + {"Tonga", "gfx803"}, + {"Fiji", "gfx803"}, + {"gfx800", "gfx803"}, + {"gfx802", "gfx803"}, + {"gfx804", "gfx803"}, + {"Vega10", "gfx900"}, + {"gfx901", "gfx900"}, + {"10.3.0 Sienna_Cichlid 18", "gfx1030"}, + }; + + const auto name = raw_name.substr(0, raw_name.find(':')); // str.substr(0, npos) returns str. 
+ + auto match = device_name_map.find(name); + if(match != device_name_map.end()) + return match->second; return name; } diff --git a/test/gemm/gemm_util.hpp b/test/gemm/gemm_util.hpp index 17e954b7f2..258ed60b08 100644 --- a/test/gemm/gemm_util.hpp +++ b/test/gemm/gemm_util.hpp @@ -60,7 +60,7 @@ template -void RunDeviceGEMM(DeviceGemmPtr_& gemmPtr, +bool RunDeviceGEMM(DeviceGemmPtr_& gemmPtr, const ck::gemm_util::GemmParams& params, const Tensor& A, const Tensor& B, @@ -73,9 +73,6 @@ void RunDeviceGEMM(DeviceGemmPtr_& gemmPtr, DeviceMem b_k_n_device_buf(sizeof(BDataType) * B.mDesc.GetElementSpace()); DeviceMem c_m_n_device_buf(sizeof(CDataType) * C.mDesc.GetElementSpace()); - a_m_k_device_buf.ToDevice(A.mData.data()); - b_k_n_device_buf.ToDevice(B.mData.data()); - auto invoker_ptr = gemmPtr->MakeInvokerPointer(); auto argument_ptr = gemmPtr->MakeArgumentPointer(static_cast(a_m_k_device_buf.GetDeviceBuffer()), @@ -91,15 +88,23 @@ void RunDeviceGEMM(DeviceGemmPtr_& gemmPtr, b_element_op, c_element_op); - if(!gemmPtr->IsSupportedArgument(argument_ptr.get())) + if(gemmPtr->IsSupportedArgument(argument_ptr.get())) { - throw std::runtime_error( - "wrong! device_gemm with the specified compilation parameters does " - "not support this GEMM problem"); + a_m_k_device_buf.ToDevice(A.mData.data()); + b_k_n_device_buf.ToDevice(B.mData.data()); + invoker_ptr->Run(argument_ptr.get()); + c_m_n_device_buf.FromDevice(C.mData.data()); + + return true; } + else + { + std::cout << "device_gemm with the specified compilation parameters does " + "not support this GEMM problem" + << std::endl; - invoker_ptr->Run(argument_ptr.get()); - c_m_n_device_buf.FromDevice(C.mData.data()); + return false; + } } template ::value) + if(is_supported) { - res = ck::utils::check_err(c_device.mData, c_host.mData); - std::cout << (res ? 
"SUCCESS" : "FAILURE") << std::endl; + // Assert + bool res = false; + if(std::is_same::value) + { + res = ck::utils::check_err(c_device.mData, c_host.mData); + std::cout << (res ? "SUCCESS" : "FAILURE") << std::endl; + } + else if(std::is_same::value) + { + res = ck::utils::check_err(c_device.mData, c_host.mData); + std::cout << (res ? "SUCCESS" : "FAILURE") << std::endl; + } + else if(std::is_same::value) + { + res = ck::utils::check_err(c_device.mData, c_host.mData); + std::cout << (res ? "SUCCESS" : "FAILURE") << std::endl; + } + + return res; } - else if(std::is_same::value) + else { - res = ck::utils::check_err(c_device.mData, c_host.mData); - std::cout << (res ? "SUCCESS" : "FAILURE") << std::endl; + return true; } - else if(std::is_same::value) - { - res = ck::utils::check_err(c_device.mData, c_host.mData); - std::cout << (res ? "SUCCESS" : "FAILURE") << std::endl; - } - - return res; } }; From 39131c62dd021f014d5febe92cbec2f000cd4813 Mon Sep 17 00:00:00 2001 From: Chao Liu Date: Fri, 20 May 2022 18:17:40 -0400 Subject: [PATCH 44/46] use new block c tile --- .../gpu/grid/gridwise_gemm_dl_v1r3.hpp | 32 +++++++------------ .../gpu/grid/gridwise_gemm_xdlops_v2r3.hpp | 1 + 2 files changed, 12 insertions(+), 21 deletions(-) diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_dl_v1r3.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_dl_v1r3.hpp index 3ae5e7a2ea..48994b2590 100644 --- a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_dl_v1r3.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_dl_v1r3.hpp @@ -1,8 +1,10 @@ #pragma once + #include "common_header.hpp" #include "multi_index_transform_helper.hpp" #include "tensor_descriptor.hpp" #include "tensor_descriptor_helper.hpp" +#include "tensor_operation/gpu/grid/block_to_ctile_map.hpp" #include "blockwise_gemm_dl_v2r3.hpp" #include "blockwise_tensor_slice_transfer_v5r1.hpp" #include "threadwise_tensor_slice_transfer.hpp" @@ -56,7 +58,7 @@ template {}; - constexpr auto N1 = 
Number{}; - - const auto M0 = M / M1; - const auto N0 = N / N1; - - const auto block_2_ctile_map = - make_single_stage_tensor_adaptor(make_tuple(make_merge_transform(make_tuple(M0, N0))), - make_tuple(Sequence<0, 1>{}), - make_tuple(Sequence<0>{})); - - return block_2_ctile_map; + return BlockToCTileMap_M00_N00_M01_N01( + c_grid_desc_m_n); } using AGridDesc_K0_M0_M1_K1 = decltype(MakeAGridDescriptor_K0_M0_M1_K1(AGridDesc_K0_M_K1{})); using BGridDesc_K0_N0_N1_K1 = decltype(MakeBGridDescriptor_K0_N0_N1_K1(BGridDesc_K0_N_K1{})); using CGridDesc_M0_M10_M11_N0_N10_N11 = - decltype(MakeCGridDescriptor_M0_M10_M11_N0_N10_N11(CMNGridDesc{})); - using Block2CTileMap = decltype(MakeDefaultBlock2CTileMap(CMNGridDesc{})); + decltype(MakeCGridDescriptor_M0_M10_M11_N0_N10_N11(CGridDesc_M_N{})); + using Block2CTileMap = decltype(MakeDefaultBlock2CTileMap(CGridDesc_M_N{})); template __device__ static void diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_v2r3.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_v2r3.hpp index bfa93e5866..d60f8c4d07 100644 --- a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_v2r3.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_v2r3.hpp @@ -1,4 +1,5 @@ #pragma once + #include "common_header.hpp" #include "multi_index_transform_helper.hpp" #include "tensor_descriptor.hpp" From a838cb9daa36df0292f85bde19f1b68c2bdfcfa2 Mon Sep 17 00:00:00 2001 From: Chao Liu Date: Fri, 20 May 2022 18:32:39 -0400 Subject: [PATCH 45/46] clean --- .../gpu/device/device_gemm_xdl_c_shuffle.hpp | 483 ------------------ 1 file changed, 483 deletions(-) delete mode 100644 include/ck/tensor_operation/gpu/device/device_gemm_xdl_c_shuffle.hpp diff --git a/include/ck/tensor_operation/gpu/device/device_gemm_xdl_c_shuffle.hpp b/include/ck/tensor_operation/gpu/device/device_gemm_xdl_c_shuffle.hpp deleted file mode 100644 index 23177d6556..0000000000 --- a/include/ck/tensor_operation/gpu/device/device_gemm_xdl_c_shuffle.hpp 
+++ /dev/null @@ -1,483 +0,0 @@ -#ifndef DEVICE_GEMM_XDL_C_SHUFFLE_HPP -#define DEVICE_GEMM_XDL_C_SHUFFLE_HPP - -#include -#include -#include "device.hpp" -#include "device_gemm.hpp" -#include "common_header.hpp" -#include "tensor_layout.hpp" -#include "tensor_descriptor.hpp" -#include "tensor_descriptor_helper.hpp" -#include "gridwise_gemm_xdlops_v3r1.hpp" -#include "device_prop.hpp" - -namespace ck { -namespace tensor_operation { -namespace device { - -template < - typename ADataType, - typename BDataType, - typename CDataType, - typename AccDataType, - typename CShuffleDataType, - typename ALayout, - typename BLayout, - typename CLayout, - typename AElementwiseOperation, - typename BElementwiseOperation, - typename CElementwiseOperation, - ck::index_t BlockSize, - ck::index_t MPerBlock, - ck::index_t NPerBlock, - ck::index_t KPerBlock, - ck::index_t AK1, - ck::index_t BK1, - ck::index_t MPerXDL, - ck::index_t NPerXDL, - ck::index_t MXdlPerWave, - ck::index_t NXdlPerWave, - typename ABlockTransferThreadClusterLengths_K0_M_K1, - typename ABlockTransferThreadClusterArrangeOrder, - typename ABlockTransferSrcAccessOrder, - ck::index_t ABlockTransferSrcVectorDim, - ck::index_t ABlockTransferSrcScalarPerVector, - ck::index_t ABlockTransferDstScalarPerVector_K1, - bool ABlockLdsAddExtraM, - typename BBlockTransferThreadClusterLengths_K0_N_K1, - typename BBlockTransferThreadClusterArrangeOrder, - typename BBlockTransferSrcAccessOrder, - ck::index_t BBlockTransferSrcVectorDim, - ck::index_t BBlockTransferSrcScalarPerVector, - ck::index_t BBlockTransferDstScalarPerVector_K1, - bool BBlockLdsAddExtraN, - index_t CShuffleMXdlPerWavePerShuffle, - index_t CShuffleNXdlPerWavePerShuffle, - typename CBlockTransferClusterLengths_MBlock_MXdlPerWave_MWaveMPerXdl_NBlock_NXdlPerWave_NWaveNPerXdl, - index_t CBlockTransferScalarPerVector_NWaveNPerXdl, - index_t NumPrefetch = 1> -struct DeviceGemmXdl_C_Shuffle - : public DeviceGemm -{ - static constexpr auto I0 = Number<0>{}; - static 
constexpr auto I1 = Number<1>{}; - static constexpr auto I2 = Number<2>{}; - - static auto MakeAGridDescriptor_K0_M_K1(index_t M, index_t K, index_t StrideA) - { - assert(K % AK1 == 0); - - const index_t K0 = K / AK1; - - const auto a_grid_desc_m_k = [&]() { - if constexpr(is_same::value) - { - return make_naive_tensor_descriptor(make_tuple(M, K), make_tuple(StrideA, I1)); - } - else if constexpr(is_same::value) - { - return make_naive_tensor_descriptor(make_tuple(M, K), make_tuple(I1, StrideA)); - } - }(); - - const auto a_grid_desc_k0_m_k1 = transform_tensor_descriptor( - a_grid_desc_m_k, - make_tuple(make_unmerge_transform(make_tuple(K0, AK1)), make_pass_through_transform(M)), - make_tuple(Sequence<1>{}, Sequence<0>{}), - make_tuple(Sequence<0, 2>{}, Sequence<1>{})); - - return a_grid_desc_k0_m_k1; - } - - static auto MakeBGridDescriptor_K0_N_K1(index_t K, index_t N, index_t StrideB) - { - assert(K % BK1 == 0); - - const index_t K0 = K / BK1; - - const auto b_grid_desc_k_n = [&]() { - if constexpr(is_same::value) - { - return make_naive_tensor_descriptor(make_tuple(K, N), make_tuple(StrideB, I1)); - } - else if constexpr(is_same::value) - { - return make_naive_tensor_descriptor(make_tuple(K, N), make_tuple(I1, StrideB)); - } - }(); - - const auto b_grid_desc_k0_n_k1 = transform_tensor_descriptor( - b_grid_desc_k_n, - make_tuple(make_unmerge_transform(make_tuple(K0, BK1)), make_pass_through_transform(N)), - make_tuple(Sequence<0>{}, Sequence<1>{}), - make_tuple(Sequence<0, 2>{}, Sequence<1>{})); - - return b_grid_desc_k0_n_k1; - } - - static auto MakeCGridDescriptor_M_N(index_t M, index_t N, index_t StrideC) - { - if constexpr(is_same::value) - { - return make_naive_tensor_descriptor(make_tuple(M, N), make_tuple(StrideC, I1)); - } - else if constexpr(is_same::value) - { - return make_naive_tensor_descriptor(make_tuple(M, N), make_tuple(I1, StrideC)); - } - } - - using AGridDesc_K0_M_K1 = decltype(MakeAGridDescriptor_K0_M_K1(1, 1, 1)); - using BGridDesc_K0_N_K1 = 
decltype(MakeBGridDescriptor_K0_N_K1(1, 1, 1)); - using CGridDesc_M_N = decltype(MakeCGridDescriptor_M_N(1, 1, 1)); - - // GridwiseGemm - using GridwiseGemm = GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v3r1< - BlockSize, - ADataType, // TODO: distinguish A/B datatype - AccDataType, - CShuffleDataType, - CDataType, - InMemoryDataOperationEnum::Set, - AGridDesc_K0_M_K1, - BGridDesc_K0_N_K1, - CGridDesc_M_N, - AElementwiseOperation, - BElementwiseOperation, - CElementwiseOperation, - MPerBlock, - NPerBlock, - KPerBlock, - AK1, - BK1, - MPerXDL, - NPerXDL, - MXdlPerWave, - NXdlPerWave, - ABlockTransferThreadClusterLengths_K0_M_K1, - ABlockTransferThreadClusterArrangeOrder, - ABlockTransferSrcAccessOrder, - ABlockTransferSrcVectorDim, - ABlockTransferSrcScalarPerVector, - ABlockTransferDstScalarPerVector_K1, - false, - ABlockLdsAddExtraM, - BBlockTransferThreadClusterLengths_K0_N_K1, - BBlockTransferThreadClusterArrangeOrder, - BBlockTransferSrcAccessOrder, - BBlockTransferSrcVectorDim, - BBlockTransferSrcScalarPerVector, - BBlockTransferDstScalarPerVector_K1, - false, - BBlockLdsAddExtraN, - CShuffleMXdlPerWavePerShuffle, - CShuffleNXdlPerWavePerShuffle, - CBlockTransferClusterLengths_MBlock_MXdlPerWave_MWaveMPerXdl_NBlock_NXdlPerWave_NWaveNPerXdl, - CBlockTransferScalarPerVector_NWaveNPerXdl, - NumPrefetch>; - - // Argument - struct Argument : public BaseArgument - { - Argument(const ADataType* p_a_grid, - const BDataType* p_b_grid, - CDataType* p_c_grid, - index_t M, - index_t N, - index_t K, - index_t StrideA, - index_t StrideB, - index_t StrideC, - index_t M01, - index_t N01, - AElementwiseOperation a_element_op, - BElementwiseOperation b_element_op, - CElementwiseOperation c_element_op) - : p_a_grid_{p_a_grid}, - p_b_grid_{p_b_grid}, - p_c_grid_{p_c_grid}, - a_grid_desc_k0_m_k1_{}, - b_grid_desc_k0_n_k1_{}, - c_grid_desc_m_n_{}, - c_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl_{}, - block_2_ctile_map_{}, - M01_{M01}, - N01_{N01}, - 
a_element_op_{a_element_op}, - b_element_op_{b_element_op}, - c_element_op_{c_element_op} - { - a_grid_desc_k0_m_k1_ = - DeviceGemmXdl_C_Shuffle::MakeAGridDescriptor_K0_M_K1(M, K, StrideA); - b_grid_desc_k0_n_k1_ = - DeviceGemmXdl_C_Shuffle::MakeBGridDescriptor_K0_N_K1(K, N, StrideB); - c_grid_desc_m_n_ = DeviceGemmXdl_C_Shuffle::MakeCGridDescriptor_M_N(M, N, StrideC); - - if(GridwiseGemm::CheckValidity( - a_grid_desc_k0_m_k1_, b_grid_desc_k0_n_k1_, c_grid_desc_m_n_, M01_, N01_)) - { - c_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl_ = - GridwiseGemm:: - MakeCGridDescriptor_MBlock_MXdlPerWave_MWaveMPerXdl_NBlock_NXdlPerWave_NWaveNPerXdl( - c_grid_desc_m_n_); - - block_2_ctile_map_ = - GridwiseGemm::MakeDefaultBlock2CTileMap(c_grid_desc_m_n_, M01, N01); - } - } - - // private: - const ADataType* p_a_grid_; - const BDataType* p_b_grid_; - CDataType* p_c_grid_; - AGridDesc_K0_M_K1 a_grid_desc_k0_m_k1_; - BGridDesc_K0_N_K1 b_grid_desc_k0_n_k1_; - CGridDesc_M_N c_grid_desc_m_n_; - typename GridwiseGemm:: - CGridDescriptor_MBlock_MXdlPerWave_MWaveMPerXdl_NBlock_NXdlPerWave_NWaveNPerXdl - c_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl_; - typename GridwiseGemm::DefaultBlock2CTileMap block_2_ctile_map_; - index_t M01_; - index_t N01_; - AElementwiseOperation a_element_op_; - BElementwiseOperation b_element_op_; - CElementwiseOperation c_element_op_; - }; - - // Invoker - struct Invoker : public BaseInvoker - { - using Argument = DeviceGemmXdl_C_Shuffle::Argument; - - float Run(const Argument& arg, int nrepeat = 1) - { - { - std::cout << "arg.a_grid_desc_k0_m_k1_{" << arg.a_grid_desc_k0_m_k1_.GetLength(I0) - << ", " << arg.a_grid_desc_k0_m_k1_.GetLength(I1) << ", " - << arg.a_grid_desc_k0_m_k1_.GetLength(I2) << "}" << std::endl; - - std::cout << "arg.b_grid_desc_k0_n_k1_{" << arg.b_grid_desc_k0_n_k1_.GetLength(I0) - << ", " << arg.b_grid_desc_k0_n_k1_.GetLength(I1) << ", " - << arg.b_grid_desc_k0_n_k1_.GetLength(I2) 
<< "}" << std::endl; - - std::cout << "arg.c_grid_desc_m_n_{ " << arg.c_grid_desc_m_n_.GetLength(I0) << ", " - << arg.c_grid_desc_m_n_.GetLength(I1) << "}" << std::endl; - } - - if(!GridwiseGemm::CheckValidity(arg.a_grid_desc_k0_m_k1_, - arg.b_grid_desc_k0_n_k1_, - arg.c_grid_desc_m_n_, - arg.M01_, - arg.N01_)) - { - throw std::runtime_error( - "wrong! GridwiseGemm_km_kn_m0m1n0n1_xdlops_v2r3 has invalid setting"); - } - - const index_t grid_size = GridwiseGemm::CalculateGridSize(arg.c_grid_desc_m_n_); - - const auto K0 = arg.a_grid_desc_k0_m_k1_.GetLength(I0); - - const bool has_main_k0_block_loop = GridwiseGemm::CalculateHasMainK0BlockLoop(K0); - - float ave_time = 0; - - if(has_main_k0_block_loop) - { - const auto kernel = kernel_gemm_xdlops_v3r1< - GridwiseGemm, - ADataType, // TODO: distiguish A/B datatype - CDataType, - remove_reference_t, - remove_reference_t, - remove_reference_t< - typename GridwiseGemm:: - CGridDescriptor_MBlock_MXdlPerWave_MWaveMPerXdl_NBlock_NXdlPerWave_NWaveNPerXdl>, - AElementwiseOperation, - BElementwiseOperation, - CElementwiseOperation, - remove_reference_t, - true>; - - ave_time = launch_and_time_kernel( - kernel, - nrepeat, - dim3(grid_size), - dim3(BlockSize), - 0, - arg.p_a_grid_, - arg.p_b_grid_, - arg.p_c_grid_, - arg.a_grid_desc_k0_m_k1_, - arg.b_grid_desc_k0_n_k1_, - arg.c_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl_, - arg.a_element_op_, - arg.b_element_op_, - arg.c_element_op_, - arg.block_2_ctile_map_); - } - else - { - const auto kernel = kernel_gemm_xdlops_v3r1< - GridwiseGemm, - ADataType, // TODO: distiguish A/B datatype - CDataType, - remove_reference_t, - remove_reference_t, - remove_reference_t< - typename GridwiseGemm:: - CGridDescriptor_MBlock_MXdlPerWave_MWaveMPerXdl_NBlock_NXdlPerWave_NWaveNPerXdl>, - AElementwiseOperation, - BElementwiseOperation, - CElementwiseOperation, - remove_reference_t, - false>; - - ave_time = launch_and_time_kernel( - kernel, - nrepeat, - dim3(grid_size), 
- dim3(BlockSize), - 0, - arg.p_a_grid_, - arg.p_b_grid_, - arg.p_c_grid_, - arg.a_grid_desc_k0_m_k1_, - arg.b_grid_desc_k0_n_k1_, - arg.c_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl_, - arg.a_element_op_, - arg.b_element_op_, - arg.c_element_op_, - arg.block_2_ctile_map_); - } - - return ave_time; - } - - // polymorphic - float Run(const BaseArgument* p_arg, int nrepeat = 1) override - { - return Run(*dynamic_cast(p_arg), nrepeat); - } - }; - - static constexpr bool IsValidCompilationParameter() - { - // TODO: properly implement this check - return true; - } - - static bool IsSupportedArgument(const Argument& arg) - { - if(!(ck::get_device_name() == "gfx908" || ck::get_device_name() == "gfx90a")) - { - return false; - } - - return GridwiseGemm::CheckValidity(arg.a_grid_desc_k0_m_k1_, - arg.b_grid_desc_k0_n_k1_, - arg.c_grid_desc_m_n_, - arg.M01_, - arg.N01_); - } - - // polymorphic - bool IsSupportedArgument(const BaseArgument* p_arg) override - { - return IsSupportedArgument(*dynamic_cast(p_arg)); - } - - static auto MakeArgument(const ADataType* p_a, - const BDataType* p_b, - CDataType* p_c, - index_t M, - index_t N, - index_t K, - index_t StrideA, - index_t StrideB, - index_t StrideC, - AElementwiseOperation a_element_op, - BElementwiseOperation b_element_op, - CElementwiseOperation c_element_op) - { - return Argument{p_a, - p_b, - p_c, - M, - N, - K, - StrideA, - StrideB, - StrideC, - 1, - 1, - a_element_op, - b_element_op, - c_element_op}; - } - - static auto MakeInvoker() { return Invoker{}; } - - // polymorphic - std::unique_ptr MakeArgumentPointer(const void* p_a, - const void* p_b, - void* p_c, - index_t M, - index_t N, - index_t K, - index_t StrideA, - index_t StrideB, - index_t StrideC, - AElementwiseOperation a_element_op, - BElementwiseOperation b_element_op, - CElementwiseOperation c_element_op, - index_t /* KBatch */ = 1) override - { - return std::make_unique(static_cast(p_a), - static_cast(p_b), - static_cast(p_c), - 
M, - N, - K, - StrideA, - StrideB, - StrideC, - 1, - 1, - a_element_op, - b_element_op, - c_element_op); - } - - // polymorphic - std::unique_ptr MakeInvokerPointer() override - { - return std::make_unique(Invoker{}); - } - - // polymorphic - std::string GetTypeString() const override - { - auto str = std::stringstream(); - - // clang-format off - str << "DeviceGemmXdl_C_Shuffle" - << "<" - << BlockSize << ", " - << MPerBlock << ", " - << NPerBlock << ", " - << KPerBlock << ", " - << AK1 << ", " - << BK1 - << ">"; - // clang-format on - - return str.str(); - } -}; - -} // namespace device -} // namespace tensor_operation -} // namespace ck -#endif From 7c7904ea07d96888bc65935ccac3d02e31db7360 Mon Sep 17 00:00:00 2001 From: Chao Liu Date: Mon, 23 May 2022 17:31:14 +0000 Subject: [PATCH 46/46] fix build --- .../tensor_operation/gpu/grid/gridwise_gemm_dl_v1r3.hpp | 8 ++++++++ 1 file changed, 8 insertions(+) diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_dl_v1r3.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_dl_v1r3.hpp index 48994b2590..3b5daf6ead 100644 --- a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_dl_v1r3.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_dl_v1r3.hpp @@ -273,6 +273,14 @@ struct GridwiseGemmDl_km_kn_mn_v1r3 const index_t im0 = __builtin_amdgcn_readfirstlane(c_m0_n0_block_cluster_idx[I0]); const index_t in0 = __builtin_amdgcn_readfirstlane(c_m0_n0_block_cluster_idx[I1]); + if(!block_2_ctile_map.ValidCTileIndex( + make_tuple(im0, in0), + make_tuple(c_grid_desc_m0_m10_m11_n0_n10_n11.GetLength(I0), + c_grid_desc_m0_m10_m11_n0_n10_n11.GetLength(I3)))) + { + return; + } + // TODO: change this. I think it needs multi-dimensional alignment constexpr auto max_lds_align = K1;