@@ -3,10 +3,11 @@
 from compressed_tensors.quantization import FP8_DTYPE
 
 import vllm.envs as envs
+import vllm.plugins
 from vllm.compilation.fusion import (FusionPass, find_auto_fn,
                                      find_auto_fn_maybe)
 from vllm.compilation.reshapes import RedundantReshapesPass
-from vllm.config import CompilationConfig
+from vllm.config import CompilationConfig, VllmConfig, CompilationLevel
 from vllm.model_executor.layers.layernorm import RMSNorm
 from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
     apply_fp8_linear)
@@ -16,8 +17,10 @@
 
 class TestModel(torch.nn.Module):
 
-    def __init__(self, hidden_size: int, eps: float, *args, **kwargs):
+    def __init__(self, hidden_size: int, eps: float, cutlass_fp8: bool, *args,
+                 **kwargs):
         super().__init__(*args, **kwargs)
+        self.cutlass_fp8 = cutlass_fp8
         self.norm = [RMSNorm(hidden_size, eps) for _ in range(3)]
         self.scale = [torch.rand(1, dtype=torch.float32) for _ in range(4)]
         self.w = [
@@ -29,11 +32,19 @@ def forward(self, x):
         resid = torch.relu(x)
         y = self.norm[0](x)
 
-        x2 = apply_fp8_linear(y, self.w[0], self.scale[0], self.scale[1])
+        x2 = apply_fp8_linear(y,
+                              self.w[0],
+                              self.scale[0],
+                              self.scale[1],
+                              cutlass_fp8_supported=self.cutlass_fp8)
         # make sure resid is used for replacement to work
         y2, resid = self.norm[1](x2, resid)
 
-        x3 = apply_fp8_linear(y2, self.w[1], self.scale[2], self.scale[3])
+        x3 = apply_fp8_linear(y2,
+                              self.w[1],
+                              self.scale[2],
+                              self.scale[3],
+                              cutlass_fp8_supported=self.cutlass_fp8)
         y3, resid = self.norm[2](x3, resid)  # use resid here
         return y3
 
@@ -42,50 +53,58 @@ def forward(self, x):
 @pytest.mark.parametrize("hidden_size", [64, 3392, 4096])
 @pytest.mark.parametrize("num_tokens", [7, 256, 533, 2048, 2049])
 @pytest.mark.parametrize("eps", [1e-5, 1e-6])
+@pytest.mark.parametrize(
+    "cutlass_fp8",
+    [True, False] if envs.VLLM_TARGET_DEVICE == "cuda" else [False])
 @pytest.mark.skipif(envs.VLLM_TARGET_DEVICE != "cuda",
                     reason="Only test on CUDA")
-def test_fusion_rmsnorm_quant(dtype, hidden_size, num_tokens, eps):
+def test_fusion_rmsnorm_quant(dtype, hidden_size, num_tokens, eps,
+                              cutlass_fp8):
     torch.set_default_device("cuda")
-    torch.set_default_dtype(torch.float16)
+    torch.set_default_dtype(dtype)
+    torch.manual_seed(1)
 
     if eps != 1e-5:
         pytest.skip("Only test eps=1e-5 for now")
 
-    # Reshape pass is needed for the fusion pass to work
-    config = CompilationConfig.PassConfig(enable_fusion=True,
-                                          enable_reshape=True)
-    reshape_pass = RedundantReshapesPass(config)
-    fusion_pass = FusionPass.instance(config)
+    vllm_config = VllmConfig(compilation_config=CompilationConfig(
+        level=CompilationLevel.PIECEWISE, custom_ops=["+rms_norm"]))
+    with vllm.plugins.set_current_vllm_config(vllm_config):
+        # Reshape pass is needed for the fusion pass to work
+        config = CompilationConfig.PassConfig(enable_fusion=True,
+                                              enable_reshape=True)
+        reshape_pass = RedundantReshapesPass(config)
+        fusion_pass = FusionPass.instance(config)
 
-    backend = TestBackend(reshape_pass, fusion_pass)
-    model = TestModel(hidden_size, eps)
+        backend = TestBackend(reshape_pass, fusion_pass)
+        model = TestModel(hidden_size, eps, cutlass_fp8)
 
-    # First dimension dynamic
-    x = torch.rand(num_tokens, hidden_size)
-    torch._dynamo.mark_dynamic(x, 0)
+        # First dimension dynamic
+        x = torch.rand(num_tokens, hidden_size)
+        torch._dynamo.mark_dynamic(x, 0)
 
-    result = model(x)
+        result = model(x)
 
-    model2 = torch.compile(model, backend=backend)
-    result2 = model2(x)
+        model2 = torch.compile(model, backend=backend)
+        result2 = model2(x)
 
-    # Check that it gives the same answer
-    torch.testing.assert_close(result, result2, atol=1e-3, rtol=1e-3)
+        # Check that it gives the same answer
+        torch.testing.assert_close(result, result2, atol=1e-3, rtol=1e-3)
 
-    # Check substitution worked
-    pre_nodes = backend.graph_pre_pass.nodes
-    post_nodes = backend.graph_post_pass.nodes
+        # Check substitution worked
+        pre_nodes = backend.graph_pre_pass.nodes
+        post_nodes = backend.graph_post_pass.nodes
 
-    rms_quant = torch.ops._C.rms_norm_static_fp8_quant.default
-    add_rms_quant = torch.ops._C.fused_add_rms_norm_static_fp8_quant.default
-    fp8_quant = torch.ops._C.static_scaled_fp8_quant.default
+        rms_quant = torch.ops._C.rms_norm_static_fp8_quant.default
+        add_rms_quant = torch.ops._C.fused_add_rms_norm_static_fp8_quant.default
+        fp8_quant = torch.ops._C.static_scaled_fp8_quant.default
 
-    # In pre-nodes, fp8 quant should be present and fused kernels should not
-    assert find_auto_fn_maybe(pre_nodes, rms_quant) is None
-    assert find_auto_fn_maybe(pre_nodes, add_rms_quant) is None
-    find_auto_fn(pre_nodes, fp8_quant)
+        # In pre-nodes, fp8 quant should be there and fused kernels should not
+        assert find_auto_fn_maybe(pre_nodes, rms_quant) is None
+        assert find_auto_fn_maybe(pre_nodes, add_rms_quant) is None
+        find_auto_fn(pre_nodes, fp8_quant)
 
-    # In post-nodes, fused kernels should be present and fp8 quant should not
-    find_auto_fn(post_nodes, rms_quant)
-    find_auto_fn(post_nodes, add_rms_quant)
-    assert find_auto_fn_maybe(post_nodes, fp8_quant) is None
+        # In post-nodes, fused kernels should be there and fp8 quant should not
+        find_auto_fn(post_nodes, rms_quant)
+        find_auto_fn(post_nodes, add_rms_quant)
+        assert find_auto_fn_maybe(post_nodes, fp8_quant) is None
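
Note: a minimal standalone sketch of the setup pattern the added lines rely on, for readers of the diff in isolation. All names (VllmConfig, CompilationLevel, set_current_vllm_config, the passes) are taken from the lines added above; how they behave outside this exact vLLM revision is an assumption, not a guarantee. The point of the change is that the passes and the test model appear to need to be constructed while a VllmConfig with custom_ops=["+rms_norm"] is active, so the custom RMSNorm kernel that the fusion pattern matches is the one that ends up in the graph.

# Sketch only; mirrors the pattern introduced in this diff, not part of the test file.
import vllm.plugins
from vllm.compilation.fusion import FusionPass
from vllm.compilation.reshapes import RedundantReshapesPass
from vllm.config import CompilationConfig, CompilationLevel, VllmConfig

# Force the custom +rms_norm op so the fusion pass can find it in the traced graph.
vllm_config = VllmConfig(compilation_config=CompilationConfig(
    level=CompilationLevel.PIECEWISE, custom_ops=["+rms_norm"]))

with vllm.plugins.set_current_vllm_config(vllm_config):
    # Reshape pass is needed for the fusion pass to work (comment kept from the diff).
    pass_config = CompilationConfig.PassConfig(enable_fusion=True,
                                               enable_reshape=True)
    reshape_pass = RedundantReshapesPass(pass_config)
    fusion_pass = FusionPass.instance(pass_config)
    # Model construction and torch.compile(...) also happen inside this block,
    # as in test_fusion_rmsnorm_quant above.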