2 changes: 1 addition & 1 deletion tests/compile/test_fusions_e2e.py
@@ -74,7 +74,7 @@ class ModelBackendTestCase(NamedTuple):
ModelBackendTestCase(
model_name="Qwen/Qwen3-30B-A3B",
model_kwargs=dict(max_model_len=1024),
backend=_Backend.TRITON_ATTN,
backend=AttentionBackendEnum.TRITON_ATTN,
attention_fusions=0,
allreduce_fusions=97,
),
17 changes: 15 additions & 2 deletions tests/distributed/test_context_parallel.py
@@ -39,6 +39,7 @@ class ParallelSetup(NamedTuple):
class CPTestOptions(NamedTuple):
multi_node_only: bool
load_format: str | None = None
attn_backend: str | None = None


@dataclass
@@ -58,6 +59,7 @@ def detailed(
multi_node_only: bool = False,
runner: RunnerOption = "auto",
load_format: str | None = None,
attn_backend: str | None = None,
):
parallel_setups = []
for eager_mode_val in [False]:
@@ -79,7 +81,9 @@
distributed_backends=["mp"],
runner=runner,
test_options=CPTestOptions(
multi_node_only=multi_node_only, load_format=load_format
multi_node_only=multi_node_only,
load_format=load_format,
attn_backend=attn_backend,
),
)

@@ -117,7 +121,7 @@ def _compare_cp_with_tp(
chunked_prefill,
) = parallel_setup

multi_node_only, load_format = test_options
multi_node_only, load_format, attn_backend = test_options

model_info = HF_EXAMPLE_MODELS.find_hf_info(model_id)
model_info.check_transformers_version(on_fail="skip")
@@ -177,6 +181,13 @@ def _compare_cp_with_tp(
if hf_overrides:
common_args.extend(["--hf-overrides", json.dumps(hf_overrides)])

if not attn_backend:
cp_env = tp_env = {}
else:
cp_env = tp_env = {
"VLLM_ATTENTION_BACKEND": attn_backend,
}

cp_args = [
*common_args,
"--tensor-parallel-size",
@@ -205,6 +216,8 @@
model_id,
cp_args,
tp_args,
cp_env,
tp_env,
method=method,
max_wait_seconds=720,
)
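
Taken together, the test_context_parallel.py changes thread a new `attn_backend` option from `CPTestOptions` through to the environment of both spawned servers. A condensed sketch of that plumbing, assuming the names shown in the hunks above (the `FLASHINFER` value is only an illustrative backend string):

```python
# Condensed sketch of the new option's flow; the backend name is an example.
from typing import NamedTuple


class CPTestOptions(NamedTuple):
    multi_node_only: bool
    load_format: str | None = None
    attn_backend: str | None = None  # new field; None means "no override"


opts = CPTestOptions(multi_node_only=False, attn_backend="FLASHINFER")
multi_node_only, load_format, attn_backend = opts

if not attn_backend:
    cp_env = tp_env = {}
else:
    # Both the CP server and the TP reference server get the same pin.
    cp_env = tp_env = {"VLLM_ATTENTION_BACKEND": attn_backend}

print(cp_env)  # {'VLLM_ATTENTION_BACKEND': 'FLASHINFER'}
```

The env dicts are then handed to the comparison helper alongside `cp_args` and `tp_args`, as the last hunk shows.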
8 changes: 8 additions & 0 deletions vllm/config/model.py
@@ -1182,6 +1182,14 @@ def verify_with_parallel_config(
f"but got {decode_context_parallel_size}"
)

num_q_per_kv = total_num_attention_heads // total_num_kv_heads
assert num_q_per_kv % decode_context_parallel_size == 0, (
f"Total number of q per kv attn heads ({num_q_per_kv})"
" must be divisible by dcp world size when enable "
"decode context parallel for GQA "
f"({parallel_config.decode_context_parallel_size})."
)

def get_sliding_window(self) -> int | None:
"""Get the sliding window size from the HF text config if present."""
return getattr(self.hf_text_config, "sliding_window", None)
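
The new check in `verify_with_parallel_config` constrains the GQA group size: the number of query heads per KV head must be a multiple of the DCP world size. A small worked example of the arithmetic (head counts are hypothetical, chosen only to illustrate the constraint):

```python
# Hypothetical head counts; real values come from the model's HF config and
# the ParallelConfig.
total_num_attention_heads = 32  # query heads
total_num_kv_heads = 8          # KV heads (GQA)
num_q_per_kv = total_num_attention_heads // total_num_kv_heads  # 4

for dcp_world_size in (1, 2, 4):
    assert num_q_per_kv % dcp_world_size == 0  # these sizes pass the check

# dcp_world_size = 3 would trip the assertion, since 4 % 3 != 0.
```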
9 changes: 9 additions & 0 deletions vllm/utils/flashinfer.py
@@ -248,6 +248,7 @@ def use_trtllm_attention(
num_kv_heads: int,
num_tokens: int,
max_seq_len: int,
dcp_world_size: int,
kv_cache_dtype: str,
q_dtype: torch.dtype,
is_prefill: bool,
@@ -261,6 +262,14 @@
if force_use_trtllm is not None and not force_use_trtllm:
return False

# Decode context parallel is not supported
if dcp_world_size > 1:
logger.warning_once(
"Trtllm does not support returning LSE and as a result"
"does not support DCP, reverting to FlashInfer"
)
return False

# The platform is not supported
if not supports_trtllm_attention():
if force_use_trtllm:
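
The fallback exists because merging decode-context-parallel results requires each rank's log-sum-exp (LSE) in addition to its partial attention output, and the TRT-LLM kernels do not return LSE. A standalone sketch of that merge (not vLLM's actual merge code; shapes and names are assumptions) shows why the LSE is indispensable:

```python
# Illustration only: combining per-rank partial attention outputs under DCP.
# Each rank attends over its own KV shard and yields a partial output plus the
# LSE of its attention logits; the LSEs provide the per-rank weights needed to
# recover the result that attention over the full context would have produced.
import torch


def merge_dcp_partials(
    outputs: list[torch.Tensor],  # each [num_tokens, num_heads, head_dim]
    lses: list[torch.Tensor],     # each [num_tokens, num_heads]
) -> torch.Tensor:
    lse_all = torch.stack(lses)                   # [world, tokens, heads]
    lse_global = torch.logsumexp(lse_all, dim=0)  # [tokens, heads]
    weights = torch.exp(lse_all - lse_global)     # per-rank correction factors
    out_all = torch.stack(outputs)                # [world, tokens, heads, dim]
    return (weights.unsqueeze(-1) * out_all).sum(dim=0)
```

Without the LSE there is no way to reweight the shards, so `use_trtllm_attention` now declines whenever `dcp_world_size > 1`.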