Skip to content

Commit f2b9890

Browse files
committed
Deprecate top level quantization APIs
Summary: This PR deprecates a few quantization APIs and here are the bc-breaking notes: 1. int8 weight only quantization int8 weight only quant module swap API ``` apply_weight_only_int8_quant(model) ``` and int8 weight only tensor subclass API ``` change_linear_weights_to_int8_woqtensors(model) ``` --> unified tensor subclass API ``` quantize(model, get_apply_int8wo_quant())) ``` 2. int8 dynamic quantization ``` apply_dynamic_quant(model) ``` or ``` change_linear_weights_to_int8_dqtensors(model) ``` --> unified tensor subclass API ``` quantize(model, get_apply_int8dyn_quant())) ``` 3. int4 weight only quantization ``` change_linear_weights_to_int4_wotensors(model) ``` --> unified tensor subclass API ``` quantize(model, get_apply_int4wo_quant())) ``` Test Plan: python test/quantization/test_quant_api.py python test/integration/test_integration.py Reviewers: Subscribers: Tasks: Tags:
1 parent 950a893 commit f2b9890

File tree

10 files changed

+370
-313
lines changed

10 files changed

+370
-313
lines changed

test/dtypes/test_aq.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22
TestCase,
33
run_tests,
44
)
5-
from torchao.quantization.quant_api import get_apply_int4wo_quant
5+
from torchao.quantization.quant_api import int4wo
66
import torch
77
import unittest
88

@@ -12,7 +12,7 @@ class TestAQ(TestCase):
1212
def test_tensor_core_layout_transpose(self):
1313
t = torch.rand(128, 256, dtype=torch.bfloat16, device="cuda")
1414
shape = t.shape
15-
apply_int4wo_quant = get_apply_int4wo_quant(groupsize=32)
15+
apply_int4wo_quant = int4wo(groupsize=32)
1616
aqt = apply_int4wo_quant(t)
1717
aqt_shape = aqt.shape
1818
self.assertEqual(aqt_shape, shape)

test/integration/test_integration.py

