@@ -647,7 +647,9 @@ static __global__ void flash_attn_stream_k_fixup(
647647}
648648
649649template <int D> // D == head size
650+ #if !defined(GGML_USE_HIP)
650651__launch_bounds__ (D, 1 )
652+ #endif // !defined(GGML_USE_HIP)
651653static __global__ void flash_attn_combine_results (
652654 const float * __restrict__ VKQ_parts,
653655 const float2 * __restrict__ VKQ_meta,
@@ -690,7 +692,10 @@ static __global__ void flash_attn_combine_results(
690692 float VKQ_numerator = 0 .0f ;
691693 float VKQ_denominator = 0 .0f ;
692694 for (int l = 0 ; l < parallel_blocks; ++l) {
693- const float KQ_max_scale = expf (meta[l].x - kqmax);
695+ const float diff = meta[l].x - kqmax;
696+ float KQ_max_scale = expf (diff);
697+ const uint32_t ftz_mask = 0xFFFFFFFF * (diff > SOFTMAX_FTZ_THRESHOLD);
698+ *((uint32_t *) &KQ_max_scale) &= ftz_mask;
694699
695700 VKQ_numerator += KQ_max_scale * VKQ_parts[l*D + tid];
696701 VKQ_denominator += KQ_max_scale * meta[l].y ;
@@ -831,10 +836,11 @@ void launch_fattn(
831836 CUDA_CHECK (cudaGetLastError ());
832837 }
833838
839+ int parallel_blocks = 1 ;
840+
834841 const dim3 block_dim (warp_size, nwarps, 1 );
835842 int max_blocks_per_sm = 1 ; // Max. number of active blocks limited by occupancy.
836843 CUDA_CHECK (cudaOccupancyMaxActiveBlocksPerMultiprocessor (&max_blocks_per_sm, fattn_kernel, block_dim.x * block_dim.y * block_dim.z , nbytes_shared));
837- int parallel_blocks = max_blocks_per_sm;
838844
839845 dim3 blocks_num;
840846 if (stream_k) {
@@ -856,6 +862,9 @@ void launch_fattn(
856862 GGML_ASSERT (K->ne [1 ] % KQ_row_granularity == 0 );
857863 const int ntiles_KQ = K->ne [1 ] / KQ_row_granularity; // Max. number of parallel blocks limited by tensor size.
858864
865+ // parallel_blocks should be at least large enough to achieve max. occupancy for a single wave:
866+ parallel_blocks = std::max ((nsm * max_blocks_per_sm) / ntiles_total, 1 );
867+
859868 // parallel_blocks must not be larger than what the tensor size allows:
860869 parallel_blocks = std::min (parallel_blocks, ntiles_KQ);
861870
0 commit comments