@@ -2405,19 +2405,20 @@ static void ggml_vec_dot_q2_0_q8_0(const int n, float * restrict s, const void *
24052405 // Initialize accumulator with zeros
24062406 __m256 acc = _mm256_setzero_ps ();
24072407
2408- for (int i = 0 ; i < nb ; i += 2 ) {
2409- __m256i bx = bytesFromCrumbs (x [i + 1 ].qs , x [i ].qs );
2408+ for (int i = 0 ; i < nb / 2 ; i ++ ) {
2409+ __m256i bx = bytesFromCrumbs (x [i * 2 + 1 ].qs , x [i * 2 ].qs );
24102410
24112411 // Compute combined scale for the block
2412- const __m128 scale_lo = _mm_set1_ps (GGML_FP16_TO_FP32 (x [i + 0 ].d ) * y [i /2 ].d );
2413- const __m128 scale_hi = _mm_set1_ps (GGML_FP16_TO_FP32 (x [i + 1 ].d ) * y [i /2 ].d );
2414- const __m256 scale = _mm256_set_m128 (scale_hi , scale_lo );
2412+ const __m128 scale_lo = _mm_set1_ps (GGML_FP16_TO_FP32 (x [i * 2 + 0 ].d ));
2413+ const __m128 scale_hi = _mm_set1_ps (GGML_FP16_TO_FP32 (x [i * 2 + 1 ].d ));
2414+ __m256 scale = _mm256_set_m128 (scale_hi , scale_lo );
2415+ scale = _mm256_mul_ps (scale , _mm256_broadcast_ss (& y [i ].d ));
24152416
24162417 const __m256i off = _mm256_set1_epi8 (2 );
24172418 bx = _mm256_sub_epi8 (bx , off );
24182419
24192420 // Load y vector
2420- const __m256i by = _mm256_loadu_si256 ((const __m256i * )y [i / 2 ].qs );
2421+ const __m256i by = _mm256_loadu_si256 ((const __m256i * )y [i ].qs );
24212422
24222423 // Get absolute values of x vectors
24232424 const __m256i ax = _mm256_sign_epi8 (bx , bx );
@@ -2470,6 +2471,7 @@ static void ggml_vec_dot_q2_0_q8_0(const int n, float * restrict s, const void *
24702471static void ggml_vec_dot_q3_0_q8_0 (const int n , float * restrict s , const void * restrict vx , const void * restrict vy ) {
24712472 assert (n % QK3_0 == 0 );
24722473 const int nb = n / QK3_0 ;
2474+ assert (nb % 2 == 0 );
24732475
24742476 const block_q3_0 * restrict x = vx ;
24752477 const block_q8_0 * restrict y = vy ;
@@ -2479,77 +2481,80 @@ static void ggml_vec_dot_q3_0_q8_0(const int n, float * restrict s, const void *
24792481#if defined(__AVX2__ )
24802482 // Initialize accumulator with zeros
24812483 __m128 acc = _mm_setzero_ps ();
2482- for (int i = 0 ; i < nb ; i ++ ) {
2483- // Compute combined scale for the block
2484- const __m128 scale = _mm_set1_ps (GGML_FP16_TO_FP32 (x [i ].d ) * y [i /2 ].d );
2485-
2486- const __m256i shift_l = _mm256_set_epi64x (2 * 3 , 64 , 4 * 3 , 0 );
2487- const __m256i shift_r = _mm256_set_epi64x ( 64 , 2 * 3 , 64 , 64 );
2488-
2489- __m256i bxx = _mm256_set1_epi64x (x [i ].qs );
2490-
2491- // legend: _=zero +=one .=don't care 0-f=3bit quantized values s=fp16 scale
2492-
2493- // shift the copies to be able to reach all values
2494- // 255 192 128 64 0
2495- // | | | |
2496- // sssssfedcba9876543210sssssfedcba9876543210sssssfedcba9876543210sssssfedcba9876543210 in
2497- // sssfedcba9876543210_______________________sfedcba9876543210____sssssfedcba9876543210 shift left
2498- // _______________________sssssfedcba98765432__________________________________________ shift right
2499- // sssfedcba9876543210____sssssfedcba98765432sfedcba9876543210____sssssfedcba9876543210 out
2500- // ^ ^ ^ ^ ^ ^ ^ ^ ^ ^ ^ ^ ^ ^ ^ ^ ^ ^ ^ ^
2501- // e b 6 3 _ . f a 7 2 c 9 4 1 _ . d 8 5 0
2502- bxx = _mm256_or_si256 (_mm256_sllv_epi64 (bxx , shift_l ), _mm256_srlv_epi64 (bxx , shift_r ));
2503-
2504- // add to itself in masked places to shift some values left one bit
2505- // 127 64 0
2506- // | | | | | | | | | | | | | | | |
2507- // ssssfffeeedddcccbbbaaa999888777666555444333222111000____________ssssssssssssssssfffeeedddcccbbbaaa999888777666555444333222111000 in
2508- // _____________________++++____________________++++____________________________________++++____________________++++_______________ mask
2509- // _____________________.999____________________.111____________________________________.ddd____________________.555_______________ masked
2510- // .............ccc.....999.............444.....111....____________.....................ddd.............888.....555.............000 sum
2511- //
2512- // 255 192 128
2513- // | | | | | | | | | | | | | | | |
2514- // ssssssssssfffeeedddcccbbbaaa999888777666555444333222111000____________ssssssssssssssssfffeeedddcccbbbaaa999888777666555444333222 in
2515- // _____________________++++____________________++++____________________________________++++____________________++++_______________ mask
2516- // _____________________.bbb____________________.333____________________________________.fff____________________.777_______________ masked
2517- // .............eee.....bbb.............666.....333..........____________...............fff.............aaa.....777.............222 sum
2518- const __m256i doublemask = _mm256_set1_epi64x (0x078000078000 );
2519- bxx = _mm256_add_epi64 (bxx , _mm256_and_si256 (doublemask , bxx ));
2520-
2521- // collect 16 bytes from 256 into 128 bits
2522- const __m256i shufmask = _mm256_set_epi8 (
2523- 5 ,14 ,-1 ,-1 ,13 , 3 ,-1 ,-1 , 2 ,11 ,-1 ,-1 ,10 , 0 ,-1 ,-1 ,
2524- -1 ,-1 , 5 ,14 ,-1 ,-1 ,13 , 3 ,-1 ,-1 , 2 ,11 ,-1 ,-1 ,10 , 0 );
2525- bxx = _mm256_shuffle_epi8 (bxx , shufmask );
2484+ for (int i = 0 ; i < nb /2 ; i ++ ) {
2485+ const __m128 scale_y = _mm_set1_ps (y [i ].d );
2486+ for (int u = 0 ; u < 2 ; u ++ ) { // let the compiler unroll this
2487+ // Compute combined scale for the block
2488+ const __m128 scale_x = _mm_set1_ps (GGML_FP16_TO_FP32 (x [i * 2 + u ].d ));
2489+ const __m128 scale = _mm_mul_ps (scale_x , scale_y );
2490+
2491+ __m256i bxx = _mm256_set1_epi64x (x [i * 2 + u ].qs );
2492+
2493+ // legend: _=zero +=one .=don't care 0-f=3bit quantized values s=fp16 scale
2494+
2495+ // shift the copies to be able to reach all values
2496+ // 255 192 128 64 0
2497+ // | | | |
2498+ // sssssfedcba9876543210sssssfedcba9876543210sssssfedcba9876543210sssssfedcba9876543210 in
2499+ // sssfedcba9876543210_______________________sfedcba9876543210____sssssfedcba9876543210 shift left
2500+ // _______________________sssssfedcba98765432__________________________________________ shift right
2501+ // sssfedcba9876543210____sssssfedcba98765432sfedcba9876543210____sssssfedcba9876543210 out
2502+ // ^ ^ ^ ^ ^ ^ ^ ^ ^ ^ ^ ^ ^ ^ ^ ^ ^ ^ ^ ^
2503+ // e b 6 3 _ . f a 7 2 c 9 4 1 _ . d 8 5 0
2504+ const __m256i shift_l = _mm256_set_epi64x (2 * 3 , 64 , 4 * 3 , 0 );
2505+ const __m256i shift_r = _mm256_set_epi64x ( 64 , 2 * 3 , 64 , 64 );
2506+ bxx = _mm256_or_si256 (_mm256_sllv_epi64 (bxx , shift_l ), _mm256_srlv_epi64 (bxx , shift_r ));
2507+
2508+ // add to itself in masked places to shift some values left one bit
2509+ // 127 64 0
2510+ // | | | | | | | | | | | | | | | |
2511+ // ssssfffeeedddcccbbbaaa999888777666555444333222111000____________ssssssssssssssssfffeeedddcccbbbaaa999888777666555444333222111000 in
2512+ // _____________________++++____________________++++____________________________________++++____________________++++_______________ mask
2513+ // _____________________.999____________________.111____________________________________.ddd____________________.555_______________ masked
2514+ // .............ccc.....999.............444.....111....____________.....................ddd.............888.....555.............000 sum
2515+ //
2516+ // 255 192 128
2517+ // | | | | | | | | | | | | | | | |
2518+ // ssssssssssfffeeedddcccbbbaaa999888777666555444333222111000____________ssssssssssssssssfffeeedddcccbbbaaa999888777666555444333222 in
2519+ // _____________________++++____________________++++____________________________________++++____________________++++_______________ mask
2520+ // _____________________.bbb____________________.333____________________________________.fff____________________.777_______________ masked
2521+ // .............eee.....bbb.............666.....333..........____________...............fff.............aaa.....777.............222 sum
2522+ const __m256i doublemask = _mm256_set1_epi64x (0x078000078000 );
2523+ bxx = _mm256_add_epi64 (bxx , _mm256_and_si256 (doublemask , bxx ));
2524+
2525+ // collect 16 bytes from 256 into 128 bits
2526+ const __m256i shufmask = _mm256_set_epi8 (
2527+ 5 ,14 ,-1 ,-1 ,13 , 3 ,-1 ,-1 , 2 ,11 ,-1 ,-1 ,10 , 0 ,-1 ,-1 ,
2528+ -1 ,-1 , 5 ,14 ,-1 ,-1 ,13 , 3 ,-1 ,-1 , 2 ,11 ,-1 ,-1 ,10 , 0 );
2529+ bxx = _mm256_shuffle_epi8 (bxx , shufmask );
2530+
2531+ __m128i bx = _mm_or_si128 (_mm256_castsi256_si128 (bxx ), _mm256_extracti128_si256 (bxx , 1 ));
2532+
2533+ const __m128i mask = _mm_set1_epi8 (7 );
2534+ bx = _mm_and_si128 (mask , bx );
2535+
2536+ const __m128i off = _mm_set1_epi8 (4 );
2537+ bx = _mm_sub_epi8 (bx , off );
2538+
2539+ const __m128i by = _mm_loadu_si128 ((const __m128i * )(y [i ].qs + u * QK3_0 ));
25262540
2527- __m128i bx = _mm_or_si128 (_mm256_castsi256_si128 (bxx ), _mm256_extracti128_si256 (bxx , 1 ));
2528-
2529- const __m128i mask = _mm_set1_epi8 (7 );
2530- bx = _mm_and_si128 (mask , bx );
2531-
2532- const __m128i off = _mm_set1_epi8 (4 );
2533- bx = _mm_sub_epi8 (bx , off );
2534-
2535- const __m128i by = _mm_loadu_si128 ((const __m128i * )(y [i /2 ].qs + (i %2 )* QK3_0 ));
2536-
2537- // Get absolute values of x vectors
2538- const __m128i ax = _mm_sign_epi8 (bx , bx );
2539- // Sign the values of the y vectors
2540- const __m128i sy = _mm_sign_epi8 (by , bx );
2541- // Perform multiplication and create 16-bit values
2542- const __m128i dot = _mm_maddubs_epi16 (ax , sy );
2541+ // Get absolute values of x vectors
2542+ const __m128i ax = _mm_sign_epi8 (bx , bx );
2543+ // Sign the values of the y vectors
2544+ const __m128i sy = _mm_sign_epi8 (by , bx );
2545+ // Perform multiplication and create 16-bit values
2546+ const __m128i dot = _mm_maddubs_epi16 (ax , sy );
25432547
2544- // Convert int16_t to int32_t by adding pairwise
2545- const __m128i ones = _mm_set1_epi16 (1 );
2546- __m128i i32 = _mm_madd_epi16 (dot , ones );
2548+ // Convert int16_t to int32_t by adding pairwise
2549+ const __m128i ones = _mm_set1_epi16 (1 );
2550+ __m128i i32 = _mm_madd_epi16 (dot , ones );
25472551
2548- // Convert int32_t to float
2549- const __m128 p = _mm_cvtepi32_ps (i32 );
2552+ // Convert int32_t to float
2553+ const __m128 p = _mm_cvtepi32_ps (i32 );
25502554
2551- // Apply the scale, and accumulate
2552- acc = _mm_fmadd_ps (scale , p , acc );
2555+ // Apply the scale, and accumulate
2556+ acc = _mm_fmadd_ps (scale , p , acc );
2557+ }
25532558 }
25542559
25552560 // Return horizontal sum of the acc vector
0 commit comments