|
1 | 1 | from __future__ import annotations |
2 | 2 |
|
3 | 3 | import asyncio |
| 4 | +import dataclasses |
4 | 5 | import inspect |
5 | 6 | from collections.abc import Awaitable |
6 | 7 | from dataclasses import dataclass |
|
51 | 52 | from .models.interface import ModelTracing |
52 | 53 | from .run_context import RunContextWrapper, TContext |
53 | 54 | from .stream_events import RunItemStreamEvent, StreamEvent |
54 | | -from .tool import ComputerTool, FunctionTool, FunctionToolResult |
| 55 | +from .tool import ComputerTool, FunctionTool, FunctionToolResult, Tool |
55 | 56 | from .tracing import ( |
56 | 57 | SpanError, |
57 | 58 | Trace, |
@@ -208,34 +209,22 @@ async def execute_tools_and_side_effects( |
208 | 209 | new_step_items.extend(computer_results) |
209 | 210 |
|
210 | 211 | # Reset tool_choice to "auto" after tool execution to prevent infinite loops |
211 | | - if (processed_response.functions or processed_response.computer_actions): |
212 | | - # Reset agent's model_settings |
213 | | - if agent.model_settings.tool_choice == "required" or isinstance(agent.model_settings.tool_choice, str): |
214 | | - # Create a new model_settings to avoid modifying the original shared instance |
215 | | - agent.model_settings = ModelSettings( |
216 | | - temperature=agent.model_settings.temperature, |
217 | | - top_p=agent.model_settings.top_p, |
218 | | - frequency_penalty=agent.model_settings.frequency_penalty, |
219 | | - presence_penalty=agent.model_settings.presence_penalty, |
220 | | - tool_choice="auto", # Reset to auto |
221 | | - parallel_tool_calls=agent.model_settings.parallel_tool_calls, |
222 | | - truncation=agent.model_settings.truncation, |
223 | | - max_tokens=agent.model_settings.max_tokens, |
| 212 | + if processed_response.functions or processed_response.computer_actions: |
| 213 | + tools = agent.tools |
| 214 | + # Only reset in the problematic scenarios where loops are likely unintentional |
| 215 | + if cls._should_reset_tool_choice(agent.model_settings, tools): |
| 216 | + agent.model_settings = dataclasses.replace( |
| 217 | + agent.model_settings, |
| 218 | + tool_choice="auto" |
224 | 219 | ) |
225 | | - |
226 | | - # Also reset run_config's model_settings if it exists |
227 | | - if run_config.model_settings and (run_config.model_settings.tool_choice == "required" or |
228 | | - isinstance(run_config.model_settings.tool_choice, str)): |
229 | | - # Create a new model_settings for run_config |
230 | | - run_config.model_settings = ModelSettings( |
231 | | - temperature=run_config.model_settings.temperature, |
232 | | - top_p=run_config.model_settings.top_p, |
233 | | - frequency_penalty=run_config.model_settings.frequency_penalty, |
234 | | - presence_penalty=run_config.model_settings.presence_penalty, |
235 | | - tool_choice="auto", # Reset to auto |
236 | | - parallel_tool_calls=run_config.model_settings.parallel_tool_calls, |
237 | | - truncation=run_config.model_settings.truncation, |
238 | | - max_tokens=run_config.model_settings.max_tokens, |
| 220 | + |
| 221 | + if ( |
| 222 | + run_config.model_settings and |
| 223 | + cls._should_reset_tool_choice(run_config.model_settings, tools) |
| 224 | + ): |
| 225 | + run_config.model_settings = dataclasses.replace( |
| 226 | + run_config.model_settings, |
| 227 | + tool_choice="auto" |
239 | 228 | ) |
240 | 229 |
|
241 | 230 | # Second, check if there are any handoffs |
@@ -328,6 +317,24 @@ async def execute_tools_and_side_effects( |
328 | 317 | next_step=NextStepRunAgain(), |
329 | 318 | ) |
330 | 319 |
|
| 320 | + @classmethod |
| 321 | + def _should_reset_tool_choice(cls, model_settings: ModelSettings, tools: list[Tool]) -> bool: |
| 322 | + if model_settings is None or model_settings.tool_choice is None: |
| 323 | + return False |
| 324 | + |
| 325 | + # for specific tool choices |
| 326 | + if ( |
| 327 | + isinstance(model_settings.tool_choice, str) and |
| 328 | + model_settings.tool_choice not in ["auto", "required", "none"] |
| 329 | + ): |
| 330 | + return True |
| 331 | + |
| 332 | + # for one tool and required tool choice |
| 333 | + if model_settings.tool_choice == "required": |
| 334 | + return len(tools) == 1 |
| 335 | + |
| 336 | + return False |
| 337 | + |
331 | 338 | @classmethod |
332 | 339 | def process_model_response( |
333 | 340 | cls, |
|
0 commit comments