@@ -235,8 +235,8 @@ template <int block_size> static __global__ void dequantize_mul_mat_q4_0(const v
235235 __shared__ float tmp[block_size]; // separate sum for each thread
236236 tmp[tid] = 0 ;
237237
238- for (int i = 0 ; i < ncols/block_size; i += 2 ) {
239- const int col = i*block_size + 2 *tid;
238+ for (int i = 0 ; i < ncols/block_size; i += 4 ) {
239+ const int col = i*block_size + 4 *tid;
240240
241241 // dequantize
242242 const float d0 = x[(row*ncols + col)/QK4_0].d ;
@@ -245,19 +245,21 @@ template <int block_size> static __global__ void dequantize_mul_mat_q4_0(const v
245245 const uint8_t * p0 = x[(row*ncols + col)/QK4_0].qs ;
246246 const int8_t * p1 = y[col/QK8_0].qs ;
247247
248- const uint8_t vui0 = p0[((row*ncols + col)%QK4_0)/2 ];
248+ const uint8_t vui00 = p0[((row*ncols + col)%QK4_0)/2 ];
249+ const uint8_t vui01 = p0[((row*ncols + col + 2 )%QK4_0)/2 ];
249250 const int vi10 = p1[(col + 0 )%QK8_0];
250251 const int vi11 = p1[(col + 1 )%QK8_0];
252+ const int vi12 = p1[(col + 2 )%QK8_0];
253+ const int vi13 = p1[(col + 3 )%QK8_0];
251254
252- const int vi00 = vui0 & 0xF ;
253- const int vi01 = vui0 >> 4 ;
254-
255- const float v0 = (vi00 - 8 )*vi10*d0*d1;
256- const float v1 = (vi01 - 8 )*vi11*d0*d1;
255+ const int vi00 = vui00 & 0xF ;
256+ const int vi01 = vui00 >> 4 ;
257+ const int vi02 = vui01 & 0xF ;
258+ const int vi03 = vui01 >> 4 ;
257259
258260 // matrix multiplication
259- tmp[tid] += v0 ;
260- tmp[tid] += v1 ;
261+ const int sumi = (vi00 - 8 )*vi10 + (vi01 - 8 )*vi11 + (vi02 - 8 )*vi12 + (vi03 - 8 )*vi13 ;
262+ tmp[tid] += sumi*d0*d1 ;
261263 }
262264
263265 // sum up partial sums and write back result
0 commit comments