Skip to content

Commit 1c5dbf1

Browse files
committed
Allow configuring a default model/provider
This patch allows users to configure a default model/provider pair in the configuration file. Now models are selected as follows: * If no model/provider is specified in either the configuration or the request, lightspeed-stack will use the FIRST MODEL AVAILABLE from llama-stack. * If the default model/provider is specified in the configuration file and a model/provider ARE NOT PROVIDED IN THE REQUEST, lightspeed-stack will use the model/provider FROM THE CONFIGURATION FILE. * If the default model/provider is specified in the configuration file and a model/provider ARE PROVIDED IN THE REQUEST, lightspeed-stack will use the model/provider FROM THE REQUEST. tl;dr the precedence order for selecting a model is: request, configuration, first available in llama-stack. Signed-off-by: Lucas Alvares Gomes <[email protected]>
1 parent a15f073 commit 1c5dbf1

File tree

6 files changed

+203
-30
lines changed

6 files changed

+203
-30
lines changed

src/app/endpoints/query.py

Lines changed: 17 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -175,12 +175,24 @@ def select_model_and_provider_id(
175175
models: ModelListResponse, query_request: QueryRequest
176176
) -> tuple[str, str | None]:
177177
"""Select the model ID and provider ID based on the request or available models."""
178+
# If model_id and provider_id are provided in the request, use them
178179
model_id = query_request.model
179180
provider_id = query_request.provider
180181

181-
# TODO(lucasagomes): support default model selection via configuration
182-
if not model_id:
183-
logger.info("No model specified in request, using the first available LLM")
182+
# If model_id is not provided in the request, check the configuration
183+
if not model_id or not provider_id:
184+
logger.debug(
185+
"No model ID or provider ID specified in request, checking configuration"
186+
)
187+
model_id = configuration.llama_stack_configuration.default_model
188+
provider_id = configuration.llama_stack_configuration.default_provider
189+
190+
# If no model is specified in the request or configuration, use the first available LLM
191+
if not model_id or not provider_id:
192+
logger.debug(
193+
"No model ID or provider ID specified in request or configuration, "
194+
"using the first available LLM"
195+
)
184196
try:
185197
model = next(
186198
m
@@ -202,7 +214,8 @@ def select_model_and_provider_id(
202214
},
203215
) from e
204216

205-
logger.info("Searching for model: %s, provider: %s", model_id, provider_id)
217+
# Validate that the model_id and provider_id are in the available models
218+
logger.debug("Searching for model: %s, provider: %s", model_id, provider_id)
206219
if not any(
207220
m.identifier == model_id and m.provider_id == provider_id for m in models
208221
):

src/metrics/utils.py

Lines changed: 18 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,15 +1,13 @@
11
"""Utility functions for metrics handling."""
22

3+
from configuration import configuration
34
from client import LlamaStackClientHolder
45
from log import get_logger
56
import metrics
67

78
logger = get_logger(__name__)
89

910

10-
# TODO(lucasagomes): Change this metric once we are allowed to set the the
11-
# default model/provider via the configuration.The default provider/model
12-
# will be set to 1, and the rest will be set to 0.
1311
def setup_model_metrics() -> None:
1412
"""Perform setup of all metrics related to LLM model and provider."""
1513
client = LlamaStackClientHolder().get_client()
@@ -19,14 +17,29 @@ def setup_model_metrics() -> None:
1917
if model.model_type == "llm" # pyright: ignore[reportAttributeAccessIssue]
2018
]
2119

