     compute_slot_mapping_start_idx, get_num_prefill_decode_query_kv_tokens,
     get_seq_len_block_table_args, is_all_cross_attn_metadata_set,
     is_all_encoder_attn_metadata_set, is_block_tables_empty)
+from vllm.envs import VLLM_FLASH_ATTN_VERSION
 from vllm.multimodal import MultiModalPlaceholderMap
+from vllm.platforms import current_platform
 from vllm.utils import async_tensor_h2d, make_tensor_with_pad
 
 if TYPE_CHECKING:
     from vllm.worker.model_runner import (ModelInputForGPUBuilder,
                                           ModelInputForGPUWithSamplingMetadata)
 
 from vllm.vllm_flash_attn import (flash_attn_varlen_func,
-                                  flash_attn_with_kvcache)
+                                  flash_attn_with_kvcache,
+                                  is_fa_version_supported)
 
 
 class FlashAttentionBackend(AttentionBackend):
@@ -634,6 +637,20 @@ def __init__(
                 f"Supported head sizes are: {support_head_sizes}.")
         self.attn_type = attn_type
 
+        # if hopper default to FA3, otherwise stick to FA2 for now
+        # TODO(lucas): profile FA3 on ampere to see if it makes sense to
+        # use FA3 as default for both
+        if current_platform.get_device_capability()[0] >= 9:
+            self.fa_version = 3 if is_fa_version_supported(3) else 2
+        else:
+            self.fa_version = 2
+
+        if VLLM_FLASH_ATTN_VERSION is not None:
+            assert VLLM_FLASH_ATTN_VERSION in [2, 3]
+            self.fa_version = VLLM_FLASH_ATTN_VERSION
+
+        assert is_fa_version_supported(self.fa_version)
+
     def forward(
         self,
         layer: AttentionLayer,
@@ -752,6 +769,7 @@ def forward(
                     alibi_slopes=alibi_slopes,
                     softcap=logits_soft_cap,
                     out=prefill_output,
+                    fa_version=self.fa_version,
                 )
             else:
                 # prefix-enabled attention
@@ -765,7 +783,7 @@ def forward(
                     v=value_cache,
                     cu_seqlens_q=prefill_meta.query_start_loc,
                     max_seqlen_q=prefill_meta.max_query_len,
-                    cu_seqlens_k=prefill_meta.seq_start_loc,
+                    seqused_k=prefill_meta.seq_lens_tensor,
                     max_seqlen_k=max_seq_len,
                     softmax_scale=softmax_scale,
                     causal=True,
@@ -774,6 +792,7 @@ def forward(
                     block_table=prefill_meta.block_tables,
                     softcap=logits_soft_cap,
                     out=prefill_output,
+                    fa_version=self.fa_version,
                 )
 
         if decode_meta := attn_metadata.decode_metadata:
@@ -793,7 +812,7 @@ def forward(
                     v=value_cache,
                     cu_seqlens_q=decode_meta.query_start_loc,
                     max_seqlen_q=decode_meta.max_decode_query_len,
-                    cu_seqlens_k=decode_meta.seq_start_loc,
+                    seqused_k=decode_meta.seq_lens_tensor,
                     max_seqlen_k=decode_meta.max_decode_seq_len,
                     softmax_scale=softmax_scale,
                     causal=True,
@@ -802,6 +821,7 @@ def forward(
                     softcap=logits_soft_cap,
                     block_table=decode_meta.block_tables,
                     out=decode_output,
+                    fa_version=self.fa_version,
                 )
             else:
                 # Use flash_attn_with_kvcache for normal decoding.
@@ -822,6 +842,7 @@ def forward(
                     alibi_slopes=alibi_slopes,
                     softcap=logits_soft_cap,
                     out=decode_output.unsqueeze(1),
+                    fa_version=self.fa_version,
                 )
         return output
 
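For reference, the version-selection policy added to __init__ above can be read as a small standalone helper. The sketch below is illustrative only, not vLLM code: select_fa_version, device_capability_major, is_supported, and env_override are hypothetical names standing in for current_platform.get_device_capability()[0], vllm.vllm_flash_attn.is_fa_version_supported, and the VLLM_FLASH_ATTN_VERSION setting imported from vllm.envs.

from typing import Callable, Optional


def select_fa_version(device_capability_major: int,
                      is_supported: Callable[[int], bool],
                      env_override: Optional[int] = None) -> int:
    """Prefer FA3 on Hopper-class GPUs (major compute capability >= 9)
    when the installed kernels support it, otherwise fall back to FA2.
    An explicit override (mirroring VLLM_FLASH_ATTN_VERSION) wins, and
    the final choice must still be a supported version."""
    if device_capability_major >= 9:
        fa_version = 3 if is_supported(3) else 2
    else:
        fa_version = 2

    if env_override is not None:
        assert env_override in (2, 3)
        fa_version = env_override

    assert is_supported(fa_version)
    return fa_version


# Hopper (SM90) with FA3 available picks version 3; Ampere (SM8x) stays on FA2.
assert select_fa_version(9, lambda v: v in (2, 3)) == 3
assert select_fa_version(8, lambda v: v in (2, 3)) == 2

The seqused_k change follows the same pattern at both varlen call sites: instead of passing seq_start_loc (a cumulative-sum tensor with batch_size + 1 entries) as cu_seqlens_k, the kernel now receives seq_lens_tensor (one KV length per sequence) as seqused_k. Assuming the usual prefix-sum relationship between those two metadata fields, the tensors relate as in this small sketch (the lengths are made up for illustration):

import torch

seq_lens = torch.tensor([5, 3, 7], dtype=torch.int32)  # hypothetical KV lengths

# seqused_k form: one entry per sequence, passed directly.
seqused_k = seq_lens

# cu_seqlens_k form: prefix sums of the lengths with a leading zero,
# so it has batch_size + 1 entries.
cu_seqlens_k = torch.zeros(len(seq_lens) + 1, dtype=torch.int32)
cu_seqlens_k[1:] = torch.cumsum(seq_lens, dim=0)

print(cu_seqlens_k.tolist())  # [0, 5, 8, 15]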