@@ -662,12 +662,12 @@ typedef struct {
662662static_assert (sizeof (block_q2_0 ) == sizeof (ggml_fp16_t ) + QK2_0 / 4 , "wrong q2_0 size/padding" );
663663
664664#define QK3_0 16
665- typedef union {
666- struct {
667- uint16_t pad [ 3 ];
668- ggml_fp16_t d ;
669- };
670- uint64_t qs ;
665+ typedef struct {
666+ ggml_fp16_t d ;
667+ // Instead of representing q3_0 as a packed format "...210210210210",
668+ // represent it as two planes: "...10101010" and "...2222"
669+ uint16_t qhi ; // The highest bit of each 3-bit number, packed together
670+ uint32_t qlo ; // The low 2-bits of each 3-bit number, packed together
671671} block_q3_0 ;
672672static_assert (sizeof (block_q3_0 ) == sizeof (ggml_fp16_t ) + QK3_0 * 3 / 8 , "wrong q3_0 size/padding" );
673673
@@ -762,17 +762,20 @@ static void quantize_row_q3_0(const float * restrict x, block_q3_0 * restrict y,
762762 const float d = max / -4 ;
763763 const float id = d ? 1.0f /d : 0.0f ;
764764
765- uint64_t qs = 0 ;
765+ uint32_t lo = 0 ;
766+ uint16_t hi = 0 ;
766767
767768 for (int l = 0 ; l < QK3_0 ; l ++ ) {
768769 const float v = x [i * QK3_0 + l ]* id ;
769770 const uint8_t vi = MIN (7 , (int8_t )roundf (v ) + 4 );
770771 assert (vi < 8 );
771- qs |= (uint64_t )vi << (l * 3 );
772+ lo |= (vi & 3 ) << (l * 2 );
773+ hi |= ((vi >> 2 ) & 1 ) << l ;
772774 }
773775
774- y [i ].qs = qs ;
775- y [i ].d = GGML_FP32_TO_FP16 (d ); // overwrite unused part of uint64_t qs
776+ y [i ].d = GGML_FP32_TO_FP16 (d );
777+ y [i ].qlo = lo ;
778+ y [i ].qhi = hi ;
776779 }
777780}
778781
@@ -1573,13 +1576,15 @@ static void dequantize_row_q3_0(const void * restrict vx, float * restrict y, in
15731576
15741577 for (int i = 0 ; i < nb ; i ++ ) {
15751578 const float d = GGML_FP16_TO_FP32 (x [i ].d );
1576- uint64_t qs = x [i ].qs ;
1579+ uint_fast32_t lo = x [i ].qlo ;
1580+ uint_fast32_t hi = x [i ].qhi << 2 ;
15771581 for (int l = 0 ; l < QK3_0 ; l ++ ) {
1578- const int8_t vi = qs & 7 ;
1582+ const int8_t vi = ( lo & 3 ) | ( hi & 4 ) ;
15791583 const float v = (vi - 4 )* d ;
15801584 y [i * QK3_0 + l ] = v ;
15811585 assert (!isnan (y [i * QK3_0 + l ]));
1582- qs >>= 3 ;
1586+ lo >>= 2 ;
1587+ hi >>= 1 ;
15831588 }
15841589 }
15851590}
@@ -2525,6 +2530,39 @@ inline static void ggml_vec_dot_f16(const int n, float * restrict s, ggml_fp16_t
25252530 * s = sumf ;
25262531}
25272532
2533+ #if __AVX2__ || __AVX512F__
2534+ // Computes the dot product of signed 8-bit integers packed into 256-bit vectors,
2535+ // converting the result to 32-bit floats packed into a 256-bit vector.
2536+ static inline __m256 dotMul (__m256i bx , __m256i by ) {
2537+ # if __AVXVNNIINT8__
2538+ // Perform multiplication and sum to 32-bit values
2539+ const __m256i i32 = _mm256_dpbssd_epi32 (bx , by , _mm256_setzero_si256 ());
2540+ # else
2541+ // Get absolute values of x vectors
2542+ const __m256i ax = _mm256_sign_epi8 (bx , bx );
2543+ // Sign the values of the y vectors
2544+ const __m256i sy = _mm256_sign_epi8 (by , bx );
2545+ // Perform multiplication and create 16-bit values
2546+ const __m256i dot = _mm256_maddubs_epi16 (ax , sy );
2547+
2548+ // Convert int16_t to int32_t by adding pairwise
2549+ const __m256i ones = _mm256_set1_epi16 (1 );
2550+ const __m256i i32 = _mm256_madd_epi16 (ones , dot );
2551+ # endif
2552+ // Convert int32_t to float
2553+ return _mm256_cvtepi32_ps (i32 );
2554+ }
2555+
2556+ // Return horizontal sum of 32-bit floats packed into a 256-bit vector.
2557+ static inline float horizontalSum (__m256 acc ) {
2558+ __m128 res = _mm256_extractf128_ps (acc , 1 );
2559+ res = _mm_add_ps (res , _mm256_castps256_ps128 (acc ));
2560+ res = _mm_add_ps (res , _mm_movehl_ps (res , res ));
2561+ res = _mm_add_ss (res , _mm_movehdup_ps (res ));
2562+ return _mm_cvtss_f32 (res );
2563+ }
2564+ #endif
2565+
25282566static void ggml_vec_dot_q2_0_q8_0 (const int n , float * restrict s , const void * restrict vx , const void * restrict vy ) {
25292567 assert (n % QK2_0 == 0 );
25302568 const int nb = n / QK2_0 ;
@@ -2554,30 +2592,15 @@ static void ggml_vec_dot_q2_0_q8_0(const int n, float * restrict s, const void *
25542592 // Load y vector
25552593 const __m256i by = _mm256_loadu_si256 ((const __m256i * )y [i ].qs );
25562594
2557- // Get absolute values of x vectors
2558- const __m256i ax = _mm256_sign_epi8 (bx , bx );
2559- // Sign the values of the y vectors
2560- const __m256i sy = _mm256_sign_epi8 (by , bx );
2561- // Perform multiplication and create 16-bit values
2562- const __m256i dot = _mm256_maddubs_epi16 (ax , sy );
2563-
2564- // Convert int16_t to int32_t by adding pairwise
2565- const __m256i ones = _mm256_set1_epi16 (1 );
2566- __m256i i32 = _mm256_madd_epi16 (ones , dot );
2567-
2568- // Convert int32_t to float
2569- __m256 p = _mm256_cvtepi32_ps (i32 );
2595+ // Do the product:
2596+ __m256 p = dotMul (bx , by );
25702597
25712598 // Apply the scale, and accumulate
25722599 acc = _mm256_fmadd_ps (scale , p , acc );
25732600 }
25742601
25752602 // Return horizontal sum of the acc vector
2576- __m128 res = _mm256_extractf128_ps (acc , 1 );
2577- res = _mm_add_ps (res , _mm256_castps256_ps128 (acc ));
2578- res = _mm_add_ps (res , _mm_movehl_ps (res , res ));
2579- res = _mm_add_ss (res , _mm_movehdup_ps (res ));
2580- sumf = _mm_cvtss_f32 (res );
2603+ sumf = horizontalSum (acc );
25812604#else
25822605 for (int i = 0 ; i < nb ; i ++ ) {
25832606 const float d0 = GGML_FP16_TO_FP32 (x [i ].d );
@@ -2602,6 +2625,20 @@ static void ggml_vec_dot_q2_0_q8_0(const int n, float * restrict s, const void *
26022625 * s = sumf ;
26032626}
26042627
2628+ // Lookup table used to convert q3_0 to SIMD vectors.
2629+ // Expands the bits of an 8-bit value into a 64 bit result, turning each bit into a byte.
2630+ // A zero bit turns into 0xFC, while a one bit turns into 0x00.
2631+ #define B0 (n ) 0x ## n
2632+ #define B1 (n ) B0(n ## FC), B0(n ## 00)
2633+ #define B2 (n ) B1(n ## FC), B1(n ## 00)
2634+ #define B3 (n ) B2(n ## FC), B2(n ## 00)
2635+ #define B4 (n ) B3(n ## FC), B3(n ## 00)
2636+ #define B5 (n ) B4(n ## FC), B4(n ## 00)
2637+ #define B6 (n ) B5(n ## FC), B5(n ## 00)
2638+ #define B7 (n ) B6(n ## FC), B6(n ## 00)
2639+ #define B8 ( ) B7( FC), B7( 00)
2640+ static const uint64_t ggml_q3_table [256 ] = { B8 () };
2641+
26052642static void ggml_vec_dot_q3_0_q8_0 (const int n , float * restrict s , const void * restrict vx , const void * restrict vy ) {
26062643 assert (n % QK3_0 == 0 );
26072644 const int nb = n / QK3_0 ;
@@ -2614,103 +2651,54 @@ static void ggml_vec_dot_q3_0_q8_0(const int n, float * restrict s, const void *
26142651
26152652#if defined(__AVX2__ )
26162653 // Initialize accumulator with zeros
2617- __m128 acc = _mm_setzero_ps ();
2654+ __m256 acc = _mm256_setzero_ps ();
2655+
26182656 for (int i = 0 ; i < nb /2 ; i ++ ) {
2619- const __m128 scale_y = _mm_set1_ps (y [i ].d );
2620- for (int u = 0 ; u < 2 ; u ++ ) { // let the compiler unroll this
2621- // Compute combined scale for the block
2622- const __m128 scale_x = _mm_set1_ps (GGML_FP16_TO_FP32 (x [i * 2 + u ].d ));
2623- const __m128 scale = _mm_mul_ps (scale_x , scale_y );
2624-
2625- __m256i bxx = _mm256_set1_epi64x (x [i * 2 + u ].qs );
2626-
2627- // legend: _=zero +=one .=don't care 0-f=3bit quantized values s=fp16 scale
2628-
2629- // shift the copies to be able to reach all values
2630- // 255 192 128 64 0
2631- // | | | |
2632- // sssssfedcba9876543210sssssfedcba9876543210sssssfedcba9876543210sssssfedcba9876543210 in
2633- // sssfedcba9876543210_______________________sfedcba9876543210____sssssfedcba9876543210 shift left
2634- // _______________________sssssfedcba98765432__________________________________________ shift right
2635- // sssfedcba9876543210____sssssfedcba98765432sfedcba9876543210____sssssfedcba9876543210 out
2636- // ^ ^ ^ ^ ^ ^ ^ ^ ^ ^ ^ ^ ^ ^ ^ ^ ^ ^ ^ ^
2637- // e b 6 3 _ . f a 7 2 c 9 4 1 _ . d 8 5 0
2638- const __m256i shift_l = _mm256_set_epi64x (2 * 3 , 64 , 4 * 3 , 0 );
2639- const __m256i shift_r = _mm256_set_epi64x ( 64 , 2 * 3 , 64 , 64 );
2640- bxx = _mm256_or_si256 (_mm256_sllv_epi64 (bxx , shift_l ), _mm256_srlv_epi64 (bxx , shift_r ));
2641-
2642- // add to itself in masked places to shift some values left one bit
2643- // 127 64 0
2644- // | | | | | | | | | | | | | | | |
2645- // ssssfffeeedddcccbbbaaa999888777666555444333222111000____________ssssssssssssssssfffeeedddcccbbbaaa999888777666555444333222111000 in
2646- // _____________________++++____________________++++____________________________________++++____________________++++_______________ mask
2647- // _____________________.999____________________.111____________________________________.ddd____________________.555_______________ masked
2648- // .............ccc.....999.............444.....111....____________.....................ddd.............888.....555.............000 sum
2649- //
2650- // 255 192 128
2651- // | | | | | | | | | | | | | | | |
2652- // ssssssssssfffeeedddcccbbbaaa999888777666555444333222111000____________ssssssssssssssssfffeeedddcccbbbaaa999888777666555444333222 in
2653- // _____________________++++____________________++++____________________________________++++____________________++++_______________ mask
2654- // _____________________.bbb____________________.333____________________________________.fff____________________.777_______________ masked
2655- // .............eee.....bbb.............666.....333..........____________...............fff.............aaa.....777.............222 sum
2656- const __m256i doublemask = _mm256_set1_epi64x (0x078000078000 );
2657- bxx = _mm256_add_epi64 (bxx , _mm256_and_si256 (doublemask , bxx ));
2658-
2659- // collect 16 bytes from 256 into 128 bits
2660- const __m256i shufmask = _mm256_set_epi8 (
2661- 5 ,14 ,-1 ,-1 ,13 , 3 ,-1 ,-1 , 2 ,11 ,-1 ,-1 ,10 , 0 ,-1 ,-1 ,
2662- -1 ,-1 , 5 ,14 ,-1 ,-1 ,13 , 3 ,-1 ,-1 , 2 ,11 ,-1 ,-1 ,10 , 0 );
2663- bxx = _mm256_shuffle_epi8 (bxx , shufmask );
2664-
2665- __m128i bx = _mm_or_si128 (_mm256_castsi256_si128 (bxx ), _mm256_extracti128_si256 (bxx , 1 ));
2666-
2667- const __m128i mask = _mm_set1_epi8 (7 );
2668- bx = _mm_and_si128 (mask , bx );
2669-
2670- const __m128i off = _mm_set1_epi8 (4 );
2671- bx = _mm_sub_epi8 (bx , off );
2672-
2673- const __m128i by = _mm_loadu_si128 ((const __m128i * )(y [i ].qs + u * QK3_0 ));
2657+ __m256i bx = bytes_from_crumbs (x [i * 2 + 1 ].qlo , x [i * 2 ].qlo );
26742658
2675- // Get absolute values of x vectors
2676- const __m128i ax = _mm_sign_epi8 (bx , bx );
2677- // Sign the values of the y vectors
2678- const __m128i sy = _mm_sign_epi8 (by , bx );
2679- // Perform multiplication and create 16-bit values
2680- const __m128i dot = _mm_maddubs_epi16 (ax , sy );
2659+ __m256i const bxhi = _mm256_set_epi64x (
2660+ ggml_q3_table [x [i * 2 + 1 ].qhi >> 8 ], ggml_q3_table [x [i * 2 + 1 ].qhi & 0xFF ],
2661+ ggml_q3_table [x [i * 2 + 0 ].qhi >> 8 ], ggml_q3_table [x [i * 2 + 0 ].qhi & 0xFF ]);
26812662
2682- // Convert int16_t to int32_t by adding pairwise
2683- const __m128i ones = _mm_set1_epi16 (1 );
2684- __m128i i32 = _mm_madd_epi16 (dot , ones );
2663+ // OR the high bits (which also handles the sign):
2664+ bx = _mm256_or_si256 (bx , bxhi );
26852665
2686- // Convert int32_t to float
2687- const __m128 p = _mm_cvtepi32_ps (i32 );
2666+ // Compute combined scale for the block
2667+ const __m128 scale_lo = _mm_set1_ps (GGML_FP16_TO_FP32 (x [i * 2 + 0 ].d ));
2668+ const __m128 scale_hi = _mm_set1_ps (GGML_FP16_TO_FP32 (x [i * 2 + 1 ].d ));
2669+ __m256 scale = _mm256_set_m128 (scale_hi , scale_lo );
2670+ scale = _mm256_mul_ps (scale , _mm256_broadcast_ss (& y [i ].d ));
26882671
2689- // Apply the scale, and accumulate
2690- acc = _mm_fmadd_ps (scale , p , acc );
2691- }
2672+ // Load y vector
2673+ const __m256i by = _mm256_loadu_si256 ((const __m256i * )y [i ].qs );
2674+
2675+ // Do the product,
2676+ __m256 p = dotMul (bx , by );
2677+
2678+ // Apply the scale, and accumulate
2679+ acc = _mm256_fmadd_ps (scale , p , acc );
26922680 }
26932681
26942682 // Return horizontal sum of the acc vector
2695- __m128 res = _mm_add_ps (acc , _mm_movehl_ps (acc , acc ));
2696- res = _mm_add_ss (res , _mm_movehdup_ps (res ));
2697- sumf = _mm_cvtss_f32 (res );
2683+ sumf = horizontalSum (acc );
26982684#else
26992685 for (int i = 0 ; i < nb ; i ++ ) {
27002686 const float d0 = GGML_FP16_TO_FP32 (x [i ].d );
27012687 const float d1 = y [i /2 ].d ;
27022688
2703- uint64_t qs0 = x [i ].qs ;
2689+ uint_fast32_t lo0 = x [i ].qlo ;
2690+ uint_fast32_t hi0 = x [i ].qhi << 2 ;
27042691 const int8_t * restrict p1 = y [i /2 ].qs + (i %2 )* QK3_0 ;
27052692
27062693 int sumi = 0 ;
2707- for (int j = 0 ; j < QK3_0 ; j ++ ) {
2708- const int8_t i0 = (int8_t )(qs0 & 7 ) - 4 ;
2709- const int_fast16_t i1 = p1 [j ];
2694+ for (int l = 0 ; l < QK3_0 ; l ++ ) {
2695+ const int8_t i0 = (int8_t )(( lo0 & 3 ) | (( hi0 & 4 ) - 4 )) ;
2696+ const int_fast16_t i1 = p1 [l ];
27102697
27112698 sumi += i0 * i1 ;
27122699
2713- qs0 >>= 3 ;
2700+ lo0 >>= 2 ;
2701+ hi0 >>= 1 ;
27142702 }
27152703 sumf += d0 * d1 * sumi ;
27162704 }
@@ -12497,11 +12485,13 @@ size_t ggml_quantize_q3_0(const float * src, void * dst, int n, int k, int64_t h
1249712485 quantize_row_q3_0 (src + j , y , k );
1249812486
1249912487 for (int i = 0 ; i < nb ; i ++ ) {
12500- uint64_t qs = y [i ].qs ;
12488+ uint_fast32_t lo = y [i ].qlo ;
12489+ uint_fast32_t hi = y [i ].qhi << 2 ;
1250112490 for (int l = 0 ; l < QK3_0 ; l ++ ) {
12502- const int8_t vi = qs & 7 ;
12491+ int8_t vi = ( lo & 3 ) | ( hi & 4 ) ;
1250312492 hist [vi ]++ ;
12504- qs >>= 3 ;
12493+ lo >>= 2 ;
12494+ hi >>= 1 ;
1250512495 }
1250612496 }
1250712497 }
0 commit comments