Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
142 changes: 138 additions & 4 deletions server/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -171,6 +171,19 @@ type MCPServer struct {
paginationLimit *int
sessions sync.Map
hooks *Hooks

// custom handlers for basic methods
InitializeHandler func(ctx context.Context, request mcp.InitializeRequest) (*mcp.InitializeResult, error)
PingHandler func(ctx context.Context, request mcp.PingRequest) (*mcp.EmptyResult, error)
ListResourcesHandler func(ctx context.Context, request mcp.ListResourcesRequest) (*mcp.ListResourcesResult, error)
ListResourceTemplatesHandler func(ctx context.Context, request mcp.ListResourceTemplatesRequest) (*mcp.ListResourceTemplatesResult, error)
ReadResourceHandler func(ctx context.Context, request mcp.ReadResourceRequest) (*mcp.ReadResourceResult, error)
ListPromptsHandler func(ctx context.Context, request mcp.ListPromptsRequest) (*mcp.ListPromptsResult, error)
GetPromptHandler func(ctx context.Context, request mcp.GetPromptRequest) (*mcp.GetPromptResult, error)
ListToolsHandler func(ctx context.Context, request mcp.ListToolsRequest) (*mcp.ListToolsResult, error)
CallToolHandler func(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error)
SetLevelHandler func(ctx context.Context, request mcp.SetLevelRequest) (*mcp.EmptyResult, error)
NotificationHandler func(ctx context.Context, notification mcp.JSONRPCNotification)
}

