From ac57a0c9f5c7c66bed5bdc98da9b6ba5b1a47cdb Mon Sep 17 00:00:00 2001
From: Jerry Zhang
Date: Fri, 17 May 2024 14:39:54 -0700
Subject: [PATCH] Add `WeightQuantizer` and `DynamicActQuantizer`

Summary:
This exposes the AffineQuantizedTensor and LinearActQuantizedTensor subclasses
through a model-level API that replaces the weights of linear layers with
(quantized) tensor subclass instances.

This is in preparation for replacing existing tensor subclass APIs such as
`change_linear_weights_to_int4_woqtensors`. Currently the two quantizers can't
be combined due to a problem with parametrization/nn.Parameter; the error is:

raise KeyError(f"attribute '{name}' already exists")
KeyError: "attribute 'weight' already exists"

and it happens in
```
lin.weight = torch.nn.Parameter(constructor(lin.weight, **copied_kwargs), requires_grad=False)
```

Test Plan:
regression tests:
```
python test/quantization/test_quant_api.py
```

Reviewers:

Subscribers:

Tasks:

Tags:
---
 test/quantization/test_quant_api.py | 115 +++++++++++++++------------
 torchao/quantization/quant_api.py   |  50 +++++++++++-
 torchao/quantization/subclass.py    |  35 ++-------
 torchao/quantization/utils.py       |  49 ++++++++++++
 4 files changed, 166 insertions(+), 83 deletions(-)

diff --git a/test/quantization/test_quant_api.py b/test/quantization/test_quant_api.py
index 8cceefb0a8..f0830cf8a8 100644
--- a/test/quantization/test_quant_api.py
+++ b/test/quantization/test_quant_api.py
@@ -18,12 +18,24 @@
     get_symmetric_quantization_config,
 )

+from torchao.quantization.subclass import (
+    to_aqt,
+    to_laqt,
+    AffineQuantizedTensor,
+    LinearActQuantizedTensor,
+)
+from torchao.quantization.quant_primitives import (
+    MappingType,
+    ZeroPointDomain,
+)
+
 from torchao.quantization.quant_api import (
     _replace_with_custom_fn_if_matches_filter,
     apply_dynamic_quant,
     apply_weight_only_int8_quant,
     Quantizer,
     TwoStepQuantizer,
+    quantize,
 )
 from torchao.quantization.utils import (
     TORCH_VERSION_AFTER_2_3,
@@ -32,6 +44,7 @@
 from pathlib import Path
 from sentencepiece import SentencePieceProcessor
 from model import Transformer, prepare_inputs_for_model
+import copy


 def dynamic_quant(model, example_inputs):
@@ -92,8 +105,8 @@ def __init__(self, m=64, n=32, k=64):
         self.linear1 = torch.nn.Linear(m, n, bias=False).to(torch.float)
         self.linear2 = torch.nn.Linear(n, k, bias=False).to(torch.float)

-    def example_inputs(self):
-        return (torch.randn(1, self.linear1.in_features).to(torch.float),)
+    def example_inputs(self, batch_size=1):
+        return (torch.randn(batch_size, self.linear1.in_features).to(torch.float),)

     def forward(self, x):
         x = self.linear1(x)
@@ -395,13 +408,6 @@ def test_eval_wrapper(self):
     # TODO: move to a separate test file
     @unittest.skipIf(not TORCH_VERSION_AFTER_2_4, "Test only enabled for 2.4+")
     def test_quantized_tensor_subclass_8da4w(self):
-        from torchao.quantization.subclass import (
-            AffineQuantizedTensor,
-            LinearActQuantizedTensor,
-        )
-        from torchao.quantization.quant_primitives import MappingType
-        import copy
-
         # weight settings
         groupsize = 32
         mapping_type = MappingType.SYMMETRIC
@@ -423,20 +429,26 @@ def get_per_token_block_size(x):
         # input settings
         input_mapping_type = MappingType.ASYMMETRIC
         input_target_dtype = torch.int8
-        input_quant_func = lambda x: AffineQuantizedTensor.from_float(x, input_mapping_type, get_per_token_block_size(x), input_target_dtype)
-
-        def dynamic_quant(linear):
-            # note: order is important
-            linear.weight = torch.nn.Parameter(AffineQuantizedTensor.from_float(linear.weight, mapping_type, block_size, target_dtype, quant_min, quant_max, eps), requires_grad=False)
-            
linear.weight = torch.nn.Parameter(LinearActQuantizedTensor.from_float(linear.weight, input_quant_func), requires_grad=False) + input_quant_func = lambda x: to_aqt(x, input_mapping_type, get_per_token_block_size(x), input_target_dtype) m = ToyLinearModel().eval() m_copy = copy.deepcopy(m) example_inputs = m.example_inputs() - dynamic_quant(m.linear1) - dynamic_quant(m.linear2) + + def apply_weight_quant(weight): + return to_aqt(weight, mapping_type, block_size, target_dtype, quant_min, quant_max, eps) + + def apply_act_quant(weight): + return to_laqt(weight, input_quant_func) + + # note: order is important + m = quantize(m, apply_weight_quant) + m = quantize(m, apply_act_quant) + assert isinstance(m.linear1.weight, LinearActQuantizedTensor) assert isinstance(m.linear2.weight, LinearActQuantizedTensor) + assert isinstance(m.linear1.weight.original_weight_tensor, AffineQuantizedTensor) + assert isinstance(m.linear2.weight.original_weight_tensor, AffineQuantizedTensor) # reference from torchao.quantization.quant_api import Int8DynActInt4WeightQuantizer @@ -454,11 +466,6 @@ def dynamic_quant(linear): @unittest.skipIf(not TORCH_VERSION_AFTER_2_4, "Test only enabled for 2.4+") @unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available") def test_quantized_tensor_subclass_int4(self): - from torchao.quantization.subclass import AffineQuantizedTensor - from torchao.quantization.quant_primitives import MappingType - from torchao.quantization.quant_primitives import ZeroPointDomain - import copy - # weight settings groupsize = 32 mapping_type = MappingType.ASYMMETRIC @@ -469,22 +476,17 @@ def test_quantized_tensor_subclass_int4(self): eps = 1e-6 preserve_zero = False zero_point_dtype = torch.bfloat16 + zero_point_domain = ZeroPointDomain.FLOAT # use 1024 so that we don't need padding m = ToyLinearModel(1024, 1024, 1024).eval().to(torch.bfloat16).to("cuda") m_copy = copy.deepcopy(m) example_inputs = tuple(map(lambda x: x.to(torch.bfloat16).to("cuda"), m.example_inputs())) - def to_quantized(weight): - return AffineQuantizedTensor.from_float( - weight, mapping_type, block_size, target_dtype, quant_min, quant_max, eps, - zero_point_dtype=zero_point_dtype, - preserve_zero=preserve_zero, - zero_point_domain=ZeroPointDomain.FLOAT, - ) + def apply_weight_quant(weight): + return to_aqt(weight, mapping_type, block_size, target_dtype, quant_min, quant_max, eps, zero_point_dtype=zero_point_dtype, preserve_zero=preserve_zero, zero_point_domain=zero_point_domain) - m.linear1.weight = torch.nn.Parameter(to_quantized(m.linear1.weight), requires_grad=False) - m.linear2.weight = torch.nn.Parameter(to_quantized(m.linear2.weight), requires_grad=False) + m = quantize(m, apply_weight_quant) assert isinstance(m.linear1.weight, AffineQuantizedTensor) assert isinstance(m.linear2.weight, AffineQuantizedTensor) @@ -501,10 +503,6 @@ def to_quantized(weight): @unittest.skipIf(not TORCH_VERSION_AFTER_2_4, "Test only enabled for 2.4+") @unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available") def test_quantized_tensor_subclass_int8(self): - from torchao.quantization.subclass import AffineQuantizedTensor - from torchao.quantization.quant_primitives import MappingType - import copy - # weight settings mapping_type = MappingType.SYMMETRIC target_dtype = torch.int8 @@ -515,12 +513,12 @@ def test_quantized_tensor_subclass_int8(self): m_copy = copy.deepcopy(m) example_inputs = tuple(map(lambda x: x.to(torch.bfloat16), m.example_inputs())) - def to_quantized(weight): + def apply_weight_quant(weight): block_size = 
(1, weight.shape[1]) - return AffineQuantizedTensor.from_float(weight, mapping_type, block_size, target_dtype, eps=eps, zero_point_dtype=zero_point_dtype) + return to_aqt(weight, mapping_type, block_size, target_dtype, eps=eps, zero_point_dtype=zero_point_dtype) + + m = quantize(m, apply_weight_quant) - m.linear1.weight = torch.nn.Parameter(to_quantized(m.linear1.weight), requires_grad=False) - m.linear2.weight = torch.nn.Parameter(to_quantized(m.linear2.weight), requires_grad=False) assert isinstance(m.linear1.weight, AffineQuantizedTensor) assert isinstance(m.linear2.weight, AffineQuantizedTensor) @@ -537,12 +535,6 @@ def to_quantized(weight): @unittest.skipIf(not TORCH_VERSION_AFTER_2_4, "Test only enabled for 2.4+") @unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available") def test_quantized_tensor_subclass_int8_dyn_quant(self): - from torchao.quantization.subclass import AffineQuantizedTensor - from torchao.quantization.subclass import LinearActQuantizedTensor - from torchao.quantization.quant_primitives import MappingType - from torchao.quantization.quant_primitives import ZeroPointDomain - import copy - # weight settings mapping_type = MappingType.SYMMETRIC def get_weight_block_size(x): @@ -563,20 +555,24 @@ def get_per_token_block_size(x): input_eps = 1e-5 input_quant_min = -127 input_quant_max = 127 - input_quant_func = lambda x: AffineQuantizedTensor.from_float(x, input_mapping_type, get_per_token_block_size(x), input_target_dtype, eps=input_eps, quant_min=input_quant_min, quant_max=input_quant_max, scale_dtype=torch.float32 if x.dtype == torch.float16 else None) + input_quant_func = lambda x: to_aqt(x, input_mapping_type, get_per_token_block_size(x), input_target_dtype, eps=input_eps, quant_min=input_quant_min, quant_max=input_quant_max, scale_dtype=torch.float32 if x.dtype == torch.float16 else None) # use 1024 so that we don't need padding m = ToyLinearModel(1024, 1024, 1024).eval().to(torch.bfloat16).to("cuda") m_copy = copy.deepcopy(m) - example_inputs = tuple(map(lambda x: x.to(torch.bfloat16).to("cuda"), m.example_inputs())) + # setting batch_size to 20 to be compatible with the kernel + example_inputs = tuple(map(lambda x: x.to(torch.bfloat16).to("cuda"), m.example_inputs(batch_size=20))) + + def apply_weight_quant(weight): + block_size = get_weight_block_size(weight) + return to_aqt(weight, mapping_type, block_size, target_dtype, eps=eps, zero_point_dtype=zero_point_dtype) - def dynamic_quant(linear): - # note: order is important - linear.weight = torch.nn.Parameter(AffineQuantizedTensor.from_float(linear.weight, mapping_type, get_weight_block_size(linear.weight), target_dtype, eps=eps, zero_point_dtype=zero_point_dtype), requires_grad=False) - linear.weight = torch.nn.Parameter(LinearActQuantizedTensor.from_float(linear.weight, input_quant_func), requires_grad=False) + def apply_act_quant(weight): + return to_laqt(weight, input_quant_func) + + m = quantize(m, apply_weight_quant) + m = quantize(m, apply_act_quant) - dynamic_quant(m.linear1) - dynamic_quant(m.linear2) assert isinstance(m.linear1.weight, LinearActQuantizedTensor) assert isinstance(m.linear2.weight, LinearActQuantizedTensor) assert isinstance(m.linear1.weight.original_weight_tensor, AffineQuantizedTensor) @@ -591,6 +587,19 @@ def dynamic_quant(linear): self.assertTrue(torch.equal(res, ref)) + # workaround for export path + from torchao.quantization.utils import unwrap_tensor_subclass + m_unwrapped = unwrap_tensor_subclass(m) + + m = torch.export.export(m_unwrapped, example_inputs).module() + 
exported_model_res = m(*example_inputs)
+
+        self.assertTrue(torch.equal(exported_model_res, ref))
+
+        # make sure it compiles
+        torch._export.aot_compile(m_unwrapped, example_inputs)
+
+
 if __name__ == "__main__":
     unittest.main()
diff --git a/torchao/quantization/quant_api.py b/torchao/quantization/quant_api.py
index a5a3a2b3db..d73742140e 100644
--- a/torchao/quantization/quant_api.py
+++ b/torchao/quantization/quant_api.py
@@ -18,6 +18,7 @@
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
+from typing import Any, Callable

 from .dynamic_quant import DynamicallyPerAxisQuantizedLinear
 from .utils import TORCH_VERSION_AFTER_2_3, TORCH_VERSION_AFTER_2_4
@@ -48,7 +49,8 @@
     "TwoStepQuantizer",
     "Int4WeightOnlyGPTQQuantizer",
     "Int4WeightOnlyQuantizer",
-    "autoquant"
+    "quantize",
+    "autoquant",
 ]

 if TORCH_VERSION_AFTER_2_3:
@@ -214,3 +216,49 @@ def replace_conv2d_1x1(conv):
     _replace_with_custom_fn_if_matches_filter(
         model, replace_conv2d_1x1, filter_fn=filter_fn
     )
+
+
+def _get_linear_subclass_inserter(constructor):
+    def insert_subclass(lin):
+        lin.weight = torch.nn.Parameter(constructor(lin.weight), requires_grad=False)
+        return lin
+
+    return insert_subclass
+
+def quantize(model: torch.nn.Module, apply_tensor_subclass: Callable[[torch.Tensor], torch.Tensor], filter_fn=None) -> torch.nn.Module:
+    """Convert the weights of linear modules in the model with `apply_tensor_subclass`
+
+    Args:
+        model: input model
+        apply_tensor_subclass (Callable[[torch.Tensor], torch.Tensor]): function that converts a floating point Tensor to a (quantized) tensor subclass instance
+        filter_fn: used to filter out the modules that we don't want to apply the tensor subclass to
+
+    Example::
+
+        # weight settings
+        groupsize = 32
+        mapping_type = MappingType.ASYMMETRIC
+        block_size = (1, groupsize)
+        target_dtype = torch.int32
+        quant_min = 0
+        quant_max = 15
+        eps = 1e-6
+        preserve_zero = False
+        zero_point_dtype = torch.bfloat16
+        zero_point_domain = ZeroPointDomain.FLOAT
+
+        apply_weight_quant = lambda x: to_aqt(x, mapping_type, block_size, target_dtype, quant_min, quant_max, eps, zero_point_dtype=zero_point_dtype, preserve_zero=preserve_zero, zero_point_domain=zero_point_domain)
+
+        # apply to modules under block0 submodule
+        def filter_fn(module, fqn):
+            return fqn == "block0"
+
+        m = MyModel(...)
+ m = quantize(m, apply_weight_quant, filter_fn) + """ + _replace_with_custom_fn_if_matches_filter( + model, + _get_linear_subclass_inserter(apply_tensor_subclass), + _is_linear if filter_fn is None else filter_fn, + ) + return model diff --git a/torchao/quantization/subclass.py b/torchao/quantization/subclass.py index 8d0af8b369..6e844530d4 100644 --- a/torchao/quantization/subclass.py +++ b/torchao/quantization/subclass.py @@ -35,6 +35,7 @@ "Int8WeightOnlyQuantizedLinearWeight", "Int4WeightOnlyQuantizedLinearWeight", "AffineQuantizedTensor", + "LinearActQuantizedTensor", ] @@ -266,7 +267,6 @@ def __new__(cls, int_data, q_scales, transposed, shape, dtype=None, **kwargs): return super().__new__(cls, int_data, transposed, shape, **kwargs) # type: ignore[attr-defined] def __init__(self, int_data, q_scales, transposed, shape, dtype=None, **kwargs): - self.q_scales = q_scales super().__init__(int_data, transposed) @@ -629,32 +629,6 @@ def to_qtensor_components(cls, input_float, groupsize=128, inner_k_tiles=8): int_data = aten._convert_weight_to_int4pack(input_int4x8, inner_k_tiles) return int_data, scales_and_zeros, False, groupsize, inner_k_tiles -def to_aqt( - input_float, - mapping_type, - block_size, - target_dtype, - quant_min = None, - quant_max = None, - eps = None, - scale_dtype = None, - zero_point_dtype = None, - preserve_zero = True, - zero_point_domain = ZeroPointDomain.INT, -): - return AffineQuantizedTensor.from_float( - input_float, - mapping_type, - block_size, - target_dtype, - quant_min=quant_min, - quant_max=quant_max, - eps=eps, - scale_dtype=scale_dtype, - zero_point_dtype=zero_point_dtype, - preserve_zero=preserve_zero, - zero_point_domain=zero_point_domain - ) # TODO: merge with nf4 implements decorator # aten op to their __torch_dispatch__ implemnetations for the tensor subclass @@ -777,7 +751,7 @@ def dequantize(self, output_dtype=None): return dequantize_affine(self.int_data, self.block_size, self.scale, self.zero_point, self.int_data.dtype, self.quant_min, self.quant_max, self.zero_point_domain, output_dtype=output_dtype) def __tensor_flatten__(self): - return ["int_data", "scales", "zero_point"], [self.block_size, self.shape, self.quant_min, self.quant_max, self.zero_point_domain, self.dtype] + return ["int_data", "scale", "zero_point"], [self.block_size, self.shape, self.quant_min, self.quant_max, self.zero_point_domain, self.dtype] @classmethod def __tensor_unflatten__( @@ -1091,7 +1065,7 @@ def __tensor_unflatten__( cls, tensor_data_dict, tensor_attributes, outer_size, outer_stride ): original_weight_tensor = tensor_data_dict["original_weight_tensor"] - input_quant_func = tensor_attributes + input_quant_func, = tensor_attributes return cls( original_weight_tensor, input_quant_func, @@ -1176,3 +1150,6 @@ def __torch_dispatch__(cls, func, types, args, kwargs): raise NotImplementedError( f"LinearActQuantizedTensor dispatch: attempting to run {func}, this is not supported" ) + +to_aqt = AffineQuantizedTensor.from_float +to_laqt = LinearActQuantizedTensor.from_float diff --git a/torchao/quantization/utils.py b/torchao/quantization/utils.py index a178edf125..11756ad616 100644 --- a/torchao/quantization/utils.py +++ b/torchao/quantization/utils.py @@ -10,6 +10,7 @@ from packaging import version from functools import reduce from math import gcd +import torch.nn.utils.parametrize as parametrize __all__ = [ @@ -17,6 +18,7 @@ "compute_error", "_apply_logging_hook", "get_model_size_in_bytes", + "unwrap_tensor_subclass", "TORCH_VERSION_AFTER_2_3", ] @@ -88,6 +90,53 @@ def 
__torch_dispatch__(self, func, types, args=(), kwargs=None): return rs +class UnwrapTensorSubclass(torch.nn.Module): + def forward(self, *tensors): + todo = list(tensors) + for tp, meta, inner_tensors in reversed(self.rebuild_stack): + nb_tensor = len(inner_tensors) + inner_tensors = {a: b for a, b in zip(inner_tensors, todo[-nb_tensor:])} + todo = todo[nb_tensor:] + rebuilt = tp.__tensor_unflatten__(inner_tensors, meta, None, None) + todo.append(rebuilt) + + assert len(todo) == 1 + return todo[0] + + def right_inverse(self, tensor): + assert type(tensor) is not torch.Tensor + rebuild_stack = [] + plain_tensors = [] + todo = [tensor] + while todo: + obj = todo.pop() + inner_tensors, metadata = obj.__tensor_flatten__() + rebuild_stack.append((type(obj), metadata, inner_tensors)) + for attr_name in inner_tensors: + val = getattr(obj, attr_name) + if type(val) is torch.Tensor: + plain_tensors.append(val) + else: + assert isinstance(val, torch.Tensor) + todo.append(val) + + self.rebuild_stack = rebuild_stack + + return plain_tensors + +def unwrap_tensor_subclass(model, filter_fn=None): + for name, child in model.named_children(): + if ( + isinstance(child, torch.nn.Linear) and + hasattr(child, "weight") and + type(child.weight) is not torch.Tensor and + isinstance(child.weight, torch.Tensor) + ): + parametrize.register_parametrization(child, "weight", UnwrapTensorSubclass()) + unwrap_tensor_subclass(child) + return model + + # https://discuss.pytorch.org/t/finding-model-size/130275
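
For reference, a minimal usage sketch of the `quantize` API introduced by this patch, modeled on `test_quantized_tensor_subclass_int8` above. The model class (`MyToyModel`) and the `eps`/`zero_point_dtype` values are illustrative assumptions (those two settings are context lines not shown in the hunk); the import paths are the ones this patch exposes and may change in later torchao versions.

```
import copy
import torch

# import paths as exposed by this patch; they may move in later torchao versions
from torchao.quantization.quant_api import quantize
from torchao.quantization.subclass import to_aqt, AffineQuantizedTensor
from torchao.quantization.quant_primitives import MappingType


class MyToyModel(torch.nn.Module):
    """Illustrative two-linear-layer model, analogous to ToyLinearModel in the tests."""

    def __init__(self, m=64, n=32, k=64):
        super().__init__()
        self.linear1 = torch.nn.Linear(m, n, bias=False)
        self.linear2 = torch.nn.Linear(n, k, bias=False)

    def forward(self, x):
        return self.linear2(self.linear1(x))


def apply_weight_quant(weight):
    # int8 symmetric per-row weight quantization, mirroring
    # test_quantized_tensor_subclass_int8; eps / zero_point_dtype are assumed values
    mapping_type = MappingType.SYMMETRIC
    target_dtype = torch.int8
    eps = torch.finfo(torch.float32).eps
    zero_point_dtype = torch.int64
    block_size = (1, weight.shape[1])
    return to_aqt(weight, mapping_type, block_size, target_dtype,
                  eps=eps, zero_point_dtype=zero_point_dtype)


m = MyToyModel().eval().to(torch.bfloat16)
m_ref = copy.deepcopy(m)
example_inputs = (torch.randn(1, 64, dtype=torch.bfloat16),)

# quantize() swaps each linear weight for an AffineQuantizedTensor wrapped in nn.Parameter
m = quantize(m, apply_weight_quant)
assert isinstance(m.linear1.weight, AffineQuantizedTensor)
assert isinstance(m.linear2.weight, AffineQuantizedTensor)

with torch.no_grad():
    quantized_out = m(*example_inputs)
    reference_out = m_ref(*example_inputs)  # compare against the unquantized copy
```

For dynamic activation quantization, `quantize` is applied a second time with a `to_laqt`-based function so that the already-quantized weight is wrapped in a `LinearActQuantizedTensor` (weight quantization first, activation wrapping second, matching the "order is important" note in the removed helper). Exporting such a model requires `unwrap_tensor_subclass` from `torchao.quantization.utils`, as exercised in `test_quantized_tensor_subclass_int8_dyn_quant`.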