@@ -8,11 +8,14 @@ static int fattn_tile_get_kq_stride_host(const int D, const int ncols, const int
     if (GGML_CUDA_CC_IS_AMD(cc)) {
         switch (D) {
             case 64:
-                return ncols <= 16 ? 32 : 64;
+                return 64;
             case 128:
-                return ncols <= 16 ? 64 : warp_size;
             case 256:
-                return 64;
+                if (GGML_CUDA_CC_IS_GCN(cc) || GGML_CUDA_CC_IS_CDNA(cc)) {
+                    return ncols <= 16 ? 64 : 32;
+                } else {
+                    return 64;
+                }
             default:
                 GGML_ABORT("fatal error");
                 return -1;
@@ -41,17 +44,26 @@ static int fattn_tile_get_kq_stride_host(const int D, const int ncols, const int
             GGML_ABORT("fatal error");
             return -1;
     }
+    GGML_UNUSED(warp_size);
 }
 
 static constexpr __device__ int fattn_tile_get_kq_stride_device(int D, int ncols, int warp_size) {
 #ifdef GGML_USE_HIP
     switch (D) {
         case 64:
-            return ncols <= 16 ? 32 : 64;
+            return 64;
         case 128:
-            return ncols <= 16 ? 64 : warp_size;
+#if defined(GCN) || defined(CDNA)
+            return ncols <= 16 ? 64 : 32;
+#else
+            return 64;
+#endif // defined(GCN) || defined(CDNA)
         case 256:
+#if defined(GCN) || defined(CDNA)
+            return ncols <= 16 ? 64 : 32;
+#else
             return 64;
+#endif // defined(GCN) || defined(CDNA)
         default:
             return -1;
     }
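The host-side and device-side selectors have to stay in lockstep, since the host picks launch parameters from the same table the kernel is compiled against. For readers who want to eyeball the new rule, here is a minimal host-only sketch of the AMD branch; `kq_stride_sketch` and its `is_gcn_or_cdna` flag are illustrative stand-ins for the `GGML_CUDA_CC_IS_GCN`/`GGML_CUDA_CC_IS_CDNA` checks, not the real GGML helpers, and the non-AMD branch is elided:

```cpp
#include <cassert>

// Hypothetical mirror of the new AMD stride rule above.
static int kq_stride_sketch(int D, int ncols, bool is_gcn_or_cdna) {
    switch (D) {
        case 64:
            return 64;
        case 128: // 128 and 256 now share one rule (note the fallthrough)
        case 256:
            return is_gcn_or_cdna ? (ncols <= 16 ? 64 : 32) : 64;
        default:
            return -1;
    }
}

int main() {
    assert(kq_stride_sketch(128, 32, true)  == 32); // GCN/CDNA, large batch
    assert(kq_stride_sketch(128,  8, true)  == 64); // GCN/CDNA, small batch
    assert(kq_stride_sketch(256, 32, false) == 64); // other AMD targets: always 64
    return 0;
}
```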
@@ -88,9 +100,17 @@ static constexpr __device__ int fattn_tile_get_kq_nbatch_device(int D, int ncols
         case 64:
             return 64;
         case 128:
-            return ncols <= 16 ? 2*warp_size : 128;
+#if defined(GCN) || defined(CDNA)
+            return ncols <= 16 ? 64 : 128;
+#else
+            return 64;
+#endif // defined(GCN) || defined(CDNA)
         case 256:
-            return ncols <= 16 ? 128 : 2*warp_size;
+#if defined(GCN) || defined(CDNA)
+            return ncols <= 16 ? 64 : 128;
+#else
+            return ncols <= 16 ? 64 : 256;
+#endif // defined(GCN) || defined(CDNA)
         default:
             return -1;
     }
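The same GCN/CDNA split shows up in the K/Q batch width. A host-side mirror of the new switch, with `gcn_or_cdna` as an illustrative stand-in for the compile-time `defined(GCN) || defined(CDNA)` check:

```cpp
#include <cassert>

// Hypothetical mirror of the nbatch table above, for eyeballing branch values.
static int kq_nbatch_sketch(int D, int ncols, bool gcn_or_cdna) {
    switch (D) {
        case  64: return 64;
        case 128: return gcn_or_cdna ? (ncols <= 16 ? 64 : 128) : 64;
        case 256: return gcn_or_cdna ? (ncols <= 16 ? 64 : 128)
                                     : (ncols <= 16 ? 64 : 256);
        default:  return -1;
    }
}

int main() {
    assert(kq_nbatch_sketch(128, 32, true)  == 128); // GCN/CDNA, large batch
    assert(kq_nbatch_sketch(256, 32, false) == 256); // other targets
    assert(kq_nbatch_sketch(256,  8, true)  ==  64); // small batch stays at 64
    return 0;
}
```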
@@ -196,14 +216,21 @@ static __global__ void flash_attn_tile(
 
     const float slope = get_alibi_slope(max_bias, head, n_head_log2, m0, m1);
 
+#if defined(GGML_USE_HIP)
+    constexpr int cpy_nb = 16;
+#else
+    constexpr int cpy_nb = 8;
+#endif // defined(GGML_USE_HIP)
+    constexpr int cpy_ne = cpy_nb / 4;
+
     __shared__ float KQ[ncols][kq_stride];
 #ifdef FAST_FP16_AVAILABLE
     __shared__ half2 Q_tmp[ncols][D/2];
-    __shared__ half2 KV_tmp_h2[kq_stride * (kq_nbatch/2 + 1)];      // Padded to avoid memory bank conflicts.
+    __shared__ half2 KV_tmp_h2[kq_stride * (kq_nbatch/2 + cpy_ne)]; // Padded to avoid memory bank conflicts.
     half2 VKQ[ncols/nwarps][D/(2*warp_size)] = {{{0.0f, 0.0f}}};
 #else
     __shared__ float Q_tmp[ncols][D];
-    __shared__ float KV_tmp_f[kq_stride * (kq_nbatch + 1)];      // Padded to avoid memory bank conflicts.
+    __shared__ float KV_tmp_f[kq_stride * (kq_nbatch + cpy_ne)]; // Padded to avoid memory bank conflicts.
     float2 * KV_tmp_f2 = (float2 *) KV_tmp_f;
     float2 VKQ[ncols/nwarps][D/(2*warp_size)] = {{{0.0f, 0.0f}}};
 #endif // FAST_FP16_AVAILABLE
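`cpy_nb` is the width of one register/shared-memory copy in bytes (16 on HIP, 8 elsewhere) and `cpy_ne` is the same width in 4-byte elements. Widening the row padding of the staging buffers from 1 element to `cpy_ne` keeps each row's starting address aligned to the full copy width while still skewing rows across shared-memory banks. A rough footprint check under illustrative parameters (kq_stride = 32, kq_nbatch = 128, HIP path); these numbers are assumptions picked from the tables above, not constants from the source:

```cpp
#include <cstdio>

int main() {
    // Illustrative values: D = 128, ncols > 16 on GCN/CDNA, HIP build.
    constexpr int kq_stride = 32;
    constexpr int kq_nbatch = 128;
    constexpr int cpy_nb    = 16;          // HIP: 16-byte copies
    constexpr int cpy_ne    = cpy_nb / 4;  // = 4 four-byte elements

    // FP16 path: half2 staging buffer, each row padded by cpy_ne half2s.
    constexpr size_t h2_elems = size_t(kq_stride) * (kq_nbatch/2 + cpy_ne);
    printf("KV_tmp_h2: %zu half2 = %zu bytes\n", h2_elems, h2_elems*4); // 2176 half2 = 8704 bytes
    return 0;
}
```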
@@ -256,11 +283,11 @@ static __global__ void flash_attn_tile(
             for (int k_KQ_1 = 0; k_KQ_1 < kq_nbatch/2; k_KQ_1 += warp_size) {
                 const half2 tmp_h2 = K_h2[int64_t(k_VKQ_0 + i_KQ)*stride_KV2 + k_KQ_0/2 + k_KQ_1 + threadIdx.x];
 #ifdef FAST_FP16_AVAILABLE
-                KV_tmp_h2[i_KQ*(kq_nbatch/2 + 1) + k_KQ_1 + threadIdx.x] = tmp_h2;
+                KV_tmp_h2[i_KQ*(kq_nbatch/2 + cpy_ne) + k_KQ_1 + threadIdx.x] = tmp_h2;
 #else
                 const float2 tmp_f2 = __half22float2(tmp_h2);
-                KV_tmp_f[i_KQ*(kq_nbatch + 1) + 2*k_KQ_1 + threadIdx.x]             = tmp_f2.x;
-                KV_tmp_f[i_KQ*(kq_nbatch + 1) + 2*k_KQ_1 + warp_size + threadIdx.x] = tmp_f2.y;
+                KV_tmp_f[i_KQ*(kq_nbatch + cpy_ne) + 2*k_KQ_1 + threadIdx.x]             = tmp_f2.x;
+                KV_tmp_f[i_KQ*(kq_nbatch + cpy_ne) + 2*k_KQ_1 + warp_size + threadIdx.x] = tmp_f2.y;
 #endif // FAST_FP16_AVAILABLE
             }
         }
@@ -269,42 +296,45 @@ static __global__ void flash_attn_tile(
 
 #ifdef FAST_FP16_AVAILABLE
 #pragma unroll
-        for (int k_KQ_1 = 0; k_KQ_1 < kq_nbatch/2; ++k_KQ_1) {
-            half2 K_k[kq_stride/warp_size];
-            half2 Q_k[ncols/nwarps];
+        for (int k_KQ_1 = 0; k_KQ_1 < kq_nbatch/2; k_KQ_1 += cpy_ne) {
+            half2 K_k[kq_stride/warp_size][cpy_ne];
+            half2 Q_k[ncols/nwarps][cpy_ne];
 #else
 #pragma unroll
-        for (int k_KQ_1 = 0; k_KQ_1 < kq_nbatch; ++k_KQ_1) {
-            float K_k[kq_stride/warp_size];
-            float Q_k[ncols/nwarps];
+        for (int k_KQ_1 = 0; k_KQ_1 < kq_nbatch; k_KQ_1 += cpy_ne) {
+            float K_k[kq_stride/warp_size][cpy_ne];
+            float Q_k[ncols/nwarps][cpy_ne];
 #endif // FAST_FP16_AVAILABLE
 
 #pragma unroll
             for (int i_KQ_0 = 0; i_KQ_0 < kq_stride; i_KQ_0 += warp_size) {
                 const int i_KQ = i_KQ_0 + threadIdx.x;
 
 #ifdef FAST_FP16_AVAILABLE
-                K_k[i_KQ_0/warp_size] = KV_tmp_h2[i_KQ*(kq_nbatch/2 + 1) + k_KQ_1];
+                ggml_cuda_memcpy_1<cpy_nb>(&K_k[i_KQ_0/warp_size], &KV_tmp_h2[i_KQ*(kq_nbatch/2 + cpy_ne) + k_KQ_1]);
 #else
-                K_k[i_KQ_0/warp_size] = KV_tmp_f[i_KQ*(kq_nbatch + 1) + k_KQ_1];
+                ggml_cuda_memcpy_1<cpy_nb>(&K_k[i_KQ_0/warp_size], &KV_tmp_f[i_KQ*(kq_nbatch + cpy_ne) + k_KQ_1]);
 #endif // FAST_FP16_AVAILABLE
             }
 #pragma unroll
             for (int j_KQ_0 = 0; j_KQ_0 < ncols; j_KQ_0 += nwarps) {
                 const int j_KQ = j_KQ_0 + threadIdx.y;
 
 #ifdef FAST_FP16_AVAILABLE
-                Q_k[j_KQ_0/nwarps] = Q_tmp[j_KQ][k_KQ_0/2 + k_KQ_1];
+                ggml_cuda_memcpy_1<cpy_nb>(&Q_k[j_KQ_0/nwarps], &Q_tmp[j_KQ][k_KQ_0/2 + k_KQ_1]);
 #else
-                Q_k[j_KQ_0/nwarps] = Q_tmp[j_KQ][k_KQ_0 + k_KQ_1];
+                ggml_cuda_memcpy_1<cpy_nb>(&Q_k[j_KQ_0/nwarps], &Q_tmp[j_KQ][k_KQ_0 + k_KQ_1]);
 #endif // FAST_FP16_AVAILABLE
             }
 
 #pragma unroll
             for (int i_KQ_0 = 0; i_KQ_0 < kq_stride; i_KQ_0 += warp_size) {
 #pragma unroll
                 for (int j_KQ_0 = 0; j_KQ_0 < ncols; j_KQ_0 += nwarps) {
-                    ggml_cuda_mad(sum[i_KQ_0/warp_size][j_KQ_0/nwarps], K_k[i_KQ_0/warp_size], Q_k[j_KQ_0/nwarps]);
+#pragma unroll
+                    for (int k = 0; k < cpy_ne; ++k) {
+                        ggml_cuda_mad(sum[i_KQ_0/warp_size][j_KQ_0/nwarps], K_k[i_KQ_0/warp_size][k], Q_k[j_KQ_0/nwarps][k]);
+                    }
                 }
             }
         }
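The inner product now moves in steps of `cpy_ne`: one wide copy pulls a whole fragment of K and Q into registers, then a short unrolled loop feeds the fragment through the multiply-add. The following self-contained toy kernel shows the same load-wide/accumulate-narrow pattern; `my_memcpy_1` is a simplified stand-in for `ggml_cuda_memcpy_1` (assumed here to lower to a single aligned vector transaction), and `dot_kernel` is not the flash-attention code:

```cpp
#include <cstdio>

// Simplified stand-in for ggml_cuda_memcpy_1: one aligned vector transaction.
template <int nbytes>
__device__ void my_memcpy_1(void * dst, const void * src) {
    if constexpr (nbytes == 16) {
        *(int4 *) dst = *(const int4 *) src;
    } else {
        *(int2 *) dst = *(const int2 *) src;
    }
}

__global__ void dot_kernel(const float * k, const float * q, float * out, int n) {
    constexpr int cpy_nb = 16;         // one 16-byte copy ...
    constexpr int cpy_ne = cpy_nb / 4; // ... carries 4 floats
    float sum = 0.0f;
    for (int i = 0; i < n; i += cpy_ne) {
        float K_k[cpy_ne], Q_k[cpy_ne];
        my_memcpy_1<cpy_nb>(K_k, k + i); // load wide: one transaction each
        my_memcpy_1<cpy_nb>(Q_k, q + i);
#pragma unroll
        for (int e = 0; e < cpy_ne; ++e) { // accumulate narrow: per-element FMA
            sum = fmaf(K_k[e], Q_k[e], sum);
        }
    }
    if (threadIdx.x == 0) {
        out[0] = sum;
    }
}

int main() {
    const int n = 64;
    float *k, *q, *out;
    cudaMallocManaged(&k,   n*sizeof(float));
    cudaMallocManaged(&q,   n*sizeof(float));
    cudaMallocManaged(&out, sizeof(float));
    for (int i = 0; i < n; ++i) { k[i] = 1.0f; q[i] = 2.0f; }
    dot_kernel<<<1, 32>>>(k, q, out, n);
    cudaDeviceSynchronize();
    printf("dot = %.1f\n", out[0]); // 128.0
    return 0;
}
```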
@@ -345,14 +375,54 @@ static __global__ void flash_attn_tile(
         kqmax[j0/nwarps] = kqmax_new[j0/nwarps];
 
         float kqsum_add = 0.0f;
+        if (kq_stride % (4*warp_size) == 0 && cpy_ne % 4 == 0) {
 #pragma unroll
-        for (int i0 = 0; i0 < kq_stride; i0 += warp_size) {
-            const int i = i0 + threadIdx.x;
+            for (int i0 = 0; i0 < kq_stride; i0 += 4*warp_size) {
+                const int i = i0 + 4*threadIdx.x;
 
-            const float diff = KQ[j][i] - kqmax[j0/nwarps];
-            const float val  = expf(diff);
-            kqsum_add += val;
-            KQ[j][i] = val;
+                float4 val = *(const float4 *) &KQ[j][i];
+                val.x = expf(val.x - kqmax[j0/nwarps]);
+                val.y = expf(val.y - kqmax[j0/nwarps]);
+                val.z = expf(val.z - kqmax[j0/nwarps]);
+                val.w = expf(val.w - kqmax[j0/nwarps]);
+                kqsum_add += val.x + val.y + val.z + val.w;
+
+#ifdef FAST_FP16_AVAILABLE
+                const half2 tmp[2] = {make_half2(val.x, val.y), make_half2(val.z, val.w)};
+                ggml_cuda_memcpy_1<sizeof(tmp)>(&KQ[j][i/2], &tmp);
+#else
+                ggml_cuda_memcpy_1<sizeof(val)>(&KQ[j][i], &val);
+#endif // FAST_FP16_AVAILABLE
+            }
+        } else if (kq_stride % (2*warp_size) == 0 && cpy_ne % 2 == 0) {
+#pragma unroll
+            for (int i0 = 0; i0 < kq_stride; i0 += 2*warp_size) {
+                const int i = i0 + 2*threadIdx.x;
+
+                float2 val = *(const float2 *) &KQ[j][i];
+                val.x = expf(val.x - kqmax[j0/nwarps]);
+                val.y = expf(val.y - kqmax[j0/nwarps]);
+                kqsum_add += val.x + val.y;
+#ifdef FAST_FP16_AVAILABLE
+                const half2 tmp = make_half2(val.x, val.y);
+                ggml_cuda_memcpy_1<sizeof(tmp)>(&KQ[j][i/2], &tmp);
+#else
+                ggml_cuda_memcpy_1<sizeof(val)>(&KQ[j][i], &val);
+#endif // FAST_FP16_AVAILABLE
+            }
+        } else {
+            for (int i0 = 0; i0 < kq_stride; i0 += warp_size) {
+                const int i = i0 + threadIdx.x;
+
+                const float diff = KQ[j][i] - kqmax[j0/nwarps];
+                const float val  = expf(diff);
+                kqsum_add += val;
+#ifdef FAST_FP16_AVAILABLE
+                ((half *) KQ[j])[i] = val;
+#else
+                KQ[j][i] = val;
+#endif // FAST_FP16_AVAILABLE
+            }
         }
         kqsum[j0/nwarps] = kqsum[j0/nwarps]*KQ_max_scale + kqsum_add;
358428
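Which of the three branches runs depends on how many KQ elements each lane covers per iteration. A quick host-side check of the dispatch conditions under an illustrative `warp_size` of 32; note that with the CUDA-side `cpy_ne = 2` the float4 branch can never be taken, it needs the HIP value `cpy_ne = 4`:

```cpp
#include <cstdio>
#include <initializer_list>

int main() {
    const int warp_size = 32;
    const int cpy_ne    = 8 / 4; // CUDA path: cpy_nb = 8 -> cpy_ne = 2
    for (int kq_stride : {32, 64, 128, 256}) {
        const char * path = "scalar";
        if (kq_stride % (4*warp_size) == 0 && cpy_ne % 4 == 0) {
            path = "float4";
        } else if (kq_stride % (2*warp_size) == 0 && cpy_ne % 2 == 0) {
            path = "float2";
        }
        printf("kq_stride = %3d -> %s path\n", kq_stride, path);
    }
    return 0;
}
```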
@@ -419,8 +489,7 @@ static __global__ void flash_attn_tile(
             const int j = j0 + threadIdx.y;
 
 #ifdef FAST_FP16_AVAILABLE
-            const float tmp = KQ[j][k0 + k1];
-            KQ_k[j0/nwarps] = make_half2(tmp, tmp);
+            KQ_k[j0/nwarps] = __half2half2(((const half *) KQ[j])[k0 + k1]);
 #else
             KQ_k[j0/nwarps] = KQ[j][k0 + k1];
 #endif // FAST_FP16_AVAILABLE
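Because the softmax step above rewrote each KQ row in place as half values (packed into the first half of the row), the V accumulation now reads the row through a `const half *` and broadcasts each scalar into both half2 lanes with `__half2half2`. A toy kernel isolating that repack-then-broadcast round trip; it is a sketch of the access pattern only, not the flash-attention code:

```cpp
#include <cuda_fp16.h>
#include <cstdio>

__global__ void repack_demo() {
    __shared__ float row[8];
    if (threadIdx.x == 0) {
        for (int i = 0; i < 8; ++i) row[i] = float(i);
        // Repack in place: 8 floats become 8 halfs in the first 16 bytes.
        for (int i = 0; i < 8; i += 2) {
            const half2 tmp = make_half2(row[i], row[i + 1]);
            *(half2 *) &row[i/2] = tmp;
        }
        // Read back one scalar and broadcast it to both half2 lanes.
        const half2 broadcast = __half2half2(((const half *) row)[5]);
        printf("lane x = %f, lane y = %f\n",
               __low2float(broadcast), __high2float(broadcast)); // 5.0, 5.0
    }
}

int main() {
    repack_demo<<<1, 1>>>();
    cudaDeviceSynchronize();
    return 0;
}
```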