From c5fa125e8ccabcb213406a02ca0fa56fe3e20785 Mon Sep 17 00:00:00 2001 From: Daniele Morotti <58258368+DanieleMorotti@users.noreply.github.com> Date: Tue, 29 Jul 2025 17:00:11 +0200 Subject: [PATCH 1/2] Allow to pass both session and input list + tests --- src/agents/run.py | 51 +++++++++++++----- tests/test_session.py | 123 +++++++++++++++++++++++++++++++++++++++++- 2 files changed, 159 insertions(+), 15 deletions(-) diff --git a/src/agents/run.py b/src/agents/run.py index 2dd9524bb..94a7e6c1b 100644 --- a/src/agents/run.py +++ b/src/agents/run.py @@ -4,7 +4,7 @@ import copy import inspect from dataclasses import dataclass, field -from typing import Any, Generic, cast +from typing import Any, Generic, Literal, cast from openai.types.responses import ResponseCompletedEvent from openai.types.responses.response_prompt_param import ( @@ -139,6 +139,11 @@ class RunConfig: An optional dictionary of additional metadata to include with the trace. """ + session_input_handling: Literal["replace", "append"] | None = None + """If a custom input list is given together with the Session, it will + be appended to the session messages or it will replace them. + """ + class RunOptions(TypedDict, Generic[TContext]): """Arguments for ``AgentRunner`` methods.""" @@ -343,7 +348,9 @@ async def run( run_config = RunConfig() # Prepare input with session if enabled - prepared_input = await self._prepare_input_with_session(input, session) + prepared_input = await self._prepare_input_with_session( + input, session, run_config.session_input_handling + ) tool_use_tracker = AgentToolUseTracker() @@ -468,7 +475,9 @@ async def run( ) # Save the conversation to session if enabled - await self._save_result_to_session(session, input, result) + await self._save_result_to_session( + session, input, result, run_config.session_input_handling + ) return result elif isinstance(turn_result.next_step, NextStepHandoff): @@ -662,7 +671,9 @@ async def _start_streaming( try: # Prepare input with session if enabled - prepared_input = await AgentRunner._prepare_input_with_session(starting_input, session) + prepared_input = await AgentRunner._prepare_input_with_session( + starting_input, session, run_config.session_input_handling + ) # Update the streamed result with the prepared input streamed_result.input = prepared_input @@ -781,7 +792,7 @@ async def _start_streaming( context_wrapper=context_wrapper, ) await AgentRunner._save_result_to_session( - session, starting_input, temp_result + session, starting_input, temp_result, run_config.session_input_handling ) streamed_result._event_queue.put_nowait(QueueCompleteSentinel()) @@ -1191,18 +1202,18 @@ async def _prepare_input_with_session( cls, input: str | list[TResponseInputItem], session: Session | None, + session_input_handling: Literal["replace", "append"] | None, ) -> str | list[TResponseInputItem]: """Prepare input by combining it with session history if enabled.""" if session is None: return input - # Validate that we don't have both a session and a list input, as this creates - # ambiguity about whether the list should append to or replace existing session history - if isinstance(input, list): + # If the user doesn't explicitly specify a mode, raise an error + if isinstance(input, list) and not session_input_handling: raise UserError( - "Cannot provide both a session and a list of input items. " - "When using session memory, provide only a string input to append to the " - "conversation, or use session=None and provide a list to manually manage " + "You must specify the `session_input_handling` in the `RunConfig`. " + "Otherwise, when using session memory, provide only a string input to append to " + "the conversation, or use session=None and provide a list to manually manage " "conversation history." ) @@ -1212,8 +1223,17 @@ async def _prepare_input_with_session( # Convert input to list format new_input_list = ItemHelpers.input_to_new_input_list(input) - # Combine history with new input - combined_input = history + new_input_list + if session_input_handling == "append" or session_input_handling is None: + # Append new input to history + combined_input = history + new_input_list + elif session_input_handling == "replace": + # Replace history with new input + combined_input = new_input_list + else: + raise UserError( + "The specified `session_input_handling` is not available. " + "Choose between `append`, `replace` or `None`." + ) return combined_input @@ -1223,11 +1243,16 @@ async def _save_result_to_session( session: Session | None, original_input: str | list[TResponseInputItem], result: RunResult, + saving_mode: Literal["replace", "append"] | None = None, ) -> None: """Save the conversation turn to session.""" if session is None: return + # Remove old history + if saving_mode == "replace": + await session.clear_session() + # Convert original input to list format if needed input_list = ItemHelpers.input_to_new_input_list(original_input) diff --git a/tests/test_session.py b/tests/test_session.py index 032f2bb38..ca1bb0c4e 100644 --- a/tests/test_session.py +++ b/tests/test_session.py @@ -6,7 +6,7 @@ import pytest -from agents import Agent, Runner, SQLiteSession, TResponseInputItem +from agents import Agent, RunConfig, Runner, SQLiteSession, TResponseInputItem from agents.exceptions import UserError from .fake_model import FakeModel @@ -394,7 +394,126 @@ async def test_session_memory_rejects_both_session_and_list_input(runner_method) await run_agent_async(runner_method, agent, list_input, session=session) # Verify the error message explains the issue - assert "Cannot provide both a session and a list of input items" in str(exc_info.value) + assert "You must specify the `session_input_handling` in" in str(exc_info.value) assert "manually manage conversation history" in str(exc_info.value) session.close() + + +@pytest.mark.parametrize("runner_method", ["run", "run_sync", "run_streamed"]) +@pytest.mark.asyncio +async def test_session_memory_append_list(runner_method): + """Test if the user passes a list of items and want to append them.""" + with tempfile.TemporaryDirectory() as temp_dir: + db_path = Path(temp_dir) / "test_memory.db" + + model = FakeModel() + agent = Agent(name="test", model=model) + + # Session + session_id = "session_1" + session = SQLiteSession(session_id, db_path) + + model.set_next_output([get_text_message("I like cats")]) + _ = await run_agent_async(runner_method, agent, "I like cats", session=session) + + append_input = [ + {"role": "user", "content": "Some random user text"}, + {"role": "assistant", "content": "You're right"}, + {"role": "user", "content": "What did I say I like?"}, + ] + second_model_response = {"role": "assistant", "content": "Yes, you mentioned cats"} + model.set_next_output([get_text_message(second_model_response.get("content", ""))]) + + _ = await run_agent_async( + runner_method, + agent, + append_input, + session=session, + run_config=RunConfig(session_input_handling="append"), + ) + + session_items = await session.get_items() + + # Check the items has been appended + assert len(session_items) == 6 + + # Check the items are the last 4 elements + append_input.append(second_model_response) + for sess_item, orig_item in zip(session_items[-4:], append_input): + assert sess_item.get("role") == orig_item.get("role") + + sess_content = sess_item.get("content") + # Narrow to list or str for mypy + assert isinstance(sess_content, (list, str)) + + if isinstance(sess_content, list): + # now mypy knows `content: list[Any]` + assert isinstance(sess_content[0], dict) and "text" in sess_content[0] + val_sess = sess_content[0]["text"] + else: + # here content is str + val_sess = sess_content + + assert val_sess == orig_item["content"] + + session.close() + + +@pytest.mark.parametrize("runner_method", ["run", "run_sync", "run_streamed"]) +@pytest.mark.asyncio +async def test_session_memory_replace_list(runner_method): + """Test if the user passes a list of items and want to replace the history.""" + with tempfile.TemporaryDirectory() as temp_dir: + db_path = Path(temp_dir) / "test_memory.db" + + model = FakeModel() + agent = Agent(name="test", model=model) + + # Session + session_id = "session_1" + session = SQLiteSession(session_id, db_path) + + model.set_next_output([get_text_message("I like cats")]) + _ = await run_agent_async(runner_method, agent, "I like cats", session=session) + + replace_input = [ + {"role": "user", "content": "Some random user text"}, + {"role": "assistant", "content": "You're right"}, + {"role": "user", "content": "What did I say I like?"}, + ] + second_model_response = {"role": "assistant", "content": "Yes, you mentioned cats"} + model.set_next_output([get_text_message(second_model_response.get("content", ""))]) + + _ = await run_agent_async( + runner_method, + agent, + replace_input, + session=session, + run_config=RunConfig(session_input_handling="replace"), + ) + + session_items = await session.get_items() + + # Check the new items replaced the history + assert len(session_items) == 4 + + # Check the items are the last 4 elements + replace_input.append(second_model_response) + for sess_item, orig_item in zip(session_items, replace_input): + assert sess_item.get("role") == orig_item.get("role") + sess_content = sess_item.get("content") + # Narrow to list or str for mypy + assert isinstance(sess_content, (list, str)) + + if isinstance(sess_content, list): + # now mypy knows `content: list[Any]` + assert isinstance(sess_content[0], dict) and "text" in sess_content[0] + val_sess = sess_content[0]["text"] + else: + # here content is str + val_sess = sess_content + + assert val_sess == orig_item["content"] + + session.close() From 6c3ca7132746555f0c936eb298406c5ef6cc0628 Mon Sep 17 00:00:00 2001 From: Daniele Morotti <58258368+DanieleMorotti@users.noreply.github.com> Date: Thu, 31 Jul 2025 17:30:36 +0200 Subject: [PATCH 2/2] Use a callback function to mix history and input --- src/agents/memory/__init__.py | 3 +- src/agents/memory/util.py | 29 +++++++++ src/agents/run.py | 59 +++++++++-------- tests/test_session.py | 115 +++++++--------------------------- 4 files changed, 82 insertions(+), 124 deletions(-) create mode 100644 src/agents/memory/util.py diff --git a/src/agents/memory/__init__.py b/src/agents/memory/__init__.py index 059ca57ab..7f3b45dba 100644 --- a/src/agents/memory/__init__.py +++ b/src/agents/memory/__init__.py @@ -1,3 +1,4 @@ from .session import Session, SQLiteSession +from .util import SessionInputHandler, SessionMixerCallable -__all__ = ["Session", "SQLiteSession"] +__all__ = ["Session", "SessionInputHandler", "SessionMixerCallable", "SQLiteSession"] diff --git a/src/agents/memory/util.py b/src/agents/memory/util.py new file mode 100644 index 000000000..e530b2f30 --- /dev/null +++ b/src/agents/memory/util.py @@ -0,0 +1,29 @@ +from __future__ import annotations + +from typing import Callable, Union + +from ..items import TResponseInputItem +from ..util._types import MaybeAwaitable + +SessionMixerCallable = Callable[ + [list[TResponseInputItem], list[TResponseInputItem]], + MaybeAwaitable[list[TResponseInputItem]], +] +"""A function that combines session history with new input items. + +Args: + history_items: The list of items from the session history. + new_items: The list of new input items for the current turn. + +Returns: + A list of combined items to be used as input for the agent. Can be sync or async. +""" + + +SessionInputHandler = Union[SessionMixerCallable, None] +"""Defines how to handle session history when new input is provided. + +- `None` (default): The new input is appended to the session history. +- `SessionMixerCallable`: A custom function that receives the history and new input, and + returns the desired combined list of items. +""" diff --git a/src/agents/run.py b/src/agents/run.py index 94a7e6c1b..3646a380a 100644 --- a/src/agents/run.py +++ b/src/agents/run.py @@ -4,7 +4,7 @@ import copy import inspect from dataclasses import dataclass, field -from typing import Any, Generic, Literal, cast +from typing import Any, Generic, cast from openai.types.responses import ResponseCompletedEvent from openai.types.responses.response_prompt_param import ( @@ -44,7 +44,7 @@ from .items import ItemHelpers, ModelResponse, RunItem, TResponseInputItem from .lifecycle import RunHooks from .logger import logger -from .memory import Session +from .memory import Session, SessionInputHandler from .model_settings import ModelSettings from .models.interface import Model, ModelProvider from .models.multi_provider import MultiProvider @@ -139,9 +139,12 @@ class RunConfig: An optional dictionary of additional metadata to include with the trace. """ - session_input_handling: Literal["replace", "append"] | None = None - """If a custom input list is given together with the Session, it will - be appended to the session messages or it will replace them. + session_input_callback: SessionInputHandler = None + """Defines how to handle session history when new input is provided. + + - `None` (default): The new input is appended to the session history. + - `SessionMixerCallable`: A custom function that receives the history and new input, and + returns the desired combined list of items. """ @@ -349,7 +352,7 @@ async def run( # Prepare input with session if enabled prepared_input = await self._prepare_input_with_session( - input, session, run_config.session_input_handling + input, session, run_config.session_input_callback ) tool_use_tracker = AgentToolUseTracker() @@ -475,9 +478,7 @@ async def run( ) # Save the conversation to session if enabled - await self._save_result_to_session( - session, input, result, run_config.session_input_handling - ) + await self._save_result_to_session(session, input, result) return result elif isinstance(turn_result.next_step, NextStepHandoff): @@ -672,7 +673,7 @@ async def _start_streaming( try: # Prepare input with session if enabled prepared_input = await AgentRunner._prepare_input_with_session( - starting_input, session, run_config.session_input_handling + starting_input, session, run_config.session_input_callback ) # Update the streamed result with the prepared input @@ -792,7 +793,7 @@ async def _start_streaming( context_wrapper=context_wrapper, ) await AgentRunner._save_result_to_session( - session, starting_input, temp_result, run_config.session_input_handling + session, starting_input, temp_result ) streamed_result._event_queue.put_nowait(QueueCompleteSentinel()) @@ -1202,16 +1203,16 @@ async def _prepare_input_with_session( cls, input: str | list[TResponseInputItem], session: Session | None, - session_input_handling: Literal["replace", "append"] | None, + session_input_callback: SessionInputHandler, ) -> str | list[TResponseInputItem]: """Prepare input by combining it with session history if enabled.""" if session is None: return input # If the user doesn't explicitly specify a mode, raise an error - if isinstance(input, list) and not session_input_handling: + if isinstance(input, list) and not session_input_callback: raise UserError( - "You must specify the `session_input_handling` in the `RunConfig`. " + "You must specify the `session_input_callback` in the `RunConfig`. " "Otherwise, when using session memory, provide only a string input to append to " "the conversation, or use session=None and provide a list to manually manage " "conversation history." @@ -1223,36 +1224,34 @@ async def _prepare_input_with_session( # Convert input to list format new_input_list = ItemHelpers.input_to_new_input_list(input) - if session_input_handling == "append" or session_input_handling is None: - # Append new input to history - combined_input = history + new_input_list - elif session_input_handling == "replace": - # Replace history with new input - combined_input = new_input_list + if session_input_callback is None: + return history + new_input_list + elif callable(session_input_callback): + res = session_input_callback(history, new_input_list) + if inspect.isawaitable(res): + return await res + return res else: raise UserError( - "The specified `session_input_handling` is not available. " - "Choose between `append`, `replace` or `None`." + f"Invalid `session_input_callback` value: {session_input_callback}. " + "Choose between `None` or a custom callable function." ) - return combined_input - @classmethod async def _save_result_to_session( cls, session: Session | None, original_input: str | list[TResponseInputItem], result: RunResult, - saving_mode: Literal["replace", "append"] | None = None, ) -> None: - """Save the conversation turn to session.""" + """ + Save the conversation turn to session. + It does not account for any filtering or modification performed by + `RunConfig.session_input_handling`. + """ if session is None: return - # Remove old history - if saving_mode == "replace": - await session.clear_session() - # Convert original input to list format if needed input_list = ItemHelpers.input_to_new_input_list(original_input) diff --git a/tests/test_session.py b/tests/test_session.py index ca1bb0c4e..1cfc62a92 100644 --- a/tests/test_session.py +++ b/tests/test_session.py @@ -394,7 +394,7 @@ async def test_session_memory_rejects_both_session_and_list_input(runner_method) await run_agent_async(runner_method, agent, list_input, session=session) # Verify the error message explains the issue - assert "You must specify the `session_input_handling` in" in str(exc_info.value) + assert "You must specify the `session_input_callback` in" in str(exc_info.value) assert "manually manage conversation history" in str(exc_info.value) session.close() @@ -402,7 +402,7 @@ async def test_session_memory_rejects_both_session_and_list_input(runner_method) @pytest.mark.parametrize("runner_method", ["run", "run_sync", "run_streamed"]) @pytest.mark.asyncio -async def test_session_memory_append_list(runner_method): +async def test_session_callback_prepared_input(runner_method): """Test if the user passes a list of items and want to append them.""" with tempfile.TemporaryDirectory() as temp_dir: db_path = Path(temp_dir) / "test_memory.db" @@ -414,106 +414,35 @@ async def test_session_memory_append_list(runner_method): session_id = "session_1" session = SQLiteSession(session_id, db_path) - model.set_next_output([get_text_message("I like cats")]) - _ = await run_agent_async(runner_method, agent, "I like cats", session=session) - - append_input = [ - {"role": "user", "content": "Some random user text"}, - {"role": "assistant", "content": "You're right"}, - {"role": "user", "content": "What did I say I like?"}, + # Add first messages manually + initial_history: list[TResponseInputItem] = [ + {"role": "user", "content": "Hello there."}, + {"role": "assistant", "content": "Hi, I'm here to assist you."}, ] - second_model_response = {"role": "assistant", "content": "Yes, you mentioned cats"} - model.set_next_output([get_text_message(second_model_response.get("content", ""))]) - - _ = await run_agent_async( - runner_method, - agent, - append_input, - session=session, - run_config=RunConfig(session_input_handling="append"), - ) - - session_items = await session.get_items() - - # Check the items has been appended - assert len(session_items) == 6 - - # Check the items are the last 4 elements - append_input.append(second_model_response) - for sess_item, orig_item in zip(session_items[-4:], append_input): - assert sess_item.get("role") == orig_item.get("role") - - sess_content = sess_item.get("content") - # Narrow to list or str for mypy - assert isinstance(sess_content, (list, str)) - - if isinstance(sess_content, list): - # now mypy knows `content: list[Any]` - assert isinstance(sess_content[0], dict) and "text" in sess_content[0] - val_sess = sess_content[0]["text"] - else: - # here content is str - val_sess = sess_content - - assert val_sess == orig_item["content"] - - session.close() - - -@pytest.mark.parametrize("runner_method", ["run", "run_sync", "run_streamed"]) -@pytest.mark.asyncio -async def test_session_memory_replace_list(runner_method): - """Test if the user passes a list of items and want to replace the history.""" - with tempfile.TemporaryDirectory() as temp_dir: - db_path = Path(temp_dir) / "test_memory.db" + await session.add_items(initial_history) - model = FakeModel() - agent = Agent(name="test", model=model) + def filter_assistant_messages(history, new_input): + # Only include user messages from history + return [item for item in history if item["role"] == "user"] + new_input - # Session - session_id = "session_1" - session = SQLiteSession(session_id, db_path) + new_turn_input = [{"role": "user", "content": "What your name?"}] + model.set_next_output([get_text_message("I'm gpt-4o")]) - model.set_next_output([get_text_message("I like cats")]) - _ = await run_agent_async(runner_method, agent, "I like cats", session=session) - - replace_input = [ - {"role": "user", "content": "Some random user text"}, - {"role": "assistant", "content": "You're right"}, - {"role": "user", "content": "What did I say I like?"}, - ] - second_model_response = {"role": "assistant", "content": "Yes, you mentioned cats"} - model.set_next_output([get_text_message(second_model_response.get("content", ""))]) - - _ = await run_agent_async( + # Run the agent with the callable + await run_agent_async( runner_method, agent, - replace_input, + new_turn_input, session=session, - run_config=RunConfig(session_input_handling="replace"), + run_config=RunConfig(session_input_callback=filter_assistant_messages), ) - session_items = await session.get_items() - - # Check the new items replaced the history - assert len(session_items) == 4 - - # Check the items are the last 4 elements - replace_input.append(second_model_response) - for sess_item, orig_item in zip(session_items, replace_input): - assert sess_item.get("role") == orig_item.get("role") - sess_content = sess_item.get("content") - # Narrow to list or str for mypy - assert isinstance(sess_content, (list, str)) - - if isinstance(sess_content, list): - # now mypy knows `content: list[Any]` - assert isinstance(sess_content[0], dict) and "text" in sess_content[0] - val_sess = sess_content[0]["text"] - else: - # here content is str - val_sess = sess_content + expected_model_input = [ + initial_history[0], # From history + new_turn_input[0], # New input + ] - assert val_sess == orig_item["content"] + assert len(model.last_turn_args["input"]) == 2 + assert model.last_turn_args["input"] == expected_model_input session.close()