diff --git a/src/app/endpoints/tools.py b/src/app/endpoints/tools.py new file mode 100644 index 00000000..31eba4fd --- /dev/null +++ b/src/app/endpoints/tools.py @@ -0,0 +1,190 @@ +"""Handler for REST API call to list available tools from MCP servers.""" + +import logging +from typing import Annotated, Any + +from fastapi import APIRouter, Depends, HTTPException, Request, status +from llama_stack_client import APIConnectionError + +from authentication import get_auth_dependency +from authentication.interface import AuthTuple +from authorization.middleware import authorize +from client import AsyncLlamaStackClientHolder +from configuration import configuration +from models.config import Action +from models.responses import ToolsResponse +from utils.endpoints import check_configuration_loaded +from utils.tool_formatter import format_tools_list + +logger = logging.getLogger(__name__) +router = APIRouter(tags=["tools"]) + + +tools_responses: dict[int | str, dict[str, Any]] = { + 200: { + "description": "Successful Response", + "content": { + "application/json": { + "example": { + "tools": [ + { + "identifier": "", + "description": "", + "parameters": [ + { + "name": "", + "description": "", + "parameter_type": "", + "required": "True/False", + "default": "null", + } + ], + "provider_id": "", + "toolgroup_id": "", + "server_source": "", + "type": "tool", + } + ] + } + } + }, + }, + 500: {"description": "Connection to Llama Stack is broken or MCP server error"}, +} + + +@router.get("/tools", responses=tools_responses) +@authorize(Action.GET_TOOLS) +async def tools_endpoint_handler( + request: Request, + auth: Annotated[AuthTuple, Depends(get_auth_dependency())], +) -> ToolsResponse: + """ + Handle requests to the /tools endpoint. + + Process GET requests to the /tools endpoint, returning a consolidated list of + available tools from all configured MCP servers. + + Raises: + HTTPException: If unable to connect to the Llama Stack server or if + tool retrieval fails for any reason. + + Returns: + ToolsResponse: An object containing the consolidated list of available tools + with metadata including tool name, description, parameters, and server source. + """ + # Used only by the middleware + _ = auth + + # Nothing interesting in the request + _ = request + + check_configuration_loaded(configuration) + + try: + # Get Llama Stack client + client = AsyncLlamaStackClientHolder().get_client() + + consolidated_tools = [] + mcp_server_names = ( + {mcp_server.name for mcp_server in configuration.mcp_servers} + if configuration.mcp_servers + else set() + ) + + # Get all available toolgroups + try: + logger.debug("Retrieving tools from all toolgroups") + toolgroups_response = await client.toolgroups.list() + + for toolgroup in toolgroups_response: + try: + # Get tools for each toolgroup + tools_response = await client.tools.list( + toolgroup_id=toolgroup.identifier + ) + + # Convert tools to dict format + tools_count = 0 + server_source = "unknown" + + for tool in tools_response: + tool_dict = dict(tool) + + # Determine server source based on toolgroup type + if toolgroup.identifier in mcp_server_names: + # This is an MCP server toolgroup + mcp_server = next( + ( + s + for s in configuration.mcp_servers + if s.name == toolgroup.identifier + ), + None, + ) + tool_dict["server_source"] = ( + mcp_server.url if mcp_server else toolgroup.identifier + ) + else: + # This is a built-in toolgroup + tool_dict["server_source"] = "builtin" + + consolidated_tools.append(tool_dict) + tools_count += 1 + server_source = tool_dict["server_source"] + + logger.debug( + "Retrieved %d tools from toolgroup %s (source: %s)", + tools_count, + toolgroup.identifier, + server_source, + ) + + except Exception as e: # pylint: disable=broad-exception-caught + # Catch any exception from individual toolgroup failures to allow + # processing of other toolgroups to continue (partial failure scenario) + logger.warning( + "Failed to retrieve tools from toolgroup %s: %s", + toolgroup.identifier, + e, + ) + continue + + except APIConnectionError as e: + logger.warning("Failed to retrieve tools from toolgroups: %s", e) + raise + except (ValueError, AttributeError) as e: + logger.warning("Failed to retrieve tools from toolgroups: %s", e) + + logger.info( + "Retrieved total of %d tools (%d from built-in toolgroups, %d from MCP servers)", + len(consolidated_tools), + len([t for t in consolidated_tools if t.get("server_source") == "builtin"]), + len([t for t in consolidated_tools if t.get("server_source") != "builtin"]), + ) + + # Format tools with structured description parsing + formatted_tools = format_tools_list(consolidated_tools) + + return ToolsResponse(tools=formatted_tools) + + # Connection to Llama Stack server + except APIConnectionError as e: + logger.error("Unable to connect to Llama Stack: %s", e) + raise HTTPException( + status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, + detail={ + "response": "Unable to connect to Llama Stack", + "cause": str(e), + }, + ) from e + # Any other exception that can occur during tool listing + except Exception as e: + logger.error("Unable to retrieve list of tools: %s", e) + raise HTTPException( + status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, + detail={ + "response": "Unable to retrieve list of tools", + "cause": str(e), + }, + ) from e diff --git a/src/app/routers.py b/src/app/routers.py index 42606cea..7cd98203 100644 --- a/src/app/routers.py +++ b/src/app/routers.py @@ -16,6 +16,7 @@ conversations, conversations_v2, metrics, + tools, ) @@ -28,6 +29,7 @@ def include_routers(app: FastAPI) -> None: app.include_router(root.router) app.include_router(info.router, prefix="/v1") app.include_router(models.router, prefix="/v1") + app.include_router(tools.router, prefix="/v1") app.include_router(shields.router, prefix="/v1") app.include_router(query.router, prefix="/v1") app.include_router(streaming_query.router, prefix="/v1") diff --git a/src/models/config.py b/src/models/config.py index 99850a50..a09e055a 100644 --- a/src/models/config.py +++ b/src/models/config.py @@ -358,6 +358,7 @@ class Action(str, Enum): DELETE_CONVERSATION = "delete_conversation" FEEDBACK = "feedback" GET_MODELS = "get_models" + GET_TOOLS = "get_tools" GET_SHIELDS = "get_shields" GET_METRICS = "get_metrics" GET_CONFIG = "get_config" diff --git a/src/models/responses.py b/src/models/responses.py index 09db643f..80b97053 100644 --- a/src/models/responses.py +++ b/src/models/responses.py @@ -36,6 +36,37 @@ class ModelsResponse(BaseModel): ) +class ToolsResponse(BaseModel): + """Model representing a response to tools request.""" + + tools: list[dict[str, Any]] = Field( + description=( + "List of tools available from all configured MCP servers and built-in toolgroups" + ), + examples=[ + [ + { + "identifier": "filesystem_read", + "description": "Read contents of a file from the filesystem", + "parameters": [ + { + "name": "path", + "description": "Path to the file to read", + "parameter_type": "string", + "required": True, + "default": None, + } + ], + "provider_id": "model-context-protocol", + "toolgroup_id": "filesystem-tools", + "server_source": "http://localhost:3000", + "type": "tool", + } + ] + ], + ) + + class ShieldsResponse(BaseModel): """Model representing a response to shields request.""" diff --git a/src/utils/tool_formatter.py b/src/utils/tool_formatter.py new file mode 100644 index 00000000..4942c3d8 --- /dev/null +++ b/src/utils/tool_formatter.py @@ -0,0 +1,108 @@ +"""Utility functions for formatting and parsing MCP tool descriptions.""" + +import logging +from typing import Any + +logger = logging.getLogger(__name__) + + +def format_tool_response(tool_dict: dict[str, Any]) -> dict[str, Any]: + """ + Format a tool dictionary to include only required fields. + + Args: + tool_dict: Raw tool dictionary from Llama Stack + + Returns: + Formatted tool dictionary with only required fields + """ + # Clean up description if it contains structured metadata + description = tool_dict.get("description", "") + if description and ("TOOL_NAME=" in description or "DISPLAY_NAME=" in description): + # Extract clean description from structured metadata + clean_description = extract_clean_description(description) + description = clean_description + + # Extract only the required fields + formatted_tool = { + "identifier": tool_dict.get("identifier", ""), + "description": description, + "parameters": tool_dict.get("parameters", []), + "provider_id": tool_dict.get("provider_id", ""), + "toolgroup_id": tool_dict.get("toolgroup_id", ""), + "server_source": tool_dict.get("server_source", ""), + "type": tool_dict.get("type", ""), + } + + return formatted_tool + + +def extract_clean_description(description: str) -> str: + """ + Extract a clean description from structured metadata format. + + Args: + description: Raw description with structured metadata + + Returns: + Clean description without metadata + """ + min_description_length = 20 + fallback_truncation_length = 200 + + try: + # Look for the main description after all the metadata + description_parts = description.split("\n\n") + for part in description_parts: + if not any( + part.strip().startswith(prefix) + for prefix in [ + "TOOL_NAME=", + "DISPLAY_NAME=", + "USECASE=", + "INSTRUCTIONS=", + "INPUT_DESCRIPTION=", + "OUTPUT_DESCRIPTION=", + "EXAMPLES=", + "PREREQUISITES=", + "AGENT_DECISION_CRITERIA=", + ] + ): + if ( + part.strip() and len(part.strip()) > min_description_length + ): # Reasonable description length + return part.strip() + + # If no clean description found, try to extract from USECASE + lines = description.split("\n") + for line in lines: + if line.startswith("USECASE="): + return line.replace("USECASE=", "").strip() + + # Fallback to first 200 characters + return ( + description[:fallback_truncation_length] + "..." + if len(description) > fallback_truncation_length + else description + ) + + except (ValueError, AttributeError) as e: + logger.warning("Failed to extract clean description: %s", e) + return ( + description[:fallback_truncation_length] + "..." + if len(description) > fallback_truncation_length + else description + ) + + +def format_tools_list(tools: list[dict[str, Any]]) -> list[dict[str, Any]]: + """ + Format a list of tools with structured description parsing. + + Args: + tools: List of raw tool dictionaries + + Returns: + List of formatted tool dictionaries + """ + return [format_tool_response(tool) for tool in tools] diff --git a/tests/unit/app/endpoints/test_tools.py b/tests/unit/app/endpoints/test_tools.py new file mode 100644 index 00000000..bcf4ddc9 --- /dev/null +++ b/tests/unit/app/endpoints/test_tools.py @@ -0,0 +1,543 @@ +"""Unit tests for tools endpoint.""" + +import pytest +from fastapi import HTTPException + +from llama_stack_client import APIConnectionError + +# Import the function directly to bypass decorators +from app.endpoints import tools +from models.responses import ToolsResponse +from models.config import ( + Configuration, + ServiceConfiguration, + LlamaStackConfiguration, + UserDataCollection, + ModelContextProtocolServer, +) + +# Shared mock auth tuple with 4 fields as expected by the application +MOCK_AUTH = ("mock_user_id", "mock_username", False, "mock_token") + + +@pytest.fixture +def mock_configuration(): + """Create a mock configuration with MCP servers.""" + return Configuration( + name="test", + service=ServiceConfiguration(), + llama_stack=LlamaStackConfiguration(url="http://localhost:8321"), + user_data_collection=UserDataCollection(feedback_enabled=False), + mcp_servers=[ + ModelContextProtocolServer( + name="filesystem-tools", + provider_id="model-context-protocol", + url="http://localhost:3000", + ), + ModelContextProtocolServer( + name="git-tools", + provider_id="model-context-protocol", + url="http://localhost:3001", + ), + ], + ) + + +@pytest.fixture +def mock_tools_response(mocker): + """Create mock tools response from LlamaStack client.""" + # Create mock tools that behave like dict when converted + tool1 = mocker.Mock() + tool1.__dict__.update( + { + "identifier": "filesystem_read", + "description": "Read contents of a file from the filesystem", + "parameters": [ + { + "name": "path", + "description": "Path to the file to read", + "parameter_type": "string", + "required": True, + "default": None, + } + ], + "provider_id": "model-context-protocol", + "toolgroup_id": "filesystem-tools", + "type": "tool", + "metadata": {}, + } + ) + # Make dict() work on the mock + tool1.keys.return_value = tool1.__dict__.keys() + tool1.__getitem__ = lambda self, key: self.__dict__[key] + tool1.__iter__ = lambda self: iter(self.__dict__) + + tool2 = mocker.Mock() + tool2.__dict__.update( + { + "identifier": "git_status", + "description": "Get the status of a git repository", + "parameters": [ + { + "name": "repository_path", + "description": "Path to the git repository", + "parameter_type": "string", + "required": True, + "default": None, + } + ], + "provider_id": "model-context-protocol", + "toolgroup_id": "git-tools", + "type": "tool", + "metadata": {}, + } + ) + # Make dict() work on the mock + tool2.keys.return_value = tool2.__dict__.keys() + tool2.__getitem__ = lambda self, key: self.__dict__[key] + tool2.__iter__ = lambda self: iter(self.__dict__) + + return [tool1, tool2] + + +@pytest.mark.asyncio +async def test_tools_endpoint_success( + mocker, mock_configuration, mock_tools_response +): # pylint: disable=redefined-outer-name + """Test successful tools endpoint response.""" + # Mock configuration + mocker.patch("app.endpoints.tools.configuration", mock_configuration) + + # Mock authorization decorator to bypass i + mocker.patch("app.endpoints.tools.authorize", lambda action: lambda func: func) + + # Mock client holder and clien + mock_client_holder = mocker.patch("app.endpoints.tools.AsyncLlamaStackClientHolder") + mock_client = mocker.AsyncMock() + mock_client_holder.return_value.get_client.return_value = mock_client + + # Mock toolgroups.list response + mock_toolgroup1 = mocker.Mock() + mock_toolgroup1.identifier = "filesystem-tools" + mock_toolgroup2 = mocker.Mock() + mock_toolgroup2.identifier = "git-tools" + mock_client.toolgroups.list.return_value = [mock_toolgroup1, mock_toolgroup2] + + # Mock tools.list responses for each MCP server + mock_client.tools.list.side_effect = [ + [mock_tools_response[0]], # filesystem-tools response + [mock_tools_response[1]], # git-tools response + ] + + # Mock request and auth + mock_request = mocker.Mock() + mock_auth = MOCK_AUTH + + # Call the endpoint + response = await tools.tools_endpoint_handler.__wrapped__(mock_request, mock_auth) + + # Verify response + assert isinstance(response, ToolsResponse) + assert len(response.tools) == 2 + + # Verify first tool + tool1 = response.tools[0] + assert tool1["identifier"] == "filesystem_read" + assert tool1["description"] == "Read contents of a file from the filesystem" + assert tool1["server_source"] == "http://localhost:3000" + assert tool1["toolgroup_id"] == "filesystem-tools" + + # Verify second tool + tool2 = response.tools[1] + assert tool2["identifier"] == "git_status" + assert tool2["description"] == "Get the status of a git repository" + assert tool2["server_source"] == "http://localhost:3001" + assert tool2["toolgroup_id"] == "git-tools" + + # Verify client calls + assert mock_client.tools.list.call_count == 2 + mock_client.tools.list.assert_any_call(toolgroup_id="filesystem-tools") + mock_client.tools.list.assert_any_call(toolgroup_id="git-tools") + + +@pytest.mark.asyncio +async def test_tools_endpoint_no_mcp_servers(mocker): + """Test tools endpoint with no MCP servers configured.""" + # Mock configuration with no MCP servers + mock_config = Configuration( + name="test", + service=ServiceConfiguration(), + llama_stack=LlamaStackConfiguration(url="http://localhost:8321"), + user_data_collection=UserDataCollection(feedback_enabled=False), + mcp_servers=[], + ) + mocker.patch("app.endpoints.tools.configuration", mock_config) + + # Mock authorization decorator to bypass i + mocker.patch("app.endpoints.tools.authorize", lambda action: lambda func: func) + + # Mock client holder and clien + mock_client_holder = mocker.patch("app.endpoints.tools.AsyncLlamaStackClientHolder") + mock_client = mocker.AsyncMock() + mock_client_holder.return_value.get_client.return_value = mock_client + + # Mock toolgroups.list response - empty for no MCP servers + mock_client.toolgroups.list.return_value = [] + + # Mock request and auth + mock_request = mocker.Mock() + mock_auth = MOCK_AUTH + + # Call the endpoint + response = await tools.tools_endpoint_handler.__wrapped__(mock_request, mock_auth) + + # Verify response + assert isinstance(response, ToolsResponse) + assert len(response.tools) == 0 + + +@pytest.mark.asyncio +async def test_tools_endpoint_api_connection_error( + mocker, mock_configuration +): # pylint: disable=redefined-outer-name + """Test tools endpoint with API connection error from individual servers.""" + # Mock configuration + mocker.patch("app.endpoints.tools.configuration", mock_configuration) + + # Mock authorization decorator to bypass i + mocker.patch("app.endpoints.tools.authorize", lambda action: lambda func: func) + + # Mock client holder and clien + mock_client_holder = mocker.patch("app.endpoints.tools.AsyncLlamaStackClientHolder") + mock_client = mocker.AsyncMock() + mock_client_holder.return_value.get_client.return_value = mock_client + + # Mock toolgroups.list response + mock_toolgroup1 = mocker.Mock() + mock_toolgroup1.identifier = "filesystem-tools" + mock_toolgroup2 = mocker.Mock() + mock_toolgroup2.identifier = "git-tools" + mock_client.toolgroups.list.return_value = [mock_toolgroup1, mock_toolgroup2] + + # Mock API connection error - create a proper APIConnectionError + api_error = APIConnectionError(request=mocker.Mock()) + mock_client.tools.list.side_effect = api_error + + # Mock request and auth + mock_request = mocker.Mock() + mock_auth = MOCK_AUTH + + # Call the endpointt - should not raise exception but return empty tools + response = await tools.tools_endpoint_handler.__wrapped__(mock_request, mock_auth) + + # Verify response - should be empty since all servers failed + assert isinstance(response, ToolsResponse) + assert len(response.tools) == 0 + + +@pytest.mark.asyncio +async def test_tools_endpoint_partial_failure( # pylint: disable=redefined-outer-name + mocker, mock_configuration, mock_tools_response +): + """Test tools endpoint with one MCP server failing.""" + # Mock configuration + mocker.patch("app.endpoints.tools.configuration", mock_configuration) + + # Mock authorization decorator to bypass i + mocker.patch("app.endpoints.tools.authorize", lambda action: lambda func: func) + + # Mock client holder and clien + mock_client_holder = mocker.patch("app.endpoints.tools.AsyncLlamaStackClientHolder") + mock_client = mocker.AsyncMock() + mock_client_holder.return_value.get_client.return_value = mock_client + + # Mock toolgroups.list response + mock_toolgroup1 = mocker.Mock() + mock_toolgroup1.identifier = "filesystem-tools" + mock_toolgroup2 = mocker.Mock() + mock_toolgroup2.identifier = "git-tools" + mock_client.toolgroups.list.return_value = [mock_toolgroup1, mock_toolgroup2] + + # Mock tools.list responses - first succeeds, second fails + mock_client.tools.list.side_effect = [ + [mock_tools_response[0]], # filesystem-tools response + Exception("Server unavailable"), # git-tools fails + ] + + # Mock request and auth + mock_request = mocker.Mock() + mock_auth = MOCK_AUTH + + # Call the endpoint + response = await tools.tools_endpoint_handler.__wrapped__(mock_request, mock_auth) + + # Verify response - should have only one tool from the successful server + assert isinstance(response, ToolsResponse) + assert len(response.tools) == 1 + assert response.tools[0]["identifier"] == "filesystem_read" + assert response.tools[0]["server_source"] == "http://localhost:3000" + + # Verify both servers were attempted + assert mock_client.tools.list.call_count == 2 + + +@pytest.mark.asyncio +async def test_tools_endpoint_builtin_toolgroup( + mocker, mock_configuration +): # pylint: disable=redefined-outer-name + """Test tools endpoint with built-in toolgroups.""" + # Mock configuration + mocker.patch("app.endpoints.tools.configuration", mock_configuration) + + # Mock authorization decorator to bypass i + mocker.patch("app.endpoints.tools.authorize", lambda action: lambda func: func) + + # Mock client holder and clien + mock_client_holder = mocker.patch("app.endpoints.tools.AsyncLlamaStackClientHolder") + mock_client = mocker.AsyncMock() + mock_client_holder.return_value.get_client.return_value = mock_client + + # Mock toolgroups.list response with built-in toolgroup + mock_toolgroup = mocker.Mock() + mock_toolgroup.identifier = "builtin-tools" # Not in MCP server names + mock_client.toolgroups.list.return_value = [mock_toolgroup] + + # Mock tools.list response for built-in toolgroup + mock_tool = mocker.Mock() + mock_tool.__dict__.update( + { + "identifier": "builtin_tool", + "description": "A built-in tool", + "parameters": [], + "provider_id": "builtin", + "toolgroup_id": "builtin-tools", + "type": "tool", + "metadata": {}, + } + ) + mock_tool.keys.return_value = mock_tool.__dict__.keys() + mock_tool.__getitem__ = lambda self, key: self.__dict__[key] + mock_tool.__iter__ = lambda self: iter(self.__dict__) + + mock_client.tools.list.return_value = [mock_tool] + + # Mock request and auth + mock_request = mocker.Mock() + mock_auth = MOCK_AUTH + + # Call the endpoint + response = await tools.tools_endpoint_handler.__wrapped__(mock_request, mock_auth) + + # Verify response + assert isinstance(response, ToolsResponse) + assert len(response.tools) == 1 + assert response.tools[0]["identifier"] == "builtin_tool" + assert response.tools[0]["server_source"] == "builtin" + + +@pytest.mark.asyncio +async def test_tools_endpoint_mixed_toolgroups(mocker): + """Test tools endpoint with both MCP and built-in toolgroups.""" + # Mock configuration with MCP servers + mock_config = Configuration( + name="test", + service=ServiceConfiguration(), + llama_stack=LlamaStackConfiguration(url="http://localhost:8321"), + user_data_collection=UserDataCollection(feedback_enabled=False), + mcp_servers=[ + ModelContextProtocolServer( + name="filesystem-tools", + provider_id="model-context-protocol", + url="http://localhost:3000", + ), + ], + ) + mocker.patch("app.endpoints.tools.configuration", mock_config) + + # Mock authorization decorator to bypass i + mocker.patch("app.endpoints.tools.authorize", lambda action: lambda func: func) + + # Mock client holder and clien + mock_client_holder = mocker.patch("app.endpoints.tools.AsyncLlamaStackClientHolder") + mock_client = mocker.AsyncMock() + mock_client_holder.return_value.get_client.return_value = mock_client + + # Mock toolgroups.list response with both MCP and built-in toolgroups + mock_toolgroup1 = mocker.Mock() + mock_toolgroup1.identifier = "filesystem-tools" # MCP server toolgroup + mock_toolgroup2 = mocker.Mock() + mock_toolgroup2.identifier = "builtin-tools" # Built-in toolgroup + mock_client.toolgroups.list.return_value = [mock_toolgroup1, mock_toolgroup2] + + # Mock tools.list responses + mock_tool1 = mocker.Mock() + mock_tool1.__dict__.update( + { + "identifier": "filesystem_read", + "description": "Read file", + "parameters": [], + "provider_id": "model-context-protocol", + "toolgroup_id": "filesystem-tools", + "type": "tool", + "metadata": {}, + } + ) + mock_tool1.keys.return_value = mock_tool1.__dict__.keys() + mock_tool1.__getitem__ = lambda self, key: self.__dict__[key] + mock_tool1.__iter__ = lambda self: iter(self.__dict__) + + mock_tool2 = mocker.Mock() + mock_tool2.__dict__.update( + { + "identifier": "builtin_tool", + "description": "Built-in tool", + "parameters": [], + "provider_id": "builtin", + "toolgroup_id": "builtin-tools", + "type": "tool", + "metadata": {}, + } + ) + mock_tool2.keys.return_value = mock_tool2.__dict__.keys() + mock_tool2.__getitem__ = lambda self, key: self.__dict__[key] + mock_tool2.__iter__ = lambda self: iter(self.__dict__) + + mock_client.tools.list.side_effect = [[mock_tool1], [mock_tool2]] + + # Mock request and auth + mock_request = mocker.Mock() + mock_auth = MOCK_AUTH + + # Call the endpoint + response = await tools.tools_endpoint_handler.__wrapped__(mock_request, mock_auth) + + # Verify response - should have both tools with correct server sources + assert isinstance(response, ToolsResponse) + assert len(response.tools) == 2 + + # Find tools by identifier to avoid order dependency + mcp_tool = next(t for t in response.tools if t["identifier"] == "filesystem_read") + builtin_tool = next(t for t in response.tools if t["identifier"] == "builtin_tool") + + assert mcp_tool["server_source"] == "http://localhost:3000" + assert builtin_tool["server_source"] == "builtin" + + +@pytest.mark.asyncio +async def test_tools_endpoint_value_attribute_error( + mocker, mock_configuration +): # pylint: disable=redefined-outer-name + """Test tools endpoint with ValueError/AttributeError in toolgroups.list.""" + # Mock configuration + mocker.patch("app.endpoints.tools.configuration", mock_configuration) + + # Mock authorization decorator to bypass i + mocker.patch("app.endpoints.tools.authorize", lambda action: lambda func: func) + + # Mock client holder and clien + mock_client_holder = mocker.patch("app.endpoints.tools.AsyncLlamaStackClientHolder") + mock_client = mocker.AsyncMock() + mock_client_holder.return_value.get_client.return_value = mock_client + + # Mock toolgroups.list to raise ValueError + mock_client.toolgroups.list.side_effect = ValueError("Invalid response format") + + # Mock request and auth + mock_request = mocker.Mock() + mock_auth = MOCK_AUTH + + # Call the endpointt - should not raise exception but return empty tools + response = await tools.tools_endpoint_handler.__wrapped__(mock_request, mock_auth) + + # Verify response - should be empty since toolgroups.list failed + assert isinstance(response, ToolsResponse) + assert len(response.tools) == 0 + + +@pytest.mark.asyncio +async def test_tools_endpoint_apiconnection_error_toolgroups( # pylint: disable=redefined-outer-name + mocker, mock_configuration +): + """Test tools endpoint with APIConnectionError in toolgroups.list.""" + # Mock configuration + mocker.patch("app.endpoints.tools.configuration", mock_configuration) + + # Mock authorization decorator to bypass i + mocker.patch("app.endpoints.tools.authorize", lambda action: lambda func: func) + + # Mock client holder and clien + mock_client_holder = mocker.patch("app.endpoints.tools.AsyncLlamaStackClientHolder") + mock_client = mocker.AsyncMock() + mock_client_holder.return_value.get_client.return_value = mock_client + + # Mock toolgroups.list to raise APIConnectionError + api_error = APIConnectionError(request=mocker.Mock()) + mock_client.toolgroups.list.side_effect = api_error + + # Mock request and auth + mock_request = mocker.Mock() + mock_auth = MOCK_AUTH + + # Call the endpointt and expect HTTPException + with pytest.raises(HTTPException) as exc_info: + await tools.tools_endpoint_handler.__wrapped__(mock_request, mock_auth) + + assert exc_info.value.status_code == 500 + assert "Unable to connect to Llama Stack" in exc_info.value.detail["response"] + + +@pytest.mark.asyncio +async def test_tools_endpoint_client_holder_apiconnection_error( # pylint: disable=redefined-outer-name + mocker, mock_configuration +): + """Test tools endpoint with APIConnectionError in client holder.""" + # Mock configuration + mocker.patch("app.endpoints.tools.configuration", mock_configuration) + + # Mock authorization decorator to bypass i + mocker.patch("app.endpoints.tools.authorize", lambda action: lambda func: func) + + # Mock client holder to raise APIConnectionError + mock_client_holder = mocker.patch("app.endpoints.tools.AsyncLlamaStackClientHolder") + api_error = APIConnectionError(request=mocker.Mock()) + mock_client_holder.return_value.get_client.side_effect = api_error + + # Mock request and auth + mock_request = mocker.Mock() + mock_auth = MOCK_AUTH + + # Call the endpointt and expect HTTPException + with pytest.raises(HTTPException) as exc_info: + await tools.tools_endpoint_handler.__wrapped__(mock_request, mock_auth) + + assert exc_info.value.status_code == 500 + assert "Unable to connect to Llama Stack" in exc_info.value.detail["response"] + + +@pytest.mark.asyncio +async def test_tools_endpoint_general_exception( + mocker, mock_configuration +): # pylint: disable=redefined-outer-name + """Test tools endpoint with general exception.""" + # Mock configuration + mocker.patch("app.endpoints.tools.configuration", mock_configuration) + + # Mock authorization decorator to bypass i + mocker.patch("app.endpoints.tools.authorize", lambda action: lambda func: func) + + # Mock client holder to raise exception + mock_client_holder = mocker.patch("app.endpoints.tools.AsyncLlamaStackClientHolder") + mock_client_holder.return_value.get_client.side_effect = Exception( + "Unexpected error" + ) + + # Mock request and auth + mock_request = mocker.Mock() + mock_auth = MOCK_AUTH + + # Call the endpointt and expect HTTPException + with pytest.raises(HTTPException) as exc_info: + await tools.tools_endpoint_handler.__wrapped__(mock_request, mock_auth) + + assert exc_info.value.status_code == 500 + assert "Unable to retrieve list of tools" in exc_info.value.detail["response"] diff --git a/tests/unit/app/test_routers.py b/tests/unit/app/test_routers.py index d1ef55a2..9b0b5520 100644 --- a/tests/unit/app/test_routers.py +++ b/tests/unit/app/test_routers.py @@ -20,6 +20,7 @@ streaming_query, authorized, metrics, + tools, ) # noqa:E402 @@ -62,10 +63,11 @@ def test_include_routers() -> None: include_routers(app) # are all routers added? - assert len(app.routers) == 13 + assert len(app.routers) == 14 assert root.router in app.get_routers() assert info.router in app.get_routers() assert models.router in app.get_routers() + assert tools.router in app.get_routers() assert shields.router in app.get_routers() assert query.router in app.get_routers() assert streaming_query.router in app.get_routers() @@ -74,6 +76,7 @@ def test_include_routers() -> None: assert health.router in app.get_routers() assert authorized.router in app.get_routers() assert conversations.router in app.get_routers() + assert conversations_v2.router in app.get_routers() assert metrics.router in app.get_routers() @@ -83,10 +86,11 @@ def test_check_prefixes() -> None: include_routers(app) # are all routers added? - assert len(app.routers) == 13 + assert len(app.routers) == 14 assert app.get_router_prefix(root.router) == "" assert app.get_router_prefix(info.router) == "/v1" assert app.get_router_prefix(models.router) == "/v1" + assert app.get_router_prefix(tools.router) == "/v1" assert app.get_router_prefix(shields.router) == "/v1" assert app.get_router_prefix(query.router) == "/v1" assert app.get_router_prefix(streaming_query.router) == "/v1"