Skip to content

Commit e4d201e

Browse files
authored
Merge pull request #591 from tisnik/lcore-724-endpoints-for-conversation-history-v2
LCORE-724: Endpoints for conversation cache v2
2 parents a32f966 + 79c3f84 commit e4d201e

File tree

7 files changed

+304
-2
lines changed

7 files changed

+304
-2
lines changed
Lines changed: 240 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,240 @@
1+
"""Handler for REST API calls to manage conversation history."""
2+
3+
import logging
4+
from typing import Any
5+
6+
from fastapi import APIRouter, Request, Depends, HTTPException, status
7+
8+
from configuration import configuration
9+
from authentication import get_auth_dependency
10+
from authorization.middleware import authorize
11+
from models.cache_entry import CacheEntry
12+
from models.config import Action
13+
from models.responses import (
14+
ConversationsListResponseV2,
15+
ConversationResponse,
16+
ConversationDeleteResponse,
17+
UnauthorizedResponse,
18+
)
19+
from utils.endpoints import check_configuration_loaded
20+
from utils.suid import check_suid
21+
22+
# Logger shared by every handler and helper defined in this module.
logger = logging.getLogger("app.endpoints.handlers")

# Router that groups the conversation endpoints below under one OpenAPI tag.
router = APIRouter(tags=["conversations_v2"])

# Authentication dependency, built once at import time and injected into
# each handler through ``Depends``.
auth_dependency = get_auth_dependency()
25+
26+
27+
# Extra OpenAPI response metadata for GET /conversations/{conversation_id},
# keyed by HTTP status code and passed to the route decorator's ``responses``.
conversation_responses: dict[int | str, dict[str, Any]] = {
    200: {
        "conversation_id": "123e4567-e89b-12d3-a456-426614174000",
        # One entry per cached exchange; the keys mirror what
        # transform_chat_message() produces for the response payload.
        "chat_history": [
            {
                "provider": "provider ID",
                "model": "model ID",
                "query": "Hi",
                "response": "Hello!",
                "messages": [
                    {"content": "Hi", "type": "user"},
                    {"content": "Hello!", "type": "assistant"},
                ],
            }
        ],
    },
    400: {
        "description": "Missing or invalid credentials provided by client",
        "model": UnauthorizedResponse,
    },
    401: {
        "description": "Unauthorized: Invalid or missing Bearer token",
        "model": UnauthorizedResponse,
    },
    404: {
        "detail": {
            "response": "Conversation not found",
            "cause": "The specified conversation ID does not exist.",
        }
    },
}
58+
59+
# Extra OpenAPI response metadata for DELETE /conversations/{conversation_id},
# keyed by HTTP status code and passed to the route decorator's ``responses``.
conversation_delete_responses: dict[int | str, dict[str, Any]] = {
    200: {
        "conversation_id": "123e4567-e89b-12d3-a456-426614174000",
        "success": True,
        # The handler fills the ``response`` field of
        # ConversationDeleteResponse, so the example uses the same key.
        "response": "Conversation deleted successfully",
    },
    400: {
        "description": "Missing or invalid credentials provided by client",
        "model": UnauthorizedResponse,
    },
    401: {
        "description": "Unauthorized: Invalid or missing Bearer token",
        "model": UnauthorizedResponse,
    },
    404: {
        "detail": {
            "response": "Conversation not found",
            "cause": "The specified conversation ID does not exist.",
        }
    },
}
80+
81+
# Extra OpenAPI response metadata for GET /conversations, keyed by HTTP
# status code and passed to the route decorator's ``responses``.
conversations_list_responses: dict[int | str, dict[str, Any]] = {
    200: {
        # ConversationsListResponseV2.conversations is declared as list[str],
        # so the example lists bare conversation IDs rather than objects.
        "conversations": [
            "123e4567-e89b-12d3-a456-426614174000",
        ]
    }
}
90+
91+
92+
@router.get("/conversations", responses=conversations_list_responses)
@authorize(Action.LIST_CONVERSATIONS)
async def get_conversations_list_endpoint_handler(
    request: Request,  # pylint: disable=unused-argument
    auth: Any = Depends(auth_dependency),
) -> ConversationsListResponseV2:
    """Handle request to retrieve all conversations for the authenticated user.

    Args:
        request: Incoming HTTP request. It is not read in this function body
            (presumably it is consumed by the ``authorize`` decorator).
        auth: Value produced by the authentication dependency; its first
            element is used as the user ID.

    Returns:
        A ConversationsListResponseV2 wrapping whatever the conversation
        cache lists for the user.

    Raises:
        HTTPException: 404 when no conversation cache is configured.
    """
    check_configuration_loaded(configuration)

    user_id = auth[0]

    logger.info("Retrieving conversations for user %s", user_id)

    # Read the cache once so the None check and the later call use the
    # same object (this also lets type checkers narrow the local).
    cache = configuration.conversation_cache
    if cache is None:
        logger.warning("Conversation cache is not configured")
        raise HTTPException(
            status_code=status.HTTP_404_NOT_FOUND,
            detail={
                "response": "Conversation cache is not configured",
                "cause": "Conversation cache is not configured",
            },
        )

    # NOTE(review): the meaning of the positional ``False`` flag is defined
    # by the cache implementation, which is not visible here.
    conversations = cache.list(user_id, False)
    logger.info("Conversations for user %s: %s", user_id, len(conversations))

    return ConversationsListResponseV2(conversations=conversations)
