Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
62 changes: 61 additions & 1 deletion src/app/endpoints/conversations.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
"""Handler for REST API calls to manage conversation history."""

import logging
import re
from typing import Any

from llama_stack_client import APIConnectionError, NotFoundError
Expand All @@ -20,6 +21,60 @@

conversation_id_to_agent_id: dict[str, str] = {}


def parse_attachments_from_content(
content: list[dict[str, str]],
) -> list[dict[str, str]] | None:
"""Parse attachment information from content items.

The content structure is:
- Index 0: User message
- Index 1: <attachments_info> section with metadata
- Index 2+: Actual attachment content (one per attachment line)

Args:
content: The content list of TextContentItems

Returns:
List of attachment dictionaries or None if no attachments found
"""
if len(content) < 2:
return None

attachments_info_text = content[1].get("text", "")
attachments_pattern = r"<attachments_info>\n(.*?)\n</attachments_info>"
attachments_match = re.search(attachments_pattern, attachments_info_text, re.DOTALL)

if not attachments_match:
return None

attachments_text = attachments_match.group(1)
attachments_info: list[dict[str, str]] = []

attachment_lines = []
for line in attachments_text.strip().split("\n"):
line = line.strip()
if line:
attachment_lines.append(line)
# Parse: "attachment_type: value, content_type: value"
attachment_match = re.match(
r"attachment_type:\s*([^,]+),\s*content_type:\s*(.+)", line
)
if attachment_match:
attachment_info = {
"attachment_type": attachment_match.group(1).strip(),
"content_type": attachment_match.group(2).strip(),
}
attachments_info.append(attachment_info)

for i, attachment_info in enumerate(attachments_info):
content_index = 2 + i
if content_index < len(content):
attachment_info["content"] = content[content_index].get("text", "")

return attachments_info if attachments_info else None


conversation_responses: dict[int | str, dict[str, Any]] = {
200: {
"conversation_id": "123e4567-e89b-12d3-a456-426614174000",
Expand Down Expand Up @@ -82,9 +137,14 @@ def simplify_session_data(session_data: Any) -> list[dict[str, Any]]:
# Clean up input messages
cleaned_messages = []
for msg in turn.get("input_messages", []):
content = msg.get("content", "")
# Parse attachments from content (handles both string and list of TextContentItems)
attachments = parse_attachments_from_content(content)

cleaned_msg = {
"content": msg.get("content"),
"content": content[0].get("text"),
"type": msg.get("role"), # Rename role to type
"attachments": attachments,
}
cleaned_messages.append(cleaned_msg)

Expand Down
22 changes: 21 additions & 1 deletion src/app/endpoints/query.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
Toolgroup,
)
from llama_stack_client.types.model_list_response import ModelListResponse
from llama_stack_client.types.shared.interleaved_content_item import TextContentItem

from fastapi import APIRouter, HTTPException, status, Depends

Expand Down Expand Up @@ -295,8 +296,27 @@ def retrieve_response( # pylint: disable=too-many-locals
toolgroups = (get_rag_toolgroups(vector_db_ids) or []) + [
mcp_server.name for mcp_server in configuration.mcp_servers
]

attachment_lines = [
f"attachment_type: {attachment.attachment_type}, "
f"content_type: {attachment.content_type}"
for attachment in (query_request.attachments or [])
]

response = agent.create_turn(
messages=[UserMessage(role="user", content=query_request.query)],
messages=[
UserMessage(
role="user",
content=[
TextContentItem(type="text", text=query_request.query),
TextContentItem(
type="text",
text=f"<attachments_info>\n{'\n'.join(attachment_lines)}\n"
f"</attachments_info>",
),
],
)
],
session_id=conversation_id,
documents=query_request.get_documents(),
stream=False,
Expand Down
29 changes: 25 additions & 4 deletions src/app/endpoints/streaming_query.py
Original file line number Diff line number Diff line change
Expand Up @@ -517,8 +517,9 @@ async def retrieve_response(
available_output_shields,
)
# use system prompt from request or default one
system_prompt = get_system_prompt(query_request, configuration)
logger.debug("Using system prompt: %s", system_prompt)
logger.debug(
"Using system prompt: %s", get_system_prompt(query_request, configuration)
)

# TODO(lucasagomes): redact attachments content before sending to LLM
# if attachments are provided, validate them
Expand All @@ -528,7 +529,7 @@ async def retrieve_response(
agent, conversation_id = await get_agent(
client,
model_id,
system_prompt,
get_system_prompt(query_request, configuration),
available_input_shields,
available_output_shields,
query_request.conversation_id,
Expand Down Expand Up @@ -561,8 +562,28 @@ async def retrieve_response(
toolgroups = (get_rag_toolgroups(vector_db_ids) or []) + [
mcp_server.name for mcp_server in configuration.mcp_servers
]

# Generate attachment info lines
attachment_lines = [
f"attachment_type: {attachment.attachment_type}, "
f"content_type: {attachment.content_type}"
for attachment in (query_request.attachments or [])
]

response = await agent.create_turn(
messages=[UserMessage(role="user", content=query_request.query)],
messages=[
UserMessage(
role="user",
content=[
TextContentItem(type="text", text=query_request.query),
TextContentItem(
type="text",
text=f"<attachments_info>\n{'\n'.join(attachment_lines)}\n"
f"</attachments_info>",
),
],
)
],
session_id=conversation_id,
documents=query_request.get_documents(),
stream=True,
Expand Down
Loading