-
Notifications
You must be signed in to change notification settings - Fork 49
implements referenced documents on /query and updates /streaming_query to match #403
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We'll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from 8 commits
f022501
8c0a1d0
d6e1475
866bb5c
3026397
d5e2622
a7f815e
af646c6
06cba91
1f8fbb9
1256193
1fa7b87
dcf4f23
12becc5
0291e11
a463e6f
36a5604
622ae9d
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -5,7 +5,9 @@ | |
| import logging | ||
| import os | ||
| from pathlib import Path | ||
| from typing import Annotated, Any | ||
| from typing import Annotated, Any, cast | ||
|
|
||
| import pydantic | ||
|
|
||
| from llama_stack_client import APIConnectionError | ||
| from llama_stack_client import AsyncLlamaStackClient # type: ignore | ||
|
|
@@ -25,7 +27,12 @@ | |
| from app.database import get_session | ||
| import metrics | ||
| from models.database.conversations import UserConversation | ||
| from models.responses import QueryResponse, UnauthorizedResponse, ForbiddenResponse | ||
| from models.responses import ( | ||
| QueryResponse, | ||
| UnauthorizedResponse, | ||
| ForbiddenResponse, | ||
| ReferencedDocument, | ||
| ) | ||
| from models.requests import QueryRequest, Attachment | ||
| import constants | ||
| from utils.endpoints import ( | ||
|
|
@@ -36,15 +43,106 @@ | |
| ) | ||
| from utils.mcp_headers import mcp_headers_dependency, handle_mcp_headers_with_toolgroups | ||
| from utils.suid import get_suid | ||
| from utils.metadata import parse_knowledge_search_metadata | ||
|
|
||
| logger = logging.getLogger("app.endpoints.handlers") | ||
| router = APIRouter(tags=["query"]) | ||
| auth_dependency = get_auth_dependency() | ||
|
|
||
|
|
||
| def _process_knowledge_search_content( | ||
| tool_response: Any, metadata_map: dict[str, dict[str, Any]] | ||
| ) -> None: | ||
| """Process knowledge search tool response content for metadata.""" | ||
| # Guard against missing tool_response or content | ||
| if not tool_response: | ||
| return | ||
|
|
||
| content = getattr(tool_response, "content", None) | ||
| if not content: | ||
| return | ||
|
|
||
| # Ensure content is iterable | ||
| try: | ||
| iter(content) | ||
| except TypeError: | ||
| return | ||
|
|
||
| for text_content_item in content: | ||
| # Skip items that lack a non-empty "text" attribute | ||
| text = getattr(text_content_item, "text", None) | ||
| if not text: | ||
| continue | ||
|
|
||
| try: | ||
| parsed_metadata = parse_knowledge_search_metadata(text) | ||
| metadata_map.update(parsed_metadata) | ||
| except ValueError: | ||
| logger.exception( | ||
| "An exception was thrown in processing metadata from text: %s", | ||
| text[:200] + "..." if len(text) > 200 else text, | ||
| ) | ||
|
||
|
|
||
|
|
||
| def extract_referenced_documents_from_steps( | ||
| steps: list[Any], | ||
| ) -> list[ReferencedDocument]: | ||
| """Extract referenced documents from tool execution steps. | ||
|
|
||
| Args: | ||
| steps: List of response steps from the agent | ||
|
|
||
| Returns: | ||
| List of referenced documents with doc_url and doc_title | ||
| """ | ||
| metadata_map: dict[str, dict[str, Any]] = {} | ||
|
|
||
| for step in steps: | ||
| if getattr(step, "step_type", "") != "tool_execution" or not hasattr( | ||
| step, "tool_responses" | ||
| ): | ||
| continue | ||
|
|
||
| for tool_response in getattr(step, "tool_responses", []) or []: | ||
| if getattr( | ||
| tool_response, "tool_name", "" | ||
| ) != "knowledge_search" or not getattr(tool_response, "content", []): | ||
| continue | ||
|
|
||
| _process_knowledge_search_content(tool_response, metadata_map) | ||
|
|
||
| # Extract referenced documents from metadata with error handling | ||
| referenced_documents = [] | ||
| for v in metadata_map.values(): | ||
| if "docs_url" in v and "title" in v: | ||
| try: | ||
| doc = ReferencedDocument(doc_url=v["docs_url"], doc_title=v["title"]) | ||
| referenced_documents.append(doc) | ||
| except (pydantic.ValidationError, ValueError) as e: | ||
| logger.warning( | ||
| "Skipping invalid referenced document with docs_url='%s', title='%s': %s", | ||
| v.get("docs_url", "<missing>"), | ||
| v.get("title", "<missing>"), | ||
| str(e), | ||
| ) | ||
| continue | ||
|
|
||
| return referenced_documents | ||
|
|
||
|
|
||
| query_response: dict[int | str, dict[str, Any]] = { | ||
| 200: { | ||
| "conversation_id": "123e4567-e89b-12d3-a456-426614174000", | ||
| "response": "LLM answer", | ||
| "referenced_documents": [ | ||
| { | ||
| "doc_url": ( | ||
| "https://docs.openshift.com/container-platform/" | ||
| "4.15/operators/olm/index.html" | ||
| ), | ||
| "doc_title": "Operator Lifecycle Manager (OLM)", | ||
| } | ||
| ], | ||
| }, | ||
| 400: { | ||
| "description": "Missing or invalid credentials provided by client", | ||
|
|
@@ -54,7 +152,7 @@ | |
| "description": "User is not authorized", | ||
| "model": ForbiddenResponse, | ||
| }, | ||
| 503: { | ||
| 500: { | ||
| "detail": { | ||
| "response": "Unable to connect to Llama Stack", | ||
| "cause": "Connection error.", | ||
|
|
@@ -189,7 +287,7 @@ async def query_endpoint_handler( | |
| user_conversation=user_conversation, query_request=query_request | ||
| ), | ||
| ) | ||
| response, conversation_id = await retrieve_response( | ||
| response, conversation_id, referenced_documents = await retrieve_response( | ||
| client, | ||
| llama_stack_model_id, | ||
| query_request, | ||
|
|
@@ -223,7 +321,11 @@ async def query_endpoint_handler( | |
| provider_id=provider_id, | ||
| ) | ||
|
|
||
| return QueryResponse(conversation_id=conversation_id, response=response) | ||
| return QueryResponse( | ||
| conversation_id=conversation_id, | ||
| response=response, | ||
| referenced_documents=referenced_documents, | ||
| ) | ||
|
|
||
| # connection to Llama Stack server | ||
| except APIConnectionError as e: | ||
|
|
@@ -316,13 +418,13 @@ def is_input_shield(shield: Shield) -> bool: | |
| return _is_inout_shield(shield) or not is_output_shield(shield) | ||
|
|
||
|
|
||
| async def retrieve_response( # pylint: disable=too-many-locals | ||
| async def retrieve_response( # pylint: disable=too-many-locals,too-many-branches | ||
| client: AsyncLlamaStackClient, | ||
| model_id: str, | ||
| query_request: QueryRequest, | ||
| token: str, | ||
| mcp_headers: dict[str, dict[str, str]] | None = None, | ||
| ) -> tuple[str, str]: | ||
| ) -> tuple[str, str, list[ReferencedDocument]]: | ||
| """Retrieve response from LLMs and agents.""" | ||
| available_input_shields = [ | ||
| shield.identifier | ||
|
|
@@ -402,15 +504,33 @@ async def retrieve_response( # pylint: disable=too-many-locals | |
| toolgroups=toolgroups, | ||
| ) | ||
|
|
||
| # Check for validation errors in the response | ||
| # Check for validation errors and extract referenced documents | ||
| steps = getattr(response, "steps", []) | ||
| for step in steps: | ||
| if step.step_type == "shield_call" and step.violation: | ||
| if getattr(step, "step_type", "") == "shield_call" and getattr( | ||
| step, "violation", False | ||
| ): | ||
| # Metric for LLM validation errors | ||
| metrics.llm_calls_validation_errors_total.inc() | ||
| break | ||
|
|
||
| return str(response.output_message.content), conversation_id # type: ignore[union-attr] | ||
| # Extract referenced documents from tool execution steps | ||
| referenced_documents = extract_referenced_documents_from_steps(steps) | ||
|
|
||
| # When stream=False, response should have output_message attribute | ||
| response_obj = cast(Any, response) | ||
|
|
||
| # Safely guard access to output_message and content | ||
| output_message = getattr(response_obj, "output_message", None) | ||
| if output_message and getattr(output_message, "content", None) is not None: | ||
| content_str = str(output_message.content) | ||
| else: | ||
| content_str = "" | ||
|
|
||
| return ( | ||
| content_str, | ||
| conversation_id, | ||
| referenced_documents, | ||
| ) | ||
|
|
||
|
|
||
| def validate_attachments_metadata(attachments: list[Attachment]) -> None: | ||
|
|
||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
`metadata_map` seems to be the return value, not a real parameter. Please refactor to return a new `metadata_map`.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
fixed with 06cba91