119+
120+
121+
@router.get("/conversations/{conversation_id}", responses=conversation_responses)
@authorize(Action.GET_CONVERSATION)
async def get_conversation_endpoint_handler(
    request: Request,  # pylint: disable=unused-argument
    conversation_id: str,
    auth: Any = Depends(auth_dependency),
) -> ConversationResponse:
    """Handle request to retrieve a conversation by ID.

    Args:
        request: Incoming HTTP request. It is not read in this function body
            (presumably it is consumed by the ``authorize`` decorator).
        conversation_id: ID of the conversation to retrieve, taken from the
            URL path.
        auth: Value produced by the authentication dependency; its first
            element is used as the user ID.

    Returns:
        A ConversationResponse holding the conversation ID and one
        transformed chat-history item per cached entry.

    Raises:
        HTTPException: 400 when the conversation ID has an invalid format,
            404 when the conversation cache is not configured or the
            conversation is not listed for the user.
    """
    check_configuration_loaded(configuration)
    check_valid_conversation_id(conversation_id)

    user_id = auth[0]
    logger.info("Retrieving conversation %s for user %s", conversation_id, user_id)

    cache = configuration.conversation_cache
    if cache is None:
        logger.warning("Conversation cache is not configured")
        raise HTTPException(
            status_code=status.HTTP_404_NOT_FOUND,
            detail={
                "response": "Conversation cache is not configured",
                "cause": "Conversation cache is not configured",
            },
        )

    # Raises 404 unless the cache lists this conversation for the user.
    check_conversation_existence(user_id, conversation_id)

    conversation = cache.get(user_id, conversation_id, False)
    chat_history = [transform_chat_message(entry) for entry in conversation]

    return ConversationResponse(
        conversation_id=conversation_id, chat_history=chat_history
    )
153+
154+
155+
@router.delete(
    "/conversations/{conversation_id}", responses=conversation_delete_responses
)
@authorize(Action.DELETE_CONVERSATION)
async def delete_conversation_endpoint_handler(
    request: Request,  # pylint: disable=unused-argument
    conversation_id: str,
    auth: Any = Depends(auth_dependency),
) -> ConversationDeleteResponse:
    """Handle request to delete a conversation by ID.

    Args:
        request: Incoming HTTP request. It is not read in this function body
            (presumably it is consumed by the ``authorize`` decorator).
        conversation_id: ID of the conversation to delete, taken from the
            URL path.
        auth: Value produced by the authentication dependency; its first
            element is used as the user ID.

    Returns:
        A ConversationDeleteResponse whose ``success`` flag reflects the
        result reported by the cache's ``delete`` call.

    Raises:
        HTTPException: 400 when the conversation ID has an invalid format,
            404 when the conversation cache is not configured or the
            conversation is not listed for the user.
    """
    check_configuration_loaded(configuration)
    check_valid_conversation_id(conversation_id)

    user_id = auth[0]
    # Logged once; the original emitted this identical message a second
    # time right before the delete call.
    logger.info("Deleting conversation %s for user %s", conversation_id, user_id)

    cache = configuration.conversation_cache
    if cache is None:
        logger.warning("Conversation cache is not configured")
        raise HTTPException(
            status_code=status.HTTP_404_NOT_FOUND,
            detail={
                "response": "Conversation cache is not configured",
                "cause": "Conversation cache is not configured",
            },
        )

    # Raises 404 unless the cache lists this conversation for the user.
    check_conversation_existence(user_id, conversation_id)

    deleted = cache.delete(user_id, conversation_id, False)

    if deleted:
        return ConversationDeleteResponse(
            conversation_id=conversation_id,
            success=True,
            response="Conversation deleted successfully",
        )
    # A failed deletion must not be reported as a success.
    return ConversationDeleteResponse(
        conversation_id=conversation_id,
        success=False,
        response="Conversation can not be deleted",
    )
