Skip to content

Commit df850c4

Browse files
[Feature][Responses API] Stream Function Call - harmony (#24317)
Signed-off-by: chaunceyjiang <[email protected]>
1 parent 720394d commit df850c4

File tree

2 files changed

+213
-70
lines changed

2 files changed

+213
-70
lines changed

tests/entrypoints/openai/test_response_api_with_harmony.py

Lines changed: 136 additions & 66 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,22 @@
1616

1717
MODEL_NAME = "openai/gpt-oss-20b"
1818

19+
# Shared OpenAI Responses API function-tool definition used by the
# function-calling tests in this file: "get_weather" takes latitude and
# longitude (numbers) and is described as returning the current
# temperature in celsius. Key order is preserved as sent to the API.
GET_WEATHER_SCHEMA = {
    "type": "function",
    "name": "get_weather",
    "description": "Get current temperature for provided coordinates in celsius.",  # noqa
    "parameters": {
        "type": "object",
        "properties": {
            "latitude": {"type": "number"},
            "longitude": {"type": "number"},
        },
        "required": ["latitude", "longitude"],
        "additionalProperties": False,
    },
    "strict": True,
}
34+
1935

2036
@pytest.fixture(scope="module")
2137
def server():
@@ -305,6 +321,54 @@ async def test_streaming_types(client: OpenAI, model_name: str):
305321
assert len(stack_of_event_types) == 0
306322

307323

324+
@pytest.mark.asyncio
@pytest.mark.parametrize("model_name", [MODEL_NAME])
async def test_function_calling_with_streaming_types(client: OpenAI, model_name: str):
    """Stream a function call and verify the event nesting is well-formed.

    Each closing event ("done"/"completed") must match the most recently
    opened event of its pair ("added"/"delta"/"created"), and nothing may
    remain open once the stream finishes.
    """
    # Maps every closing event type to the opening event type it must close.
    closers_to_openers = {
        "response.completed": "response.created",
        "response.output_item.done": "response.output_item.added",
        "response.output_text.done": "response.output_text.delta",
        "response.reasoning_text.done": "response.reasoning_text.delta",
        "response.reasoning_part.done": "response.reasoning_part.added",
        "response.function_call_arguments.done": "response.function_call_arguments.delta",  # noqa
    }

    stream_response = await client.responses.create(
        model=model_name,
        input=[
            {
                "role": "user",
                "content": "What's the weather like in Paris today?",
            }
        ],
        tools=[GET_WEATHER_SCHEMA],
        stream=True,
    )

    open_events = []
    async for event in stream_response:
        kind = event.type
        if kind == "response.created":
            open_events.append(kind)
        elif kind == "response.completed":
            assert open_events[-1] == closers_to_openers[kind]
            open_events.pop()
        if kind.endswith("added"):
            open_events.append(kind)
        elif kind.endswith("delta"):
            # Consecutive deltas of the same type collapse onto one entry.
            if open_events[-1] != kind:
                open_events.append(kind)
        elif kind.endswith("done"):
            assert open_events[-1] == closers_to_openers[kind]
            open_events.pop()
    # Every opened block must have been closed by the end of the stream.
    assert not open_events
370+
371+
308372
@pytest.mark.asyncio
309373
@pytest.mark.parametrize("model_name", [MODEL_NAME])
310374
@pytest.mark.parametrize("background", [True, False])
@@ -483,23 +547,7 @@ def call_function(name, args):
483547
@pytest.mark.asyncio
484548
@pytest.mark.parametrize("model_name", [MODEL_NAME])
485549
async def test_function_calling(client: OpenAI, model_name: str):
486-
tools = [
487-
{
488-
"type": "function",
489-
"name": "get_weather",
490-
"description": "Get current temperature for provided coordinates in celsius.", # noqa
491-
"parameters": {
492-
"type": "object",
493-
"properties": {
494-
"latitude": {"type": "number"},
495-
"longitude": {"type": "number"},
496-
},
497-
"required": ["latitude", "longitude"],
498-
"additionalProperties": False,
499-
},
500-
"strict": True,
501-
}
502-
]
550+
tools = [GET_WEATHER_SCHEMA]
503551

504552
response = await client.responses.create(
505553
model=model_name,
@@ -565,21 +613,7 @@ async def test_function_calling_multi_turn(client: OpenAI, model_name: str):
565613
},
566614
"strict": True,
567615
},
568-
{
569-
"type": "function",
570-
"name": "get_weather",
571-
"description": "Get current temperature for provided coordinates in celsius.", # noqa
572-
"parameters": {
573-
"type": "object",
574-
"properties": {
575-
"latitude": {"type": "number"},
576-
"longitude": {"type": "number"},
577-
},
578-
"required": ["latitude", "longitude"],
579-
"additionalProperties": False,
580-
},
581-
"strict": True,
582-
},
616+
GET_WEATHER_SCHEMA,
583617
]
584618

