@@ -285,28 +285,31 @@ def _get_decode_attn_bias(
 ) -> list[BlockDiagonalGappyKeysMask]:
     """
     Generate attention bias masks for decode phase in context parallel.
-
-    This function creates attention masks that allow queries to attend to KV cache
-    distributed across different CP ranks. Each sequence's KV cache is owned by a specific
-    rank, and we need to create appropriate masks for cross-rank attention.
-
+
+    This function creates attention masks that allow queries to attend to
+    KV cache distributed across different CP ranks. Each sequence's KV cache
+    is owned by a specific rank, and we need to create appropriate masks for
+    cross-rank attention.
+
     Example:
         If we have 2 CP ranks and 3 sequences:
         - q_seqlens = [1, 1, 1]  # Each decode query has length 1
         - kv_seqlens = [10, 15, 8]  # KV cache lengths for each sequence
-        - current_kv_rank = [0, 1, 0]  # Which rank owns the CURRENT token's KV pair
-
+        - current_kv_rank = [0, 1, 0]  # Which rank owns the CURRENT token's
+          KV pair
+
         For src_rank=0: mask=[0, -1, 0] -> adjusted_kv_seqlens=[10, 14, 8]
         For src_rank=1: mask=[-1, 0, -1] -> adjusted_kv_seqlens=[9, 15, 7]
-
-    This creates masks where sequences not owned by src_rank have reduced length,
-    effectively masking out the last token position.
-
+
+    This creates masks where sequences not owned by src_rank have reduced
+    length, effectively masking out the last token position.
+
     Args:
         q_seqlens: Query sequence lengths for each sequence in the batch
-        kv_seqlens: Key-value cache lengths for each sequence
-        current_kv_rank: Tensor indicating which CP rank owns each sequence's KV cache
-
+        kv_seqlens: Key-value cache lengths for each sequence
+        current_kv_rank: Tensor indicating which CP rank owns each sequence's
+            KV cache
+
     Returns:
         List of BlockDiagonalGappyKeysMask objects, one for each source rank
     """
@@ -348,30 +351,32 @@ def _get_prefill_attn_bias(
 ) -> list[BlockDiagonalGappyKeysMask]:
     """
     Generate attention bias masks for prefill phase in context parallel.
-
-    This function creates attention masks for distributed prefill computation where
-    queries and KV cache are sharded across multiple CP ranks. It handles
-    both causal masking (for local rank) and block diagonal masking (for remote ranks).
-
+
+    This function creates attention masks for distributed prefill computation
+    where queries and KV cache are sharded across multiple CP ranks. It handles
+    both causal masking (for local rank) and block diagonal masking (for remote
+    ranks).
+
     Example with 2 CP ranks and 2 requests:
-        cp_sharded_q_seqlen = [[4, 6], [3, 5]]  # Sharded query lengths per request
+        cp_sharded_q_seqlen = [[4, 6], [3, 5]]  # Sharded query lengths per req
         cp_sharded_pass_x_kvlens_per_rank = [
             [[8, 12], [6, 10]],  # KV lengths for rank 0
             [[7, 11], [5, 9]]    # KV lengths for rank 1
         ]
-
+
     For rank 0 (cp_rank=0):
-        - Uses BlockDiagonalCausalWithOffsetGappyKeysMask for local data (rank 0)
-        - Uses BlockDiagonalGappyKeysMask for remote data (rank 1)
-
+        - Uses BlockDiagonalCausalWithOffsetGappyKeysMask for local data
+        - Uses BlockDiagonalGappyKeysMask for remote data
+
     For rank 1 (cp_rank=1):
-        - Uses BlockDiagonalGappyKeysMask for remote data (rank 0)
-        - Uses BlockDiagonalCausalWithOffsetGappyKeysMask for local data (rank 1)
-
+        - Uses BlockDiagonalGappyKeysMask for remote data
+        - Uses BlockDiagonalCausalWithOffsetGappyKeysMask for local data
+
     Args:
         cp_sharded_q_seqlen: Query sequence lengths [request][cp_shard]
-        cp_sharded_pass_x_kvlens_per_rank: KV lengths [src_rank][request][cp_shard]
-
+        cp_sharded_pass_x_kvlens_per_rank: KV lengths
+            [src_rank][request][cp_shard]
+
     Returns:
         List of attention bias masks, one for each source rank
     """
@@ -386,19 +391,25 @@ def flatten(kv_seqlens: list[list[int]]) -> list[int]:
     cp_sharded_q_seqlen_flatten = flatten(cp_sharded_q_seqlen)
 
     # Determine bias type for each source rank:
-    # - Causal mask for local rank (allows attending to past and current tokens)
-    # - Block diagonal mask for remote ranks (allows attending to all tokens in block)
-    # TODO: use PagedBlockDiagonalCausalWithOffsetGappyKeysMask for local attention
+    # - Causal mask for local rank (allows attending to past and current
+    #   tokens)
+    # - Block diagonal mask for remote ranks (allows attending to all tokens
+    #   in block)
+    # TODO: use PagedBlockDiagonalCausalWithOffsetGappyKeysMask for local
+    #   attention
     bias_type = [(BlockDiagonalCausalWithOffsetGappyKeysMask
                   if cp_rank == i else BlockDiagonalGappyKeysMask)
                  for i in range(cp_size)]
 
     def get_kv_seqstarts(kv_seqlen: list[int]) -> list[int]:
         """
-        Calculate starting positions for KV sequences in attention computation.
-
-        Processes pairs of KV lengths to determine where each sequence block starts.
-        Example: kv_seqlen=[8, 12, 6, 10] -> kv_seqstarts=[0, 0, 12, 12, 22]
+        Calculate starting positions for KV sequences in attention
+        computation.
+
+        Processes pairs of KV lengths to determine where each sequence
+        block starts.
+        Example: kv_seqlen=[8, 12, 6, 10] ->
+            kv_seqstarts=[0, 0, 12, 12, 22]
         """
         kv_seqstarts = [0]
         for i in range(0, len(kv_seqlen), 2):
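The loop body of get_kv_seqstarts is cut off by the hunk boundary, but the documented example can be reproduced with the sketch below: both shards of a request share a start offset, and the offset then advances by the second KV length of the pair. Treat this as an illustration of the example, not the shipped loop body.

```python
def get_kv_seqstarts_sketch(kv_seqlen: list[int]) -> list[int]:
    # Reproduces the docstring example: [8, 12, 6, 10] -> [0, 0, 12, 12, 22]
    kv_seqstarts = [0]
    for i in range(0, len(kv_seqlen), 2):
        kv_seqstarts.append(kv_seqstarts[-1])                     # same start for the pair
        kv_seqstarts.append(kv_seqstarts[-1] + kv_seqlen[i + 1])  # advance by second length
    return kv_seqstarts

assert get_kv_seqstarts_sketch([8, 12, 6, 10]) == [0, 0, 12, 12, 22]
```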
@@ -425,7 +436,7 @@ def _cp_partial_prefill_get_kv_seqlens(
     num_computed_tokens: int,
 ) -> list[list[int]]:
     # For prefill by passing KV among CP group, get
-    # the KV seqlens (part of the attention bias) for computing partial attention
+    # the KV seqlens (part of the attention bias) for computing partial attn
     # on KV received from each CP rank.
     cp_world_size = get_context_parallel_world_size()
     cp_rank = get_context_parallel_rank()
@@ -476,7 +487,7 @@ def _merge_attn_flash_partial(
     attn_out: list[torch.Tensor],
     attn_lse: list[torch.Tensor],
 ) -> torch.Tensor:
-    # merges the partial attention outputs from flash varseq fwd to get the final attention output
+    # merges partial attention outputs from flash varseq fwd to final output
     assert len(attn_out) == len(attn_lse)
     assert len(attn_out) >= 1
 
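The body of _merge_attn_flash_partial is not part of this diff beyond the asserts, but the comment describes the standard log-sum-exp merge of partial attention outputs. A minimal sketch is below, assuming the [1, B_T, H, D] output and [1, H, B_T] LSE shapes that memory_efficient_attention_partial returns; xformers also provides a merge_attentions helper that the real code may use instead.

```python
import torch

def merge_partials_sketch(attn_out: list[torch.Tensor],
                          attn_lse: list[torch.Tensor]) -> torch.Tensor:
    # attn_out[i]: [1, B_T, H, D], attn_lse[i]: [1, H, B_T] (assumed shapes)
    lse = torch.stack(attn_lse)                    # [n, 1, H, B_T]
    weights = torch.softmax(lse, dim=0)            # softmax over partial results
    out = torch.stack(attn_out)                    # [n, 1, B_T, H, D]
    w = weights.permute(0, 1, 3, 2).unsqueeze(-1)  # [n, 1, B_T, H, 1]
    return (out * w).sum(dim=0)                    # [1, B_T, H, D]
```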
@@ -494,29 +505,32 @@ def _merge_attn_flash_partial(
 
 
 def _prefill_pass_kv_attention(
-    cp_world_size: int,
-    cp_rank: int,
-    cache_k: torch.Tensor,
-    cache_v: torch.Tensor,
-    xq_out: torch.Tensor,
-    slot_mapping: torch.Tensor,
-    B_T: int,
-    N_H_L: int,
-    D_H: int,
-    attn_bias: list[BlockDiagonalGappyKeysMask],
+    cp_world_size: int,
+    cp_rank: int,
+    cache_k: torch.Tensor,
+    cache_v: torch.Tensor,
+    xq_out: torch.Tensor,
+    slot_mapping: torch.Tensor,
+    B_T: int,
+    N_H_L: int,
+    D_H: int,
+    attn_bias: list[BlockDiagonalGappyKeysMask],  # type: ignore
 ) -> torch.Tensor:
     """
-    Computes attention for fused varseq prompt by passing KV among CP group for best
-    overlap between CP comms and attention compute. KV from different prefill batches
-    are padded to the maximum seqlen in the fused prefill.
+    Computes attention for fused varseq prompt by passing KV among CP group for
+    best overlap between CP comms and attention compute. KV from different
+    prefill batches are padded to the maximum seqlen in the fused prefill.
 
     Args:
-        max_global_kvlen: maximum seqlen in current batch, used for pass_kv only
+        max_global_kvlen: maximum seqlen in current batch, used for pass_kv
+            only
         prefetched_lengths: indicates the starting position of cache, used for
             duplicate_kv with persistent cache enabled
         varseq_batch_dedup: batch indices of the current batch.
         varseq_seqlen: padded seqlen after cp sharding
     """
+
+    assert XFORMERS_AVAILABLE
     # TODO: extract KV pieces after local attention
     cache_k_ = torch.index_select(cache_k, 1, slot_mapping)
     cache_v_ = torch.index_select(cache_v, 1, slot_mapping)
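The hunks that follow show fragments of the ring loop: post the send/recv of the locally held KV shard, run partial attention on it while the transfer is in flight, then repeat for each received shard. The schematic below expresses that overlap with plain torch.distributed point-to-point ops; the actual code uses the repo's cp_pass_around helper and memory_efficient_attention_partial, and the shards are assumed to share a shape because they are padded to the max seqlen, as the docstring states.

```python
import torch
import torch.distributed as dist

def ring_pass_kv_sketch(xq, k, v, rank, world_size, partial_attn, merge):
    to_rank = (rank + 1) % world_size
    from_rank = (rank - 1) % world_size
    outs, lses = [], []
    for step in range(world_size):
        if step < world_size - 1:
            # Start moving the shard we hold before computing on it, so the
            # transfer overlaps with the partial attention below.
            k_next, v_next = torch.empty_like(k), torch.empty_like(v)
            reqs = dist.batch_isend_irecv([
                dist.P2POp(dist.isend, k, to_rank),
                dist.P2POp(dist.isend, v, to_rank),
                dist.P2POp(dist.irecv, k_next, from_rank),
                dist.P2POp(dist.irecv, v_next, from_rank),
            ])
        out, lse = partial_attn(xq, k, v)  # partial attention on the shard in hand
        outs.append(out)
        lses.append(lse)
        if step < world_size - 1:
            for r in reqs:
                r.wait()
            k, v = k_next, v_next
    return merge(outs, lses)  # e.g. the LSE merge sketched earlier
```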
@@ -533,7 +547,7 @@ def _prefill_pass_kv_attention(
     next_tensors, reqs = cp_pass_around([cache_k_, cache_v_, src_rank],
                                         to_rank, from_rank)
     # local partial attn
-    attn_out_self, lse_out_self = xops.fmha.memory_efficient_attention_partial(
+    attn_out_self, lse_out_self = xops.fmha.memory_efficient_attention_partial(  # type: ignore
         xq_out,
         cache_k_self,
         cache_v_self,
@@ -557,7 +571,7 @@ def _prefill_pass_kv_attention(
         cache_k_i_, cache_v_i_ = (t.view(1, -1, N_H_L, D_H)
                                   for t in (cache_k_i, cache_v_i))
 
-        attn_out_i, lse_out_i = xops.fmha.memory_efficient_attention_partial(
+        attn_out_i, lse_out_i = xops.fmha.memory_efficient_attention_partial(  # type: ignore
             xq_out,
             cache_k_i_,
             cache_v_i_,
@@ -573,25 +587,28 @@ def _prefill_pass_kv_attention(
 
 
 def _decode_allgather_attention(
-    cache_k: torch.Tensor,
-    cache_v: torch.Tensor,
-    xq_out: torch.Tensor,
-    slot_mapping: torch.Tensor,
-    B_T: int,
-    N_H_L: int,
-    D_H: int,
-    attn_bias: list[BlockDiagonalGappyKeysMask],
+    cache_k: torch.Tensor,
+    cache_v: torch.Tensor,
+    xq_out: torch.Tensor,
+    slot_mapping: torch.Tensor,
+    B_T: int,
+    N_H_L: int,
+    D_H: int,
+    attn_bias: list[BlockDiagonalGappyKeysMask],  # type: ignore
 ) -> torch.Tensor:
     """
     Supports CP decode by allgather partial attention among CP ranks.
     This function distributes attention computation across multiple CP ranks by:
-    1. Each CP rank computes partial attention: Attn(local_Q, local_KV)
-    2. All ranks gather partial attention outputs and log-sum-exp values via allgather
+    1. Each CP rank computes partial attention: Attn(local_Q, local_KV)
+    2. All ranks gather partial attention outputs and log-sum-exp values via
+       allgather
     3. Merges all partial attention results to produce final attention output
 
     Returns:
         Merged attention output tensor [1, B_T, N_H_L * D_H]
     """
+
+    assert XFORMERS_AVAILABLE
     cp_rank = get_context_parallel_rank()
 
     cache_k_ = torch.index_select(cache_k, 1,
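Step 2 of the list above, gathering every rank's partial output and LSE so each rank can merge locally, can be sketched with vanilla collectives. The real code presumably runs this over the CP process group and reuses _merge_attn_flash_partial for the reduction.

```python
import torch
import torch.distributed as dist

def allgather_and_merge_sketch(out_local, lse_local, group, merge):
    world = dist.get_world_size(group=group)
    outs = [torch.empty_like(out_local) for _ in range(world)]
    lses = [torch.empty_like(lse_local) for _ in range(world)]
    dist.all_gather(outs, out_local, group=group)
    dist.all_gather(lses, lse_local, group=group)
    # Every rank now holds all partial results and reduces them locally.
    return merge(outs, lses)
```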
@@ -601,7 +618,7 @@ def _decode_allgather_attention(
 
     xq_out = xq_out.view(1, B_T, N_H_L, D_H)
 
-    attn_out_ = xops.fmha.memory_efficient_attention_partial(
+    attn_out_ = xops.fmha.memory_efficient_attention_partial(  # type: ignore
         xq_out,
         cache_k_,
         cache_v_,