From 794fdd6d1119659a4648432595d149b78b4138cf Mon Sep 17 00:00:00 2001 From: Kimish Patel Date: Thu, 10 Apr 2025 13:48:00 -0700 Subject: [PATCH] Add gemm kernel to interface Summary: Add the gemm kernel to the interface, and add the ability to split large work between the gemm and gemv kernels. Reviewed By: metascroy Differential Revision: D71833068 --- .../kernels/cpu/aarch64/matmul/matmul.h | 78 +++++++++++++++++++ .../kernels/cpu/interface/quantized_matmul.h | 8 +- .../cpu/interface/test_qmatmul_interface.cpp | 16 ++++ 3 files changed, 97 insertions(+), 5 deletions(-) diff --git a/torchao/experimental/kernels/cpu/aarch64/matmul/matmul.h b/torchao/experimental/kernels/cpu/aarch64/matmul/matmul.h index 4663927655..94c53add0e 100644 --- a/torchao/experimental/kernels/cpu/aarch64/matmul/matmul.h +++ b/torchao/experimental/kernels/cpu/aarch64/matmul/matmul.h @@ -10,6 +10,7 @@ #pragma once +#include #if defined(__aarch64__) && defined(__ARM_NEON) #include @@ -106,6 +107,83 @@ void kernel( const int rhs_qparams_stride); } // namespace fp32_a_input_channelwise_8bit_b_4x16x4_f32 + +namespace fp32_a_input_channelwise_8bit_b_f32 { + +template +void kernel( + int m, + int n, + int k, + const float* lhs, + int lhs_stride_m, + const int8_t* rhs, + int rhs_stride_n, + float32_t* output, + int out_stride_m, + const int8_t* rhs_zero_points, + const float* rhs_scales, + const float beta, + const int rhs_qparams_stride); + +template +void kernel( + int m, + int n, + int k, + const float* lhs, + int lhs_stride_m, + const int8_t* rhs, + int rhs_stride_n, + float32_t* output, + int out_stride_m, + const int8_t* rhs_zero_points, + const float* rhs_scales, + const float beta, + const int rhs_qparams_stride) { + assert(n >= 16); + if (m > 16) { + auto remaining_m = m % 16; + auto m_for_gemm_kernel = m - remaining_m; + fp32_a_input_channelwise_8bit_b_4x16x4_f32:: + kernel( + m_for_gemm_kernel, + n, + k, + lhs, + lhs_stride_m, + rhs, + rhs_stride_n, + output, + out_stride_m, + rhs_zero_points, + rhs_scales, + beta, + rhs_qparams_stride); 
output += m_for_gemm_kernel * out_stride_m; + lhs += m_for_gemm_kernel * lhs_stride_m; + m = remaining_m; + } + if (m > 0) { + fp32_a_input_channelwise_8bit_b_1x16x4_f32:: + kernel( + m, + n, + k, + lhs, + lhs_stride_m, + rhs, + rhs_stride_n, + output, + out_stride_m, + rhs_zero_points, + rhs_scales, + beta, + rhs_qparams_stride); + } +} + +} // namespace fp32_a_input_channelwise_8bit_b_f32 } // namespace torchao::kernels::cpu::aarch64::quantized_matmul #include diff --git a/torchao/experimental/kernels/cpu/interface/quantized_matmul.h b/torchao/experimental/kernels/cpu/interface/quantized_matmul.h index d9c9d23271..1fec5109a6 100644 --- a/torchao/experimental/kernels/cpu/interface/quantized_matmul.h +++ b/torchao/experimental/kernels/cpu/interface/quantized_matmul.h @@ -12,9 +12,7 @@ #include #if defined(__aarch64__) && defined(__ARM_NEON) -#include -#include -#include +#include #endif // defined(__aarch64__) && defined(__ARM_NEON) namespace torchao::kernels::cpu::quantized_matmul { @@ -138,8 +136,8 @@ get_fp32_a_input_channelwise_8bit_b_f32_c_matmul( if (!a_transposed && !b_transposed && n >= 16) { a_stride_m = k; b_stride_n = n; - return aarch64::quantized_matmul:: - fp32_a_input_channelwise_8bit_b_1x16x4_f32::kernel; + return aarch64::quantized_matmul::fp32_a_input_channelwise_8bit_b_f32:: + kernel; } #endif // defined(__aarch64__) && defined(__ARM_NEON) assert(!a_transposed); diff --git a/torchao/experimental/kernels/cpu/interface/test_qmatmul_interface.cpp b/torchao/experimental/kernels/cpu/interface/test_qmatmul_interface.cpp index 4024f3f1de..5062dfb908 100644 --- a/torchao/experimental/kernels/cpu/interface/test_qmatmul_interface.cpp +++ b/torchao/experimental/kernels/cpu/interface/test_qmatmul_interface.cpp @@ -624,6 +624,22 @@ TEST_P( /*m=*/4, /*k=*/5, /*n=*/3, beta(), *this, 32); } +TEST_P( + FP32A_QuantizedB_FP32C_Interface_Test, + BTranposedWithZeroPointsOddSizes2) { + generate(19, 37, 35, true, false, false); + test_fp32_a_input_channelwise_8bit_b( + 
/*m=*/19, /*k=*/37, /*n=*/35, beta(), *this); +} + +TEST_P( + FP32A_QuantizedB_FP32C_Interface_Test, + BTranposedWithZeroPointsOddSizesStrided2) { + generate(23, 37, 50, true, false, false, 32); + test_fp32_a_input_channelwise_8bit_b( + /*m=*/23, /*k=*/37, /*n=*/50, beta(), *this, 32); +} + INSTANTIATE_TEST_SUITE_P( F32AInt8BFP32CTest, FP32A_QuantizedB_FP32C_Interface_Test,