diff --git a/crates/rmcp/src/transport/sse_client.rs b/crates/rmcp/src/transport/sse_client.rs index f9d0e434..7b6f280c 100644 --- a/crates/rmcp/src/transport/sse_client.rs +++ b/crates/rmcp/src/transport/sse_client.rs @@ -121,7 +121,10 @@ impl SseClientTransport { let mut sse_stream = client.get_stream(sse_endpoint.clone(), None, None).await?; let message_endpoint = if let Some(endpoint) = config.use_message_endpoint.clone() { - endpoint.parse::()? + let ep = endpoint.parse::()?; + let mut sse_endpoint_parts = sse_endpoint.clone().into_parts(); + sse_endpoint_parts.path_and_query = ep.into_parts().path_and_query; + Uri::from_parts(sse_endpoint_parts)? } else { // wait the endpoint event loop { @@ -132,17 +135,12 @@ impl SseClientTransport { let Some("endpoint") = sse.event.as_deref() else { continue; }; - let sse_endpoint = sse.data.unwrap_or_default(); - break sse_endpoint.parse::()?; + let ep = sse.data.unwrap_or_default(); + + break message_endpoint(sse_endpoint.clone(), ep)?; } }; - // sse: -> - let message_endpoint = { - let mut sse_endpoint_parts = sse_endpoint.clone().into_parts(); - sse_endpoint_parts.path_and_query = message_endpoint.into_parts().path_and_query; - Uri::from_parts(sse_endpoint_parts)? - }; let stream = Box::pin(SseAutoReconnectStream::new( sse_stream, SseClientReconnect { @@ -160,6 +158,36 @@ impl SseClientTransport { } } +fn message_endpoint(base: http::Uri, endpoint: String) -> Result { + // If endpoint is a full URL, parse and return it directly + if endpoint.starts_with("http://") || endpoint.starts_with("https://") { + return endpoint.parse::(); + } + + let mut base_parts = base.into_parts(); + let endpoint_clone = endpoint.clone(); + + if endpoint.starts_with("?") { + // Query only - keep base path and append query + if let Some(base_path_and_query) = &base_parts.path_and_query { + let base_path = base_path_and_query.path(); + base_parts.path_and_query = Some(format!("{}{}", base_path, endpoint).parse()?); + } else { + base_parts.path_and_query = Some(format!("/{}", endpoint).parse()?); + } + } else { + // Path (with optional query) - replace entire path_and_query + let path_to_use = if endpoint.starts_with("/") { + endpoint // Use absolute path as-is + } else { + format!("/{}", endpoint) // Make relative path absolute + }; + base_parts.path_and_query = Some(path_to_use.parse()?); + } + + http::Uri::from_parts(base_parts).map_err(|_| endpoint_clone.parse::().unwrap_err()) +} + #[derive(Debug, Clone)] pub struct SseClientConfig { /// client sse endpoint @@ -188,3 +216,33 @@ impl Default for SseClientConfig { } } } + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_message_endpoint() { + let base_url = "https://localhost/sse".parse::().unwrap(); + + // Query only + let result = message_endpoint(base_url.clone(), "?sessionId=x".to_string()).unwrap(); + assert_eq!(result.to_string(), "https://localhost/sse?sessionId=x"); + + // Relative path with query + let result = message_endpoint(base_url.clone(), "mypath?sessionId=x".to_string()).unwrap(); + assert_eq!(result.to_string(), "https://localhost/mypath?sessionId=x"); + + // Absolute path with query + let result = message_endpoint(base_url.clone(), "/xxx?sessionId=x".to_string()).unwrap(); + assert_eq!(result.to_string(), "https://localhost/xxx?sessionId=x"); + + // Full URL + let result = message_endpoint( + base_url.clone(), + "http://example.com/xxx?sessionId=x".to_string(), + ) + .unwrap(); + assert_eq!(result.to_string(), "http://example.com/xxx?sessionId=x"); + } +}