@@ -427,9 +427,35 @@ static const size_t CACHE_LINE_SIZE_F32 = CACHE_LINE_SIZE/sizeof(float);
 // quantization
 //
 
-// AVX routines provided by GH user Const-me
-// ref: https://github.com/ggerganov/ggml/pull/27#issuecomment-1464934600
 #if __AVX2__ || __AVX512F__
+// Unpack 32 2-bit fields into 32 bytes
+// The output vector contains 32 bytes, each one in [ 0 .. 3 ] interval
+static inline __m256i bytesFromCrumbs(uint32_t packed_hi, uint32_t packed_lo) {
+    __m128i bx_hi = _mm_set1_epi32(packed_hi);
+    __m128i bx_lo = _mm_set1_epi32(packed_lo);
+    __m256i bx = _mm256_set_m128i(bx_hi, bx_lo);
+
+    // shift counts to get all bit pairs in lowest position of each byte
+    const __m256i shift256 = _mm256_set_epi32(6, 4, 2, 0,
+                                              6, 4, 2, 0);
+    bx = _mm256_srlv_epi32(bx, shift256);
+
+    const __m256i shufmask = _mm256_set_epi8(15, 11,  7,  3,
+                                             14, 10,  6,  2,
+                                             13,  9,  5,  1,
+                                             12,  8,  4,  0,
+                                             15, 11,  7,  3,
+                                             14, 10,  6,  2,
+                                             13,  9,  5,  1,
+                                             12,  8,  4,  0);
+    bx = _mm256_shuffle_epi8(bx, shufmask);
+
+    const __m256i mask = _mm256_set1_epi8(3);
+    bx = _mm256_and_si256(mask, bx);
+
+    return bx;
+}
+
 // Unpack 32 4-bit fields into 32 bytes
 // The output vector contains 32 bytes, each one in [ 0 .. 15 ] interval
 static inline __m256i bytesFromNibbles(const uint8_t * rsi)
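
In the new helper above, each 32-bit input carries 16 consecutive 2-bit fields ("crumbs"); the broadcast, per-dword shifts, and byte shuffle arrange things so that output byte j of each 128-bit lane ends up holding field j of the corresponding input in its low two bits. A scalar reference of that mapping, as a standalone sketch (the name unpack_crumbs_ref and the sample values are illustrative, not part of the patch):

#include <stdint.h>
#include <stdio.h>

// Scalar equivalent of bytesFromCrumbs: out[0..15] are the 2-bit fields of
// packed_lo (lowest bits first), out[16..31] those of packed_hi, each in [0..3].
static void unpack_crumbs_ref(uint32_t packed_hi, uint32_t packed_lo, uint8_t out[32]) {
    for (int j = 0; j < 16; j++) {
        out[j]      = (packed_lo >> (2*j)) & 3;
        out[16 + j] = (packed_hi >> (2*j)) & 3;
    }
}

int main(void) {
    uint8_t out[32];
    unpack_crumbs_ref(0xFFFFFFFFu, 0xE4E4E4E4u, out); // 0xE4 = 0b11100100 -> fields 0,1,2,3
    for (int j = 0; j < 32; j++) {
        printf("%u ", out[j]);
    }
    printf("\n");
    return 0;
}
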
@@ -2368,6 +2394,7 @@ inline static void ggml_vec_dot_f16(const int n, float * restrict s, ggml_fp16_t
 static void ggml_vec_dot_q2_0_q8_0(const int n, float * restrict s, const void * restrict vx, const void * restrict vy) {
     assert(n % QK2_0 == 0);
     const int nb = n / QK2_0;
+    assert(nb % 2 == 0);
 
     const block_q2_0 * restrict x = vx;
     const block_q8_0 * restrict y = vy;
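
The new assert reflects that the AVX2 loop below now consumes two Q2_0 blocks (one Q8_0 block) per iteration, so the block count must be even. For context, a sketch of the block layouts the routine appears to assume, inferred from the field accesses in this diff (the real typedefs live elsewhere in ggml.c and may differ in detail):

#define QK2_0 16                // 2-bit quants per block: 16 x 2 bits fit in one uint32_t
#define QK8_0 32                // 8-bit quants per block: two Q2_0 blocks share one Q8_0 block

typedef struct {
    ggml_fp16_t d;              // scale
    uint32_t    qs;             // 16 quants, 2 bits each, stored lowest bits first
} block_q2_0;

typedef struct {
    float  d;                   // scale
    int8_t qs[QK8_0];           // 32 quants, 8 bits each
} block_q8_0;
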
@@ -2376,49 +2403,44 @@ static void ggml_vec_dot_q2_0_q8_0(const int n, float * restrict s, const void *
 
 #if defined(__AVX2__)
     // Initialize accumulator with zeros
-    __m128 acc = _mm_setzero_ps();
-
-    for (int i = 0; i < nb; i++) {
-        // Compute combined scale for the block
-        const __m128 scale = _mm_set1_ps(GGML_FP16_TO_FP32(x[i].d) * y[i/2].d);
-
-        __m128i bx = _mm_set1_epi32(x[i].qs);
+    __m256 acc = _mm256_setzero_ps();
 
-        // shift counts to get all bit pairs in lowest position of each byte
-        const __m128i shift128 = _mm_set_epi32(6, 4, 2, 0);
-        bx = _mm_srlv_epi32(bx, shift128);
+    for (int i = 0; i < nb; i += 2) {
+        __m256i bx = bytesFromCrumbs(x[i + 1].qs, x[i].qs);
 
-        const __m128i shufmask = _mm_set_epi8(15,11,7,3,14,10,6,2,13,9,5,1,12,8,4,0);
-        bx = _mm_shuffle_epi8(bx, shufmask);
+        // Compute combined scale for the block
+        const __m128 scale_lo = _mm_set1_ps(GGML_FP16_TO_FP32(x[i + 0].d) * y[i/2].d);
+        const __m128 scale_hi = _mm_set1_ps(GGML_FP16_TO_FP32(x[i + 1].d) * y[i/2].d);
+        const __m256 scale = _mm256_set_m128(scale_hi, scale_lo);
 
-        const __m128i mask = _mm_set1_epi8(3);
-        bx = _mm_and_si128(mask, bx);
+        const __m256i off = _mm256_set1_epi8(2);
+        bx = _mm256_sub_epi8(bx, off);
 
-        const __m128i off = _mm_set1_epi8(2);
-        bx = _mm_sub_epi8(bx, off);
-
-        const __m128i by = _mm_loadu_si128((const __m128i *)(y[i/2].qs + (i%2)*QK2_0));
+        // Load y vector
+        const __m256i by = _mm256_loadu_si256((const __m256i *)y[i/2].qs);
 
         // Get absolute values of x vectors
-        const __m128i ax = _mm_sign_epi8(bx, bx);
+        const __m256i ax = _mm256_sign_epi8(bx, bx);
         // Sign the values of the y vectors
-        const __m128i sy = _mm_sign_epi8(by, bx);
+        const __m256i sy = _mm256_sign_epi8(by, bx);
         // Perform multiplication and create 16-bit values
-        const __m128i dot = _mm_maddubs_epi16(ax, sy);
+        const __m256i dot = _mm256_maddubs_epi16(ax, sy);
 
         // Convert int16_t to int32_t by adding pairwise
-        const __m128i ones = _mm_set1_epi16(1);
-        __m128i i32 = _mm_madd_epi16(dot, ones);
+        const __m256i ones = _mm256_set1_epi16(1);
+        __m256i i32 = _mm256_madd_epi16(ones, dot);
 
         // Convert int32_t to float
-        const __m128 p = _mm_cvtepi32_ps(i32);
+        __m256 p = _mm256_cvtepi32_ps(i32);
 
         // Apply the scale, and accumulate
-        acc = _mm_fmadd_ps(scale, p, acc);
+        acc = _mm256_fmadd_ps(scale, p, acc);
     }
 
     // Return horizontal sum of the acc vector
-    __m128 res = _mm_add_ps(acc, _mm_movehl_ps(acc, acc));
+    __m128 res = _mm256_extractf128_ps(acc, 1);
+    res = _mm_add_ps(res, _mm256_castps256_ps128(acc));
+    res = _mm_add_ps(res, _mm_movehl_ps(res, res));
     res = _mm_add_ss(res, _mm_movehdup_ps(res));
     sumf = _mm_cvtss_f32(res);
 #else
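
The integer core of the loop uses the usual maddubs/sign trick for a signed byte dot product: maddubs requires an unsigned first operand, so the code multiplies |x| (sign(bx, bx)) by y with the sign of x transferred onto it (sign(by, bx)), which yields the same per-byte products as x*y. A scalar sketch of the whole routine, using the block layouts sketched earlier (illustrative only; the actual non-AVX2 fallback is in the elided #else branch):

static float vec_dot_q2_0_q8_0_ref(int n, const block_q2_0 * x, const block_q8_0 * y) {
    const int nb = n / QK2_0;
    float sumf = 0.0f;
    for (int i = 0; i < nb; i++) {
        // per-block x scale times the shared y-block scale
        const float d = GGML_FP16_TO_FP32(x[i].d) * y[i/2].d;
        int sumi = 0;
        for (int j = 0; j < QK2_0; j++) {
            const int xj = (int)((x[i].qs >> (2*j)) & 3) - 2;   // 2-bit quant, offset to [-2 .. 1]
            const int yj = y[i/2].qs[(i % 2)*QK2_0 + j];        // matching 8-bit quant
            sumi += xj * yj;
        }
        sumf += d * (float)sumi;
    }
    return sumf;
}
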