diff --git a/torchao/prototype/awq/api.py b/torchao/prototype/awq/api.py
index 70fe96429b..8ebe83cedc 100644
--- a/torchao/prototype/awq/api.py
+++ b/torchao/prototype/awq/api.py
@@ -1,5 +1,9 @@
+import types
+from dataclasses import dataclass
+
 import torch
 
+from torchao.core.config import AOBaseConfig
 from torchao.dtypes import (
     TensorCoreTiledLayout,
     to_affine_quantized_intx,
@@ -7,12 +11,18 @@ from torchao.dtypes.uintx.uintx_layout import _DTYPE_TO_BIT_WIDTH, UintxLayout
 from torchao.quantization import to_weight_tensor_with_linear_activation_scale_metadata
 from torchao.quantization.granularity import PerGroup
-from torchao.quantization.quant_api import _replace_with_custom_fn_if_matches_filter
+from torchao.quantization.quant_api import (
+    _linear_extra_repr,
+    _replace_with_custom_fn_if_matches_filter,
+)
 from torchao.quantization.quant_primitives import (
     _DTYPE_TO_QVALUE_BOUNDS,
     MappingType,
     ZeroPointDomain,
 )
+from torchao.quantization.transform_module import (
+    register_quantize_module_handler,
+)
 
 from .core import (
     AWQObservedLinear,
@@ -82,88 +92,86 @@ def replace_with_observer(layer):
     _replace_with_custom_fn_if_matches_filter(model, replace_with_observer, _is_linear)
 
 
-def _observed_linear_subclass_inserter(constructor):
+@dataclass
+class AWQUIntXConfig(AOBaseConfig):
     """
-    Replaces unquantized AWQObservedLinear instances with quantized linear instances.
+    Configuration for quantizing linear layers when passed into quantize_()
 
     Args:
-        constructor: the function which applies quantization to the AWQObservedLinear layer
+        quant_dtype: The data type of the quantized weights. Currently only torch.uint4 is intended to be used but can be used with torch.uint1 -> torch.uint8
+        group_size: Quantization granularity. Use -1 for channel wise quantization
+        use_hqq: Whether to use HQQ for the underlying affine quantization
     """
 
-    def insert_subclass(observed_linear):
-        # creates the new linear layer using constructor
-        linear = torch.nn.Linear(
-            observed_linear.in_features,
-            observed_linear.out_features,
-            observed_linear.bias != None,
-            device=observed_linear.weight.device,
-            dtype=observed_linear.weight.dtype,
-        )
-        linear.weight = torch.nn.Parameter(
-            constructor(observed_linear), requires_grad=False
-        )
-        linear.bias = observed_linear.bias
-        return linear
+    quant_dtype: torch.dtype = torch.uint4
+    group_size: int = 64
+    use_hqq: bool = False
 
-    return insert_subclass
 
+# for bc
+awq_uintx = AWQUIntXConfig
 
-def awq_uintx(
-    quant_dtype: torch.dtype = torch.uint4,
-    group_size: int = 64,
-    use_hqq: bool = False,
-):
-    """
-    Quantizes linear layers when passed into quantize_()
-    Args:
-        quant_dtype: The data type of the quantized weights. Currently only torch.uint4 is intended to be used but can be used with torch.uint1 -> torch.uint8
-        group_size: Quantization granularity. Use -1 for channel wise quantization
-        weight_quant_fn: The quantization function to be used, which takes in the weight and returns the quantized weight. If None, then affine uint4 quantization is used
-    """
+
+@register_quantize_module_handler(AWQUIntXConfig)
+def _awq_uintx_transform(
+    module: torch.nn.Module,
+    config: AWQUIntXConfig,
+) -> torch.nn.Module:
+    quant_dtype = config.quant_dtype
+    group_size = config.group_size
+    use_hqq = config.use_hqq
+    observed_linear = module
+
     assert (
         quant_dtype in _DTYPE_TO_BIT_WIDTH or quant_dtype == torch.uint8
     ), "Invalid quant_dtype. Please use torch.uint1 .. torch.uint8"
 
-    def weight_quant_func(observed_linear):
-        equalization_scale = observed_linear.act_obs.calculate_qparams()
-        # AQT config
-        if quant_dtype == torch.uint4:
-            target_dtype = torch.int32
-            eps = 1e-6
-            preserve_zero = False
-            zero_point_dtype = torch.bfloat16
-            zero_point_domain = ZeroPointDomain.FLOAT
-            _layout = TensorCoreTiledLayout(inner_k_tiles=8)
-        else:
-            target_dtype = torch.uint8
-            eps = torch.finfo(torch.float32).eps
-            preserve_zero = True
-            zero_point_dtype = torch.int64
-            zero_point_domain = ZeroPointDomain.INT
-            _layout = UintxLayout(quant_dtype)
-
-        mapping_type = MappingType.ASYMMETRIC
-        block_size = (1, group_size)
-        quant_min = _DTYPE_TO_QVALUE_BOUNDS[quant_dtype][0]
-        quant_max = _DTYPE_TO_QVALUE_BOUNDS[quant_dtype][1]
-        qw = to_affine_quantized_intx(
-            observed_linear.weight * equalization_scale,
-            mapping_type,
-            block_size,
-            target_dtype,
-            quant_min,
-            quant_max,
-            eps,
-            zero_point_dtype=zero_point_dtype,
-            preserve_zero=preserve_zero,
-            zero_point_domain=zero_point_domain,
-            _layout=_layout,
-            use_hqq=use_hqq,
-        )
+    equalization_scale = observed_linear.act_obs.calculate_qparams()
+    # AQT config
+    if quant_dtype == torch.uint4:
+        target_dtype = torch.int32
+        eps = 1e-6
+        preserve_zero = False
+        zero_point_dtype = torch.bfloat16
+        zero_point_domain = ZeroPointDomain.FLOAT
+        _layout = TensorCoreTiledLayout(inner_k_tiles=8)
+    else:
+        target_dtype = torch.uint8
+        eps = torch.finfo(torch.float32).eps
+        preserve_zero = True
+        zero_point_dtype = torch.int64
+        zero_point_domain = ZeroPointDomain.INT
+        _layout = UintxLayout(quant_dtype)
 
-        return to_weight_tensor_with_linear_activation_scale_metadata(
-            qw, equalization_scale
-        )
+    mapping_type = MappingType.ASYMMETRIC
+    block_size = (1, group_size)
+    quant_min = _DTYPE_TO_QVALUE_BOUNDS[quant_dtype][0]
+    quant_max = _DTYPE_TO_QVALUE_BOUNDS[quant_dtype][1]
+    qw = to_affine_quantized_intx(
+        observed_linear.weight * equalization_scale,
+        mapping_type,
+        block_size,
+        target_dtype,
+        quant_min,
+        quant_max,
+        eps,
+        zero_point_dtype=zero_point_dtype,
+        preserve_zero=preserve_zero,
+        zero_point_domain=zero_point_domain,
+        _layout=_layout,
+        use_hqq=use_hqq,
+    )
+
+    qw = to_weight_tensor_with_linear_activation_scale_metadata(qw, equalization_scale)
 
-    return _observed_linear_subclass_inserter(weight_quant_func)
+    linear = torch.nn.Linear(
+        observed_linear.in_features,
+        observed_linear.out_features,
+        observed_linear.bias is not None,
+        device=observed_linear.weight.device,
+        dtype=observed_linear.weight.dtype,
+    )
+    linear.weight = torch.nn.Parameter(qw, requires_grad=False)
+    linear.extra_repr = types.MethodType(_linear_extra_repr, module)
+    linear.bias = observed_linear.bias
+    return linear