Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -32,8 +32,9 @@ TORCHAO_ALWAYS_INLINE static void block_mul_1x8x16(
const size_t ldb,
int32x4_t (&partial_sums)[8],
int32_t& row_sum_a,
int32_t (&row_sum_b)[8]) {
int32x4_t (&row_sum_b)[8]) {
int8x16_t a_vec = vld1q_s8(a);
int8x16_t ones = vdupq_n_s8(1);
row_sum_a = row_sum_a + vaddlvq_s8(a_vec);

// godbolt (https://godbolt.org/z/9vbq1d1qY) shows this loops doesnt quantize
Expand All @@ -42,8 +43,9 @@ TORCHAO_ALWAYS_INLINE static void block_mul_1x8x16(
// deconstruct the loop and do manual optimization. Or just write assembly.
#pragma unroll(8)
for (int i = 0; i < 8; ++i) {
int8x16_t b_vec = vld1q_s8(b + i * ldb);
row_sum_b[i] = row_sum_b[i] + vaddlvq_s8(b_vec);
int8x16_t b_vec = vld1q_s8(b);
b += ldb;
row_sum_b[i] = vdotq_s32(row_sum_b[i], b_vec, ones);
partial_sums[i] = vdotq_s32(partial_sums[i], a_vec, b_vec);
}
}
Expand Down Expand Up @@ -234,8 +236,9 @@ struct KernelImpl<true, true, false, true> {
const int8_t* rhs_ptr = (const int8_t*)rhs + n_idx * rhs_stride_n;
int32x4_t int32_sums[nr] = {vdupq_n_s32(0)};
int32_t row_sum_lhs = 0;
int32_t row_sum_rhs[nr] = {0, 0, 0, 0, 0, 0, 0, 0};
int32x4_t row_sum_rhs_vec[nr] = {vdupq_n_s32(0)};
int32_t sums[nr];
int32_t row_sum_rhs[nr];

// Loop k_idx by group
int k_idx = 0;
Expand All @@ -246,12 +249,13 @@ struct KernelImpl<true, true, false, true> {
rhs_stride_n,
int32_sums,
row_sum_lhs,
row_sum_rhs);
row_sum_rhs_vec);
lhs_ptr += kr;
rhs_ptr += kr;
}

reduce_1x8_int32x4_t_sums(int32_sums, sums);
reduce_1x8_int32x4_t_sums(row_sum_rhs_vec, row_sum_rhs);
for (int ki = 0; ki < (k - k_idx); ++ki) {
row_sum_lhs += (int32_t)lhs_ptr[ki];
}
Expand Down
Loading