diff --git a/src/agents/realtime/model.py b/src/agents/realtime/model.py index d7ebe4ffa..c0632aa9b 100644 --- a/src/agents/realtime/model.py +++ b/src/agents/realtime/model.py @@ -118,6 +118,12 @@ class RealtimeModelConfig(TypedDict): the OpenAI Realtime model will use the default OpenAI WebSocket URL. """ + headers: NotRequired[dict[str, str]] + """The headers to use when connecting. If unset, the model will use a sane default. + Note that, when you set this, authorization header won't be set under the hood. + e.g., {"api-key": "your api key here"} for Azure OpenAI Realtime WebSocket connections. + """ + initial_model_settings: NotRequired[RealtimeSessionModelSettings] """The initial model settings to use when connecting.""" diff --git a/src/agents/realtime/openai_realtime.py b/src/agents/realtime/openai_realtime.py index 766c49f8d..b9048a1ec 100644 --- a/src/agents/realtime/openai_realtime.py +++ b/src/agents/realtime/openai_realtime.py @@ -188,15 +188,23 @@ async def connect(self, options: RealtimeModelConfig) -> None: else: self._tracing_config = "auto" - if not api_key: - raise UserError("API key is required but was not provided.") - url = options.get("url", f"wss://api.openai.com/v1/realtime?model={self.model}") - headers = { - "Authorization": f"Bearer {api_key}", - "OpenAI-Beta": "realtime=v1", - } + headers: dict[str, str] = {} + if options.get("headers") is not None: + # For customizing request headers + headers.update(options["headers"]) + else: + # OpenAI's Realtime API + if not api_key: + raise UserError("API key is required but was not provided.") + + headers.update( + { + "Authorization": f"Bearer {api_key}", + "OpenAI-Beta": "realtime=v1", + } + ) self._websocket = await websockets.connect( url, user_agent_header=_USER_AGENT, @@ -490,9 +498,7 @@ async def _handle_ws_event(self, event: dict[str, Any]): try: if "previous_item_id" in event and event["previous_item_id"] is None: event["previous_item_id"] = "" # TODO (rm) remove - parsed: AllRealtimeServerEvents = self._server_event_type_adapter.validate_python( - event - ) + parsed: AllRealtimeServerEvents = self._server_event_type_adapter.validate_python(event) except pydantic.ValidationError as e: logger.error(f"Failed to validate server event: {event}", exc_info=True) await self._emit_event( @@ -583,11 +589,13 @@ async def _handle_ws_event(self, event: dict[str, Any]): ): await self._handle_output_item(parsed.item) elif parsed.type == "input_audio_buffer.timeout_triggered": - await self._emit_event(RealtimeModelInputAudioTimeoutTriggeredEvent( - item_id=parsed.item_id, - audio_start_ms=parsed.audio_start_ms, - audio_end_ms=parsed.audio_end_ms, - )) + await self._emit_event( + RealtimeModelInputAudioTimeoutTriggeredEvent( + item_id=parsed.item_id, + audio_start_ms=parsed.audio_start_ms, + audio_end_ms=parsed.audio_end_ms, + ) + ) def _update_created_session(self, session: OpenAISessionObject) -> None: self._created_session = session diff --git a/tests/realtime/test_openai_realtime.py b/tests/realtime/test_openai_realtime.py index 4c410bf6e..08b8d878f 100644 --- a/tests/realtime/test_openai_realtime.py +++ b/tests/realtime/test_openai_realtime.py @@ -84,8 +84,45 @@ def mock_create_task_func(coro): # Verify internal state assert model._websocket == mock_websocket - assert model._websocket_task is not None - assert model.model == "gpt-4o-realtime-preview" + assert model._websocket_task is not None + assert model.model == "gpt-4o-realtime-preview" + + @pytest.mark.asyncio + async def test_connect_with_custom_headers_overrides_defaults(self, model, mock_websocket): + """If custom headers are provided, use them verbatim without adding defaults.""" + # Even when custom headers are provided, the implementation still requires api_key. + config = { + "api_key": "unused-because-headers-override", + "headers": {"api-key": "azure-key", "x-custom": "1"}, + "url": "wss://custom.example.com/realtime?model=custom", + # Use a valid realtime model name for session.update to validate. + "initial_model_settings": {"model_name": "gpt-4o-realtime-preview"}, + } + + async def async_websocket(*args, **kwargs): + return mock_websocket + + with patch("websockets.connect", side_effect=async_websocket) as mock_connect: + with patch("asyncio.create_task") as mock_create_task: + mock_task = AsyncMock() + + def mock_create_task_func(coro): + coro.close() + return mock_task + + mock_create_task.side_effect = mock_create_task_func + + await model.connect(config) + + # Verify WebSocket connection used the provided URL + called_url = mock_connect.call_args[0][0] + assert called_url == "wss://custom.example.com/realtime?model=custom" + + # Verify headers are exactly as provided and no defaults were injected + headers = mock_connect.call_args.kwargs["additional_headers"] + assert headers == {"api-key": "azure-key", "x-custom": "1"} + assert "Authorization" not in headers + assert "OpenAI-Beta" not in headers @pytest.mark.asyncio async def test_connect_with_callable_api_key(self, model, mock_websocket):