Commit 29cd86f

Author: Qirui Yang
Commit message: lint
1 parent 9e845dc commit 29cd86f

File tree

1 file changed: +82 -65 lines


vllm/v1/attention/backends/xformers.py

Lines changed: 82 additions & 65 deletions
@@ -285,28 +285,31 @@ def _get_decode_attn_bias(
 ) -> list[BlockDiagonalGappyKeysMask]:
     """
     Generate attention bias masks for decode phase in context parallel.
-
-    This function creates attention masks that allow queries to attend to KV cache
-    distributed across different CP ranks. Each sequence's KV cache is owned by a specific
-    rank, and we need to create appropriate masks for cross-rank attention.
-
+
+    This function creates attention masks that allow queries to attend to
+    KV cache distributed across different CP ranks. Each sequence's KV cache
+    is owned by a specific rank, and we need to create appropriate masks for
+    cross-rank attention.
+
     Example:
     If we have 2 CP ranks and 3 sequences:
     - q_seqlens = [1, 1, 1]  # Each decode query has length 1
     - kv_seqlens = [10, 15, 8]  # KV cache lengths for each sequence
-    - current_kv_rank = [0, 1, 0]  # Which rank owns the CURRENT token's KV pair
-
+    - current_kv_rank = [0, 1, 0]  # Which rank owns the CURRENT token's
+      KV pair
+
     For src_rank=0: mask=[0, -1, 0] -> adjusted_kv_seqlens=[10, 14, 8]
     For src_rank=1: mask=[-1, 0, -1] -> adjusted_kv_seqlens=[9, 15, 7]
-
-    This creates masks where sequences not owned by src_rank have reduced length,
-    effectively masking out the last token position.
-
+
+    This creates masks where sequences not owned by src_rank have reduced
+    length, effectively masking out the last token position.
+
     Args:
         q_seqlens: Query sequence lengths for each sequence in the batch
-        kv_seqlens: Key-value cache lengths for each sequence
-        current_kv_rank: Tensor indicating which CP rank owns each sequence's KV cache
-
+        kv_seqlens: Key-value cache lengths for each sequence
+        current_kv_rank: Tensor indicating which CP rank owns each sequence's
+            KV cache
+
     Returns:
         List of BlockDiagonalGappyKeysMask objects, one for each source rank
     """
@@ -348,30 +351,32 @@ def _get_prefill_attn_bias(
 ) -> list[BlockDiagonalGappyKeysMask]:
     """
     Generate attention bias masks for prefill phase in context parallel.
-
-    This function creates attention masks for distributed prefill computation where
-    queries and KV cache are sharded across multiple CP ranks. It handles
-    both causal masking (for local rank) and block diagonal masking (for remote ranks).
-
+
+    This function creates attention masks for distributed prefill computation
+    where queries and KV cache are sharded across multiple CP ranks. It handles
+    both causal masking (for local rank) and block diagonal masking (for remote
+    ranks).
+
     Example with 2 CP ranks and 2 requests:
-    cp_sharded_q_seqlen = [[4, 6], [3, 5]]  # Sharded query lengths per request
+    cp_sharded_q_seqlen = [[4, 6], [3, 5]]  # Sharded query lengths per req
     cp_sharded_pass_x_kvlens_per_rank = [
         [[8, 12], [6, 10]],  # KV lengths for rank 0
         [[7, 11], [5, 9]]    # KV lengths for rank 1
     ]
-
+
     For rank 0 (cp_rank=0):
-    - Uses BlockDiagonalCausalWithOffsetGappyKeysMask for local data (rank 0)
-    - Uses BlockDiagonalGappyKeysMask for remote data (rank 1)
-
+    - Uses BlockDiagonalCausalWithOffsetGappyKeysMask for local data
+    - Uses BlockDiagonalGappyKeysMask for remote data
+
     For rank 1 (cp_rank=1):
-    - Uses BlockDiagonalGappyKeysMask for remote data (rank 0)
-    - Uses BlockDiagonalCausalWithOffsetGappyKeysMask for local data (rank 1)
-
+    - Uses BlockDiagonalGappyKeysMask for remote data
+    - Uses BlockDiagonalCausalWithOffsetGappyKeysMask for local data
+
     Args:
         cp_sharded_q_seqlen: Query sequence lengths [request][cp_shard]
-        cp_sharded_pass_x_kvlens_per_rank: KV lengths [src_rank][request][cp_shard]
-
+        cp_sharded_pass_x_kvlens_per_rank: KV lengths
+            [src_rank][request][cp_shard]
+
     Returns:
         List of attention bias masks, one for each source rank
     """
@@ -386,19 +391,25 @@ def flatten(kv_seqlens: list[list[int]]) -> list[int]:
     cp_sharded_q_seqlen_flatten = flatten(cp_sharded_q_seqlen)
 
     # Determine bias type for each source rank:
-    # - Causal mask for local rank (allows attending to past and current tokens)
-    # - Block diagonal mask for remote ranks (allows attending to all tokens in block)
-    # TODO: use PagedBlockDiagonalCausalWithOffsetGappyKeysMask for local attention
+    # - Causal mask for local rank (allows attending to past and current
+    #   tokens)
+    # - Block diagonal mask for remote ranks (allows attending to all tokens
+    #   in block)
+    # TODO: use PagedBlockDiagonalCausalWithOffsetGappyKeysMask for local
+    #   attention
     bias_type = [(BlockDiagonalCausalWithOffsetGappyKeysMask
                   if cp_rank == i else BlockDiagonalGappyKeysMask)
                  for i in range(cp_size)]
 
     def get_kv_seqstarts(kv_seqlen: list[int]) -> list[int]:
         """
-        Calculate starting positions for KV sequences in attention computation.
-
-        Processes pairs of KV lengths to determine where each sequence block starts.
-        Example: kv_seqlen=[8, 12, 6, 10] -> kv_seqstarts=[0, 0, 12, 12, 22]
+        Calculate starting positions for KV sequences in attention
+        computation.
+
+        Processes pairs of KV lengths to determine where each sequence
+        block starts.
+        Example: kv_seqlen=[8, 12, 6, 10] ->
+            kv_seqstarts=[0, 0, 12, 12, 22]
         """
         kv_seqstarts = [0]
         for i in range(0, len(kv_seqlen), 2):
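
The loop body of get_kv_seqstarts is not part of this hunk; the sketch below is one implementation consistent with the documented example, where each pair of shards shares a start offset and the running offset advances by the second shard's KV length.

def get_kv_seqstarts_sketch(kv_seqlen: list[int]) -> list[int]:
    kv_seqstarts = [0]
    for i in range(0, len(kv_seqlen), 2):
        kv_seqstarts.append(kv_seqstarts[-1])  # shards of a pair share a start
        kv_seqstarts.append(kv_seqstarts[-1] + kv_seqlen[i + 1])  # advance past the block
    return kv_seqstarts

assert get_kv_seqstarts_sketch([8, 12, 6, 10]) == [0, 0, 12, 12, 22]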
@@ -425,7 +436,7 @@ def _cp_partial_prefill_get_kv_seqlens(
     num_computed_tokens: int,
 ) -> list[list[int]]:
     # For prefill by passing KV among CP group, get
-    # the KV seqlens (part of the attention bias) for computing partial attention
+    # the KV seqlens (part of the attention bias) for computing partial attn
    # on KV received from each CP rank.
     cp_world_size = get_context_parallel_world_size()
     cp_rank = get_context_parallel_rank()
@@ -476,7 +487,7 @@ def _merge_attn_flash_partial(
     attn_out: list[torch.Tensor],
     attn_lse: list[torch.Tensor],
 ) -> torch.Tensor:
-    # merges the partial attention outputs from flash varseq fwd to get the final attention output
+    # merges partial attention outputs from flash varseq fwd to final output
     assert len(attn_out) == len(attn_lse)
     assert len(attn_out) >= 1
 
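
As background for the merge above: partial attention outputs can be combined exactly from their log-sum-exp values, since a softmax over the per-source LSEs gives each partial result's weight. The sketch below is a generic version with assumed tensor layouts, not the body of _merge_attn_flash_partial.

import torch

def merge_partials_sketch(attn_out: list[torch.Tensor],
                          attn_lse: list[torch.Tensor]) -> torch.Tensor:
    # Assumed layouts: attn_out[i] is [B, T, H, D] and attn_lse[i] is
    # [B, T, H]; the real backend's layouts may require a transpose.
    lse = torch.stack(attn_lse)                        # [S, B, T, H]
    weights = torch.softmax(lse, dim=0).unsqueeze(-1)  # [S, B, T, H, 1]
    out = torch.stack(attn_out)                        # [S, B, T, H, D]
    return (weights * out).sum(dim=0)                  # [B, T, H, D]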
@@ -494,29 +505,32 @@ def _merge_attn_flash_partial(
 
 
 def _prefill_pass_kv_attention(
-    cp_world_size: int,
-    cp_rank: int,
-    cache_k: torch.Tensor,
-    cache_v: torch.Tensor,
-    xq_out: torch.Tensor,
-    slot_mapping: torch.Tensor,
-    B_T: int,
-    N_H_L: int,
-    D_H: int,
-    attn_bias: list[BlockDiagonalGappyKeysMask],
+    cp_world_size: int,
+    cp_rank: int,
+    cache_k: torch.Tensor,
+    cache_v: torch.Tensor,
+    xq_out: torch.Tensor,
+    slot_mapping: torch.Tensor,
+    B_T: int,
+    N_H_L: int,
+    D_H: int,
+    attn_bias: list[BlockDiagonalGappyKeysMask],  # type: ignore
 ) -> torch.Tensor:
     """
-    Computes attention for fused varseq prompt by passing KV among CP group for best
-    overlap between CP comms and attention compute. KV from different prefill batches
-    are padded to the maximum seqlen in the fused prefill.
+    Computes attention for fused varseq prompt by passing KV among CP group for
+    best overlap between CP comms and attention compute. KV from different
+    prefill batches are padded to the maximum seqlen in the fused prefill.
 
     Args:
-        max_global_kvlen: maximum seqlen in current batch, used for pass_kv only
+        max_global_kvlen: maximum seqlen in current batch, used for pass_kv
+            only
         prefetched_lengths: indicates the starting position of cache, used for
             duplicate_kv with persistent cache enabled
         varseq_batch_dedup: batch indices of the current batch.
         varseq_seqlen: padded seqlen after cp sharding
     """
+
+    assert XFORMERS_AVAILABLE
     # TODO: extract KV pieces after local attention
     cache_k_ = torch.index_select(cache_k, 1, slot_mapping)
     cache_v_ = torch.index_select(cache_v, 1, slot_mapping)
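
The pass-KV schedule the docstring describes can be sketched generically with torch.distributed point-to-point ops. This is not the file's cp_pass_around helper; partial_attn and merge are stand-ins for the xformers partial-attention call and the LSE merge, and KV shards are assumed to be padded to the same shape on every rank, as the docstring notes.

import torch
import torch.distributed as dist

def ring_pass_kv_sketch(q, k, v, partial_attn, merge, rank: int, world_size: int):
    to_rank = (rank + 1) % world_size
    from_rank = (rank - 1) % world_size
    outs, lses = [], []
    for step in range(world_size):
        reqs = []
        if step < world_size - 1:
            # Post the next step's KV transfer first so the communication
            # overlaps with the partial attention computed below.
            k_next, v_next = torch.empty_like(k), torch.empty_like(v)
            reqs = [dist.isend(k, to_rank), dist.isend(v, to_rank),
                    dist.irecv(k_next, from_rank), dist.irecv(v_next, from_rank)]
        out, lse = partial_attn(q, k, v)  # attention on the KV held this step
        outs.append(out)
        lses.append(lse)
        for r in reqs:
            r.wait()
        if step < world_size - 1:
            k, v = k_next, v_next
    return merge(outs, lses)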
@@ -533,7 +547,7 @@ def _prefill_pass_kv_attention(
     next_tensors, reqs = cp_pass_around([cache_k_, cache_v_, src_rank],
                                         to_rank, from_rank)
     # local partial attn
-    attn_out_self, lse_out_self = xops.fmha.memory_efficient_attention_partial(
+    attn_out_self, lse_out_self = xops.fmha.memory_efficient_attention_partial(  # type: ignore
         xq_out,
         cache_k_self,
         cache_v_self,
@@ -557,7 +571,7 @@ def _prefill_pass_kv_attention(
         cache_k_i_, cache_v_i_ = (t.view(1, -1, N_H_L, D_H)
                                   for t in (cache_k_i, cache_v_i))
 
-        attn_out_i, lse_out_i = xops.fmha.memory_efficient_attention_partial(
+        attn_out_i, lse_out_i = xops.fmha.memory_efficient_attention_partial(  # type: ignore
             xq_out,
             cache_k_i_,
             cache_v_i_,
@@ -573,25 +587,28 @@ def _prefill_pass_kv_attention(
 
 
 def _decode_allgather_attention(
-    cache_k: torch.Tensor,
-    cache_v: torch.Tensor,
-    xq_out: torch.Tensor,
-    slot_mapping: torch.Tensor,
-    B_T: int,
-    N_H_L: int,
-    D_H: int,
-    attn_bias: list[BlockDiagonalGappyKeysMask],
+    cache_k: torch.Tensor,
+    cache_v: torch.Tensor,
+    xq_out: torch.Tensor,
+    slot_mapping: torch.Tensor,
+    B_T: int,
+    N_H_L: int,
+    D_H: int,
+    attn_bias: list[BlockDiagonalGappyKeysMask],  # type: ignore
 ) -> torch.Tensor:
     """
     Supports CP decode by allgather partial attention among CP ranks.
     This function distributes attention computation across multiple CP ranks by:
-    1. Each CP rank computes partial attention: Attn(local_Q, local_KV)
-    2. All ranks gather partial attention outputs and log-sum-exp values via allgather
+    1. Each CP rank computes partial attention: Attn(local_Q, local_KV)
+    2. All ranks gather partial attention outputs and log-sum-exp values via
+       allgather
     3. Merges all partial attention results to produce final attention output
 
     Returns:
         Merged attention output tensor [1, B_T, N_H_L * D_H]
     """
+
+    assert XFORMERS_AVAILABLE
     cp_rank = get_context_parallel_rank()
 
     cache_k_ = torch.index_select(cache_k, 1,
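
The three steps in the docstring map onto a short generic sketch: compute the local partial attention, all-gather the partial outputs and LSEs, then merge. This is illustrative only; group handling and tensor layouts in the real backend differ, and partial_attn and merge are stand-ins as before.

import torch
import torch.distributed as dist

def decode_allgather_sketch(q, k_local, v_local, partial_attn, merge,
                            world_size: int, group=None):
    # 1. Partial attention of the decode queries against the local KV shard
    out, lse = partial_attn(q, k_local, v_local)
    # 2. Gather every rank's partial output and log-sum-exp values
    outs = [torch.empty_like(out) for _ in range(world_size)]
    lses = [torch.empty_like(lse) for _ in range(world_size)]
    dist.all_gather(outs, out, group=group)
    dist.all_gather(lses, lse, group=group)
    # 3. Merge partial results into the final attention output
    return merge(outs, lses)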
@@ -601,7 +618,7 @@ def _decode_allgather_attention(
 
     xq_out = xq_out.view(1, B_T, N_H_L, D_H)
 
-    attn_out_ = xops.fmha.memory_efficient_attention_partial(
+    attn_out_ = xops.fmha.memory_efficient_attention_partial(  # type: ignore
         xq_out,
         cache_k_,
         cache_v_,
