@@ -647,9 +647,7 @@ static __global__ void flash_attn_stream_k_fixup(
 }
 
 template <int D> // D == head size
-#if !defined(GGML_USE_HIP)
 __launch_bounds__(D, 1)
-#endif // !(defined(GGML_USE_HIP)
 static __global__ void flash_attn_combine_results(
     const float  * __restrict__ VKQ_parts,
     const float2 * __restrict__ VKQ_meta,
@@ -692,10 +690,7 @@ static __global__ void flash_attn_combine_results(
     float VKQ_numerator   = 0.0f;
     float VKQ_denominator = 0.0f;
     for (int l = 0; l < parallel_blocks; ++l) {
-        const float diff = meta[l].x - kqmax;
-        float KQ_max_scale = expf(diff);
-        const uint32_t ftz_mask = 0xFFFFFFFF * (diff > SOFTMAX_FTZ_THRESHOLD);
-        *((uint32_t *) &KQ_max_scale) &= ftz_mask;
+        const float KQ_max_scale = expf(meta[l].x - kqmax);
 
         VKQ_numerator   += KQ_max_scale * VKQ_parts[l*D + tid];
         VKQ_denominator += KQ_max_scale * meta[l].y;
@@ -836,11 +831,10 @@ void launch_fattn(
         CUDA_CHECK(cudaGetLastError());
     }
 
-    int parallel_blocks = 1;
-
     const dim3 block_dim(warp_size, nwarps, 1);
     int max_blocks_per_sm = 1; // Max. number of active blocks limited by occupancy.
     CUDA_CHECK(cudaOccupancyMaxActiveBlocksPerMultiprocessor(&max_blocks_per_sm, fattn_kernel, block_dim.x * block_dim.y * block_dim.z, nbytes_shared));
+    int parallel_blocks = max_blocks_per_sm;
 
     dim3 blocks_num;
     if (stream_k) {
@@ -862,9 +856,6 @@ void launch_fattn(
         GGML_ASSERT(K->ne[1] % KQ_row_granularity == 0);
         const int ntiles_KQ = K->ne[1] / KQ_row_granularity; // Max. number of parallel blocks limited by tensor size.
 
-        // parallel_blocks should be at least large enough to achieve max. occupancy for a single wave:
-        parallel_blocks = std::max((nsm * max_blocks_per_sm) / ntiles_total, 1);
-
         // parallel_blocks must not be larger than what the tensor size allows:
         parallel_blocks = std::min(parallel_blocks, ntiles_KQ);