Commit 0e64636

Browse files
Merge remote-tracking branch 'upstream/main' into feature/gdn-apc
2 parents: 538c9a0 + ea97940

7 files changed: +209 additions, -33 deletions

tests/distributed/test_context_parallel.py

Lines changed: 5 additions & 1 deletion
@@ -204,17 +204,21 @@ def _compare_cp_with_tp(
 
 
 CP_TEXT_GENERATION_MODELS = {
-    # [MLA attention only]
     "deepseek-ai/DeepSeek-V2-Lite-Chat": [
         CPTestSettings.detailed(),
         CPTestSettings.detailed(tp_base=2),
     ],
+    "bigcode/gpt_bigcode-santacoder": [
+        CPTestSettings.detailed(),
+        CPTestSettings.detailed(tp_base=2),
+    ],
 }
 
 CP_TEST_MODELS = [
     # TODO support other models
     # [LANGUAGE GENERATION]
     "deepseek-ai/DeepSeek-V2-Lite-Chat",
+    "bigcode/gpt_bigcode-santacoder",
 ]
 
 
tests/models/registry.py

Lines changed: 4 additions & 1 deletion
@@ -262,7 +262,10 @@ def check_available_online(
     "GPT2LMHeadModel": _HfExamplesInfo("openai-community/gpt2", {"alias": "gpt2"}),
     "GPTBigCodeForCausalLM": _HfExamplesInfo(
         "bigcode/starcoder",
-        extras={"tiny": "bigcode/tiny_starcoder_py"},
+        extras={
+            "tiny": "bigcode/tiny_starcoder_py",
+            "santacoder": "bigcode/gpt_bigcode-santacoder",
+        },
         min_transformers_version="4.55.1",
         transformers_version_reason="HF model broken in 4.55.0",
     ),

vllm/attention/ops/common.py

Lines changed: 9 additions & 1 deletion
@@ -173,6 +173,7 @@ def cp_lse_ag_out_rs(
     cp_attn_lse: torch.Tensor,
     cp_group: GroupCoordinator,
     ctx: CPTritonContext = None,
+    return_lse=False,
 ):
     """
     cp_attn_out: [ B, H, D ]
@@ -192,8 +193,15 @@ def cp_lse_ag_out_rs(
 
     cp_attn_lse = cp_attn_lse.contiguous()
     lses = cp_group.all_gather(cp_attn_lse, dim=0).view_as(lses)
-    out, _ = correct_attn_out(cp_attn_out, lses, cp_group.rank_in_group, ctx)
+    out, lse = correct_attn_out(cp_attn_out, lses, cp_group.rank_in_group, ctx)
+    assert out.is_contiguous()
     out = cp_group.reduce_scatter(out, dim=1)
+
+    if return_lse:
+        cp_num_heads = lse.shape[1] // cp_group.world_size
+        cp_rank = cp_group.rank_in_group
+        lse = lse[:, cp_num_heads * cp_rank : cp_num_heads * (cp_rank + 1)]
+        return out, lse
     return out
 
 
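A note on the hunk above: with return_lse=True, cp_lse_ag_out_rs now hands back the corrected LSE alongside the reduce-scattered output, sliced down to the heads this rank keeps. A minimal sketch of that slice arithmetic, in plain torch with made-up sizes rather than vLLM's GroupCoordinator:

import torch

world_size, rank = 4, 2            # hypothetical CP/DCP group of 4, this rank is 2
batch, total_heads = 8, 32         # corrected LSE still covers all heads
lse = torch.randn(batch, total_heads)

# reduce_scatter(out, dim=1) leaves rank r with the r-th contiguous block of heads,
# so the LSE is sliced to exactly the same block before being returned.
heads_per_rank = lse.shape[1] // world_size
lse_local = lse[:, heads_per_rank * rank : heads_per_rank * (rank + 1)]
assert lse_local.shape == (batch, heads_per_rank)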

vllm/config/model.py

Lines changed: 17 additions & 0 deletions
@@ -1202,6 +1202,23 @@ def verify_with_parallel_config(
                 "Supported models implement the `SupportsPP` interface."
             )
 
+        decode_context_parallel_size = parallel_config.decode_context_parallel_size
+        if decode_context_parallel_size > 1 and not self.use_mla:
+            total_num_kv_heads = self.get_total_num_kv_heads()
+            assert tensor_parallel_size > total_num_kv_heads, (
+                f"tensor parallel size {tensor_parallel_size} must be greater "
+                f"than total num kv heads {total_num_kv_heads} when enable "
+                f"decode context parallel for GQA/MQA"
+            )
+
+            max_dcp_size = tensor_parallel_size // total_num_kv_heads
+            assert decode_context_parallel_size <= max_dcp_size, (
+                f"decode context parallel size must less than or equal to "
+                f"(tensor parallel size {tensor_parallel_size} // total "
+                f"num kv heads {total_num_kv_heads}) = {max_dcp_size}, "
+                f"but got {decode_context_parallel_size}"
+            )
+
     def get_sliding_window(self) -> int | None:
         """Get the sliding window size from the HF text config if present."""
         return getattr(self.hf_text_config, "sliding_window", None)
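
The new check above constrains decode context parallel (DCP) for GQA/MQA (non-MLA) models: it only kicks in once tensor parallelism has more ranks than KV heads, and it caps the DCP size at tp_size // total_num_kv_heads, i.e. at the KV-head replication factor TP already implies. A small illustrative helper restating that bound (not part of vLLM, names are hypothetical):

def max_decode_context_parallel_size(tensor_parallel_size: int,
                                     total_num_kv_heads: int) -> int:
    # Mirrors the assertions above: TP must exceed the KV-head count, and DCP
    # can be at most the integer replication factor tp // kv_heads.
    assert tensor_parallel_size > total_num_kv_heads
    return tensor_parallel_size // total_num_kv_heads

# e.g. an MQA model such as gpt_bigcode-santacoder (1 KV head) with TP=8
# admits decode_context_parallel_size up to 8; with 4 KV heads the cap is 2.
print(max_decode_context_parallel_size(8, 1))  # 8
print(max_decode_context_parallel_size(8, 4))  # 2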

vllm/v1/attention/backends/flash_attn.py

Lines changed: 172 additions & 30 deletions
@@ -17,6 +17,7 @@
     is_quantized_kv_cache,
 )
 from vllm.attention.layer import Attention
+from vllm.attention.ops.common import cp_lse_ag_out_rs
 from vllm.attention.ops.merge_attn_states import merge_attn_states
 from vllm.attention.utils.fa_utils import (
     flash_attn_supports_fp8,
@@ -32,6 +33,7 @@
 )
 
 from vllm.config import VllmConfig, get_layers_from_vllm_config
+from vllm.distributed.parallel_state import get_dcp_group
 from vllm.logger import init_logger
 from vllm.utils import cdiv
 from vllm.v1.attention.backends.utils import (
@@ -147,6 +149,10 @@ class FlashAttentionMetadata:
     prefix_kv_lens: torch.Tensor | None
     suffix_kv_lens: torch.Tensor | None
 
+    # For GQA DCP
+    max_dcp_context_kv_len: int | None = None
+    dcp_context_kv_lens: torch.Tensor | None = None
+
     # Optional aot scheduling
     scheduler_metadata: torch.Tensor | None = None
     prefix_scheduler_metadata: torch.Tensor | None = None
@@ -216,6 +222,16 @@ def __init__(
         self.max_num_splits = 0  # No upper bound on the number of splits.
         self.aot_schedule = get_flash_attn_version() == 3
 
+        try:
+            from vllm.distributed.parallel_state import get_dcp_group
+
+            self.dcp_world_size = get_dcp_group().world_size
+            self.dcp_rank = get_dcp_group().rank_in_group
+        except AssertionError:
+            # DCP might not be initialized in testing
+            self.dcp_world_size = 1
+            self.dcp_rank = 0
+
         self.use_full_cuda_graph = (
             self.compilation_config.cudagraph_mode.has_full_cudagraphs()
         )
@@ -306,7 +322,7 @@ def schedule(
                     batch_size=batch_size,
                     max_seqlen_q=max_query_len,
                     max_seqlen_k=max_seq_len,
-                    num_heads_q=self.num_heads_q,
+                    num_heads_q=self.num_heads_q * self.dcp_world_size,
                     num_heads_kv=self.num_heads_kv,
                     headdim=self.headdim,
                     cache_seqlens=seqlens,
@@ -320,8 +336,35 @@
             return None
 
         use_cascade = common_prefix_len > 0
+        max_dcp_context_kv_len = 0
+        dcp_context_kv_lens = None
+
+        cu_prefix_query_lens = None
+        prefix_kv_lens = None
+        suffix_kv_lens = None
+        prefix_scheduler_metadata = None
+
+        if self.dcp_world_size > 1:
+            query_kv_lens_cpu = (
+                common_attn_metadata.query_start_loc_cpu[1:]
+                - common_attn_metadata.query_start_loc_cpu[:-1]
+            )
+            dcp_context_kv_lens_cpu = seq_lens_cpu - query_kv_lens_cpu
+            dcp_context_kv_lens_cpu = dcp_context_kv_lens_cpu // self.dcp_world_size + (
+                self.dcp_rank <= (dcp_context_kv_lens_cpu - 1) % self.dcp_world_size
+            )
+            dcp_context_kv_lens = dcp_context_kv_lens_cpu.to(self.device)
+            max_dcp_context_kv_len = dcp_context_kv_lens.max().item()
 
-        if use_cascade:
+            scheduler_metadata = schedule(
+                batch_size=num_reqs,
+                cu_query_lens=query_start_loc,
+                max_query_len=max_query_len,
+                seqlens=dcp_context_kv_lens,
+                max_seq_len=max_dcp_context_kv_len,
+                causal=False,
+            )
+        elif use_cascade:
             cu_prefix_query_lens = torch.tensor(
                 [0, num_actual_tokens], dtype=torch.int32, device=self.device
            )
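
In the hunk above, when DCP is active the builder derives a per-rank dcp_context_kv_lens from each request's context length, since every DCP rank attends over only its own share of the context KV cache, and it schedules that context pass as non-causal. A minimal sketch of the length split, in plain torch with made-up lengths (nothing here is vLLM API):

import torch

# Hypothetical DCP group of 4 ranks and three requests with these context lengths.
world_size = 4
context_lens = torch.tensor([10, 7, 1])

# Same expression as dcp_context_kv_lens_cpu above, evaluated for every rank:
# rank r keeps len // world_size tokens, plus one when r <= (len - 1) % world_size.
per_rank = torch.stack(
    [
        context_lens // world_size + (rank <= (context_lens - 1) % world_size)
        for rank in range(world_size)
    ]
)
print(per_rank)
# tensor([[3, 2, 1],
#         [3, 2, 0],
#         [2, 2, 0],
#         [2, 1, 0]])
# For these lengths the columns add back up to 10, 7 and 1, and the max of a row is
# what that rank records as max_dcp_context_kv_len.
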
@@ -348,10 +391,6 @@ def schedule(
                 causal=True,
             )
         else:
-            cu_prefix_query_lens = None
-            prefix_kv_lens = None
-            suffix_kv_lens = None
-            prefix_scheduler_metadata = None
             scheduler_metadata = schedule(
                 batch_size=num_reqs,
                 cu_query_lens=query_start_loc,
@@ -379,6 +418,8 @@
             seq_lens=seq_lens,
             block_table=block_table_tensor,
             slot_mapping=slot_mapping,
+            max_dcp_context_kv_len=max_dcp_context_kv_len,
+            dcp_context_kv_lens=dcp_context_kv_lens,
             use_cascade=use_cascade,
             common_prefix_len=common_prefix_len,
             scheduler_metadata=scheduler_metadata,
@@ -396,6 +437,8 @@ def use_cascade_attention(self, *args, **kwargs) -> bool:
 
 
 class FlashAttentionImpl(AttentionImpl):
+    can_return_lse_for_decode: bool = True
+
     def __init__(
         self,
         num_heads: int,
@@ -562,30 +605,45 @@ def forward(
 
             descale_shape = (cu_seqlens_q.shape[0] - 1, self.num_kv_heads)
 
-            flash_attn_varlen_func(
-                q=query[:num_actual_tokens],
-                k=key_cache,
-                v=value_cache,
-                out=output[:num_actual_tokens],
-                cu_seqlens_q=cu_seqlens_q,
-                max_seqlen_q=max_seqlen_q,
-                seqused_k=seqused_k,
-                max_seqlen_k=max_seqlen_k,
-                softmax_scale=self.scale,
-                causal=attn_metadata.causal,
-                alibi_slopes=self.alibi_slopes,
-                window_size=self.sliding_window,
-                block_table=block_table,
-                softcap=self.logits_soft_cap,
-                scheduler_metadata=scheduler_metadata,
-                fa_version=self.vllm_flash_attn_version,
-                q_descale=layer._q_scale.expand(descale_shape),
-                k_descale=layer._k_scale.expand(descale_shape),
-                v_descale=layer._v_scale.expand(descale_shape),
-                num_splits=attn_metadata.max_num_splits,
-                s_aux=self.sinks,
-            )
-            return output
+            if self.dcp_world_size > 1:
+                self._forward_with_dcp(
+                    query[:num_actual_tokens],
+                    key[:num_actual_tokens],
+                    value[:num_actual_tokens],
+                    key_cache,
+                    value_cache,
+                    output[:num_actual_tokens],
+                    attn_metadata,
+                    q_descale=layer._q_scale.expand(descale_shape),
+                    k_descale=layer._k_scale.expand(descale_shape),
+                    v_descale=layer._v_scale.expand(descale_shape),
+                )
+                return output
+            else:
+                flash_attn_varlen_func(
+                    q=query[:num_actual_tokens],
+                    k=key_cache,
+                    v=value_cache,
+                    out=output[:num_actual_tokens],
+                    cu_seqlens_q=cu_seqlens_q,
+                    max_seqlen_q=max_seqlen_q,
+                    seqused_k=seqused_k,
+                    max_seqlen_k=max_seqlen_k,
+                    softmax_scale=self.scale,
+                    causal=attn_metadata.causal,
+                    alibi_slopes=self.alibi_slopes,
+                    window_size=self.sliding_window,
+                    block_table=block_table,
+                    softcap=self.logits_soft_cap,
+                    scheduler_metadata=scheduler_metadata,
+                    fa_version=self.vllm_flash_attn_version,
+                    q_descale=layer._q_scale.expand(descale_shape),
+                    k_descale=layer._k_scale.expand(descale_shape),
+                    v_descale=layer._v_scale.expand(descale_shape),
+                    num_splits=attn_metadata.max_num_splits,
+                    s_aux=self.sinks,
+                )
+                return output
 
         # Cascade attention (rare case).
         cascade_attention(
@@ -615,6 +673,86 @@ def forward(
         )
         return output
 
+    def _forward_with_dcp(
+        self,
+        query: torch.Tensor,
+        key: torch.Tensor,
+        value: torch.Tensor,
+        key_cache: torch.Tensor,
+        value_cache: torch.Tensor,
+        output: torch.Tensor,
+        attn_metadata: FlashAttentionMetadata,
+        q_descale: torch.Tensor | None = None,
+        k_descale: torch.Tensor | None = None,
+        v_descale: torch.Tensor | None = None,
+    ) -> torch.Tensor:
+        cu_seqlens_q = attn_metadata.query_start_loc
+        max_seqlen_q = attn_metadata.max_query_len
+        block_table = attn_metadata.block_table
+
+        query = query.contiguous()
+        query_across_dcp = get_dcp_group().all_gather(query, dim=1)
+        context_attn_out, context_lse = flash_attn_varlen_func(
+            q=query_across_dcp,
+            k=key_cache,
+            v=value_cache,
+            out=None,
+            cu_seqlens_q=cu_seqlens_q,
+            max_seqlen_q=max_seqlen_q,
+            seqused_k=attn_metadata.dcp_context_kv_lens,
+            max_seqlen_k=attn_metadata.max_dcp_context_kv_len,
+            softmax_scale=self.scale,
+            causal=False,
+            alibi_slopes=self.alibi_slopes,
+            window_size=self.sliding_window,
+            block_table=block_table,
+            softcap=self.logits_soft_cap,
+            return_softmax_lse=True,
+            scheduler_metadata=attn_metadata.scheduler_metadata,
+            fa_version=self.vllm_flash_attn_version,
+            q_descale=q_descale,
+            k_descale=k_descale,
+            v_descale=v_descale,
+        )
+        # FA returns LSE in shape [ H, B ] but cp_lse_ag_out_rs wants [ B, H ]
+        context_attn_out_cor, context_lse_cor = cp_lse_ag_out_rs(
+            context_attn_out,
+            context_lse.transpose(0, 1),
+            get_dcp_group(),
+            return_lse=True,
+        )
+        context_lse_cor = context_lse_cor.transpose(0, 1).contiguous()
+
+        query_attn_out, query_lse = flash_attn_varlen_func(
+            q=query,
+            k=key,
+            v=value,
+            out=None,
+            cu_seqlens_q=cu_seqlens_q,
+            max_seqlen_q=max_seqlen_q,
+            cu_seqlens_k=cu_seqlens_q,
+            max_seqlen_k=max_seqlen_q,
+            softmax_scale=self.scale,
+            causal=attn_metadata.causal,
+            alibi_slopes=self.alibi_slopes,
+            window_size=self.sliding_window,
+            softcap=self.logits_soft_cap,
+            return_softmax_lse=True,
+            fa_version=self.vllm_flash_attn_version,
+            q_descale=q_descale,
+            k_descale=k_descale,
+            v_descale=v_descale,
+        )
+        assert context_attn_out_cor.shape == query_attn_out.shape
+        assert context_lse_cor.shape == query_lse.shape
+        merge_attn_states(
+            output,
+            context_attn_out_cor,
+            context_lse_cor,
+            query_attn_out,
+            query_lse,
+        )
+
     def _forward_encoder_attention(
         self,
         query: torch.Tensor,
@@ -684,6 +822,7 @@ def use_cascade_attention(
     use_sliding_window: bool,
     use_local_attention: bool,
     num_sms: int,
+    dcp_world_size: int,
 ) -> bool:
     """Decide whether to use cascade attention.
 
@@ -705,6 +844,9 @@ def use_cascade_attention(
     num_reqs = len(query_lens)
    if num_reqs < 8:
        return False
+    # disable cascade attention for DCP
+    if dcp_world_size > 1:
+        return False
 
     # Heuristics to decide whether using cascade attention is beneficial.
     # 1. When FlashDecoding is not used for normal attention, cascade attention
vllm/v1/attention/backends/utils.py

Lines changed: 1 addition & 0 deletions
@@ -345,6 +345,7 @@ def use_cascade_attention(
         use_sliding_window: bool,
         use_local_attention: bool,
         num_sms: int,
+        dcp_world_size: int,
     ) -> bool:
         return False
 
vllm/v1/worker/gpu_model_runner.py

Lines changed: 1 addition & 0 deletions
@@ -1523,6 +1523,7 @@ def _compute_cascade_attn_prefix_len(
             use_sliding_window=use_sliding_window,
             use_local_attention=use_local_attention,
             num_sms=self.num_sms,
+            dcp_world_size=self.dcp_world_size,
         )
         return common_prefix_len if use_cascade else 0
 