Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
76 changes: 67 additions & 9 deletions crates/rmcp/src/transport/sse_client.rs
Original file line number Diff line number Diff line change
Expand Up @@ -121,7 +121,10 @@ impl<C: SseClient> SseClientTransport<C> {

let mut sse_stream = client.get_stream(sse_endpoint.clone(), None, None).await?;
let message_endpoint = if let Some(endpoint) = config.use_message_endpoint.clone() {
endpoint.parse::<http::Uri>()?
let ep = endpoint.parse::<http::Uri>()?;
let mut sse_endpoint_parts = sse_endpoint.clone().into_parts();
sse_endpoint_parts.path_and_query = ep.into_parts().path_and_query;
Uri::from_parts(sse_endpoint_parts)?
} else {
// wait the endpoint event
loop {
Expand All @@ -132,17 +135,12 @@ impl<C: SseClient> SseClientTransport<C> {
let Some("endpoint") = sse.event.as_deref() else {
continue;
};
let sse_endpoint = sse.data.unwrap_or_default();
break sse_endpoint.parse::<http::Uri>()?;
let ep = sse.data.unwrap_or_default();

break message_endpoint(sse_endpoint.clone(), ep)?;
}
};

// sse: <authority><sse_pq> -> <authority><message_pq>
let message_endpoint = {
let mut sse_endpoint_parts = sse_endpoint.clone().into_parts();
sse_endpoint_parts.path_and_query = message_endpoint.into_parts().path_and_query;
Uri::from_parts(sse_endpoint_parts)?
};
let stream = Box::pin(SseAutoReconnectStream::new(
sse_stream,
SseClientReconnect {
Expand All @@ -160,6 +158,36 @@ impl<C: SseClient> SseClientTransport<C> {
}
}

fn message_endpoint(base: http::Uri, endpoint: String) -> Result<http::Uri, http::uri::InvalidUri> {
// If endpoint is a full URL, parse and return it directly
if endpoint.starts_with("http://") || endpoint.starts_with("https://") {
return endpoint.parse::<http::Uri>();
}

let mut base_parts = base.into_parts();
let endpoint_clone = endpoint.clone();

if endpoint.starts_with("?") {
// Query only - keep base path and append query
if let Some(base_path_and_query) = &base_parts.path_and_query {
let base_path = base_path_and_query.path();
base_parts.path_and_query = Some(format!("{}{}", base_path, endpoint).parse()?);
} else {
base_parts.path_and_query = Some(format!("/{}", endpoint).parse()?);
}
} else {
// Path (with optional query) - replace entire path_and_query
let path_to_use = if endpoint.starts_with("/") {
endpoint // Use absolute path as-is
} else {
format!("/{}", endpoint) // Make relative path absolute
};
base_parts.path_and_query = Some(path_to_use.parse()?);
}

http::Uri::from_parts(base_parts).map_err(|_| endpoint_clone.parse::<http::Uri>().unwrap_err())
}

#[derive(Debug, Clone)]
pub struct SseClientConfig {
/// client sse endpoint
Expand Down Expand Up @@ -188,3 +216,33 @@ impl Default for SseClientConfig {
}
}
}

#[cfg(test)]
mod tests {
use super::*;

#[test]
fn test_message_endpoint() {
let base_url = "https://localhost/sse".parse::<http::Uri>().unwrap();

// Query only
let result = message_endpoint(base_url.clone(), "?sessionId=x".to_string()).unwrap();
assert_eq!(result.to_string(), "https://localhost/sse?sessionId=x");

// Relative path with query
let result = message_endpoint(base_url.clone(), "mypath?sessionId=x".to_string()).unwrap();
assert_eq!(result.to_string(), "https://localhost/mypath?sessionId=x");

// Absolute path with query
let result = message_endpoint(base_url.clone(), "/xxx?sessionId=x".to_string()).unwrap();
assert_eq!(result.to_string(), "https://localhost/xxx?sessionId=x");

// Full URL
let result = message_endpoint(
base_url.clone(),
"http://example.com/xxx?sessionId=x".to_string(),
)
.unwrap();
assert_eq!(result.to_string(), "http://example.com/xxx?sessionId=x");
}
}
Loading