197+
198+
199+
def check_valid_conversation_id(conversation_id: str) -> None:
    """Reject conversation IDs that ``check_suid`` does not accept.

    Args:
        conversation_id: Conversation ID received from the client.

    Raises:
        HTTPException: 400 (Bad Request) when the ID fails the check.
    """
    if check_suid(conversation_id):
        return

    logger.error("Invalid conversation ID format: %s", conversation_id)
    error_detail = {
        "response": "Invalid conversation ID format",
        "cause": f"Conversation ID {conversation_id} is not a valid UUID",
    }
    raise HTTPException(
        status_code=status.HTTP_400_BAD_REQUEST,
        detail=error_detail,
    )
210+
211+
212+
def check_conversation_existence(user_id: str, conversation_id: str) -> None:
    """Ensure the conversation is among those the cache lists for the user.

    Args:
        user_id: ID of the user whose conversations are listed.
        conversation_id: ID of the conversation that must be present.

    Raises:
        HTTPException: 404 (Not Found) when the ID is not in the list
            returned by the conversation cache.
    """
    cache = configuration.conversation_cache
    # callers have checked this already, but the guard keeps pyright happy
    if cache is None:
        return

    if conversation_id in cache.list(user_id, False):
        return

    logger.error("No conversation found for conversation ID %s", conversation_id)
    not_found_detail = {
        "response": "Conversation not found",
        "cause": f"Conversation {conversation_id} could not be retrieved.",
    }
    raise HTTPException(
        status_code=status.HTTP_404_NOT_FOUND,
        detail=not_found_detail,
    )
227+
228+
229+
def transform_chat_message(entry: CacheEntry) -> dict[str, Any]:
    """Convert one cache entry into the chat-history item used in responses.

    Args:
        entry: Cache entry providing ``provider``, ``model``, ``query`` and
            ``response`` attributes.

    Returns:
        Dictionary with the provider, model, query and response of the entry
        plus a ``messages`` list that pairs the query (type ``user``) with
        the response (type ``assistant``).
    """
    user_message = {"content": entry.query, "type": "user"}
    assistant_message = {"content": entry.response, "type": "assistant"}
    return {
        "provider": entry.provider,
        "model": entry.model,
        "query": entry.query,
        "response": entry.response,
        "messages": [user_message, assistant_message],
    }

src/app/endpoints/query.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -47,6 +47,7 @@
4747
check_configuration_loaded,
4848
get_agent,
4949
get_system_prompt,
50+
store_conversation_into_cache,
5051
validate_conversation_ownership,
5152
validate_model_provider_override,
5253
)
@@ -279,6 +280,16 @@ async def query_endpoint_handler(
279280
provider_id=provider_id,
280281
)
281282

283+
store_conversation_into_cache(
284+
configuration,
285+
user_id,
286+
conversation_id,
287+
provider_id,
288+
model_id,
289+
query_request.query,
290+
summary.llm_response,
291+
)
292+
282293
# Convert tool calls to response format
283294
logger.info("Processing tool calls...")
284295
tool_calls = [

src/app/endpoints/streaming_query.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -47,6 +47,7 @@
4747
check_configuration_loaded,
4848
get_agent,
4949
get_system_prompt,
50+
store_conversation_into_cache,
5051
validate_model_provider_override,
5152
)
5253
from utils.mcp_headers import handle_mcp_headers_with_toolgroups, mcp_headers_dependency
@@ -704,6 +705,16 @@ async def response_generator(
704705
attachments=query_request.attachments or [],
705706
)
706707

