Skip to content
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions cmake/external_projects/vllm_flash_attn.cmake
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ else()
FetchContent_Declare(
vllm-flash-attn
GIT_REPOSITORY https://github.com/vllm-project/flash-attention.git
GIT_TAG 720c94869cf2e0ff5a706e9c7f1dce0939686ade
GIT_TAG 6f27a0544e027b53fc93d8d24f65cc2b75eedc78
GIT_PROGRESS TRUE
# Don't share the vllm-flash-attn build between build types
BINARY_DIR ${CMAKE_BINARY_DIR}/vllm-flash-attn
Expand All @@ -64,4 +64,4 @@ install(
DESTINATION vllm_flash_attn
COMPONENT _vllm_fa3_C
FILES_MATCHING PATTERN "*.py"
)
)
7 changes: 6 additions & 1 deletion vllm/attention/backends/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -595,14 +595,19 @@ def get_flash_attn_version():
# if hopper default to FA3, otherwise stick to FA2 for now
# TODO(lucas): profile FA3 on ampere to see if it makes sense to
# use FA3 as default for both
if current_platform.get_device_capability()[0] >= 9:
if current_platform.get_device_capability()[0] == 9:
fa_version = 3 if is_fa_version_supported(3) else 2
else:
fa_version = 2

if envs.VLLM_FLASH_ATTN_VERSION is not None:
assert envs.VLLM_FLASH_ATTN_VERSION in [2, 3]
fa_version = envs.VLLM_FLASH_ATTN_VERSION
if (current_platform.get_device_capability()[0] == 10
and envs.VLLM_FLASH_ATTN_VERSION == 3):
logger.warning("Cannot use FA version 3 on Blackwell platform",
"defaulting to FA version 2.")
fa_version = 2

if not is_fa_version_supported(fa_version):
logger.error("Cannot use FA version %d is not supported due to %s",
Expand Down