@@ -228,7 +228,7 @@ static __global__ void dequantize_block_q8_0(const void * vx, float * y) {
228228template <int block_size> static __global__ void dequantize_mul_mat_q4_0 (const void * vx, const float * y, float * dst, const int ncols) {
229229 const block_q4_0 * x = (const block_q4_0 *) vx;
230230
231- const int row = blockIdx .x ;
231+ const int row = blockIdx .x * 2 + threadIdx . y ;
232232 const int tid = threadIdx .x ;
233233
234234 float tmp = 0 ;
@@ -305,9 +305,12 @@ static void dequantize_mul_mat_q4_0_cuda(const void * vx, const float * y, float
305305 // }
306306 // }
307307 // dequantize_mul_mat_q4_0<<<nrows, block_size, 0, stream>>>(vx, y, dst, ncols);
308- const int block_size = 32 ;
309- GGML_ASSERT (ncols % block_size == 0 );
310- dequantize_mul_mat_q4_0<block_size><<<nrows, block_size, 0 , stream>>> (vx, y, dst, ncols);
308+ const int reduce_size = 32 ;
309+ const int rows_per_block = 2 ;
310+ const dim3 block_size (reduce_size, rows_per_block, 1 );
311+ GGML_ASSERT (nrows % rows_per_block == 0 );
312+ GGML_ASSERT (ncols % reduce_size == 0 );
313+ dequantize_mul_mat_q4_0<reduce_size><<<nrows / rows_per_block, block_size, 0 , stream>>> (vx, y, dst, ncols);
311314}
312315
313316// TODO: optimize
0 commit comments