@@ -228,8 +228,7 @@ def test_quantization(self):
228228 ("int8wo" , np .array ([0.4648 , 0.5195 , 0.5547 , 0.4199 , 0.4414 , 0.6445 , 0.4316 , 0.4531 , 0.5625 ])),
229229 ("int8dq" , np .array ([0.4648 , 0.5195 , 0.5547 , 0.4199 , 0.4414 , 0.6445 , 0.4316 , 0.4531 , 0.5625 ])),
230230 ("uint4wo" , np .array ([0.4609 , 0.5234 , 0.5508 , 0.4199 , 0.4336 , 0.6406 , 0.4316 , 0.4531 , 0.5625 ])),
231- ("int_a8w8" , np .array ([0.4648 , 0.5195 , 0.5547 , 0.4199 , 0.4414 , 0.6445 , 0.4316 , 0.4531 , 0.5625 ])),
232- ("uint_a16w7" , np .array ([0.4648 , 0.5195 , 0.5547 , 0.4219 , 0.4414 , 0.6445 , 0.4316 , 0.4531 , 0.5625 ])),
231+ ("uint7wo" , np .array ([0.4648 , 0.5195 , 0.5547 , 0.4219 , 0.4414 , 0.6445 , 0.4316 , 0.4531 , 0.5625 ])),
233232 ]
234233
235234 if TorchAoConfig ._is_cuda_capability_atleast_8_9 ():
@@ -253,8 +252,8 @@ def test_quantization(self):

         for quantization_name, expected_slice in QUANTIZATION_TYPES_TO_TEST:
             quant_kwargs = {}
-            if quantization_name in ["uint4wo", "uint_a16w7"]:
-                # The dummy flux model that we use requires us to impose some restrictions on group_size here
+            if quantization_name in ["uint4wo", "uint7wo"]:
+                # The dummy flux model that we use has smaller dimensions. This imposes some restrictions on group_size here
                 quant_kwargs.update({"group_size": 16})
             quantization_config = TorchAoConfig(
                 quant_type=quantization_name, modules_to_not_convert=["x_embedder"], **quant_kwargs