Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -537,7 +537,7 @@ For examples, see the [`examples/`](examples/) directory.

### Transports

MCP-Go supports stdio, SSE and streamable-HTTP transport layers.
MCP-Go supports stdio, SSE and streamable-HTTP transport layers. For SSE transport, you can use `SetConnectionLostHandler()` to detect and handle HTTP/2 idle timeout disconnections (NO_ERROR) for implementing reconnection logic.

### Session Management

Expand Down
11 changes: 11 additions & 0 deletions client/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -113,6 +113,17 @@ func (c *Client) OnNotification(
c.notifications = append(c.notifications, handler)
}

// OnConnectionLost registers a handler function to be called when the connection is lost.
// This is useful for handling HTTP2 idle timeout disconnections that should not be treated as errors.
func (c *Client) OnConnectionLost(handler func(error)) {
type connectionLostSetter interface {
SetConnectionLostHandler(func(error))
}
if setter, ok := c.transport.(connectionLostSetter); ok {
setter.SetConnectionLostHandler(handler)
}
}

// sendRequest sends a JSON-RPC request to the server and waits for a response.
// Returns the raw JSON response message or an error if the request fails.
func (c *Client) sendRequest(
Expand Down
29 changes: 25 additions & 4 deletions client/transport/sse.go
Original file line number Diff line number Diff line change
Expand Up @@ -34,10 +34,12 @@ type SSE struct {
headers map[string]string
headerFunc HTTPHeaderFunc

started atomic.Bool
closed atomic.Bool
cancelSSEStream context.CancelFunc
protocolVersion atomic.Value // string
started atomic.Bool
closed atomic.Bool
cancelSSEStream context.CancelFunc
protocolVersion atomic.Value // string
onConnectionLost func(error)
connectionLostMu sync.RWMutex

// OAuth support
oauthHandler *OAuthHandler
Expand Down Expand Up @@ -204,6 +206,19 @@ func (c *SSE) readSSE(reader io.ReadCloser) {
}
break
}
// Checking whether the connection was terminated due to NO_ERROR in HTTP2 based on RFC9113
// Only handle NO_ERROR specially if onConnectionLost handler is set to maintain backward compatibility
if strings.Contains(err.Error(), "NO_ERROR") {
c.connectionLostMu.RLock()
handler := c.onConnectionLost
c.connectionLostMu.RUnlock()

if handler != nil {
// This is not actually an error - HTTP2 idle timeout disconnection
handler(err)
return
}
}
if !c.closed.Load() {
fmt.Printf("SSE stream error: %v\n", err)
}
Expand Down Expand Up @@ -294,6 +309,12 @@ func (c *SSE) SetNotificationHandler(handler func(notification mcp.JSONRPCNotifi
c.onNotification = handler
}

func (c *SSE) SetConnectionLostHandler(handler func(error)) {
c.connectionLostMu.Lock()
defer c.connectionLostMu.Unlock()
c.onConnectionLost = handler
}

// SendRequest sends a JSON-RPC request to the server and waits for a response.
// Returns the raw JSON response message or an error if the request fails.
func (c *SSE) SendRequest(
Expand Down
247 changes: 247 additions & 0 deletions client/transport/sse_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,8 @@ import (
"context"
"encoding/json"
"errors"
"io"
"strings"
"sync"
"testing"
"time"
Expand All @@ -15,6 +17,39 @@ import (
"github.com/mark3labs/mcp-go/mcp"
)

// mockReaderWithError is a mock io.ReadCloser that simulates reading some data
// and then returning a specific error
type mockReaderWithError struct {
data []byte
err error
position int
closed bool
}

func (m *mockReaderWithError) Read(p []byte) (n int, err error) {
if m.closed {
return 0, io.EOF
}

if m.position >= len(m.data) {
return 0, m.err
}

n = copy(p, m.data[m.position:])
m.position += n

if m.position >= len(m.data) {
return n, m.err
}

return n, nil
}

func (m *mockReaderWithError) Close() error {
m.closed = true
return nil
}

// startMockSSEEchoServer starts a test HTTP server that implements
// a minimal SSE-based echo server for testing purposes.
// It returns the server URL and a function to close the server.
Expand Down Expand Up @@ -508,6 +543,218 @@ func TestSSE(t *testing.T) {
}
})

t.Run("NO_ERROR_WithoutConnectionLostHandler", func(t *testing.T) {
// Test that NO_ERROR without connection lost handler maintains backward compatibility
// When no connection lost handler is set, NO_ERROR should be treated as a regular error

// Create a mock Reader that simulates NO_ERROR
mockReader := &mockReaderWithError{
data: []byte("event: endpoint\ndata: /message\n\n"),
err: errors.New("connection closed: NO_ERROR"),
}

// Create SSE transport
url, closeF := startMockSSEEchoServer()
defer closeF()

trans, err := NewSSE(url)
if err != nil {
t.Fatal(err)
}

// DO NOT set connection lost handler to test backward compatibility

// Capture stderr to verify the error is printed (backward compatible behavior)
// Since we can't easily capture fmt.Printf output in tests, we'll just verify
// that the readSSE method returns without calling any handler

// Directly test the readSSE method with our mock reader
go trans.readSSE(mockReader)

// Wait for readSSE to complete
time.Sleep(100 * time.Millisecond)

// The test passes if readSSE completes without panicking or hanging
// In backward compatibility mode, NO_ERROR should be treated as a regular error
t.Log("Backward compatibility test passed: NO_ERROR handled as regular error when no handler is set")
})

t.Run("NO_ERROR_ConnectionLost", func(t *testing.T) {
// Test that NO_ERROR in HTTP/2 connection loss is properly handled
// This test verifies that when a connection is lost in a way that produces
// an error message containing "NO_ERROR", the connection lost handler is called

var connectionLostCalled bool
var connectionLostError error
var mu sync.Mutex

// Create a mock Reader that simulates connection loss with NO_ERROR
mockReader := &mockReaderWithError{
data: []byte("event: endpoint\ndata: /message\n\n"),
err: errors.New("http2: stream closed with error code NO_ERROR"),
}

// Create SSE transport
url, closeF := startMockSSEEchoServer()
defer closeF()

trans, err := NewSSE(url)
if err != nil {
t.Fatal(err)
}

// Set connection lost handler
trans.SetConnectionLostHandler(func(err error) {
mu.Lock()
defer mu.Unlock()
connectionLostCalled = true
connectionLostError = err
})

// Directly test the readSSE method with our mock reader that simulates NO_ERROR
go trans.readSSE(mockReader)

// Wait for connection lost handler to be called
timeout := time.After(1 * time.Second)
ticker := time.NewTicker(10 * time.Millisecond)
defer ticker.Stop()

for {
select {
case <-timeout:
t.Fatal("Connection lost handler was not called within timeout for NO_ERROR connection loss")
case <-ticker.C:
mu.Lock()
called := connectionLostCalled
err := connectionLostError
mu.Unlock()

if called {
if err == nil {
t.Fatal("Expected connection lost error, got nil")
}

// Verify that the error contains "NO_ERROR" string
if !strings.Contains(err.Error(), "NO_ERROR") {
t.Errorf("Expected error to contain 'NO_ERROR', got: %v", err)
}

t.Logf("Connection lost handler called with NO_ERROR: %v", err)
return
}
}
}
})

t.Run("NO_ERROR_Handling", func(t *testing.T) {
// Test specific NO_ERROR string handling in readSSE method
// This tests the code path at line 209 where NO_ERROR is checked

// Create a mock Reader that simulates an error containing "NO_ERROR"
mockReader := &mockReaderWithError{
data: []byte("event: endpoint\ndata: /message\n\n"),
err: errors.New("connection closed: NO_ERROR"),
}

// Create SSE transport
url, closeF := startMockSSEEchoServer()
defer closeF()

trans, err := NewSSE(url)
if err != nil {
t.Fatal(err)
}

var connectionLostCalled bool
var connectionLostError error
var mu sync.Mutex

// Set connection lost handler to verify it's called for NO_ERROR
trans.SetConnectionLostHandler(func(err error) {
mu.Lock()
defer mu.Unlock()
connectionLostCalled = true
connectionLostError = err
})

// Directly test the readSSE method with our mock reader
go trans.readSSE(mockReader)

// Wait for connection lost handler to be called
timeout := time.After(1 * time.Second)
ticker := time.NewTicker(10 * time.Millisecond)
defer ticker.Stop()

for {
select {
case <-timeout:
t.Fatal("Connection lost handler was not called within timeout for NO_ERROR")
case <-ticker.C:
mu.Lock()
called := connectionLostCalled
err := connectionLostError
mu.Unlock()

if called {
if err == nil {
t.Fatal("Expected connection lost error with NO_ERROR, got nil")
}

// Verify that the error contains "NO_ERROR" string
if !strings.Contains(err.Error(), "NO_ERROR") {
t.Errorf("Expected error to contain 'NO_ERROR', got: %v", err)
}

t.Logf("Successfully handled NO_ERROR: %v", err)
return
}
}
}
})

t.Run("RegularError_DoesNotTriggerConnectionLost", func(t *testing.T) {
// Test that regular errors (not containing NO_ERROR) do not trigger connection lost handler

// Create a mock Reader that simulates a regular error
mockReader := &mockReaderWithError{
data: []byte("event: endpoint\ndata: /message\n\n"),
err: errors.New("regular connection error"),
}

// Create SSE transport
url, closeF := startMockSSEEchoServer()
defer closeF()

trans, err := NewSSE(url)
if err != nil {
t.Fatal(err)
}

var connectionLostCalled bool
var mu sync.Mutex

// Set connection lost handler - this should NOT be called for regular errors
trans.SetConnectionLostHandler(func(err error) {
mu.Lock()
defer mu.Unlock()
connectionLostCalled = true
})

// Directly test the readSSE method with our mock reader
go trans.readSSE(mockReader)

// Wait and verify connection lost handler is NOT called
time.Sleep(200 * time.Millisecond)

mu.Lock()
called := connectionLostCalled
mu.Unlock()

if called {
t.Error("Connection lost handler should not be called for regular errors")
}
})

}

func TestSSEErrors(t *testing.T) {
Expand Down