Lines changed: 65 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -20,12 +20,17 @@
2020
DynamicallyPerAxisQuantizedLinear,
2121
)
2222
from torchao.quantization.quant_api import (
23-
apply_dynamic_quant,
24-
apply_weight_only_int8_quant,
23+
int4wo,
24+
int8wo,
25+
int8da_int8w,
26+
quantize,
27+
_replace_with_custom_fn_if_matches_filter,
28+
)
29+
# APIs to be deprecated (used for torch 2.2.2 and 2.3)
30+
from torchao.quantization.quant_api import (
2531
change_linear_weights_to_int8_dqtensors,
2632
change_linear_weights_to_int8_woqtensors,
2733
change_linear_weights_to_int4_woqtensors,
28-
_replace_with_custom_fn_if_matches_filter,
2934
)
3035
from torchao.quantization.quant_primitives import (
3136
safe_int_mm,
@@ -73,26 +78,52 @@
7378
from parameterized import parameterized
7479
import itertools
7580
import logging
76-
from torchao.utils import TORCH_VERSION_AFTER_2_3, TORCH_VERSION_AFTER_2_4
81+
from torchao.utils import (
82+
TORCH_VERSION_AFTER_2_3,
83+
TORCH_VERSION_AFTER_2_4,
84+
unwrap_tensor_subclass,
85+
)
7786

7887
logger = logging.getLogger("INFO")
7988

8089
torch.manual_seed(0)
8190
config.cache_size_limit = 100
8291

83-
# TODO: use this to reduce the number of tests
84-
TENSOR_SUBCLASS_APIS = [
85-
change_linear_weights_to_int8_dqtensors,
86-
change_linear_weights_to_int8_woqtensors,
87-
change_linear_weights_to_int4_woqtensors,
88-
]
89-
9092
COMMON_DEVICES = ["cpu", "cuda"]
9193

9294
COMMON_DTYPES = [torch.float32, torch.float16, torch.bfloat16]
9395

9496
COMMON_DEVICE_DTYPE = list(itertools.product(COMMON_DEVICES, COMMON_DTYPES)).copy()
9597

98+
def _int8wo_api(mod):
99+
if TORCH_VERSION_AFTER_2_4:
100+
quantize(mod, int8wo())
101+
unwrap_tensor_subclass(mod)
102+
else:
103+
change_linear_weights_to_int8_woqtensors(mod)
104+
105+
def _int8da_int8w_api(mod):
106+
if TORCH_VERSION_AFTER_2_4:
107+
quantize(mod, int8da_int8w())
108+
unwrap_tensor_subclass(mod)
109+
else:
110+
change_linear_weights_to_int8_dqtensors(mod)
111+
112+
def _int4wo_api(mod):
113+
if TORCH_VERSION_AFTER_2_4:
114+
quantize(mod, int4wo())
115+
unwrap_tensor_subclass(mod)
116+
else:
117+
change_linear_weights_to_int4_woqtensors(mod)
118+
119+
# TODO: use this to reduce the number of tests
120+
TENSOR_SUBCLASS_APIS = [
121+
_int8wo_api,
122+
_int8da_int8w_api,
123+
_int4wo_api,
124+
]
125+
126+
96127
def combine_parameters(a, b):
97128
new_tuples = []
98129
for (tuple1, tuple2) in itertools.product(a, b):
@@ -756,13 +787,13 @@ def _test_lin_weight_subclass_api_impl(
756787
@unittest.skipIf(TORCH_VERSION_AFTER_2_4, "skip because there is some bug in inductor codegen")
757788
def test_int8_dynamic_quant_subclass_api(self, device, dtype):
758789
self._test_lin_weight_subclass_api_impl(
759-
change_linear_weights_to_int8_dqtensors, device, 35, test_dtype=dtype
790+
_int8da_int8w_api, device, 35, test_dtype=dtype
760791
)
761792

762793
@parameterized.expand(COMMON_DEVICE_DTYPE)
763794
def test_int8_weight_only_quant_subclass_api(self, device, dtype):
764795
self._test_lin_weight_subclass_api_impl(
765-
change_linear_weights_to_int8_woqtensors, device, 40, test_dtype=dtype
796+
_int8wo_api, device, 40, test_dtype=dtype
766797
)
767798

768799
@parameterized.expand(COMMON_DEVICE_DTYPE)
@@ -772,7 +803,7 @@ def test_int4_weight_only_quant_subclass_api(self, device, dtype):
772803
self.skipTest(f"Fails for {dtype}")
773804
for test_shape in ([(16, 1024, 16)] + ([(1, 1024, 256)] if device=='cuda' else [])):
774805
self._test_lin_weight_subclass_api_impl(
775-
change_linear_weights_to_int4_woqtensors,
806+
_int4wo_api,
776807
device,
777808
15,
778809
test_shape=test_shape,
@@ -788,8 +819,16 @@ def test_int4_weight_only_quant_subclass_api_grouped(self, device, dtype):
788819
for groupsize in [64, 32]:
789820
for inner_k_tiles in [4, 2]:
790821
kwargs = {"groupsize": groupsize, "inner_k_tiles": inner_k_tiles}
822+
823+
def api(mod):
824+
if TORCH_VERSION_AFTER_2_4:
825+
quantize(mod, int4wo(**kwargs))
826+
unwrap_tensor_subclass(mod)
827+
else:
828+
change_linear_weights_to_int4_woqtensors(mod, **kwargs)
829+
791830
self._test_lin_weight_subclass_api_impl(
792-
lambda mod: change_linear_weights_to_int4_woqtensors(mod, **kwargs),
831+
api,
793832
device,
794833
15,
795834
test_shape=test_shape,
@@ -804,7 +843,7 @@ def test_dynamic_quant(self):
804843
m = nn.Sequential(nn.Linear(K, N))
805844

806845
y_ref = m(x)
807-
apply_dynamic_quant(m)
846+
quantize(m, int8da_int8w())
808847
y_test = m(x)
809848

810849
sqnr = compute_error(y_ref, y_test)
@@ -818,7 +857,7 @@ def test_weight_only_quant(self):
818857
x = torch.randn(*x_shape)
819858
m = nn.Sequential(nn.Linear(4, 5))
820859
y_ref = m(x)
821-
apply_weight_only_int8_quant(m)
860+
_int8wo_api(m)
822861
y_wo = m(x)
823862
sqnr = compute_error(y_ref, y_wo)
824863
self.assertGreater(sqnr, 44.0)
@@ -841,7 +880,7 @@ def test_weight_only_quant_force_mixed_mm(self, device, dtype):
841880
x = torch.randn(*x_shape).to(device).to(dtype)
842881
m = nn.Sequential(nn.Linear(4, 5)).to(device).to(dtype)
843882
y_ref = m(x)
844-
apply_weight_only_int8_quant(m)
883+
_int8wo_api(m)
845884
m(x)
846885
m_c = torch.compile(m, mode="max-autotune")
847886
y_wo, (code,) = run_and_get_code(m_c, x)
@@ -868,7 +907,7 @@ def test_weight_only_quant_use_mixed_mm(self, device, dtype):
868907
x = torch.randn(*x_shape).to(device).to(dtype)
869908
m = nn.Sequential(nn.Linear(4, 5)).to(device).to(dtype)
870909
y_ref = m(x)
871-
apply_weight_only_int8_quant(m)
910+
_int8wo_api(m)
872911
m_c = torch.compile(m, mode="max-autotune")
873912
y_wo, (code,) = run_and_get_code(m_c, x)
874913
sqnr = compute_error(y_ref, y_wo)
@@ -909,6 +948,7 @@ def forward(self, x):
909948

910949
# save quantized state_dict
911950
api(model)
951+
912952
torch.save(model.state_dict(), "test.pth")
913953
# get quantized reference
914954
model_qc = torch.compile(model, mode="max-autotune")
@@ -924,6 +964,7 @@ def forward(self, x):
924964
# load quantized state_dict
925965
state_dict = torch.load("test.pth", mmap=True)
926966
os.remove("test.pth")
967+
927968
model.load_state_dict(state_dict, assign=True)
928969
model = model.to(device=test_device, dtype=test_dtype).eval()
929970

@@ -939,20 +980,20 @@ def forward(self, x):
939980
def test_save_load_dqtensors(self, device, dtype):
940981
if device == "cpu":
941982
self.skipTest(f"indcutor failed for cpu right now")
942-
self._test_handle_save_load_meta_impl(change_linear_weights_to_int8_dqtensors, device, test_dtype=dtype)
983+
self._test_handle_save_load_meta_impl(_int8da_int8w_api, device, test_dtype=dtype)
943984

944985
@parameterized.expand(COMMON_DEVICE_DTYPE)
945986
@torch.no_grad()
946987
def test_save_load_int8woqtensors(self, device, dtype):
947-
self._test_handle_save_load_meta_impl(change_linear_weights_to_int8_woqtensors, device, test_dtype=dtype)
988+
self._test_handle_save_load_meta_impl(_int8wo_api, device, test_dtype=dtype)
948989

949990
@parameterized.expand(COMMON_DEVICE_DTYPE)
950991
@unittest.skipIf(not TORCH_VERSION_AFTER_2_3, "int4 requires torch nightly.")
951992
@torch.no_grad()
952993
def test_save_load_int4woqtensors(self, device, dtype):
953994
if dtype != torch.bfloat16:
954995
self.skipTest(f"Fails for {dtype}")
955-
self._test_handle_save_load_meta_impl(change_linear_weights_to_int4_woqtensors, device, 20, test_dtype=dtype)
996+
self._test_handle_save_load_meta_impl(_int4wo_api, device, 20, test_dtype=dtype)
956997

957998

958999
class TorchCompileUnitTest(unittest.TestCase):
@@ -1271,8 +1312,7 @@ def forward(self, x):
12711312
model = test_model().to(dtype=test_dtype, device=test_device).eval()
12721313
ref_f = model(x)
12731314

1274-
kwargs = {"dtype": test_dtype}
1275-
api(model, **kwargs)
1315+
api(model)
12761316

12771317
# running model
12781318
model(x)
@@ -1317,8 +1357,7 @@ def forward(self, x):
13171357
model = test_model().to(dtype=test_dtype, device=test_device).eval()
13181358
ref_f = model(x)
13191359

1320-
kwargs = {"dtype": test_dtype}
1321-
api(model, **kwargs)
1360+
api(model)
13221361

13231362
# running model
13241363
ref = model(x)

test/prototype/mx_formats/test_mx_linear.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -189,7 +189,7 @@ def test_inference_compile_simple(elem_dtype):
189189
if elem_dtype is torch.float8_e4m3fn:
190190
assert sqnr >= 20.0
191191
else:
192-
assert sqnr >= 14.0
192+
assert sqnr >= 13.5
193193

194194

195195
def test_filter_fn():

test/quantization/test_quant_api.py

Lines changed: 29 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -28,22 +28,19 @@
2828
ZeroPointDomain,
2929
)
3030
from torchao.quantization.subclass import (
31-
to_laq,
3231
LinearActQuantizedTensor,
3332
Int8WeightOnlyQuantizedLinearWeight,
3433
Int4WeightOnlyQuantizedLinearWeight,
3534
)
3635
from torchao.quantization.quant_api import (
3736
_replace_with_custom_fn_if_matches_filter,
38-
apply_dynamic_quant,
39-
apply_weight_only_int8_quant,
4037
Quantizer,
4138
TwoStepQuantizer,
4239
quantize,
43-
get_apply_8da4w_quant,
44-
get_apply_int4wo_quant,
45-
get_apply_int8wo_quant,
46-
get_apply_int8dyn_quant,
40+
int8da_int4w,
41+
int4wo,
42+
int8wo,
43+
int8da_int8w,
4744
)
4845
from torchao.utils import (
4946
TORCH_VERSION_AFTER_2_3,
@@ -52,7 +49,9 @@
5249
from pathlib import Path
5350
from torchao._models.llama.tokenizer import get_tokenizer
5451
from torchao._models.llama.model import Transformer, prepare_inputs_for_model
52+
from torchao.utils import unwrap_tensor_subclass
5553
import copy
54+
import tempfile
5655

5756

5857
def dynamic_quant(model, example_inputs):
@@ -62,20 +61,6 @@ def dynamic_quant(model, example_inputs):
6261
m = convert_pt2e(m)
6362
return m
6463

65-
def _apply_dynamic_quant(model):
66-
"""
67-
Applies dynamic symmetric per-token activation and per-channel weight
68-
quantization to all linear layers in the given model using
69-
module swaps.
70-
"""
71-
_replace_with_custom_fn_if_matches_filter(
72-
model,
73-
lambda linear_mod: dynamic_quant(linear_mod, (torch.randn(1, linear_mod.in_features),)),
74-
lambda mod, fqn: isinstance(mod, torch.nn.Linear),
75-
)
76-
return model
77-
78-
7964
def capture_and_prepare(model, example_inputs):
8065
m = torch.export.export(model, example_inputs)
8166
quantizer = XNNPACKQuantizer().set_global(get_symmetric_quantization_config(is_dynamic=True))
@@ -104,7 +89,7 @@ def convert(self, model: torch.nn.Module) -> torch.nn.Module:
10489

10590
class TorchCompileDynamicQuantizer(Quantizer):
10691
def quantize(self, model: torch.nn.Module) -> torch.nn.Module:
107-
apply_dynamic_quant(model)
92+
quantize(model, int8da_int8w())
10893
return model
10994

11095
class ToyLinearModel(torch.nn.Module):
@@ -167,7 +152,7 @@ class TestQuantFlow(unittest.TestCase):
167152
def test_dynamic_quant_gpu_singleline(self):
168153
m = ToyLinearModel().eval()
169154
example_inputs = m.example_inputs()
170-
m = _apply_dynamic_quant(m)
155+
m = quantize(m, int8da_int8w())
171156
quantized = m(*example_inputs)
172157
# AssertionError: Expecting input to have dtype torch.float32, but got dtype: torch.float64
173158
# While executing %choose_qparams_tensor_1 : [num_users=2] = call_function[target=torch.ops.quantized_decomposed.choose_qparams.tensor](args = (%arg0_3, -128, 127, 0.000244140625, torch.int8), kwargs = {})
@@ -203,18 +188,28 @@ def test_dynamic_quant_gpu_unified_api_eager_mode_impl(self):
203188
torch.testing.assert_close(quantized, compiled, atol=0, rtol=0)
204189

205190
@unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available")
191+
@unittest.skipIf(not TORCH_VERSION_AFTER_2_4, "only works for torch 2.4+")
206192
def test_int8_wo_quant_save_load(self):
193+
from torchao.quantization.quant_api import (
194+
change_linear_weights_to_int8_woqtensors,
195+
)
207196
m = ToyLinearModel().eval().cpu()
208-
apply_weight_only_int8_quant(m)
197+
def api(model):
198+
model = quantize(model, int8wo())
199+
unwrap_tensor_subclass(model)
200+
201+
api(m)
202+
209203
example_inputs = m.example_inputs()
210204
ref = m(*example_inputs)
211-
_TMP_FN = "_test.pt"
212-
torch.save(m.state_dict(), _TMP_FN)
205+
with tempfile.NamedTemporaryFile() as f:
206+
torch.save(m.state_dict(), f)
207+
f.seek(0)
208+
state_dict = torch.load(f)
209+
210+
m2 = ToyLinearModel().eval().cpu()
211+
api(m2)
213212

214-
state_dict = torch.load(_TMP_FN)
215-
os.remove(_TMP_FN)
216-
m2 = ToyLinearModel().eval()
217-
apply_weight_only_int8_quant(m2)
218213
m2.load_state_dict(state_dict)
219214
m2 = m2.to(device="cuda")
220215
example_inputs = map(lambda x: x.cuda(), example_inputs)
@@ -508,7 +503,7 @@ def test_quantized_tensor_subclass_8da4w(self):
508503
m = ToyLinearModel().eval()
509504
m_copy = copy.deepcopy(m)
510505
example_inputs = m.example_inputs()
511-
m = quantize(m, get_apply_8da4w_quant(groupsize=groupsize))
506+
m = quantize(m, int8da_int4w(groupsize=groupsize))
512507

513508
assert isinstance(m.linear1.weight, LinearActQuantizedTensor)
514509
assert isinstance(m.linear2.weight, LinearActQuantizedTensor)
@@ -537,7 +532,7 @@ def test_quantized_tensor_subclass_int4(self):
537532
example_inputs = m.example_inputs(dtype=torch.bfloat16, device="cuda")
538533

539534
groupsize = 32
540-
m = quantize(m, get_apply_int4wo_quant(groupsize=groupsize))
535+
m = quantize(m, int4wo(groupsize=groupsize))
541536
assert isinstance(m.linear1.weight, AffineQuantizedTensor)
542537
assert isinstance(m.linear2.weight, AffineQuantizedTensor)
543538

@@ -557,7 +552,7 @@ def test_quantized_tensor_subclass_int8_wo(self):
557552
m_copy = copy.deepcopy(m)
558553
example_inputs = tuple(map(lambda x: x.to(torch.bfloat16), m.example_inputs()))
559554

560-
m = quantize(m, get_apply_int8wo_quant())
555+
m = quantize(m, int8wo())
561556

562557
assert isinstance(m.linear1.weight, AffineQuantizedTensor)
563558
assert isinstance(m.linear2.weight, AffineQuantizedTensor)
@@ -580,7 +575,7 @@ def test_quantized_tensor_subclass_int8_dyn_quant(self):
580575
m_copy = copy.deepcopy(m)
581576
# setting batch_size to 20 to be compatible with the kernel
582577
example_inputs = m.example_inputs(batch_size=20, dtype=torch.bfloat16, device="cuda")
583-
m = quantize(m, get_apply_int8dyn_quant())
578+
m = quantize(m, int8da_int8w())
584579

585580
assert isinstance(m.linear1.weight, LinearActQuantizedTensor)
586581
assert isinstance(m.linear2.weight, LinearActQuantizedTensor)

0 commit comments

Comments
 (0)