Merged

Changes from 1 commit

Commits (50)
c9ca102  Move apply_w8a8_block_fp8_linear to an op class  (ElizaWszola, Sep 11, 2025)
eef4349  Remove TODO, bring back old one  (ElizaWszola, Sep 11, 2025)
dd53183  CUDA graphs fix  (ElizaWszola, Sep 11, 2025)
bb24881  Clean up  (ElizaWszola, Sep 11, 2025)
1ba47cd  Create linear op objects conditionally, move some arch checks to bloc…  (ElizaWszola, Sep 11, 2025)
02793b9  format  (ElizaWszola, Sep 11, 2025)
b72c9f2  clean up repetitive code  (ElizaWszola, Sep 12, 2025)
d51f35c  More aggressive dispatch of blockscale ops  (ElizaWszola, Sep 12, 2025)
a6ae689  fix  (ElizaWszola, Sep 12, 2025)
3238ff6  Deep_gemm fix  (ElizaWszola, Sep 12, 2025)
f9c79aa  Merge branch 'main' into move-apply_w8a8_block_fp8_linear-to-class  (ElizaWszola, Sep 12, 2025)
23341c2  Merge branch 'main' into move-apply_w8a8_block_fp8_linear-to-class  (ElizaWszola, Sep 12, 2025)
9b09b60  Post-merge fixes, better dispatch  (ElizaWszola, Sep 12, 2025)
e6b0028  small fixes  (ElizaWszola, Sep 12, 2025)
9b5c552  Merge branch 'main' into move-apply_w8a8_block_fp8_linear-to-class  (ElizaWszola, Sep 15, 2025)
ef6f1e2  Fix cutlass compilation issue on Hopper  (ElizaWszola, Sep 17, 2025)
77335de  Cleanup bad transpose  (ElizaWszola, Sep 17, 2025)
5eaf155  Merge branch 'main' into move-apply_w8a8_block_fp8_linear-to-class  (ElizaWszola, Sep 17, 2025)
e036dac  Wrap w8a8_block_fp8_matmul  (ElizaWszola, Sep 17, 2025)
233e874  Rename padded_cutlass to padded_cutlass_scaled_mm, add todo  (ElizaWszola, Sep 17, 2025)
1edfedc  Cleanup dispatch_w8a8_blockscale_func  (ElizaWszola, Sep 17, 2025)
35a0236  Merge branch 'main' into move-apply_w8a8_block_fp8_linear-to-class  (ElizaWszola, Sep 18, 2025)
0ac3a1e  Deep gemm warmup fix  (ElizaWszola, Sep 18, 2025)
9a48100  Fix deep gemm support function  (ElizaWszola, Sep 18, 2025)
b6a8fb8  Feedback  (ElizaWszola, Sep 19, 2025)
e89ecd8  Pre-commit fixes  (ElizaWszola, Sep 19, 2025)
00cb05c  Pre-commit fixes 2  (ElizaWszola, Sep 19, 2025)
66c89e6  Feedback  (ElizaWszola, Sep 19, 2025)
d9b4121  fix type issue  (ElizaWszola, Sep 19, 2025)
1bc81a1  Add use_ue8m0 support to _quantize_group_native  (ElizaWszola, Sep 19, 2025)
ec73268  Fix padding compilation issue  (ElizaWszola, Sep 22, 2025)
d19bf4b  Feedback  (ElizaWszola, Sep 22, 2025)
1f895e9  Update vllm/model_executor/layers/quantization/utils/fp8_utils.py  (ElizaWszola, Sep 22, 2025)
be3ac58  Link bad group shape issue  (ElizaWszola, Sep 22, 2025)
3772f2f  format  (ElizaWszola, Sep 22, 2025)
8b6cbe4  Merge branch 'main' into move-apply_w8a8_block_fp8_linear-to-class  (ElizaWszola, Sep 22, 2025)
2a87a3b  fix quant config condition  (ElizaWszola, Sep 22, 2025)
012eaff  Merge branch 'main' into move-apply_w8a8_block_fp8_linear-to-class  (mgoin, Sep 22, 2025)
e7f6ec9  fix quant issue (TODO test)  (ProExpertProg, Sep 22, 2025)
10829d3  fix custom op test  (ProExpertProg, Sep 22, 2025)
15cf30e  Merge branch 'main' into move-apply_w8a8_block_fp8_linear-to-class  (ElizaWszola, Sep 23, 2025)
ebdcb10  CUDA condition for compressed tensors and H100  (ElizaWszola, Sep 23, 2025)
2e3d206  Fix quantfp8 test  (ElizaWszola, Sep 23, 2025)
bd32cb9  Test scales_col vs. scales_native  (ElizaWszola, Sep 23, 2025)
efa4446  Merge branch 'main' into move-apply_w8a8_block_fp8_linear-to-class  (ElizaWszola, Sep 23, 2025)
1f00804  Add compressed tensors model test  (ElizaWszola, Sep 23, 2025)
e895df6  Extra asserts, don't use enabled()  (ElizaWszola, Sep 23, 2025)
9806cf8  CUDA path for quant  (ProExpertProg, Sep 23, 2025)
2ae1ef9  Merge branch 'main' into move-apply_w8a8_block_fp8_linear-to-class  (ProExpertProg, Sep 23, 2025)
00bd638  Merge branch 'main' into move-apply_w8a8_block_fp8_linear-to-class  (ProExpertProg, Sep 23, 2025)
13 changes: 10 additions & 3 deletions vllm/model_executor/layers/quantization/fp8.py
@@ -49,7 +49,9 @@
from vllm.platforms import current_platform
from vllm.scalar_type import scalar_types
from vllm.utils import has_deep_gemm
from vllm.utils.deep_gemm import is_deep_gemm_e8m0_used, is_deep_gemm_supported
from vllm.utils.deep_gemm import (is_deep_gemm_e8m0_used,
is_deep_gemm_supported,
should_use_deepgemm_for_fp8_linear)
from vllm.utils.flashinfer import has_flashinfer_moe

if TYPE_CHECKING:
@@ -251,8 +253,10 @@ def __init__(self, quant_config: Fp8Config):
act_quant_group_shape=self.act_q_group_shape)

self.w8a8_block_fp8_linear = W8A8BlockFp8LinearOp(
self.cutlass_block_fp8_supported,
self.use_aiter_and_is_supported,
cutlass_block_fp8_supported=self.cutlass_block_fp8_supported,
use_aiter_and_is_supported=self.use_aiter_and_is_supported,
ue8m0_deepgemm_supported=is_deep_gemm_e8m0_used(),
is_blackwell=current_platform.has_device_capability(100),
)

def create_weights(
@@ -365,6 +369,9 @@ def create_weights(
else:
layer.register_parameter("input_scale", None)

self.w8a8_block_fp8_linear.set_should_use_deepgemm(
should_use_deepgemm_for_fp8_linear(self.out_dtype, weight))

def _maybe_pad_weight(self, weight: torch.Tensor) -> torch.Tensor:
# Pad the weight tensor. This is an optimization on ROCm platform, which
# can benefit from tensors located far enough from one another in memory
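To make the new configuration flow concrete, here is a minimal, self-contained sketch of the pattern this file now follows. The class and flag values below are illustrative stand-ins, not vLLM's real signatures: capability flags are fixed as keyword arguments when the op object is built in __init__, while the DeepGEMM decision depends on the actual weights and is therefore injected later from create_weights via a setter.

class BlockFp8LinearOpSketch:
    def __init__(self, *, cutlass_block_fp8_supported: bool,
                 use_aiter_and_is_supported: bool,
                 ue8m0_deepgemm_supported: bool,
                 is_blackwell: bool):
        # Capability flags are known at construction time.
        self.cutlass_block_fp8_supported = cutlass_block_fp8_supported
        self.use_aiter_and_is_supported = use_aiter_and_is_supported
        self.ue8m0_deepgemm_supported = ue8m0_deepgemm_supported
        self.is_blackwell = is_blackwell
        # The weight-dependent decision is unknown until weights exist.
        self.should_use_deepgemm = False

    def set_should_use_deepgemm(self, flag: bool) -> None:
        self.should_use_deepgemm = flag


# Construction mirrors Fp8LinearMethod.__init__ above (flag values assumed):
op = BlockFp8LinearOpSketch(cutlass_block_fp8_supported=True,
                            use_aiter_and_is_supported=False,
                            ue8m0_deepgemm_supported=False,
                            is_blackwell=False)
# Later, once the layer's weights are created (mirrors create_weights above):
op.set_should_use_deepgemm(True)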
23 changes: 18 additions & 5 deletions vllm/model_executor/layers/quantization/utils/fp8_utils.py
@@ -112,7 +112,7 @@
# TODO fix ROCm->Triton custom path:
# https://github.com/vllm-project/vllm/issues/14397
class W8A8BlockFp8LinearOp:
"""

Check failure on line 115 (GitHub Actions / pre-commit, Ruff E501): vllm/model_executor/layers/quantization/utils/fp8_utils.py:115:81: Line too long (81 > 80)
This class executes a Blocked FP8 linear layer using cutlass if supported and
torch.scaled_mm otherwise.
"""
@@ -121,9 +121,20 @@
self,
cutlass_block_fp8_supported: bool = CUTLASS_BLOCK_FP8_SUPPORTED,
use_aiter_and_is_supported: bool = False,
ue8m0_deepgemm_supported: bool = False,
is_blackwell: bool = False,
):
self.cutlass_block_fp8_supported = cutlass_block_fp8_supported
self.use_aiter_and_is_supported = use_aiter_and_is_supported
self.ue8m0_deepgemm_supported = ue8m0_deepgemm_supported
self.is_blackwell = is_blackwell
self.should_use_deepgemm = False

def set_should_use_deepgemm(
self,
should_use_deepgemm: bool,
):
self.should_use_deepgemm = should_use_deepgemm

def apply(
self,
@@ -140,7 +151,7 @@
output_shape = [*input.shape[:-1], weight.shape[0]]
output_dtype = input.dtype

if should_use_deepgemm_for_fp8_linear(output_dtype, weight):
if self.should_use_deepgemm:

input_2d = input.view(-1, input.shape[-1])
output_shape = [*input.shape[:-1], weight.shape[0]]
@@ -149,6 +160,7 @@
input_2d,
block_size[1],
column_major_scales=True,
use_ue8m0=self.ue8m0_deepgemm_supported,
)

# ensure DeepGEMM-backed custom op is registered before use
@@ -166,12 +178,11 @@
return output.to(dtype=output_dtype).view(*output_shape)

if current_platform.is_cuda():
if current_platform.has_device_capability(100):

if self.is_blackwell:
use_cutlass = self.cutlass_block_fp8_supported and (
cdiv(weight.shape[0], 128) == weight_scale.shape[0]
and cdiv(weight.shape[1], 128) == weight_scale.shape[1])
else:

Check failure on line 185 (GitHub Actions / pre-commit, Ruff E501): vllm/model_executor/layers/quantization/utils/fp8_utils.py:185:81: Line too long (83 > 80)
# TODO: update this after switching to public sm90 block scale gemm
# as it also supports weight.shape % 128 != 0
use_cutlass = self.cutlass_block_fp8_supported and (
@@ -183,7 +194,8 @@
use_cutlass, self.use_aiter_and_is_supported)
if use_cutlass:
q_input, x_scale = per_token_group_quant_fp8(
input_2d, block_size[1], column_major_scales=use_cutlass)
input_2d, block_size[1], column_major_scales=use_cutlass,
use_ue8m0=self.ue8m0_deepgemm_supported)
output = w8a8_blockscale_func(q_input, weight, x_scale, weight_scale,
block_size, input.dtype)
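As a side note on the use_cutlass condition in the Blackwell branch above: the cdiv comparison checks that the weight_scale tensor carries exactly one scale per 128x128 weight block. A rough standalone illustration follows; the helper and the shapes are assumed for the example, not taken from the PR.

def cdiv(a: int, b: int) -> int:
    # Ceiling division, a stand-in for the cdiv helper used above.
    return -(a // -b)

def scales_match_block_layout(weight_shape, scale_shape, block: int = 128) -> bool:
    n, k = weight_shape
    return (cdiv(n, block), cdiv(k, block)) == tuple(scale_shape)

print(scales_match_block_layout((7168, 2048), (56, 16)))  # True: 7168/128 = 56, 2048/128 = 16
print(scales_match_block_layout((7168, 2000), (56, 16)))  # True: ceil(2000/128) = 16
print(scales_match_block_layout((7168, 2048), (56, 15)))  # False: scale grid does not match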

@@ -193,7 +205,8 @@
input_2d.contiguous(), quant_dtype=rocm_aiter.dtypes.fp8)
else:
q_input, x_scale = per_token_group_quant_fp8(
input_2d, block_size[1], column_major_scales=use_cutlass)
input_2d, block_size[1], column_major_scales=use_cutlass,
use_ue8m0=self.ue8m0_deepgemm_supported)

output = w8a8_blockscale_func(q_input, weight, x_scale, weight_scale,
block_size, input.dtype)
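Stepping back, the apply path in this file now picks a backend in a fixed priority order. The sketch below is a simplified pure-Python paraphrase of that dispatch; the backend labels and the helper name are illustrative, while the real code selects kernel functions via dispatch_w8a8_blockscale_func and quantizes activations accordingly.

def pick_w8a8_block_fp8_backend(*, should_use_deepgemm: bool,
                                use_cutlass: bool,
                                use_aiter: bool) -> str:
    if should_use_deepgemm:
        # DeepGEMM path: per-token-group activation quant with
        # column-major scales (ue8m0 scales when supported).
        return "deepgemm"
    if use_cutlass:
        # Cutlass block-scaled GEMM, also with column-major activation scales.
        return "cutlass"
    if use_aiter:
        # ROCm aiter path quantizes via its own fp8 helpers.
        return "aiter"
    # Default Triton/native w8a8 block fp8 matmul.
    return "triton"

assert pick_w8a8_block_fp8_backend(should_use_deepgemm=True,
                                   use_cutlass=True,
                                   use_aiter=False) == "deepgemm"
assert pick_w8a8_block_fp8_backend(should_use_deepgemm=False,
                                   use_cutlass=False,
                                   use_aiter=False) == "triton"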