// WithPaginationLimit sets the pagination limit for the server.
Expand Down Expand Up @@ -647,9 +660,21 @@ func (s *MCPServer) AddNotificationHandler(

func (s *MCPServer) handleInitialize(
ctx context.Context,
_ any,
id any,
request mcp.InitializeRequest,
) (*mcp.InitializeResult, *requestError) {
if s.InitializeHandler != nil {
result, err := s.InitializeHandler(ctx, request)
if err != nil {
return nil, &requestError{
id: id,
code: mcp.INTERNAL_ERROR,
err: err,
}
}
return result, nil
}

capabilities := mcp.ServerCapabilities{}

// Only add resource capabilities if they're configured
Expand Down Expand Up @@ -729,10 +754,21 @@ func (s *MCPServer) protocolVersion(clientVersion string) string {
}

func (s *MCPServer) handlePing(
_ context.Context,
_ any,
_ mcp.PingRequest,
ctx context.Context,
id any,
request mcp.PingRequest,
) (*mcp.EmptyResult, *requestError) {
if s.PingHandler != nil {
result, err := s.PingHandler(ctx, request)
if err != nil {
return nil, &requestError{
id: id,
code: mcp.INTERNAL_ERROR,
err: err,
}
}
return result, nil
}
return &mcp.EmptyResult{}, nil
}

Expand All @@ -741,6 +777,18 @@ func (s *MCPServer) handleSetLevel(
id any,
request mcp.SetLevelRequest,
) (*mcp.EmptyResult, *requestError) {
if s.SetLevelHandler != nil {
result, err := s.SetLevelHandler(ctx, request)
if err != nil {
return nil, &requestError{
id: id,
code: mcp.INTERNAL_ERROR,
err: err,
}
}
return result, nil
}

clientSession := ClientSessionFromContext(ctx)
if clientSession == nil || !clientSession.Initialized() {
return nil, &requestError{
Expand Down Expand Up @@ -820,6 +868,18 @@ func (s *MCPServer) handleListResources(
id any,
request mcp.ListResourcesRequest,
) (*mcp.ListResourcesResult, *requestError) {
if s.ListResourcesHandler != nil {
result, err := s.ListResourcesHandler(ctx, request)
if err != nil {
return nil, &requestError{
id: id,
code: mcp.INTERNAL_ERROR,
err: err,
}
}
return result, nil
}

s.resourcesMu.RLock()
resources := make([]mcp.Resource, 0, len(s.resources))
for _, entry := range s.resources {
Expand Down Expand Up @@ -858,6 +918,18 @@ func (s *MCPServer) handleListResourceTemplates(
id any,
request mcp.ListResourceTemplatesRequest,
) (*mcp.ListResourceTemplatesResult, *requestError) {
if s.ListResourceTemplatesHandler != nil {
result, err := s.ListResourceTemplatesHandler(ctx, request)
if err != nil {
return nil, &requestError{
id: id,
code: mcp.INTERNAL_ERROR,
err: err,
}
}
return result, nil
}

s.resourcesMu.RLock()
templates := make([]mcp.ResourceTemplate, 0, len(s.resourceTemplates))
for _, entry := range s.resourceTemplates {
Expand Down Expand Up @@ -894,6 +966,18 @@ func (s *MCPServer) handleReadResource(
id any,
request mcp.ReadResourceRequest,
) (*mcp.ReadResourceResult, *requestError) {
if s.ReadResourceHandler != nil {
result, err := s.ReadResourceHandler(ctx, request)
if err != nil {
return nil, &requestError{
id: id,
code: mcp.INTERNAL_ERROR,
err: err,
}
}
return result, nil
}

s.resourcesMu.RLock()
// First try direct resource handlers
if entry, ok := s.resources[request.Params.URI]; ok {
Expand Down Expand Up @@ -972,6 +1056,18 @@ func (s *MCPServer) handleListPrompts(
id any,
request mcp.ListPromptsRequest,
) (*mcp.ListPromptsResult, *requestError) {
if s.ListPromptsHandler != nil {
result, err := s.ListPromptsHandler(ctx, request)
if err != nil {
return nil, &requestError{
id: id,
code: mcp.INTERNAL_ERROR,
err: err,
}
}
return result, nil
}

s.promptsMu.RLock()
prompts := make([]mcp.Prompt, 0, len(s.prompts))
for _, prompt := range s.prompts {
Expand Down Expand Up @@ -1010,6 +1106,18 @@ func (s *MCPServer) handleGetPrompt(
id any,
request mcp.GetPromptRequest,
) (*mcp.GetPromptResult, *requestError) {
if s.GetPromptHandler != nil {
result, err := s.GetPromptHandler(ctx, request)
if err != nil {
return nil, &requestError{
id: id,
code: mcp.INTERNAL_ERROR,
err: err,
}
}
return result, nil
}

s.promptsMu.RLock()
handler, ok := s.promptHandlers[request.Params.Name]
s.promptsMu.RUnlock()
Expand Down Expand Up @@ -1039,6 +1147,17 @@ func (s *MCPServer) handleListTools(
id any,
request mcp.ListToolsRequest,
) (*mcp.ListToolsResult, *requestError) {
if s.ListToolsHandler != nil {
result, err := s.ListToolsHandler(ctx, request)
if err != nil {
return nil, &requestError{
id: id,
code: mcp.INTERNAL_ERROR,
err: err,
}
}
return result, nil
}
// Get the base tools from the server
s.toolsMu.RLock()
tools := make([]mcp.Tool, 0, len(s.tools))
Expand Down Expand Up @@ -1129,6 +1248,17 @@ func (s *MCPServer) handleToolCall(
id any,
request mcp.CallToolRequest,
) (*mcp.CallToolResult, *requestError) {
if s.CallToolHandler != nil {
result, err := s.CallToolHandler(ctx, request)
if err != nil {
return nil, &requestError{
id: id,
code: mcp.INTERNAL_ERROR,
err: err,
}
}
return result, nil
}
// First check session-specific tools
var tool ServerTool
var ok bool
Expand Down Expand Up @@ -1188,6 +1318,10 @@ func (s *MCPServer) handleNotification(
ctx context.Context,
notification mcp.JSONRPCNotification,
) mcp.JSONRPCMessage {
if s.NotificationHandler != nil {
s.NotificationHandler(ctx, notification)
return nil
}
s.notificationHandlersMu.RLock()
handler, ok := s.notificationHandlers[notification.Method]
s.notificationHandlersMu.RUnlock()
Expand Down
8 changes: 6 additions & 2 deletions server/sse.go
Original file line number Diff line number Diff line change
Expand Up @@ -473,8 +473,12 @@ func (s *SSEServer) GetMessageEndpointForClient(r *http.Request, sessionID strin
if s.useFullURLForMessageEndpoint && s.baseURL != "" {
endpointPath = s.baseURL + endpointPath
}

return fmt.Sprintf("%s?sessionId=%s", endpointPath, sessionID)
if strings.Contains(endpointPath, "?") {
endpointPath += "&"
} else {
endpointPath += "?"
}
return fmt.Sprintf("%ssessionId=%s", endpointPath, sessionID)
}

// handleMessage processes incoming JSON-RPC messages from clients and sends responses
Expand Down