Skip to content

Commit eeee551

Browse files
committed
Merge branch 'main' into feat/add-callback-when-connection-lost
2 parents 24f4205 + a43b104 commit eeee551

20 files changed

+438
-112
lines changed

client/client.go

Lines changed: 14 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@ import (
55
"encoding/json"
66
"errors"
77
"fmt"
8+
"slices"
89
"sync"
910
"sync/atomic"
1011

@@ -22,6 +23,7 @@ type Client struct {
2223
requestID atomic.Int64
2324
clientCapabilities mcp.ClientCapabilities
2425
serverCapabilities mcp.ServerCapabilities
26+
protocolVersion string
2527
samplingHandler SamplingHandler
2628
}
2729

@@ -187,8 +189,19 @@ func (c *Client) Initialize(
187189
return nil, fmt.Errorf("failed to unmarshal response: %w", err)
188190
}
189191

190-
// Store serverCapabilities
192+
// Validate protocol version
193+
if !slices.Contains(mcp.ValidProtocolVersions, result.ProtocolVersion) {
194+
return nil, mcp.UnsupportedProtocolVersionError{Version: result.ProtocolVersion}
195+
}
196+
197+
// Store serverCapabilities and protocol version
191198
c.serverCapabilities = result.Capabilities
199+
c.protocolVersion = result.ProtocolVersion
200+
201+
// Set protocol version on HTTP transports
202+
if httpConn, ok := c.transport.(transport.HTTPConnection); ok {
203+
httpConn.SetProtocolVersion(result.ProtocolVersion)
204+
}
192205

193206
// Send initialized notification
194207
notification := mcp.JSONRPCNotification{
Lines changed: 231 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,231 @@
1+
package client
2+
3+
import (
4+
"context"
5+
"encoding/json"
6+
"fmt"
7+
"strings"
8+
"testing"
9+
10+
"github.com/mark3labs/mcp-go/client/transport"
11+
"github.com/mark3labs/mcp-go/mcp"
12+
)
13+
14+
// mockProtocolTransport implements transport.Interface for testing protocol negotiation
15+
type mockProtocolTransport struct {
16+
responses map[string]string
17+
notificationHandler func(mcp.JSONRPCNotification)
18+
started bool
19+
closed bool
20+
}
21+
22+
func (m *mockProtocolTransport) Start(ctx context.Context) error {
23+
m.started = true
24+
return nil
25+
}
26+
27+
func (m *mockProtocolTransport) SendRequest(ctx context.Context, request transport.JSONRPCRequest) (*transport.JSONRPCResponse, error) {
28+
responseStr, ok := m.responses[request.Method]
29+
if !ok {
30+
return nil, fmt.Errorf("no mock response for method %s", request.Method)
31+
}
32+
33+
return &transport.JSONRPCResponse{
34+
JSONRPC: "2.0",
35+
ID: request.ID,
36+
Result: json.RawMessage(responseStr),
37+
}, nil
38+
}
39+
40+
func (m *mockProtocolTransport) SendNotification(ctx context.Context, notification mcp.JSONRPCNotification) error {
41+
return nil
42+
}
43+
44+
func (m *mockProtocolTransport) SetNotificationHandler(handler func(notification mcp.JSONRPCNotification)) {
45+
m.notificationHandler = handler
46+
}
47+
48+
func (m *mockProtocolTransport) Close() error {
49+
m.closed = true
50+
return nil
51+
}
52+
53+
func (m *mockProtocolTransport) GetSessionId() string {
54+
return "mock-session"
55+
}
56+
57+
func TestProtocolVersionNegotiation(t *testing.T) {
58+
tests := []struct {
59+
name string
60+
serverVersion string
61+
expectError bool
62+
errorContains string
63+
}{
64+
{
65+
name: "supported latest version",
66+
serverVersion: mcp.LATEST_PROTOCOL_VERSION,
67+
expectError: false,
68+
},
69+
{
70+
name: "supported older version 2025-03-26",
71+
serverVersion: "2025-03-26",
72+
expectError: false,
73+
},
74+
{
75+
name: "supported older version 2024-11-05",
76+
serverVersion: "2024-11-05",
77+
expectError: false,
78+
},
79+
{
80+
name: "unsupported version",
81+
serverVersion: "2023-01-01",
82+
expectError: true,
83+
errorContains: "unsupported protocol version",
84+
},
85+
{
86+
name: "unsupported future version",
87+
serverVersion: "2030-01-01",
88+
expectError: true,
89+
errorContains: "unsupported protocol version",
90+
},
91+
{
92+
name: "empty protocol version",
93+
serverVersion: "",
94+
expectError: true,
95+
errorContains: "unsupported protocol version",
96+
},
97+
{
98+
name: "malformed protocol version - invalid format",
99+
serverVersion: "not-a-date",
100+
expectError: true,
101+
errorContains: "unsupported protocol version",
102+
},
103+
{
104+
name: "malformed protocol version - partial date",
105+
serverVersion: "2025-06",
106+
expectError: true,
107+
errorContains: "unsupported protocol version",
108+
},
109+
{
110+
name: "malformed protocol version - just numbers",
111+
serverVersion: "20250618",
112+
expectError: true,
113+
errorContains: "unsupported protocol version",
114+
},
115+
}
116+
117+
for _, tt := range tests {
118+
t.Run(tt.name, func(t *testing.T) {
119+
// Create mock transport that returns specific version
120+
mockTransport := &mockProtocolTransport{
121+
responses: map[string]string{
122+
"initialize": fmt.Sprintf(`{
123+
"protocolVersion": "%s",
124+
"capabilities": {},
125+
"serverInfo": {"name": "test", "version": "1.0"}
126+
}`, tt.serverVersion),
127+
},
128+
}
129+
130+
client := NewClient(mockTransport)
131+
132+
_, err := client.Initialize(context.Background(), mcp.InitializeRequest{
133+
Params: mcp.InitializeParams{
134+
ProtocolVersion: mcp.LATEST_PROTOCOL_VERSION,
135+
ClientInfo: mcp.Implementation{Name: "test-client", Version: "1.0"},
136+
Capabilities: mcp.ClientCapabilities{},
137+
},
138+
})
139+
140+
if tt.expectError {
141+
if err == nil {
142+
t.Errorf("expected error but got none")
143+
} else if !strings.Contains(err.Error(), tt.errorContains) {
144+
t.Errorf("expected error containing %q, got %q", tt.errorContains, err.Error())
145+
}
146+
// Verify it's the correct error type
147+
if !mcp.IsUnsupportedProtocolVersion(err) {
148+
t.Errorf("expected UnsupportedProtocolVersionError, got %T", err)
149+
}
150+
} else {
151+
if err != nil {
152+
t.Errorf("unexpected error: %v", err)
153+
}
154+
// Verify the protocol version was stored
155+
if client.protocolVersion != tt.serverVersion {
156+
t.Errorf("expected protocol version %q, got %q", tt.serverVersion, client.protocolVersion)
157+
}
158+
}
159+
})
160+
}
161+
}
162+
163+
// mockHTTPTransport implements both transport.Interface and transport.HTTPConnection
164+
type mockHTTPTransport struct {
165+
mockProtocolTransport
166+
protocolVersion string
167+
}
168+
169+
func (m *mockHTTPTransport) SetProtocolVersion(version string) {
170+
m.protocolVersion = version
171+
}
172+
173+
func TestProtocolVersionHeaderSetting(t *testing.T) {
174+
// Create mock HTTP transport
175+
mockTransport := &mockHTTPTransport{
176+
mockProtocolTransport: mockProtocolTransport{
177+
responses: map[string]string{
178+
"initialize": fmt.Sprintf(`{
179+
"protocolVersion": "%s",
180+
"capabilities": {},
181+
"serverInfo": {"name": "test", "version": "1.0"}
182+
}`, mcp.LATEST_PROTOCOL_VERSION),
183+
},
184+
},
185+
}
186+
187+
client := NewClient(mockTransport)
188+
189+
_, err := client.Initialize(context.Background(), mcp.InitializeRequest{
190+
Params: mcp.InitializeParams{
191+
ProtocolVersion: mcp.LATEST_PROTOCOL_VERSION,
192+
ClientInfo: mcp.Implementation{Name: "test-client", Version: "1.0"},
193+
Capabilities: mcp.ClientCapabilities{},
194+
},
195+
})
196+
197+
if err != nil {
198+
t.Fatalf("unexpected error: %v", err)
199+
}
200+
201+
// Verify SetProtocolVersion was called on HTTP transport
202+
if mockTransport.protocolVersion != mcp.LATEST_PROTOCOL_VERSION {
203+
t.Errorf("expected SetProtocolVersion to be called with %q, got %q",
204+
mcp.LATEST_PROTOCOL_VERSION, mockTransport.protocolVersion)
205+
}
206+
}
207+
208+
func TestUnsupportedProtocolVersionError_Is(t *testing.T) {
209+
// Test that errors.Is works correctly with UnsupportedProtocolVersionError
210+
err1 := mcp.UnsupportedProtocolVersionError{Version: "2023-01-01"}
211+
err2 := mcp.UnsupportedProtocolVersionError{Version: "2024-01-01"}
212+
213+
// Test Is method
214+
if !err1.Is(err2) {
215+
t.Error("expected UnsupportedProtocolVersionError.Is to return true for same error type")
216+
}
217+
218+
// Test with different error type
219+
otherErr := fmt.Errorf("some other error")
220+
if err1.Is(otherErr) {
221+
t.Error("expected UnsupportedProtocolVersionError.Is to return false for different error type")
222+
}
223+
224+
// Test IsUnsupportedProtocolVersion helper
225+
if !mcp.IsUnsupportedProtocolVersion(err1) {
226+
t.Error("expected IsUnsupportedProtocolVersion to return true")
227+
}
228+
if mcp.IsUnsupportedProtocolVersion(otherErr) {
229+
t.Error("expected IsUnsupportedProtocolVersion to return false for different error type")
230+
}
231+
}

client/stdio_test.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -93,7 +93,7 @@ func TestStdioMCPClient(t *testing.T) {
9393
defer cancel()
9494

9595
request := mcp.InitializeRequest{}
96-
request.Params.ProtocolVersion = "1.0"
96+
request.Params.ProtocolVersion = mcp.LATEST_PROTOCOL_VERSION
9797
request.Params.ClientInfo = mcp.Implementation{
9898
Name: "test-client",
9999
Version: "1.0.0",

client/transport/constants.go

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,7 @@
1+
package transport
2+
3+
// Common HTTP header constants used across transports
4+
const (
5+
HeaderKeySessionID = "Mcp-Session-Id"
6+
HeaderKeyProtocolVersion = "Mcp-Protocol-Version"
7+
)

client/transport/interface.go

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -47,6 +47,13 @@ type BidirectionalInterface interface {
4747
SetRequestHandler(handler RequestHandler)
4848
}
4949

50+
// HTTPConnection is a Transport that runs over HTTP and supports
51+
// protocol version headers.
52+
type HTTPConnection interface {
53+
Interface
54+
SetProtocolVersion(version string)
55+
}
56+
5057
type JSONRPCRequest struct {
5158
JSONRPC string `json:"jsonrpc"`
5259
ID mcp.RequestId `json:"id"`

client/transport/sse.go

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,7 @@ type SSE struct {
3737
started atomic.Bool
3838
closed atomic.Bool
3939
cancelSSEStream context.CancelFunc
40+
protocolVersion atomic.Value // string
4041
onConnectionLost func(error)
4142
connectionLostMu sync.RWMutex
4243

@@ -345,6 +346,12 @@ func (c *SSE) SendRequest(
345346

346347
// Set headers
347348
req.Header.Set("Content-Type", "application/json")
349+
// Set protocol version header if negotiated
350+
if v := c.protocolVersion.Load(); v != nil {
351+
if version, ok := v.(string); ok && version != "" {
352+
req.Header.Set(HeaderKeyProtocolVersion, version)
353+
}
354+
}
348355
for k, v := range c.headers {
349356
req.Header.Set(k, v)
350357
}
@@ -455,6 +462,11 @@ func (c *SSE) GetSessionId() string {
455462
return ""
456463
}
457464

465+
// SetProtocolVersion sets the negotiated protocol version for this connection.
466+
func (c *SSE) SetProtocolVersion(version string) {
467+
c.protocolVersion.Store(version)
468+
}
469+
458470
// SendNotification sends a JSON-RPC notification to the server without expecting a response.
459471
func (c *SSE) SendNotification(ctx context.Context, notification mcp.JSONRPCNotification) error {
460472
if c.endpoint == nil {
@@ -477,6 +489,12 @@ func (c *SSE) SendNotification(ctx context.Context, notification mcp.JSONRPCNoti
477489
}
478490

479491
req.Header.Set("Content-Type", "application/json")
492+
// Set protocol version header if negotiated
493+
if v := c.protocolVersion.Load(); v != nil {
494+
if version, ok := v.(string); ok && version != "" {
495+
req.Header.Set(HeaderKeyProtocolVersion, version)
496+
}
497+
}
480498
// Set custom HTTP headers
481499
for k, v := range c.headers {
482500
req.Header.Set(k, v)

client/transport/sse_test.go

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -443,9 +443,9 @@ func TestSSE(t *testing.T) {
443443
t.Run("SSEEventWithoutEventField", func(t *testing.T) {
444444
// Test that SSE events with only data field (no event field) are processed correctly
445445
// This tests the fix for issue #369
446-
446+
447447
var messageReceived chan struct{}
448-
448+
449449
// Create a custom mock server that sends SSE events without event field
450450
sseHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
451451
w.Header().Set("Content-Type", "text/event-stream")
@@ -484,7 +484,7 @@ func TestSSE(t *testing.T) {
484484
messageHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
485485
w.Header().Set("Content-Type", "application/json")
486486
w.WriteHeader(http.StatusAccepted)
487-
487+
488488
// Signal that message was received
489489
close(messageReceived)
490490
})

0 commit comments

Comments
 (0)