From 2f0a6c9ac370e10b9c941475eca0c910bbb66f13 Mon Sep 17 00:00:00 2001 From: Ali Hashemi Date: Mon, 18 Aug 2025 21:23:54 -0300 Subject: [PATCH] fix: handle missing client details and abort keep-alive task on drop - Added guard (AbortTaskOnDrop) to ensure keep-alive task is aborted when no longer needed - Fixed bug where client_info was mistakenly returning None --- .../src/mcp_runtimes/server_runtime.rs | 21 +++++++------------ crates/rust-mcp-sdk/src/utils.rs | 17 +++++++++++++++ 2 files changed, 25 insertions(+), 13 deletions(-) diff --git a/crates/rust-mcp-sdk/src/mcp_runtimes/server_runtime.rs b/crates/rust-mcp-sdk/src/mcp_runtimes/server_runtime.rs index d7b53a1..d787a10 100644 --- a/crates/rust-mcp-sdk/src/mcp_runtimes/server_runtime.rs +++ b/crates/rust-mcp-sdk/src/mcp_runtimes/server_runtime.rs @@ -10,6 +10,7 @@ use crate::schema::{ }, InitializeRequestParams, InitializeResult, RequestId, RpcError, }; +use crate::utils::AbortTaskOnDrop; use async_trait::async_trait; use futures::future::try_join_all; use futures::{StreamExt, TryFutureExt}; @@ -17,7 +18,7 @@ use futures::{StreamExt, TryFutureExt}; use rust_mcp_transport::SessionId; use rust_mcp_transport::{IoStream, TransportDispatcher}; use std::collections::HashMap; -use std::sync::{Arc, RwLock}; +use std::sync::Arc; use std::time::Duration; use tokio::io::AsyncWriteExt; use tokio::sync::{oneshot, watch}; @@ -41,8 +42,6 @@ pub struct ServerRuntime { handler: Arc, // Information about the server server_details: Arc, - // Details about the connected client - client_details: Arc>>, #[cfg(feature = "hyper-server")] session_id: Option, transport_map: tokio::sync::RwLock>, @@ -123,12 +122,7 @@ impl McpServer for ServerRuntime { /// Returns the client information if available, after successful initialization , otherwise returns None fn client_info(&self) -> Option { - if let Ok(details) = self.client_details.read() { - details.clone() - } else { - // Failed to acquire read lock, likely due to PoisonError from a thread panic. Returning None. - None - } + self.client_details_rx.borrow().clone() } /// Main runtime loop, processes incoming messages and handles requests @@ -404,6 +398,11 @@ impl ServerRuntime { .await? .abort_handle(); + // ensure keep_alive task will be aborted + let _abort_guard = AbortTaskOnDrop { + handle: abort_alive_task, + }; + // in case there is a payload, we consume it by transport to get processed if let Some(payload) = payload { transport.consume_string_payload(&payload).await?; @@ -439,13 +438,11 @@ impl ServerRuntime { } // close the stream after all messages are sent, unless it is a standalone stream if !stream_id.eq(DEFAULT_STREAM_ID){ - abort_alive_task.abort(); return Ok(()); } } _ = &mut disconnect_rx => { self.remove_transport(stream_id).await?; - abort_alive_task.abort(); // Disconnection detected by keep-alive task return Err(SdkError::connection_closed().into()); @@ -469,7 +466,6 @@ impl ServerRuntime { watch::channel::>(None); Self { server_details, - client_details: Arc::new(RwLock::new(None)), handler, session_id: Some(session_id), transport_map: tokio::sync::RwLock::new(HashMap::new()), @@ -495,7 +491,6 @@ impl ServerRuntime { watch::channel::>(None); Self { server_details: Arc::new(server_details), - client_details: Arc::new(RwLock::new(None)), handler, #[cfg(feature = "hyper-server")] session_id: None, diff --git a/crates/rust-mcp-sdk/src/utils.rs b/crates/rust-mcp-sdk/src/utils.rs index de92a06..e98a1ed 100644 --- a/crates/rust-mcp-sdk/src/utils.rs +++ b/crates/rust-mcp-sdk/src/utils.rs @@ -4,6 +4,23 @@ use crate::error::{McpSdkError, SdkResult}; use crate::schema::ProtocolVersion; use std::cmp::Ordering; +/// A guard type that automatically aborts a Tokio task when dropped. +/// +/// This ensures that the associated task does not outlive the scope +/// of this struct, preventing runaway or leaked background tasks. +/// +pub struct AbortTaskOnDrop { + /// The handle used to abort the spawned Tokio task. + pub handle: tokio::task::AbortHandle, +} + +impl Drop for AbortTaskOnDrop { + fn drop(&mut self) { + // Automatically abort the associated task when this guard is dropped. + self.handle.abort(); + } +} + /// Formats an assertion error message for unsupported capabilities. /// /// Constructs a string describing that a specific entity (e.g., server or client) lacks