@@ -514,6 +514,9 @@ static const ggml_type_traits_t type_traits[GGML_TYPE_COUNT] = {
514514 .from_float_reference = (ggml_from_float_t) quantize_row_q8_0_reference,
515515 .vec_dot = ggml_vec_dot_q8_0_q8_0,
516516 .vec_dot_type = GGML_TYPE_Q8_0,
517+ #if defined(__ARM_FEATURE_MMLA)
518+ .vec_mmla = ggml_vec_mmla_q8_0_q8_0,
519+ #endif
517520 },
518521 [GGML_TYPE_Q8_1] = {
519522 .type_name = "q8_1",
@@ -9801,6 +9804,9 @@ static void ggml_compute_forward_mul_mat(
98019804 ggml_vec_dot_t const vec_dot = type_traits[type].vec_dot;
98029805 enum ggml_type const vec_dot_type = type_traits[type].vec_dot_type;
98039806 ggml_from_float_t const from_float_to_vec_dot = type_traits[vec_dot_type].from_float;
9807+ #if defined(__ARM_FEATURE_MMLA)
9808+ ggml_vec_mmla_t const vec_mmla = type_traits[type].vec_mmla;
9809+ #endif
98049810
98059811 GGML_ASSERT(ne0 == ne01);
98069812 GGML_ASSERT(ne1 == ne11);
@@ -9952,43 +9958,107 @@ static void ggml_compute_forward_mul_mat(
99529958
99539959 // attempt to reduce false-sharing (does not seem to make a difference)
99549960 float tmp[16];
9961+ #if defined(__ARM_FEATURE_MMLA)
9962+ float tmp1[16];
9963+ float tmp2[16];
99559964
9956- for (int64_t iir1 = ir110; iir1 < ir111; iir1 += blck_1) {
9957- for (int64_t iir0 = ir010; iir0 < ir011; iir0 += blck_0) {
9958- for (int64_t ir1 = iir1; ir1 < iir1 + blck_1 && ir1 < ir111; ++ir1) {
9959- const int64_t i13 = (ir1/(ne12*ne1));
9960- const int64_t i12 = (ir1 - i13*ne12*ne1)/ne1;
9961- const int64_t i11 = (ir1 - i13*ne12*ne1 - i12*ne1);
9965+ if ((vec_mmla != NULL) && (nr0 % 2 == 0) && (nr1 %2 == 0)) {
9966+ for (int64_t iir1 = ir110; iir1 < ir111; iir1 += blck_1) {
9967+ for (int64_t iir0 = ir010; iir0 < ir011; iir0 += blck_0) {
9968+ for (int64_t ir1 = iir1; ir1 < iir1 + blck_1 && ir1 < ir111; ir1 += 2) {
9969+ const int64_t i13 = (ir1/(ne12*ne11));
9970+ const int64_t i12 = (ir1 - i13*ne12*ne11)/ne11;
9971+ const int64_t i11 = (ir1 - i13*ne12*ne11 - i12*ne11);
99629972
9963- // broadcast src0 into src1
9964- const int64_t i03 = i13/r3;
9965- const int64_t i02 = i12/r2;
9973+ // broadcast src0 into src1
9974+ const int64_t i03 = i13/r3;
9975+ const int64_t i02 = i12/r2;
9976+
9977+ const int64_t i1 = i11;
9978+ const int64_t i2 = i12;
9979+ const int64_t i3 = i13;
9980+
9981+ const int64_t i13_ = ((ir1+1)/(ne12*ne11));
9982+ const int64_t i12_ = ((ir1+1) - (i13_)*ne12*ne11)/ne11;
9983+ const int64_t i11_ = ((ir1+1) - (i13_)*ne12*ne11 - (i12_)*ne11);
9984+
9985+ // broadcast src0 into src1
9986+ const int64_t i03_ = i13_/r3;
9987+ const int64_t i02_ = i12_/r2;
9988+
9989+ const int64_t i1_ = i11_;
9990+ const int64_t i2_ = i12_;
9991+ const int64_t i3_ = i13_;
9992+
9993+ const char * src0_row = (const char *) src0->data + (0 + i02*nb02 + i03*nb03);
9994+
9995+ // desc: when src1 is not a contiguous memory block we have to calculate the offset using the strides
9996+ // if it is, then we have either copied the data to params->wdata and made it contiguous or we are using
9997+ // the original src1 data pointer, so we should index using the indices directly
9998+ // TODO: this is a bit of a hack, we should probably have a better way to handle this
9999+ const char * src1_col = (const char *) wdata +
10000+ (src1_cont || src1->type != vec_dot_type
10001+ ? (i11 + i12*ne11 + i13*ne12*ne11)*row_size
10002+ : (i11*nb11 + i12*nb12 + i13*nb13));
10003+
10004+ const char * src1_col_ = (const char *) wdata +
10005+ (src1_cont || src1->type != vec_dot_type
10006+ ? ((i11_) + (i12_)*ne11 + (i13_)*ne12*ne11)*row_size
10007+ : ((i11_)*nb11 + (i12_)*nb12 + (i13_)*nb13));
10008+
10009+ float * dst_col = (float *) ((char *) dst->data + (i1*nb1 + i2*nb2 + i3*nb3));
10010+ float * dst_col_ = (float *) ((char *) dst->data + ((i1_)*nb1 + (i2_)*nb2 + (i3_)*nb3));
10011+
10012+ for (int64_t ir0 = iir0; ir0 < iir0 + blck_0 && ir0 < ir011; ir0 += 2) {
10013+ vec_mmla(ne00, &tmp1[ir0 - iir0], &tmp2[ir0 - iir0], src0_row + ir0*nb01,
10014+ src0_row + (ir0+1)*nb01, src1_col, src1_col_);
10015+ }
10016+
10017+ memcpy(&dst_col[iir0], tmp1, (MIN(iir0 + blck_0, ir011) - iir0)*sizeof(float));
10018+ memcpy(&dst_col_[iir0], tmp2, (MIN(iir0 + blck_0, ir011) - iir0)*sizeof(float));
10019+ }
10020+ }
10021+ }
10022+ } else
10023+ #endif
10024+ {
10025+ for (int64_t iir1 = ir110; iir1 < ir111; iir1 += blck_1) {
10026+ for (int64_t iir0 = ir010; iir0 < ir011; iir0 += blck_0) {
10027+ for (int64_t ir1 = iir1; ir1 < iir1 + blck_1 && ir1 < ir111; ++ir1) {
10028+ const int64_t i13 = (ir1/(ne12*ne1));
10029+ const int64_t i12 = (ir1 - i13*ne12*ne1)/ne1;
10030+ const int64_t i11 = (ir1 - i13*ne12*ne1 - i12*ne1);
996610031
9967- const int64_t i1 = i11;
9968- const int64_t i2 = i12 ;
9969- const int64_t i3 = i13 ;
10032+ // broadcast src0 into src1
10033+ const int64_t i03 = i13/r3 ;
10034+ const int64_t i02 = i12/r2 ;
997010035
9971- const char * src0_row = (const char *) src0->data + (0 + i02*nb02 + i03*nb03);
10036+ const int64_t i1 = i11;
10037+ const int64_t i2 = i12;
10038+ const int64_t i3 = i13;
997210039
9973- // desc: when src1 is not a contiguous memory block we have to calculate the offset using the strides
9974- // if it is, then we have either copied the data to params->wdata and made it contiguous or we are using
9975- // the original src1 data pointer, so we should index using the indices directly
9976- // TODO: this is a bit of a hack, we should probably have a better way to handle this
9977- const char * src1_col = (const char *) wdata +
9978- (src1_cont || src1->type != vec_dot_type
9979- ? (i11 + i12*ne11 + i13*ne12*ne11)*row_size
9980- : (i11*nb11 + i12*nb12 + i13*nb13));
10040+ const char * src0_row = (const char *) src0->data + (0 + i02*nb02 + i03*nb03);
998110041
9982- float * dst_col = (float *) ((char *) dst->data + (i1*nb1 + i2*nb2 + i3*nb3));
10042+ // desc: when src1 is not a contiguous memory block we have to calculate the offset using the strides
10043+ // if it is, then we have either copied the data to params->wdata and made it contiguous or we are using
10044+ // the original src1 data pointer, so we should index using the indices directly
10045+ // TODO: this is a bit of a hack, we should probably have a better way to handle this
10046+ const char * src1_col = (const char *) wdata +
10047+ (src1_cont || src1->type != vec_dot_type
10048+ ? (i11 + i12*ne11 + i13*ne12*ne11)*row_size
10049+ : (i11*nb11 + i12*nb12 + i13*nb13));
998310050
9984- //for (int64_t ir0 = iir0; ir0 < iir0 + blck_0 && ir0 < ir011; ++ir0) {
9985- // vec_dot(ne00, &dst_col[ir0], src0_row + ir0*nb01, src1_col);
9986- //}
10051+ float * dst_col = (float *) ((char *) dst->data + (i1*nb1 + i2*nb2 + i3*nb3));
998710052
9988- for (int64_t ir0 = iir0; ir0 < iir0 + blck_0 && ir0 < ir011; ++ir0) {
9989- vec_dot(ne00, &tmp[ir0 - iir0], src0_row + ir0*nb01, src1_col);
10053+ //for (int64_t ir0 = iir0; ir0 < iir0 + blck_0 && ir0 < ir011; ++ir0) {
10054+ // vec_dot(ne00, &dst_col[ir0], src0_row + ir0*nb01, src1_col);
10055+ //}
10056+
10057+ for (int64_t ir0 = iir0; ir0 < iir0 + blck_0 && ir0 < ir011; ++ir0) {
10058+ vec_dot(ne00, &tmp[ir0 - iir0], src0_row + ir0*nb01, src1_col);
10059+ }
10060+ memcpy(&dst_col[iir0], tmp, (MIN(iir0 + blck_0, ir011) - iir0)*sizeof(float));
999010061 }
9991- memcpy(&dst_col[iir0], tmp, (MIN(iir0 + blck_0, ir011) - iir0)*sizeof(float));
999210062 }
999310063 }
999410064 }
@@ -19965,6 +20035,14 @@ int ggml_cpu_has_arm_fma(void) {
1996520035#endif
1996620036}
1996720037
// Report whether this build has Arm NEON integer matrix-multiply (MMLA)
// kernels compiled in.
//
// Returns: 1 if the translation unit was built with __ARM_FEATURE_MMLA
//          defined, 0 otherwise. Purely a compile-time capability probe;
//          no runtime CPU detection is performed.
//
// NOTE(review): the standard ACLE feature-test macro for the i8mm
// extension is __ARM_FEATURE_MATMUL_INT8 — confirm __ARM_FEATURE_MMLA is
// the spelling the target toolchain actually defines.
int ggml_cpu_has_neon_mmla(void) {
#if defined(__ARM_FEATURE_MMLA)
    const int supported = 1;
#else
    const int supported = 0;
#endif
    return supported;
}
20045+
1996820046int ggml_cpu_has_metal(void) {
1996920047#if defined(GGML_USE_METAL)
1997020048 return 1;
0 commit comments