diff --git a/examples/sam2_amg_server/server.py b/examples/sam2_amg_server/server.py
index cbf916c2aa..627a7be8f2 100644
--- a/examples/sam2_amg_server/server.py
+++ b/examples/sam2_amg_server/server.py
@@ -338,6 +338,7 @@ def main(checkpoint_path,
          baseline=False,
          fast=False,
          furious=False,
+         use_autoquant=False,
          unittest=False,
          benchmark=False,
          profile=None,
@@ -366,13 +367,13 @@ def main(checkpoint_path,
     from torchao._models.sam2.build_sam import build_sam2
     from torchao._models.sam2.automatic_mask_generator import SAM2AutomaticMaskGenerator
     from torchao._models.sam2.utils.amg import rle_to_mask
-    
+
     device = "cuda"
     sam2_checkpoint, model_cfg = model_type_to_paths(checkpoint_path, model_type)
-    
+
     logging.info(f"Loading model {sam2_checkpoint} with config {model_cfg}")
     sam2 = build_sam2(model_cfg, sam2_checkpoint, device=device, apply_postprocessing=False)
-    
+
     logging.info(f"Using {points_per_batch} points_per_batch")
     mask_generator = SAM2AutomaticMaskGenerator(sam2, points_per_batch=points_per_batch, output_mode="uncompressed_rle")
@@ -409,6 +410,18 @@ def main(checkpoint_path,
         # NOTE: Not baseline feature
         mask_generator.predictor.model.sam_mask_decoder._src_dtype = torch.float16
+    # since autoquant replicates what furious mode does, don't use the two together
+    elif use_autoquant:
+        from torchao import autoquant
+        from torchao.quantization import DEFAULT_FLOAT_AUTOQUANT_CLASS_LIST
+        mask_generator.predictor.model = autoquant(mask_generator.predictor.model, qtensor_class_list=DEFAULT_FLOAT_AUTOQUANT_CLASS_LIST, min_sqnr=40)
+
+        mask_generator.predictor.model.image_encoder = mask_generator.predictor.model.image_encoder.to(torch.float16)
+        # NOTE: Not baseline feature
+        mask_generator.predictor._transforms_device = mask_generator.predictor.device
+        torch.set_float32_matmul_precision('high')
+
+
     with open('dog.jpg', 'rb') as f:
         image_tensor = file_bytes_to_image_tensor(bytearray(f.read()))
@@ -487,7 +500,7 @@ async def upload_rle(image: UploadFile = File(...)):
         await request_queue.put((image_tensor, response_future))
         masks = await response_future
         return masks_to_rle_dict(masks)
-    
+
     @app.post("/upload")
     async def upload_image(image: UploadFile = File(...)):
         image_tensor = file_bytes_to_image_tensor(bytearray(await image.read()))
@@ -505,7 +518,7 @@ async def upload_image(image: UploadFile = File(...)):
             plt.savefig(buf, format='png')
             buf.seek(0)
             return StreamingResponse(buf, media_type="image/png")
-    
+
     # uvicorn.run(app, host=host, port=port, log_level="info")
     uvicorn.run(app, host=host, port=port)
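The `use_autoquant` branch above can be exercised standalone; the sketch below mirrors the call pattern from server.py on a stand-in module (the two-layer model, dtype, and input shape are illustrative, not part of the diff; a CUDA device is assumed, as in the server):

```python
import torch
from torchao import autoquant
from torchao.quantization import DEFAULT_FLOAT_AUTOQUANT_CLASS_LIST

# stand-in for mask_generator.predictor.model; any nn.Module with Linear layers works
model = torch.nn.Sequential(torch.nn.Linear(128, 128)).cuda().to(torch.bfloat16)

# restrict candidates to the float32/bfloat16/float16 fallbacks and reject any
# candidate whose per-layer SQNR against the unquantized baseline is below 40 dB
model = autoquant(
    model,
    qtensor_class_list=DEFAULT_FLOAT_AUTOQUANT_CLASS_LIST,
    min_sqnr=40,
)
model(torch.randn(16, 128, device="cuda", dtype=torch.bfloat16))  # first call finalizes
```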
diff --git a/test/integration/test_integration.py b/test/integration/test_integration.py
index ac2403d6dc..663db20b7b 100644
--- a/test/integration/test_integration.py
+++ b/test/integration/test_integration.py
@@ -1514,6 +1514,23 @@ def forward(self, x):
         assert not isinstance(model.lin1.weight.weight, AutoQuantizableLinearWeight)
         model(x_in)
 
+    @parameterized.expand(list(itertools.product(["cuda"], COMMON_DTYPES)))
+    @unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available")
+    def test_autoquant_min_sqnr(self, device, dtype):
+        m, k, n = 128, 128, 128
+        example_input = torch.randn(m, k, device=device, dtype=dtype)
+        model = torch.nn.Sequential(
+            torch.nn.ReLU(),
+            torch.nn.Linear(k, n),
+            torch.nn.ReLU(),
+        ).to(device).to(dtype)
+        out = model(example_input)
+        torchao.autoquant(model, min_sqnr=60)
+        out2 = model(example_input)
+        sqnr = SQNR(out, out2)
+        # without setting min_sqnr to 60, we get around 45-50 final sqnr;
+        # setting min_sqnr for each individual linear to 60 allows us to achieve >= 50 final sqnr
+        self.assertTrue(sqnr >= 50, f"sqnr: {sqnr}")
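For reference, the `SQNR` helper the test asserts on is the standard signal-to-quantization-noise ratio in decibels; a minimal sketch of that definition (torchao's own `SQNR`/`compute_error` may differ in implementation details):

```python
import torch

def sqnr_db(ref: torch.Tensor, approx: torch.Tensor) -> torch.Tensor:
    # signal power over error power, in dB: 20 * log10(||ref|| / ||ref - approx||)
    return 20 * torch.log10(torch.linalg.norm(ref) / torch.linalg.norm(ref - approx))
```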
diff --git a/torchao/_models/llama/generate.py b/torchao/_models/llama/generate.py
index 862f5d186d..549f10c7dc 100644
--- a/torchao/_models/llama/generate.py
+++ b/torchao/_models/llama/generate.py
@@ -379,6 +379,8 @@ def main(
             model = autoquant(model, manual=True, qtensor_class_list = torchao.quantization.DEFAULT_INT4_AUTOQUANT_CLASS_LIST, example_input=inputs)
         elif "autoquant-float8" == quantization:
             model = autoquant(model, manual=True, qtensor_class_list = torchao.quantization.OTHER_AUTOQUANT_CLASS_LIST, example_input=inputs)
+        elif "autoquant-fp" == quantization:
+            model = autoquant(model, manual=True, qtensor_class_list = torchao.quantization.DEFAULT_FLOAT_AUTOQUANT_CLASS_LIST, example_input=inputs)
         else:
             model = autoquant(model, manual=True, example_input=inputs)
diff --git a/torchao/quantization/__init__.py b/torchao/quantization/__init__.py
index ff66e23cc9..344bdeea41 100644
--- a/torchao/quantization/__init__.py
+++ b/torchao/quantization/__init__.py
@@ -11,6 +11,7 @@
 from .autoquant import (
     DEFAULT_AUTOQUANT_CLASS_LIST,
+    DEFAULT_FLOAT_AUTOQUANT_CLASS_LIST,
     DEFAULT_INT4_AUTOQUANT_CLASS_LIST,
     OTHER_AUTOQUANT_CLASS_LIST,
     autoquant,
@@ -89,6 +90,7 @@
     "autoquant",
     "DEFAULT_AUTOQUANT_CLASS_LIST",
     "DEFAULT_INT4_AUTOQUANT_CLASS_LIST",
+    "DEFAULT_FLOAT_AUTOQUANT_CLASS_LIST",
     "OTHER_AUTOQUANT_CLASS_LIST",
     # top level API - manual
     "quantize_",
diff --git a/torchao/quantization/autoquant.py b/torchao/quantization/autoquant.py
index 87cb5e2655..1731b6cf39 100644
--- a/torchao/quantization/autoquant.py
+++ b/torchao/quantization/autoquant.py
@@ -18,7 +18,10 @@
     MappingType,
     ZeroPointDomain,
 )
-from torchao.quantization.utils import quantize_activation_per_token_absmax
+from torchao.quantization.utils import (
+    compute_error,
+    quantize_activation_per_token_absmax,
+)
 from torchao.utils import TORCH_VERSION_AT_LEAST_2_3, TORCH_VERSION_AT_LEAST_2_5
 
 from .granularity import (
@@ -36,6 +39,7 @@
     "autoquant",
     "DEFAULT_AUTOQUANT_CLASS_LIST",
     "DEFAULT_INT4_AUTOQUANT_CLASS_LIST",
+    "DEFAULT_FLOAT_AUTOQUANT_CLASS_LIST",
     "OTHER_AUTOQUANT_CLASS_LIST",
 ]
@@ -69,7 +73,15 @@ class AutoQuantizableLinearWeight(torch.Tensor):
     """
 
     @staticmethod
-    def __new__(cls, weight, qtensor_class_list, *args, mode=["relu", None], **kwargs):
+    def __new__(
+        cls,
+        weight,
+        qtensor_class_list,
+        *args,
+        mode=["relu", None],
+        min_sqnr=None,
+        **kwargs,
+    ):
         kwargs["device"] = weight.device
         kwargs["layout"] = (
             kwargs.get("layout") if kwargs.get("layout", False) else weight.layout
@@ -82,12 +94,19 @@ def __new__(cls, weight, qtensor_class_list, *args, mode=["relu", None], **kwarg
         return torch.Tensor._make_wrapper_subclass(cls, shape, **kwargs)  # type: ignore[attr-defined]
 
     def __init__(
-        self, weight, qtensor_class_list, *args, mode=["relu", None], **kwargs
+        self,
+        weight,
+        qtensor_class_list,
+        *args,
+        mode=["relu", None],
+        min_sqnr=None,
+        **kwargs,
     ):
         self.weight = weight
         self.qtensor_class_list = qtensor_class_list
         self.logged_data = {}
         self.mode = mode
+        self.min_sqnr = min_sqnr
 
     def __repr__(self):
         return (
@@ -123,9 +142,25 @@ def tune_autoquant(self, q_cls, shapes_and_dtype, best_time):
             else torch.randn(bias_shape, dtype=act_dtype, device=self.device)
         )
         try:
-            res = q_cls._autoquant_test(
-                act_mat, self.weight, bias, best_time, self.mode
+            ref_output = AQDefaultLinearWeight._quantized_linear_op(
+                act_mat, self.weight, bias
             )
+            q_output = q_cls._quantized_linear_op(
+                act_mat, q_cls.from_float(self.weight), bias
+            )
+            if (
+                self.min_sqnr is not None
+                and (sqnr := compute_error(q_output, ref_output))
+                < self.min_sqnr
+            ):
+                print(
+                    f"skipping q_cls: {q_cls} because the sqnr is too small, minimum expected sqnr: {self.min_sqnr}, got {sqnr}"
+                )
+                res = torch.inf
+            else:
+                res = q_cls._autoquant_test(
+                    act_mat, self.weight, bias, best_time, self.mode
+                )
         except Exception as e:
             print(
                 f"warning: failed to autoquant {q_cls.__name__} for shape: {shapes_and_dtype} due to {e}"
             )
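Stripped of the class machinery, the gate added in the hunk above reduces to a threshold check before benchmarking; a self-contained illustration of the control flow (the tensors and threshold are hypothetical, not the class's actual entry points):

```python
import torch
from torchao.quantization.utils import compute_error

min_sqnr = 40  # hypothetical threshold
ref_output = torch.randn(16, 128)
q_output = ref_output + 1e-3 * torch.randn(16, 128)  # simulated quantization error

# a candidate whose SQNR falls below the threshold gets res = torch.inf,
# so the benchmarking loop can never select it as best_cls
sqnr = compute_error(q_output, ref_output)
res = torch.inf if sqnr < min_sqnr else 0.0  # 0.0 stands in for the measured time
```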
@@ -141,7 +176,7 @@ def to_quantized(self, error_on_unseen, **kwargs):
             )
         elif (self.logged_data == {}) and not error_on_unseen:
             # default back to non-quantized weight if not seen
-            self = AQFloatLinearWeight.from_float(self.weight)
+            self = AQDefaultLinearWeight.from_float(self.weight)
             return self
 
         # only want to print shape (at start) and final result (at end)
@@ -194,34 +229,49 @@ def count_shapes(self, do_print=True):
                 print(
                     f">time (all shapes): {cur_time:0.4f}ms for {q_cls}, prev_best: {best_time:0.4f}ms"
                 )
-            if best_time >= cur_time:
+            if cur_time != torch.inf and best_time >= cur_time:
                 best_time = cur_time
                 best_cls = q_cls
         # if no new benchmarking was done, don't print the final result, it will be the same as for another layer
         if ran_new_benchmarks:
             print(f"best_cls={best_cls}\n")
+
+        if best_cls is None:
+            best_cls = AQDefaultLinearWeight
+
         # TODO handle random cls args/kwargs? or should they be curried?
         self = best_cls.from_float(self.weight)
         return self
 
     def _apply_fn_to_data(self, fn):
         return self.__class__(
-            fn(self.weight), self.qtensor_class_list, dtype=self.dtype, mode=self.mode
+            fn(self.weight),
+            self.qtensor_class_list,
+            dtype=self.dtype,
+            mode=self.mode,
+            min_sqnr=self.min_sqnr,
         )
 
     def __tensor_flatten__(self):
-        return ["weight"], [self.qtensor_class_list, self.mode, self.dtype, self.shape]
+        return ["weight"], [
+            self.qtensor_class_list,
+            self.mode,
+            self.min_sqnr,
+            self.dtype,
+            self.shape,
+        ]
 
     @classmethod
     def __tensor_unflatten__(
         cls, tensor_data_dict, tensor_attributes, outer_size=None, outer_stride=None
     ):
         weight = tensor_data_dict["weight"]
-        qtensor_class_list, mode, dtype, shape = tensor_attributes
+        qtensor_class_list, mode, min_sqnr, dtype, shape = tensor_attributes
         return cls(
             weight,
             qtensor_class_list,
-            mode,
+            mode=mode,
+            min_sqnr=min_sqnr,
             shape=shape if outer_size is None else outer_size,
             dtype=dtype,
             strides=outer_stride,
@@ -608,7 +658,7 @@ class AQInt4G256WeightOnlyQuantizedLinearWeight(
     group_size: int = 256
 
 
-class AQFloatLinearWeight(torch.Tensor, AQMixin):
+class AQDefaultLinearWeight(torch.Tensor, AQMixin):
     """
     A class to be used in concert with AutoQuantizableLinearWeight to provide a
     default/non-quantized option. Only implements the bare minimum needed to work with the
@@ -629,6 +679,81 @@ def from_float(cls, weight):
         return weight
 
 
+class AQFloat32LinearWeight(torch.Tensor, AQMixin):
+    """
+    AutoQuantizable version for float32 precision weight
+
+    (also converts input activation and bias to float32, and restores the original precision after
+    linear)
+    """
+
+    def __init__(self):
+        super().__init__()
+
+    @staticmethod
+    def _quantized_linear_op(act_mat, w_qtensor, bias):
+        orig_dtype = act_mat.dtype
+        return torch.nn.functional.linear(
+            act_mat.to(torch.float32),
+            w_qtensor,
+            bias.to(torch.float32) if bias is not None else bias,
+        ).to(dtype=orig_dtype)
+
+    @classmethod
+    def from_float(cls, weight):
+        return weight.to(torch.float32)
+
+
+class AQBFloat16LinearWeight(torch.Tensor, AQMixin):
+    """
+    AutoQuantizable version for bfloat16 precision weight
+
+    (also converts input activation and bias to bfloat16, and restores the original precision after
+    linear)
+    """
+
+    def __init__(self):
+        super().__init__()
+
+    @staticmethod
+    def _quantized_linear_op(act_mat, w_qtensor, bias):
+        orig_dtype = act_mat.dtype
+        return torch.nn.functional.linear(
+            act_mat.to(torch.bfloat16),
+            w_qtensor,
+            bias.to(torch.bfloat16) if bias is not None else bias,
+        ).to(dtype=orig_dtype)
+
+    @classmethod
+    def from_float(cls, weight):
+        return weight.to(torch.bfloat16)
+
+
+class AQFloat16LinearWeight(torch.Tensor, AQMixin):
+    """
+    AutoQuantizable version for float16 precision weight
+
+    (also converts input activation and bias to float16, and restores the original precision after
+    linear)
+    """
+
+    def __init__(self):
+        super().__init__()
+
+    @staticmethod
+    def _quantized_linear_op(act_mat, w_qtensor, bias):
+        orig_dtype = act_mat.dtype
+        return torch.nn.functional.linear(
+            act_mat.to(torch.float16),
+            w_qtensor,
+            bias.to(torch.float16) if bias is not None else bias,
+        ).to(dtype=orig_dtype)
+
+    @classmethod
+    def from_float(cls, weight):
+        return weight.to(torch.float16)
+
+
 class AQFloat8WeightOnlyQuantizedLinearWeight(AffineQuantizedTensor, AQMixin):
     """
     AutoQuantizable version of Float8WeightOnlyQuantizedLinearWeight for target_dtype=torch.float8_e4m3fn
@@ -742,7 +867,7 @@ def get_weight_block_size(x):
 
 # here we don't include int4 quantization in since int8 tends to be a better apples to apples comparison
 DEFAULT_AUTOQUANT_CLASS_LIST = [
-    AQFloatLinearWeight,
+    AQDefaultLinearWeight,
     AQInt8WeightOnlyQuantizedLinearWeight,
     AQInt8WeightOnlyQuantizedLinearWeight2,
     # AQInt8WeightOnlyQuantizedLinearWeight3,
@@ -751,11 +876,17 @@ def get_weight_block_size(x):
 ]
 
 DEFAULT_INT4_AUTOQUANT_CLASS_LIST = [
-    AQFloatLinearWeight,
+    AQDefaultLinearWeight,
     AQInt8DynamicallyQuantizedLinearWeight,
     AQInt4G64WeightOnlyQuantizedLinearWeight,
 ]
 
+DEFAULT_FLOAT_AUTOQUANT_CLASS_LIST = [
+    AQFloat32LinearWeight,
+    AQBFloat16LinearWeight,
+    AQFloat16LinearWeight,
+]
+
 OTHER_AUTOQUANT_CLASS_LIST = [
     AQFloat8WeightOnlyQuantizedLinearWeight,
     AQFloat8PerRowScalingDynamicallyQuantizedLinearWeight,
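Each entry in the new `DEFAULT_FLOAT_AUTOQUANT_CLASS_LIST` is a pure dtype-cast candidate; folding `from_float` and `_quantized_linear_op` together, the shared behavior is equivalent to this sketch (standalone, with `cast_linear` as an illustrative name rather than anything in the diff):

```python
import torch

def cast_linear(act_mat, weight, bias, dtype):
    # run the matmul in `dtype`, then restore the activation's original dtype
    out = torch.nn.functional.linear(
        act_mat.to(dtype),
        weight.to(dtype),  # the classes do this cast once, in from_float
        bias.to(dtype) if bias is not None else bias,
    )
    return out.to(dtype=act_mat.dtype)

act, w = torch.randn(4, 8), torch.randn(8, 8)
assert cast_linear(act, w, None, torch.bfloat16).dtype == torch.float32
```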
@@ -779,6 +910,7 @@ def _change_linears_to_autoquantizable(model, **kwargs):
         "qtensor_class_list", DEFAULT_AUTOQUANT_CLASS_LIST
     )
     kwargs["mode"] = kwargs.get("mode", ["relu", None])
+    kwargs["min_sqnr"] = kwargs.get("min_sqnr", None)
     from torchao.quantization.quant_api import (
         _get_subclass_inserter,
         _replace_with_custom_fn_if_matches_filter,
@@ -853,6 +985,7 @@ def autoquant(
     manual=False,
     set_inductor_config=True,
     supress_autoquant_errors=True,
+    min_sqnr=None,
     **aq_kwargs,
 ):
     """
@@ -887,6 +1020,9 @@ def autoquant(
         the user to call model.finalize_autoquant (True) so inputs with several shapes/dtypes can be logged.
         set_inductor_config (bool, optional): Whether to automatically use recommended inductor config settings (defaults to True)
         supress_autoquant_errors (bool, optional): Whether to suppress errors during autoquantization. (defaults to True)
+        min_sqnr (float, optional): minimum acceptable signal-to-quantization-noise ratio (https://en.wikipedia.org/wiki/Signal-to-quantization-noise_ratio) for the output of a
+            quantized layer vs. the non-quantized layer; used to filter out quantization methods that cause too
+            large a numerical impact. Users can start with a reasonable number like 40 and adjust depending on the result.
         **aq_kwargs: Additional keyword arguments for the autoquantization process.
 
     Returns:
@@ -919,6 +1055,7 @@ def autoquant(
         filter_fn=filter_fn,
         qtensor_class_list=qtensor_class_list,
         mode=mode,
+        min_sqnr=min_sqnr,
         **aq_kwargs,
     )
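End to end, the new keyword threads through the top-level API and composes with the default candidate list; a minimal sketch mirroring the new test (model, shapes, and threshold are illustrative; CUDA assumed):

```python
import torch
import torchao

model = torch.nn.Sequential(torch.nn.Linear(128, 128)).cuda()
x = torch.randn(64, 128, device="cuda")

# candidates whose per-layer SQNR vs. the float baseline is under 60 dB are skipped;
# if every candidate is filtered out, the layer falls back to AQDefaultLinearWeight
torchao.autoquant(model, min_sqnr=60)
model(x)  # the first forward benchmarks surviving candidates and finalizes
```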