Skip to content

Commit d924089

Browse files
committed
ggml: aarch64: implement mmla kernel for q8_0_q8_0 quantized gemm
armv8.2-a and above supports MMLA instructions that have better throughput than DOT. This commit adds support for an MMLA kernel for q8_0_q8_0 gemm. It is disabled by default; build llama.cpp with the option "LLAMA_ARM_MMLA=ON" to enable it. On AWS Graviton3 processors this kernel resulted in up to 1.5x improvement in prompt evaluation throughput compared to the default sdot kernel.
1 parent 3e5ca79 commit d924089

File tree

8 files changed

+198
-28
lines changed

8 files changed

+198
-28
lines changed

CMakeLists.txt

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -107,6 +107,9 @@ option(LLAMA_BUILD_TESTS "llama: build tests" ${LLAMA_STA
107107
option(LLAMA_BUILD_EXAMPLES "llama: build examples" ${LLAMA_STANDALONE})
108108
option(LLAMA_BUILD_SERVER "llama: build server example" ON)
109109

110+
# armv8.2-a+ extensions
111+
option(LLAMA_ARM_MMLA "llama: enable aarch64 mmla kernels" OFF)
112+
110113
# Required for relocatable CMake package
111114
include(${CMAKE_CURRENT_SOURCE_DIR}/scripts/build-info.cmake)
112115

@@ -626,6 +629,10 @@ if ((${CMAKE_SYSTEM_PROCESSOR} MATCHES "arm") OR (${CMAKE_SYSTEM_PROCESSOR} MATC
626629
# Raspberry Pi 3, 4, Zero 2 (32-bit)
627630
add_compile_options(-mno-unaligned-access)
628631
endif()
632+
if (LLAMA_ARM_MMLA)
633+
add_compile_options(-march=armv8.2-a+i8mm)
634+
add_compile_definitions(__ARM_FEATURE_MMLA)
635+
endif()
629636
endif()
630637
elseif (${CMAKE_SYSTEM_PROCESSOR} MATCHES "^(x86_64|i686|AMD64)$" OR "${CMAKE_GENERATOR_PLATFORM_LWR}" MATCHES "^(x86_64|i686|amd64|x64)$" )
631638
message(STATUS "x86 detected")

Makefile

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -498,6 +498,13 @@ ggml-mpi.o: ggml-mpi.c ggml-mpi.h
498498
$(CC) $(CFLAGS) -c $< -o $@
499499
endif # LLAMA_MPI
500500

501+
ifdef LLAMA_ARM_MMLA
502+
MK_CPPFLAGS += -D__ARM_FEATURE_MMLA
503+
MK_CFLAGS += -D__ARM_FEATURE_MMLA
504+
MK_CXXFLAGS += -D__ARM_FEATURE_MMLA
505+
endif # LLAMA_ARM_MMLA
506+
507+
501508
GF_CC := $(CC)
502509
include scripts/get-flags.mk
503510

common/common.cpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1450,6 +1450,7 @@ void dump_non_result_info_yaml(FILE * stream, const gpt_params & params, const l
14501450
fprintf(stream, "cpu_has_fma: %s\n", ggml_cpu_has_fma() ? "true" : "false");
14511451
fprintf(stream, "cpu_has_gpublas: %s\n", ggml_cpu_has_gpublas() ? "true" : "false");
14521452
fprintf(stream, "cpu_has_neon: %s\n", ggml_cpu_has_neon() ? "true" : "false");
1453+
fprintf(stream, "cpu_has_neon_mmla: %s\n", ggml_cpu_has_neon_mmla() ? "true" : "false");
14531454
fprintf(stream, "cpu_has_f16c: %s\n", ggml_cpu_has_f16c() ? "true" : "false");
14541455
fprintf(stream, "cpu_has_fp16_va: %s\n", ggml_cpu_has_fp16_va() ? "true" : "false");
14551456
fprintf(stream, "cpu_has_wasm_simd: %s\n", ggml_cpu_has_wasm_simd() ? "true" : "false");

ggml-quants.c

Lines changed: 62 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4421,6 +4421,68 @@ void ggml_vec_dot_q5_1_q8_1(const int n, float * restrict s, const void * restri
44214421
#endif
44224422
}
44234423

4424+
#if defined(__ARM_FEATURE_MMLA)
4425+
void ggml_vec_mmla_q8_0_q8_0(const int n, float * restrict s0, float * restrict s1, const void * restrict lhs0,
4426+
const void * restrict lhs1, const void * restrict rhs0, const void * restrict rhs1) {
4427+
const int qk = QK8_0;
4428+
const int nb = n / qk;
4429+
4430+
assert(n % qk == 0);
4431+
4432+
const block_q8_0 * restrict vx0 = lhs0;
4433+
const block_q8_0 * restrict vx1 = lhs1;
4434+
const block_q8_0 * restrict vy0 = rhs0;
4435+
const block_q8_0 * restrict vy1 = rhs1;
4436+
4437+
float32x4_t sumv0 = vdupq_n_f32(0.0f);
4438+
4439+
for (int i = 0; i < nb; i++) {
4440+
const block_q8_0 * restrict b_x0 = &vx0[i];
4441+
const block_q8_0 * restrict b_y0 = &vy0[i];
4442+
4443+
const block_q8_0 * restrict b_x1 = &vx1[i];
4444+
const block_q8_0 * restrict b_y1 = &vy1[i];
4445+
4446+
const int8x16_t x0_l = vld1q_s8(b_x0->qs);
4447+
const int8x16_t x0_h = vld1q_s8(b_x0->qs + 16);
4448+
const int8x16_t x1_l = vld1q_s8(b_x1->qs);
4449+
const int8x16_t x1_h = vld1q_s8(b_x1->qs + 16);
4450+
4451+
// load y
4452+
const int8x16_t y0_l = vld1q_s8(b_y0->qs);
4453+
const int8x16_t y0_h = vld1q_s8(b_y0->qs + 16);
4454+
const int8x16_t y1_l = vld1q_s8(b_y1->qs);
4455+
const int8x16_t y1_h = vld1q_s8(b_y1->qs + 16);
4456+
4457+
float32x4_t scale = {GGML_FP16_TO_FP32(b_x0->d)*GGML_FP16_TO_FP32(b_y0->d),
4458+
GGML_FP16_TO_FP32(b_x0->d)*GGML_FP16_TO_FP32(b_y1->d),
4459+
GGML_FP16_TO_FP32(b_x1->d)*GGML_FP16_TO_FP32(b_y0->d),
4460+
GGML_FP16_TO_FP32(b_x1->d)*GGML_FP16_TO_FP32(b_y1->d)};
4461+
4462+
int8x16_t l0 = vreinterpretq_s8_s64(vzip1q_s64(vreinterpretq_s64_s8(x0_l), vreinterpretq_s64_s8(x1_l)));
4463+
int8x16_t l1 = vreinterpretq_s8_s64(vzip2q_s64(vreinterpretq_s64_s8(x0_l), vreinterpretq_s64_s8(x1_l)));
4464+
4465+
int8x16_t l2 = vreinterpretq_s8_s64(vzip1q_s64(vreinterpretq_s64_s8(x0_h), vreinterpretq_s64_s8(x1_h)));
4466+
int8x16_t l3 = vreinterpretq_s8_s64(vzip2q_s64(vreinterpretq_s64_s8(x0_h), vreinterpretq_s64_s8(x1_h)));
4467+
4468+
int8x16_t r0 = vreinterpretq_s8_s64(vzip1q_s64(vreinterpretq_s64_s8(y0_l), vreinterpretq_s64_s8(y1_l)));
4469+
int8x16_t r1 = vreinterpretq_s8_s64(vzip2q_s64(vreinterpretq_s64_s8(y0_l), vreinterpretq_s64_s8(y1_l)));
4470+
4471+
int8x16_t r2 = vreinterpretq_s8_s64(vzip1q_s64(vreinterpretq_s64_s8(y0_h), vreinterpretq_s64_s8(y1_h)));
4472+
int8x16_t r3 = vreinterpretq_s8_s64(vzip2q_s64(vreinterpretq_s64_s8(y0_h), vreinterpretq_s64_s8(y1_h)));
4473+
4474+
sumv0 = vmlaq_f32(sumv0,(vcvtq_f32_s32(vmmlaq_s32((vmmlaq_s32((vmmlaq_s32((vmmlaq_s32(vdupq_n_s32(0), l0, r0)), l1, r1)), l2, r2)), l3, r3))), scale);
4475+
}
4476+
4477+
float32x4_t sumv1 = vextq_f32(sumv0, sumv0, 2);
4478+
float32x4_t sumv2 = vzip1q_f32(sumv0, sumv1);
4479+
4480+
vst1_f32(s0, vget_low_f32(sumv2));
4481+
vst1_f32(s1, vget_high_f32(sumv2));
4482+
}
4483+
#endif
4484+
4485+
44244486
void ggml_vec_dot_q8_0_q8_0(const int n, float * restrict s, const void * restrict vx, const void * restrict vy) {
44254487
const int qk = QK8_0;
44264488
const int nb = n / qk;

ggml-quants.h

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -243,6 +243,12 @@ void ggml_vec_dot_q6_K_q8_K(int n, float * restrict s, const void * restrict vx,
243243
void ggml_vec_dot_iq2_xxs_q8_K(int n, float * restrict s, const void * restrict vx, const void * restrict vy);
244244
void ggml_vec_dot_iq2_xs_q8_K (int n, float * restrict s, const void * restrict vx, const void * restrict vy);
245245

246+
#if defined(__ARM_FEATURE_MMLA)
247+
// mmla
248+
void ggml_vec_mmla_q8_0_q8_0(int n, float * restrict s0, float * restrict s1, const void * restrict vx0, const void * restrict vx1,
249+
const void * restrict vy0, const void * restrict vy1);
250+
#endif
251+
246252
//
247253
// Quantization utilizing an importance matrix (a.k.a. "Activation aWare Quantization")
248254
//

ggml.c

Lines changed: 106 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -514,6 +514,9 @@ static const ggml_type_traits_t type_traits[GGML_TYPE_COUNT] = {
514514
.from_float_reference = (ggml_from_float_t) quantize_row_q8_0_reference,
515515
.vec_dot = ggml_vec_dot_q8_0_q8_0,
516516
.vec_dot_type = GGML_TYPE_Q8_0,
517+
#if defined(__ARM_FEATURE_MMLA)
518+
.vec_mmla = ggml_vec_mmla_q8_0_q8_0,
519+
#endif
517520
},
518521
[GGML_TYPE_Q8_1] = {
519522
.type_name = "q8_1",
@@ -9801,6 +9804,9 @@ static void ggml_compute_forward_mul_mat(
98019804
ggml_vec_dot_t const vec_dot = type_traits[type].vec_dot;
98029805
enum ggml_type const vec_dot_type = type_traits[type].vec_dot_type;
98039806
ggml_from_float_t const from_float_to_vec_dot = type_traits[vec_dot_type].from_float;
9807+
#if defined(__ARM_FEATURE_MMLA)
9808+
ggml_vec_mmla_t const vec_mmla = type_traits[type].vec_mmla;
9809+
#endif
98049810

98059811
GGML_ASSERT(ne0 == ne01);
98069812
GGML_ASSERT(ne1 == ne11);
@@ -9952,43 +9958,107 @@ static void ggml_compute_forward_mul_mat(
99529958

99539959
// attempt to reduce false-sharing (does not seem to make a difference)
99549960
float tmp[16];
9961+
#if defined(__ARM_FEATURE_MMLA)
9962+
float tmp1[16];
9963+
float tmp2[16];
99559964

9956-
for (int64_t iir1 = ir110; iir1 < ir111; iir1 += blck_1) {
9957-
for (int64_t iir0 = ir010; iir0 < ir011; iir0 += blck_0) {
9958-
for (int64_t ir1 = iir1; ir1 < iir1 + blck_1 && ir1 < ir111; ++ir1) {
9959-
const int64_t i13 = (ir1/(ne12*ne1));
9960-
const int64_t i12 = (ir1 - i13*ne12*ne1)/ne1;
9961-
const int64_t i11 = (ir1 - i13*ne12*ne1 - i12*ne1);
9965+
if ((vec_mmla != NULL) && (nr0 % 2 == 0) && (nr1 %2 == 0)) {
9966+
for (int64_t iir1 = ir110; iir1 < ir111; iir1 += blck_1) {
9967+
for (int64_t iir0 = ir010; iir0 < ir011; iir0 += blck_0) {
9968+
for (int64_t ir1 = iir1; ir1 < iir1 + blck_1 && ir1 < ir111; ir1 += 2) {
9969+
const int64_t i13 = (ir1/(ne12*ne11));
9970+
const int64_t i12 = (ir1 - i13*ne12*ne11)/ne11;
9971+
const int64_t i11 = (ir1 - i13*ne12*ne11 - i12*ne11);
99629972

9963-
// broadcast src0 into src1
9964-
const int64_t i03 = i13/r3;
9965-
const int64_t i02 = i12/r2;
9973+
// broadcast src0 into src1
9974+
const int64_t i03 = i13/r3;
9975+
const int64_t i02 = i12/r2;
9976+
9977+
const int64_t i1 = i11;
9978+
const int64_t i2 = i12;
9979+
const int64_t i3 = i13;
9980+
9981+
const int64_t i13_ = ((ir1+1)/(ne12*ne11));
9982+
const int64_t i12_ = ((ir1+1) - (i13_)*ne12*ne11)/ne11;
9983+
const int64_t i11_ = ((ir1+1) - (i13_)*ne12*ne11 - (i12_)*ne11);
9984+
9985+
// broadcast src0 into src1
9986+
const int64_t i03_ = i13_/r3;
9987+
const int64_t i02_ = i12_/r2;
9988+
9989+
const int64_t i1_ = i11_;
9990+
const int64_t i2_ = i12_;
9991+
const int64_t i3_ = i13_;
9992+
9993+
const char * src0_row = (const char *) src0->data + (0 + i02*nb02 + i03*nb03);
9994+
9995+
// desc: when src1 is not a contiguous memory block we have to calculate the offset using the strides
9996+
// if it is, then we have either copied the data to params->wdata and made it contiguous or we are using
9997+
// the original src1 data pointer, so we should index using the indices directly
9998+
// TODO: this is a bit of a hack, we should probably have a better way to handle this
9999+
const char * src1_col = (const char *) wdata +
10000+
(src1_cont || src1->type != vec_dot_type
10001+
? (i11 + i12*ne11 + i13*ne12*ne11)*row_size
10002+
: (i11*nb11 + i12*nb12 + i13*nb13));
10003+
10004+
const char * src1_col_ = (const char *) wdata +
10005+
(src1_cont || src1->type != vec_dot_type
10006+
? ((i11_) + (i12_)*ne11 + (i13_)*ne12*ne11)*row_size
10007+
: ((i11_)*nb11 + (i12_)*nb12 + (i13_)*nb13));
10008+
10009+
float * dst_col = (float *) ((char *) dst->data + (i1*nb1 + i2*nb2 + i3*nb3));
10010+
float * dst_col_ = (float *) ((char *) dst->data + ((i1_)*nb1 + (i2_)*nb2 + (i3_)*nb3));
10011+
10012+
for (int64_t ir0 = iir0; ir0 < iir0 + blck_0 && ir0 < ir011; ir0 += 2) {
10013+
vec_mmla(ne00, &tmp1[ir0 - iir0], &tmp2[ir0 - iir0], src0_row + ir0*nb01,
10014+
src0_row + (ir0+1)*nb01, src1_col, src1_col_);
10015+
}
10016+
10017+
memcpy(&dst_col[iir0], tmp1, (MIN(iir0 + blck_0, ir011) - iir0)*sizeof(float));
10018+
memcpy(&dst_col_[iir0], tmp2, (MIN(iir0 + blck_0, ir011) - iir0)*sizeof(float));
10019+
}
10020+
}
10021+
}
10022+
} else
10023+
#endif
10024+
{
10025+
for (int64_t iir1 = ir110; iir1 < ir111; iir1 += blck_1) {
10026+
for (int64_t iir0 = ir010; iir0 < ir011; iir0 += blck_0) {
10027+
for (int64_t ir1 = iir1; ir1 < iir1 + blck_1 && ir1 < ir111; ++ir1) {
10028+
const int64_t i13 = (ir1/(ne12*ne1));
10029+
const int64_t i12 = (ir1 - i13*ne12*ne1)/ne1;
10030+
const int64_t i11 = (ir1 - i13*ne12*ne1 - i12*ne1);
996610031

9967-
const int64_t i1 = i11;
9968-
const int64_t i2 = i12;
9969-
const int64_t i3 = i13;
10032+
// broadcast src0 into src1
10033+
const int64_t i03 = i13/r3;
10034+
const int64_t i02 = i12/r2;
997010035

9971-
const char * src0_row = (const char *) src0->data + (0 + i02*nb02 + i03*nb03);
10036+
const int64_t i1 = i11;
10037+
const int64_t i2 = i12;
10038+
const int64_t i3 = i13;
997210039

9973-
// desc: when src1 is not a contiguous memory block we have to calculate the offset using the strides
9974-
// if it is, then we have either copied the data to params->wdata and made it contiguous or we are using
9975-
// the original src1 data pointer, so we should index using the indices directly
9976-
// TODO: this is a bit of a hack, we should probably have a better way to handle this
9977-
const char * src1_col = (const char *) wdata +
9978-
(src1_cont || src1->type != vec_dot_type
9979-
? (i11 + i12*ne11 + i13*ne12*ne11)*row_size
9980-
: (i11*nb11 + i12*nb12 + i13*nb13));
10040+
const char * src0_row = (const char *) src0->data + (0 + i02*nb02 + i03*nb03);
998110041

9982-
float * dst_col = (float *) ((char *) dst->data + (i1*nb1 + i2*nb2 + i3*nb3));
10042+
// desc: when src1 is not a contiguous memory block we have to calculate the offset using the strides
10043+
// if it is, then we have either copied the data to params->wdata and made it contiguous or we are using
10044+
// the original src1 data pointer, so we should index using the indices directly
10045+
// TODO: this is a bit of a hack, we should probably have a better way to handle this
10046+
const char * src1_col = (const char *) wdata +
10047+
(src1_cont || src1->type != vec_dot_type
10048+
? (i11 + i12*ne11 + i13*ne12*ne11)*row_size
10049+
: (i11*nb11 + i12*nb12 + i13*nb13));
998310050

9984-
//for (int64_t ir0 = iir0; ir0 < iir0 + blck_0 && ir0 < ir011; ++ir0) {
9985-
// vec_dot(ne00, &dst_col[ir0], src0_row + ir0*nb01, src1_col);
9986-
//}
10051+
float * dst_col = (float *) ((char *) dst->data + (i1*nb1 + i2*nb2 + i3*nb3));
998710052

9988-
for (int64_t ir0 = iir0; ir0 < iir0 + blck_0 && ir0 < ir011; ++ir0) {
9989-
vec_dot(ne00, &tmp[ir0 - iir0], src0_row + ir0*nb01, src1_col);
10053+
//for (int64_t ir0 = iir0; ir0 < iir0 + blck_0 && ir0 < ir011; ++ir0) {
10054+
// vec_dot(ne00, &dst_col[ir0], src0_row + ir0*nb01, src1_col);
10055+
//}
10056+
10057+
for (int64_t ir0 = iir0; ir0 < iir0 + blck_0 && ir0 < ir011; ++ir0) {
10058+
vec_dot(ne00, &tmp[ir0 - iir0], src0_row + ir0*nb01, src1_col);
10059+
}
10060+
memcpy(&dst_col[iir0], tmp, (MIN(iir0 + blck_0, ir011) - iir0)*sizeof(float));
999010061
}
9991-
memcpy(&dst_col[iir0], tmp, (MIN(iir0 + blck_0, ir011) - iir0)*sizeof(float));
999210062
}
999310063
}
999410064
}
@@ -19965,6 +20035,14 @@ int ggml_cpu_has_arm_fma(void) {
1996520035
#endif
1996620036
}
1996720037

20038+
int ggml_cpu_has_neon_mmla(void) {
20039+
#if defined(__ARM_FEATURE_MMLA)
20040+
return 1;
20041+
#else
20042+
return 0;
20043+
#endif
20044+
}
20045+
1996820046
int ggml_cpu_has_metal(void) {
1996920047
#if defined(GGML_USE_METAL)
1997020048
return 1;

ggml.h

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2217,6 +2217,7 @@ extern "C" {
22172217
GGML_API int ggml_cpu_has_fma (void);
22182218
GGML_API int ggml_cpu_has_neon (void);
22192219
GGML_API int ggml_cpu_has_arm_fma (void);
2220+
GGML_API int ggml_cpu_has_neon_mmla (void);
22202221
GGML_API int ggml_cpu_has_metal (void);
22212222
GGML_API int ggml_cpu_has_f16c (void);
22222223
GGML_API int ggml_cpu_has_fp16_va (void);
@@ -2242,6 +2243,10 @@ extern "C" {
22422243
typedef void (*ggml_to_float_t) (const void * GGML_RESTRICT x, float * GGML_RESTRICT y, int k);
22432244
typedef void (*ggml_from_float_t)(const float * GGML_RESTRICT x, void * GGML_RESTRICT y, int k);
22442245
typedef void (*ggml_vec_dot_t) (const int n, float * GGML_RESTRICT s, const void * GGML_RESTRICT x, const void * GGML_RESTRICT y);
2246+
#if defined(__ARM_FEATURE_MMLA)
2247+
typedef void (*ggml_vec_mmla_t) (const int n, float * GGML_RESTRICT s0, float * GGML_RESTRICT s1, const void * GGML_RESTRICT x0,
2248+
const void * GGML_RESTRICT x1, const void * GGML_RESTRICT y0, const void * GGML_RESTRICT y1);
2249+
#endif
22452250

22462251
typedef struct {
22472252
const char * type_name;
@@ -2253,6 +2258,9 @@ extern "C" {
22532258
ggml_from_float_t from_float_reference;
22542259
ggml_vec_dot_t vec_dot;
22552260
enum ggml_type vec_dot_type;
2261+
#if defined(__ARM_FEATURE_MMLA)
2262+
ggml_vec_mmla_t vec_mmla;
2263+
#endif
22562264
} ggml_type_traits_t;
22572265

22582266
GGML_API ggml_type_traits_t ggml_internal_get_type_traits(enum ggml_type type);

llama.cpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10541,6 +10541,7 @@ const char * llama_print_system_info(void) {
1054110541
s += "AVX512_VNNI = " + std::to_string(ggml_cpu_has_avx512_vnni()) + " | ";
1054210542
s += "FMA = " + std::to_string(ggml_cpu_has_fma()) + " | ";
1054310543
s += "NEON = " + std::to_string(ggml_cpu_has_neon()) + " | ";
10544+
s += "NEON_MMLA = " + std::to_string(ggml_cpu_has_neon_mmla()) + " | ";
1054410545
s += "ARM_FMA = " + std::to_string(ggml_cpu_has_arm_fma()) + " | ";
1054510546
s += "F16C = " + std::to_string(ggml_cpu_has_f16c()) + " | ";
1054610547
s += "FP16_VA = " + std::to_string(ggml_cpu_has_fp16_va()) + " | ";

0 commit comments

Comments
 (0)