Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
190 changes: 190 additions & 0 deletions src/app/endpoints/tools.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,190 @@
"""Handler for REST API call to list available tools from MCP servers."""

import logging
from typing import Annotated, Any

from fastapi import APIRouter, Depends, HTTPException, Request, status
from llama_stack_client import APIConnectionError

from authentication import get_auth_dependency
from authentication.interface import AuthTuple
from authorization.middleware import authorize
from client import AsyncLlamaStackClientHolder
from configuration import configuration
from models.config import Action
from models.responses import ToolsResponse
from utils.endpoints import check_configuration_loaded
from utils.tool_formatter import format_tools_list

logger = logging.getLogger(__name__)
router = APIRouter(tags=["tools"])


tools_responses: dict[int | str, dict[str, Any]] = {
200: {
"description": "Successful Response",
"content": {
"application/json": {
"example": {
"tools": [
{
"identifier": "",
"description": "",
"parameters": [
{
"name": "",
"description": "",
"parameter_type": "",
"required": "True/False",
"default": "null",
}
],
"provider_id": "",
"toolgroup_id": "",
"server_source": "",
"type": "tool",
}
]
}
}
},
},
500: {"description": "Connection to Llama Stack is broken or MCP server error"},
}


@router.get("/tools", responses=tools_responses)
@authorize(Action.GET_TOOLS)
async def tools_endpoint_handler(
request: Request,
auth: Annotated[AuthTuple, Depends(get_auth_dependency())],
) -> ToolsResponse:
"""
Handle requests to the /tools endpoint.

Process GET requests to the /tools endpoint, returning a consolidated list of
available tools from all configured MCP servers.

Raises:
HTTPException: If unable to connect to the Llama Stack server or if
tool retrieval fails for any reason.

Returns:
ToolsResponse: An object containing the consolidated list of available tools
with metadata including tool name, description, parameters, and server source.
"""
# Used only by the middleware
_ = auth

# Nothing interesting in the request
_ = request

check_configuration_loaded(configuration)

try:
# Get Llama Stack client
client = AsyncLlamaStackClientHolder().get_client()

consolidated_tools = []
mcp_server_names = (
{mcp_server.name for mcp_server in configuration.mcp_servers}
if configuration.mcp_servers
else set()
)

# Get all available toolgroups
try:
logger.debug("Retrieving tools from all toolgroups")
toolgroups_response = await client.toolgroups.list()

for toolgroup in toolgroups_response:
try:
# Get tools for each toolgroup
tools_response = await client.tools.list(
toolgroup_id=toolgroup.identifier
)

# Convert tools to dict format
tools_count = 0
server_source = "unknown"

for tool in tools_response:
tool_dict = dict(tool)

# Determine server source based on toolgroup type
if toolgroup.identifier in mcp_server_names:
# This is an MCP server toolgroup
mcp_server = next(
(
s
for s in configuration.mcp_servers
if s.name == toolgroup.identifier
),
None,
)
tool_dict["server_source"] = (
mcp_server.url if mcp_server else toolgroup.identifier
)
else:
# This is a built-in toolgroup
tool_dict["server_source"] = "builtin"

consolidated_tools.append(tool_dict)
tools_count += 1
server_source = tool_dict["server_source"]

logger.debug(
"Retrieved %d tools from toolgroup %s (source: %s)",
tools_count,
toolgroup.identifier,
server_source,
)

except Exception as e: # pylint: disable=broad-exception-caught
# Catch any exception from individual toolgroup failures to allow
# processing of other toolgroups to continue (partial failure scenario)
logger.warning(
"Failed to retrieve tools from toolgroup %s: %s",
toolgroup.identifier,
e,
)
continue

except APIConnectionError as e:
logger.warning("Failed to retrieve tools from toolgroups: %s", e)
raise
except (ValueError, AttributeError) as e:
logger.warning("Failed to retrieve tools from toolgroups: %s", e)

logger.info(
"Retrieved total of %d tools (%d from built-in toolgroups, %d from MCP servers)",
len(consolidated_tools),
len([t for t in consolidated_tools if t.get("server_source") == "builtin"]),
len([t for t in consolidated_tools if t.get("server_source") != "builtin"]),
)

# Format tools with structured description parsing
formatted_tools = format_tools_list(consolidated_tools)

return ToolsResponse(tools=formatted_tools)

# Connection to Llama Stack server
except APIConnectionError as e:
logger.error("Unable to connect to Llama Stack: %s", e)
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail={
"response": "Unable to connect to Llama Stack",
"cause": str(e),
},
) from e
# Any other exception that can occur during tool listing
except Exception as e:
logger.error("Unable to retrieve list of tools: %s", e)
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail={
"response": "Unable to retrieve list of tools",
"cause": str(e),
},
) from e
2 changes: 2 additions & 0 deletions src/app/routers.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
conversations,
conversations_v2,
metrics,
tools,
)


