Skip to content

Commit 2a602b0

Browse files
authored
forward fix PR 14245, restore build on ROCm 6.2 (#14709)
Signed-off-by: Jeff Daily <[email protected]>
1 parent 7888e1d commit 2a602b0

File tree

1 file changed

+12
-0
lines changed

1 file changed

+12
-0
lines changed

csrc/quantization/fp8/amd/quant_utils.cuh

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,12 +19,24 @@ __device__ __forceinline__ fp8_type cvt_c10(float const r) {
1919
return {};
2020
}
2121

22+
// __hip_fp8_e4m3 only exists starting in ROCm 6.3. The macro
23+
// HIP_FP8_TYPE_OCP comes from the hip_fp8.h header and also makes
24+
// its first appearance in ROCm 6.3. Since VLLM_DISPATCH_FP8_TYPES
25+
// on ROCm instantiates both OCP and FNUZ kernels, we need to replace
26+
// the new HW cvt with something reasonable that doesn't rely on the
27+
// ROCm 6.3 feature. This allows compiling on ROCm 6.2 or newer.
2228
template <>
2329
__device__ __forceinline__ c10::Float8_e4m3fn cvt_c10(float const r) {
30+
#if HIP_FP8_TYPE_OCP
2431
return c10::Float8_e4m3fn(
2532
__hip_cvt_float_to_fp8(r, __hip_fp8_e4m3::__default_saturation,
2633
__hip_fp8_e4m3::__default_interpret),
2734
c10::Float8_e4m3fn::from_bits());
35+
#else
36+
// Cast implemented by pytorch. Uses bit manipulation instead of HW cvt.
37+
// HW cvt above is faster when it is available (ROCm 6.3 or newer).
38+
return static_cast<c10::Float8_e4m3fn>(r);
39+
#endif
2840
}
2941

3042
template <>

0 commit comments

Comments
 (0)