Skip to content

Commit 73b7ce4

Browse files
committed
Revert "HIP: use v_dot2_f32_f16 instruction for FA (ggml-org#15884)"
This reverts commit 17bc5a8.
1 parent 815c8ef commit 73b7ce4

File tree

2 files changed

+6
-26
lines changed

2 files changed

+6
-26
lines changed

ggml/src/ggml-cuda/common.cuh

Lines changed: 0 additions & 25 deletions
Original file line number | Diff line number | Diff line change
@@ -550,31 +550,6 @@ static __device__ __forceinline__ int ggml_cuda_dp4a(const int a, const int b, i
550550
#endif // defined(GGML_USE_HIP)
551551
}
552552

553-
static __device__ __forceinline__ void ggml_cuda_mad(float & acc, const float v, const float u) {
554-
acc += v*u;
555-
}
556-
557-
static __device__ __forceinline__ void ggml_cuda_mad(float & acc, const float2 v, const float2 u) {
558-
acc += v.x*u.x;
559-
acc += v.y*u.y;
560-
}
561-
562-
static __device__ __forceinline__ void ggml_cuda_mad(float & acc, const half2 v, const half2 u) {
563-
#if defined(GGML_USE_HIP) && defined(GCN)
564-
asm volatile("v_dot2_f32_f16 %0, %1, %2, %0" : "+v"(acc) : "v"(v), "v"(u));
565-
#else
566-
#ifdef FAST_FP16_AVAILABLE
567-
const float2 tmp = __half22float2(v*u);
568-
acc += tmp.x + tmp.y;
569-
#else
570-
const float2 tmpv = __half22float2(v);
571-
const float2 tmpu = __half22float2(u);
572-
acc += tmpv.x * tmpu.x;
573-
acc += tmpv.y * tmpu.y;
574-
#endif // FAST_FP16_AVAILABLE
575-
#endif // defined(GGML_USE_HIP) && defined(GCN)
576-
}
577-
578553
static __device__ __forceinline__ float ggml_cuda_e8m0_to_fp32(uint8_t x) {
579554
#if CUDART_VERSION >= 12080
580555
const nv_bfloat16 e = __nv_cvt_e8m0_to_bf16raw(x);

ggml/src/ggml-cuda/fattn-tile.cu

Lines changed: 6 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -304,7 +304,12 @@ static __global__ void flash_attn_tile(
304304
for (int i_KQ_0 = 0; i_KQ_0 < kq_stride; i_KQ_0 += warp_size) {
305305
#pragma unroll
306306
for (int j_KQ_0 = 0; j_KQ_0 < ncols; j_KQ_0 += nwarps) {
307-
ggml_cuda_mad(sum[i_KQ_0/warp_size][j_KQ_0/nwarps], K_k[i_KQ_0/warp_size], Q_k[j_KQ_0/nwarps]);
307+
#ifdef FAST_FP16_AVAILABLE
308+
const float2 tmp = __half22float2(K_k[i_KQ_0/warp_size] * Q_k[j_KQ_0/nwarps]);
309+
sum[i_KQ_0/warp_size][j_KQ_0/nwarps] += tmp.x + tmp.y;
310+
#else
311+
sum[i_KQ_0/warp_size][j_KQ_0/nwarps] += K_k[i_KQ_0/warp_size] * Q_k[j_KQ_0/nwarps];
312+
#endif // FAST_FP16_AVAILABLE
308313
}
309314
}
310315
}

0 commit comments

Comments (0)