Skip to content

Commit 1d40cfd

Browse files
LyrisZhong, djmmoss
authored and committed
[Kernel] SM90 CUTLASS FP8 GEMM: add support for swap AB + kernel tuning (vllm-project#20396)
Signed-off-by: Faqin Zhong <[email protected]> Co-authored-by: Duncan Moss <[email protected]>
1 parent 47a6c89 commit 1d40cfd

File tree

3 files changed

+277
-52
lines changed

3 files changed

+277
-52
lines changed

csrc/quantization/cutlass_w8a8/c3x/scaled_mm_sm90_fp8.cu

Lines changed: 4 additions & 5 deletions
Original file line number | Diff line number | Diff line change
@@ -1,6 +1,5 @@
 #include "scaled_mm_kernels.hpp"
 #include "scaled_mm_sm90_fp8_dispatch.cuh"
-#include "cutlass_extensions/epilogue/scaled_mm_epilogues_c3x.hpp"

 namespace vllm {

@@ -13,11 +12,11 @@ void cutlass_scaled_mm_sm90_fp8(torch::Tensor& out, torch::Tensor const& a,
   if (bias) {
     TORCH_CHECK(bias->dtype() == out.dtype(),
                 "currently bias dtype must match output dtype ", out.dtype());
-    return cutlass_scaled_mm_sm90_fp8_epilogue<c3x::ScaledEpilogueBias>(
-        out, a, b, a_scales, b_scales, *bias);
+    return cutlass_scaled_mm_sm90_fp8_epilogue<true>(out, a, b, a_scales,
+                                                     b_scales, *bias);
   } else {
-    return cutlass_scaled_mm_sm90_fp8_epilogue<c3x::ScaledEpilogue>(
-        out, a, b, a_scales, b_scales);
+    return cutlass_scaled_mm_sm90_fp8_epilogue<false>(out, a, b, a_scales,
+                                                      b_scales);
   }
 }

0 commit comments

Comments
 (0)