diff --git a/src/ModelContextProtocol.AspNetCore/StreamableHttpHandler.cs b/src/ModelContextProtocol.AspNetCore/StreamableHttpHandler.cs index 70976d44f..9f4af7ea5 100644 --- a/src/ModelContextProtocol.AspNetCore/StreamableHttpHandler.cs +++ b/src/ModelContextProtocol.AspNetCore/StreamableHttpHandler.cs @@ -1,6 +1,7 @@ using Microsoft.AspNetCore.Http; using Microsoft.AspNetCore.Http.Features; using Microsoft.AspNetCore.WebUtilities; +using Microsoft.Extensions.Hosting; using Microsoft.Extensions.Logging; using Microsoft.Extensions.Options; using Microsoft.Net.Http.Headers; @@ -17,8 +18,9 @@ internal sealed class StreamableHttpHandler( IOptionsFactory mcpServerOptionsFactory, IOptions httpServerTransportOptions, StatefulSessionManager sessionManager, - ILoggerFactory loggerFactory, - IServiceProvider applicationServices) + IHostApplicationLifetime hostApplicationLifetime, + IServiceProvider applicationServices, + ILoggerFactory loggerFactory) { private const string McpSessionIdHeaderName = "Mcp-Session-Id"; @@ -60,7 +62,7 @@ await WriteJsonRpcErrorAsync(context, } InitializeSseResponse(context); - var wroteResponse = await session.Transport.HandlePostRequest(message, context.Response.Body, context.RequestAborted); + var wroteResponse = await session.Transport.HandlePostRequestAsync(message, context.Response.Body, context.RequestAborted); if (!wroteResponse) { // We wound up writing nothing, so there should be no Content-Type response header. @@ -94,14 +96,28 @@ await WriteJsonRpcErrorAsync(context, return; } - await using var _ = await session.AcquireReferenceAsync(context.RequestAborted); - InitializeSseResponse(context); + // Link the GET request to both RequestAborted and ApplicationStopping. + // The GET request should complete immediately during graceful shutdown without waiting for + // in-flight POST requests to complete. This prevents slow shutdown when clients are still connected. + using var sseCts = CancellationTokenSource.CreateLinkedTokenSource(context.RequestAborted, hostApplicationLifetime.ApplicationStopping); + var cancellationToken = sseCts.Token; - // We should flush headers to indicate a 200 success quickly, because the initialization response - // will be sent in response to a different POST request. It might be a while before we send a message - // over this response body. - await context.Response.Body.FlushAsync(context.RequestAborted); - await session.Transport.HandleGetRequest(context.Response.Body, context.RequestAborted); + try + { + await using var _ = await session.AcquireReferenceAsync(cancellationToken); + InitializeSseResponse(context); + + // We should flush headers to indicate a 200 success quickly, because the initialization response + // will be sent in response to a different POST request. It might be a while before we send a message + // over this response body. + await context.Response.Body.FlushAsync(cancellationToken); + await session.Transport.HandleGetRequestAsync(context.Response.Body, cancellationToken); + } + catch (OperationCanceledException) when (cancellationToken.IsCancellationRequested) + { + // RequestAborted always triggers when the client disconnects before a complete response body is written, + // but this is how SSE connections are typically closed. + } } public async Task HandleDeleteRequestAsync(HttpContext context) diff --git a/src/ModelContextProtocol.Core/Client/StreamableHttpClientSessionTransport.cs b/src/ModelContextProtocol.Core/Client/StreamableHttpClientSessionTransport.cs index f2fd55f16..2b9700f4f 100644 --- a/src/ModelContextProtocol.Core/Client/StreamableHttpClientSessionTransport.cs +++ b/src/ModelContextProtocol.Core/Client/StreamableHttpClientSessionTransport.cs @@ -175,16 +175,27 @@ private async Task ReceiveUnsolicitedMessagesAsync() request.Headers.Accept.Add(s_textEventStreamMediaType); CopyAdditionalHeaders(request.Headers, _options.AdditionalHeaders, SessionId, _negotiatedProtocolVersion); - using var response = await _httpClient.SendAsync(request, message: null, _connectionCts.Token).ConfigureAwait(false); - - if (!response.IsSuccessStatusCode) + // Server support for the GET request is optional. If it fails, we don't care. It just means we won't receive unsolicited messages. + HttpResponseMessage response; + try + { + response = await _httpClient.SendAsync(request, message: null, _connectionCts.Token).ConfigureAwait(false); + } + catch (HttpRequestException) { - // Server support for the GET request is optional. If it fails, we don't care. It just means we won't receive unsolicited messages. return; } - using var responseStream = await response.Content.ReadAsStreamAsync(_connectionCts.Token).ConfigureAwait(false); - await ProcessSseResponseAsync(responseStream, relatedRpcRequest: null, _connectionCts.Token).ConfigureAwait(false); + using (response) + { + if (!response.IsSuccessStatusCode) + { + return; + } + + using var responseStream = await response.Content.ReadAsStreamAsync(_connectionCts.Token).ConfigureAwait(false); + await ProcessSseResponseAsync(responseStream, relatedRpcRequest: null, _connectionCts.Token).ConfigureAwait(false); + } } private async Task ProcessSseResponseAsync(Stream responseStream, JsonRpcRequest? relatedRpcRequest, CancellationToken cancellationToken) diff --git a/src/ModelContextProtocol.Core/Protocol/JsonRpcMessage.cs b/src/ModelContextProtocol.Core/Protocol/JsonRpcMessage.cs index ae15453db..a01a5b58a 100644 --- a/src/ModelContextProtocol.Core/Protocol/JsonRpcMessage.cs +++ b/src/ModelContextProtocol.Core/Protocol/JsonRpcMessage.cs @@ -40,7 +40,7 @@ private protected JsonRpcMessage() /// /// This property should only be set when implementing a custom /// that needs to pass additional per-message context or to pass a - /// to + /// to /// or . /// [JsonIgnore] diff --git a/src/ModelContextProtocol.Core/Server/StreamableHttpServerTransport.cs b/src/ModelContextProtocol.Core/Server/StreamableHttpServerTransport.cs index 4bbb49be9..ee943ea70 100644 --- a/src/ModelContextProtocol.Core/Server/StreamableHttpServerTransport.cs +++ b/src/ModelContextProtocol.Core/Server/StreamableHttpServerTransport.cs @@ -43,14 +43,14 @@ public sealed class StreamableHttpServerTransport : ITransport /// /// Configures whether the transport should be in stateless mode that does not require all requests for a given session /// to arrive to the same ASP.NET Core application process. Unsolicited server-to-client messages are not supported in this mode, - /// so calling results in an . + /// so calling results in an . /// Server-to-client requests are also unsupported, because the responses may arrive at another ASP.NET Core application process. /// Client sampling and roots capabilities are also disabled in stateless mode, because the server cannot make requests. /// public bool Stateless { get; init; } /// - /// Gets a value indicating whether the execution context should flow from the calls to + /// Gets a value indicating whether the execution context should flow from the calls to /// to the corresponding property contained in the instances returned by the . /// /// @@ -76,7 +76,7 @@ public sealed class StreamableHttpServerTransport : ITransport /// The response stream to write MCP JSON-RPC messages as SSE events to. /// The to monitor for cancellation requests. The default is . /// A task representing the send loop that writes JSON-RPC messages to the SSE response stream. - public async Task HandleGetRequest(Stream sseResponseStream, CancellationToken cancellationToken = default) + public async Task HandleGetRequestAsync(Stream sseResponseStream, CancellationToken cancellationToken = default) { Throw.IfNull(sseResponseStream); @@ -111,7 +111,7 @@ public async Task HandleGetRequest(Stream sseResponseStream, CancellationToken c /// If 's an authenticated sent the message, that can be included in the . /// No other part of the context should be set. /// - public async Task HandlePostRequest(JsonRpcMessage message, Stream responseStream, CancellationToken cancellationToken = default) + public async Task HandlePostRequestAsync(JsonRpcMessage message, Stream responseStream, CancellationToken cancellationToken = default) { Throw.IfNull(message); Throw.IfNull(responseStream); diff --git a/tests/ModelContextProtocol.AspNetCore.Tests/MapMcpStreamableHttpTests.cs b/tests/ModelContextProtocol.AspNetCore.Tests/MapMcpStreamableHttpTests.cs index 0e953e4d7..9f50793ce 100644 --- a/tests/ModelContextProtocol.AspNetCore.Tests/MapMcpStreamableHttpTests.cs +++ b/tests/ModelContextProtocol.AspNetCore.Tests/MapMcpStreamableHttpTests.cs @@ -181,9 +181,10 @@ public async Task StreamableHttpClient_SendsMcpProtocolVersionHeader_AfterInitia await mcpClient.DisposeAsync(); - // The header should be included in the GET request, the initialized notification, the tools/list call, and the delete request. - // The DELETE request won't be sent for Stateless mode due to the lack of an Mcp-Session-Id. - Assert.Equal(Stateless ? 3 : 4, protocolVersionHeaderValues.Count); + // The GET request might not have started in time, and the DELETE request won't be sent in + // Stateless mode due to the lack of an Mcp-Session-Id, but the header should be included in the + // initialized notification and the tools/list call at a minimum. + Assert.True(protocolVersionHeaderValues.Count > 1); Assert.All(protocolVersionHeaderValues, v => Assert.Equal("2025-03-26", v)); } } diff --git a/tests/ModelContextProtocol.AspNetCore.Tests/MapMcpTests.cs b/tests/ModelContextProtocol.AspNetCore.Tests/MapMcpTests.cs index d7b9eaa01..166d492f2 100644 --- a/tests/ModelContextProtocol.AspNetCore.Tests/MapMcpTests.cs +++ b/tests/ModelContextProtocol.AspNetCore.Tests/MapMcpTests.cs @@ -8,6 +8,7 @@ using ModelContextProtocol.Server; using ModelContextProtocol.Tests.Utils; using System.ComponentModel; +using System.Diagnostics; using System.Net; using System.Security.Claims; @@ -114,10 +115,9 @@ public async Task Messages_FromNewUser_AreRejected() } [Fact] - public async Task ClaimsPrincipal_CanBeInjectedIntoToolMethod() + public async Task ClaimsPrincipal_CanBeInjected_IntoToolMethod() { Builder.Services.AddMcpServer().WithHttpTransport(ConfigureStateless).WithTools(); - Builder.Services.AddHttpContextAccessor(); await using var app = Builder.Build(); @@ -211,6 +211,35 @@ public async Task Sampling_DoesNotCloseStream_Prematurely() m.Message.Contains("request '2' for method 'sampling/createMessage'")); } + [Fact] + public async Task Server_ShutsDownQuickly_WhenClientIsConnected() + { + Builder.Services.AddMcpServer().WithHttpTransport().WithTools(); + + await using var app = Builder.Build(); + app.MapMcp(); + + await app.StartAsync(TestContext.Current.CancellationToken); + + // Connect a client which will open a long-running GET request (SSE or Streamable HTTP) + await using var mcpClient = await ConnectAsync(); + + // Verify the client is connected + var tools = await mcpClient.ListToolsAsync(cancellationToken: TestContext.Current.CancellationToken); + Assert.NotEmpty(tools); + + // Now measure how long it takes to stop the server + var stopwatch = Stopwatch.StartNew(); + await app.StopAsync(TestContext.Current.CancellationToken); + stopwatch.Stop(); + + // The server should shut down quickly (within a few seconds). We use 5 seconds as a generous threshold. + // This is much less than the default HostOptions.ShutdownTimeout of 30 seconds. + Assert.True(stopwatch.Elapsed < TimeSpan.FromSeconds(5), + $"Server took {stopwatch.Elapsed.TotalSeconds:F2} seconds to shut down with a connected client. " + + "This suggests the GET request is not respecting ApplicationStopping token."); + } + private ClaimsPrincipal CreateUser(string name) => new ClaimsPrincipal(new ClaimsIdentity( [new Claim("name", name), new Claim(ClaimTypes.NameIdentifier, name)], diff --git a/tests/ModelContextProtocol.AspNetCore.Tests/Utils/KestrelInMemoryConnection.cs b/tests/ModelContextProtocol.AspNetCore.Tests/Utils/KestrelInMemoryConnection.cs index b7d2ce643..c632630b0 100644 --- a/tests/ModelContextProtocol.AspNetCore.Tests/Utils/KestrelInMemoryConnection.cs +++ b/tests/ModelContextProtocol.AspNetCore.Tests/Utils/KestrelInMemoryConnection.cs @@ -6,28 +6,29 @@ namespace ModelContextProtocol.AspNetCore.Tests.Utils; public sealed class KestrelInMemoryConnection : ConnectionContext { - private readonly Pipe _clientToServerPipe = new(); - private readonly Pipe _serverToClientPipe = new(); private readonly CancellationTokenSource _connectionClosedCts = new(); private readonly FeatureCollection _features = new(); public KestrelInMemoryConnection() { + Pipe clientToServerPipe = new(); + Pipe serverToClientPipe = new(); + ConnectionClosed = _connectionClosedCts.Token; Transport = new DuplexPipe { - Input = _clientToServerPipe.Reader, - Output = _serverToClientPipe.Writer, + Input = clientToServerPipe.Reader, + Output = serverToClientPipe.Writer, }; - Application = new DuplexPipe + ClientPipe = new DuplexPipe { - Input = _serverToClientPipe.Reader, - Output = _clientToServerPipe.Writer, + Input = serverToClientPipe.Reader, + Output = clientToServerPipe.Writer, }; - ClientStream = new DuplexStream(Application, _connectionClosedCts); + ClientStream = new DuplexStream(ClientPipe, _connectionClosedCts); } - public IDuplexPipe Application { get; } + public IDuplexPipe ClientPipe { get; } public Stream ClientStream { get; } public override IDuplexPipe Transport { get; set; } @@ -41,8 +42,8 @@ public override async ValueTask DisposeAsync() { // This is called by Kestrel. The client should dispose the DuplexStream which // completes the other half of these pipes. - await _serverToClientPipe.Writer.CompleteAsync(); - await _serverToClientPipe.Reader.CompleteAsync(); + await Transport.Input.CompleteAsync(); + await Transport.Output.CompleteAsync(); // Don't bother disposing the _connectionClosedCts, since this is just for testing, // and it's annoying to synchronize with DuplexStream. diff --git a/tests/ModelContextProtocol.AspNetCore.Tests/Utils/KestrelInMemoryTransport.cs b/tests/ModelContextProtocol.AspNetCore.Tests/Utils/KestrelInMemoryTransport.cs index 71809ad6c..e5686a16f 100644 --- a/tests/ModelContextProtocol.AspNetCore.Tests/Utils/KestrelInMemoryTransport.cs +++ b/tests/ModelContextProtocol.AspNetCore.Tests/Utils/KestrelInMemoryTransport.cs @@ -13,7 +13,11 @@ public sealed class KestrelInMemoryTransport : IConnectionListenerFactory public KestrelInMemoryConnection CreateConnection(EndPoint endpoint) { var connection = new KestrelInMemoryConnection(); - GetAcceptQueue(endpoint).Writer.TryWrite(connection); + if (!GetAcceptQueue(endpoint).Writer.TryWrite(connection)) + { + throw new IOException("The KestrelInMemoryTransport has been shut down."); + }; + return connection; } @@ -37,12 +41,9 @@ private sealed class KestrelInMemoryListener(EndPoint endpoint, Channel AcceptAsync(CancellationToken cancellationToken = default) { - if (await acceptQueue.Reader.WaitToReadAsync(cancellationToken)) + await foreach (var item in acceptQueue.Reader.ReadAllAsync(cancellationToken)) { - while (acceptQueue.Reader.TryRead(out var item)) - { - return item; - } + return item; } return null; diff --git a/tests/ModelContextProtocol.TestSseServer/Program.cs b/tests/ModelContextProtocol.TestSseServer/Program.cs index cf78c0896..c2898c542 100644 --- a/tests/ModelContextProtocol.TestSseServer/Program.cs +++ b/tests/ModelContextProtocol.TestSseServer/Program.cs @@ -370,6 +370,7 @@ private static void HandleStatelessMcp(IApplicationBuilder app) var serviceCollection = new ServiceCollection(); serviceCollection.AddLogging(); serviceCollection.AddSingleton(app.ApplicationServices.GetRequiredService()); + serviceCollection.AddSingleton(app.ApplicationServices.GetRequiredService()); serviceCollection.AddSingleton(app.ApplicationServices.GetRequiredService()); serviceCollection.AddRoutingCore();