diff --git a/examples/model_providers/litellm_auto.py b/examples/model_providers/litellm_auto.py
index 12b1e8914..5e6942713 100644
--- a/examples/model_providers/litellm_auto.py
+++ b/examples/model_providers/litellm_auto.py
@@ -2,7 +2,9 @@
 
 import asyncio
 
-from agents import Agent, Runner, function_tool, set_tracing_disabled
+from pydantic import BaseModel
+
+from agents import Agent, ModelSettings, Runner, function_tool, set_tracing_disabled
 
 """This example uses the built-in support for LiteLLM. To use this, ensure you have the
 ANTHROPIC_API_KEY environment variable set.
@@ -10,12 +12,18 @@
 
 set_tracing_disabled(disabled=True)
 
+# import logging
+# logging.basicConfig(level=logging.DEBUG)
 
 @function_tool
 def get_weather(city: str):
     print(f"[debug] getting weather for {city}")
     return f"The weather in {city} is sunny."
 
+class Result(BaseModel):
+    output_text: str
+    tool_results: list[str]
+
 
 async def main():
     agent = Agent(
@@ -24,6 +32,8 @@ async def main():
         # We prefix with litellm/ to tell the Runner to use the LitellmModel
         model="litellm/anthropic/claude-3-5-sonnet-20240620",
         tools=[get_weather],
+        model_settings=ModelSettings(tool_choice="required"),
+        output_type=Result,
     )
 
     result = await Runner.run(agent, "What's the weather in Tokyo?")
diff --git a/src/agents/_run_impl.py b/src/agents/_run_impl.py
index 6c417b308..56784004c 100644
--- a/src/agents/_run_impl.py
+++ b/src/agents/_run_impl.py
@@ -509,13 +509,29 @@ def process_model_response(
            # Regular function tool call
            else:
                if output.name not in function_map:
-                    _error_tracing.attach_error_to_current_span(
-                        SpanError(
-                            message="Tool not found",
-                            data={"tool_name": output.name},
+                    if output_schema is not None and output.name == "json_tool_call":
+                        # LiteLLM could generate non-existent tool calls for structured outputs
+                        items.append(ToolCallItem(raw_item=output, agent=agent))
+                        functions.append(
+                            ToolRunFunction(
+                                tool_call=output,
+                                # this tool does not exist in function_map, so build an ad-hoc
+                                # one that parses the input if it's a string and returns the
+                                # value unchanged otherwise
+                                function_tool=_build_litellm_json_tool_call(output),
+                            )
                        )
-                    )
-                    raise ModelBehaviorError(f"Tool {output.name} not found in agent {agent.name}")
+                        continue
+                    else:
+                        _error_tracing.attach_error_to_current_span(
+                            SpanError(
+                                message="Tool not found",
+                                data={"tool_name": output.name},
+                            )
+                        )
+                        error = f"Tool {output.name} not found in agent {agent.name}"
+                        raise ModelBehaviorError(error)
+
                items.append(ToolCallItem(raw_item=output, agent=agent))
                functions.append(
                    ToolRunFunction(
@@ -1193,3 +1209,21 @@ async def execute(
            # "id": "out" + call.tool_call.id,  # TODO remove this, it should be optional
        },
    )
+
+
+def _build_litellm_json_tool_call(output: ResponseFunctionToolCall) -> FunctionTool:
+    async def on_invoke_tool(_ctx: ToolContext[Any], value: Any) -> Any:
+        if isinstance(value, str):
+            import json
+
+            return json.loads(value)
+        return value
+
+    return FunctionTool(
+        name=output.name,
+        description=output.name,
+        params_json_schema={},
+        on_invoke_tool=on_invoke_tool,
+        strict_json_schema=True,
+        is_enabled=True,
+    )
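
A hedged note on the example change: with output_type=Result set on the agent, the runner's
final output should arrive as a Result instance rather than a plain string. A minimal sketch of
how the result could be consumed (the isinstance check is illustrative, not part of the patch):

    result = await Runner.run(agent, "What's the weather in Tokyo?")
    assert isinstance(result.final_output, Result)  # hypothetical sanity check
    print(result.final_output.output_text)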
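
For reference, a standalone sketch of what the ad-hoc tool built by _build_litellm_json_tool_call
does at invocation time: LiteLLM packs the structured output into the arguments of a synthetic
json_tool_call, and the handler parses that JSON string (or passes a non-string value through).
The helper name and sample payload below are hypothetical, not part of the patch:

    import json

    from pydantic import BaseModel


    class Result(BaseModel):
        output_text: str
        tool_results: list[str]


    def parse_json_tool_call_arguments(value):
        # Mirrors on_invoke_tool above: parse JSON when the arguments arrive
        # as a string; otherwise return the value unchanged.
        if isinstance(value, str):
            return json.loads(value)
        return value


    payload = '{"output_text": "The weather in Tokyo is sunny.", "tool_results": ["sunny"]}'
    print(Result.model_validate(parse_json_tool_call_arguments(payload)))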