     DynamicallyPerAxisQuantizedLinear,
 )
 from torchao.quantization.quant_api import (
-    apply_dynamic_quant,
-    apply_weight_only_int8_quant,
+    int4wo,
+    int8wo,
+    int8da_int8w,
+    quantize,
+    _replace_with_custom_fn_if_matches_filter,
+)
+# APIs to be deprecated (used for torch 2.2.2 and 2.3)
+from torchao.quantization.quant_api import (
     change_linear_weights_to_int8_dqtensors,
     change_linear_weights_to_int8_woqtensors,
     change_linear_weights_to_int4_woqtensors,
-    _replace_with_custom_fn_if_matches_filter,
 )
 from torchao.quantization.quant_primitives import (
     safe_int_mm,
 from parameterized import parameterized
 import itertools
 import logging
-from torchao.utils import TORCH_VERSION_AFTER_2_3, TORCH_VERSION_AFTER_2_4
+from torchao.utils import (
+    TORCH_VERSION_AFTER_2_3,
+    TORCH_VERSION_AFTER_2_4,
+    unwrap_tensor_subclass,
+)

 logger = logging.getLogger("INFO")

 torch.manual_seed(0)
 config.cache_size_limit = 100

-# TODO: use this to reduce the number of tests
-TENSOR_SUBCLASS_APIS = [
-    change_linear_weights_to_int8_dqtensors,
-    change_linear_weights_to_int8_woqtensors,
-    change_linear_weights_to_int4_woqtensors,
-]
-
 COMMON_DEVICES = ["cpu", "cuda"]

 COMMON_DTYPES = [torch.float32, torch.float16, torch.bfloat16]

 COMMON_DEVICE_DTYPE = list(itertools.product(COMMON_DEVICES, COMMON_DTYPES)).copy()

+def _int8wo_api(mod):
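+    # On torch 2.4+, use the new quantize() config API, then unwrap the weight
+    # tensor subclasses so the quantized module still works with torch.compile
+    # and state_dict save/load in these tests; older torch versions fall back
+    # to the change_linear_weights_* APIs kept above for torch 2.2.2 and 2.3.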
+    if TORCH_VERSION_AFTER_2_4:
+        quantize(mod, int8wo())
+        unwrap_tensor_subclass(mod)
+    else:
+        change_linear_weights_to_int8_woqtensors(mod)
+
+def _int8da_int8w_api(mod):
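+    # same torch-version gating, for int8 dynamic-activation / int8-weight quant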
+    if TORCH_VERSION_AFTER_2_4:
+        quantize(mod, int8da_int8w())
+        unwrap_tensor_subclass(mod)
+    else:
+        change_linear_weights_to_int8_dqtensors(mod)
+
+def _int4wo_api(mod):
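+    # same torch-version gating, for int4 weight-only quant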
+    if TORCH_VERSION_AFTER_2_4:
+        quantize(mod, int4wo())
+        unwrap_tensor_subclass(mod)
+    else:
+        change_linear_weights_to_int4_woqtensors(mod)
+
+# TODO: use this to reduce the number of tests
+TENSOR_SUBCLASS_APIS = [
+    _int8wo_api,
+    _int8da_int8w_api,
+    _int4wo_api,
+]
+
+
 def combine_parameters(a, b):
     new_tuples = []
     for (tuple1, tuple2) in itertools.product(a, b):
@@ -756,13 +787,13 @@ def _test_lin_weight_subclass_api_impl(
     @unittest.skipIf(TORCH_VERSION_AFTER_2_4, "skip because there is some bug in inductor codegen")
     def test_int8_dynamic_quant_subclass_api(self, device, dtype):
         self._test_lin_weight_subclass_api_impl(
-            change_linear_weights_to_int8_dqtensors, device, 35, test_dtype=dtype
+            _int8da_int8w_api, device, 35, test_dtype=dtype
         )

     @parameterized.expand(COMMON_DEVICE_DTYPE)
     def test_int8_weight_only_quant_subclass_api(self, device, dtype):
         self._test_lin_weight_subclass_api_impl(
-            change_linear_weights_to_int8_woqtensors, device, 40, test_dtype=dtype
+            _int8wo_api, device, 40, test_dtype=dtype
         )

     @parameterized.expand(COMMON_DEVICE_DTYPE)
@@ -772,7 +803,7 @@ def test_int4_weight_only_quant_subclass_api(self, device, dtype):
             self.skipTest(f"Fails for {dtype}")
         for test_shape in ([(16, 1024, 16)] + ([(1, 1024, 256)] if device == 'cuda' else [])):
             self._test_lin_weight_subclass_api_impl(
-                change_linear_weights_to_int4_woqtensors,
+                _int4wo_api,
                 device,
                 15,
                 test_shape=test_shape,
@@ -788,8 +819,16 @@ def test_int4_weight_only_quant_subclass_api_grouped(self, device, dtype):
         for groupsize in [64, 32]:
             for inner_k_tiles in [4, 2]:
                 kwargs = {"groupsize": groupsize, "inner_k_tiles": inner_k_tiles}
+
+                def api(mod):
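+                    # closure forwarding groupsize/inner_k_tiles to the version-appropriate API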
+                    if TORCH_VERSION_AFTER_2_4:
+                        quantize(mod, int4wo(**kwargs))
+                        unwrap_tensor_subclass(mod)
+                    else:
+                        change_linear_weights_to_int4_woqtensors(mod, **kwargs)
+
                 self._test_lin_weight_subclass_api_impl(
-                    lambda mod: change_linear_weights_to_int4_woqtensors(mod, **kwargs),
+                    api,
                     device,
                     15,
                     test_shape=test_shape,
@@ -804,7 +843,7 @@ def test_dynamic_quant(self):
         m = nn.Sequential(nn.Linear(K, N))

         y_ref = m(x)
-        apply_dynamic_quant(m)
+        quantize(m, int8da_int8w())
         y_test = m(x)

         sqnr = compute_error(y_ref, y_test)
@@ -818,7 +857,7 @@ def test_weight_only_quant(self):
         x = torch.randn(*x_shape)
         m = nn.Sequential(nn.Linear(4, 5))
         y_ref = m(x)
-        apply_weight_only_int8_quant(m)
+        _int8wo_api(m)
         y_wo = m(x)
         sqnr = compute_error(y_ref, y_wo)
         self.assertGreater(sqnr, 44.0)
@@ -841,7 +880,7 @@ def test_weight_only_quant_force_mixed_mm(self, device, dtype):
         x = torch.randn(*x_shape).to(device).to(dtype)
         m = nn.Sequential(nn.Linear(4, 5)).to(device).to(dtype)
         y_ref = m(x)
-        apply_weight_only_int8_quant(m)
+        _int8wo_api(m)
         m(x)
         m_c = torch.compile(m, mode="max-autotune")
         y_wo, (code,) = run_and_get_code(m_c, x)
@@ -868,7 +907,7 @@ def test_weight_only_quant_use_mixed_mm(self, device, dtype):
         x = torch.randn(*x_shape).to(device).to(dtype)
         m = nn.Sequential(nn.Linear(4, 5)).to(device).to(dtype)
         y_ref = m(x)
-        apply_weight_only_int8_quant(m)
+        _int8wo_api(m)
         m_c = torch.compile(m, mode="max-autotune")
         y_wo, (code,) = run_and_get_code(m_c, x)
         sqnr = compute_error(y_ref, y_wo)
@@ -909,6 +948,7 @@ def forward(self, x):

         # save quantized state_dict
         api(model)
+
         torch.save(model.state_dict(), "test.pth")
         # get quantized reference
         model_qc = torch.compile(model, mode="max-autotune")
@@ -924,6 +964,7 @@ def forward(self, x):
         # load quantized state_dict
         state_dict = torch.load("test.pth", mmap=True)
         os.remove("test.pth")
+
         model.load_state_dict(state_dict, assign=True)
         model = model.to(device=test_device, dtype=test_dtype).eval()

@@ -939,20 +980,20 @@ def forward(self, x):
     def test_save_load_dqtensors(self, device, dtype):
         if device == "cpu":
             self.skipTest(f"inductor failed for cpu right now")
-        self._test_handle_save_load_meta_impl(change_linear_weights_to_int8_dqtensors, device, test_dtype=dtype)
+        self._test_handle_save_load_meta_impl(_int8da_int8w_api, device, test_dtype=dtype)

     @parameterized.expand(COMMON_DEVICE_DTYPE)
     @torch.no_grad()
     def test_save_load_int8woqtensors(self, device, dtype):
-        self._test_handle_save_load_meta_impl(change_linear_weights_to_int8_woqtensors, device, test_dtype=dtype)
+        self._test_handle_save_load_meta_impl(_int8wo_api, device, test_dtype=dtype)

     @parameterized.expand(COMMON_DEVICE_DTYPE)
     @unittest.skipIf(not TORCH_VERSION_AFTER_2_3, "int4 requires torch nightly.")
     @torch.no_grad()
     def test_save_load_int4woqtensors(self, device, dtype):
         if dtype != torch.bfloat16:
             self.skipTest(f"Fails for {dtype}")
-        self._test_handle_save_load_meta_impl(change_linear_weights_to_int4_woqtensors, device, 20, test_dtype=dtype)
+        self._test_handle_save_load_meta_impl(_int4wo_api, device, 20, test_dtype=dtype)


 class TorchCompileUnitTest(unittest.TestCase):
@@ -1271,8 +1312,7 @@ def forward(self, x):
         model = test_model().to(dtype=test_dtype, device=test_device).eval()
         ref_f = model(x)

-        kwargs = {"dtype": test_dtype}
-        api(model, **kwargs)
+        api(model)

         # running model
         model(x)
@@ -1317,8 +1357,7 @@ def forward(self, x):
         model = test_model().to(dtype=test_dtype, device=test_device).eval()
         ref_f = model(x)

-        kwargs = {"dtype": test_dtype}
-        api(model, **kwargs)
+        api(model)

         # running model
         ref = model(x)