Skip to content

Commit 0976711

Browse files
[Refactor] Simplify and extract the logic shared between chat completion and responses (#27961)
Signed-off-by: chaunceyjiang <[email protected]>
1 parent e261d37 commit 0976711

File tree

2 files changed

+112
-62
lines changed

2 files changed

+112
-62
lines changed

vllm/entrypoints/openai/serving_chat.py

Lines changed: 35 additions & 61 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,6 @@
1313
import regex as re
1414
from fastapi import Request
1515
from openai_harmony import Message as OpenAIMessage
16-
from pydantic import TypeAdapter
1716

1817
from vllm.engine.protocol import EngineClient
1918
from vllm.entrypoints.chat_utils import (
@@ -47,8 +46,6 @@
4746
DeltaMessage,
4847
DeltaToolCall,
4948
ErrorResponse,
50-
FunctionCall,
51-
FunctionDefinition,
5249
PromptTokenUsageInfo,
5350
RequestResponseMetadata,
5451
ToolCall,
@@ -1394,6 +1391,16 @@ async def chat_completion_full_generator(
13941391
auto_tools_called = False
13951392
# if auto tools are not enabled, and a named tool choice using
13961393
# outlines is not being used
1394+
tool_calls, content = self._parse_tool_calls_from_content(
1395+
request=request,
1396+
tokenizer=tokenizer,
1397+
content=content,
1398+
enable_auto_tools=self.enable_auto_tools,
1399+
tool_parser_cls=self.tool_parser,
1400+
)
1401+
tool_call_class = (
1402+
MistralToolCall if isinstance(tokenizer, MistralTokenizer) else ToolCall
1403+
)
13971404
if (not self.enable_auto_tools or not self.tool_parser) and (
13981405
not isinstance(request.tool_choice, ChatCompletionNamedToolChoiceParam)
13991406
and request.tool_choice != "required"
@@ -1407,63 +1414,33 @@ async def chat_completion_full_generator(
14071414
request.tool_choice
14081415
and type(request.tool_choice) is ChatCompletionNamedToolChoiceParam
14091416
):
1410-
tool_call_class = (
1411-
MistralToolCall
1412-
if isinstance(tokenizer, MistralTokenizer)
1413-
else ToolCall
1414-
)
1417+
assert tool_calls is not None and len(tool_calls) > 0
14151418
message = ChatMessage(
14161419
role=role,
14171420
reasoning_content=reasoning_content,
14181421
content="",
1419-
tool_calls=[
1420-
tool_call_class(
1421-
function=FunctionCall(
1422-
name=request.tool_choice.function.name,
1423-
arguments=content,
1424-
)
1425-
)
1426-
],
1422+
tool_calls=[tool_call_class(function=tc) for tc in tool_calls],
14271423
)
14281424

14291425
elif request.tool_choice and request.tool_choice == "required":
1430-
tool_call_class = (
1431-
MistralToolCall
1432-
if isinstance(tokenizer, MistralTokenizer)
1433-
else ToolCall
1434-
)
1435-
1436-
# the fields of FunctionDefinition are a superset of the
1437-
# tool call outputs and can be used for parsing
1438-
assert content is not None
1439-
tool_calls = TypeAdapter(list[FunctionDefinition]).validate_json(
1440-
content
1441-
)
1442-
tool_call_ids = []
1426+
tool_call_class_items = []
1427+
assert tool_calls is not None and len(tool_calls) > 0
14431428
for tool_call in tool_calls:
1444-
tool_call_ids.append(
1445-
make_tool_call_id(
1446-
id_type=self.tool_call_id_type,
1447-
func_name=tool_call.name,
1448-
idx=history_tool_call_cnt,
1429+
tool_call_class_items.append(
1430+
tool_call_class(
1431+
id=make_tool_call_id(
1432+
id_type=self.tool_call_id_type,
1433+
func_name=tool_call.name,
1434+
idx=history_tool_call_cnt,
1435+
),
1436+
function=tool_call,
14491437
)
14501438
)
14511439
history_tool_call_cnt += 1
14521440
message = ChatMessage(
14531441
role=role,
14541442
content="",
1455-
tool_calls=[
1456-
tool_call_class(
1457-
id=tool_call_ids[i],
1458-
function=FunctionCall(
1459-
name=tool_call.name,
1460-
arguments=json.dumps(
1461-
tool_call.parameters, ensure_ascii=False
1462-
),
1463-
),
1464-
)
1465-
for i, tool_call in enumerate(tool_calls)
1466-
],
1443+
tool_calls=tool_call_class_items,
14671444
reasoning_content=reasoning_content,
14681445
)
14691446

@@ -1481,25 +1458,22 @@ async def chat_completion_full_generator(
14811458
and self.enable_auto_tools
14821459
and self.tool_parser
14831460
):
1484-
try:
1485-
tool_parser = self.tool_parser(tokenizer)
1486-
except RuntimeError as e:
1487-
logger.exception("Error in tool parser creation.")
1488-
return self.create_error_response(str(e))
1489-
1490-
tool_call_info = tool_parser.extract_tool_calls(
1491-
content if content is not None else "", request=request
1492-
)
14931461
# In the OpenAI API the finish_reason is "tools_called"
14941462
# if the tool choice is auto and the model produced a tool
14951463
# call. The same is not true for named function calls
1496-
auto_tools_called = tool_call_info.tools_called
1497-
if tool_call_info.tools_called:
1464+
auto_tools_called = tool_calls is not None and len(tool_calls) > 0
1465+
if tool_calls:
14981466
message = ChatMessage(
14991467
role=role,
15001468
reasoning_content=reasoning_content,
1501-
content=tool_call_info.content,
1502-
tool_calls=tool_call_info.tool_calls,
1469+
content=content,
1470+
tool_calls=[
1471+
ToolCall(
1472+
function=tc,
1473+
type="function",
1474+
)
1475+
for tc in tool_calls
1476+
],
15031477
)
15041478

15051479
else:
@@ -1509,8 +1483,8 @@ async def chat_completion_full_generator(
15091483

15101484
# try to use content return from tool parser first,
15111485
# tool parser may do some modify for the content.
1512-
if tool_call_info.content and len(tool_call_info.content) > 0:
1513-
ret_content = tool_call_info.content
1486+
if content and len(content) > 0:
1487+
ret_content = content
15141488
message = ChatMessage(
15151489
role=role,
15161490
reasoning_content=reasoning_content,

vllm/entrypoints/openai/serving_engine.py

Lines changed: 77 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@
1212

1313
import torch
1414
from fastapi import Request
15-
from pydantic import BaseModel, ConfigDict, Field
15+
from pydantic import BaseModel, ConfigDict, Field, TypeAdapter
1616
from starlette.datastructures import Headers
1717
from typing_extensions import TypeIs
1818

@@ -21,6 +21,10 @@
2121
else:
2222
from typing_extensions import TypedDict
2323

24+
from openai.types.responses import (
25+
ToolChoiceFunction,
26+
)
27+
2428
import vllm.envs as envs
2529
from vllm.beam_search import BeamSearchSequence, create_sort_beams_key_function
2630
from vllm.engine.protocol import EngineClient
@@ -36,6 +40,7 @@
3640
from vllm.entrypoints.context import ConversationContext
3741
from vllm.entrypoints.logger import RequestLogger
3842
from vllm.entrypoints.openai.protocol import (
43+
ChatCompletionNamedToolChoiceParam,
3944
ChatCompletionRequest,
4045
ChatCompletionResponse,
4146
ClassificationRequest,
@@ -49,6 +54,8 @@
4954
EmbeddingResponse,
5055
ErrorInfo,
5156
ErrorResponse,
57+
FunctionCall,
58+
FunctionDefinition,
5259
IOProcessorRequest,
5360
PoolingResponse,
5461
RerankRequest,
@@ -1305,6 +1312,75 @@ def _get_data_parallel_rank(raw_request: Request | None) -> int | None:
13051312
except ValueError:
13061313
return None
13071314

1315+
@staticmethod
def _parse_tool_calls_from_content(
    request: ResponsesRequest | ChatCompletionRequest,
    tokenizer: AnyTokenizer,
    enable_auto_tools: bool,
    tool_parser_cls: Callable[[AnyTokenizer], ToolParser] | None,
    content: str | None = None,
) -> tuple[list[FunctionCall] | None, str | None]:
    """Shared tool-call extraction for chat completion and responses.

    Interprets the generated ``content`` according to ``request.tool_choice``:

    * a forced/named tool choice (``ToolChoiceFunction`` or
      ``ChatCompletionNamedToolChoiceParam``) treats the whole content as
      the arguments of that single tool call;
    * ``"required"`` parses the content as a JSON list of function
      definitions (the fields of ``FunctionDefinition`` are a superset of
      the tool-call outputs, so it doubles as the parse schema);
    * ``"auto"`` / ``None`` delegates to the configured tool parser, when
      auto tools are enabled.

    Args:
        request: The originating request; only ``tool_choice`` is consulted.
        tokenizer: Tokenizer passed to the tool parser on the auto path.
        enable_auto_tools: Whether automatic tool-call parsing is enabled.
        tool_parser_cls: Factory that builds a ``ToolParser`` from a
            tokenizer; may be ``None`` when no parser is configured.
        content: Raw generated text to parse; may be ``None``.

    Returns:
        ``(function_calls, content)``. ``function_calls`` is ``None`` when
        no tool-choice mode matched (caller should treat the output as
        plain content). ``content`` is ``None`` when it was fully consumed
        by a forced/required tool call, or whatever text the tool parser
        left over on the auto path.

    Raises:
        RuntimeError: If the tool parser cannot be constructed.
        pydantic.ValidationError: If ``tool_choice == "required"`` and the
            content is not a valid JSON list of function definitions.
    """
    function_calls: list[FunctionCall] = []

    # Forced function call: the entire content is the arguments of one
    # named tool. The two request flavors differ only in where the tool
    # name lives, so resolve the name first and share the rest.
    forced_name: str | None = None
    if isinstance(request.tool_choice, ToolChoiceFunction):
        forced_name = request.tool_choice.name
    elif isinstance(request.tool_choice, ChatCompletionNamedToolChoiceParam):
        forced_name = request.tool_choice.function.name

    if forced_name is not None:
        assert content is not None
        function_calls.append(FunctionCall(name=forced_name, arguments=content))
        content = None  # Clear content since tool is called.
    elif request.tool_choice == "required":
        assert content is not None
        tool_calls = TypeAdapter(list[FunctionDefinition]).validate_json(content)
        function_calls.extend(
            FunctionCall(
                name=tool_call.name,
                arguments=json.dumps(tool_call.parameters, ensure_ascii=False),
            )
            for tool_call in tool_calls
        )
        content = None  # Clear content since tool is called.
    elif (
        tool_parser_cls
        and enable_auto_tools
        and (request.tool_choice == "auto" or request.tool_choice is None)
    ):
        # Automatic tool-call parsing via the configured parser.
        try:
            tool_parser = tool_parser_cls(tokenizer)
        except RuntimeError:
            logger.exception("Error in tool parser creation.")
            raise  # Bare raise preserves the original traceback (TRY201).
        tool_call_info = tool_parser.extract_tool_calls(
            content if content is not None else "",
            request=request,  # type: ignore
        )
        if tool_call_info is not None and tool_call_info.tools_called:
            function_calls.extend(
                FunctionCall(
                    name=tool_call.function.name,
                    arguments=tool_call.function.arguments,
                )
                for tool_call in tool_call_info.tool_calls
            )
            # The parser may strip or rewrite the surrounding text.
            content = tool_call_info.content
    else:
        # No tool-choice mode matched: report no tool calls.
        return None, content

    return function_calls, content
1383+
13081384
@staticmethod
13091385
def _get_decoded_token(
13101386
logprob: Logprob,

0 commit comments

Comments
 (0)