Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion go.mod
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
module github.com/mark3labs/mcp-go

go 1.23
go 1.23.0
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

go failed to run on macos without adding in this patch version specifier 🤷


require (
github.com/google/uuid v1.6.0
Expand Down
59 changes: 51 additions & 8 deletions server/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,12 @@
package server

import (
"cmp"
"context"
"encoding/base64"
"encoding/json"
"fmt"
"maps"
"slices"
"sort"
"sync"
Expand Down Expand Up @@ -826,21 +828,36 @@ func (s *MCPServer) handleListResources(
request mcp.ListResourcesRequest,
) (*mcp.ListResourcesResult, *requestError) {
s.resourcesMu.RLock()
resources := make([]mcp.Resource, 0, len(s.resources))
for _, entry := range s.resources {
resources = append(resources, entry.resource)
resourceMap := make(map[string]mcp.Resource, len(s.resources))
for uri, entry := range s.resources {
resourceMap[uri] = entry.resource
}
s.resourcesMu.RUnlock()

// Check if there are session-specific resources
session := ClientSessionFromContext(ctx)
if session != nil {
if sessionWithResources, ok := session.(SessionWithResources); ok {
if sessionResources := sessionWithResources.GetSessionResources(); sessionResources != nil {
// Merge session-specific resources with global resources
for uri, serverResource := range sessionResources {
resourceMap[uri] = serverResource.Resource
}
}
}
}

// Sort the resources by name
sort.Slice(resources, func(i, j int) bool {
return resources[i].Name < resources[j].Name
resourcesList := slices.SortedFunc(maps.Values(resourceMap), func(a, b mcp.Resource) int {
return cmp.Compare(a.Name, b.Name)
})

// Apply pagination
resourcesToReturn, nextCursor, err := listByPagination(
ctx,
s,
request.Params.Cursor,
resources,
resourcesList,
)
if err != nil {
return nil, &requestError{
Expand Down Expand Up @@ -900,9 +917,35 @@ func (s *MCPServer) handleReadResource(
request mcp.ReadResourceRequest,
) (*mcp.ReadResourceResult, *requestError) {
s.resourcesMu.RLock()

// First check session-specific resources
var handler ResourceHandlerFunc
var ok bool

session := ClientSessionFromContext(ctx)
if session != nil {
if sessionWithResources, typeAssertOk := session.(SessionWithResources); typeAssertOk {
if sessionResources := sessionWithResources.GetSessionResources(); sessionResources != nil {
resource, sessionOk := sessionResources[request.Params.URI]
if sessionOk {
handler = resource.Handler
ok = true
}
}
}
}

// If not found in session tools, check global tools
if !ok {
globalResource, rok := s.resources[request.Params.URI]
if rok {
handler = globalResource.handler
ok = true
}
}

// First try direct resource handlers
if entry, ok := s.resources[request.Params.URI]; ok {
handler := entry.handler
if ok {
s.resourcesMu.RUnlock()

finalHandler := handler
Expand Down
3 changes: 1 addition & 2 deletions server/server_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -445,9 +445,8 @@ func TestMCPServer_HandleValidMessages(t *testing.T) {
resp, ok := response.(mcp.JSONRPCResponse)
assert.True(t, ok)

listResult, ok := resp.Result.(mcp.ListResourcesResult)
_, ok = resp.Result.(mcp.ListResourcesResult)
assert.True(t, ok)
assert.NotNil(t, listResult.Resources)
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

this NotNil was erroring because the new listResources code returns a nil slice instead of a zero-length non-nil slice. My understanding is that there is not an important semantic different between the two, so I don't think it makes sense to assert non-nil here.

},
},
}
Expand Down
11 changes: 11 additions & 0 deletions server/session.go
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,17 @@ type SessionWithTools interface {
SetSessionTools(tools map[string]ServerTool)
}

// SessionWithResources is an extension of ClientSession that can store session-specific resource data
type SessionWithResources interface {
ClientSession
// GetSessionResources returns the resources specific to this session, if any
// This method must be thread-safe for concurrent access
GetSessionResources() map[string]ServerResource
// SetSessionResources sets resources specific to this session
// This method must be thread-safe for concurrent access
SetSessionResources(resources map[string]ServerResource)
}

// SessionWithClientInfo is an extension of ClientSession that can store client info
type SessionWithClientInfo interface {
ClientSession
Expand Down
127 changes: 126 additions & 1 deletion server/session_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ import (
"context"
"encoding/json"
"errors"
"maps"
"sync"
"sync/atomic"
"testing"
Expand Down Expand Up @@ -100,6 +101,60 @@ func (f *sessionTestClientWithTools) SetSessionTools(tools map[string]ServerTool
f.sessionTools = toolsCopy
}

// sessionTestClientWithResources implements the SessionWithResources interface for testing
type sessionTestClientWithResources struct {
sessionID string
notificationChannel chan mcp.JSONRPCNotification
initialized bool
sessionResources map[string]ServerResource
mu sync.RWMutex // Mutex to protect concurrent access to sessionResources
}

func (f *sessionTestClientWithResources) SessionID() string {
return f.sessionID
}

func (f *sessionTestClientWithResources) NotificationChannel() chan<- mcp.JSONRPCNotification {
return f.notificationChannel
}

func (f *sessionTestClientWithResources) Initialize() {
f.initialized = true
}

func (f *sessionTestClientWithResources) Initialized() bool {
return f.initialized
}

func (f *sessionTestClientWithResources) GetSessionResources() map[string]ServerResource {
f.mu.RLock()
defer f.mu.RUnlock()

if f.sessionResources == nil {
return nil
}

// Return a copy of the map to prevent concurrent modification
resourcesCopy := make(map[string]ServerResource, len(f.sessionResources))
maps.Copy(resourcesCopy, f.sessionResources)
return resourcesCopy
}

func (f *sessionTestClientWithResources) SetSessionResources(resources map[string]ServerResource) {
f.mu.Lock()
defer f.mu.Unlock()

if resources == nil {
f.sessionResources = nil
return
}

// Create a copy of the map to prevent concurrent modification
resourcesCopy := make(map[string]ServerResource, len(resources))
maps.Copy(resourcesCopy, resources)
f.sessionResources = resourcesCopy
}

// sessionTestClientWithClientInfo implements the SessionWithClientInfo interface for testing
type sessionTestClientWithClientInfo struct {
sessionID string
Expand Down Expand Up @@ -151,7 +206,7 @@ func (f *sessionTestClientWithClientInfo) SetClientCapabilities(clientCapabiliti
f.clientCapabilities.Store(clientCapabilities)
}

// sessionTestClientWithTools implements the SessionWithLogging interface for testing
// sessionTestClientWithLogging implements the SessionWithLogging interface for testing
type sessionTestClientWithLogging struct {
sessionID string
notificationChannel chan mcp.JSONRPCNotification
Expand Down Expand Up @@ -190,6 +245,7 @@ func (f *sessionTestClientWithLogging) GetLogLevel() mcp.LoggingLevel {
var (
_ ClientSession = (*sessionTestClient)(nil)
_ SessionWithTools = (*sessionTestClientWithTools)(nil)
_ SessionWithResources = (*sessionTestClientWithResources)(nil)
_ SessionWithLogging = (*sessionTestClientWithLogging)(nil)
_ SessionWithClientInfo = (*sessionTestClientWithClientInfo)(nil)
)
Expand Down Expand Up @@ -260,6 +316,75 @@ func TestSessionWithTools_Integration(t *testing.T) {
})
}

func TestSessionWithResources_Integration(t *testing.T) {
server := NewMCPServer("test-server", "1.0.0")

// Create session-specific resources
sessionResource := ServerResource{
Resource: mcp.NewResource("ui://resource", "session-resource"),
Handler: func(ctx context.Context, request mcp.ReadResourceRequest) ([]mcp.ResourceContents, error) {
return []mcp.ResourceContents{mcp.TextResourceContents{
URI: "ui://resource",
Text: "session-resource result",
}}, nil
},
}

// Create a session with resources
session := &sessionTestClientWithResources{
sessionID: "session-1",
notificationChannel: make(chan mcp.JSONRPCNotification, 10),
initialized: true,
sessionResources: map[string]ServerResource{
"ui://resource": sessionResource,
},
}

// Register the session
err := server.RegisterSession(context.Background(), session)
require.NoError(t, err)

// Test that we can access the session-specific resource
testReq := mcp.ReadResourceRequest{}
testReq.Params.URI = "ui://resource"
testReq.Params.Arguments = map[string]any{}

// Call using session context
sessionCtx := server.WithContext(context.Background(), session)

// Check if the session was stored in the context correctly
s := ClientSessionFromContext(sessionCtx)
require.NotNil(t, s, "Session should be available from context")
assert.Equal(t, session.SessionID(), s.SessionID(), "Session ID should match")

// Check if the session can be cast to SessionWithResources
swr, ok := s.(SessionWithResources)
require.True(t, ok, "Session should implement SessionWithResources")

// Check if the resources are accessible
resources := swr.GetSessionResources()
require.NotNil(t, resources, "Session resources should be available")
require.Contains(t, resources, "ui://resource", "Session should have ui://resource")

// Test session resource access with session context
t.Run("test session resource access", func(t *testing.T) {
// First test directly getting the resource from session resources
resource, exists := resources["ui://resource"]
require.True(t, exists, "Session resource should exist in the map")
require.NotNil(t, resource, "Session resource should not be nil")

// Now test calling directly with the handler
result, err := resource.Handler(sessionCtx, testReq)
require.NoError(t, err, "No error calling session resource handler directly")
require.NotNil(t, result, "Result should not be nil")
require.Len(t, result, 1, "Result should have one content item")

textContent, ok := result[0].(mcp.TextResourceContents)
require.True(t, ok, "Content should be TextResourceContents")
assert.Equal(t, "session-resource result", textContent.Text, "Result text should match")
})
}

func TestMCPServer_ToolsWithSessionTools(t *testing.T) {
// Basic test to verify that session-specific tools are returned correctly in a tools list
server := NewMCPServer("test-server", "1.0.0", WithToolCapabilities(true))
Expand Down
23 changes: 23 additions & 0 deletions server/sse.go
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@ type sseSession struct {
initialized atomic.Bool
loggingLevel atomic.Value
tools sync.Map // stores session-specific tools
resources sync.Map // stores session-specific resources
clientInfo atomic.Value // stores session-specific client info
clientCapabilities atomic.Value // stores session-specific client capabilities
}
Expand Down Expand Up @@ -75,6 +76,27 @@ func (s *sseSession) GetLogLevel() mcp.LoggingLevel {
return level.(mcp.LoggingLevel)
}

func (s *sseSession) GetSessionResources() map[string]ServerResource {
resources := make(map[string]ServerResource)
s.resources.Range(func(key, value any) bool {
if resource, ok := value.(ServerResource); ok {
resources[key.(string)] = resource
}
return true
})
return resources
}

func (s *sseSession) SetSessionResources(resources map[string]ServerResource) {
// Clear existing resources
s.resources.Clear()

// Set new resources
for name, resource := range resources {
s.resources.Store(name, resource)
}
}

func (s *sseSession) GetSessionTools() map[string]ServerTool {
tools := make(map[string]ServerTool)
s.tools.Range(func(key, value any) bool {
Expand Down Expand Up @@ -125,6 +147,7 @@ func (s *sseSession) GetClientCapabilities() mcp.ClientCapabilities {
var (
_ ClientSession = (*sseSession)(nil)
_ SessionWithTools = (*sseSession)(nil)
_ SessionWithResources = (*sseSession)(nil)
_ SessionWithLogging = (*sseSession)(nil)
_ SessionWithClientInfo = (*sseSession)(nil)
)
Expand Down
Loading
Loading