From 0c5e610c564d3677bb059bafd093b998cd922c5f Mon Sep 17 00:00:00 2001
From: Haoyu Sun
Date: Tue, 2 Sep 2025 22:48:53 +0200
Subject: [PATCH] LCORE-411: add token usage metrics

Signed-off-by: Haoyu Sun
---
 src/app/endpoints/query.py                 | 11 ++++-
 src/app/endpoints/streaming_query.py       |  8 +++
 src/metrics/utils.py                       | 31 +++++++++++-
 tests/unit/app/endpoints/test_query.py     | 23 +++++++++
 .../app/endpoints/test_streaming_query.py  | 10 ++++
 tests/unit/metrics/test_utis.py            | 49 ++++++++++++++++++-
 6 files changed, 128 insertions(+), 4 deletions(-)

diff --git a/src/app/endpoints/query.py b/src/app/endpoints/query.py
index 0a465e55..f29b65fc 100644
--- a/src/app/endpoints/query.py
+++ b/src/app/endpoints/query.py
@@ -24,6 +24,7 @@
 from configuration import configuration
 from app.database import get_session
 import metrics
+from metrics.utils import update_llm_token_count_from_turn
 import constants
 from authorization.middleware import authorize
 from models.config import Action
@@ -218,6 +219,7 @@ async def query_endpoint_handler(
             query_request,
             token,
             mcp_headers=mcp_headers,
+            provider_id=provider_id,
         )
         # Update metrics for the LLM call
         metrics.llm_calls_total.labels(provider_id, model_id).inc()
@@ -387,12 +389,14 @@ def is_input_shield(shield: Shield) -> bool:
     return _is_inout_shield(shield) or not is_output_shield(shield)
 
 
-async def retrieve_response(  # pylint: disable=too-many-locals,too-many-branches
+async def retrieve_response(  # pylint: disable=too-many-locals,too-many-branches,too-many-arguments
     client: AsyncLlamaStackClient,
     model_id: str,
     query_request: QueryRequest,
     token: str,
     mcp_headers: dict[str, dict[str, str]] | None = None,
+    *,
+    provider_id: str = "",
 ) -> tuple[TurnSummary, str]:
     """
     Retrieve response from LLMs and agents.
@@ -411,6 +415,7 @@ async def retrieve_response(  # pylint: disable=too-many-locals,too-many-branche
 
     Parameters:
         model_id (str): The identifier of the LLM model to use.
+        provider_id (str): The identifier of the LLM provider to use.
         query_request (QueryRequest): The user's query and associated metadata.
         token (str): The authentication token for authorization.
         mcp_headers (dict[str, dict[str, str]], optional): Headers for multi-component processing.
@@ -510,6 +515,10 @@ async def retrieve_response(  # pylint: disable=too-many-locals,too-many-branche
         tool_calls=[],
     )
 
+    # Update token count metrics for the LLM call
+    model_label = model_id.split("/", 1)[1] if "/" in model_id else model_id
+    update_llm_token_count_from_turn(response, model_label, provider_id, system_prompt)
+
     # Check for validation errors in the response
     steps = response.steps or []
     for step in steps:
diff --git a/src/app/endpoints/streaming_query.py b/src/app/endpoints/streaming_query.py
index 4aafc6f9..a78e21b3 100644
--- a/src/app/endpoints/streaming_query.py
+++ b/src/app/endpoints/streaming_query.py
@@ -26,6 +26,7 @@
 from client import AsyncLlamaStackClientHolder
 from configuration import configuration
 import metrics
+from metrics.utils import update_llm_token_count_from_turn
 from models.config import Action
 from models.requests import QueryRequest
 from models.database.conversations import UserConversation
@@ -621,6 +622,13 @@ async def response_generator(
                 summary.llm_response = interleaved_content_as_str(
                     p.turn.output_message.content
                 )
+                system_prompt = get_system_prompt(query_request, configuration)
+                try:
+                    update_llm_token_count_from_turn(
+                        p.turn, model_id, provider_id, system_prompt
+                    )
+                except Exception:  # pylint: disable=broad-except
+                    logger.exception("Failed to update token usage metrics")
             elif p.event_type == "step_complete":
                 if p.step_details.step_type == "tool_execution":
                     summary.append_tool_calls_from_llama(p.step_details)
diff --git a/src/metrics/utils.py b/src/metrics/utils.py
index aceddd82..2ba51645 100644
--- a/src/metrics/utils.py
+++ b/src/metrics/utils.py
@@ -1,9 +1,16 @@
 """Utility functions for metrics handling."""
 
-from configuration import configuration
+from typing import cast
+
+from llama_stack.models.llama.datatypes import RawMessage
+from llama_stack.models.llama.llama3.chat_format import ChatFormat
+from llama_stack.models.llama.llama3.tokenizer import Tokenizer
+from llama_stack_client.types.agents.turn import Turn
+
+import metrics
 from client import AsyncLlamaStackClientHolder
+from configuration import configuration
 from log import get_logger
-import metrics
 from utils.common import run_once_async
 
 logger = get_logger(__name__)
@@ -48,3 +55,23 @@ async def setup_model_metrics() -> None:
             default_model_value,
         )
     logger.info("Model metrics setup complete")
+
+
+def update_llm_token_count_from_turn(
+    turn: Turn, model: str, provider: str, system_prompt: str = ""
+) -> None:
+    """Update the LLM calls metrics from a turn."""
+    tokenizer = Tokenizer.get_instance()
+    formatter = ChatFormat(tokenizer)
+
+    raw_message = cast(RawMessage, turn.output_message)
+    encoded_output = formatter.encode_dialog_prompt([raw_message])
+    token_count = len(encoded_output.tokens) if encoded_output.tokens else 0
+    metrics.llm_token_received_total.labels(provider, model).inc(token_count)
+
+    input_messages = [RawMessage(role="user", content=system_prompt)] + cast(
+        list[RawMessage], turn.input_messages
+    )
+    encoded_input = formatter.encode_dialog_prompt(input_messages)
+    token_count = len(encoded_input.tokens) if encoded_input.tokens else 0
+    metrics.llm_token_sent_total.labels(provider, model).inc(token_count)
diff --git a/tests/unit/app/endpoints/test_query.py b/tests/unit/app/endpoints/test_query.py
index b12101b4..d1ecc216 100644
--- a/tests/unit/app/endpoints/test_query.py
+++ b/tests/unit/app/endpoints/test_query.py
@@ -46,6 +46,14 @@ def dummy_request() -> Request:
     return req
 
 
+def mock_metrics(mocker):
+    """Helper function to mock metrics operations for query endpoints."""
+    mocker.patch(
+        "app.endpoints.query.update_llm_token_count_from_turn",
+        return_value=None,
+    )
+
+
 def mock_database_operations(mocker):
     """Helper function to mock database operations for query endpoints."""
     mocker.patch(
@@ -443,6 +451,7 @@ async def test_retrieve_response_no_returned_message(prepare_agent_mocks, mocker
         "app.endpoints.query.get_agent",
         return_value=(mock_agent, "fake_conversation_id", "fake_session_id"),
     )
+    mock_metrics(mocker)
 
     query_request = QueryRequest(query="What is OpenStack?")
     model_id = "fake_model_id"
@@ -474,6 +483,7 @@ async def test_retrieve_response_message_without_content(prepare_agent_mocks, mo
         "app.endpoints.query.get_agent",
         return_value=(mock_agent, "fake_conversation_id", "fake_session_id"),
     )
+    mock_metrics(mocker)
 
     query_request = QueryRequest(query="What is OpenStack?")
     model_id = "fake_model_id"
@@ -506,6 +516,7 @@ async def test_retrieve_response_vector_db_available(prepare_agent_mocks, mocker
         "app.endpoints.query.get_agent",
         return_value=(mock_agent, "fake_conversation_id", "fake_session_id"),
     )
+    mock_metrics(mocker)
 
     query_request = QueryRequest(query="What is OpenStack?")
     model_id = "fake_model_id"
@@ -544,6 +555,7 @@ async def test_retrieve_response_no_available_shields(prepare_agent_mocks, mocke
         "app.endpoints.query.get_agent",
         return_value=(mock_agent, "fake_conversation_id", "fake_session_id"),
     )
+    mock_metrics(mocker)
 
     query_request = QueryRequest(query="What is OpenStack?")
     model_id = "fake_model_id"
@@ -593,6 +605,7 @@ def __repr__(self):
         "app.endpoints.query.get_agent",
         return_value=(mock_agent, "fake_conversation_id", "fake_session_id"),
     )
+    mock_metrics(mocker)
 
     query_request = QueryRequest(query="What is OpenStack?")
     model_id = "fake_model_id"
@@ -645,6 +658,7 @@ def __repr__(self):
         "app.endpoints.query.get_agent",
         return_value=(mock_agent, "fake_conversation_id", "fake_session_id"),
     )
+    mock_metrics(mocker)
 
     query_request = QueryRequest(query="What is OpenStack?")
     model_id = "fake_model_id"
@@ -699,6 +713,7 @@ def __repr__(self):
         "app.endpoints.query.get_agent",
         return_value=(mock_agent, "fake_conversation_id", "fake_session_id"),
     )
+    mock_metrics(mocker)
 
     query_request = QueryRequest(query="What is OpenStack?")
     model_id = "fake_model_id"
@@ -755,6 +770,7 @@ async def test_retrieve_response_with_one_attachment(prepare_agent_mocks, mocker
         "app.endpoints.query.get_agent",
         return_value=(mock_agent, "fake_conversation_id", "fake_session_id"),
     )
+    mock_metrics(mocker)
 
     query_request = QueryRequest(query="What is OpenStack?", attachments=attachments)
     model_id = "fake_model_id"
@@ -809,6 +825,7 @@ async def test_retrieve_response_with_two_attachments(prepare_agent_mocks, mocke
         "app.endpoints.query.get_agent",
         return_value=(mock_agent, "fake_conversation_id", "fake_session_id"),
     )
+    mock_metrics(mocker)
 
     query_request = QueryRequest(query="What is OpenStack?", attachments=attachments)
     model_id = "fake_model_id"
@@ -864,6 +881,7 @@ async def test_retrieve_response_with_mcp_servers(prepare_agent_mocks, mocker):
         "app.endpoints.query.get_agent",
         return_value=(mock_agent, "fake_conversation_id", "fake_session_id"),
     )
+    mock_metrics(mocker)
 
     query_request = QueryRequest(query="What is OpenStack?")
     model_id = "fake_model_id"
@@ -933,6 +951,7 @@ async def test_retrieve_response_with_mcp_servers_empty_token(
         "app.endpoints.query.get_agent",
         return_value=(mock_agent, "fake_conversation_id", "fake_session_id"),
     )
+    mock_metrics(mocker)
 
     query_request = QueryRequest(query="What is OpenStack?")
     model_id = "fake_model_id"
@@ -994,6 +1013,7 @@ async def test_retrieve_response_with_mcp_servers_and_mcp_headers(
         "app.endpoints.query.get_agent",
         return_value=(mock_agent, "fake_conversation_id", "fake_session_id"),
     )
+    mock_metrics(mocker)
 
     query_request = QueryRequest(query="What is OpenStack?")
     model_id = "fake_model_id"
@@ -1090,6 +1110,7 @@ async def test_retrieve_response_shield_violation(prepare_agent_mocks, mocker):
         "app.endpoints.query.get_agent",
         return_value=(mock_agent, "fake_conversation_id", "fake_session_id"),
     )
+    mock_metrics(mocker)
 
     query_request = QueryRequest(query="What is OpenStack?")
 
@@ -1326,6 +1347,7 @@ async def test_retrieve_response_no_tools_bypasses_mcp_and_rag(
         "app.endpoints.query.get_agent",
         return_value=(mock_agent, "fake_conversation_id", "fake_session_id"),
     )
+    mock_metrics(mocker)
 
     query_request = QueryRequest(query="What is OpenStack?", no_tools=True)
     model_id = "fake_model_id"
@@ -1376,6 +1398,7 @@ async def test_retrieve_response_no_tools_false_preserves_functionality(
         "app.endpoints.query.get_agent",
         return_value=(mock_agent, "fake_conversation_id", "fake_session_id"),
     )
+    mock_metrics(mocker)
 
     query_request = QueryRequest(query="What is OpenStack?", no_tools=False)
     model_id = "fake_model_id"
diff --git a/tests/unit/app/endpoints/test_streaming_query.py b/tests/unit/app/endpoints/test_streaming_query.py
index 794e5c18..1770575a 100644
--- a/tests/unit/app/endpoints/test_streaming_query.py
+++ b/tests/unit/app/endpoints/test_streaming_query.py
@@ -58,6 +58,14 @@ def mock_database_operations(mocker):
     mocker.patch("app.endpoints.streaming_query.persist_user_conversation_details")
 
 
+def mock_metrics(mocker):
+    """Helper function to mock metrics operations for streaming query endpoints."""
+    mocker.patch(
+        "app.endpoints.streaming_query.update_llm_token_count_from_turn",
+        return_value=None,
+    )
+
+
 SAMPLE_KNOWLEDGE_SEARCH_RESULTS = [
     """knowledge_search tool found 2 chunks:
 BEGIN of knowledge_search tool results.
@@ -346,12 +354,14 @@ async def _test_streaming_query_endpoint_handler(mocker, store_transcript=False)
 @pytest.mark.asyncio
 async def test_streaming_query_endpoint_handler(mocker):
     """Test the streaming query endpoint handler with transcript storage disabled."""
+    mock_metrics(mocker)
     await _test_streaming_query_endpoint_handler(mocker, store_transcript=False)
 
 
 @pytest.mark.asyncio
 async def test_streaming_query_endpoint_handler_store_transcript(mocker):
     """Test the streaming query endpoint handler with transcript storage enabled."""
+    mock_metrics(mocker)
     await _test_streaming_query_endpoint_handler(mocker, store_transcript=True)
 
 
diff --git a/tests/unit/metrics/test_utis.py b/tests/unit/metrics/test_utis.py
index b434a0cc..ee6432ca 100644
--- a/tests/unit/metrics/test_utis.py
+++ b/tests/unit/metrics/test_utis.py
@@ -1,6 +1,6 @@
 """Unit tests for functions defined in metrics/utils.py"""
 
-from metrics.utils import setup_model_metrics
+from metrics.utils import setup_model_metrics, update_llm_token_count_from_turn
 
 
 async def test_setup_model_metrics(mocker):
@@ -74,3 +74,50 @@ async def test_setup_model_metrics(mocker):
         ],
         any_order=False,  # Order matters here
     )
+
+
+def test_update_llm_token_count_from_turn(mocker):
+    """Test the update_llm_token_count_from_turn function."""
+    mocker.patch("metrics.utils.Tokenizer.get_instance")
+    mock_formatter_class = mocker.patch("metrics.utils.ChatFormat")
+    mock_formatter = mocker.Mock()
+    mock_formatter_class.return_value = mock_formatter
+
+    mock_received_metric = mocker.patch(
+        "metrics.utils.metrics.llm_token_received_total"
+    )
+    mock_sent_metric = mocker.patch("metrics.utils.metrics.llm_token_sent_total")
+
+    mock_turn = mocker.Mock()
+    # turn.output_message should satisfy the type RawMessage
+    mock_turn.output_message = {"role": "assistant", "content": "test response"}
+    # turn.input_messages should satisfy the type list[RawMessage]
+    mock_turn.input_messages = [{"role": "user", "content": "test input"}]
+
+    # Mock the encoded results with tokens
+    mock_encoded_output = mocker.Mock()
+    mock_encoded_output.tokens = ["token1", "token2", "token3"]  # 3 tokens
+    mock_encoded_input = mocker.Mock()
+    mock_encoded_input.tokens = ["token1", "token2"]  # 2 tokens
+    mock_formatter.encode_dialog_prompt.side_effect = [
+        mock_encoded_output,
+        mock_encoded_input,
+    ]
+
+    test_model = "test_model"
+    test_provider = "test_provider"
+    test_system_prompt = "test system prompt"
+
+    update_llm_token_count_from_turn(
+        mock_turn, test_model, test_provider, test_system_prompt
+    )
+
+    # Verify that llm_token_received_total.labels() was called with correct metrics
+    mock_received_metric.labels.assert_called_once_with(test_provider, test_model)
+    mock_received_metric.labels().inc.assert_called_once_with(
+        3
+    )  # token count from output
+
+    # Verify that llm_token_sent_total.labels() was called with correct metrics
+    mock_sent_metric.labels.assert_called_once_with(test_provider, test_model)
+    mock_sent_metric.labels().inc.assert_called_once_with(2)  # token count from input
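
A minimal usage sketch, not part of the patch itself: it mirrors how the query.py hunk above drives the new helper. The wrapper name record_turn_usage is hypothetical; only update_llm_token_count_from_turn and the llm_token_sent_total / llm_token_received_total counters come from this change. The sent counter covers the system prompt plus the turn's input messages, the received counter covers the output message.

    from llama_stack_client.types.agents.turn import Turn

    from metrics.utils import update_llm_token_count_from_turn


    def record_turn_usage(turn: Turn, model_id: str, provider_id: str, system_prompt: str) -> None:
        # Strip any "provider/" prefix so the model label matches the one used by llm_calls_total.
        model_label = model_id.split("/", 1)[1] if "/" in model_id else model_id
        # Increments llm_token_sent_total and llm_token_received_total with (provider, model) labels.
        update_llm_token_count_from_turn(turn, model_label, provider_id, system_prompt)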