SmoothQuant using tensor subclassing #1030
New file (@@ -0,0 +1,165 @@): unit tests for the SmoothQuant prototype

```python
from copy import deepcopy

import pytest
import torch
import tempfile
from torchao.quantization import quantize_
from torchao.utils import (
    TORCH_VERSION_AT_LEAST_2_2,
    TORCH_VERSION_AT_LEAST_2_4,
    TORCH_VERSION_AT_LEAST_2_5,
)
from torchao.quantization.utils import (
    dynamically_quantize_per_channel,
    dequantize_per_channel,
)
from torchao.prototype.smoothquant import (
    insert_smooth_quant_observer_,
    smooth_quant,
    SmoothQuantObservedLinear,
    save_smooth_quant_recipe,
    load_smooth_quant_recipe,
)


class ToyLinearModel(torch.nn.Module):
    def __init__(self, m=512, n=256, k=128):
        super().__init__()
        self.linear1 = torch.nn.Linear(m, n, bias=False)
        self.linear2 = torch.nn.Linear(n, k, bias=False)
        self.linear3 = torch.nn.Linear(k, 1, bias=False)

    def example_inputs(self, batch_size, sequence_length=10, dtype=torch.bfloat16, device="cuda"):
        return [
            torch.randn(1, sequence_length, self.linear1.in_features, dtype=dtype, device=device)
            for j in range(batch_size)
        ]

    def forward(self, x):
        x = self.linear1(x)
        x = self.linear2(x)
        x = self.linear3(x)
        return x


bias_list = [True, False]
alpha_list = [None, 0.5, 0.75]
quant_mode_list = ["static", "dynamic"]
devices = ["cpu"]
if torch.cuda.is_available():
    devices.append("cuda")
idtypes = (torch.float, torch.bfloat16, torch.half)

if TORCH_VERSION_AT_LEAST_2_5:
    # This test case will trigger recompilation many times, so set a large cache_size_limit here
    torch._dynamo.config.cache_size_limit = 128


@pytest.mark.parametrize("bias", bias_list)
@pytest.mark.parametrize("alpha", alpha_list)
@pytest.mark.parametrize("quant_mode", quant_mode_list)
@pytest.mark.parametrize("device", devices)
@pytest.mark.parametrize("idtype", idtypes)
def test_compute(bias, alpha, quant_mode, device, idtype):
    class Linear(torch.nn.Module):
        def __init__(self, bias: bool):
            super().__init__()
            self.fc = torch.nn.Linear(32, 32, bias)
            self.fc.weight.data = torch.randn_like(self.fc.weight.data)

        def forward(self, x):
            return self.fc(x)

    m = Linear(bias).eval().to(idtype).to(device)
    m_ref = deepcopy(m)
    data = torch.randn(2, 32, dtype=idtype, device=device)

    # calibrate
    insert_smooth_quant_observer_(m, alpha, quant_mode)
    m(data)
    # quantize
    is_observed_linear = lambda m, fqn: isinstance(m, SmoothQuantObservedLinear)
    quantize_(m, smooth_quant(), is_observed_linear)
    with torch.inference_mode():
        if TORCH_VERSION_AT_LEAST_2_5:
            m = torch.compile(m, fullgraph=True)
        out = m(data)

    # reference
    weight = m_ref.fc.weight.data.float()
    b = m_ref.fc.bias if bias else None
    x_abs_max_per_ic = torch.abs(data).max(dim=0).values
    w_abs_max_per_ic = torch.abs(weight).max(dim=0).values
    smoothing_factor = 1 if alpha is None else (
        torch.pow(x_abs_max_per_ic, alpha) / torch.pow(
            w_abs_max_per_ic, 1 - alpha)
    )
    act = data / smoothing_factor
    wei = weight * smoothing_factor
    qw, w_scales, w_zps = dynamically_quantize_per_channel(
        wei, -127, 127, torch.int8
    )
    fq_wei = dequantize_per_channel(qw, w_scales, w_zps, idtype)
    if quant_mode == "static":
        # activation is quantized per-tensor
        act_min, act_max = torch.aminmax(act.float())
        max_val_pos = torch.max(-act_min, act_max)
        act_scale = max_val_pos / 127.0
        fq_act = torch.quantize_per_tensor(
            act.float(), scale=act_scale.item(), zero_point=0, dtype=torch.qint8
        ).dequantize().to(idtype)
        out_ref = torch.nn.functional.linear(fq_act, fq_wei, b)
    else:
        # activation is quantized per-row (batch * sequence_length)
        qx, x_scales, x_zps = dynamically_quantize_per_channel(
            act.float(), -127, 127, torch.int8
        )
        fq_act = dequantize_per_channel(qx, x_scales, x_zps, idtype)
        out_ref = torch.nn.functional.linear(fq_act, fq_wei, b)

    # BFloat16 and Float16 have larger errors
    atol = 0.1 if idtype == torch.float else (
        0.2 if idtype == torch.half else 0.3
    )
    assert torch.allclose(out, out_ref.to(idtype), atol=atol)


@pytest.mark.parametrize("alpha", alpha_list)
@pytest.mark.parametrize("quant_mode", quant_mode_list)
@pytest.mark.parametrize("device", devices)
@pytest.mark.parametrize("idtype", idtypes)
def test_save_load_recipe(alpha, quant_mode, device, idtype):
    dataset_size = 20
    l1, l2, l3 = 512, 256, 128
    original_dtype = idtype
    n_calib_examples = 10
    sequence_length = 5

    m = ToyLinearModel(l1, l2, l3).eval().to(original_dtype).to(device)
    m_save_load = deepcopy(m)

    dataset = m.example_inputs(dataset_size, sequence_length=sequence_length, dtype=original_dtype, device=device)
    calibration_data = dataset[:n_calib_examples]

    # calibrate
    insert_smooth_quant_observer_(m, alpha, quant_mode)
    insert_smooth_quant_observer_(m_save_load, alpha, quant_mode)

    for example in calibration_data:
        m(example.to(device))
        m_save_load(example.to(device))

    with tempfile.NamedTemporaryFile() as fp:
        save_path = fp.name
        save_smooth_quant_recipe(m_save_load, save_path)
        load_smooth_quant_recipe(m_save_load, save_path)

    # quantize
    is_observed_linear = lambda m, fqn: isinstance(m, SmoothQuantObservedLinear)
    quantize_(m, smooth_quant(), is_observed_linear)
    if TORCH_VERSION_AT_LEAST_2_5:
        # earlier versions are not compatible
        m = torch.compile(m, fullgraph=True)
        m_save_load = torch.compile(m_save_load, fullgraph=True)
    out_list = [m(data.squeeze(0)) for data in dataset]
    out = torch.cat(out_list)
    save_load_out_list = [m_save_load(data.squeeze(0)) for data in dataset]
    save_load_out = torch.cat(save_load_out_list)

    assert out is not None
    assert save_load_out is not None
    assert torch.allclose(out, save_load_out)
```
New file (@@ -0,0 +1,98 @@): README for the SmoothQuant prototype
# SmoothQuant quantization
This is a native PyTorch implementation of the algorithm described in [this paper](https://arxiv.org/abs/2211.10438).

In this implementation, weights are smoothed (equalized) and quantized to int8 ahead of time, while activations are smoothed and quantized to int8 at runtime. Activation quantization can be dynamic or static: with dynamic quantization, activation qparams (i.e., scales) are computed at runtime and activations are quantized per token; with static quantization, qparams are determined during quantization (ahead of runtime) and activations are quantized per tensor. Dynamic quantization generally gives better accuracy, while static quantization gives better latency. In both cases, weights and activations are quantized symmetrically.
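
Concretely, the smoothing step computes a per-input-channel factor from activation and weight statistics and moves scale from activations to weights without changing the linear layer's output. The following is a minimal sketch of that math, mirroring the reference computation in the unit tests above (tensor names and shapes here are illustrative only):
```python
import torch

# Illustrative shapes: x is (tokens, in_features), w is (out_features, in_features)
x = torch.randn(16, 32)
w = torch.randn(64, 32)
alpha = 0.5

# Per-input-channel absolute maxima of activations and weights
x_abs_max = x.abs().max(dim=0).values  # shape (in_features,)
w_abs_max = w.abs().max(dim=0).values  # shape (in_features,)

# SmoothQuant smoothing factor: s_j = max|X_j|**alpha / max|W_j|**(1 - alpha)
s = x_abs_max.pow(alpha) / w_abs_max.pow(1 - alpha)

# Dividing activations and multiplying weights by s leaves the matmul unchanged,
# but shifts quantization difficulty from activations to weights.
x_smoothed = x / s
w_smoothed = w * s
assert torch.allclose(x @ w.t(), x_smoothed @ w_smoothed.t(), atol=1e-4)
```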

## Quick start
Run the example code with
```bash
python example.py -m MODEL_ID --device=<cuda or cpu> --quant-mode=<dynamic or static>
# An example
python example.py -m meta-llama/Llama-2-7b-hf --device=cuda --quant-mode=dynamic
```
To use `torch.compile` for speedup, add `--compile`. You may want to export `TORCHINDUCTOR_FREEZING=1` for even better performance.
```bash
TORCHINDUCTOR_FREEZING=1 python example.py -m MODEL_ID --device=<cuda or cpu> --quant-mode=<dynamic or static> --compile
```
To save a quantized model for reuse, specify `--model-save-path`
```bash
python example.py -m MODEL_ID --device=<cuda or cpu> --quant-mode=<dynamic or static> --model-save-path ./quantized_model.pt
```
and load it with `--model-load-path`
```bash
python example.py -m MODEL_ID --device=<cuda or cpu> --quant-mode=<dynamic or static> --model-load-path ./quantized_model.pt
```

## Usage of API
The following APIs are provided:
- `insert_smooth_quant_observer_`
- `smooth_quant`
- `save_smooth_quant_recipe` (advanced)
- `load_smooth_quant_recipe` (advanced)

`insert_smooth_quant_observer_` inserts observers into the model to be quantized. For example:
```python
insert_smooth_quant_observer_(model, alpha=0.5, quant_mode="dynamic")
```
After insertion, run the model for calibration on a dataset, or (advanced) load a recipe; a short calibration sketch is shown below.
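
For instance, a short calibration pass could look like the following sketch; `model` and `calibration_data` are assumed to be provided by the user:
```python
import torch
from torchao.prototype.smoothquant import insert_smooth_quant_observer_

insert_smooth_quant_observer_(model, alpha=0.5, quant_mode="dynamic")

# Feed a few representative batches through the model so the observers can
# record activation statistics; gradients are not needed for calibration.
with torch.no_grad():
    for batch in calibration_data:
        model(batch)
```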

`smooth_quant` applies SmoothQuant to each linear layer of the model. Use it with `torchao.quantization.quantize_`. For example:
```python
from torchao.prototype.smoothquant import SmoothQuantObservedLinear
is_observed_linear = lambda m, fqn: isinstance(m, SmoothQuantObservedLinear)
torchao.quantization.quantize_(model, smooth_quant(), is_observed_linear)
```
`is_observed_linear` is a filter so that only observed linear layers are quantized.

(Advanced) `save_smooth_quant_recipe` and `load_smooth_quant_recipe` save or load a recipe for a model.

A recipe contains the smoothing factors and the quantization parameters of weights and activations for all linear layers to be quantized. Advanced users can save and modify these parameters to improve accuracy, e.g., by using a different alpha for different layers, or leave some linear layers unquantized by deleting them from the recipe. Such modifications can be published as a recipe. By loading the recipe, it can be reused and calibration is no longer needed.

To save a recipe, users should insert observers and run calibration first. For example,
```python
insert_smooth_quant_observer_(model, alpha=0.5, quant_mode="dynamic")
for data in dataset_for_calibration:
    model(data)
save_smooth_quant_recipe(model, "./smooth_quant_recipe.json")
```
To load a recipe, users should insert observers first. For example,
```python
insert_smooth_quant_observer_(model)
load_smooth_quant_recipe(model, "./smooth_quant_recipe.json")
```
## Benchmark
Running the example with `torch.compile` on an NVIDIA A10G GPU.
### meta-llama/Llama-2-7b-hf
Perplexity
| Quant Method | alpha=0.25 | alpha=0.5 | alpha=0.75 | alpha=None* |
|-|-|-|-|-|
| Dynamic | 8.1872 | 7.4257 | 7.2518 | 7.5509 |
| Static | 43.8051 | 11.2984 | 7.5791 | 19.5050 |

Note*: Conventional quantization without SmoothQuant

### meta-llama/Meta-Llama-3-8B
Perplexity
| Quant Method | alpha=0.25 | alpha=0.5 | alpha=0.75 | alpha=None* |
|-|-|-|-|-|
| Dynamic | 21.2475 | 8.8288 | 9.6514 | 8.3574 |
| Static | 301.7118 | 18.0617 | 10.8343 | 278.9819 |

Note*: Conventional quantization without SmoothQuant

Review comments on these results:
- "so looks like it's more effective on static quant"
- "did you also do a sanity check for perf to make sure this doesn't regress performance?"
- "For performance, it's found from high to low … It's expected that SmoothQuant is slower because it inserts …"
- "yeah that's fine as long as it's reasonable, it's just a sanity check"
### Test method
**Commands**
```bash
# dynamic quant
TORCHINDUCTOR_FREEZING=1 python example.py -m <model_id> --device=cuda --quant-mode=dynamic --compile
# static quant
TORCHINDUCTOR_FREEZING=1 python example.py -m <model_id> --device=cuda --quant-mode=static --compile
```
Use `--alpha` to specify the alpha parameter. Add `--disable-smooth-quant` to run quantization without SmoothQuant.

**Environment**
- AWS g5.12xlarge instance
- torch==2.6.0.dev20241017+cu124
- python==3.12.6
New file (@@ -0,0 +1,7 @@): package `__init__` re-exporting the SmoothQuant prototype API
```python
from .api import (
    insert_smooth_quant_observer_,
    smooth_quant,
    save_smooth_quant_recipe,
    load_smooth_quant_recipe,
)
from .core import SmoothQuantObservedLinear
```
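
For reference, the exported names above compose into the following end-to-end flow; this is a minimal sketch that mirrors the unit tests earlier in this PR, with a placeholder toy model and random calibration inputs used purely for illustration:
```python
import torch
from torchao.quantization import quantize_
from torchao.prototype.smoothquant import (
    insert_smooth_quant_observer_,
    smooth_quant,
    SmoothQuantObservedLinear,
)

# Placeholder model and calibration inputs for illustration only
model = torch.nn.Sequential(torch.nn.Linear(32, 32, bias=False)).eval()
calibration_data = [torch.randn(4, 32) for _ in range(8)]

# 1. Insert observers and run calibration
insert_smooth_quant_observer_(model, alpha=0.5, quant_mode="dynamic")
with torch.no_grad():
    for batch in calibration_data:
        model(batch)

# 2. Swap the observed linear layers for SmoothQuant-quantized ones
is_observed_linear = lambda m, fqn: isinstance(m, SmoothQuantObservedLinear)
quantize_(model, smooth_quant(), is_observed_linear)

# 3. Run inference with the quantized model
with torch.inference_mode():
    out = model(torch.randn(4, 32))
```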