Skip to content

Commit 0001743

Browse files
add registry
Signed-off-by: Matthew Bonanni <[email protected]>
1 parent bb6d430 commit 0001743

File tree

1 file changed

+91
-0
lines changed

1 file changed

+91
-0
lines changed
Lines changed: 91 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,91 @@
1+
# SPDX-License-Identifier: Apache-2.0
2+
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3+
"""Attention backend registry"""
4+
5+
import enum
6+
from typing import Optional, Type
7+
8+
from vllm.utils import resolve_obj_by_qualname
9+
10+
11+
class _Backend(enum.Enum):
    """Identifiers for the attention backends known to this registry.

    Values are assigned with ``enum.auto()``, so each member's integer value
    follows its declaration order; members are meant to be referred to by
    name rather than by value.
    """

    FLASH_ATTN = enum.auto()
    TRITON_ATTN = enum.auto()
    XFORMERS = enum.auto()
    ROCM_FLASH = enum.auto()
    ROCM_AITER_MLA = enum.auto()  # Supported by V1
    ROCM_AITER_FA = enum.auto()  # used for ViT attn backend
    TORCH_SDPA = enum.auto()
    FLASHINFER = enum.auto()
    FLASHINFER_MLA = enum.auto()
    TRITON_MLA = enum.auto()  # Supported by V1
    CUTLASS_MLA = enum.auto()
    FLASHMLA = enum.auto()  # Supported by V1
    FLASH_ATTN_MLA = enum.auto()  # Supported by V1
    PALLAS = enum.auto()
    IPEX = enum.auto()
    DUAL_CHUNK_FLASH_ATTN = enum.auto()
    DIFFERENTIAL_FLASH_ATTN = enum.auto()
    NO_ATTENTION = enum.auto()
    FLEX_ATTENTION = enum.auto()
    TREE_ATTN = enum.auto()
    ROCM_ATTN = enum.auto()
33+
34+
35+
# Registry mapping a _Backend member to the class path string registered for
# it. Written by register_attn_backend and read by backend_to_class_str; it
# starts out empty.
BACKEND_MAPPING: dict[_Backend, str] = {}
36+
37+
38+
def register_attn_backend(backend: _Backend,
                          class_path: Optional[str] = None):
    """Return a class decorator that registers an attention backend.

    Applying the returned decorator stores a class path for ``backend`` in
    ``BACKEND_MAPPING`` and returns the decorated class unchanged. An
    existing entry for the same backend is overwritten.

    Args:
        backend: The ``_Backend`` member to register under.
        class_path: Fully qualified name to store for the backend. When
            omitted, the path is derived from the decorated class as
            ``"<module>.<qualname>"``.

    Returns:
        A decorator that records the mapping when applied to a class.

    Raises:
        ValueError: If ``backend`` is not a ``_Backend`` member.
    """
    # Validate when the decorator is created, before anything is written to
    # the registry.
    if not isinstance(backend, _Backend):
        raise ValueError(f"{backend} is not a valid _Backend enum value.")

    def decorator(cls):
        # An explicit class_path always wins; otherwise fall back to the
        # decorated class's own module and qualified name so callers do not
        # have to repeat (and keep in sync) the import path by hand.
        path = class_path
        if path is None:
            path = f"{cls.__module__}.{cls.__qualname__}"
        BACKEND_MAPPING[backend] = path
        return cls

    return decorator
52+
53+
54+
def backend_to_class_str(backend: _Backend) -> str:
    """Look up the class path registered for a backend.

    Args:
        backend: The backend enum value to look up.

    Returns:
        The string stored for ``backend`` in ``BACKEND_MAPPING``.

    Raises:
        KeyError: If nothing has been registered for ``backend``.
    """
    class_path = BACKEND_MAPPING[backend]
    return class_path
64+
65+
66+
def backend_to_class(backend: _Backend) -> Type:
    """Resolve a backend enum value to the object behind its class path.

    Args:
        backend: The backend enum value to resolve.

    Returns:
        Whatever ``resolve_obj_by_qualname`` returns for the class path
        registered for ``backend`` (presumably the backend class itself;
        see that helper for the exact resolution rules).
    """
    # Look up the registered path, then hand it to the resolver.
    return resolve_obj_by_qualname(backend_to_class_str(backend))
77+
78+
79+
def backend_name_to_enum(backend_name: str) -> Optional[_Backend]:
    """Translate a backend name string into its ``_Backend`` member.

    A trailing ``"_VLLM_V1"`` is stripped from ``backend_name`` before the
    lookup, so suffixed and unsuffixed names resolve to the same member.

    Args:
        backend_name: Name of the backend; must not be ``None``.

    Returns:
        The matching ``_Backend`` member, or ``None`` when the name is not a
        member of ``_Backend`` (an invalid in-tree name, or a name belonging
        to an out-of-tree platform).
    """
    assert backend_name is not None
    member_name = backend_name.removesuffix("_VLLM_V1")
    # __members__ maps member names to members, so .get() yields the member
    # for a known name and None for anything else.
    return _Backend.__members__.get(member_name)

0 commit comments

Comments
 (0)