Commit e7acb20

[Feature] Batch invariant torch.compile (#27660)
Signed-off-by: PaulZhang12 <[email protected]>
Co-authored-by: Wentao Ye <[email protected]>
1 parent 4b68c4a commit e7acb20

4 files changed: +82 additions, -9 deletions

vllm/config/model.py

Lines changed: 0 additions & 7 deletions
@@ -20,9 +20,6 @@
 from vllm.config.scheduler import RunnerType
 from vllm.config.utils import assert_hashable, config, getattr_iter
 from vllm.logger import init_logger
-from vllm.model_executor.layers.batch_invariant import (
-    vllm_is_batch_invariant,
-)
 from vllm.platforms import current_platform
 from vllm.transformers_utils.config import (
     ConfigFormat,
@@ -436,10 +433,6 @@ def __post_init__(
         skip_mm_profiling: bool | None,
         video_pruning_rate: float | None,
     ) -> None:
-        # Enable batch invariance settings if requested
-        if vllm_is_batch_invariant():
-            self.enforce_eager = True
-
         # Set the default seed to 0 in V1.
         # NOTE(woosuk): In V1, we use separate processes for workers (unless
         # VLLM_ENABLE_V1_MULTIPROCESSING=0), so setting a seed here
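
With the import and the __post_init__ override removed, requesting batch-invariant mode no longer forces enforce_eager=True, so the model config keeps whatever compilation setting the user chose. A hypothetical usage sketch (not part of this commit; it assumes the mode is driven by the VLLM_BATCH_INVARIANT environment variable that vllm_is_batch_invariant() reads):

import os
from vllm import LLM

# Assumption: this env var is what vllm_is_batch_invariant() checks.
os.environ["VLLM_BATCH_INVARIANT"] = "1"
# Compilation is no longer forced off for batch-invariant runs.
llm = LLM(model="facebook/opt-125m", enforce_eager=False)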

vllm/envs.py

Lines changed: 7 additions & 1 deletion
@@ -251,6 +251,9 @@ def disable_compile_cache() -> bool:


 def use_aot_compile() -> bool:
+    from vllm.model_executor.layers.batch_invariant import (
+        vllm_is_batch_invariant,
+    )
     from vllm.utils.torch_utils import is_torch_equal_or_newer

     default_value = (
@@ -259,7 +262,10 @@ def use_aot_compile() -> bool:
         else "0"
     )

-    return os.environ.get("VLLM_USE_AOT_COMPILE", default_value) == "1"
+    return (
+        not vllm_is_batch_invariant()
+        and os.environ.get("VLLM_USE_AOT_COMPILE", default_value) == "1"
+    )


 def env_with_choices(
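
The new gate means batch-invariant mode always wins over the environment variable: even with VLLM_USE_AOT_COMPILE=1 set explicitly, use_aot_compile() returns False when invariance is on. A minimal, self-contained sketch of that behavior (the vllm_is_batch_invariant stub and the VLLM_BATCH_INVARIANT variable are assumptions for illustration only):

import os

def vllm_is_batch_invariant() -> bool:
    # Stub for illustration; the real check lives in
    # vllm.model_executor.layers.batch_invariant.
    return os.environ.get("VLLM_BATCH_INVARIANT", "0") == "1"

def use_aot_compile(default_value: str = "0") -> bool:
    # Mirrors the gate added in this commit.
    return (
        not vllm_is_batch_invariant()
        and os.environ.get("VLLM_USE_AOT_COMPILE", default_value) == "1"
    )

os.environ["VLLM_USE_AOT_COMPILE"] = "1"
os.environ["VLLM_BATCH_INVARIANT"] = "1"
print(use_aot_compile())  # False: batch invariance disables AOT compile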

vllm/model_executor/layers/batch_invariant.py

Lines changed: 71 additions & 0 deletions
@@ -11,6 +11,7 @@
 import vllm.envs as envs
 from vllm.logger import init_logger
 from vllm.triton_utils import tl, triton
+from vllm.utils.torch_utils import is_torch_equal_or_newer

 logger = init_logger(__name__)

@@ -716,6 +717,10 @@ def linear_batch_invariant(input, weight, bias=None):
 _batch_invariant_MODE = False
 _batch_invariant_LIB = None
 _original_torch_bmm = None
+_original_fp16_reduction_precision = None
+_original_bf16_reduction_precision = None
+_original_cublas_workspace_cfg = None
+_original_cublaslt_workspace_size = None


 def is_batch_invariant_mode_enabled():
@@ -724,6 +729,8 @@ def is_batch_invariant_mode_enabled():

 def enable_batch_invariant_mode():
     global _batch_invariant_MODE, _batch_invariant_LIB, _original_torch_bmm
+    global _original_fp16_reduction_precision, _original_bf16_reduction_precision
+    global _original_cublas_workspace_cfg, _original_cublaslt_workspace_size
     if _batch_invariant_MODE:
         return

@@ -745,14 +752,75 @@ def enable_batch_invariant_mode():
     _original_torch_bmm = torch.bmm
     torch.bmm = bmm_batch_invariant

+    _original_bf16_reduction_precision = (
+        torch.backends.cuda.matmul.allow_bf16_reduced_precision_reduction
+    )
+    _original_fp16_reduction_precision = (
+        torch.backends.cuda.matmul.allow_fp16_reduced_precision_reduction
+    )
+
+    reduced_precision_val = (
+        (False, False) if is_torch_equal_or_newer("2.10.0.dev") else False
+    )
+    torch.backends.cuda.matmul.allow_fp16_reduced_precision_reduction = (
+        reduced_precision_val
+    )
+    torch.backends.cuda.matmul.allow_bf16_reduced_precision_reduction = (
+        reduced_precision_val
+    )
+    torch.backends.cuda.preferred_blas_library(backend="cublaslt")
+
+    if not is_torch_equal_or_newer("2.10.0.dev"):
+        _original_cublas_workspace_cfg = os.environ.get("CUBLAS_WORKSPACE_CONFIG", None)
+        _original_cublaslt_workspace_size = os.environ.get(
+            "CUBLASLT_WORKSPACE_SIZE", None
+        )
+        os.environ["CUBLAS_WORKSPACE_CONFIG"] = ":16:8"
+        os.environ["CUBLASLT_WORKSPACE_SIZE"] = "1"
+

 def disable_batch_invariant_mode():
     global _batch_invariant_MODE, _batch_invariant_LIB, _original_torch_bmm
+    global _original_fp16_reduction_precision, _original_bf16_reduction_precision
+    global _original_cublas_workspace_cfg, _original_cublaslt_workspace_size
+    if not _batch_invariant_MODE:
+        return
+
     if _batch_invariant_LIB is not None:
         _batch_invariant_LIB._destroy()
     if _original_torch_bmm is not None:
         torch.bmm = _original_torch_bmm
         _original_torch_bmm = None
+
+    if _original_bf16_reduction_precision is not None:
+        torch.backends.cuda.matmul.allow_bf16_reduced_precision_reduction = (
+            _original_bf16_reduction_precision
+        )
+        _original_bf16_reduction_precision = None
+    if _original_fp16_reduction_precision is not None:
+        torch.backends.cuda.matmul.allow_fp16_reduced_precision_reduction = (
+            _original_fp16_reduction_precision
+        )
+        _original_fp16_reduction_precision = None
+
+    torch.backends.cuda.preferred_blas_library(backend="default")
+
+    if not is_torch_equal_or_newer("2.10.0.dev"):
+        # Set cublas env vars to previous results. If previous results are None,
+        # that means the env vars were not set, so we should remove them.
+        if _original_cublas_workspace_cfg:
+            os.environ["CUBLAS_WORKSPACE_CONFIG"] = _original_cublas_workspace_cfg
+        elif "CUBLAS_WORKSPACE_CONFIG" in os.environ:
+            del os.environ["CUBLAS_WORKSPACE_CONFIG"]
+
+        if _original_cublaslt_workspace_size:
+            os.environ["CUBLASLT_WORKSPACE_SIZE"] = _original_cublaslt_workspace_size
+        elif "CUBLASLT_WORKSPACE_SIZE" in os.environ:
+            del os.environ["CUBLASLT_WORKSPACE_SIZE"]
+
+        _original_cublas_workspace_cfg = None
+        _original_cublaslt_workspace_size = None
+
     _batch_invariant_MODE = False
     _batch_invariant_LIB = None

@@ -831,6 +899,9 @@ def override_envs_for_invariance():
     os.environ["NCCL_NTHREADS"] = "1"
     os.environ["NCCL_SOCKET_NTHREADS"] = "1"

+    # torch.compile settings
+    os.environ["VLLM_USE_AOT_COMPILE"] = "0"
+

 def init_batch_invariance():
     # this will hit all the csrc overrides as well
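
Enabling the mode now mutates process-global state: the fp16/bf16 reduced-precision matmul flags, the preferred BLAS backend, and (on torch older than 2.10) the CUBLAS_WORKSPACE_CONFIG / CUBLASLT_WORKSPACE_SIZE environment variables, all of which disable_batch_invariant_mode() restores. A small, hypothetical convenience wrapper (not part of this commit) that guarantees the restore runs even if the body raises:

from contextlib import contextmanager

from vllm.model_executor.layers.batch_invariant import (
    disable_batch_invariant_mode,
    enable_batch_invariant_mode,
)

@contextmanager
def batch_invariant_mode():
    # Hypothetical helper: enable the mode, then restore the saved matmul,
    # BLAS-backend, and cuBLAS workspace state on exit.
    enable_batch_invariant_mode()
    try:
        yield
    finally:
        disable_batch_invariant_mode()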

vllm/model_executor/layers/quantization/fp8.py

Lines changed: 4 additions & 1 deletion
@@ -363,6 +363,7 @@ def __init__(self, quant_config: Fp8Config):
         self.use_marlin = False

         self.use_aiter_and_is_supported = check_aiter_fp8_linear_support()
+        self.use_deep_gemm = is_deep_gemm_supported()

         self.weight_block_size = self.quant_config.weight_block_size
         self.block_quant = self.weight_block_size is not None
@@ -545,8 +546,10 @@ def apply(
         # if batch invariant mode is enabled, prefer DeepGEMM FP8 path
         # we will use BF16 dequant when DeepGEMM is not supported.
         if vllm_is_batch_invariant():
+            # Call is_deep_gemm_supported() ahead of time for torch.compile
+            # dynamo has trouble tracing through
             if self.block_quant and should_use_deepgemm_for_fp8_linear(
-                torch.bfloat16, layer.weight, None
+                torch.bfloat16, layer.weight, self.use_deep_gemm
             ):
                 # use group quant consistent with block size across K
                 assert self.act_q_group_shape is not None
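
The fp8 change follows a general torch.compile-friendliness pattern: run environment and capability probes once at construction time and hand the cached result to the compiled path, instead of calling the probe inside the traced region where dynamo may fail to trace it or graph-break. A generic, self-contained sketch of the pattern (hypothetical module, not vLLM code; slow_capability_probe stands in for a check like is_deep_gemm_supported()):

import torch

def slow_capability_probe() -> bool:
    # Stand-in for a device/library capability check that dynamo traces poorly.
    return torch.cuda.is_available()

class CachedProbeLinear(torch.nn.Module):
    def __init__(self, n: int):
        super().__init__()
        self.weight = torch.nn.Parameter(torch.randn(n, n))
        # Probe once, eagerly; the compiled forward only sees a plain Python bool.
        self.use_fast_path = slow_capability_probe()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        if self.use_fast_path:
            return x @ self.weight.t()
        return torch.nn.functional.linear(x, self.weight)

compiled = torch.compile(CachedProbeLinear(16))
out = compiled(torch.randn(4, 16))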
