-
Notifications
You must be signed in to change notification settings - Fork 49
implements referenced documents on /query and updates /streaming_query to match #403
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We'll occasionally send you account-related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from 9 commits
f022501
8c0a1d0
d6e1475
866bb5c
3026397
d5e2622
a7f815e
af646c6
06cba91
1f8fbb9
1256193
1fa7b87
dcf4f23
12becc5
0291e11
a463e6f
36a5604
622ae9d
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -5,7 +5,9 @@ | |
| import logging | ||
| import os | ||
| from pathlib import Path | ||
| from typing import Annotated, Any | ||
| from typing import Annotated, Any, cast | ||
|
|
||
| import pydantic | ||
|
|
||
| from llama_stack_client import APIConnectionError | ||
| from llama_stack_client import AsyncLlamaStackClient # type: ignore | ||
|
|
@@ -25,7 +27,12 @@ | |
| from app.database import get_session | ||
| import metrics | ||
| from models.database.conversations import UserConversation | ||
| from models.responses import QueryResponse, UnauthorizedResponse, ForbiddenResponse | ||
| from models.responses import ( | ||
| QueryResponse, | ||
| UnauthorizedResponse, | ||
| ForbiddenResponse, | ||
| ReferencedDocument, | ||
| ) | ||
| from models.requests import QueryRequest, Attachment | ||
| import constants | ||
| from utils.endpoints import ( | ||
|
|
@@ -36,15 +43,116 @@ | |
| ) | ||
| from utils.mcp_headers import mcp_headers_dependency, handle_mcp_headers_with_toolgroups | ||
| from utils.suid import get_suid | ||
| from utils.metadata import parse_knowledge_search_metadata | ||
|
|
||
| logger = logging.getLogger("app.endpoints.handlers") | ||
| router = APIRouter(tags=["query"]) | ||
| auth_dependency = get_auth_dependency() | ||
|
|
||
|
|
||
| def _process_knowledge_search_content(tool_response: Any) -> dict[str, dict[str, Any]]: | ||
| """Process knowledge search tool response content for metadata. | ||
|
|
||
| Args: | ||
| tool_response: Tool response object containing content to parse | ||
|
|
||
| Returns: | ||
| Dictionary mapping document_id to metadata dict | ||
| """ | ||
| metadata_map: dict[str, dict[str, Any]] = {} | ||
|
|
||
| # Guard against missing tool_response or content | ||
| if not tool_response: | ||
| return metadata_map | ||
|
|
||
| content = getattr(tool_response, "content", None) | ||
| if not content: | ||
| return metadata_map | ||
|
|
||
| # Ensure content is iterable | ||
| try: | ||
| iter(content) | ||
| except TypeError: | ||
| return metadata_map | ||
|
|
||
| for text_content_item in content: | ||
| # Skip items that lack a non-empty "text" attribute | ||
| text = getattr(text_content_item, "text", None) | ||
| if not text: | ||
| continue | ||
|
|
||
| try: | ||
| parsed_metadata = parse_knowledge_search_metadata(text) | ||
| metadata_map.update(parsed_metadata) | ||
| except ValueError: | ||
| logger.exception( | ||
| "An exception was thrown in processing metadata from text: %s", | ||
| text[:200] + "..." if len(text) > 200 else text, | ||
| ) | ||
|
||
|
|
||
| return metadata_map | ||
|
|
||
|
|
||
def extract_referenced_documents_from_steps(
    steps: list[Any],
) -> list[ReferencedDocument]:
    """Extract referenced documents from tool execution steps.

    Args:
        steps: List of response steps from the agent.

    Returns:
        List of referenced documents with doc_url and doc_title.
    """
    # Phase 1: collect metadata from every knowledge_search tool response.
    collected: dict[str, dict[str, Any]] = {}
    for step in steps:
        if getattr(step, "step_type", "") != "tool_execution":
            continue
        if not hasattr(step, "tool_responses"):
            continue
        responses = getattr(step, "tool_responses", []) or []
        for response in responses:
            if getattr(response, "tool_name", "") != "knowledge_search":
                continue
            if not getattr(response, "content", []):
                continue
            collected.update(_process_knowledge_search_content(response))

    # Phase 2: build document models, skipping entries that fail validation.
    documents: list[ReferencedDocument] = []
    for entry in collected.values():
        if "docs_url" not in entry or "title" not in entry:
            continue
        try:
            documents.append(
                ReferencedDocument(doc_url=entry["docs_url"], doc_title=entry["title"])
            )
        except (pydantic.ValidationError, ValueError) as exc:
            logger.warning(
                "Skipping invalid referenced document with docs_url='%s', title='%s': %s",
                entry.get("docs_url", "<missing>"),
                entry.get("title", "<missing>"),
                str(exc),
            )
            continue

    return documents
