Skip to content

Commit 4e9f1e2

Browse files
committed
Add inputSchema validation to lowlevel server
- Add jsonschema dependency for schema validation - Implement tool definition cache in Server class that gets refreshed when list_tools is called - Add _validate_tool_arguments helper method to validate tool arguments against inputSchema - Update call_tool handler to validate arguments before execution - Log warning and skip validation for tools not found in cache - Add comprehensive tests for validation scenarios This ensures tool arguments are validated against their JSON schemas before execution, providing better error messages and preventing invalid tool calls from reaching handlers.
1 parent 679b229 commit 4e9f1e2

File tree

4 files changed

+499
-1
lines changed

4 files changed

+499
-1
lines changed

pyproject.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,7 @@ dependencies = [
3131
"sse-starlette>=1.6.1",
3232
"pydantic-settings>=2.5.2",
3333
"uvicorn>=0.23.1; sys_platform != 'emscripten'",
34+
"jsonschema==4.20.0",
3435
]
3536

3637
[project.optional-dependencies]

src/mcp/server/lowlevel/server.py

Lines changed: 45 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -75,6 +75,7 @@ async def main():
7575
from typing import Any, Generic
7676

7777
import anyio
78+
import jsonschema
7879
from anyio.streams.memory import MemoryObjectReceiveStream, MemoryObjectSendStream
7980
from pydantic import AnyUrl
8081
from typing_extensions import TypeVar
@@ -143,6 +144,7 @@ def __init__(
143144
}
144145
self.notification_handlers: dict[type, Callable[..., Awaitable[None]]] = {}
145146
self.notification_options = NotificationOptions()
147+
self._tool_cache: dict[str, types.Tool] = {}
146148
logger.debug("Initializing server %r", name)
147149

148150
def create_initialization_options(
@@ -373,13 +375,42 @@ def decorator(func: Callable[[], Awaitable[list[types.Tool]]]):
373375

374376
async def handler(_: Any):
375377
tools = await func()
378+
# Refresh the tool cache
379+
self._tool_cache.clear()
380+
for tool in tools:
381+
self._tool_cache[tool.name] = tool
376382
return types.ServerResult(types.ListToolsResult(tools=tools))
377383

378384
self.request_handlers[types.ListToolsRequest] = handler
379385
return func
380386

381387
return decorator
382388

389+
async def _validate_tool_arguments(self, tool_name: str, arguments: dict[str, Any]) -> str | None:
390+
"""Validate tool arguments against inputSchema.
391+
392+
Returns None if validation passes, or an error message if validation fails.
393+
"""
394+
# Check if tool is in cache
395+
if tool_name not in self._tool_cache:
396+
# Try to refresh the cache by calling list_tools
397+
if types.ListToolsRequest in self.request_handlers:
398+
logger.debug("Tool cache miss for %s, refreshing cache", tool_name)
399+
await self.request_handlers[types.ListToolsRequest](None)
400+
401+
# Check again after potential refresh
402+
if tool_name in self._tool_cache:
403+
tool = self._tool_cache[tool_name]
404+
try:
405+
# Validate arguments against inputSchema
406+
jsonschema.validate(instance=arguments, schema=tool.inputSchema)
407+
return None
408+
except jsonschema.ValidationError as e:
409+
return f"Input validation error: {e.message}"
410+
else:
411+
logger.warning("Tool '%s' not found in cache, validation will not be performed", tool_name)
412+
return None
413+
383414
def call_tool(self):
384415
def decorator(
385416
func: Callable[
@@ -391,7 +422,20 @@ def decorator(
391422

392423
async def handler(req: types.CallToolRequest):
393424
try:
394-
results = await func(req.params.name, (req.params.arguments or {}))
425+
tool_name = req.params.name
426+
arguments = req.params.arguments or {}
427+
428+
# Validate arguments
429+
validation_error = await self._validate_tool_arguments(tool_name, arguments)
430+
if validation_error:
431+
return types.ServerResult(
432+
types.CallToolResult(
433+
content=[types.TextContent(type="text", text=validation_error)],
434+
isError=True,
435+
)
436+
)
437+
438+
results = await func(tool_name, arguments)
395439
return types.ServerResult(types.CallToolResult(content=list(results), isError=False))
396440
except Exception as e:
397441
return types.ServerResult(

0 commit comments

Comments
 (0)