# Version threshold: onnxruntime strictly newer than 1.16.1 selects the
# MatMulNBits packing layout (raw packed weights, no inline scale/zp bytes);
# older versions use the legacy MatMulFpQ4 blob layout — see get_blob_size.
ONNXRT1161_VERSION = Version("1.16.1")
def get_blob_size(group_size, num_bits, has_zp):  # pragma: no cover
    """Return the size in bytes of one packed quantization block.

    Args:
        group_size (int): how many quantized elements share one scale/zp
        num_bits (int): bit width of each quantized weight element
        has_zp (bool): whether a zero_point is packed into the blob
            (only affects the legacy pre-1.16.2 MatMulFpQ4 layout)

    Returns:
        int: number of uint8 bytes occupied by one block of packed weights
    """
    if Version(ort.__version__) > ONNXRT1161_VERSION:
        # MatMulNBits layout: packed quantized data only; scale/zp live in
        # separate initializers, so no extra per-block bytes are needed.
        blob_size = group_size * num_bits // 8
    elif has_zp:
        # Legacy blob layout: packed data + 4 extra bytes (presumably the
        # fp32 scale — TODO confirm) + 1 byte zero_point.
        blob_size = group_size * num_bits // 8 + 4 + 1
    else:
        # Legacy blob layout: packed data + 4 extra bytes (presumably the
        # fp32 scale — TODO confirm), no zero_point.
        blob_size = group_size * num_bits // 8 + 4
    return blob_size
5757
5858
@@ -86,7 +86,7 @@ def make_matmul_weight_only_node(
8686 matmul_weight_only_node: MatMulFpQ4 or MatMulNBits node
8787 new_inits: initializers of the new node
8888 """
89- blob_size = get_blob_size (group_size , zero_point is not None )
89+ blob_size = get_blob_size (group_size , num_bits , zero_point is not None )
9090 packed = np .zeros ((q_weight .shape [0 ], blob_size ), dtype = "uint8" )
9191 q_weight_name = node .input [1 ] + "_Q{}G{}" .format (str (num_bits ), str (group_size ))
9292 input_names = [node .input [0 ], q_weight_name ]
@@ -97,8 +97,16 @@ def make_matmul_weight_only_node(
9797 op_type = "MatMulNBits"
9898
9999 # pack quantized weight
100- q_weight_pairs = q_weight [:, ::2 ] | q_weight [:, 1 ::2 ] << 4
101- packed [:, :] = q_weight_pairs [:, :blob_size ]
100+ if num_bits == 4 :
101+ q_weight_pairs = q_weight [:, ::2 ] | q_weight [:, 1 ::2 ] << 4
102+ packed [:, :] = q_weight_pairs [:, :blob_size ]
103+ elif num_bits == 8 :
104+ packed = q_weight
105+ else :
106+ logger .error (
107+ "MatMulNBits does not have kernel support for num_bits = {}." .format (num_bits )
108+ )
109+
102110 packed = np .reshape (packed , (- 1 , k_blocks , blob_size ))
103111
104112 # build scale tensor
@@ -115,8 +123,10 @@ def make_matmul_weight_only_node(
115123
116124 # build zero_point tensor
117125 if zero_point is not None :
118- if num_bits > 4 :
119- packed_zp = np .reshape (zero_point , (1 , - 1 )).astype ("uint8" )
126+ if num_bits == 8 :
127+ packed_zp = zero_point .astype ("uint8" )
128+ elif num_bits > 4 :
129+ packed_zp = np .reshape (zero_point , (scale .shape [0 ], - 1 )).astype ("uint8" )
120130 else :
121131 packed_zp = np .full ((zero_point .shape [0 ] + 1 ) // 2 , 136 , dtype = "uint8" )
122132 # create an index array
@@ -463,7 +473,7 @@ def rtn_quantize(
463473 ratios = {},
464474 accuracy_level = 0 ,
465475 providers = ["CPUExecutionProvider" ],
466- algorithm = "rtn " ,
476+ algorithm = "k_quant " ,
467477):
468478 """Quant the model with round to nearst method.
469479
@@ -527,7 +537,8 @@ def rtn_quantize(
527537
528538 weight = pad_tensor (weight , group_size , k_blocks )
529539
530- satisfy_MatMulNBits_condition = Version (ort .__version__ ) > ONNXRT1161_VERSION and num_bits == 4
540+ enable_MatMulNBits_8bits = True
541+ satisfy_MatMulNBits_condition = (Version (ort .__version__ ) > ONNXRT1161_VERSION and num_bits == 4 ) or (enable_MatMulNBits_8bits and num_bits == 8 )
531542 satisfy_MatMulFpQ4_condition = (
532543 Version (ort .__version__ ) >= ONNXRT116_VERSION and num_bits == 4 and group_size == 32
533544 )
0 commit comments