11 changes: 10 additions & 1 deletion src/app/endpoints/query.py
@@ -24,6 +24,7 @@
from configuration import configuration
from app.database import get_session
import metrics
from metrics.utils import update_llm_token_count_from_turn
import constants
from authorization.middleware import authorize
from models.config import Action
@@ -218,6 +219,7 @@ async def query_endpoint_handler(
query_request,
token,
mcp_headers=mcp_headers,
provider_id=provider_id,
)
# Update metrics for the LLM call
metrics.llm_calls_total.labels(provider_id, model_id).inc()
@@ -387,12 +389,14 @@ def is_input_shield(shield: Shield) -> bool:
return _is_inout_shield(shield) or not is_output_shield(shield)


async def retrieve_response( # pylint: disable=too-many-locals,too-many-branches
async def retrieve_response( # pylint: disable=too-many-locals,too-many-branches,too-many-arguments
client: AsyncLlamaStackClient,
model_id: str,
query_request: QueryRequest,
token: str,
mcp_headers: dict[str, dict[str, str]] | None = None,
*,
provider_id: str = "",
) -> tuple[TurnSummary, str]:
"""
Retrieve response from LLMs and agents.
@@ -411,6 +415,7 @@ async def retrieve_response( # pylint: disable=too-many-locals,too-many-branches

Parameters:
model_id (str): The identifier of the LLM model to use.
provider_id (str): The identifier of the LLM provider to use.
query_request (QueryRequest): The user's query and associated metadata.
token (str): The authentication token for authorization.
mcp_headers (dict[str, dict[str, str]], optional): Headers for multi-component processing.
@@ -510,6 +515,10 @@ async def retrieve_response( # pylint: disable=too-many-locals,too-many-branches
tool_calls=[],
)

# Update token count metrics for the LLM call
model_label = model_id.split("/", 1)[1] if "/" in model_id else model_id
update_llm_token_count_from_turn(response, model_label, provider_id, system_prompt)

Comment on lines +518 to +521
🛠️ Refactor suggestion

Harden metrics update to never break the endpoint; use last path segment.

Instrumentation must be best-effort. If Tokenizer/formatter or turn fields are missing, this can raise and 500 the request. Also, rsplit is safer for nested identifiers.

Apply:

-    # Update token count metrics for the LLM call
-    model_label = model_id.split("/", 1)[1] if "/" in model_id else model_id
-    update_llm_token_count_from_turn(response, model_label, provider_id, system_prompt)
+    # Update token count metrics for the LLM call (best-effort; do not disrupt response)
+    try:
+        model_label = model_id.rsplit("/", 1)[-1]
+        update_llm_token_count_from_turn(
+            response, model_label, provider_id or "unknown", system_prompt
+        )
+    except Exception as err:  # pylint: disable=broad-except
+        logger.warning("Failed to update token count metrics: %s", err)

If you prefer, we can centralize this “safe update” in a tiny helper to reuse here and in streaming_query.
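
For reference, a minimal sketch of such a helper (not part of this diff; the name safe_update_llm_token_count is an assumption, and placing it in metrics/utils.py would let it reuse that module's existing logger and imports):

def safe_update_llm_token_count(
    turn: Turn, model_id: str, provider_id: str, system_prompt: str = ""
) -> None:
    """Best-effort token metrics update; never raises into the request path."""
    try:
        # Use the last path segment so nested identifiers such as "org/model" still resolve.
        model_label = model_id.rsplit("/", 1)[-1]
        update_llm_token_count_from_turn(
            turn, model_label, provider_id or "unknown", system_prompt
        )
    except Exception:  # pylint: disable=broad-except
        # Instrumentation is best-effort; log and keep serving the request.
        logger.warning("Failed to update token count metrics", exc_info=True)

Both query.py and streaming_query.py could then call this single guard instead of duplicating the try/except.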

🤖 Prompt for AI Agents
In src/app/endpoints/query.py around lines 518-521, the metrics update can raise
if tokenizer/formatter or expected fields on the turn/response are missing and
using split("/") can pick the wrong segment for nested identifiers; change to
use rsplit("/", 1) to get the last path segment for model_label and wrap the
call to update_llm_token_count_from_turn in a try/except that swallows
exceptions (and logs a debug/warn) so instrumentation is best-effort and cannot
500 the endpoint; optionally factor this logic into a small helper (used here
and in streaming_query) that extracts the safe model_label and invokes the
update inside a non-throwing guard.

# Check for validation errors in the response
steps = response.steps or []
for step in steps:
8 changes: 8 additions & 0 deletions src/app/endpoints/streaming_query.py
@@ -26,6 +26,7 @@
from client import AsyncLlamaStackClientHolder
from configuration import configuration
import metrics
from metrics.utils import update_llm_token_count_from_turn
from models.config import Action
from models.requests import QueryRequest
from models.database.conversations import UserConversation
@@ -621,6 +622,13 @@ async def response_generator(
summary.llm_response = interleaved_content_as_str(
p.turn.output_message.content
)
system_prompt = get_system_prompt(query_request, configuration)
try:
update_llm_token_count_from_turn(
p.turn, model_id, provider_id, system_prompt
)
except Exception: # pylint: disable=broad-except
logger.exception("Failed to update token usage metrics")
elif p.event_type == "step_complete":
if p.step_details.step_type == "tool_execution":
summary.append_tool_calls_from_llama(p.step_details)
31 changes: 29 additions & 2 deletions src/metrics/utils.py
@@ -1,9 +1,16 @@
"""Utility functions for metrics handling."""

from configuration import configuration
from typing import cast

from llama_stack.models.llama.datatypes import RawMessage
from llama_stack.models.llama.llama3.chat_format import ChatFormat
from llama_stack.models.llama.llama3.tokenizer import Tokenizer
from llama_stack_client.types.agents.turn import Turn

import metrics
from client import AsyncLlamaStackClientHolder
from configuration import configuration
from log import get_logger
import metrics
from utils.common import run_once_async

logger = get_logger(__name__)
@@ -48,3 +55,23 @@ async def setup_model_metrics() -> None:
default_model_value,
)
logger.info("Model metrics setup complete")


def update_llm_token_count_from_turn(
turn: Turn, model: str, provider: str, system_prompt: str = ""
) -> None:
"""Update the LLM calls metrics from a turn."""
tokenizer = Tokenizer.get_instance()
formatter = ChatFormat(tokenizer)

raw_message = cast(RawMessage, turn.output_message)
encoded_output = formatter.encode_dialog_prompt([raw_message])
token_count = len(encoded_output.tokens) if encoded_output.tokens else 0
metrics.llm_token_received_total.labels(provider, model).inc(token_count)

input_messages = [RawMessage(role="user", content=system_prompt)] + cast(
list[RawMessage], turn.input_messages
)
encoded_input = formatter.encode_dialog_prompt(input_messages)
token_count = len(encoded_input.tokens) if encoded_input.tokens else 0
metrics.llm_token_sent_total.labels(provider, model).inc(token_count)
23 changes: 23 additions & 0 deletions tests/unit/app/endpoints/test_query.py
@@ -46,6 +46,14 @@ def dummy_request() -> Request:
return req


def mock_metrics(mocker):
"""Helper function to mock metrics operations for query endpoints."""
mocker.patch(
"app.endpoints.query.update_llm_token_count_from_turn",
return_value=None,
)


def mock_database_operations(mocker):
"""Helper function to mock database operations for query endpoints."""
mocker.patch(
@@ -443,6 +451,7 @@ async def test_retrieve_response_no_returned_message(prepare_agent_mocks, mocker
"app.endpoints.query.get_agent",
return_value=(mock_agent, "fake_conversation_id", "fake_session_id"),
)
mock_metrics(mocker)

query_request = QueryRequest(query="What is OpenStack?")
model_id = "fake_model_id"
@@ -474,6 +483,7 @@ async def test_retrieve_response_message_without_content(prepare_agent_mocks, mo
"app.endpoints.query.get_agent",
return_value=(mock_agent, "fake_conversation_id", "fake_session_id"),
)
mock_metrics(mocker)

query_request = QueryRequest(query="What is OpenStack?")
model_id = "fake_model_id"
@@ -506,6 +516,7 @@ async def test_retrieve_response_vector_db_available(prepare_agent_mocks, mocker
"app.endpoints.query.get_agent",
return_value=(mock_agent, "fake_conversation_id", "fake_session_id"),
)
mock_metrics(mocker)

query_request = QueryRequest(query="What is OpenStack?")
model_id = "fake_model_id"
@@ -544,6 +555,7 @@ async def test_retrieve_response_no_available_shields(prepare_agent_mocks, mocke
"app.endpoints.query.get_agent",
return_value=(mock_agent, "fake_conversation_id", "fake_session_id"),
)
mock_metrics(mocker)

query_request = QueryRequest(query="What is OpenStack?")
model_id = "fake_model_id"
@@ -593,6 +605,7 @@ def __repr__(self):
"app.endpoints.query.get_agent",
return_value=(mock_agent, "fake_conversation_id", "fake_session_id"),
)
mock_metrics(mocker)

query_request = QueryRequest(query="What is OpenStack?")
model_id = "fake_model_id"
@@ -645,6 +658,7 @@ def __repr__(self):
"app.endpoints.query.get_agent",
return_value=(mock_agent, "fake_conversation_id", "fake_session_id"),
)
mock_metrics(mocker)

query_request = QueryRequest(query="What is OpenStack?")
model_id = "fake_model_id"
@@ -699,6 +713,7 @@ def __repr__(self):
"app.endpoints.query.get_agent",
return_value=(mock_agent, "fake_conversation_id", "fake_session_id"),
)
mock_metrics(mocker)

query_request = QueryRequest(query="What is OpenStack?")
model_id = "fake_model_id"
@@ -755,6 +770,7 @@ async def test_retrieve_response_with_one_attachment(prepare_agent_mocks, mocker
"app.endpoints.query.get_agent",
return_value=(mock_agent, "fake_conversation_id", "fake_session_id"),
)
mock_metrics(mocker)

query_request = QueryRequest(query="What is OpenStack?", attachments=attachments)
model_id = "fake_model_id"
@@ -809,6 +825,7 @@ async def test_retrieve_response_with_two_attachments(prepare_agent_mocks, mocke
"app.endpoints.query.get_agent",
return_value=(mock_agent, "fake_conversation_id", "fake_session_id"),
)
mock_metrics(mocker)

query_request = QueryRequest(query="What is OpenStack?", attachments=attachments)
model_id = "fake_model_id"
@@ -864,6 +881,7 @@ async def test_retrieve_response_with_mcp_servers(prepare_agent_mocks, mocker):
"app.endpoints.query.get_agent",
return_value=(mock_agent, "fake_conversation_id", "fake_session_id"),
)
mock_metrics(mocker)

query_request = QueryRequest(query="What is OpenStack?")
model_id = "fake_model_id"
@@ -933,6 +951,7 @@ async def test_retrieve_response_with_mcp_servers_empty_token(
"app.endpoints.query.get_agent",
return_value=(mock_agent, "fake_conversation_id", "fake_session_id"),
)
mock_metrics(mocker)

query_request = QueryRequest(query="What is OpenStack?")
model_id = "fake_model_id"
@@ -994,6 +1013,7 @@ async def test_retrieve_response_with_mcp_servers_and_mcp_headers(
"app.endpoints.query.get_agent",
return_value=(mock_agent, "fake_conversation_id", "fake_session_id"),
)
mock_metrics(mocker)

query_request = QueryRequest(query="What is OpenStack?")
model_id = "fake_model_id"
@@ -1090,6 +1110,7 @@ async def test_retrieve_response_shield_violation(prepare_agent_mocks, mocker):
"app.endpoints.query.get_agent",
return_value=(mock_agent, "fake_conversation_id", "fake_session_id"),
)
mock_metrics(mocker)

query_request = QueryRequest(query="What is OpenStack?")

@@ -1326,6 +1347,7 @@ async def test_retrieve_response_no_tools_bypasses_mcp_and_rag(
"app.endpoints.query.get_agent",
return_value=(mock_agent, "fake_conversation_id", "fake_session_id"),
)
mock_metrics(mocker)

query_request = QueryRequest(query="What is OpenStack?", no_tools=True)
model_id = "fake_model_id"
@@ -1376,6 +1398,7 @@ async def test_retrieve_response_no_tools_false_preserves_functionality(
"app.endpoints.query.get_agent",
return_value=(mock_agent, "fake_conversation_id", "fake_session_id"),
)
mock_metrics(mocker)

query_request = QueryRequest(query="What is OpenStack?", no_tools=False)
model_id = "fake_model_id"
10 changes: 10 additions & 0 deletions tests/unit/app/endpoints/test_streaming_query.py
@@ -58,6 +58,14 @@ def mock_database_operations(mocker):
mocker.patch("app.endpoints.streaming_query.persist_user_conversation_details")


def mock_metrics(mocker):
"""Helper function to mock metrics operations for streaming query endpoints."""
mocker.patch(
"app.endpoints.streaming_query.update_llm_token_count_from_turn",
return_value=None,
)


SAMPLE_KNOWLEDGE_SEARCH_RESULTS = [
"""knowledge_search tool found 2 chunks:
BEGIN of knowledge_search tool results.
@@ -346,12 +354,14 @@ async def _test_streaming_query_endpoint_handler(mocker, store_transcript=False)
@pytest.mark.asyncio
async def test_streaming_query_endpoint_handler(mocker):
"""Test the streaming query endpoint handler with transcript storage disabled."""
mock_metrics(mocker)
await _test_streaming_query_endpoint_handler(mocker, store_transcript=False)


@pytest.mark.asyncio
async def test_streaming_query_endpoint_handler_store_transcript(mocker):
"""Test the streaming query endpoint handler with transcript storage enabled."""
mock_metrics(mocker)
await _test_streaming_query_endpoint_handler(mocker, store_transcript=True)


49 changes: 48 additions & 1 deletion tests/unit/metrics/test_utis.py
@@ -1,6 +1,6 @@
"""Unit tests for functions defined in metrics/utils.py"""

from metrics.utils import setup_model_metrics
from metrics.utils import setup_model_metrics, update_llm_token_count_from_turn


async def test_setup_model_metrics(mocker):
@@ -74,3 +74,50 @@ async def test_setup_model_metrics(mocker):
],
any_order=False, # Order matters here
)


def test_update_llm_token_count_from_turn(mocker):
"""Test the update_llm_token_count_from_turn function."""
mocker.patch("metrics.utils.Tokenizer.get_instance")
mock_formatter_class = mocker.patch("metrics.utils.ChatFormat")
mock_formatter = mocker.Mock()
mock_formatter_class.return_value = mock_formatter

mock_received_metric = mocker.patch(
"metrics.utils.metrics.llm_token_received_total"
)
mock_sent_metric = mocker.patch("metrics.utils.metrics.llm_token_sent_total")

mock_turn = mocker.Mock()
# turn.output_message should satisfy the type RawMessage
mock_turn.output_message = {"role": "assistant", "content": "test response"}
# turn.input_messages should satisfy the type list[RawMessage]
mock_turn.input_messages = [{"role": "user", "content": "test input"}]

# Mock the encoded results with tokens
mock_encoded_output = mocker.Mock()
mock_encoded_output.tokens = ["token1", "token2", "token3"] # 3 tokens
mock_encoded_input = mocker.Mock()
mock_encoded_input.tokens = ["token1", "token2"] # 2 tokens
mock_formatter.encode_dialog_prompt.side_effect = [
mock_encoded_output,
mock_encoded_input,
]

test_model = "test_model"
test_provider = "test_provider"
test_system_prompt = "test system prompt"

update_llm_token_count_from_turn(
mock_turn, test_model, test_provider, test_system_prompt
)

# Verify that llm_token_received_total.labels() was called with correct metrics
mock_received_metric.labels.assert_called_once_with(test_provider, test_model)
mock_received_metric.labels().inc.assert_called_once_with(
3
) # token count from output

# Verify that llm_token_sent_total.labels() was called with correct metrics
mock_sent_metric.labels.assert_called_once_with(test_provider, test_model)
mock_sent_metric.labels().inc.assert_called_once_with(2) # token count from input
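
A possible companion test, not part of this PR, could pin down the zero-token path; the sketch below assumes the same pytest-mock setup and the existing import of update_llm_token_count_from_turn, and the test name is hypothetical:

def test_update_llm_token_count_from_turn_no_tokens(mocker):
    """Sketch: counters are still labelled but incremented by zero when no tokens are returned."""
    mocker.patch("metrics.utils.Tokenizer.get_instance")
    mock_formatter_class = mocker.patch("metrics.utils.ChatFormat")
    mock_formatter = mocker.Mock()
    mock_formatter_class.return_value = mock_formatter

    mock_received_metric = mocker.patch("metrics.utils.metrics.llm_token_received_total")
    mock_sent_metric = mocker.patch("metrics.utils.metrics.llm_token_sent_total")

    mock_turn = mocker.Mock()
    mock_turn.output_message = {"role": "assistant", "content": ""}
    mock_turn.input_messages = []

    # Both encode calls report no tokens at all
    mock_encoded = mocker.Mock()
    mock_encoded.tokens = None
    mock_formatter.encode_dialog_prompt.return_value = mock_encoded

    update_llm_token_count_from_turn(mock_turn, "test_model", "test_provider", "")

    # Each counter is labelled once and incremented by zero
    mock_received_metric.labels.assert_called_once_with("test_provider", "test_model")
    mock_received_metric.labels.return_value.inc.assert_called_once_with(0)
    mock_sent_metric.labels.assert_called_once_with("test_provider", "test_model")
    mock_sent_metric.labels.return_value.inc.assert_called_once_with(0)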