Skip to content
Merged
96 changes: 35 additions & 61 deletions vllm/entrypoints/openai/serving_chat.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,6 @@
import regex as re
from fastapi import Request
from openai_harmony import Message as OpenAIMessage
from pydantic import TypeAdapter

from vllm.engine.protocol import EngineClient
from vllm.entrypoints.chat_utils import (
Expand Down Expand Up @@ -47,8 +46,6 @@
DeltaMessage,
DeltaToolCall,
ErrorResponse,
FunctionCall,
FunctionDefinition,
PromptTokenUsageInfo,
RequestResponseMetadata,
ToolCall,
Expand Down Expand Up @@ -1394,6 +1391,16 @@ async def chat_completion_full_generator(
auto_tools_called = False
# if auto tools are not enabled, and a named tool choice using
# outlines is not being used
tool_calls, content = self._parse_tool_calls_from_content(
request=request,
tokenizer=tokenizer,
content=content,
enable_auto_tools=self.enable_auto_tools,
tool_parser_cls=self.tool_parser,
)
tool_call_class = (
MistralToolCall if isinstance(tokenizer, MistralTokenizer) else ToolCall
)
if (not self.enable_auto_tools or not self.tool_parser) and (
not isinstance(request.tool_choice, ChatCompletionNamedToolChoiceParam)
and request.tool_choice != "required"
Expand All @@ -1407,63 +1414,33 @@ async def chat_completion_full_generator(
request.tool_choice
and type(request.tool_choice) is ChatCompletionNamedToolChoiceParam
):
tool_call_class = (
MistralToolCall
if isinstance(tokenizer, MistralTokenizer)
else ToolCall
)
assert tool_calls is not None and len(tool_calls) > 0
message = ChatMessage(
role=role,
reasoning_content=reasoning_content,
content="",
tool_calls=[
tool_call_class(
function=FunctionCall(
name=request.tool_choice.function.name,
arguments=content,
)
)
],
tool_calls=[tool_call_class(function=tc) for tc in tool_calls],
)

elif request.tool_choice and request.tool_choice == "required":
tool_call_class = (
MistralToolCall
if isinstance(tokenizer, MistralTokenizer)
else ToolCall
)

# the fields of FunctionDefinition are a superset of the
# tool call outputs and can be used for parsing
assert content is not None
tool_calls = TypeAdapter(list[FunctionDefinition]).validate_json(
content
)
tool_call_ids = []
tool_call_class_items = []
assert tool_calls is not None and len(tool_calls) > 0
for tool_call in tool_calls:
tool_call_ids.append(
make_tool_call_id(
id_type=self.tool_call_id_type,
func_name=tool_call.name,
idx=history_tool_call_cnt,
tool_call_class_items.append(
tool_call_class(
id=make_tool_call_id(
id_type=self.tool_call_id_type,
func_name=tool_call.name,
idx=history_tool_call_cnt,
),
function=tool_call,
)
)
history_tool_call_cnt += 1
message = ChatMessage(
role=role,
content="",
tool_calls=[
tool_call_class(
id=tool_call_ids[i],
function=FunctionCall(
name=tool_call.name,
arguments=json.dumps(
tool_call.parameters, ensure_ascii=False
),
),
)
for i, tool_call in enumerate(tool_calls)
],
tool_calls=tool_call_class_items,
reasoning_content=reasoning_content,
)

Expand All @@ -1481,25 +1458,22 @@ async def chat_completion_full_generator(
and self.enable_auto_tools
and self.tool_parser
):
try:
tool_parser = self.tool_parser(tokenizer)
except RuntimeError as e:
logger.exception("Error in tool parser creation.")
return self.create_error_response(str(e))

tool_call_info = tool_parser.extract_tool_calls(
content if content is not None else "", request=request
)
# In the OpenAI API the finish_reason is "tools_called"
# if the tool choice is auto and the model produced a tool
# call. The same is not true for named function calls
auto_tools_called = tool_call_info.tools_called
if tool_call_info.tools_called:
auto_tools_called = tool_calls is not None and len(tool_calls) > 0
if tool_calls:
message = ChatMessage(
role=role,
reasoning_content=reasoning_content,
content=tool_call_info.content,
tool_calls=tool_call_info.tool_calls,
content=content,
tool_calls=[
ToolCall(
function=tc,
type="function",
)
for tc in tool_calls
],
)

else:
Expand All @@ -1509,8 +1483,8 @@ async def chat_completion_full_generator(

# try to use content return from tool parser first,
# tool parser may do some modify for the content.
if tool_call_info.content and len(tool_call_info.content) > 0:
ret_content = tool_call_info.content
if content and len(content) > 0:
ret_content = content
message = ChatMessage(
role=role,
reasoning_content=reasoning_content,
Expand Down
78 changes: 77 additions & 1 deletion vllm/entrypoints/openai/serving_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@

import torch
from fastapi import Request
from pydantic import BaseModel, ConfigDict, Field
from pydantic import BaseModel, ConfigDict, Field, TypeAdapter
from starlette.datastructures import Headers
from typing_extensions import TypeIs

Expand All @@ -21,6 +21,10 @@
else:
from typing_extensions import TypedDict

from openai.types.responses import (
ToolChoiceFunction,
)

import vllm.envs as envs
from vllm.beam_search import BeamSearchSequence, create_sort_beams_key_function
from vllm.engine.protocol import EngineClient
Expand All @@ -36,6 +40,7 @@
from vllm.entrypoints.context import ConversationContext
from vllm.entrypoints.logger import RequestLogger
from vllm.entrypoints.openai.protocol import (
ChatCompletionNamedToolChoiceParam,
ChatCompletionRequest,
ChatCompletionResponse,
ClassificationRequest,
Expand All @@ -49,6 +54,8 @@
EmbeddingResponse,
ErrorInfo,
ErrorResponse,
FunctionCall,
FunctionDefinition,
IOProcessorRequest,
PoolingResponse,
RerankRequest,
Expand Down Expand Up @@ -1305,6 +1312,75 @@ def _get_data_parallel_rank(raw_request: Request | None) -> int | None:
except ValueError:
return None

@staticmethod
def _parse_tool_calls_from_content(
    request: ResponsesRequest | ChatCompletionRequest,
    tokenizer: AnyTokenizer,
    enable_auto_tools: bool,
    tool_parser_cls: Callable[[AnyTokenizer], ToolParser] | None,
    content: str | None = None,
) -> tuple[list[FunctionCall] | None, str | None]:
    """Extract tool calls from model-generated ``content``.

    Four tool-choice modes are handled:

    1. ``ToolChoiceFunction`` (Responses API forced call): the entire
       ``content`` is treated as the arguments of the named function.
    2. ``ChatCompletionNamedToolChoiceParam`` (Chat API forced call):
       same, but the name lives under ``tool_choice.function``.
    3. ``tool_choice == "required"``: ``content`` is a JSON array of
       function definitions; each entry becomes one call.
    4. Auto mode (``"auto"`` or ``None``) with a tool parser enabled:
       the parser decides whether/which tool calls are present.

    Args:
        request: The originating request (Responses or Chat Completions).
        tokenizer: Tokenizer used to instantiate the tool parser.
        enable_auto_tools: Whether automatic tool parsing is enabled.
        tool_parser_cls: Factory producing a ``ToolParser``, or ``None``.
        content: Raw generated text, or ``None``.

    Returns:
        ``(function_calls, content)`` where ``function_calls`` is ``None``
        when no tool call was produced. ``content`` is set to ``None``
        when the whole output was consumed as tool-call arguments; in
        auto mode it is whatever remainder the parser reports.

    Raises:
        RuntimeError: If the tool parser cannot be constructed.
    """
    function_calls: list[FunctionCall] = []
    if isinstance(request.tool_choice, ToolChoiceFunction):
        # Forced function call (Responses API): whole output is the args.
        assert content is not None
        function_calls.append(
            FunctionCall(name=request.tool_choice.name, arguments=content)
        )
        content = None  # Clear content since tool is called.
    elif isinstance(request.tool_choice, ChatCompletionNamedToolChoiceParam):
        # Forced function call (Chat Completions API).
        assert content is not None
        function_calls.append(
            FunctionCall(name=request.tool_choice.function.name, arguments=content)
        )
        content = None  # Clear content since tool is called.
    elif request.tool_choice == "required":
        assert content is not None
        # The fields of FunctionDefinition are a superset of the tool
        # call outputs, so it can be reused for parsing the JSON array.
        tool_calls = TypeAdapter(list[FunctionDefinition]).validate_json(content)
        function_calls.extend(
            FunctionCall(
                name=tool_call.name,
                arguments=json.dumps(tool_call.parameters, ensure_ascii=False),
            )
            for tool_call in tool_calls
        )
        content = None  # Clear content since tool is called.
    elif (
        tool_parser_cls
        and enable_auto_tools
        and (request.tool_choice == "auto" or request.tool_choice is None)
    ):
        # Automatic tool call parsing.
        try:
            tool_parser = tool_parser_cls(tokenizer)
        except RuntimeError:
            logger.exception("Error in tool parser creation.")
            # Bare raise preserves the original traceback for callers.
            raise
        tool_call_info = tool_parser.extract_tool_calls(
            content if content is not None else "",
            request=request,  # type: ignore
        )
        if tool_call_info is not None and tool_call_info.tools_called:
            # extract_tool_calls() returns a list of tool calls.
            function_calls.extend(
                FunctionCall(
                    name=tool_call.function.name,
                    arguments=tool_call.function.arguments,
                )
                for tool_call in tool_call_info.tool_calls
            )
            # The parser may strip/modify content around the calls.
            content = tool_call_info.content
    else:
        # No tool-choice mode applies: no tool calls.
        return None, content

    return function_calls, content

@staticmethod
def _get_decoded_token(
logprob: Logprob,
Expand Down