Expand All @@ -28,6 +29,7 @@ def include_routers(app: FastAPI) -> None:
app.include_router(root.router)
app.include_router(info.router, prefix="/v1")
app.include_router(models.router, prefix="/v1")
app.include_router(tools.router, prefix="/v1")
app.include_router(shields.router, prefix="/v1")
app.include_router(query.router, prefix="/v1")
app.include_router(streaming_query.router, prefix="/v1")
Expand Down
1 change: 1 addition & 0 deletions src/models/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -358,6 +358,7 @@ class Action(str, Enum):
DELETE_CONVERSATION = "delete_conversation"
FEEDBACK = "feedback"
GET_MODELS = "get_models"
GET_TOOLS = "get_tools"
GET_SHIELDS = "get_shields"
GET_METRICS = "get_metrics"
GET_CONFIG = "get_config"
Expand Down
31 changes: 31 additions & 0 deletions src/models/responses.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,37 @@ class ModelsResponse(BaseModel):
)


class ToolsResponse(BaseModel):
"""Model representing a response to tools request."""

tools: list[dict[str, Any]] = Field(
description=(
"List of tools available from all configured MCP servers and built-in toolgroups"
),
examples=[
[
{
"identifier": "filesystem_read",
"description": "Read contents of a file from the filesystem",
"parameters": [
{
"name": "path",
"description": "Path to the file to read",
"parameter_type": "string",
"required": True,
"default": None,
}
],
"provider_id": "model-context-protocol",
"toolgroup_id": "filesystem-tools",
"server_source": "http://localhost:3000",
"type": "tool",
}
]
],
)


class ShieldsResponse(BaseModel):
"""Model representing a response to shields request."""

Expand Down
108 changes: 108 additions & 0 deletions src/utils/tool_formatter.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,108 @@
"""Utility functions for formatting and parsing MCP tool descriptions."""

import logging
from typing import Any

logger = logging.getLogger(__name__)


def format_tool_response(tool_dict: dict[str, Any]) -> dict[str, Any]:
"""
Format a tool dictionary to include only required fields.

Args:
tool_dict: Raw tool dictionary from Llama Stack

Returns:
Formatted tool dictionary with only required fields
"""
# Clean up description if it contains structured metadata
description = tool_dict.get("description", "")
if description and ("TOOL_NAME=" in description or "DISPLAY_NAME=" in description):
# Extract clean description from structured metadata
clean_description = extract_clean_description(description)
description = clean_description

# Extract only the required fields
formatted_tool = {
"identifier": tool_dict.get("identifier", ""),
"description": description,
"parameters": tool_dict.get("parameters", []),
"provider_id": tool_dict.get("provider_id", ""),
"toolgroup_id": tool_dict.get("toolgroup_id", ""),
"server_source": tool_dict.get("server_source", ""),
"type": tool_dict.get("type", ""),
}

return formatted_tool


def extract_clean_description(description: str) -> str:
"""
Extract a clean description from structured metadata format.

Args:
description: Raw description with structured metadata

Returns:
Clean description without metadata
"""
min_description_length = 20
fallback_truncation_length = 200

try:
# Look for the main description after all the metadata
description_parts = description.split("\n\n")
for part in description_parts:
if not any(
part.strip().startswith(prefix)
for prefix in [
"TOOL_NAME=",
"DISPLAY_NAME=",
"USECASE=",
"INSTRUCTIONS=",
"INPUT_DESCRIPTION=",
"OUTPUT_DESCRIPTION=",
"EXAMPLES=",
"PREREQUISITES=",
"AGENT_DECISION_CRITERIA=",
]
):
if (
part.strip() and len(part.strip()) > min_description_length
): # Reasonable description length
return part.strip()

# If no clean description found, try to extract from USECASE
lines = description.split("\n")
for line in lines:
if line.startswith("USECASE="):
return line.replace("USECASE=", "").strip()

# Fallback to first 200 characters
return (
description[:fallback_truncation_length] + "..."
if len(description) > fallback_truncation_length
else description
)

except (ValueError, AttributeError) as e:
logger.warning("Failed to extract clean description: %s", e)
return (
description[:fallback_truncation_length] + "..."
if len(description) > fallback_truncation_length
else description
)


def format_tools_list(tools: list[dict[str, Any]]) -> list[dict[str, Any]]:
"""
Format a list of tools with structured description parsing.

Args:
tools: List of raw tool dictionaries

Returns:
List of formatted tool dictionaries
"""
return [format_tool_response(tool) for tool in tools]
Loading
Loading