Skip to content

Commit d5e2622

Browse files
committed
more coderabbit
1 parent 3026397 commit d5e2622

File tree

6 files changed

+369
-49
lines changed

6 files changed

+369
-49
lines changed

src/app/endpoints/query.py

Lines changed: 28 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -1,14 +1,14 @@
11
"""Handler for REST API call to provide answer to query."""
22

3-
import ast
43
from datetime import datetime, UTC
54
import json
65
import logging
76
import os
87
from pathlib import Path
9-
import re
108
from typing import Annotated, Any, cast
119

10+
import pydantic
11+
1212
from llama_stack_client import APIConnectionError
1313
from llama_stack_client import AsyncLlamaStackClient # type: ignore
1414
from llama_stack_client.types import UserMessage, Shield # type: ignore
@@ -43,13 +43,12 @@
4343
)
4444
from utils.mcp_headers import mcp_headers_dependency, handle_mcp_headers_with_toolgroups
4545
from utils.suid import get_suid
46+
from utils.metadata import parse_knowledge_search_metadata
4647

4748
logger = logging.getLogger("app.endpoints.handlers")
4849
router = APIRouter(tags=["query"])
4950
auth_dependency = get_auth_dependency()
5051

51-
METADATA_PATTERN = re.compile(r"^\s*Metadata:\s*(\{.*?\})\s*$", re.MULTILINE)
52-
5352

