
Commit 4db72e5

[Bugfix][Refactor] Unify model management in frontend (#11660)
Signed-off-by: Joe Runde <[email protected]>
1 parent: 0c6f998

15 files changed: +365 additions, -307 deletions

tests/entrypoints/openai/test_cli_args.py

Lines changed: 1 addition & 1 deletion
@@ -4,7 +4,7 @@
 
 from vllm.entrypoints.openai.cli_args import (make_arg_parser,
                                               validate_parsed_serve_args)
-from vllm.entrypoints.openai.serving_engine import LoRAModulePath
+from vllm.entrypoints.openai.serving_models import LoRAModulePath
 from vllm.utils import FlexibleArgumentParser
 
 from ...utils import VLLM_PATH

tests/entrypoints/openai/test_lora_lineage.py

Lines changed: 29 additions & 3 deletions
@@ -55,7 +55,10 @@ def server_with_lora_modules_json(zephyr_lora_files):
         "64",
     ]
 
-    with RemoteOpenAIServer(MODEL_NAME, args) as remote_server:
+    # Enable the /v1/load_lora_adapter endpoint
+    envs = {"VLLM_ALLOW_RUNTIME_LORA_UPDATING": "True"}
+
+    with RemoteOpenAIServer(MODEL_NAME, args, env_dict=envs) as remote_server:
         yield remote_server
 
 
@@ -67,8 +70,8 @@ async def client_for_lora_lineage(server_with_lora_modules_json):
 
 
 @pytest.mark.asyncio
-async def test_check_lora_lineage(client_for_lora_lineage: openai.AsyncOpenAI,
-                                  zephyr_lora_files):
+async def test_static_lora_lineage(client_for_lora_lineage: openai.AsyncOpenAI,
+                                   zephyr_lora_files):
     models = await client_for_lora_lineage.models.list()
     models = models.data
     served_model = models[0]
@@ -81,3 +84,26 @@ async def test_check_lora_lineage(client_for_lora_lineage: openai.AsyncOpenAI,
     assert all(lora_model.parent == MODEL_NAME for lora_model in lora_models)
     assert lora_models[0].id == "zephyr-lora"
     assert lora_models[1].id == "zephyr-lora2"
+
+
+@pytest.mark.asyncio
+async def test_dynamic_lora_lineage(
+        client_for_lora_lineage: openai.AsyncOpenAI, zephyr_lora_files):
+
+    response = await client_for_lora_lineage.post("load_lora_adapter",
+                                                  cast_to=str,
+                                                  body={
+                                                      "lora_name":
+                                                      "zephyr-lora-3",
+                                                      "lora_path":
+                                                      zephyr_lora_files
+                                                  })
+    # Ensure adapter loads before querying /models
+    assert "success" in response
+
+    models = await client_for_lora_lineage.models.list()
+    models = models.data
+    dynamic_lora_model = models[-1]
+    assert dynamic_lora_model.root == zephyr_lora_files
+    assert dynamic_lora_model.parent == MODEL_NAME
+    assert dynamic_lora_model.id == "zephyr-lora-3"
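
The new test_dynamic_lora_lineage drives the runtime endpoint through the OpenAI client's generic post helper. Outside pytest, the same flow can be exercised over plain HTTP; a minimal sketch, assuming a vLLM server started with VLLM_ALLOW_RUNTIME_LORA_UPDATING=True and reachable at localhost:8000 (the address, adapter name, and path below are placeholders, not values from this commit):

import requests

BASE_URL = "http://localhost:8000"  # assumed server address

# Register a new adapter at runtime via the /v1/load_lora_adapter endpoint.
resp = requests.post(f"{BASE_URL}/v1/load_lora_adapter",
                     json={
                         "lora_name": "my-adapter",        # placeholder name
                         "lora_path": "/path/to/adapter",  # placeholder path
                     })
resp.raise_for_status()

# The adapter should now appear in /v1/models alongside the base model,
# with root pointing at the adapter path and parent at the base model.
models = requests.get(f"{BASE_URL}/v1/models").json()["data"]
print([m["id"] for m in models])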

tests/entrypoints/openai/test_serving_chat.py

Lines changed: 10 additions & 10 deletions
@@ -8,7 +8,8 @@
 from vllm.engine.multiprocessing.client import MQLLMEngineClient
 from vllm.entrypoints.openai.protocol import ChatCompletionRequest
 from vllm.entrypoints.openai.serving_chat import OpenAIServingChat
-from vllm.entrypoints.openai.serving_engine import BaseModelPath
+from vllm.entrypoints.openai.serving_models import (BaseModelPath,
+                                                    OpenAIServingModels)
 from vllm.transformers_utils.tokenizer import get_tokenizer
 
 MODEL_NAME = "openai-community/gpt2"
@@ -50,14 +51,13 @@ async def _async_serving_chat_init():
     engine = MockEngine()
     model_config = await engine.get_model_config()
 
+    models = OpenAIServingModels(model_config, BASE_MODEL_PATHS)
     serving_completion = OpenAIServingChat(engine,
                                            model_config,
-                                           BASE_MODEL_PATHS,
+                                           models,
                                            response_role="assistant",
                                            chat_template=CHAT_TEMPLATE,
                                            chat_template_content_format="auto",
-                                           lora_modules=None,
-                                           prompt_adapters=None,
                                            request_logger=None)
     return serving_completion
 
@@ -72,14 +72,14 @@ def test_serving_chat_should_set_correct_max_tokens():
     mock_engine.get_tokenizer.return_value = get_tokenizer(MODEL_NAME)
     mock_engine.errored = False
 
+    models = OpenAIServingModels(base_model_paths=BASE_MODEL_PATHS,
+                                 model_config=MockModelConfig())
     serving_chat = OpenAIServingChat(mock_engine,
                                      MockModelConfig(),
-                                     BASE_MODEL_PATHS,
+                                     models,
                                      response_role="assistant",
                                      chat_template=CHAT_TEMPLATE,
                                      chat_template_content_format="auto",
-                                     lora_modules=None,
-                                     prompt_adapters=None,
                                      request_logger=None)
     req = ChatCompletionRequest(
         model=MODEL_NAME,
@@ -115,14 +115,14 @@ def test_serving_chat_could_load_correct_generation_config():
     mock_engine.errored = False
 
     # Initialize the serving chat
+    models = OpenAIServingModels(base_model_paths=BASE_MODEL_PATHS,
+                                 model_config=mock_model_config)
     serving_chat = OpenAIServingChat(mock_engine,
                                      mock_model_config,
-                                     BASE_MODEL_PATHS,
+                                     models,
                                      response_role="assistant",
                                      chat_template=CHAT_TEMPLATE,
                                      chat_template_content_format="auto",
-                                     lora_modules=None,
-                                     prompt_adapters=None,
                                      request_logger=None)
     req = ChatCompletionRequest(
         model=MODEL_NAME,
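
Taken together, these test changes show the refactor's new wiring: model and adapter bookkeeping moves out of each handler's constructor into a single OpenAIServingModels instance that is built first and passed in, replacing the old base_model_paths/lora_modules/prompt_adapters arguments. A condensed sketch of that pattern, assuming the caller supplies the engine, model config, model name, and chat template (the helper name build_chat_handler is hypothetical):

from vllm.entrypoints.openai.serving_chat import OpenAIServingChat
from vllm.entrypoints.openai.serving_models import (BaseModelPath,
                                                    OpenAIServingModels)


def build_chat_handler(engine, model_config, model_name, chat_template):
    # One OpenAIServingModels instance now owns base-model and adapter state.
    base_model_paths = [BaseModelPath(name=model_name, model_path=model_name)]
    models = OpenAIServingModels(base_model_paths=base_model_paths,
                                 model_config=model_config)
    # Handlers receive the shared models object instead of separate
    # base_model_paths / lora_modules / prompt_adapters arguments.
    return OpenAIServingChat(engine,
                             model_config,
                             models,
                             response_role="assistant",
                             chat_template=chat_template,
                             chat_template_content_format="auto",
                             request_logger=None)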

tests/entrypoints/openai/test_serving_engine.py renamed to tests/entrypoints/openai/test_serving_models.py

Lines changed: 32 additions & 34 deletions
@@ -4,11 +4,11 @@
 import pytest
 
 from vllm.config import ModelConfig
-from vllm.engine.protocol import EngineClient
 from vllm.entrypoints.openai.protocol import (ErrorResponse,
                                               LoadLoraAdapterRequest,
                                               UnloadLoraAdapterRequest)
-from vllm.entrypoints.openai.serving_engine import BaseModelPath, OpenAIServing
+from vllm.entrypoints.openai.serving_models import (BaseModelPath,
+                                                    OpenAIServingModels)
 from vllm.lora.request import LoRARequest
 
 MODEL_NAME = "meta-llama/Llama-2-7b"
@@ -19,101 +19,99 @@
     "Success: LoRA adapter '{lora_name}' removed successfully.")
 
 
-async def _async_serving_engine_init():
-    mock_engine_client = MagicMock(spec=EngineClient)
+async def _async_serving_models_init() -> OpenAIServingModels:
     mock_model_config = MagicMock(spec=ModelConfig)
     # Set the max_model_len attribute to avoid missing attribute
     mock_model_config.max_model_len = 2048
 
-    serving_engine = OpenAIServing(mock_engine_client,
-                                   mock_model_config,
-                                   BASE_MODEL_PATHS,
-                                   lora_modules=None,
-                                   prompt_adapters=None,
-                                   request_logger=None)
-    return serving_engine
+    serving_models = OpenAIServingModels(base_model_paths=BASE_MODEL_PATHS,
+                                         model_config=mock_model_config,
+                                         lora_modules=None,
+                                         prompt_adapters=None)
+
+    return serving_models
 
 
 @pytest.mark.asyncio
 async def test_serving_model_name():
-    serving_engine = await _async_serving_engine_init()
-    assert serving_engine._get_model_name(None) == MODEL_NAME
+    serving_models = await _async_serving_models_init()
+    assert serving_models.model_name(None) == MODEL_NAME
     request = LoRARequest(lora_name="adapter",
                           lora_path="/path/to/adapter2",
                           lora_int_id=1)
-    assert serving_engine._get_model_name(request) == request.lora_name
+    assert serving_models.model_name(request) == request.lora_name
 
 
 @pytest.mark.asyncio
 async def test_load_lora_adapter_success():
-    serving_engine = await _async_serving_engine_init()
+    serving_models = await _async_serving_models_init()
    request = LoadLoraAdapterRequest(lora_name="adapter",
                                      lora_path="/path/to/adapter2")
-    response = await serving_engine.load_lora_adapter(request)
+    response = await serving_models.load_lora_adapter(request)
     assert response == LORA_LOADING_SUCCESS_MESSAGE.format(lora_name='adapter')
-    assert len(serving_engine.lora_requests) == 1
-    assert serving_engine.lora_requests[0].lora_name == "adapter"
+    assert len(serving_models.lora_requests) == 1
+    assert serving_models.lora_requests[0].lora_name == "adapter"
 
 
 @pytest.mark.asyncio
 async def test_load_lora_adapter_missing_fields():
-    serving_engine = await _async_serving_engine_init()
+    serving_models = await _async_serving_models_init()
     request = LoadLoraAdapterRequest(lora_name="", lora_path="")
-    response = await serving_engine.load_lora_adapter(request)
+    response = await serving_models.load_lora_adapter(request)
     assert isinstance(response, ErrorResponse)
     assert response.type == "InvalidUserInput"
     assert response.code == HTTPStatus.BAD_REQUEST
 
 
 @pytest.mark.asyncio
 async def test_load_lora_adapter_duplicate():
-    serving_engine = await _async_serving_engine_init()
+    serving_models = await _async_serving_models_init()
     request = LoadLoraAdapterRequest(lora_name="adapter1",
                                      lora_path="/path/to/adapter1")
-    response = await serving_engine.load_lora_adapter(request)
+    response = await serving_models.load_lora_adapter(request)
     assert response == LORA_LOADING_SUCCESS_MESSAGE.format(
         lora_name='adapter1')
-    assert len(serving_engine.lora_requests) == 1
+    assert len(serving_models.lora_requests) == 1
 
     request = LoadLoraAdapterRequest(lora_name="adapter1",
                                      lora_path="/path/to/adapter1")
-    response = await serving_engine.load_lora_adapter(request)
+    response = await serving_models.load_lora_adapter(request)
     assert isinstance(response, ErrorResponse)
     assert response.type == "InvalidUserInput"
     assert response.code == HTTPStatus.BAD_REQUEST
-    assert len(serving_engine.lora_requests) == 1
+    assert len(serving_models.lora_requests) == 1
 
 
 @pytest.mark.asyncio
 async def test_unload_lora_adapter_success():
-    serving_engine = await _async_serving_engine_init()
+    serving_models = await _async_serving_models_init()
     request = LoadLoraAdapterRequest(lora_name="adapter1",
                                      lora_path="/path/to/adapter1")
-    response = await serving_engine.load_lora_adapter(request)
-    assert len(serving_engine.lora_requests) == 1
+    response = await serving_models.load_lora_adapter(request)
+    assert len(serving_models.lora_requests) == 1
 
     request = UnloadLoraAdapterRequest(lora_name="adapter1")
-    response = await serving_engine.unload_lora_adapter(request)
+    response = await serving_models.unload_lora_adapter(request)
     assert response == LORA_UNLOADING_SUCCESS_MESSAGE.format(
         lora_name='adapter1')
-    assert len(serving_engine.lora_requests) == 0
+    assert len(serving_models.lora_requests) == 0
 
 
 @pytest.mark.asyncio
 async def test_unload_lora_adapter_missing_fields():
-    serving_engine = await _async_serving_engine_init()
+    serving_models = await _async_serving_models_init()
     request = UnloadLoraAdapterRequest(lora_name="", lora_int_id=None)
-    response = await serving_engine.unload_lora_adapter(request)
+    response = await serving_models.unload_lora_adapter(request)
     assert isinstance(response, ErrorResponse)
     assert response.type == "InvalidUserInput"
     assert response.code == HTTPStatus.BAD_REQUEST
 
 
 @pytest.mark.asyncio
 async def test_unload_lora_adapter_not_found():
-    serving_engine = await _async_serving_engine_init()
+    serving_models = await _async_serving_models_init()
     request = UnloadLoraAdapterRequest(lora_name="nonexistent_adapter")
-    response = await serving_engine.unload_lora_adapter(request)
+    response = await serving_models.unload_lora_adapter(request)
     assert isinstance(response, ErrorResponse)
     assert response.type == "InvalidUserInput"
     assert response.code == HTTPStatus.BAD_REQUEST
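
The renamed module now exercises the adapter lifecycle directly on OpenAIServingModels rather than on OpenAIServing, and no longer needs an EngineClient mock. A minimal standalone sketch of that lifecycle, using the same mocked ModelConfig as the tests above (the adapter name and path are placeholders, and whether a given path loads cleanly depends on your deployment):

import asyncio
from unittest.mock import MagicMock

from vllm.config import ModelConfig
from vllm.entrypoints.openai.protocol import (LoadLoraAdapterRequest,
                                              UnloadLoraAdapterRequest)
from vllm.entrypoints.openai.serving_models import (BaseModelPath,
                                                    OpenAIServingModels)


async def lora_lifecycle_demo() -> None:
    mock_model_config = MagicMock(spec=ModelConfig)
    mock_model_config.max_model_len = 2048  # required by the serving layer
    models = OpenAIServingModels(
        base_model_paths=[BaseModelPath(name="base", model_path="base")],
        model_config=mock_model_config)

    # Load an adapter; success returns a message, failure an ErrorResponse.
    await models.load_lora_adapter(
        LoadLoraAdapterRequest(lora_name="adapter",
                               lora_path="/path/to/adapter"))
    # Loaded adapters are tracked on the instance and resolve their own name.
    assert models.model_name(models.lora_requests[0]) == "adapter"

    # Unloading removes the adapter from the tracked requests.
    await models.unload_lora_adapter(
        UnloadLoraAdapterRequest(lora_name="adapter"))
    assert len(models.lora_requests) == 0


asyncio.run(lora_lifecycle_demo())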
