Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 16 additions & 0 deletions server/streamable_http.go
Original file line number Diff line number Diff line change
Expand Up @@ -69,6 +69,16 @@ func WithHeartbeatInterval(interval time.Duration) StreamableHTTPOption {
}
}

// WithDisableStreaming prevents the server from responding to GET requests with
// a streaming response. Instead, it will respond with a 405 Method Not Allowed status.
// This can be useful in scenarios where streaming is not desired or supported.
// The default is false, meaning streaming is enabled.
func WithDisableStreaming(disable bool) StreamableHTTPOption {
return func(s *StreamableHTTPServer) {
s.disableStreaming = disable
}
}

// WithHTTPContextFunc sets a function that will be called to customise the context
// to the server using the incoming request.
// This can be used to inject context values from headers, for example.
Expand Down Expand Up @@ -141,6 +151,7 @@ type StreamableHTTPServer struct {
listenHeartbeatInterval time.Duration
logger util.Logger
sessionLogLevels *sessionLogLevelsStore
disableStreaming bool

tlsCertFile string
tlsKeyFile string
Expand Down Expand Up @@ -400,6 +411,11 @@ func (s *StreamableHTTPServer) handlePost(w http.ResponseWriter, r *http.Request
func (s *StreamableHTTPServer) handleGet(w http.ResponseWriter, r *http.Request) {
// get request is for listening to notifications
// https://modelcontextprotocol.io/specification/2025-03-26/basic/transports#listening-for-messages-from-the-server
if s.disableStreaming {
s.logger.Infof("Rejected GET request: streaming is disabled (session: %s)", r.Header.Get(HeaderKeySessionID))
http.Error(w, "Streaming is disabled on this server", http.StatusMethodNotAllowed)
return
}

sessionID := r.Header.Get(HeaderKeySessionID)
// the specification didn't say we should validate the session id
Expand Down
94 changes: 94 additions & 0 deletions server/streamable_http_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -914,6 +914,100 @@ func TestStreamableHTTPServer_TLS(t *testing.T) {
})
}

func TestStreamableHTTPServer_WithDisableStreaming(t *testing.T) {
t.Run("WithDisableStreaming blocks GET requests", func(t *testing.T) {
mcpServer := NewMCPServer("test-mcp-server", "1.0.0")
server := NewTestStreamableHTTPServer(mcpServer, WithDisableStreaming(true))
defer server.Close()

// Attempt a GET request (which should be blocked)
req, err := http.NewRequest(http.MethodGet, server.URL, nil)
if err != nil {
t.Fatalf("Failed to create request: %v", err)
}
req.Header.Set("Content-Type", "text/event-stream")

resp, err := server.Client().Do(req)
if err != nil {
t.Fatalf("Failed to send request: %v", err)
}
defer resp.Body.Close()

// Verify the request is rejected with 405 Method Not Allowed
if resp.StatusCode != http.StatusMethodNotAllowed {
t.Errorf("Expected status 405 Method Not Allowed, got %d", resp.StatusCode)
}

// Verify the error message
bodyBytes, err := io.ReadAll(resp.Body)
if err != nil {
t.Fatalf("Failed to read response body: %v", err)
}

expectedMessage := "Streaming is disabled on this server"
if !strings.Contains(string(bodyBytes), expectedMessage) {
t.Errorf("Expected error message to contain '%s', got '%s'", expectedMessage, string(bodyBytes))
}
})

t.Run("POST requests still work with WithDisableStreaming", func(t *testing.T) {
mcpServer := NewMCPServer("test-mcp-server", "1.0.0")
server := NewTestStreamableHTTPServer(mcpServer, WithDisableStreaming(true))
defer server.Close()

// POST requests should still work
resp, err := postJSON(server.URL, initRequest)
if err != nil {
t.Fatalf("Failed to send message: %v", err)
}
defer resp.Body.Close()

if resp.StatusCode != http.StatusOK {
t.Errorf("Expected status 200, got %d", resp.StatusCode)
}

// Verify the response is valid
bodyBytes, _ := io.ReadAll(resp.Body)
var responseMessage jsonRPCResponse
if err := json.Unmarshal(bodyBytes, &responseMessage); err != nil {
t.Fatalf("Failed to unmarshal response: %v", err)
}
if responseMessage.Result["protocolVersion"] != mcp.LATEST_PROTOCOL_VERSION {
t.Errorf("Expected protocol version %s, got %s", mcp.LATEST_PROTOCOL_VERSION, responseMessage.Result["protocolVersion"])
}
})

t.Run("Streaming works when WithDisableStreaming is false", func(t *testing.T) {
mcpServer := NewMCPServer("test-mcp-server", "1.0.0")
server := NewTestStreamableHTTPServer(mcpServer, WithDisableStreaming(false))
defer server.Close()

ctx, cancel := context.WithTimeout(context.Background(), 500*time.Millisecond)
defer cancel()

// GET request should work when streaming is enabled
req, err := http.NewRequestWithContext(ctx, http.MethodGet, server.URL, nil)
if err != nil {
t.Fatalf("Failed to create request: %v", err)
}
req.Header.Set("Content-Type", "text/event-stream")

resp, err := server.Client().Do(req)
if err != nil {
t.Fatalf("Failed to send request: %v", err)
}
defer resp.Body.Close()

if resp.StatusCode != http.StatusOK {
t.Errorf("Expected status 200, got %d", resp.StatusCode)
}

if resp.Header.Get("content-type") != "text/event-stream" {
t.Errorf("Expected content-type text/event-stream, got %s", resp.Header.Get("content-type"))
}
})
}

func postJSON(url string, bodyObject any) (*http.Response, error) {
jsonBody, _ := json.Marshal(bodyObject)
req, _ := http.NewRequest(http.MethodPost, url, bytes.NewBuffer(jsonBody))
Expand Down
Loading