18
18
get_symmetric_quantization_config ,
19
19
)
20
20
21
+ from torchao .quantization .subclass import (
22
+ to_aqt ,
23
+ to_laqt ,
24
+ AffineQuantizedTensor ,
25
+ LinearActQuantizedTensor ,
26
+ )
27
+ from torchao .quantization .quant_primitives import (
28
+ MappingType ,
29
+ ZeroPointDomain ,
30
+ )
31
+
21
32
from torchao .quantization .quant_api import (
22
33
_replace_with_custom_fn_if_matches_filter ,
23
34
apply_dynamic_quant ,
24
35
apply_weight_only_int8_quant ,
25
36
Quantizer ,
26
37
TwoStepQuantizer ,
38
+ quantize ,
27
39
)
28
40
from torchao .quantization .utils import (
29
41
TORCH_VERSION_AFTER_2_3 ,
32
44
from pathlib import Path
33
45
from sentencepiece import SentencePieceProcessor
34
46
from model import Transformer , prepare_inputs_for_model
47
+ import copy
35
48
36
49
37
50
def dynamic_quant (model , example_inputs ):
@@ -92,8 +105,8 @@ def __init__(self, m=64, n=32, k=64):
92
105
self .linear1 = torch .nn .Linear (m , n , bias = False ).to (torch .float )
93
106
self .linear2 = torch .nn .Linear (n , k , bias = False ).to (torch .float )
94
107
95
- def example_inputs (self ):
96
- return (torch .randn (1 , self .linear1 .in_features ).to (torch .float ),)
108
+ def example_inputs (self , batch_size = 1 ):
109
+ return (torch .randn (batch_size , self .linear1 .in_features ).to (torch .float ),)
97
110
98
111
def forward (self , x ):
99
112
x = self .linear1 (x )
@@ -395,13 +408,6 @@ def test_eval_wrapper(self):
395
408
# TODO: move to a separate test file
396
409
@unittest .skipIf (not TORCH_VERSION_AFTER_2_4 , "Test only enabled for 2.4+" )
397
410
def test_quantized_tensor_subclass_8da4w (self ):
398
- from torchao .quantization .subclass import (
399
- AffineQuantizedTensor ,
400
- LinearActQuantizedTensor ,
401
- )
402
- from torchao .quantization .quant_primitives import MappingType
403
- import copy
404
-
405
411
# weight settings
406
412
groupsize = 32
407
413
mapping_type = MappingType .SYMMETRIC
@@ -423,20 +429,26 @@ def get_per_token_block_size(x):
423
429
# input settings
424
430
input_mapping_type = MappingType .ASYMMETRIC
425
431
input_target_dtype = torch .int8
426
- input_quant_func = lambda x : AffineQuantizedTensor .from_float (x , input_mapping_type , get_per_token_block_size (x ), input_target_dtype )
427
-
428
- def dynamic_quant (linear ):
429
- # note: order is important
430
- linear .weight = torch .nn .Parameter (AffineQuantizedTensor .from_float (linear .weight , mapping_type , block_size , target_dtype , quant_min , quant_max , eps ), requires_grad = False )
431
- linear .weight = torch .nn .Parameter (LinearActQuantizedTensor .from_float (linear .weight , input_quant_func ), requires_grad = False )
432
+ input_quant_func = lambda x : to_aqt (x , input_mapping_type , get_per_token_block_size (x ), input_target_dtype )
432
433
433
434
m = ToyLinearModel ().eval ()
434
435
m_copy = copy .deepcopy (m )
435
436
example_inputs = m .example_inputs ()
436
- dynamic_quant (m .linear1 )
437
- dynamic_quant (m .linear2 )
437
+
438
+ def apply_weight_quant (weight ):
439
+ return to_aqt (weight , mapping_type , block_size , target_dtype , quant_min , quant_max , eps )
440
+
441
+ def apply_act_quant (weight ):
442
+ return to_laqt (weight , input_quant_func )
443
+
444
+ # note: order is important
445
+ m = quantize (m , apply_weight_quant )
446
+ m = quantize (m , apply_act_quant )
447
+
438
448
assert isinstance (m .linear1 .weight , LinearActQuantizedTensor )
439
449
assert isinstance (m .linear2 .weight , LinearActQuantizedTensor )
450
+ assert isinstance (m .linear1 .weight .original_weight_tensor , AffineQuantizedTensor )
451
+ assert isinstance (m .linear2 .weight .original_weight_tensor , AffineQuantizedTensor )
440
452
441
453
# reference
442
454
from torchao .quantization .quant_api import Int8DynActInt4WeightQuantizer
@@ -454,11 +466,6 @@ def dynamic_quant(linear):
454
466
@unittest .skipIf (not TORCH_VERSION_AFTER_2_4 , "Test only enabled for 2.4+" )
455
467
@unittest .skipIf (not torch .cuda .is_available (), "Need CUDA available" )
456
468
def test_quantized_tensor_subclass_int4 (self ):
457
- from torchao .quantization .subclass import AffineQuantizedTensor
458
- from torchao .quantization .quant_primitives import MappingType
459
- from torchao .quantization .quant_primitives import ZeroPointDomain
460
- import copy
461
-
462
469
# weight settings
463
470
groupsize = 32
464
471
mapping_type = MappingType .ASYMMETRIC
@@ -469,22 +476,17 @@ def test_quantized_tensor_subclass_int4(self):
469
476
eps = 1e-6
470
477
preserve_zero = False
471
478
zero_point_dtype = torch .bfloat16
479
+ zero_point_domain = ZeroPointDomain .FLOAT
472
480
473
481
# use 1024 so that we don't need padding
474
482
m = ToyLinearModel (1024 , 1024 , 1024 ).eval ().to (torch .bfloat16 ).to ("cuda" )
475
483
m_copy = copy .deepcopy (m )
476
484
example_inputs = tuple (map (lambda x : x .to (torch .bfloat16 ).to ("cuda" ), m .example_inputs ()))
477
485
478
- def to_quantized (weight ):
479
- return AffineQuantizedTensor .from_float (
480
- weight , mapping_type , block_size , target_dtype , quant_min , quant_max , eps ,
481
- zero_point_dtype = zero_point_dtype ,
482
- preserve_zero = preserve_zero ,
483
- zero_point_domain = ZeroPointDomain .FLOAT ,
484
- )
486
+ def apply_weight_quant (weight ):
487
+ return to_aqt (weight , mapping_type , block_size , target_dtype , quant_min , quant_max , eps , zero_point_dtype = zero_point_dtype , preserve_zero = preserve_zero , zero_point_domain = zero_point_domain )
485
488
486
- m .linear1 .weight = torch .nn .Parameter (to_quantized (m .linear1 .weight ), requires_grad = False )
487
- m .linear2 .weight = torch .nn .Parameter (to_quantized (m .linear2 .weight ), requires_grad = False )
489
+ m = quantize (m , apply_weight_quant )
488
490
assert isinstance (m .linear1 .weight , AffineQuantizedTensor )
489
491
assert isinstance (m .linear2 .weight , AffineQuantizedTensor )
490
492
@@ -501,10 +503,6 @@ def to_quantized(weight):
501
503
@unittest .skipIf (not TORCH_VERSION_AFTER_2_4 , "Test only enabled for 2.4+" )
502
504
@unittest .skipIf (not torch .cuda .is_available (), "Need CUDA available" )
503
505
def test_quantized_tensor_subclass_int8 (self ):
504
- from torchao .quantization .subclass import AffineQuantizedTensor
505
- from torchao .quantization .quant_primitives import MappingType
506
- import copy
507
-
508
506
# weight settings
509
507
mapping_type = MappingType .SYMMETRIC
510
508
target_dtype = torch .int8
@@ -515,12 +513,12 @@ def test_quantized_tensor_subclass_int8(self):
515
513
m_copy = copy .deepcopy (m )
516
514
example_inputs = tuple (map (lambda x : x .to (torch .bfloat16 ), m .example_inputs ()))
517
515
518
- def to_quantized (weight ):
516
+ def apply_weight_quant (weight ):
519
517
block_size = (1 , weight .shape [1 ])
520
- return AffineQuantizedTensor .from_float (weight , mapping_type , block_size , target_dtype , eps = eps , zero_point_dtype = zero_point_dtype )
518
+ return to_aqt (weight , mapping_type , block_size , target_dtype , eps = eps , zero_point_dtype = zero_point_dtype )
519
+
520
+ m = quantize (m , apply_weight_quant )
521
521
522
- m .linear1 .weight = torch .nn .Parameter (to_quantized (m .linear1 .weight ), requires_grad = False )
523
- m .linear2 .weight = torch .nn .Parameter (to_quantized (m .linear2 .weight ), requires_grad = False )
524
522
assert isinstance (m .linear1 .weight , AffineQuantizedTensor )
525
523
assert isinstance (m .linear2 .weight , AffineQuantizedTensor )
526
524
@@ -537,12 +535,6 @@ def to_quantized(weight):
537
535
@unittest .skipIf (not TORCH_VERSION_AFTER_2_4 , "Test only enabled for 2.4+" )
538
536
@unittest .skipIf (not torch .cuda .is_available (), "Need CUDA available" )
539
537
def test_quantized_tensor_subclass_int8_dyn_quant (self ):
540
- from torchao .quantization .subclass import AffineQuantizedTensor
541
- from torchao .quantization .subclass import LinearActQuantizedTensor
542
- from torchao .quantization .quant_primitives import MappingType
543
- from torchao .quantization .quant_primitives import ZeroPointDomain
544
- import copy
545
-
546
538
# weight settings
547
539
mapping_type = MappingType .SYMMETRIC
548
540
def get_weight_block_size (x ):
@@ -563,20 +555,24 @@ def get_per_token_block_size(x):
563
555
input_eps = 1e-5
564
556
input_quant_min = - 127
565
557
input_quant_max = 127
566
- input_quant_func = lambda x : AffineQuantizedTensor . from_float (x , input_mapping_type , get_per_token_block_size (x ), input_target_dtype , eps = input_eps , quant_min = input_quant_min , quant_max = input_quant_max , scale_dtype = torch .float32 if x .dtype == torch .float16 else None )
558
+ input_quant_func = lambda x : to_aqt (x , input_mapping_type , get_per_token_block_size (x ), input_target_dtype , eps = input_eps , quant_min = input_quant_min , quant_max = input_quant_max , scale_dtype = torch .float32 if x .dtype == torch .float16 else None )
567
559
568
560
# use 1024 so that we don't need padding
569
561
m = ToyLinearModel (1024 , 1024 , 1024 ).eval ().to (torch .bfloat16 ).to ("cuda" )
570
562
m_copy = copy .deepcopy (m )
571
- example_inputs = tuple (map (lambda x : x .to (torch .bfloat16 ).to ("cuda" ), m .example_inputs ()))
563
+ # setting batch_size to 20 to be compatible with the kernel
564
+ example_inputs = tuple (map (lambda x : x .to (torch .bfloat16 ).to ("cuda" ), m .example_inputs (batch_size = 20 )))
565
+
566
+ def apply_weight_quant (weight ):
567
+ block_size = get_weight_block_size (weight )
568
+ return to_aqt (weight , mapping_type , block_size , target_dtype , eps = eps , zero_point_dtype = zero_point_dtype )
572
569
573
- def dynamic_quant (linear ):
574
- # note: order is important
575
- linear .weight = torch .nn .Parameter (AffineQuantizedTensor .from_float (linear .weight , mapping_type , get_weight_block_size (linear .weight ), target_dtype , eps = eps , zero_point_dtype = zero_point_dtype ), requires_grad = False )
576
- linear .weight = torch .nn .Parameter (LinearActQuantizedTensor .from_float (linear .weight , input_quant_func ), requires_grad = False )
570
+ def apply_act_quant (weight ):
571
+ return to_laqt (weight , input_quant_func )
572
+
573
+ m = quantize (m , apply_weight_quant )
574
+ m = quantize (m , apply_act_quant )
577
575
578
- dynamic_quant (m .linear1 )
579
- dynamic_quant (m .linear2 )
580
576
assert isinstance (m .linear1 .weight , LinearActQuantizedTensor )
581
577
assert isinstance (m .linear2 .weight , LinearActQuantizedTensor )
582
578
assert isinstance (m .linear1 .weight .original_weight_tensor , AffineQuantizedTensor )
@@ -591,6 +587,19 @@ def dynamic_quant(linear):
591
587
592
588
self .assertTrue (torch .equal (res , ref ))
593
589
590
+ # workaround for export path
591
+ from torchao .quantization .utils import unwrap_tensor_subclass
592
+ m_unwrapped = unwrap_tensor_subclass (m )
593
+
594
+ m = torch .export .export (m_unwrapped , example_inputs ).module ()
595
+ exported_model_res = m (* example_inputs )
596
+
597
+ self .assertTrue (torch .equal (exported_model_res , ref ))
598
+
599
+ # make sure it compiles
600
+ torch ._export .aot_compile (m_unwrapped , example_inputs )
601
+
602
+
594
603
595
604
if __name__ == "__main__" :
596
605
unittest .main ()
0 commit comments