    is_quantized_kv_cache,
)
from vllm.attention.layer import Attention
+from vllm.attention.ops.common import cp_lse_ag_out_rs
from vllm.attention.ops.merge_attn_states import merge_attn_states
from vllm.attention.utils.fa_utils import (
    flash_attn_supports_fp8,
    )

from vllm.config import VllmConfig, get_layers_from_vllm_config
+from vllm.distributed.parallel_state import get_dcp_group
from vllm.logger import init_logger
from vllm.utils import cdiv
from vllm.v1.attention.backends.utils import (
@@ -147,6 +149,10 @@ class FlashAttentionMetadata:
    prefix_kv_lens: torch.Tensor | None
    suffix_kv_lens: torch.Tensor | None

+    # For GQA decode context parallelism (DCP)
+    max_dcp_context_kv_len: int | None = None
+    dcp_context_kv_lens: torch.Tensor | None = None
+
    # Optional aot scheduling
    scheduler_metadata: torch.Tensor | None = None
    prefix_scheduler_metadata: torch.Tensor | None = None
@@ -216,6 +222,16 @@ def __init__(
        self.max_num_splits = 0  # No upper bound on the number of splits.
        self.aot_schedule = get_flash_attn_version() == 3

+        try:
+            from vllm.distributed.parallel_state import get_dcp_group
+
+            self.dcp_world_size = get_dcp_group().world_size
+            self.dcp_rank = get_dcp_group().rank_in_group
+        except AssertionError:
+            # The DCP group may not be initialized in testing; fall back to
+            # a single rank so the non-DCP code paths are used.
+            self.dcp_world_size = 1
+            self.dcp_rank = 0
+
        self.use_full_cuda_graph = (
            self.compilation_config.cudagraph_mode.has_full_cudagraphs()
        )
@@ -306,7 +322,7 @@ def schedule(
                    batch_size=batch_size,
                    max_seqlen_q=max_query_len,
                    max_seqlen_k=max_seq_len,
-                    num_heads_q=self.num_heads_q,
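+                    # Queries are all-gathered across DCP ranks before the
+                    # context pass, so the scheduler must plan for
+                    # num_heads_q * dcp_world_size query heads.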
+                    num_heads_q=self.num_heads_q * self.dcp_world_size,
                    num_heads_kv=self.num_heads_kv,
                    headdim=self.headdim,
                    cache_seqlens=seqlens,
@@ -320,8 +336,35 @@ def schedule(
            return None

        use_cascade = common_prefix_len > 0
+        max_dcp_context_kv_len = 0
+        dcp_context_kv_lens = None
+
+        cu_prefix_query_lens = None
+        prefix_kv_lens = None
+        suffix_kv_lens = None
+        prefix_scheduler_metadata = None
+
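+        # Three scheduling paths follow: DCP (context KV sharded across
+        # ranks), cascade attention, or the regular single-pass schedule.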
+        if self.dcp_world_size > 1:
+            query_kv_lens_cpu = (
+                common_attn_metadata.query_start_loc_cpu[1:]
+                - common_attn_metadata.query_start_loc_cpu[:-1]
+            )
+            dcp_context_kv_lens_cpu = seq_lens_cpu - query_kv_lens_cpu
+            dcp_context_kv_lens_cpu = dcp_context_kv_lens_cpu // self.dcp_world_size + (
+                self.dcp_rank <= (dcp_context_kv_lens_cpu - 1) % self.dcp_world_size
+            )
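+            # Context KV is interleaved across DCP ranks, so this rank holds
+            # roughly 1/world_size of each request's context. For example, a
+            # 10-token context on 4 ranks gives 10 // 4 + (rank <= (10 - 1) % 4)
+            # = [3, 3, 2, 2] tokens on ranks 0..3; lower ranks absorb the
+            # remainder.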
+            dcp_context_kv_lens = dcp_context_kv_lens_cpu.to(self.device)
+            max_dcp_context_kv_len = dcp_context_kv_lens.max().item()

-        if use_cascade:
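+            # Schedule the context pass over this rank's KV shard. It is
+            # non-causal because every context token precedes all query tokens.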
+            scheduler_metadata = schedule(
+                batch_size=num_reqs,
+                cu_query_lens=query_start_loc,
+                max_query_len=max_query_len,
+                seqlens=dcp_context_kv_lens,
+                max_seq_len=max_dcp_context_kv_len,
+                causal=False,
+            )
+        elif use_cascade:
            cu_prefix_query_lens = torch.tensor(
                [0, num_actual_tokens], dtype=torch.int32, device=self.device
            )
@@ -348,10 +391,6 @@ def schedule(
                causal=True,
            )
        else:
-            cu_prefix_query_lens = None
-            prefix_kv_lens = None
-            suffix_kv_lens = None
-            prefix_scheduler_metadata = None
            scheduler_metadata = schedule(
                batch_size=num_reqs,
                cu_query_lens=query_start_loc,
@@ -379,6 +418,8 @@ def schedule(
            seq_lens=seq_lens,
            block_table=block_table_tensor,
            slot_mapping=slot_mapping,
+            max_dcp_context_kv_len=max_dcp_context_kv_len,
+            dcp_context_kv_lens=dcp_context_kv_lens,
            use_cascade=use_cascade,
            common_prefix_len=common_prefix_len,
            scheduler_metadata=scheduler_metadata,
@@ -396,6 +437,8 @@ def use_cascade_attention(self, *args, **kwargs) -> bool:


class FlashAttentionImpl(AttentionImpl):
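+    # Decode can return the softmax LSE, which DCP needs in order to merge
+    # partial attention outputs across ranks.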
+    can_return_lse_for_decode: bool = True
+
    def __init__(
        self,
        num_heads: int,
@@ -562,30 +605,45 @@ def forward(

            descale_shape = (cu_seqlens_q.shape[0] - 1, self.num_kv_heads)

-            flash_attn_varlen_func(
-                q=query[:num_actual_tokens],
-                k=key_cache,
-                v=value_cache,
-                out=output[:num_actual_tokens],
-                cu_seqlens_q=cu_seqlens_q,
-                max_seqlen_q=max_seqlen_q,
-                seqused_k=seqused_k,
-                max_seqlen_k=max_seqlen_k,
-                softmax_scale=self.scale,
-                causal=attn_metadata.causal,
-                alibi_slopes=self.alibi_slopes,
-                window_size=self.sliding_window,
-                block_table=block_table,
-                softcap=self.logits_soft_cap,
-                scheduler_metadata=scheduler_metadata,
-                fa_version=self.vllm_flash_attn_version,
-                q_descale=layer._q_scale.expand(descale_shape),
-                k_descale=layer._k_scale.expand(descale_shape),
-                v_descale=layer._v_scale.expand(descale_shape),
-                num_splits=attn_metadata.max_num_splits,
-                s_aux=self.sinks,
-            )
-            return output
+            if self.dcp_world_size > 1:
+                self._forward_with_dcp(
+                    query[:num_actual_tokens],
+                    key[:num_actual_tokens],
+                    value[:num_actual_tokens],
+                    key_cache,
+                    value_cache,
+                    output[:num_actual_tokens],
+                    attn_metadata,
+                    q_descale=layer._q_scale.expand(descale_shape),
+                    k_descale=layer._k_scale.expand(descale_shape),
+                    v_descale=layer._v_scale.expand(descale_shape),
+                )
+                return output
+            else:
+                flash_attn_varlen_func(
+                    q=query[:num_actual_tokens],
+                    k=key_cache,
+                    v=value_cache,
+                    out=output[:num_actual_tokens],
+                    cu_seqlens_q=cu_seqlens_q,
+                    max_seqlen_q=max_seqlen_q,
+                    seqused_k=seqused_k,
+                    max_seqlen_k=max_seqlen_k,
+                    softmax_scale=self.scale,
+                    causal=attn_metadata.causal,
+                    alibi_slopes=self.alibi_slopes,
+                    window_size=self.sliding_window,
+                    block_table=block_table,
+                    softcap=self.logits_soft_cap,
+                    scheduler_metadata=scheduler_metadata,
+                    fa_version=self.vllm_flash_attn_version,
+                    q_descale=layer._q_scale.expand(descale_shape),
+                    k_descale=layer._k_scale.expand(descale_shape),
+                    v_descale=layer._v_scale.expand(descale_shape),
+                    num_splits=attn_metadata.max_num_splits,
+                    s_aux=self.sinks,
+                )
+                return output

        # Cascade attention (rare case).
        cascade_attention(
@@ -615,6 +673,86 @@ def forward(
        )
        return output

+    def _forward_with_dcp(
+        self,
+        query: torch.Tensor,
+        key: torch.Tensor,
+        value: torch.Tensor,
+        key_cache: torch.Tensor,
+        value_cache: torch.Tensor,
+        output: torch.Tensor,
+        attn_metadata: FlashAttentionMetadata,
+        q_descale: torch.Tensor | None = None,
+        k_descale: torch.Tensor | None = None,
+        v_descale: torch.Tensor | None = None,
+    ) -> torch.Tensor:
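+        """Two-pass attention for decode context parallelism (DCP).
+
+        The first pass attends the all-gathered query heads to this rank's
+        shard of the context KV cache, and cp_lse_ag_out_rs combines the
+        per-rank partial outputs via their log-sum-exps. The second pass
+        attends the new tokens to their own keys/values; merge_attn_states
+        then fuses the two partial results into `output`.
+        """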
+        cu_seqlens_q = attn_metadata.query_start_loc
+        max_seqlen_q = attn_metadata.max_query_len
+        block_table = attn_metadata.block_table
+
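+        # All-gather along dim 1 (the head dim of [num_tokens, num_heads,
+        # head_size]) so each rank attends over its local KV shard with the
+        # full set of query heads.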
+        query = query.contiguous()
+        query_across_dcp = get_dcp_group().all_gather(query, dim=1)
+        context_attn_out, context_lse = flash_attn_varlen_func(
+            q=query_across_dcp,
+            k=key_cache,
+            v=value_cache,
+            out=None,
+            cu_seqlens_q=cu_seqlens_q,
+            max_seqlen_q=max_seqlen_q,
+            seqused_k=attn_metadata.dcp_context_kv_lens,
+            max_seqlen_k=attn_metadata.max_dcp_context_kv_len,
+            softmax_scale=self.scale,
+            causal=False,
+            alibi_slopes=self.alibi_slopes,
+            window_size=self.sliding_window,
+            block_table=block_table,
+            softcap=self.logits_soft_cap,
+            return_softmax_lse=True,
+            scheduler_metadata=attn_metadata.scheduler_metadata,
+            fa_version=self.vllm_flash_attn_version,
+            q_descale=q_descale,
+            k_descale=k_descale,
+            v_descale=v_descale,
+        )
+        # FA returns LSE in shape [H, B] but cp_lse_ag_out_rs wants [B, H]
+        context_attn_out_cor, context_lse_cor = cp_lse_ag_out_rs(
+            context_attn_out,
+            context_lse.transpose(0, 1),
+            get_dcp_group(),
+            return_lse=True,
+        )
+        context_lse_cor = context_lse_cor.transpose(0, 1).contiguous()
+
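+        # Second pass: the new tokens attend to their own keys/values
+        # (cu_seqlens_k == cu_seqlens_q), with the usual causal mask.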
+        query_attn_out, query_lse = flash_attn_varlen_func(
+            q=query,
+            k=key,
+            v=value,
+            out=None,
+            cu_seqlens_q=cu_seqlens_q,
+            max_seqlen_q=max_seqlen_q,
+            cu_seqlens_k=cu_seqlens_q,
+            max_seqlen_k=max_seqlen_q,
+            softmax_scale=self.scale,
+            causal=attn_metadata.causal,
+            alibi_slopes=self.alibi_slopes,
+            window_size=self.sliding_window,
+            softcap=self.logits_soft_cap,
+            return_softmax_lse=True,
+            fa_version=self.vllm_flash_attn_version,
+            q_descale=q_descale,
+            k_descale=k_descale,
+            v_descale=v_descale,
+        )
+        assert context_attn_out_cor.shape == query_attn_out.shape
+        assert context_lse_cor.shape == query_lse.shape
+        # Merge the context-pass and new-token partial results via their
+        # LSEs, writing the combined attention output in place.
+        merge_attn_states(
+            output,
+            context_attn_out_cor,
+            context_lse_cor,
+            query_attn_out,
+            query_lse,
+        )
+        return output
+
    def _forward_encoder_attention(
        self,
        query: torch.Tensor,
@@ -684,6 +822,7 @@ def use_cascade_attention(
    use_sliding_window: bool,
    use_local_attention: bool,
    num_sms: int,
+    dcp_world_size: int,
) -> bool:
    """Decide whether to use cascade attention.

@@ -705,6 +844,9 @@ def use_cascade_attention(
    num_reqs = len(query_lens)
    if num_reqs < 8:
        return False
+    # Cascade attention is disabled under DCP, where the context KV cache is
+    # sharded across ranks.
+    if dcp_world_size > 1:
+        return False

    # Heuristics to decide whether using cascade attention is beneficial.
    # 1. When FlashDecoding is not used for normal attention, cascade attention