@@ -606,12 +606,12 @@ typedef struct {
606606static_assert (sizeof (block_q2_0 ) == sizeof (ggml_fp16_t ) + QK2_0 / 4 , "wrong q2_0 size/padding" );
607607
608608#define QK3_0 16
609- typedef union {
610- struct {
611- uint16_t pad [ 3 ];
612- ggml_fp16_t d ;
613- };
614- uint64_t qs ;
609+ typedef struct {
610+ ggml_fp16_t d ;
611+ // Instead of representing q3_0 as a packed format "...210210210210",
612+ // represent it as two planes: "...10101010" and "...2222"
613+ uint16_t qhi ; // The highest bit of each 3-bit number, packed together
614+ uint32_t qlo ; // The low 2-bits of each 3-bit number, packed together
615615} block_q3_0 ;
616616static_assert (sizeof (block_q3_0 ) == sizeof (ggml_fp16_t ) + QK3_0 * 3 / 8 , "wrong q3_0 size/padding" );
617617
@@ -691,17 +691,20 @@ static void quantize_row_q3_0(const float * restrict x, block_q3_0 * restrict y,
691691 const float d = max / -4 ;
692692 const float id = d ? 1.0f /d : 0.0f ;
693693
694- uint64_t qs = 0 ;
694+ uint32_t lo = 0 ;
695+ uint16_t hi = 0 ;
695696
696697 for (int l = 0 ; l < QK3_0 ; l ++ ) {
697698 const float v = x [i * QK3_0 + l ]* id ;
698699 const uint8_t vi = MIN (7 , (int8_t )roundf (v ) + 4 );
699700 assert (vi < 8 );
700- qs |= (uint64_t )vi << (l * 3 );
701+ lo |= (vi & 3 ) << (l * 2 );
702+ hi |= ((vi >> 2 ) & 1 ) << l ;
701703 }
702704
703- y [i ].qs = qs ;
704- y [i ].d = GGML_FP32_TO_FP16 (d ); // overwrite unused part of uint64_t qs
705+ y [i ].d = GGML_FP32_TO_FP16 (d );
706+ y [i ].qlo = lo ;
707+ y [i ].qhi = hi ;
705708 }
706709}
707710
@@ -1335,13 +1338,15 @@ static void dequantize_row_q3_0(const void * restrict vx, float * restrict y, in
13351338
13361339 for (int i = 0 ; i < nb ; i ++ ) {
13371340 const float d = GGML_FP16_TO_FP32 (x [i ].d );
1338- uint64_t qs = x [i ].qs ;
1341+ uint_fast32_t lo = x [i ].qlo ;
1342+ uint_fast32_t hi = x [i ].qhi << 2 ;
13391343 for (int l = 0 ; l < QK3_0 ; l ++ ) {
1340- const int8_t vi = qs & 7 ;
1344+ const int8_t vi = ( lo & 3 ) | ( hi & 4 ) ;
13411345 const float v = (vi - 4 )* d ;
13421346 y [i * QK3_0 + l ] = v ;
13431347 assert (!isnan (y [i * QK3_0 + l ]));
1344- qs >>= 3 ;
1348+ lo >>= 2 ;
1349+ hi >>= 1 ;
13451350 }
13461351 }
13471352}
@@ -2391,6 +2396,39 @@ inline static void ggml_vec_dot_f16(const int n, float * restrict s, ggml_fp16_t
23912396 * s = sumf ;
23922397}
23932398
2399+ #if __AVX2__ || __AVX512F__
2400+ // Computes the dot product of signed 8-bit integers packed into 256-bit vectors,
2401+ // converting the result to 32-bit floats packed into a 256-bit vector.
2402+ static inline __m256 dotMul (__m256i bx , __m256i by ) {
2403+ # if __AVXVNNIINT8__
2404+ // Perform multiplication and sum to 32-bit values
2405+ const __m256i i32 = _mm256_dpbssd_epi32 (bx , by , _mm256_setzero_si256 ());
2406+ # else
2407+ // Get absolute values of x vectors
2408+ const __m256i ax = _mm256_sign_epi8 (bx , bx );
2409+ // Sign the values of the y vectors
2410+ const __m256i sy = _mm256_sign_epi8 (by , bx );
2411+ // Perform multiplication and create 16-bit values
2412+ const __m256i dot = _mm256_maddubs_epi16 (ax , sy );
2413+
2414+ // Convert int16_t to int32_t by adding pairwise
2415+ const __m256i ones = _mm256_set1_epi16 (1 );
2416+ const __m256i i32 = _mm256_madd_epi16 (ones , dot );
2417+ # endif
2418+ // Convert int32_t to float
2419+ return _mm256_cvtepi32_ps (i32 );
2420+ }
2421+
2422+ // Return horizontal sum of 32-bit floats packed into a 256-bit vector.
2423+ static inline float horizontalSum (__m256 acc ) {
2424+ __m128 res = _mm256_extractf128_ps (acc , 1 );
2425+ res = _mm_add_ps (res , _mm256_castps256_ps128 (acc ));
2426+ res = _mm_add_ps (res , _mm_movehl_ps (res , res ));
2427+ res = _mm_add_ss (res , _mm_movehdup_ps (res ));
2428+ return _mm_cvtss_f32 (res );
2429+ }
2430+ #endif
2431+
23942432static void ggml_vec_dot_q2_0_q8_0 (const int n , float * restrict s , const void * restrict vx , const void * restrict vy ) {
23952433 assert (n % QK2_0 == 0 );
23962434 const int nb = n / QK2_0 ;
@@ -2420,30 +2458,15 @@ static void ggml_vec_dot_q2_0_q8_0(const int n, float * restrict s, const void *
24202458 // Load y vector
24212459 const __m256i by = _mm256_loadu_si256 ((const __m256i * )y [i ].qs );
24222460
2423- // Get absolute values of x vectors
2424- const __m256i ax = _mm256_sign_epi8 (bx , bx );
2425- // Sign the values of the y vectors
2426- const __m256i sy = _mm256_sign_epi8 (by , bx );
2427- // Perform multiplication and create 16-bit values
2428- const __m256i dot = _mm256_maddubs_epi16 (ax , sy );
2429-
2430- // Convert int16_t to int32_t by adding pairwise
2431- const __m256i ones = _mm256_set1_epi16 (1 );
2432- __m256i i32 = _mm256_madd_epi16 (ones , dot );
2433-
2434- // Convert int32_t to float
2435- __m256 p = _mm256_cvtepi32_ps (i32 );
2461+ // Do the product:
2462+ __m256 p = dotMul (bx , by );
24362463
24372464 // Apply the scale, and accumulate
24382465 acc = _mm256_fmadd_ps (scale , p , acc );
24392466 }
24402467
24412468 // Return horizontal sum of the acc vector
2442- __m128 res = _mm256_extractf128_ps (acc , 1 );
2443- res = _mm_add_ps (res , _mm256_castps256_ps128 (acc ));
2444- res = _mm_add_ps (res , _mm_movehl_ps (res , res ));
2445- res = _mm_add_ss (res , _mm_movehdup_ps (res ));
2446- sumf = _mm_cvtss_f32 (res );
2469+ sumf = horizontalSum (acc );
24472470#else
24482471 for (int i = 0 ; i < nb ; i ++ ) {
24492472 const float d0 = GGML_FP16_TO_FP32 (x [i ].d );
@@ -2468,6 +2491,20 @@ static void ggml_vec_dot_q2_0_q8_0(const int n, float * restrict s, const void *
24682491 * s = sumf ;
24692492}
24702493
2494+ // Lookup table used to convert q3_0 to SIMD vectors.
2495+ // Expands the bits of an 8-bit value into a 64 bit result, turning each bit into a byte.
2496+ // A zero bit turns into 0xFC, while a one bit turns into 0x00.
2497+ #define B0 (n ) 0x ## n
2498+ #define B1 (n ) B0(n ## FC), B0(n ## 00)
2499+ #define B2 (n ) B1(n ## FC), B1(n ## 00)
2500+ #define B3 (n ) B2(n ## FC), B2(n ## 00)
2501+ #define B4 (n ) B3(n ## FC), B3(n ## 00)
2502+ #define B5 (n ) B4(n ## FC), B4(n ## 00)
2503+ #define B6 (n ) B5(n ## FC), B5(n ## 00)
2504+ #define B7 (n ) B6(n ## FC), B6(n ## 00)
2505+ #define B8 ( ) B7( FC), B7( 00)
2506+ static const uint64_t ggml_q3_table [256 ] = { B8 () };
2507+
24712508static void ggml_vec_dot_q3_0_q8_0 (const int n , float * restrict s , const void * restrict vx , const void * restrict vy ) {
24722509 assert (n % QK3_0 == 0 );
24732510 const int nb = n / QK3_0 ;
@@ -2480,103 +2517,54 @@ static void ggml_vec_dot_q3_0_q8_0(const int n, float * restrict s, const void *
24802517
24812518#if defined(__AVX2__ )
24822519 // Initialize accumulator with zeros
2483- __m128 acc = _mm_setzero_ps ();
2520+ __m256 acc = _mm256_setzero_ps ();
2521+
24842522 for (int i = 0 ; i < nb /2 ; i ++ ) {
2485- const __m128 scale_y = _mm_set1_ps (y [i ].d );
2486- for (int u = 0 ; u < 2 ; u ++ ) { // let the compiler unroll this
2487- // Compute combined scale for the block
2488- const __m128 scale_x = _mm_set1_ps (GGML_FP16_TO_FP32 (x [i * 2 + u ].d ));
2489- const __m128 scale = _mm_mul_ps (scale_x , scale_y );
2490-
2491- __m256i bxx = _mm256_set1_epi64x (x [i * 2 + u ].qs );
2492-
2493- // legend: _=zero +=one .=don't care 0-f=3bit quantized values s=fp16 scale
2494-
2495- // shift the copies to be able to reach all values
2496- // 255 192 128 64 0
2497- // | | | |
2498- // sssssfedcba9876543210sssssfedcba9876543210sssssfedcba9876543210sssssfedcba9876543210 in
2499- // sssfedcba9876543210_______________________sfedcba9876543210____sssssfedcba9876543210 shift left
2500- // _______________________sssssfedcba98765432__________________________________________ shift right
2501- // sssfedcba9876543210____sssssfedcba98765432sfedcba9876543210____sssssfedcba9876543210 out
2502- // ^ ^ ^ ^ ^ ^ ^ ^ ^ ^ ^ ^ ^ ^ ^ ^ ^ ^ ^ ^
2503- // e b 6 3 _ . f a 7 2 c 9 4 1 _ . d 8 5 0
2504- const __m256i shift_l = _mm256_set_epi64x (2 * 3 , 64 , 4 * 3 , 0 );
2505- const __m256i shift_r = _mm256_set_epi64x ( 64 , 2 * 3 , 64 , 64 );
2506- bxx = _mm256_or_si256 (_mm256_sllv_epi64 (bxx , shift_l ), _mm256_srlv_epi64 (bxx , shift_r ));
2507-
2508- // add to itself in masked places to shift some values left one bit
2509- // 127 64 0
2510- // | | | | | | | | | | | | | | | |
2511- // ssssfffeeedddcccbbbaaa999888777666555444333222111000____________ssssssssssssssssfffeeedddcccbbbaaa999888777666555444333222111000 in
2512- // _____________________++++____________________++++____________________________________++++____________________++++_______________ mask
2513- // _____________________.999____________________.111____________________________________.ddd____________________.555_______________ masked
2514- // .............ccc.....999.............444.....111....____________.....................ddd.............888.....555.............000 sum
2515- //
2516- // 255 192 128
2517- // | | | | | | | | | | | | | | | |
2518- // ssssssssssfffeeedddcccbbbaaa999888777666555444333222111000____________ssssssssssssssssfffeeedddcccbbbaaa999888777666555444333222 in
2519- // _____________________++++____________________++++____________________________________++++____________________++++_______________ mask
2520- // _____________________.bbb____________________.333____________________________________.fff____________________.777_______________ masked
2521- // .............eee.....bbb.............666.....333..........____________...............fff.............aaa.....777.............222 sum
2522- const __m256i doublemask = _mm256_set1_epi64x (0x078000078000 );
2523- bxx = _mm256_add_epi64 (bxx , _mm256_and_si256 (doublemask , bxx ));
2524-
2525- // collect 16 bytes from 256 into 128 bits
2526- const __m256i shufmask = _mm256_set_epi8 (
2527- 5 ,14 ,-1 ,-1 ,13 , 3 ,-1 ,-1 , 2 ,11 ,-1 ,-1 ,10 , 0 ,-1 ,-1 ,
2528- -1 ,-1 , 5 ,14 ,-1 ,-1 ,13 , 3 ,-1 ,-1 , 2 ,11 ,-1 ,-1 ,10 , 0 );
2529- bxx = _mm256_shuffle_epi8 (bxx , shufmask );
2530-
2531- __m128i bx = _mm_or_si128 (_mm256_castsi256_si128 (bxx ), _mm256_extracti128_si256 (bxx , 1 ));
2532-
2533- const __m128i mask = _mm_set1_epi8 (7 );
2534- bx = _mm_and_si128 (mask , bx );
2535-
2536- const __m128i off = _mm_set1_epi8 (4 );
2537- bx = _mm_sub_epi8 (bx , off );
2538-
2539- const __m128i by = _mm_loadu_si128 ((const __m128i * )(y [i ].qs + u * QK3_0 ));
2523+ __m256i bx = bytesFromCrumbs (x [i * 2 + 1 ].qlo , x [i * 2 ].qlo );
25402524
2541- // Get absolute values of x vectors
2542- const __m128i ax = _mm_sign_epi8 (bx , bx );
2543- // Sign the values of the y vectors
2544- const __m128i sy = _mm_sign_epi8 (by , bx );
2545- // Perform multiplication and create 16-bit values
2546- const __m128i dot = _mm_maddubs_epi16 (ax , sy );
2525+ __m256i const bxhi = _mm256_set_epi64x (
2526+ ggml_q3_table [x [i * 2 + 1 ].qhi >> 8 ], ggml_q3_table [x [i * 2 + 1 ].qhi & 0xFF ],
2527+ ggml_q3_table [x [i * 2 + 0 ].qhi >> 8 ], ggml_q3_table [x [i * 2 + 0 ].qhi & 0xFF ]);
25472528
2548- // Convert int16_t to int32_t by adding pairwise
2549- const __m128i ones = _mm_set1_epi16 (1 );
2550- __m128i i32 = _mm_madd_epi16 (dot , ones );
2529+ // OR the high bits (which also handles the sign):
2530+ bx = _mm256_or_si256 (bx , bxhi );
2531+
2532+ // Compute combined scale for the block
2533+ const __m128 scale_lo = _mm_set1_ps (GGML_FP16_TO_FP32 (x [i * 2 + 0 ].d ));
2534+ const __m128 scale_hi = _mm_set1_ps (GGML_FP16_TO_FP32 (x [i * 2 + 1 ].d ));
2535+ __m256 scale = _mm256_set_m128 (scale_hi , scale_lo );
2536+ scale = _mm256_mul_ps (scale , _mm256_broadcast_ss (& y [i ].d ));
25512537
2552- // Convert int32_t to float
2553- const __m128 p = _mm_cvtepi32_ps ( i32 );
2538+ // Load y vector
2539+ const __m256i by = _mm256_loadu_si256 (( const __m256i * ) y [ i ]. qs );
25542540
2555- // Apply the scale, and accumulate
2556- acc = _mm_fmadd_ps (scale , p , acc );
2557- }
2541+ // Do the product,
2542+ __m256 p = dotMul (bx , by );
2543+
2544+ // Apply the scale, and accumulate
2545+ acc = _mm256_fmadd_ps (scale , p , acc );
25582546 }
25592547
25602548 // Return horizontal sum of the acc vector
2561- __m128 res = _mm_add_ps (acc , _mm_movehl_ps (acc , acc ));
2562- res = _mm_add_ss (res , _mm_movehdup_ps (res ));
2563- sumf = _mm_cvtss_f32 (res );
2549+ sumf = horizontalSum (acc );
25642550#else
25652551 for (int i = 0 ; i < nb ; i ++ ) {
25662552 const float d0 = GGML_FP16_TO_FP32 (x [i ].d );
25672553 const float d1 = y [i /2 ].d ;
25682554
2569- uint64_t qs0 = x [i ].qs ;
2555+ uint_fast32_t lo0 = x [i ].qlo ;
2556+ uint_fast32_t hi0 = x [i ].qhi << 2 ;
25702557 const int8_t * restrict p1 = y [i /2 ].qs + (i %2 )* QK3_0 ;
25712558
25722559 int sumi = 0 ;
2573- for (int j = 0 ; j < QK3_0 ; j ++ ) {
2574- const int8_t i0 = (int8_t )(qs0 & 7 ) - 4 ;
2575- const int_fast16_t i1 = p1 [j ];
2560+ for (int l = 0 ; l < QK3_0 ; l ++ ) {
2561+ const int8_t i0 = (int8_t )(( lo0 & 3 ) | (( hi0 & 4 ) - 4 )) ;
2562+ const int_fast16_t i1 = p1 [l ];
25762563
25772564 sumi += i0 * i1 ;
25782565
2579- qs0 >>= 3 ;
2566+ lo0 >>= 2 ;
2567+ hi0 >>= 1 ;
25802568 }
25812569 sumf += d0 * d1 * sumi ;
25822570 }
@@ -12064,11 +12052,13 @@ size_t ggml_quantize_q3_0(const float * src, void * dst, int n, int k, int64_t h
1206412052 quantize_row_q3_0 (src + j , y , k );
1206512053
1206612054 for (int i = 0 ; i < nb ; i ++ ) {
12067- uint64_t qs = y [i ].qs ;
12055+ uint_fast32_t lo = y [i ].qlo ;
12056+ uint_fast32_t hi = y [i ].qhi << 2 ;
1206812057 for (int l = 0 ; l < QK3_0 ; l ++ ) {
12069- const int8_t vi = qs & 7 ;
12058+ int8_t vi = ( lo & 3 ) | ( hi & 4 ) ;
1207012059 hist [vi ]++ ;
12071- qs >>= 3 ;
12060+ lo >>= 2 ;
12061+ hi >>= 1 ;
1207212062 }
1207312063 }
1207412064 }
0 commit comments