1111import vllm .envs as envs
1212from vllm .logger import init_logger
1313from vllm .triton_utils import tl , triton
14+ from vllm .utils .torch_utils import is_torch_equal_or_newer
1415
1516logger = init_logger (__name__ )
1617
@@ -716,6 +717,10 @@ def linear_batch_invariant(input, weight, bias=None):
716717_batch_invariant_MODE = False
717718_batch_invariant_LIB = None
718719_original_torch_bmm = None
720+ _original_fp16_reduction_precision = None
721+ _original_bf16_reduction_precision = None
722+ _original_cublas_workspace_cfg = None
723+ _original_cublaslt_workspace_size = None
719724
720725
721726def is_batch_invariant_mode_enabled ():
@@ -724,6 +729,8 @@ def is_batch_invariant_mode_enabled():
724729
725730def enable_batch_invariant_mode ():
726731 global _batch_invariant_MODE , _batch_invariant_LIB , _original_torch_bmm
732+ global _original_fp16_reduction_precision , _original_bf16_reduction_precision
733+ global _original_cublas_workspace_cfg , _original_cublaslt_workspace_size
727734 if _batch_invariant_MODE :
728735 return
729736
@@ -745,14 +752,75 @@ def enable_batch_invariant_mode():
745752 _original_torch_bmm = torch .bmm
746753 torch .bmm = bmm_batch_invariant
747754
755+ _original_bf16_reduction_precision = (
756+ torch .backends .cuda .matmul .allow_bf16_reduced_precision_reduction
757+ )
758+ _original_fp16_reduction_precision = (
759+ torch .backends .cuda .matmul .allow_fp16_reduced_precision_reduction
760+ )
761+
762+ reduced_precision_val = (
763+ (False , False ) if is_torch_equal_or_newer ("2.10.0.dev" ) else False
764+ )
765+ torch .backends .cuda .matmul .allow_fp16_reduced_precision_reduction = (
766+ reduced_precision_val
767+ )
768+ torch .backends .cuda .matmul .allow_bf16_reduced_precision_reduction = (
769+ reduced_precision_val
770+ )
771+ torch .backends .cuda .preferred_blas_library (backend = "cublaslt" )
772+
773+ if not is_torch_equal_or_newer ("2.10.0.dev" ):
774+ _original_cublas_workspace_cfg = os .environ .get ("CUBLAS_WORKSPACE_CONFIG" , None )
775+ _original_cublaslt_workspace_size = os .environ .get (
776+ "CUBLASLT_WORKSPACE_SIZE" , None
777+ )
778+ os .environ ["CUBLAS_WORKSPACE_CONFIG" ] = ":16:8"
779+ os .environ ["CUBLASLT_WORKSPACE_SIZE" ] = "1"
780+
748781
749782def disable_batch_invariant_mode ():
750783 global _batch_invariant_MODE , _batch_invariant_LIB , _original_torch_bmm
784+ global _original_fp16_reduction_precision , _original_bf16_reduction_precision
785+ global _original_cublas_workspace_cfg , _original_cublaslt_workspace_size
786+ if not _batch_invariant_MODE :
787+ return
788+
751789 if _batch_invariant_LIB is not None :
752790 _batch_invariant_LIB ._destroy ()
753791 if _original_torch_bmm is not None :
754792 torch .bmm = _original_torch_bmm
755793 _original_torch_bmm = None
794+
795+ if _original_bf16_reduction_precision is not None :
796+ torch .backends .cuda .matmul .allow_bf16_reduced_precision_reduction = (
797+ _original_bf16_reduction_precision
798+ )
799+ _original_bf16_reduction_precision = None
800+ if _original_fp16_reduction_precision is not None :
801+ torch .backends .cuda .matmul .allow_fp16_reduced_precision_reduction = (
802+ _original_fp16_reduction_precision
803+ )
804+ _original_fp16_reduction_precision = None
805+
806+ torch .backends .cuda .preferred_blas_library (backend = "default" )
807+
808+ if not is_torch_equal_or_newer ("2.10.0.dev" ):
809+ # Set cublas env vars to previous results. If previous results are None,
810+ # that means the env vars were not set, so we should remove them.
811+ if _original_cublas_workspace_cfg :
812+ os .environ ["CUBLAS_WORKSPACE_CONFIG" ] = _original_cublas_workspace_cfg
813+ elif "CUBLAS_WORKSPACE_CONFIG" in os .environ :
814+ del os .environ ["CUBLAS_WORKSPACE_CONFIG" ]
815+
816+ if _original_cublaslt_workspace_size :
817+ os .environ ["CUBLASLT_WORKSPACE_SIZE" ] = _original_cublaslt_workspace_size
818+ elif "CUBLASLT_WORKSPACE_SIZE" in os .environ :
819+ del os .environ ["CUBLASLT_WORKSPACE_SIZE" ]
820+
821+ _original_cublas_workspace_cfg = None
822+ _original_cublaslt_workspace_size = None
823+
756824 _batch_invariant_MODE = False
757825 _batch_invariant_LIB = None
758826
@@ -831,6 +899,9 @@ def override_envs_for_invariance():
831899 os .environ ["NCCL_NTHREADS" ] = "1"
832900 os .environ ["NCCL_SOCKET_NTHREADS" ] = "1"
833901
902+ # torch.compile settings
903+ os .environ ["VLLM_USE_AOT_COMPILE" ] = "0"
904+
834905
835906def init_batch_invariance ():
836907 # this will hit all the csrc overrides as well
0 commit comments