20+
default_model_label = (
21+
configuration.llama_stack_configuration.default_provider,
22+
configuration.llama_stack_configuration.default_model,
23+
)
24+
2225
for model in models:
2326
provider = model.provider_id
2427
model_name = model.identifier
2528
if provider and model_name:
29+
# If the model/provider combination is the default, set the metric value to 1
30+
# Otherwise, set it to 0
31+
default_model_value = 0
2632
label_key = (provider, model_name)
27-
metrics.provider_model_configuration.labels(*label_key).set(1)
33+
if label_key == default_model_label:
34+
default_model_value = 1
35+
36+
# Set the metric for the provider/model configuration
37+
metrics.provider_model_configuration.labels(*label_key).set(
38+
default_model_value
39+
)
2840
logger.debug(
29-
"Set provider/model configuration for %s/%s to 1",
41+
"Set provider/model configuration for %s/%s to %d",
3042
provider,
3143
model_name,
44+
default_model_value,
3245
)

src/models/config.py

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -62,6 +62,8 @@ class LlamaStackConfiguration(BaseModel):
6262
api_key: Optional[str] = None
6363
use_as_library_client: Optional[bool] = None
6464
library_client_config_path: Optional[str] = None
65+
default_model: Optional[str] = None
66+
default_provider: Optional[str] = None
6567

6668
@model_validator(mode="after")
6769
def check_llama_stack_model(self) -> Self:
@@ -100,6 +102,19 @@ def check_llama_stack_model(self) -> Self:
100102
)
101103
return self
102104

105+
@model_validator(mode="after")
106+
def check_default_model_and_provider(self) -> Self:
107+
"""Check default model and provider."""
108+
if self.default_model is None and self.default_provider is not None:
109+
raise ValueError(
110+
"Default model must be specified when default provider is set"
111+
)
112+
if self.default_model is not None and self.default_provider is None:
113+
raise ValueError(
114+
"Default provider must be specified when default model is set"
115+
)
116+
return self
117+
103118

104119
class DataCollectorConfiguration(BaseModel):
105120
"""Data collector configuration for sending data to ingress server."""

tests/unit/app/endpoints/test_query.py

Lines changed: 54 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -179,30 +179,70 @@ def test_query_endpoint_handler_store_transcript(mocker):
179179
_test_query_endpoint_handler(mocker, store_transcript_to_file=True)
180180

181181

