Commit dfac5da

[V0 deprecation] Remove VLLM_USE_V1 usage in most modules

Signed-off-by: wangxiyuan <[email protected]>
1 parent 0ce743f

19 files changed: +108 −243 lines
docs/usage/v1_guide.md

Lines changed: 0 additions & 2 deletions

@@ -6,8 +6,6 @@
 
 V1 is now enabled by default for all supported use cases, and we will gradually enable it for every use case we plan to support. Please share any feedback on [GitHub](https://github.com/vllm-project/vllm) or in the [vLLM Slack](https://inviter.co/vllm-slack).
 
-To disable V1, please set the environment variable as: `VLLM_USE_V1=0`, and send us a GitHub issue sharing the reason!
-
 ## Why vLLM V1?
 
 vLLM V0 successfully supported a wide range of models and hardware, but as new features were developed independently, the system grew increasingly complex. This complexity made it harder to integrate new capabilities and introduced technical debt, revealing the need for a more streamlined and unified design.

tests/conftest.py

Lines changed: 0 additions & 20 deletions

@@ -154,26 +154,6 @@ def prompts(self, prompts: AudioAssetPrompts) -> list[str]:
 """Singleton instance of {class}`AudioTestAssets`."""
 
 
-@pytest.fixture(scope="function", autouse=True)
-def cleanup_VLLM_USE_V1(monkeypatch):
-    """
-    The V1 oracle sets "VLLM_USE_V1" during loading. This means
-    that each invocation of a test change the env variable.
-
-    If we touch "VLLM_USE_V1" with monkeypatch, then any changes
-    made during the test run by vLLM will be cleaned up.
-
-    This fixture is used by every test.
-    """
-
-    # If VLLM_USE_V1 is not set, set then delete. This will
-    # cause monkeypatch to clean up VLLM_USE_V1 upon exit
-    # if VLLM modifies the value of envs.VLLM_USE_V1.
-    if "VLLM_USE_V1" not in os.environ:
-        monkeypatch.setenv("VLLM_USE_V1", "")
-        monkeypatch.delenv("VLLM_USE_V1")
-
-
 @pytest.fixture(autouse=True)
 def init_test_http_connection():
     # pytest_asyncio may use a different event loop per test
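
For context, the removed fixture relied on standard pytest `monkeypatch` bookkeeping: once an environment variable is touched via `setenv`/`delenv`, its original state is recorded and restored at teardown, so later mutations made by library code are rolled back automatically. A minimal standalone sketch of that pattern (`MY_FLAG` and the test names are hypothetical, not vLLM code):

```python
import os

import pytest


@pytest.fixture(autouse=True)
def cleanup_my_flag(monkeypatch):
    """Restore MY_FLAG after each test, even if code under test mutates it."""
    # Touching the variable registers it with monkeypatch; its original
    # (absent) state is restored on teardown, mirroring the removed
    # cleanup_VLLM_USE_V1 fixture.
    if "MY_FLAG" not in os.environ:
        monkeypatch.setenv("MY_FLAG", "")
        monkeypatch.delenv("MY_FLAG")


def test_flag_mutation_is_undone():
    # Any mutation made during the test is undone after it finishes.
    os.environ["MY_FLAG"] = "1"
    assert os.environ["MY_FLAG"] == "1"
```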

tests/v1/engine/test_async_llm.py

Lines changed: 2 additions & 5 deletions

@@ -424,15 +424,12 @@ async def test_customize_loggers(monkeypatch):
 
 
 @pytest.mark.asyncio
-async def test_customize_aggregated_loggers(monkeypatch):
+async def test_customize_aggregated_loggers():
     """Test that we can customize the aggregated loggers.
     If a customized logger is provided at the init, it should
     be added to the default loggers.
     """
-
-    with monkeypatch.context() as m, ExitStack() as after:
-        m.setenv("VLLM_USE_V1", "1")
-
+    with ExitStack() as after:
         with set_default_torch_num_threads(1):
             engine = AsyncLLM.from_engine_args(
                 TEXT_ENGINE_ARGS,

tests/v1/entrypoints/llm/test_struct_output_generate.py

Lines changed: 0 additions & 3 deletions

@@ -868,11 +868,8 @@ def test_structured_output_batched_with_non_structured_outputs_requests(
 
 @pytest.mark.parametrize("guided_decoding_backend", ["xgrammar"])
 def test_structured_output_with_structural_tag(
-    monkeypatch: pytest.MonkeyPatch,
     guided_decoding_backend: str,
 ):
-    monkeypatch.setenv("VLLM_USE_V1", "1")
-
     llm = LLM(
         model="Qwen/Qwen2.5-1.5B-Instruct",
         guided_decoding_backend=guided_decoding_backend,

tests/v1/sample/test_logprobs.py

Lines changed: 59 additions & 62 deletions

@@ -530,7 +530,6 @@ def test_logprobs_mode(logprobs_mode: LogprobsMode):
 def test_spec_decode_logprobs(
     logprobs_mode: LogprobsMode,
     model_setup: tuple[str, str, str],
-    monkeypatch: pytest.MonkeyPatch,
 ):
     """Spec decode logprobs should match those of the base model.
 
@@ -541,64 +540,62 @@ def test_spec_decode_logprobs(
     """
     from vllm import LLM
 
-    with monkeypatch.context() as m:
-        m.setenv("VLLM_USE_V1", "1")
-        prompt = "Hello world"
-        sampling_params = SamplingParams(
-            temperature=0, logprobs=3, max_tokens=10, ignore_eos=False
-        )
-        method, model_name, spec_model_name = model_setup
-        max_model_len = 256
-
-        # Run base LLM.
-        ref_llm = LLM(
-            model=model_name,
-            max_logprobs=5,
-            max_model_len=max_model_len,
-            seed=42,
-            logprobs_mode=logprobs_mode,
-            gpu_memory_utilization=0.4,
-        )
-        ref_results = ref_llm.generate([prompt], sampling_params)
-        # Collect logprobs outputs from reference LLM.
-        ref_logprobs = []
-        for output in ref_results[0].outputs:
-            for logprobs in output.logprobs:
-                for token_id in logprobs:
-                    ref_logprobs.append(logprobs[token_id])
-        del ref_llm
-        torch.cuda.empty_cache()
-        cleanup_dist_env_and_memory()
-
-        # Run spec decode LLM.
-        spec_llm = LLM(
-            model_name,
-            speculative_config={
-                "method": method,
-                "model": spec_model_name,
-                "num_speculative_tokens": 3,
-                "max_model_len": max_model_len,
-            },
-            max_logprobs=5,
-            max_model_len=max_model_len,
-            seed=42,
-            logprobs_mode=logprobs_mode,
-            gpu_memory_utilization=0.4,
-        )
-        spec_results = spec_llm.generate([prompt], sampling_params)
-        # Collect logprobs outputs from spec decode LLM.
-        spec_logprobs = []
-        for output in spec_results[0].outputs:
-            for logprobs in output.logprobs:
-                for token_id in logprobs:
-                    spec_logprobs.append(logprobs[token_id])
-        del spec_llm
-        torch.cuda.empty_cache()
-        cleanup_dist_env_and_memory()
-
-        # Per-token logprobs are expected to be the same.
-        assert len(ref_logprobs) == len(spec_logprobs)
-        for ref_logprob, spec_logprob in zip(ref_logprobs, spec_logprobs):
-            assert math.isclose(ref_logprob.logprob, spec_logprob.logprob, abs_tol=1e-3)
-            assert ref_logprob.rank == spec_logprob.rank
-            assert ref_logprob.decoded_token == spec_logprob.decoded_token
+    prompt = "Hello world"
+    sampling_params = SamplingParams(
+        temperature=0, logprobs=3, max_tokens=10, ignore_eos=False
+    )
+    method, model_name, spec_model_name = model_setup
+    max_model_len = 256
+
+    # Run base LLM.
+    ref_llm = LLM(
+        model=model_name,
+        max_logprobs=5,
+        max_model_len=max_model_len,
+        seed=42,
+        logprobs_mode=logprobs_mode,
+        gpu_memory_utilization=0.4,
+    )
+    ref_results = ref_llm.generate([prompt], sampling_params)
+    # Collect logprobs outputs from reference LLM.
+    ref_logprobs = []
+    for output in ref_results[0].outputs:
+        for logprobs in output.logprobs:
+            for token_id in logprobs:
+                ref_logprobs.append(logprobs[token_id])
+    del ref_llm
+    torch.cuda.empty_cache()
+    cleanup_dist_env_and_memory()
+
+    # Run spec decode LLM.
+    spec_llm = LLM(
+        model_name,
+        speculative_config={
+            "method": method,
+            "model": spec_model_name,
+            "num_speculative_tokens": 3,
+            "max_model_len": max_model_len,
+        },
+        max_logprobs=5,
+        max_model_len=max_model_len,
+        seed=42,
+        logprobs_mode=logprobs_mode,
+        gpu_memory_utilization=0.4,
+    )
+    spec_results = spec_llm.generate([prompt], sampling_params)
+    # Collect logprobs outputs from spec decode LLM.
+    spec_logprobs = []
+    for output in spec_results[0].outputs:
+        for logprobs in output.logprobs:
+            for token_id in logprobs:
+                spec_logprobs.append(logprobs[token_id])
+    del spec_llm
+    torch.cuda.empty_cache()
+    cleanup_dist_env_and_memory()
+
+    # Per-token logprobs are expected to be the same.
+    assert len(ref_logprobs) == len(spec_logprobs)
+    for ref_logprob, spec_logprob in zip(ref_logprobs, spec_logprobs):
+        assert math.isclose(ref_logprob.logprob, spec_logprob.logprob, abs_tol=1e-3)
+        assert ref_logprob.rank == spec_logprob.rank
+        assert ref_logprob.decoded_token == spec_logprob.decoded_token

vllm/attention/layers/chunked_local_attention.py

Lines changed: 6 additions & 14 deletions

@@ -1,11 +1,9 @@
 # SPDX-License-Identifier: Apache-2.0
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
-import functools
 from typing import ClassVar
 
 import torch
 
-from vllm import envs
 from vllm.attention.backends.abstract import AttentionBackend, AttentionMetadata
 from vllm.attention.selector import get_attn_backend
 from vllm.config import CacheConfig
@@ -22,7 +20,6 @@
 from ..layer import Attention
 
 
-@functools.lru_cache
 def create_chunked_local_attention_backend(
     underlying_attn_backend: AttentionBackend,
     attention_chunk_size: int,
@@ -78,17 +75,12 @@ def __init__(
             kv_cache_dtype = "auto"
             block_size = 16
 
-        if envs.VLLM_USE_V1:
-            underlying_attn_backend = get_attn_backend(
-                head_size, dtype, kv_cache_dtype, block_size
-            )
-
-            attn_backend = create_chunked_local_attention_backend(
-                underlying_attn_backend, attention_chunk_size, block_size
-            )
-        else:
-            # in v0 the local attention is handled inside the backends
-            attn_backend = None
+        underlying_attn_backend = get_attn_backend(
+            head_size, dtype, kv_cache_dtype, block_size
+        )
+        attn_backend = create_chunked_local_attention_backend(
+            underlying_attn_backend, attention_chunk_size, block_size
+        )
 
         super().__init__(
             num_heads=num_heads,

vllm/attention/layers/cross_attention.py

Lines changed: 4 additions & 12 deletions

@@ -1,12 +1,10 @@
 # SPDX-License-Identifier: Apache-2.0
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
-import functools
 from copy import copy
 
 import numpy as np
 import torch
 
-from vllm import envs
 from vllm.attention.backends.abstract import (
     AttentionBackend,
     AttentionMetadata,
@@ -78,7 +76,6 @@ def _get_cross_slot_mapping(
         return torch.empty(0, dtype=torch.int64, device=device)
 
 
-@functools.lru_cache
 def create_cross_attention_backend(
     underlying_attn_backend: AttentionBackend,
 ) -> type[AttentionBackend]:
@@ -150,15 +147,10 @@ def __init__(
             kv_cache_dtype = "auto"
             block_size = 16
 
-        if envs.VLLM_USE_V1:
-            underlying_attn_backend = get_attn_backend(
-                head_size, dtype, kv_cache_dtype, block_size
-            )
-
-            attn_backend = create_cross_attention_backend(underlying_attn_backend)
-        else:
-            # in v0 cross attention is handled inside the backends
-            attn_backend = None
+        underlying_attn_backend = get_attn_backend(
+            head_size, dtype, kv_cache_dtype, block_size
+        )
+        attn_backend = create_cross_attention_backend(underlying_attn_backend)
 
         if attn_type is not None:
             assert attn_type == AttentionType.ENCODER_DECODER, (

vllm/attention/layers/encoder_only_attention.py

Lines changed: 4 additions & 13 deletions

@@ -1,11 +1,9 @@
 # SPDX-License-Identifier: Apache-2.0
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
-import functools
 from copy import copy
 
 import torch
 
-from vllm import envs
 from vllm.attention.backends.abstract import (
     AttentionBackend,
     AttentionMetadata,
@@ -22,7 +20,6 @@
 from vllm.v1.kv_cache_interface import KVCacheSpec
 
 
-@functools.lru_cache
 def create_encoder_only_attention_backend(
     underlying_attn_backend: AttentionBackend,
 ) -> type[AttentionBackend]:
@@ -74,17 +71,11 @@ def __init__(
             kv_cache_dtype = "auto"
             block_size = 16
 
-        if envs.VLLM_USE_V1:
-            underlying_attn_backend = get_attn_backend(
-                head_size, dtype, kv_cache_dtype, block_size
-            )
+        underlying_attn_backend = get_attn_backend(
+            head_size, dtype, kv_cache_dtype, block_size
+        )
 
-            attn_backend = create_encoder_only_attention_backend(
-                underlying_attn_backend
-            )
-        else:
-            # in v0 encoder only attention is handled inside the backends
-            attn_backend = None
+        attn_backend = create_encoder_only_attention_backend(underlying_attn_backend)
 
         if attn_type is not None:
             assert attn_type == AttentionType.ENCODER_ONLY, (

vllm/attention/selector.py

Lines changed: 1 addition & 5 deletions

@@ -134,16 +134,12 @@ def get_attn_backend(
     use_sparse: bool = False,
 ) -> type[AttentionBackend]:
     """Selects which attention backend to use and lazily imports it."""
-    # Accessing envs.* behind an @lru_cache decorator can cause the wrong
-    # value to be returned from the cache if the value changes between calls.
-    # To avoid this, we read envs.VLLM_USE_V1 here and pass it explicitly to the
-    # private function.
    return _cached_get_attn_backend(
        head_size=head_size,
        dtype=dtype,
        kv_cache_dtype=kv_cache_dtype,
        block_size=block_size,
-        use_v1=envs.VLLM_USE_V1,
+        use_v1=True,
        use_mla=use_mla,
        has_sink=has_sink,
        use_sparse=use_sparse,
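
The deleted comment documented why `use_v1` was resolved outside the cached helper: reading a mutable environment value inside an `@lru_cache`-decorated function can pin a stale result for the lifetime of the cache. With `VLLM_USE_V1` gone, that workaround is no longer needed. A minimal standalone sketch of the pitfall and the workaround (`MY_FLAG` and `pick_backend` are illustrative names, not vLLM code):

```python
import functools
import os


@functools.lru_cache
def pick_backend() -> str:
    # The env var is read inside the cached function, so the first call's
    # result is reused even if MY_FLAG changes afterwards.
    return "v1" if os.environ.get("MY_FLAG", "1") == "1" else "v0"


os.environ["MY_FLAG"] = "1"
print(pick_backend())  # "v1"

os.environ["MY_FLAG"] = "0"
print(pick_backend())  # still "v1" -- the cached value is returned


# The pattern the removed comment described: read the env var in the public
# wrapper and pass it explicitly, so it becomes part of the cache key.
@functools.lru_cache
def _pick_backend(use_v1: bool) -> str:
    return "v1" if use_v1 else "v0"


def pick_backend_fixed() -> str:
    return _pick_backend(os.environ.get("MY_FLAG", "1") == "1")
```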

vllm/distributed/kv_transfer/kv_connector/factory.py

Lines changed: 0 additions & 7 deletions

@@ -5,7 +5,6 @@
 from collections.abc import Callable
 from typing import TYPE_CHECKING, cast
 
-import vllm.envs as envs
 from vllm.config import VllmConfig
 from vllm.distributed.kv_transfer.kv_connector.base import (
     KVConnectorBase,
@@ -44,12 +43,6 @@ def create_connector(
         config: VllmConfig,
         role: KVConnectorRole,
     ) -> KVConnectorBase:
-        if not envs.VLLM_USE_V1:
-            raise ValueError(
-                "Attempting to initialize a V1 Connector, "
-                f"but found {envs.VLLM_USE_V1=}"
-            )
-
         kv_transfer_config = config.kv_transfer_config
         if kv_transfer_config is None:
             raise ValueError("kv_transfer_config must be set to create a connector")