5453
def _process_knowledge_search_content(
5554
tool_response: Any, metadata_map: dict[str, dict[str, Any]]
@@ -75,17 +74,14 @@ def _process_knowledge_search_content(
7574
if not text:
7675
continue
7776

78-
for match in METADATA_PATTERN.findall(text):
79-
try:
80-
meta = ast.literal_eval(match)
81-
# Verify the result is a dict before accessing keys
82-
if isinstance(meta, dict) and "document_id" in meta:
83-
metadata_map[meta["document_id"]] = meta
84-
except (SyntaxError, ValueError): # only expected from literal_eval
85-
logger.exception(
86-
"An exception was thrown in processing %s",
87-
match,
88-
)
77+
try:
78+
parsed_metadata = parse_knowledge_search_metadata(text)
79+
metadata_map.update(parsed_metadata)
80+
except ValueError:
81+
logger.exception(
82+
"An exception was thrown in processing metadata from text: %s",
83+
text[:200] + "..." if len(text) > 200 else text,
84+
)
8985

9086

9187
def extract_referenced_documents_from_steps(steps: list) -> list[ReferencedDocument]:
@@ -113,12 +109,23 @@ def extract_referenced_documents_from_steps(steps: list) -> list[ReferencedDocum
113109

114110
_process_knowledge_search_content(tool_response, metadata_map)
115111

116-
# Extract referenced documents from metadata
117-
return [
118-
ReferencedDocument(doc_url=v["docs_url"], doc_title=v["title"])
119-
for v in metadata_map.values()
120-
if "docs_url" in v and "title" in v
121-
]
112+
# Extract referenced documents from metadata with error handling
113+
referenced_documents = []
114+
for v in metadata_map.values():
115+
if "docs_url" in v and "title" in v:
116+
try:
117+
doc = ReferencedDocument(doc_url=v["docs_url"], doc_title=v["title"])
118+
referenced_documents.append(doc)
119+
except (pydantic.ValidationError, ValueError, Exception) as e:
120+
logger.warning(
121+
"Skipping invalid referenced document with docs_url='%s', title='%s': %s",
122+
v.get("docs_url", "<missing>"),
123+
v.get("title", "<missing>"),
124+
str(e),
125+
)
126+
continue
127+
128+
return referenced_documents
122129

123130

124131
query_response: dict[int | str, dict[str, Any]] = {

src/app/endpoints/streaming_query.py

Lines changed: 35 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,11 @@
11
"""Handler for REST API call to provide answer to streaming query."""
22

3-
import ast
43
import json
5-
import re
64
import logging
75
from typing import Annotated, Any, AsyncIterator, Iterator
86

7+
import pydantic
8+
99
from llama_stack_client import APIConnectionError
1010
from llama_stack_client import AsyncLlamaStackClient # type: ignore
1111
from llama_stack_client.types import UserMessage # type: ignore
@@ -27,6 +27,7 @@
2727
from models.responses import ReferencedDocument
2828
from utils.endpoints import check_configuration_loaded, get_agent, get_system_prompt
2929
from utils.mcp_headers import mcp_headers_dependency, handle_mcp_headers_with_toolgroups
30+
from utils.metadata import parse_knowledge_search_metadata
3031

3132
from app.endpoints.query import (
3233
get_rag_toolgroups,
@@ -46,9 +47,6 @@
4647
auth_dependency = get_auth_dependency()
4748

4849

49-
METADATA_PATTERN = re.compile(r"\nMetadata: (\{.+})\n")
50-
51-
5250
def format_stream_data(d: dict) -> str:
5351
"""Format outbound data in the Event Stream Format."""
5452
data = json.dumps(d)
@@ -79,15 +77,24 @@ def stream_end_event(metadata_map: dict) -> str:
7977
lambda v: ("docs_url" in v) and ("title" in v),
8078
metadata_map.values(),
8179
):
82-
doc = ReferencedDocument(doc_url=v["docs_url"], doc_title=v["title"])
83-
referenced_documents.append(
84-
{
85-
"doc_url": str(
86-
doc.doc_url
87-
), # Convert AnyUrl to string for JSON serialization
88-
"doc_title": doc.doc_title,
89-
}
90-
)
80+
try:
81+
doc = ReferencedDocument(doc_url=v["docs_url"], doc_title=v["title"])
82+
referenced_documents.append(
83+
{
84+
"doc_url": str(
85+
doc.doc_url
86+
), # Convert AnyUrl to string for JSON serialization
87+
"doc_title": doc.doc_title,
88+
}
89+
)
90+
except (pydantic.ValidationError, ValueError, Exception) as e:
91+
logger.warning(
92+
"Skipping invalid referenced document with docs_url='%s', title='%s': %s",
93+
v.get("docs_url", "<missing>"),
94+
v.get("title", "<missing>"),
95+
str(e),
96+
)
97+
continue
9198

9299
return format_stream_data(
93100
{
@@ -335,20 +342,20 @@ def _handle_tool_execution_event(
335342
newline_pos = summary.find("\n")
336343
if newline_pos > 0:
337344
summary = summary[:newline_pos]
338-
for match in METADATA_PATTERN.findall(text_content_item.text):
339-
try:
340-
meta = ast.literal_eval(match)
341-
# Verify the result is a dict before accessing keys
342-
if isinstance(meta, dict) and "document_id" in meta:
343-
metadata_map[meta["document_id"]] = meta
344-
except (
345-
SyntaxError,
346-
ValueError,
347-
): # only expected from literal_eval
348-
logger.exception(
349-
"An exception was thrown in processing %s",
350-
match,
351-
)
345+
try:
346+
parsed_metadata = parse_knowledge_search_metadata(
347+
text_content_item.text
348+
)
349+
metadata_map.update(parsed_metadata)
350+
except ValueError:
351+
logger.exception(
352+
"An exception was thrown in processing metadata from text: %s",
353+
(
354+
text_content_item.text[:200] + "..."
355+
if len(text_content_item.text) > 200
356+
else text_content_item.text
357+
),
358+
)
352359

353360
yield format_stream_data(
354361
{

src/utils/metadata.py

Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,34 @@
1+
"""Shared utilities for parsing metadata from knowledge search responses."""
2+
3+
import ast
4+
import re
5+
from typing import Any
6+
7+
8+
METADATA_PATTERN = re.compile(r"^\s*Metadata:\s*(\{.*?\})\s*$", re.MULTILINE)
9+
10+
11+
def parse_knowledge_search_metadata(text: str) -> dict[str, Any]:
12+
"""Parse metadata from knowledge search text content.
13+
14+
Args:
15+
text: Text content that may contain metadata patterns
16+
17+
Returns:
18+
Dictionary of document_id -> metadata mappings
19+
20+
Raises:
21+
ValueError: If metadata parsing fails due to invalid JSON or syntax
22+
"""
23+
metadata_map: dict[str, Any] = {}
24+
25+
for match in METADATA_PATTERN.findall(text):
26+
try:
27+
meta = ast.literal_eval(match)
28+
# Verify the result is a dict before accessing keys
29+
if isinstance(meta, dict) and "document_id" in meta:
30+
metadata_map[meta["document_id"]] = meta
31+
except (SyntaxError, ValueError) as e:
32+
raise ValueError(f"Failed to parse metadata '{match}': {e}") from e
33+
34+
return metadata_map

tests/unit/app/endpoints/test_query.py

Lines changed: 104 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1888,3 +1888,107 @@ async def test_retrieve_response_with_structured_content_object(
18881888
# Should convert the structured object to string representation
18891889
assert response == str(structured_content)
18901890
assert conversation_id == "fake_conversation_id"
1891+
1892+
1893+
@pytest.mark.asyncio
1894+
async def test_retrieve_response_skips_invalid_docs_url(prepare_agent_mocks, mocker):
1895+
"""Test that retrieve_response skips entries with invalid docs_url."""
1896+
mock_client, mock_agent = prepare_agent_mocks
1897+
mock_agent.create_turn.return_value.output_message.content = "LLM answer"
1898+
mock_client.shields.list.return_value = []
1899+
mock_client.vector_dbs.list.return_value = []
1900+
1901+
# Mock tool response with valid and invalid docs_url entries
1902+
invalid_docs_url_results = [
1903+
"""knowledge_search tool found 2 chunks:
1904+
BEGIN of knowledge_search tool results.
1905+
""",
1906+
"""Result 1
1907+
Content: Valid content
1908+
Metadata: {'docs_url': 'https://example.com/doc1', 'title': 'Valid Doc', 'document_id': 'doc-1'}
1909+
""",
1910+
"""Result 2
1911+
Content: Invalid content
1912+
Metadata: {'docs_url': 'not-a-valid-url', 'title': 'Invalid Doc', 'document_id': 'doc-2'}
1913+
""",
1914+
"""END of knowledge_search tool results.
1915+
""",
1916+
]
1917+
1918+
mock_tool_response = mocker.Mock()
1919+
mock_tool_response.call_id = "c1"
1920+
mock_tool_response.tool_name = "knowledge_search"
1921+
mock_tool_response.content = [
1922+
mocker.Mock(text=s, type="text") for s in invalid_docs_url_results
1923+
]
1924+
1925+
mock_tool_execution_step = mocker.Mock()
1926+
mock_tool_execution_step.step_type = "tool_execution"
1927+
mock_tool_execution_step.tool_responses = [mock_tool_response]
1928+
1929+
mock_agent.create_turn.return_value.steps = [mock_tool_execution_step]
1930+
1931+
# Mock configuration with empty MCP servers
1932+
mock_config = mocker.Mock()
1933+
mock_config.mcp_servers = []
1934+
mocker.patch("app.endpoints.query.configuration", mock_config)
1935+
mocker.patch(
1936+
"app.endpoints.query.get_agent",
1937+
return_value=(mock_agent, "fake_conversation_id", "fake_session_id"),
1938+
)
1939+
1940+
query_request = QueryRequest(query="What is OpenStack?")
1941+
model_id = "fake_model_id"
1942+
access_token = "test_token"
1943+
1944+
response, conversation_id, referenced_documents = await retrieve_response(
1945+
mock_client, model_id, query_request, access_token
1946+
)
1947+
1948+
assert response == "LLM answer"
1949+
assert conversation_id == "fake_conversation_id"
1950+
1951+
# Assert only the valid document is included, invalid one is skipped
1952+
assert len(referenced_documents) == 1
1953+
assert str(referenced_documents[0].doc_url) == "https://example.com/doc1"
1954+
assert referenced_documents[0].doc_title == "Valid Doc"
1955+
1956+
1957+
def test_extract_referenced_documents_from_steps_handles_validation_errors(
    mocker,
):
    """Test that extract_referenced_documents_from_steps handles validation errors gracefully.

    The second chunk carries a docs_url expected to fail pydantic URL
    validation; the extractor should skip it and keep the valid document.
    """
    # NOTE: converted from ``async def`` + ``@pytest.mark.asyncio`` to a plain
    # sync test — the function under test is synchronous and nothing here was
    # awaited, so the asyncio wrapper only added event-loop overhead.

    # Mock tool response with invalid docs_url that will cause pydantic validation error
    mock_tool_response = mocker.Mock()
    mock_tool_response.tool_name = "knowledge_search"
    mock_tool_response.content = [
        mocker.Mock(
            text="""Result 1
Content: Valid content
Metadata: {'docs_url': 'https://example.com/doc1', 'title': 'Valid Doc', 'document_id': 'doc-1'}
"""
        ),
        mocker.Mock(
            text="""Result 2
Content: Invalid content
Metadata: {'docs_url': 'invalid-url', 'title': 'Invalid Doc', 'document_id': 'doc-2'}
"""
        ),
    ]

    mock_tool_execution_step = mocker.Mock()
    mock_tool_execution_step.step_type = "tool_execution"
    mock_tool_execution_step.tool_responses = [mock_tool_response]

    steps = [mock_tool_execution_step]

    # Import the function directly to test it
    from app.endpoints.query import extract_referenced_documents_from_steps

    referenced_documents = extract_referenced_documents_from_steps(steps)

    # Should only return the valid document, skipping the invalid one
    assert len(referenced_documents) == 1
    assert str(referenced_documents[0].doc_url) == "https://example.com/doc1"
    assert referenced_documents[0].doc_title == "Valid Doc"

tests/unit/app/endpoints/test_streaming_query.py

Lines changed: 38 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1104,6 +1104,44 @@ def test_stream_end_event_with_referenced_documents():
11041104
assert "Test Document 4" not in doc_titles
11051105

11061106

1107+
def test_stream_end_event_skips_invalid_docs_url():
    """Verify stream_end_event drops documents whose docs_url fails validation."""
    docs = {
        "doc-1": {
            "docs_url": "https://example.com/doc1",
            "title": "Valid Document",
            "document_id": "doc-1",
        },
        "doc-2": {
            # Malformed URL — should trigger a ValidationError and be skipped.
            "docs_url": "not-a-valid-url",
            "title": "Invalid Document",
            "document_id": "doc-2",
        },
        "doc-3": {
            # Empty URL — also expected to fail validation.
            "docs_url": "",
            "title": "Empty URL Document",
            "document_id": "doc-3",
        },
    }

    event = stream_end_event(docs)

    # Strip the SSE prefix and decode the JSON payload.
    payload = json.loads(event.replace("data: ", ""))

    assert payload["event"] == "end"
    assert "referenced_documents" in payload["data"]

    surviving = payload["data"]["referenced_documents"]
    # Only the well-formed doc-1 entry should survive.
    assert len(surviving) == 1
    assert surviving[0]["doc_url"] == "https://example.com/doc1"
    assert surviving[0]["doc_title"] == "Valid Document"
1143+
1144+
11071145
def test_stream_build_event_error():
11081146
"""Test stream_build_event function returns a 'error' when chunk contains error information."""
11091147
# Create a mock chunk without an expected payload structure

0 commit comments

Comments
 (0)