diff --git a/torchao/experimental/kernels/cpu/aarch64/matmul/channelwise_8bit_a_channelwise_8bit_b_1x8x16_f32_neondot-impl.h b/torchao/experimental/kernels/cpu/aarch64/matmul/channelwise_8bit_a_channelwise_8bit_b_1x8x16_f32_neondot-impl.h
index 898fa30b18..c976be39f5 100644
--- a/torchao/experimental/kernels/cpu/aarch64/matmul/channelwise_8bit_a_channelwise_8bit_b_1x8x16_f32_neondot-impl.h
+++ b/torchao/experimental/kernels/cpu/aarch64/matmul/channelwise_8bit_a_channelwise_8bit_b_1x8x16_f32_neondot-impl.h
@@ -32,8 +32,9 @@ TORCHAO_ALWAYS_INLINE static void block_mul_1x8x16(
     const size_t ldb,
     int32x4_t (&partial_sums)[8],
     int32_t& row_sum_a,
-    int32_t (&row_sum_b)[8]) {
+    int32x4_t (&row_sum_b)[8]) {
   int8x16_t a_vec = vld1q_s8(a);
+  int8x16_t ones = vdupq_n_s8(1);
   row_sum_a = row_sum_a + vaddlvq_s8(a_vec);
 
   // godbolt (https://godbolt.org/z/9vbq1d1qY) shows this loops doesnt quantize
@@ -42,8 +43,9 @@ TORCHAO_ALWAYS_INLINE static void block_mul_1x8x16(
   // deconstruct the loop and do manual optimization. Or just write assembly.
 #pragma unroll(8)
   for (int i = 0; i < 8; ++i) {
-    int8x16_t b_vec = vld1q_s8(b + i * ldb);
-    row_sum_b[i] = row_sum_b[i] + vaddlvq_s8(b_vec);
+    int8x16_t b_vec = vld1q_s8(b);
+    b += ldb;
+    row_sum_b[i] = vdotq_s32(row_sum_b[i], b_vec, ones);
     partial_sums[i] = vdotq_s32(partial_sums[i], a_vec, b_vec);
   }
 }
@@ -234,8 +236,9 @@ struct KernelImpl {
         const int8_t* rhs_ptr = (const int8_t*)rhs + n_idx * rhs_stride_n;
         int32x4_t int32_sums[nr] = {vdupq_n_s32(0)};
         int32_t row_sum_lhs = 0;
-        int32_t row_sum_rhs[nr] = {0, 0, 0, 0, 0, 0, 0, 0};
+        int32x4_t row_sum_rhs_vec[nr] = {vdupq_n_s32(0)};
         int32_t sums[nr];
+        int32_t row_sum_rhs[nr];
 
         // Loop k_idx by group
         int k_idx = 0;
@@ -246,12 +249,13 @@ struct KernelImpl {
             rhs_stride_n,
             int32_sums,
             row_sum_lhs,
-            row_sum_rhs);
+            row_sum_rhs_vec);
          lhs_ptr += kr;
          rhs_ptr += kr;
        }
 
       reduce_1x8_int32x4_t_sums(int32_sums, sums);
+      reduce_1x8_int32x4_t_sums(row_sum_rhs_vec, row_sum_rhs);
       for (int ki = 0; ki < (k - k_idx); ++ki) {
         row_sum_lhs += (int32_t)lhs_ptr[ki];
       }
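
Note on the change above: instead of calling the cross-lane reduction vaddlvq_s8 on every 16-byte block of each RHS row, the patch keeps the row sums in int32x4_t accumulators and accumulates each block with vdotq_s32 against a vector of ones, deferring the horizontal reduction to a single pass after the k loop (reduce_1x8_int32x4_t_sums). A minimal standalone sketch of that pattern follows; the function name row_sum_dot and the test harness are illustrative only (not part of the kernel), and it assumes an aarch64 toolchain with the dot-product extension enabled (e.g. -march=armv8.2-a+dotprod).

// Sketch only: sum `k` int8 values (k a multiple of 16) the same way the
// patched kernel sums each RHS row: one vdotq_s32 per 16-byte block, with
// the horizontal reduction paid once at the end.
#include <arm_neon.h>
#include <cstdint>
#include <cstdio>

static int32_t row_sum_dot(const int8_t* x, int k) {
  const int8x16_t ones = vdupq_n_s8(1);
  int32x4_t acc = vdupq_n_s32(0);
  for (int i = 0; i < k; i += 16) {
    int8x16_t v = vld1q_s8(x + i);
    // Each 32-bit lane accumulates the sum of its 4 adjacent int8 elements
    // (dot product with 1s), so no cross-lane work happens inside the loop.
    acc = vdotq_s32(acc, v, ones);
  }
  // Single cross-lane reduction, analogous to reduce_1x8_int32x4_t_sums.
  return vaddvq_s32(acc);
}

int main() {
  int8_t data[32];
  int32_t expected = 0;
  for (int i = 0; i < 32; ++i) {
    data[i] = (int8_t)(i - 16);
    expected += data[i];
  }
  std::printf("dot-with-ones sum = %d (expected %d)\n",
              (int)row_sum_dot(data, 32), (int)expected);
  return 0;
}

The apparent rationale, as reflected in the diff, is to keep the hot loop lane-parallel: vaddlvq_s8 is a cross-lane reduction issued per 16-byte block, while the vdotq_s32-with-ones form reuses the same dot-product pipeline as the partial_sums accumulation and reduces only once per tile.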