diff --git a/test/prototype/test_parq.py b/test/prototype/test_parq.py
index 83c60ccf12..68c25821ee 100644
--- a/test/prototype/test_parq.py
+++ b/test/prototype/test_parq.py
@@ -27,8 +27,15 @@
 )
 from torchao.prototype.parq.quant.uniform_torchao import _BIT_WIDTH_TO_DTYPE
 from torchao.quantization.granularity import PerGroup
+from torchao.quantization.qat import (
+    FakeQuantizeConfig,
+    FromIntXQuantizationAwareTrainingConfig,
+    IntXQuantizationAwareTrainingConfig,
+)
 from torchao.quantization.quant_api import (
+    Int8DynamicActivationIntxWeightConfig,
     IntxWeightOnlyConfig,
+    MappingType,
     _is_linear,
     int4_weight_only,
     quantize_,
@@ -68,9 +75,9 @@ def build_param_groups(model, b: int = 2, group_size: Optional[int] = None):
 
 
 class M(nn.Module):
-    def __init__(self, m=256, n=128, k=16, bias=False):
+    def __init__(self, m=256, n=128, k=16, bias=False, embedding=True):
         super().__init__()
-        self.embedding = nn.Embedding(10, m)
+        self.embedding = nn.Embedding(10, m) if embedding else nn.Identity()
         self.linear1 = nn.Linear(m, n, bias=bias)
         self.linear2 = nn.Linear(n, k, bias=bias)
         self.relu = nn.ReLU()
@@ -83,7 +90,11 @@ def reset_parameters(self):
                 nn.init.zeros_(module.bias)
 
     def example_inputs(self, device=None):
-        return torch.randint(1, 10, (1, 256), device=device)
+        return (
+            torch.randint(1, 10, (1, self.linear1.in_features), device=device)
+            if isinstance(self.embedding, nn.Embedding)
+            else torch.randn(1, self.linear1.in_features, device=device)
+        )
 
     def forward(self, x):
         x = self.embedding(x)
@@ -150,11 +161,11 @@ def compare_quantized_models(
                 p = p.view(-1, group_size)
 
             q, Q = quantizer.quantize(p, b=b, dim=-1)
-            q = q.view(original_shape)
 
             # compare to AffineQuantizedTensor instance
+            q = q.view(original_shape)
             ref = getattr(m_ref, n).weight.dequantize()
-            self.assertTrue(q.equal(ref))
+            torch.testing.assert_close(q, ref, atol=0, rtol=0)
 
     def compare_parq_convert(
         self,
@@ -182,13 +193,13 @@ def compare_parq_convert(
             p = module.weight.dequantize()  # PARQ weight after quantize_
 
             p_ref = getattr(m_ref, n).weight.dequantize()  # native quantize_
-            self.assertTrue(p_orig.equal(p_ref))
-            self.assertTrue(p.equal(p_ref))
+            torch.testing.assert_close(p_orig, p_ref, atol=0, rtol=0)
+            torch.testing.assert_close(p, p_ref, atol=0, rtol=0)
 
     @unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_4, "Test only enabled for 2.4+")
     @common_utils.parametrize("group_size", [32, 256])
     def test_int4_weight_only(self, group_size: int = 32):
-        model = M(m=512, n=512).to(torch.bfloat16).to(_DEVICE)
+        model = M(m=512, n=512).to(_DEVICE, dtype=torch.bfloat16)
         model.reset_parameters()
 
         m_ref = copy.deepcopy(model).eval().to(_DEVICE)
@@ -265,8 +276,70 @@ def test_intx_weight_only_e2e(self, b: int = 2, group_size: int = 32):
         self.compare_parq_convert(model, m_ref, optimizer, config)
 
 
+class TestInt8DynamicActivationTorchaoQuantizer(common_utils.TestCase):
+    def setUp(self):
+        torch.manual_seed(123)
+
+    @unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_6, "Test only enabled for 2.6+")
+    @common_utils.parametrize("b", [2, 3, 4, 8])
+    @common_utils.parametrize("model_dtype", [torch.float16, torch.float32])
+    @common_utils.parametrize("group_size", [32, 128])
+    def test_int8_dynamic_activation_intx_e2e(
+        self,
+        b: int = 2,
+        model_dtype: torch.dtype = torch.float32,
+        group_size: int = 32,
+    ):
+        model = M(embedding=False).to(_DEVICE, dtype=model_dtype)
+        x = model.example_inputs(device=_DEVICE).to(model_dtype)
+
+        # reference model using native quantization
+        m_ref = copy.deepcopy(model).eval().to(_DEVICE)
+        quantizer = UnifTorchaoQuantizer()
+        config = Int8DynamicActivationIntxWeightConfig(
+            weight_dtype=_BIT_WIDTH_TO_DTYPE[b],
+            weight_granularity=PerGroup(group_size),
+            weight_mapping_type=quantizer.mapping_type,
+            act_mapping_type=MappingType.ASYMMETRIC,
+        )
+        quantize_(m_ref, config)
+        ref_out = m_ref(x)
+
+        # quantize weights with PARQ
+        base_optimizer = torch.optim.SGD(build_param_groups(model, b, group_size))
+        optimizer = QuantOptimizer(
+            base_optimizer, quantizer, ProxHardQuant(), quant_per_channel=True
+        )
+        optimizer.zero_grad()
+        optimizer.step()
+
+        # apply torchao quantized activations on top
+        activation_config = FakeQuantizeConfig(
+            torch.int8,
+            granularity="per_token",
+            mapping_type=config.act_mapping_type,
+        )
+        filter_fn = optimizer.get_filter_fn(model)
+        quantize_(
+            model,
+            IntXQuantizationAwareTrainingConfig(activation_config=activation_config),
+            filter_fn=filter_fn,
+        )
+        out = model(x)
+        torch.testing.assert_close(out, ref_out, atol=0, rtol=0)
+
+        # equivalent to torchao's convert step
+        model.eval()
+        optimizer.restore_latent_params()
+        quantize_(model, FromIntXQuantizationAwareTrainingConfig(), filter_fn=filter_fn)
+        quantize_(model, config, filter_fn=filter_fn)
+        converted_out = model(x)
+        torch.testing.assert_close(converted_out, ref_out, atol=0, rtol=0)
+
+
 common_utils.instantiate_parametrized_tests(TestPARQuantization)
 common_utils.instantiate_parametrized_tests(TestUnifTorchaoQuantizer)
+common_utils.instantiate_parametrized_tests(TestInt8DynamicActivationTorchaoQuantizer)
 
 
 if __name__ == "__main__":
diff --git a/torchao/prototype/parq/quant/uniform_torchao.py b/torchao/prototype/parq/quant/uniform_torchao.py
index a71ac8b5b3..4f90f9cb92 100644
--- a/torchao/prototype/parq/quant/uniform_torchao.py
+++ b/torchao/prototype/parq/quant/uniform_torchao.py
@@ -50,8 +50,23 @@ def __init__(
         self.quant_min = quant_min
         self.quant_max = quant_max
         self.eps = eps
-        self.preserve_zero = preserve_zero
-        self.zero_point_domain = zero_point_domain
+
+        # defaults: zero_point_domain=ZeroPointDomain.INT, preserve_zero=True
+        self._choose_qparams = choose_qparams_affine
+        self._quantize = quantize_affine
+        self._dequantize = dequantize_affine
+
+        if zero_point_domain == ZeroPointDomain.FLOAT and not preserve_zero:
+            self._choose_qparams = choose_qparams_affine_tinygemm
+            self._quantize = quantize_affine_tinygemm
+            self._dequantize = dequantize_affine_tinygemm
+        elif zero_point_domain == ZeroPointDomain.INT and not preserve_zero:
+            self._choose_qparams = choose_qparams_affine_dont_preserve_zero
+            self._quantize = quantize_affine
+            self._dequantize = dequantize_affine
+        elif zero_point_domain == ZeroPointDomain.NONE:
+            self._quantize = quantize_affine_no_zero_point
+            self._dequantize = dequantize_affine_no_zero_point
 
     def _init_quant_min_max(self, b: int) -> None:
         if self.quant_min is None or self.quant_max is None:
@@ -74,24 +89,7 @@ def quantize(
         # assume that p has already been grouped in QuantOptimizer.step
         block_size = (1, p.size(-1)) if dim is not None else p.size()
 
-        if self.zero_point_domain == ZeroPointDomain.FLOAT and not self.preserve_zero:
-            _choose_qparams_affine = choose_qparams_affine_tinygemm
-            _quantize_affine = quantize_affine_tinygemm
-            _dequantize_affine = dequantize_affine_tinygemm
-        elif self.zero_point_domain == ZeroPointDomain.INT and not self.preserve_zero:
-            _choose_qparams_affine = choose_qparams_affine_dont_preserve_zero
-            _quantize_affine = quantize_affine
-            _dequantize_affine = dequantize_affine
-        else:  # Default case: zero_point_domain == ZeroPointDomain.INT/NONE and preserve_zero
-            _choose_qparams_affine = choose_qparams_affine
-            if self.zero_point_domain == ZeroPointDomain.INT:
-                _quantize_affine = quantize_affine
-                _dequantize_affine = dequantize_affine
-            else:
-                _quantize_affine = quantize_affine_no_zero_point
-                _dequantize_affine = dequantize_affine_no_zero_point
-
-        s, zero_point = _choose_qparams_affine(
+        s, zero_point = self._choose_qparams(
             p,
             self.mapping_type,
             block_size,
@@ -101,13 +99,13 @@ def quantize(
             quant_max=self.quant_max,
         )
         q_args = (block_size, s, zero_point, self.target_dtype)
-        q = _quantize_affine(
+        q = self._quantize(
             p,
             *q_args,
             quant_min=self.quant_min,
             quant_max=self.quant_max,
         )
-        q = _dequantize_affine(
+        q = self._dequantize(
             q,
             *q_args,
             output_dtype=p.dtype,
@@ -124,7 +122,7 @@ def quantize(
         else:
             block_size = Q.shape
 
-        Q = _dequantize_affine(
+        Q = self._dequantize(
             Q,
             block_size,
             *q_args[1:],
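
Reviewer note: below is a rough end-to-end sketch of the recipe the new test exercises, written outside the test harness: quantize weights with PARQ's QuantOptimizer + UnifTorchaoQuantizer during training, then convert with Int8DynamicActivationIntxWeightConfig on the same modules. It is illustrative only and not part of this diff: the toy model, bit width (4), group size (32), and learning rate are placeholders, the `quant_bits`/`quant_block_size` param-group keys and import paths follow the PARQ prototype's conventions as mirrored by the test's build_param_groups helper, and torch.int4 assumes a PyTorch new enough for the sub-byte dtypes the test gates on (2.6+).

import torch
from torch import nn

from torchao.prototype.parq.optim import ProxHardQuant, QuantOptimizer
from torchao.prototype.parq.quant import UnifTorchaoQuantizer
from torchao.quantization.granularity import PerGroup
from torchao.quantization.quant_api import (
    Int8DynamicActivationIntxWeightConfig,
    MappingType,
    quantize_,
)

model = nn.Sequential(nn.Linear(256, 128), nn.ReLU(), nn.Linear(128, 16))

# Linear weights go in a group tagged for 4-bit PARQ quantization with group
# size 32; non-quantized params would go in a separate group without the
# quant_* keys.
param_groups = [
    {
        "params": [m.weight for m in model if isinstance(m, nn.Linear)],
        "quant_bits": 4,
        "quant_block_size": 32,
    }
]

quantizer = UnifTorchaoQuantizer()
optimizer = QuantOptimizer(
    torch.optim.SGD(param_groups, lr=0.1),
    quantizer,
    ProxHardQuant(),
    quant_per_channel=True,
)

# stand-in for the training loop, mirroring the single step taken in the test
optimizer.zero_grad()
optimizer.step()

# convert: restore latent weights, then apply the torchao config to the same
# modules PARQ quantized
optimizer.restore_latent_params()
quantize_(
    model,
    Int8DynamicActivationIntxWeightConfig(
        weight_dtype=torch.int4,
        weight_granularity=PerGroup(32),
        weight_mapping_type=quantizer.mapping_type,
        act_mapping_type=MappingType.ASYMMETRIC,
    ),
    filter_fn=optimizer.get_filter_fn(model),
)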