Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
import datetime
import json
import logging
import warnings
from collections import defaultdict
from typing import Any, cast

Expand Down Expand Up @@ -184,22 +185,32 @@ def store_blob_events_batch(
)

def get_events(
self, session_id: str, actor_id: str, limit: int = 100
self, session_id: str, actor_id: str, max_results: int | None = 100
) -> list[EventType]:
"""Retrieve events from AgentCore Memory."""
"""Retrieve events from AgentCore Memory.

if limit is not None and limit <= 0:
Args:
session_id: The session ID to retrieve events for
actor_id: The actor ID to retrieve events for
max_results: Maximum number of events to retrieve. Defaults to 100.

Returns:
List of retrieved events
"""

if max_results is not None and max_results <= 0:
return []

all_events = []
next_token = None
limit_reached = False

while True:
params = {
"memoryId": self.memory_id,
"actorId": actor_id,
"sessionId": session_id,
"maxResults": 100,
"maxResults": max_results,
"includePayloads": True,
}

Expand All @@ -218,8 +229,26 @@ def get_events(
except EventDecodingError as e:
logger.warning(f"Failed to decode event: {e}")

if max_results is not None and len(all_events) >= max_results:
limit_reached = True
break

if limit_reached:
break

next_token = response.get("nextToken")
if not next_token or (limit is not None and len(all_events) >= limit):

if limit_reached and next_token:
warnings.warn(
f"Stopped retrieving events at max_results of {max_results}. "
f"There may be additional checkpoints that were not retrieved. "
f"Consider increasing the max_results parameter "
"(defaults to 100). ",
UserWarning,
stacklevel=2,
)

if limit_reached or not next_token:
break

return all_events
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -49,26 +49,40 @@ class AgentCoreMemorySaver(BaseCheckpointSaver[str]):
Args:
memory_id: the ID of the memory resource created in AgentCore Memory
serde: serialization protocol to be used. Defaults to JSONPlusSerializer
max_results: maximum number of events to retrieve from AgentCore Memory.
Set to None for no limit. Defaults to 100
"""

def __init__(
    self,
    memory_id: str,
    *,
    serde: SerializerProtocol | None = None,
    max_results: int | None = 100,
    **boto3_kwargs: Any,
) -> None:
    """Initialize a checkpoint saver backed by Bedrock AgentCore Memory.

    Args:
        memory_id: ID of the memory resource created in AgentCore Memory.
        serde: Serialization protocol to use. Defaults to the base class's
            default (JSONPlusSerializer) when ``None``.
        max_results: Maximum number of events to retrieve from AgentCore
            Memory per lookup. ``None`` means no limit. Defaults to 100.
        **boto3_kwargs: Additional keyword arguments forwarded to
            ``AgentCoreEventClient`` (presumably boto3 client options —
            NOTE(review): confirm against the client's constructor).
    """
    super().__init__(serde=serde)

    # Plain configuration first, then the collaborators built from it.
    self.memory_id = memory_id
    self.max_results = max_results

    # EventSerializer wraps the checkpoint serde chosen by the base class.
    self.serializer = EventSerializer(self.serde)
    self.checkpoint_event_client = AgentCoreEventClient(
        memory_id,
        self.serializer,
        **boto3_kwargs,
    )
    self.processor = EventProcessor()

def get_tuple(self, config: RunnableConfig) -> CheckpointTuple | None:
"""Get a checkpoint tuple from Bedrock AgentCore Memory."""
def get_tuple(
self,
config: RunnableConfig,
) -> CheckpointTuple | None:
"""Get a checkpoint tuple from Bedrock AgentCore Memory.

Args:
config: The runnable config containing checkpoint information

Returns:
CheckpointTuple if found, None otherwise
"""

# TODO: There is room for caching here on the client side

Expand All @@ -77,7 +91,9 @@ def get_tuple(self, config: RunnableConfig) -> CheckpointTuple | None:
)

events = self.checkpoint_event_client.get_events(
checkpoint_config.session_id, checkpoint_config.actor_id
checkpoint_config.session_id,
checkpoint_config.actor_id,
self.max_results,
)

checkpoints, writes_by_checkpoint, channel_data = self.processor.process_events(
Expand Down Expand Up @@ -122,7 +138,7 @@ def list(
events = self.checkpoint_event_client.get_events(
checkpoint_config.session_id,
checkpoint_config.actor_id,
100 if limit is None else limit,
self.max_results,
)

checkpoints, writes_by_checkpoint, channel_data = self.processor.process_events(
Expand Down