diff --git a/test/prototype/test_autoround.py b/test/prototype/test_autoround.py new file mode 100644 index 0000000000..e100348725 --- /dev/null +++ b/test/prototype/test_autoround.py @@ -0,0 +1,174 @@ +import pytest +from torchao.prototype.autoround.utils import is_auto_round_available + +if not is_auto_round_available(): + pytest.skip("AutoRound is not available", allow_module_level=True) + +import torch +from torch.testing._internal.common_utils import ( + instantiate_parametrized_tests, + parametrize, + run_tests, + TestCase, +) +from torchao import quantize_ + +from torchao.dtypes import AffineQuantizedTensor +from torchao.prototype.autoround.core import ( + apply_auto_round, + prepare_model_for_applying_auto_round_, +) +from torchao.prototype.autoround.multi_tensor import MultiTensor +from torchao.utils import TORCH_VERSION_AT_LEAST_2_5 + +_AVAILABLE_DEVICES = ["cpu"] + (["cuda"] if torch.cuda.is_available() else []) + + +# Copied from https://github.com/pytorch/ao/pull/721 +class TwoLinear(torch.nn.Module): + def __init__(self, in_features=64, out_features=128): + super().__init__() + self.linear1 = torch.nn.Linear(in_features, out_features) + self.linear2 = torch.nn.Linear(in_features, out_features) + + def forward(self, x, y): + x = self.linear1(x) + y = self.linear2(y) + return x + y + + +class M(torch.nn.Module): + def __init__(self): + super().__init__() + self.two_linear1 = TwoLinear() + self.two_linear2 = TwoLinear(128, 256) + + def forward(self, x, y): + x1 = self.two_linear1(x, y) + x2 = self.two_linear2(x1, x1) + return x2 + + +def _is_two_linear(mod, fqn): + return isinstance(mod, TwoLinear) + + +class ModelWithInplaceOp(torch.nn.Module): + def __init__(self, DIM=128): + super().__init__() + self.lin = torch.nn.Linear(DIM, DIM) + self.register_buffer("other", torch.zeros(DIM, DIM)) + + def forward(self, x, idx): + x = x + self.lin(x) + # update buffer + self.other[idx] = x + return x + + +class M2(torch.nn.Module): + def __init__(self, DIM=128): + super().__init__() + self.m1 = ModelWithInplaceOp(DIM) + self.m2 = ModelWithInplaceOp(DIM) + + def forward(self, x, idx): + x = self.m1(x, idx) + x = self.m2(x, idx) + return x + + +def _check_params_and_buffers_type(module, check_fun): + return [check_fun(p) for p in module.parameters()] + [ + check_fun(b) for b in module.buffers() + ] + + +class TestAutoRound(TestCase): + + @pytest.mark.skip(not TORCH_VERSION_AT_LEAST_2_5, "Requires torch 2.5 or later") + @parametrize("device", _AVAILABLE_DEVICES) + @torch.no_grad() + def test_auto_round(self, device: str): + example_inputs = ( + torch.randn(32, 64).to(device), + torch.randn(32, 64).to(device), + ) + m = M().eval().to(device) + before_quant = m(*example_inputs) + prepare_model_for_applying_auto_round_( + m, + is_target_module=_is_two_linear, + bits=7, + group_size=32, + iters=20, + device=device, + ) + assert all( + _check_params_and_buffers_type(m, lambda x: isinstance(x, MultiTensor)) + ), "Expected all parameters and buffers to be `MultiTensor`." + input1 = [] + input2 = [] + for _ in range(10): + input1.append(torch.randn(32, 64).to(device)) + input2.append(torch.randn(32, 64).to(device)) + + mt_input1 = MultiTensor(input1) + mt_input2 = MultiTensor(input2) + out = m(mt_input1, mt_input2) + assert isinstance(out, MultiTensor), f"Expected MultiTensor, got {type(out)}" + assert all( + _check_params_and_buffers_type(m, lambda x: not isinstance(x, MultiTensor)) + ), "Expected all parameters and buffers have been converted back to tensor." 
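+        # Auto-Round attaches the optimized `scale`/`zp` to each target Linear during the
+        # MultiTensor forward pass; `quantize_` below converts those weights to `AffineQuantizedTensor`.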
+ quantize_(m, apply_auto_round(), _is_two_linear, device=device) + for l in m.modules(): + if isinstance(l, torch.nn.Linear): + assert isinstance(l.weight, AffineQuantizedTensor) + after_quant = m(*example_inputs) + assert after_quant is not None, "Quantized model forward pass failed" + + @pytest.mark.skip(not TORCH_VERSION_AT_LEAST_2_5, "Requires torch 2.5 or later") + @parametrize("device", _AVAILABLE_DEVICES) + @torch.no_grad() + def test_wrap_model_with_multi_tensor(self, device: str): + + _is_model_with_inplace_op = lambda mod, fqn: isinstance(mod, ModelWithInplaceOp) + + DIM = 128 + m = M2(DIM).eval().to(device) + prepare_model_for_applying_auto_round_( + m, + is_target_module=_is_model_with_inplace_op, + bits=7, + group_size=32, + iters=20, + device=device, + ) + assert all( + _check_params_and_buffers_type(m, lambda x: isinstance(x, MultiTensor)) + ), "Expected all parameters and buffers to be `MultiTensor`." + input1 = [] + input2 = [] + for _ in range(2): + input1.append(torch.randn(DIM, DIM).to(device)) + input2.append(torch.randint(0, DIM, (DIM,), dtype=torch.long).to(device)) + + mt_input1 = MultiTensor(input1) + mt_input2 = MultiTensor(input2) + out = m(mt_input1, mt_input2) + assert isinstance(out, MultiTensor), f"Expected MultiTensor, got {type(out)}" + assert all( + _check_params_and_buffers_type(m, lambda x: not isinstance(x, MultiTensor)) + ), "Expected all parameters and buffers have been converted back to tensor." + quantize_(m, apply_auto_round(), _is_model_with_inplace_op, device=device) + for l in m.modules(): + if isinstance(l, torch.nn.Linear): + assert isinstance(l.weight, AffineQuantizedTensor) + after_quant = m(input1[0], input2[0]) + assert after_quant is not None, "Quantized model forward pass failed" + + +instantiate_parametrized_tests(TestAutoRound) + +if __name__ == "__main__": + run_tests() diff --git a/torchao/_models/llama/benchmarks.sh b/torchao/_models/llama/benchmarks.sh index b4d3bf5fe9..9020715e70 100644 --- a/torchao/_models/llama/benchmarks.sh +++ b/torchao/_models/llama/benchmarks.sh @@ -12,6 +12,12 @@ python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --co python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --compile --quantization int4wo-64 --write_result benchmark_results.txt python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --compile --compile_prefill --quantization autoquant --write_result benchmark_results.txt +# auto-round w/ quant_lm_head +python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --compile --quantization autoround +# auto-round w/o quant_lm_head +python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --compile --quantization autoround-cuda-0 + + export MODEL_REPO=meta-llama/Meta-Llama-3-8B python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --precision torch.float32 --write_result benchmark_results.txt python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --write_result benchmark_results.txt @@ -23,6 +29,12 @@ python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --co python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --compile --quantization int4wo-64 --write_result benchmark_results.txt python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --compile --compile_prefill --quantization autoquant --write_result benchmark_results.txt +# auto-round w/ quant_lm_head +python generate.py 
--checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --compile --quantization autoround +# auto-round w/o quant_lm_head +python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --compile --quantization autoround-cuda-0 + + export MODEL_REPO=meta-llama/Meta-Llama-3.1-8B python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --write_result benchmark_results.txt --cache_size 8192 python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --write_result benchmark_results.txt --cache_size 8192 --kv_cache_quantization diff --git a/torchao/_models/llama/generate.py b/torchao/_models/llama/generate.py index 94a18488b2..a559fc241c 100644 --- a/torchao/_models/llama/generate.py +++ b/torchao/_models/llama/generate.py @@ -30,7 +30,7 @@ def device_sync(device): wd = Path(__file__).parent.parent.resolve() sys.path.append(str(wd)) -from torchao._models.llama.model import Transformer, prepare_inputs_for_model +from torchao._models.llama.model import Transformer, prepare_inputs_for_model, TransformerBlock from torchao._models.llama.tokenizer import get_tokenizer def multinomial_sample_one_no_sync(probs_sort): # Does multinomial sampling without a cuda synchronization @@ -219,6 +219,53 @@ def main( groupsize=int(quantization.split("-")[-1]) assert groupsize in [32,64,128,256], f"int4wo groupsize needs to be one of [32,64,128,256] but got {groupsize}" quantize_(model, int4_weight_only(group_size=groupsize)) + + if "autoround" in quantization: + from torchao.prototype.autoround.autoround_llm import quantize_model_with_autoround_ + from transformers import AutoTokenizer + + _tokenizer = AutoTokenizer.from_pretrained(checkpoint_path.parent) + # parse args from quantization string: + # autoround------- + # A lightweight configuration for generation benchmarking. 
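+            # As parsed below, the expected format is
+            #   autoround-<model_device>-<quant_lm_head>-<iters>-<groupsize>-<batch_size>-<seqlen>-<nsamples>
+            # where every field after "autoround" is optional and falls back to `_default_quant_args`,
+            # e.g. "autoround-cuda-0" runs on cuda without quantizing the lm_head.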
+ _quant_args = quantization.split("-") + _default_quant_args = [True, 1, 128, 1, 512, 32] + _model_devie = _quant_args[1] if len(_quant_args) > 1 else device + _quant_args = _quant_args[2:] + quant_lm_head, iters, groupsize, batch_size, seqlen, nsamples = [ + int(x) for x in _quant_args + ] + _default_quant_args[len(_quant_args) :] + model = model.to(_model_devie) + print( + ( + f"Quantizing model with autoround(iters={iters}, groupsize={groupsize}, " + f"quant_lm_head={quant_lm_head}, batch_size={batch_size}, seqlen={seqlen}, nsamples={nsamples})" + ) + ) + with torch.device(_model_devie): + model.setup_caches( + max_batch_size=batch_size, max_seq_length=seqlen, training=True + ) + + if quant_lm_head: + is_target_module = ( + lambda mod, fqn: isinstance(mod, TransformerBlock) or "output" in fqn + ) + else: + is_target_module = lambda mod, fqn: isinstance(mod, TransformerBlock) + quantize_model_with_autoround_( + model=model, + tokenizer=_tokenizer, + is_target_module=is_target_module, + bits=4, + seqlen=seqlen, + bs=batch_size, + iters=iters, + nsamples=nsamples, + ) + model.to(device) + model.reset_caches() + if "fp6" in quantization: quantize_(model, fpx_weight_only(3, 2)) if "autoquant" == quantization: @@ -387,7 +434,7 @@ def callback(x): parser.add_argument('--top_k', type=int, default=200, help='Top-k for sampling.') parser.add_argument('--temperature', type=float, default=0.8, help='Temperature for sampling.') parser.add_argument('--checkpoint_path', type=Path, default=Path("../../../checkpoints/meta-llama/Llama-2-7b-chat-hf/model.pth"), help='Model checkpoint path.') - parser.add_argument('-q', '--quantization', type=str, help='Which quantization techniques to apply: int8dq, int8wo, int4wo-, autoquant') + parser.add_argument('-q', '--quantization', type=str, help='Which quantization techniques to apply: int8dq, int8wo, int4wo-, autoquant, autoround-------') parser.add_argument('--kv_cache_quantization', action='store_true', help='Whether to quantize the KV cache') parser.add_argument('--cache_size', type=int, default=None, help='Force size of cache to be a certain number of tokens, if not set, will use max_new_tokens+prompt_size') parser.add_argument('--linear_causal_mask', action='store_true', help='Whether to use the memory efficient, but slightly less fast, linear causal mask (important for long context lengths)') diff --git a/torchao/_models/llama/model.py b/torchao/_models/llama/model.py index ab3a51eef3..59b5cf98fd 100644 --- a/torchao/_models/llama/model.py +++ b/torchao/_models/llama/model.py @@ -190,7 +190,16 @@ def setup_caches(self, max_batch_size, max_seq_length, training: bool=False, kv_ dtype, use_scaled=self.config.use_scaled_rope ) + + def reset_caches(self): + """Reset caches. + The caches used by training stage and inference stage may be different, reset them before switching. + """ + self.max_batch_size = -1 + self.max_seq_length = -1 + self.freqs_cis: Optional[Tensor] = None + self.mask_cache: Optional[Tensor] = None def forward(self, idx: Tensor, input_pos: Optional[Tensor] = None) -> Tensor: assert self.freqs_cis is not None, "Caches must be initialized first" diff --git a/torchao/prototype/autoround/README.md b/torchao/prototype/autoround/README.md new file mode 100644 index 0000000000..11671009b5 --- /dev/null +++ b/torchao/prototype/autoround/README.md @@ -0,0 +1,104 @@ +# Auto-Round + +Auto-Round is an advanced quantization algorithm designed for low-bit LLM inference. 
It leverages [sign gradient descent](https://arxiv.org/abs/1905.12938) to fine-tune rounding values and minmax values of weights. This approach competes impressively with recent methods without introducing any additional inference overhead while using low tuning costs. This module provides the end-to-end examples to quantize floating-point models to low-bit and integration with torchao's `quantize_` API and low-bit kernels. + +## Usage + +### Quick Start + +```python +python autoround_llm.py -m /model/name/or/path +``` + + +> [!NOTE] +> Before running, ensure you have installed the `auto-round` with `pip install -r requirements.txt`. + + +### Detailed Usage + +`Auto-Round` is a calibration-based quantization algorithm. The flow involves three main steps: 1) insert hooks to the modules you want to quantize, 2) Wrap the calibration data with `MultiTensor` and run the model, 3) Replace the optimized weight with `AffineQuantizedTensor` to select the appropriate low-bit kernel. + +> [!NOTE] +> To learn more about the flow and `MultiTensor`, please refer to [this example](https://github.com/pytorch/ao/blob/main/tutorials/calibration_flow/gptq_like.py). + +#### Step 1: Prepare the Model +```python +model = ... # Load your model +model_device = next(model.parameters()).device +device = "cuda" if torch.cuda.is_available() else "cpu" + +# Define a function to identify target modules for quantization. +# For example, to apply Auto-Round to all decoder layers and the `lm-head` in a Llama model: +decoder_cls = transformers.models.llama.modeling_llama.LlamaDecoderLayer +is_target_module = lambda mod, fqn: isinstance(mod, decoder_cls) or "lm_head" in fqn +# Prepare the model for Auto-Round +from torchao.prototype.autoround.core import prepare_model_for_applying_auto_round_ + +prepare_model_for_applying_auto_round_( + model, + is_target_module=is_target_module, + bits=4, + group_size=128, + iters=200, + device=device, +) +``` +> [!NOTE] +> To avoid OOM issues, load the model on CPU, and set `device` to `'cuda'`. + +#### Step 2: Apply Optimization +Wrap all inputs as a `MultiTensor` to track all calibration data for optimized modules: + +```python +input_ids_lst = [] +for data in dataloader: + input_ids_lst.append(data["input_ids"].to(model_device)) + +multi_t_input_ids = MultiTensor(input_ids_lst) +# The optimization is applied during the forward pass +out = model(multi_t_input_ids) +``` +#### Step 3: Finalize Quantization +After obtaining optimized `zero_point` and `scale` values, create the `AffineQuantizedTensor` +for each target weight to select the right low-bits kernel. + +```python +from torchao.prototype.autoround.core import apply_auto_round + +quantize_(model, apply_auto_round(), is_target_module) +``` + +## End-to-End Results +### [meta-llama/Meta-Llama-3.1-8B-Instruct](https://huggingface.co/meta-llama/Meta-Llama-3.1-8B-Instruct) +| | Avg. | Mmlu | Piqa | Winogrande | Hellaswag | Lambada_openai | +| -------------- | ------- | ------ | ------ | ---------- | --------- | -------------- | +| bf16 | 0.7080 | 0.6783 | 0.8003 | 0.7403 | 0.5910 | 0.7303 | +| auto-round-4bit | 0.6988 | 0.6533 | 0.7949 | 0.7372 | 0.5837 | 0.7250 | +| torchao-int4wo | 0.6883 | 0.6363 | 0.7938 | 0.7348 | 0.5784 | 0.6980 | + +### [meta-llama/Meta-Llama-3-8B-Instruct](https://huggingface.co/meta-llama/Meta-Llama-3-8B-Instruct) +| | Avg. 
| Mmlu | Piqa | Winogrande | Hellaswag | Lambada_openai | +| -------------- | ------- | ------ | ------ | ---------- | --------- | -------------- | +| bf16 | 0.6881 | 0.6389 | 0.7840 | 0.7222 | 0.5772 | 0.7184 | +| auto-round-4bit | 0.6818 | 0.6232 | 0.7862 | 0.7230 | 0.5661 | 0.7105 | +| torchao-int4wo | 0.6728 | 0.5939 | 0.7737 | 0.7222 | 0.5612 | 0.7132 | + + +### [meta-llama/Llama-2-7b-chat-hf](https://huggingface.co/meta-llama/Llama-2-7b-chat-hf) +| | Avg. | Mmlu | Piqa | Winogrande | Hellaswag | Lambada_openai | +| -------------- | ------- | ------ | ------ | ---------- | --------- | -------------- | +| bf16 | 0.6347 | 0.4647 | 0.7644 | 0.6606 | 0.577 | 0.7070 | +| auto-round-4bit | 0.6327 | 0.4534 | 0.7590 | 0.6661 | 0.5706 | 0.7143 | +| torchao-int4wo | 0.6252 | 0.4427 | 0.7617 | 0.6654 | 0.5674 | 0.6889 | + +> [!NOTE] +> - `auto-round-4bit` represents the following configuration: `bits=4`, `iters=200`, `seqlen=2048`, `train_bs=8`, `group_size=128`, and `quant_lm_head=False`.
+> - `torchao-int4wo` represents `int4_weight_only(group_size=128)` and `quant_lm_head=False`. +> - If the model includes operations without a deterministic implementation (such as Flash Attention), the results may differ slightly. + + +## Credits + +- Paper: https://arxiv.org/abs/2309.05516 +- Authors: [IntelĀ® Neural Compressor Team](https://github.com/intel/neural-compressor) diff --git a/torchao/prototype/autoround/__init__.py b/torchao/prototype/autoround/__init__.py new file mode 100644 index 0000000000..4c2ff65adb --- /dev/null +++ b/torchao/prototype/autoround/__init__.py @@ -0,0 +1,5 @@ +from torchao.prototype.autoround.core import ( + apply_auto_round, + prepare_model_for_applying_auto_round_, +) +from torchao.prototype.autoround.multi_tensor import MultiTensor diff --git a/torchao/prototype/autoround/autoround_llm.py b/torchao/prototype/autoround/autoround_llm.py new file mode 100644 index 0000000000..2d464be0f4 --- /dev/null +++ b/torchao/prototype/autoround/autoround_llm.py @@ -0,0 +1,180 @@ +import argparse +import logging + +import torch + +import torchao +import torchao.prototype.autoround.utils as ar_utils + +from torchao.prototype.autoround.core import ( + apply_auto_round, + prepare_model_for_applying_auto_round_, +) +from torchao.prototype.autoround.multi_tensor import MultiTensor +from torchao.quantization import quantize_ + +ar_utils.freeze_random(42) + + +@torch.no_grad() +def quantize_model_with_autoround_( + model, + tokenizer, + is_target_module, + bits: int = 4, + group_size: int = 128, + iters: int = 200, + seqlen: int = 2048, + dataset_name: str = "NeelNanda/pile-10k", + bs: int = 8, + nsamples: int = 128, + use_optimized_layer_output: bool = False, +): + # Step 1. Prepare the model for applying auto-round + + model_device = next(model.parameters()).device + device = "cuda" if torch.cuda.is_available() else "cpu" + + prepare_model_for_applying_auto_round_( + model, + is_target_module, + bits, + group_size, + iters, + use_optimized_layer_output, + device=device, + ) + + # Step 2. Caliration and optimization + dataloader = ar_utils.import_dataloader()( + tokenizer, + seqlen=seqlen, + dataset_name=dataset_name, + bs=bs, + nsamples=nsamples, + ) + input_ids_lst = [] + for data in dataloader: + input_ids_lst.append(data["input_ids"].to(model_device)) + print( + f"Number of batches: {len(input_ids_lst)}, shape of all batches: {[inp.shape for inp in input_ids_lst]}" + ) + + multi_t_input_ids = MultiTensor(input_ids_lst) + + # The optimization is applied during the forward pass + out = model(multi_t_input_ids) + + # Step 3. Apply the quantization + quantize_(model, apply_auto_round(), is_target_module, device=device) + + num_quantized_weight = ar_utils.count_tensor_of_type( + model, torchao.dtypes.AffineQuantizedTensor + ) + print(f"Quantized {num_quantized_weight} Linear layers.") + + return model + + +def main(args): + # Get the model, tokenizer, and decoder_cls + model_name_or_path = args.model_name_or_path + model, tokenizer, decoder_cls = ar_utils.get_float_model_info( + model_name_or_path, torch_dtype=torch.bfloat16 + ) + # Disable the `use_cache` for calibration stage. + model.config.use_cache = False + ar_utils.gen_text(model, tokenizer, "Float model", max_length=50) + + model = model.to(args.model_device) + + # User need to prepare a `is_target_module` function for identifying the target modules that need to be quantized. 
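+    # `is_target_module(mod, fqn) -> bool` should return True for each module that Auto-Round
+    # optimizes as a unit: typically every decoder block, plus the `lm_head` when it is quantized.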
+ if args.quant_lm_head: + is_target_module = ( + lambda mod, fqn: isinstance(mod, decoder_cls) or "lm_head" in fqn + ) + else: + is_target_module = lambda mod, fqn: isinstance(mod, decoder_cls) + + quantize_model_with_autoround_( + model=model, + tokenizer=tokenizer, + is_target_module=is_target_module, + bits=args.bits, + iters=args.iters, + seqlen=args.seqlen, + dataset_name=args.dataset_name, + bs=args.train_bs, + nsamples=args.nsamples, + use_optimized_layer_output=args.use_optimized_layer_output, + ) + # Revert the `use_cache` for generation stage. + model.config.use_cache = True + + # Generate text using the quantized model + ar_utils.gen_text(model, tokenizer, "Quantized model", max_length=50) + + +if __name__ == "__main__": + parser = argparse.ArgumentParser( + formatter_class=argparse.ArgumentDefaultsHelpFormatter + ) + parser.add_argument( + "-m", + "--model_name_or_path", + type=str, + default="facebook/opt-125m", + help="Model name or path", + ) + parser.add_argument( + "--dataset_name", + type=str, + default="NeelNanda/pile-10k", + help="Dataset name for calibration", + ) + parser.add_argument( + "--iters", + default=200, + type=int, + help="Number of iterations for auto-round optimization", + ) + parser.add_argument( + "--bits", default=4, type=int, help="Number of bits for quantization" + ) + parser.add_argument( + "--train_bs", default=8, type=int, help="Batch size for auto-round optimization" + ) + parser.add_argument( + "--nsamples", + default=128, + type=int, + help="Number of samples for calibration process", + ) + parser.add_argument( + "--seqlen", + default=2048, + type=int, + help="Sequence length for calibration process", + ) + parser.add_argument( + "--quant_lm_head", + default=False, + action="store_true", + help="Quantize the `lm_head` or not", + ) + parser.add_argument( + "--use_optimized_layer_output", + default=False, + action="store_true", + help="Use the optimized layer output for next layer or not", + ) + parser.add_argument( + "-d", + "--model_device", + default="cuda", + type=str, + choices=["cpu", "cuda"], + help="Device for loading the float model", + ) + args = parser.parse_args() + main(args) diff --git a/torchao/prototype/autoround/core.py b/torchao/prototype/autoround/core.py new file mode 100644 index 0000000000..342f14d825 --- /dev/null +++ b/torchao/prototype/autoround/core.py @@ -0,0 +1,352 @@ +import dataclasses +import logging +from typing import Any, Callable, Dict, Optional, Tuple + +import torch +from torch.utils._pytree import tree_flatten, tree_unflatten + +import torchao.prototype.autoround.utils as ar_utils +import torchao.quantization as ao_quant +from torchao.dtypes import TensorCoreTiledLayoutType, to_affine_quantized_intx_static +from torchao.prototype.autoround.multi_tensor import _multi_tensor_config, MultiTensor +from torchao.quantization.quant_primitives import ZeroPointDomain +from torchao.utils import find_multiple + + +@ar_utils.singleton +@dataclasses.dataclass +class _AutoRoundConfig: + bits: int = 4 + group_size: int = 128 + iters: int = 200 + use_optimized_layer_output: bool = False + + +_auto_round_config = _AutoRoundConfig() + + +@ar_utils.singleton +@dataclasses.dataclass +class _OptimizationTracker: + num_layers: int = 0 + optimized_layers: int = 0 + + def reset(self): + self.num_layers = 0 + self.optimized_layers = 0 + + +_optimization_tracker = _OptimizationTracker() + + +def _replace_model_buffers_and_params(model, replacement_fn): + model = replacement_fn(model) + for name, child in model.named_children(): + 
new_child = _replace_model_buffers_and_params(child, replacement_fn) + if new_child is not child: + setattr(model, name, new_child) + return model + + +def _tensor_to_multi_tensor(model): + for name, buf in model.named_buffers(recurse=False): + setattr(model, name, MultiTensor([buf])) + for name, param in model.named_parameters(recurse=False): + setattr(model, name, torch.nn.Parameter(MultiTensor([param]), False)) + return model + + +def _multi_tensor_to_tensor(model): + for name, buf in model.named_buffers(recurse=False): + if isinstance(buf, MultiTensor): + assert ( + len(buf.values) == 1 + ), f"The buffer should only have one tensor, but got {buf.count}." + model.register_buffer(name, buf.values[0]) + for name, param in model.named_parameters(recurse=False): + if isinstance(param, MultiTensor): + assert ( + len(param.values) == 1 + ), f"The parameter should only have one tensor, but got {param.count}." + setattr( + model, name, torch.nn.Parameter(param.values[0], requires_grad=False) + ) + return model + + +@torch.no_grad() +def prepare_model_for_applying_auto_round_( + model: torch.nn.Module, + is_target_module: Callable[[torch.nn.Module, str], bool], + bits: int = 4, + group_size: int = 128, + iters: int = 200, + use_optimized_layer_output: bool = False, + device: Optional[torch.types.Device] = None, +): + """Prepares the model for applying auto round optimization. + + Args: + model (torch.nn.Module): The floating-point model to be quantized. + is_target_module (Callable[[torch.nn.Module, str], bool]): A function that determines + whether a module is a target module. + bits (int, optional): The number of bits for quantization. Defaults to 4, options are 1 to 8. + group_size (int, optional): The group size for quantization. Defaults to 128. + iters (int, optional): The number of iterations for optimization. Defaults to 200. + use_optimized_layer_output (bool, optional): Whether to use optimized layer output. Defaults to False. + device (Optional[torch.types.Device], optional): The device to use for accelrating optimization and calibration. + Defaults to None. 
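+
+    Example (a minimal sketch; `DecoderLayer` stands in for the model's decoder block class):
+
+        >>> is_target_module = lambda mod, fqn: isinstance(mod, DecoderLayer)
+        >>> prepare_model_for_applying_auto_round_(
+        ...     model, is_target_module, bits=4, group_size=128, iters=200, device="cuda"
+        ... )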
+ """ + _multi_tensor_config.device = device + _multi_tensor_config.offload = next(model.parameters()).device.type != device + _optimization_tracker.reset() + + _auto_round_config.bits = bits + _auto_round_config.group_size = group_size + _auto_round_config.iters = iters + _auto_round_config.use_optimized_layer_output = use_optimized_layer_output + + logging.warning(f"config {_auto_round_config}") + + # Wrap the model buffers and parameters with `MultiTensor` + model = _replace_model_buffers_and_params(model, _tensor_to_multi_tensor) + + def _revert_buffers_and_params_fn( + module, + input: Tuple[MultiTensor], + output: Tuple[MultiTensor], + ): + module._forward_hook_handle_for_revert_buffers_and_params.remove() + _replace_model_buffers_and_params(module, _multi_tensor_to_tensor) + return output + + # Register forward hook for reverting the replacement of buffers and parameters + model._forward_hook_handle_for_revert_buffers_and_params = ( + model.register_forward_hook(_revert_buffers_and_params_fn) + ) + + # Register forward hook for applying auto-round optimization + def auto_round_optimization_hook( + module, + args: Tuple[MultiTensor], + kwargs: Dict[str, MultiTensor], + output: Tuple[MultiTensor], + ): + apply_auto_round_optimization( + module, args, kwargs, output, config=_auto_round_config + ) + return output + + def _register_forward_hook(module: torch.nn.Module): + forward_hook_handle = module.register_forward_hook( + auto_round_optimization_hook, with_kwargs=True + ) + module._forward_hook_handle_for_auto_round = forward_hook_handle + _optimization_tracker.num_layers += 1 + return module + + model.eval() + ao_quant.quant_api._replace_with_custom_fn_if_matches_filter( + model, _register_forward_hook, is_target_module + ) + + +def apply_auto_round(): + """Create the quantized model from the model optimized by auto-round. + + More details about the auto-round can be found at https://arxiv.org/abs/2309.05516. + """ + + def _apply_auto_round(optimized_model: torch.nn.Module): + """ + The `optimized_model` includes `Linear` layers optimized by auto-round, which includes `qdq_weight`, `scale`, `zp`. + """ + + @torch.no_grad() + def convert_weight_to_affine_quantized_tensor(observed_linear: torch.nn.Module): + device = observed_linear.weight.device + scale = observed_linear.scale.to(device) + zero_point = observed_linear.zp.to(device) + + def to_uintx_weight(input_float): + quant_min = 0 + quant_max = _auto_round_config.bits**2 - 1 + block_size = (1, observed_linear.group_size) + from torchao.dtypes.uintx.Uintx import ( + _BIT_WIDTH_TO_DTYPE, + UintxLayoutType, + ) + from torchao.quantization.quant_primitives import ZeroPointDomain + + assert ( + _auto_round_config.bits in _BIT_WIDTH_TO_DTYPE + ), f"Invalid bits: {_auto_round_config.bits}" + dtype = _BIT_WIDTH_TO_DTYPE[_auto_round_config.bits] + pack_dim = -1 + layout_type = UintxLayoutType(dtype=dtype, pack_dim=pack_dim) + return to_affine_quantized_intx_static( + input_float=input_float, + scale=scale.to(input_float.dtype), + zero_point=zero_point, + block_size=block_size, + target_dtype=torch.uint8, + quant_min=quant_min, + quant_max=quant_max, + zero_point_domain=ZeroPointDomain.INT, + layout_type=layout_type, + ) + + def to_int4_tinygemm_weight(input_float): + # TODO(Yi): check the weight shape, `group_size`, and `inner_k_tiles` to make sure the tinygemm can handle it + inner_k_tiles = 8 + quant_min = 0 + quant_max = _auto_round_config.bits**2 - 1 + # Shift the `zero_point` to align with tiny gemm. 
+ # The dequantization process in tiny gemm: + # tiny_dequant = (tiny_quant - 8) * scale + tiny_zp + # The dequantization porcess in auto-round + # dequant = (quant - zp) * scale + # To align with tiny gemm: + # dequant = (quant - 8 + 8 - zp) * scale + # = (quant - 8) * scale + (8 - zp) * scale + # \__/ \______________/ + # tiny_quant tiny_zp + mid_point = (quant_max + quant_min + 1) / 2 + shifted_zero_point = (mid_point - zero_point) * scale + block_size = (1, observed_linear.group_size) + orig_out_features, orig_in_features = input_float.shape + in_features = find_multiple(orig_in_features, 1024) + out_features = find_multiple(orig_out_features, 8) + orig_num_groups = orig_in_features // observed_linear.group_size + new_num_groups = in_features // observed_linear.group_size + # pad scale/zero_point from [orig_out_features, orig_num_groups] to [out_features, new_num_groups] + pad_scale = torch.nn.functional.pad( + scale, + ( + 0, + new_num_groups - orig_num_groups, + 0, + out_features - orig_out_features, + ), + ) + pad_shifted_zero_point = torch.nn.functional.pad( + shifted_zero_point, + ( + 0, + new_num_groups - orig_num_groups, + 0, + out_features - orig_out_features, + ), + ) + return to_affine_quantized_intx_static( + input_float=input_float, + scale=pad_scale.to(torch.bfloat16), + zero_point=pad_shifted_zero_point.to(torch.bfloat16), + block_size=block_size, + target_dtype=torch.int32, + quant_min=quant_min, + quant_max=quant_max, + zero_point_domain=ZeroPointDomain.FLOAT, + layout_type=TensorCoreTiledLayoutType(inner_k_tiles=inner_k_tiles), + ) + + # TODO(Yi): better way to select the weight quantization function + if ( + _auto_round_config.bits == 4 + and observed_linear.weight.device.type == "cuda" + ): + weight_func = to_int4_tinygemm_weight + else: + weight_func = to_uintx_weight + + observed_linear.weight = torch.nn.Parameter( + weight_func(observed_linear.weight), requires_grad=False + ) + del observed_linear.scale + del observed_linear.zp + return observed_linear + + def _is_observed_linear(mod: torch.nn.Module, fqn: str): + return hasattr(mod, "scale") + + qmodel = ao_quant.quant_api._replace_with_custom_fn_if_matches_filter( + optimized_model, + convert_weight_to_affine_quantized_tensor, + _is_observed_linear, + ) + return qmodel + + return _apply_auto_round + + +@torch.no_grad() +def _apply_auto_round_optimization( + block, block_inputs, block_outputs, config: _AutoRoundConfig +): + # Call the auto-round to execute the optimization process. + # https://github.com/intel/auto-round/tree/patch-for-ao-2 + # TODO(Yi), make the branch more stable + if ar_utils.is_auto_round_available(): + import auto_round + else: + raise ImportError( + ( + "This example requires the `auto-round` library." + "Please install it with `pip install git+https://github.com/intel/auto-round.git@patch-for-ao-2`" + ) + ) + orig_device = next(block.parameters()).device + block = block.to(_multi_tensor_config.device) + _optimization_tracker.optimized_layers += 1 + logging.warning( + "Apply auto-round optimization on layer %d / %d.", + _optimization_tracker.optimized_layers, + _optimization_tracker.num_layers, + ) + + # Start the training process to update the v, alpha and beta. 
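+    # Per the Auto-Round paper (https://arxiv.org/abs/2309.05516), `v` learns a per-weight
+    # rounding perturbation while `alpha`/`beta` tune the min/max clipping range; all are
+    # updated with sign gradient descent.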
+ rounder = auto_round.AutoRound( + model=block, + tokenizer=None, + sym=False, + bits=config.bits, + iters=config.iters, + group_size=config.group_size, + amp=True, + model_dtype=next(block.parameters()).dtype, + ) + + with torch.enable_grad(): + rounder.quant_block_v2_( + block, + inputs=block_inputs, + outputs=block_outputs, + device=_multi_tensor_config.device, + ) + block.to(orig_device) + + +@ar_utils.dump_elapsed_time() +@torch.no_grad() +def apply_auto_round_optimization( + module: torch.nn.Module, + args: Tuple[MultiTensor], + kwargs: Dict[str, Any], + output: Any, + config: _AutoRoundConfig, +): + # Remove the hook to avoid recursive calls + module._forward_hook_handle_for_auto_round.remove() + # Revert the model to the original state for applying auto-round optimization + module = _replace_model_buffers_and_params(module, _multi_tensor_to_tensor) + + block_inputs = MultiTensor.revert_to_tensor_pairs(args, kwargs) + block_outputs = MultiTensor.revert_to_tensor_pairs(output) + + _apply_auto_round_optimization(module, block_inputs, block_outputs, config) + # Get the new output of the optimized model + if config.use_optimized_layer_output: + # Re-replace the model buffers and parameters with `MultiTensor` + _replace_model_buffers_and_params(module, _tensor_to_multi_tensor) + output = module(*args, **kwargs) + return output diff --git a/torchao/prototype/autoround/eval_autoround.py b/torchao/prototype/autoround/eval_autoround.py new file mode 100644 index 0000000000..e72205aba9 --- /dev/null +++ b/torchao/prototype/autoround/eval_autoround.py @@ -0,0 +1,222 @@ +import argparse + +import torchao.prototype.autoround.utils as ar_utils + +ar_utils.freeze_random(42) +import torch + +torch.use_deterministic_algorithms(True, warn_only=True) +import torchao + +import torchao.quantization +from torchao.utils import TORCH_VERSION_AT_LEAST_2_5 + + +@ar_utils.dump_elapsed_time() +def run_evaluation(model, tokenizer, tasks, compile=False, batch_size=4): + try: + from lm_eval.evaluator import evaluate + from lm_eval.models.huggingface import HFLM + from lm_eval.tasks import get_task_dict + except ImportError as e: + print( + """ + Error: The 'lm_eval' module was not found. 
+ To install, follow these steps: + pip install git+https://github.com/EleutherAI/lm-evaluation-harness.git + """ + ) + raise # Re-raise the ImportError + + with torch.no_grad(): + result = evaluate( + HFLM(pretrained=model, tokenizer=tokenizer, batch_size=batch_size), + get_task_dict(tasks), + ) + torch.cuda.empty_cache() + from lm_eval.utils import make_table + + print(make_table(result)) + + +def bench_accuracy(model, tokenizer, tasks, msg=""): + with torch.no_grad(): + print(f"==================== {msg} ====================") + print(f"tasks: {tasks}") + from torchao.prototype.autoround.hf_eval_utils import run_evaluation + + torch.cuda.empty_cache() + res = run_evaluation(model, tokenizer, tasks=tasks) + torch.cuda.empty_cache() + + +def _is_linear_but_not_lm_head(mod, fqn): + return isinstance(mod, torch.nn.Linear) and "lm_head" not in fqn + + +def main(args): + with torch.no_grad(): + model_name_or_path = args.model_name_or_path + model, tokenizer, decoder_cls = ar_utils.get_float_model_info( + model_name_or_path, torch_dtype=torch.bfloat16 + ) + model.eval() + model_device = args.model_device + ar_utils.gen_text(model, tokenizer, "Float model", max_length=50) + model = model.to(model_device) + model.config.use_cache = False + msg = "Float-model" if args.eval_float_model else "Quantized-model" + if not args.eval_float_model: + filter_fn = None if args.quant_lm_head else _is_linear_but_not_lm_head + # Evaluate the quantized model + if args.woq_int4: + msg += " (int4wo)" + from torchao.quantization import int4_weight_only, quantize_ + + quantize_( + model, + int4_weight_only(group_size=args.group_size), + filter_fn=filter_fn, + device=model_device, + ) + elif args.uintx: + msg += f" (uintx {args.bits} bits)" + from torchao.dtypes.uintx.Uintx import _BIT_WIDTH_TO_DTYPE + from torchao.quantization.quant_api import quantize_, uintx_weight_only + + bits = args.bits + assert bits in _BIT_WIDTH_TO_DTYPE, f"Invalid bits: {bits}" + dtype = _BIT_WIDTH_TO_DTYPE[bits] + quantize_( + model, + uintx_weight_only(dtype=dtype, group_size=args.group_size), + filter_fn=filter_fn, + device=model_device, + ) + + else: + msg += f" (auto-round {args.bits} bits)" + torch.cuda.empty_cache() + from torchao.prototype.autoround.autoround_llm import ( + quantize_model_with_autoround_, + ) + + # User need to prepare a `is_target_module` function for identifying the target modules that need to be quantized. 
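+                # `decoder_cls` was auto-detected by `ar_utils.get_float_model_info`; the filter
+                # mirrors the one used in `autoround_llm.py`.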
+ if args.quant_lm_head: + is_target_module = ( + lambda mod, fqn: isinstance(mod, decoder_cls) + or "lm_head" in fqn + ) + else: + is_target_module = lambda mod, fqn: isinstance(mod, decoder_cls) + + model = quantize_model_with_autoround_( + model=model, + tokenizer=tokenizer, + is_target_module=is_target_module, + bits=args.bits, + group_size=args.group_size, + iters=args.iters, + seqlen=args.seqlen, + bs=args.train_bs, + nsamples=args.nsamples, + use_optimized_layer_output=args.use_optimized_layer_output, + ) + quantized_layer_cnt = ar_utils.count_tensor_of_type( + model, torchao.dtypes.AffineQuantizedTensor + ) + msg += f" quantized {quantized_layer_cnt} Linear layers " + ar_utils.gen_text(model, tokenizer, msg, max_length=50) + + bench_accuracy(model, tokenizer, tasks=args.tasks, msg=msg) + + +if __name__ == "__main__" and TORCH_VERSION_AT_LEAST_2_5 and torch.cuda.is_available(): + parser = argparse.ArgumentParser( + formatter_class=argparse.ArgumentDefaultsHelpFormatter + ) + parser.add_argument( + "-m", + "--model_name_or_path", + type=str, + default="facebook/opt-125m", + help="Model name or path", + ) + parser.add_argument( + "--iters", + default=200, + type=int, + help="Number of iterations for auto-round optimization", + ) + parser.add_argument( + "--bits", default=4, type=int, help="Number of bits for quantization" + ) + parser.add_argument( + "--train_bs", default=8, type=int, help="Batch size for auto-round optimization" + ) + parser.add_argument( + "--nsamples", + default=128, + type=int, + help="Number of samples for calibration process", + ) + parser.add_argument( + "--group_size", + default=128, + type=int, + help="Group size for quantization", + ) + parser.add_argument( + "--seqlen", + default=2048, + type=int, + help="Sequence length for calibration process", + ) + parser.add_argument( + "--quant_lm_head", + default=False, + action="store_true", + help="Quantize the `lm_head` or not", + ) + parser.add_argument( + "--use_optimized_layer_output", + default=False, + action="store_true", + help="Use the optimized layer output for next layer or not", + ) + parser.add_argument( + "-d", + "--model_device", + default="cuda", + type=str, + choices=["cpu", "cuda"], + help="Device for loading the float model", + ) + parser.add_argument( + "--eval_float_model", + default=False, + action="store_true", + help="Evaluate the float model", + ) + parser.add_argument( + "--woq_int4", + default=False, + action="store_true", + help="Quantize the model with int4 weight only", + ) + parser.add_argument( + "--uintx", + default=False, + action="store_true", + help="Quantize the model with int4 weight only", + ) + parser.add_argument( + "--tasks", + nargs="+", + type=str, + default=["wikitext"], + help="List of lm-eluther tasks to evaluate usage: --tasks task1 task2", + ) + args = parser.parse_args() + + main(args) diff --git a/torchao/prototype/autoround/multi_tensor.py b/torchao/prototype/autoround/multi_tensor.py new file mode 100644 index 0000000000..72af4d8546 --- /dev/null +++ b/torchao/prototype/autoround/multi_tensor.py @@ -0,0 +1,171 @@ +import dataclasses +from typing import List + +import torch +from torch.utils._pytree import tree_flatten, tree_unflatten + + +@dataclasses.dataclass +class _MultiTensorConfig: + device: str = "cuda" if torch.cuda.is_available() else "cpu" + ops_to_accelerate: List[str] = dataclasses.field( + default_factory=lambda: [ + torch.nn.functional.linear, + torch.matmul, + torch.bmm, + torch.nn.functional.scaled_dot_product_attention, + ] + ) + offload: bool = 
False + + +# Note: As the `MultiTensor` includes a list of tensors, during the calibration stage, +# placing all output tensors on the GPU would consume a significant amount of GPU memory. +# This is especially true for models with a large `lm-head`, such as Llama-3.1. +# In these cases, we load the model onto the DRAM and only transfer tensors to the GPU for compute-intensive operations. +_multi_tensor_config = _MultiTensorConfig() + + +class MultiTensor(torch.Tensor): + # Modified from https://gist.github.com/HDCharles/a1b575bbf8875f994af8a01b225e1227 + @staticmethod + def __new__(cls, input, **kwargs): + if isinstance(input, (list, tuple)): + input = input[0] + kwargs["dtype"] = kwargs.get("dtype", input.dtype) + shape = kwargs.pop("shape", input.shape) + return torch.Tensor._make_wrapper_subclass(cls, shape, **kwargs) + + def __init__(self, input, **kwargs): + self.values = [] + self.count = 0 + self.add_tensors(input) + self.debug = True + + def __repr__(self): + return f"{self.__class__.__name__}(data={self.values})" + + def add_tensors(self, input): + if isinstance(input, (tuple, list)): + for inp in input: + self.add_tensors(inp) + else: + assert isinstance( + input, torch.Tensor + ), f"MultiTensor can only use add_tensors for Tensors or lists of tensors but got {type(input)}" + self.count += 1 + self.values.append(input) + return self + + def get_value(self, i): + # Instead of copy the last tensor to pad the MultiTensor, we use this function to do fake padding + # to avoid introducing extra memory usage. + if i + 1 <= self.count: + return self.values[i] + else: + return self.values[-1] + + @classmethod + def flat_to_grouped(cls, flat): + # size of biggest MultiTensor + multi_tensor_size = max( + [x.count if isinstance(x, MultiTensor) else 1 for x in flat] + ) + + grouped = [] + for i in range(multi_tensor_size): + sub_group = [] + for x in flat: + if isinstance(x, MultiTensor): + sub_group.append(x.get_value(i)) + else: + sub_group.append(x) + grouped.append(sub_group) + return grouped + + @classmethod + def grouped_to_flat(cls, grouped): + # convert [[A,b1,c1], [A,b2,c2] [A,b3,c3]] => [A, MultiTensor(b1,b2,b3), MultiTensor(c1,c2,c3)] + # where A is nontensor, b's,c's are tensors + # convert [[A,b1,c1], [A,b2,c2] [A,b3,c3]] => [(A,A,A), (b1,b2,b3), (c1,c2,c3)] + flat_tups = list(zip(*grouped)) + # convert [(A,A,A), (b1,b2,b3), (c1,c2,c3)] => [A, MultiTensor(b1,b2,b3), MultiTensor(c1,c2,c3)] + flattened = [ + cls(tup) if isinstance(tup[0], torch.Tensor) else tup[0] + for tup in flat_tups + ] + # need to check that getting rid of all but one from each nonTensor tuple is OK + non_tensors_equal = min( + [True] + + [ + min( + [True] + + [ # handle situation where tuples have size 0 + tup[0] == x for x in tup # check all elements match + ] + ) + for tup in flat_tups + if not isinstance(tup[0], torch.Tensor) # look at tuples of nonTensors + ] + ) + return flattened, non_tensors_equal + + @classmethod + def revert_to_tensor_pairs(cls, args, kwargs=None): + if kwargs is None: + kwargs = {} + # combine args and kwargs and remove lists and tuples + flat_args, spec = tree_flatten((args, kwargs)) + # convert [A, MultiTensor(b1,b2,b3), MultiTensor(c1,c2,c3)] => [[A,b1,c1], [A,b2,c2] [A,b3,c3]] + grouped_args = cls.flat_to_grouped(flat_args) + args_kwargs_pairs = [] + for i, inp in enumerate(grouped_args): + cur_args, cur_kwargs = tree_unflatten(inp, spec) + args_kwargs_pairs.append((cur_args, cur_kwargs)) + return args_kwargs_pairs + + @classmethod + def __torch_function__(cls, func, types, 
args=(), kwargs=None): + args_kwargs_pairs = cls.revert_to_tensor_pairs(args, kwargs) + # run function for each of the multitensors and return a multitensor + outputs = [] + with torch._C.DisableTorchFunctionSubclass(): + for cur_args, cur_kwargs in args_kwargs_pairs: + if func in _multi_tensor_config.ops_to_accelerate: + device = _multi_tensor_config.device + cur_args = [ + (arg.to(device) if isinstance(arg, torch.Tensor) else arg) + for arg in cur_args + ] + cur_kwargs = { + k: (v.to(device) if isinstance(v, torch.Tensor) else v) + for k, v in cur_kwargs.items() + } + out = func(*cur_args, **cur_kwargs) + offload = _multi_tensor_config.offload + outputs.append( + out.to("cpu") if isinstance(out, torch.Tensor) and offload else out + ) + grouped_outputs = [tree_flatten(x)[0] for x in outputs] + out_spec = tree_flatten(outputs[0])[1] + # convert [[A,b1,c1], [A,b2,c2] [A,b3,c3]] => [A, MultiTensor(b1,b2,b3), MultiTensor(c1,c2,c3)] + flat_outputs, non_tensors_equal = cls.grouped_to_flat(grouped_outputs) + assert non_tensors_equal, ( + f"ERR: found a function in model: {func} which " + "caused an error in MultiTensor, the function dispatch only works for functions" + " with Tensor outputs or that have the same non-Tensor output value for all across all inputs" + ) + return tree_unflatten(flat_outputs, out_spec) + + @classmethod + def __torch_dispatch__(cls, func, types, args=(), kwargs={}): + pass + + def __tensor_flatten__(self): + return ["values"], None + + @classmethod + def __tensor_unflatten__( + cls, tensor_data_dict, tensor_attributes, outer_size, outer_stride + ): + return cls(tensor_data_dict["values"]) diff --git a/torchao/prototype/autoround/requirements.txt b/torchao/prototype/autoround/requirements.txt new file mode 100644 index 0000000000..e18bdd0829 --- /dev/null +++ b/torchao/prototype/autoround/requirements.txt @@ -0,0 +1,4 @@ +auto_round @ git+https://github.com/intel/auto-round.git@patch-for-ao-2 +numpy < 2.0 # dataset requires numpy < 2.0, can be removed after dataset is updated +datasets # for loading the calibration dataset +transformers # for loading the model \ No newline at end of file diff --git a/torchao/prototype/autoround/run_example.sh b/torchao/prototype/autoround/run_example.sh new file mode 100644 index 0000000000..9ae8956ebf --- /dev/null +++ b/torchao/prototype/autoround/run_example.sh @@ -0,0 +1,17 @@ +# Run examples +python autoround_llm.py -m meta-llama/Llama-2-7b-chat-hf +python autoround_llm.py -m meta-llama/Llama-2-7b-chat-hf --quant_lm_head +python autoround_llm.py -m meta-llama/Meta-Llama-3-8B-Instruct --model_device cpu +python autoround_llm.py -m meta-llama/Meta-Llama-3.1-8B-Instruct --model_device cpu + +# Evaluate with lm-eval +# Auto-round +python eval_autoround.py -m meta-llama/Llama-2-7b-chat-hf --tasks wikitext lambada_openai hellaswag winogrande piqa mmlu +python eval_autoround.py -m meta-llama/Meta-Llama-3-8B-Instruct --model_device cpu --tasks wikitext lambada_openai hellaswag winogrande piqa mmlu +python eval_autoround.py -m meta-llama/Meta-Llama-3.1-8B-Instruct --model_device cpu --tasks wikitext lambada_openai hellaswag winogrande piqa mmlu +# wo_int4 +python eval_autoround.py -m meta-llama/Llama-2-7b-chat-hf --woq_int4 --tasks wikitext lambada_openai hellaswag winogrande piqa mmlu +python eval_autoround.py -m meta-llama/Meta-Llama-3-8B-Instruct --woq_int4 --tasks wikitext lambada_openai hellaswag winogrande piqa mmlu +python eval_autoround.py -m meta-llama/Meta-Llama-3.1-8B-Instruct --woq_int4 --tasks wikitext lambada_openai hellaswag 
winogrande piqa mmlu +# uintx +python eval_autoround.py -m /models/Meta-Llama-3.1-8B-Instruct/ --uintx --bits 2 --tasks wikitext lambada_openai hellaswag winogrande piqa mmlu \ No newline at end of file diff --git a/torchao/prototype/autoround/utils.py b/torchao/prototype/autoround/utils.py new file mode 100644 index 0000000000..ea62b2e34d --- /dev/null +++ b/torchao/prototype/autoround/utils.py @@ -0,0 +1,178 @@ +# ==------------------------------------------------------------------------------------------== +# Utils for the auto-round +# ==------------------------------------------------------------------------------------------== +import logging +import random + +import numpy as np +import torch + + +def _is_package_available(pkg_name, metadata_name=None): + # Copied from Accelerate https://github.com/huggingface/accelerate + import importlib + + # Check we're not importing a "pkg_name" directory somewhere but the actual library by trying to grab the version + package_exists = importlib.util.find_spec(pkg_name) is not None + if package_exists: + try: + # Some libraries have different names in the metadata + _ = importlib.metadata.metadata( + pkg_name if metadata_name is None else metadata_name + ) + return True + except importlib.metadata.PackageNotFoundError: + return False + + +def is_auto_round_available() -> bool: + return _is_package_available("auto_round") + + +def import_dataloader(): + if is_auto_round_available(): + import auto_round + + get_dataloader = auto_round.calib_dataset.get_dataloader + return get_dataloader + else: + raise ImportError( + ( + "This example requires the `auto-round` library." + "Please install it with `pip install git+https://github.com/intel/auto-round.git@patch-for-ao-2`" + ) + ) + + +def singleton(cls): + """Singleton decorator.""" + instances = {} + + def _singleton(*args, **kw): + """Create a singleton object.""" + if cls not in instances: + instances[cls] = cls(*args, **kw) + return instances[cls] + + return _singleton + + +def freeze_random(seed=0): + random.seed(seed) + + torch.manual_seed(seed) + + np.random.seed(seed) + + g = torch.Generator() + g.manual_seed(seed) + + if torch.cuda.is_available(): + torch.cuda.manual_seed(seed) + torch.cuda.manual_seed_all(seed) + + +def count_tensor_of_type(mod, cls): + res = 0 + for name, param in mod.named_parameters(): + if isinstance(param, cls): + res += 1 + return res + + +def see_memory_usage(message: str = "", force=True): + # Modified from DeepSpeed https://github.com/microsoft/DeepSpeed + import gc + import logging + + import torch.distributed as dist + + if not force: + return + if dist.is_initialized() and not dist.get_rank() == 0: + return + + gc.collect() + + # Print message except when distributed but not rank 0 + logging.warning(message) + bytes_to_gb = 1024 * 1024 * 1024 + logging.warning( + f"AllocatedMem {round(torch.cuda.memory_allocated() / (bytes_to_gb),2 )} GB \ + MaxAllocatedMem {round(torch.cuda.max_memory_allocated() / (bytes_to_gb),2)} GB \ + ReservedMem {round(torch.cuda.memory_reserved() / (bytes_to_gb),2)} GB \ + MaxReservedMem {round(torch.cuda.max_memory_reserved() / (bytes_to_gb))} GB " + ) + + # get the peak memory to report correct data, so reset the counter for the next call + torch.cuda.reset_peak_memory_stats() + + +@torch.no_grad() +def gen_text( + model, tokenizer, msg="", device="cuda", prompt="What's AI?", max_length=20 +): + inputs = tokenizer(prompt, return_tensors="pt") + model = model.to(device) + new_tokens = model.generate(**inputs.to(device), 
max_length=max_length) + text = tokenizer.decode(new_tokens[0], skip_special_tokens=True) + print(f"Generated text ({msg}): {text}") + + +def gen_example_inputs(tokenizer, device, max_length=20): + inputs = tokenizer( + "What's AI?", return_tensors="pt", padding="max_length", max_length=max_length + ) + input_ids = inputs["input_ids"].to(device) + return (input_ids,) + + +def _auto_detect_decoder_cls(model): + for name, module in model.named_modules(): + if isinstance(module, torch.nn.ModuleList): + first_module = module[0] + return type(first_module) + + +def get_float_model_info(model_name_or_path, torch_dtype=torch.float32): + import transformers + + model = transformers.AutoModelForCausalLM.from_pretrained( + model_name_or_path, torch_dtype=torch_dtype + ) + tokenizer = transformers.AutoTokenizer.from_pretrained(model_name_or_path) + decoder_cls = _auto_detect_decoder_cls(model) + logging.warning(f"Detected decoder class: {decoder_cls}") + if decoder_cls is None: + raise ValueError( + f"Cannot detect the decoder class from the model, please provide it manually." + ) + return model, tokenizer, decoder_cls + + +def dump_elapsed_time(customized_msg=""): + """Get the elapsed time for decorated functions. + + Args: + customized_msg (string, optional): The parameter passed to decorator. Defaults to None. + """ + import logging + import time + + def f(func): + def fi(*args, **kwargs): + start = time.time() + res = func(*args, **kwargs) + end = time.time() + logging.warning( + "%s elapsed time: %s ms" + % ( + customized_msg if customized_msg else func.__qualname__, + round((end - start) * 1000, 2), + ) + ) + return res + + return fi + + return f