diff --git a/csrc/quantization/w8a8/cutlass/c3x/scaled_mm_sm100_fp8.cu b/csrc/quantization/w8a8/cutlass/c3x/scaled_mm_sm100_fp8.cu
index cf2cccc913f6..62aeb927ccdc 100644
--- a/csrc/quantization/w8a8/cutlass/c3x/scaled_mm_sm100_fp8.cu
+++ b/csrc/quantization/w8a8/cutlass/c3x/scaled_mm_sm100_fp8.cu
@@ -1,6 +1,5 @@
 #include "scaled_mm_kernels.hpp"
 #include "scaled_mm_sm100_fp8_dispatch.cuh"
-#include "cutlass_extensions/epilogue/scaled_mm_epilogues_c3x.hpp"
 
 namespace vllm {
 
@@ -13,11 +12,11 @@ void cutlass_scaled_mm_sm100_fp8(torch::Tensor& out, torch::Tensor const& a,
   if (bias) {
     TORCH_CHECK(bias->dtype() == out.dtype(),
                 "currently bias dtype must match output dtype ", out.dtype());
-    return cutlass_scaled_mm_sm100_fp8_epilogue<c3x::ScaledEpilogueBias>(
-        out, a, b, a_scales, b_scales, *bias);
+    return cutlass_scaled_mm_sm100_fp8_epilogue<true>(out, a, b, a_scales,
+                                                      b_scales, *bias);
   } else {
-    return cutlass_scaled_mm_sm100_fp8_epilogue<c3x::ScaledEpilogue>(
-        out, a, b, a_scales, b_scales);
+    return cutlass_scaled_mm_sm100_fp8_epilogue<false>(out, a, b, a_scales,
+                                                       b_scales);
   }
 }
 
diff --git a/csrc/quantization/w8a8/cutlass/c3x/scaled_mm_sm100_fp8_dispatch.cuh b/csrc/quantization/w8a8/cutlass/c3x/scaled_mm_sm100_fp8_dispatch.cuh
index f876b7d9acd8..c950008b4139 100644
--- a/csrc/quantization/w8a8/cutlass/c3x/scaled_mm_sm100_fp8_dispatch.cuh
+++ b/csrc/quantization/w8a8/cutlass/c3x/scaled_mm_sm100_fp8_dispatch.cuh
@@ -2,6 +2,7 @@
 
 #include "scaled_mm.cuh"
 #include "cutlass_gemm_caller.cuh"
+#include "cutlass_extensions/epilogue/scaled_mm_epilogues_c3x.hpp"
 
 /**
  * This file defines Gemm kernel configurations for SM100 (fp8) based on the
@@ -12,8 +13,88 @@ namespace vllm {
 
 using c3x::cutlass_gemm_caller;
 
-template <typename InType, typename OutType,
-          template <typename, typename, typename> typename Epilogue>
+template <typename ElementAB_, typename ElementD_,
+          template <typename, typename, typename> typename Epilogue_,
+          typename TileShape, typename ClusterShape, typename KernelSchedule,
+          typename EpilogueSchedule, bool swap_ab_ = false>
+struct cutlass_3x_gemm_sm100_fp8 {
+  using ElementAB = ElementAB_;
+  using ElementC = ElementD_;
+  using ElementD = ElementD_;
+  using ElementAcc =
+      typename std::conditional<std::is_same_v<ElementAB, int8_t>, int32_t,
+                                float>::type;
+
+  using Epilogue = Epilogue_<ElementAcc, ElementD, TileShape>;
+
+  using EVTCompute = typename Epilogue::EVTCompute;
+
+  static constexpr int AlignmentAB =
+      128 / cutlass::sizeof_bits<ElementAB>::value;
+  static constexpr int AlignmentCD =
+      128 / cutlass::sizeof_bits<ElementD>::value;
+
+  // Compile-time swap_ab flag
+  static constexpr bool swap_ab = swap_ab_;
+
+  // -----------------------------------------------------------
+  // Layout definitions
+  // -----------------------------------------------------------
+  using LayoutA = cutlass::layout::RowMajor;
+  using LayoutA_T = typename cutlass::layout::LayoutTranspose<LayoutA>::type;
+
+  using LayoutB = cutlass::layout::ColumnMajor;
+  using LayoutB_T = typename cutlass::layout::LayoutTranspose<LayoutB>::type;
+
+  using LayoutD = cutlass::layout::RowMajor;
+  using LayoutD_Transpose =
+      typename cutlass::layout::LayoutTranspose<LayoutD>::type;
+
+  using LayoutC = LayoutD;
+  using LayoutC_Transpose = LayoutD_Transpose;
+
+  // -----------------------------------------------------------
+  // Collective epilogue (conditionally swap operands and layouts)
+  // -----------------------------------------------------------
+  using CollectiveEpilogue =
+      typename cutlass::epilogue::collective::CollectiveBuilder<
+          cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp, TileShape,
+          ClusterShape, cutlass::epilogue::collective::EpilogueTileAuto,
+          ElementAcc, float, ElementC,
+          conditional_t<swap_ab, LayoutC_Transpose, LayoutC>, AlignmentCD,
+          ElementD, conditional_t<swap_ab, LayoutD_Transpose, LayoutD>,
+          AlignmentCD, EpilogueSchedule, EVTCompute>::CollectiveOp;
+
+  static constexpr size_t CEStorageSize =
+      sizeof(typename CollectiveEpilogue::SharedStorage);
+
+  using Stages = typename cutlass::gemm::collective::StageCountAutoCarveout<
+      static_cast<int>(CEStorageSize)>;
+
+  // -----------------------------------------------------------
+  // Collective mainloop (conditionally swap operands and layouts)
+  // -----------------------------------------------------------
+  using CollectiveMainloop = conditional_t<
+      swap_ab,
+      typename cutlass::gemm::collective::CollectiveBuilder<
+          cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp, ElementAB,
+          LayoutB_T, AlignmentAB,             // Swapped B (as A)
+          ElementAB, LayoutA_T, AlignmentAB,  // Swapped A (as B)
+          ElementAcc, TileShape, ClusterShape, Stages,
+          KernelSchedule>::CollectiveOp,
+      typename cutlass::gemm::collective::CollectiveBuilder<
+          cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp, ElementAB,
+          LayoutA, AlignmentAB, ElementAB, LayoutB, AlignmentAB, ElementAcc,
+          TileShape, ClusterShape, Stages, KernelSchedule>::CollectiveOp>;
+
+  // -----------------------------------------------------------
+  // Kernel definition
+  // -----------------------------------------------------------
+  using GemmKernel = cutlass::gemm::kernel::GemmUniversal<
+      Shape<int, int, int, int>, CollectiveMainloop, CollectiveEpilogue,
+      void>;
+};
+
+template <typename InType, typename OutType, bool EnableBias>
 struct sm100_fp8_config_default {
   // M in (256, inf)
   static_assert(std::is_same<InType, cutlass::float_e4m3_t>());
@@ -22,12 +103,16 @@ struct sm100_fp8_config_default {
   using TileShape = Shape<_256, _128, _128>;
   using ClusterShape = Shape<_2, _2, _1>;
   using Cutlass3xGemm =
-      cutlass_3x_gemm_sm100<InType, OutType, Epilogue, TileShape, ClusterShape,
-                            KernelSchedule, EpilogueSchedule>;
+      conditional_t<EnableBias,
+                    cutlass_3x_gemm_sm100_fp8<
+                        InType, OutType, c3x::ScaledEpilogueBias, TileShape,
+                        ClusterShape, KernelSchedule, EpilogueSchedule>,
+                    cutlass_3x_gemm_sm100_fp8<
+                        InType, OutType, c3x::ScaledEpilogue, TileShape,
+                        ClusterShape, KernelSchedule, EpilogueSchedule>>;
 };
 
-template <typename InType, typename OutType,
-          template <typename, typename, typename> typename Epilogue>
+template <typename InType, typename OutType, bool EnableBias>
 struct sm100_fp8_config_M256 {
   // M in (64, 256]
   static_assert(std::is_same<InType, cutlass::float_e4m3_t>());
@@ -36,44 +121,127 @@ struct sm100_fp8_config_M256 {
   using TileShape = Shape<_128, _128, _128>;
   using ClusterShape = Shape<_2, _1, _1>;
   using Cutlass3xGemm =
-      cutlass_3x_gemm_sm100<InType, OutType, Epilogue, TileShape, ClusterShape,
-                            KernelSchedule, EpilogueSchedule>;
+      conditional_t<EnableBias,
+                    cutlass_3x_gemm_sm100_fp8<
+                        InType, OutType, c3x::ScaledEpilogueBias, TileShape,
+                        ClusterShape, KernelSchedule, EpilogueSchedule>,
+                    cutlass_3x_gemm_sm100_fp8<
+                        InType, OutType, c3x::ScaledEpilogue, TileShape,
+                        ClusterShape, KernelSchedule, EpilogueSchedule>>;
 };
 
-template <typename InType, typename OutType,
-          template <typename, typename, typename> typename Epilogue>
+template <typename InType, typename OutType, bool EnableBias>
+struct sm100_fp8_config_M64_swap_ab {
+  // This config is for M in (16, 64] and K >= 4096
+  static_assert(std::is_same<InType, cutlass::float_e4m3_t>());
+  using KernelSchedule = cutlass::gemm::collective::KernelScheduleAuto;
+  using EpilogueSchedule = cutlass::epilogue::collective::EpilogueScheduleAuto;
+  using TileShape = Shape<_128, _64, _256>;
+  using ClusterShape = Shape<_4, _1, _1>;
+
+  // Use ScaledEpilogueColumnBias instead of ScaledEpilogueBias when doing swap
+  // AB
+  using Cutlass3xGemm = conditional_t<
+      EnableBias,
+      cutlass_3x_gemm_sm100_fp8<InType, OutType, c3x::ScaledEpilogueColumnBias,
+                                TileShape, ClusterShape, KernelSchedule,
+                                EpilogueSchedule, true>,
+      cutlass_3x_gemm_sm100_fp8<InType, OutType, c3x::ScaledEpilogue,
+                                TileShape, ClusterShape, KernelSchedule,
+                                EpilogueSchedule, true>>;
+};
+
+template <typename InType, typename OutType, bool EnableBias>
 struct sm100_fp8_config_M64 {
-  // M in (16, 64]
+  // This config is for M = 64 and K < 4096 (do not enable swap AB in such
+  // case)
   static_assert(std::is_same<InType, cutlass::float_e4m3_t>());
   using KernelSchedule = cutlass::gemm::collective::KernelScheduleAuto;
   using EpilogueSchedule = cutlass::epilogue::collective::EpilogueScheduleAuto;
   using TileShape = Shape<_64, _64, _128>;
   using ClusterShape = Shape<_1, _1, _1>;
+
   using Cutlass3xGemm =
-      cutlass_3x_gemm_sm100<InType, OutType, Epilogue, TileShape, ClusterShape,
-                            KernelSchedule, EpilogueSchedule>;
+      conditional_t<EnableBias,
+                    cutlass_3x_gemm_sm100_fp8<
+                        InType, OutType, c3x::ScaledEpilogueBias, TileShape,
+                        ClusterShape, KernelSchedule, EpilogueSchedule>,
+                    cutlass_3x_gemm_sm100_fp8<
+                        InType, OutType, c3x::ScaledEpilogue, TileShape,
+                        ClusterShape, KernelSchedule, EpilogueSchedule>>;
 };
 
-template <typename InType, typename OutType,
-          template <typename, typename, typename> typename Epilogue>
-struct sm100_fp8_config_M16 {
+template <typename InType, typename OutType, bool EnableBias>
+struct sm100_fp8_config_M16_swap_ab {
   // M in [1, 16]
   static_assert(std::is_same<InType, cutlass::float_e4m3_t>());
   using KernelSchedule = cutlass::gemm::collective::KernelScheduleAuto;
   using EpilogueSchedule = cutlass::epilogue::collective::EpilogueScheduleAuto;
-  using TileShape = Shape<_64, _64, _128>;
-  using ClusterShape = Shape<_1, _4, _1>;
-  using Cutlass3xGemm =
-      cutlass_3x_gemm_sm100<InType, OutType, Epilogue, TileShape, ClusterShape,
-                            KernelSchedule, EpilogueSchedule>;
+  using TileShape = Shape<_128, _32, _128>;
+  using ClusterShape = Shape<_4, _1, _1>;
+
+  // Use ScaledEpilogueColumnBias instead of ScaledEpilogueBias when doing swap
+  // AB
+  using Cutlass3xGemm = conditional_t<
+      EnableBias,
+      cutlass_3x_gemm_sm100_fp8<InType, OutType, c3x::ScaledEpilogueColumnBias,
+                                TileShape, ClusterShape, KernelSchedule,
+                                EpilogueSchedule, true>,
+      cutlass_3x_gemm_sm100_fp8<InType, OutType, c3x::ScaledEpilogue,
+                                TileShape, ClusterShape, KernelSchedule,
+                                EpilogueSchedule, true>>;
 };
 
-template <typename InType, typename OutType,
-          template <typename, typename, typename> typename Epilogue,
+template <typename Gemm, typename... EpilogueArgs>
+void cutlass_gemm_caller_sm100_fp8(torch::Tensor& out, torch::Tensor const& a,
+                                   torch::Tensor const& b,
+                                   EpilogueArgs&&... epilogue_params) {
+  static constexpr bool swap_ab = Gemm::swap_ab;
+  using ElementAB = typename Gemm::ElementAB;
+  using ElementD = typename Gemm::ElementD;
+  using GemmKernel = typename Gemm::GemmKernel;
+
+  using StrideA = typename Gemm::GemmKernel::StrideA;
+  using StrideB = typename Gemm::GemmKernel::StrideB;
+  using StrideC = typename Gemm::GemmKernel::StrideC;
+
+  int32_t m = a.size(0), n = b.size(1), k = a.size(1);
+  auto prob_shape =
+      swap_ab ? cute::make_shape(n, m, k, 1) : cute::make_shape(m, n, k, 1);
+
+  StrideA a_stride =
+      cutlass::make_cute_packed_stride(StrideA{}, cute::make_shape(m, k, 1));
+  StrideB b_stride =
+      cutlass::make_cute_packed_stride(StrideB{}, cute::make_shape(n, k, 1));
+  StrideC c_stride = cutlass::make_cute_packed_stride(
+      StrideC{},
+      swap_ab ? cute::make_shape(n, m, 1) : cute::make_shape(m, n, 1));
+
+  auto a_ptr = static_cast<ElementAB*>(a.data_ptr());
+  auto b_ptr = static_cast<ElementAB*>(b.data_ptr());
+  auto c_ptr = static_cast<ElementD*>(out.data_ptr());
+
+  typename GemmKernel::MainloopArguments mainloop_args =
+      swap_ab ? typename GemmKernel::MainloopArguments{b_ptr, b_stride, a_ptr,
+                                                       a_stride}
              : typename GemmKernel::MainloopArguments{a_ptr, a_stride, b_ptr,
                                                       b_stride};
+
+  typename GemmKernel::EpilogueArguments epilogue_args{
+      Gemm::Epilogue::prepare_args(
+          std::forward<EpilogueArgs>(epilogue_params)...),
+      c_ptr, c_stride, c_ptr, c_stride};
+
+  c3x::cutlass_gemm_caller<GemmKernel>(a.device(), prob_shape, mainloop_args,
+                                       epilogue_args);
+}
+
+template <typename InType, typename OutType, bool EnableBias,
           typename... EpilogueArgs>
 inline void cutlass_gemm_sm100_fp8_dispatch(torch::Tensor& out,
                                             torch::Tensor const& a,
                                             torch::Tensor const& b,
+                                            torch::Tensor const& a_scales,
+                                            torch::Tensor const& b_scales,
                                             EpilogueArgs&&... args) {
   static_assert(std::is_same<InType, cutlass::float_e4m3_t>());
   TORCH_CHECK(a.dtype() == torch::kFloat8_e4m3fn);
 
@@ -81,55 +249,69 @@ inline void cutlass_gemm_sm100_fp8_dispatch(torch::Tensor& out,
   using Cutlass3xGemmDefault =
       typename sm100_fp8_config_default<InType, OutType,
-                                        Epilogue>::Cutlass3xGemm;
-  using Cutlass3xGemmM16 =
-      typename sm100_fp8_config_M16<InType, OutType, Epilogue>::Cutlass3xGemm;
+                                        EnableBias>::Cutlass3xGemm;
+  using Cutlass3xGemmM16SwapAB =
+      typename sm100_fp8_config_M16_swap_ab<InType, OutType,
+                                            EnableBias>::Cutlass3xGemm;
+  using Cutlass3xGemmM64SwapAB =
+      typename sm100_fp8_config_M64_swap_ab<InType, OutType,
+                                            EnableBias>::Cutlass3xGemm;
   using Cutlass3xGemmM64 =
-      typename sm100_fp8_config_M64<InType, OutType, Epilogue>::Cutlass3xGemm;
+      typename sm100_fp8_config_M64<InType, OutType,
+                                    EnableBias>::Cutlass3xGemm;
+
   using Cutlass3xGemmM256 =
-      typename sm100_fp8_config_M256<InType, OutType, Epilogue>::Cutlass3xGemm;
+      typename sm100_fp8_config_M256<InType, OutType,
+                                     EnableBias>::Cutlass3xGemm;
 
   uint32_t const m = a.size(0);
-  uint32_t const mp2 =
-      std::max(static_cast<uint32_t>(16), next_pow_2(m));  // next power of 2
+  uint32_t const k = a.size(1);
 
-  if (mp2 <= 16) {
+  if (m <= 16) {
     // m in [1, 16]
-    return cutlass_gemm_caller<Cutlass3xGemmM16>(
-        out, a, b, std::forward<EpilogueArgs>(args)...);
-  } else if (mp2 <= 64) {
+    return cutlass_gemm_caller_sm100_fp8<Cutlass3xGemmM16SwapAB>(
+        out, a, b, b_scales, a_scales, std::forward<EpilogueArgs>(args)...);
+  } else if (m <= 64) {
     // m in (16, 64]
-    return cutlass_gemm_caller<Cutlass3xGemmM64>(
-        out, a, b, std::forward<EpilogueArgs>(args)...);
-  } else if (mp2 <= 256) {
+    if (m == 64 && k < 4096) {
+      // do not enable swap AB
+      return cutlass_gemm_caller_sm100_fp8<Cutlass3xGemmM64>(
+          out, a, b, a_scales, b_scales, std::forward<EpilogueArgs>(args)...);
+    }
+    return cutlass_gemm_caller_sm100_fp8<Cutlass3xGemmM64SwapAB>(
+        out, a, b, b_scales, a_scales, std::forward<EpilogueArgs>(args)...);
+
+  } else if (m <= 256) {
     // m in (64, 256]
-    return cutlass_gemm_caller<Cutlass3xGemmM256>(
-        out, a, b, std::forward<EpilogueArgs>(args)...);
+    return cutlass_gemm_caller_sm100_fp8<Cutlass3xGemmM256>(
+        out, a, b, a_scales, b_scales, std::forward<EpilogueArgs>(args)...);
   } else {
     // m in (256, inf)
-    return cutlass_gemm_caller<Cutlass3xGemmDefault>(
-        out, a, b, std::forward<EpilogueArgs>(args)...);
+    return cutlass_gemm_caller_sm100_fp8<Cutlass3xGemmDefault>(
+        out, a, b, a_scales, b_scales, std::forward<EpilogueArgs>(args)...);
   }
 }
 
-template
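
Reviewer note: below is a minimal host-side sketch of the dispatch heuristic this patch introduces, for reference while reading the hunks above. The m/k thresholds and the operand/scale swap mirror the diff; the struct and function names in the sketch are hypothetical, for illustration only, and are not part of the patch.

#include <cstdint>

// Hypothetical mirror of the kernel selection in
// cutlass_gemm_sm100_fp8_dispatch: pick a tile-config bucket and decide
// whether the A/B operands are swapped. When swap_ab is true, the dispatch
// above also passes b_scales in the a_scales position (and vice versa).
struct KernelChoice {
  const char* config;  // which sm100_fp8_config_* bucket is selected
  bool swap_ab;        // true -> operands and per-tensor scales are swapped
};

inline KernelChoice select_sm100_fp8_kernel(int64_t m, int64_t k) {
  if (m <= 16) return {"M16_swap_ab", true};  // m in [1, 16]
  if (m <= 64) {                              // m in (16, 64]
    // At m == 64 with small K the swap is not applied (see the M64 config).
    if (m == 64 && k < 4096) return {"M64", false};
    return {"M64_swap_ab", true};
  }
  if (m <= 256) return {"M256", false};  // m in (64, 256]
  return {"default", false};             // m in (256, inf)
}

With swap AB the kernel effectively computes C^T = B^T * A^T, so the output is produced transposed. A per-row bias broadcast therefore becomes a per-column broadcast, which is why the swap-AB configs select c3x::ScaledEpilogueColumnBias instead of c3x::ScaledEpilogueBias, and why the swap-AB call sites pass b_scales where a_scales normally goes.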