Skip to content

Commit 22f10e3

Browse files
committed
allow disabling query model and provider
1 parent 18fdf3c commit 22f10e3

File tree

10 files changed

+196
-10
lines changed

10 files changed

+196
-10
lines changed

README.md

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -369,6 +369,10 @@ customization:
369369
disable_query_system_prompt: true
370370
```
371371

372+
### Control model/provider overrides via authorization
373+
374+
By default, clients may specify `model` and `provider` in `/v1/query` and `/v1/streaming_query`. Override is permitted only to callers granted the `MODEL_OVERRIDE` action via the authorization rules. Requests that include `model` or `provider` without this permission are rejected with HTTP 403.
375+
372376
## Safety Shields
373377

374378
A single Llama Stack configuration file can include multiple safety shields, which are utilized in agent

docs/openapi.json

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -786,7 +786,8 @@
786786
"get_models",
787787
"get_metrics",
788788
"get_config",
789-
"info"
789+
"info",
790+
"model_override"
790791
],
791792
"title": "Action",
792793
"description": "Available actions in the system."

src/app/endpoints/query.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,7 @@
3535
get_agent,
3636
get_system_prompt,
3737
validate_conversation_ownership,
38+
validate_model_provider_override,
3839
)
3940
from utils.mcp_headers import mcp_headers_dependency, handle_mcp_headers_with_toolgroups
4041
from utils.transcripts import store_transcript
@@ -174,6 +175,9 @@ async def query_endpoint_handler(
174175
"""
175176
check_configuration_loaded(configuration)
176177

178+
# Enforce RBAC: optionally disallow overriding model/provider in requests
179+
validate_model_provider_override(query_request, request.state.authorized_actions)
180+
177181
# log Llama Stack configuration, but without sensitive information
178182
llama_stack_config = configuration.llama_stack_configuration.model_copy()
179183
llama_stack_config.api_key = "********"

src/app/endpoints/streaming_query.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,7 @@
3333
from utils.mcp_headers import mcp_headers_dependency, handle_mcp_headers_with_toolgroups
3434
from utils.transcripts import store_transcript
3535
from utils.types import TurnSummary
36+
from utils.endpoints import validate_model_provider_override
3637

3738
from app.endpoints.query import (
3839
get_rag_toolgroups,
@@ -548,6 +549,9 @@ async def streaming_query_endpoint_handler( # pylint: disable=too-many-locals
548549

549550
check_configuration_loaded(configuration)
550551

552+
# Enforce RBAC: optionally disallow overriding model/provider in requests
553+
validate_model_provider_override(query_request, request.state.authorized_actions)
554+
551555
# log Llama Stack configuration, but without sensitive information
552556
llama_stack_config = configuration.llama_stack_configuration.model_copy()
553557
llama_stack_config.api_key = "********"

src/authorization/middleware.py

Lines changed: 15 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
from functools import wraps, lru_cache
55
from typing import Any, Callable, Tuple
66
from fastapi import HTTPException, status
7+
from starlette.requests import Request
78

89
from authorization.resolvers import (
910
AccessResolver,
@@ -64,7 +65,9 @@ def get_authorization_resolvers() -> Tuple[RolesResolver, AccessResolver]:
6465
)
6566

6667

67-
async def _perform_authorization_check(action: Action, kwargs: dict[str, Any]) -> None:
68+
async def _perform_authorization_check(
69+
action: Action, args: tuple[Any, ...], kwargs: dict[str, Any]
70+
) -> None:
6871
"""Perform authorization check - common logic for all decorators."""
6972
role_resolver, access_resolver = get_authorization_resolvers()
7073

@@ -93,12 +96,16 @@ async def _perform_authorization_check(action: Action, kwargs: dict[str, Any]) -
9396

9497
authorized_actions = access_resolver.get_actions(user_roles)
9598

96-
try:
97-
request = kwargs["request"]
98-
request.state.authorized_actions = authorized_actions
99-
except KeyError:
100-
# This endpoint doesn't seem care about the authorized actions, so no need to set it
101-
pass
99+
req: Request | None = None
100+
if "request" in kwargs and isinstance(kwargs["request"], Request):
101+
req = kwargs["request"]
102+
else:
103+
for arg in args:
104+
if isinstance(arg, Request):
105+
req = arg
106+
break
107+
if req is not None:
108+
req.state.authorized_actions = authorized_actions
102109

103110

104111
def authorize(action: Action) -> Callable:
@@ -107,7 +114,7 @@ def authorize(action: Action) -> Callable:
107114
def decorator(func: Callable) -> Callable:
108115
@wraps(func)
109116
async def wrapper(*args: Any, **kwargs: Any) -> Any:
110-
await _perform_authorization_check(action, kwargs)
117+
await _perform_authorization_check(action, args, kwargs)
111118
return await func(*args, **kwargs)
112119

113120
return wrapper

src/models/config.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -311,6 +311,8 @@ class Action(str, Enum):
311311
GET_CONFIG = "get_config"
312312

313313
INFO = "info"
314+
# Allow overriding model/provider via request
315+
MODEL_OVERRIDE = "model_override"
314316

315317

316318
class AccessRule(ConfigurationBase):

src/utils/endpoints.py

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
from configuration import AppConfig
1414
from utils.suid import get_suid
1515
from utils.types import GraniteToolParser
16+
from models.config import Action
1617

1718

1819
logger = logging.getLogger("utils.endpoints")
@@ -84,6 +85,29 @@ def get_system_prompt(query_request: QueryRequest, config: AppConfig) -> str:
8485
return constants.DEFAULT_SYSTEM_PROMPT
8586

8687

88+
def validate_model_provider_override(
89+
query_request: QueryRequest, authorized_actions: set[Action] | frozenset[Action]
90+
) -> None:
91+
"""Validate whether model/provider overrides are allowed by RBAC.
92+
93+
Raises HTTP 403 if the request includes model or provider and the caller
94+
lacks Action.MODEL_OVERRIDE permission.
95+
"""
96+
if (query_request.model is not None or query_request.provider is not None) and (
97+
Action.MODEL_OVERRIDE not in authorized_actions
98+
):
99+
raise HTTPException(
100+
status_code=status.HTTP_403_FORBIDDEN,
101+
detail={
102+
"response": (
103+
"This instance does not permit overriding model/provider in the query request "
104+
"(missing permission: MODEL_OVERRIDE). Please remove the model and provider "
105+
"fields from your request."
106+
)
107+
},
108+
)
109+
110+
87111
# # pylint: disable=R0913,R0917
88112
async def get_agent(
89113
client: AsyncLlamaStackClient,

tests/unit/app/endpoints/test_query.py

Lines changed: 57 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1507,3 +1507,60 @@ def test_evaluate_model_hints(
15071507

15081508
assert provider_id == expected_provider
15091509
assert model_id == expected_model
1510+
1511+
1512+
@pytest.mark.asyncio
1513+
async def test_query_endpoint_rejects_model_provider_override_without_permission(
1514+
mocker, dummy_request
1515+
):
1516+
"""Assert 403 and message when request includes model/provider without MODEL_OVERRIDE."""
1517+
# Patch endpoint configuration (no need to set customization)
1518+
cfg = AppConfig()
1519+
cfg.init_from_dict(
1520+
{
1521+
"name": "test",
1522+
"service": {
1523+
"host": "localhost",
1524+
"port": 8080,
1525+
"auth_enabled": False,
1526+
"workers": 1,
1527+
"color_log": True,
1528+
"access_log": True,
1529+
},
1530+
"llama_stack": {
1531+
"api_key": "test-key",
1532+
"url": "http://test.com:1234",
1533+
"use_as_library_client": False,
1534+
},
1535+
"user_data_collection": {"transcripts_enabled": False},
1536+
"mcp_servers": [],
1537+
}
1538+
)
1539+
mocker.patch("app.endpoints.query.configuration", cfg)
1540+
1541+
# Patch authorization to exclude MODEL_OVERRIDE from authorized actions
1542+
from authorization.resolvers import NoopRolesResolver
1543+
1544+
access_resolver = mocker.Mock()
1545+
access_resolver.check_access.return_value = True
1546+
access_resolver.get_actions.return_value = set(Action) - {Action.MODEL_OVERRIDE}
1547+
mocker.patch(
1548+
"authorization.middleware.get_authorization_resolvers",
1549+
return_value=(NoopRolesResolver(), access_resolver),
1550+
)
1551+
1552+
# Build a request that tries to override model/provider
1553+
query_request = QueryRequest(query="What?", model="m", provider="p")
1554+
1555+
with pytest.raises(HTTPException) as exc_info:
1556+
await query_endpoint_handler(
1557+
request=dummy_request, query_request=query_request, auth=MOCK_AUTH
1558+
)
1559+
1560+
expected_msg = (
1561+
"This instance does not permit overriding model/provider in the query request "
1562+
"(missing permission: MODEL_OVERRIDE). Please remove the model and provider "
1563+
"fields from your request."
1564+
)
1565+
assert exc_info.value.status_code == status.HTTP_403_FORBIDDEN
1566+
assert exc_info.value.detail["response"] == expected_msg

tests/unit/app/endpoints/test_streaming_query.py

Lines changed: 61 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -43,7 +43,7 @@
4343
)
4444

4545
from models.requests import QueryRequest, Attachment
46-
from models.config import ModelContextProtocolServer
46+
from models.config import ModelContextProtocolServer, Action
4747
from utils.types import ToolCallSummary, TurnSummary
4848

4949
MOCK_AUTH = ("mock_user_id", "mock_username", "mock_token")
@@ -1515,3 +1515,63 @@ async def test_retrieve_response_no_tools_false_preserves_functionality(
15151515
stream=True,
15161516
toolgroups=expected_toolgroups,
15171517
)
1518+
1519+
1520+
@pytest.mark.asyncio
1521+
async def test_streaming_query_endpoint_rejects_model_provider_override_without_permission(
1522+
mocker,
1523+
):
1524+
"""Assert 403 when request includes model/provider without MODEL_OVERRIDE."""
1525+
cfg = AppConfig()
1526+
cfg.init_from_dict(
1527+
{
1528+
"name": "test",
1529+
"service": {
1530+
"host": "localhost",
1531+
"port": 8080,
1532+
"auth_enabled": False,
1533+
"workers": 1,
1534+
"color_log": True,
1535+
"access_log": True,
1536+
},
1537+
"llama_stack": {
1538+
"api_key": "test-key",
1539+
"url": "http://test.com:1234",
1540+
"use_as_library_client": False,
1541+
},
1542+
"user_data_collection": {"transcripts_enabled": False},
1543+
"mcp_servers": [],
1544+
}
1545+
)
1546+
mocker.patch("app.endpoints.streaming_query.configuration", cfg)
1547+
1548+
# Patch authorization to exclude MODEL_OVERRIDE from authorized actions
1549+
from authorization.resolvers import NoopRolesResolver
1550+
1551+
access_resolver = mocker.Mock()
1552+
access_resolver.check_access.return_value = True
1553+
access_resolver.get_actions.return_value = set(Action) - {Action.MODEL_OVERRIDE}
1554+
mocker.patch(
1555+
"authorization.middleware.get_authorization_resolvers",
1556+
return_value=(NoopRolesResolver(), access_resolver),
1557+
)
1558+
1559+
# Build a query request that tries to override model/provider
1560+
query_request = QueryRequest(query="What?", model="m", provider="p")
1561+
1562+
request = Request(
1563+
scope={
1564+
"type": "http",
1565+
}
1566+
)
1567+
1568+
with pytest.raises(HTTPException) as exc_info:
1569+
await streaming_query_endpoint_handler(request, query_request, auth=MOCK_AUTH)
1570+
1571+
expected_msg = (
1572+
"This instance does not permit overriding model/provider in the query request "
1573+
"(missing permission: MODEL_OVERRIDE). Please remove the model and provider "
1574+
"fields from your request."
1575+
)
1576+
assert exc_info.value.status_code == status.HTTP_403_FORBIDDEN
1577+
assert exc_info.value.detail["response"] == expected_msg

tests/unit/utils/test_endpoints.py

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111
from models.requests import QueryRequest
1212
from utils import endpoints
1313
from utils.endpoints import get_agent
14+
from models.config import Action
1415

1516
CONFIGURED_SYSTEM_PROMPT = "This is a configured system prompt"
1617

@@ -591,3 +592,25 @@ async def test_get_agent_no_tools_false_preserves_parser(
591592
tool_parser=mock_parser,
592593
enable_session_persistence=True,
593594
)
595+
596+
597+
def test_validate_model_provider_override_allowed_with_action():
598+
"""Ensure no exception when caller has MODEL_OVERRIDE and request includes model/provider."""
599+
query_request = QueryRequest(query="q", model="m", provider="p")
600+
authorized_actions = {Action.MODEL_OVERRIDE}
601+
endpoints.validate_model_provider_override(query_request, authorized_actions)
602+
603+
604+
def test_validate_model_provider_override_rejected_without_action():
605+
"""Ensure HTTP 403 when request includes model/provider and caller lacks permission."""
606+
query_request = QueryRequest(query="q", model="m", provider="p")
607+
authorized_actions: set[Action] = set()
608+
with pytest.raises(HTTPException) as exc_info:
609+
endpoints.validate_model_provider_override(query_request, authorized_actions)
610+
assert exc_info.value.status_code == 403
611+
612+
613+
def test_validate_model_provider_override_no_override_without_action():
614+
"""No exception when request does not include model/provider regardless of permission."""
615+
query_request = QueryRequest(query="q")
616+
endpoints.validate_model_provider_override(query_request, set())

0 commit comments

Comments
 (0)