Skip to content
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
36 changes: 26 additions & 10 deletions src/ModelContextProtocol.AspNetCore/StreamableHttpHandler.cs
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
using Microsoft.AspNetCore.Http;
using Microsoft.AspNetCore.Http.Features;
using Microsoft.AspNetCore.WebUtilities;
using Microsoft.Extensions.Hosting;
using Microsoft.Extensions.Logging;
using Microsoft.Extensions.Options;
using Microsoft.Net.Http.Headers;
Expand All @@ -17,8 +18,9 @@ internal sealed class StreamableHttpHandler(
IOptionsFactory<McpServerOptions> mcpServerOptionsFactory,
IOptions<HttpServerTransportOptions> httpServerTransportOptions,
StatefulSessionManager sessionManager,
ILoggerFactory loggerFactory,
IServiceProvider applicationServices)
IHostApplicationLifetime hostApplicationLifetime,
IServiceProvider applicationServices,
ILoggerFactory loggerFactory)
{
private const string McpSessionIdHeaderName = "Mcp-Session-Id";

Expand Down Expand Up @@ -60,7 +62,7 @@ await WriteJsonRpcErrorAsync(context,
}

InitializeSseResponse(context);
var wroteResponse = await session.Transport.HandlePostRequest(message, context.Response.Body, context.RequestAborted);
var wroteResponse = await session.Transport.HandlePostRequestAsync(message, context.Response.Body, context.RequestAborted);
if (!wroteResponse)
{
// We wound up writing nothing, so there should be no Content-Type response header.
Expand Down Expand Up @@ -94,14 +96,28 @@ await WriteJsonRpcErrorAsync(context,
return;
}

await using var _ = await session.AcquireReferenceAsync(context.RequestAborted);
InitializeSseResponse(context);
// Link the GET request to both RequestAborted and ApplicationStopping.
// The GET request should complete immediately during graceful shutdown without waiting for
// in-flight POST requests to complete. This prevents slow shutdown when clients are still connected.
using var sseCts = CancellationTokenSource.CreateLinkedTokenSource(context.RequestAborted, hostApplicationLifetime.ApplicationStopping);
var cancellationToken = sseCts.Token;

// We should flush headers to indicate a 200 success quickly, because the initialization response
// will be sent in response to a different POST request. It might be a while before we send a message
// over this response body.
await context.Response.Body.FlushAsync(context.RequestAborted);
await session.Transport.HandleGetRequest(context.Response.Body, context.RequestAborted);
try
{
await using var _ = await session.AcquireReferenceAsync(cancellationToken);
InitializeSseResponse(context);

// We should flush headers to indicate a 200 success quickly, because the initialization response
// will be sent in response to a different POST request. It might be a while before we send a message
// over this response body.
await context.Response.Body.FlushAsync(sseCts.Token);
await session.Transport.HandleGetRequestAsync(context.Response.Body, sseCts.Token);
}
catch (OperationCanceledException) when (cancellationToken.IsCancellationRequested)
{
// RequestAborted always triggers when the client disconnects before a complete response body is written,
// but this is how SSE connections are typically closed.
}
}

public async Task HandleDeleteRequestAsync(HttpContext context)
Expand Down
2 changes: 1 addition & 1 deletion src/ModelContextProtocol.Core/Protocol/JsonRpcMessage.cs
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@ private protected JsonRpcMessage()
/// <remarks>
/// This property should only be set when implementing a custom <see cref="ITransport"/>
/// that needs to pass additional per-message context or to pass a <see cref="JsonRpcMessageContext.User"/>
/// to <see cref="StreamableHttpServerTransport.HandlePostRequest(JsonRpcMessage, Stream, CancellationToken)"/>
/// to <see cref="StreamableHttpServerTransport.HandlePostRequestAsync(JsonRpcMessage, Stream, CancellationToken)"/>
/// or <see cref="SseResponseStreamTransport.OnMessageReceivedAsync(JsonRpcMessage, CancellationToken)"/> .
/// </remarks>
[JsonIgnore]
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -43,14 +43,14 @@ public sealed class StreamableHttpServerTransport : ITransport
/// <summary>
/// Configures whether the transport should be in stateless mode that does not require all requests for a given session
/// to arrive to the same ASP.NET Core application process. Unsolicited server-to-client messages are not supported in this mode,
/// so calling <see cref="HandleGetRequest(Stream, CancellationToken)"/> results in an <see cref="InvalidOperationException"/>.
/// so calling <see cref="HandleGetRequestAsync(Stream, CancellationToken)"/> results in an <see cref="InvalidOperationException"/>.
/// Server-to-client requests are also unsupported, because the responses may arrive at another ASP.NET Core application process.
/// Client sampling and roots capabilities are also disabled in stateless mode, because the server cannot make requests.
/// </summary>
public bool Stateless { get; init; }

/// <summary>
/// Gets a value indicating whether the execution context should flow from the calls to <see cref="HandlePostRequest(JsonRpcMessage, Stream, CancellationToken)"/>
/// Gets a value indicating whether the execution context should flow from the calls to <see cref="HandlePostRequestAsync(JsonRpcMessage, Stream, CancellationToken)"/>
/// to the corresponding <see cref="JsonRpcMessageContext.ExecutionContext"/> property contained in the <see cref="JsonRpcMessage"/> instances returned by the <see cref="MessageReader"/>.
/// </summary>
/// <remarks>
Expand All @@ -76,7 +76,7 @@ public sealed class StreamableHttpServerTransport : ITransport
/// <param name="sseResponseStream">The response stream to write MCP JSON-RPC messages as SSE events to.</param>
/// <param name="cancellationToken">The <see cref="CancellationToken"/> to monitor for cancellation requests. The default is <see cref="CancellationToken.None"/>.</param>
/// <returns>A task representing the send loop that writes JSON-RPC messages to the SSE response stream.</returns>
public async Task HandleGetRequest(Stream sseResponseStream, CancellationToken cancellationToken = default)
public async Task HandleGetRequestAsync(Stream sseResponseStream, CancellationToken cancellationToken = default)
{
Throw.IfNull(sseResponseStream);

Expand Down Expand Up @@ -111,7 +111,7 @@ public async Task HandleGetRequest(Stream sseResponseStream, CancellationToken c
/// If 's an authenticated <see cref="ClaimsPrincipal"/> sent the message, that can be included in the <see cref="JsonRpcMessage.Context"/>.
/// No other part of the context should be set.
/// </para>
public async Task<bool> HandlePostRequest(JsonRpcMessage message, Stream responseStream, CancellationToken cancellationToken = default)
public async Task<bool> HandlePostRequestAsync(JsonRpcMessage message, Stream responseStream, CancellationToken cancellationToken = default)
{
Throw.IfNull(message);
Throw.IfNull(responseStream);
Expand Down
33 changes: 31 additions & 2 deletions tests/ModelContextProtocol.AspNetCore.Tests/MapMcpTests.cs
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
using ModelContextProtocol.Server;
using ModelContextProtocol.Tests.Utils;
using System.ComponentModel;
using System.Diagnostics;
using System.Net;
using System.Security.Claims;

Expand Down Expand Up @@ -114,10 +115,9 @@ public async Task Messages_FromNewUser_AreRejected()
}

[Fact]
public async Task ClaimsPrincipal_CanBeInjectedIntoToolMethod()
public async Task ClaimsPrincipal_CanBeInjected_IntoToolMethod()
{
Builder.Services.AddMcpServer().WithHttpTransport(ConfigureStateless).WithTools<ClaimsPrincipalTools>();
Builder.Services.AddHttpContextAccessor();

await using var app = Builder.Build();

Expand Down Expand Up @@ -211,6 +211,35 @@ public async Task Sampling_DoesNotCloseStream_Prematurely()
m.Message.Contains("request '2' for method 'sampling/createMessage'"));
}

[Fact]
public async Task Server_ShutsDownQuickly_WhenClientIsConnected()
{
Builder.Services.AddMcpServer().WithHttpTransport().WithTools<ClaimsPrincipalTools>();

await using var app = Builder.Build();
app.MapMcp();

await app.StartAsync(TestContext.Current.CancellationToken);

// Connect a client which will open a long-running GET request (SSE or Streamable HTTP)
await using var mcpClient = await ConnectAsync();

// Verify the client is connected
var tools = await mcpClient.ListToolsAsync(cancellationToken: TestContext.Current.CancellationToken);
Assert.NotEmpty(tools);

// Now measure how long it takes to stop the server
var stopwatch = Stopwatch.StartNew();
await app.StopAsync(TestContext.Current.CancellationToken);
stopwatch.Stop();

// The server should shut down quickly (within a few seconds). We use 5 seconds as a generous threshold.
// This is much less than the default HostOptions.ShutdownTimeout of 30 seconds.
Assert.True(stopwatch.Elapsed < TimeSpan.FromSeconds(5),
$"Server took {stopwatch.Elapsed.TotalSeconds:F2} seconds to shut down with a connected client. " +
"This suggests the GET request is not respecting ApplicationStopping token.");
}

private ClaimsPrincipal CreateUser(string name)
=> new ClaimsPrincipal(new ClaimsIdentity(
[new Claim("name", name), new Claim(ClaimTypes.NameIdentifier, name)],
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,11 @@ public sealed class KestrelInMemoryTransport : IConnectionListenerFactory
public KestrelInMemoryConnection CreateConnection(EndPoint endpoint)
{
var connection = new KestrelInMemoryConnection();
GetAcceptQueue(endpoint).Writer.TryWrite(connection);
if (!GetAcceptQueue(endpoint).Writer.TryWrite(connection))
{
throw new IOException("The KestrelInMemoryTransport has been shut down.");
};

return connection;
}

Expand All @@ -37,7 +41,7 @@ private sealed class KestrelInMemoryListener(EndPoint endpoint, Channel<Connecti

public async ValueTask<ConnectionContext?> AcceptAsync(CancellationToken cancellationToken = default)
{
if (await acceptQueue.Reader.WaitToReadAsync(cancellationToken))
while (await acceptQueue.Reader.WaitToReadAsync(cancellationToken))
{
while (acceptQueue.Reader.TryRead(out var item))
{
Expand Down
1 change: 1 addition & 0 deletions tests/ModelContextProtocol.TestSseServer/Program.cs
Original file line number Diff line number Diff line change
Expand Up @@ -370,6 +370,7 @@ private static void HandleStatelessMcp(IApplicationBuilder app)
var serviceCollection = new ServiceCollection();
serviceCollection.AddLogging();
serviceCollection.AddSingleton(app.ApplicationServices.GetRequiredService<ILoggerFactory>());
serviceCollection.AddSingleton(app.ApplicationServices.GetRequiredService<IHostApplicationLifetime>());
serviceCollection.AddSingleton(app.ApplicationServices.GetRequiredService<DiagnosticListener>());
serviceCollection.AddRoutingCore();

Expand Down
Loading