Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
21 changes: 8 additions & 13 deletions crates/rust-mcp-sdk/src/mcp_runtimes/server_runtime.rs
Original file line number Diff line number Diff line change
Expand Up @@ -10,14 +10,15 @@ use crate::schema::{
},
InitializeRequestParams, InitializeResult, RequestId, RpcError,
};
use crate::utils::AbortTaskOnDrop;
use async_trait::async_trait;
use futures::future::try_join_all;
use futures::{StreamExt, TryFutureExt};
#[cfg(feature = "hyper-server")]
use rust_mcp_transport::SessionId;
use rust_mcp_transport::{IoStream, TransportDispatcher};
use std::collections::HashMap;
use std::sync::{Arc, RwLock};
use std::sync::Arc;
use std::time::Duration;
use tokio::io::AsyncWriteExt;
use tokio::sync::{oneshot, watch};
Expand All @@ -41,8 +42,6 @@ pub struct ServerRuntime {
handler: Arc<dyn McpServerHandler>,
// Information about the server
server_details: Arc<InitializeResult>,
// Details about the connected client
client_details: Arc<RwLock<Option<InitializeRequestParams>>>,
#[cfg(feature = "hyper-server")]
session_id: Option<SessionId>,
transport_map: tokio::sync::RwLock<HashMap<String, TransportType>>,
Expand Down Expand Up @@ -123,12 +122,7 @@ impl McpServer for ServerRuntime {

/// Returns the client information if available, after successful initialization , otherwise returns None
fn client_info(&self) -> Option<InitializeRequestParams> {
if let Ok(details) = self.client_details.read() {
details.clone()
} else {
// Failed to acquire read lock, likely due to PoisonError from a thread panic. Returning None.
None
}
self.client_details_rx.borrow().clone()
}

/// Main runtime loop, processes incoming messages and handles requests
Expand Down Expand Up @@ -404,6 +398,11 @@ impl ServerRuntime {
.await?
.abort_handle();

// ensure keep_alive task will be aborted
let _abort_guard = AbortTaskOnDrop {
handle: abort_alive_task,
};

// in case there is a payload, we consume it by transport to get processed
if let Some(payload) = payload {
transport.consume_string_payload(&payload).await?;
Expand Down Expand Up @@ -439,13 +438,11 @@ impl ServerRuntime {
}
// close the stream after all messages are sent, unless it is a standalone stream
if !stream_id.eq(DEFAULT_STREAM_ID){
abort_alive_task.abort();
return Ok(());
}
}
_ = &mut disconnect_rx => {
self.remove_transport(stream_id).await?;
abort_alive_task.abort();
// Disconnection detected by keep-alive task
return Err(SdkError::connection_closed().into());

Expand All @@ -469,7 +466,6 @@ impl ServerRuntime {
watch::channel::<Option<InitializeRequestParams>>(None);
Self {
server_details,
client_details: Arc::new(RwLock::new(None)),
handler,
session_id: Some(session_id),
transport_map: tokio::sync::RwLock::new(HashMap::new()),
Expand All @@ -495,7 +491,6 @@ impl ServerRuntime {
watch::channel::<Option<InitializeRequestParams>>(None);
Self {
server_details: Arc::new(server_details),
client_details: Arc::new(RwLock::new(None)),
handler,
#[cfg(feature = "hyper-server")]
session_id: None,
Expand Down
17 changes: 17 additions & 0 deletions crates/rust-mcp-sdk/src/utils.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,23 @@ use crate::error::{McpSdkError, SdkResult};
use crate::schema::ProtocolVersion;
use std::cmp::Ordering;

/// A guard type that automatically aborts a Tokio task when dropped.
///
/// This ensures that the associated task does not outlive the scope
/// of this struct, preventing runaway or leaked background tasks.
///
pub struct AbortTaskOnDrop {
/// The handle used to abort the spawned Tokio task.
pub handle: tokio::task::AbortHandle,
}

impl Drop for AbortTaskOnDrop {
fn drop(&mut self) {
// Automatically abort the associated task when this guard is dropped.
self.handle.abort();
}
}

/// Formats an assertion error message for unsupported capabilities.
///
/// Constructs a string describing that a specific entity (e.g., server or client) lacks
Expand Down
Loading