From b05c03d7ba577475c3e2da65ee4d6e529b1dc63e Mon Sep 17 00:00:00 2001 From: Rohan Mehta Date: Thu, 14 Aug 2025 17:48:07 -0400 Subject: [PATCH] Allow modifying the input sent to the model --- src/agents/run.py | 101 +++++++++++++++++-- tests/test_call_model_input_filter.py | 79 +++++++++++++++ tests/test_call_model_input_filter_unit.py | 107 +++++++++++++++++++++ 3 files changed, 281 insertions(+), 6 deletions(-) create mode 100644 tests/test_call_model_input_filter.py create mode 100644 tests/test_call_model_input_filter_unit.py diff --git a/src/agents/run.py b/src/agents/run.py index d0748e514..5f9ec10ac 100644 --- a/src/agents/run.py +++ b/src/agents/run.py @@ -4,7 +4,7 @@ import copy import inspect from dataclasses import dataclass, field -from typing import Any, Generic, cast +from typing import Any, Callable, Generic, cast from openai.types.responses import ResponseCompletedEvent from openai.types.responses.response_prompt_param import ( @@ -56,6 +56,7 @@ from .tracing.span_data import AgentSpanData from .usage import Usage from .util import _coro, _error_tracing +from .util._types import MaybeAwaitable DEFAULT_MAX_TURNS = 10 @@ -81,6 +82,27 @@ def get_default_agent_runner() -> AgentRunner: return DEFAULT_AGENT_RUNNER +@dataclass +class ModelInputData: + """Container for the data that will be sent to the model.""" + + input: list[TResponseInputItem] + instructions: str | None + + +@dataclass +class CallModelData(Generic[TContext]): + """Data passed to `RunConfig.call_model_input_filter` prior to model call.""" + + model_data: ModelInputData + agent: Agent[TContext] + context: TContext | None + + +# Type alias for the optional input filter callback +CallModelInputFilter = Callable[[CallModelData[Any]], MaybeAwaitable[ModelInputData]] + + @dataclass class RunConfig: """Configures settings for the entire agent run.""" @@ -139,6 +161,16 @@ class RunConfig: An optional dictionary of additional metadata to include with the trace. """ + call_model_input_filter: CallModelInputFilter | None = None + """ + Optional callback that is invoked immediately before calling the model. It receives the current + agent, context and the model input (instructions and input items), and must return a possibly + modified `ModelInputData` to use for the model call. + + This allows you to edit the input sent to the model e.g. to stay within a token limit. + For example, you can use this to add a system prompt to the input. + """ + class RunOptions(TypedDict, Generic[TContext]): """Arguments for ``AgentRunner`` methods.""" @@ -593,6 +625,47 @@ def run_streamed( ) return streamed_result + @classmethod + async def _maybe_filter_model_input( + cls, + *, + agent: Agent[TContext], + run_config: RunConfig, + context_wrapper: RunContextWrapper[TContext], + input_items: list[TResponseInputItem], + system_instructions: str | None, + ) -> ModelInputData: + """Apply optional call_model_input_filter to modify model input. + + Returns a `ModelInputData` that will be sent to the model. + """ + effective_instructions = system_instructions + effective_input: list[TResponseInputItem] = input_items + + if run_config.call_model_input_filter is None: + return ModelInputData(input=effective_input, instructions=effective_instructions) + + try: + model_input = ModelInputData( + input=copy.deepcopy(effective_input), + instructions=effective_instructions, + ) + filter_payload: CallModelData[TContext] = CallModelData( + model_data=model_input, + agent=agent, + context=context_wrapper.context, + ) + maybe_updated = run_config.call_model_input_filter(filter_payload) + updated = await maybe_updated if inspect.isawaitable(maybe_updated) else maybe_updated + if not isinstance(updated, ModelInputData): + raise UserError("call_model_input_filter must return a ModelInputData instance") + return updated + except Exception as e: + _error_tracing.attach_error_to_current_span( + SpanError(message="Error in call_model_input_filter", data={"error": str(e)}) + ) + raise + @classmethod async def _run_input_guardrails_with_queue( cls, @@ -863,10 +936,18 @@ async def _run_single_turn_streamed( input = ItemHelpers.input_to_new_input_list(streamed_result.input) input.extend([item.to_input_item() for item in streamed_result.new_items]) + filtered = await cls._maybe_filter_model_input( + agent=agent, + run_config=run_config, + context_wrapper=context_wrapper, + input_items=input, + system_instructions=system_prompt, + ) + # 1. Stream the output events async for event in model.stream_response( - system_prompt, - input, + filtered.instructions, + filtered.input, model_settings, all_tools, output_schema, @@ -1034,7 +1115,6 @@ async def _get_single_step_result_from_streamed_response( run_config: RunConfig, tool_use_tracker: AgentToolUseTracker, ) -> SingleStepResult: - original_input = streamed_result.input pre_step_items = streamed_result.new_items event_queue = streamed_result._event_queue @@ -1161,13 +1241,22 @@ async def _get_new_response( previous_response_id: str | None, prompt_config: ResponsePromptParam | None, ) -> ModelResponse: + # Allow user to modify model input right before the call, if configured + filtered = await cls._maybe_filter_model_input( + agent=agent, + run_config=run_config, + context_wrapper=context_wrapper, + input_items=input, + system_instructions=system_prompt, + ) + model = cls._get_model(agent, run_config) model_settings = agent.model_settings.resolve(run_config.model_settings) model_settings = RunImpl.maybe_reset_tool_choice(agent, tool_use_tracker, model_settings) new_response = await model.get_response( - system_instructions=system_prompt, - input=input, + system_instructions=filtered.instructions, + input=filtered.input, model_settings=model_settings, tools=all_tools, output_schema=output_schema, diff --git a/tests/test_call_model_input_filter.py b/tests/test_call_model_input_filter.py new file mode 100644 index 000000000..be2dc28e6 --- /dev/null +++ b/tests/test_call_model_input_filter.py @@ -0,0 +1,79 @@ +from __future__ import annotations + +from typing import Any + +import pytest + +from agents import Agent, RunConfig, Runner, UserError +from agents.run import CallModelData, ModelInputData + +from .fake_model import FakeModel +from .test_responses import get_text_input_item, get_text_message + + +@pytest.mark.asyncio +async def test_call_model_input_filter_sync_non_streamed() -> None: + model = FakeModel() + agent = Agent(name="test", model=model) + + # Prepare model output + model.set_next_output([get_text_message("ok")]) + + def filter_fn(data: CallModelData[Any]) -> ModelInputData: + mi = data.model_data + new_input = list(mi.input) + [get_text_input_item("added-sync")] + return ModelInputData(input=new_input, instructions="filtered-sync") + + await Runner.run( + agent, + input="start", + run_config=RunConfig(call_model_input_filter=filter_fn), + ) + + assert model.last_turn_args["system_instructions"] == "filtered-sync" + assert isinstance(model.last_turn_args["input"], list) + assert len(model.last_turn_args["input"]) == 2 + assert model.last_turn_args["input"][-1]["content"] == "added-sync" + + +@pytest.mark.asyncio +async def test_call_model_input_filter_async_streamed() -> None: + model = FakeModel() + agent = Agent(name="test", model=model) + + # Prepare model output + model.set_next_output([get_text_message("ok")]) + + async def filter_fn(data: CallModelData[Any]) -> ModelInputData: + mi = data.model_data + new_input = list(mi.input) + [get_text_input_item("added-async")] + return ModelInputData(input=new_input, instructions="filtered-async") + + result = Runner.run_streamed( + agent, + input="start", + run_config=RunConfig(call_model_input_filter=filter_fn), + ) + async for _ in result.stream_events(): + pass + + assert model.last_turn_args["system_instructions"] == "filtered-async" + assert isinstance(model.last_turn_args["input"], list) + assert len(model.last_turn_args["input"]) == 2 + assert model.last_turn_args["input"][-1]["content"] == "added-async" + + +@pytest.mark.asyncio +async def test_call_model_input_filter_invalid_return_type_raises() -> None: + model = FakeModel() + agent = Agent(name="test", model=model) + + def invalid_filter(_data: CallModelData[Any]): + return "bad" + + with pytest.raises(UserError): + await Runner.run( + agent, + input="start", + run_config=RunConfig(call_model_input_filter=invalid_filter), + ) diff --git a/tests/test_call_model_input_filter_unit.py b/tests/test_call_model_input_filter_unit.py new file mode 100644 index 000000000..7cf3a00a9 --- /dev/null +++ b/tests/test_call_model_input_filter_unit.py @@ -0,0 +1,107 @@ +from __future__ import annotations + +import sys +from pathlib import Path +from typing import Any + +import pytest +from openai.types.responses import ResponseOutputMessage, ResponseOutputText + +# Make the repository tests helpers importable from this unit test +sys.path.insert(0, str(Path(__file__).resolve().parent.parent / "tests")) +from fake_model import FakeModel # type: ignore + +# Import directly from submodules to avoid heavy __init__ side effects +from agents.agent import Agent +from agents.exceptions import UserError +from agents.run import CallModelData, ModelInputData, RunConfig, Runner + + +@pytest.mark.asyncio +async def test_call_model_input_filter_sync_non_streamed_unit() -> None: + model = FakeModel() + agent = Agent(name="test", model=model) + + model.set_next_output( + [ + ResponseOutputMessage( + id="1", + type="message", + role="assistant", + content=[ResponseOutputText(text="ok", type="output_text", annotations=[])], + status="completed", + ) + ] + ) + + def filter_fn(data: CallModelData[Any]) -> ModelInputData: + mi = data.model_data + new_input = list(mi.input) + [ + {"content": "added-sync", "role": "user"} + ] # pragma: no cover - trivial + return ModelInputData(input=new_input, instructions="filtered-sync") + + await Runner.run( + agent, + input="start", + run_config=RunConfig(call_model_input_filter=filter_fn), + ) + + assert model.last_turn_args["system_instructions"] == "filtered-sync" + assert isinstance(model.last_turn_args["input"], list) + assert len(model.last_turn_args["input"]) == 2 + assert model.last_turn_args["input"][-1]["content"] == "added-sync" + + +@pytest.mark.asyncio +async def test_call_model_input_filter_async_streamed_unit() -> None: + model = FakeModel() + agent = Agent(name="test", model=model) + + model.set_next_output( + [ + ResponseOutputMessage( + id="1", + type="message", + role="assistant", + content=[ResponseOutputText(text="ok", type="output_text", annotations=[])], + status="completed", + ) + ] + ) + + async def filter_fn(data: CallModelData[Any]) -> ModelInputData: + mi = data.model_data + new_input = list(mi.input) + [ + {"content": "added-async", "role": "user"} + ] # pragma: no cover - trivial + return ModelInputData(input=new_input, instructions="filtered-async") + + result = Runner.run_streamed( + agent, + input="start", + run_config=RunConfig(call_model_input_filter=filter_fn), + ) + async for _ in result.stream_events(): + pass + + assert model.last_turn_args["system_instructions"] == "filtered-async" + assert isinstance(model.last_turn_args["input"], list) + assert len(model.last_turn_args["input"]) == 2 + assert model.last_turn_args["input"][-1]["content"] == "added-async" + + +@pytest.mark.asyncio +async def test_call_model_input_filter_invalid_return_type_raises_unit() -> None: + model = FakeModel() + agent = Agent(name="test", model=model) + + def invalid_filter(_data: CallModelData[Any]): + return "bad" + + with pytest.raises(UserError): + await Runner.run( + agent, + input="start", + run_config=RunConfig(call_model_input_filter=invalid_filter), + )