@@ -246,6 +246,168 @@ def quant_tensor(data, num_bits=4, group_size=32, scheme="asym", dtype="int", ra
 
     return q_weight, scale, zero_point
 
+def quant_tensor_k_quant_cpu(data, num_bits=4, group_size=32):
+    """Quantize tensor per group based on k quant.
+
+    Ref: https://github.com/ggml-org/llama.cpp/blob/64eda5deb9859e87a020e56bab5d2f9ca956f1de/ggml/src/ggml-quants.c
+
+    Args:
+        data: input weight
+        num_bits (int, optional): number of bits. Defaults to 4.
+        group_size (int, optional): how many elements share one scale/zp. Defaults to 32.
+
+    Returns:
+        output: quantized weight
+        scale: scale
+        zero_point: zero point
+    """
+    data = np.reshape(data, (-1, group_size)).astype(np.float32)  # (nb, group_size)
+    maxq = 2**num_bits - 1
+    minq = 0
+    sum_x2 = np.sum(data**2, axis=1, keepdims=True)  # (nb, 1)
+    av_x = np.sqrt(sum_x2 / group_size)  # (nb, 1)
+    weights = np.add(av_x, np.abs(data))  # (nb, group_size)
+    rmin = np.min(data, axis=1, keepdims=True)  # (nb, 1)
+    rmax = np.max(data, axis=1, keepdims=True)  # (nb, 1)
+    sum_w = np.sum(weights, axis=1, keepdims=True)  # (nb, 1)
+    sum_x = np.sum(weights * data, axis=1, keepdims=True)  # (nb, 1)
+    iscale = np.ones(rmax.shape, dtype=data.dtype)  # (nb, 1)
+    mask = rmin != rmax
+    iscale[mask] = (maxq - minq) / (rmax[mask] - rmin[mask])
+    scale = 1 / iscale
+    quant_data = np.clip(np.round(iscale * (data - rmin)), minq, maxq)  # (nb, group_size)
+    diff = scale * quant_data + rmin - data  # (nb, group_size)
+    best_mad = np.sum(weights * diff**2, axis=1, keepdims=True)  # (nb, 1)
+    nstep = 20
+    rdelta = 0.1
+    # nstep * rdelta = -2 * rrmin, maxq - minq = 2**num_bits - 1
+    rrmin = -1
+    for is_ in range(nstep):
+        iscale_new = np.ones(rmax.shape, dtype=data.dtype)  # (nb, 1)
+        factor = np.array([rrmin + rdelta * is_ + maxq - minq]).astype(data.dtype)[0]
+        mask = rmin != rmax
+        iscale_new[mask] = factor / (rmax[mask] - rmin[mask])
+        quant_data_new = np.clip(np.round(iscale_new * (data - rmin)), minq, maxq)  # (nb, group_size)
+        mul_weights_quant_data_new = weights * quant_data_new
+        sum_l = np.sum(mul_weights_quant_data_new, axis=1, keepdims=True)  # (nb, 1)
+        sum_l2 = np.sum(mul_weights_quant_data_new * quant_data_new, axis=1, keepdims=True)  # (nb, 1)
+        sum_xl = np.sum(mul_weights_quant_data_new * data, axis=1, keepdims=True)  # (nb, 1)
+        D = np.subtract(sum_w * sum_l2, sum_l**2)  # (nb, 1)
+
+        this_scale = (sum_w * sum_xl - sum_x * sum_l) / D  # (nb, 1)
+        this_min = (sum_l2 * sum_x - sum_l * sum_xl) / D  # (nb, 1)
+
+        diff = this_scale * quant_data_new + this_min - data  # (nb, group_size)
+        mad = np.sum(weights * diff**2, axis=1, keepdims=True)  # (nb, 1)
+
+        mad_1 = np.array(mad)
+        best_mad_1 = np.array(best_mad)
+        idx_to_replace = np.where(mad_1 < best_mad_1)[0]
+        quant_data[idx_to_replace, :] = quant_data_new[idx_to_replace, :]
+        best_mad[idx_to_replace] = mad[idx_to_replace]
+        scale[idx_to_replace] = this_scale[idx_to_replace]
+        rmin[idx_to_replace] = this_min[idx_to_replace]
+    zero_point = np.clip(((-rmin) / scale).round(), 0, maxq).astype("uint8")
+    scale = scale.astype(np.float64)
+    q_weight = np.empty_like(data, dtype=scale.dtype)
+    np.divide(data, scale, out=q_weight)
+    np.add(q_weight, zero_point, out=q_weight)
+    np.round(q_weight, out=q_weight)
+    np.clip(q_weight, minq, maxq, out=q_weight)
+
+    return q_weight, scale, zero_point
+
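For fixed integer codes q, each step of the search above refits (scale, min) per group in closed form: it minimizes the weighted error sum_i w_i * (scale * q_i + min - x_i)^2, and Cramer's rule on the normal equations gives

    D     = sum(w) * sum(w*q^2) - sum(w*q)^2
    scale = (sum(w) * sum(w*q*x) - sum(w*q) * sum(w*x)) / D
    min   = (sum(w*q^2) * sum(w*x) - sum(w*q) * sum(w*q*x)) / D

which is exactly what D, this_scale, and this_min compute, with sum_w, sum_l, sum_l2, sum_x, and sum_xl holding the respective weighted sums. A minimal round-trip sketch of the CPU path (the random w is illustrative only, not part of this change):

    w = np.random.randn(8, 32).astype(np.float32)
    q, s, zp = quant_tensor_k_quant_cpu(w, num_bits=4, group_size=32)
    w_rec = (q - zp) * s  # dequantize: x ≈ (q - zp) * scale
    mse = np.mean((w - w_rec) ** 2)  # small reconstruction error expected at 4 bits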
+def quant_tensor_k_quant_cuda(data, num_bits=4, group_size=32):
+    """Quantize tensor per group based on k quant.
+
+    Ref: https://github.com/ggml-org/llama.cpp/blob/64eda5deb9859e87a020e56bab5d2f9ca956f1de/ggml/src/ggml-quants.c
+
+    Args:
+        data: input weight
+        num_bits (int, optional): number of bits. Defaults to 4.
+        group_size (int, optional): how many elements share one scale/zp. Defaults to 32.
+
+    Returns:
+        output: quantized weight
+        scale: scale
+        zero_point: zero point
+    """
+    try:
+        import cupy as cp
+        import torch
+        if torch.cuda.is_available():
+            data = cp.asarray(data)
+            data = data.reshape((-1, group_size)).astype(np.float32)  # (nb, group_size)
+            nb = data.shape[0]
+            maxq = 2**num_bits - 1
+            minq = 0
+            sum_x2 = np.sum(data**2, axis=1, keepdims=True)  # (nb, 1)
+            av_x = np.sqrt(sum_x2 / group_size)  # (nb, 1)
+            weights = np.add(av_x, np.abs(data))  # (nb, group_size)
+            rmin = np.min(data, axis=1, keepdims=True)  # (nb, 1)
+            rmax = np.max(data, axis=1, keepdims=True)  # (nb, 1)
+            sum_w = np.sum(weights, axis=1, keepdims=True)  # (nb, 1)
+            sum_x = np.sum(weights * data, axis=1, keepdims=True)  # (nb, 1)
+            iscale = cp.ones(rmax.shape, dtype=data.dtype)  # (nb, 1)
+            mask = rmin != rmax
+            iscale[mask] = (maxq - minq) / (rmax[mask] - rmin[mask])
+            scale = 1 / iscale
+            quant_data = np.clip(np.round(iscale * (data - rmin)), minq, maxq)  # (nb, group_size)
+            diff = scale * quant_data + rmin - data  # (nb, group_size)
+            best_mad = np.sum(weights * diff**2, axis=1, keepdims=True)  # (nb, 1)
+            nstep = 20
+            rdelta = 0.1
+            rrmin = -1
+            for is_ in range(nstep):
+                iscale_new = cp.ones(rmax.shape, dtype=data.dtype)  # (nb, 1)
+                factor = cp.array([rrmin + rdelta * is_ + maxq - minq]).astype(data.dtype)[0]
+                mask = rmin != rmax
+                iscale_new[mask] = factor / (rmax[mask] - rmin[mask])
+                quant_data_new = np.clip(np.round(iscale_new * (data - rmin)), minq, maxq)  # (nb, group_size)
+                mul_weights_quant_data_new = weights * quant_data_new
+                sum_l = np.sum(mul_weights_quant_data_new, axis=1, keepdims=True)  # (nb, 1)
+                sum_l2 = np.sum(mul_weights_quant_data_new * quant_data_new, axis=1, keepdims=True)  # (nb, 1)
+                sum_xl = np.sum(mul_weights_quant_data_new * data, axis=1, keepdims=True)  # (nb, 1)
+                D = np.subtract(sum_w * sum_l2, sum_l**2)  # (nb, 1)
+
+                this_scale = (sum_w * sum_xl - sum_x * sum_l) / D  # (nb, 1)
+                this_min = (sum_l2 * sum_x - sum_l * sum_xl) / D  # (nb, 1)
+
+                diff = this_scale * quant_data_new + this_min - data  # (nb, group_size)
+                mad = np.sum(weights * diff**2, axis=1, keepdims=True)  # (nb, 1)
+
+                mad_1 = cp.array(mad)
+                best_mad_1 = cp.array(best_mad)
+                idx_to_replace = np.where(mad_1 < best_mad_1)[0]
+                quant_data[idx_to_replace, :] = quant_data_new[idx_to_replace, :]
+                best_mad[idx_to_replace] = mad[idx_to_replace]
+                scale[idx_to_replace] = this_scale[idx_to_replace]
+                rmin[idx_to_replace] = this_min[idx_to_replace]
+            zero_point = np.clip(((-rmin) / scale).round(), 0, maxq).astype("uint8")
+            scale = scale.astype(np.float64)
+            q_weight = np.empty_like(data, dtype=scale.dtype)
+            np.divide(data, scale, out=q_weight)
+            np.add(q_weight, zero_point, out=q_weight)
+            np.round(q_weight, out=q_weight)
+            np.clip(q_weight, minq, maxq, out=q_weight)
+
+            return q_weight.get(), scale.get(), zero_point.get()
+        else:
+            logger.warning(
+                "Tried to use k-quant quantization on CUDA, but CUDA is not available. "
+                "Falling back to k-quant quantization on CPU."
+            )
+            return quant_tensor_k_quant_cpu(data, num_bits, group_size)
+    except ImportError:
+        logger.info(
+            "Running k-quant quantization on CPU, which is time-consuming. "
+            "Consider installing cupy to speed it up on CUDA (see https://cupy.dev/), "
+            "and install torch so CUDA availability can be checked."
+        )
+        return quant_tensor_k_quant_cpu(data, num_bits, group_size)
+
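Note on the CUDA path: the arrays are allocated with cupy, but the body keeps calling np.* functions. This works because NumPy's __array_function__ protocol dispatches those calls to CuPy whenever the operands are cupy arrays, so the same expressions run on the GPU and return cupy arrays; the final .get() calls copy the results back to host memory. A minimal sketch of that dispatch, assuming cupy is installed and a CUDA device is present:

    import numpy as np
    import cupy as cp

    x = cp.arange(6, dtype=cp.float32).reshape(2, 3)  # lives on the GPU
    s = np.sum(x, axis=1, keepdims=True)  # dispatched to CuPy, returns a cupy.ndarray
    host = s.get()  # explicit device-to-host copy, as in the return statement above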
 
 def qdq_tensor(data, num_bits=4, group_size=32, scheme="asym", dtype="int", ratio=1.0):
     """Quant dequant tensor per group.
@@ -299,6 +461,7 @@ def rtn_quantize(
     ratios={},
     accuracy_level=0,
     providers=["CPUExecutionProvider"],
+    algorithm="rtn",
 ):
     """Quantize the model with the round-to-nearest method.
 
@@ -372,9 +535,15 @@ def rtn_quantize(
         ):  # pragma: no cover
             # MatMulFpQ4 supports 4 bits and 32 group_size with ort 1.16.0 and 1.16.1 versions, supported by CPU EP
             # MatMulNBits supports 4 bits and 2^n group_size with ort > 1.16.1, supported by CPU EP and CUDA EP
-            q_weight, scale, zp = quant_tensor(
-                weight.T, num_bits, group_size, scheme, "uint", ratios.get(node.input[1], 1)
-            )
+            if algorithm == "k_quant":
+                q_weight, scale, zp = quant_tensor_k_quant_cuda(weight.T, num_bits, group_size)
+            else:
+                q_weight, scale, zp = quant_tensor(
+                    weight.T, num_bits, group_size, scheme, "uint", ratios.get(node.input[1], 1)
+                )
+
             q_matmul_node, new_inits = make_matmul_weight_only_node(
                 node=node,
                 weight_shape=org_w_shape,
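With the new algorithm switch, callers opt into the k-quant path per quantization run while the default "rtn" behavior is unchanged. A hedged usage sketch; model stands in for the caller's ONNX model, passing it positionally is an assumption here, and all keyword arguments other than algorithm keep the defaults visible in the signature above:

    q_model = rtn_quantize(
        model,  # hypothetical: the ONNX model to quantize
        providers=["CUDAExecutionProvider"],
        algorithm="k_quant",  # new: route MatMul weights through quant_tensor_k_quant_cuda
    )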