@@ -4046,16 +4046,87 @@ void ggml_vec_dot_q4_1_q8_1(int n, float * restrict s, const void * restrict vx,
     const int nb = n / qk;

     assert(n % qk == 0);
+#if defined(__ARM_FEATURE_MATMUL_INT8)
+    assert((nrc == 2) || (nrc == 1));
+#else
     assert(nrc == 1);
-    UNUSED(nrc);
-    UNUSED(bx);
-    UNUSED(by);
+#endif

     const block_q4_1 * restrict x = vx;
     const block_q8_1 * restrict y = vy;

+#if defined(__ARM_FEATURE_MATMUL_INT8)
+    if (nrc == 2) {
+        const block_q4_1 * restrict vx0 = vx;
+        const block_q4_1 * restrict vx1 = vx + bx;
+        const block_q8_1 * restrict vy0 = vy;
+        const block_q8_1 * restrict vy1 = vy + by;
+
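+        // sumv0 accumulates the four scale*dot terms of the 2x2 output tile
+        // (SMMLA lane order); summs0 accumulates the q4_1 offset terms m*s
+        // (final, post-shuffle lane order)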
+        float32x4_t sumv0  = vdupq_n_f32(0.0f);
+        float32x4_t summs0 = vdupq_n_f32(0.0f);
+
+        for (int i = 0; i < nb; i++) {
+            const block_q4_1 * restrict b_x0 = &vx0[i];
+            const block_q4_1 * restrict b_x1 = &vx1[i];
+            const block_q8_1 * restrict b_y0 = &vy0[i];
+            const block_q8_1 * restrict b_y1 = &vy1[i];
+
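+            // each q4_1 block stores a minimum m, which contributes a
+            // constant m * sum(y) = m * s to the dot product; accumulate
+            // that term for all four tile entries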
+            float32x4_t summs_t = {GGML_FP16_TO_FP32(b_x0->m) * b_y0->s,
+                                   GGML_FP16_TO_FP32(b_x1->m) * b_y0->s,
+                                   GGML_FP16_TO_FP32(b_x0->m) * b_y1->s,
+                                   GGML_FP16_TO_FP32(b_x1->m) * b_y1->s};
+            summs0 += summs_t;
+
+            const uint8x16_t m4b = vdupq_n_u8(0x0F);
+
+            const uint8x16_t v0_0 = vld1q_u8(b_x0->qs);
+            const uint8x16_t v0_1 = vld1q_u8(b_x1->qs);
+
+            // 4-bit -> 8-bit
+            const int8x16_t x0_l = vreinterpretq_s8_u8(vandq_u8  (v0_0, m4b));
+            const int8x16_t x0_h = vreinterpretq_s8_u8(vshrq_n_u8(v0_0, 4));
+            const int8x16_t x1_l = vreinterpretq_s8_u8(vandq_u8  (v0_1, m4b));
+            const int8x16_t x1_h = vreinterpretq_s8_u8(vshrq_n_u8(v0_1, 4));
+
+            // load y
+            const int8x16_t y0_l = vld1q_s8(b_y0->qs);
+            const int8x16_t y0_h = vld1q_s8(b_y0->qs + 16);
+            const int8x16_t y1_l = vld1q_s8(b_y1->qs);
+            const int8x16_t y1_h = vld1q_s8(b_y1->qs + 16);
+
+            // per-entry scale factors d(x)*d(y), in SMMLA output lane order
+            float32x4_t scale = {GGML_FP16_TO_FP32(b_x0->d)*GGML_FP16_TO_FP32(b_y0->d),
+                                 GGML_FP16_TO_FP32(b_x0->d)*GGML_FP16_TO_FP32(b_y1->d),
+                                 GGML_FP16_TO_FP32(b_x1->d)*GGML_FP16_TO_FP32(b_y0->d),
+                                 GGML_FP16_TO_FP32(b_x1->d)*GGML_FP16_TO_FP32(b_y1->d)};
+
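+            // interleave the two x rows and the two y rows 64 bits at a
+            // time: each SMMLA operand is a 2x8 tile with row 0 in the low
+            // half and row 1 in the high half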
+            int8x16_t l0 = vreinterpretq_s8_s64(vzip1q_s64(vreinterpretq_s64_s8(x0_l), vreinterpretq_s64_s8(x1_l)));
+            int8x16_t l1 = vreinterpretq_s8_s64(vzip2q_s64(vreinterpretq_s64_s8(x0_l), vreinterpretq_s64_s8(x1_l)));
+
+            int8x16_t l2 = vreinterpretq_s8_s64(vzip1q_s64(vreinterpretq_s64_s8(x0_h), vreinterpretq_s64_s8(x1_h)));
+            int8x16_t l3 = vreinterpretq_s8_s64(vzip2q_s64(vreinterpretq_s64_s8(x0_h), vreinterpretq_s64_s8(x1_h)));
+
+            int8x16_t r0 = vreinterpretq_s8_s64(vzip1q_s64(vreinterpretq_s64_s8(y0_l), vreinterpretq_s64_s8(y1_l)));
+            int8x16_t r1 = vreinterpretq_s8_s64(vzip2q_s64(vreinterpretq_s64_s8(y0_l), vreinterpretq_s64_s8(y1_l)));
+
+            int8x16_t r2 = vreinterpretq_s8_s64(vzip1q_s64(vreinterpretq_s64_s8(y0_h), vreinterpretq_s64_s8(y1_h)));
+            int8x16_t r3 = vreinterpretq_s8_s64(vzip2q_s64(vreinterpretq_s64_s8(y0_h), vreinterpretq_s64_s8(y1_h)));
+
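+            // four chained SMMLAs accumulate the 32-element dot products; the
+            // int32x4_t holds the 2x2 tile as {x0.y0, x0.y1, x1.y0, x1.y1}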
+            sumv0 = vmlaq_f32(sumv0, vcvtq_f32_s32(vmmlaq_s32(vmmlaq_s32(vmmlaq_s32(vmmlaq_s32(
+                        vdupq_n_s32(0), l0, r0), l1, r1), l2, r2), l3, r3)), scale);
+        }
+
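+        // transpose the tile to {x0.y0, x1.y0, x0.y1, x1.y1}, add the m*s
+        // offsets, then store the two result rows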
+        float32x4_t sumv1 = vextq_f32(sumv0, sumv0, 2);
+        float32x4_t sumv2 = vzip1q_f32(sumv0, sumv1);
+        sumv2 = sumv2 + summs0;
+
+        vst1_f32(s,      vget_low_f32(sumv2));
+        vst1_f32(s + 16, vget_high_f32(sumv2));
+    } else
+#endif
     // TODO: add WASM SIMD
 #if defined(__ARM_NEON)
+    {
     float32x4_t sumv0 = vdupq_n_f32(0.0f);
     float32x4_t sumv1 = vdupq_n_f32(0.0f);

@@ -4097,6 +4168,7 @@ void ggml_vec_dot_q4_1_q8_1(int n, float * restrict s, const void * restrict vx,
     }

     *s = vaddvq_f32(sumv0) + vaddvq_f32(sumv1) + summs;
+    }
 #elif defined(__AVX2__) || defined(__AVX__)
     // Initialize accumulator with zeros
     __m256 acc = _mm256_setzero_ps();
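
For reference, here is a minimal standalone sketch (not part of the patch) of the SMMLA lane order the kernel above relies on: vmmlaq_s32 treats each int8x16_t operand as a 2x8 int8 matrix, row 0 in the low 64 bits and row 1 in the high 64 bits, which is why the kernel pairs rows with vzip1q_s64/vzip2q_s64, and it accumulates the 2x2 matrix product into an int32x4_t. The file name and build flag are assumptions; it needs an AArch64 toolchain with the i8mm extension enabled.

// smmla_demo.c -- hypothetical file name; compile with e.g.
//   cc -march=armv8.2-a+i8mm -o smmla_demo smmla_demo.c
#include <arm_neon.h>
#include <stdio.h>

int main(void) {
    // operand a = rows [x0; x1], operand b = rows [y0; y1], 8 int8 values each
    const int8_t a[16] = {1,1,1,1,1,1,1,1,    // x0
                          2,2,2,2,2,2,2,2};   // x1
    const int8_t b[16] = {1,2,3,4,5,6,7,8,    // y0
                          1,1,1,1,1,1,1,1};   // y1

    // SMMLA: acc += a (2x8) * b^T (8x2), giving a 2x2 int32 tile
    const int32x4_t acc = vmmlaq_s32(vdupq_n_s32(0), vld1q_s8(a), vld1q_s8(b));

    int32_t out[4];
    vst1q_s32(out, acc);
    // lane order is {x0.y0, x0.y1, x1.y0, x1.y1}: prints 36 8 72 16
    printf("%d %d %d %d\n", out[0], out[1], out[2], out[3]);
    return 0;
}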