Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
67 changes: 52 additions & 15 deletions src/app/endpoints/query.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
from pathlib import Path
from typing import Annotated, Any


from llama_stack_client import APIConnectionError
from llama_stack_client import AsyncLlamaStackClient # type: ignore
from llama_stack_client.types import UserMessage, Shield # type: ignore
Expand All @@ -25,7 +26,12 @@
from app.database import get_session
import metrics
from models.database.conversations import UserConversation
from models.responses import QueryResponse, UnauthorizedResponse, ForbiddenResponse
from models.responses import (
QueryResponse,
UnauthorizedResponse,
ForbiddenResponse,
ReferencedDocument,
)
from models.requests import QueryRequest, Attachment
import constants
from utils.endpoints import (
Expand All @@ -36,15 +42,28 @@
)
from utils.mcp_headers import mcp_headers_dependency, handle_mcp_headers_with_toolgroups
from utils.suid import get_suid
from utils.metadata import (
extract_referenced_documents_from_steps,
)

logger = logging.getLogger("app.endpoints.handlers")
router = APIRouter(tags=["query"])
auth_dependency = get_auth_dependency()


query_response: dict[int | str, dict[str, Any]] = {
200: {
"conversation_id": "123e4567-e89b-12d3-a456-426614174000",
"response": "LLM answer",
"referenced_documents": [
{
"doc_url": (
"https://docs.openshift.com/container-platform/"
"4.15/operators/olm/index.html"
),
"doc_title": "Operator Lifecycle Manager (OLM)",
}
],
},
400: {
"description": "Missing or invalid credentials provided by client",
Expand All @@ -54,7 +73,7 @@
"description": "User is not authorized",
"model": ForbiddenResponse,
},
503: {
500: {
"detail": {
"response": "Unable to connect to Llama Stack",
"cause": "Connection error.",
Expand Down Expand Up @@ -203,7 +222,7 @@ async def query_endpoint_handler(
user_conversation=user_conversation, query_request=query_request
),
)
response, conversation_id = await retrieve_response(
response, conversation_id, referenced_documents = await retrieve_response(
client,
llama_stack_model_id,
query_request,
Expand Down Expand Up @@ -237,7 +256,11 @@ async def query_endpoint_handler(
provider_id=provider_id,
)

return QueryResponse(conversation_id=conversation_id, response=response)
return QueryResponse(
conversation_id=conversation_id,
response=response,
referenced_documents=referenced_documents,
)

# connection to Llama Stack server
except APIConnectionError as e:
Expand Down Expand Up @@ -381,7 +404,7 @@ async def retrieve_response( # pylint: disable=too-many-locals,too-many-branche
query_request: QueryRequest,
token: str,
mcp_headers: dict[str, dict[str, str]] | None = None,
) -> tuple[str, str]:
) -> tuple[str, str, list[ReferencedDocument]]:
"""
Retrieve response from LLMs and agents.

Expand All @@ -404,8 +427,9 @@ async def retrieve_response( # pylint: disable=too-many-locals,too-many-branche
mcp_headers (dict[str, dict[str, str]], optional): Headers for multi-component processing.

Returns:
tuple[str, str]: A tuple containing the LLM or agent's response content
and the conversation ID.
tuple[str, str, list[ReferencedDocument]]: A tuple containing the response
content, the conversation ID, and the list of referenced documents parsed
from tool execution steps.
"""
available_input_shields = [
shield.identifier
Expand Down Expand Up @@ -485,26 +509,39 @@ async def retrieve_response( # pylint: disable=too-many-locals,too-many-branche
toolgroups=toolgroups,
)

# Check for validation errors in the response
# Check for validation errors and extract referenced documents
steps = getattr(response, "steps", [])
for step in steps:
if step.step_type == "shield_call" and step.violation:
if getattr(step, "step_type", "") == "shield_call" and getattr(
step, "violation", False
):
# Metric for LLM validation errors
metrics.llm_calls_validation_errors_total.inc()
break

# Extract referenced documents from tool execution steps
referenced_documents = extract_referenced_documents_from_steps(steps)

# When stream=False, response should have output_message attribute
output_message = getattr(response, "output_message", None)
if output_message is not None:
content = getattr(output_message, "content", None)
if content is not None:
return str(content), conversation_id
response_text = str(content)
else:
response_text = ""
else:
# fallback
logger.warning(
"Response lacks output_message.content (conversation_id=%s)",
conversation_id,
)
response_text = ""

# fallback
logger.warning(
"Response lacks output_message.content (conversation_id=%s)",
return (
response_text,
conversation_id,
referenced_documents,
)
return "", conversation_id


def validate_attachments_metadata(attachments: list[Attachment]) -> None:
Expand Down
65 changes: 40 additions & 25 deletions src/app/endpoints/streaming_query.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,11 @@
"""Handler for REST API call to provide answer to streaming query."""

import ast
import json
import re
import logging
from typing import Annotated, Any, AsyncIterator, Iterator

import pydantic

from llama_stack_client import APIConnectionError
from llama_stack_client import AsyncLlamaStackClient # type: ignore
from llama_stack_client.types import UserMessage # type: ignore
Expand All @@ -24,8 +24,10 @@
import metrics
from models.requests import QueryRequest
from models.database.conversations import UserConversation
from models.responses import ReferencedDocument
from utils.endpoints import check_configuration_loaded, get_agent, get_system_prompt
from utils.mcp_headers import mcp_headers_dependency, handle_mcp_headers_with_toolgroups
from utils.metadata import parse_knowledge_search_metadata

from app.endpoints.query import (
get_rag_toolgroups,
Expand All @@ -45,9 +47,6 @@
auth_dependency = get_auth_dependency()


METADATA_PATTERN = re.compile(r"\nMetadata: (\{.+})\n")


def format_stream_data(d: dict) -> str:
"""
Format a dictionary as a Server-Sent Events (SSE) data string.
Expand Down Expand Up @@ -102,20 +101,36 @@ def stream_end_event(metadata_map: dict) -> str:
str: A Server-Sent Events (SSE) formatted string
representing the end of the data stream.
"""
# Create ReferencedDocument objects and convert them to serializable dict format
referenced_documents = []
for v in filter(
lambda v: ("docs_url" in v) and ("title" in v),
metadata_map.values(),
):
try:
doc = ReferencedDocument(doc_url=v["docs_url"], doc_title=v["title"])
referenced_documents.append(
{
"doc_url": str(
doc.doc_url
), # Convert AnyUrl to string for JSON serialization
"doc_title": doc.doc_title,
}
)
except (pydantic.ValidationError, ValueError) as e:
logger.warning(
"Skipping invalid referenced document with docs_url='%s', title='%s': %s",
v.get("docs_url", "<missing>"),
v.get("title", "<missing>"),
str(e),
)
continue

return format_stream_data(
{
"event": "end",
"data": {
"referenced_documents": [
{
"doc_url": v["docs_url"],
"doc_title": v["title"],
}
for v in filter(
lambda v: ("docs_url" in v) and ("title" in v),
metadata_map.values(),
)
],
"referenced_documents": referenced_documents,
"truncated": None, # TODO(jboos): implement truncated
"input_tokens": 0, # TODO(jboos): implement input tokens
"output_tokens": 0, # TODO(jboos): implement output tokens
Expand Down Expand Up @@ -435,16 +450,16 @@ def _handle_tool_execution_event(
newline_pos = summary.find("\n")
if newline_pos > 0:
summary = summary[:newline_pos]
for match in METADATA_PATTERN.findall(text_content_item.text):
try:
meta = ast.literal_eval(match)
if "document_id" in meta:
metadata_map[meta["document_id"]] = meta
except Exception: # pylint: disable=broad-except
logger.debug(
"An exception was thrown in processing %s",
match,
)
try:
parsed_metadata = parse_knowledge_search_metadata(
text_content_item.text, strict=False
)
metadata_map.update(parsed_metadata)
except ValueError as e:
logger.exception(
"Error processing metadata from text; position=%s",
getattr(e, "position", "unknown"),
)

yield format_stream_data(
{
Expand Down
40 changes: 37 additions & 3 deletions src/models/responses.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

from typing import Any, Optional

from pydantic import BaseModel, Field
from pydantic import BaseModel, Field, AnyUrl


class ModelsResponse(BaseModel):
Expand Down Expand Up @@ -36,21 +36,30 @@ class ModelsResponse(BaseModel):

# TODO(lucasagomes): a lot of fields to add to QueryResponse. For now
# we are keeping it simple. The missing fields are:
# - referenced_documents: The optional URLs and titles for the documents used
# to generate the response.
# - truncated: Set to True if conversation history was truncated to be within context window.
# - input_tokens: Number of tokens sent to LLM
# - output_tokens: Number of tokens received from LLM
# - available_quotas: Quota available as measured by all configured quota limiters
# - tool_calls: List of tool requests.
# - tool_results: List of tool results.
# See LLMResponse in ols-service for more details.


class ReferencedDocument(BaseModel):
"""Model representing a document referenced in generating a response."""

doc_url: AnyUrl = Field(description="URL of the referenced document")
doc_title: str = Field(description="Title of the referenced document")


class QueryResponse(BaseModel):
"""Model representing LLM response to a query.

Attributes:
conversation_id: The optional conversation ID (UUID).
response: The response.
referenced_documents: The optional URLs and titles for the documents used
to generate the response.
"""

conversation_id: Optional[str] = Field(
Expand All @@ -66,13 +75,38 @@ class QueryResponse(BaseModel):
],
)

referenced_documents: list[ReferencedDocument] = Field(
default_factory=list,
description="List of documents referenced in generating the response",
examples=[
[
{
"doc_url": (
"https://docs.openshift.com/container-platform/"
"4.15/operators/olm/index.html"
),
"doc_title": "Operator Lifecycle Manager (OLM)",
}
]
],
)

# provides examples for /docs endpoint
model_config = {
"json_schema_extra": {
"examples": [
{
"conversation_id": "123e4567-e89b-12d3-a456-426614174000",
"response": "Operator Lifecycle Manager (OLM) helps users install...",
"referenced_documents": [
{
"doc_url": (
"https://docs.openshift.com/container-platform/"
"4.15/operators/olm/index.html"
),
"doc_title": "Operator Lifecycle Manager (OLM)",
}
],
}
]
}
Expand Down
Loading