@@ -4,11 +4,11 @@
 import pytest
 
 from vllm.config import ModelConfig
-from vllm.engine.protocol import EngineClient
 from vllm.entrypoints.openai.protocol import (ErrorResponse,
                                               LoadLoraAdapterRequest,
                                               UnloadLoraAdapterRequest)
-from vllm.entrypoints.openai.serving_engine import BaseModelPath, OpenAIServing
+from vllm.entrypoints.openai.serving_models import (BaseModelPath,
+                                                    OpenAIServingModels)
 from vllm.lora.request import LoRARequest
 
 MODEL_NAME = "meta-llama/Llama-2-7b"
@@ -19,101 +19,99 @@
     "Success: LoRA adapter '{lora_name}' removed successfully.")
 
 
-async def _async_serving_engine_init():
-    mock_engine_client = MagicMock(spec=EngineClient)
+async def _async_serving_models_init() -> OpenAIServingModels:
     mock_model_config = MagicMock(spec=ModelConfig)
     # Set the max_model_len attribute to avoid missing attribute
     mock_model_config.max_model_len = 2048
 
-    serving_engine = OpenAIServing(mock_engine_client,
-                                   mock_model_config,
-                                   BASE_MODEL_PATHS,
-                                   lora_modules=None,
-                                   prompt_adapters=None,
-                                   request_logger=None)
-    return serving_engine
+    serving_models = OpenAIServingModels(base_model_paths=BASE_MODEL_PATHS,
+                                         model_config=mock_model_config,
+                                         lora_modules=None,
+                                         prompt_adapters=None)
+
+    return serving_models
 
 
 @pytest.mark.asyncio
 async def test_serving_model_name():
-    serving_engine = await _async_serving_engine_init()
-    assert serving_engine._get_model_name(None) == MODEL_NAME
+    serving_models = await _async_serving_models_init()
+    assert serving_models.model_name(None) == MODEL_NAME
     request = LoRARequest(lora_name="adapter",
                           lora_path="/path/to/adapter2",
                           lora_int_id=1)
-    assert serving_engine._get_model_name(request) == request.lora_name
+    assert serving_models.model_name(request) == request.lora_name
 
 
 @pytest.mark.asyncio
 async def test_load_lora_adapter_success():
-    serving_engine = await _async_serving_engine_init()
+    serving_models = await _async_serving_models_init()
     request = LoadLoraAdapterRequest(lora_name="adapter",
                                      lora_path="/path/to/adapter2")
-    response = await serving_engine.load_lora_adapter(request)
+    response = await serving_models.load_lora_adapter(request)
     assert response == LORA_LOADING_SUCCESS_MESSAGE.format(lora_name='adapter')
-    assert len(serving_engine.lora_requests) == 1
-    assert serving_engine.lora_requests[0].lora_name == "adapter"
+    assert len(serving_models.lora_requests) == 1
+    assert serving_models.lora_requests[0].lora_name == "adapter"
 
 
 @pytest.mark.asyncio
 async def test_load_lora_adapter_missing_fields():
-    serving_engine = await _async_serving_engine_init()
+    serving_models = await _async_serving_models_init()
     request = LoadLoraAdapterRequest(lora_name="", lora_path="")
-    response = await serving_engine.load_lora_adapter(request)
+    response = await serving_models.load_lora_adapter(request)
     assert isinstance(response, ErrorResponse)
     assert response.type == "InvalidUserInput"
     assert response.code == HTTPStatus.BAD_REQUEST
 
 
 @pytest.mark.asyncio
 async def test_load_lora_adapter_duplicate():
-    serving_engine = await _async_serving_engine_init()
+    serving_models = await _async_serving_models_init()
     request = LoadLoraAdapterRequest(lora_name="adapter1",
                                      lora_path="/path/to/adapter1")
-    response = await serving_engine.load_lora_adapter(request)
+    response = await serving_models.load_lora_adapter(request)
     assert response == LORA_LOADING_SUCCESS_MESSAGE.format(
         lora_name='adapter1')
-    assert len(serving_engine.lora_requests) == 1
+    assert len(serving_models.lora_requests) == 1
 
     request = LoadLoraAdapterRequest(lora_name="adapter1",
                                      lora_path="/path/to/adapter1")
-    response = await serving_engine.load_lora_adapter(request)
+    response = await serving_models.load_lora_adapter(request)
     assert isinstance(response, ErrorResponse)
     assert response.type == "InvalidUserInput"
     assert response.code == HTTPStatus.BAD_REQUEST
-    assert len(serving_engine.lora_requests) == 1
+    assert len(serving_models.lora_requests) == 1
 
 
 @pytest.mark.asyncio
 async def test_unload_lora_adapter_success():
-    serving_engine = await _async_serving_engine_init()
+    serving_models = await _async_serving_models_init()
    request = LoadLoraAdapterRequest(lora_name="adapter1",
                                      lora_path="/path/to/adapter1")
-    response = await serving_engine.load_lora_adapter(request)
-    assert len(serving_engine.lora_requests) == 1
+    response = await serving_models.load_lora_adapter(request)
+    assert len(serving_models.lora_requests) == 1
 
     request = UnloadLoraAdapterRequest(lora_name="adapter1")
-    response = await serving_engine.unload_lora_adapter(request)
+    response = await serving_models.unload_lora_adapter(request)
     assert response == LORA_UNLOADING_SUCCESS_MESSAGE.format(
         lora_name='adapter1')
-    assert len(serving_engine.lora_requests) == 0
+    assert len(serving_models.lora_requests) == 0
 
 
 @pytest.mark.asyncio
 async def test_unload_lora_adapter_missing_fields():
-    serving_engine = await _async_serving_engine_init()
+    serving_models = await _async_serving_models_init()
     request = UnloadLoraAdapterRequest(lora_name="", lora_int_id=None)
-    response = await serving_engine.unload_lora_adapter(request)
+    response = await serving_models.unload_lora_adapter(request)
     assert isinstance(response, ErrorResponse)
     assert response.type == "InvalidUserInput"
     assert response.code == HTTPStatus.BAD_REQUEST
 
 
 @pytest.mark.asyncio
 async def test_unload_lora_adapter_not_found():
-    serving_engine = await _async_serving_engine_init()
+    serving_models = await _async_serving_models_init()
     request = UnloadLoraAdapterRequest(lora_name="nonexistent_adapter")
-    response = await serving_engine.unload_lora_adapter(request)
+    response = await serving_models.unload_lora_adapter(request)
     assert isinstance(response, ErrorResponse)
     assert response.type == "InvalidUserInput"
     assert response.code == HTTPStatus.BAD_REQUEST
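Side note: the diff moves the test helper from `OpenAIServing` to the new `OpenAIServingModels` class and its `model_name` / `load_lora_adapter` / `unload_lora_adapter` methods. The sketch below drives that same API directly with `asyncio` rather than pytest, purely as an illustration; it assumes `BaseModelPath` exposes `name` and `model_path` fields (the `BASE_MODEL_PATHS` constant itself is not shown in this hunk) and that adapter load/unload only bookkeeps `LoRARequest` entries, so the fake adapter path is never read from disk.

```python
# Rough sketch, not part of the PR: exercising OpenAIServingModels directly.
import asyncio
from unittest.mock import MagicMock

from vllm.config import ModelConfig
from vllm.entrypoints.openai.protocol import (LoadLoraAdapterRequest,
                                              UnloadLoraAdapterRequest)
from vllm.entrypoints.openai.serving_models import (BaseModelPath,
                                                    OpenAIServingModels)


async def main() -> None:
    # Mocked model config, mirroring the test setup above.
    model_config = MagicMock(spec=ModelConfig)
    model_config.max_model_len = 2048

    # Assumption: BaseModelPath takes `name` and `model_path` fields.
    serving_models = OpenAIServingModels(
        base_model_paths=[
            BaseModelPath(name="meta-llama/Llama-2-7b",
                          model_path="meta-llama/Llama-2-7b")
        ],
        model_config=model_config,
        lora_modules=None,
        prompt_adapters=None)

    # Round-trip a LoRA adapter registration; the path is never opened here.
    response = await serving_models.load_lora_adapter(
        LoadLoraAdapterRequest(lora_name="adapter",
                               lora_path="/path/to/adapter"))
    print(response)

    response = await serving_models.unload_lora_adapter(
        UnloadLoraAdapterRequest(lora_name="adapter"))
    print(response)


asyncio.run(main())
```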