@@ -4,28 +4,26 @@

import torch

-from torchao.prototype.uintx import uintx_affine_weight_only, to_uintx
-from torchao.quantization.quant_api import quantize_
+from torchao.dtypes.uintx.Uintx import to_uintx
+from torchao.quantization.quant_api import quantize_, uintx_weight_only
from torchao.utils import TORCH_VERSION_AFTER_2_5

from torchao.quantization.quant_primitives import (
-    MappingType,
-    ZeroPointDomain,
-    choose_qparams_affine,
-    quantize_affine,
-    dequantize_affine,
-)
+    MappingType,
+    ZeroPointDomain,
+    choose_qparams_affine,
+    quantize_affine,
+    dequantize_affine,
+)

-bit_sizes = (1,2, 3, 4, 5, 6, 7)
-group_sizes = [32,64,128]
+bit_widths = (1, 2, 3, 4, 5, 6, 7)
+group_sizes = [32, 64, 128]
devices = ["cpu", "cuda"]
@pytest.fixture(autouse=True)
def run_before_and_after_tests():
    yield
    torch._dynamo.reset() # reset cache between tests

-
-
class Linear16(torch.nn.Module):
    def __init__(self, scale, device):
        super().__init__()
@@ -37,52 +35,52 @@ def __init__(self, scale, device):

    def forward(self, x):
        return self.net(x)
-
-@pytest.mark.parametrize("bit_size", bit_sizes)
+
+@pytest.mark.parametrize("bit_width", bit_widths)
@pytest.mark.parametrize("group_size", group_sizes)
@pytest.mark.parametrize("device", devices)
-@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
+@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
@pytest.mark.skipif(not TORCH_VERSION_AFTER_2_5, reason="only works with fix in the nightly build")
-def test_uintx_affine_weight_only_model_quant(bit_size, group_size, device):
+def test_uintx_weight_only_model_quant(bit_width, group_size, device):
    scale = 512
    fp16 = Linear16(scale, device)
-    quantize_(fp16, uintx_affine_weight_only(bit_size, group_size=group_size))
+    quantize_(fp16, uintx_weight_only(bit_width, group_size=group_size))
    uintx = torch.compile(fp16, fullgraph=True)
    test_input = torch.randn(scale * 2, dtype=torch.float16, device=device)
    output = uintx.forward(test_input)
    assert output != None, "model quantization failed"
-
-@pytest.mark.parametrize("bit_size", bit_sizes)
+
+@pytest.mark.parametrize("bit_width", bit_widths)
@pytest.mark.parametrize("group_size", group_sizes)
@pytest.mark.parametrize("device", devices)
-@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
+@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
@pytest.mark.skipif(not TORCH_VERSION_AFTER_2_5, reason="only works with fix in the nightly build")
-def test_uintx_affine_weight_only_quant(bit_size, group_size, device):
-    input_float = torch.randn((1,256), dtype=torch.float16, device=device)
+def test_uintx_weight_only_quant(bit_width, group_size, device):
+    input_float = torch.randn((1, 256), dtype=torch.float16, device=device)
    mapping_type = MappingType.SYMMETRIC
    quant_min = 0
-    quant_max = 2**bit_size - 1
+    quant_max = 2**bit_width - 1
    eps = torch.finfo(torch.float32).eps
    zero_point_dtype = torch.int32
    zero_point_domain = ZeroPointDomain.INT
    target_dtype = torch.uint8
    block_size = (1, group_size)
-
+
    scale, zero_point = choose_qparams_affine(
-        input_float, mapping_type, block_size,
-        target_dtype, quant_min, quant_max, eps, torch.float32,
-        zero_point_dtype, True, zero_point_domain
+        input_float, mapping_type, block_size,
+        target_dtype, quant_min, quant_max, eps, torch.float32,
+        zero_point_dtype, True, zero_point_domain
    )
-
+
    aqt = quantize_affine(
        input_float, block_size, scale,
        zero_point, target_dtype,
        quant_min=quant_min,
        quant_max=quant_max,
        zero_point_domain=zero_point_domain
-    )
-
-    q = to_uintx(aqt, bit_size, -1)
+    )
+
+    q = to_uintx(aqt, bit_width, -1)
    assert q != None, "quantization failed"
    deqaunt = dequantize_affine(
        q, block_size, scale,
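For readers following the rename, the model-level call pattern is unchanged; only the names move from bit_size/uintx_affine_weight_only to bit_width/uintx_weight_only. Below is a minimal usage sketch under the new API, assuming the import paths introduced in this diff and a torchao nightly recent enough to pass the TORCH_VERSION_AFTER_2_5 check; the toy model, the 4-bit width, and the group size of 64 are illustrative choices, not values taken from the PR.

import torch
from torchao.quantization.quant_api import quantize_, uintx_weight_only

# Toy fp16 model on CUDA (illustrative only; any module containing nn.Linear
# layers is handled the same way by quantize_).
model = torch.nn.Sequential(torch.nn.Linear(256, 256)).to(device="cuda", dtype=torch.float16)

# Replace each Linear weight with a uintx weight-only quantized tensor:
# 4-bit values, with one scale/zero_point per group of 64 weight elements.
quantize_(model, uintx_weight_only(4, group_size=64))

# Optionally compile, as the test above does, then run inference as usual.
model = torch.compile(model, fullgraph=True)
out = model(torch.randn(1, 256, dtype=torch.float16, device="cuda"))

test_uintx_weight_only_quant above exercises the same machinery at the tensor level, going through choose_qparams_affine, quantize_affine, to_uintx, and dequantize_affine directly instead of the quantize_ entry point.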