Commit e5ded5c

Fix for the padding in the non-cutlass-fp8 case

Signed-off-by: luka <[email protected]>
1 parent c92acb9

File tree (4 files changed: +86 −42 lines)

tests/compile/backend.py
tests/compile/test_fusion.py
vllm/compilation/vllm_inductor_pass.py
vllm/model_executor/layers/quantization/utils/w8a8_utils.py

tests/compile/backend.py

Lines changed: 9 additions & 4 deletions
@@ -11,21 +11,26 @@ class TestBackend:
     This class provides a simple Inductor backend that can be used for testing.
     It takes a list of custom passes and runs them after Inductor's passes.
     It also saves the graph before and after the custom passes for inspection.
+
+    Inductor config can be modified directly by editing the inductor_config
+    property. This can be helpful for adding passes like the
+    'pre_grad_custom_pass' and the 'post_grad_custom_pre_pass'.
     """
 
     def __init__(self, *passes: Union[InductorPass, Callable[[fx.Graph],
                                                               None]]):
         self.custom_passes = list(passes)
         from torch._inductor import config
-        self.current_config = config.shallow_copy_dict()
-        self.current_config['force_disable_caches'] = True
-        self.current_config['post_grad_custom_post_pass'] = self.post_pass
+        self.inductor_config = config.shallow_copy_dict()
+        self.inductor_config['force_disable_caches'] = True
+        self.inductor_config['post_grad_custom_post_pass'] = self.post_pass
 
     def __call__(self, graph: fx.GraphModule, example_inputs):
+        self.graph_pre_compile = deepcopy(graph)
         from torch._inductor.compile_fx import compile_fx
         return compile_fx(graph,
                           example_inputs,
-                          config_patches=self.current_config)
+                          config_patches=self.inductor_config)
 
     def post_pass(self, graph: fx.Graph):
         self.graph_pre_pass = deepcopy(graph)

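A hedged usage sketch of the renamed property (not part of the diff; reshape_pass, fusion_pass, model and x are assumed to be set up as in tests/compile/test_fusion.py below):

    import torch
    from torch import fx

    def print_pre_post_grad(graph: fx.Graph) -> None:
        # Runs right before Inductor's post-grad passes; here we just print.
        print(graph)

    backend = TestBackend(reshape_pass, fusion_pass)
    backend.inductor_config['post_grad_custom_pre_pass'] = print_pre_post_grad

    compiled_model = torch.compile(model, backend=backend)
    compiled_model(x)
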
tests/compile/test_fusion.py

Lines changed: 54 additions & 35 deletions
@@ -3,10 +3,11 @@
 from compressed_tensors.quantization import FP8_DTYPE
 
 import vllm.envs as envs
+import vllm.plugins
 from vllm.compilation.fusion import (FusionPass, find_auto_fn,
                                      find_auto_fn_maybe)
 from vllm.compilation.reshapes import RedundantReshapesPass
-from vllm.config import CompilationConfig
+from vllm.config import CompilationConfig, VllmConfig, CompilationLevel
 from vllm.model_executor.layers.layernorm import RMSNorm
 from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
     apply_fp8_linear)
@@ -16,8 +17,10 @@
 
 class TestModel(torch.nn.Module):
 
-    def __init__(self, hidden_size: int, eps: float, *args, **kwargs):
+    def __init__(self, hidden_size: int, eps: float, cutlass_fp8: bool, *args,
+                 **kwargs):
         super().__init__(*args, **kwargs)
+        self.cutlass_fp8 = cutlass_fp8
         self.norm = [RMSNorm(hidden_size, eps) for _ in range(3)]
         self.scale = [torch.rand(1, dtype=torch.float32) for _ in range(4)]
         self.w = [
@@ -29,11 +32,19 @@ def forward(self, x):
         resid = torch.relu(x)
         y = self.norm[0](x)
 
-        x2 = apply_fp8_linear(y, self.w[0], self.scale[0], self.scale[1])
+        x2 = apply_fp8_linear(y,
+                              self.w[0],
+                              self.scale[0],
+                              self.scale[1],
+                              cutlass_fp8_supported=self.cutlass_fp8)
         # make sure resid is used for replacement to work
         y2, resid = self.norm[1](x2, resid)
 
-        x3 = apply_fp8_linear(y2, self.w[1], self.scale[2], self.scale[3])
+        x3 = apply_fp8_linear(y2,
+                              self.w[1],
+                              self.scale[2],
+                              self.scale[3],
+                              cutlass_fp8_supported=self.cutlass_fp8)
         y3, resid = self.norm[2](x3, resid)  # use resid here
         return y3
 
@@ -42,50 +53,58 @@ def forward(self, x):
 @pytest.mark.parametrize("hidden_size", [64, 3392, 4096])
 @pytest.mark.parametrize("num_tokens", [7, 256, 533, 2048, 2049])
 @pytest.mark.parametrize("eps", [1e-5, 1e-6])
+@pytest.mark.parametrize(
+    "cutlass_fp8",
+    [True, False] if envs.VLLM_TARGET_DEVICE == "cuda" else [False])
 @pytest.mark.skipif(envs.VLLM_TARGET_DEVICE != "cuda",
                     reason="Only test on CUDA")
-def test_fusion_rmsnorm_quant(dtype, hidden_size, num_tokens, eps):
+def test_fusion_rmsnorm_quant(dtype, hidden_size, num_tokens, eps,
+                              cutlass_fp8):
     torch.set_default_device("cuda")
-    torch.set_default_dtype(torch.float16)
+    torch.set_default_dtype(dtype)
+    torch.manual_seed(1)
 
     if eps != 1e-5:
         pytest.skip("Only test eps=1e-5 for now")
 
-    # Reshape pass is needed for the fusion pass to work
-    config = CompilationConfig.PassConfig(enable_fusion=True,
-                                          enable_reshape=True)
-    reshape_pass = RedundantReshapesPass(config)
-    fusion_pass = FusionPass.instance(config)
+    vllm_config = VllmConfig(compilation_config=CompilationConfig(
+        level=CompilationLevel.PIECEWISE, custom_ops=["+rms_norm"]))
+    with vllm.plugins.set_current_vllm_config(vllm_config):
+        # Reshape pass is needed for the fusion pass to work
+        config = CompilationConfig.PassConfig(enable_fusion=True,
+                                              enable_reshape=True)
+        reshape_pass = RedundantReshapesPass(config)
+        fusion_pass = FusionPass.instance(config)
 
-    backend = TestBackend(reshape_pass, fusion_pass)
-    model = TestModel(hidden_size, eps)
+        backend = TestBackend(reshape_pass, fusion_pass)
+        model = TestModel(hidden_size, eps, cutlass_fp8)
 
-    # First dimension dynamic
-    x = torch.rand(num_tokens, hidden_size)
-    torch._dynamo.mark_dynamic(x, 0)
+        # First dimension dynamic
+        x = torch.rand(num_tokens, hidden_size)
+        torch._dynamo.mark_dynamic(x, 0)
 
-    result = model(x)
+        result = model(x)
 
-    model2 = torch.compile(model, backend=backend)
-    result2 = model2(x)
+        model2 = torch.compile(model, backend=backend)
+        result2 = model2(x)
 
-    # Check that it gives the same answer
-    torch.testing.assert_close(result, result2, atol=1e-3, rtol=1e-3)
+        # Check that it gives the same answer
+        torch.testing.assert_close(result, result2, atol=1e-3, rtol=1e-3)
 
-    # Check substitution worked
-    pre_nodes = backend.graph_pre_pass.nodes
-    post_nodes = backend.graph_post_pass.nodes
+        # Check substitution worked
+        pre_nodes = backend.graph_pre_pass.nodes
+        post_nodes = backend.graph_post_pass.nodes
 
-    rms_quant = torch.ops._C.rms_norm_static_fp8_quant.default
-    add_rms_quant = torch.ops._C.fused_add_rms_norm_static_fp8_quant.default
-    fp8_quant = torch.ops._C.static_scaled_fp8_quant.default
+        rms_quant = torch.ops._C.rms_norm_static_fp8_quant.default
+        add_rms_quant = torch.ops._C.fused_add_rms_norm_static_fp8_quant.default
+        fp8_quant = torch.ops._C.static_scaled_fp8_quant.default
 
-    # In pre-nodes, fp8 quant should be present and fused kernels should not
-    assert find_auto_fn_maybe(pre_nodes, rms_quant) is None
-    assert find_auto_fn_maybe(pre_nodes, add_rms_quant) is None
-    find_auto_fn(pre_nodes, fp8_quant)
+        # In pre-nodes, fp8 quant should be there and fused kernels should not
+        assert find_auto_fn_maybe(pre_nodes, rms_quant) is None
+        assert find_auto_fn_maybe(pre_nodes, add_rms_quant) is None
+        find_auto_fn(pre_nodes, fp8_quant)
 
-    # In post-nodes, fused kernels should be present and fp8 quant should not
-    find_auto_fn(post_nodes, rms_quant)
-    find_auto_fn(post_nodes, add_rms_quant)
-    assert find_auto_fn_maybe(post_nodes, fp8_quant) is None
+        # In post-nodes, fused kernels should be there and fp8 quant should not
+        find_auto_fn(post_nodes, rms_quant)
+        find_auto_fn(post_nodes, add_rms_quant)
+        assert find_auto_fn_maybe(post_nodes, fp8_quant) is None

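The structural change in the test is that pass construction and the compiled run now happen inside set_current_vllm_config, because apply_fp8_linear (see w8a8_utils.py below) reads the active config to decide whether to pad. A minimal sketch of the pattern, using only names that appear in this diff:

    import vllm.plugins
    from vllm.config import CompilationConfig, CompilationLevel, VllmConfig

    # Any code that calls get_current_vllm_config() must run while this
    # context is active; level=PIECEWISE is what makes apply_fp8_linear
    # skip the num_token_padding=17 path.
    vllm_config = VllmConfig(compilation_config=CompilationConfig(
        level=CompilationLevel.PIECEWISE, custom_ops=["+rms_norm"]))

    with vllm.plugins.set_current_vllm_config(vllm_config):
        ...  # build the passes, then compile and run the model as above
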
vllm/compilation/vllm_inductor_pass.py

Lines changed: 16 additions & 2 deletions
@@ -30,8 +30,8 @@ def __init__(self, config: CompilationConfig.PassConfig):
         self.config = config
         self.pass_name = self.__class__.__name__
 
-    def dump_graph(self, graph: torch.fx.Graph, stage: str):
-        if stage in self.config.dump_graph_stages:
+    def dump_graph(self, graph: torch.fx.Graph, stage: str, always=False):
+        if stage in self.config.dump_graph_stages or always:
             # Make sure filename includes rank in the distributed setting
             parallel = p_is_init() and get_tp_world_size() > 1
             rank = f"-{get_tp_rank()}" if parallel else ""
@@ -51,3 +51,17 @@ def end_and_log(self):
         self._end_time = time.perf_counter_ns()
         duration_ms = float(self._end_time - self._start_time) / 1.0e6
         logger.debug("%s completed in %.1f ms", self.pass_name, duration_ms)
+
+
+class PrinterInductorPass(VllmInductorPass):
+
+    def __init__(self,
+                 name: str,
+                 config: CompilationConfig.PassConfig,
+                 always=False):
+        super().__init__(config)
+        self.name = name
+        self.always = always
+
+    def __call__(self, graph: torch.fx.Graph):
+        self.dump_graph(graph, self.name, always=self.always)

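A hedged usage sketch for the new PrinterInductorPass (not part of the diff): since it is a plain callable on torch.fx.Graph, it can be interleaved with other custom passes in TestBackend, and always=True bypasses the dump_graph_stages check so the intermediate graph is dumped on every compilation. pass_config, reshape_pass and fusion_pass are assumed to be the objects built in tests/compile/test_fusion.py above.

    from vllm.compilation.vllm_inductor_pass import PrinterInductorPass

    # Dump the graph as it looks between the reshape pass and the fusion pass,
    # regardless of which stages are listed in pass_config.dump_graph_stages.
    printer = PrinterInductorPass("after_reshape", pass_config, always=True)
    backend = TestBackend(reshape_pass, printer, fusion_pass)
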
vllm/model_executor/layers/quantization/utils/w8a8_utils.py

Lines changed: 7 additions & 1 deletion
@@ -3,7 +3,9 @@
 import torch
 
 from vllm import _custom_ops as ops
+from vllm.config import CompilationLevel
 from vllm.platforms import current_platform
+from vllm.plugins import get_current_vllm_config
 
 # Input scaling factors are no longer optional in _scaled_mm starting
 # from pytorch 2.5. Allocating a dummy tensor to pass as input_scale
@@ -122,10 +124,14 @@ def apply_fp8_linear(
         # Note: we pad the input because torch._scaled_mm is more performant
         # for matrices with batch dimension > 16.
         # This could change in the future.
+        # We also don't pad when using torch.compile,
+        # as it breaks with dynamic shapes.
+        config = get_current_vllm_config().compilation_config
+        do_pad = config.level < CompilationLevel.PIECEWISE
         qinput, x_scale = ops.scaled_fp8_quant(
             input_2d,
             input_scale,
-            num_token_padding=17,
+            num_token_padding=17 if do_pad else None,
             use_per_token_if_dynamic=use_per_token_if_dynamic)
 
         per_tensor_weights = (weight_scale.numel() == 1)

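To make the new comment concrete, a hedged, standalone illustration of what the do_pad gate changes (not part of the diff; it assumes a CUDA device with fp8 support, and the exact padded size is an assumption based on the "batch dimension > 16" note above):

    import torch
    from vllm import _custom_ops as ops

    x = torch.rand(7, 64, dtype=torch.float16, device="cuda")
    scale = torch.ones(1, dtype=torch.float32, device="cuda")

    # Eager path: rows are padded so torch._scaled_mm sees a batch dim > 16.
    q_pad, _ = ops.scaled_fp8_quant(x, scale, num_token_padding=17)

    # torch.compile (PIECEWISE) path: no padding, so the output keeps the
    # dynamic first dimension that Dynamo traced.
    q_nopad, _ = ops.scaled_fp8_quant(x, scale)

    print(q_pad.shape, q_nopad.shape)  # assumed: (17, 64) vs. (7, 64)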