Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
125 changes: 49 additions & 76 deletions examples/everything/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -28,42 +28,30 @@ const (
COMPLEX PromptName = "complex_prompt"
)

type MCPServer struct {
server *server.MCPServer
subscriptions map[string]bool
updateTicker *time.Ticker
allResources []mcp.Resource
}

func NewMCPServer() *MCPServer {
s := &MCPServer{
server: server.NewMCPServer(
"example-servers/everything",
"1.0.0",
server.WithResourceCapabilities(true, true),
server.WithPromptCapabilities(true),
server.WithLogging(),
),
subscriptions: make(map[string]bool),
updateTicker: time.NewTicker(5 * time.Second),
allResources: generateResources(),
}
func NewMCPServer() *server.MCPServer {
mcpServer := server.NewMCPServer(
"example-servers/everything",
"1.0.0",
server.WithResourceCapabilities(true, true),
server.WithPromptCapabilities(true),
server.WithLogging(),
)

s.server.AddResource(mcp.NewResource("test://static/resource",
mcpServer.AddResource(mcp.NewResource("test://static/resource",
"Static Resource",
mcp.WithMIMEType("text/plain"),
), s.handleReadResource)
s.server.AddResourceTemplate(
), handleReadResource)
mcpServer.AddResourceTemplate(
mcp.NewResourceTemplate(
"test://dynamic/resource/{id}",
"Dynamic Resource",
),
s.handleResourceTemplate,
handleResourceTemplate,
)
s.server.AddPrompt(mcp.NewPrompt(string(SIMPLE),
mcpServer.AddPrompt(mcp.NewPrompt(string(SIMPLE),
mcp.WithPromptDescription("A simple prompt"),
), s.handleSimplePrompt)
s.server.AddPrompt(mcp.NewPrompt(string(COMPLEX),
), handleSimplePrompt)
mcpServer.AddPrompt(mcp.NewPrompt(string(COMPLEX),
mcp.WithPromptDescription("A complex prompt"),
mcp.WithArgument("temperature",
mcp.ArgumentDescription("The temperature parameter for generation"),
Expand All @@ -73,21 +61,21 @@ func NewMCPServer() *MCPServer {
mcp.ArgumentDescription("The style to use for the response"),
mcp.RequiredArgument(),
),
), s.handleComplexPrompt)
s.server.AddTool(mcp.NewTool(string(ECHO),
), handleComplexPrompt)
mcpServer.AddTool(mcp.NewTool(string(ECHO),
mcp.WithDescription("Echoes back the input"),
mcp.WithString("message",
mcp.Description("Message to echo"),
mcp.Required(),
),
), s.handleEchoTool)
), handleEchoTool)

s.server.AddTool(
mcpServer.AddTool(
mcp.NewTool("notify"),
s.handleSendNotification,
handleSendNotification,
)

s.server.AddTool(mcp.NewTool(string(ADD),
mcpServer.AddTool(mcp.NewTool(string(ADD),
mcp.WithDescription("Adds two numbers"),
mcp.WithNumber("a",
mcp.Description("First number"),
Expand All @@ -97,8 +85,8 @@ func NewMCPServer() *MCPServer {
mcp.Description("Second number"),
mcp.Required(),
),
), s.handleAddTool)
s.server.AddTool(mcp.NewTool(
), handleAddTool)
mcpServer.AddTool(mcp.NewTool(
string(LONG_RUNNING_OPERATION),
mcp.WithDescription(
"Demonstrates a long running operation with progress updates",
Expand All @@ -111,7 +99,7 @@ func NewMCPServer() *MCPServer {
mcp.Description("Number of steps in the operation"),
mcp.DefaultNumber(5),
),
), s.handleLongRunningOperationTool)
), handleLongRunningOperationTool)

// s.server.AddTool(mcp.Tool{
// Name: string(SAMPLE_LLM),
Expand All @@ -131,15 +119,13 @@ func NewMCPServer() *MCPServer {
// },
// },
// }, s.handleSampleLLMTool)
s.server.AddTool(mcp.NewTool(string(GET_TINY_IMAGE),
mcpServer.AddTool(mcp.NewTool(string(GET_TINY_IMAGE),
mcp.WithDescription("Returns the MCP_TINY_IMAGE"),
), s.handleGetTinyImageTool)

s.server.AddNotificationHandler("notification", s.handleNotification)
), handleGetTinyImageTool)

go s.runUpdateInterval()
mcpServer.AddNotificationHandler("notification", handleNotification)

return s
return mcpServer
}

func generateResources() []mcp.Resource {
Expand All @@ -163,7 +149,7 @@ func generateResources() []mcp.Resource {
return resources
}

func (s *MCPServer) runUpdateInterval() {
func runUpdateInterval() {
// for range s.updateTicker.C {
// for uri := range s.subscriptions {
// s.server.HandleMessage(
Expand All @@ -184,7 +170,7 @@ func (s *MCPServer) runUpdateInterval() {
// }
}

func (s *MCPServer) handleReadResource(
func handleReadResource(
ctx context.Context,
request mcp.ReadResourceRequest,
) ([]interface{}, error) {
Expand All @@ -199,7 +185,7 @@ func (s *MCPServer) handleReadResource(
}, nil
}

func (s *MCPServer) handleResourceTemplate(
func handleResourceTemplate(
ctx context.Context,
request mcp.ReadResourceRequest,
) ([]interface{}, error) {
Expand All @@ -214,7 +200,7 @@ func (s *MCPServer) handleResourceTemplate(
}, nil
}

func (s *MCPServer) handleSimplePrompt(
func handleSimplePrompt(
ctx context.Context,
request mcp.GetPromptRequest,
) (*mcp.GetPromptResult, error) {
Expand All @@ -232,7 +218,7 @@ func (s *MCPServer) handleSimplePrompt(
}, nil
}

func (s *MCPServer) handleComplexPrompt(
func handleComplexPrompt(
ctx context.Context,
request mcp.GetPromptRequest,
) (*mcp.GetPromptResult, error) {
Expand Down Expand Up @@ -270,7 +256,7 @@ func (s *MCPServer) handleComplexPrompt(
}, nil
}

func (s *MCPServer) handleEchoTool(
func handleEchoTool(
ctx context.Context,
request mcp.CallToolRequest,
) (*mcp.CallToolResult, error) {
Expand All @@ -289,7 +275,7 @@ func (s *MCPServer) handleEchoTool(
}, nil
}

func (s *MCPServer) handleAddTool(
func handleAddTool(
ctx context.Context,
request mcp.CallToolRequest,
) (*mcp.CallToolResult, error) {
Expand All @@ -310,7 +296,7 @@ func (s *MCPServer) handleAddTool(
}, nil
}

func (s *MCPServer) handleSendNotification(
func handleSendNotification(
ctx context.Context,
request mcp.CallToolRequest,
) (*mcp.CallToolResult, error) {
Expand Down Expand Up @@ -339,15 +325,11 @@ func (s *MCPServer) handleSendNotification(
}, nil
}

func (s *MCPServer) ServeSSE(addr string) *server.SSEServer {
return server.NewSSEServer(s.server, fmt.Sprintf("http://%s", addr))
}

func (s *MCPServer) ServeStdio() error {
return server.ServeStdio(s.server)
func ServeSSE(mcpServer *server.MCPServer, addr string) *server.SSEServer {
return server.NewSSEServer(mcpServer, fmt.Sprintf("http://%s", addr))
}

func (s *MCPServer) handleLongRunningOperationTool(
func handleLongRunningOperationTool(
ctx context.Context,
request mcp.CallToolRequest,
) (*mcp.CallToolResult, error) {
Expand Down Expand Up @@ -407,7 +389,7 @@ func (s *MCPServer) handleLongRunningOperationTool(
// }, nil
// }

func (s *MCPServer) handleGetTinyImageTool(
func handleGetTinyImageTool(
ctx context.Context,
request mcp.CallToolRequest,
) (*mcp.CallToolResult, error) {
Expand All @@ -430,17 +412,13 @@ func (s *MCPServer) handleGetTinyImageTool(
}, nil
}

func (s *MCPServer) handleNotification(
func handleNotification(
ctx context.Context,
notification mcp.JSONRPCNotification,
) {
log.Printf("Received notification: %s", notification.Method)
}

func (s *MCPServer) Serve() error {
return server.ServeStdio(s.server)
}

func main() {
var transport string
flag.StringVar(&transport, "t", "stdio", "Transport type (stdio or sse)")
Expand All @@ -452,24 +430,19 @@ func main() {
)
flag.Parse()

server := NewMCPServer()
mcpServer := NewMCPServer()

switch transport {
case "stdio":
if err := server.ServeStdio(); err != nil {
log.Fatalf("Server error: %v", err)
}
case "sse":
sseServer := server.ServeSSE("localhost:8080")
// Only check for "sse" since stdio is the default
if transport == "sse" {
sseServer := ServeSSE(mcpServer, "localhost:8080")
log.Printf("SSE server listening on :8080")
if err := sseServer.Start(":8080"); err != nil {
log.Fatalf("Server error: %v", err)
}
default:
log.Fatalf(
"Invalid transport type: %s. Must be 'stdio' or 'sse'",
transport,
)
} else {
if err := server.ServeStdio(mcpServer); err != nil {
log.Fatalf("Server error: %v", err)
}
}
}

Expand Down
12 changes: 12 additions & 0 deletions server/sse.go
Original file line number Diff line number Diff line change
Expand Up @@ -244,3 +244,15 @@ func (s *SSEServer) SendEventToSession(
return nil
}
}

// ServeHTTP implements the http.Handler interface.
func (s *SSEServer) ServeHTTP(w http.ResponseWriter, r *http.Request) {
switch r.URL.Path {
case "/sse":
s.handleSSE(w, r)
case "/message":
s.handleMessage(w, r)
default:
http.NotFound(w, r)
}
}
Loading