Skip to content
Merged
21 changes: 12 additions & 9 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -149,9 +149,10 @@ func main() {

// Add the calculator handler
s.AddTool(calculatorTool, func(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) {
op := request.Params.Arguments["operation"].(string)
x := request.Params.Arguments["x"].(float64)
y := request.Params.Arguments["y"].(float64)
args := request.GetArguments()
op := args["operation"].(string)
x := args["x"].(float64)
y := args["y"].(float64)

var result float64
switch op {
Expand Down Expand Up @@ -312,9 +313,10 @@ calculatorTool := mcp.NewTool("calculate",
)

s.AddTool(calculatorTool, func(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) {
op := request.Params.Arguments["operation"].(string)
x := request.Params.Arguments["x"].(float64)
y := request.Params.Arguments["y"].(float64)
args := request.GetArguments()
op := args["operation"].(string)
x := args["x"].(float64)
y := args["y"].(float64)

var result float64
switch op {
Expand Down Expand Up @@ -355,10 +357,11 @@ httpTool := mcp.NewTool("http_request",
)

s.AddTool(httpTool, func(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) {
method := request.Params.Arguments["method"].(string)
url := request.Params.Arguments["url"].(string)
args := request.GetArguments()
method := args["method"].(string)
url := args["url"].(string)
body := ""
if b, ok := request.Params.Arguments["body"].(string); ok {
if b, ok := args["body"].(string); ok {
body = b
}

Expand Down
2 changes: 1 addition & 1 deletion client/inprocess_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ func TestInProcessMCPClient(t *testing.T) {
Content: []mcp.Content{
mcp.TextContent{
Type: "text",
Text: "Input parameter: " + request.Params.Arguments["parameter-1"].(string),
Text: "Input parameter: " + request.GetArguments()["parameter-1"].(string),
},
mcp.AudioContent{
Type: "audio",
Expand Down
2 changes: 1 addition & 1 deletion client/sse_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ func TestSSEMCPClient(t *testing.T) {
Content: []mcp.Content{
mcp.TextContent{
Type: "text",
Text: "Input parameter: " + request.Params.Arguments["parameter-1"].(string),
Text: "Input parameter: " + request.GetArguments()["parameter-1"].(string),
},
},
}, nil
Expand Down
2 changes: 1 addition & 1 deletion examples/custom_context/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -81,7 +81,7 @@ func handleMakeAuthenticatedRequestTool(
ctx context.Context,
request mcp.CallToolRequest,
) (*mcp.CallToolResult, error) {
message, ok := request.Params.Arguments["message"].(string)
message, ok := request.GetArguments()["message"].(string)
if !ok {
return nil, fmt.Errorf("missing message")
}
Expand Down
2 changes: 1 addition & 1 deletion examples/dynamic_path/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ func main() {

// Add a trivial tool for demonstration
mcpServer.AddTool(mcp.NewTool("echo"), func(ctx context.Context, req mcp.CallToolRequest) (*mcp.CallToolResult, error) {
return mcp.NewToolResultText(fmt.Sprintf("Echo: %v", req.Params.Arguments["message"])), nil
return mcp.NewToolResultText(fmt.Sprintf("Echo: %v", req.GetArguments()["message"])), nil
})

// Use a dynamic base path based on a path parameter (Go 1.22+)
Expand Down
6 changes: 3 additions & 3 deletions examples/everything/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -312,7 +312,7 @@ func handleEchoTool(
ctx context.Context,
request mcp.CallToolRequest,
) (*mcp.CallToolResult, error) {
arguments := request.Params.Arguments
arguments := request.GetArguments()
message, ok := arguments["message"].(string)
if !ok {
return nil, fmt.Errorf("invalid message argument")
Expand All @@ -331,7 +331,7 @@ func handleAddTool(
ctx context.Context,
request mcp.CallToolRequest,
) (*mcp.CallToolResult, error) {
arguments := request.Params.Arguments
arguments := request.GetArguments()
a, ok1 := arguments["a"].(float64)
b, ok2 := arguments["b"].(float64)
if !ok1 || !ok2 {
Expand Down Expand Up @@ -382,7 +382,7 @@ func handleLongRunningOperationTool(
ctx context.Context,
request mcp.CallToolRequest,
) (*mcp.CallToolResult, error) {
arguments := request.Params.Arguments
arguments := request.GetArguments()
progressToken := request.Params.Meta.ProgressToken
duration, _ := arguments["duration"].(float64)
steps, _ := arguments["steps"].(float64)
Expand Down
19 changes: 17 additions & 2 deletions mcp/tools.go
Original file line number Diff line number Diff line change
Expand Up @@ -44,8 +44,8 @@ type CallToolResult struct {
type CallToolRequest struct {
Request
Params struct {
Name string `json:"name"`
Arguments map[string]any `json:"arguments,omitempty"`
Name string `json:"name"`
Arguments any `json:"arguments,omitempty"` // Can be map[string]any or any other type
Meta *struct {
// If specified, the caller is requesting out-of-band progress
// notifications for this request (as represented by
Expand All @@ -58,6 +58,21 @@ type CallToolRequest struct {
} `json:"params"`
}

// GetArguments returns the Arguments as map[string]any for backward compatibility
// If Arguments is not a map, it returns an empty map
func (r CallToolRequest) GetArguments() map[string]any {
if args, ok := r.Params.Arguments.(map[string]any); ok {
return args
}
return map[string]any{}
}

// GetRawArguments returns the Arguments as-is without type conversion
// This allows users to access the raw arguments in any format
func (r CallToolRequest) GetRawArguments() any {
return r.Params.Arguments
}

// ToolListChangedNotification is an optional notification from the server to
// the client, informing it that the list of tools it offers has changed. This may
// be issued by servers without any previous subscription from the client.
Expand Down
98 changes: 98 additions & 0 deletions mcp/tools_arguments_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,98 @@
package mcp

import (
"encoding/json"
"testing"

"github.com/stretchr/testify/assert"
)

func TestCallToolRequestWithMapArguments(t *testing.T) {
// Create a request with map arguments
req := CallToolRequest{}
req.Params.Name = "test-tool"
req.Params.Arguments = map[string]any{
"key1": "value1",
"key2": 123,
}

// Test GetArguments
args := req.GetArguments()
assert.Equal(t, "value1", args["key1"])
assert.Equal(t, 123, args["key2"])

// Test GetRawArguments
rawArgs := req.GetRawArguments()
mapArgs, ok := rawArgs.(map[string]any)
assert.True(t, ok)
assert.Equal(t, "value1", mapArgs["key1"])
assert.Equal(t, 123, mapArgs["key2"])
}

func TestCallToolRequestWithNonMapArguments(t *testing.T) {
// Create a request with non-map arguments
req := CallToolRequest{}
req.Params.Name = "test-tool"
req.Params.Arguments = "string-argument"

// Test GetArguments (should return empty map)
args := req.GetArguments()
assert.Empty(t, args)

// Test GetRawArguments
rawArgs := req.GetRawArguments()
strArg, ok := rawArgs.(string)
assert.True(t, ok)
assert.Equal(t, "string-argument", strArg)
}

func TestCallToolRequestWithStructArguments(t *testing.T) {
// Create a custom struct
type CustomArgs struct {
Field1 string `json:"field1"`
Field2 int `json:"field2"`
}

// Create a request with struct arguments
req := CallToolRequest{}
req.Params.Name = "test-tool"
req.Params.Arguments = CustomArgs{
Field1: "test",
Field2: 42,
}

// Test GetArguments (should return empty map)
args := req.GetArguments()
assert.Empty(t, args)

// Test GetRawArguments
rawArgs := req.GetRawArguments()
structArg, ok := rawArgs.(CustomArgs)
assert.True(t, ok)
assert.Equal(t, "test", structArg.Field1)
assert.Equal(t, 42, structArg.Field2)
}

func TestCallToolRequestJSONMarshalUnmarshal(t *testing.T) {
// Create a request with map arguments
req := CallToolRequest{}
req.Params.Name = "test-tool"
req.Params.Arguments = map[string]any{
"key1": "value1",
"key2": 123,
}

// Marshal to JSON
data, err := json.Marshal(req)
assert.NoError(t, err)

// Unmarshal from JSON
var unmarshaledReq CallToolRequest
err = json.Unmarshal(data, &unmarshaledReq)
assert.NoError(t, err)

// Check if arguments are correctly unmarshaled
args := unmarshaledReq.GetArguments()
assert.Equal(t, "value1", args["key1"])
assert.Equal(t, float64(123), args["key2"]) // JSON numbers are unmarshaled as float64
}
5 changes: 3 additions & 2 deletions mcp/utils.go
Original file line number Diff line number Diff line change
Expand Up @@ -675,10 +675,11 @@ func ParseReadResourceResult(rawMessage *json.RawMessage) (*ReadResourceResult,
}

func ParseArgument(request CallToolRequest, key string, defaultVal any) any {
if _, ok := request.Params.Arguments[key]; !ok {
args := request.GetArguments()
if _, ok := args[key]; !ok {
return defaultVal
} else {
return request.Params.Arguments[key]
return args[key]
}
}

Expand Down
2 changes: 1 addition & 1 deletion mcptest/mcptest_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,7 @@ func TestServer(t *testing.T) {

func helloWorldHandler(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) {
// Extract name from request arguments
name, ok := request.Params.Arguments["name"].(string)
name, ok := request.GetArguments()["name"].(string)
if !ok {
name = "World"
}
Expand Down