@@ -116,11 +116,15 @@ static void ggml_cuda_flash_attn_ext_mma_f16(ggml_backend_cuda_context & ctx, gg
     }
 }
 
-#define FATTN_VEC_CASE(D, type_K, type_V)                                 \
-    if (Q->ne[0] == (D) && K->type == (type_K) && V->type == (type_V)) {  \
-        ggml_cuda_flash_attn_ext_vec_case<D, type_K, type_V>(ctx, dst);   \
-        return;                                                           \
-    }                                                                     \
+#define FATTN_VEC_CASE(D, type_K, type_V)                                                                         \
+    {                                                                                                              \
+        const bool type_K_okay = K->type == (type_K) || (K->type == GGML_TYPE_F32 && (type_K) == GGML_TYPE_F16);  \
+        const bool type_V_okay = V->type == (type_V) || (V->type == GGML_TYPE_F32 && (type_V) == GGML_TYPE_F16);  \
+        if (Q->ne[0] == (D) && type_K_okay && type_V_okay) {                                                      \
+            ggml_cuda_flash_attn_ext_vec_case<D, type_K, type_V>(ctx, dst);                                       \
+            return;                                                                                               \
+        }                                                                                                          \
+    }                                                                                                              \
 
 #define FATTN_VEC_CASES_ALL_D(type_K, type_V) \
     FATTN_VEC_CASE( 64, type_K, type_V)       \
@@ -247,6 +251,7 @@ static best_fattn_kernel ggml_cuda_get_best_fattn_kernel(const int device, const
 #endif // GGML_CUDA_FA_ALL_QUANTS
 
     switch (K->type) {
+        case GGML_TYPE_F32:
         case GGML_TYPE_F16:
             break;
         case GGML_TYPE_Q4_1:
@@ -272,7 +277,7 @@ static best_fattn_kernel ggml_cuda_get_best_fattn_kernel(const int device, const
     // If Turing tensor cores available, use them:
     if (turing_mma_available(cc) && K->ne[1] % FATTN_KQ_STRIDE == 0 && Q->ne[0] != 40) {
         if (can_use_vector_kernel) {
-            if (K->type == GGML_TYPE_F16 && V->type == GGML_TYPE_F16) {
+            if (!ggml_is_quantized(K->type) && !ggml_is_quantized(V->type)) {
                 if (cc >= GGML_CUDA_CC_ADA_LOVELACE && Q->ne[1] == 1 && Q->ne[3] == 1 && !(gqa_ratio > 4 && K->ne[1] >= 8192)) {
                     return BEST_FATTN_KERNEL_VEC;
                 }
@@ -305,7 +310,7 @@ static best_fattn_kernel ggml_cuda_get_best_fattn_kernel(const int device, const
 
     // If there are no tensor cores available, use the generic tile kernel:
     if (can_use_vector_kernel) {
-        if (K->type == GGML_TYPE_F16 && V->type == GGML_TYPE_F16) {
+        if (!ggml_is_quantized(K->type) && !ggml_is_quantized(V->type)) {
             if (Q->ne[1] == 1) {
                 if (!gqa_opt_applies) {
                     return BEST_FATTN_KERNEL_VEC;
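
For reference, the change above makes F32 K/V tensors eligible for the same vector-kernel cases as F16: the new FATTN_VEC_CASE macro accepts an F32 tensor wherever the case type is F16, and the kernel-selection checks now only exclude quantized types. Below is a minimal standalone sketch of that compatibility rule, not part of the diff; the helper name is hypothetical, and it assumes ggml.h is available for ggml_type and the GGML_TYPE_* constants.

    #include "ggml.h"

    // Sketch only: the same rule the new FATTN_VEC_CASE macro encodes inline.
    // An F32 K or V tensor may take the code path compiled for F16; any other
    // combination must match the case type exactly.
    static bool fattn_vec_kv_type_okay(enum ggml_type tensor_type, enum ggml_type case_type) {
        return tensor_type == case_type ||
               (tensor_type == GGML_TYPE_F32 && case_type == GGML_TYPE_F16);
    }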