diff --git a/test/dtypes/test_affine_quantized.py b/test/dtypes/test_affine_quantized.py
index 4e98ffd564..0e6a37c3b5 100644
--- a/test/dtypes/test_affine_quantized.py
+++ b/test/dtypes/test_affine_quantized.py
@@ -10,6 +10,7 @@
     int8_dynamic_activation_int8_semi_sparse_weight,
     float8_weight_only,
 )
+from torchao.quantization.quant_primitives import MappingType
 from torchao.dtypes import SemiSparseLayout
 from torch.testing._internal import common_utils
 from torchao.utils import TORCH_VERSION_AT_LEAST_2_5
@@ -26,6 +27,7 @@ def get_quantization_functions(do_sparse: bool, do_int4: bool):
         int8_weight_only(),
         int8_dynamic_activation_int4_weight(),
         int8_dynamic_activation_int8_weight(),
+        int8_dynamic_activation_int8_weight(act_mapping_type=MappingType.ASYMMETRIC),
     ]
     if do_int4:
         base_functions.append(int4_weight_only(group_size=32))
diff --git a/test/quantization/test_quant_api.py b/test/quantization/test_quant_api.py
index 310f51f897..458cd07810 100644
--- a/test/quantization/test_quant_api.py
+++ b/test/quantization/test_quant_api.py
@@ -47,6 +47,7 @@
     TORCH_VERSION_AT_LEAST_2_3,
     TORCH_VERSION_AT_LEAST_2_4,
     TORCH_VERSION_AT_LEAST_2_5,
+    TORCH_VERSION_AT_LEAST_2_6,
 )
 from pathlib import Path
 from torchao._models.llama.tokenizer import get_tokenizer
@@ -576,6 +577,7 @@ def test_quantized_tensor_subclass_int8_wo(self):
 
     @unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_4, "Test only enabled for 2.4+")
     @unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available")
+    @unittest.skipIf(TORCH_VERSION_AT_LEAST_2_6, "Test only enabled for 2.5 and below")
     def test_quantized_tensor_subclass_int8_dyn_quant(self):
         # use multiples of 1024 so that we don't need padding
         m = ToyLinearModel(1024, 1024, 2048).eval().to(torch.bfloat16).to("cuda")
@@ -732,8 +734,8 @@ def test_multitensor_pad_unpad(self):
         self.assertEqual(mt.count, 3)
         mt.unpad()
         self.assertEqual(mt.count, 1)
-
-    @unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_4, "Test only enabled for 2.4+")
+
+    @unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_4, "Test only enabled for 2.4+")
     @unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available")
     def test_multitensor_inplace_operation(self):
         from torchao.quantization.GPTQ_MT import MultiTensor
@@ -742,7 +744,7 @@ def test_multitensor_inplace_operation(self):
 
         mt += 1  # In-place addition
         self.assertTrue(torch.equal(mt.values[0], torch.full((3, 3), 2)))
-
+
 
 
 common_utils.instantiate_parametrized_tests(TestQuantFlow)
diff --git a/torchao/quantization/quant_api.py b/torchao/quantization/quant_api.py
index 91803fe3f7..58855f39d9 100644
--- a/torchao/quantization/quant_api.py
+++ b/torchao/quantization/quant_api.py
@@ -38,6 +38,7 @@
 from torchao.utils import (
     TORCH_VERSION_AT_LEAST_2_4,
     TORCH_VERSION_AT_LEAST_2_5,
+    TORCH_VERSION_AT_LEAST_2_6,
     unwrap_tensor_subclass,
 )
 from .subclass import (
@@ -480,7 +481,10 @@ def _int8_asymm_per_token_quant(x: torch.Tensor) -> torch.Tensor:
     """
     mapping_type = MappingType.ASYMMETRIC
     target_dtype = torch.int8
-    return to_affine_quantized_intx(x, mapping_type, _get_per_token_block_size(x), target_dtype)
+    if TORCH_VERSION_AT_LEAST_2_6:
+        return to_affine_quantized_intx(x, mapping_type, _get_per_token_block_size(x), target_dtype, scale_dtype=torch.float64, zero_point_dtype=torch.int64)
+    else:
+        return to_affine_quantized_intx(x, mapping_type, _get_per_token_block_size(x), target_dtype)
 
 def apply_int8_dynamic_activation_int4_weight_quant(weight, group_size=32, mapping_type=MappingType.SYMMETRIC):
     """This is defined here instead of local function to support serialization
@@ -586,7 +590,7 @@ def _int8_symm_per_token_reduced_range_quant(x: torch.Tensor) -> torch.Tensor:
     return to_affine_quantized_intx(x, mapping_type, _get_per_token_block_size(x), target_dtype, eps=eps, quant_min=quant_min, quant_max=quant_max, scale_dtype=torch.float32 if x.dtype == torch.float16 else None)
 
 
-def int8_dynamic_activation_int8_weight(layout=PlainLayout()):
+def int8_dynamic_activation_int8_weight(layout=PlainLayout(), act_mapping_type=MappingType.SYMMETRIC):
     """
     Applies int8 dynamic symmetric per-token activation and int8 per-channel weight quantization to linear layers
     """
@@ -609,7 +613,10 @@ def get_weight_block_size(x):
         zero_point_dtype = torch.int64
 
         # input settings
-        input_quant_func = _int8_symm_per_token_reduced_range_quant
+        if act_mapping_type == MappingType.SYMMETRIC:
+            input_quant_func = _int8_symm_per_token_reduced_range_quant
+        else:
+            input_quant_func = _int8_asymm_per_token_quant
 
         block_size = get_weight_block_size(weight)
         weight = to_affine_quantized_intx(weight, mapping_type, block_size, target_dtype, eps=eps, zero_point_dtype=zero_point_dtype, _layout=layout)
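
For reviewers, a minimal usage sketch of the `act_mapping_type` knob this patch adds. The toy model, shapes, and CUDA device below are illustrative assumptions, not part of the patch; `quantize_`, `int8_dynamic_activation_int8_weight`, and `MappingType` are the existing torchao entry points touched above.

```python
# Illustrative sketch only: the model, shapes, and device are assumptions,
# not taken from this patch.
import torch
from torchao.quantization import quantize_, int8_dynamic_activation_int8_weight
from torchao.quantization.quant_primitives import MappingType

def make_model():
    # Hypothetical stand-in for a real model containing linear layers.
    return torch.nn.Sequential(torch.nn.Linear(1024, 1024)).eval().to(torch.bfloat16).to("cuda")

# Default behavior (unchanged): symmetric per-token activation quantization,
# routed to _int8_symm_per_token_reduced_range_quant.
m_sym = make_model()
quantize_(m_sym, int8_dynamic_activation_int8_weight())

# New option from this patch: asymmetric per-token activation quantization,
# routed to _int8_asymm_per_token_quant.
m_asym = make_model()
quantize_(m_asym, int8_dynamic_activation_int8_weight(act_mapping_type=MappingType.ASYMMETRIC))

x = torch.randn(16, 1024, dtype=torch.bfloat16, device="cuda")
with torch.no_grad():
    y = m_asym(x)  # activations are quantized asymmetrically at inference time
```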