diff --git a/server/server.go b/server/server.go index b9fb3612b..4e60e7b04 100644 --- a/server/server.go +++ b/server/server.go @@ -171,6 +171,19 @@ type MCPServer struct { paginationLimit *int sessions sync.Map hooks *Hooks + + // custom handlers for basic methods + InitializeHandler func(ctx context.Context, request mcp.InitializeRequest) (*mcp.InitializeResult, error) + PingHandler func(ctx context.Context, request mcp.PingRequest) (*mcp.EmptyResult, error) + ListResourcesHandler func(ctx context.Context, request mcp.ListResourcesRequest) (*mcp.ListResourcesResult, error) + ListResourceTemplatesHandler func(ctx context.Context, request mcp.ListResourceTemplatesRequest) (*mcp.ListResourceTemplatesResult, error) + ReadResourceHandler func(ctx context.Context, request mcp.ReadResourceRequest) (*mcp.ReadResourceResult, error) + ListPromptsHandler func(ctx context.Context, request mcp.ListPromptsRequest) (*mcp.ListPromptsResult, error) + GetPromptHandler func(ctx context.Context, request mcp.GetPromptRequest) (*mcp.GetPromptResult, error) + ListToolsHandler func(ctx context.Context, request mcp.ListToolsRequest) (*mcp.ListToolsResult, error) + CallToolHandler func(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) + SetLevelHandler func(ctx context.Context, request mcp.SetLevelRequest) (*mcp.EmptyResult, error) + NotificationHandler func(ctx context.Context, notification mcp.JSONRPCNotification) } // WithPaginationLimit sets the pagination limit for the server. @@ -647,9 +660,21 @@ func (s *MCPServer) AddNotificationHandler( func (s *MCPServer) handleInitialize( ctx context.Context, - _ any, + id any, request mcp.InitializeRequest, ) (*mcp.InitializeResult, *requestError) { + if s.InitializeHandler != nil { + result, err := s.InitializeHandler(ctx, request) + if err != nil { + return nil, &requestError{ + id: id, + code: mcp.INTERNAL_ERROR, + err: err, + } + } + return result, nil + } + capabilities := mcp.ServerCapabilities{} // Only add resource capabilities if they're configured @@ -729,10 +754,21 @@ func (s *MCPServer) protocolVersion(clientVersion string) string { } func (s *MCPServer) handlePing( - _ context.Context, - _ any, - _ mcp.PingRequest, + ctx context.Context, + id any, + request mcp.PingRequest, ) (*mcp.EmptyResult, *requestError) { + if s.PingHandler != nil { + result, err := s.PingHandler(ctx, request) + if err != nil { + return nil, &requestError{ + id: id, + code: mcp.INTERNAL_ERROR, + err: err, + } + } + return result, nil + } return &mcp.EmptyResult{}, nil } @@ -741,6 +777,18 @@ func (s *MCPServer) handleSetLevel( id any, request mcp.SetLevelRequest, ) (*mcp.EmptyResult, *requestError) { + if s.SetLevelHandler != nil { + result, err := s.SetLevelHandler(ctx, request) + if err != nil { + return nil, &requestError{ + id: id, + code: mcp.INTERNAL_ERROR, + err: err, + } + } + return result, nil + } + clientSession := ClientSessionFromContext(ctx) if clientSession == nil || !clientSession.Initialized() { return nil, &requestError{ @@ -820,6 +868,18 @@ func (s *MCPServer) handleListResources( id any, request mcp.ListResourcesRequest, ) (*mcp.ListResourcesResult, *requestError) { + if s.ListResourcesHandler != nil { + result, err := s.ListResourcesHandler(ctx, request) + if err != nil { + return nil, &requestError{ + id: id, + code: mcp.INTERNAL_ERROR, + err: err, + } + } + return result, nil + } + s.resourcesMu.RLock() resources := make([]mcp.Resource, 0, len(s.resources)) for _, entry := range s.resources { @@ -858,6 +918,18 @@ func (s *MCPServer) handleListResourceTemplates( id any, request mcp.ListResourceTemplatesRequest, ) (*mcp.ListResourceTemplatesResult, *requestError) { + if s.ListResourceTemplatesHandler != nil { + result, err := s.ListResourceTemplatesHandler(ctx, request) + if err != nil { + return nil, &requestError{ + id: id, + code: mcp.INTERNAL_ERROR, + err: err, + } + } + return result, nil + } + s.resourcesMu.RLock() templates := make([]mcp.ResourceTemplate, 0, len(s.resourceTemplates)) for _, entry := range s.resourceTemplates { @@ -894,6 +966,18 @@ func (s *MCPServer) handleReadResource( id any, request mcp.ReadResourceRequest, ) (*mcp.ReadResourceResult, *requestError) { + if s.ReadResourceHandler != nil { + result, err := s.ReadResourceHandler(ctx, request) + if err != nil { + return nil, &requestError{ + id: id, + code: mcp.INTERNAL_ERROR, + err: err, + } + } + return result, nil + } + s.resourcesMu.RLock() // First try direct resource handlers if entry, ok := s.resources[request.Params.URI]; ok { @@ -972,6 +1056,18 @@ func (s *MCPServer) handleListPrompts( id any, request mcp.ListPromptsRequest, ) (*mcp.ListPromptsResult, *requestError) { + if s.ListPromptsHandler != nil { + result, err := s.ListPromptsHandler(ctx, request) + if err != nil { + return nil, &requestError{ + id: id, + code: mcp.INTERNAL_ERROR, + err: err, + } + } + return result, nil + } + s.promptsMu.RLock() prompts := make([]mcp.Prompt, 0, len(s.prompts)) for _, prompt := range s.prompts { @@ -1010,6 +1106,18 @@ func (s *MCPServer) handleGetPrompt( id any, request mcp.GetPromptRequest, ) (*mcp.GetPromptResult, *requestError) { + if s.GetPromptHandler != nil { + result, err := s.GetPromptHandler(ctx, request) + if err != nil { + return nil, &requestError{ + id: id, + code: mcp.INTERNAL_ERROR, + err: err, + } + } + return result, nil + } + s.promptsMu.RLock() handler, ok := s.promptHandlers[request.Params.Name] s.promptsMu.RUnlock() @@ -1039,6 +1147,17 @@ func (s *MCPServer) handleListTools( id any, request mcp.ListToolsRequest, ) (*mcp.ListToolsResult, *requestError) { + if s.ListToolsHandler != nil { + result, err := s.ListToolsHandler(ctx, request) + if err != nil { + return nil, &requestError{ + id: id, + code: mcp.INTERNAL_ERROR, + err: err, + } + } + return result, nil + } // Get the base tools from the server s.toolsMu.RLock() tools := make([]mcp.Tool, 0, len(s.tools)) @@ -1129,6 +1248,17 @@ func (s *MCPServer) handleToolCall( id any, request mcp.CallToolRequest, ) (*mcp.CallToolResult, *requestError) { + if s.CallToolHandler != nil { + result, err := s.CallToolHandler(ctx, request) + if err != nil { + return nil, &requestError{ + id: id, + code: mcp.INTERNAL_ERROR, + err: err, + } + } + return result, nil + } // First check session-specific tools var tool ServerTool var ok bool @@ -1188,6 +1318,10 @@ func (s *MCPServer) handleNotification( ctx context.Context, notification mcp.JSONRPCNotification, ) mcp.JSONRPCMessage { + if s.NotificationHandler != nil { + s.NotificationHandler(ctx, notification) + return nil + } s.notificationHandlersMu.RLock() handler, ok := s.notificationHandlers[notification.Method] s.notificationHandlersMu.RUnlock() diff --git a/server/sse.go b/server/sse.go index 9c9766cf3..b5ee2ff20 100644 --- a/server/sse.go +++ b/server/sse.go @@ -473,8 +473,12 @@ func (s *SSEServer) GetMessageEndpointForClient(r *http.Request, sessionID strin if s.useFullURLForMessageEndpoint && s.baseURL != "" { endpointPath = s.baseURL + endpointPath } - - return fmt.Sprintf("%s?sessionId=%s", endpointPath, sessionID) + if strings.Contains(endpointPath, "?") { + endpointPath += "&" + } else { + endpointPath += "?" + } + return fmt.Sprintf("%ssessionId=%s", endpointPath, sessionID) } // handleMessage processes incoming JSON-RPC messages from clients and sends responses