From 6985568c7414bec82dd88e194e7481ac120f12ed Mon Sep 17 00:00:00 2001 From: Eran Cohen Date: Tue, 15 Jul 2025 14:17:43 +0300 Subject: [PATCH] Add /conversations endpoint for conversation history management - Add GET /v1/conversations/{conversation_id} to retrieve conversation history - Add DELETE /v1/conversations/{conversation_id} to delete conversations - Use llama-stack client.agents.session.retrieve and .delete methods - Map conversation ID to agent ID for LlamaStack operations - Add ConversationResponse and ConversationDeleteResponse models - Include conversations router in main app routing - Maintain consistent error handling and authentication patterns --- src/app/endpoints/conversations.py | 263 +++++++++++++++++++++++++++ src/app/endpoints/query.py | 3 + src/app/endpoints/streaming_query.py | 2 + src/app/routers.py | 2 + src/models/responses.py | 86 +++++++++ tests/unit/app/test_routers.py | 7 +- 6 files changed, 361 insertions(+), 2 deletions(-) create mode 100644 src/app/endpoints/conversations.py diff --git a/src/app/endpoints/conversations.py b/src/app/endpoints/conversations.py new file mode 100644 index 00000000..90208997 --- /dev/null +++ b/src/app/endpoints/conversations.py @@ -0,0 +1,263 @@ +"""Handler for REST API calls to manage conversation history.""" + +import logging +from typing import Any + +from llama_stack_client import APIConnectionError, NotFoundError + +from fastapi import APIRouter, HTTPException, status, Depends + +from client import LlamaStackClientHolder +from configuration import configuration +from models.responses import ConversationResponse, ConversationDeleteResponse +from auth import get_auth_dependency +from utils.endpoints import check_configuration_loaded +from utils.suid import check_suid + +logger = logging.getLogger("app.endpoints.handlers") +router = APIRouter(tags=["conversations"]) +auth_dependency = get_auth_dependency() + +conversation_id_to_agent_id: dict[str, str] = {} + +conversation_responses: dict[int | str, dict[str, Any]] = { + 200: { + "conversation_id": "123e4567-e89b-12d3-a456-426614174000", + "session_data": { + "session_id": "123e4567-e89b-12d3-a456-426614174000", + "turns": [], + "started_at": "2024-01-01T00:00:00Z", + }, + }, + 404: { + "detail": { + "response": "Conversation not found", + "cause": "The specified conversation ID does not exist.", + } + }, + 503: { + "detail": { + "response": "Unable to connect to Llama Stack", + "cause": "Connection error.", + } + }, +} + +conversation_delete_responses: dict[int | str, dict[str, Any]] = { + 200: { + "conversation_id": "123e4567-e89b-12d3-a456-426614174000", + "success": True, + "message": "Conversation deleted successfully", + }, + 404: { + "detail": { + "response": "Conversation not found", + "cause": "The specified conversation ID does not exist.", + } + }, + 503: { + "detail": { + "response": "Unable to connect to Llama Stack", + "cause": "Connection error.", + } + }, +} + + +def simplify_session_data(session_data: Any) -> list[dict[str, Any]]: + """Simplify session data to include only essential conversation information. + + Args: + session_data: The full session data from llama-stack + + Returns: + Simplified session data with only input_messages and output_message per turn + """ + session_dict = session_data.model_dump() + # Create simplified structure + chat_history = [] + + # Extract only essential data from each turn + for turn in session_dict.get("turns", []): + # Clean up input messages + cleaned_messages = [] + for msg in turn.get("input_messages", []): + cleaned_msg = { + "content": msg.get("content"), + "type": msg.get("role"), # Rename role to type + } + cleaned_messages.append(cleaned_msg) + + # Clean up output message + output_msg = turn.get("output_message", {}) + cleaned_messages.append( + { + "content": output_msg.get("content"), + "type": output_msg.get("role"), # Rename role to type + } + ) + + simplified_turn = { + "messages": cleaned_messages, + "started_at": turn.get("started_at"), + "completed_at": turn.get("completed_at"), + } + chat_history.append(simplified_turn) + + return chat_history + + +@router.get("/conversations/{conversation_id}", responses=conversation_responses) +def get_conversation_endpoint_handler( + conversation_id: str, + _auth: Any = Depends(auth_dependency), +) -> ConversationResponse: + """Handle request to retrieve a conversation by ID.""" + check_configuration_loaded(configuration) + + # Validate conversation ID format + if not check_suid(conversation_id): + logger.error("Invalid conversation ID format: %s", conversation_id) + raise HTTPException( + status_code=status.HTTP_400_BAD_REQUEST, + detail={ + "response": "Invalid conversation ID format", + "cause": f"Conversation ID {conversation_id} is not a valid UUID", + }, + ) + + agent_id = conversation_id_to_agent_id.get(conversation_id) + if not agent_id: + logger.error("Agent ID not found for conversation %s", conversation_id) + raise HTTPException( + status_code=status.HTTP_404_NOT_FOUND, + detail={ + "response": "conversation ID not found", + "cause": f"conversation ID {conversation_id} not found!", + }, + ) + + logger.info("Retrieving conversation %s", conversation_id) + + try: + client = LlamaStackClientHolder().get_client() + + session_data = client.agents.session.retrieve( + agent_id=agent_id, session_id=conversation_id + ) + + logger.info("Successfully retrieved conversation %s", conversation_id) + + # Simplify the session data to include only essential conversation information + chat_history = simplify_session_data(session_data) + + return ConversationResponse( + conversation_id=conversation_id, + chat_history=chat_history, + ) + + except APIConnectionError as e: + logger.error("Unable to connect to Llama Stack: %s", e) + raise HTTPException( + status_code=status.HTTP_503_SERVICE_UNAVAILABLE, + detail={ + "response": "Unable to connect to Llama Stack", + "cause": str(e), + }, + ) from e + except NotFoundError as e: + logger.error("Conversation not found: %s", e) + raise HTTPException( + status_code=status.HTTP_404_NOT_FOUND, + detail={ + "response": "Conversation not found", + "cause": f"Conversation {conversation_id} could not be retrieved: {str(e)}", + }, + ) from e + except Exception as e: + # Handle case where session doesn't exist or other errors + logger.exception("Error retrieving conversation %s: %s", conversation_id, e) + raise HTTPException( + status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, + detail={ + "response": "Unknown error", + "cause": f"Unknown error while getting conversation {conversation_id} : {str(e)}", + }, + ) from e + + +@router.delete( + "/conversations/{conversation_id}", responses=conversation_delete_responses +) +def delete_conversation_endpoint_handler( + conversation_id: str, + _auth: Any = Depends(auth_dependency), +) -> ConversationDeleteResponse: + """Handle request to delete a conversation by ID.""" + check_configuration_loaded(configuration) + + # Validate conversation ID format + if not check_suid(conversation_id): + logger.error("Invalid conversation ID format: %s", conversation_id) + raise HTTPException( + status_code=status.HTTP_400_BAD_REQUEST, + detail={ + "response": "Invalid conversation ID format", + "cause": f"Conversation ID {conversation_id} is not a valid UUID", + }, + ) + agent_id = conversation_id_to_agent_id.get(conversation_id) + if not agent_id: + logger.error("Agent ID not found for conversation %s", conversation_id) + raise HTTPException( + status_code=status.HTTP_404_NOT_FOUND, + detail={ + "response": "conversation ID not found", + "cause": f"conversation ID {conversation_id} not found!", + }, + ) + logger.info("Deleting conversation %s", conversation_id) + + try: + # Get Llama Stack client + client = LlamaStackClientHolder().get_client() + # Delete session using the conversation_id as session_id + # In this implementation, conversation_id and session_id are the same + client.agents.session.delete(agent_id=agent_id, session_id=conversation_id) + + logger.info("Successfully deleted conversation %s", conversation_id) + + return ConversationDeleteResponse( + conversation_id=conversation_id, + success=True, + response="Conversation deleted successfully", + ) + + except APIConnectionError as e: + logger.error("Unable to connect to Llama Stack: %s", e) + raise HTTPException( + status_code=status.HTTP_503_SERVICE_UNAVAILABLE, + detail={ + "response": "Unable to connect to Llama Stack", + "cause": str(e), + }, + ) from e + except NotFoundError as e: + logger.error("Conversation not found: %s", e) + raise HTTPException( + status_code=status.HTTP_404_NOT_FOUND, + detail={ + "response": "Conversation not found", + "cause": f"Conversation {conversation_id} could not be deleted: {str(e)}", + }, + ) from e + except Exception as e: + # Handle case where session doesn't exist or other errors + logger.exception("Error deleting conversation %s: %s", conversation_id, e) + raise HTTPException( + status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, + detail={ + "response": "Unknown error", + "cause": f"Unknown error while deleting conversation {conversation_id} : {str(e)}", + }, + ) from e diff --git a/src/app/endpoints/query.py b/src/app/endpoints/query.py index 3ebe0528..3aa833f4 100644 --- a/src/app/endpoints/query.py +++ b/src/app/endpoints/query.py @@ -23,6 +23,7 @@ from client import LlamaStackClientHolder from configuration import configuration +from app.endpoints.conversations import conversation_id_to_agent_id from models.responses import QueryResponse, UnauthorizedResponse, ForbiddenResponse from models.requests import QueryRequest, Attachment import constants @@ -97,6 +98,8 @@ def get_agent( ) conversation_id = agent.create_session(get_suid()) _agent_cache[conversation_id] = agent + conversation_id_to_agent_id[conversation_id] = agent.agent_id + return agent, conversation_id diff --git a/src/app/endpoints/streaming_query.py b/src/app/endpoints/streaming_query.py index 2a3353c4..2e2092e1 100644 --- a/src/app/endpoints/streaming_query.py +++ b/src/app/endpoints/streaming_query.py @@ -26,6 +26,7 @@ from utils.suid import get_suid from utils.types import GraniteToolParser +from app.endpoints.conversations import conversation_id_to_agent_id from app.endpoints.query import ( get_rag_toolgroups, is_transcripts_enabled, @@ -67,6 +68,7 @@ async def get_agent( ) conversation_id = await agent.create_session(get_suid()) _agent_cache[conversation_id] = agent + conversation_id_to_agent_id[conversation_id] = agent.agent_id return agent, conversation_id diff --git a/src/app/routers.py b/src/app/routers.py index fd17e26c..abfc0059 100644 --- a/src/app/routers.py +++ b/src/app/routers.py @@ -12,6 +12,7 @@ feedback, streaming_query, authorized, + conversations, ) @@ -28,6 +29,7 @@ def include_routers(app: FastAPI) -> None: app.include_router(streaming_query.router, prefix="/v1") app.include_router(config.router, prefix="/v1") app.include_router(feedback.router, prefix="/v1") + app.include_router(conversations.router, prefix="/v1") # road-core does not version these endpoints app.include_router(health.router) diff --git a/src/models/responses.py b/src/models/responses.py index a9778343..76270739 100644 --- a/src/models/responses.py +++ b/src/models/responses.py @@ -298,3 +298,89 @@ class ForbiddenResponse(UnauthorizedResponse): ] } } + + +class ConversationResponse(BaseModel): + """Model representing a response for retrieving a conversation. + + Attributes: + conversation_id: The conversation ID (UUID). + chat_history: The simplified chat history as a list of conversation turns. + + Example: + ```python + conversation_response = ConversationResponse( + conversation_id="123e4567-e89b-12d3-a456-426614174000", + chat_history=[ + { + "messages": [ + {"content": "Hello", "type": "user"}, + {"content": "Hi there!", "type": "assistant"} + ], + "started_at": "2024-01-01T00:01:00Z", + "completed_at": "2024-01-01T00:01:05Z" + } + ] + ) + ``` + """ + + conversation_id: str + chat_history: list[dict[str, Any]] + + # provides examples for /docs endpoint + model_config = { + "json_schema_extra": { + "examples": [ + { + "conversation_id": "123e4567-e89b-12d3-a456-426614174000", + "chat_history": [ + { + "messages": [ + {"content": "Hello", "type": "user"}, + {"content": "Hi there!", "type": "assistant"}, + ], + "started_at": "2024-01-01T00:01:00Z", + "completed_at": "2024-01-01T00:01:05Z", + } + ], + } + ] + } + } + + +class ConversationDeleteResponse(BaseModel): + """Model representing a response for deleting a conversation. + + Attributes: + conversation_id: The conversation ID (UUID) that was deleted. + success: Whether the deletion was successful. + response: A message about the deletion result. + + Example: + ```python + delete_response = ConversationDeleteResponse( + conversation_id="123e4567-e89b-12d3-a456-426614174000", + success=True, + response="Conversation deleted successfully" + ) + ``` + """ + + conversation_id: str + success: bool + response: str + + # provides examples for /docs endpoint + model_config = { + "json_schema_extra": { + "examples": [ + { + "conversation_id": "123e4567-e89b-12d3-a456-426614174000", + "success": True, + "response": "Conversation deleted successfully", + } + ] + } + } diff --git a/tests/unit/app/test_routers.py b/tests/unit/app/test_routers.py index 674a6433..335443e5 100644 --- a/tests/unit/app/test_routers.py +++ b/tests/unit/app/test_routers.py @@ -5,6 +5,7 @@ from app.routers import include_routers # noqa:E402 from app.endpoints import ( + conversations, root, info, models, @@ -43,7 +44,7 @@ def test_include_routers() -> None: include_routers(app) # are all routers added? - assert len(app.routers) == 9 + assert len(app.routers) == 10 assert root.router in app.get_routers() assert info.router in app.get_routers() assert models.router in app.get_routers() @@ -53,6 +54,7 @@ def test_include_routers() -> None: assert feedback.router in app.get_routers() assert health.router in app.get_routers() assert authorized.router in app.get_routers() + assert conversations.router in app.get_routers() def test_check_prefixes() -> None: @@ -61,7 +63,7 @@ def test_check_prefixes() -> None: include_routers(app) # are all routers added? - assert len(app.routers) == 9 + assert len(app.routers) == 10 assert app.get_router_prefix(root.router) is None assert app.get_router_prefix(info.router) == "/v1" assert app.get_router_prefix(models.router) == "/v1" @@ -71,3 +73,4 @@ def test_check_prefixes() -> None: assert app.get_router_prefix(feedback.router) == "/v1" assert app.get_router_prefix(health.router) is None assert app.get_router_prefix(authorized.router) is None + assert app.get_router_prefix(conversations.router) == "/v1"