diff --git a/src/app/endpoints/streaming_query.py b/src/app/endpoints/streaming_query.py index 9b898f62..486558ad 100644 --- a/src/app/endpoints/streaming_query.py +++ b/src/app/endpoints/streaming_query.py @@ -1,11 +1,12 @@ -"""Handler for REST API call to provide answer to streaming query.""" +"""Handler for REST API call to provide answer to streaming query.""" # pylint: disable=too-many-lines import ast import json import logging import re +import uuid from datetime import UTC, datetime -from typing import Annotated, Any, AsyncIterator, Iterator, cast +from typing import Annotated, Any, AsyncGenerator, AsyncIterator, Iterator, cast from fastapi import APIRouter, Depends, HTTPException, Request, status from fastapi.responses import StreamingResponse @@ -39,7 +40,7 @@ from authorization.middleware import authorize from client import AsyncLlamaStackClientHolder from configuration import configuration -from constants import DEFAULT_RAG_TOOL +from constants import DEFAULT_RAG_TOOL, MEDIA_TYPE_JSON, MEDIA_TYPE_TEXT import metrics from metrics.utils import update_llm_token_count_from_turn from models.config import Action @@ -48,7 +49,6 @@ from models.responses import ForbiddenResponse, UnauthorizedResponse from utils.endpoints import ( check_configuration_loaded, - create_referenced_documents_with_metadata, create_rag_chunks_dict, get_agent, get_system_prompt, @@ -66,20 +66,25 @@ 200: { "description": "Streaming response with Server-Sent Events", "content": { - "text/event-stream": { + "application/json": { "schema": { "type": "string", "example": ( 'data: {"event": "start", ' '"data": {"conversation_id": "123e4567-e89b-12d3-a456-426614174000"}}\n\n' - 'data: {"event": "token", "data": {"id": 0, "role": "inference", ' - '"token": "Hello"}}\n\n' + 'data: {"event": "token", "data": {"id": 0, "token": "Hello"}}\n\n' 'data: {"event": "end", "data": {"referenced_documents": [], ' '"truncated": null, "input_tokens": 0, "output_tokens": 0}, ' '"available_quotas": {}}\n\n' ), } - } + }, + "text/plain": { + "schema": { + "type": "string", + "example": "Hello world!\n\n---\n\nReference: https://example.com/doc", + } + }, }, }, 400: { @@ -105,6 +110,11 @@ METADATA_PATTERN = re.compile(r"\nMetadata: (\{.+})\n") +# OLS-compatible event types +LLM_TOKEN_EVENT = "token" +LLM_TOOL_CALL_EVENT = "tool_call" +LLM_TOOL_RESULT_EVENT = "tool_result" + def format_stream_data(d: dict) -> str: """ @@ -144,7 +154,7 @@ def stream_start_event(conversation_id: str) -> str: ) -def stream_end_event(metadata_map: dict, summary: TurnSummary) -> str: +def stream_end_event(metadata_map: dict, media_type: str = MEDIA_TYPE_JSON) -> str: """ Yield the end of the data stream. @@ -155,31 +165,38 @@ def stream_end_event(metadata_map: dict, summary: TurnSummary) -> str: Parameters: metadata_map (dict): A mapping containing metadata about referenced documents. + media_type (str): The media type for the response format. Returns: str: A Server-Sent Events (SSE) formatted string representing the end of the data stream. 
""" - # Process RAG chunks using utility function - rag_chunks = create_rag_chunks_dict(summary) - - # Extract referenced documents using utility function - referenced_docs = create_referenced_documents_with_metadata(summary, metadata_map) + if media_type == MEDIA_TYPE_TEXT: + ref_docs_string = "\n".join( + f'{v["title"]}: {v["docs_url"]}' + for v in filter( + lambda v: ("docs_url" in v) and ("title" in v), + metadata_map.values(), + ) + ) + return f"\n\n---\n\n{ref_docs_string}" if ref_docs_string else "" - # Convert ReferencedDocument objects to dictionaries for JSON serialization + # For JSON media type, we need to create a proper structure + # Since we don't have access to summary here, we'll create a basic structure referenced_docs_dict = [ { - "doc_url": str(doc.doc_url) if doc.doc_url else None, - "doc_title": doc.doc_title, + "doc_url": v.get("docs_url"), + "doc_title": v.get("title"), } - for doc in referenced_docs + for v in metadata_map.values() + if "docs_url" in v and "title" in v ] return format_stream_data( { "event": "end", "data": { - "rag_chunks": rag_chunks, + "rag_chunks": [], # TODO(jboos): implement RAG chunks when summary is available "referenced_documents": referenced_docs_dict, "truncated": None, # TODO(jboos): implement truncated "input_tokens": 0, # TODO(jboos): implement input tokens @@ -190,7 +207,41 @@ def stream_end_event(metadata_map: dict, summary: TurnSummary) -> str: ) -def stream_build_event(chunk: Any, chunk_id: int, metadata_map: dict) -> Iterator[str]: +def stream_event(data: dict, event_type: str, media_type: str) -> str: + """Build an item to yield based on media type. + + Args: + data: The data to yield. + event_type: The type of event (e.g. token, tool request, tool execution). + media_type: Media type of the response (e.g. text or JSON). + + Returns: + str: The formatted string or JSON to yield. + """ + if media_type == MEDIA_TYPE_TEXT: + if event_type == LLM_TOKEN_EVENT: + return data["token"] + if event_type == LLM_TOOL_CALL_EVENT: + return f"\nTool call: {json.dumps(data)}\n" + if event_type == LLM_TOOL_RESULT_EVENT: + return f"\nTool result: {json.dumps(data)}\n" + logger.error("Unknown event type: %s", event_type) + return "" + return format_stream_data( + { + "event": event_type, + "data": data, + } + ) + + +def stream_build_event( + chunk: Any, + chunk_id: int, + metadata_map: dict, + media_type: str = MEDIA_TYPE_JSON, + conversation_id: str | None = None, +) -> Iterator[str]: """Build a streaming event from a chunk response. 
This function processes chunks from the Llama Stack streaming response and @@ -210,35 +261,39 @@ def stream_build_event(chunk: Any, chunk_id: int, metadata_map: dict) -> Iterato Iterator[str]: An iterable list of formatted SSE data strings with event information """ if hasattr(chunk, "error"): - yield from _handle_error_event(chunk, chunk_id) + yield from _handle_error_event(chunk, chunk_id, media_type) event_type = chunk.event.payload.event_type step_type = getattr(chunk.event.payload, "step_type", None) match (event_type, step_type): case (("turn_start" | "turn_awaiting_input"), _): - yield from _handle_turn_start_event(chunk_id) + yield from _handle_turn_start_event(chunk_id, media_type, conversation_id) case ("turn_complete", _): - yield from _handle_turn_complete_event(chunk, chunk_id) + yield from _handle_turn_complete_event(chunk, chunk_id, media_type) case (_, "shield_call"): - yield from _handle_shield_event(chunk, chunk_id) + yield from _handle_shield_event(chunk, chunk_id, media_type) case (_, "inference"): - yield from _handle_inference_event(chunk, chunk_id) + yield from _handle_inference_event(chunk, chunk_id, media_type) case (_, "tool_execution"): - yield from _handle_tool_execution_event(chunk, chunk_id, metadata_map) + yield from _handle_tool_execution_event( + chunk, chunk_id, metadata_map, media_type + ) case _: logger.debug( "Unhandled event combo: event_type=%s, step_type=%s", event_type, step_type, ) - yield from _handle_heartbeat_event(chunk_id) + yield from _handle_heartbeat_event(chunk_id, media_type) # ----------------------------------- # Error handling # ----------------------------------- -def _handle_error_event(chunk: Any, chunk_id: int) -> Iterator[str]: +def _handle_error_event( + chunk: Any, chunk_id: int, media_type: str = MEDIA_TYPE_JSON +) -> Iterator[str]: """ Yield error event. @@ -248,13 +303,68 @@ def _handle_error_event(chunk: Any, chunk_id: int) -> Iterator[str]: Parameters: chunk_id (int): The unique identifier for the current streaming chunk. + media_type (str): The media type for the response format. + """ + if media_type == MEDIA_TYPE_TEXT: + yield f"Error: {chunk.error['message']}" + else: + yield format_stream_data( + { + "event": "error", + "data": { + "id": chunk_id, + "token": chunk.error["message"], + }, + } + ) + + +def prompt_too_long_error(error: Exception, media_type: str) -> str: + """Return error representation for long prompts. + + Args: + error: The exception raised for long prompts. + media_type: Media type of the response (e.g. text or JSON). + + Returns: + str: The error message formatted for the media type. """ - yield format_stream_data( + logger.error("Prompt is too long: %s", error) + if media_type == MEDIA_TYPE_TEXT: + return f"Prompt is too long: {error}" + return format_stream_data( { "event": "error", "data": { - "id": chunk_id, - "token": chunk.error["message"], + "status_code": 413, + "response": "Prompt is too long", + "cause": str(error), + }, + } + ) + + +def generic_llm_error(error: Exception, media_type: str) -> str: + """Return error representation for generic LLM errors. + + Args: + error: The exception raised during processing. + media_type: Media type of the response (e.g. text or JSON). + + Returns: + str: The error message formatted for the media type. 
+ """ + logger.error("Error while obtaining answer for user question") + logger.exception(error) + + if media_type == MEDIA_TYPE_TEXT: + return f"Error: {str(error)}" + return format_stream_data( + { + "event": "error", + "data": { + "response": "Internal server error", + "cause": str(error), }, } ) @@ -263,11 +373,15 @@ def _handle_error_event(chunk: Any, chunk_id: int) -> Iterator[str]: # ----------------------------------- # Turn handling # ----------------------------------- -def _handle_turn_start_event(chunk_id: int) -> Iterator[str]: +def _handle_turn_start_event( + _chunk_id: int, + media_type: str = MEDIA_TYPE_JSON, + conversation_id: str | None = None, +) -> Iterator[str]: """ Yield turn start event. - Yield a Server-Sent Event (SSE) token event indicating the + Yield a Server-Sent Event (SSE) start event indicating the start of a new conversation turn. Parameters: @@ -275,21 +389,28 @@ def _handle_turn_start_event(chunk_id: int) -> Iterator[str]: chunk. Yields: - str: SSE-formatted token event with an empty token to - signal turn start. + str: SSE-formatted start event with conversation_id. """ - yield format_stream_data( - { - "event": "token", - "data": { - "id": chunk_id, - "token": "", - }, - } - ) + # Use provided conversation_id or generate one if not available + if conversation_id is None: + conversation_id = str(uuid.uuid4()) + + if media_type == MEDIA_TYPE_TEXT: + yield ( + f"data: {json.dumps({'event': 'start', 'data': {'conversation_id': conversation_id}})}\n\n" # pylint: disable=line-too-long + ) + else: + yield format_stream_data( + { + "event": "start", + "data": {"conversation_id": conversation_id}, + } + ) -def _handle_turn_complete_event(chunk: Any, chunk_id: int) -> Iterator[str]: +def _handle_turn_complete_event( + chunk: Any, _chunk_id: int, media_type: str = MEDIA_TYPE_JSON +) -> Iterator[str]: """ Yield turn complete event. @@ -304,23 +425,29 @@ def _handle_turn_complete_event(chunk: Any, chunk_id: int) -> Iterator[str]: str: SSE-formatted string containing the turn completion event and output message content. """ - yield format_stream_data( - { - "event": "turn_complete", - "data": { - "id": chunk_id, - "token": interleaved_content_as_str( - chunk.event.payload.turn.output_message.content - ), - }, - } + full_response = interleaved_content_as_str( + chunk.event.payload.turn.output_message.content ) + if media_type == MEDIA_TYPE_TEXT: + yield ( + f"data: {json.dumps({'event': 'turn_complete', 'data': {'token': full_response}})}\n\n" + ) + else: + yield format_stream_data( + { + "event": "turn_complete", + "data": {"token": full_response}, + } + ) + # ----------------------------------- # Shield handling # ----------------------------------- -def _handle_shield_event(chunk: Any, chunk_id: int) -> Iterator[str]: +def _handle_shield_event( + chunk: Any, chunk_id: int, media_type: str = MEDIA_TYPE_JSON +) -> Iterator[str]: """ Yield shield event. 
@@ -334,15 +461,13 @@ def _handle_shield_event(chunk: Any, chunk_id: int) -> Iterator[str]: if chunk.event.payload.event_type == "step_complete": violation = chunk.event.payload.step_details.violation if not violation: - yield format_stream_data( - { - "event": "token", - "data": { - "id": chunk_id, - "role": chunk.event.payload.step_type, - "token": "No Violation", - }, - } + yield stream_event( + data={ + "id": chunk_id, + "token": "No Violation", + }, + event_type=LLM_TOKEN_EVENT, + media_type=media_type, ) else: # Metric for LLM validation errors @@ -350,22 +475,22 @@ def _handle_shield_event(chunk: Any, chunk_id: int) -> Iterator[str]: violation = ( f"Violation: {violation.user_message} (Metadata: {violation.metadata})" ) - yield format_stream_data( - { - "event": "token", - "data": { - "id": chunk_id, - "role": chunk.event.payload.step_type, - "token": violation, - }, - } + yield stream_event( + data={ + "id": chunk_id, + "token": violation, + }, + event_type=LLM_TOKEN_EVENT, + media_type=media_type, ) # ----------------------------------- # Inference handling # ----------------------------------- -def _handle_inference_event(chunk: Any, chunk_id: int) -> Iterator[str]: +def _handle_inference_event( + chunk: Any, chunk_id: int, media_type: str = MEDIA_TYPE_JSON +) -> Iterator[str]: """ Yield inference step event. @@ -377,52 +502,44 @@ def _handle_inference_event(chunk: Any, chunk_id: int) -> Iterator[str]: Supports both string and ToolCall object tool calls. """ if chunk.event.payload.event_type == "step_start": - yield format_stream_data( - { - "event": "token", - "data": { - "id": chunk_id, - "role": chunk.event.payload.step_type, - "token": "", - }, - } + yield stream_event( + data={ + "id": chunk_id, + "token": "", + }, + event_type=LLM_TOKEN_EVENT, + media_type=media_type, ) elif chunk.event.payload.event_type == "step_progress": if chunk.event.payload.delta.type == "tool_call": if isinstance(chunk.event.payload.delta.tool_call, str): - yield format_stream_data( - { - "event": "tool_call", - "data": { - "id": chunk_id, - "role": chunk.event.payload.step_type, - "token": chunk.event.payload.delta.tool_call, - }, - } + yield stream_event( + data={ + "id": chunk_id, + "token": chunk.event.payload.delta.tool_call, + }, + event_type=LLM_TOOL_CALL_EVENT, + media_type=media_type, ) elif isinstance(chunk.event.payload.delta.tool_call, ToolCall): - yield format_stream_data( - { - "event": "tool_call", - "data": { - "id": chunk_id, - "role": chunk.event.payload.step_type, - "token": chunk.event.payload.delta.tool_call.tool_name, - }, - } + yield stream_event( + data={ + "id": chunk_id, + "token": chunk.event.payload.delta.tool_call.tool_name, + }, + event_type=LLM_TOOL_CALL_EVENT, + media_type=media_type, ) elif chunk.event.payload.delta.type == "text": - yield format_stream_data( - { - "event": "token", - "data": { - "id": chunk_id, - "role": chunk.event.payload.step_type, - "token": chunk.event.payload.delta.text, - }, - } + yield stream_event( + data={ + "id": chunk_id, + "token": chunk.event.payload.delta.text, + }, + event_type=LLM_TOKEN_EVENT, + media_type=media_type, ) @@ -431,7 +548,7 @@ def _handle_inference_event(chunk: Any, chunk_id: int) -> Iterator[str]: # ----------------------------------- # pylint: disable=R1702,R0912 def _handle_tool_execution_event( - chunk: Any, chunk_id: int, metadata_map: dict + chunk: Any, chunk_id: int, metadata_map: dict, media_type: str = MEDIA_TYPE_JSON ) -> Iterator[str]: """ Yield tool call event. 
@@ -454,48 +571,42 @@ def _handle_tool_execution_event( events and responses. """ if chunk.event.payload.event_type == "step_start": - yield format_stream_data( - { - "event": "tool_call", - "data": { - "id": chunk_id, - "role": chunk.event.payload.step_type, - "token": "", - }, - } + yield stream_event( + data={ + "id": chunk_id, + "token": "", + }, + event_type=LLM_TOOL_CALL_EVENT, + media_type=media_type, ) elif chunk.event.payload.event_type == "step_complete": for t in chunk.event.payload.step_details.tool_calls: - yield format_stream_data( - { - "event": "tool_call", - "data": { - "id": chunk_id, - "role": chunk.event.payload.step_type, - "token": { - "tool_name": t.tool_name, - "arguments": t.arguments, - }, + yield stream_event( + data={ + "id": chunk_id, + "token": { + "tool_name": t.tool_name, + "arguments": t.arguments, }, - } + }, + event_type=LLM_TOOL_CALL_EVENT, + media_type=media_type, ) for r in chunk.event.payload.step_details.tool_responses: if r.tool_name == "query_from_memory": inserted_context = interleaved_content_as_str(r.content) - yield format_stream_data( - { - "event": "tool_call", - "data": { - "id": chunk_id, - "role": chunk.event.payload.step_type, - "token": { - "tool_name": r.tool_name, - "response": f"Fetched {len(inserted_context)} bytes from memory", - }, + yield stream_event( + data={ + "id": chunk_id, + "token": { + "tool_name": r.tool_name, + "response": f"Fetched {len(inserted_context)} bytes from memory", }, - } + }, + event_type=LLM_TOOL_RESULT_EVENT, + media_type=media_type, ) elif r.tool_name == DEFAULT_RAG_TOOL and r.content: @@ -518,40 +629,38 @@ def _handle_tool_execution_event( match, ) - yield format_stream_data( - { - "event": "tool_call", - "data": { - "id": chunk_id, - "role": chunk.event.payload.step_type, - "token": { - "tool_name": r.tool_name, - "summary": summary, - }, + yield stream_event( + data={ + "id": chunk_id, + "token": { + "tool_name": r.tool_name, + "summary": summary, }, - } + }, + event_type=LLM_TOOL_RESULT_EVENT, + media_type=media_type, ) else: - yield format_stream_data( - { - "event": "tool_call", - "data": { - "id": chunk_id, - "role": chunk.event.payload.step_type, - "token": { - "tool_name": r.tool_name, - "response": interleaved_content_as_str(r.content), - }, + yield stream_event( + data={ + "id": chunk_id, + "token": { + "tool_name": r.tool_name, + "response": interleaved_content_as_str(r.content), }, - } + }, + event_type=LLM_TOOL_RESULT_EVENT, + media_type=media_type, ) # ----------------------------------- # Catch-all for everything else # ----------------------------------- -def _handle_heartbeat_event(chunk_id: int) -> Iterator[str]: +def _handle_heartbeat_event( + chunk_id: int, media_type: str = MEDIA_TYPE_JSON +) -> Iterator[str]: """ Yield a heartbeat event. @@ -565,20 +674,19 @@ def _handle_heartbeat_event(chunk_id: int) -> Iterator[str]: Yields: str: SSE-formatted heartbeat event string. 
""" - yield format_stream_data( - { - "event": "heartbeat", - "data": { - "id": chunk_id, - "token": "heartbeat", - }, - } + yield stream_event( + data={ + "id": chunk_id, + "token": "heartbeat", + }, + event_type=LLM_TOKEN_EVENT, + media_type=media_type, ) @router.post("/streaming_query", responses=streaming_query_responses) @authorize(Action.STREAMING_QUERY) -async def streaming_query_endpoint_handler( # pylint: disable=R0915,R0914 +async def streaming_query_endpoint_handler( # pylint: disable=too-many-locals,too-many-statements request: Request, query_request: QueryRequest, auth: Annotated[AuthTuple, Depends(get_auth_dependency())], @@ -673,7 +781,10 @@ async def response_generator( llm_response="No response from the model", tool_calls=[] ) - # Send start event + # Determine media type for response formatting + media_type = query_request.media_type or MEDIA_TYPE_JSON + + # Send start event at the beginning of the stream yield stream_start_event(conversation_id) async for chunk in turn_response: @@ -695,11 +806,13 @@ async def response_generator( if p.step_details.step_type == "tool_execution": summary.append_tool_calls_from_llama(p.step_details) - for event in stream_build_event(chunk, chunk_id, metadata_map): + for event in stream_build_event( + chunk, chunk_id, metadata_map, media_type, conversation_id + ): chunk_id += 1 yield event - yield stream_end_event(metadata_map, summary) + yield stream_end_event(metadata_map, media_type) if not is_transcripts_enabled(): logger.debug("Transcript collection is disabled in the configuration") @@ -758,7 +871,12 @@ async def response_generator( # Update metrics for the LLM call metrics.llm_calls_total.labels(provider_id, model_id).inc() - return StreamingResponse(response_generator(response)) + # Determine media type for response + # Note: The HTTP Content-Type header is always text/event-stream for SSE, + # but the media_type parameter controls how the content is formatted + return StreamingResponse( + response_generator(response), media_type="text/event-stream" + ) # connection to Llama Stack server except APIConnectionError as e: # Update metrics for the LLM call failure @@ -771,6 +889,21 @@ async def response_generator( "cause": str(e), }, ) from e + except Exception as e: # pylint: disable=broad-except + # Handle other errors with OLS-compatible error response + # This broad exception catch is intentional to ensure all errors + # are converted to OLS-compatible streaming responses + media_type = query_request.media_type or MEDIA_TYPE_JSON + error_response = generic_llm_error(e, media_type) + + async def error_generator() -> AsyncGenerator[str, None]: + yield error_response + + # Use text/event-stream for SSE-formatted JSON responses, text/plain for plain text + content_type = ( + "text/event-stream" if media_type == MEDIA_TYPE_JSON else "text/plain" + ) + return StreamingResponse(error_generator(), media_type=content_type) async def retrieve_response( diff --git a/src/constants.py b/src/constants.py index 4d4b3237..5d67ecee 100644 --- a/src/constants.py +++ b/src/constants.py @@ -116,6 +116,10 @@ # default RAG tool value DEFAULT_RAG_TOOL = "knowledge_search" +# Media type constants for streaming responses +MEDIA_TYPE_JSON = "application/json" +MEDIA_TYPE_TEXT = "text/plain" + # PostgreSQL connection constants # See: https://www.postgresql.org/docs/current/libpq-connect.html#LIBPQ-CONNECT-SSLMODE POSTGRES_DEFAULT_SSL_MODE = "prefer" diff --git a/src/models/requests.py b/src/models/requests.py index fcfaecca..5032fe34 100644 --- 
a/src/models/requests.py +++ b/src/models/requests.py @@ -8,6 +8,7 @@ from log import get_logger from utils import suid +from constants import MEDIA_TYPE_JSON, MEDIA_TYPE_TEXT logger = get_logger(__name__) @@ -80,6 +81,7 @@ class QueryRequest(BaseModel): system_prompt: The optional system prompt. attachments: The optional attachments. no_tools: Whether to bypass all tools and MCP servers (default: False). + media_type: The optional media type for response format (application/json or text/plain). Example: ```python @@ -144,12 +146,10 @@ class QueryRequest(BaseModel): examples=[True, False], ) - # media_type is not used in 'lightspeed-stack' that only supports application/json. - # the field is kept here to enable compatibility with 'road-core' clients. media_type: Optional[str] = Field( None, - description="Media type (used just to enable compatibility)", - examples=["application/json"], + description="Media type for the response format", + examples=[MEDIA_TYPE_JSON, MEDIA_TYPE_TEXT], ) # provides examples for /docs endpoint @@ -214,10 +214,13 @@ def validate_provider_and_model(self) -> Self: @model_validator(mode="after") def validate_media_type(self) -> Self: - """Log use of media_type that is unsupported but kept for backward compatibility.""" - if self.media_type: - logger.warning( - "media_type was set in the request but is not supported. The value will be ignored." + """Validate media_type field.""" + if self.media_type and self.media_type not in [ + MEDIA_TYPE_JSON, + MEDIA_TYPE_TEXT, + ]: + raise ValueError( + f"media_type must be either '{MEDIA_TYPE_JSON}' or '{MEDIA_TYPE_TEXT}'" ) return self diff --git a/tests/unit/app/endpoints/test_streaming_query.py b/tests/unit/app/endpoints/test_streaming_query.py index eda89d49..59410398 100644 --- a/tests/unit/app/endpoints/test_streaming_query.py +++ b/tests/unit/app/endpoints/test_streaming_query.py @@ -40,6 +40,13 @@ streaming_query_endpoint_handler, retrieve_response, stream_build_event, + stream_event, + stream_end_event, + prompt_too_long_error, + generic_llm_error, + LLM_TOKEN_EVENT, + LLM_TOOL_CALL_EVENT, + LLM_TOOL_RESULT_EVENT, ) from authorization.resolvers import NoopRolesResolver @@ -47,6 +54,7 @@ from models.requests import QueryRequest, Attachment from models.responses import RAGChunk from utils.types import ToolCallSummary, TurnSummary +from constants import MEDIA_TYPE_JSON, MEDIA_TYPE_TEXT MOCK_AUTH = ( "017adfa4-7cc6-46e4-b663-3653e1ae69df", @@ -190,11 +198,13 @@ async def test_streaming_query_endpoint_on_connection_error(mocker): "type": "http", } ) - # await the async function - with pytest.raises(HTTPException) as e: - await streaming_query_endpoint_handler(request, query_request, auth=MOCK_AUTH) - assert e.status_code == status.HTTP_500_INTERNAL_SERVER_ERROR - assert e.detail["response"] == "Configuration is not loaded" + # await the async function - should return a streaming response with error + response = await streaming_query_endpoint_handler( + request, query_request, auth=MOCK_AUTH + ) + + assert isinstance(response, StreamingResponse) + assert response.media_type == "text/event-stream" async def _test_streaming_query_endpoint_handler(mocker, store_transcript=False): @@ -792,9 +802,8 @@ def test_stream_build_event_turn_start(): assert result is not None assert "data: " in result - assert '"event": "token"' in result - assert '"token": ""' in result - assert '"id": 0' in result + assert '"event": "start"' in result + assert '"conversation_id"' in result def test_stream_build_event_turn_awaiting_input(): @@ -826,9 
+835,8 @@ def test_stream_build_event_turn_awaiting_input(): assert result is not None assert "data: " in result - assert '"event": "token"' in result - assert '"token": ""' in result - assert '"id": 0' in result + assert '"event": "start"' in result + assert '"conversation_id"' in result def test_stream_build_event_turn_complete(): @@ -862,7 +870,6 @@ def test_stream_build_event_turn_complete(): assert "data: " in result assert '"event": "turn_complete"' in result assert '"token": "content"' in result - assert '"id": 0' in result def test_stream_build_event_shield_call_step_complete_no_violation(mocker): @@ -894,7 +901,7 @@ def test_stream_build_event_shield_call_step_complete_no_violation(mocker): assert "data: " in result assert '"event": "token"' in result assert '"token": "No Violation"' in result - assert '"role": "shield_call"' in result + # Role field removed for OLS compatibility assert '"id": 0' in result # Assert that the metric for validation errors is NOT incremented mock_metric.inc.assert_not_called() @@ -937,7 +944,7 @@ def test_stream_build_event_shield_call_step_complete_with_violation(mocker): '"token": "Violation: I don\'t like the cut of your jib (Metadata: {})"' in result ) - assert '"role": "shield_call"' in result + # Role field removed for OLS compatibility assert '"id": 0' in result # Assert that the metric for validation errors is incremented mock_metric.inc.assert_called_once() @@ -965,7 +972,7 @@ def test_stream_build_event_step_progress(): assert "data: " in result assert '"event": "token"' in result assert '"token": "This is a test response"' in result - assert '"role": "inference"' in result + # Role field removed for OLS compatibility assert '"id": 0' in result @@ -993,7 +1000,7 @@ def test_stream_build_event_step_progress_tool_call_str(): assert "data: " in result assert '"event": "tool_call"' in result assert '"token": "tool-called"' in result - assert '"role": "inference"' in result + # Role field removed for OLS compatibility assert '"id": 0' in result @@ -1025,7 +1032,7 @@ def test_stream_build_event_step_progress_tool_call_tool_call(): assert "data: " in result assert '"event": "tool_call"' in result assert '"token": "my-tool"' in result - assert '"role": "inference"' in result + # Role field removed for OLS compatibility assert '"id": 0' in result @@ -1077,7 +1084,7 @@ def test_stream_build_event_step_complete(): '"token": {"tool_name": "knowledge_search", ' '"summary": "knowledge_search tool found 2 chunks:"}' in result ) - assert '"role": "tool_execution"' in result + # Role field removed for OLS compatibility assert '"id": 0' in result @@ -1117,7 +1124,7 @@ def test_stream_build_event_returns_heartbeat(): assert result is not None assert '"id": 0' in result - assert '"event": "heartbeat"' in result + assert '"event": "token"' in result assert '"token": "heartbeat"' in result @@ -1715,3 +1722,272 @@ async def test_streaming_query_handles_none_event(mocker): request, query_request, auth=MOCK_AUTH ) assert isinstance(response, StreamingResponse) + + +# ============================================================================ +# OLS Compatibility Tests +# ============================================================================ + + +class TestOLSStreamEventFormatting: + """Test the stream_event function for both media types (OLS compatibility).""" + + def test_stream_event_json_token(self): + """Test token event formatting for JSON media type.""" + data = {"id": 0, "token": "Hello"} + result = stream_event(data, LLM_TOKEN_EVENT, MEDIA_TYPE_JSON) 
+ + expected = 'data: {"event": "token", "data": {"id": 0, "token": "Hello"}}\n\n' + assert result == expected + + def test_stream_event_text_token(self): + """Test token event formatting for text media type.""" + + data = {"id": 0, "token": "Hello"} + result = stream_event(data, LLM_TOKEN_EVENT, MEDIA_TYPE_TEXT) + + assert result == "Hello" + + def test_stream_event_json_tool_call(self): + """Test tool call event formatting for JSON media type.""" + + data = { + "id": 0, + "token": {"tool_name": "search", "arguments": {"query": "test"}}, + } + result = stream_event(data, LLM_TOOL_CALL_EVENT, MEDIA_TYPE_JSON) + + expected = ( + 'data: {"event": "tool_call", "data": {"id": 0, "token": ' + '{"tool_name": "search", "arguments": {"query": "test"}}}}\n\n' + ) + assert result == expected + + def test_stream_event_text_tool_call(self): + """Test tool call event formatting for text media type.""" + + data = { + "id": 0, + "token": {"tool_name": "search", "arguments": {"query": "test"}}, + } + result = stream_event(data, LLM_TOOL_CALL_EVENT, MEDIA_TYPE_TEXT) + + expected = ( + '\nTool call: {"id": 0, "token": ' + '{"tool_name": "search", "arguments": {"query": "test"}}}\n' + ) + assert result == expected + + def test_stream_event_json_tool_result(self): + """Test tool result event formatting for JSON media type.""" + + data = { + "id": 0, + "token": {"tool_name": "search", "response": "Found results"}, + } + result = stream_event(data, LLM_TOOL_RESULT_EVENT, MEDIA_TYPE_JSON) + + expected = ( + 'data: {"event": "tool_result", "data": {"id": 0, "token": ' + '{"tool_name": "search", "response": "Found results"}}}\n\n' + ) + assert result == expected + + def test_stream_event_text_tool_result(self): + """Test tool result event formatting for text media type.""" + + data = { + "id": 0, + "token": {"tool_name": "search", "response": "Found results"}, + } + result = stream_event(data, LLM_TOOL_RESULT_EVENT, MEDIA_TYPE_TEXT) + + expected = ( + '\nTool result: {"id": 0, "token": ' + '{"tool_name": "search", "response": "Found results"}}\n' + ) + assert result == expected + + def test_stream_event_unknown_type(self): + """Test handling of unknown event types.""" + + data = {"id": 0, "token": "test"} + result = stream_event(data, "unknown_event", MEDIA_TYPE_TEXT) + + assert result == "" + + +class TestOLSStreamEndEvent: + """Test the stream_end_event function for both media types (OLS compatibility).""" + + def test_stream_end_event_json(self): + """Test end event formatting for JSON media type.""" + + metadata_map = { + "doc1": {"title": "Test Doc 1", "docs_url": "https://example.com/doc1"}, + "doc2": {"title": "Test Doc 2", "docs_url": "https://example.com/doc2"}, + } + result = stream_end_event(metadata_map, MEDIA_TYPE_JSON) + + # Parse the result to verify structure + data_part = result.replace("data: ", "").strip() + parsed = json.loads(data_part) + + assert parsed["event"] == "end" + assert "referenced_documents" in parsed["data"] + assert len(parsed["data"]["referenced_documents"]) == 2 + assert parsed["data"]["referenced_documents"][0]["doc_title"] == "Test Doc 1" + assert ( + parsed["data"]["referenced_documents"][0]["doc_url"] + == "https://example.com/doc1" + ) + assert "available_quotas" in parsed + + def test_stream_end_event_text(self): + """Test end event formatting for text media type.""" + + metadata_map = { + "doc1": {"title": "Test Doc 1", "docs_url": "https://example.com/doc1"}, + "doc2": {"title": "Test Doc 2", "docs_url": "https://example.com/doc2"}, + } + result = 
stream_end_event(metadata_map, MEDIA_TYPE_TEXT) + + expected = ( + "\n\n---\n\nTest Doc 1: https://example.com/doc1\n" + "Test Doc 2: https://example.com/doc2" + ) + assert result == expected + + def test_stream_end_event_text_no_docs(self): + """Test end event formatting for text media type with no documents.""" + + metadata_map = {} + result = stream_end_event(metadata_map, MEDIA_TYPE_TEXT) + + assert result == "" + + +class TestOLSErrorHandling: + """Test error handling functions (OLS compatibility).""" + + def test_prompt_too_long_error_json(self): + """Test prompt too long error for JSON media type.""" + + error = Exception("Prompt exceeds maximum length") + result = prompt_too_long_error(error, MEDIA_TYPE_JSON) + + data_part = result.replace("data: ", "").strip() + parsed = json.loads(data_part) + assert parsed["event"] == "error" + assert parsed["data"]["status_code"] == 413 + assert parsed["data"]["response"] == "Prompt is too long" + assert parsed["data"]["cause"] == "Prompt exceeds maximum length" + + def test_prompt_too_long_error_text(self): + """Test prompt too long error for text media type.""" + + error = Exception("Prompt exceeds maximum length") + result = prompt_too_long_error(error, MEDIA_TYPE_TEXT) + + assert result == "Prompt is too long: Prompt exceeds maximum length" + + def test_generic_llm_error_json(self): + """Test generic LLM error for JSON media type.""" + + error = Exception("Connection failed") + result = generic_llm_error(error, MEDIA_TYPE_JSON) + + data_part = result.replace("data: ", "").strip() + parsed = json.loads(data_part) + assert parsed["event"] == "error" + assert parsed["data"]["response"] == "Internal server error" + assert parsed["data"]["cause"] == "Connection failed" + + def test_generic_llm_error_text(self): + """Test generic LLM error for text media type.""" + + error = Exception("Connection failed") + result = generic_llm_error(error, MEDIA_TYPE_TEXT) + + assert result == "Error: Connection failed" + + +class TestOLSCompatibilityIntegration: + """Integration tests for OLS compatibility.""" + + def test_media_type_validation(self): + """Test that media type validation works correctly.""" + + # Valid media types + valid_request = QueryRequest(query="test", media_type="application/json") + assert valid_request.media_type == "application/json" + + valid_request = QueryRequest(query="test", media_type="text/plain") + assert valid_request.media_type == "text/plain" + + # Invalid media type should raise error + with pytest.raises(ValueError, match="media_type must be either"): + QueryRequest(query="test", media_type="invalid/type") + + def test_ols_event_structure(self): + """Test that events follow OLS structure.""" + + # Test token event structure + token_data = {"id": 0, "token": "Hello"} + token_event = stream_event(token_data, LLM_TOKEN_EVENT, MEDIA_TYPE_JSON) + + data_part = token_event.replace("data: ", "").strip() + parsed = json.loads(data_part) + + assert parsed["event"] == "token" + assert "id" in parsed["data"] + assert "token" in parsed["data"] + assert "role" not in parsed["data"] # Role field is not included + + # Test tool call event structure + tool_data = { + "id": 0, + "token": {"tool_name": "search", "arguments": {"query": "test"}}, + } + tool_event = stream_event(tool_data, LLM_TOOL_CALL_EVENT, MEDIA_TYPE_JSON) + + data_part = tool_event.replace("data: ", "").strip() + parsed = json.loads(data_part) + + assert parsed["event"] == "tool_call" + assert "id" in parsed["data"] + assert "role" not in parsed["data"] + assert "token" 
in parsed["data"] + + # Test tool result event structure + result_data = { + "id": 0, + "token": {"tool_name": "search", "response": "Found results"}, + } + result_event = stream_event(result_data, LLM_TOOL_RESULT_EVENT, MEDIA_TYPE_JSON) + + data_part = result_event.replace("data: ", "").strip() + parsed = json.loads(data_part) + + assert parsed["event"] == "tool_result" + assert "id" in parsed["data"] + assert "role" not in parsed["data"] + assert "token" in parsed["data"] + + def test_ols_end_event_structure(self): + """Test that end event follows OLS structure.""" + + metadata_map = { + "doc1": {"title": "Test Doc", "docs_url": "https://example.com/doc"} + } + + end_event = stream_end_event(metadata_map, MEDIA_TYPE_JSON) + data_part = end_event.replace("data: ", "").strip() + parsed = json.loads(data_part) + + assert parsed["event"] == "end" + assert "referenced_documents" in parsed["data"] + assert "truncated" in parsed["data"] + assert "input_tokens" in parsed["data"] + assert "output_tokens" in parsed["data"] + assert "available_quotas" in parsed # At root level, not inside data diff --git a/tests/unit/models/requests/test_query_request.py b/tests/unit/models/requests/test_query_request.py index e0b6fb07..55176337 100644 --- a/tests/unit/models/requests/test_query_request.py +++ b/tests/unit/models/requests/test_query_request.py @@ -152,6 +152,5 @@ def test_validate_media_type(self, mocker) -> None: assert qr.model == "gpt-3.5-turbo" assert qr.media_type == "text/plain" - mock_logger.warning.assert_called_once_with( - "media_type was set in the request but is not supported. The value will be ignored." - ) + # Media type is now fully supported, no warning expected + mock_logger.warning.assert_not_called()