24 changes: 23 additions & 1 deletion docs/openapi.json
@@ -939,14 +939,36 @@
"response": {
"type": "string",
"title": "Response"
},
"input_tokens": {
"anyOf": [
{
"type": "integer"
},
{
"type": "null"
}
],
"title": "Input Tokens"
},
"output_tokens": {
"anyOf": [
{
"type": "integer"
},
{
"type": "null"
}
],
"title": "Output Tokens"
}
},
"type": "object",
"required": [
"response"
],
"title": "QueryResponse",
"description": "Model representing LLM response to a query.\n\nAttributes:\n conversation_id: The optional conversation ID (UUID).\n response: The response.",
"description": "Model representing LLM response to a query.\n\nAttributes:\n conversation_id: The optional conversation ID (UUID).\n response: The response.\n input_tokens: Number of tokens sent to LLM.\n output_tokens: Number of tokens received from LLM.",
"examples": [
{
"conversation_id": "123e4567-e89b-12d3-a456-426614174000",
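With these schema additions, a QueryResponse payload would report the estimated token counts alongside the answer. An illustrative example (the token values are made up):

    {
      "conversation_id": "123e4567-e89b-12d3-a456-426614174000",
      "response": "...",
      "input_tokens": 123,
      "output_tokens": 456
    }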
1 change: 1 addition & 0 deletions lightspeed-stack.yaml
@@ -22,3 +22,4 @@ user_data_collection:
transcripts_storage: "/tmp/data/transcripts"
authentication:
module: "noop"
default_estimation_tokenizer: "cl100k_base"
Contributor: Probably worth setting this up in constants.py and allowing it to be overridden in the customization block? WDYT?
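A rough sketch of that suggestion (hypothetical wiring; the constant name and the config default shown here are assumptions, not part of this PR):

    # constants.py
    # Fallback tiktoken encoding used when no model-specific tokenizer is known.
    DEFAULT_ESTIMATION_TOKENIZER = "cl100k_base"

    # models/config.py
    default_estimation_tokenizer: str = constants.DEFAULT_ESTIMATION_TOKENIZER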

1 change: 1 addition & 0 deletions pyproject.toml
@@ -13,6 +13,7 @@ dependencies = [
"llama-stack>=0.2.13",
"rich>=14.0.0",
"cachetools>=6.1.0",
"tiktoken>=0.9.0,<1.0.0",
]

[tool.pyright]
1 change: 1 addition & 0 deletions src/app/endpoints/config.py
@@ -46,6 +46,7 @@
{"name": "server2", "provider_id": "provider2", "url": "http://url.com:2"},
{"name": "server3", "provider_id": "provider3", "url": "http://url.com:3"},
],
"default_estimation_tokenizer": "cl100k_base",
},
503: {
"detail": {
49 changes: 40 additions & 9 deletions src/app/endpoints/query.py
@@ -31,6 +31,7 @@
from utils.endpoints import check_configuration_loaded, get_system_prompt
from utils.mcp_headers import mcp_headers_dependency
from utils.suid import get_suid
from utils.token_counter import get_token_counter

logger = logging.getLogger("app.endpoints.handlers")
router = APIRouter(tags=["query"])
@@ -106,7 +107,7 @@ def query_endpoint_handler(
# try to get Llama Stack client
client = LlamaStackClientHolder().get_client()
model_id = select_model_id(client.models.list(), query_request)
response, conversation_id = retrieve_response(
response, conversation_id, token_usage = retrieve_response(
client,
model_id,
query_request,
@@ -129,7 +130,12 @@
attachments=query_request.attachments or [],
)

return QueryResponse(conversation_id=conversation_id, response=response)
return QueryResponse(
conversation_id=conversation_id,
response=response,
input_tokens=token_usage["input_tokens"],
output_tokens=token_usage["output_tokens"],
)

# connection to Llama Stack server
except APIConnectionError as e:
@@ -187,13 +193,21 @@ def select_model_id(models: ModelListResponse, query_request: QueryRequest) -> s
return model_id


def _build_toolgroups(client: LlamaStackClient) -> list[Toolgroup] | None:
Contributor: Is this part of this PR?

Author: It was to alleviate a "too many local variables" error that appeared while I was fixing merge conflicts.

Contributor: Probably we can just use # pylint: disable=too-many-locals for this.

"""Build toolgroups from vector DBs and MCP servers."""
vector_db_ids = [vector_db.identifier for vector_db in client.vector_dbs.list()]
return (get_rag_toolgroups(vector_db_ids) or []) + [
mcp_server.name for mcp_server in configuration.mcp_servers
]


def retrieve_response(
client: LlamaStackClient,
model_id: str,
query_request: QueryRequest,
token: str,
mcp_headers: dict[str, dict[str, str]] | None = None,
) -> tuple[str, str]:
) -> tuple[str, str, dict[str, int]]:
"""Retrieve response from LLMs and agents."""
available_shields = [shield.identifier for shield in client.shields.list()]
if not available_shields:
@@ -235,19 +249,36 @@ def retrieve_response(
),
}

vector_db_ids = [vector_db.identifier for vector_db in client.vector_dbs.list()]
toolgroups = (get_rag_toolgroups(vector_db_ids) or []) + [
mcp_server.name for mcp_server in configuration.mcp_servers
]
response = agent.create_turn(
messages=[UserMessage(role="user", content=query_request.query)],
session_id=conversation_id,
documents=query_request.get_documents(),
stream=False,
toolgroups=toolgroups or None,
toolgroups=_build_toolgroups(client) or None,
)

return str(response.output_message.content), conversation_id # type: ignore[union-attr]
response_content = str(response.output_message.content) # type: ignore[union-attr]

# Currently (2025-07-08) the usage is not returned by the API, so we need to estimate it
# try:
# token_usage = {
# "input_tokens": response.usage.get("prompt_tokens", 0),
# "output_tokens": response.usage.get("completion_tokens", 0),
# }
# except AttributeError:
# Estimate token usage
try:
token_usage = get_token_counter(model_id).count_turn_tokens(
system_prompt, query_request.query, response_content
)
except Exception as e: # pylint: disable=broad-exception-caught
logger.warning("Failed to estimate token usage: %s", e)
Contributor: nit: This probably should be an error instead of a warning.
token_usage = {
"input_tokens": 0,
"output_tokens": 0,
}

return response_content, conversation_id, token_usage


def validate_attachments_metadata(attachments: list[Attachment]) -> None:
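The handlers above import get_token_counter from utils.token_counter, which is not shown in this diff. A minimal sketch of what such a helper might look like, assuming tiktoken (the dependency added to pyproject.toml above) and the cl100k_base fallback from the configuration; only get_token_counter, count_tokens, and count_turn_tokens are names taken from the diff, everything else here is an assumption and the actual module in this PR may differ:

    # Hypothetical sketch of utils/token_counter.py; the real module may differ.
    import tiktoken

    DEFAULT_ESTIMATION_TOKENIZER = "cl100k_base"  # assumed fallback encoding


    class TokenCounter:
        """Estimate token usage for a model via a tiktoken encoding."""

        def __init__(self, model_id: str, fallback: str = DEFAULT_ESTIMATION_TOKENIZER) -> None:
            try:
                # Use the model-specific encoding when tiktoken recognizes the model id.
                self.encoding = tiktoken.encoding_for_model(model_id)
            except KeyError:
                # Otherwise fall back to the configured default tokenizer.
                self.encoding = tiktoken.get_encoding(fallback)

        def count_tokens(self, text: str) -> int:
            """Count tokens in a single string."""
            return len(self.encoding.encode(text))

        def count_turn_tokens(
            self, system_prompt: str, query: str, response: str = ""
        ) -> dict[str, int]:
            """Estimate input/output token counts for one turn."""
            return {
                "input_tokens": self.count_tokens(system_prompt) + self.count_tokens(query),
                "output_tokens": self.count_tokens(response) if response else 0,
            }


    def get_token_counter(model_id: str) -> TokenCounter:
        """Return a token counter suited to the given model id."""
        return TokenCounter(model_id)

This shape would match how the endpoints call it: count_turn_tokens(system_prompt, query, response) in query.py, and count_turn_tokens(system_prompt, query) plus count_tokens(complete_response) in streaming_query.py.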
59 changes: 51 additions & 8 deletions src/app/endpoints/streaming_query.py
@@ -24,6 +24,7 @@
from utils.common import retrieve_user_id
from utils.mcp_headers import mcp_headers_dependency
from utils.suid import get_suid
from utils.token_counter import get_token_counter
from utils.types import GraniteToolParser

from app.endpoints.query import (
@@ -95,8 +96,13 @@ def stream_start_event(conversation_id: str) -> str:
)


def stream_end_event(metadata_map: dict) -> str:
"""Yield the end of the data stream."""
def stream_end_event(metadata_map: dict, metrics_map: dict[str, int]) -> str:
"""Yield the end of the data stream.

Args:
metadata_map: Dictionary containing metadata about referenced documents
metrics_map: Dictionary containing metrics like 'input_tokens' and 'output_tokens'
"""
return format_stream_data(
{
"event": "end",
@@ -112,8 +118,8 @@ def stream_end_event(metadata_map: dict) -> str:
)
],
"truncated": None, # TODO(jboos): implement truncated
"input_tokens": 0, # TODO(jboos): implement input tokens
"output_tokens": 0, # TODO(jboos): implement output tokens
"input_tokens": metrics_map.get("input_tokens", 0),
"output_tokens": metrics_map.get("output_tokens", 0),
},
"available_quotas": {}, # TODO(jboos): implement available quotas
}
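Assuming format_stream_data emits standard SSE data: frames, the end event would then carry the estimated counts instead of the previous hard-coded zeros; an illustrative frame (values made up):

    data: {"event": "end", "data": {"referenced_documents": [], "truncated": null, "input_tokens": 123, "output_tokens": 456}, "available_quotas": {}}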
@@ -200,7 +206,7 @@ async def streaming_query_endpoint_handler(
# try to get Llama Stack client
client = AsyncLlamaStackClientHolder().get_client()
model_id = select_model_id(await client.models.list(), query_request)
response, conversation_id = await retrieve_response(
response, conversation_id, token_usage = await retrieve_response(
client,
model_id,
query_request,
@@ -225,7 +231,25 @@ async def response_generator(turn_response: Any) -> AsyncIterator[str]:
chunk_id += 1
yield event

yield stream_end_event(metadata_map)
# Currently (2025-07-08) the usage is not returned by the API, so we need to estimate
# try:
# output_tokens = response.usage.get("completion_tokens", 0)
# except AttributeError:
# Estimate output tokens from complete response
try:
output_tokens = get_token_counter(model_id).count_tokens(
complete_response
)
logger.debug("Estimated output tokens: %s", output_tokens)
except Exception as e: # pylint: disable=broad-exception-caught
logger.warning("Failed to estimate output tokens: %s", e)
Contributor: ditto: s/warning/error/g
output_tokens = 0

metrics_map = {
"input_tokens": token_usage["input_tokens"],
"output_tokens": output_tokens,
}
yield stream_end_event(metadata_map, metrics_map)

if not is_transcripts_enabled():
logger.debug("Transcript collection is disabled in the configuration")
@@ -262,7 +286,7 @@ async def retrieve_response(
query_request: QueryRequest,
token: str,
mcp_headers: dict[str, dict[str, str]] | None = None,
) -> tuple[Any, str]:
) -> tuple[Any, str, dict[str, int]]:
"""Retrieve response from LLMs and agents."""
available_shields = [shield.identifier for shield in await client.shields.list()]
if not available_shields:
@@ -319,4 +343,23 @@
toolgroups=toolgroups or None,
)

return response, conversation_id
# Currently (2025-07-08) the usage is not returned by the API, so we need to estimate it
# try:
# token_usage = {
# "input_tokens": response.usage.get("prompt_tokens", 0),
# "output_tokens": 0, # Will be calculated during streaming
# }
# except AttributeError:
# # Estimate input tokens (Output will be calculated during streaming)
try:
token_usage = get_token_counter(model_id).count_turn_tokens(
system_prompt, query_request.query
)
except Exception as e: # pylint: disable=broad-exception-caught
logger.warning("Failed to estimate token usage: %s", e)
token_usage = {
"input_tokens": 0,
"output_tokens": 0,
}

return response, conversation_id, token_usage
1 change: 1 addition & 0 deletions src/models/config.py
@@ -154,6 +154,7 @@ class Configuration(BaseModel):
AuthenticationConfiguration()
)
customization: Optional[Customization] = None
default_estimation_tokenizer: str = "cl100k_base"
Contributor: Ditto: the constant might be stored in constants.py with a comment explaining what it means.


def dump(self, filename: str = "configuration.json") -> None:
"""Dump actual configuration into JSON file."""
6 changes: 4 additions & 2 deletions src/models/responses.py
@@ -16,8 +16,6 @@ class ModelsResponse(BaseModel):
# - referenced_documents: The optional URLs and titles for the documents used
# to generate the response.
# - truncated: Set to True if conversation history was truncated to be within context window.
# - input_tokens: Number of tokens sent to LLM
# - output_tokens: Number of tokens received from LLM
# - available_quotas: Quota available as measured by all configured quota limiters
# - tool_calls: List of tool requests.
# - tool_results: List of tool results.
@@ -28,10 +26,14 @@ class QueryResponse(BaseModel):
Attributes:
conversation_id: The optional conversation ID (UUID).
response: The response.
input_tokens: Number of tokens sent to LLM.
output_tokens: Number of tokens received from LLM.
"""

conversation_id: Optional[str] = None
response: str
input_tokens: Optional[int] = None
Contributor: AFAIK it is not optional - you set 0 for any exceptions, so it's still int?
output_tokens: Optional[int] = None

# provides examples for /docs endpoint
model_config = {
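If the reviewer is right that these fields are always set (the handlers fall back to 0 when estimation fails), they could be plain ints rather than Optional; a sketch of that alternative, not what the PR currently does:

    input_tokens: int = 0   # 0 when estimation fails
    output_tokens: int = 0  # 0 when estimation fails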