@@ -274,6 +274,92 @@ template <int block_size> static __global__ void dequantize_mul_mat_q4_0(const v
     }
 }
 
+template <int NT, int NR> static __global__ void dequantize_mul_mat_q4_0_test(const void * vx, const void * vy, float * dst, const int ncols, const int nrows) {
+    const block_q4_0 * x = (const block_q4_0 *) vx;
+    const block_q8_0 * y = (const block_q8_0 *) vy;
+
+    const int bid = blockIdx.x;
+    const int tid = threadIdx.x;
+
+    __shared__ float tmp[NR][NT];
+    for (int i = 0; i < NR; ++i) {
+        tmp[i][tid] = 0.0f;
+    }
+
+    const int nbc = (ncols + 16*NT - 1)/(16*NT);
+    const int nbm = ncols/QK8_0;
+
+    uint64_t xa0;
+    uint64_t xa1;
+
+    const int8_t * xb0 = (const int8_t *) &xa0;
+    const int8_t * xb1 = (const int8_t *) &xa1;
+
+    for (int ibc = 0; ibc < nbc; ++ibc) {
+        const int iyb = (ibc*(16*NT) + 16*tid)/QK8_0;
+        const int iyq = (ibc*(16*NT) + 16*tid)%QK8_0;
+
+        if (iyb >= nbm) {
+            continue;
+        }
+
+        const int8_t * yb = (const int8_t *) &y[iyb].qs[iyq];
+
+        const float dy = y[iyb].d;
+
+        for (int ibr = 0; ibr < NR; ++ibr) {
+            const int ir = bid*NR + ibr;
+            if (ir >= nrows) {
+                continue;
+            }
+
+            // block offset
+            const int ixo = (ir*ncols)/QK4_0 + iyb;
+
+            memcpy(&xa0, &x[ixo].qs[iyq/2 + 0], sizeof(uint64_t));
+            xa1 = xa0;
+
+            xa0 = (xa0     ) & 0x0F0F0F0F0F0F0F0F;
+            xa1 = (xa1 >> 4) & 0x0F0F0F0F0F0F0F0F;
+
+            const float dx = x[ixo].d;
+
+            // the (int) cast is probably unnecessary, but just to make sure the result is accumulated in 32 bits
+            tmp[ibr][tid] += (
+                ((int)(xb0[0] - 8))*yb[0]  + ((int)(xb1[0] - 8))*yb[1]  +
+                ((int)(xb0[1] - 8))*yb[2]  + ((int)(xb1[1] - 8))*yb[3]  +
+                ((int)(xb0[2] - 8))*yb[4]  + ((int)(xb1[2] - 8))*yb[5]  +
+                ((int)(xb0[3] - 8))*yb[6]  + ((int)(xb1[3] - 8))*yb[7]  +
+                ((int)(xb0[4] - 8))*yb[8]  + ((int)(xb1[4] - 8))*yb[9]  +
+                ((int)(xb0[5] - 8))*yb[10] + ((int)(xb1[5] - 8))*yb[11] +
+                ((int)(xb0[6] - 8))*yb[12] + ((int)(xb1[6] - 8))*yb[13] +
+                ((int)(xb0[7] - 8))*yb[14] + ((int)(xb1[7] - 8))*yb[15]
+                )*dx*dy;
+        }
+    }
+
+    // reduce
+    __syncthreads();
+
+    for (int s = NT/2; s > 0; s >>= 1) {
+        if (tid < s) {
+            for (int ibr = 0; ibr < NR; ++ibr) {
+                tmp[ibr][tid] += tmp[ibr][tid + s];
+            }
+        }
+        __syncthreads();
+    }
+
+    if (tid == 0) {
+        for (int ibr = 0; ibr < NR; ++ibr) {
+            const int ir = bid*NR + ibr;
+            if (ir < nrows) {
+                dst[ir] = tmp[ibr][0];
+            }
+        }
+    }
+}
+
 static void dequantize_row_q4_0_cuda(const void * vx, float * y, int k, cudaStream_t stream) {
     const int nb = k / QK4_0;
     dequantize_block_q4_0<<<nb, 1, 0, stream>>>(vx, y);
@@ -316,9 +402,14 @@ static void dequantize_mul_mat_q4_0_cuda(const void * vx, const void * y, float
     // }
     // }
     // dequantize_mul_mat_q4_0<<<nrows, block_size, 0, stream>>>(vx, y, dst, ncols);
-    const int block_size = 32;
-    GGML_ASSERT(ncols % block_size == 0);
-    dequantize_mul_mat_q4_0<block_size><<<nrows, block_size, 0, stream>>>(vx, y, dst, ncols);
+    // const int block_size = 32;
+    // GGML_ASSERT(ncols % block_size == 0);
+    // dequantize_mul_mat_q4_0<block_size><<<nrows, block_size, 0, stream>>>(vx, y, dst, ncols);
+
+    const int NR = 1;  // unroll rows (seems to not help)
+    const int NT = 64; // number of threads per row
+
+    dequantize_mul_mat_q4_0_test<NT, NR><<<(nrows + NR - 1)/NR, NT, 0, stream>>>(vx, y, dst, ncols, nrows);
 }
 
 // TODO: optimize
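
Not part of the commit: a minimal host-side sketch of the 64-bit nibble-unpacking trick used in dequantize_mul_mat_q4_0_test above. One masked copy of 8 packed q4_0 bytes yields the 8 low nibbles, a shift-plus-mask yields the 8 high nibbles, and subtracting 8 restores the signed range [-8, 7]. The byte values are arbitrary, and the little-endian layout is assumed, as it is in the kernel.

// host-side illustration only; not taken from the commit
#include <stdint.h>
#include <stdio.h>
#include <string.h>

int main(void) {
    // 8 packed bytes, i.e. 16 q4_0 quants (arbitrary example values)
    const uint8_t qs[8] = { 0x1F, 0x80, 0x37, 0xA4, 0x5B, 0xC2, 0x6D, 0xE9 };

    uint64_t lo, hi;
    memcpy(&lo, qs, sizeof lo);
    hi = lo;

    lo = (lo     ) & 0x0F0F0F0F0F0F0F0Full; // low  nibble of each byte
    hi = (hi >> 4) & 0x0F0F0F0F0F0F0F0Full; // high nibble of each byte

    const int8_t * lb = (const int8_t *) &lo;
    const int8_t * hb = (const int8_t *) &hi;

    for (int i = 0; i < 8; ++i) {
        // on a little-endian machine this matches (qs[i] & 0x0F) - 8 and (qs[i] >> 4) - 8
        printf("byte %d: low = %d, high = %d\n", i, lb[i] - 8, hb[i] - 8);
    }
    return 0;
}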