41 changes: 40 additions & 1 deletion src/agents/lifecycle.py
@@ -1,8 +1,9 @@
-from typing import Any, Generic
+from typing import Any, Generic, Optional

from typing_extensions import TypeVar

from .agent import Agent, AgentBase
from .items import ModelResponse, TResponseInputItem
from .run_context import RunContextWrapper, TContext
from .tool import Tool

@@ -14,6 +15,25 @@ class RunHooksBase(Generic[TContext, TAgent]):
override the methods you need.
"""

async def on_llm_start(
self,
context: RunContextWrapper[TContext],
agent: Agent[TContext],
system_prompt: Optional[str],
input_items: list[TResponseInputItem],
) -> None:
"""Called just before invoking the LLM for this agent."""
pass

async def on_llm_end(
self,
context: RunContextWrapper[TContext],
agent: Agent[TContext],
response: ModelResponse,
) -> None:
"""Called immediately after the LLM call returns for this agent."""
pass

async def on_agent_start(self, context: RunContextWrapper[TContext], agent: TAgent) -> None:
"""Called before the agent is invoked. Called each time the current agent changes."""
pass
@@ -106,6 +126,25 @@ async def on_tool_end(
"""Called after a tool is invoked."""
pass

async def on_llm_start(
self,
context: RunContextWrapper[TContext],
agent: Agent[TContext],
system_prompt: Optional[str],
input_items: list[TResponseInputItem],
) -> None:
"""Called immediately before the agent issues an LLM call."""
pass

async def on_llm_end(
self,
context: RunContextWrapper[TContext],
agent: Agent[TContext],
response: ModelResponse,
) -> None:
"""Called immediately after the agent receives the LLM response."""
pass


RunHooks = RunHooksBase[TContext, Agent]
"""Run hooks when using `Agent`."""
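For readers skimming this diff, a minimal sketch of how a consumer might override the new agent-scoped hooks (the `LoggingHooks` class and its print output are illustrative, not part of this PR; the imports mirror those used in the tests below):

from typing import Any, Optional

from agents.agent import Agent
from agents.items import ModelResponse, TResponseInputItem
from agents.lifecycle import AgentHooks
from agents.run_context import RunContextWrapper


class LoggingHooks(AgentHooks):
    """Illustrative agent-scoped hooks; override only the methods you need."""

    async def on_llm_start(
        self,
        context: RunContextWrapper[Any],
        agent: Agent[Any],
        system_prompt: Optional[str],
        input_items: list[TResponseInputItem],
    ) -> None:
        # Fires once per model invocation, before the request is sent.
        print(f"[{agent.name}] calling LLM with {len(input_items)} input item(s)")

    async def on_llm_end(
        self,
        context: RunContextWrapper[Any],
        agent: Agent[Any],
        response: ModelResponse,
    ) -> None:
        # Fires once the model response for this call is available.
        print(f"[{agent.name}] LLM call finished")

Attach it via Agent(..., hooks=LoggingHooks()), exactly as the tests below do.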
22 changes: 22 additions & 0 deletions src/agents/run.py
@@ -935,6 +935,7 @@ async def _run_single_turn_streamed(
input = ItemHelpers.input_to_new_input_list(streamed_result.input)
input.extend([item.to_input_item() for item in streamed_result.new_items])

filtered = await cls._maybe_filter_model_input(
agent=agent,
run_config=run_config,
@@ -943,6 +944,12 @@
system_instructions=system_prompt,
)

        # Call the hook just before the model is invoked, with the filtered system prompt and input.
if agent.hooks:
await agent.hooks.on_llm_start(
context_wrapper, agent, filtered.instructions, filtered.input
)

# 1. Stream the output events
async for event in model.stream_response(
filtered.instructions,
@@ -979,6 +986,10 @@

streamed_result._event_queue.put_nowait(RawResponsesStreamEvent(data=event))

# Call hook just after the model response is finalized.
if agent.hooks and final_response is not None:
await agent.hooks.on_llm_end(context_wrapper, agent, final_response)

# 2. At this point, the streaming is complete for this turn of the agent loop.
if not final_response:
raise ModelBehaviorError("Model did not produce a final response!")
@@ -1252,6 +1263,14 @@ async def _get_new_response(
model = cls._get_model(agent, run_config)
model_settings = agent.model_settings.resolve(run_config.model_settings)
model_settings = RunImpl.maybe_reset_tool_choice(agent, tool_use_tracker, model_settings)
        # If the agent has hooks, call on_llm_start just before the model call.
if agent.hooks:
await agent.hooks.on_llm_start(
context_wrapper,
agent,
filtered.instructions, # Use filtered instructions
filtered.input, # Use filtered input
)

new_response = await model.get_response(
system_instructions=filtered.instructions,
@@ -1266,6 +1285,9 @@
previous_response_id=previous_response_id,
prompt=prompt_config,
)
        # If the agent has hooks, call on_llm_end right after the model call.
if agent.hooks:
await agent.hooks.on_llm_end(context_wrapper, agent, new_response)

context_wrapper.usage.add(new_response.usage)

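Since the two call sites above bracket each model invocation (on_llm_start after input filtering, on_llm_end once the response is in hand), a hook subclass can time individual LLM calls. A sketch under that assumption (`TimingHooks` is illustrative, not part of this PR):

import time
from typing import Any, Optional

from agents.agent import Agent
from agents.items import ModelResponse, TResponseInputItem
from agents.lifecycle import AgentHooks
from agents.run_context import RunContextWrapper


class TimingHooks(AgentHooks):
    """Illustrative: records the wall-clock duration of each LLM call."""

    def __init__(self) -> None:
        self._started_at: Optional[float] = None
        self.durations: list[float] = []

    async def on_llm_start(
        self,
        context: RunContextWrapper[Any],
        agent: Agent[Any],
        system_prompt: Optional[str],
        input_items: list[TResponseInputItem],
    ) -> None:
        # Mark the start of this model call.
        self._started_at = time.monotonic()

    async def on_llm_end(
        self,
        context: RunContextWrapper[Any],
        agent: Agent[Any],
        response: ModelResponse,
    ) -> None:
        # Pair with the matching on_llm_start and record the elapsed time.
        if self._started_at is not None:
            self.durations.append(time.monotonic() - self._started_at)
            self._started_at = None

Note that in the streamed path, on_llm_end fires only after the final response is assembled, so the recorded duration covers the whole stream.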
130 changes: 130 additions & 0 deletions tests/test_agent_llm_hooks.py
@@ -0,0 +1,130 @@
from collections import defaultdict
from typing import Any, Optional

import pytest

from agents.agent import Agent
from agents.items import ItemHelpers, ModelResponse, TResponseInputItem
from agents.lifecycle import AgentHooks
from agents.run import Runner
from agents.run_context import RunContextWrapper, TContext
from agents.tool import Tool

from .fake_model import FakeModel
from .test_responses import (
get_function_tool,
get_text_message,
)


class AgentHooksForTests(AgentHooks):
def __init__(self):
self.events: dict[str, int] = defaultdict(int)

def reset(self):
self.events.clear()

async def on_start(self, context: RunContextWrapper[TContext], agent: Agent[TContext]) -> None:
self.events["on_start"] += 1

async def on_end(
self, context: RunContextWrapper[TContext], agent: Agent[TContext], output: Any
) -> None:
self.events["on_end"] += 1

async def on_handoff(
self, context: RunContextWrapper[TContext], agent: Agent[TContext], source: Agent[TContext]
) -> None:
self.events["on_handoff"] += 1

async def on_tool_start(
self, context: RunContextWrapper[TContext], agent: Agent[TContext], tool: Tool
) -> None:
self.events["on_tool_start"] += 1

async def on_tool_end(
self,
context: RunContextWrapper[TContext],
agent: Agent[TContext],
tool: Tool,
result: str,
) -> None:
self.events["on_tool_end"] += 1

# NEW: LLM hooks
async def on_llm_start(
self,
context: RunContextWrapper[TContext],
agent: Agent[TContext],
system_prompt: Optional[str],
input_items: list[TResponseInputItem],
) -> None:
self.events["on_llm_start"] += 1

async def on_llm_end(
self,
context: RunContextWrapper[TContext],
agent: Agent[TContext],
response: ModelResponse,
) -> None:
self.events["on_llm_end"] += 1


# Example test using the above hooks:
@pytest.mark.asyncio
async def test_async_agent_hooks_with_llm():
hooks = AgentHooksForTests()
model = FakeModel()
agent = Agent(
name="A", model=model, tools=[get_function_tool("f", "res")], handoffs=[], hooks=hooks
)
# Simulate a single LLM call producing an output:
model.set_next_output([get_text_message("hello")])
await Runner.run(agent, input="hello")
# Expect one on_start, one on_llm_start, one on_llm_end, and one on_end
assert hooks.events == {"on_start": 1, "on_llm_start": 1, "on_llm_end": 1, "on_end": 1}


def test_sync_agent_hook_with_llm():
hooks = AgentHooksForTests()
model = FakeModel()
agent = Agent(
name="A", model=model, tools=[get_function_tool("f", "res")], handoffs=[], hooks=hooks
)
# Simulate a single LLM call producing an output:
model.set_next_output([get_text_message("hello")])
Runner.run_sync(agent, input="hello")
# Expect one on_start, one on_llm_start, one on_llm_end, and one on_end
assert hooks.events == {"on_start": 1, "on_llm_start": 1, "on_llm_end": 1, "on_end": 1}


@pytest.mark.asyncio
async def test_streamed_agent_hooks_with_llm():
hooks = AgentHooksForTests()
model = FakeModel()
agent = Agent(
name="A", model=model, tools=[get_function_tool("f", "res")], handoffs=[], hooks=hooks
)
# Simulate a single LLM call producing an output:
model.set_next_output([get_text_message("hello")])
stream = Runner.run_streamed(agent, input="hello")

async for event in stream.stream_events():
if event.type == "raw_response_event":
continue
if event.type == "agent_updated_stream_event":
print(f"[EVENT] agent_updated → {event.new_agent.name}")
elif event.type == "run_item_stream_event":
item = event.item
if item.type == "tool_call_item":
print("[EVENT] tool_call_item")
elif item.type == "tool_call_output_item":
print(f"[EVENT] tool_call_output_item → {item.output}")
elif item.type == "message_output_item":
text = ItemHelpers.text_message_output(item)
print(f"[EVENT] message_output_item → {text}")

# Expect one on_start, one on_llm_start, one on_llm_end, and one on_end
assert hooks.events == {"on_start": 1, "on_llm_start": 1, "on_llm_end": 1, "on_end": 1}
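
The tests above exercise only agent-scoped hooks. The new methods are also declared on RunHooksBase, but the call sites in this diff consult only agent.hooks, so whether run-scoped hooks fire depends on Runner wiring not shown in this excerpt. For completeness, a run-scoped subclass would look like the sketch below (assuming Runner.run accepts a hooks= keyword, as elsewhere in the SDK; MyRunHooks and its prints are illustrative):

import asyncio
from typing import Any, Optional

from agents.agent import Agent
from agents.items import ModelResponse, TResponseInputItem
from agents.lifecycle import RunHooks
from agents.run import Runner
from agents.run_context import RunContextWrapper


class MyRunHooks(RunHooks):
    """Illustrative run-scoped hooks, observing every agent in the run."""

    async def on_llm_start(
        self,
        context: RunContextWrapper[Any],
        agent: Agent[Any],
        system_prompt: Optional[str],
        input_items: list[TResponseInputItem],
    ) -> None:
        print(f"run-level: {agent.name} is about to call the model")

    async def on_llm_end(
        self,
        context: RunContextWrapper[Any],
        agent: Agent[Any],
        response: ModelResponse,
    ) -> None:
        print(f"run-level: {agent.name} received a model response")


async def main() -> None:
    # Uses the SDK's default model; pass model=... to override.
    agent = Agent(name="A", instructions="Reply briefly.")
    # hooks= is assumed here; whether these run-level LLM hooks fire depends
    # on wiring in Runner that this diff does not show.
    await Runner.run(agent, input="hello", hooks=MyRunHooks())


if __name__ == "__main__":
    asyncio.run(main())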