|
| 1 | +# SPDX-License-Identifier: Apache-2.0 |
| 2 | +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project |
| 3 | +"""Attention backend registry""" |
| 4 | + |
| 5 | +import enum |
| 6 | +from typing import Optional, Type |
| 7 | + |
| 8 | +from vllm.utils import resolve_obj_by_qualname |
| 9 | + |
| 10 | + |
| 11 | +class _Backend(enum.Enum): |
| 12 | + FLASH_ATTN = enum.auto() |
| 13 | + TRITON_ATTN = enum.auto() |
| 14 | + XFORMERS = enum.auto() |
| 15 | + ROCM_FLASH = enum.auto() |
| 16 | + ROCM_AITER_MLA = enum.auto() # Supported by V1 |
| 17 | + ROCM_AITER_FA = enum.auto() # used for ViT attn backend |
| 18 | + TORCH_SDPA = enum.auto() |
| 19 | + FLASHINFER = enum.auto() |
| 20 | + FLASHINFER_MLA = enum.auto() |
| 21 | + TRITON_MLA = enum.auto() # Supported by V1 |
| 22 | + CUTLASS_MLA = enum.auto() |
| 23 | + FLASHMLA = enum.auto() # Supported by V1 |
| 24 | + FLASH_ATTN_MLA = enum.auto() # Supported by V1 |
| 25 | + PALLAS = enum.auto() |
| 26 | + IPEX = enum.auto() |
| 27 | + DUAL_CHUNK_FLASH_ATTN = enum.auto() |
| 28 | + DIFFERENTIAL_FLASH_ATTN = enum.auto() |
| 29 | + NO_ATTENTION = enum.auto() |
| 30 | + FLEX_ATTENTION = enum.auto() |
| 31 | + TREE_ATTN = enum.auto() |
| 32 | + ROCM_ATTN = enum.auto() |
| 33 | + |
| 34 | + |
| 35 | +BACKEND_MAPPING = {} |
| 36 | + |
| 37 | + |
| 38 | +def register_attn_backend(backend: _Backend, class_path: str): |
| 39 | + """ |
| 40 | + Decorator: register a custom attention backend into BACKEND_MAPPING. |
| 41 | + Validation: only checks if 'backend' is a valid _Backend enum member. |
| 42 | + Overwriting existing mappings is allowed. |
| 43 | + """ |
| 44 | + if not isinstance(backend, _Backend): |
| 45 | + raise ValueError(f"{backend} is not a valid _Backend enum value.") |
| 46 | + |
| 47 | + def decorator(cls): |
| 48 | + BACKEND_MAPPING[backend] = class_path |
| 49 | + return cls |
| 50 | + |
| 51 | + return decorator |
| 52 | + |
| 53 | + |
| 54 | +def backend_to_class_str(backend: _Backend) -> str: |
| 55 | + """Get the backend class string |
| 56 | + |
| 57 | + Args: |
| 58 | + backend: The backend enum value |
| 59 | + |
| 60 | + Returns: |
| 61 | + The backend class string |
| 62 | + """ |
| 63 | + return BACKEND_MAPPING[backend] |
| 64 | + |
| 65 | + |
| 66 | +def backend_to_class(backend: _Backend) -> Type: |
| 67 | + """Get the backend class. |
| 68 | +
|
| 69 | + Args: |
| 70 | + backend: The backend enum value |
| 71 | +
|
| 72 | + Returns: |
| 73 | + The backend class |
| 74 | + """ |
| 75 | + backend_class_name = backend_to_class_str(backend) |
| 76 | + return resolve_obj_by_qualname(backend_class_name) |
| 77 | + |
| 78 | + |
| 79 | +def backend_name_to_enum(backend_name: str) -> Optional[_Backend]: |
| 80 | + """ |
| 81 | + Convert a string backend name to a _Backend enum value. |
| 82 | +
|
| 83 | + Returns: |
| 84 | + _Backend: enum value if backend_name is a valid in-tree type |
| 85 | + None: otherwise it's an invalid in-tree type or an out-of-tree platform |
| 86 | + is loaded. |
| 87 | + """ |
| 88 | + assert backend_name is not None |
| 89 | + backend_name = backend_name.removesuffix("_VLLM_V1") |
| 90 | + return _Backend[backend_name] if backend_name in _Backend.__members__ else \ |
| 91 | + None |
0 commit comments