Skip to content
Open
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -5,3 +5,4 @@
.claude
coverage.out
coverage.txt
.vscode/launch.json
72 changes: 72 additions & 0 deletions client/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ type Client struct {
serverCapabilities mcp.ServerCapabilities
protocolVersion string
samplingHandler SamplingHandler
rootsHandler RootsHandler
elicitationHandler ElicitationHandler
}

Expand All @@ -44,6 +45,15 @@ func WithSamplingHandler(handler SamplingHandler) ClientOption {
}
}

// WithRootsHandler sets the roots handler for the client.
// WithRootsHandler returns a ClientOption that sets the client's RootsHandler.
// When provided, the client will declare the roots capability (ListChanged) during initialization.
func WithRootsHandler(handler RootsHandler) ClientOption {
return func(c *Client) {
c.rootsHandler = handler
}
}

// WithElicitationHandler sets the elicitation handler for the client.
// When set, the client will declare elicitation capability during initialization.
func WithElicitationHandler(handler ElicitationHandler) ClientOption {
Expand Down Expand Up @@ -177,6 +187,13 @@ func (c *Client) Initialize(
if c.samplingHandler != nil {
capabilities.Sampling = &struct{}{}
}
if c.rootsHandler != nil {
capabilities.Roots = &struct {
ListChanged bool `json:"listChanged,omitempty"`
}{
ListChanged: true,
}
}
// Add elicitation capability if handler is configured
if c.elicitationHandler != nil {
capabilities.Elicitation = &struct{}{}
Expand Down Expand Up @@ -464,6 +481,28 @@ func (c *Client) Complete(
return &result, nil
}

// RootListChanges sends a roots list-changed notification to the server.
func (c *Client) RootListChanges(
ctx context.Context,
) error {
// Send root list changes notification
notification := mcp.JSONRPCNotification{
JSONRPC: mcp.JSONRPC_VERSION,
Notification: mcp.Notification{
Method: mcp.MethodNotificationRootsListChanged,
},
}

err := c.transport.SendNotification(ctx, notification)
if err != nil {
return fmt.Errorf(
"failed to send root list change notification: %w",
err,
)
}
return nil
}

// handleIncomingRequest processes incoming requests from the server.
// This is the main entry point for server-to-client requests like sampling and elicitation.
func (c *Client) handleIncomingRequest(ctx context.Context, request transport.JSONRPCRequest) (*transport.JSONRPCResponse, error) {
Expand All @@ -474,6 +513,8 @@ func (c *Client) handleIncomingRequest(ctx context.Context, request transport.JS
return c.handleElicitationRequestTransport(ctx, request)
case string(mcp.MethodPing):
return c.handlePingRequestTransport(ctx, request)
case string(mcp.MethodListRoots):
return c.handleListRootsRequestTransport(ctx, request)
default:
return nil, fmt.Errorf("unsupported request method: %s", request.Method)
}
Expand Down Expand Up @@ -536,6 +577,37 @@ func (c *Client) handleSamplingRequestTransport(ctx context.Context, request tra
return response, nil
}

// handleListRootsRequestTransport handles list roots requests at the transport level.
func (c *Client) handleListRootsRequestTransport(ctx context.Context, request transport.JSONRPCRequest) (*transport.JSONRPCResponse, error) {
if c.rootsHandler == nil {
return nil, fmt.Errorf("no roots handler configured")
}

// Create the MCP request
mcpRequest := mcp.ListRootsRequest{
Request: mcp.Request{
Method: string(mcp.MethodListRoots),
},
}

// Call the list roots handler
result, err := c.rootsHandler.ListRoots(ctx, mcpRequest)
if err != nil {
return nil, err
}

// Marshal the result
resultBytes, err := json.Marshal(result)
if err != nil {
return nil, fmt.Errorf("failed to marshal result: %w", err)
}

// Create the transport response
response := transport.NewJSONRPCResultResponse(request.ID, json.RawMessage(resultBytes))

return response, nil
}

// handleElicitationRequestTransport handles elicitation requests at the transport level.
func (c *Client) handleElicitationRequestTransport(ctx context.Context, request transport.JSONRPCRequest) (*transport.JSONRPCResponse, error) {
if c.elicitationHandler == nil {
Expand Down
17 changes: 17 additions & 0 deletions client/roots.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
package client

import (
"context"

"github.com/mark3labs/mcp-go/mcp"
)

// RootsHandler defines the interface for handling roots requests from servers.
// Clients can implement this interface to provide roots list to servers.
type RootsHandler interface {
// ListRoots handles a list root request from the server and returns the roots list.
// The implementation should:
// 1. Validate input against the requested schema
// 2. Return the appropriate response
ListRoots(ctx context.Context, request mcp.ListRootsRequest) (*mcp.ListRootsResult, error)
}
11 changes: 9 additions & 2 deletions client/transport/inprocess.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ type InProcessTransport struct {
server *server.MCPServer
samplingHandler server.SamplingHandler
elicitationHandler server.ElicitationHandler
rootsHandler server.RootsHandler
session *server.InProcessSession
sessionID string

Expand All @@ -37,6 +38,12 @@ func WithElicitationHandler(handler server.ElicitationHandler) InProcessOption {
}
}

func WithRootsHandler(handler server.RootsHandler) InProcessOption {
return func(t *InProcessTransport) {
t.rootsHandler = handler
}
}

func NewInProcessTransport(server *server.MCPServer) *InProcessTransport {
return &InProcessTransport{
server: server,
Expand Down Expand Up @@ -66,8 +73,8 @@ func (c *InProcessTransport) Start(ctx context.Context) error {
c.startedMu.Unlock()

// Create and register session if we have handlers
if c.samplingHandler != nil || c.elicitationHandler != nil {
c.session = server.NewInProcessSessionWithHandlers(c.sessionID, c.samplingHandler, c.elicitationHandler)
if c.samplingHandler != nil || c.elicitationHandler != nil || c.rootsHandler != nil {
c.session = server.NewInProcessSessionWithHandlers(c.sessionID, c.samplingHandler, c.elicitationHandler, c.rootsHandler)
if err := c.server.RegisterSession(ctx, c.session); err != nil {
c.startedMu.Lock()
c.started = false
Expand Down
151 changes: 151 additions & 0 deletions examples/roots_client/main.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,151 @@
package main

import (
"context"
"fmt"
"log"
"net/url"
"os"
"os/signal"
"path/filepath"
"syscall"

"github.com/mark3labs/mcp-go/client"
"github.com/mark3labs/mcp-go/client/transport"
"github.com/mark3labs/mcp-go/mcp"
)

// MockRootsHandler implements client.RootsHandler for demonstration.
// In a real implementation, this would enumerate workspace/project roots.
type MockRootsHandler struct{}

func (h *MockRootsHandler) ListRoots(ctx context.Context, request mcp.ListRootsRequest) (*mcp.ListRootsResult, error) {
home, err := os.UserHomeDir()
if err != nil {
log.Printf("Warning: failed to get home directory: %v", err)
home = "/tmp" // fallback for demonstration
}
app := filepath.ToSlash(filepath.Join(home, "app"))
proj := filepath.ToSlash(filepath.Join(home, "projects", "test-project"))
result := &mcp.ListRootsResult{
Roots: []mcp.Root{
{
Name: "app",
URI: (&url.URL{Scheme: "file", Path: app}).String(),
},
{
Name: "test-project",
URI: (&url.URL{Scheme: "file", Path: proj}).String(),
},
},
}
return result, nil
}

// main starts a mock MCP roots client that communicates with a subprocess over stdio.
// It expects the server command as the first command-line argument, creates a stdio
// transport and an MCP client with a MockRootsHandler, starts and initializes the
// client, logs server info and available tools, notifies the server of root list
// changes, invokes the "roots" tool and prints any text content returned, and
// shuts down the client gracefully on SIGINT or SIGTERM.
func main() {
if len(os.Args) < 2 {
log.Fatal("Usage: roots_client <server_command>")
}

serverCommand := os.Args[1]
serverArgs := os.Args[2:]

// Create stdio transport to communicate with the server
stdio := transport.NewStdio(serverCommand, nil, serverArgs...)

// Create roots handler
rootsHandler := &MockRootsHandler{}

// Create client with roots capability
mcpClient := client.NewClient(stdio, client.WithRootsHandler(rootsHandler))

ctx := context.Background()

// Start the client
if err := mcpClient.Start(ctx); err != nil {
log.Fatalf("Failed to start client: %v", err)
}

// Setup graceful shutdown
sigChan := make(chan os.Signal, 1)
signal.Notify(sigChan, syscall.SIGINT, syscall.SIGTERM)

// Create a context that cancels on signal
ctx, cancel := context.WithCancel(ctx)
go func() {
<-sigChan
log.Println("Received shutdown signal, closing client...")
cancel()
}()

// Move defer after error checking
defer func() {
if err := mcpClient.Close(); err != nil {
log.Printf("Error closing client: %v", err)
}
}()

// Initialize the connection
initResult, err := mcpClient.Initialize(ctx, mcp.InitializeRequest{
Params: mcp.InitializeParams{
ProtocolVersion: mcp.LATEST_PROTOCOL_VERSION,
ClientInfo: mcp.Implementation{
Name: "roots-stdio-client",
Version: "1.0.0",
},
Capabilities: mcp.ClientCapabilities{
// Roots capability will be automatically added by WithRootsHandler
},
},
})
if err != nil {
log.Fatalf("Failed to initialize: %v", err)
}

log.Printf("Connected to server: %s v%s", initResult.ServerInfo.Name, initResult.ServerInfo.Version)
log.Printf("Server capabilities: %+v", initResult.Capabilities)

// list tools
toolsResult, err := mcpClient.ListTools(ctx, mcp.ListToolsRequest{})
if err != nil {
log.Fatalf("Failed to list tools: %v", err)
}
log.Printf("Available tools:")
for _, tool := range toolsResult.Tools {
log.Printf(" - %s: %s", tool.Name, tool.Description)
}

// call server tool
request := mcp.CallToolRequest{}
request.Params.Name = "roots"
request.Params.Arguments = map[string]any{"testonly": "yes"}
result, err := mcpClient.CallTool(ctx, request)
if err != nil {
log.Fatalf("failed to call tool roots: %v", err)
} else if result.IsError {
log.Printf("tool reported error")
} else if len(result.Content) > 0 {
resultStr := ""
for _, content := range result.Content {
if textContent, ok := content.(mcp.TextContent); ok {
resultStr += fmt.Sprintf("%s\n", textContent.Text)
}
}
fmt.Printf("client call tool result: %s\n", resultStr)
}

// mock the root change
if err := mcpClient.RootListChanges(ctx); err != nil {
log.Printf("fail to notify root list change: %v", err)
}

// Keep running until cancelled by signal
<-ctx.Done()
log.Println("Client context cancelled")
}
Loading