docs/design/cuda_graphs.md (3 changes: 2 additions & 1 deletion)
@@ -177,8 +177,9 @@ The following table lists backends that support full CUDA Graphs at the time of
 | FlashAttention v3 | `ALWAYS` | has unified routine for both batches, so `FULL` mode is good |
 | Triton Attention | `ALWAYS` | prefer `FULL_AND_PIECEWISE` since it has different kernels for prefill/mixed and pure decode batches |
 | AITER FlashAttention | `UNIFORM_BATCH`| |
-| FlashInfer | `UNIFORM_SINGLE_TOKEN_DECODE` | |
+| FlashInfer | `UNIFORM_SINGLE_TOKEN_DECODE` | Will be set to `UNIFORM_BATCH` when using TRTLLM attention on Blackwell |
 | FlashMLA | `UNIFORM_BATCH` | |
+| FlashInferMLA | `UNIFORM_BATCH` | |
 | AITER MLA | `UNIFORM_SINGLE_TOKEN_DECODE` | |
 | CUTLASS MLA | `UNIFORM_SINGLE_TOKEN_DECODE` | |
 | Mamba attention| `UNIFORM_SINGLE_TOKEN_DECODE` | |
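
As a quick orientation for readers of the table above, the sketch below shows how a backend advertises one of these support levels in code. It is a self-contained, illustrative stand-in: the enum members mirror the mode names in the table, but these minimal class definitions are not vLLM's actual `AttentionCGSupport` / `AttentionMetadataBuilder` implementations.

```python
# Illustrative stand-in only: names mirror vLLM's CUDA-graph support levels
# from the table above, but these minimal definitions are not the real ones.
import enum
from typing import ClassVar


class AttentionCGSupport(enum.Enum):
    NEVER = enum.auto()
    UNIFORM_SINGLE_TOKEN_DECODE = enum.auto()  # graphs only for 1-token decode
    UNIFORM_BATCH = enum.auto()  # uniform multi-token (e.g. spec-decode) batches
    ALWAYS = enum.auto()  # one routine handles any batch shape


class AttentionMetadataBuilder:
    # The base class declares a conservative default at class scope.
    cudagraph_support: ClassVar[AttentionCGSupport] = AttentionCGSupport.NEVER


class SomeBackendMetadataBuilder(AttentionMetadataBuilder):
    # Each backend overrides the class attribute with its own support level.
    cudagraph_support: ClassVar[AttentionCGSupport] = (
        AttentionCGSupport.UNIFORM_BATCH
    )
```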
vllm/v1/attention/backends/flashinfer.py (8 changes: 6 additions & 2 deletions)
@@ -3,7 +3,6 @@
 """Attention layer with FlashInfer."""
 
 from dataclasses import dataclass
-from typing import ClassVar
 
 import numpy as np
 import torch
@@ -272,7 +271,9 @@


 class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]):
-    cudagraph_support: ClassVar[AttentionCGSupport] = (
+    # When using TRTLLM attention with cudagraphs, we can use UNIFORM_BATCH
+    # mode. This will be overridden in the initializer if supported.
+    cudagraph_support: AttentionCGSupport = (

[GitHub Actions / pre-commit] Check failure on line 276 in vllm/v1/attention/backends/flashinfer.py:
Cannot override class variable (previously declared on base class "AttentionMetadataBuilder") with instance variable [misc]
         AttentionCGSupport.UNIFORM_SINGLE_TOKEN_DECODE
     )
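
The pre-commit failure above is mypy's `[misc]` check: once a base class annotates an attribute as `ClassVar`, a subclass may not re-declare it as a plain (instance-assignable) variable. A minimal repro with throwaway names, assuming the base class `AttentionMetadataBuilder` still carries the `ClassVar` annotation:

```python
from typing import ClassVar


class Base:
    support: ClassVar[int] = 0


class Child(Base):
    # mypy: Cannot override class variable (previously declared on base
    # class "Base") with instance variable  [misc]
    support: int = 1
```

Given the later `self.cudagraph_support = ...` assignment in this PR, one plausible resolution is to drop `ClassVar` from the base-class annotation as well, so that both the subclass re-declaration and the per-instance assignment type-check.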

@@ -354,7 +355,10 @@
         else:
             self.q_data_type = self.model_config.dtype
 
+        # If using trtllm attention, we can support uniform_batch speculative decoding
         self._init_reorder_batch_threshold(1, supports_spec_as_decode=can_use_trtllm)
+        if can_use_trtllm:
+            self.cudagraph_support = AttentionCGSupport.UNIFORM_BATCH
 
         self._cascade_wrapper = None  # Wrapper for cascade attention

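Taken together, the two hunks change `cudagraph_support` from a fixed class constant into a class-level default that an individual builder instance upgrades when TRTLLM attention is usable. A stripped-down sketch of that pattern follows; the `can_use_trtllm` flag and the builder's constructor shape are assumptions for illustration, not the real `FlashInferMetadataBuilder` signature:

```python
import enum


class AttentionCGSupport(enum.Enum):  # illustrative stand-in, as above
    UNIFORM_SINGLE_TOKEN_DECODE = enum.auto()
    UNIFORM_BATCH = enum.auto()


class FlashInferLikeBuilder:
    # Class-level default (no ClassVar), so instances may legally shadow it.
    cudagraph_support = AttentionCGSupport.UNIFORM_SINGLE_TOKEN_DECODE

    def __init__(self, can_use_trtllm: bool) -> None:
        if can_use_trtllm:
            # TRTLLM attention can replay graphs for uniform multi-token
            # (speculative-decode) batches, so advertise the stronger level.
            self.cudagraph_support = AttentionCGSupport.UNIFORM_BATCH


# The instance attribute shadows the class default only when upgraded.
assert (FlashInferLikeBuilder(False).cudagraph_support
        is AttentionCGSupport.UNIFORM_SINGLE_TOKEN_DECODE)
assert (FlashInferLikeBuilder(True).cudagraph_support
        is AttentionCGSupport.UNIFORM_BATCH)
```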