@@ -647,7 +647,9 @@ static __global__ void flash_attn_stream_k_fixup(
647647}
648648
649649template <int  D> //  D == head size
650+ #if  !defined(GGML_USE_HIP)
650651__launch_bounds__ (D, 1 )
652+ #endif  //  !(defined(GGML_USE_HIP)
651653static  __global__  void  flash_attn_combine_results (
652654        const  float   * __restrict__  VKQ_parts,
653655        const  float2  * __restrict__  VKQ_meta,
@@ -690,7 +692,10 @@ static __global__ void flash_attn_combine_results(
690692    float  VKQ_numerator   = 0 .0f ;
691693    float  VKQ_denominator = 0 .0f ;
692694    for  (int  l = 0 ; l < parallel_blocks; ++l) {
693-         const  float  KQ_max_scale = expf (meta[l].x  - kqmax);
695+         const  float  diff = meta[l].x  - kqmax;
696+         float  KQ_max_scale = expf (diff);
697+         const  uint32_t  ftz_mask = 0xFFFFFFFF  * (diff > SOFTMAX_FTZ_THRESHOLD);
698+         *((uint32_t  *) &KQ_max_scale) &= ftz_mask;
694699
695700        VKQ_numerator   += KQ_max_scale * VKQ_parts[l*D + tid];
696701        VKQ_denominator += KQ_max_scale * meta[l].y ;
@@ -831,10 +836,11 @@ void launch_fattn(
831836        CUDA_CHECK (cudaGetLastError ());
832837    }
833838
839+     int  parallel_blocks = 1 ;
840+ 
834841    const  dim3  block_dim (warp_size, nwarps, 1 );
835842    int  max_blocks_per_sm = 1 ; //  Max. number of active blocks limited by occupancy.
836843    CUDA_CHECK (cudaOccupancyMaxActiveBlocksPerMultiprocessor (&max_blocks_per_sm, fattn_kernel, block_dim.x  * block_dim.y  * block_dim.z , nbytes_shared));
837-     int  parallel_blocks = max_blocks_per_sm;
838844
839845    dim3  blocks_num;
840846    if  (stream_k) {
@@ -856,6 +862,9 @@ void launch_fattn(
856862        GGML_ASSERT (K->ne [1 ] % KQ_row_granularity == 0 );
857863        const  int  ntiles_KQ = K->ne [1 ] / KQ_row_granularity; //  Max. number of parallel blocks limited by tensor size.
858864
865+         //  parallel_blocks should be at least large enough to achieve max. occupancy for a single wave:
866+         parallel_blocks = std::max ((nsm * max_blocks_per_sm) / ntiles_total, 1 );
867+ 
859868        //  parallel_blocks must not be larger than what the tensor size allows:
860869        parallel_blocks = std::min (parallel_blocks, ntiles_KQ);
861870
0 commit comments