Commit f8446f0

ggml: aarch64: implement smmla kernel for q4_1_q8_1 quantized gemm
Armv8.2-A and above support MMLA instructions, which have higher throughput than DOT. This commit adds an MMLA-based kernel for the q4_1_q8_1 GEMM path. The feature is enabled when the platform defines "__ARM_FEATURE_MATMUL_INT8". On AWS Graviton3 processors this kernel yields up to a 1.5x improvement in prompt-evaluation throughput compared to the default sdot kernel.
1 parent 5cb29cf commit f8446f0
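
For readers unfamiliar with the instruction, the standalone sketch below (not part of the commit; the file name and values are made up) shows what a single SMMLA operation computes via the vmmlaq_s32 intrinsic: each source register is treated as a 2x8 row-major int8 tile, and the instruction accumulates the 2x2 int32 product A·Bᵀ into the destination. It assumes an i8mm-capable toolchain and core.

```c
// smmla_demo.c -- illustrative only; not part of the commit.
// Build (example): gcc -O2 -march=armv8.2-a+i8mm smmla_demo.c
#include <arm_neon.h>
#include <stdio.h>

#if !defined(__ARM_FEATURE_MATMUL_INT8)
#error "this demo needs an i8mm-enabled target (e.g. -march=armv8.2-a+i8mm)"
#endif

int main(void) {
    // A and B are each a 2x8 int8 tile, rows packed back to back in one
    // 128-bit register: lanes 0..7 = row 0, lanes 8..15 = row 1.
    const int8_t a[16] = {  1,  2,  3,  4,  5,  6,  7,  8,    // A row 0
                           -1, -2, -3, -4, -5, -6, -7, -8 };  // A row 1
    const int8_t b[16] = {  1,  1,  1,  1,  1,  1,  1,  1,    // B row 0
                            2,  2,  2,  2,  2,  2,  2,  2 };  // B row 1

    // acc += A * B^T : the result is a 2x2 int32 tile, stored row-major.
    int32x4_t acc = vmmlaq_s32(vdupq_n_s32(0), vld1q_s8(a), vld1q_s8(b));

    int32_t c[4];
    vst1q_s32(c, acc);
    // Expected lanes: {A0.B0, A0.B1, A1.B0, A1.B1} = {36, 72, -36, -72}
    printf("%d %d %d %d\n", c[0], c[1], c[2], c[3]);
    return 0;
}
```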

2 files changed: 78 additions, 3 deletions

ggml-quants.c

Lines changed: 75 additions & 3 deletions
@@ -4046,16 +4046,87 @@ void ggml_vec_dot_q4_1_q8_1(int n, float * restrict s, const void * restrict vx,
     const int nb = n / qk;
 
     assert(n % qk == 0);
+#if defined(__ARM_FEATURE_MATMUL_INT8)
+    assert((nrc == 2) || (nrc == 1));
+#else
     assert(nrc == 1);
-    UNUSED(nrc);
-    UNUSED(bx);
-    UNUSED(by);
+#endif
 
     const block_q4_1 * restrict x = vx;
     const block_q8_1 * restrict y = vy;
 
+#if defined(__ARM_FEATURE_MATMUL_INT8)
+    if (nrc == 2) {
+        const block_q4_1 * restrict vx0 = vx;
+        const block_q4_1 * restrict vx1 = vx + bx;
+        const block_q8_1 * restrict vy0 = vy;
+        const block_q8_1 * restrict vy1 = vy + by;
+
+        float32x4_t sumv0 = vdupq_n_f32(0.0f);
+        float32x4_t summs0 = vdupq_n_f32(0.0f);
+
+        for (int i = 0; i < nb; i++) {
+            const block_q4_1 * restrict b_x0 = &vx0[i];
+            const block_q4_1 * restrict b_x1 = &vx1[i];
+            const block_q8_1 * restrict b_y0 = &vy0[i];
+            const block_q8_1 * restrict b_y1 = &vy1[i];
+
+            float32x4_t summs_t = {GGML_FP16_TO_FP32(b_x0->m) * b_y0->s,
+                                   GGML_FP16_TO_FP32(b_x1->m) * b_y0->s,
+                                   GGML_FP16_TO_FP32(b_x0->m) * b_y1->s,
+                                   GGML_FP16_TO_FP32(b_x1->m) * b_y1->s};
+            summs0 += summs_t;
+
+            const uint8x16_t m4b = vdupq_n_u8(0x0F);
+
+            const uint8x16_t v0_0 = vld1q_u8(b_x0->qs);
+            const uint8x16_t v0_1 = vld1q_u8(b_x1->qs);
+
+            // 4-bit -> 8-bit
+            const int8x16_t x0_l = vreinterpretq_s8_u8(vandq_u8 (v0_0, m4b));
+            const int8x16_t x0_h = vreinterpretq_s8_u8(vshrq_n_u8(v0_0, 4));
+            const int8x16_t x1_l = vreinterpretq_s8_u8(vandq_u8 (v0_1, m4b));
+            const int8x16_t x1_h = vreinterpretq_s8_u8(vshrq_n_u8(v0_1, 4));
+
+            // load y
+            const int8x16_t y0_l = vld1q_s8(b_y0->qs);
+            const int8x16_t y0_h = vld1q_s8(b_y0->qs + 16);
+            const int8x16_t y1_l = vld1q_s8(b_y1->qs);
+            const int8x16_t y1_h = vld1q_s8(b_y1->qs + 16);
+
+            // mmla into int32x4_t
+            float32x4_t scale = {GGML_FP16_TO_FP32(b_x0->d)*GGML_FP16_TO_FP32(b_y0->d),
+                                 GGML_FP16_TO_FP32(b_x0->d)*GGML_FP16_TO_FP32(b_y1->d),
+                                 GGML_FP16_TO_FP32(b_x1->d)*GGML_FP16_TO_FP32(b_y0->d),
+                                 GGML_FP16_TO_FP32(b_x1->d)*GGML_FP16_TO_FP32(b_y1->d)};
+
+            int8x16_t l0 = vreinterpretq_s8_s64(vzip1q_s64(vreinterpretq_s64_s8(x0_l), vreinterpretq_s64_s8(x1_l)));
+            int8x16_t l1 = vreinterpretq_s8_s64(vzip2q_s64(vreinterpretq_s64_s8(x0_l), vreinterpretq_s64_s8(x1_l)));
+
+            int8x16_t l2 = vreinterpretq_s8_s64(vzip1q_s64(vreinterpretq_s64_s8(x0_h), vreinterpretq_s64_s8(x1_h)));
+            int8x16_t l3 = vreinterpretq_s8_s64(vzip2q_s64(vreinterpretq_s64_s8(x0_h), vreinterpretq_s64_s8(x1_h)));
+
+            int8x16_t r0 = vreinterpretq_s8_s64(vzip1q_s64(vreinterpretq_s64_s8(y0_l), vreinterpretq_s64_s8(y1_l)));
+            int8x16_t r1 = vreinterpretq_s8_s64(vzip2q_s64(vreinterpretq_s64_s8(y0_l), vreinterpretq_s64_s8(y1_l)));
+
+            int8x16_t r2 = vreinterpretq_s8_s64(vzip1q_s64(vreinterpretq_s64_s8(y0_h), vreinterpretq_s64_s8(y1_h)));
+            int8x16_t r3 = vreinterpretq_s8_s64(vzip2q_s64(vreinterpretq_s64_s8(y0_h), vreinterpretq_s64_s8(y1_h)));
+
+            sumv0 = vmlaq_f32(sumv0,(vcvtq_f32_s32(vmmlaq_s32((vmmlaq_s32((vmmlaq_s32((vmmlaq_s32(vdupq_n_s32(0), l0, r0)),
+                                                                l1, r1)), l2, r2)), l3, r3))), scale);
+        }
+
+        float32x4_t sumv1 = vextq_f32(sumv0, sumv0, 2);
+        float32x4_t sumv2 = vzip1q_f32(sumv0, sumv1);
+        sumv2 = sumv2 + summs0;
+
+        vst1_f32(s, vget_low_f32(sumv2));
+        vst1_f32(s + 16, vget_high_f32(sumv2));
+    } else
+#endif
     // TODO: add WASM SIMD
 #if defined(__ARM_NEON)
+    {
     float32x4_t sumv0 = vdupq_n_f32(0.0f);
     float32x4_t sumv1 = vdupq_n_f32(0.0f);
 
@@ -4097,6 +4168,7 @@ void ggml_vec_dot_q4_1_q8_1(int n, float * restrict s, const void * restrict vx,
     }
 
     *s = vaddvq_f32(sumv0) + vaddvq_f32(sumv1) + summs;
+    }
 #elif defined(__AVX2__) || defined(__AVX__)
     // Initialize accumulator with zeros
     __m256 acc = _mm256_setzero_ps();
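
A note on the data shuffling in the nrc == 2 path above: vmmlaq_s32 expects each operand as a 2x8 row-major int8 tile, so the kernel uses 64-bit zips to pair 8 bytes of row x0 with the corresponding 8 bytes of row x1 (and likewise for the two y rows). The standalone sketch below (not from the commit; names and values are illustrative) just demonstrates that pairing with vzip1q_s64/vzip2q_s64.

```c
// zip_demo.c -- illustrative only; shows the row-pair layout built above.
#include <arm_neon.h>
#include <stdio.h>

int main(void) {
    int8_t row0[16], row1[16];
    for (int i = 0; i < 16; i++) { row0[i] = (int8_t) i; row1[i] = (int8_t)(100 + i); }

    int8x16_t x0 = vld1q_s8(row0);
    int8x16_t x1 = vld1q_s8(row1);

    // lo = { row0[0..7], row1[0..7] }, hi = { row0[8..15], row1[8..15] }:
    // each result is one 2x8 row-major tile, ready to feed vmmlaq_s32.
    int8x16_t lo = vreinterpretq_s8_s64(vzip1q_s64(vreinterpretq_s64_s8(x0), vreinterpretq_s64_s8(x1)));
    int8x16_t hi = vreinterpretq_s8_s64(vzip2q_s64(vreinterpretq_s64_s8(x0), vreinterpretq_s64_s8(x1)));

    int8_t out[16];
    vst1q_s8(out, lo);
    for (int i = 0; i < 16; i++) printf("%d ", out[i]);   // 0..7 100..107
    printf("\n");
    vst1q_s8(out, hi);
    for (int i = 0; i < 16; i++) printf("%d ", out[i]);   // 8..15 108..115
    printf("\n");
    return 0;
}
```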

ggml.c

Lines changed: 3 additions & 0 deletions
@@ -493,6 +493,9 @@ static const ggml_type_traits_t type_traits[GGML_TYPE_COUNT] = {
         .from_float_reference = (ggml_from_float_t) quantize_row_q4_1_reference,
         .vec_dot = ggml_vec_dot_q4_1_q8_1,
         .vec_dot_type = GGML_TYPE_Q8_1,
+#if defined (__ARM_FEATURE_MATMUL_INT8)
+        .hw_matmul = true,
+#endif
     },
     [4] = { // GGML_TYPE_Q4_2
         .type_name = "DEPRECATED",
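
Note that this flag, like the kernel itself, only takes effect when the code is compiled with __ARM_FEATURE_MATMUL_INT8 defined. With GCC or Clang that typically means targeting an i8mm-capable CPU, for example with flags along the lines of -march=armv8.2-a+i8mm or -mcpu=neoverse-v1 (Graviton3); the exact flags depend on the toolchain and build setup.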