708+
store_conversation_into_cache(
709+
configuration,
710+
user_id,
711+
conversation_id,
712+
provider_id,
713+
model_id,
714+
query_request.query,
715+
summary.llm_response,
716+
)
717+
707718
persist_user_conversation_details(
708719
user_id=user_id,
709720
conversation_id=conversation_id,

src/app/routers.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
streaming_query,
1414
authorized,
1515
conversations,
16+
conversations_v2,
1617
metrics,
1718
)
1819

@@ -31,6 +32,7 @@ def include_routers(app: FastAPI) -> None:
3132
app.include_router(config.router, prefix="/v1")
3233
app.include_router(feedback.router, prefix="/v1")
3334
app.include_router(conversations.router, prefix="/v1")
35+
app.include_router(conversations_v2.router, prefix="/v2")
3436

3537
# road-core does not version these endpoints
3638
app.include_router(health.router)

src/models/responses.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -667,6 +667,16 @@ class ConversationsListResponse(BaseModel):
667667
}
668668

669669

670+
class ConversationsListResponseV2(BaseModel):
    """Response body listing the conversations that belong to one user.

    Attributes:
        conversations: Conversation IDs associated with the user.
    """

    # One ID string per conversation.
    conversations: list[str]
678+
679+
670680
class ErrorResponse(BaseModel):
671681
"""Model representing error response for query endpoint."""
672682

src/utils/endpoints.py

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
from llama_stack_client.lib.agents.agent import AsyncAgent
77

88
import constants
9+
from models.cache_entry import CacheEntry
910
from models.requests import QueryRequest
1011
from models.database.conversations import UserConversation
1112
from models.config import Action
@@ -135,6 +136,31 @@ def validate_model_provider_override(
135136
)
136137

137138

139+
# # pylint: disable=R0913,R0917
140+
def store_conversation_into_cache(
    config: AppConfig,
    user_id: str,
    conversation_id: str,
    provider_id: str,
    model_id: str,
    query: str,
    response: str,
) -> None:
    """Append one query/response exchange to the conversation history cache.

    Nothing is stored when the configured cache type is ``None``. When a
    type is configured but the cache object is missing, a warning is logged
    and the function returns.

    Args:
        config: Application configuration exposing the cache settings and
            the cache object.
        user_id: ID of the user who owns the conversation.
        conversation_id: ID of the conversation being extended.
        provider_id: Provider recorded in the cache entry.
        model_id: Model recorded in the cache entry.
        query: Query text recorded in the cache entry.
        response: Response text recorded in the cache entry.
    """
    if config.conversation_cache_configuration.type is None:
        return

    cache = config.conversation_cache
    if cache is None:
        logger.warning("Conversation cache configured but not initialized")
        return

    entry = CacheEntry(
        query=query,
        response=response,
        provider=provider_id,
        model=model_id,
    )
    cache.insert_or_append(user_id, conversation_id, entry, False)
162+
163+
138164
# # pylint: disable=R0913,R0917
139165
async def get_agent(
140166
client: AsyncLlamaStackClient,

tests/unit/app/test_routers.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88

99
from app.endpoints import (
1010
conversations,
11+
conversations_v2,
1112
root,
1213
info,
1314
models,
@@ -60,7 +61,7 @@ def test_include_routers() -> None:
6061
include_routers(app)
6162

6263
# are all routers added?
63-
assert len(app.routers) == 11
64+
assert len(app.routers) == 12
6465
assert root.router in app.get_routers()
6566
assert info.router in app.get_routers()
6667
assert models.router in app.get_routers()
@@ -80,7 +81,7 @@ def test_check_prefixes() -> None:
8081
include_routers(app)
8182

8283
# are all routers added?
83-
assert len(app.routers) == 11
84+
assert len(app.routers) == 12
8485
assert app.get_router_prefix(root.router) == ""
8586
assert app.get_router_prefix(info.router) == "/v1"
8687
assert app.get_router_prefix(models.router) == "/v1"
@@ -92,3 +93,4 @@ def test_check_prefixes() -> None:
9293
assert app.get_router_prefix(authorized.router) == ""
9394
assert app.get_router_prefix(conversations.router) == "/v1"
9495
assert app.get_router_prefix(metrics.router) == ""
96+
assert app.get_router_prefix(conversations_v2.router) == "/v2"

0 commit comments

Comments
 (0)