diff --git a/torchao/quantization/qat/linear.py b/torchao/quantization/qat/linear.py
index ef1714808c..329d3fb1f2 100644
--- a/torchao/quantization/qat/linear.py
+++ b/torchao/quantization/qat/linear.py
@@ -107,12 +107,23 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
         return F.linear(x, w)
 
 
+class _LegacyQATQuantizer(TwoStepQuantizer):
+    """
+    Base class for sharing common methods across legacy QAT quantizers.
+    """
+    def get_activation_fake_quantize_config(self) -> Optional[FakeQuantizeConfig]:
+        return None
+
+    def get_weight_fake_quantize_config(self) -> Optional[FakeQuantizeConfig]:
+        return None
+
+
 # =========================================================
 # |   Linear int8 dynamic activations + int4 weight QAT  |
 # =========================================================
 
 
-class Int8DynActInt4WeightQATQuantizer(TwoStepQuantizer):
+class Int8DynActInt4WeightQATQuantizer(_LegacyQATQuantizer):
     """
     Quantizer for performing QAT on a model, where linear layers have int8
     dynamic per token fake quantized activations and int4 fake quantized
@@ -189,6 +200,12 @@ def _convert_qat_linear_8da4w(self, module: torch.nn.Module):
             else:
                 self._convert_qat_linear_8da4w(child)
 
+    def get_activation_fake_quantize_config(self) -> Optional[FakeQuantizeConfig]:
+        return _get_8da4w_activation_config(self.scales_precision)
+
+    def get_weight_fake_quantize_config(self) -> Optional[FakeQuantizeConfig]:
+        return _get_8da4w_weight_config(self.groupsize, self.scales_precision)
+
 
 class Int8DynActInt4WeightQATLinear(FakeQuantizedLinear):
     """
@@ -211,22 +228,8 @@ def __init__(
         precision: torch.dtype = torch.float32,
         scales_precision: torch.dtype = torch.float32,
     ) -> None:
-        activation_config = FakeQuantizeConfig(
-            dtype=torch.int8,
-            granularity="per_token",
-            is_symmetric=False,
-            is_dynamic=True,
-            scale_precision=scales_precision,
-            zero_point_precision=scales_precision,
-        )
-        weight_config = FakeQuantizeConfig(
-            dtype=TorchAODType.INT4,
-            group_size=groupsize,
-            is_symmetric=True,
-            is_dynamic=True,
-            scale_precision=scales_precision,
-            zero_point_precision=scales_precision,
-        )
+        activation_config = _get_8da4w_activation_config(scales_precision)
+        weight_config = _get_8da4w_weight_config(groupsize, scales_precision)
         super().__init__(
             in_features,
             out_features,
@@ -261,12 +264,43 @@ def disable_8da4w_fake_quant(mod: torch.nn.Module):
         mod.disable_fake_quant()
 
 
+def _get_8da4w_activation_config(qparams_precision: torch.dtype) -> FakeQuantizeConfig:
+    """
+    Return the activation `FakeQuantizeConfig` for `Int8DynActInt4WeightQATQuantizer`.
+    """
+    return FakeQuantizeConfig(
+        dtype=torch.int8,
+        granularity="per_token",
+        is_symmetric=False,
+        is_dynamic=True,
+        scale_precision=qparams_precision,
+        zero_point_precision=qparams_precision,
+    )
+
+
+def _get_8da4w_weight_config(
+    group_size: int,
+    qparams_precision: torch.dtype,
+) -> FakeQuantizeConfig:
+    """
+    Return the weight `FakeQuantizeConfig` for `Int8DynActInt4WeightQATQuantizer`.
+    """
+    return FakeQuantizeConfig(
+        dtype=TorchAODType.INT4,
+        group_size=group_size,
+        is_symmetric=True,
+        is_dynamic=True,
+        scale_precision=qparams_precision,
+        zero_point_precision=qparams_precision,
+    )
+
+
 # ===================================
 # |   Linear int4 weight-only QAT   |
 # ===================================
 
 
-class Int4WeightOnlyQATQuantizer(TwoStepQuantizer):
+class Int4WeightOnlyQATQuantizer(_LegacyQATQuantizer):
     """
     Quantizer for performing QAT on a model, where linear layers have
     int4 fake quantized grouped per channel weights.
@@ -348,6 +382,9 @@ def _convert_qat_linear_4w(self, module: torch.nn.Module):
             else:
                 self._convert_qat_linear_4w(child)
 
+    def get_weight_fake_quantize_config(self) -> Optional[FakeQuantizeConfig]:
+        return _get_4w_weight_config(self.groupsize, self.scales_precision)
+
 
 class Int4WeightOnlyQATLinear(FakeQuantizedLinear):
     """
@@ -376,15 +413,7 @@ def __init__(
         if not _check_linear_int4_k(in_features, groupsize, inner_k_tiles):
             raise ValueError("Padding for QAT 4w is not supported yet")
         self.inner_k_tiles = inner_k_tiles
-        weight_config = FakeQuantizeConfig(
-            dtype=torch.uint4,
-            group_size=groupsize,
-            is_symmetric=False,
-            is_dynamic=True,
-            scale_precision=scales_precision,
-            zero_point_precision=scales_precision,
-            zero_point_domain=ZeroPointDomain.FLOAT,
-        )
+        weight_config = _get_4w_weight_config(groupsize, scales_precision)
         super().__init__(
             in_features,
             out_features,
@@ -417,3 +446,21 @@ def disable_4w_fake_quant(mod: torch.nn.Module):
     """
     if isinstance(mod, Int4WeightOnlyQATLinear):
         mod.disable_fake_quant()
+
+
+def _get_4w_weight_config(
+    group_size: int,
+    qparams_precision: torch.dtype,
+) -> FakeQuantizeConfig:
+    """
+    Return the weight `FakeQuantizeConfig` for `Int4WeightOnlyQATQuantizer`.
+    """
+    return FakeQuantizeConfig(
+        dtype=torch.uint4,
+        group_size=group_size,
+        is_symmetric=False,
+        is_dynamic=True,
+        scale_precision=qparams_precision,
+        zero_point_precision=qparams_precision,
+        zero_point_domain=ZeroPointDomain.FLOAT,
+    )
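Usage sketch (not part of the patch): the hunks above give the legacy quantizers a shared _LegacyQATQuantizer base whose get_activation_fake_quantize_config() / get_weight_fake_quantize_config() accessors expose the FakeQuantizeConfig each quantizer uses. A minimal way to exercise them might look like the following; the import path and constructor arguments are assumptions based on existing torchao usage, not part of this diff.

    # Illustrative only: exercises the accessors introduced by this patch.
    from torchao.quantization.qat import (
        Int4WeightOnlyQATQuantizer,
        Int8DynActInt4WeightQATQuantizer,
    )

    # Int8DynActInt4WeightQATQuantizer overrides both accessors.
    qat_8da4w = Int8DynActInt4WeightQATQuantizer(groupsize=32)
    print(qat_8da4w.get_activation_fake_quantize_config())  # int8, per_token, asymmetric, dynamic
    print(qat_8da4w.get_weight_fake_quantize_config())      # INT4, group_size=32, symmetric

    # Int4WeightOnlyQATQuantizer only overrides the weight accessor; the
    # activation accessor falls back to the base class and returns None.
    qat_4w = Int4WeightOnlyQATQuantizer(groupsize=32)
    assert qat_4w.get_activation_fake_quantize_config() is None
    print(qat_4w.get_weight_fake_quantize_config())         # uint4, zero_point_domain=FLOAT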