|  | 
|  | 1 | +"""Module for AWS Bedrock Agent Core memory integration. | 
|  | 2 | +
 | 
|  | 3 | +This module provides integration between LangChain/LangGraph and AWS Bedrock Agent Core | 
|  | 4 | +memory API. It includes a memory store implementation and tools for managing and | 
|  | 5 | +searching memories. | 
|  | 6 | +""" | 
|  | 7 | + | 
|  | 8 | +import json | 
|  | 9 | +import logging | 
|  | 10 | +from typing import List | 
|  | 11 | + | 
|  | 12 | +from bedrock_agentcore.memory import MemoryClient | 
|  | 13 | +from langchain_core.messages import AIMessage, BaseMessage, HumanMessage | 
|  | 14 | +from langchain_core.runnables import RunnableConfig | 
|  | 15 | +from langchain_core.tools import StructuredTool | 
|  | 16 | + | 
|  | 17 | +logger = logging.getLogger(__name__) | 
|  | 18 | + | 
|  | 19 | + | 
|  | 20 | +def create_store_messages_tool( | 
|  | 21 | +    memory_client: MemoryClient, | 
|  | 22 | +    name: str = "store_messages" | 
|  | 23 | +) -> StructuredTool: | 
|  | 24 | +    """Create a tool for storing messages directly with Bedrock Agent Core MemoryClient. | 
|  | 25 | +
 | 
|  | 26 | +    This tool enables AI assistants to store messages in Bedrock Agent Core. | 
|  | 27 | +    The tool expects the following configuration values to be passed via RunnableConfig: | 
|  | 28 | +    - memory_id: The ID of the memory to store in | 
|  | 29 | +    - actor_id: (optional) The actor ID to use | 
|  | 30 | +    - session_id: (optional) The session ID to use | 
|  | 31 | +
 | 
|  | 32 | +    Args: | 
|  | 33 | +        memory_client: The MemoryClient instance to use | 
|  | 34 | +        name: The name of the tool | 
|  | 35 | +
 | 
|  | 36 | +    Returns: | 
|  | 37 | +        A structured tool for storing messages | 
|  | 38 | +    """ | 
|  | 39 | + | 
|  | 40 | +    instructions = ( | 
|  | 41 | +        "Use this tool to store all messages from the user and AI model. These " | 
|  | 42 | +        "messages are processed to extract summary or facts of the conversation, " | 
|  | 43 | +        "which can be later retrieved using the search_memory tool." | 
|  | 44 | +    ) | 
|  | 45 | + | 
|  | 46 | +    def store_messages( | 
|  | 47 | +        messages: List[BaseMessage], | 
|  | 48 | +        config: RunnableConfig, | 
|  | 49 | +    ) -> str: | 
|  | 50 | +        """Stores conversation messages in AWS Bedrock Agent Core Memory. | 
|  | 51 | +
 | 
|  | 52 | +        Args: | 
|  | 53 | +            messages: List of messages to store | 
|  | 54 | +
 | 
|  | 55 | +        Returns: | 
|  | 56 | +            A confirmation message. | 
|  | 57 | +        """ | 
|  | 58 | +        if not (configurable := config.get("configurable", None)): | 
|  | 59 | +            raise ValueError( | 
|  | 60 | +                "A runtime config containing memory_id, actor_id, and session_id is required." | 
|  | 61 | +            ) | 
|  | 62 | +         | 
|  | 63 | +        if not (memory_id := configurable.get("memory_id", None)): | 
|  | 64 | +            raise ValueError( | 
|  | 65 | +                "Missing memory_id in the runtime config." | 
|  | 66 | +            ) | 
|  | 67 | +         | 
|  | 68 | +        if not (session_id := configurable.get("session_id", None)): | 
|  | 69 | +            raise ValueError( | 
|  | 70 | +                "Missing session_id in the runtime config." | 
|  | 71 | +            ) | 
|  | 72 | +         | 
|  | 73 | +        if not (actor_id := configurable.get("actor_id", None)): | 
|  | 74 | +            raise ValueError( | 
|  | 75 | +                "Missing actor_id in the runtime config." | 
|  | 76 | +            ) | 
|  | 77 | +             | 
|  | 78 | +        # Convert BaseMessage list to list of (text, role) tuples | 
|  | 79 | +        # TODO: This should correctly convert to  | 
|  | 80 | +        converted_messages = [] | 
|  | 81 | +        for msg in messages: | 
|  | 82 | +             | 
|  | 83 | +            # Skip if event already saved | 
|  | 84 | +            if msg.additional_kwargs.get("event_id", None) is not None: | 
|  | 85 | +                continue | 
|  | 86 | + | 
|  | 87 | +            # Extract text content | 
|  | 88 | +            content = msg.content | 
|  | 89 | +            if isinstance(content, str): | 
|  | 90 | +                text = content | 
|  | 91 | +            elif isinstance(content, dict) and content['type'] == 'text': | 
|  | 92 | +                text = content['text'] | 
|  | 93 | +            else: | 
|  | 94 | +                continue | 
|  | 95 | +             | 
|  | 96 | +            # Map LangChain roles to Bedrock Agent Core roles | 
|  | 97 | +            # Available roles in Bedrock: USER, ASSISTANT, TOOL | 
|  | 98 | +            if msg.type == "human": | 
|  | 99 | +                role = "USER" | 
|  | 100 | +            elif msg.type == "ai": | 
|  | 101 | +                role = "ASSISTANT" | 
|  | 102 | +            elif msg.type == "tool": | 
|  | 103 | +                role = "TOOL" | 
|  | 104 | +            else: | 
|  | 105 | +                continue  # Skip unsupported message types | 
|  | 106 | +             | 
|  | 107 | +            converted_messages.append((text, role)) | 
|  | 108 | +         | 
|  | 109 | +        # Create event with converted messages directly using the MemoryClient | 
|  | 110 | +        response = memory_client.create_event( | 
|  | 111 | +            memory_id=memory_id, | 
|  | 112 | +            actor_id=actor_id, | 
|  | 113 | +            session_id=session_id, | 
|  | 114 | +            messages=converted_messages | 
|  | 115 | +        ) | 
|  | 116 | +         | 
|  | 117 | +        return f"Memory created with ID: {response.get('eventId')}" | 
|  | 118 | + | 
|  | 119 | +    # Create a StructuredTool with the custom name | 
|  | 120 | +    return StructuredTool.from_function( | 
|  | 121 | +        func=store_messages, name=name, description=instructions | 
|  | 122 | +    ) | 
|  | 123 | + | 
|  | 124 | + | 
|  | 125 | +def create_list_messages_tool( | 
|  | 126 | +    memory_client: MemoryClient, | 
|  | 127 | +    name: str = "list_messages", | 
|  | 128 | +) -> StructuredTool: | 
|  | 129 | +    """Create a tool for listing conversation messages from Bedrock Agent Core Memory. | 
|  | 130 | +
 | 
|  | 131 | +    This tool allows AI assistants to retrieve the message history from a conversation | 
|  | 132 | +    stored in Bedrock Agent Core Memory. | 
|  | 133 | +     | 
|  | 134 | +    The tool expects the following configuration values to be passed via RunnableConfig: | 
|  | 135 | +    - memory_id: The ID of the memory to retrieve from (required) | 
|  | 136 | +    - actor_id: The actor ID to use (required) | 
|  | 137 | +    - session_id: The session ID to use (required) | 
|  | 138 | +
 | 
|  | 139 | +    Args: | 
|  | 140 | +        memory_client: The MemoryClient instance to use | 
|  | 141 | +        name: The name of the tool | 
|  | 142 | +
 | 
|  | 143 | +    Returns: | 
|  | 144 | +        A structured tool for listing conversation messages | 
|  | 145 | +    """ | 
|  | 146 | + | 
|  | 147 | +    instructions = ( | 
|  | 148 | +        "Use this tool to retrieve the conversation history from memory. " | 
|  | 149 | +        "This can help in understanding the context of the current conversation, " | 
|  | 150 | +        "or reviewing past interactions." | 
|  | 151 | +    ) | 
|  | 152 | + | 
|  | 153 | +    def list_messages( | 
|  | 154 | +        max_results: int = 100, | 
|  | 155 | +        config: RunnableConfig = None, | 
|  | 156 | +    ) -> List[BaseMessage]: | 
|  | 157 | +        """List conversation messages from AWS Bedrock Agent Core Memory. | 
|  | 158 | +
 | 
|  | 159 | +        Args: | 
|  | 160 | +            max_results: Maximum number of messages to return | 
|  | 161 | +            config: RunnableConfig containing memory_id, actor_id, and session_id | 
|  | 162 | +
 | 
|  | 163 | +        Returns: | 
|  | 164 | +            A list of LangChain message objects (HumanMessage, AIMessage, ToolMessage) | 
|  | 165 | +        """ | 
|  | 166 | +        if not (configurable := config.get("configurable", None)): | 
|  | 167 | +            raise ValueError( | 
|  | 168 | +                "A runtime config with memory_id, actor_id, and session_id is required" | 
|  | 169 | +                " for list_messages tool." | 
|  | 170 | +            ) | 
|  | 171 | +         | 
|  | 172 | +        if not (memory_id := configurable.get("memory_id", None)): | 
|  | 173 | +            raise ValueError( | 
|  | 174 | +                "Missing memory_id in the runtime config." | 
|  | 175 | +            ) | 
|  | 176 | +             | 
|  | 177 | +        if not (actor_id := configurable.get("actor_id", None)): | 
|  | 178 | +            raise ValueError( | 
|  | 179 | +                "Missing actor_id in the runtime config." | 
|  | 180 | +            ) | 
|  | 181 | +             | 
|  | 182 | +        if not (session_id := configurable.get("session_id", None)): | 
|  | 183 | +            raise ValueError( | 
|  | 184 | +                "Missing session_id in the runtime config." | 
|  | 185 | +            ) | 
|  | 186 | +         | 
|  | 187 | +        events = memory_client.list_events( | 
|  | 188 | +            memory_id=memory_id, | 
|  | 189 | +            actor_id=actor_id, | 
|  | 190 | +            session_id=session_id, | 
|  | 191 | +            max_results=max_results, | 
|  | 192 | +            include_payload=True | 
|  | 193 | +        ) | 
|  | 194 | +         | 
|  | 195 | +        # Extract and format messages as LangChain message objects | 
|  | 196 | +        messages = [] | 
|  | 197 | +        for event in events: | 
|  | 198 | +            # Extract messages from event payload | 
|  | 199 | +            if "payload" in event: | 
|  | 200 | +                for payload_item in event.get("payload", []): | 
|  | 201 | +                    if "conversational" in payload_item: | 
|  | 202 | +                        conv = payload_item["conversational"] | 
|  | 203 | +                        role = conv.get("role", "") | 
|  | 204 | +                        content = conv.get("content", {}).get("text", "") | 
|  | 205 | +                         | 
|  | 206 | +                        # Convert to appropriate LangChain message type based on role | 
|  | 207 | +                        if role == "USER": | 
|  | 208 | +                            message = HumanMessage(content=content) | 
|  | 209 | +                        elif role == "ASSISTANT": | 
|  | 210 | +                            message = AIMessage(content=content) | 
|  | 211 | +                        elif role == "TOOL": | 
|  | 212 | +                            #message = ToolMessage(content=content, tool_call_id="unknown") | 
|  | 213 | +                            # skipping tool events as tool_call_id etc. will be missing | 
|  | 214 | +                            continue | 
|  | 215 | +                        else: | 
|  | 216 | +                            # Skip unknown message types | 
|  | 217 | +                            continue | 
|  | 218 | +                             | 
|  | 219 | +                        # Add metadata if available | 
|  | 220 | +                        if "eventId" in event: | 
|  | 221 | +                            message.additional_kwargs["event_id"] = event["eventId"] | 
|  | 222 | +                        if "eventTimestamp" in event: | 
|  | 223 | +                            pass | 
|  | 224 | +                            # Skip this, this currently not serialized correctly | 
|  | 225 | +                            # message.additional_kwargs["timestamp"] = event["eventTimestamp"] | 
|  | 226 | +                             | 
|  | 227 | +                        messages.append(message) | 
|  | 228 | +         | 
|  | 229 | +        return messages | 
|  | 230 | + | 
|  | 231 | +    # Create a StructuredTool with the custom name | 
|  | 232 | +    return StructuredTool.from_function( | 
|  | 233 | +        func=list_messages, name=name, description=instructions | 
|  | 234 | +    ) | 
|  | 235 | + | 
|  | 236 | + | 
|  | 237 | +def create_search_memory_tool( | 
|  | 238 | +    memory_client: MemoryClient, | 
|  | 239 | +    name: str = "search_memory", | 
|  | 240 | +) -> StructuredTool: | 
|  | 241 | +    """Create a tool for searching memories in AWS Bedrock Agent Core. | 
|  | 242 | +
 | 
|  | 243 | +    This tool allows AI assistants to search through stored memories in AWS | 
|  | 244 | +    Bedrock Agent Core using semantic search. | 
|  | 245 | +     | 
|  | 246 | +    The tool expects the following configuration values to be passed via RunnableConfig: | 
|  | 247 | +    - memory_id: The ID of the memory to search in (required) | 
|  | 248 | +    - namespace: The namespace to search in (required) | 
|  | 249 | +
 | 
|  | 250 | +    Args: | 
|  | 251 | +        memory_client: The MemoryClient instance to use | 
|  | 252 | +        name: The name of the tool | 
|  | 253 | +
 | 
|  | 254 | +    Returns: | 
|  | 255 | +        A structured tool for searching memories. | 
|  | 256 | +    """ | 
|  | 257 | + | 
|  | 258 | +    instructions = ( | 
|  | 259 | +        "Use this tool to search for helpful facts and preferences from the past " | 
|  | 260 | +        "conversations. Based on the namespace and configured memories, this will " | 
|  | 261 | +        "provide summaries, user preferences or semantic search for the session." | 
|  | 262 | +    ) | 
|  | 263 | + | 
|  | 264 | +    def search_memory( | 
|  | 265 | +        query: str, | 
|  | 266 | +        limit: int = 3, | 
|  | 267 | +        config: RunnableConfig = None, | 
|  | 268 | +    ) -> str: | 
|  | 269 | +        """Search for memories in AWS Bedrock Agent Core. | 
|  | 270 | +
 | 
|  | 271 | +        Args: | 
|  | 272 | +            query: The search query to find relevant memories. | 
|  | 273 | +            limit: Maximum number of results to return. | 
|  | 274 | +
 | 
|  | 275 | +        Returns: | 
|  | 276 | +            A string representation of the search results. | 
|  | 277 | +        """ | 
|  | 278 | +        if not (configurable := config.get("configurable", None)): | 
|  | 279 | +            raise ValueError( | 
|  | 280 | +                "A runtime config with memory_id, namespace, and actor_id is required." | 
|  | 281 | +            ) | 
|  | 282 | +         | 
|  | 283 | +        if not (memory_id := configurable.get("memory_id", None)): | 
|  | 284 | +            raise ValueError( | 
|  | 285 | +                "Missing memory_id in the runtime config." | 
|  | 286 | +            ) | 
|  | 287 | +             | 
|  | 288 | +        # Namespace is required | 
|  | 289 | +        if not (namespace_val := configurable.get("namespace", None)): | 
|  | 290 | +            raise ValueError( | 
|  | 291 | +                "Missing namespace in the runtime config." | 
|  | 292 | +            ) | 
|  | 293 | +             | 
|  | 294 | +        # Format the namespace | 
|  | 295 | +        if isinstance(namespace_val, tuple): | 
|  | 296 | +            # Join tuple elements with '/' | 
|  | 297 | +            namespace_str = "/" + "/".join(namespace_val) | 
|  | 298 | +        elif isinstance(namespace_val, str): | 
|  | 299 | +            # Ensure string starts with '/' | 
|  | 300 | +            namespace_str = namespace_val if namespace_val.startswith("/") else f"/{namespace_val}" | 
|  | 301 | +        else: | 
|  | 302 | +            raise ValueError( | 
|  | 303 | +                f"Namespace must be a string or tuple, got {type(namespace_val)}" | 
|  | 304 | +            ) | 
|  | 305 | +                 | 
|  | 306 | +        # Perform the search directly using the MemoryClient | 
|  | 307 | +        memories = memory_client.retrieve_memories( | 
|  | 308 | +            memory_id=memory_id, | 
|  | 309 | +            namespace=namespace_str, | 
|  | 310 | +            query=query, | 
|  | 311 | +            top_k=limit, | 
|  | 312 | +        ) | 
|  | 313 | + | 
|  | 314 | +        # Process and format results | 
|  | 315 | +        results = [] | 
|  | 316 | +        for item in memories: | 
|  | 317 | +            # Extract content from the memory item | 
|  | 318 | +            content = item.get("content", {}).get("text", "") | 
|  | 319 | + | 
|  | 320 | +            # Try to parse JSON content if it looks like JSON | 
|  | 321 | +            if content and content.startswith("{") and content.endswith("}"): | 
|  | 322 | +                try: | 
|  | 323 | +                    content = json.loads(content) | 
|  | 324 | +                except json.JSONDecodeError: | 
|  | 325 | +                    pass | 
|  | 326 | + | 
|  | 327 | +            results.append(content) | 
|  | 328 | + | 
|  | 329 | +        return results | 
|  | 330 | +         | 
|  | 331 | + | 
|  | 332 | +    # Create a StructuredTool with the custom name | 
|  | 333 | +    return StructuredTool.from_function( | 
|  | 334 | +        func=search_memory, | 
|  | 335 | +        name=name, | 
|  | 336 | +        description=instructions | 
|  | 337 | +    ) | 
0 commit comments