4 changes: 4 additions & 0 deletions vllm/attention/backends/abstract.py
@@ -294,3 +294,7 @@ def forward(
output: Optional[torch.Tensor] = None,
) -> torch.Tensor:
raise NotImplementedError


def is_quantized_kv_cache(kv_cache_dtype: str) -> bool:
return kv_cache_dtype != "auto"
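The new helper centralizes the scattered kv_cache_dtype != "auto" comparisons: "auto" means the KV cache simply follows the model dtype, while any explicit dtype string (the FP8 variants such as "fp8", "fp8_e4m3", "fp8_e5m2") is treated as a quantized cache. A minimal sketch of the guard pattern the backends below adopt, assuming this branch is installed; DemoAttentionImpl is a hypothetical class used only for illustration, not part of the PR:

from vllm.attention.backends.abstract import is_quantized_kv_cache

# "auto" means the KV cache inherits the model dtype, so it is not quantized.
assert not is_quantized_kv_cache("auto")
# Any explicit dtype string (the FP8 variants) counts as quantized.
assert is_quantized_kv_cache("fp8")
assert is_quantized_kv_cache("fp8_e4m3")


class DemoAttentionImpl:
    """Hypothetical backend, only to illustrate the guard pattern below."""

    def __init__(self, kv_cache_dtype: str = "auto") -> None:
        self.kv_cache_dtype = kv_cache_dtype
        # Fail fast at construction time rather than deep inside decode.
        if is_quantized_kv_cache(self.kv_cache_dtype):
            raise NotImplementedError(
                "DemoAttention with FP8 KV cache not yet supported")

Raising in __init__ surfaces the limitation as soon as the backend is instantiated, instead of on the first decode step as the removed _forward_decode checks did.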
9 changes: 8 additions & 1 deletion vllm/attention/backends/flash_attn.py
@@ -8,11 +8,15 @@
import torch

from vllm import _custom_ops as ops
# yapf conflicts with isort for this block
# yapf: disable
from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl,
AttentionLayer,
AttentionMetadata,
AttentionMetadataBuilder,
AttentionType)
AttentionType,
is_quantized_kv_cache)
# yapf: enable
from vllm.attention.backends.utils import (
PAD_SLOT_ID, CommonAttentionState, compute_slot_mapping,
compute_slot_mapping_start_idx, get_flash_attn_version,
@@ -626,6 +630,9 @@ def __init__(
self.sliding_window = ((sliding_window - 1,
0) if sliding_window is not None else (-1, -1))
self.kv_cache_dtype = kv_cache_dtype
if is_quantized_kv_cache(self.kv_cache_dtype):
raise NotImplementedError(
"FlashAttention with FP8 KV cache not yet supported")
if logits_soft_cap is None:
# In flash-attn, setting logits_soft_cap as 0 means no soft cap.
logits_soft_cap = 0
9 changes: 6 additions & 3 deletions vllm/attention/backends/flashmla.py
@@ -6,7 +6,8 @@

import torch

from vllm.attention.backends.abstract import AttentionType
from vllm.attention.backends.abstract import (AttentionType,
is_quantized_kv_cache)
from vllm.attention.backends.mla.common import (MLACommonBackend,
MLACommonImpl,
MLACommonMetadata,
@@ -207,6 +208,10 @@ def __init__(
"are not implemented for "
"FlashMLAImpl")

if is_quantized_kv_cache(self.kv_cache_dtype):
raise NotImplementedError(
"FlashMLA with FP8 KV cache not yet supported")

def _forward_decode(
self,
q_nope: torch.Tensor,
@@ -215,8 +220,6 @@
attn_metadata: FlashMLAMetadata,
) -> torch.Tensor:
assert kv_c_and_k_pe_cache.numel() > 0
if self.kv_cache_dtype.startswith("fp8"):
raise NotImplementedError("FP8 FlashMLA not yet supported")

decode_meta = attn_metadata.decode_metadata
assert decode_meta is not None
7 changes: 6 additions & 1 deletion vllm/attention/backends/hpu_attn.py
@@ -15,7 +15,8 @@

from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl,
AttentionLayer,
AttentionMetadata, AttentionType)
AttentionMetadata, AttentionType,
is_quantized_kv_cache)
from vllm.attention.backends.utils import CommonAttentionState
from vllm.attention.ops.hpu_paged_attn import (HPUPagedAttention,
HPUPagedAttentionMetadata)
@@ -158,6 +159,10 @@ def __init__(
"are not implemented for "
"HPUAttentionImpl")

if is_quantized_kv_cache(self.kv_cache_dtype):
raise NotImplementedError(
"HPUAttention with FP8 KV cache not yet supported")

def forward(
self,
layer: AttentionLayer,
5 changes: 3 additions & 2 deletions vllm/attention/backends/ipex_attn.py
@@ -9,7 +9,8 @@
from vllm._ipex_ops import ipex_ops
from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl,
AttentionLayer,
AttentionMetadata, AttentionType)
AttentionMetadata, AttentionType,
is_quantized_kv_cache)
from vllm.attention.backends.utils import CommonAttentionState
from vllm.attention.ops.paged_attn import (PagedAttention,
PagedAttentionMetadata)
@@ -145,7 +146,7 @@ def __init__(
raise ValueError(
f"Head size {head_size} is not supported by PagedAttention. "
f"Supported head sizes are: {supported_head_sizes}.")
if kv_cache_dtype != "auto":
if is_quantized_kv_cache(kv_cache_dtype):
raise NotImplementedError(
"IPEX backend does not support FP8 KV cache. "
"Please use xFormers backend instead.")
5 changes: 3 additions & 2 deletions vllm/attention/backends/pallas.py
@@ -8,7 +8,8 @@

from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl,
AttentionLayer,
AttentionMetadata, AttentionType)
AttentionMetadata, AttentionType,
is_quantized_kv_cache)
from vllm.attention.backends.utils import CommonAttentionState


@@ -119,7 +120,7 @@ def __init__(
raise NotImplementedError("Alibi slopes is not supported.")
if sliding_window is not None:
raise NotImplementedError("Sliding window is not supported.")
if kv_cache_dtype != "auto":
if is_quantized_kv_cache(kv_cache_dtype):
raise NotImplementedError("FP8 KV cache dtype is not supported.")
if blocksparse_params is not None:
raise NotImplementedError("Blocksparse is not supported.")
8 changes: 6 additions & 2 deletions vllm/attention/backends/torch_sdpa.py
@@ -7,11 +7,15 @@
import torch
from torch.nn.functional import scaled_dot_product_attention

# yapf conflicts with isort for this block
# yapf: disable
from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl,
AttentionLayer,
AttentionMetadata,
AttentionMetadataBuilder,
AttentionType)
AttentionType,
is_quantized_kv_cache)
# yapf: enable
from vllm.attention.backends.utils import CommonAttentionState
from vllm.attention.ops.ipex_attn import PagedAttention
from vllm.attention.ops.paged_attn import PagedAttentionMetadata
@@ -427,7 +431,7 @@ def __init__(
raise ValueError(
f"Head size {head_size} is not supported by PagedAttention. "
f"Supported head sizes are: {supported_head_sizes}.")
if kv_cache_dtype != "auto":
if is_quantized_kv_cache(kv_cache_dtype):
raise NotImplementedError(
"Torch SDPA backend does not support FP8 KV cache. "
"Please use xFormers backend instead.")
6 changes: 4 additions & 2 deletions vllm/attention/backends/triton_mla.py
@@ -58,6 +58,10 @@ def __init__(
"are not implemented for "
"TritonMLAImpl")

if self.kv_cache_dtype.startswith("fp8"):
raise NotImplementedError(
"TritonMLA with FP8 KV cache not yet supported")

def _forward_decode(
self,
q_nope: torch.Tensor,
@@ -66,8 +70,6 @@
attn_metadata: MLACommonMetadata,
) -> torch.Tensor:
assert kv_c_and_k_pe_cache.numel() > 0
if self.kv_cache_dtype.startswith("fp8"):
raise NotImplementedError("FP8 Triton MLA not yet supported")

decode_meta = attn_metadata.decode_metadata
assert decode_meta is not None
6 changes: 5 additions & 1 deletion vllm/v1/attention/backends/flash_attn.py
@@ -7,7 +7,8 @@
import torch

from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl,
AttentionMetadata, AttentionType)
AttentionMetadata, AttentionType,
is_quantized_kv_cache)
from vllm.attention.backends.utils import get_flash_attn_version
from vllm.attention.ops.triton_merge_attn_states import merge_attn_states
from vllm.logger import init_logger
@@ -180,6 +181,9 @@ def __init__(
else:
self.sliding_window = (sliding_window - 1, 0)
self.kv_cache_dtype = kv_cache_dtype
if is_quantized_kv_cache(self.kv_cache_dtype):
raise NotImplementedError(
"FlashAttention V1 with FP8 KV cache not yet supported")
if logits_soft_cap is None:
# In flash-attn, setting logits_soft_cap as 0 means no soft cap.
logits_soft_cap = 0
9 changes: 6 additions & 3 deletions vllm/v1/attention/backends/mla/flashmla.py
@@ -5,7 +5,8 @@

import torch

from vllm.attention.backends.abstract import AttentionType
from vllm.attention.backends.abstract import (AttentionType,
is_quantized_kv_cache)
from vllm.attention.ops.flashmla import (flash_mla_with_kvcache,
get_mla_metadata,
is_flashmla_supported)
@@ -107,6 +108,10 @@ def __init__(
"are not implemented for "
"FlashMLAImpl")

if is_quantized_kv_cache(self.kv_cache_dtype):
raise NotImplementedError(
"FlashMLA V1 with FP8 KV cache not yet supported")

def _forward_decode(
self,
q_nope: torch.Tensor,
@@ -115,8 +120,6 @@
attn_metadata: FlashMLAMetadata,
) -> torch.Tensor:
assert kv_c_and_k_pe_cache.numel() > 0
if self.kv_cache_dtype.startswith("fp8"):
raise NotImplementedError("FP8 FlashMLA not yet supported")

q = torch.cat([q_nope, q_pe], dim=-1)\
.unsqueeze(1) # Add seqlen dim of 1 (decode)
7 changes: 6 additions & 1 deletion vllm/v1/attention/backends/mla/triton_mla.py
@@ -4,7 +4,8 @@

import torch

from vllm.attention.backends.abstract import AttentionType
from vllm.attention.backends.abstract import (AttentionType,
is_quantized_kv_cache)
from vllm.attention.ops.triton_decode_attention import decode_attention_fwd
from vllm.logger import init_logger
from vllm.v1.attention.backends.mla.common import (MLACommonBackend,
@@ -61,6 +62,10 @@ def __init__(
"are not implemented for "
"TritonMLAImpl")

if is_quantized_kv_cache(self.kv_cache_dtype):
raise NotImplementedError(
"TritonMLA V1 with FP8 KV cache not yet supported")

def _forward_decode(
self,
q_nope: torch.Tensor,