@@ -68,11 +68,12 @@ async def main():
6868from __future__ import annotations as _annotations
6969
7070import contextvars
71+ import json
7172import logging
7273import warnings
7374from collections .abc import AsyncIterator , Awaitable , Callable , Iterable
7475from contextlib import AbstractAsyncContextManager , AsyncExitStack , asynccontextmanager
75- from typing import Any , Generic
76+ from typing import Any , Generic , cast
7677
7778import anyio
7879import jsonschema
@@ -386,36 +387,48 @@ async def handler(_: Any):
386387
387388 return decorator
388389
389- async def _validate_tool_arguments (self , tool_name : str , arguments : dict [str , Any ]) -> str | None :
390- """Validate tool arguments against inputSchema.
390+ def _make_error_result (self , error_message : str ) -> types .ServerResult :
391+ """Create a ServerResult with an error CallToolResult."""
392+ return types .ServerResult (
393+ types .CallToolResult (
394+ content = [types .TextContent (type = "text" , text = error_message )],
395+ isError = True ,
396+ )
397+ )
398+
399+ async def _get_cached_tool_definition (self , tool_name : str ) -> types .Tool | None :
400+ """Get tool definition from cache, refreshing if necessary.
391401
392- Returns None if validation passes, or an error message if validation fails .
402+ Returns the Tool object if found, None otherwise .
393403 """
394- # Check if tool is in cache
395404 if tool_name not in self ._tool_cache :
396- # Try to refresh the cache by calling list_tools
397405 if types .ListToolsRequest in self .request_handlers :
398406 logger .debug ("Tool cache miss for %s, refreshing cache" , tool_name )
399407 await self .request_handlers [types .ListToolsRequest ](None )
400408
401- # Check again after potential refresh
402- if tool_name in self ._tool_cache :
403- tool = self ._tool_cache [tool_name ]
404- try :
405- # Validate arguments against inputSchema
406- jsonschema .validate (instance = arguments , schema = tool .inputSchema )
407- return None
408- except jsonschema .ValidationError as e :
409- return f"Input validation error: { e .message } "
410- else :
411- logger .warning ("Tool '%s' not found in cache, validation will not be performed" , tool_name )
412- return None
409+ tool = self ._tool_cache .get (tool_name )
410+ if tool is None :
411+ logger .warning ("Tool '%s' not listed, no validation will be performed" , tool_name )
412+
413+ return tool
413414
414415 def call_tool (self ):
416+ """Register a tool call handler.
417+
418+ The handler validates input against inputSchema, calls the tool function, and processes results:
419+ - Content only: returns as-is
420+ - Dict only: serializes to JSON text and returns as content with structuredContent
421+ - Both: returns content and structuredContent
422+
423+ If outputSchema is defined, validates structuredContent or errors if missing.
424+ """
425+
415426 def decorator (
416427 func : Callable [
417428 ...,
418- Awaitable [Iterable [types .ContentBlock ]],
429+ Awaitable [
430+ Iterable [types .ContentBlock ] | dict [str , Any ] | tuple [Iterable [types .ContentBlock ], dict [str , Any ]]
431+ ],
419432 ],
420433 ):
421434 logger .debug ("Registering handler for CallToolRequest" )
@@ -424,26 +437,53 @@ async def handler(req: types.CallToolRequest):
424437 try :
425438 tool_name = req .params .name
426439 arguments = req .params .arguments or {}
440+ tool = await self ._get_cached_tool_definition (tool_name )
427441
428- # Validate arguments
429- validation_error = await self ._validate_tool_arguments (tool_name , arguments )
430- if validation_error :
431- return types .ServerResult (
432- types .CallToolResult (
433- content = [types .TextContent (type = "text" , text = validation_error )],
434- isError = True ,
435- )
436- )
442+ # input validation
443+ if tool :
444+ try :
445+ jsonschema .validate (instance = arguments , schema = tool .inputSchema )
446+ except jsonschema .ValidationError as e :
447+ return self ._make_error_result (f"Input validation error: { e .message } " )
437448
449+ # tool call
438450 results = await func (tool_name , arguments )
439- return types .ServerResult (types .CallToolResult (content = list (results ), isError = False ))
440- except Exception as e :
451+
452+ # output normalization
453+ content : list [types .ContentBlock ]
454+ structured_content : dict [str , Any ] | None
455+
456+ if isinstance (results , tuple ) and len (results ) == 2 :
457+ # tool returned both content and structured content
458+ structured_content = cast (dict [str , Any ], results [1 ])
459+ content = list (cast (Iterable [types .ContentBlock ], results [0 ]))
460+ elif isinstance (results , dict ):
461+ # tool returned structured content only
462+ structured_content = cast (dict [str , Any ], results )
463+ content = [types .TextContent (type = "text" , text = json .dumps (results , indent = 2 ))]
464+ else :
465+ # tool returned content only
466+ structured_content = None
467+ content = list (cast (Iterable [types .ContentBlock ], results ))
468+
469+ # output validation
470+ if tool and tool .outputSchema is not None :
471+ if structured_content is None :
472+ return self ._make_error_result (
473+ "Output validation error: outputSchema defined but no structured output returned"
474+ )
475+ else :
476+ try :
477+ jsonschema .validate (instance = structured_content , schema = tool .outputSchema )
478+ except jsonschema .ValidationError as e :
479+ return self ._make_error_result (f"Output validation error: { e .message } " )
480+
481+ # result
441482 return types .ServerResult (
442- types .CallToolResult (
443- content = [types .TextContent (type = "text" , text = str (e ))],
444- isError = True ,
445- )
483+ types .CallToolResult (content = content , structuredContent = structured_content , isError = False )
446484 )
485+ except Exception as e :
486+ return self ._make_error_result (str (e ))
447487
448488 self .request_handlers [types .CallToolRequest ] = handler
449489 return func
0 commit comments