diff --git a/src/agents/items.py b/src/agents/items.py index c43e9f856..651d73127 100644 --- a/src/agents/items.py +++ b/src/agents/items.py @@ -1,7 +1,6 @@ from __future__ import annotations import abc -import copy from dataclasses import dataclass from typing import TYPE_CHECKING, Any, Generic, Literal, TypeVar, Union @@ -277,7 +276,7 @@ def input_to_new_input_list( "role": "user", } ] - return copy.deepcopy(input) + return input.copy() @classmethod def text_message_outputs(cls, items: list[RunItem]) -> str: diff --git a/src/agents/realtime/session.py b/src/agents/realtime/session.py index 4629f1bb5..42d61cf2b 100644 --- a/src/agents/realtime/session.py +++ b/src/agents/realtime/session.py @@ -10,6 +10,7 @@ from ..agent import Agent from ..exceptions import ModelBehaviorError, UserError from ..handoffs import Handoff +from ..logger import logger from ..run_context import RunContextWrapper, TContext from ..tool import FunctionTool from ..tool_context import ToolContext @@ -33,7 +34,7 @@ RealtimeToolStart, ) from .handoffs import realtime_handoff -from .items import InputAudio, InputText, RealtimeItem +from .items import AssistantAudio, InputAudio, InputText, RealtimeItem from .model import RealtimeModel, RealtimeModelConfig, RealtimeModelListener from .model_events import ( RealtimeModelEvent, @@ -246,7 +247,58 @@ async def on_event(self, event: RealtimeModelEvent) -> None: self._enqueue_guardrail_task(self._item_transcripts[item_id], event.response_id) elif event.type == "item_updated": is_new = not any(item.item_id == event.item.item_id for item in self._history) - self._history = self._get_new_history(self._history, event.item) + + # Preserve previously known transcripts when updating existing items. + # This prevents transcripts from disappearing when an item is later + # retrieved without transcript fields populated. + incoming_item = event.item + existing_item = next( + (i for i in self._history if i.item_id == incoming_item.item_id), None + ) + + if ( + existing_item is not None + and existing_item.type == "message" + and incoming_item.type == "message" + ): + try: + # Merge transcripts for matching content indices + existing_content = existing_item.content + new_content = [] + for idx, entry in enumerate(incoming_item.content): + # Only attempt to preserve for audio-like content + if entry.type in ("audio", "input_audio"): + # Use tuple form for Python 3.9 compatibility + assert isinstance(entry, (InputAudio, AssistantAudio)) + # Determine if transcript is missing/empty on the incoming entry + entry_transcript = entry.transcript + if not entry_transcript: + preserved: str | None = None + # First prefer any transcript from the existing history item + if idx < len(existing_content): + this_content = existing_content[idx] + if isinstance(this_content, AssistantAudio) or isinstance( + this_content, InputAudio + ): + preserved = this_content.transcript + + # If still missing and this is an assistant item, fall back to + # accumulated transcript deltas tracked during the turn. + if not preserved and incoming_item.role == "assistant": + preserved = self._item_transcripts.get(incoming_item.item_id) + + if preserved: + entry = entry.model_copy(update={"transcript": preserved}) + + new_content.append(entry) + + if new_content: + incoming_item = incoming_item.model_copy(update={"content": new_content}) + except Exception: + logger.error("Error merging transcripts", exc_info=True) + pass + + self._history = self._get_new_history(self._history, incoming_item) if is_new: new_item = next( item for item in self._history if item.item_id == event.item.item_id diff --git a/src/agents/run.py b/src/agents/run.py index 5f9ec10ac..3945e5131 100644 --- a/src/agents/run.py +++ b/src/agents/run.py @@ -1,7 +1,6 @@ from __future__ import annotations import asyncio -import copy import inspect from dataclasses import dataclass, field from typing import Any, Callable, Generic, cast @@ -387,7 +386,7 @@ async def run( disabled=run_config.tracing_disabled, ): current_turn = 0 - original_input: str | list[TResponseInputItem] = copy.deepcopy(prepared_input) + original_input: str | list[TResponseInputItem] = _copy_str_or_list(prepared_input) generated_items: list[RunItem] = [] model_responses: list[ModelResponse] = [] @@ -446,7 +445,7 @@ async def run( starting_agent, starting_agent.input_guardrails + (run_config.input_guardrails or []), - copy.deepcopy(prepared_input), + _copy_str_or_list(prepared_input), context_wrapper, ), self._run_single_turn( @@ -594,7 +593,7 @@ def run_streamed( ) streamed_result = RunResultStreaming( - input=copy.deepcopy(input), + input=_copy_str_or_list(input), new_items=[], current_agent=starting_agent, raw_responses=[], @@ -647,7 +646,7 @@ async def _maybe_filter_model_input( try: model_input = ModelInputData( - input=copy.deepcopy(effective_input), + input=effective_input.copy(), instructions=effective_instructions, ) filter_payload: CallModelData[TContext] = CallModelData( @@ -786,7 +785,7 @@ async def _start_streaming( cls._run_input_guardrails_with_queue( starting_agent, starting_agent.input_guardrails + (run_config.input_guardrails or []), - copy.deepcopy(ItemHelpers.input_to_new_input_list(prepared_input)), + ItemHelpers.input_to_new_input_list(prepared_input), context_wrapper, streamed_result, current_span, @@ -1376,3 +1375,9 @@ async def _save_result_to_session( DEFAULT_AGENT_RUNNER = AgentRunner() + + +def _copy_str_or_list(input: str | list[TResponseInputItem]) -> str | list[TResponseInputItem]: + if isinstance(input, str): + return input + return input.copy() diff --git a/tests/realtime/test_session.py b/tests/realtime/test_session.py index 3b6c5bac6..cd562c522 100644 --- a/tests/realtime/test_session.py +++ b/tests/realtime/test_session.py @@ -22,6 +22,7 @@ RealtimeToolStart, ) from agents.realtime.items import ( + AssistantAudio, AssistantMessageItem, AssistantText, InputAudio, @@ -1625,3 +1626,65 @@ async def test_update_agent_creates_handoff_and_session_update_event(self, mock_ # Check that the current agent and session settings are updated assert session._current_agent == second_agent + + +class TestTranscriptPreservation: + """Tests ensuring assistant transcripts are preserved across updates.""" + + @pytest.mark.asyncio + async def test_assistant_transcript_preserved_on_item_update(self, mock_model, mock_agent): + session = RealtimeSession(mock_model, mock_agent, None) + + # Initial assistant message with audio transcript present (e.g., from first turn) + initial_item = AssistantMessageItem( + item_id="assist_1", + role="assistant", + content=[AssistantAudio(audio=None, transcript="Hello there")], + ) + session._history = [initial_item] + + # Later, the platform retrieves/updates the same item but without transcript populated + updated_without_transcript = AssistantMessageItem( + item_id="assist_1", + role="assistant", + content=[AssistantAudio(audio=None, transcript=None)], + ) + + await session.on_event(RealtimeModelItemUpdatedEvent(item=updated_without_transcript)) + + # Transcript should be preserved from existing history + assert len(session._history) == 1 + preserved_item = cast(AssistantMessageItem, session._history[0]) + assert isinstance(preserved_item.content[0], AssistantAudio) + assert preserved_item.content[0].transcript == "Hello there" + + @pytest.mark.asyncio + async def test_assistant_transcript_can_fallback_to_deltas(self, mock_model, mock_agent): + session = RealtimeSession(mock_model, mock_agent, None) + + # Simulate transcript deltas accumulated for an assistant item during generation + await session.on_event( + RealtimeModelTranscriptDeltaEvent( + item_id="assist_2", delta="partial transcript", response_id="resp_2" + ) + ) + + # Add initial assistant message without transcript + initial_item = AssistantMessageItem( + item_id="assist_2", + role="assistant", + content=[AssistantAudio(audio=None, transcript=None)], + ) + await session.on_event(RealtimeModelItemUpdatedEvent(item=initial_item)) + + # Later update still lacks transcript; merge should fallback to accumulated deltas + update_again = AssistantMessageItem( + item_id="assist_2", + role="assistant", + content=[AssistantAudio(audio=None, transcript=None)], + ) + await session.on_event(RealtimeModelItemUpdatedEvent(item=update_again)) + + preserved_item = cast(AssistantMessageItem, session._history[0]) + assert isinstance(preserved_item.content[0], AssistantAudio) + assert preserved_item.content[0].transcript == "partial transcript" diff --git a/tests/test_items_helpers.py b/tests/test_items_helpers.py index f711f21e1..a94d74547 100644 --- a/tests/test_items_helpers.py +++ b/tests/test_items_helpers.py @@ -1,5 +1,7 @@ from __future__ import annotations +import json + from openai.types.responses.response_computer_tool_call import ( ActionScreenshot, ResponseComputerToolCall, @@ -20,8 +22,10 @@ from openai.types.responses.response_output_message_param import ResponseOutputMessageParam from openai.types.responses.response_output_refusal import ResponseOutputRefusal from openai.types.responses.response_output_text import ResponseOutputText +from openai.types.responses.response_output_text_param import ResponseOutputTextParam from openai.types.responses.response_reasoning_item import ResponseReasoningItem, Summary from openai.types.responses.response_reasoning_item_param import ResponseReasoningItemParam +from pydantic import TypeAdapter from agents import ( Agent, @@ -290,3 +294,34 @@ def test_to_input_items_for_reasoning() -> None: print(converted_dict) print(expected) assert converted_dict == expected + + +def test_input_to_new_input_list_copies_the_ones_produced_by_pydantic() -> None: + # Given a list of message dictionaries, ensure the returned list is a deep copy. + original = ResponseOutputMessageParam( + id="a75654dc-7492-4d1c-bce0-89e8312fbdd7", + content=[ + ResponseOutputTextParam( + type="output_text", + text="Hey, what's up?", + annotations=[], + ) + ], + role="assistant", + status="completed", + type="message", + ) + original_json = json.dumps(original) + output_item = TypeAdapter(ResponseOutputMessageParam).validate_json(original_json) + new_list = ItemHelpers.input_to_new_input_list([output_item]) + assert len(new_list) == 1 + assert new_list[0]["id"] == original["id"] # type: ignore + size = 0 + for i, item in enumerate(original["content"]): + size += 1 # pydantic_core._pydantic_core.ValidatorIterator does not support len() + assert item["type"] == original["content"][i]["type"] # type: ignore + assert item["text"] == original["content"][i]["text"] # type: ignore + assert size == 1 + assert new_list[0]["role"] == original["role"] # type: ignore + assert new_list[0]["status"] == original["status"] # type: ignore + assert new_list[0]["type"] == original["type"]