Commit 4b4dd40

Merge branch 'main' into jcaip/fix-readme-links-rd2
2 parents: d0b7941 + 923ff4d

4 files changed: +166 −83 lines


test/quantization/test_quant_api.py

Lines changed: 62 additions & 53 deletions
```diff
@@ -18,12 +18,24 @@
     get_symmetric_quantization_config,
 )
+from torchao.quantization.subclass import (
+    to_aqt,
+    to_laqt,
+    AffineQuantizedTensor,
+    LinearActQuantizedTensor,
+)
+from torchao.quantization.quant_primitives import (
+    MappingType,
+    ZeroPointDomain,
+)
+
 from torchao.quantization.quant_api import (
     _replace_with_custom_fn_if_matches_filter,
     apply_dynamic_quant,
     apply_weight_only_int8_quant,
     Quantizer,
     TwoStepQuantizer,
+    quantize,
 )
 from torchao.quantization.utils import (
     TORCH_VERSION_AFTER_2_3,
@@ -32,6 +44,7 @@
 from pathlib import Path
 from sentencepiece import SentencePieceProcessor
 from model import Transformer, prepare_inputs_for_model
+import copy
@@ -92,8 +105,8 @@ def __init__(self, m=64, n=32, k=64):
         self.linear1 = torch.nn.Linear(m, n, bias=False).to(torch.float)
         self.linear2 = torch.nn.Linear(n, k, bias=False).to(torch.float)
 
-    def example_inputs(self):
-        return (torch.randn(1, self.linear1.in_features).to(torch.float),)
+    def example_inputs(self, batch_size=1):
+        return (torch.randn(batch_size, self.linear1.in_features).to(torch.float),)
 
     def forward(self, x):
         x = self.linear1(x)
@@ -395,13 +408,6 @@ def test_eval_wrapper(self):
     # TODO: move to a separate test file
     @unittest.skipIf(not TORCH_VERSION_AFTER_2_4, "Test only enabled for 2.4+")
     def test_quantized_tensor_subclass_8da4w(self):
-        from torchao.quantization.subclass import (
-            AffineQuantizedTensor,
-            LinearActQuantizedTensor,
-        )
-        from torchao.quantization.quant_primitives import MappingType
-        import copy
-
         # weight settings
         groupsize = 32
         mapping_type = MappingType.SYMMETRIC
@@ -423,20 +429,26 @@ def get_per_token_block_size(x):
         # input settings
         input_mapping_type = MappingType.ASYMMETRIC
         input_target_dtype = torch.int8
-        input_quant_func = lambda x: AffineQuantizedTensor.from_float(x, input_mapping_type, get_per_token_block_size(x), input_target_dtype)
-
-        def dynamic_quant(linear):
-            # note: order is important
-            linear.weight = torch.nn.Parameter(AffineQuantizedTensor.from_float(linear.weight, mapping_type, block_size, target_dtype, quant_min, quant_max, eps), requires_grad=False)
-            linear.weight = torch.nn.Parameter(LinearActQuantizedTensor.from_float(linear.weight, input_quant_func), requires_grad=False)
+        input_quant_func = lambda x: to_aqt(x, input_mapping_type, get_per_token_block_size(x), input_target_dtype)
 
         m = ToyLinearModel().eval()
         m_copy = copy.deepcopy(m)
         example_inputs = m.example_inputs()
-        dynamic_quant(m.linear1)
-        dynamic_quant(m.linear2)
+
+        def apply_weight_quant(weight):
+            return to_aqt(weight, mapping_type, block_size, target_dtype, quant_min, quant_max, eps)
+
+        def apply_act_quant(weight):
+            return to_laqt(weight, input_quant_func)
+
+        # note: order is important
+        m = quantize(m, apply_weight_quant)
+        m = quantize(m, apply_act_quant)
+
         assert isinstance(m.linear1.weight, LinearActQuantizedTensor)
         assert isinstance(m.linear2.weight, LinearActQuantizedTensor)
+        assert isinstance(m.linear1.weight.original_weight_tensor, AffineQuantizedTensor)
+        assert isinstance(m.linear2.weight.original_weight_tensor, AffineQuantizedTensor)
 
         # reference
         from torchao.quantization.quant_api import Int8DynActInt4WeightQuantizer
```
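Pulled out of the test, the new flow reads as two composable `quantize` passes instead of per-module surgery. A minimal runnable sketch of the same 8da4w pattern, assuming plausible values for the settings the diff elides (`target_dtype`, `quant_min`, `quant_max`, `eps`, and the body of `get_per_token_block_size` are illustrative, not copied from the test):

```python
import torch
from torchao.quantization.quant_api import quantize
from torchao.quantization.quant_primitives import MappingType
from torchao.quantization.subclass import to_aqt, to_laqt

def get_per_token_block_size(x):
    # one quantization group per token: all dims are 1 except the last
    return [1] * (x.dim() - 1) + [x.shape[-1]]

# input settings: int8 asymmetric, per-token (as in the test)
input_quant_func = lambda x: to_aqt(
    x, MappingType.ASYMMETRIC, get_per_token_block_size(x), torch.int8
)

# weight settings: int4 range stored in int8, groupwise (values assumed)
groupsize = 32

def apply_weight_quant(weight):
    return to_aqt(weight, MappingType.SYMMETRIC, (1, groupsize), torch.int8,
                  quant_min=-8, quant_max=7, eps=torch.finfo(torch.float32).eps)

def apply_act_quant(weight):
    return to_laqt(weight, input_quant_func)

m = torch.nn.Sequential(torch.nn.Linear(64, 32, bias=False)).eval()
m = quantize(m, apply_weight_quant)  # note: order is important,
m = quantize(m, apply_act_quant)     # weights first, then the activation wrapper
```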
```diff
@@ -454,11 +466,6 @@ def dynamic_quant(linear):
     @unittest.skipIf(not TORCH_VERSION_AFTER_2_4, "Test only enabled for 2.4+")
     @unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available")
     def test_quantized_tensor_subclass_int4(self):
-        from torchao.quantization.subclass import AffineQuantizedTensor
-        from torchao.quantization.quant_primitives import MappingType
-        from torchao.quantization.quant_primitives import ZeroPointDomain
-        import copy
-
         # weight settings
         groupsize = 32
         mapping_type = MappingType.ASYMMETRIC
@@ -469,22 +476,17 @@ def test_quantized_tensor_subclass_int4(self):
         eps = 1e-6
         preserve_zero = False
         zero_point_dtype = torch.bfloat16
+        zero_point_domain = ZeroPointDomain.FLOAT
 
         # use 1024 so that we don't need padding
         m = ToyLinearModel(1024, 1024, 1024).eval().to(torch.bfloat16).to("cuda")
         m_copy = copy.deepcopy(m)
         example_inputs = tuple(map(lambda x: x.to(torch.bfloat16).to("cuda"), m.example_inputs()))
 
-        def to_quantized(weight):
-            return AffineQuantizedTensor.from_float(
-                weight, mapping_type, block_size, target_dtype, quant_min, quant_max, eps,
-                zero_point_dtype=zero_point_dtype,
-                preserve_zero=preserve_zero,
-                zero_point_domain=ZeroPointDomain.FLOAT,
-            )
+        def apply_weight_quant(weight):
+            return to_aqt(weight, mapping_type, block_size, target_dtype, quant_min, quant_max, eps, zero_point_dtype=zero_point_dtype, preserve_zero=preserve_zero, zero_point_domain=zero_point_domain)
 
-        m.linear1.weight = torch.nn.Parameter(to_quantized(m.linear1.weight), requires_grad=False)
-        m.linear2.weight = torch.nn.Parameter(to_quantized(m.linear2.weight), requires_grad=False)
+        m = quantize(m, apply_weight_quant)
         assert isinstance(m.linear1.weight, AffineQuantizedTensor)
         assert isinstance(m.linear2.weight, AffineQuantizedTensor)
```
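The int4 path now goes through the same entry point with a single weight-only transform. The settings below are taken from the `quantize()` docstring example added in `quant_api.py` further down (`torch.int32` target with the [0, 15] range and a float-domain zero point); treat this as a sketch of the configuration, not the exact test:

```python
import torch
from torchao.quantization.quant_api import quantize
from torchao.quantization.quant_primitives import MappingType, ZeroPointDomain
from torchao.quantization.subclass import to_aqt

groupsize = 32
apply_weight_quant = lambda w: to_aqt(
    w, MappingType.ASYMMETRIC, (1, groupsize), torch.int32,
    quant_min=0, quant_max=15, eps=1e-6,
    zero_point_dtype=torch.bfloat16,
    preserve_zero=False,                      # zero need not be exactly representable
    zero_point_domain=ZeroPointDomain.FLOAT,  # zero point lives in the float domain
)

m = torch.nn.Sequential(torch.nn.Linear(1024, 1024, bias=False)).to(torch.bfloat16).eval()
m = quantize(m, apply_weight_quant)
```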

```diff
@@ -501,10 +503,6 @@ def to_quantized(weight):
     @unittest.skipIf(not TORCH_VERSION_AFTER_2_4, "Test only enabled for 2.4+")
     @unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available")
     def test_quantized_tensor_subclass_int8(self):
-        from torchao.quantization.subclass import AffineQuantizedTensor
-        from torchao.quantization.quant_primitives import MappingType
-        import copy
-
         # weight settings
         mapping_type = MappingType.SYMMETRIC
         target_dtype = torch.int8
@@ -515,12 +513,12 @@ def test_quantized_tensor_subclass_int8(self):
         m_copy = copy.deepcopy(m)
         example_inputs = tuple(map(lambda x: x.to(torch.bfloat16), m.example_inputs()))
 
-        def to_quantized(weight):
+        def apply_weight_quant(weight):
             block_size = (1, weight.shape[1])
-            return AffineQuantizedTensor.from_float(weight, mapping_type, block_size, target_dtype, eps=eps, zero_point_dtype=zero_point_dtype)
+            return to_aqt(weight, mapping_type, block_size, target_dtype, eps=eps, zero_point_dtype=zero_point_dtype)
+
+        m = quantize(m, apply_weight_quant)
 
-        m.linear1.weight = torch.nn.Parameter(to_quantized(m.linear1.weight), requires_grad=False)
-        m.linear2.weight = torch.nn.Parameter(to_quantized(m.linear2.weight), requires_grad=False)
         assert isinstance(m.linear1.weight, AffineQuantizedTensor)
         assert isinstance(m.linear2.weight, AffineQuantizedTensor)
```
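Computing `block_size` inside `apply_weight_quant` is what makes the transform shape-generic: `(1, weight.shape[1])` puts each output channel of an `(out, in)` weight in its own quantization group. A small sketch (the `eps` and `zero_point_dtype` values here are illustrative, not from the test):

```python
import torch
from torchao.quantization.quant_primitives import MappingType
from torchao.quantization.subclass import to_aqt

w = torch.randn(32, 64)  # (out_features, in_features)
qw = to_aqt(w, MappingType.SYMMETRIC, (1, w.shape[1]), torch.int8,
            eps=1e-5, zero_point_dtype=torch.int64)
print(qw.scale.shape)  # expect 32 scales: one per output channel
```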

```diff
@@ -537,12 +535,6 @@ def to_quantized(weight):
     @unittest.skipIf(not TORCH_VERSION_AFTER_2_4, "Test only enabled for 2.4+")
     @unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available")
     def test_quantized_tensor_subclass_int8_dyn_quant(self):
-        from torchao.quantization.subclass import AffineQuantizedTensor
-        from torchao.quantization.subclass import LinearActQuantizedTensor
-        from torchao.quantization.quant_primitives import MappingType
-        from torchao.quantization.quant_primitives import ZeroPointDomain
-        import copy
-
         # weight settings
         mapping_type = MappingType.SYMMETRIC
         def get_weight_block_size(x):
@@ -563,20 +555,24 @@ def get_per_token_block_size(x):
         input_eps = 1e-5
         input_quant_min = -127
         input_quant_max = 127
-        input_quant_func = lambda x: AffineQuantizedTensor.from_float(x, input_mapping_type, get_per_token_block_size(x), input_target_dtype, eps=input_eps, quant_min=input_quant_min, quant_max=input_quant_max, scale_dtype=torch.float32 if x.dtype == torch.float16 else None)
+        input_quant_func = lambda x: to_aqt(x, input_mapping_type, get_per_token_block_size(x), input_target_dtype, eps=input_eps, quant_min=input_quant_min, quant_max=input_quant_max, scale_dtype=torch.float32 if x.dtype == torch.float16 else None)
 
         # use 1024 so that we don't need padding
         m = ToyLinearModel(1024, 1024, 1024).eval().to(torch.bfloat16).to("cuda")
         m_copy = copy.deepcopy(m)
-        example_inputs = tuple(map(lambda x: x.to(torch.bfloat16).to("cuda"), m.example_inputs()))
+        # setting batch_size to 20 to be compatible with the kernel
+        example_inputs = tuple(map(lambda x: x.to(torch.bfloat16).to("cuda"), m.example_inputs(batch_size=20)))
+
+        def apply_weight_quant(weight):
+            block_size = get_weight_block_size(weight)
+            return to_aqt(weight, mapping_type, block_size, target_dtype, eps=eps, zero_point_dtype=zero_point_dtype)
 
-        def dynamic_quant(linear):
-            # note: order is important
-            linear.weight = torch.nn.Parameter(AffineQuantizedTensor.from_float(linear.weight, mapping_type, get_weight_block_size(linear.weight), target_dtype, eps=eps, zero_point_dtype=zero_point_dtype), requires_grad=False)
-            linear.weight = torch.nn.Parameter(LinearActQuantizedTensor.from_float(linear.weight, input_quant_func), requires_grad=False)
+        def apply_act_quant(weight):
+            return to_laqt(weight, input_quant_func)
+
+        m = quantize(m, apply_weight_quant)
+        m = quantize(m, apply_act_quant)
 
-        dynamic_quant(m.linear1)
-        dynamic_quant(m.linear2)
         assert isinstance(m.linear1.weight, LinearActQuantizedTensor)
         assert isinstance(m.linear2.weight, LinearActQuantizedTensor)
         assert isinstance(m.linear1.weight.original_weight_tensor, AffineQuantizedTensor)
@@ -591,6 +587,19 @@ def dynamic_quant(linear):
 
         self.assertTrue(torch.equal(res, ref))
 
+        # workaround for export path
+        from torchao.quantization.utils import unwrap_tensor_subclass
+        m_unwrapped = unwrap_tensor_subclass(m)
+
+        m = torch.export.export(m_unwrapped, example_inputs).module()
+        exported_model_res = m(*example_inputs)
+
+        self.assertTrue(torch.equal(exported_model_res, ref))
+
+        # make sure it compiles
+        torch._export.aot_compile(m_unwrapped, example_inputs)
+
 
 if __name__ == "__main__":
     unittest.main()
```
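The export block added at the end of this test is the notable part: the test treats subclass weights as a workaround case for `torch.export`, converting the model back to plain modules with `unwrap_tensor_subclass` before exporting. As a standalone recipe (this sketch reuses `m` and `example_inputs` from an already-`quantize`d model, as in the test):

```python
import torch
from torchao.quantization.utils import unwrap_tensor_subclass

# workaround: swap subclass weights for plain equivalents before export
m_unwrapped = unwrap_tensor_subclass(m)

exported = torch.export.export(m_unwrapped, example_inputs).module()
exported_res = exported(*example_inputs)  # should match the eager result

# the same unwrapped module also feeds AOT compilation
torch._export.aot_compile(m_unwrapped, example_inputs)
```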

torchao/quantization/quant_api.py

Lines changed: 49 additions & 1 deletion
```diff
@@ -18,6 +18,7 @@
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
+from typing import Any, Callable
 
 from .dynamic_quant import DynamicallyPerAxisQuantizedLinear
 from .utils import TORCH_VERSION_AFTER_2_3, TORCH_VERSION_AFTER_2_4
@@ -48,7 +49,8 @@
     "TwoStepQuantizer",
     "Int4WeightOnlyGPTQQuantizer",
     "Int4WeightOnlyQuantizer",
-    "autoquant"
+    "quantize",
+    "autoquant",
 ]
 
 if TORCH_VERSION_AFTER_2_3:
@@ -215,3 +217,49 @@ def replace_conv2d_1x1(conv):
     _replace_with_custom_fn_if_matches_filter(
         model, replace_conv2d_1x1, filter_fn=filter_fn
     )
+
+
+def _get_linear_subclass_inserter(constructor):
+    def insert_subclass(lin):
+        lin.weight = torch.nn.Parameter(constructor(lin.weight), requires_grad=False)
+        return lin
+
+    return insert_subclass
+
+def quantize(model: torch.nn.Module, apply_tensor_subclass: Callable[[torch.Tensor], torch.Tensor], filter_fn=None) -> torch.nn.Module:
+    """Convert the weight of linear modules in the model with `apply_tensor_subclass`
+
+    Args:
+        model: input model
+        apply_tensor_subclass (Callable[[torch.Tensor], torch.Tensor]): function that converts a floating point Tensor to a (quantized) tensor subclass instance
+        filter_fn: used to filter out the modules that we don't want to apply the tensor subclass to
+
+    Example::
+
+        # weight settings
+        groupsize = 32
+        mapping_type = MappingType.ASYMMETRIC
+        block_size = (1, groupsize)
+        target_dtype = torch.int32
+        quant_min = 0
+        quant_max = 15
+        eps = 1e-6
+        preserve_zero = False
+        zero_point_dtype = torch.bfloat16
+        zero_point_domain = ZeroPointDomain.FLOAT
+
+        apply_weight_quant = lambda x: to_aqt(x, mapping_type, block_size, target_dtype, quant_min, quant_max, eps, zero_point_dtype=zero_point_dtype, preserve_zero=preserve_zero, zero_point_domain=zero_point_domain)
+
+        # apply to modules under block0 submodule
+        def filter_fn(module, fqn):
+            return fqn == "block0"
+
+        m = MyModel(...)
+        m = quantize(m, apply_weight_quant, filter_fn)
+    """
+    _replace_with_custom_fn_if_matches_filter(
+        model,
+        _get_linear_subclass_inserter(apply_tensor_subclass),
+        _is_linear if filter_fn is None else filter_fn,
+    )
+    return model
```
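The docstring example above targets a hypothetical `MyModel`. Here is a self-contained variant of the same `filter_fn` usage, with the simple int8 per-channel weight transform from the tests (the two-block model and the `eps` value are made up for illustration):

```python
import torch
from torchao.quantization.quant_api import quantize
from torchao.quantization.quant_primitives import MappingType
from torchao.quantization.subclass import AffineQuantizedTensor, to_aqt

class TwoBlock(torch.nn.Module):  # hypothetical model
    def __init__(self):
        super().__init__()
        self.block0 = torch.nn.Linear(16, 16, bias=False)
        self.block1 = torch.nn.Linear(16, 16, bias=False)

    def forward(self, x):
        return self.block1(self.block0(x))

def apply_weight_quant(w):
    return to_aqt(w, MappingType.SYMMETRIC, (1, w.shape[1]), torch.int8, eps=1e-5)

# only block0 matches the filter, so block1 keeps its float weight
m = quantize(TwoBlock().eval(), apply_weight_quant,
             filter_fn=lambda module, fqn: fqn == "block0")
assert isinstance(m.block0.weight, AffineQuantizedTensor)
assert not isinstance(m.block1.weight, AffineQuantizedTensor)
```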

torchao/quantization/subclass.py

Lines changed: 6 additions & 29 deletions
```diff
@@ -35,6 +35,7 @@
     "Int8WeightOnlyQuantizedLinearWeight",
     "Int4WeightOnlyQuantizedLinearWeight",
     "AffineQuantizedTensor",
+    "LinearActQuantizedTensor",
 ]
 
 
@@ -266,7 +267,6 @@ def __new__(cls, int_data, q_scales, transposed, shape, dtype=None, **kwargs):
         return super().__new__(cls, int_data, transposed, shape, **kwargs)  # type: ignore[attr-defined]
 
     def __init__(self, int_data, q_scales, transposed, shape, dtype=None, **kwargs):
-
         self.q_scales = q_scales
         super().__init__(int_data, transposed)
 
@@ -629,32 +629,6 @@ def to_qtensor_components(cls, input_float, groupsize=128, inner_k_tiles=8):
         int_data = aten._convert_weight_to_int4pack(input_int4x8, inner_k_tiles)
         return int_data, scales_and_zeros, False, groupsize, inner_k_tiles
 
-def to_aqt(
-    input_float,
-    mapping_type,
-    block_size,
-    target_dtype,
-    quant_min = None,
-    quant_max = None,
-    eps = None,
-    scale_dtype = None,
-    zero_point_dtype = None,
-    preserve_zero = True,
-    zero_point_domain = ZeroPointDomain.INT,
-):
-    return AffineQuantizedTensor.from_float(
-        input_float,
-        mapping_type,
-        block_size,
-        target_dtype,
-        quant_min=quant_min,
-        quant_max=quant_max,
-        eps=eps,
-        scale_dtype=scale_dtype,
-        zero_point_dtype=zero_point_dtype,
-        preserve_zero=preserve_zero,
-        zero_point_domain=zero_point_domain
-    )
 
 # TODO: merge with nf4 implements decorator
 # aten op to their __torch_dispatch__ implemnetations for the tensor subclass
```
```diff
@@ -777,7 +751,7 @@ def dequantize(self, output_dtype=None):
         return dequantize_affine(self.int_data, self.block_size, self.scale, self.zero_point, self.int_data.dtype, self.quant_min, self.quant_max, self.zero_point_domain, output_dtype=output_dtype)
 
     def __tensor_flatten__(self):
-        return ["int_data", "scales", "zero_point"], [self.block_size, self.shape, self.quant_min, self.quant_max, self.zero_point_domain, self.dtype]
+        return ["int_data", "scale", "zero_point"], [self.block_size, self.shape, self.quant_min, self.quant_max, self.zero_point_domain, self.dtype]
 
     @classmethod
     def __tensor_unflatten__(
```
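The `"scales"` → `"scale"` rename is a correctness fix, not cosmetic: `__tensor_flatten__` returns attribute *names* that the runtime fetches off the subclass (the attribute here is stored as `self.scale`, as the `dequantize` line above shows). A toy illustration of the failure mode with a stale name:

```python
class Toy:
    def __init__(self):
        self.scale = 1.0  # the attribute that actually exists

t = Toy()
print(getattr(t, "scale"))  # 1.0
try:
    getattr(t, "scales")    # the stale pre-fix name
except AttributeError as e:
    print("flattening with a stale name fails like this:", e)
```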
```diff
@@ -1091,7 +1065,7 @@ def __tensor_unflatten__(
         cls, tensor_data_dict, tensor_attributes, outer_size, outer_stride
     ):
         original_weight_tensor = tensor_data_dict["original_weight_tensor"]
-        input_quant_func = tensor_attributes
+        input_quant_func, = tensor_attributes
         return cls(
             original_weight_tensor,
             input_quant_func,
```
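A one-character fix: `tensor_attributes` arrives as a sequence, and the old code bound the whole sequence to `input_quant_func` instead of its single element. The trailing comma is tuple unpacking:

```python
tensor_attributes = ["my_quant_func"]   # stand-in for the real metadata sequence
bound_wrong = tensor_attributes         # pre-fix: the list itself
bound_right, = tensor_attributes        # post-fix: the single element
assert bound_wrong == ["my_quant_func"]
assert bound_right == "my_quant_func"
```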
```diff
@@ -1176,3 +1150,6 @@ def __torch_dispatch__(cls, func, types, args, kwargs):
         raise NotImplementedError(
             f"LinearActQuantizedTensor dispatch: attempting to run {func}, this is not supported"
         )
+
+to_aqt = AffineQuantizedTensor.from_float
+to_laqt = LinearActQuantizedTensor.from_float
```
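With the standalone `to_aqt` wrapper deleted, the shortcut is now just an alias of the classmethod, so the two can never drift apart in signature. A quick sanity check of what the aliases point at:

```python
from torchao.quantization.subclass import (
    AffineQuantizedTensor,
    LinearActQuantizedTensor,
    to_aqt,
    to_laqt,
)

# the aliases compare equal to the bound classmethods they were assigned from
assert to_aqt == AffineQuantizedTensor.from_float
assert to_laqt == LinearActQuantizedTensor.from_float
```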
