File tree Expand file tree Collapse file tree 1 file changed +12
-0
lines changed
csrc/quantization/fp8/amd Expand file tree Collapse file tree 1 file changed +12
-0
lines changed Original file line number Diff line number Diff line change @@ -19,12 +19,24 @@ __device__ __forceinline__ fp8_type cvt_c10(float const r) {
1919 return {};
2020}
2121
22+ // __hip_fp8_e4m3 only exists starting in ROCm 6.3. The macro
23+ // HIP_FP8_TYPE_OCP comes from the hip_fp8.h header and also makes
24+ // its first appearance in ROCm 6.3. Since VLLM_DISPATCH_FP8_TYPES
25+ // on ROCm instantiates both OCP and FNUZ kernels, we need to replace
26+ // the new HW cvt with something reasonable that doesn't rely on the
27+ // ROCm 6.3 feature. This allows compiling on ROCm 6.2 or newer.
2228template <>
2329__device__ __forceinline__ c10::Float8_e4m3fn cvt_c10 (float const r) {
30+ #if HIP_FP8_TYPE_OCP
2431 return c10::Float8_e4m3fn (
2532 __hip_cvt_float_to_fp8 (r, __hip_fp8_e4m3::__default_saturation,
2633 __hip_fp8_e4m3::__default_interpret),
2734 c10::Float8_e4m3fn::from_bits ());
35+ #else
36+ // Cast implemented by pytorch. Uses bit manipulation instead of HW cvt.
37+ // HW cvt above is faster when it is available (ROCm 6.3 or newer).
38+ return static_cast <c10::Float8_e4m3fn>(r);
39+ #endif
2840}
2941
3042template <>
You can’t perform that action at this time.
0 commit comments