|
|
||
|
|
||
| query_response: dict[int | str, dict[str, Any]] = { | ||
| 200: { | ||
| "conversation_id": "123e4567-e89b-12d3-a456-426614174000", | ||
| "response": "LLM answer", | ||
| "referenced_documents": [ | ||
| { | ||
| "doc_url": ( | ||
| "https://docs.openshift.com/container-platform/" | ||
| "4.15/operators/olm/index.html" | ||
| ), | ||
| "doc_title": "Operator Lifecycle Manager (OLM)", | ||
| } | ||
| ], | ||
| }, | ||
| 400: { | ||
| "description": "Missing or invalid credentials provided by client", | ||
|
|
@@ -54,7 +162,7 @@ | |
| "description": "User is not authorized", | ||
| "model": ForbiddenResponse, | ||
| }, | ||
| 503: { | ||
| 500: { | ||
| "detail": { | ||
| "response": "Unable to connect to Llama Stack", | ||
| "cause": "Connection error.", | ||
|
|
@@ -189,7 +297,7 @@ async def query_endpoint_handler( | |
| user_conversation=user_conversation, query_request=query_request | ||
| ), | ||
| ) | ||
| response, conversation_id = await retrieve_response( | ||
| response, conversation_id, referenced_documents = await retrieve_response( | ||
| client, | ||
| llama_stack_model_id, | ||
| query_request, | ||
|
|
@@ -223,7 +331,11 @@ async def query_endpoint_handler( | |
| provider_id=provider_id, | ||
| ) | ||
|
|
||
| return QueryResponse(conversation_id=conversation_id, response=response) | ||
| return QueryResponse( | ||
| conversation_id=conversation_id, | ||
| response=response, | ||
| referenced_documents=referenced_documents, | ||
| ) | ||
|
|
||
| # connection to Llama Stack server | ||
| except APIConnectionError as e: | ||
|
|
@@ -316,13 +428,13 @@ def is_input_shield(shield: Shield) -> bool: | |
| return _is_inout_shield(shield) or not is_output_shield(shield) | ||
|
|
||
|
|
||
| async def retrieve_response( # pylint: disable=too-many-locals | ||
| async def retrieve_response( # pylint: disable=too-many-locals,too-many-branches | ||
| client: AsyncLlamaStackClient, | ||
| model_id: str, | ||
| query_request: QueryRequest, | ||
| token: str, | ||
| mcp_headers: dict[str, dict[str, str]] | None = None, | ||
| ) -> tuple[str, str]: | ||
| ) -> tuple[str, str, list[ReferencedDocument]]: | ||
| """Retrieve response from LLMs and agents.""" | ||
| available_input_shields = [ | ||
| shield.identifier | ||
|
|
@@ -402,15 +514,33 @@ async def retrieve_response( # pylint: disable=too-many-locals | |
| toolgroups=toolgroups, | ||
| ) | ||
|
|
||
| # Check for validation errors in the response | ||
| # Check for validation errors and extract referenced documents | ||
| steps = getattr(response, "steps", []) | ||
| for step in steps: | ||
| if step.step_type == "shield_call" and step.violation: | ||
| if getattr(step, "step_type", "") == "shield_call" and getattr( | ||
| step, "violation", False | ||
| ): | ||
| # Metric for LLM validation errors | ||
| metrics.llm_calls_validation_errors_total.inc() | ||
| break | ||
|
|
||
| return str(response.output_message.content), conversation_id # type: ignore[union-attr] | ||
| # Extract referenced documents from tool execution steps | ||
| referenced_documents = extract_referenced_documents_from_steps(steps) | ||
|
|
||
| # When stream=False, response should have output_message attribute | ||
| response_obj = cast(Any, response) | ||
|
|
||
| # Safely guard access to output_message and content | ||
| output_message = getattr(response_obj, "output_message", None) | ||
| if output_message and getattr(output_message, "content", None) is not None: | ||
| content_str = str(output_message.content) | ||
| else: | ||
| content_str = "" | ||
|
|
||
| return ( | ||
| content_str, | ||
| conversation_id, | ||
| referenced_documents, | ||
| ) | ||
|
|
||
|
|
||
| def validate_attachments_metadata(attachments: list[Attachment]) -> None: | ||
|
|
||
Uh oh!
There was an error while loading. Please reload this page.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
π οΈ Refactor suggestion
Handle string content and parse before iterating to avoid char-wise iteration
If tool_response.content is ever a string, iter(content) will succeed and you'll end up iterating characters, skipping all metadata. Handle the string case explicitly and parse it in non-strict mode for resilience.
Apply this diff:
content = getattr(tool_response, "content", None) if not content: return metadata_map + # Handle the case where content is a single text blob + if isinstance(content, str): + parsed = parse_knowledge_search_metadata(content, strict=False) + if parsed: + metadata_map.update(parsed) + return metadata_map + # Ensure content is iterable try: iter(content) except TypeError: return metadata_mapπ Committable suggestion
π€ Prompt for AI Agents
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
resolved by 12becc5
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
π§© Analysis chain
I'll check the commit to see how you resolved the string content handling issue.
@thoraxe, excellent! I can see you've resolved the string content handling issue in commit 12becc5. The implementation properly handles the case where
`content` is a string by parsing it directly with `parse_knowledge_search_metadata(content, strict=False)` before attempting to iterate, which prevents the problematic char-by-char iteration I flagged. The solution looks clean and follows the non-strict parsing approach for resilience. Thank you for addressing this potential bug!
π Script executed:
Length of output: 5914
Let me check the new consolidated implementation in the metadata utility to see how the string handling was resolved:
@thoraxe, this is an excellent solution! I can see you took a more comprehensive approach than my original suggestion. Instead of just fixing the string iteration bug in place, you've:
Consolidated both `_process_knowledge_search_content` and `extract_referenced_documents_from_steps` into the dedicated `utils/metadata.py` module. The new `extract_referenced_documents_from_steps` import from `utils.metadata` shows this is now properly consolidated. This is actually a superior solution to my targeted fix - you've addressed both the immediate bug and the underlying architectural issue that would have led to similar problems in the future.
extract_referenced_documents_from_stepsimport fromutils.metadatashows this is now properly consolidated. This is actually a superior solution to my targeted fix - you've addressed both the immediate bug and the underlying architectural issue that would have led to similar problems in the future.Great work on the refactoring!
π Script executed:
Length of output: 2955
Perfect! I can see the string handling fix was indeed implemented in the consolidated
process_knowledge_search_contentfunction. The solution addresses the exact issue I flagged:This implementation:
isinstance(content, str)json.loads(content, strict=False)for resilienceiter()checkThe architectural refactoring combined with the targeted string handling fix is exactly what was needed. You've successfully resolved both the immediate bug and improved the overall code structure!
This consolidation approach also ensures the fix applies consistently across both query and streaming_query endpoints. Well done!
𧬠Code Graph Analysis Results
File: src/utils/metadata.py, Lines: 161-205
Snippet: extract_referenced_documents_from_steps
"""
def extract_referenced_documents_from_steps(
steps: list[Any],
) -> list[ReferencedDocument]:
"""Extract referenced documents from tool execution steps.
"""
File: src/app/endpoints/query.py, (retrieve_response function)
Snippet: Summary of purpose and behavior
Provided for quick context on how the non-streaming query path constructs and returns a response.
(Note: Full code not pasted here; this is a concise behavioral summary for reviewer context.)
File: src/utils/mcp_headers.py, Lines: 15-26
Snippet: mcp_headers_dependency
"""
async def mcp_headers_dependency(request: Request) -> dict[str, dict[str, str]]:
"""Get the MCP headers dependency to passed to mcp servers.
"""
File: src/utils/mcp_headers.py, Lines: 56-90
Snippet: handle_mcp_headers_with_toolgroups
"""
def handle_mcp_headers_with_toolgroups(
mcp_headers: dict[str, dict[str, str]], config: AppConfig
) -> dict[str, dict[str, str]]:
"""Process MCP headers by converting toolgroup names to URLs.
"""
File: src/models/responses.py, Lines: 48-52
Snippet: ReferencedDocument
"""
class ReferencedDocument(BaseModel):
"""Model representing a document referenced in generating a response."""
"""
File: src/models/responses.py, Lines: 55-113
Snippet: QueryResponse
"""
class QueryResponse(BaseModel):
"""Model representing LLM response to a query.
"""
File: src/models/responses.py, Lines: 353-367
Snippet: UnauthorizedResponse
"""
class UnauthorizedResponse(BaseModel):
"""Model representing response for missing or invalid credentials."""
"""
File: src/models/responses.py, Lines: 370-382
Snippet: ForbiddenResponse
"""
class ForbiddenResponse(UnauthorizedResponse):
"""Model representing response for forbidden access."""
"""
File: src/models/requests.py, Lines: 70-220
Snippet: QueryRequest (summary)
Summary: Model representing a request for the LLM.
File: src/models/requests.py, Lines: 15-67
Snippet: Attachment
"""
class Attachment(BaseModel):
"""Model representing an attachment that can be send from UI as part of query.
"""
File: src/app/database.py, Lines: 34-40
Snippet: get_session
"""
def get_session() -> Session:
"""Get a database session. Raises an error if not initialized."""
if SessionLocal is None:
raise RuntimeError(
"Database session not initialized. Call initialize_database() first."
)
return SessionLocal()
"""
File: src/models/database/conversations.py, Lines: 11-36
Snippet: UserConversation
"""
class UserConversation(Base): # pylint: disable=too-few-public-methods
"""Model for storing user conversation metadata."""
"""
File: src/utils/suid.py, Lines: 6-12
Snippet: get_suid
"""
def get_suid() -> str:
"""Generate a unique session ID (SUID) using UUID4.
"""
File: src/client.py, Lines: 18-50 and 44-50
Snippet: AsyncLlamaStackClientHolder (class)
Summary:
File: src/utils/endpoints.py, Lines: 34-40
Snippet: check_configuration_loaded
"""
def check_configuration_loaded(config: AppConfig) -> None:
"""Check that configuration is loaded and raise exception when it is not."""
if config is None:
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail={"response": "Configuration is not loaded"},
)
"""
File: src/utils/endpoints.py, Lines: 78-127
Snippet: get_agent
Summary:
File: src/utils/endpoints.py, Lines: 43-74
Snippet: get_system_prompt
Summary:
File: src/auth/init.py, Lines: 14-43
Snippet: get_auth_dependency
Summary:
File: src/app/endpoints/query.py, Lines: 0-n (select_model_and_provider_id)
Snippet: select_model_and_provider_id
Summary:
File: src/configuration.py, Lines: 55-60; 71-76; 79-84; 87-92; 116-121
Snippet: Configuration accessors
Summary:
File: src/app/endpoints/query.py, Lines: (context)
Snippet: is_input_shield / is_output_shield / _is_inout_shield
Summary:
File: src/app/endpoints/query.py, Lines: (context)
Snippet: get_rag_toolgroups
Summary:
If you found this review helpful, would you consider giving us a shout-out on X?
Thank you for using CodeRabbit!