99"""
1010
1111import contextlib
12+ import logging
1213from collections .abc import Callable
1314from datetime import timedelta
15+ from types import TracebackType
1416from typing import Any , TypeAlias
1517
18+ import anyio
1619from pydantic import BaseModel
20+ from typing_extensions import Self
1721
1822import mcp
1923from mcp import types
@@ -72,6 +76,14 @@ class ClientSessionGroup:
7276 For auxiliary handlers, such as resource subscription, this is delegated to
7377 the client and can be accessed via the session. For example:
7478 mcp_session_group.get_session("server_name").subscribe_to_resource(...)
79+
80+ Example Usage:
81+ name_fn = lambda name, server_info: f"{(server_info.name)}-{name}"
82+ with async ClientSessionGroup(component_name_hook=name_fn) as group:
83+ for server_params in server_params:
84+ group.connect_to_server(server_param)
85+ ...
86+
7587 """
7688
7789 class _ComponentNames (BaseModel ):
@@ -90,6 +102,7 @@ class _ComponentNames(BaseModel):
90102 _sessions : dict [mcp .ClientSession , _ComponentNames ]
91103 _tool_to_session : dict [str , mcp .ClientSession ]
92104 _exit_stack : contextlib .AsyncExitStack
105+ _session_exit_stacks : dict [mcp .ClientSession , contextlib .AsyncExitStack ]
93106
94107 # Optional fn consuming (component_name, serverInfo) for custom names.
95108 # This is provide a means to mitigate naming conflicts across servers.
@@ -99,7 +112,7 @@ class _ComponentNames(BaseModel):
99112
100113 def __init__ (
101114 self ,
102- exit_stack : contextlib .AsyncExitStack = contextlib . AsyncExitStack () ,
115+ exit_stack : contextlib .AsyncExitStack | None = None ,
103116 component_name_hook : _ComponentNameHook | None = None ,
104117 ) -> None :
105118 """Initializes the MCP client."""
@@ -110,9 +123,41 @@ def __init__(
110123
111124 self ._sessions = {}
112125 self ._tool_to_session = {}
113- self ._exit_stack = exit_stack
126+ self ._exit_stack = exit_stack or contextlib .AsyncExitStack ()
127+ self ._session_exit_stacks = {}
114128 self ._component_name_hook = component_name_hook
115129
130+ async def __aenter__ (self ) -> Self :
131+ # If ClientSessionGroup itself is managing the lifecycle of _exit_stack
132+ # (i.e., it created it), it should enter it.
133+ # If _exit_stack was passed in, it's assumed the caller manages
134+ # its entry/exit.
135+ # For simplicity and consistency with how AsyncExitStack is often used when
136+ # provided as a dependency, we might not need to enter it here if it's
137+ # managed externally. However, if this class is the primary owner, entering it
138+ # ensures its 'aclose' is called even if passed in. Let's assume the
139+ # passed-in stack is already entered by the caller if needed.
140+ # For now, we just return self as the main stack's lifecycle is tied to aclose.
141+ return self
142+
143+ async def __aexit__ (
144+ self ,
145+ _exc_type : type [BaseException ] | None ,
146+ _exc_val : BaseException | None ,
147+ _exc_tb : TracebackType | None ,
148+ ) -> bool | None :
149+ """Closes session exit stacks and main exit stack upon completion."""
150+ async with anyio .create_task_group () as tg :
151+ for exit_stack in self ._session_exit_stacks .values ():
152+ tg .start_soon (exit_stack .aclose )
153+ await self ._exit_stack .aclose ()
154+ return None
155+
156+ @property
157+ def sessions (self ) -> list [mcp .ClientSession ]:
158+ """Returns the list of sessions being managed."""
159+ return list (self ._sessions .keys ())
160+
116161 @property
117162 def prompts (self ) -> dict [str , types .Prompt ]:
118163 """Returns the prompts as a dictionary of names to prompts."""
@@ -131,33 +176,45 @@ def tools(self) -> dict[str, types.Tool]:
131176 async def call_tool (self , name : str , args : dict [str , Any ]) -> types .CallToolResult :
132177 """Executes a tool given its name and arguments."""
133178 session = self ._tool_to_session [name ]
134- return await session .call_tool (name , args )
179+ session_tool_name = self .tools [name ].name
180+ return await session .call_tool (session_tool_name , args )
135181
136- def disconnect_from_server (self , session : mcp .ClientSession ) -> None :
182+ async def disconnect_from_server (self , session : mcp .ClientSession ) -> None :
137183 """Disconnects from a single MCP server."""
138184
139- if session not in self ._sessions :
185+ session_known_for_components = session in self ._sessions
186+ session_known_for_stack = session in self ._session_exit_stacks
187+
188+ if not session_known_for_components and not session_known_for_stack :
140189 raise McpError (
141190 types .ErrorData (
142191 code = types .INVALID_PARAMS ,
143- message = "Provided session is not being managed." ,
192+ message = "Provided session is not managed or already disconnected ." ,
144193 )
145194 )
146- component_names = self ._sessions [session ]
147-
148- # Remove prompts associated with the session.
149- for name in component_names .prompts :
150- del self ._prompts [name ]
151195
152- # Remove resources associated with the session.
153- for name in component_names .resources :
154- del self ._resources [name ]
155-
156- # Remove tools associated with the session.
157- for name in component_names .tools :
158- del self ._tools [name ]
159-
160- del self ._sessions [session ]
196+ if session_known_for_components :
197+ component_names = self ._sessions .pop (session ) # Pop from _sessions tracking
198+
199+ # Remove prompts associated with the session.
200+ for name in component_names .prompts :
201+ if name in self ._prompts :
202+ del self ._prompts [name ]
203+ # Remove resources associated with the session.
204+ for name in component_names .resources :
205+ if name in self ._resources :
206+ del self ._resources [name ]
207+ # Remove tools associated with the session.
208+ for name in component_names .tools :
209+ if name in self ._tools :
210+ del self ._tools [name ]
211+ if name in self ._tool_to_session :
212+ del self ._tool_to_session [name ]
213+
214+ # Clean up the session's resources via its dedicated exit stack
215+ if session_known_for_stack :
216+ session_stack_to_close = self ._session_exit_stacks .pop (session )
217+ await session_stack_to_close .aclose ()
161218
162219 async def connect_to_server (
163220 self ,
@@ -181,47 +238,66 @@ async def connect_to_server(
181238 tool_to_session_temp : dict [str , mcp .ClientSession ] = {}
182239
183240 # Query the server for its prompts and aggregate to list.
184- prompts = (await session .list_prompts ()).prompts
185- for prompt in prompts :
186- name = self ._component_name (prompt .name , server_info )
187- if name in self ._prompts :
188- raise McpError (
189- types .ErrorData (
190- code = types .INVALID_PARAMS ,
191- message = f"{ name } already exists in group prompts." ,
192- )
193- )
194- prompts_temp [name ] = prompt
195- component_names .prompts .add (name )
241+ try :
242+ prompts = (await session .list_prompts ()).prompts
243+ for prompt in prompts :
244+ name = self ._component_name (prompt .name , server_info )
245+ prompts_temp [name ] = prompt
246+ component_names .prompts .add (name )
247+ except McpError as err :
248+ logging .warning (f"Could not fetch prompts: { err } " )
196249
197250 # Query the server for its resources and aggregate to list.
198- resources = (await session .list_resources ()).resources
199- for resource in resources :
200- name = self ._component_name (resource .name , server_info )
201- if name in self ._resources :
202- raise McpError (
203- types .ErrorData (
204- code = types .INVALID_PARAMS ,
205- message = f"{ name } already exists in group resources." ,
206- )
207- )
208- resources_temp [name ] = resource
209- component_names .resources .add (name )
251+ try :
252+ resources = (await session .list_resources ()).resources
253+ for resource in resources :
254+ name = self ._component_name (resource .name , server_info )
255+ resources_temp [name ] = resource
256+ component_names .resources .add (name )
257+ except McpError as err :
258+ logging .warning (f"Could not fetch resources: { err } " )
210259
211260 # Query the server for its tools and aggregate to list.
212- tools = (await session .list_tools ()).tools
213- for tool in tools :
214- name = self ._component_name (tool .name , server_info )
215- if name in self ._tools :
216- raise McpError (
217- types .ErrorData (
218- code = types .INVALID_PARAMS ,
219- message = f"{ name } already exists in group tools." ,
220- )
261+ try :
262+ tools = (await session .list_tools ()).tools
263+ for tool in tools :
264+ name = self ._component_name (tool .name , server_info )
265+ tools_temp [name ] = tool
266+ tool_to_session_temp [name ] = session
267+ component_names .tools .add (name )
268+ except McpError as err :
269+ logging .warning (f"Could not fetch tools: { err } " )
270+
271+ # Clean up exit stack for session if we couldn't retrieve anything
272+ # from the server.
273+ if not any ((prompts_temp , resources_temp , tools_temp )):
274+ del self ._session_exit_stacks [session ]
275+
276+ # Check for duplicates.
277+ matching_prompts = prompts_temp .keys () & self ._prompts .keys ()
278+ if matching_prompts :
279+ raise McpError (
280+ types .ErrorData (
281+ code = types .INVALID_PARAMS ,
282+ message = f"{ matching_prompts } already exist in group prompts." ,
221283 )
222- tools_temp [name ] = tool
223- tool_to_session_temp [name ] = session
224- component_names .tools .add (name )
284+ )
285+ matching_resources = resources_temp .keys () & self ._resources .keys ()
286+ if matching_resources :
287+ raise McpError (
288+ types .ErrorData (
289+ code = types .INVALID_PARAMS ,
290+ message = f"{ matching_resources } already exist in group resources." ,
291+ )
292+ )
293+ matching_tools = tools_temp .keys () & self ._tools .keys ()
294+ if matching_tools :
295+ raise McpError (
296+ types .ErrorData (
297+ code = types .INVALID_PARAMS ,
298+ message = f"{ matching_tools } already exist in group tools." ,
299+ )
300+ )
225301
226302 # Aggregate components.
227303 self ._sessions [session ] = component_names
@@ -237,33 +313,48 @@ async def _establish_session(
237313 ) -> tuple [types .Implementation , mcp .ClientSession ]:
238314 """Establish a client session to an MCP server."""
239315
240- # Create read and write streams that facilitate io with the server.
241- if isinstance (server_params , StdioServerParameters ):
242- client = mcp .stdio_client (server_params )
243- read , write = await self ._exit_stack .enter_async_context (client )
244- elif isinstance (server_params , SseServerParameters ):
245- client = sse_client (
246- url = server_params .url ,
247- headers = server_params .headers ,
248- timeout = server_params .timeout ,
249- sse_read_timeout = server_params .sse_read_timeout ,
250- )
251- read , write = await self ._exit_stack .enter_async_context (client )
252- else :
253- client = streamablehttp_client (
254- url = server_params .url ,
255- headers = server_params .headers ,
256- timeout = server_params .timeout ,
257- sse_read_timeout = server_params .sse_read_timeout ,
258- terminate_on_close = server_params .terminate_on_close ,
259- )
260- read , write , _ = await self ._exit_stack .enter_async_context (client )
316+ session_specific_stack = contextlib .AsyncExitStack ()
317+ try :
318+ # Create read and write streams that facilitate io with the server.
319+ if isinstance (server_params , StdioServerParameters ):
320+ client = mcp .stdio_client (server_params )
321+ read , write = await self ._exit_stack .enter_async_context (client )
322+ elif isinstance (server_params , SseServerParameters ):
323+ client = sse_client (
324+ url = server_params .url ,
325+ headers = server_params .headers ,
326+ timeout = server_params .timeout ,
327+ sse_read_timeout = server_params .sse_read_timeout ,
328+ )
329+ read , write = await self ._exit_stack .enter_async_context (client )
330+ else :
331+ client = streamablehttp_client (
332+ url = server_params .url ,
333+ headers = server_params .headers ,
334+ timeout = server_params .timeout ,
335+ sse_read_timeout = server_params .sse_read_timeout ,
336+ terminate_on_close = server_params .terminate_on_close ,
337+ )
338+ read , write , _ = await self ._exit_stack .enter_async_context (client )
261339
262- session = await self ._exit_stack .enter_async_context (
263- mcp .ClientSession (read , write )
264- )
265- result = await session .initialize ()
266- return result .serverInfo , session
340+ session = await self ._exit_stack .enter_async_context (
341+ mcp .ClientSession (read , write )
342+ )
343+ result = await session .initialize ()
344+
345+ # Session successfully initialized.
346+ # Store its stack and register the stack with the main group stack.
347+ self ._session_exit_stacks [session ] = session_specific_stack
348+ # session_specific_stack itself becomes a resource managed by the
349+ # main _exit_stack.
350+ await self ._exit_stack .enter_async_context (session_specific_stack )
351+
352+ return result .serverInfo , session
353+ except Exception :
354+ # If anything during this setup fails, ensure the session-specific
355+ # stack is closed.
356+ await session_specific_stack .aclose ()
357+ raise
267358
268359 def _component_name (self , name : str , server_info : types .Implementation ) -> str :
269360 if self ._component_name_hook :
0 commit comments