from compressed_tensors.quantization import FP8_DTYPE

import vllm.envs as envs
+import vllm.plugins
from vllm.compilation.fusion import (FUSED_OPS, QUANT_OPS, FusedRMSQuantKey,
                                     FusionPass, QuantKey)
from vllm.compilation.fx_utils import find_auto_fn, find_auto_fn_maybe
-from vllm.compilation.reshapes import RedundantReshapesPass
-from vllm.config import CompilationConfig
+from vllm.compilation.noop_elimination import NoOpEliminationPass
+from vllm.config import CompilationConfig, CompilationLevel, VllmConfig
from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
-    apply_fp8_linear)
+    CUTLASS_FP8_SUPPORTED, apply_fp8_linear, maybe_create_device_identity)

from .backend import TestBackend


class TestModel(torch.nn.Module):

-    def __init__(self, hidden_size: int, eps: float, static: bool, *args,
-                 **kwargs):
+    def __init__(self, hidden_size: int, eps: float, static: bool,
+                 cutlass_fp8_enabled: bool, *args, **kwargs):
        super().__init__(*args, **kwargs)
+        # forwarded to apply_fp8_linear(cutlass_fp8_supported=...) below
+        self.cutlass_fp8_enabled = cutlass_fp8_enabled
        self.norm = [RMSNorm(hidden_size, eps) for _ in range(3)]
        self.wscale = [torch.rand(1, dtype=torch.float32) for _ in range(2)]
        if static:
@@ -41,15 +43,17 @@ def forward(self, x):
                              self.w[0],
                              self.wscale[0],
                              self.scale[0],
-                              use_per_token_if_dynamic=True)
+                              use_per_token_if_dynamic=True,
+                              cutlass_fp8_supported=self.cutlass_fp8_enabled)
        # make sure resid is used for the replacement to work
        y2, resid = self.norm[1](x2, resid)

        x3 = apply_fp8_linear(y2,
                              self.w[1],
                              self.wscale[1],
                              self.scale[1],
-                              use_per_token_if_dynamic=True)
+                              use_per_token_if_dynamic=True,
+                              cutlass_fp8_supported=self.cutlass_fp8_enabled)
        y3, resid = self.norm[2](x3, resid)  # use resid here
        return y3

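For context, the pattern this model exercises, and that the fusion pass is expected to collapse into single fused kernels, is an RMSNorm followed by an FP8 quantization. Below is a minimal plain-torch sketch of the unfused static (per-tensor) reference; the helper name is local to this sketch (not a vLLM API) and the e4m3 FP8 dtype is an assumption:

    import torch

    def rms_norm_then_static_quant(x: torch.Tensor, weight: torch.Tensor,
                                   scale: torch.Tensor, eps: float) -> torch.Tensor:
        # RMSNorm, computed in fp32 for stability and cast back
        var = x.float().pow(2).mean(-1, keepdim=True)
        y = (x.float() * torch.rsqrt(var + eps)).to(x.dtype) * weight
        # static per-tensor quant: divide by a fixed scale, clamp to the fp8 range
        info = torch.finfo(torch.float8_e4m3fn)
        return (y / scale).clamp(info.min, info.max).to(torch.float8_e4m3fn)

The dynamic (per-token) variant instead derives the scale from each row's absolute maximum before quantizing, which is why the test treats static and dynamic as different fused-op keys.
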
@@ -59,60 +63,67 @@ def forward(self, x):
@pytest.mark.parametrize("num_tokens", [7, 256, 533, 2048, 2049])
@pytest.mark.parametrize("eps", [1e-5, 1e-6])
@pytest.mark.parametrize("static", [True, False])
+@pytest.mark.parametrize("cutlass_fp8_enabled",
+                         [True, False] if CUTLASS_FP8_SUPPORTED else [False])
@pytest.mark.skipif(envs.VLLM_TARGET_DEVICE != "cuda",
                    reason="Only test on CUDA")
-def test_fusion_rmsnorm_quant(dtype, hidden_size, num_tokens, eps, static):
+def test_fusion_rmsnorm_quant(dtype, hidden_size, num_tokens, eps, static,
+                              cutlass_fp8_enabled):
    torch.set_default_device("cuda")
    torch.set_default_dtype(dtype)
    torch.manual_seed(1)
+    maybe_create_device_identity()  # needed for certain non-cutlass fp8 paths

-    # Reshape pass is needed for the fusion pass to work
-    config = CompilationConfig.PassConfig(enable_fusion=True,
-                                          enable_reshape=True)
-    reshape_pass = RedundantReshapesPass(config)
-    fusion_pass = FusionPass.instance(config)
-
-    backend = TestBackend(reshape_pass, fusion_pass)
-    model = TestModel(hidden_size, eps, static)
-
-    # First dimension dynamic
-    x = torch.rand(num_tokens, hidden_size)
-    torch._dynamo.mark_dynamic(x, 0)
-
-    result = model(x)
-
-    model2 = torch.compile(model, backend=backend)
-    result2 = model2(x)
-
-    # Higher tol for dynamic, even higher for bfloat16
-    if static:
-        ATOL, RTOL = (1e-3, 1e-3)
-    elif dtype == torch.float16:
-        ATOL, RTOL = (2e-3, 2e-3)
-    else:
-        ATOL, RTOL = (1e-2, 1e-2)
-
-    torch.testing.assert_close(result, result2, atol=ATOL, rtol=RTOL)
-
-    # Check substitution worked
-    pre_nodes = backend.graph_pre_pass.nodes
-    post_nodes = backend.graph_post_pass.nodes
-
-    # static is per-tensor, dynamic is per-token
-    key = QuantKey(dtype=FP8_DTYPE,
-                   static=static,
-                   per_tensor=static,
-                   symmetric=True)
-    rms_quant = FUSED_OPS[FusedRMSQuantKey(key, False)]
-    add_rms_quant = FUSED_OPS[FusedRMSQuantKey(key, True)]
-    fp8_quant = QUANT_OPS[key]
-
-    # In pre-nodes, fp8 quant should be present and fused kernels should not
-    assert find_auto_fn_maybe(pre_nodes, rms_quant) is None
-    assert find_auto_fn_maybe(pre_nodes, add_rms_quant) is None
-    find_auto_fn(pre_nodes, fp8_quant)
-
-    # In post-nodes, fused kernels should be present and fp8 quant should not
-    find_auto_fn(post_nodes, rms_quant)
-    find_auto_fn(post_nodes, add_rms_quant)
-    assert find_auto_fn_maybe(post_nodes, fp8_quant) is None
+    vllm_config = VllmConfig(compilation_config=CompilationConfig(
+        level=CompilationLevel.PIECEWISE, custom_ops=["+rms_norm"]))
+    with vllm.config.set_current_vllm_config(vllm_config):
+        # The no-op elimination pass is needed for the fusion pass to work
+        config = CompilationConfig.PassConfig(enable_fusion=True,
+                                              enable_noop=True)
+        noop_pass = NoOpEliminationPass(config)
+        fusion_pass = FusionPass.instance(config)
+
+        backend = TestBackend(noop_pass, fusion_pass)
+        model = TestModel(hidden_size, eps, static, cutlass_fp8_enabled)
+
+        # First dimension dynamic
+        x = torch.rand(num_tokens, hidden_size)
+        torch._dynamo.mark_dynamic(x, 0)
+
+        result = model(x)
+
+        model2 = torch.compile(model, backend=backend)
+        result2 = model2(x)
+
+        # Higher tol for dynamic, even higher for bfloat16
+        if static:
+            ATOL, RTOL = (1e-3, 1e-3)
+        elif dtype == torch.float16:
+            ATOL, RTOL = (2e-3, 2e-3)
+        else:
+            ATOL, RTOL = (1e-2, 1e-2)
+
+        torch.testing.assert_close(result, result2, atol=ATOL, rtol=RTOL)
+
+        # Check substitution worked
+        pre_nodes = backend.graph_pre_pass.nodes
+        post_nodes = backend.graph_post_pass.nodes
+
+        # static is per-tensor, dynamic is per-token
+        key = QuantKey(dtype=FP8_DTYPE,
+                       static=static,
+                       per_tensor=static,
+                       symmetric=True)
+        rms_quant = FUSED_OPS[FusedRMSQuantKey(key, False)]
+        add_rms_quant = FUSED_OPS[FusedRMSQuantKey(key, True)]
+        fp8_quant = QUANT_OPS[key]
+
+        # In pre-nodes, fp8 quant should be present and fused kernels should not
+        assert find_auto_fn_maybe(pre_nodes, rms_quant) is None
+        assert find_auto_fn_maybe(pre_nodes, add_rms_quant) is None
+        find_auto_fn(pre_nodes, fp8_quant)
+
+        # In post-nodes, fused kernels should be present and fp8 quant should not
+        find_auto_fn(post_nodes, rms_quant)
+        find_auto_fn(post_nodes, add_rms_quant)
+        assert find_auto_fn_maybe(post_nodes, fp8_quant) is None
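
As a usage note, the parametrization grid can be exercised from Python via pytest's entry point; a minimal sketch, assuming the file lives at tests/compile/test_fusion.py (the path is an assumption):

    import pytest

    # Hypothetical path; add -k to narrow the parametrization grid.
    pytest.main(["-q", "tests/compile/test_fusion.py"])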