diff --git a/examples/custom_sse_pattern/main.go b/examples/custom_sse_pattern/main.go new file mode 100644 index 000000000..58f5788cd --- /dev/null +++ b/examples/custom_sse_pattern/main.go @@ -0,0 +1,57 @@ +package main + +import ( + "context" + "fmt" + "log" + "net/http" + + "github.com/mark3labs/mcp-go/mcp" + "github.com/mark3labs/mcp-go/server" +) + +// Custom context function for SSE connections +func customContextFunc(ctx context.Context, r *http.Request) context.Context { + params := server.GetRouteParams(ctx) + log.Printf("SSE Connection Established - Route Parameters: %+v", params) + log.Printf("Request Path: %s", r.URL.Path) + return ctx +} + +// Message handler for simulating message sending +func messageHandler(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) { + // Get channel parameter from context + channel := server.GetRouteParam(ctx, "channel") + log.Printf("Processing Message - Channel Parameter: %s", channel) + + if channel == "" { + return mcp.NewToolResultText("Failed to get channel parameter"), nil + } + + message := fmt.Sprintf("Message sent to channel: %s", channel) + return mcp.NewToolResultText(message), nil +} + +func main() { + // Create MCP Server + mcpServer := server.NewMCPServer("test-server", "1.0.0") + + // Register test tool + mcpServer.AddTool(mcp.NewTool("send_message"), messageHandler) + + // Create SSE Server with custom route pattern + sseServer := server.NewSSEServer(mcpServer, + server.WithBaseURL("http://localhost:8080"), + server.WithSSEPattern("/:channel/sse"), + server.WithSSEContextFunc(customContextFunc), + ) + + // Start server + log.Printf("Server started on port :8080") + log.Printf("Test URL: http://localhost:8080/test/sse") + log.Printf("Test URL: http://localhost:8080/news/sse") + + if err := sseServer.Start(":8080"); err != nil { + log.Fatalf("Server error: %v", err) + } +} diff --git a/server/sse.go b/server/sse.go index f69451c6d..a5c30c173 100644 --- a/server/sse.go +++ b/server/sse.go @@ -25,6 +25,7 @@ type sseSession struct { sessionID string notificationChannel chan mcp.JSONRPCNotification initialized atomic.Bool + routeParams RouteParams // Store route parameters in session } // SSEContextFunc is a function that takes an existing context and the current @@ -32,6 +33,28 @@ type sseSession struct { // content. This can be used to inject context values from headers, for example. type SSEContextFunc func(ctx context.Context, r *http.Request) context.Context +// RouteParamsKey is the key type for storing route parameters in context +type RouteParamsKey struct{} + +// RouteParams stores path parameters +type RouteParams map[string]string + +// GetRouteParam retrieves a route parameter from context +func GetRouteParam(ctx context.Context, key string) string { + if params, ok := ctx.Value(RouteParamsKey{}).(RouteParams); ok { + return params[key] + } + return "" +} + +// GetRouteParams retrieves all route parameters from context +func GetRouteParams(ctx context.Context) RouteParams { + if params, ok := ctx.Value(RouteParamsKey{}).(RouteParams); ok { + return params + } + return RouteParams{} +} + func (s *sseSession) SessionID() string { return s.sessionID } @@ -53,18 +76,18 @@ var _ ClientSession = (*sseSession)(nil) // SSEServer implements a Server-Sent Events (SSE) based MCP server. // It provides real-time communication capabilities over HTTP using the SSE protocol. type SSEServer struct { - server *MCPServer - baseURL string - basePath string - useFullURLForMessageEndpoint bool - messageEndpoint string - sseEndpoint string - sessions sync.Map - srv *http.Server - contextFunc SSEContextFunc - - keepAlive bool - keepAliveInterval time.Duration + server *MCPServer + baseURL string + basePath string + useFullURLForMessageEndpoint bool + messageEndpoint string + sseEndpoint string + ssePattern string + sessions sync.Map + srv *http.Server + contextFunc SSEContextFunc + keepAlive bool + keepAliveInterval time.Duration } // SSEOption defines a function type for configuring SSEServer @@ -127,6 +150,13 @@ func WithSSEEndpoint(endpoint string) SSEOption { } } +// WithSSEPattern sets the SSE endpoint pattern with route parameters +func WithSSEPattern(pattern string) SSEOption { + return func(s *SSEServer) { + s.ssePattern = pattern + } +} + // WithHTTPServer sets the HTTP server instance func WithHTTPServer(srv *http.Server) SSEOption { return func(s *SSEServer) { @@ -147,8 +177,7 @@ func WithKeepAlive(keepAlive bool) SSEOption { } } -// WithContextFunc sets a function that will be called to customise the context -// to the server using the incoming request. +// WithSSEContextFunc sets a function that will be called to customise the context func WithSSEContextFunc(fn SSEContextFunc) SSEOption { return func(s *SSEServer) { s.contextFunc = fn @@ -158,12 +187,12 @@ func WithSSEContextFunc(fn SSEContextFunc) SSEOption { // NewSSEServer creates a new SSE server instance with the given MCP server and options. func NewSSEServer(server *MCPServer, opts ...SSEOption) *SSEServer { s := &SSEServer{ - server: server, - sseEndpoint: "/sse", - messageEndpoint: "/message", - useFullURLForMessageEndpoint: true, - keepAlive: false, - keepAliveInterval: 10 * time.Second, + server: server, + sseEndpoint: "/sse", + messageEndpoint: "/message", + useFullURLForMessageEndpoint: true, + keepAlive: false, + keepAliveInterval: 10 * time.Second, } // Apply all options @@ -241,12 +270,21 @@ func (s *SSEServer) handleSSE(w http.ResponseWriter, r *http.Request) { eventQueue: make(chan string, 100), // Buffer for events sessionID: sessionID, notificationChannel: make(chan mcp.JSONRPCNotification, 100), + routeParams: GetRouteParams(r.Context()), // Store route parameters from context } s.sessions.Store(sessionID, session) defer s.sessions.Delete(sessionID) - if err := s.server.RegisterSession(r.Context(), session); err != nil { + // Create base context with session + ctx := s.server.WithContext(r.Context(), session) + + // Apply custom context function if set + if s.contextFunc != nil { + ctx = s.contextFunc(ctx, r) + } + + if err := s.server.RegisterSession(ctx, session); err != nil { http.Error(w, fmt.Sprintf("Session registration failed: %v", err), http.StatusInternalServerError) return } @@ -268,7 +306,7 @@ func (s *SSEServer) handleSSE(w http.ResponseWriter, r *http.Request) { } case <-session.done: return - case <-r.Context().Done(): + case <-ctx.Done(): return } } @@ -286,14 +324,13 @@ func (s *SSEServer) handleSSE(w http.ResponseWriter, r *http.Request) { session.eventQueue <- fmt.Sprintf(":ping - %s\n\n", time.Now().Format(time.RFC3339)) case <-session.done: return - case <-r.Context().Done(): + case <-ctx.Done(): return } } }() } - // Send the initial endpoint event fmt.Fprintf(w, "event: endpoint\ndata: %s\r\n\r\n", s.GetMessageEndpointForClient(sessionID)) flusher.Flush() @@ -305,7 +342,7 @@ func (s *SSEServer) handleSSE(w http.ResponseWriter, r *http.Request) { // Write the event to the response fmt.Fprint(w, event) flusher.Flush() - case <-r.Context().Done(): + case <-ctx.Done(): close(session.done) return } @@ -343,8 +380,15 @@ func (s *SSEServer) handleMessage(w http.ResponseWriter, r *http.Request) { } session := sessionI.(*sseSession) - // Set the client context before handling the message + // Create base context with session ctx := s.server.WithContext(r.Context(), session) + + // Add stored route parameters to context + if len(session.routeParams) > 0 { + ctx = context.WithValue(ctx, RouteParamsKey{}, session.routeParams) + } + + // Apply custom context function if set if s.contextFunc != nil { ctx = s.contextFunc(ctx, r) } @@ -356,7 +400,7 @@ func (s *SSEServer) handleMessage(w http.ResponseWriter, r *http.Request) { return } - // Process message through MCPServer + // Process message through MCPServer with the context containing route parameters response := s.server.HandleMessage(ctx, rawMessage) // Only send response if there is one (not for notifications) @@ -423,6 +467,7 @@ func (s *SSEServer) SendEventToSession( return fmt.Errorf("event queue full") } } + func (s *SSEServer) GetUrlPath(input string) (string, error) { parse, err := url.Parse(input) if err != nil { @@ -434,6 +479,7 @@ func (s *SSEServer) GetUrlPath(input string) (string, error) { func (s *SSEServer) CompleteSseEndpoint() string { return s.baseURL + s.basePath + s.sseEndpoint } + func (s *SSEServer) CompleteSsePath() string { path, err := s.GetUrlPath(s.CompleteSseEndpoint()) if err != nil { @@ -445,6 +491,7 @@ func (s *SSEServer) CompleteSsePath() string { func (s *SSEServer) CompleteMessageEndpoint() string { return s.baseURL + s.basePath + s.messageEndpoint } + func (s *SSEServer) CompleteMessagePath() string { path, err := s.GetUrlPath(s.CompleteMessageEndpoint()) if err != nil { @@ -456,17 +503,61 @@ func (s *SSEServer) CompleteMessagePath() string { // ServeHTTP implements the http.Handler interface. func (s *SSEServer) ServeHTTP(w http.ResponseWriter, r *http.Request) { path := r.URL.Path - // Use exact path matching rather than Contains - ssePath := s.CompleteSsePath() - if ssePath != "" && path == ssePath { - s.handleSSE(w, r) - return - } messagePath := s.CompleteMessagePath() + + // Handle message endpoint if messagePath != "" && path == messagePath { s.handleMessage(w, r) return } + // Handle SSE endpoint with route parameters + if s.ssePattern != "" { + // Try pattern matching if pattern is set + fullPattern := s.basePath + s.ssePattern + matches, params := matchPath(fullPattern, path) + if matches { + // Create new context with route parameters + ctx := context.WithValue(r.Context(), RouteParamsKey{}, params) + s.handleSSE(w, r.WithContext(ctx)) + return + } + // If pattern is set but doesn't match, return 404 + http.NotFound(w, r) + return + } + + // If no pattern is set, use the default SSE endpoint + ssePath := s.CompleteSsePath() + if ssePath != "" && path == ssePath { + s.handleSSE(w, r) + return + } + http.NotFound(w, r) } + +// matchPath checks if the given path matches the pattern and extracts parameters +// pattern format: /user/:id/profile/:type +func matchPath(pattern, path string) (bool, RouteParams) { + patternParts := strings.Split(strings.Trim(pattern, "/"), "/") + pathParts := strings.Split(strings.Trim(path, "/"), "/") + + if len(patternParts) != len(pathParts) { + return false, nil + } + + params := make(RouteParams) + for i, part := range patternParts { + if strings.HasPrefix(part, ":") { + // This is a parameter + paramName := strings.TrimPrefix(part, ":") + params[paramName] = pathParts[i] + } else if part != pathParts[i] { + // Static part doesn't match + return false, nil + } + } + + return true, params +}