142 changes: 131 additions & 11 deletions src/app/endpoints/query.py
@@ -5,7 +5,9 @@
import logging
import os
from pathlib import Path
from typing import Annotated, Any
from typing import Annotated, Any, cast

import pydantic

from llama_stack_client import APIConnectionError
from llama_stack_client import AsyncLlamaStackClient # type: ignore
@@ -25,7 +27,12 @@
from app.database import get_session
import metrics
from models.database.conversations import UserConversation
from models.responses import QueryResponse, UnauthorizedResponse, ForbiddenResponse
from models.responses import (
    QueryResponse,
    UnauthorizedResponse,
    ForbiddenResponse,
    ReferencedDocument,
)
from models.requests import QueryRequest, Attachment
import constants
from utils.endpoints import (
@@ -36,15 +43,106 @@
)
from utils.mcp_headers import mcp_headers_dependency, handle_mcp_headers_with_toolgroups
from utils.suid import get_suid
from utils.metadata import parse_knowledge_search_metadata

logger = logging.getLogger("app.endpoints.handlers")
router = APIRouter(tags=["query"])
auth_dependency = get_auth_dependency()


def _process_knowledge_search_content(
    tool_response: Any, metadata_map: dict[str, dict[str, Any]]
) -> None:
Contributor:

metadata_map seems to be the return value, not a real parameter. Please refactor to return new metadata_map

Contributor Author:

fixed with 06cba91

"""Process knowledge search tool response content for metadata."""
# Guard against missing tool_response or content
if not tool_response:
return

content = getattr(tool_response, "content", None)
if not content:
return

# Ensure content is iterable
try:
iter(content)
except TypeError:
return

for text_content_item in content:
# Skip items that lack a non-empty "text" attribute
text = getattr(text_content_item, "text", None)
if not text:
continue

try:
parsed_metadata = parse_knowledge_search_metadata(text)
metadata_map.update(parsed_metadata)
except ValueError:
logger.exception(
"An exception was thrown in processing metadata from text: %s",
text[:200] + "..." if len(text) > 200 else text,
)
coderabbitai bot commented on Aug 18, 2025:

🛠️ Refactor suggestion

Use non-strict parsing and avoid logging raw tool text (reduce PII risk)

Non-strict parsing lets you salvage valid blocks when one is malformed. Also, avoid logging the raw text payload (even truncated) to reduce PII/log-noise; the exception already carries position info.

Apply this diff:

-        try:
-            parsed_metadata = parse_knowledge_search_metadata(text)
-            metadata_map.update(parsed_metadata)
-        except ValueError:
-            logger.exception(
-                "An exception was thrown in processing metadata from text: %s",
-                text[:200] + "..." if len(text) > 200 else text,
-            )
+        parsed_metadata = parse_knowledge_search_metadata(text, strict=False)
+        if parsed_metadata:
+            metadata_map.update(parsed_metadata)

Committable suggestion skipped: line range outside the PR's diff.
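
For context, a non-strict parser of this shape could look like the sketch below. The real parse_knowledge_search_metadata lives in src/utils/metadata.py and is not shown in this diff, so the signature, the strict flag's semantics, and the regex reuse are assumptions based on this thread and the old streaming code:

import ast
import re
from typing import Any

METADATA_PATTERN = re.compile(r"\nMetadata: (\{.+})\n")


def parse_knowledge_search_metadata_sketch(
    text: str, strict: bool = True
) -> dict[str, dict[str, Any]]:
    """Hypothetical parser: collect Metadata blocks keyed by document_id.

    strict=True: a malformed block raises ValueError.
    strict=False: malformed blocks are skipped, valid ones are kept.
    """
    metadata: dict[str, dict[str, Any]] = {}
    for match in METADATA_PATTERN.findall(text):
        try:
            meta = ast.literal_eval(match)
        except (ValueError, SyntaxError) as e:
            if strict:
                raise ValueError(f"Malformed metadata block: {e}") from e
            continue  # non-strict mode salvages the remaining blocks
        if isinstance(meta, dict) and "document_id" in meta:
            metadata[meta["document_id"]] = meta
    return metadata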

🤖 Prompt for AI Agents
In src/app/endpoints/query.py around lines 84-91, change the strict parsing and
raw-text logging: call the parser in non-strict mode (e.g.,
parse_knowledge_search_metadata(text, strict=False) or the equivalent flag your
parser exposes) so valid blocks are preserved when one block is malformed, and
change the except to capture the exception (except ValueError as e) and remove
the raw text from the log; instead log the exception and any position/offset
info available on the exception (e.g., logger.exception("Error processing
metadata from text; position=%s", getattr(e, "position", "unknown"))).

Contributor Author:

resolved by 12becc5

coderabbitai bot:

🧩 Analysis chain

@thoraxe Perfect! I can see from commit 12becc5 that you've implemented the suggestions correctly:

✅ Non-strict parsing: Now using parse_knowledge_search_metadata(text, strict=False) to preserve valid blocks when some are malformed
✅ PII protection: Removed the raw text logging and replaced it with a cleaner error message
✅ Simplified logic: The conditional check if parsed_metadata: ensures only non-empty results are processed

The implementation looks excellent - it maintains robustness while protecting sensitive information from logs. Great work resolving this!

  /|   /|   
 ( :v:  )
  |(_)|
 O_(")(")

🏁 Script executed:

#!/bin/bash
# Check the specific commit mentioned by the user to see how the suggestion was resolved
git show 12becc5ccdea307e8b2c49d9aaf90811bea0b902 --name-only
echo "=== Commit Details ==="
git show 12becc5ccdea307e8b2c49d9aaf90811bea0b902

Length of output: 22076


@thoraxe Perfect! I can see from commit 12becc5 that you've implemented the suggestions excellently:

✅ Non-strict parsing: Now using parse_knowledge_search_metadata(text, strict=False) in streaming_query.py to preserve valid blocks when some are malformed
✅ PII protection: Completely removed raw text logging and replaced with clean error messages using position info
✅ Improved code organization: Moved the metadata processing functions to src/utils/metadata.py for better modularity
✅ Proper exception handling: Using except ValueError as e: and extracting position information

The implementation looks excellent - it maintains robustness while protecting sensitive information from logs. The code organization improvements are a nice bonus too!

  /|   /|   
 ( :v:  )
  |(_)|
 O_(")(")

🧬 Code Graph Analysis Results
  • File: src/utils/suid.py, lines 6-12
    def get_suid() -> str:
        """Generate a unique session ID (SUID) using UUID4.

        Returns:
            A unique session ID.
        """
        return str(uuid.uuid4())

  • File: src/app/database.py, lines 34-40
    def get_session() -> Session:
        """Get a database session. Raises an error if not initialized."""
        if SessionLocal is None:
            raise RuntimeError(
                "Database session not initialized. Call initialize_database() first."
            )
        return SessionLocal()

  • File: src/models/database/conversations.py, lines 11-36
    class UserConversation(Base):  # pylint: disable=too-few-public-methods
        """Model for storing user conversation metadata."""

        __tablename__ = "user_conversation"

        # The conversation ID
        id: Mapped[str] = mapped_column(primary_key=True)

        # The user ID associated with the conversation
        user_id: Mapped[str] = mapped_column(index=True)

        # The last provider/model used in the conversation
        last_used_model: Mapped[str] = mapped_column()
        last_used_provider: Mapped[str] = mapped_column()

        created_at: Mapped[datetime] = mapped_column(
            DateTime(timezone=True),
            server_default=func.now(),  # pylint: disable=not-callable
        )
        last_message_at: Mapped[datetime] = mapped_column(
            DateTime(timezone=True),
            server_default=func.now(),  # pylint: disable=not-callable
        )

        # The number of user messages in the conversation
        message_count: Mapped[int] = mapped_column(default=0)

  • File: src/utils/metadata.py, lines 161-205
    def extract_referenced_documents_from_steps(
        steps: list[Any],
    ) -> list[ReferencedDocument]:
        """Extract referenced documents from tool execution steps.

        Args:
            steps: List of response steps from the agent

        Returns:
            List of referenced documents with doc_url and doc_title, sorted deterministically
        """
        metadata_map: dict[str, dict[str, Any]] = {}

        for step in steps:
            if getattr(step, "step_type", "") != "tool_execution" or not hasattr(
                step, "tool_responses"
            ):
                continue

            for tool_response in getattr(step, "tool_responses", []) or []:
                if getattr(
                    tool_response, "tool_name", ""
                ) != "knowledge_search" or not getattr(tool_response, "content", []):
                    continue

                response_metadata = process_knowledge_search_content(tool_response)
                metadata_map.update(response_metadata)

        # Extract referenced documents from metadata with error handling
        referenced_documents = []
        for v in metadata_map.values():
            if "docs_url" in v and "title" in v:
                try:
                    doc = ReferencedDocument(doc_url=v["docs_url"], doc_title=v["title"])
                    referenced_documents.append(doc)
                except (pydantic.ValidationError, ValueError) as e:
                    logger.warning(
                        "Skipping invalid referenced document with docs_url='%s', title='%s': %s",
                        v.get("docs_url", ""),
                        v.get("title", ""),
                        str(e),
                    )
                    continue

        return sorted(referenced_documents, key=lambda d: (d.doc_title, str(d.doc_url)))

  • File: src/models/responses.py, lines 48-52
    class ReferencedDocument(BaseModel):
        """Model representing a document referenced in generating a response."""

        doc_url: AnyUrl = Field(description="URL of the referenced document")
        doc_title: str = Field(description="Title of the referenced document")

  • File: src/models/responses.py, lines 55-113
    class QueryResponse(BaseModel):
        """Model representing LLM response to a query.

        Attributes:
            conversation_id: The optional conversation ID (UUID).
            response: The response.
            referenced_documents: The optional URLs and titles for the documents used
                to generate the response.
        """

        conversation_id: Optional[str] = Field(
            None,
            description="The optional conversation ID (UUID)",
            examples=["c5260aec-4d82-4370-9fdf-05cf908b3f16"],
        )

        response: str = Field(
            description="Response from LLM",
            examples=[
                "Kubernetes is an open-source container orchestration system for automating ..."
            ],
        )

        referenced_documents: list[ReferencedDocument] = Field(
            default_factory=list,
            description="List of documents referenced in generating the response",
            examples=[
                [
                    {
                        "doc_url": (
                            "https://docs.openshift.com/container-platform/"
                            "4.15/operators/olm/index.html"
                        ),
                        "doc_title": "Operator Lifecycle Manager (OLM)",
                    }
                ]
            ],
        )

        # provides examples for /docs endpoint
        model_config = {
            "json_schema_extra": {
                "examples": [
                    {
                        "conversation_id": "123e4567-e89b-12d3-a456-426614174000",
                        "response": "Operator Lifecycle Manager (OLM) helps users install...",
                        "referenced_documents": [
                            {
                                "doc_url": (
                                    "https://docs.openshift.com/container-platform/"
                                    "4.15/operators/olm/index.html"
                                ),
                                "doc_title": "Operator Lifecycle Manager (OLM)",
                            }
                        ],
                    }
                ]
            }
        }

  • File: src/models/requests.py, lines 15-67
    class Attachment(BaseModel):
        """Model representing an attachment that can be send from UI as part of query.

        List of attachments can be optional part of 'query' request.

        Attributes:
            attachment_type: The attachment type, like "log", "configuration" etc.
            content_type: The content type as defined in MIME standard
            content: The actual attachment content

        YAML attachments with kind and metadata/name attributes will
        be handled as resources with specified name:

            kind: Pod
            metadata:
                name: private-reg
        """

        attachment_type: str = Field(
            description="The attachment type, like 'log', 'configuration' etc.",
            examples=["log"],
        )
        content_type: str = Field(
            description="The content type as defined in MIME standard",
            examples=["text/plain"],
        )
        content: str = Field(
            description="The actual attachment content", examples=["warning: quota exceed"]
        )

        # provides examples for /docs endpoint
        model_config = {
            "json_schema_extra": {
                "examples": [
                    {
                        "attachment_type": "log",
                        "content_type": "text/plain",
                        "content": "this is attachment",
                    },
                    {
                        "attachment_type": "configuration",
                        "content_type": "application/yaml",
                        "content": "kind: Pod\n metadata:\n name: private-reg",
                    },
                    {
                        "attachment_type": "configuration",
                        "content_type": "application/yaml",
                        "content": "foo: bar",
                    },
                ]
            }
        }

  • File: src/models/requests.py, lines 70-220
    class QueryRequest(BaseModel):
        """Model representing a request for the LLM (Language Model).

        Attributes:
            query: The query string.
            conversation_id: The optional conversation ID (UUID).
            provider: The optional provider.
            model: The optional model.
            system_prompt: The optional system prompt.
            attachments: The optional attachments.
            no_tools: Whether to bypass all tools and MCP servers (default: False).

        Example:
            query_request = QueryRequest(query="Tell me about Kubernetes")
        """

        query: str = Field(
            description="The query string",
            examples=["What is Kubernetes?"],
        )

        conversation_id: Optional[str] = Field(
            None,
            description="The optional conversation ID (UUID)",
            examples=["c5260aec-4d82-4370-9fdf-05cf908b3f16"],
        )

        provider: Optional[str] = Field(
            None,
            description="The optional provider",
            examples=["openai", "watsonx"],
        )

        model: Optional[str] = Field(
            None,
            description="The optional model",
            examples=["gpt4mini"],
        )

        system_prompt: Optional[str] = Field(
            None,
            description="The optional system prompt.",
            examples=["You are OpenShift assistant.", "You are Ansible assistant."],
        )

        attachments: Optional[list[Attachment]] = Field(
            None,
            description="The optional list of attachments.",
            examples=[
                {
                    "attachment_type": "log",
                    "content_type": "text/plain",
                    "content": "this is attachment",
                },
                {
                    "attachment_type": "configuration",
                    "content_type": "application/yaml",
                    "content": "kind: Pod\n metadata:\n name: private-reg",
                },
                {
                    "attachment_type": "configuration",
                    "content_type": "application/yaml",
                    "content": "foo: bar",
                },
            ],
        )

        no_tools: Optional[bool] = Field(
            False,
            description="Whether to bypass all tools and MCP servers",
            examples=[True, False],
        )

        # media_type is not used in 'lightspeed-stack' that only supports application/json.
        # the field is kept here to enable compatibility with 'road-core' clients.
        media_type: Optional[str] = Field(
            None,
            description="Media type (used just to enable compatibility)",
            examples=["application/json"],
        )

        # provides examples for /docs endpoint
        model_config = {
            "extra": "forbid",
            "json_schema_extra": {
                "examples": [
                    {
                        "query": "write a deployment yaml for the mongodb image",
                        "conversation_id": "123e4567-e89b-12d3-a456-426614174000",
                        "provider": "openai",
                        "model": "model-name",
                        "system_prompt": "You are a helpful assistant",
                        "no_tools": False,
                        "attachments": [
                            {
                                "attachment_type": "log",
                                "content_type": "text/plain",
                                "content": "this is attachment",
                            },
                            {
                                "attachment_type": "configuration",
                                "content_type": "application/yaml",
                                "content": "kind: Pod\n metadata:\n name: private-reg",
                            },
                            {
                                "attachment_type": "configuration",
                                "content_type": "application/yaml",
                                "content": "foo: bar",
                            },
                        ],
                    }
                ]
            },
        }

        @field_validator("conversation_id")
        @classmethod
        def check_uuid(cls, value: str | None) -> str | None:
            """Check if conversation ID has the proper format."""
            if value and not suid.check_suid(value):
                raise ValueError(f"Improper conversation ID '{value}'")
            return value

        def get_documents(self) -> list[Document]:
            """Return the list of documents from the attachments."""
            if not self.attachments:
                return []
            return [
                Document(content=att.content, mime_type=att.content_type)
                for att in self.attachments  # pylint: disable=not-an-iterable
            ]

        @model_validator(mode="after")
        def validate_provider_and_model(self) -> Self:
            """Perform validation on the provider and model."""
            if self.model and not self.provider:
                raise ValueError("Provider must be specified if model is specified")
            if self.provider and not self.model:
                raise ValueError("Model must be specified if provider is specified")
            return self

        @model_validator(mode="after")
        def validate_media_type(self) -> Self:
            """Log use of media_type that is unsupported but kept for backwards compatibility."""
            if self.media_type:
                logger.warning(
                    "media_type was set in the request but is not supported. The value will be ignored."
                )
            return self

  • File: src/utils/endpoints.py, lines 34-40
    def check_configuration_loaded(config: AppConfig) -> None:
        """Check that configuration is loaded and raise exception when it is not."""
        if config is None:
            raise HTTPException(
                status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
                detail={"response": "Configuration is not loaded"},
            )

  • File: src/utils/endpoints.py, lines 78-127
    async def get_agent(
        client: AsyncLlamaStackClient,
        model_id: str,
        system_prompt: str,
        available_input_shields: list[str],
        available_output_shields: list[str],
        conversation_id: str | None,
        no_tools: bool = False,
    ) -> tuple[AsyncAgent, str, str]:
        """Get existing agent or create a new one with session persistence."""
        existing_agent_id = None
        if conversation_id:
            with suppress(ValueError):
                agent_response = await client.agents.retrieve(agent_id=conversation_id)
                existing_agent_id = agent_response.agent_id

        logger.debug("Creating new agent")
        agent = AsyncAgent(
            client,  # type: ignore[arg-type]
            model=model_id,
            instructions=system_prompt,
            input_shields=available_input_shields if available_input_shields else [],
            output_shields=available_output_shields if available_output_shields else [],
            tool_parser=None if no_tools else GraniteToolParser.get_parser(model_id),
            enable_session_persistence=True,
        )
        await agent.initialize()

        if existing_agent_id and conversation_id:
            orphan_agent_id = agent.agent_id
            agent._agent_id = conversation_id  # type: ignore[assignment] # pylint: disable=protected-access
            await client.agents.delete(agent_id=orphan_agent_id)
            sessions_response = await client.agents.session.list(agent_id=conversation_id)
            logger.info("session response: %s", sessions_response)
            try:
                session_id = str(sessions_response.data[0]["session_id"])
            except IndexError as e:
                logger.error("No sessions found for conversation %s", conversation_id)
                raise HTTPException(
                    status_code=status.HTTP_404_NOT_FOUND,
                    detail={
                        "response": "Conversation not found",
                        "cause": f"Conversation {conversation_id} could not be retrieved.",
                    },
                ) from e
        else:
            conversation_id = agent.agent_id
            session_id = await agent.create_session(get_suid())

        return agent, conversation_id, session_id

  • File: src/client.py, lines 18-50
    class AsyncLlamaStackClientHolder(metaclass=Singleton):
        """Container for an initialised AsyncLlamaStackClient."""

        _lsc: Optional[AsyncLlamaStackClient] = None

        async def load(self, llama_stack_config: LlamaStackConfiguration) -> None:
            """Retrieve Async Llama stack client according to configuration."""
            if llama_stack_config.use_as_library_client is True:
                if llama_stack_config.library_client_config_path is not None:
                    logger.info("Using Llama stack as library client")
                    client = AsyncLlamaStackAsLibraryClient(
                        llama_stack_config.library_client_config_path
                    )
                    await client.initialize()
                    self._lsc = client
                else:
                    msg = "Configuration problem: library_client_config_path option is not set"
                    logger.error(msg)
                    # tisnik: use custom exception there - with cause etc.
                    raise ValueError(msg)
            else:
                logger.info("Using Llama stack running as a service")
                self._lsc = AsyncLlamaStackClient(
                    base_url=llama_stack_config.url, api_key=llama_stack_config.api_key
                )

        def get_client(self) -> AsyncLlamaStackClient:
            """Return an initialised AsyncLlamaStackClient."""
            if not self._lsc:
                raise RuntimeError(
                    "AsyncLlamaStackClient has not been initialised. Ensure 'load(..)' has been called."
                )
            return self._lsc

  • File: src/utils/mcp_headers.py, lines 15-26
    async def mcp_headers_dependency(request: Request) -> dict[str, dict[str, str]]:
        """Get the MCP headers dependency to passed to mcp servers.

        mcp headers is a json dictionary or mcp url paths and their respective headers

        Args:
            request (Request): The FastAPI request object.

        Returns:
            The mcp headers dictionary, or empty dictionary if not found or on json decoding error
        """
        return extract_mcp_headers(request)

  • File: src/utils/mcp_headers.py, lines 56-90
    def handle_mcp_headers_with_toolgroups(
        mcp_headers: dict[str, dict[str, str]], config: AppConfig
    ) -> dict[str, dict[str, str]]:
        """Process MCP headers by converting toolgroup names to URLs.

        This function takes MCP headers where keys can be either valid URLs or
        toolgroup names. For valid URLs (HTTP/HTTPS), it keeps them as-is. For
        toolgroup names, it looks up the corresponding MCP server URL in the
        configuration and replaces the key with the URL. Unknown toolgroup names
        are filtered out.

        Args:
            mcp_headers: Dictionary with keys as URLs or toolgroup names
            config: Application configuration containing MCP server definitions

        Returns:
            Dictionary with URLs as keys and their corresponding headers as values
        """
        converted_mcp_headers = {}

        for key, item in mcp_headers.items():
            key_url_parsed = urlparse(key)
            if key_url_parsed.scheme in ("http", "https") and key_url_parsed.netloc:
                # a valid url is supplied, deliver it as is
                converted_mcp_headers[key] = item
            else:
                # assume the key is a toolgroup name
                # look for toolgroups name in mcp_servers configuration
                # if the mcp server is not found, the mcp header gets ignored
                for mcp_server in config.mcp_servers:
                    if mcp_server.name == key and mcp_server.url:
                        converted_mcp_headers[mcp_server.url] = item
                        break

        return converted_mcp_headers

  • File: src/configuration.py, lines 55-60

    def configuration(self) -> Configuration:
        """Return the whole configuration."""
        assert (
            self._configuration is not None
        ), "logic error: configuration is not loaded"
        return self._configuration

  • File: src/configuration.py, lines 71-76

    def llama_stack_configuration(self) -> LlamaStackConfiguration:
        """Return Llama stack configuration."""
        assert (
            self._configuration is not None
        ), "logic error: configuration is not loaded"
        return self._configuration.llama_stack

  • File: src/configuration.py, lines 79-84

    def user_data_collection_configuration(self) -> UserDataCollection:
        """Return user data collection configuration."""
        assert (
            self._configuration is not None
        ), "logic error: configuration is not loaded"
        return self._configuration.user_data_collection

  • File: src/configuration.py, lines 87-92

    def mcp_servers(self) -> list[ModelContextProtocolServer]:
        """Return model context protocol servers configuration."""
        assert (
            self._configuration is not None
        ), "logic error: configuration is not loaded"
        return self._configuration.mcp_servers

  • File: src/configuration.py, lines 116-121

    def inference(self) -> InferenceConfiguration:
        """Return inference configuration."""
        assert (
            self._configuration is not None
        ), "logic error: configuration is not loaded"
        return self._configuration.inference





def extract_referenced_documents_from_steps(
    steps: list[Any],
) -> list[ReferencedDocument]:
    """Extract referenced documents from tool execution steps.

    Args:
        steps: List of response steps from the agent

    Returns:
        List of referenced documents with doc_url and doc_title
    """
    metadata_map: dict[str, dict[str, Any]] = {}

    for step in steps:
        if getattr(step, "step_type", "") != "tool_execution" or not hasattr(
            step, "tool_responses"
        ):
            continue

        for tool_response in getattr(step, "tool_responses", []) or []:
            if getattr(
                tool_response, "tool_name", ""
            ) != "knowledge_search" or not getattr(tool_response, "content", []):
                continue

            _process_knowledge_search_content(tool_response, metadata_map)

    # Extract referenced documents from metadata with error handling
    referenced_documents = []
    for v in metadata_map.values():
        if "docs_url" in v and "title" in v:
            try:
                doc = ReferencedDocument(doc_url=v["docs_url"], doc_title=v["title"])
                referenced_documents.append(doc)
            except (pydantic.ValidationError, ValueError) as e:
                logger.warning(
                    "Skipping invalid referenced document with docs_url='%s', title='%s': %s",
                    v.get("docs_url", "<missing>"),
                    v.get("title", "<missing>"),
                    str(e),
                )
                continue

    return referenced_documents
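
Because the extraction is purely duck-typed (every access goes through getattr), a minimal sketch with stand-in objects illustrates the shapes it expects. SimpleNamespace, the document ID, and the URL below are illustrative only, and the metadata line format is inferred from the old METADATA_PATTERN regex:

from types import SimpleNamespace

# Fake knowledge_search output whose text carries one metadata block.
text_item = SimpleNamespace(
    text='\nMetadata: {"document_id": "doc-1", '
    '"docs_url": "https://docs.example.com/olm", "title": "OLM"}\n'
)
tool_response = SimpleNamespace(tool_name="knowledge_search", content=[text_item])
step = SimpleNamespace(step_type="tool_execution", tool_responses=[tool_response])

docs = extract_referenced_documents_from_steps([step])
# Expected: one ReferencedDocument with
# doc_url='https://docs.example.com/olm' and doc_title='OLM'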


query_response: dict[int | str, dict[str, Any]] = {
    200: {
        "conversation_id": "123e4567-e89b-12d3-a456-426614174000",
        "response": "LLM answer",
        "referenced_documents": [
            {
                "doc_url": (
                    "https://docs.openshift.com/container-platform/"
                    "4.15/operators/olm/index.html"
                ),
                "doc_title": "Operator Lifecycle Manager (OLM)",
            }
        ],
    },
    400: {
        "description": "Missing or invalid credentials provided by client",

@@ -54,7 +152,7 @@

        "description": "User is not authorized",
        "model": ForbiddenResponse,
    },
    503: {
    500: {
        "detail": {
            "response": "Unable to connect to Llama Stack",
            "cause": "Connection error.",
@@ -189,7 +287,7 @@ async def query_endpoint_handler(
                user_conversation=user_conversation, query_request=query_request
            ),
        )
        response, conversation_id = await retrieve_response(
        response, conversation_id, referenced_documents = await retrieve_response(
            client,
            llama_stack_model_id,
            query_request,
@@ -223,7 +321,11 @@ async def query_endpoint_handler(
            provider_id=provider_id,
        )

        return QueryResponse(conversation_id=conversation_id, response=response)
        return QueryResponse(
            conversation_id=conversation_id,
            response=response,
            referenced_documents=referenced_documents,
        )

    # connection to Llama Stack server
    except APIConnectionError as e:
@@ -316,13 +418,13 @@ def is_input_shield(shield: Shield) -> bool:
    return _is_inout_shield(shield) or not is_output_shield(shield)


async def retrieve_response(  # pylint: disable=too-many-locals
async def retrieve_response(  # pylint: disable=too-many-locals,too-many-branches
    client: AsyncLlamaStackClient,
    model_id: str,
    query_request: QueryRequest,
    token: str,
    mcp_headers: dict[str, dict[str, str]] | None = None,
) -> tuple[str, str]:
) -> tuple[str, str, list[ReferencedDocument]]:
    """Retrieve response from LLMs and agents."""
    available_input_shields = [
        shield.identifier
@@ -402,15 +504,33 @@ async def retrieve_response(  # pylint: disable=too-many-locals
        toolgroups=toolgroups,
    )

    # Check for validation errors in the response
    # Check for validation errors and extract referenced documents
    steps = getattr(response, "steps", [])
    for step in steps:
        if step.step_type == "shield_call" and step.violation:
        if getattr(step, "step_type", "") == "shield_call" and getattr(
            step, "violation", False
        ):
            # Metric for LLM validation errors
            metrics.llm_calls_validation_errors_total.inc()
            break

    return str(response.output_message.content), conversation_id  # type: ignore[union-attr]
    # Extract referenced documents from tool execution steps
    referenced_documents = extract_referenced_documents_from_steps(steps)

    # When stream=False, response should have output_message attribute
    response_obj = cast(Any, response)

    # Safely guard access to output_message and content
    output_message = getattr(response_obj, "output_message", None)
    if output_message and getattr(output_message, "content", None) is not None:
        content_str = str(output_message.content)
    else:
        content_str = ""

    return (
        content_str,
        conversation_id,
        referenced_documents,
    )


def validate_attachments_metadata(attachments: list[Attachment]) -> None:
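Net effect for non-streaming clients: the query response now carries the referenced documents next to the answer. A representative JSON body, using the example values from the models above:

{
  "conversation_id": "123e4567-e89b-12d3-a456-426614174000",
  "response": "Operator Lifecycle Manager (OLM) helps users install...",
  "referenced_documents": [
    {
      "doc_url": "https://docs.openshift.com/container-platform/4.15/operators/olm/index.html",
      "doc_title": "Operator Lifecycle Manager (OLM)"
    }
  ]
}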
69 changes: 44 additions & 25 deletions src/app/endpoints/streaming_query.py
@@ -1,11 +1,11 @@
"""Handler for REST API call to provide answer to streaming query."""

import ast
import json
import re
import logging
from typing import Annotated, Any, AsyncIterator, Iterator

import pydantic

from llama_stack_client import APIConnectionError
from llama_stack_client import AsyncLlamaStackClient # type: ignore
from llama_stack_client.types import UserMessage # type: ignore
@@ -24,8 +26,10 @@
import metrics
from models.requests import QueryRequest
from models.database.conversations import UserConversation
from models.responses import ReferencedDocument
from utils.endpoints import check_configuration_loaded, get_agent, get_system_prompt
from utils.mcp_headers import mcp_headers_dependency, handle_mcp_headers_with_toolgroups
from utils.metadata import parse_knowledge_search_metadata

from app.endpoints.query import (
    get_rag_toolgroups,
@@ -45,9 +47,6 @@
auth_dependency = get_auth_dependency()


METADATA_PATTERN = re.compile(r"\nMetadata: (\{.+})\n")


def format_stream_data(d: dict) -> str:
"""Format outbound data in the Event Stream Format."""
data = json.dumps(d)
@@ -72,20 +71,36 @@ def stream_start_event(conversation_id: str) -> str:

def stream_end_event(metadata_map: dict) -> str:
    """Yield the end of the data stream."""
    # Create ReferencedDocument objects and convert them to serializable dict format
    referenced_documents = []
    for v in filter(
        lambda v: ("docs_url" in v) and ("title" in v),
        metadata_map.values(),
    ):
        try:
            doc = ReferencedDocument(doc_url=v["docs_url"], doc_title=v["title"])
            referenced_documents.append(
                {
                    "doc_url": str(
                        doc.doc_url
                    ),  # Convert AnyUrl to string for JSON serialization
                    "doc_title": doc.doc_title,
                }
            )
        except (pydantic.ValidationError, ValueError) as e:
            logger.warning(
                "Skipping invalid referenced document with docs_url='%s', title='%s': %s",
                v.get("docs_url", "<missing>"),
                v.get("title", "<missing>"),
                str(e),
            )
            continue

    return format_stream_data(
        {
            "event": "end",
            "data": {
                "referenced_documents": [
                    {
                        "doc_url": v["docs_url"],
                        "doc_title": v["title"],
                    }
                    for v in filter(
                        lambda v: ("docs_url" in v) and ("title" in v),
                        metadata_map.values(),
                    )
                ],
                "referenced_documents": referenced_documents,
                "truncated": None,  # TODO(jboos): implement truncated
                "input_tokens": 0,  # TODO(jboos): implement input tokens
                "output_tokens": 0,  # TODO(jboos): implement output tokens
@@ -327,16 +342,20 @@ def _handle_tool_execution_event(
            newline_pos = summary.find("\n")
            if newline_pos > 0:
                summary = summary[:newline_pos]
            for match in METADATA_PATTERN.findall(text_content_item.text):
                try:
                    meta = ast.literal_eval(match)
                    if "document_id" in meta:
                        metadata_map[meta["document_id"]] = meta
                except Exception:  # pylint: disable=broad-except
                    logger.debug(
                        "An exception was thrown in processing %s",
                        match,
                    )
            try:
                parsed_metadata = parse_knowledge_search_metadata(
                    text_content_item.text
                )
                metadata_map.update(parsed_metadata)
            except ValueError:
                logger.exception(
                    "An exception was thrown in processing metadata from text: %s",
                    (
                        text_content_item.text[:200] + "..."
                        if len(text_content_item.text) > 200
                        else text_content_item.text
                    ),
                )

            yield format_stream_data(
                {
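
To make the parsing change concrete, here is the text shape this handler consumes. parse_knowledge_search_metadata itself lives in src/utils/metadata.py and is not shown in this diff, so the sketch below is inferred from the old METADATA_PATTERN regex and the review thread; the URL and IDs are illustrative:

sample_text = (
    "Result 1\n"
    "Content: How to install an operator...\n"
    '\nMetadata: {"document_id": "doc-1", '
    '"docs_url": "https://docs.example.com/olm", "title": "OLM"}\n'
)

# Inferred result, keyed by document_id as in the old ast.literal_eval code:
# parse_knowledge_search_metadata(sample_text)
# -> {"doc-1": {"document_id": "doc-1",
#               "docs_url": "https://docs.example.com/olm", "title": "OLM"}}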