# LCORE-178 Token counting #215
**Query endpoint module**

```diff
@@ -31,6 +31,7 @@
 from utils.endpoints import check_configuration_loaded, get_system_prompt
 from utils.mcp_headers import mcp_headers_dependency
 from utils.suid import get_suid
+from utils.token_counter import get_token_counter

 logger = logging.getLogger("app.endpoints.handlers")
 router = APIRouter(tags=["query"])
@@ -106,7 +107,7 @@ def query_endpoint_handler(
         # try to get Llama Stack client
         client = LlamaStackClientHolder().get_client()
         model_id = select_model_id(client.models.list(), query_request)
-        response, conversation_id = retrieve_response(
+        response, conversation_id, token_usage = retrieve_response(
             client,
             model_id,
             query_request,
@@ -129,7 +130,12 @@ def query_endpoint_handler(
             attachments=query_request.attachments or [],
         )

-        return QueryResponse(conversation_id=conversation_id, response=response)
+        return QueryResponse(
+            conversation_id=conversation_id,
+            response=response,
+            input_tokens=token_usage["input_tokens"],
+            output_tokens=token_usage["output_tokens"],
+        )

     # connection to Llama Stack server
     except APIConnectionError as e:
@@ -187,13 +193,21 @@ def select_model_id(models: ModelListResponse, query_request: QueryRequest) -> s
     return model_id


+def _build_toolgroups(client: LlamaStackClient) -> list[Toolgroup] | None:
+    """Build toolgroups from vector DBs and MCP servers."""
+    vector_db_ids = [vector_db.identifier for vector_db in client.vector_dbs.list()]
+    return (get_rag_toolgroups(vector_db_ids) or []) + [
+        mcp_server.name for mcp_server in configuration.mcp_servers
+    ]
+
+
 def retrieve_response(
     client: LlamaStackClient,
     model_id: str,
     query_request: QueryRequest,
     token: str,
     mcp_headers: dict[str, dict[str, str]] | None = None,
-) -> tuple[str, str]:
+) -> tuple[str, str, dict[str, int]]:
     """Retrieve response from LLMs and agents."""
     available_shields = [shield.identifier for shield in client.shields.list()]
     if not available_shields:
@@ -235,19 +249,36 @@ def retrieve_response(
         ),
     }

-    vector_db_ids = [vector_db.identifier for vector_db in client.vector_dbs.list()]
-    toolgroups = (get_rag_toolgroups(vector_db_ids) or []) + [
-        mcp_server.name for mcp_server in configuration.mcp_servers
-    ]
     response = agent.create_turn(
         messages=[UserMessage(role="user", content=query_request.query)],
         session_id=conversation_id,
         documents=query_request.get_documents(),
         stream=False,
-        toolgroups=toolgroups or None,
+        toolgroups=_build_toolgroups(client) or None,
     )

-    return str(response.output_message.content), conversation_id  # type: ignore[union-attr]
+    response_content = str(response.output_message.content)  # type: ignore[union-attr]
+
+    # Currently (2025-07-08) the usage is not returned by the API, so we need to estimate it
+    # try:
+    #     token_usage = {
+    #         "input_tokens": response.usage.get("prompt_tokens", 0),
+    #         "output_tokens": response.usage.get("completion_tokens", 0),
+    #     }
+    # except AttributeError:
+    # Estimate token usage
+    try:
+        token_usage = get_token_counter(model_id).count_turn_tokens(
+            system_prompt, query_request.query, response_content
+        )
+    except Exception as e:  # pylint: disable=broad-exception-caught
+        logger.warning("Failed to estimate token usage: %s", e)
+        token_usage = {
+            "input_tokens": 0,
+            "output_tokens": 0,
+        }
+
+    return response_content, conversation_id, token_usage


 def validate_attachments_metadata(attachments: list[Attachment]) -> None:
```

Review comments on `_build_toolgroups`:

> is it part of this PR?

> It was to alleviate a "too many local variables" error that appeared while I was fixing merge conflicts.

> Probably we can just `# pylint: disable=too-many-locals` for this

Review comment on `logger.warning("Failed to estimate token usage: %s", e)`:

> nit: This probably should be an error instead of a warning

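`utils/token_counter.py` is imported by both endpoints but is not part of the hunks shown here, so the sketch below is an assumption rather than the PR's actual implementation. It assumes a tiktoken-based estimator that falls back to the `cl100k_base` encoding (matching the new `default_estimation_tokenizer` default) and an interface that satisfies both call sites: the three-argument `count_turn_tokens(system_prompt, query, response)` used above, and `count_tokens(text)` plus a two-argument `count_turn_tokens` used by the streaming endpoint below.

```python
# Hypothetical sketch of utils/token_counter.py -- not taken from this PR.
import tiktoken


class TokenCounter:
    """Estimate token counts for a given model."""

    def __init__(self, model_id: str, default_encoding: str = "cl100k_base") -> None:
        try:
            # Works for model identifiers that tiktoken knows about.
            self.encoding = tiktoken.encoding_for_model(model_id)
        except KeyError:
            # Unknown model: fall back to the default estimation encoding.
            self.encoding = tiktoken.get_encoding(default_encoding)

    def count_tokens(self, text: str) -> int:
        """Return the estimated number of tokens in a piece of text."""
        return len(self.encoding.encode(text))

    def count_turn_tokens(
        self, system_prompt: str, query: str, response: str = ""
    ) -> dict[str, int]:
        """Estimate input/output token usage for one turn.

        The streaming path omits the response; output tokens are then counted
        separately from the accumulated streamed content.
        """
        return {
            "input_tokens": self.count_tokens(system_prompt) + self.count_tokens(query),
            "output_tokens": self.count_tokens(response) if response else 0,
        }


def get_token_counter(model_id: str) -> TokenCounter:
    """Return a token counter suitable for the given model."""
    return TokenCounter(model_id)
```

Under this assumption both endpoints degrade gracefully: if estimation fails for any reason, the handlers fall back to zero counts instead of failing the request.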
**Streaming query endpoint module**

```diff
@@ -24,6 +24,7 @@
 from utils.common import retrieve_user_id
 from utils.mcp_headers import mcp_headers_dependency
 from utils.suid import get_suid
+from utils.token_counter import get_token_counter
 from utils.types import GraniteToolParser

 from app.endpoints.query import (
@@ -95,8 +96,13 @@ def stream_start_event(conversation_id: str) -> str:
     )


-def stream_end_event(metadata_map: dict) -> str:
-    """Yield the end of the data stream."""
+def stream_end_event(metadata_map: dict, metrics_map: dict[str, int]) -> str:
+    """Yield the end of the data stream.
+
+    Args:
+        metadata_map: Dictionary containing metadata about referenced documents
+        metrics_map: Dictionary containing metrics like 'input_tokens' and 'output_tokens'
+    """
     return format_stream_data(
         {
             "event": "end",
@@ -112,8 +118,8 @@ def stream_end_event(metadata_map: dict) -> str:
                     )
                 ],
                 "truncated": None,  # TODO(jboos): implement truncated
-                "input_tokens": 0,  # TODO(jboos): implement input tokens
-                "output_tokens": 0,  # TODO(jboos): implement output tokens
+                "input_tokens": metrics_map.get("input_tokens", 0),
+                "output_tokens": metrics_map.get("output_tokens", 0),
             },
             "available_quotas": {},  # TODO(jboos): implement available quotas
         }
@@ -200,7 +206,7 @@ async def streaming_query_endpoint_handler(
         # try to get Llama Stack client
         client = AsyncLlamaStackClientHolder().get_client()
         model_id = select_model_id(await client.models.list(), query_request)
-        response, conversation_id = await retrieve_response(
+        response, conversation_id, token_usage = await retrieve_response(
            client,
            model_id,
            query_request,
@@ -225,7 +231,25 @@ async def response_generator(turn_response: Any) -> AsyncIterator[str]:
                chunk_id += 1
                yield event

-            yield stream_end_event(metadata_map)
+            # Currently (2025-07-08) the usage is not returned by the API, so we need to estimate
+            # try:
+            #     output_tokens = response.usage.get("completion_tokens", 0)
+            # except AttributeError:
+            # Estimate output tokens from complete response
+            try:
+                output_tokens = get_token_counter(model_id).count_tokens(
+                    complete_response
+                )
+                logger.debug("Estimated output tokens: %s", output_tokens)
+            except Exception as e:  # pylint: disable=broad-exception-caught
+                logger.warning("Failed to estimate output tokens: %s", e)
+                output_tokens = 0
+
+            metrics_map = {
+                "input_tokens": token_usage["input_tokens"],
+                "output_tokens": output_tokens,
+            }
+            yield stream_end_event(metadata_map, metrics_map)

            if not is_transcripts_enabled():
                logger.debug("Transcript collection is disabled in the configuration")
@@ -262,7 +286,7 @@ async def retrieve_response(
     query_request: QueryRequest,
     token: str,
     mcp_headers: dict[str, dict[str, str]] | None = None,
-) -> tuple[Any, str]:
+) -> tuple[Any, str, dict[str, int]]:
     """Retrieve response from LLMs and agents."""
     available_shields = [shield.identifier for shield in await client.shields.list()]
     if not available_shields:
@@ -319,4 +343,23 @@ async def retrieve_response(
         toolgroups=toolgroups or None,
     )

-    return response, conversation_id
+    # Currently (2025-07-08) the usage is not returned by the API, so we need to estimate it
+    # try:
+    #     token_usage = {
+    #         "input_tokens": response.usage.get("prompt_tokens", 0),
+    #         "output_tokens": 0,  # Will be calculated during streaming
+    #     }
+    # except AttributeError:
+    # # Estimate input tokens (Output will be calculated during streaming)
+    try:
+        token_usage = get_token_counter(model_id).count_turn_tokens(
+            system_prompt, query_request.query
+        )
+    except Exception as e:  # pylint: disable=broad-exception-caught
+        logger.warning("Failed to estimate token usage: %s", e)
+        token_usage = {
+            "input_tokens": 0,
+            "output_tokens": 0,
+        }
+
+    return response, conversation_id, token_usage
```

Review comment on `logger.warning("Failed to estimate output tokens: %s", e)`:

> ditto: s/warning/error/g

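The end-event envelope produced by `format_stream_data` is only partially visible in this diff, so the following test is a hypothetical sketch rather than part of the PR: it assumes the frame is a `data: <json>` SSE string rendered with `json.dumps`-style spacing, and that the module lives at `app.endpoints.streaming_query`.

```python
# Hypothetical test sketch -- not part of this PR. The module path and the
# exact SSE frame layout are assumptions.
from app.endpoints.streaming_query import stream_end_event


def test_stream_end_event_reports_token_usage() -> None:
    event = stream_end_event(
        metadata_map={},  # no referenced documents for this sketch
        metrics_map={"input_tokens": 42, "output_tokens": 7},
    )

    # Assumes json.dumps default spacing ('"key": value') in the serialized frame.
    assert '"input_tokens": 42' in event
    assert '"output_tokens": 7' in event
```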
**Configuration model**

```diff
@@ -154,6 +154,7 @@ class Configuration(BaseModel):
         AuthenticationConfiguration()
     )
     customization: Optional[Customization] = None
+    default_estimation_tokenizer: str = "cl100k_base"

     def dump(self, filename: str = "configuration.json") -> None:
         """Dump actual configuration into JSON file."""
```
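`cl100k_base` is the name of a tiktoken encoding, so this setting presumably selects the encoding used for estimation when no model-specific tokenizer can be resolved (the token-counter implementation is not shown in this diff, so treat the backend choice as an assumption). A quick illustration of what the value refers to:

```python
import tiktoken

# "cl100k_base" (and e.g. "o200k_base") are tiktoken encoding names; under the
# assumption above, any of them would be a valid default_estimation_tokenizer.
encoding = tiktoken.get_encoding("cl100k_base")
print(len(encoding.encode("How many tokens is this sentence?")))  # rough estimate
```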
**Response models**

```diff
@@ -16,8 +16,6 @@ class ModelsResponse(BaseModel):
 # - referenced_documents: The optional URLs and titles for the documents used
 #   to generate the response.
 # - truncated: Set to True if conversation history was truncated to be within context window.
-# - input_tokens: Number of tokens sent to LLM
-# - output_tokens: Number of tokens received from LLM
 # - available_quotas: Quota available as measured by all configured quota limiters
 # - tool_calls: List of tool requests.
 # - tool_results: List of tool results.
@@ -28,10 +26,14 @@
     Attributes:
         conversation_id: The optional conversation ID (UUID).
         response: The response.
+        input_tokens: Number of tokens sent to LLM.
+        output_tokens: Number of tokens received from LLM.
     """

     conversation_id: Optional[str] = None
     response: str
+    input_tokens: Optional[int] = None
+    output_tokens: Optional[int] = None

     # provides examples for /docs endpoint
     model_config = {
```
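With the two new optional fields, a query response carries the estimated counts alongside the answer. An illustrative example (field values are made up; the `models.responses` import path is assumed):

```python
from models.responses import QueryResponse

# Illustrative only -- the ID, text, and counts below are invented for the example.
example = QueryResponse(
    conversation_id="123e4567-e89b-12d3-a456-426614174000",
    response="LLM answer goes here",
    input_tokens=123,
    output_tokens=87,
)
print(example.model_dump())  # pydantic v2 serialization, matching the model_config usage
```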
Review comment (presumably on the new `default_estimation_tokenizer` default):

> prob. worth to set it up in `constants.py` and allow it to be overridden by anything else in the `customization` block? WDYT?