From 0e5fc3eadec5fc0d91ad9fd07ecbdc11dfec2cef Mon Sep 17 00:00:00 2001
From: HDCharles
Date: Fri, 21 Jun 2024 20:57:08 -0700
Subject: [PATCH 01/13] adding default inductor config settings

Summary: make the autoquant and quantize APIs call a new
recommended_inductor_config_setter util that sets the recommended inductor
config settings; also update groupsize -> group_size in generate.py

Test Plan: sh benchmarks.sh

comparison of different config combinations for matmul precision, mixed_mm and coordinate_descent

tok/s= 9.14, mem/s= 60.55 GB/s, peak_mem= 8.33 GB, model_size= 6.62 GB quant: int8dq, mod: Llama-2-7b-chat-hf,
tok/s=147.02, mem/s= 973.53 GB/s, peak_mem= 8.95 GB, model_size= 6.62 GB quant: int8wo, mod: Llama-2-7b-chat-hf,
tok/s= 9.23, mem/s= 61.11 GB/s, peak_mem= 8.33 GB, model_size= 6.62 GB quant: int8dq, mod: Llama-2-7b-chat-hf,
tok/s=139.59, mem/s= 924.33 GB/s, peak_mem= 8.95 GB, model_size= 6.62 GB quant: int8wo, mod: Llama-2-7b-chat-hf,
tok/s= 9.10, mem/s= 60.26 GB/s, peak_mem= 8.33 GB, model_size= 6.62 GB quant: int8dq, mod: Llama-2-7b-chat-hf,
tok/s=146.98, mem/s= 973.23 GB/s, peak_mem= 8.95 GB, model_size= 6.62 GB quant: int8wo, mod: Llama-2-7b-chat-hf,
tok/s= 9.28, mem/s= 61.48 GB/s, peak_mem= 8.33 GB, model_size= 6.62 GB quant: int8dq, mod: Llama-2-7b-chat-hf,
tok/s=146.90, mem/s= 972.73 GB/s, peak_mem= 8.95 GB, model_size= 6.62 GB quant: int8wo, mod: Llama-2-7b-chat-hf,
tok/s= 9.08, mem/s= 60.09 GB/s, peak_mem= 8.33 GB, model_size= 6.62 GB quant: int8dq, mod: Llama-2-7b-chat-hf,
tok/s=137.58, mem/s= 911.00 GB/s, peak_mem= 8.95 GB, model_size= 6.62 GB quant: int8wo, mod: Llama-2-7b-chat-hf,
tok/s= 9.19, mem/s= 60.87 GB/s, peak_mem= 8.61 GB, model_size= 6.62 GB quant: int8dq, mod: Llama-2-7b-chat-hf,
tok/s=166.02, mem/s=1099.30 GB/s, peak_mem= 8.97 GB, model_size= 6.62 GB quant: int8wo, mod: Llama-2-7b-chat-hf,

Reviewers:

Subscribers:

Tasks:

Tags:
---
 test/integration/test_integration.py |  6 +++---
 torchao/_models/llama/eval.py        |  4 +---
 torchao/_models/llama/generate.py    | 14 ++++----------
 torchao/quantization/README.md       | 10 +++-------
 torchao/quantization/autoquant.py    |  9 ++++++++-
 torchao/quantization/quant_api.py    |  6 +++++-
 torchao/quantization/utils.py        | 18 ++++++++++++++++++
 7 files changed, 42 insertions(+), 25 deletions(-)

diff --git a/test/integration/test_integration.py b/test/integration/test_integration.py
index b4fbcb152a..b2ef3c07ec 100644
--- a/test/integration/test_integration.py
+++ b/test/integration/test_integration.py
@@ -98,21 +98,21 @@ def _int8wo_api(mod):
     if TORCH_VERSION_AFTER_2_4:
-        quantize(mod, int8_weight_only())
+        quantize(mod, int8_weight_only(), set_inductor_config=False)
         unwrap_tensor_subclass(mod)
     else:
         change_linear_weights_to_int8_woqtensors(mod)

 def _int8da_int8w_api(mod):
     if TORCH_VERSION_AFTER_2_4:
-        quantize(mod, int8_dynamic_activation_int8_weight())
+        quantize(mod, int8_dynamic_activation_int8_weight(), set_inductor_config=False)
         unwrap_tensor_subclass(mod)
     else:
         change_linear_weights_to_int8_dqtensors(mod)

 def _int4wo_api(mod):
     if TORCH_VERSION_AFTER_2_4:
-        quantize(mod, int4_weight_only())
+        quantize(mod, int4_weight_only(), set_inductor_config=False)
         unwrap_tensor_subclass(mod)
     else:
         change_linear_weights_to_int4_woqtensors(mod)
diff --git a/torchao/_models/llama/eval.py b/torchao/_models/llama/eval.py
index 36e5085018..d19a038415 100644
--- a/torchao/_models/llama/eval.py
+++ b/torchao/_models/llama/eval.py
@@ -22,9 +22,7 @@ import time
 from torchao.quantization.GPTQ import Int4WeightOnlyGPTQQuantizer
 from torchao._models.llama.model import prepare_inputs_for_model
-
-torch._inductor.config.fx_graph_cache = True
-torch._inductor.config.force_fuse_int_mm_with_mul = True
+torchao.quantization.utils.recommended_inductor_config_setter()

 def run_evaluation(
     checkpoint_path: Path,
diff --git a/torchao/_models/llama/generate.py b/torchao/_models/llama/generate.py
index 34e7ca82b2..ea20799d75 100644
--- a/torchao/_models/llama/generate.py
+++ b/torchao/_models/llama/generate.py
@@ -13,6 +13,7 @@ import torch._dynamo.config
 import torch._inductor.config
 from torchao.utils import get_model_size_in_bytes
+torchao.quantization.utils.recommended_inductor_config_setter()

 def device_sync(device):
     if "cuda" in device:
@@ -22,13 +23,6 @@ def device_sync(device):
     else:
         print(f"device={device} is not yet suppported")
-
-torch._inductor.config.coordinate_descent_tuning = True
-torch._inductor.config.triton.unique_kernel_names = True
-torch._inductor.config.fx_graph_cache = True # Experimental feature to reduce compilation times, will be on by default in future
-torch._inductor.config.force_fuse_int_mm_with_mul = True
-# torch._inductor.config.use_mixed_mm = True
-
 default_device = 'cuda' if torch.cuda.is_available() else 'cpu'

 # support running without installing as a package
@@ -203,7 +197,7 @@ def main(
         if "int4wo" in quantization:
             groupsize=int(quantization.split("-")[-1])
             assert groupsize in [32,64,128,256], f"int4wo groupsize needs to be one of [32,64,128,256] but got {groupsize}"
-            quantize(model, int4_weight_only(groupsize=groupsize))
+            quantize(model, int4_weight_only(group_size=groupsize))
         if "autoquant" == quantization:
             model = autoquant(model, manual=True)
@@ -339,8 +333,8 @@ def callback(x):
     parser.add_argument('--max_new_tokens', type=int, default=200, help='Maximum number of new tokens.')
     parser.add_argument('--top_k', type=int, default=200, help='Top-k for sampling.')
     parser.add_argument('--temperature', type=float, default=0.8, help='Temperature for sampling.')
-    parser.add_argument('--checkpoint_path', type=Path, default=Path("checkpoints/meta-Transformer/Transformer-2-7b-chat-hf/model.pth"), help='Model checkpoint path.')
-    parser.add_argument("--quantization", type=str, help='Which quantization techniques to apply: int8dq, int8wo, int4wo-, autoquant')
+    parser.add_argument('--checkpoint_path', type=Path, default=Path("../../../checkpoints/meta-llama/Llama-2-7b-chat-hf/model.pth"), help='Model checkpoint path.')
+    parser.add_argument('-q', '--quantization', type=str, help='Which quantization techniques to apply: int8dq, int8wo, int4wo-, autoquant')
     parser.add_argument('--compile', action='store_true', help='Whether to compile the model.')
     parser.add_argument('--compile_prefill', action='store_true', help='Whether to compile the prefill (improves prefill perf, but higher compile times)')
     parser.add_argument('--profile', type=Path, default=None, help='Profile path.')
diff --git a/torchao/quantization/README.md b/torchao/quantization/README.md
index a6e95d0bed..086065b8da 100644
--- a/torchao/quantization/README.md
+++ b/torchao/quantization/README.md
@@ -30,10 +30,6 @@ of the activations that the different linear layers see, it then benchmarks thes
 import torch
 import torchao

-# inductor settings which improve torch.compile performance for quantized modules
-torch._inductor.config.force_fuse_int_mm_with_mul = True
-torch._inductor.config.use_mixed_mm = True
-
 # Plug in your model and example input
 model = torch.nn.Sequential(torch.nn.Linear(32, 64)).cuda().to(torch.bfloat16)
 input = torch.randn(32,32, dtype=torch.bfloat16, device='cuda')
@@ -107,9 +103,6 @@ m_bf16 = torch.compile(m_bf16, mode='max-autotune')
 group_size = 32
 m = quantize(m, int4_weight_only(group_size=group_size))

-torch._inductor.config.force_fuse_int_mm_with_mul = True
-torch._inductor.config.use_mixed_mm = True
-
 # temporary workaround for tensor subclass + torch.compile
 from torchao.quantization.utils import unwrap_tensor_subclass
 m = unwrap_tensor_subclass(m)
@@ -163,6 +156,9 @@ m = torch.export.export(m_unwrapped, example_inputs).module()
 torch._export.aot_compile(m_unwrapped, example_inputs)
 ```

+### Automatic Inductor Configuration
+The `quantize` and `autoquant` APIs now automatically set our recommended inductor configuration settings. You can replicate the same settings in your own experiments by calling `torchao.quantization.utils.recommended_inductor_config_setter`. To disable them, pass the keyword argument `set_inductor_config=False` to `quantize` or `autoquant`. You can also override individual settings after they are assigned, as long as you do so before passing any inputs to the torch.compile'd model. Previous flows that manually set various inductor configurations are therefore outdated, though continuing to set those same configurations manually is unlikely to cause any issues.
+
 ### Other Available Quantization Techniques

 #### A8W8 Dynamic Quantization
diff --git a/torchao/quantization/autoquant.py b/torchao/quantization/autoquant.py
index 18a58cd17f..40c333e322 100644
--- a/torchao/quantization/autoquant.py
+++ b/torchao/quantization/autoquant.py
@@ -1,4 +1,5 @@
 import torch
+import torchao
 from .subclass import ( # noqa
     Int8DynamicallyQuantizedLinearWeight,
     Int8WeightOnlyQuantizedLinearWeight,
@@ -443,8 +444,10 @@ def autoquant(
     model,
     example_input=None,
     qtensor_class_list=DEFAULT_CLASS_LIST,
-    filter_fn=None, mode=["interpolate", .85],
+    filter_fn=None,
+    mode=["interpolate", .85],
     manual=False,
+    set_inductor_config=True,
     **aq_kwargs
 ):
     """
@@ -477,6 +480,7 @@ def autoquant(
             and the second element is the mode value (e.g., 0.85). Defaults to ["interpolate", .85].
         manual (bool, optional): Whether to stop shape calibration and do autoquant after a single run (default, False) or to wait for the user to call model.finalize_autoquant (True) so inputs with several shapes/dtypes can be logged.
+        set_inductor_config (bool, optional): Whether to automatically use recommended inductor config settings (defaults to True)
         **aq_kwargs: Additional keyword arguments for the autoquantization process.

     Returns:
@@ -493,6 +497,9 @@ def autoquant(
         model(*example_input2)
         model.finalize_autoquant()
     """
+    if set_inductor_config:
+        torchao.quantization.utils.recommended_inductor_config_setter()
+
     # perform initial swap from linear weights
     # to AutoQuantizableLinearWeight
diff --git a/torchao/quantization/quant_api.py b/torchao/quantization/quant_api.py
index 6f7f549704..33821f1d82 100644
--- a/torchao/quantization/quant_api.py
+++ b/torchao/quantization/quant_api.py
@@ -16,6 +16,7 @@
 """
 import torch
+import torchao
 import torch.nn as nn
 import torch.nn.functional as F
 from typing import Any, Callable, Union, Dict, Optional
@@ -258,7 +259,7 @@ def insert_subclass(lin):
     return insert_subclass

-def quantize(model: torch.nn.Module, apply_tensor_subclass: Callable[[torch.Tensor], torch.Tensor], filter_fn: Optional[Callable[[torch.nn.Module, str], bool]]=None) -> torch.nn.Module:
+def quantize(model: torch.nn.Module, apply_tensor_subclass: Callable[[torch.Tensor], torch.Tensor], filter_fn: Optional[Callable[[torch.nn.Module, str], bool]]=None, set_inductor_config: bool=True) -> torch.nn.Module:
     """Convert the weight of linear modules in the model with `apply_tensor_subclass`

     Args:
@@ -266,6 +267,7 @@ def quantize(model: torch.nn.Module, apply_tensor_subclass: Callable[[torch.Tens
         apply_tensor_subclass (Callable[[torch.Tensor], torch.Tensor]): function that convert a floating point Tensor to a (quantized) tensor subclass instance (e.g. affine quantized tensor instance)
         filter_fn (Optional[Callable[[torch.nn.Module, str], bool]]): function that takes a nn.Module instance and fully qualified name of the module, returns True if we want to run `apply_tensor_subclass` on the weight of the module
+        set_inductor_config (bool, optional): Whether to automatically use recommended inductor config settings (defaults to True)

     Example::

@@ -306,6 +308,8 @@ def filter_fn(module: nn.Module, fqn: str) -> bool:

     m = quantize(m, apply_weight_quant, filter_fn)
     """
+    if set_inductor_config:
+        torchao.quantization.utils.recommended_inductor_config_setter()
     if isinstance(apply_tensor_subclass, str):
         if apply_tensor_subclass not in _APPLY_TS_TABLE:
             raise ValueError(f"{apply_tensor_subclass} not supported: {_APPLY_TS_TABLE.keys()}")
diff --git a/torchao/quantization/utils.py b/torchao/quantization/utils.py
index 3e3943c93c..d158862147 100644
--- a/torchao/quantization/utils.py
+++ b/torchao/quantization/utils.py
@@ -36,6 +36,7 @@
     "groupwise_affine_dequantize_tensor",
     "per_token_dynamic_quant",
     "get_group_qparams_symmetric",
+    "recommended_inductor_config_setter"
 ]

 try:
@@ -456,3 +457,20 @@ def per_token_dynamic_quant(input: torch.Tensor) -> torch.Tensor:
         input, scales, zero_points, quant_min, quant_max, torch.int8, orig_dtype
     )
     return input.to(orig_dtype)
+
+def recommended_inductor_config_setter():
+    """
+    Set inductor config to use the following optimizations which have been shown to improve performance for quantized models:
+        coordinate_descent_tuning = True
+        coordinate_descent_check_all_directions = True
+        force_fuse_int_mm_with_mul = True
+        fx_graph_cache = True
+        triton.unique_kernel_names = True
+        torch.set_float32_matmul_precision("high")
+    """
+    torch._inductor.config.coordinate_descent_tuning = True
+    torch._inductor.config.coordinate_descent_check_all_directions = True
+    torch._inductor.config.force_fuse_int_mm_with_mul = True
+    torch._inductor.config.fx_graph_cache = True
+    torch._inductor.config.triton.unique_kernel_names = True
+    torch.set_float32_matmul_precision("high")

From
1421de84d74baafab255cddb6857248ee56f5c10 Mon Sep 17 00:00:00 2001 From: HDCharles Date: Mon, 24 Jun 2024 22:51:54 -0700 Subject: [PATCH 02/13] fixing tests Summary: Test Plan: Reviewers: Subscribers: Tasks: Tags: --- test/integration/test_integration.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/test/integration/test_integration.py b/test/integration/test_integration.py index b2ef3c07ec..a2467f61e7 100644 --- a/test/integration/test_integration.py +++ b/test/integration/test_integration.py @@ -689,6 +689,8 @@ def test_int8_dynamic_quant_subclass(self, device, dtype): @parameterized.expand(COMMON_DEVICE_DTYPE) def test_int8_weight_only_quant_subclass(self, device, dtype): + if dtype == torch.float32: + self.skipTest("Currently not working for float32") self._test_lin_weight_subclass_impl( Int8WeightOnlyQuantizedLinearWeight.from_float, device, 40, test_dtype=dtype ) @@ -881,6 +883,8 @@ def test_weight_only_quant(self): def test_weight_only_quant_force_mixed_mm(self, device, dtype): if device != "cuda": self.skipTest(f"weight_only_quant_force_mixed_mm can't be constructed on {device}") + if dtype == torch.float32: + self.skipTest("currently not working for float32") if dtype == torch.bfloat16 and torch.cuda.get_device_capability() < (8, 0): self.skipTest("test requires SM capability of at least (8, 0).") from torch._inductor import config From c9932cea35949f34647d95298864efc47ccb94ed Mon Sep 17 00:00:00 2001 From: HDCharles Date: Mon, 24 Jun 2024 23:54:10 -0700 Subject: [PATCH 03/13] fix weight only failures Summary: Test Plan: Reviewers: Subscribers: Tasks: Tags: --- test/integration/test_integration.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/test/integration/test_integration.py b/test/integration/test_integration.py index a2467f61e7..9d15fbd151 100644 --- a/test/integration/test_integration.py +++ b/test/integration/test_integration.py @@ -796,6 +796,8 @@ def test_int8_dynamic_quant_subclass_api(self, device, dtype): @parameterized.expand(COMMON_DEVICE_DTYPE) @unittest.skipIf(is_fbcode(), "broken in fbcode") def test_int8_weight_only_quant_subclass_api(self, device, dtype): + if dtype == torch.float32: + self.skipTest("Currently not working for float32") self._test_lin_weight_subclass_api_impl( _int8wo_api, device, 40, test_dtype=dtype ) @@ -915,6 +917,8 @@ def test_weight_only_quant_use_mixed_mm(self, device, dtype): self.skipTest(f"weight_only_quant_force_mixed_mm can't be constructed on {device}") if dtype == torch.bfloat16 and torch.cuda.get_device_capability() < (8, 0): self.skipTest("test requires SM capability of at least (8, 0).") + if dtype == torch.float32: + self.skipTest("currently not working for float32") torch.manual_seed(0) from torch._inductor import config mixed_mm_key, mixed_mm_val = ("mixed_mm_choice", "triton") if TORCH_VERSION_AFTER_2_4 else ("force_mixed_mm", True) From 18cd1aa67102ac953f9a6514b0564b68081d5f0e Mon Sep 17 00:00:00 2001 From: HDCharles Date: Tue, 25 Jun 2024 00:27:25 -0700 Subject: [PATCH 04/13] fixing new broken test Summary: Test Plan: Reviewers: Subscribers: Tasks: Tags: --- test/integration/test_integration.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/test/integration/test_integration.py b/test/integration/test_integration.py index 9d15fbd151..b14f07a41d 100644 --- a/test/integration/test_integration.py +++ b/test/integration/test_integration.py @@ -1012,6 +1012,8 @@ def test_save_load_dqtensors(self, device, dtype): @torch.no_grad() @unittest.skipIf(is_fbcode(), "broken in fbcode") def 
test_save_load_int8woqtensors(self, device, dtype): + if dtype == torch.float32: + self.skipTest("currently not working for float32") self._test_handle_save_load_meta_impl(_int8wo_api, device, test_dtype=dtype) @parameterized.expand(COMMON_DEVICE_DTYPE) From ee5183fbe9763c8c45a6b692a44cbc1c964b02ff Mon Sep 17 00:00:00 2001 From: HDCharles Date: Tue, 25 Jun 2024 06:19:47 -0700 Subject: [PATCH 05/13] fixing autoquant test Summary: Test Plan: Reviewers: Subscribers: Tasks: Tags: --- test/integration/test_integration.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test/integration/test_integration.py b/test/integration/test_integration.py index b14f07a41d..5f26ddd1d0 100644 --- a/test/integration/test_integration.py +++ b/test/integration/test_integration.py @@ -1212,7 +1212,7 @@ def test_autoquant_compile(self, device, dtype, m1, m2, k, n): example_input2 = torch.randn(m2, k, device=device, dtype=dtype) out = model(example_input) - mod = torchao.autoquant(torch.compile(model), manual=True) + mod = torchao.autoquant(torch.compile(model), manual=True, set_inductor_config=False) mod(example_input) mod(example_input2) mod.finalize_autoquant() From d0e822ba7688c5320a5d05cc7a70c77b92736a1f Mon Sep 17 00:00:00 2001 From: HDCharles Date: Tue, 25 Jun 2024 06:38:54 -0700 Subject: [PATCH 06/13] testing if inductor config is the issue Summary: Test Plan: Reviewers: Subscribers: Tasks: Tags: --- torchao/quantization/utils.py | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/torchao/quantization/utils.py b/torchao/quantization/utils.py index d158862147..15dfdeb5d0 100644 --- a/torchao/quantization/utils.py +++ b/torchao/quantization/utils.py @@ -468,9 +468,9 @@ def recommended_inductor_config_setter(): triton.unique_kernel_names = True torch.set_float32_matmul_precision("high") """ - torch._inductor.config.coordinate_descent_tuning = True - torch._inductor.config.coordinate_descent_check_all_directions = True - torch._inductor.config.force_fuse_int_mm_with_mul = True - torch._inductor.config.fx_graph_cache = True - torch._inductor.config.triton.unique_kernel_names = True - torch.set_float32_matmul_precision("high") + # torch._inductor.config.coordinate_descent_tuning = True + # torch._inductor.config.coordinate_descent_check_all_directions = True + # torch._inductor.config.force_fuse_int_mm_with_mul = True + # torch._inductor.config.fx_graph_cache = True + # torch._inductor.config.triton.unique_kernel_names = True + # torch.set_float32_matmul_precision("high") From a5654f3d27e910e4802f99b21bd7baaec6ce0520 Mon Sep 17 00:00:00 2001 From: HDCharles Date: Tue, 25 Jun 2024 07:06:03 -0700 Subject: [PATCH 07/13] are inductor configs somehow being set? 
Summary: Test Plan: Reviewers: Subscribers: Tasks: Tags: --- test/integration/test_integration.py | 4 ++-- torchao/_models/llama/eval.py | 4 +++- torchao/_models/llama/generate.py | 4 +++- torchao/quantization/utils.py | 12 ++++++------ 4 files changed, 14 insertions(+), 10 deletions(-) diff --git a/test/integration/test_integration.py b/test/integration/test_integration.py index 5f26ddd1d0..827376d937 100644 --- a/test/integration/test_integration.py +++ b/test/integration/test_integration.py @@ -1190,9 +1190,8 @@ def test_autoquant_one_input(self, device, dtype, m, k, n): @parameterized.expand(combine_parameters(COMMON_DEVICE_DTYPE, [ - (1, 1, 128, 128), (1, 32, 128, 128), - (32, 32, 128, 128), + (8, 16, 128, 128), ])) @unittest.skipIf(not TORCH_VERSION_AFTER_2_3, "autoquant requires 2.3+.") def test_autoquant_compile(self, device, dtype, m1, m2, k, n): @@ -1213,6 +1212,7 @@ def test_autoquant_compile(self, device, dtype, m1, m2, k, n): out = model(example_input) mod = torchao.autoquant(torch.compile(model), manual=True, set_inductor_config=False) + assert torch._inductor.config.coordinate_descent_tuning == False mod(example_input) mod(example_input2) mod.finalize_autoquant() diff --git a/torchao/_models/llama/eval.py b/torchao/_models/llama/eval.py index d19a038415..73deafffec 100644 --- a/torchao/_models/llama/eval.py +++ b/torchao/_models/llama/eval.py @@ -22,7 +22,6 @@ import time from torchao.quantization.GPTQ import Int4WeightOnlyGPTQQuantizer from torchao._models.llama.model import prepare_inputs_for_model -torchao.quantization.utils.recommended_inductor_config_setter() def run_evaluation( checkpoint_path: Path, @@ -39,6 +38,9 @@ def run_evaluation( pad_calibration_inputs: Optional[bool] = False, ): """Runs the evaluation of a model using LM Eval.""" + + torchao.quantization.utils.recommended_inductor_config_setter() + assert checkpoint_path.is_file(), checkpoint_path tokenizer_path = checkpoint_path.parent / "tokenizer.model" assert tokenizer_path.is_file(), str(tokenizer_path) diff --git a/torchao/_models/llama/generate.py b/torchao/_models/llama/generate.py index ea20799d75..8142f80bb8 100644 --- a/torchao/_models/llama/generate.py +++ b/torchao/_models/llama/generate.py @@ -13,7 +13,6 @@ import torch._dynamo.config import torch._inductor.config from torchao.utils import get_model_size_in_bytes -torchao.quantization.utils.recommended_inductor_config_setter() def device_sync(device): if "cuda" in device: @@ -157,6 +156,9 @@ def main( ) -> None: """Generates text samples based on a pre-trained Transformer model and tokenizer. 
""" + + torchao.quantization.utils.recommended_inductor_config_setter() + assert checkpoint_path.is_file(), checkpoint_path tokenizer_path = checkpoint_path.parent / "tokenizer.model" assert tokenizer_path.is_file(), str(tokenizer_path) diff --git a/torchao/quantization/utils.py b/torchao/quantization/utils.py index 15dfdeb5d0..d158862147 100644 --- a/torchao/quantization/utils.py +++ b/torchao/quantization/utils.py @@ -468,9 +468,9 @@ def recommended_inductor_config_setter(): triton.unique_kernel_names = True torch.set_float32_matmul_precision("high") """ - # torch._inductor.config.coordinate_descent_tuning = True - # torch._inductor.config.coordinate_descent_check_all_directions = True - # torch._inductor.config.force_fuse_int_mm_with_mul = True - # torch._inductor.config.fx_graph_cache = True - # torch._inductor.config.triton.unique_kernel_names = True - # torch.set_float32_matmul_precision("high") + torch._inductor.config.coordinate_descent_tuning = True + torch._inductor.config.coordinate_descent_check_all_directions = True + torch._inductor.config.force_fuse_int_mm_with_mul = True + torch._inductor.config.fx_graph_cache = True + torch._inductor.config.triton.unique_kernel_names = True + torch.set_float32_matmul_precision("high") From 26d911067ba62da20d60b1c31d082eafd5744406 Mon Sep 17 00:00:00 2001 From: HDCharles Date: Tue, 25 Jun 2024 10:37:48 -0700 Subject: [PATCH 08/13] when is coordinate descent tuning beinng enabled? Summary: Test Plan: Reviewers: Subscribers: Tasks: Tags: --- test/integration/test_integration.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/test/integration/test_integration.py b/test/integration/test_integration.py index 827376d937..3d6f499a3c 100644 --- a/test/integration/test_integration.py +++ b/test/integration/test_integration.py @@ -1195,6 +1195,7 @@ def test_autoquant_one_input(self, device, dtype, m, k, n): ])) @unittest.skipIf(not TORCH_VERSION_AFTER_2_3, "autoquant requires 2.3+.") def test_autoquant_compile(self, device, dtype, m1, m2, k, n): + assert torch._inductor.config.coordinate_descent_tuning == False, "coordinate descent tuning was enabled A" if device != "cuda" or not torch.cuda.is_available(): self.skipTest(f"autoquant currently does not support {device}") if torch.cuda.is_available() and torch.cuda.get_device_capability() < (8, 0): @@ -1202,17 +1203,19 @@ def test_autoquant_compile(self, device, dtype, m1, m2, k, n): self.skipTest(f"bfloat16 requires sm80+") if m1 == 1 or m2 == 1: self.skipTest(f"Shape {(m1, m2, k, n)} requires sm80+") + assert torch._inductor.config.coordinate_descent_tuning == False, "coordinate descent tuning was enabled B" model = torch.nn.Sequential( torch.nn.ReLU(), torch.nn.Linear(k,n), torch.nn.ReLU(), ).to(device).to(dtype) + assert torch._inductor.config.coordinate_descent_tuning == False, "coordinate descent tuning was enabled C" example_input = torch.randn(m1, k, device=device, dtype=dtype) example_input2 = torch.randn(m2, k, device=device, dtype=dtype) out = model(example_input) mod = torchao.autoquant(torch.compile(model), manual=True, set_inductor_config=False) - assert torch._inductor.config.coordinate_descent_tuning == False + assert torch._inductor.config.coordinate_descent_tuning == False, "coordinate descent tuning was enabled D" mod(example_input) mod(example_input2) mod.finalize_autoquant() From 762ef4188b39ead55127739dafe1029f522af7d6 Mon Sep 17 00:00:00 2001 From: HDCharles Date: Tue, 25 Jun 2024 12:09:34 -0700 Subject: [PATCH 09/13] reset inductor config for tests Summary: 
Test Plan: Reviewers: Subscribers: Tasks: Tags: --- test/integration/test_integration.py | 30 ++++++++++++++-------------- 1 file changed, 15 insertions(+), 15 deletions(-) diff --git a/test/integration/test_integration.py b/test/integration/test_integration.py index 3d6f499a3c..bc88370040 100644 --- a/test/integration/test_integration.py +++ b/test/integration/test_integration.py @@ -124,6 +124,13 @@ def _int4wo_api(mod): _int4wo_api, ] +def undo_recommended_configs(): + torch._inductor.config.coordinate_descent_tuning = False + torch._inductor.config.coordinate_descent_check_all_directions = False + torch._inductor.config.force_fuse_int_mm_with_mul = False + torch._inductor.config.fx_graph_cache = False + torch._inductor.config.triton.unique_kernel_names = False + torch.set_float32_matmul_precision("highest") def combine_parameters(a, b): new_tuples = [] @@ -689,8 +696,7 @@ def test_int8_dynamic_quant_subclass(self, device, dtype): @parameterized.expand(COMMON_DEVICE_DTYPE) def test_int8_weight_only_quant_subclass(self, device, dtype): - if dtype == torch.float32: - self.skipTest("Currently not working for float32") + undo_recommended_configs() self._test_lin_weight_subclass_impl( Int8WeightOnlyQuantizedLinearWeight.from_float, device, 40, test_dtype=dtype ) @@ -796,8 +802,7 @@ def test_int8_dynamic_quant_subclass_api(self, device, dtype): @parameterized.expand(COMMON_DEVICE_DTYPE) @unittest.skipIf(is_fbcode(), "broken in fbcode") def test_int8_weight_only_quant_subclass_api(self, device, dtype): - if dtype == torch.float32: - self.skipTest("Currently not working for float32") + undo_recommended_configs() self._test_lin_weight_subclass_api_impl( _int8wo_api, device, 40, test_dtype=dtype ) @@ -883,10 +888,9 @@ def test_weight_only_quant(self): @torch.no_grad() @unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available") def test_weight_only_quant_force_mixed_mm(self, device, dtype): + undo_recommended_configs() if device != "cuda": self.skipTest(f"weight_only_quant_force_mixed_mm can't be constructed on {device}") - if dtype == torch.float32: - self.skipTest("currently not working for float32") if dtype == torch.bfloat16 and torch.cuda.get_device_capability() < (8, 0): self.skipTest("test requires SM capability of at least (8, 0).") from torch._inductor import config @@ -913,12 +917,11 @@ def test_weight_only_quant_force_mixed_mm(self, device, dtype): @parameterized.expand(COMMON_DEVICE_DTYPE) @unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available") def test_weight_only_quant_use_mixed_mm(self, device, dtype): + undo_recommended_configs() if device != "cuda": self.skipTest(f"weight_only_quant_force_mixed_mm can't be constructed on {device}") if dtype == torch.bfloat16 and torch.cuda.get_device_capability() < (8, 0): self.skipTest("test requires SM capability of at least (8, 0).") - if dtype == torch.float32: - self.skipTest("currently not working for float32") torch.manual_seed(0) from torch._inductor import config mixed_mm_key, mixed_mm_val = ("mixed_mm_choice", "triton") if TORCH_VERSION_AFTER_2_4 else ("force_mixed_mm", True) @@ -1012,8 +1015,7 @@ def test_save_load_dqtensors(self, device, dtype): @torch.no_grad() @unittest.skipIf(is_fbcode(), "broken in fbcode") def test_save_load_int8woqtensors(self, device, dtype): - if dtype == torch.float32: - self.skipTest("currently not working for float32") + undo_recommended_configs() self._test_handle_save_load_meta_impl(_int8wo_api, device, test_dtype=dtype) @parameterized.expand(COMMON_DEVICE_DTYPE) @@ -1190,12 
+1192,13 @@ def test_autoquant_one_input(self, device, dtype, m, k, n): @parameterized.expand(combine_parameters(COMMON_DEVICE_DTYPE, [ + (1, 1, 128, 128), (1, 32, 128, 128), - (8, 16, 128, 128), + (32, 32, 128, 128), ])) @unittest.skipIf(not TORCH_VERSION_AFTER_2_3, "autoquant requires 2.3+.") def test_autoquant_compile(self, device, dtype, m1, m2, k, n): - assert torch._inductor.config.coordinate_descent_tuning == False, "coordinate descent tuning was enabled A" + undo_recommended_configs() if device != "cuda" or not torch.cuda.is_available(): self.skipTest(f"autoquant currently does not support {device}") if torch.cuda.is_available() and torch.cuda.get_device_capability() < (8, 0): @@ -1203,19 +1206,16 @@ def test_autoquant_compile(self, device, dtype, m1, m2, k, n): self.skipTest(f"bfloat16 requires sm80+") if m1 == 1 or m2 == 1: self.skipTest(f"Shape {(m1, m2, k, n)} requires sm80+") - assert torch._inductor.config.coordinate_descent_tuning == False, "coordinate descent tuning was enabled B" model = torch.nn.Sequential( torch.nn.ReLU(), torch.nn.Linear(k,n), torch.nn.ReLU(), ).to(device).to(dtype) - assert torch._inductor.config.coordinate_descent_tuning == False, "coordinate descent tuning was enabled C" example_input = torch.randn(m1, k, device=device, dtype=dtype) example_input2 = torch.randn(m2, k, device=device, dtype=dtype) out = model(example_input) mod = torchao.autoquant(torch.compile(model), manual=True, set_inductor_config=False) - assert torch._inductor.config.coordinate_descent_tuning == False, "coordinate descent tuning was enabled D" mod(example_input) mod(example_input2) mod.finalize_autoquant() From b0c4e2348904bbbb8dd29e2ad777aa1d891c6436 Mon Sep 17 00:00:00 2001 From: HDCharles Date: Tue, 25 Jun 2024 12:36:19 -0700 Subject: [PATCH 10/13] more test fixes Summary: Test Plan: Reviewers: Subscribers: Tasks: Tags: --- test/integration/test_integration.py | 12 ++++++++---- torchao/quantization/autoquant.py | 5 ++++- 2 files changed, 12 insertions(+), 5 deletions(-) diff --git a/test/integration/test_integration.py b/test/integration/test_integration.py index bc88370040..8ce568abeb 100644 --- a/test/integration/test_integration.py +++ b/test/integration/test_integration.py @@ -1165,6 +1165,7 @@ class TestAutoQuant(unittest.TestCase): ])) @unittest.skipIf(not TORCH_VERSION_AFTER_2_3, "autoquant requires 2.3+.") def test_autoquant_one_input(self, device, dtype, m, k, n): + undo_recommended_configs() print("(m, k, n): ", (m, k, n)) if device != "cuda" or not torch.cuda.is_available(): self.skipTest(f"autoquant currently does not support {device}") @@ -1185,7 +1186,7 @@ def test_autoquant_one_input(self, device, dtype, m, k, n): torch.nn.ReLU(), ).to(device).to(dtype) out = model(example_input) - torchao.autoquant(model) + torchao.autoquant(model, set_inductor_config=False) out2 = model(example_input) sqnr = SQNR(out, out2) self.assertTrue(sqnr >= 30) @@ -1227,6 +1228,7 @@ def test_autoquant_compile(self, device, dtype, m1, m2, k, n): @parameterized.expand(COMMON_DEVICE_DTYPE) @unittest.skipIf(not TORCH_VERSION_AFTER_2_3, "autoquant requires 2.3+.") def test_autoquant_manual(self, device, dtype): + undo_recommended_configs() if device != "cuda" or not torch.cuda.is_available(): self.skipTest(f"autoquant currently does not support {device}") if torch.cuda.is_available() and torch.cuda.get_device_capability() < (8, 0): @@ -1242,7 +1244,7 @@ def test_autoquant_manual(self, device, dtype): example_input2 = torch.randn(m2, k, device=device, dtype=dtype) out = model(example_input) - 
mod = torchao.autoquant(torch.compile(model), manual=True) + mod = torchao.autoquant(torch.compile(model), manual=True, set_inductor_config=False) mod(example_input) mod(example_input2) mod.finalize_autoquant() @@ -1267,6 +1269,7 @@ def test_autoquant_manual(self, device, dtype): ])) @unittest.skipIf(not TORCH_VERSION_AFTER_2_3, "autoquant requires 2.3+.") def test_autoquant_kwargs(self, device, dtype, m1, m2, k, n): + undo_recommended_configs() if device != "cuda" or not torch.cuda.is_available(): self.skipTest(f"autoquant currently does not support {device}") if torch.cuda.is_available() and torch.cuda.get_device_capability() < (8, 0): @@ -1293,7 +1296,7 @@ def forward(self, x, y): } out = model(**example_input) - mod = torchao.autoquant(torch.compile(model)) + mod = torchao.autoquant(torch.compile(model), set_inductor_config=False) mod(**example_input) out2 = mod(**example_input) @@ -1306,6 +1309,7 @@ def forward(self, x, y): ])) @unittest.skipIf(not TORCH_VERSION_AFTER_2_3, "autoquant requires 2.3+.") def test_autoquant_double_access(self, device, dtype, m, k, n): + undo_recommended_configs() if device != "cuda" or not torch.cuda.is_available(): self.skipTest(f"autoquant currently does not support {device}") if torch.cuda.is_available() and torch.cuda.get_device_capability() < (8, 0): @@ -1329,7 +1333,7 @@ def forward(self, x): x_in = torch.randn(m, k, device=device, dtype=dtype) model = DoubleAccess().to(device).to(dtype) model(x_in) - torchao.autoquant(model) + torchao.autoquant(model, set_inductor_config=False) assert not isinstance(model.lin1.weight.weight, AutoQuantizableLinearWeight) model(x_in) diff --git a/torchao/quantization/autoquant.py b/torchao/quantization/autoquant.py index 40c333e322..c81f84a9b8 100644 --- a/torchao/quantization/autoquant.py +++ b/torchao/quantization/autoquant.py @@ -91,7 +91,10 @@ def tune_autoquant(self, q_cls, shapes_and_dtype, best_time): with torch.no_grad(): act_mat = torch.randn(act_shape, dtype=act_dtype, device=self.device) bias = None if bias_shape is None else torch.randn(bias_shape, dtype=act_dtype, device=self.device) - res = q_cls._autoquant_test(act_mat, self.weight, bias, best_time, self.mode) + try: + res = q_cls._autoquant_test(act_mat, self.weight, bias, best_time, self.mode) + except Exception as e: + print(f"warning: failed to autoquant {q_cls.__name__} for shape: {shapes_and_dtype} due to {e}") update_cache(q_cls, shapes_and_dtype, res) @torch.no_grad() From 8fadc3984c94e0053db15b1ac6b91a2326483948 Mon Sep 17 00:00:00 2001 From: HDCharles Date: Tue, 25 Jun 2024 12:37:12 -0700 Subject: [PATCH 11/13] adding warning Summary: Test Plan: Reviewers: Subscribers: Tasks: Tags: --- torchao/quantization/autoquant.py | 1 + 1 file changed, 1 insertion(+) diff --git a/torchao/quantization/autoquant.py b/torchao/quantization/autoquant.py index c81f84a9b8..07287030d7 100644 --- a/torchao/quantization/autoquant.py +++ b/torchao/quantization/autoquant.py @@ -95,6 +95,7 @@ def tune_autoquant(self, q_cls, shapes_and_dtype, best_time): res = q_cls._autoquant_test(act_mat, self.weight, bias, best_time, self.mode) except Exception as e: print(f"warning: failed to autoquant {q_cls.__name__} for shape: {shapes_and_dtype} due to {e}") + res = torch.inf update_cache(q_cls, shapes_and_dtype, res) @torch.no_grad() From 3c2825e5cb7c7d9805cba1ea03f14721da48655c Mon Sep 17 00:00:00 2001 From: HDCharles Date: Tue, 25 Jun 2024 12:48:49 -0700 Subject: [PATCH 12/13] handling of errors Summary: Test Plan: Reviewers: Subscribers: Tasks: Tags: --- 
test/integration/test_integration.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/test/integration/test_integration.py b/test/integration/test_integration.py index 8ce568abeb..4d5a2c511c 100644 --- a/test/integration/test_integration.py +++ b/test/integration/test_integration.py @@ -1252,7 +1252,7 @@ def test_autoquant_manual(self, device, dtype): sqnr = SQNR(out, out2) self.assertTrue(sqnr >= 30) - mod2 = torchao.autoquant(model, manual=True) + mod2 = torchao.autoquant(model, manual=True, set_inductor_config=False) mod2(example_input) mod2(example_input2) mod2.finalize_autoquant() @@ -1460,7 +1460,7 @@ def test_get_model_size_autoquant(self, device, dtype): qtensor_class_list = ( AQWeightOnlyQuantizedLinearWeight2, ) - mod = torchao.autoquant(torch.compile(model), qtensor_class_list = qtensor_class_list) + mod = torchao.autoquant(torch.compile(model), qtensor_class_list = qtensor_class_list, set_inductor_config=False) mod(example_input) size2 = torchao.utils.get_model_size_in_bytes(mod) self.assertTrue(size2 < size) From d1050729e7a2a2c2515d13a63d33fe45da6e32d2 Mon Sep 17 00:00:00 2001 From: HDCharles Date: Tue, 25 Jun 2024 13:22:12 -0700 Subject: [PATCH 13/13] option to supress autoquant errors Summary: Test Plan: Reviewers: Subscribers: Tasks: Tags: --- torchao/quantization/autoquant.py | 20 +++++++++++++++++--- 1 file changed, 17 insertions(+), 3 deletions(-) diff --git a/torchao/quantization/autoquant.py b/torchao/quantization/autoquant.py index 07287030d7..83d7837d3e 100644 --- a/torchao/quantization/autoquant.py +++ b/torchao/quantization/autoquant.py @@ -412,16 +412,21 @@ def _change_linears_to_autoquantizable(model, **kwargs): filter_fn if filter_fn is not None else _is_linear, ) -def _change_autoquantizable_to_quantized(model, **kwargs): +def _change_autoquantizable_to_quantized(model, supress_autoquant_errors=True, **kwargs): """ Converts AutoQuantizableLinearWeight tensor subclasses to various quantized/non-quantized tensor subclasses depending on benchmark results. Expectation is that these modules are torch.compiled afterwards. """ - hold = torch._dynamo.config.automatic_dynamic_shapes + hold_automatic_dynamic_shapes = torch._dynamo.config.automatic_dynamic_shapes torch._dynamo.config.automatic_dynamic_shapes = False + if supress_autoquant_errors: + hold_supress_errors = torch._dynamo.config.suppress_errors + torch._dynamo.config.suppress_errors = True + import logging + torch._logging.set_logs(inductor=logging.CRITICAL, dynamo=logging.CRITICAL) filter_fn = kwargs.pop( "filter_fn", lambda mod, *args: @@ -437,7 +442,13 @@ def _change_autoquantizable_to_quantized(model, **kwargs): ), filter_fn, ) - torch._dynamo.config.automatic_dynamic_shapes = hold + # undo dynamic shape change + torch._dynamo.config.automatic_dynamic_shapes = hold_automatic_dynamic_shapes + + # undo error supression + if supress_autoquant_errors: + torch._dynamo.config.suppress_errors = hold_supress_errors + torch._logging.set_logs() torch._dynamo.reset() # TODO: example_input seems weird to include in the API @@ -452,6 +463,7 @@ def autoquant( mode=["interpolate", .85], manual=False, set_inductor_config=True, + supress_autoquant_errors=True, **aq_kwargs ): """ @@ -485,6 +497,7 @@ def autoquant( manual (bool, optional): Whether to stop shape calibration and do autoquant after a single run (default, False) or to wait for the user to call model.finalize_autoquant (True) so inputs with several shapes/dtypes can be logged. 
set_inductor_config (bool, optional): Whether to automatically use recommended inductor config settings (defaults to True) + supress_autoquant_errors (bool, optional): Whether to suppress errors during autoquantization. (defaults to True) **aq_kwargs: Additional keyword arguments for the autoquantization process. Returns: @@ -550,6 +563,7 @@ def autoquant_prehook(module, args, kwargs): def finalize_autoquant(): _change_autoquantizable_to_quantized( real_model, + supress_autoquant_errors, **aq_kwargs, ) if hasattr(real_model, "old_forward"):
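
A minimal usage sketch of the API surface this series introduces, for reference. It assumes a torchao build that contains these patches and a CUDA device; the `set_inductor_config` keyword, `recommended_inductor_config_setter`, and `unwrap_tensor_subclass` come from the diffs above, while the toy model, shapes, and the `quant_api` import path for `quantize`/`int8_weight_only` are illustrative assumptions rather than facts stated in the patches.

```python
# Sketch of the workflow after this patch series (assumptions noted above).
import torch
import torch._inductor.config
import torchao
from torchao.quantization.quant_api import quantize, int8_weight_only  # assumed import path
from torchao.quantization.utils import (
    recommended_inductor_config_setter,
    unwrap_tensor_subclass,
)

model = torch.nn.Sequential(torch.nn.Linear(32, 64)).cuda().to(torch.bfloat16)
example_input = torch.randn(32, 32, dtype=torch.bfloat16, device="cuda")

# Default behaviour: quantize() (and autoquant()) now apply the recommended
# inductor settings automatically -- coordinate descent tuning, int_mm/mul
# fusion, fx graph cache, unique kernel names, high fp32 matmul precision.
model = quantize(model, int8_weight_only())

# To opt out and manage the flags yourself instead:
#   model = quantize(model, int8_weight_only(), set_inductor_config=False)
#   recommended_inductor_config_setter()  # same settings, applied manually
# or, for autoquant:
#   model = torchao.autoquant(torch.compile(model), set_inductor_config=False)

# Individual settings can still be overridden after they are assigned, as long
# as this happens before the first inputs reach the compiled model.
torch._inductor.config.coordinate_descent_tuning = True

model = unwrap_tensor_subclass(model)  # temporary workaround for tensor subclass + torch.compile
model = torch.compile(model, mode="max-autotune")
out = model(example_input)  # first call compiles with the settings chosen above
```

The opt-out keyword mirrors what the updated tests do (`set_inductor_config=False` plus an explicit `undo_recommended_configs` helper), keeping the library-level config changes visible and reversible.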