diff --git a/test/quantization/test_qat.py b/test/quantization/test_qat.py
index d655abaf62..7444c3dbb5 100644
--- a/test/quantization/test_qat.py
+++ b/test/quantization/test_qat.py
@@ -9,6 +9,7 @@
 import copy
 import unittest
+from typing import List
 
 import torch
 import torch.nn.functional as F
@@ -26,7 +27,9 @@ from torchao.quantization.qat.api import (
     ComposableQATQuantizer,
     FakeQuantizeConfig,
+    IntXQuantizationAwareTrainingConfig,
     from_intx_quantization_aware_training,
+    initialize_fake_quantizers,
     intx_quantization_aware_training,
 )
 from torchao.quantization.qat.embedding import (
@@ -99,6 +102,16 @@ def __init__(self):
     def example_inputs(self):
         return (torch.randn(1, 512).to(torch.float),)
 
+    def _get_all_weight_qparams(self) -> List[torch.Tensor]:
+        return [
+            self.linear1.weight_fake_quantizer.scale,
+            self.linear1.weight_fake_quantizer.zero_point,
+            self.sub.linear.weight_fake_quantizer.scale,
+            self.sub.linear.weight_fake_quantizer.zero_point,
+            self.linear2.weight_fake_quantizer.scale,
+            self.linear2.weight_fake_quantizer.zero_point,
+        ]
+
     def forward(self, x):
         x = self.linear1(x)
         x = self.sub(x)
@@ -996,6 +1009,21 @@ def test_fake_quantize_config_dtype(self):
         FakeQuantizeConfig(TorchAODType.INT7, "per_token")
         FakeQuantizeConfig(torch.int8, "per_token")
 
+    def test_fake_quantize_config_dynamic_and_range_learning(self):
+        """
+        Test that `is_dynamic` and `range_learning` cannot both be set.
+        """
+        FakeQuantizeConfig(
+            torch.int8, "per_channel", is_dynamic=True, range_learning=False
+        )
+        FakeQuantizeConfig(
+            torch.int8, "per_channel", is_dynamic=False, range_learning=True
+        )
+        with self.assertRaisesRegex(ValueError, "not compatible"):
+            FakeQuantizeConfig(
+                torch.int8, "per_channel", is_dynamic=True, range_learning=True
+            )
+
     @unittest.skipIf(
         not TORCH_VERSION_AT_LEAST_2_4, "skipping when torch version is 2.4 or lower"
     )
@@ -1591,6 +1619,95 @@ def test_qat_8da4w_eps(self):
         actual_out = converted_model.linear1(x)
         torch.testing.assert_close(expected_out, actual_out, atol=0, rtol=0)
 
+    @unittest.skipIf(
+        not TORCH_VERSION_AT_LEAST_2_4, "skipping when torch version is 2.4 or lower"
+    )
+    def test_fake_quantizer_range_learning(self):
+        """
+        Test that range learning requires `FakeQuantizer`s to be initialized correctly.
+        """
+        config = FakeQuantizeConfig(
+            torch.int8,
+            "per_channel",
+            is_dynamic=False,
+            range_learning=True,
+            scale_precision=torch.float32,
+            zero_point_precision=torch.float32,
+        )
+        fake_quantizer = FakeQuantizer(config)
+        example_inputs = (torch.randn(2, 3),)
+
+        # Not initialized, should fail
+        self.assertFalse(fake_quantizer._initialized)
+        self.assertIsNone(fake_quantizer.scale)
+        self.assertIsNone(fake_quantizer.zero_point)
+        with self.assertRaisesRegex(
+            ValueError,
+            "Please call `torchao.quantization.qat.initialize_fake_quantizers` "
+            "before initializing the optimizer and beginning training.",
+        ):
+            fake_quantizer(*example_inputs)
+
+        # Should pass after initializing
+        initialize_fake_quantizers(fake_quantizer, example_inputs)
+        self.assertTrue(fake_quantizer._initialized)
+        self.assertIsInstance(fake_quantizer.scale, torch.nn.Parameter)
+        self.assertIsInstance(fake_quantizer.zero_point, torch.nn.Parameter)
+        self.assertTrue(fake_quantizer.scale.requires_grad)
+        self.assertTrue(fake_quantizer.zero_point.requires_grad)
+        fake_quantizer(*example_inputs)
+
+    @unittest.skipIf(
+        not TORCH_VERSION_AT_LEAST_2_4, "skipping when torch version is 2.4 or lower"
+    )
+    def test_qat_range_learning(self):
+        """
+        Test end-to-end QAT flow with range learning.
+        """
+        config = FakeQuantizeConfig(
+            torch.int8,
+            "per_channel",
+            is_dynamic=False,
+            range_learning=True,
+            scale_precision=torch.float32,
+            zero_point_precision=torch.float32,
+        )
+        m = M()
+        example_inputs = m.example_inputs()
+        quantize_(m, IntXQuantizationAwareTrainingConfig(weight_config=config))
+
+        # Not initialized, should fail
+        for t in m._get_all_weight_qparams():
+            self.assertIsNone(t)
+        with self.assertRaisesRegex(
+            ValueError,
+            "Please call `torchao.quantization.qat.initialize_fake_quantizers` "
+            "before initializing the optimizer and beginning training.",
+        ):
+            m(*example_inputs)
+
+        # Should pass after initializing
+        # All scales and zero points should be in `m.parameters()`
+        initialize_fake_quantizers(m, example_inputs)
+        params = set(m.parameters())
+        for t in m._get_all_weight_qparams():
+            self.assertIsInstance(t, torch.nn.Parameter)
+            self.assertTrue(t.requires_grad)
+            self.assertTrue(t in params)
+        m(*example_inputs)
+
+        # Simulate training
+        optimizer = torch.optim.SGD(
+            m.parameters(), lr=0.001, momentum=0.9, weight_decay=1e-5
+        )
+        loss_fn = torch.nn.CrossEntropyLoss()
+        target = torch.randn(1, 512).float()
+        out = m(*example_inputs)
+        loss = loss_fn(out, target)
+        optimizer.zero_grad()
+        loss.backward()
+        optimizer.step()
+
 
 if __name__ == "__main__":
     unittest.main()
diff --git a/third_party/cutlass b/third_party/cutlass
index e94e888df3..afa1772203 160000
--- a/third_party/cutlass
+++ b/third_party/cutlass
@@ -1 +1 @@
-Subproject commit e94e888df3551224738bfa505787b515eae8352f
+Subproject commit afa1772203677c5118fcd82537a9c8fefbcc7008
diff --git a/torchao/quantization/qat/__init__.py b/torchao/quantization/qat/__init__.py
index 5dc3d8e008..010ccfc8cc 100644
--- a/torchao/quantization/qat/__init__.py
+++ b/torchao/quantization/qat/__init__.py
@@ -4,6 +4,7 @@
     FromIntXQuantizationAwareTrainingConfig,
     IntXQuantizationAwareTrainingConfig,
     from_intx_quantization_aware_training,
+    initialize_fake_quantizers,
     intx_quantization_aware_training,
 )
 from .embedding import (
@@ -17,11 +18,12 @@
 __all__ = [
     "ComposableQATQuantizer",
     "FakeQuantizeConfig",
-    "Int4WeightOnlyQATQuantizer",
+    "FromIntXQuantizationAwareTrainingConfig",
     "Int4WeightOnlyEmbeddingQATQuantizer",
+    "Int4WeightOnlyQATQuantizer",
     "Int8DynActInt4WeightQATQuantizer",
+    "IntXQuantizationAwareTrainingConfig",
+    "initialize_fake_quantizers",
     "intx_quantization_aware_training",
     "from_intx_quantization_aware_training",
-    "FromIntXQuantizationAwareTrainingConfig",
-    "IntXQuantizationAwareTrainingConfig",
 ]
diff --git a/torchao/quantization/qat/api.py b/torchao/quantization/qat/api.py
index e025a43d94..8fba195363 100644
--- a/torchao/quantization/qat/api.py
+++ b/torchao/quantization/qat/api.py
@@ -5,7 +5,7 @@
 # LICENSE file in the root directory of this source tree.
 
 from dataclasses import dataclass
-from typing import Any, List, Optional, Union
+from typing import Any, List, Optional, Tuple, Union
 
 import torch
 
@@ -51,7 +51,8 @@ class FakeQuantizeConfig:
         zero_point_precision: zero point dtype (default torch.int32)
         zero_point_domain: whether zero point is in integer (default) or float domain
         is_dynamic: whether to use dynamic (default) or static scale and zero points
-        range_learning: whether to learn scale and zero points during training (coming soon)
+        range_learning: whether to learn scale and zero points during training
+            (default False); not compatible with `is_dynamic`.
 
     kwargs (optional):
         group_size: size of each group in per group fake quantization,
@@ -123,6 +124,10 @@ def __init__(
                 "Unsupported dtype '%s', choose from %s" % (dtype, all_dtypes)
             )
 
+        # Dynamic is not compatible with range learning
+        if is_dynamic and range_learning:
+            raise ValueError("`is_dynamic` is not compatible with `range_learning`")
+
     def _get_granularity(
         self,
         granularity: Union[Granularity, str, None],
@@ -394,3 +399,23 @@ def convert(
         for quantizer in self.quantizers:
             model = quantizer.convert(model)
         return model
+
+
+def initialize_fake_quantizers(
+    model: torch.nn.Module,
+    example_inputs: Tuple[Any, ...],
+) -> None:
+    """
+    Initialize the scales and zero points on all
+    :class:`~torchao.quantization.qat.fake_quantizer.FakeQuantizer`
+    in the model based on the provided example inputs.
+    """
+    # avoid circular dependencies
+    from torchao.quantization.qat.fake_quantizer import FakeQuantizer
+
+    def _set_initialized(m: torch.nn.Module):
+        if isinstance(m, FakeQuantizer):
+            m._initialized = True
+
+    model.apply(_set_initialized)
+    model(*example_inputs)
diff --git a/torchao/quantization/qat/embedding.py b/torchao/quantization/qat/embedding.py
index 2770956a2c..aec23712ed 100644
--- a/torchao/quantization/qat/embedding.py
+++ b/torchao/quantization/qat/embedding.py
@@ -92,6 +92,7 @@ def to_embedding(self) -> torch.nn.Embedding:
             self.scale_grad_by_freq,
             self.sparse,
             device=self.weight.device,
+            dtype=self.weight.dtype,
         )
         # In distributed training, the model may be instantiated
         # on the meta device, in which case there is no need to
@@ -116,6 +117,7 @@ def from_embedding(
             mod.sparse,
             weight_config=weight_config,
             device=mod.weight.device,
+            dtype=mod.weight.dtype,
         )
         # In distributed training, the model may be instantiated
         # on the meta device, in which case there is no need to
diff --git a/torchao/quantization/qat/fake_quantizer.py b/torchao/quantization/qat/fake_quantizer.py
index 0d2521cac0..90206b5d6e 100644
--- a/torchao/quantization/qat/fake_quantizer.py
+++ b/torchao/quantization/qat/fake_quantizer.py
@@ -31,6 +31,7 @@ from .utils import (
     _fake_quantize_per_channel_group,
     _fake_quantize_per_token,
+    _Round,
 )
 
 
@@ -46,11 +47,12 @@ def __init__(self, config: FakeQuantizeConfig):
         self.scale: Optional[torch.Tensor] = None
         self.zero_point: Optional[torch.Tensor] = None
 
-        # TODO: support range learinng
-        if self.config.range_learning:
-            raise NotImplementedError("Range learning is not supported yet")
+        # For range learning only
+        # TODO: make this configurable?
+        self._scale_eps = 1e-9
+        self._initialized = False
 
-    def forward(self, x: torch.Tensor):
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
         """
         Apply fake quantization to the tensor based on the bit-width,
         granularity, symmetry, and other properties specified in the config.
         """
         if not self.enabled:
             return x
 
+        if (
+            self.config.range_learning
+            and not self._initialized
+            and (self.scale is None or self.zero_point is None)
+        ):
+            raise ValueError(
+                "Scales and zero points must be initialized for range learning. "
+                "Please call `torchao.quantization.qat.initialize_fake_quantizers` "
+                "before initializing the optimizer and beginning training."
+            )
+
         if isinstance(self.config.granularity, PerToken):
             return self._per_token_forward(x)
         elif isinstance(self.config.granularity, (PerAxis, PerGroup)):
             return self._per_channel_or_group_forward(x)
         else:
             raise ValueError("Unknown granularity '%s'" % self.config.granularity)
 
-    def _per_token_forward(self, x: torch.Tensor):
+    def _per_token_forward(self, x: torch.Tensor) -> torch.Tensor:
         """
         Perform per token fake quantization on the tensor.
         """
         if self.config.is_symmetric:
             raise NotImplementedError("Symmetric per token is not supported yet")
-
         qmin, qmax = _DTYPE_TO_QVALUE_BOUNDS[self.config.dtype]
         if self._should_compute_qparams():
             self.scale, self.zero_point = choose_qparams_affine(
@@ -85,9 +97,10 @@ def _per_token_forward(self, x: torch.Tensor):
                 scale_dtype=self.config.scale_precision,
                 zero_point_dtype=self.config.zero_point_precision,
             )
+            self._maybe_update_qparams_for_range_learning()
         return _fake_quantize_per_token(x, self.scale, self.zero_point, qmin, qmax)
 
-    def _per_channel_or_group_forward(self, x: torch.Tensor):
+    def _per_channel_or_group_forward(self, x: torch.Tensor) -> torch.Tensor:
         """
         Perform per channel or per group fake quantization on the tensor.
         We express per channel using per group where the group size is the size
@@ -129,6 +142,7 @@ def _per_channel_or_group_forward(self, x: torch.Tensor):
                 eps=self.config.eps,
             )
             self.zero_point = self.zero_point.to(zero_point_precision)
+            self._maybe_update_qparams_for_range_learning()
 
         qmin, qmax = _DTYPE_TO_QVALUE_BOUNDS[self.config.dtype]
         return _fake_quantize_per_channel_group(
@@ -147,6 +161,26 @@ def _should_compute_qparams(self) -> bool:
         """
         return self.config.is_dynamic or self.scale is None or self.zero_point is None
 
+    def _maybe_update_qparams_for_range_learning(self) -> None:
+        """
+        If range learning is enabled, turn scales and zero points into trainable parameters.
+        This function is idempotent; calls after the first are no-ops.
+        """
+        if (
+            not self.config.range_learning
+            or isinstance(self.scale, torch.nn.Parameter)
+            or isinstance(self.zero_point, torch.nn.Parameter)
+        ):
+            return
+        scale, zero_point = self.scale, self.zero_point
+        qmin, qmax = _DTYPE_TO_QVALUE_BOUNDS[self.config.dtype]
+        # Stabilize range learning
+        scale = torch.clamp(scale, min=self._scale_eps)
+        zero_point = _Round.apply(zero_point)
+        zero_point = torch.clamp(zero_point, qmin, qmax)
+        self.scale = torch.nn.Parameter(scale, requires_grad=True)
+        self.zero_point = torch.nn.Parameter(zero_point, requires_grad=True)
+
     def __repr__(self) -> str:
         """
         Return a human readable representation of this `FakeQuantizer` with config details.
diff --git a/torchao/quantization/qat/linear.py b/torchao/quantization/qat/linear.py
index a912f04b83..7c32bc4b19 100644
--- a/torchao/quantization/qat/linear.py
+++ b/torchao/quantization/qat/linear.py
@@ -18,6 +18,7 @@
     _replace_linear_int4,
     groupwise_affine_quantize_tensor,
 )
+from torchao.quantization.granularity import PerGroup
 from torchao.quantization.quant_primitives import (
     TorchAODType,
     ZeroPointDomain,
@@ -83,12 +84,13 @@ def __init__(
 
         # initialize weight fake quantizer
         if weight_config is not None:
-            group_size = weight_config.group_size
-            if group_size is not None and in_features % group_size != 0:
-                raise ValueError(
-                    "in_features (%s) %% group_size (%s) must be == 0"
-                    % (in_features, group_size)
-                )
+            if isinstance(weight_config.granularity, PerGroup):
+                group_size = weight_config.group_size
+                if group_size is not None and in_features % group_size != 0:
+                    raise ValueError(
+                        "in_features (%s) %% group_size (%s) must be == 0"
+                        % (in_features, group_size)
+                    )
             self.weight_fake_quantizer = FakeQuantizer(weight_config)
         else:
             self.weight_fake_quantizer = None
@@ -108,6 +110,7 @@ def to_linear(self) -> torch.nn.Linear:
             self.out_features,
             self.bias is not None,
             device=self.weight.device,
+            dtype=self.weight.dtype,
         )
         # In distributed training, the model may be instantiated
         # on the meta device, in which case there is no need to
@@ -131,6 +134,7 @@ def from_linear(
             activation_config=activation_config,
             weight_config=weight_config,
             device=mod.weight.device,
+            dtype=mod.weight.dtype,
         )
         # In distributed training, the model may be instantiated
         # on the meta device, in which case there is no need to
diff --git a/torchao/quantization/qat/utils.py b/torchao/quantization/qat/utils.py
index 12e9097ada..71e9a96ec5 100644
--- a/torchao/quantization/qat/utils.py
+++ b/torchao/quantization/qat/utils.py
@@ -91,6 +91,20 @@ def backward(ctx, gy):
         return (gy,)
 
 
+class _Round(torch.autograd.Function):
+    """
+    Generic round operation with a straight-through estimator (STE) backward pass.
+    """
+
+    @staticmethod
+    def forward(ctx, x: torch.Tensor) -> torch.Tensor:
+        return torch.round(x)
+
+    @staticmethod
+    def backward(ctx, gy: torch.Tensor) -> torch.Tensor:
+        return gy
+
+
 def _fake_quantize_per_channel_group(
     input: torch.Tensor,
     scales: torch.Tensor,
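
Usage note (editor's addition, not part of the patch): the sketch below strings together the pieces added in this diff — `FakeQuantizeConfig(..., range_learning=True)`, `IntXQuantizationAwareTrainingConfig`, and `initialize_fake_quantizers` — into the end-to-end flow the new `test_qat_range_learning` test exercises. `MyModel` is a hypothetical stand-in for any `torch.nn.Module`; all torchao names are taken from the diff above.

```python
import torch
from torchao.quantization import quantize_
from torchao.quantization.qat import (
    FakeQuantizeConfig,
    IntXQuantizationAwareTrainingConfig,
    initialize_fake_quantizers,
)

model = MyModel()  # hypothetical user-defined torch.nn.Module
example_inputs = (torch.randn(1, 512),)

# Static scales/zero points that are learned during training.
# `is_dynamic=False` is required: dynamic fake quantization recomputes
# qparams on every forward pass instead of learning them, so the two
# options are mutually exclusive (see the new ValueError in api.py).
config = FakeQuantizeConfig(
    torch.int8,
    "per_channel",
    is_dynamic=False,
    range_learning=True,
    scale_precision=torch.float32,
    zero_point_precision=torch.float32,
)
quantize_(model, IntXQuantizationAwareTrainingConfig(weight_config=config))

# Must run before the optimizer is constructed: this flips `_initialized`
# on every FakeQuantizer and runs one forward pass, during which the
# scale/zero_point tensors are promoted to trainable nn.Parameters.
initialize_fake_quantizers(model, example_inputs)
optimizer = torch.optim.SGD(model.parameters(), lr=1e-3)
```

The ordering matters because `model.parameters()` only picks up the learnable scales and zero points after `initialize_fake_quantizers` has promoted them to `torch.nn.Parameter`, which is exactly what `test_qat_range_learning` asserts.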
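
A second sketch, on the `_Round` helper added to utils.py: `torch.round` has a zero derivative almost everywhere, which would stop gradients from ever reaching the learned zero points, so the backward pass substitutes the identity gradient (a straight-through estimator). The behavior can be checked directly:

```python
import torch
from torchao.quantization.qat.utils import _Round

x = torch.tensor([0.4, 1.6, -2.7], requires_grad=True)
y = _Round.apply(x)  # forward: plain torch.round -> [0., 2., -3.]
y.sum().backward()   # backward: upstream gradient passes through unchanged
print(x.grad)        # tensor([1., 1., 1.]), not zeros
```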