Commit ecff830

[ROCm] Env variable to trigger custom PA (#15557)

Signed-off-by: Gregory Shtrasberg <[email protected]>

1 parent: dcf2a59

File tree

2 files changed: +8, -1 lines changed

vllm/attention/backends/rocm_flash_attn.py (2 additions, 1 deletion)

@@ -908,4 +908,5 @@ def _use_rocm_custom_paged_attention(qtype: torch.dtype, head_size: int,
             and (qtype == torch.half or qtype == torch.bfloat16)
             and (head_size == 64 or head_size == 128)
             and (block_size == 16 or block_size == 32)
-            and (gqa_ratio >= 1 and gqa_ratio <= 16) and max_seq_len <= 32768)
+            and (gqa_ratio >= 1 and gqa_ratio <= 16) and max_seq_len <= 32768
+            and envs.VLLM_ROCM_CUSTOM_PAGED_ATTN)
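
For readers skimming the diff, here is a self-contained sketch of the eligibility check after this change. It reproduces only the conditions visible in the hunk; the real function may include further checks above line 908 that this excerpt does not show, and a direct environment read stands in for the envs.VLLM_ROCM_CUSTOM_PAGED_ATTN lookup.

    import os

    import torch


    def use_rocm_custom_paged_attention(qtype: torch.dtype, head_size: int,
                                        block_size: int, gqa_ratio: int,
                                        max_seq_len: int) -> bool:
        # Stand-in for envs.VLLM_ROCM_CUSTOM_PAGED_ATTN: defaults to enabled,
        # and only "true" / "1" (case-insensitive) keep it on.
        custom_pa_enabled = (os.getenv("VLLM_ROCM_CUSTOM_PAGED_ATTN",
                                       "True").lower() in ("true", "1"))
        # Mirrors the hunk: fp16/bf16 queries, head sizes 64/128, block sizes
        # 16/32, GQA ratio 1..16, sequences up to 32768 tokens, and now the
        # new environment-variable gate.
        return ((qtype == torch.half or qtype == torch.bfloat16)
                and (head_size == 64 or head_size == 128)
                and (block_size == 16 or block_size == 32)
                and (gqa_ratio >= 1 and gqa_ratio <= 16)
                and max_seq_len <= 32768
                and custom_pa_enabled)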

vllm/envs.py (6 additions, 0 deletions)

@@ -78,6 +78,7 @@
     VLLM_ROCM_USE_AITER_RMSNORM: bool = True
     VLLM_ROCM_FP8_PADDING: bool = True
     VLLM_ROCM_MOE_PADDING: bool = True
+    VLLM_ROCM_CUSTOM_PAGED_ATTN: bool = True
     VLLM_ENABLE_V1_MULTIPROCESSING: bool = True
     VLLM_LOG_BATCHSIZE_INTERVAL: float = -1
     VLLM_DISABLE_COMPILE_CACHE: bool = False

@@ -541,6 +542,11 @@ def maybe_convert_int(value: Optional[str]) -> Optional[int]:
     "VLLM_ROCM_MOE_PADDING":
     lambda: bool(int(os.getenv("VLLM_ROCM_MOE_PADDING", "1"))),

+    # custom paged attention kernel for MI3* cards
+    "VLLM_ROCM_CUSTOM_PAGED_ATTN":
+    lambda: (os.getenv("VLLM_ROCM_CUSTOM_PAGED_ATTN", "True").lower() in
+             ("true", "1")),
+
     # Divisor for dynamic query scale factor calculation for FP8 KV Cache
     "Q_SCALE_CONSTANT":
     lambda: int(os.getenv("Q_SCALE_CONSTANT", "200")),
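
Note the parser's semantics: the variable defaults to "True", and any value other than "true" or "1" (case-insensitive), e.g. "0" or "false", disables the kernel. A minimal usage sketch, assuming a stock vLLM install on a ROCm machine; the model name is illustrative:

    import os

    # Disable the custom ROCm paged-attention kernel for this process.
    # Per the lambda above, any value outside {"true", "1"} turns it off.
    os.environ["VLLM_ROCM_CUSTOM_PAGED_ATTN"] = "0"

    from vllm import LLM

    llm = LLM(model="facebook/opt-125m")  # illustrative model
    print(llm.generate("Hello, ROCm!"))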