182-
def test_select_model_and_provider_id(mocker):
182+
def test_select_model_and_provider_id_from_request(mocker):
183183
"""Test the select_model_and_provider_id function."""
184-
mock_client = mocker.Mock()
185-
mock_client.models.list.return_value = [
184+
mocker.patch(
185+
"metrics.utils.configuration.llama_stack_configuration.default_provider",
186+
"default_provider",
187+
)
188+
mocker.patch(
189+
"metrics.utils.configuration.llama_stack_configuration.default_model",
190+
"default_model",
191+
)
192+
193+
model_list = [
186194
mocker.Mock(identifier="model1", model_type="llm", provider_id="provider1"),
187195
mocker.Mock(identifier="model2", model_type="llm", provider_id="provider2"),
196+
mocker.Mock(
197+
identifier="default_model", model_type="llm", provider_id="default_provider"
198+
),
188199
]
189200

201+
# Create a query request with model and provider specified
190202
query_request = QueryRequest(
191-
query="What is OpenStack?", model="model1", provider="provider1"
203+
query="What is OpenStack?", model="model2", provider="provider2"
192204
)
193205

194-
model_id, provider_id = select_model_and_provider_id(
195-
mock_client.models.list(), query_request
206+
# Assert the model and provider from request take precedence from the configuration one
207+
model_id, provider_id = select_model_and_provider_id(model_list, query_request)
208+
209+
assert model_id == "model2"
210+
assert provider_id == "provider2"
211+
212+
213+
def test_select_model_and_provider_id_from_configuration(mocker):
214+
"""Test the select_model_and_provider_id function."""
215+
mocker.patch(
216+
"metrics.utils.configuration.llama_stack_configuration.default_provider",
217+
"default_provider",
218+
)
219+
mocker.patch(
220+
"metrics.utils.configuration.llama_stack_configuration.default_model",
221+
"default_model",
196222
)
197223

198-
assert model_id == "model1"
199-
assert provider_id == "provider1"
224+
model_list = [
225+
mocker.Mock(identifier="model1", model_type="llm", provider_id="provider1"),
226+
mocker.Mock(
227+
identifier="default_model", model_type="llm", provider_id="default_provider"
228+
),
229+
]
230+
231+
# Create a query request without model and provider specified
232+
query_request = QueryRequest(
233+
query="What is OpenStack?",
234+
)
235+
236+
model_id, provider_id = select_model_and_provider_id(model_list, query_request)
237+
238+
# Assert that the default model and provider from the configuration are returned
239+
assert model_id == "default_model"
240+
assert provider_id == "default_provider"
200241

201242

202-
def test_select_model_and_provider_id_no_model(mocker):
243+
def test_select_model_and_provider_id_first_from_list(mocker):
203244
"""Test the select_model_and_provider_id function when no model is specified."""
204-
mock_client = mocker.Mock()
205-
mock_client.models.list.return_value = [
245+
model_list = [
206246
mocker.Mock(
207247
identifier="not_llm_type", model_type="embedding", provider_id="provider1"
208248
),
@@ -216,11 +256,10 @@ def test_select_model_and_provider_id_no_model(mocker):
216256

217257
query_request = QueryRequest(query="What is OpenStack?")
218258

219-
model_id, provider_id = select_model_and_provider_id(
220-
mock_client.models.list(), query_request
221-
)
259+
model_id, provider_id = select_model_and_provider_id(model_list, query_request)
222260

223-
# Assert return the first available LLM model
261+
# Assert return the first available LLM model when no model/provider is
262+
# specified in the request or in the configuration
224263
assert model_id == "first_model"
225264
assert provider_id == "provider1"
226265

tests/unit/metrics/test_utis.py

Lines changed: 52 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -8,16 +8,62 @@ def test_setup_model_metrics(mocker):
88

99
# Mock the LlamaStackAsLibraryClient
1010
mock_client = mocker.patch("client.LlamaStackClientHolder.get_client").return_value
11+
mocker.patch(
12+
"metrics.utils.configuration.llama_stack_configuration.default_provider",
13+
"default_provider",
14+
)
15+
mocker.patch(
16+
"metrics.utils.configuration.llama_stack_configuration.default_model",
17+
"default_model",
18+
)
1119

1220
mock_metric = mocker.patch("metrics.provider_model_configuration")
13-
fake_model = mocker.Mock(
14-
provider_id="test_provider",
15-
identifier="test_model",
21+
# Mock a model that is the default
22+
model_default = mocker.Mock(
23+
provider_id="default_provider",
24+
identifier="default_model",
1625
model_type="llm",
1726
)
18-
mock_client.models.list.return_value = [fake_model]
27+
# Mock a model that is not the default
28+
model_0 = mocker.Mock(
29+
provider_id="test_provider-0",
30+
identifier="test_model-0",
31+
model_type="llm",
32+
)
33+
# Mock a second model which is not default
34+
model_1 = mocker.Mock(
35+
provider_id="test_provider-1",
36+
identifier="test_model-1",
37+
model_type="llm",
38+
)
39+
# Mock a model that is not an LLM type, should be ignored
40+
not_llm_model = mocker.Mock(
41+
provider_id="not-llm-provider",
42+
identifier="not-llm-model",
43+
model_type="not-llm",
44+
)
45+
46+
# Mock the list of models returned by the client
47+
mock_client.models.list.return_value = [
48+
model_0,
49+
model_default,
50+
not_llm_model,
51+
model_1,
52+
]
1953

2054
setup_model_metrics()
2155

22-
# Assert that the metric was set correctly
23-
mock_metric.labels("test_provider", "test_model").set.assert_called_once_with(1)
56+
# Check that the provider_model_configuration metric was set correctly
57+
# The default model should have a value of 1, others should be 0
58+
assert mock_metric.labels.call_count == 3
59+
mock_metric.assert_has_calls(
60+
[
61+
mocker.call.labels("test_provider-0", "test_model-0"),
62+
mocker.call.labels().set(0),
63+
mocker.call.labels("default_provider", "default_model"),
64+
mocker.call.labels().set(1),
65+
mocker.call.labels("test_provider-1", "test_model-1"),
66+
mocker.call.labels().set(0),
67+
],
68+
any_order=False, # Order matters here
69+
)

tests/unit/models/test_config.py

Lines changed: 47 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -85,6 +85,15 @@ def test_llama_stack_configuration_constructor() -> None:
8585
)
8686
assert llama_stack_configuration is not None
8787

88+
# Test default model and provider
89+
llama_stack_configuration = LlamaStackConfiguration(
90+
use_as_library_client=False,
91+
url="http://localhost",
92+
default_model="default_model",
93+
default_provider="default_provider",
94+
)
95+
assert llama_stack_configuration is not None
96+
8897

8998
def test_llama_stack_configuration_no_run_yaml() -> None:
9099
"""
@@ -131,6 +140,36 @@ def test_llama_stack_wrong_configuration_no_config_file() -> None:
131140
LlamaStackConfiguration(use_as_library_client=True)
132141

133142

143+
def test_llama_stack_configuration_default_model_missing() -> None:
144+
"""
145+
Test case where only default provider is set, should fail
146+
"""
147+
with pytest.raises(
148+
ValueError,
149+
match="Default model must be specified when default provider is set",
150+
):
151+
LlamaStackConfiguration(
152+
use_as_library_client=False,
153+
url="http://localhost",
154+
default_provider="default_provider",
155+
)
156+
157+
158+
def test_llama_stack_configuration_default_provider_missing() -> None:
159+
"""
160+
Test case where only default model is set, should fail
161+
"""
162+
with pytest.raises(
163+
ValueError,
164+
match="Default provider must be specified when default model is set",
165+
):
166+
LlamaStackConfiguration(
167+
use_as_library_client=False,
168+
url="http://localhost",
169+
default_model="default_model",
170+
)
171+
172+
134173
def test_user_data_collection_feedback_enabled() -> None:
135174
"""Test the UserDataCollection constructor for feedback."""
136175
# correct configuration
@@ -420,6 +459,8 @@ def test_dump_configuration(tmp_path) -> None:
420459
llama_stack=LlamaStackConfiguration(
421460
use_as_library_client=True,
422461
library_client_config_path="tests/configuration/run.yaml",
462+
default_provider="default_provider",
463+
default_model="default_model",
423464
),
424465
user_data_collection=UserDataCollection(
425466
feedback_enabled=False, feedback_storage=None
@@ -465,6 +506,8 @@ def test_dump_configuration(tmp_path) -> None:
465506
"api_key": None,
466507
"use_as_library_client": True,
467508
"library_client_config_path": "tests/configuration/run.yaml",
509+
"default_provider": "default_provider",
510+
"default_model": "default_model",
468511
},
469512
"user_data_collection": {
470513
"feedback_enabled": False,
@@ -550,6 +593,8 @@ def test_dump_configuration_with_one_mcp_server(tmp_path) -> None:
550593
"api_key": None,
551594
"use_as_library_client": True,
552595
"library_client_config_path": "tests/configuration/run.yaml",
596+
"default_provider": None,
597+
"default_model": None,
553598
},
554599
"user_data_collection": {
555600
"feedback_enabled": False,
@@ -650,6 +695,8 @@ def test_dump_configuration_with_more_mcp_servers(tmp_path) -> None:
650695
"api_key": None,
651696
"use_as_library_client": True,
652697
"library_client_config_path": "tests/configuration/run.yaml",
698+
"default_provider": None,
699+
"default_model": None,
653700
},
654701
"user_data_collection": {
655702
"feedback_enabled": False,

0 commit comments

Comments
 (0)