585619
response = await client.responses.create(
@@ -643,23 +677,7 @@ async def test_function_calling_multi_turn(client: OpenAI, model_name: str):
643677
@pytest.mark.asyncio
644678
@pytest.mark.parametrize("model_name", [MODEL_NAME])
645679
async def test_function_calling_required(client: OpenAI, model_name: str):
646-
tools = [
647-
{
648-
"type": "function",
649-
"name": "get_weather",
650-
"description": "Get current temperature for provided coordinates in celsius.", # noqa
651-
"parameters": {
652-
"type": "object",
653-
"properties": {
654-
"latitude": {"type": "number"},
655-
"longitude": {"type": "number"},
656-
},
657-
"required": ["latitude", "longitude"],
658-
"additionalProperties": False,
659-
},
660-
"strict": True,
661-
}
662-
]
680+
tools = [GET_WEATHER_SCHEMA]
663681

664682
with pytest.raises(BadRequestError):
665683
await client.responses.create(
@@ -689,23 +707,7 @@ async def test_system_message_with_tools(client: OpenAI, model_name: str):
689707
@pytest.mark.asyncio
690708
@pytest.mark.parametrize("model_name", [MODEL_NAME])
691709
async def test_function_calling_full_history(client: OpenAI, model_name: str):
692-
tools = [
693-
{
694-
"type": "function",
695-
"name": "get_weather",
696-
"description": "Get current temperature for provided coordinates in celsius.", # noqa
697-
"parameters": {
698-
"type": "object",
699-
"properties": {
700-
"latitude": {"type": "number"},
701-
"longitude": {"type": "number"},
702-
},
703-
"required": ["latitude", "longitude"],
704-
"additionalProperties": False,
705-
},
706-
"strict": True,
707-
}
708-
]
710+
tools = [GET_WEATHER_SCHEMA]
709711

710712
input_messages = [
711713
{"role": "user", "content": "What's the weather like in Paris today?"}
@@ -745,6 +747,74 @@ async def test_function_calling_full_history(client: OpenAI, model_name: str):
745747
assert response_2.output_text is not None
746748

747749

750+
@pytest.mark.asyncio
@pytest.mark.parametrize("model_name", [MODEL_NAME])
async def test_function_calling_with_stream(client: OpenAI, model_name: str):
    """Exercise a full streamed function-calling round trip.

    Turn 1 streams a ``get_weather`` tool call, accumulating its arguments
    from delta events and checking they match the final "done" event.
    Turn 2 feeds the tool output back and verifies the follow-up stream
    yields output text and no further function-call events.
    """
    tools = [GET_WEATHER_SCHEMA]
    input_list = [
        {
            "role": "user",
            "content": "What's the weather like in Paris today?",
        }
    ]
    stream_response = await client.responses.create(
        model=model_name,
        input=input_list,
        tools=tools,
        stream=True,
    )
    assert stream_response is not None
    final_tool_calls = {}  # output_index -> in-progress function_call item
    final_tool_calls_named = {}  # function name -> in-progress item
    async for event in stream_response:
        if event.type == "response.output_item.added":
            if event.item.type != "function_call":
                continue
            final_tool_calls[event.output_index] = event.item
            final_tool_calls_named[event.item.name] = event.item
        elif event.type == "response.function_call_arguments.delta":
            index = event.output_index
            tool_call = final_tool_calls[index]
            if tool_call:
                tool_call.arguments += event.delta
                final_tool_calls_named[tool_call.name] = tool_call
        elif event.type == "response.function_call_arguments.done":
            # The accumulated deltas must equal the final arguments.
            assert event.arguments == final_tool_calls_named[event.name].arguments

    # Initialize both so a model that never called get_weather fails the
    # assertions below cleanly instead of raising NameError.
    tool_call = None
    result = None
    for candidate in final_tool_calls.values():
        if (
            candidate
            and candidate.type == "function_call"
            and candidate.name == "get_weather"
        ):
            tool_call = candidate
            args = json.loads(tool_call.arguments)
            result = call_function(tool_call.name, args)
            input_list += [tool_call]
            break
    assert tool_call is not None
    assert result is not None
    response = await client.responses.create(
        model=model_name,
        input=input_list
        + [
            {
                "type": "function_call_output",
                "call_id": tool_call.call_id,
                "output": str(result),
            }
        ],
        tools=tools,
        stream=True,
    )
    assert response is not None
    async for event in response:
        # Check that no function-call events appear in the stream.
        assert event.type != "response.function_call_arguments.delta"
        assert event.type != "response.function_call_arguments.done"
        # Check that the completed response contains output text.
        if event.type == "response.completed":
            assert len(event.response.output) > 0
            assert event.response.output_text is not None
816+
817+
748818
@pytest.mark.asyncio
749819
@pytest.mark.parametrize("model_name", [MODEL_NAME])
750820
async def test_output_messages_enabled(client: OpenAI, model_name: str, server):

vllm/entrypoints/openai/serving_responses.py

Lines changed: 77 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,8 @@
2323
ResponseCodeInterpreterToolCallParam,
2424
ResponseContentPartAddedEvent,
2525
ResponseContentPartDoneEvent,
26+
ResponseFunctionCallArgumentsDeltaEvent,
27+
ResponseFunctionCallArgumentsDoneEvent,
2628
ResponseFunctionToolCall,
2729
ResponseFunctionWebSearch,
2830
ResponseOutputItem,
@@ -927,6 +929,11 @@ def _construct_input_messages_with_harmony(
927929
# to add the tool call request to prev_outputs so that the
928930
# parse_response_input can find the tool call request when
929931
# parsing the tool call output.
932+
if (
933+
isinstance(response_msg, dict)
934+
and response_msg.get("type") == "function_call"
935+
):
936+
response_msg = ResponseFunctionToolCall.model_validate(response_msg)
930937
if isinstance(response_msg, ResponseFunctionToolCall):
931938
prev_outputs.append(response_msg)
932939
return messages
@@ -1398,19 +1405,48 @@ async def _process_harmony_streaming_events(
13981405
current_output_index = 0
13991406
current_item_id: str = ""
14001407
sent_output_item_added = False
1401-
1408+
is_first_function_call_delta = False
14021409
async for ctx in result_generator:
14031410
assert isinstance(ctx, StreamingHarmonyContext)
14041411

14051412
if ctx.is_expecting_start():
14061413
current_output_index += 1
14071414
sent_output_item_added = False
1408-
1415+
is_first_function_call_delta = False
14091416
if len(ctx.parser.messages) > 0:
14101417
previous_item = ctx.parser.messages[-1]
14111418
if previous_item.recipient is not None:
1412-
# Deal with tool call here
1413-
pass
1419+
# Deal with tool call
1420+
if previous_item.recipient.startswith("functions."):
1421+
function_name = previous_item.recipient[len("functions.") :]
1422+
yield _increment_sequence_number_and_return(
1423+
ResponseFunctionCallArgumentsDoneEvent(
1424+
type="response.function_call_arguments.done",
1425+
arguments=previous_item.content[0].text,
1426+
name=function_name,
1427+
item_id=current_item_id,
1428+
output_index=current_output_index,
1429+
sequence_number=-1,
1430+
)
1431+
)
1432+
function_call_item = ResponseFunctionToolCall(
1433+
type="function_call",
1434+
arguments=previous_item.content[0].text,
1435+
name=function_name,
1436+
item_id=current_item_id,
1437+
output_index=current_output_index,
1438+
sequence_number=-1,
1439+
call_id=f"fc_{random_uuid()}",
1440+
status="completed",
1441+
)
1442+
yield _increment_sequence_number_and_return(
1443+
ResponseOutputItemDoneEvent(
1444+
type="response.output_item.done",
1445+
sequence_number=-1,
1446+
output_index=current_output_index,
1447+
item=function_call_item,
1448+
)
1449+
)
14141450
elif previous_item.channel == "analysis":
14151451
content = ResponseReasoningTextContent(
14161452
text=previous_item.content[0].text,
@@ -1766,6 +1802,43 @@ async def _process_harmony_streaming_events(
17661802
),
17671803
)
17681804
)
1805+
# developer tools will be triggered on the commentary channel
1806+
# and the recipient starts with "functions.TOOL_NAME"
1807+
if (
1808+
ctx.parser.current_channel == "commentary"
1809+
and ctx.parser.current_recipient
1810+
and ctx.parser.current_recipient.startswith("functions.")
1811+
):
1812+
if is_first_function_call_delta is False:
1813+
is_first_function_call_delta = True
1814+
fc_name = ctx.parser.current_recipient[len("functions.") :]
1815+
tool_call_item = ResponseFunctionToolCall(
1816+
name=fc_name,
1817+
type="function_call",
1818+
id=current_item_id,
1819+
call_id=f"call_{random_uuid()}",
1820+
arguments="",
1821+
status="in_progress",
1822+
)
1823+
current_item_id = f"fc_{random_uuid()}"
1824+
yield _increment_sequence_number_and_return(
1825+
ResponseOutputItemAddedEvent(
1826+
type="response.output_item.added",
1827+
sequence_number=-1,
1828+
output_index=current_output_index,
1829+
item=tool_call_item,
1830+
)
1831+
)
1832+
else:
1833+
yield _increment_sequence_number_and_return(
1834+
ResponseFunctionCallArgumentsDeltaEvent(
1835+
item_id=current_item_id,
1836+
delta=ctx.parser.last_content_delta,
1837+
output_index=current_output_index,
1838+
sequence_number=-1,
1839+
type="response.function_call_arguments.delta",
1840+
)
1841+
)
17691842

17701843
async def responses_stream_generator(
17711844
self,

0 commit comments

Comments
 (0)