diff --git a/docs/clients/tools.mdx b/docs/clients/tools.mdx index 3821725cb..68d4424bc 100644 --- a/docs/clients/tools.mdx +++ b/docs/clients/tools.mdx @@ -37,10 +37,13 @@ Execute a tool using `call_tool()` with the tool name and arguments: async with client: # Simple tool call result = await client.call_tool("add", {"a": 5, "b": 3}) - # result -> list[mcp.types.TextContent | mcp.types.ImageContent | ...] + # result -> CallToolResult with structured and unstructured data - # Access the result content - print(result[0].text) # Assuming TextContent, e.g., '8' + # Access structured data (automatically deserialized) + print(result.data) # 8 (int) or {"result": 8} for primitive types + + # Access traditional content blocks + print(result.content[0].text) # "8" (TextContent) ``` ### Advanced Execution Options @@ -72,21 +75,97 @@ async with client: ## Handling Results -Tool execution returns a list of content objects. The most common types are: + + +Tool execution returns a `CallToolResult` object with both structured and traditional content. FastMCP's standout feature is the `.data` property, which doesn't just provide raw JSON but actually hydrates complete Python objects including complex types like datetimes, UUIDs, and custom classes. + +### CallToolResult Properties + + + + **FastMCP exclusive**: Fully hydrated Python objects with complex type support (datetimes, UUIDs, custom classes). Goes beyond JSON to provide complete object reconstruction from output schemas. + + + + Standard MCP content blocks (`TextContent`, `ImageContent`, `AudioContent`, etc.) available from all MCP servers. + + + + Standard MCP structured JSON data as sent by the server, available from all MCP servers that support structured outputs. + -- **`TextContent`**: Text-based results with a `.text` attribute -- **`ImageContent`**: Image data with image-specific attributes -- **`BlobContent`**: Binary data content + + Boolean indicating if the tool execution failed. + + + +### Structured Data Access + +FastMCP's `.data` property provides fully hydrated Python objects, not just JSON dictionaries. This includes complex type reconstruction: ```python +from datetime import datetime +from uuid import UUID + async with client: result = await client.call_tool("get_weather", {"city": "London"}) - for content in result: - if hasattr(content, 'text'): - print(f"Text result: {content.text}") - elif hasattr(content, 'data'): - print(f"Binary data: {len(content.data)} bytes") + # FastMCP reconstructs complete Python objects from the server's output schema + weather = result.data # Server-defined WeatherReport object + print(f"Temperature: {weather.temperature}°C at {weather.timestamp}") + print(f"Station: {weather.station_id}") + print(f"Humidity: {weather.humidity}%") + + # The timestamp is a real datetime object, not a string! 
+ assert isinstance(weather.timestamp, datetime) + assert isinstance(weather.station_id, UUID) + + # Compare with raw structured JSON (standard MCP) + print(f"Raw JSON: {result.structured_content}") + # {"temperature": 20, "timestamp": "2024-01-15T14:30:00Z", "station_id": "123e4567-..."} + + # Traditional content blocks (standard MCP) + print(f"Text content: {result.content[0].text}") +``` + +### Fallback Behavior + +For tools without output schemas or when deserialization fails, `.data` will be `None`: + +```python +async with client: + result = await client.call_tool("legacy_tool", {"param": "value"}) + + if result.data is not None: + # Structured output available and successfully deserialized + print(f"Structured: {result.data}") + else: + # No structured output or deserialization failed - use content blocks + for content in result.content: + if hasattr(content, 'text'): + print(f"Text result: {content.text}") + elif hasattr(content, 'data'): + print(f"Binary data: {len(content.data)} bytes") +``` + +### Primitive Type Unwrapping + + +FastMCP servers automatically wrap non-object results (like `int`, `str`, `bool`) in a `{"result": value}` structure to create valid structured outputs. FastMCP clients understand this convention and automatically unwrap the value in `.data` for convenience, so you get the original primitive value instead of a wrapper object. + + +```python +async with client: + result = await client.call_tool("calculate_sum", {"a": 5, "b": 3}) + + # FastMCP client automatically unwraps for convenience + print(result.data) # 8 (int) - the original value + + # Raw structured content shows the server-side wrapping + print(result.structured_content) # {"result": 8} + + # Other MCP clients would need to manually access ["result"] + # value = result.structured_content["result"] # Not needed with FastMCP! ``` ## Error Handling @@ -101,14 +180,32 @@ from fastmcp.exceptions import ToolError async with client: try: result = await client.call_tool("potentially_failing_tool", {"param": "value"}) - print("Tool succeeded:", result) + print("Tool succeeded:", result.data) except ToolError as e: print(f"Tool failed: {e}") ``` ### Manual Error Checking -For more granular control, use `call_tool_mcp()` which returns the raw MCP protocol object with an `isError` flag: +You can disable automatic error raising and manually check the result: + +```python +async with client: + result = await client.call_tool( + "potentially_failing_tool", + {"param": "value"}, + raise_on_error=False + ) + + if result.is_error: + print(f"Tool failed: {result.content[0].text}") + else: + print(f"Tool succeeded: {result.data}") +``` + +### Raw MCP Protocol Access + +For complete control, use `call_tool_mcp()` which returns the raw MCP protocol object: ```python async with client: @@ -119,6 +216,7 @@ async with client: print(f"Tool failed: {result.content}") else: print(f"Tool succeeded: {result.content}") + # Note: No automatic deserialization with call_tool_mcp() ``` ## Argument Handling diff --git a/docs/patterns/tool-transformation.mdx b/docs/patterns/tool-transformation.mdx index f736f791c..c9528282e 100644 --- a/docs/patterns/tool-transformation.mdx +++ b/docs/patterns/tool-transformation.mdx @@ -89,6 +89,7 @@ The `Tool.from_tool()` class method is the primary way to create a transformed t - `description`: An optional description for the new tool. - `transform_args`: A dictionary of `ArgTransform` objects, one for each argument you want to modify. 
- `transform_fn`: An optional function that will be called instead of the parent tool's logic. +- `output_schema`: Control output schema and structured outputs (see [Output Schema Control](#output-schema-control)). - `tags`: An optional set of tags for the new tool. - `annotations`: An optional set of `ToolAnnotations` for the new tool. - `serializer`: An optional function that will be called to serialize the result of the new tool. @@ -439,7 +440,44 @@ mcp.add_tool(new_tool) In the above example, `**kwargs` receives the renamed argument `b`, not the original argument `y`. It is therefore recommended to use with `forward()`, not `forward_raw()`. - + + +## Output Schema Control + + + +Transformed tools inherit output schemas from their parent by default, but you can control this behavior: + +**Inherit from Parent (Default)** +```python +Tool.from_tool(parent_tool, name="renamed_tool") +``` +The transformed tool automatically uses the parent tool's output schema and structured output behavior. + +**Custom Output Schema** +```python +Tool.from_tool(parent_tool, output_schema={ + "type": "object", + "properties": {"status": {"type": "string"}} +}) +``` +Provide your own schema that differs from the parent. The tool must return data matching this schema. + +**Remove Output Schema** +```python +Tool.from_tool(parent_tool, output_schema=False) +``` +Removes the output schema declaration. Automatic structured content still works for object-like returns (dict, dataclass, Pydantic models) but primitive types won't be structured. + +**Full Control with Transform Functions** +```python +async def custom_output(**kwargs) -> ToolResult: + result = await forward(**kwargs) + return ToolResult(content=[...], structured_content={...}) + +Tool.from_tool(parent_tool, transform_fn=custom_output) +``` +Use a transform function returning `ToolResult` for complete control over both content blocks and structured outputs. ## Common Patterns diff --git a/docs/servers/tools.mdx b/docs/servers/tools.mdx index 070b4bf8c..7ba5535a4 100644 --- a/docs/servers/tools.mdx +++ b/docs/servers/tools.mdx @@ -288,55 +288,229 @@ Use `async def` when your tool needs to perform operations that might wait for e ### Return Values -FastMCP automatically converts the value returned by your function into the appropriate MCP content format for the client: -- **`str`**: Sent as `TextContent`. -- **`dict`, `list`, Pydantic `BaseModel`**: Serialized to a JSON string and sent as `TextContent`. -- **`bytes`**: Base64 encoded and sent as `BlobResourceContents` (often within an `EmbeddedResource`). -- **`fastmcp.utilities.types.Image`**: A helper class for easily returning image data. Sent as `ImageContent`. -- **`fastmcp.utilities.types.Audio`**: A helper class for easily returning audio data. Sent as `AudioContent`. -- **`fastmcp.utilities.types.File`**: A helper class for easily returning binary data as base64-encoded content. Sent as `EmbeddedResource`. -- **A list of any of the above**: Automatically converts each item appropriately. -- **`None`**: Results in an empty response (no content is sent back to the client). +FastMCP tools can return data in two complementary formats: **traditional content blocks** (like text and images) and **structured outputs** (machine-readable JSON). When you add return type annotations, FastMCP automatically generates **output schemas** to validate the structured data and enables clients to deserialize results back to Python objects. -FastMCP will attempt to serialize other types to a string if possible. 
+Understanding how these three concepts work together: - -At this time, FastMCP responds only to your tool's return *value*, not its return *annotation*. - +- **Return Values**: What your Python function returns (determines both content blocks and structured data) +- **Structured Outputs**: JSON data sent alongside traditional content for machine processing +- **Output Schemas**: JSON Schema declarations that describe and validate the structured output format -```python +The following sections explain each concept in detail. + +#### Content Blocks + +FastMCP automatically converts tool return values into appropriate MCP content blocks: + +- **`str`**: Sent as `TextContent` +- **`bytes`**: Base64 encoded and sent as `BlobResourceContents` (within an `EmbeddedResource`) +- **`fastmcp.utilities.types.Image`**: Sent as `ImageContent` +- **`fastmcp.utilities.types.Audio`**: Sent as `AudioContent` +- **`fastmcp.utilities.types.File`**: Sent as base64-encoded `EmbeddedResource` +- **A list of any of the above**: Converts each item appropriately +- **`None`**: Results in an empty response + +#### Structured Output + + + +The 6/18/2025 MCP spec update [introduced](https://modelcontextprotocol.io/specification/2025-06-18/server/tools#structured-content) structured content, which is a new way to return data from tools. Structured content is a JSON object that is sent alongside traditional content. FastMCP automatically creates structured outputs alongside traditional content when your tool returns data that has a JSON object representation. This provides machine-readable JSON data that clients can deserialize back to Python objects. + +**Automatic Structured Content Rules:** +- **Object-like results** (`dict`, Pydantic models, dataclasses) → Always become structured content (even without output schema) +- **Non-object results** (`int`, `str`, `list`) → Only become structured content if there's an output schema to validate/serialize them +- **All results** → Always become traditional content blocks for backward compatibility + + +This automatic behavior enables clients to receive machine-readable data alongside human-readable content without requiring explicit output schemas for object-like returns. 
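+
+From the client side, these rules are directly observable. A brief sketch, assuming a connected `fastmcp.Client` and the example tools defined in the sections below:
+
+```python
+async with client:
+    # Object-like return (dict): structured content is produced even without an output schema
+    result = await client.call_tool("get_user_data", {"user_id": "abc"})
+    print(result.structured_content)  # {"name": "Alice", "age": 30, "active": True}
+
+    # Bare int return with no return annotation: traditional content only
+    result = await client.call_tool("calculate_sum", {"a": 5, "b": 3})
+    print(result.structured_content)  # None
+    print(result.content[0].text)     # "8"
+```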
+ + +##### Object-like Results (Automatic Structured Content) + + +```python Dict Return (No Schema Needed) +@mcp.tool +def get_user_data(user_id: str) -> dict: + """Get user data without type annotation.""" + return {"name": "Alice", "age": 30, "active": True} +``` + +```json Traditional Content +"{\n \"name\": \"Alice\",\n \"age\": 30,\n \"active\": true\n}" +``` + +```json Structured Content (Automatic) +{ + "name": "Alice", + "age": 30, + "active": true +} +``` + + +##### Non-object Results (Schema Required) + + +```python Integer Return (No Schema) +@mcp.tool +def calculate_sum(a: int, b: int): + """Calculate sum without return annotation.""" + return a + b # Returns 8 +``` + +```json Traditional Content Only +"8" +``` + +```python Integer Return (With Schema) +@mcp.tool +def calculate_sum(a: int, b: int) -> int: + """Calculate sum with return annotation.""" + return a + b # Returns 8 +``` + +```json Traditional Content +"8" +``` + +```json Structured Content (From Schema) +{ + "result": 8 +} +``` + + +##### Complex Type Example + + +```python Tool Definition +from dataclasses import dataclass from fastmcp import FastMCP -from fastmcp.utilities.types import Image -import io -try: - from PIL import Image as PILImage -except ImportError: - raise ImportError("Please install the `pillow` library to run this example.") +mcp = FastMCP() -mcp = FastMCP("Image Demo") +@dataclass +class Person: + name: str + age: int + email: str @mcp.tool -def generate_image(width: int, height: int, color: str) -> Image: - """Generates a solid color image.""" - # Create image using Pillow - img = PILImage.new("RGB", (width, height), color=color) +def get_user_profile(user_id: str) -> Person: + """Get a user's profile information.""" + return Person(name="Alice", age=30, email="alice@example.com") +``` + +```json Generated Output Schema +{ + "properties": { + "name": {"title": "Name", "type": "string"}, + "age": {"title": "Age", "type": "integer"}, + "email": {"title": "Email", "type": "string"} + }, + "required": ["name", "age", "email"], + "title": "Person", + "type": "object" +} +``` + +```json Structured Output +{ + "name": "Alice", + "age": 30, + "email": "alice@example.com" +} +``` + + +#### Output Schemas + + + +The 6/18/2025 MCP spec update [introduced](https://modelcontextprotocol.io/specification/2025-06-18/server/tools#output-schema) output schemas, which are a new way to describe the expected output format of a tool. When an output schema is provided, the tool *must* return structured output that matches the schema. + +When you add return type annotations to your functions, FastMCP automatically generates JSON schemas that describe the expected output format. These schemas help MCP clients understand and validate the structured data they receive. 
- # Save to a bytes buffer - buffer = io.BytesIO() - img.save(buffer, format="PNG") - img_bytes = buffer.getvalue() +##### Primitive Type Wrapping - # Return using FastMCP's Image helper - return Image(data=img_bytes, format="png") +For primitive return types (like `int`, `str`, `bool`), FastMCP automatically wraps the result under a `"result"` key to create valid structured output: + +```python Primitive Return Type @mcp.tool -def do_nothing() -> None: - """This tool performs an action but returns no data.""" - print("Performing a side effect...") - return None +def calculate_sum(a: int, b: int) -> int: + """Add two numbers together.""" + return a + b +``` + +```json Generated Schema (Wrapped) +{ + "type": "object", + "properties": { + "result": {"type": "integer"} + }, + "x-fastmcp-wrap-result": true +} +``` + +```json Structured Output +{ + "result": 8 +} ``` + + +##### Manual Schema Control + +You can override the automatically generated schema by providing a custom `output_schema`: + +```python +@mcp.tool(output_schema={ + "type": "object", + "properties": { + "data": {"type": "string"}, + "metadata": {"type": "object"} + } +}) +def custom_schema_tool() -> dict: + """Tool with custom output schema.""" + return {"data": "Hello", "metadata": {"version": "1.0"}} +``` + +Schema generation works for most common types including basic types, collections, union types, Pydantic models, TypedDict structures, and dataclasses. + + +**Important Constraints**: +- Output schemas must be object types (`"type": "object"`) +- If you provide an output schema, your tool **must** return structured output that matches it +- However, you can provide structured output without an output schema (using `ToolResult`) + + +#### Full Control with ToolResult + +For complete control over both traditional content and structured output, return a `ToolResult` object: + +```python +from fastmcp.tools.tool import ToolResult + +@mcp.tool +def advanced_tool() -> ToolResult: + """Tool with full control over output.""" + return ToolResult( + content=[TextContent(text="Human-readable summary")], + structured_content={"data": "value", "count": 42} + ) +``` + +When returning `ToolResult`: +- You control exactly what content and structured data is sent +- Output schemas are optional - structured content can be provided without a schema +- Clients receive both traditional content blocks and structured data + + +If your return type annotation cannot be converted to a JSON schema (e.g., complex custom classes without Pydantic support), the output schema will be omitted but the tool will still function normally with traditional content. 
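+
+For instance, a return type that Pydantic cannot model simply produces no output schema and no structured content, while the tool's result is still returned as traditional content. A minimal sketch (`LegacyReport` is an illustrative class, not part of FastMCP):
+
+```python
+class LegacyReport:
+    """An arbitrary class that Pydantic cannot generate a schema for."""
+    def __init__(self, text: str):
+        self.text = text
+
+@mcp.tool
+def legacy_tool() -> LegacyReport:
+    """No output schema is generated; the result is serialized as text content."""
+    return LegacyReport("quarterly numbers")
+```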
+ ### Error Handling diff --git a/pyproject.toml b/pyproject.toml index 86108e1d2..d4fbbed4e 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -12,6 +12,7 @@ dependencies = [ "rich>=13.9.4", "typer>=0.15.2", "authlib>=1.5.2", + "pydantic[email]>=2.11.7", ] requires-python = ">=3.10" readme = "README.md" diff --git a/src/fastmcp/client/client.py b/src/fastmcp/client/client.py index dbf088606..f160e978a 100644 --- a/src/fastmcp/client/client.py +++ b/src/fastmcp/client/client.py @@ -1,6 +1,9 @@ +from __future__ import annotations + import asyncio import datetime from contextlib import AsyncExitStack, asynccontextmanager +from dataclasses import dataclass from pathlib import Path from typing import Any, Generic, Literal, cast, overload @@ -10,7 +13,6 @@ import pydantic_core from exceptiongroup import catch from mcp import ClientSession -from mcp.types import ContentBlock from pydantic import AnyUrl import fastmcp @@ -30,7 +32,10 @@ from fastmcp.exceptions import ToolError from fastmcp.server import FastMCP from fastmcp.utilities.exceptions import get_catch_handlers +from fastmcp.utilities.json_schema_type import json_schema_to_type +from fastmcp.utilities.logging import get_logger from fastmcp.utilities.mcp_config import MCPConfig +from fastmcp.utilities.types import get_cached_typeadapter from .transports import ( ClientTransportT, @@ -56,6 +61,8 @@ "ProgressHandler", ] +logger = get_logger(__name__) + class Client(Generic[ClientTransportT]): """ @@ -99,34 +106,39 @@ def __new__( cls, transport: ClientTransportT, **kwargs: Any, - ) -> "Client[ClientTransportT]": ... + ) -> Client[ClientTransportT]: ... @overload def __new__( cls, transport: AnyUrl, **kwargs - ) -> "Client[SSETransport|StreamableHttpTransport]": ... + ) -> Client[SSETransport | StreamableHttpTransport]: ... @overload def __new__( cls, transport: FastMCP | FastMCP1Server, **kwargs - ) -> "Client[FastMCPTransport]": ... + ) -> Client[FastMCPTransport]: ... @overload def __new__( cls, transport: Path, **kwargs - ) -> "Client[PythonStdioTransport|NodeStdioTransport]": ... + ) -> Client[PythonStdioTransport | NodeStdioTransport]: ... @overload def __new__( cls, transport: MCPConfig | dict[str, Any], **kwargs - ) -> "Client[MCPConfigTransport]": ... + ) -> Client[MCPConfigTransport]: ... @overload def __new__( cls, transport: str, **kwargs - ) -> "Client[PythonStdioTransport|NodeStdioTransport|SSETransport|StreamableHttpTransport]": ... - - def __new__(cls, transport, **kwargs) -> "Client": + ) -> Client[ + PythonStdioTransport + | NodeStdioTransport + | SSETransport + | StreamableHttpTransport + ]: ... + + def __new__(cls, transport, **kwargs) -> Client: instance = super().__new__(cls) return instance @@ -675,7 +687,8 @@ async def call_tool( arguments: dict[str, Any] | None = None, timeout: datetime.timedelta | float | int | None = None, progress_handler: ProgressHandler | None = None, - ) -> list[ContentBlock]: + raise_on_error: bool = True, + ) -> CallToolResult: """Call a tool on the server. Unlike call_tool_mcp, this method raises a ToolError if the tool call results in an error. @@ -687,8 +700,13 @@ async def call_tool( progress_handler (ProgressHandler | None, optional): The progress handler to use for the tool call. Defaults to None. Returns: - list[mcp.types.TextContent | mcp.types.ImageContent | mcp.types.AudioContent | mcp.types.EmbeddedResource]: - The content returned by the tool. + CallToolResult: + The content returned by the tool. 
If the tool returns structured + outputs, they are returned as a dataclass (if an output schema + is available) or a dictionary; otherwise, a list of content + blocks is returned. Note: to receive both structured and + unstructured outputs, use call_tool_mcp instead and access the + raw result object. Raises: ToolError: If the tool call results in an error. @@ -700,7 +718,43 @@ async def call_tool( timeout=timeout, progress_handler=progress_handler, ) - if result.isError: + data = None + if result.isError and raise_on_error: msg = cast(mcp.types.TextContent, result.content[0]).text raise ToolError(msg) - return result.content + elif result.structuredContent: + try: + if name not in self.session._tool_output_schemas: + await self.session.list_tools() + if name in self.session._tool_output_schemas: + output_schema = self.session._tool_output_schemas.get(name) + if output_schema: + if output_schema.get("x-fastmcp-wrap-result"): + output_schema = output_schema.get("properties", {}).get( + "result" + ) + structured_content = result.structuredContent.get("result") + else: + structured_content = result.structuredContent + output_type = json_schema_to_type(output_schema) + type_adapter = get_cached_typeadapter(output_type) + data = type_adapter.validate_python(structured_content) + else: + data = result.structuredContent + except Exception as e: + logger.error(f"Error parsing structured content: {e}") + + return CallToolResult( + content=result.content, + structured_content=result.structuredContent, + data=data, + is_error=result.isError, + ) + + +@dataclass +class CallToolResult: + content: list[mcp.types.ContentBlock] + structured_content: dict[str, Any] | None + data: Any = None + is_error: bool = False diff --git a/src/fastmcp/server/low_level.py b/src/fastmcp/server/low_level.py index 620abe713..7dd3e9d4b 100644 --- a/src/fastmcp/server/low_level.py +++ b/src/fastmcp/server/low_level.py @@ -4,12 +4,14 @@ LifespanResultT, NotificationOptions, RequestT, - Server, +) +from mcp.server.lowlevel.server import ( + Server as _Server, ) from mcp.server.models import InitializationOptions -class LowLevelServer(Server[LifespanResultT, RequestT]): +class LowLevelServer(_Server[LifespanResultT, RequestT]): def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) # FastMCP servers support notifications for all components diff --git a/src/fastmcp/server/openapi.py b/src/fastmcp/server/openapi.py index 02a68de5d..1e8051ace 100644 --- a/src/fastmcp/server/openapi.py +++ b/src/fastmcp/server/openapi.py @@ -13,7 +13,7 @@ from typing import TYPE_CHECKING, Any, Literal import httpx -from mcp.types import ContentBlock, ToolAnnotations +from mcp.types import ToolAnnotations from pydantic.networks import AnyUrl import fastmcp @@ -21,7 +21,7 @@ from fastmcp.resources import Resource, ResourceTemplate from fastmcp.server.dependencies import get_http_headers from fastmcp.server.server import FastMCP -from fastmcp.tools.tool import Tool, _convert_to_content +from fastmcp.tools.tool import Tool, ToolResult from fastmcp.utilities import openapi from fastmcp.utilities.logging import get_logger from fastmcp.utilities.openapi import ( @@ -254,7 +254,7 @@ def __repr__(self) -> str: """Custom representation to prevent recursion errors when printing.""" return f"OpenAPITool(name={self.name!r}, method={self._route.method}, path={self._route.path})" - async def run(self, arguments: dict[str, Any]) -> list[ContentBlock]: + async def run(self, arguments: dict[str, Any]) -> ToolResult: """Execute the HTTP request based 
on the route configuration.""" # Prepare URL @@ -450,10 +450,11 @@ async def run(self, arguments: dict[str, Any]) -> list[ContentBlock]: # Try to parse as JSON first try: result = response.json() - except (json.JSONDecodeError, ValueError): - # Return text content if not JSON - result = response.text - return _convert_to_content(result) + if not isinstance(result, dict): + result = {"result": result} + return ToolResult(structured_content=result) + except json.JSONDecodeError: + return ToolResult(content=response.text) except httpx.HTTPStatusError as e: # Handle HTTP errors (4xx, 5xx) diff --git a/src/fastmcp/server/proxy.py b/src/fastmcp/server/proxy.py index f80c9c38e..07650fc27 100644 --- a/src/fastmcp/server/proxy.py +++ b/src/fastmcp/server/proxy.py @@ -8,7 +8,6 @@ from mcp.types import ( METHOD_NOT_FOUND, BlobResourceContents, - ContentBlock, GetPromptResult, TextResourceContents, ) @@ -23,7 +22,7 @@ from fastmcp.resources.resource_manager import ResourceManager from fastmcp.server.context import Context from fastmcp.server.server import FastMCP -from fastmcp.tools.tool import Tool +from fastmcp.tools.tool import Tool, ToolResult from fastmcp.tools.tool_manager import ToolManager from fastmcp.utilities.logging import get_logger @@ -67,9 +66,7 @@ async def list_tools(self) -> list[Tool]: tools_dict = await self.get_tools() return list(tools_dict.values()) - async def call_tool( - self, key: str, arguments: dict[str, Any] - ) -> list[ContentBlock]: + async def call_tool(self, key: str, arguments: dict[str, Any]) -> ToolResult: """Calls a tool, trying local/mounted first, then proxy if not found.""" try: # First try local and mounted tools @@ -77,7 +74,11 @@ async def call_tool( except NotFoundError: # If not found locally, try proxy async with self.client: - return await self.client.call_tool(key, arguments) + result = await self.client.call_tool(key, arguments) + return ToolResult( + content=result.content, + structured_content=result.structured_content, + ) class ProxyResourceManager(ResourceManager): @@ -226,13 +227,14 @@ def from_mcp_tool(cls, client: Client, mcp_tool: mcp.types.Tool) -> ProxyTool: description=mcp_tool.description, parameters=mcp_tool.inputSchema, annotations=mcp_tool.annotations, + output_schema=mcp_tool.outputSchema, ) async def run( self, arguments: dict[str, Any], context: Context | None = None, - ) -> list[ContentBlock]: + ) -> ToolResult: """Executes the tool by making a call through the client.""" # This is where the remote execution logic lives. 
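+        # Forward through the wrapped client and re-wrap the raw MCP result as a
+        # ToolResult, preserving both content blocks and structured content.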
async with self._client: @@ -242,7 +244,10 @@ async def run( ) if result.isError: raise ToolError(cast(mcp.types.TextContent, result.content[0]).text) - return result.content + return ToolResult( + content=result.content, + structured_content=result.structuredContent, + ) class ProxyResource(Resource): diff --git a/src/fastmcp/server/server.py b/src/fastmcp/server/server.py index 76b1373a6..d3252b771 100644 --- a/src/fastmcp/server/server.py +++ b/src/fastmcp/server/server.py @@ -58,11 +58,12 @@ from fastmcp.server.middleware import Middleware, MiddlewareContext from fastmcp.settings import Settings from fastmcp.tools import ToolManager -from fastmcp.tools.tool import FunctionTool, Tool +from fastmcp.tools.tool import FunctionTool, Tool, ToolResult from fastmcp.utilities.cache import TimedCache from fastmcp.utilities.components import FastMCPComponent from fastmcp.utilities.logging import get_logger from fastmcp.utilities.mcp_config import MCPConfig +from fastmcp.utilities.types import NotSet, NotSetT if TYPE_CHECKING: from fastmcp.client import Client @@ -592,7 +593,7 @@ async def _handler( async def _mcp_call_tool( self, key: str, arguments: dict[str, Any] - ) -> list[ContentBlock]: + ) -> list[ContentBlock] | tuple[list[ContentBlock], dict[str, Any]]: """ Handle MCP 'callTool' requests. @@ -609,22 +610,21 @@ async def _mcp_call_tool( async with fastmcp.server.context.Context(fastmcp=self): try: - return await self._call_tool(key, arguments) + result = await self._call_tool(key, arguments) + return result.to_mcp_result() except DisabledError: raise NotFoundError(f"Unknown tool: {key}") except NotFoundError: raise NotFoundError(f"Unknown tool: {key}") - async def _call_tool( - self, key: str, arguments: dict[str, Any] - ) -> list[ContentBlock]: + async def _call_tool(self, key: str, arguments: dict[str, Any]) -> ToolResult: """ Applies this server's middleware and delegates the filtered call to the manager. 
""" async def _handler( context: MiddlewareContext[mcp.types.CallToolRequestParams], - ) -> list[ContentBlock]: + ) -> ToolResult: tool = await self._tool_manager.get_tool(context.message.name) if not self._should_enable_component(tool): raise NotFoundError(f"Unknown tool: {context.message.name!r}") @@ -792,6 +792,7 @@ def tool( name: str | None = None, description: str | None = None, tags: set[str] | None = None, + output_schema: dict[str, Any] | None | NotSetT = NotSet, annotations: ToolAnnotations | dict[str, Any] | None = None, exclude_args: list[str] | None = None, enabled: bool | None = None, @@ -805,6 +806,7 @@ def tool( name: str | None = None, description: str | None = None, tags: set[str] | None = None, + output_schema: dict[str, Any] | None | NotSetT = NotSet, annotations: ToolAnnotations | dict[str, Any] | None = None, exclude_args: list[str] | None = None, enabled: bool | None = None, @@ -817,6 +819,7 @@ def tool( name: str | None = None, description: str | None = None, tags: set[str] | None = None, + output_schema: dict[str, Any] | None | NotSetT = NotSet, annotations: ToolAnnotations | dict[str, Any] | None = None, exclude_args: list[str] | None = None, enabled: bool | None = None, @@ -839,6 +842,7 @@ def tool( name: Optional name for the tool (keyword-only, alternative to name_or_fn) description: Optional description of what the tool does tags: Optional set of tags for categorizing the tool + output_schema: Optional JSON schema for the tool's output annotations: Optional annotations about the tool's behavior exclude_args: Optional list of argument names to exclude from the tool schema enabled: Optional boolean to enable or disable the tool @@ -895,6 +899,7 @@ def my_tool(x: int) -> str: name=tool_name, description=description, tags=tags, + output_schema=output_schema, annotations=annotations, exclude_args=exclude_args, serializer=self._tool_serializer, @@ -925,6 +930,7 @@ def my_tool(x: int) -> str: name=tool_name, description=description, tags=tags, + output_schema=output_schema, annotations=annotations, exclude_args=exclude_args, enabled=enabled, diff --git a/src/fastmcp/tools/tool.py b/src/fastmcp/tools/tool.py index 49588df28..35dec76f6 100644 --- a/src/fastmcp/tools/tool.py +++ b/src/fastmcp/tools/tool.py @@ -3,12 +3,13 @@ import inspect from collections.abc import Callable from dataclasses import dataclass -from typing import TYPE_CHECKING, Any +from typing import TYPE_CHECKING, Annotated, Any, Literal +import mcp.types import pydantic_core from mcp.types import ContentBlock, TextContent, ToolAnnotations from mcp.types import Tool as MCPTool -from pydantic import Field +from pydantic import Field, PydanticSchemaGenerationError from fastmcp.server.dependencies import get_context from fastmcp.utilities.components import FastMCPComponent @@ -18,8 +19,11 @@ Audio, File, Image, + NotSet, + NotSetT, find_kwarg_by_type, get_cached_typeadapter, + replace_type, ) if TYPE_CHECKING: @@ -28,20 +32,91 @@ logger = get_logger(__name__) +class _UnserializableType: + pass + + def default_serializer(data: Any) -> str: return pydantic_core.to_json(data, fallback=str, indent=2).decode() +def _wrap_schema_if_needed(schema: dict[str, Any] | None) -> dict[str, Any] | None: + """Wrap non-object schemas with result property for structured output. + + This wrapping allows primitive types (int, str, etc.) to be returned as + structured content by placing them under a "result" key. 
+ + Args: + schema: The JSON schema to potentially wrap + + Returns: + Wrapped schema if needed, or original schema if already an object type + """ + if schema and schema.get("type") != "object": + return { + "type": "object", + "properties": {"result": schema}, + "x-fastmcp-wrap-result": True, + } + return schema + + +class ToolResult: + def __init__( + self, + content: list[ContentBlock] | Any | None = None, + structured_content: dict[str, Any] | Any | None = None, + ): + if content is None and structured_content is None: + raise ValueError("Either content or structured_content must be provided") + elif content is None: + content = structured_content + + self.content = _convert_to_content(content) + + if structured_content is not None: + try: + structured_content = pydantic_core.to_jsonable_python( + structured_content + ) + except pydantic_core.PydanticSerializationError as e: + logger.error( + f"Could not serialize structured content. If this is unexpected, set your tool's output_schema to None to disable automatic serialization: {e}" + ) + raise + if not isinstance(structured_content, dict): + raise ValueError( + "structured_content must be a dict or None. " + f"Got {type(structured_content).__name__}: {structured_content!r}. " + "Tools should wrap non-dict values based on their output_schema." + ) + self.structured_content: dict[str, Any] | None = structured_content + + def to_mcp_result( + self, + ) -> list[ContentBlock] | tuple[list[ContentBlock], dict[str, Any]]: + if self.structured_content is None: + return self.content + return self.content, self.structured_content + + class Tool(FastMCPComponent): """Internal tool registration info.""" - parameters: dict[str, Any] = Field(description="JSON schema for tool parameters") - annotations: ToolAnnotations | None = Field( - default=None, description="Additional annotations about the tool" - ) - serializer: Callable[[Any], str] | None = Field( - default=None, description="Optional custom serializer for tool results" - ) + parameters: Annotated[ + dict[str, Any], Field(description="JSON schema for tool parameters") + ] + output_schema: Annotated[ + dict[str, Any] | None, Field(description="JSON schema for tool output") + ] = None + annotations: Annotated[ + ToolAnnotations | None, + Field(description="Additional annotations about the tool"), + ] = None + serializer: Annotated[ + Callable[[Any], str] | None, + Field(description="Optional custom serializer for tool results"), + ] = None def enable(self) -> None: super().enable() @@ -64,6 +139,7 @@ def to_mcp_tool(self, **overrides: Any) -> MCPTool: "name": self.name, "description": self.description, "inputSchema": self.parameters, + "outputSchema": self.output_schema, "annotations": self.annotations, } return MCPTool(**kwargs | overrides) @@ -76,6 +152,7 @@ def from_function( tags: set[str] | None = None, annotations: ToolAnnotations | None = None, exclude_args: list[str] | None = None, + output_schema: dict[str, Any] | None | NotSetT | Literal[False] = NotSet, serializer: Callable[[Any], str] | None = None, enabled: bool | None = None, ) -> FunctionTool: @@ -87,12 +164,21 @@ def from_function( tags=tags, annotations=annotations, exclude_args=exclude_args, + output_schema=output_schema, serializer=serializer, enabled=enabled, ) - async def run(self, arguments: dict[str, Any]) -> list[ContentBlock]: - """Run the tool with arguments.""" + async def run(self, arguments: dict[str, Any]) -> ToolResult: + """ + Run the tool with arguments. 
+ + This method is not implemented in the base Tool class and must be + implemented by subclasses. + + `run()` can EITHER return a list of ContentBlocks, or a tuple of + (list of ContentBlocks, dict of structured output). + """ raise NotImplementedError("Subclasses must implement run()") @classmethod @@ -105,6 +191,7 @@ def from_tool( description: str | None = None, tags: set[str] | None = None, annotations: ToolAnnotations | None = None, + output_schema: dict[str, Any] | None | Literal[False] = None, serializer: Callable[[Any], str] | None = None, enabled: bool | None = None, ) -> TransformedTool: @@ -118,6 +205,7 @@ def from_tool( description=description, tags=tags, annotations=annotations, + output_schema=output_schema, serializer=serializer, enabled=enabled, ) @@ -135,6 +223,7 @@ def from_function( tags: set[str] | None = None, annotations: ToolAnnotations | None = None, exclude_args: list[str] | None = None, + output_schema: dict[str, Any] | None | NotSetT | Literal[False] = NotSet, serializer: Callable[[Any], str] | None = None, enabled: bool | None = None, ) -> FunctionTool: @@ -145,18 +234,32 @@ def from_function( if name is None and parsed_fn.name == "": raise ValueError("You must provide a name for lambda functions") + if isinstance(output_schema, NotSetT): + output_schema = _wrap_schema_if_needed(parsed_fn.output_schema) + elif output_schema is False: + output_schema = None + # Note: explicit schemas (dict) are used as-is without auto-wrapping + + # Validate that explicit schemas are object type for structured content + if output_schema is not None and isinstance(output_schema, dict): + if output_schema.get("type") != "object": + raise ValueError( + f'Output schemas must have "type" set to "object" due to MCP spec limitations. Received: {output_schema!r}' + ) + return cls( fn=parsed_fn.fn, name=name or parsed_fn.name, description=description or parsed_fn.description, - parameters=parsed_fn.parameters, - tags=tags or set(), + parameters=parsed_fn.input_schema, + output_schema=output_schema, annotations=annotations, + tags=tags or set(), serializer=serializer, enabled=enabled if enabled is not None else True, ) - async def run(self, arguments: dict[str, Any]) -> list[ContentBlock]: + async def run(self, arguments: dict[str, Any]) -> ToolResult: """Run the tool with arguments.""" from fastmcp.server.context import Context @@ -168,10 +271,37 @@ async def run(self, arguments: dict[str, Any]) -> list[ContentBlock]: type_adapter = get_cached_typeadapter(self.fn) result = type_adapter.validate_python(arguments) + if inspect.isawaitable(result): result = await result - return _convert_to_content(result, serializer=self.serializer) + if isinstance(result, ToolResult): + return result + + unstructured_result = _convert_to_content(result, serializer=self.serializer) + + structured_output = None + # First handle structured content based on output schema, if any + if self.output_schema is not None: + if self.output_schema.get("x-fastmcp-wrap-result"): + # Schema says wrap - always wrap in result key + structured_output = {"result": result} + else: + structured_output = result + # If no output schema, try to serialize the result. If it is a dict, use + # it as structured content. If it is not a dict, ignore it. 
+ if structured_output is None: + try: + structured_output = pydantic_core.to_jsonable_python(result) + if not isinstance(structured_output, dict): + structured_output = None + except Exception: + pass + + return ToolResult( + content=unstructured_result, + structured_content=structured_output, + ) @dataclass @@ -179,13 +309,15 @@ class ParsedFunction: fn: Callable[..., Any] name: str description: str | None - parameters: dict[str, Any] + input_schema: dict[str, Any] + output_schema: dict[str, Any] | None @classmethod def from_function( cls, fn: Callable[..., Any], exclude_args: list[str] | None = None, + ignore_response_types: list[type] | None = None, validate: bool = True, ) -> ParsedFunction: from fastmcp.server.context import Context @@ -225,9 +357,6 @@ def from_function( if isinstance(fn, staticmethod): fn = fn.__func__ - type_adapter = get_cached_typeadapter(fn) - schema = type_adapter.json_schema() - prune_params: list[str] = [] context_kwarg = find_kwarg_by_type(fn, kwarg_type=Context) if context_kwarg: @@ -235,12 +364,65 @@ def from_function( if exclude_args: prune_params.extend(exclude_args) - schema = compress_schema(schema, prune_params=prune_params) + input_type_adapter = get_cached_typeadapter(fn) + input_schema = input_type_adapter.json_schema() + input_schema = compress_schema(input_schema, prune_params=prune_params) + + output_schema = None + output_type = inspect.signature(fn).return_annotation + + if output_type not in (inspect._empty, None, Any, ...): + # there are a variety of types that we don't want to attempt to + # serialize because they are either used by FastMCP internally, + # or are MCP content types that explicitly don't form structured + # content. By replacing them with an explicitly unserializable type, + # we ensure that no output schema is automatically generated. 
+ output_type = replace_type( + output_type, + { + t: _UnserializableType + for t in ( + Image, + Audio, + File, + ToolResult, + mcp.types.TextContent, + mcp.types.ImageContent, + mcp.types.AudioContent, + mcp.types.ResourceLink, + mcp.types.EmbeddedResource, + ) + }, + ) + + try: + output_type_adapter = get_cached_typeadapter(output_type) + output_schema = output_type_adapter.json_schema() + except PydanticSchemaGenerationError as e: + if "_UnserializableType" not in str(e): + logger.debug(f"Unable to generate schema for type {output_type!r}") + + return cls( + fn=fn, + name=fn_name, + description=fn_doc, + input_schema=input_schema, + output_schema=output_schema or None, + ) + + try: + output_type_adapter = get_cached_typeadapter(output_type) + output_schema = output_type_adapter.json_schema() + except PydanticSchemaGenerationError as e: + if "_UnserializableType" not in str(e): + logger.debug(f"Unable to generate schema for type {output_type!r}") + return cls( fn=fn, name=fn_name, description=fn_doc, - parameters=schema, + input_schema=input_schema, + output_schema=output_schema or None, ) diff --git a/src/fastmcp/tools/tool_manager.py b/src/fastmcp/tools/tool_manager.py index ee577a450..29bb956c3 100644 --- a/src/fastmcp/tools/tool_manager.py +++ b/src/fastmcp/tools/tool_manager.py @@ -4,12 +4,12 @@ from collections.abc import Callable from typing import TYPE_CHECKING, Any -from mcp.types import ContentBlock, ToolAnnotations +from mcp.types import ToolAnnotations from fastmcp import settings from fastmcp.exceptions import NotFoundError, ToolError from fastmcp.settings import DuplicateBehavior -from fastmcp.tools.tool import Tool +from fastmcp.tools.tool import Tool, ToolResult from fastmcp.utilities.logging import get_logger if TYPE_CHECKING: @@ -169,9 +169,7 @@ def remove_tool(self, key: str) -> None: else: raise NotFoundError(f"Tool {key!r} not found") - async def call_tool( - self, key: str, arguments: dict[str, Any] - ) -> list[ContentBlock]: + async def call_tool(self, key: str, arguments: dict[str, Any]) -> ToolResult: """ Internal API for servers: Finds and calls a tool, respecting the filtered protocol path. diff --git a/src/fastmcp/tools/tool_transform.py b/src/fastmcp/tools/tool_transform.py index b2b871513..61d9ecc9e 100644 --- a/src/fastmcp/tools/tool_transform.py +++ b/src/fastmcp/tools/tool_transform.py @@ -4,20 +4,17 @@ from collections.abc import Callable from contextvars import ContextVar from dataclasses import dataclass -from types import EllipsisType from typing import Any, Literal -from mcp.types import ContentBlock, ToolAnnotations +from mcp.types import ToolAnnotations from pydantic import ConfigDict -from fastmcp.tools.tool import ParsedFunction, Tool +from fastmcp.tools.tool import ParsedFunction, Tool, ToolResult, _wrap_schema_if_needed from fastmcp.utilities.logging import get_logger -from fastmcp.utilities.types import get_cached_typeadapter +from fastmcp.utilities.types import NotSet, NotSetT, get_cached_typeadapter logger = get_logger(__name__) -NotSet = ... - # Context variable to store current transformed tool _current_tool: ContextVar[TransformedTool | None] = ContextVar( @@ -25,7 +22,7 @@ ) -async def forward(**kwargs) -> Any: +async def forward(**kwargs) -> ToolResult: """Forward to parent tool with argument transformation applied. This function can only be called from within a transformed tool's custom @@ -41,7 +38,7 @@ async def forward(**kwargs) -> Any: **kwargs: Arguments to forward to the parent tool (using transformed names). 
Returns: - The result from the parent tool execution. + The ToolResult from the parent tool execution. Raises: RuntimeError: If called outside a transformed tool context. @@ -55,7 +52,7 @@ async def forward(**kwargs) -> Any: return await tool.forwarding_fn(**kwargs) -async def forward_raw(**kwargs) -> Any: +async def forward_raw(**kwargs) -> ToolResult: """Forward directly to parent tool without transformation. This function bypasses all argument transformation and validation, calling the parent @@ -69,7 +66,7 @@ async def forward_raw(**kwargs) -> Any: **kwargs: Arguments to pass directly to the parent tool (using original names). Returns: - The result from the parent tool execution. + The ToolResult from the parent tool execution. Raises: RuntimeError: If called outside a transformed tool context. @@ -151,14 +148,14 @@ class ArgTransform: ``` """ - name: str | EllipsisType = NotSet - description: str | EllipsisType = NotSet - default: Any | EllipsisType = NotSet - default_factory: Callable[[], Any] | EllipsisType = NotSet - type: Any | EllipsisType = NotSet + name: str | NotSetT = NotSet + description: str | NotSetT = NotSet + default: Any | NotSetT = NotSet + default_factory: Callable[[], Any] | NotSetT = NotSet + type: Any | NotSetT = NotSet hide: bool = False - required: Literal[True] | EllipsisType = NotSet - examples: Any | EllipsisType = NotSet + required: Literal[True] | NotSetT = NotSet + examples: Any | NotSetT = NotSet def __post_init__(self): """Validate that only one of default or default_factory is provided.""" @@ -201,11 +198,12 @@ class TransformedTool(Tool): This class represents a tool that has been created by transforming another tool. It supports argument renaming, schema modification, custom function injection, - and provides context for the forward() and forward_raw() functions. + structured output control, and provides context for the forward() and forward_raw() functions. The transformation can be purely schema-based (argument renaming, dropping, etc.) or can include a custom function that uses forward() to call the parent tool - with transformed arguments. + with transformed arguments. Output schemas and structured outputs are automatically + inherited from the parent tool but can be overridden or disabled. Attributes: parent_tool: The original tool that this tool was transformed from. @@ -222,7 +220,7 @@ class TransformedTool(Tool): forwarding_fn: Callable[..., Any] # Always present, handles arg transformation transform_args: dict[str, ArgTransform] - async def run(self, arguments: dict[str, Any]) -> list[ContentBlock]: + async def run(self, arguments: dict[str, Any]) -> ToolResult: """Run the tool with context set for forward() functions. This method executes the tool's function while setting up the context @@ -233,8 +231,7 @@ async def run(self, arguments: dict[str, Any]) -> list[ContentBlock]: arguments: Dictionary of arguments to pass to the tool's function. Returns: - List of content objects (text, image, or embedded resources) representing - the tool's output. + ToolResult object containing content and optional structured output. 
""" from fastmcp.tools.tool import _convert_to_content @@ -272,7 +269,57 @@ async def run(self, arguments: dict[str, Any]) -> list[ContentBlock]: token = _current_tool.set(self) try: result = await self.fn(**arguments) - return _convert_to_content(result, serializer=self.serializer) + + # If transform function returns ToolResult, respect our output_schema setting + if isinstance(result, ToolResult): + if self.output_schema is None: + # Check if this is from a custom function that returns ToolResult + import inspect + + return_annotation = inspect.signature(self.fn).return_annotation + if return_annotation is ToolResult: + # Custom function returns ToolResult - preserve its content + return result + else: + # Forwarded call with disabled schema - strip structured content + return ToolResult( + content=result.content, + structured_content=None, + ) + elif self.output_schema.get( + "type" + ) != "object" and not self.output_schema.get("x-fastmcp-wrap-result"): + # Non-object explicit schemas disable structured content + return ToolResult( + content=result.content, + structured_content=None, + ) + else: + return result + + # Otherwise convert to content and create ToolResult with proper structured content + from fastmcp.tools.tool import _convert_to_content + + unstructured_result = _convert_to_content( + result, serializer=self.serializer + ) + + # Handle structured content based on output schema + if self.output_schema is not None: + if self.output_schema.get("x-fastmcp-wrap-result"): + # Schema says wrap - always wrap in result key + structured_output = {"result": result} + else: + # Object schemas - use result directly + # User is responsible for returning dict-compatible data + structured_output = result + else: + structured_output = None + + return ToolResult( + content=unstructured_result, + structured_content=structured_output, + ) finally: _current_tool.reset(token) @@ -286,6 +333,7 @@ def from_tool( transform_fn: Callable[..., Any] | None = None, transform_args: dict[str, ArgTransform] | None = None, annotations: ToolAnnotations | None = None, + output_schema: dict[str, Any] | None | Literal[False] = None, serializer: Callable[[Any], str] | None = None, enabled: bool | None = None, ) -> TransformedTool: @@ -305,6 +353,10 @@ def from_tool( description: New description. Defaults to parent's description. tags: New tags. Defaults to parent's tags. annotations: New annotations. Defaults to parent's annotations. + output_schema: Control output schema for structured outputs: + - None (default): Inherit from transform_fn if available, then parent tool + - dict: Use custom output schema + - False: Disable output schema and structured outputs serializer: New serializer. Defaults to parent's serializer. 
Returns: @@ -333,6 +385,26 @@ async def flexible(**kwargs) -> str: Tool.from_tool(parent, transform_fn=flexible, transform_args={"a": "x"}) ``` + + # Control structured outputs and schemas + ```python + # Custom output schema + Tool.from_tool(parent, output_schema={ + "type": "object", + "properties": {"status": {"type": "string"}} + }) + + # Disable structured outputs + Tool.from_tool(parent, output_schema=False) + + # Return ToolResult for full control + async def custom_output(**kwargs) -> ToolResult: + result = await forward(**kwargs) + return ToolResult( + content=[TextContent(text="Summary")], + structured_content={"processed": True} + ) + ``` """ transform_args = transform_args or {} @@ -348,19 +420,45 @@ async def flexible(**kwargs) -> str: # Always create the forwarding transform schema, forwarding_fn = cls._create_forwarding_transform(tool, transform_args) + # Handle output schema with smart fallback + if output_schema is False: + final_output_schema = None + elif output_schema is not None: + # Explicit schema provided - use as-is + final_output_schema = output_schema + else: + # Smart fallback: try custom function, then parent, then None + if transform_fn is not None: + parsed_fn = ParsedFunction.from_function(transform_fn, validate=False) + final_output_schema = _wrap_schema_if_needed(parsed_fn.output_schema) + if final_output_schema is None: + # Check if function returns ToolResult - if so, don't fall back to parent + import inspect + + return_annotation = inspect.signature( + transform_fn + ).return_annotation + if return_annotation is ToolResult: + final_output_schema = None + else: + final_output_schema = tool.output_schema + else: + final_output_schema = tool.output_schema + if transform_fn is None: # User wants pure transformation - use forwarding_fn as the main function final_fn = forwarding_fn final_schema = schema else: # User provided custom function - merge schemas - parsed_fn = ParsedFunction.from_function(transform_fn, validate=False) + if "parsed_fn" not in locals(): + parsed_fn = ParsedFunction.from_function(transform_fn, validate=False) final_fn = transform_fn has_kwargs = cls._function_has_kwargs(transform_fn) # Validate function parameters against transformed schema - fn_params = set(parsed_fn.parameters.get("properties", {}).keys()) + fn_params = set(parsed_fn.input_schema.get("properties", {}).keys()) transformed_params = set(schema.get("properties", {}).keys()) if not has_kwargs: @@ -377,7 +475,7 @@ async def flexible(**kwargs) -> str: # ArgTransform takes precedence over function signature # Start with function schema as base, then override with transformed schema final_schema = cls._merge_schema_with_precedence( - parsed_fn.parameters, schema + parsed_fn.input_schema, schema ) else: # With **kwargs, function can access all transformed params @@ -386,7 +484,7 @@ async def flexible(**kwargs) -> str: # Start with function schema as base, then override with transformed schema final_schema = cls._merge_schema_with_precedence( - parsed_fn.parameters, schema + parsed_fn.input_schema, schema ) # Additional validation: check for naming conflicts after transformation @@ -422,6 +520,7 @@ async def flexible(**kwargs) -> str: name=name or tool.name, description=final_description, parameters=final_schema, + output_schema=final_output_schema, tags=tags or tool.tags, annotations=annotations or tool.annotations, serializer=serializer or tool.serializer, diff --git a/src/fastmcp/utilities/json_schema_type.py b/src/fastmcp/utilities/json_schema_type.py new file mode 100644 
index 000000000..1160c6334 --- /dev/null +++ b/src/fastmcp/utilities/json_schema_type.py @@ -0,0 +1,646 @@ +"""Convert JSON Schema to Python types with validation. + +The json_schema_to_type function converts a JSON Schema into a Python type that can be used +for validation with Pydantic. It supports: + +- Basic types (string, number, integer, boolean, null) +- Complex types (arrays, objects) +- Format constraints (date-time, email, uri) +- Numeric constraints (minimum, maximum, multipleOf) +- String constraints (minLength, maxLength, pattern) +- Array constraints (minItems, maxItems, uniqueItems) +- Object properties with defaults +- References and recursive schemas +- Enums and constants +- Union types + +Example: + ```python + schema = { + "type": "object", + "properties": { + "name": {"type": "string", "minLength": 1}, + "age": {"type": "integer", "minimum": 0}, + "email": {"type": "string", "format": "email"} + }, + "required": ["name", "age"] + } + + # Name is optional and will be inferred from schema's "title" property if not provided + Person = json_schema_to_type(schema) + # Creates a validated dataclass with name, age, and optional email fields + ``` +""" + +from __future__ import annotations + +import hashlib +import json +import re +from collections.abc import Callable, Mapping +from copy import deepcopy +from dataclasses import MISSING, field, make_dataclass +from datetime import datetime +from enum import Enum +from typing import ( + Annotated, + Any, + ForwardRef, + Literal, + Union, +) + +from pydantic import ( + AnyUrl, + BaseModel, + ConfigDict, + EmailStr, + Field, + Json, + StringConstraints, + model_validator, +) +from typing_extensions import NotRequired, TypedDict + +__all__ = ["json_schema_to_type", "JSONSchema"] + + +FORMAT_TYPES: dict[str, Any] = { + "date-time": datetime, + "email": EmailStr, + "uri": AnyUrl, + "json": Json, +} + +_classes: dict[tuple[str, Any], type | None] = {} + + +class JSONSchema(TypedDict): + type: NotRequired[str | list[str]] + properties: NotRequired[dict[str, JSONSchema]] + required: NotRequired[list[str]] + additionalProperties: NotRequired[bool | JSONSchema] + items: NotRequired[JSONSchema | list[JSONSchema]] + enum: NotRequired[list[Any]] + const: NotRequired[Any] + default: NotRequired[Any] + description: NotRequired[str] + title: NotRequired[str] + examples: NotRequired[list[Any]] + format: NotRequired[str] + allOf: NotRequired[list[JSONSchema]] + anyOf: NotRequired[list[JSONSchema]] + oneOf: NotRequired[list[JSONSchema]] + not_: NotRequired[JSONSchema] + definitions: NotRequired[dict[str, JSONSchema]] + dependencies: NotRequired[dict[str, JSONSchema | list[str]]] + pattern: NotRequired[str] + minLength: NotRequired[int] + maxLength: NotRequired[int] + minimum: NotRequired[int | float] + maximum: NotRequired[int | float] + exclusiveMinimum: NotRequired[int | float] + exclusiveMaximum: NotRequired[int | float] + multipleOf: NotRequired[int | float] + uniqueItems: NotRequired[bool] + minItems: NotRequired[int] + maxItems: NotRequired[int] + additionalItems: NotRequired[bool | JSONSchema] + + +def json_schema_to_type( + schema: Mapping[str, Any], + name: str | None = None, +) -> type: + """Convert JSON schema to appropriate Python type with validation. + + Args: + schema: A JSON Schema dictionary defining the type structure and validation rules + name: Optional name for object schemas. Only allowed when schema type is "object". + If not provided for objects, name will be inferred from schema's "title" + property or default to "Root". 
+ + Returns: + A Python type (typically a dataclass for objects) with Pydantic validation + + Raises: + ValueError: If a name is provided for a non-object schema + + Examples: + Create a dataclass from an object schema: + ```python + schema = { + "type": "object", + "title": "Person", + "properties": { + "name": {"type": "string", "minLength": 1}, + "age": {"type": "integer", "minimum": 0}, + "email": {"type": "string", "format": "email"} + }, + "required": ["name", "age"] + } + + Person = json_schema_to_type(schema) + # Creates a dataclass with name, age, and optional email fields: + # @dataclass + # class Person: + # name: str + # age: int + # email: str | None = None + ``` + Person(name="John", age=30) + + Create a scalar type with constraints: + ```python + schema = { + "type": "string", + "minLength": 3, + "pattern": "^[A-Z][a-z]+$" + } + + NameType = json_schema_to_type(schema) + # Creates Annotated[str, StringConstraints(min_length=3, pattern="^[A-Z][a-z]+$")] + + @dataclass + class Name: + name: NameType + ``` + """ + # Always use the top-level schema for references + if schema.get("type") == "object": + # If no properties defined but has additionalProperties, return typed dict + if not schema.get("properties") and schema.get("additionalProperties"): + additional_props = schema["additionalProperties"] + if additional_props is True: + return dict[str, Any] # type: ignore - additionalProperties: true means dict[str, Any] + else: + # Handle typed dictionaries like dict[str, str] + value_type = _schema_to_type(additional_props, schemas=schema) + return dict[str, value_type] # type: ignore + # If no properties and no additionalProperties, default to dict[str, Any] for safety + elif not schema.get("properties") and not schema.get("additionalProperties"): + return dict[str, Any] # type: ignore + # If has properties AND additionalProperties is True, use Pydantic BaseModel + elif schema.get("properties") and schema.get("additionalProperties") is True: + return _create_pydantic_model(schema, name, schemas=schema) + # Otherwise use fast dataclass + return _create_dataclass(schema, name, schemas=schema) + elif name: + raise ValueError(f"Can not apply name to non-object schema: {name}") + result = _schema_to_type(schema, schemas=schema) + return result # type: ignore[return-value] + + +def _hash_schema(schema: Mapping[str, Any]) -> str: + """Generate a deterministic hash for schema caching.""" + return hashlib.sha256(json.dumps(schema, sort_keys=True).encode()).hexdigest() + + +def _resolve_ref(ref: str, schemas: Mapping[str, Any]) -> Mapping[str, Any]: + """Resolve JSON Schema reference to target schema.""" + path = ref.replace("#/", "").split("/") + current = schemas + for part in path: + current = current.get(part, {}) + return current + + +def _create_string_type(schema: Mapping[str, Any]) -> type | Annotated[Any, ...]: + """Create string type with optional constraints.""" + if "const" in schema: + return Literal[schema["const"]] # type: ignore + + if fmt := schema.get("format"): + if fmt == "uri": + return AnyUrl + elif fmt == "uri-reference": + return str + return FORMAT_TYPES.get(fmt, str) + + constraints = { + k: v + for k, v in { + "min_length": schema.get("minLength"), + "max_length": schema.get("maxLength"), + "pattern": schema.get("pattern"), + }.items() + if v is not None + } + + return Annotated[str, StringConstraints(**constraints)] if constraints else str + + +def _create_numeric_type( + base: type[int | float], schema: Mapping[str, Any] +) -> type | Annotated[Any, ...]: + 
"""Create numeric type with optional constraints.""" + if "const" in schema: + return Literal[schema["const"]] # type: ignore + + constraints = { + k: v + for k, v in { + "gt": schema.get("exclusiveMinimum"), + "ge": schema.get("minimum"), + "lt": schema.get("exclusiveMaximum"), + "le": schema.get("maximum"), + "multiple_of": schema.get("multipleOf"), + }.items() + if v is not None + } + + return Annotated[base, Field(**constraints)] if constraints else base + + +def _create_enum(name: str, values: list[Any]) -> type: + """Create enum type from list of values.""" + if all(isinstance(v, str) for v in values): + return Enum(name, {v.upper(): v for v in values}) # type: ignore[return-value] + return Literal[tuple(values)] # type: ignore[return-value] + + +def _create_array_type( + schema: Mapping[str, Any], schemas: Mapping[str, Any] +) -> type | Annotated[Any, ...]: + """Create list/set type with optional constraints.""" + items = schema.get("items", {}) + if isinstance(items, list): + # Handle positional item schemas + item_types = [_schema_to_type(s, schemas) for s in items] + combined = Union[tuple(item_types)] # type: ignore # noqa: UP007 + base = list[combined] + else: + # Handle single item schema + item_type = _schema_to_type(items, schemas) + base_class = set if schema.get("uniqueItems") else list + base = base_class[item_type] # type: ignore[misc] + + constraints = { + k: v + for k, v in { + "min_length": schema.get("minItems"), + "max_length": schema.get("maxItems"), + }.items() + if v is not None + } + + return Annotated[base, Field(**constraints)] if constraints else base + + +def _return_Any() -> Any: + return Any + + +def _get_from_type_handler( + schema: Mapping[str, Any], schemas: Mapping[str, Any] +) -> Callable[..., Any]: + """Get the appropriate type handler for the schema.""" + + type_handlers: dict[str, Callable[..., Any]] = { # TODO + "string": lambda s: _create_string_type(s), # type: ignore + "integer": lambda s: _create_numeric_type(int, s), # type: ignore + "number": lambda s: _create_numeric_type(float, s), # type: ignore + "boolean": lambda _: bool, # type: ignore + "null": lambda _: type(None), # type: ignore + "array": lambda s: _create_array_type(s, schemas), # type: ignore + "object": lambda s: ( + _create_pydantic_model(s, s.get("title"), schemas) + if s.get("properties") and s.get("additionalProperties") is True + else _create_dataclass(s, s.get("title"), schemas) + ), # type: ignore + } + return type_handlers.get(schema.get("type", None), _return_Any) + + +def _schema_to_type( + schema: Mapping[str, Any], + schemas: Mapping[str, Any], +) -> type | ForwardRef: + """Convert schema to appropriate Python type.""" + if not schema: + return object + + if "type" not in schema and "properties" in schema: + return _create_dataclass(schema, schema.get("title", ""), schemas) + + # Handle references first + if "$ref" in schema: + ref = schema["$ref"] + # Handle self-reference + if ref == "#": + return ForwardRef(schema.get("title", "Root")) # type: ignore[return-value] + return _schema_to_type(_resolve_ref(ref, schemas), schemas) + + if "const" in schema: + return Literal[schema["const"]] # type: ignore + + if "enum" in schema: + return _create_enum(f"Enum_{len(_classes)}", schema["enum"]) + + # Handle anyOf unions + if "anyOf" in schema: + types: list[type | Any] = [] + for subschema in schema["anyOf"]: + # Special handling for dict-like objects in unions + if ( + subschema.get("type") == "object" + and not subschema.get("properties") + and 
subschema.get("additionalProperties") + ): + # This is a dict type, handle it directly + additional_props = subschema["additionalProperties"] + if additional_props is True: + types.append(dict[str, Any]) # type: ignore + else: + value_type = _schema_to_type(additional_props, schemas) + types.append(dict[str, value_type]) # type: ignore + else: + types.append(_schema_to_type(subschema, schemas)) + + # Check if one of the types is None (null) + has_null = type(None) in types + types = [t for t in types if t is not type(None)] + + if len(types) == 0: + return type(None) + elif len(types) == 1: + if has_null: + return types[0] | None # type: ignore + else: + return types[0] + else: + if has_null: + return Union[tuple(types + [type(None)])] # type: ignore # noqa: UP007 + else: + return Union[tuple(types)] # type: ignore # noqa: UP007 + + schema_type = schema.get("type") + if not schema_type: + return Any # type: ignore[return-value] + + if isinstance(schema_type, list): + # Create a copy of the schema for each type, but keep all constraints + types: list[type | Any] = [] + for t in schema_type: + type_schema = dict(schema) + type_schema["type"] = t + types.append(_schema_to_type(type_schema, schemas)) + has_null = type(None) in types + types = [t for t in types if t is not type(None)] + if has_null: + if len(types) == 1: + return types[0] | None # type: ignore + else: + return Union[tuple(types + [type(None)])] # type: ignore # noqa: UP007 + return Union[tuple(types)] # type: ignore # noqa: UP007 + + return _get_from_type_handler(schema, schemas)(schema) + + +def _sanitize_name(name: str) -> str: + """Convert string to valid Python identifier.""" + # Step 1: replace everything except [0-9a-zA-Z_] with underscores + cleaned = re.sub(r"[^0-9a-zA-Z_]", "_", name) + # Step 2: deduplicate underscores + cleaned = re.sub(r"__+", "_", cleaned) + # Step 3: if the first char of original name isn't a letter, prepend field_ + if not name or not re.match(r"[a-zA-Z]", name[0]): + cleaned = f"field_{cleaned}" + # Step 4: deduplicate again and strip trailing underscores + cleaned = re.sub(r"__+", "_", cleaned).strip("_") + return cleaned + + +def _get_default_value( + schema: dict[str, Any], + prop_name: str, + parent_default: dict[str, Any] | None = None, +) -> Any: + """Get default value with proper priority ordering. + 1. Value from parent's default if it exists + 2. Property's own default if it exists + 3. 
None + """ + if parent_default is not None and prop_name in parent_default: + return parent_default[prop_name] + return schema.get("default") + + +def _create_field_with_default( + field_type: type, + default_value: Any, + schema: dict[str, Any], +) -> Any: + """Create a field with simplified default handling.""" + # Always use None as default for complex types + if isinstance(default_value, dict | list) or default_value is None: + return field(default=None) + + # For simple types, use the value directly + return field(default=default_value) + + +def _create_pydantic_model( + schema: Mapping[str, Any], + name: str | None = None, + schemas: Mapping[str, Any] | None = None, +) -> type: + """Create Pydantic BaseModel from object schema with additionalProperties.""" + name = name or schema.get("title", "Root") + assert name is not None # Should not be None after the or operation + sanitized_name = _sanitize_name(name) + schema_hash = _hash_schema(schema) + cache_key = (schema_hash, sanitized_name) + + # Return existing class if already built + if cache_key in _classes: + existing = _classes[cache_key] + if existing is None: + return ForwardRef(sanitized_name) # type: ignore[return-value] + return existing + + # Place placeholder for recursive references + _classes[cache_key] = None + + properties = schema.get("properties", {}) + required = schema.get("required", []) + + # Build field annotations and defaults + annotations = {} + defaults = {} + + for prop_name, prop_schema in properties.items(): + field_type = _schema_to_type(prop_schema, schemas or {}) + + # Handle defaults + default_value = prop_schema.get("default", MISSING) + if default_value is not MISSING: + defaults[prop_name] = default_value + annotations[prop_name] = field_type + elif prop_name in required: + annotations[prop_name] = field_type + else: + annotations[prop_name] = Union[field_type, type(None)] # type: ignore[misc] # noqa: UP007 + defaults[prop_name] = None + + # Create Pydantic model class + cls_dict = { + "__annotations__": annotations, + "model_config": ConfigDict(extra="allow"), + **defaults, + } + + cls = type(sanitized_name, (BaseModel,), cls_dict) + + # Store completed class + _classes[cache_key] = cls + return cls + + +def _create_dataclass( + schema: Mapping[str, Any], + name: str | None = None, + schemas: Mapping[str, Any] | None = None, +) -> type: + """Create dataclass from object schema.""" + name = name or schema.get("title", "Root") + # Sanitize name for class creation + assert name is not None # Should not be None after the or operation + sanitized_name = _sanitize_name(name) + schema_hash = _hash_schema(schema) + cache_key = (schema_hash, sanitized_name) + original_schema = dict(schema) # Store copy for validator + + # Return existing class if already built + if cache_key in _classes: + existing = _classes[cache_key] + if existing is None: + return ForwardRef(sanitized_name) # type: ignore[return-value] + return existing + + # Place placeholder for recursive references + _classes[cache_key] = None + + if "$ref" in schema: + ref = schema["$ref"] + if ref == "#": + return ForwardRef(sanitized_name) # type: ignore[return-value] + schema = _resolve_ref(ref, schemas or {}) + + properties = schema.get("properties", {}) + required = schema.get("required", []) + + fields: list[tuple[Any, ...]] = [] + for prop_name, prop_schema in properties.items(): + field_name = _sanitize_name(prop_name) + + # Check for self-reference in property + if prop_schema.get("$ref") == "#": + field_type = ForwardRef(sanitized_name) + 
else: + field_type = _schema_to_type(prop_schema, schemas or {}) + + default_val = prop_schema.get("default", MISSING) + is_required = prop_name in required + + # Include alias in field metadata + meta = {"alias": prop_name} + + if default_val is not MISSING: + if isinstance(default_val, dict | list): + field_def = field( + default_factory=lambda d=default_val: deepcopy(d), metadata=meta + ) + else: + field_def = field(default=default_val, metadata=meta) + else: + if is_required: + field_def = field(metadata=meta) + else: + field_def = field(default=None, metadata=meta) + + if is_required and default_val is not MISSING: + fields.append((field_name, field_type, field_def)) + elif is_required: + fields.append((field_name, field_type, field_def)) + else: + fields.append((field_name, Union[field_type, type(None)], field_def)) # type: ignore[misc] # noqa: UP007 + + cls = make_dataclass(sanitized_name, fields, kw_only=True) + + # Add model validator for defaults + @model_validator(mode="before") + @classmethod + def _apply_defaults(cls, data: Mapping[str, Any]): + if isinstance(data, dict): + return _merge_defaults(data, original_schema) + return data + + setattr(cls, "_apply_defaults", _apply_defaults) + + # Store completed class + _classes[cache_key] = cls + return cls + + +def _merge_defaults( + data: Mapping[str, Any], + schema: Mapping[str, Any], + parent_default: Mapping[str, Any] | None = None, +) -> dict[str, Any]: + """Merge defaults with provided data at all levels.""" + # If we have no data + if not data: + # Start with parent default if available + if parent_default: + result = dict(parent_default) + # Otherwise use schema default if available + elif "default" in schema: + result = dict(schema["default"]) + # Otherwise start empty + else: + result = {} + # If we have data and a parent default, merge them + elif parent_default: + result = dict(parent_default) + for key, value in data.items(): + if ( + isinstance(value, dict) + and key in result + and isinstance(result[key], dict) + ): + # recursively merge nested dicts + result[key] = _merge_defaults(value, {"properties": {}}, result[key]) + else: + result[key] = value + # Otherwise just use the data + else: + result = dict(data) + + # For each property in the schema + for prop_name, prop_schema in schema.get("properties", {}).items(): + # If property is missing, apply defaults in priority order + if prop_name not in result: + if parent_default and prop_name in parent_default: + result[prop_name] = parent_default[prop_name] + elif "default" in prop_schema: + result[prop_name] = prop_schema["default"] + + # If property exists and is an object, recursively merge + if ( + prop_name in result + and isinstance(result[prop_name], dict) + and prop_schema.get("type") == "object" + ): + # Get the appropriate default for this nested object + nested_default = None + if parent_default and prop_name in parent_default: + nested_default = parent_default[prop_name] + elif "default" in prop_schema: + nested_default = prop_schema["default"] + + result[prop_name] = _merge_defaults( + result[prop_name], prop_schema, nested_default + ) + + return result diff --git a/src/fastmcp/utilities/types.py b/src/fastmcp/utilities/types.py index 8c65bd82c..919f03abd 100644 --- a/src/fastmcp/utilities/types.py +++ b/src/fastmcp/utilities/types.py @@ -6,21 +6,19 @@ from collections.abc import Callable from functools import lru_cache from pathlib import Path -from types import UnionType -from typing import Annotated, TypeVar, Union, get_args, get_origin - -from 
mcp.types import ( - Annotations, - AudioContent, - BlobResourceContents, - EmbeddedResource, - ImageContent, - TextResourceContents, -) +from types import EllipsisType, UnionType +from typing import Annotated, TypeAlias, TypeVar, Union, get_args, get_origin + +import mcp.types +from mcp.types import Annotations from pydantic import AnyUrl, BaseModel, ConfigDict, TypeAdapter, UrlConstraints T = TypeVar("T") +# sentinel values for optional arguments +NotSet = ... +NotSetT: TypeAlias = EllipsisType + class FastMCPBaseModel(BaseModel): """Base model for FastMCP models.""" @@ -129,7 +127,7 @@ def to_image_content( self, mime_type: str | None = None, annotations: Annotations | None = None, - ) -> ImageContent: + ) -> mcp.types.ImageContent: """Convert to MCP ImageContent.""" if self.path: with open(self.path, "rb") as f: @@ -139,7 +137,7 @@ def to_image_content( else: raise ValueError("No image data available") - return ImageContent( + return mcp.types.ImageContent( type="image", data=data, mimeType=mime_type or self._mime_type, @@ -188,7 +186,7 @@ def to_audio_content( self, mime_type: str | None = None, annotations: Annotations | None = None, - ) -> AudioContent: + ) -> mcp.types.AudioContent: if self.path: with open(self.path, "rb") as f: data = base64.b64encode(f.read()).decode() @@ -197,7 +195,7 @@ def to_audio_content( else: raise ValueError("No audio data available") - return AudioContent( + return mcp.types.AudioContent( type="audio", data=data, mimeType=mime_type or self._mime_type, @@ -248,7 +246,7 @@ def to_resource_content( self, mime_type: str | None = None, annotations: Annotations | None = None, - ) -> EmbeddedResource: + ) -> mcp.types.EmbeddedResource: if self.path: with open(self.path, "rb") as f: raw_data = f.read() @@ -271,21 +269,57 @@ def to_resource_content( text = raw_data.decode("utf-8") except UnicodeDecodeError: text = raw_data.decode("latin-1") - resource = TextResourceContents( + resource = mcp.types.TextResourceContents( text=text, mimeType=mime, uri=uri, ) else: data = base64.b64encode(raw_data).decode() - resource = BlobResourceContents( + resource = mcp.types.BlobResourceContents( blob=data, mimeType=mime, uri=uri, ) - return EmbeddedResource( + return mcp.types.EmbeddedResource( type="resource", resource=resource, annotations=annotations or self.annotations, ) + + +def replace_type(type_, type_map: dict[type, type]): + """ + Given a (possibly generic, nested, or otherwise complex) type, replaces all + instances of old_type with new_type. + + This is useful for transforming types when creating tools. + + Args: + type_: The type to replace instances of old_type with new_type. + old_type: The type to replace. + new_type: The type to replace old_type with. 
+ + Examples: + >>> replace_type(list[int | bool], {int: str}) + list[str | bool] + + >>> replace_type(list[list[int]], {int: str}) + list[list[str]] + + """ + if type_ in type_map: + return type_map[type_] + + origin = get_origin(type_) + if not origin: + return type_ + + args = get_args(type_) + new_args = tuple(replace_type(arg, type_map) for arg in args) + + if origin is UnionType: + return Union[new_args] # type: ignore # noqa: UP007 + else: + return origin[new_args] diff --git a/tests/auth/test_oauth_client.py b/tests/auth/test_oauth_client.py index cb7204818..c12f80a1d 100644 --- a/tests/auth/test_oauth_client.py +++ b/tests/auth/test_oauth_client.py @@ -226,7 +226,9 @@ async def test_call_tool(client_with_headless_oauth: Client): """Test that we can call a tool.""" async with client_with_headless_oauth: result = await client_with_headless_oauth.call_tool("add", {"a": 5, "b": 3}) - assert result[0].text == "8" # type: ignore[attr-defined] + # The add tool returns int which gets wrapped as structured output + # Client unwraps it and puts the actual int in the data field + assert result.data == 8 async def test_list_resources(client_with_headless_oauth: Client): diff --git a/tests/client/test_client.py b/tests/client/test_client.py index 55210c432..499a3f256 100644 --- a/tests/client/test_client.py +++ b/tests/client/test_client.py @@ -121,9 +121,10 @@ async def test_call_tool(fastmcp_server): async with client: result = await client.call_tool("greet", {"name": "World"}) - # The result content should contain our greeting - content_str = str(result[0]) - assert "Hello, World!" in content_str + assert result.content[0].text == "Hello, World!" # type: ignore[attr-defined] + assert result.structured_content == {"result": "Hello, World!"} + assert result.data == "Hello, World!" 
+ assert result.is_error is False async def test_call_tool_mcp(fastmcp_server): diff --git a/tests/client/test_notifications.py b/tests/client/test_notifications.py index 0c833adf2..62a5ea283 100644 --- a/tests/client/test_notifications.py +++ b/tests/client/test_notifications.py @@ -126,7 +126,7 @@ async def test_tool_enable_sends_notification( # Enable the target tool result = await client.call_tool("enable_target_tool", {}) - assert result[0].text == "Target tool enabled" # type: ignore[attr-defined] + assert result.data == "Target tool enabled" # Check that notification was sent recording_message_handler.assert_notification_sent( @@ -147,7 +147,7 @@ async def test_tool_disable_sends_notification( # Disable the target tool result = await client.call_tool("disable_target_tool", {}) - assert result[0].text == "Target tool disabled" # type: ignore[attr-defined] + assert result.data == "Target tool disabled" # Check that notification was sent recording_message_handler.assert_notification_sent( @@ -231,7 +231,7 @@ async def test_resource_enable_sends_notification( # Enable the target resource result = await client.call_tool("enable_target_resource", {}) - assert result[0].text == "Target resource enabled" # type: ignore[attr-defined] + assert result.data == "Target resource enabled" # Check that notification was sent recording_message_handler.assert_notification_sent( @@ -252,7 +252,7 @@ async def test_resource_disable_sends_notification( # Disable the target resource result = await client.call_tool("disable_target_resource", {}) - assert result[0].text == "Target resource disabled" # type: ignore[attr-defined] + assert result.data == "Target resource disabled" # Check that notification was sent recording_message_handler.assert_notification_sent( @@ -313,7 +313,7 @@ async def test_prompt_enable_sends_notification( # Enable the target prompt result = await client.call_tool("enable_target_prompt", {}) - assert result[0].text == "Target prompt enabled" # type: ignore[attr-defined] + assert result.data == "Target prompt enabled" # Check that notification was sent recording_message_handler.assert_notification_sent( @@ -334,7 +334,7 @@ async def test_prompt_disable_sends_notification( # Disable the target prompt result = await client.call_tool("disable_target_prompt", {}) - assert result[0].text == "Target prompt disabled" # type: ignore[attr-defined] + assert result.data == "Target prompt disabled" # Check that notification was sent recording_message_handler.assert_notification_sent( diff --git a/tests/client/test_openapi.py b/tests/client/test_openapi.py index 6f662e927..04ba6c123 100644 --- a/tests/client/test_openapi.py +++ b/tests/client/test_openapi.py @@ -118,7 +118,7 @@ async def test_client_headers_sse_tool(self, sse_server: str): transport=SSETransport(sse_server, headers={"X-TEST": "test-123"}) ) as client: result = await client.call_tool("post_headers_headers_post") - headers = json.loads(result[0].text) # type: ignore[attr-defined] + headers: dict[str, str] = result.data assert headers["x-test"] == "test-123" async def test_client_headers_shttp_tool(self, shttp_server: str): @@ -128,7 +128,7 @@ async def test_client_headers_shttp_tool(self, shttp_server: str): ) ) as client: result = await client.call_tool("post_headers_headers_post") - headers = json.loads(result[0].text) # type: ignore[attr-defined] + headers: dict[str, str] = result.data assert headers["x-test"] == "test-123" async def test_client_overrides_server_headers(self, shttp_server: str): diff --git 
a/tests/client/test_roots.py b/tests/client/test_roots.py index f4df827de..d3bc7d5ca 100644 --- a/tests/client/test_roots.py +++ b/tests/client/test_roots.py @@ -1,5 +1,3 @@ -import json - import pytest from fastmcp import Client, Context, FastMCP @@ -40,7 +38,7 @@ async def test_invalid_urls(self, fastmcp_server: FastMCP, roots: list[str]): async def test_valid_roots(self, fastmcp_server: FastMCP, roots: list[str]): async with Client(fastmcp_server, roots=roots) as client: result = await client.call_tool("list_roots", {}) - assert json.loads(result[0].text) == [ # type: ignore[attr-defined] + assert result.data == [ "file://x/y/z", "file://x/y/z", ] diff --git a/tests/client/test_sampling.py b/tests/client/test_sampling.py index 497aa8513..5b11b0885 100644 --- a/tests/client/test_sampling.py +++ b/tests/client/test_sampling.py @@ -47,8 +47,7 @@ def sampling_handler( async with Client(fastmcp_server, sampling_handler=sampling_handler) as client: result = await client.call_tool("simple_sample", {"message": "Hello, world!"}) - reply = cast(TextContent, result[0]) - assert reply.text == "This is the sample message!" + assert result.data == "This is the sample message!" async def test_sampling_with_system_prompt(fastmcp_server: FastMCP): @@ -62,8 +61,7 @@ def sampling_handler( result = await client.call_tool( "sample_with_system_prompt", {"message": "Hello, world!"} ) - reply = cast(TextContent, result[0]) - assert reply.text == "You love FastMCP" + assert result.data == "You love FastMCP" async def test_sampling_with_messages(fastmcp_server: FastMCP): @@ -81,5 +79,4 @@ def sampling_handler( result = await client.call_tool( "sample_with_messages", {"message": "Hello, world!"} ) - reply = cast(TextContent, result[0]) - assert reply.text == "I need to think." + assert result.data == "I need to think." 
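The test updates above and below exercise the new structured-result handling, so a short standalone sketch may help illustrate the `json_schema_to_type` utility introduced earlier in this patch (`src/fastmcp/utilities/json_schema_type.py`). The schema, the `WeatherReport` name, the field names, and the use of Pydantic's `TypeAdapter` for validation are illustrative assumptions for this sketch, not part of the patch itself:

```python
from fastmcp.utilities.json_schema_type import json_schema_to_type
from pydantic import TypeAdapter

# Hypothetical output schema a tool might advertise (assumed for illustration)
schema = {
    "type": "object",
    "title": "WeatherReport",
    "properties": {
        "temperature": {"type": "number"},
        "timestamp": {"type": "string", "format": "date-time"},
    },
    "required": ["temperature", "timestamp"],
}

# Build a Python type (typically a dataclass for object schemas) from the schema
WeatherReport = json_schema_to_type(schema)

# Validate raw structured JSON into that type; the "date-time" field is
# coerced to a real datetime object by Pydantic during validation
report = TypeAdapter(WeatherReport).validate_python(
    {"temperature": 20.5, "timestamp": "2024-01-15T14:30:00Z"}
)
print(report.temperature, report.timestamp)
```

Under those assumptions, this is the kind of schema-driven hydration that the `result.data` assertions in these tests depend on.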
diff --git a/tests/client/test_stdio.py b/tests/client/test_stdio.py index d9f9247d8..0ccac8f5d 100644 --- a/tests/client/test_stdio.py +++ b/tests/client/test_stdio.py @@ -48,11 +48,11 @@ async def test_keep_alive_maintains_session_across_multiple_calls( async with client: result1 = await client.call_tool("pid") - pid1 = int(result1[0].text) # type: ignore[attr-defined] + pid1: int = result1.data async with client: result2 = await client.call_tool("pid") - pid2 = int(result2[0].text) # type: ignore[attr-defined] + pid2: int = result2.data assert pid1 == pid2 @@ -66,11 +66,11 @@ async def test_keep_alive_false_starts_new_session_across_multiple_calls( async with client: result1 = await client.call_tool("pid") - pid1 = int(result1[0].text) # type: ignore[attr-defined] + pid1: int = result1.data async with client: result2 = await client.call_tool("pid") - pid2 = int(result2[0].text) # type: ignore[attr-defined] + pid2: int = result2.data assert pid1 != pid2 @@ -80,13 +80,13 @@ async def test_keep_alive_starts_new_session_if_manually_closed(self, stdio_scri async with client: result1 = await client.call_tool("pid") - pid1 = int(result1[0].text) # type: ignore[attr-defined] + pid1: int = result1.data await client.close() async with client: result2 = await client.call_tool("pid") - pid2 = int(result2[0].text) # type: ignore[attr-defined] + pid2: int = result2.data assert pid1 != pid2 @@ -96,14 +96,14 @@ async def test_keep_alive_maintains_session_if_reentered(self, stdio_script): async with client: result1 = await client.call_tool("pid") - pid1 = int(result1[0].text) # type: ignore[attr-defined] + pid1: int = result1.data async with client: result2 = await client.call_tool("pid") - pid2 = int(result2[0].text) # type: ignore[attr-defined] + pid2: int = result2.data result3 = await client.call_tool("pid") - pid3 = int(result3[0].text) # type: ignore[attr-defined] + pid3: int = result3.data assert pid1 == pid2 == pid3 diff --git a/tests/client/test_streamable_http.py b/tests/client/test_streamable_http.py index efcb79f16..7bc7b4e7d 100644 --- a/tests/client/test_streamable_http.py +++ b/tests/client/test_streamable_http.py @@ -7,7 +7,6 @@ import pytest import uvicorn from mcp import McpError -from mcp.types import TextContent from starlette.applications import Starlette from starlette.routing import Mount @@ -166,10 +165,7 @@ async def test_greet_with_progress_tool(streamable_http_server: str): progress_handler=progress_handler, ) as client: result = await client.call_tool("greet_with_progress", {"name": "Alice"}) - - assert isinstance(result, list) - assert isinstance(result[0], TextContent) - assert result[0].text == "Hello, Alice!" + assert result.data == "Hello, Alice!" 
progress_handler.assert_called_once_with(0.5, 1.0, "Greeting in progress") diff --git a/tests/contrib/test_bulk_tool_caller.py b/tests/contrib/test_bulk_tool_caller.py index 58aebb762..578dcba96 100644 --- a/tests/contrib/test_bulk_tool_caller.py +++ b/tests/contrib/test_bulk_tool_caller.py @@ -59,7 +59,10 @@ async def no_return_tool(arg1: str) -> None: def no_return_tool_result_factory(arg1: str) -> CallToolRequestResult: """A tool that returns a result based on the input arguments.""" return CallToolRequestResult( - isError=False, content=[], tool="no_return_tool", arguments={"arg1": arg1} + isError=False, + content=[], + tool="no_return_tool", + arguments={"arg1": arg1}, ) diff --git a/tests/deprecated/test_mount_import_arg_order.py b/tests/deprecated/test_mount_import_arg_order.py index 7fc273b36..b65b75c30 100644 --- a/tests/deprecated/test_mount_import_arg_order.py +++ b/tests/deprecated/test_mount_import_arg_order.py @@ -36,7 +36,7 @@ def sub_tool() -> str: # Test functionality async with Client(main_app) as client: result = await client.call_tool("sub_sub_tool", {}) - assert result[0].text == "Sub tool result" # type: ignore[attr-defined] + assert result.data == "Sub tool result" async def test_mount_new_arg_order_no_warning(self): """Test that mount(server, prefix) works without deprecation warning.""" @@ -122,7 +122,7 @@ def sub_tool() -> str: # Test functionality async with Client(main_app) as client: result = await client.call_tool("sub_sub_tool", {}) - assert result[0].text == "Sub tool result" # type: ignore[attr-defined] + assert result.data == "Sub tool result" async def test_import_new_arg_order_no_warning(self): """Test that import_server(server, prefix) works without deprecation warning.""" diff --git a/tests/server/http/test_http_dependencies.py b/tests/server/http/test_http_dependencies.py index ff88e1e7a..32aa87588 100644 --- a/tests/server/http/test_http_dependencies.py +++ b/tests/server/http/test_http_dependencies.py @@ -86,9 +86,8 @@ async def test_http_headers_tool_shttp(shttp_server: str): ) ) as client: result = await client.call_tool("get_headers_tool") - json_result = json.loads(result[0].text) # type: ignore[attr-defined] - assert "x-demo-header" in json_result - assert json_result["x-demo-header"] == "ABC" + assert "x-demo-header" in result.data + assert result.data["x-demo-header"] == "ABC" async def test_http_headers_tool_sse(sse_server: str): @@ -96,9 +95,8 @@ async def test_http_headers_tool_sse(sse_server: str): transport=SSETransport(sse_server, headers={"X-DEMO-HEADER": "ABC"}) ) as client: result = await client.call_tool("get_headers_tool") - json_result = json.loads(result[0].text) # type: ignore[attr-defined] - assert "x-demo-header" in json_result - assert json_result["x-demo-header"] == "ABC" + assert "x-demo-header" in result.data + assert result.data["x-demo-header"] == "ABC" async def test_http_headers_prompt_shttp(shttp_server: str): diff --git a/tests/server/openapi/test_openapi.py b/tests/server/openapi/test_openapi.py index 5f1583fcf..f8d5013c6 100644 --- a/tests/server/openapi/test_openapi.py +++ b/tests/server/openapi/test_openapi.py @@ -269,9 +269,8 @@ async def test_call_create_user_tool( "create_user_users_post", {"name": "David", "active": False} ) - response_data = json.loads(tool_response[0].text) # type: ignore[attr-defined] expected_user = User(id=4, name="David", active=False).model_dump() - assert response_data == expected_user + assert tool_response.data == expected_user # Check that the user was created via API response = 
await api_client.get("/users") @@ -298,9 +297,8 @@ async def test_call_update_user_name_tool( {"user_id": 1, "name": "XYZ"}, ) - response_data = json.loads(tool_response[0].text) # type: ignore[attr-defined] expected_data = dict(id=1, name="XYZ", active=True) - assert response_data == expected_data + assert tool_response.data == expected_data # Check that the user was updated via API response = await api_client.get("/users") @@ -332,10 +330,12 @@ async def test_call_tool_return_list( ) async with Client(mcp_server) as client: tool_response = await client.call_tool("get_users_users_get", {}) - assert json.loads(tool_response[0].text) == [ # type: ignore[attr-defined] - user.model_dump() - for user in sorted(users_db.values(), key=lambda x: x.id) - ] + assert tool_response.data == { + "result": [ + user.model_dump() + for user in sorted(users_db.values(), key=lambda x: x.id) + ] + } class TestResources: @@ -729,12 +729,22 @@ async def test_tool_execution(self, openapi_30_server_with_all_types): "createProduct", {"name": "New Product", "price": 39.99} ) # Result should be a text content - assert len(result) == 1 - product = json.loads(result[0].text) # type: ignore[attr-defined] + assert len(result.content) == 1 + product = json.loads(result.content[0].text) # type: ignore[attr-defined] assert product["id"] == "p3" assert product["name"] == "New Product" assert product["price"] == 39.99 + assert result.structured_content is not None + assert result.structured_content["id"] == "p3" + assert result.structured_content["name"] == "New Product" + assert result.structured_content["price"] == 39.99 + + assert result.data is not None + assert result.data["id"] == "p3" + assert result.data["name"] == "New Product" + assert result.data["price"] == 39.99 + class TestOpenAPI31Compatibility: """Tests for compatibility with OpenAPI 3.1 specifications.""" @@ -905,12 +915,22 @@ async def test_tool_execution(self, openapi_31_server_with_all_types): "createOrder", {"customer": "Charlie", "items": ["item4", "item5"]} ) # Result should be a text content - assert len(result) == 1 - order = json.loads(result[0].text) # type: ignore[attr-dict] + assert len(result.content) == 1 + order = json.loads(result.content[0].text) # type: ignore[attr-dict] assert order["id"] == "o3" assert order["customer"] == "Charlie" assert order["items"] == ["item4", "item5"] + assert result.structured_content is not None + assert result.structured_content["id"] == "o3" + assert result.structured_content["customer"] == "Charlie" + assert result.structured_content["items"] == ["item4", "item5"] + + assert result.data is not None + assert result.data["id"] == "o3" + assert result.data["customer"] == "Charlie" + assert result.data["items"] == ["item4", "item5"] + async def test_empty_query_parameters_not_sent( fastapi_app: FastAPI, api_client: httpx.AsyncClient diff --git a/tests/server/openapi/test_openapi_path_parameters.py b/tests/server/openapi/test_openapi_path_parameters.py index 7f977ea50..075ad7045 100644 --- a/tests/server/openapi/test_openapi_path_parameters.py +++ b/tests/server/openapi/test_openapi_path_parameters.py @@ -301,20 +301,11 @@ async def select_days( # Single day result = await client.call_tool(tool_name, {"days": ["monday"]}) - # Client returns TextContent objects, so parse the JSON - assert len(result) == 1 - assert result[0].type == "text" - import json - - result_data = json.loads(result[0].text) - assert result_data == {"selected": ["monday"]} + assert result.data == {"selected": ["monday"]} # Multiple days 
result = await client.call_tool(tool_name, {"days": ["monday", "tuesday"]}) - assert len(result) == 1 - assert result[0].type == "text" - result_data = json.loads(result[0].text) - assert result_data == {"selected": ["monday", "tuesday"]} + assert result.data == {"selected": ["monday", "tuesday"]} async def test_array_query_parameter_format(mock_client): diff --git a/tests/server/test_import_server.py b/tests/server/test_import_server.py index 958fbdc83..1f8f6611c 100644 --- a/tests/server/test_import_server.py +++ b/tests/server/test_import_server.py @@ -224,7 +224,7 @@ def fetch_data(query: str) -> str: async with Client(main_app) as client: result = await client.call_tool("api_get_data", {"query": "test"}) - assert result[0].text == "Data for query: test" # type: ignore[attr-defined] + assert result.data == "Data for query: test" async def test_first_level_importing_with_custom_name(): @@ -278,7 +278,7 @@ def calculate_value(input: int) -> int: async with Client(main_app) as client: result = await client.call_tool("service_provider_compute", {"input": 21}) - assert result[0].text == "42" # type: ignore[attr-defined] + assert result.data == 42 async def test_import_with_proxy_tools(): @@ -302,7 +302,7 @@ def get_data(query: str) -> str: async with Client(main_app) as client: result = await client.call_tool("api_get_data", {"query": "test"}) - assert result[0].text == "Data for query: test" # type: ignore[attr-defined] + assert result.data == "Data for query: test" async def test_import_with_proxy_prompts(): @@ -443,7 +443,7 @@ def sub_prompt() -> str: async with Client(main_app) as client: # Test tool tool_result = await client.call_tool("sub_tool", {}) - assert tool_result[0].text == "Sub tool result" # type: ignore[attr-defined] + assert tool_result.data == "Sub tool result" # Test resource resource_result = await client.read_resource("data://config") @@ -485,7 +485,7 @@ def second_shared_tool() -> str: assert tool_names.count("shared_tool") == 1 # Should only appear once result = await client.call_tool("shared_tool", {}) - assert result[0].text == "Second app tool" # type: ignore[attr-defined] + assert result.data == "Second app tool" async def test_import_conflict_resolution_resources(): @@ -604,4 +604,4 @@ def second_shared_tool() -> str: assert tool_names.count("api_shared_tool") == 1 # Should only appear once result = await client.call_tool("api_shared_tool", {}) - assert result[0].text == "Second app tool" # type: ignore[attr-defined] + assert result.data == "Second app tool" diff --git a/tests/server/test_mount.py b/tests/server/test_mount.py index bb1d65d8f..942809c4e 100644 --- a/tests/server/test_mount.py +++ b/tests/server/test_mount.py @@ -33,7 +33,7 @@ def sub_tool() -> str: async with Client(main_app) as client: result = await client.call_tool("sub_sub_tool", {}) - assert result[0].text == "This is from the sub app" # type: ignore[attr-defined] + assert result.data == "This is from the sub app" async def test_mount_with_custom_separator(self): """Test mounting with a custom tool separator (deprecated but still supported).""" @@ -52,8 +52,9 @@ def greet(name: str) -> str: assert "sub_greet" in tools # Call the tool - result = await main_app._mcp_call_tool("sub_greet", {"name": "World"}) - assert result[0].text == "Hello, World!" # type: ignore[attr-defined] + async with Client(main_app) as client: + result = await client.call_tool("sub_greet", {"name": "World"}) + assert result.data == "Hello, World!" 
async def test_mount_invalid_resource_prefix(self): main_app = FastMCP("MainApp") @@ -104,8 +105,9 @@ def sub_tool() -> str: assert "sub_tool" in tools # Call the tool to verify it works - result = await main_app._mcp_call_tool("sub_tool", {}) - assert result[0].text == "This is from the sub app" # type: ignore[attr-defined] + async with Client(main_app) as client: + result = await client.call_tool("sub_tool", {}) + assert result.data == "This is from the sub app" async def test_mount_tools_no_prefix(self): """Test mounting a server with tools without prefix.""" @@ -124,8 +126,9 @@ def sub_tool() -> str: assert "sub_tool" in tools # Test actual functionality - tool_result = await main_app._mcp_call_tool("sub_tool", {}) - assert tool_result[0].text == "Sub tool result" # type: ignore[attr-defined] + async with Client(main_app) as client: + tool_result = await client.call_tool("sub_tool", {}) + assert tool_result.data == "Sub tool result" async def test_mount_resources_no_prefix(self): """Test mounting a server with resources without prefix.""" @@ -144,8 +147,9 @@ def sub_resource(): assert "data://config" in resources # Test actual functionality - resource_result = await main_app._mcp_read_resource("data://config") - assert resource_result[0].content == "Sub resource data" # type: ignore[attr-defined] + async with Client(main_app) as client: + resource_result = await client.read_resource("data://config") + assert resource_result[0].text == "Sub resource data" # type: ignore[attr-defined] async def test_mount_resource_templates_no_prefix(self): """Test mounting a server with resource templates without prefix.""" @@ -164,8 +168,9 @@ def sub_template(user_id: str): assert "users://{user_id}/info" in templates # Test actual functionality - template_result = await main_app._mcp_read_resource("users://123/info") - assert template_result[0].content == "Sub template for user 123" # type: ignore[attr-defined] + async with Client(main_app) as client: + template_result = await client.read_resource("users://123/info") + assert template_result[0].text == "Sub template for user 123" # type: ignore[attr-defined] async def test_mount_prompts_no_prefix(self): """Test mounting a server with prompts without prefix.""" @@ -184,8 +189,9 @@ def sub_prompt() -> str: assert "sub_prompt" in prompts # Test actual functionality - prompt_result = await main_app._mcp_get_prompt("sub_prompt", {}) - assert prompt_result.messages is not None + async with Client(main_app) as client: + prompt_result = await client.get_prompt("sub_prompt", {}) + assert prompt_result.messages is not None class TestMultipleServerMount: @@ -215,11 +221,11 @@ def get_headlines() -> str: assert "news_get_headlines" in tools # Call tools from both mounted servers - result1 = await main_app._mcp_call_tool("weather_get_forecast", {}) - assert result1[0].text == "Weather forecast" # type: ignore[attr-defined] - - result2 = await main_app._mcp_call_tool("news_get_headlines", {}) - assert result2[0].text == "News headlines" # type: ignore[attr-defined] + async with Client(main_app) as client: + result1 = await client.call_tool("weather_get_forecast", {}) + assert result1.data == "Weather forecast" + result2 = await client.call_tool("news_get_headlines", {}) + assert result2.data == "News headlines" async def test_mount_same_prefix(self): """Test that mounting with the same prefix replaces the previous mount.""" @@ -292,7 +298,7 @@ def working_prompt() -> str: # Test calling a tool result = await client.call_tool("working_working_tool", {}) - assert 
result[0].text == "Working tool" # type: ignore[attr-defined] + assert result.data == "Working tool" # Test resources resources = await client.list_resources() @@ -352,7 +358,7 @@ def second_shared_tool() -> str: # Test that calling the tool uses the later server's implementation result = await client.call_tool("shared_tool", {}) - assert result[0].text == "Second app tool" # type: ignore[attr-defined] + assert result.data == "Second app tool" async def test_later_server_wins_tools_same_prefix(self): """Test that later mounted server wins for tools when same prefix is used.""" @@ -381,7 +387,7 @@ def second_shared_tool() -> str: # Test that calling the tool uses the later server's implementation result = await client.call_tool("api_shared_tool", {}) - assert result[0].text == "Second app tool" # type: ignore[attr-defined] + assert result.data == "Second app tool" async def test_later_server_wins_resources_no_prefix(self): """Test that later mounted server wins for resources when no prefix is used.""" @@ -593,8 +599,9 @@ def dynamic_tool() -> str: assert "sub_dynamic_tool" in tools # Call the dynamically added tool - result = await main_app._mcp_call_tool("sub_dynamic_tool", {}) - assert result[0].text == "Added after mounting" # type: ignore[attr-defined] + async with Client(main_app) as client: + result = await client.call_tool("sub_dynamic_tool", {}) + assert result.data == "Added after mounting" async def test_removing_tool_after_mounting(self): """Test that tools removed from mounted servers are no longer accessible.""" @@ -726,8 +733,9 @@ def greeting(name: str) -> str: assert "assistant_greeting" in prompts # Render the prompt - result = await main_app._mcp_get_prompt("assistant_greeting", {"name": "World"}) - assert result.messages is not None + async with Client(main_app) as client: + result = await client.get_prompt("assistant_greeting", {"name": "World"}) + assert result.messages is not None # The message should contain our greeting text async def test_adding_prompt_after_mounting(self): @@ -748,8 +756,9 @@ def farewell(name: str) -> str: assert "assistant_farewell" in prompts # Render the prompt - result = await main_app._mcp_get_prompt("assistant_farewell", {"name": "World"}) - assert result.messages is not None + async with Client(main_app) as client: + result = await client.get_prompt("assistant_farewell", {"name": "World"}) + assert result.messages is not None # The message should contain our farewell text @@ -779,8 +788,9 @@ def get_data(query: str) -> str: assert "proxy_get_data" in tools # Call the tool - result = await main_app._mcp_call_tool("proxy_get_data", {"query": "test"}) - assert result[0].text == "Data for test" # type: ignore[attr-defined] + async with Client(main_app) as client: + result = await client.call_tool("proxy_get_data", {"query": "test"}) + assert result.data == "Data for test" async def test_dynamically_adding_to_proxied_server(self): """Test that changes to the original server are reflected in the mounted proxy.""" @@ -806,8 +816,9 @@ def dynamic_data() -> str: assert "proxy_dynamic_data" in tools # Call the tool - result = await main_app._mcp_call_tool("proxy_dynamic_data", {}) - assert result[0].text == "Dynamic data" # type: ignore[attr-defined] + async with Client(main_app) as client: + result = await client.call_tool("proxy_dynamic_data", {}) + assert result.data == "Dynamic data" async def test_proxy_server_with_resources(self): """Test mounting a proxy server with resources.""" @@ -828,9 +839,10 @@ def get_config(): 
main_app.mount(proxy_server, "proxy") # Resource should be accessible through main app - result = await main_app._mcp_read_resource("config://proxy/settings") - config = json.loads(result[0].content) # type: ignore[attr-defined] - assert config["api_key"] == "12345" + async with Client(main_app) as client: + result = await client.read_resource("config://proxy/settings") + config = json.loads(result[0].text) # type: ignore[attr-defined] + assert config["api_key"] == "12345" async def test_proxy_server_with_prompts(self): """Test mounting a proxy server with prompts.""" @@ -851,8 +863,9 @@ def welcome(name: str) -> str: main_app.mount(proxy_server, "proxy") # Prompt should be accessible through main app - result = await main_app._mcp_get_prompt("proxy_welcome", {"name": "World"}) - assert result.messages is not None + async with Client(main_app) as client: + result = await client.get_prompt("proxy_welcome", {"name": "World"}) + assert result.messages is not None # The message should contain our welcome text diff --git a/tests/server/test_proxy.py b/tests/server/test_proxy.py index 612f8bfc8..e63c18551 100644 --- a/tests/server/test_proxy.py +++ b/tests/server/test_proxy.py @@ -89,15 +89,17 @@ async def test_create_proxy(fastmcp_server): async def test_as_proxy_with_server(fastmcp_server): """FastMCP.as_proxy should accept a FastMCP instance.""" proxy = FastMCP.as_proxy(fastmcp_server) - result = await proxy._mcp_call_tool("greet", {"name": "Test"}) - assert result[0].text == "Hello, Test!" # type: ignore[attr-defined] + async with Client(proxy) as client: + result = await client.call_tool("greet", {"name": "Test"}) + assert result.data == "Hello, Test!" async def test_as_proxy_with_transport(fastmcp_server): """FastMCP.as_proxy should accept a ClientTransport.""" proxy = FastMCP.as_proxy(FastMCPTransport(fastmcp_server)) - result = await proxy._mcp_call_tool("greet", {"name": "Test"}) - assert result[0].text == "Hello, Test!" # type: ignore[attr-defined] + async with Client(proxy) as client: + result = await client.call_tool("greet", {"name": "Test"}) + assert result.data == "Hello, Test!" def test_as_proxy_with_url(): @@ -137,7 +139,7 @@ async def test_call_tool_result_same_as_original( async def test_call_tool_calls_tool(self, proxy_server): async with Client(proxy_server) as client: proxy_result = await client.call_tool("add", {"a": 1, "b": 2}) - assert proxy_result[0].text == "3" # type: ignore[attr-defined] + assert proxy_result.data == 3 async def test_error_tool_raises_error(self, proxy_server): with pytest.raises(ToolError, match="This is a test error"): @@ -155,7 +157,7 @@ def greet(name: str, extra: str = "extra") -> str: async with Client(proxy_server) as client: result = await client.call_tool("greet", {"name": "Marvin", "extra": "abc"}) - assert result[0].text == "Overwritten, Marvin! abc" # type: ignore[attr-defined] + assert result.data == "Overwritten, Marvin! abc" async def test_proxy_errors_if_overwritten_tool_is_disabled(self, proxy_server): """ diff --git a/tests/server/test_server.py b/tests/server/test_server.py index d255ad76c..9c08955c2 100644 --- a/tests/server/test_server.py +++ b/tests/server/test_server.py @@ -45,9 +45,7 @@ def hello_world(name: str = "世界") -> str: assert "🎉" in tool.description result = await client.call_tool("hello_world", {}) - assert len(result) == 1 - content = result[0] - assert content.text == "¡Hola, 世界! 👋" # type: ignore[attr-defined] + assert result.data == "¡Hola, 世界! 
👋" class TestTools: @@ -129,8 +127,9 @@ async def test_tool_decorator(self): def add(x: int, y: int) -> int: return x + y - result = await mcp._mcp_call_tool("add", {"x": 1, "y": 2}) - assert result[0].text == "3" # type: ignore[attr-defined] + async with Client(mcp) as client: + result = await client.call_tool("add", {"x": 1, "y": 2}) + assert result.data == 3 async def test_tool_decorator_without_parentheses(self): """Test that @tool decorator works without parentheses.""" @@ -146,8 +145,9 @@ def add(x: int, y: int) -> int: assert "add" in tools # Verify it can be called - result = await mcp._mcp_call_tool("add", {"x": 1, "y": 2}) - assert result[0].text == "3" # type: ignore[attr-defined] + async with Client(mcp) as client: + result = await client.call_tool("add", {"x": 1, "y": 2}) + assert result.data == 3 async def test_tool_decorator_with_name(self): mcp = FastMCP() @@ -156,8 +156,9 @@ async def test_tool_decorator_with_name(self): def add(x: int, y: int) -> int: return x + y - result = await mcp._mcp_call_tool("custom-add", {"x": 1, "y": 2}) - assert result[0].text == "3" # type: ignore[attr-defined] + async with Client(mcp) as client: + result = await client.call_tool("custom-add", {"x": 1, "y": 2}) + assert result.data == 3 async def test_tool_decorator_with_description(self): mcp = FastMCP() @@ -183,8 +184,9 @@ def add(self, y: int) -> int: obj = MyClass(10) mcp.add_tool(Tool.from_function(obj.add)) - result = await mcp._mcp_call_tool("add", {"y": 2}) - assert result[0].text == "12" # type: ignore[attr-defined] + async with Client(mcp) as client: + result = await client.call_tool("add", {"y": 2}) + assert result.data == 12 async def test_tool_decorator_classmethod(self): mcp = FastMCP() @@ -197,8 +199,9 @@ def add(cls, y: int) -> int: return cls.x + y mcp.add_tool(Tool.from_function(MyClass.add)) - result = await mcp._mcp_call_tool("add", {"y": 2}) - assert result[0].text == "12" # type: ignore[attr-defined] + async with Client(mcp) as client: + result = await client.call_tool("add", {"y": 2}) + assert result.data == 12 async def test_tool_decorator_staticmethod(self): mcp = FastMCP() @@ -209,8 +212,9 @@ class MyClass: def add(x: int, y: int) -> int: return x + y - result = await mcp._mcp_call_tool("add", {"x": 1, "y": 2}) - assert result[0].text == "3" # type: ignore[attr-defined] + async with Client(mcp) as client: + result = await client.call_tool("add", {"x": 1, "y": 2}) + assert result.data == 3 async def test_tool_decorator_async_function(self): mcp = FastMCP() @@ -219,8 +223,9 @@ async def test_tool_decorator_async_function(self): async def add(x: int, y: int) -> int: return x + y - result = await mcp._mcp_call_tool("add", {"x": 1, "y": 2}) - assert result[0].text == "3" # type: ignore[attr-defined] + async with Client(mcp) as client: + result = await client.call_tool("add", {"x": 1, "y": 2}) + assert result.data == 3 async def test_tool_decorator_classmethod_error(self): mcp = FastMCP() @@ -244,8 +249,9 @@ async def add(cls, y: int) -> int: return cls.x + y mcp.add_tool(Tool.from_function(MyClass.add)) - result = await mcp._mcp_call_tool("add", {"y": 2}) - assert result[0].text == "12" # type: ignore[attr-defined] + async with Client(mcp) as client: + result = await client.call_tool("add", {"y": 2}) + assert result.data == 12 async def test_tool_decorator_staticmethod_async_function(self): mcp = FastMCP() @@ -256,8 +262,9 @@ async def add(x: int, y: int) -> int: return x + y mcp.add_tool(Tool.from_function(MyClass.add)) - result = await mcp._mcp_call_tool("add", {"x": 1, 
"y": 2}) - assert result[0].text == "3" # type: ignore[attr-defined] + async with Client(mcp) as client: + result = await client.call_tool("add", {"x": 1, "y": 2}) + assert result.data == 3 async def test_tool_decorator_staticmethod_order(self): """Test that the recommended decorator order works for static methods""" @@ -270,8 +277,9 @@ def add_v1(x: int, y: int) -> int: return x + y # Test that the recommended order works - result = await mcp._mcp_call_tool("add_v1", {"x": 1, "y": 2}) - assert result[0].text == "3" # type: ignore[attr-defined] + async with Client(mcp) as client: + result = await client.call_tool("add_v1", {"x": 1, "y": 2}) + assert result.data == 3 async def test_tool_decorator_with_tags(self): """Test that the tool decorator properly sets tags.""" @@ -301,8 +309,9 @@ def multiply(a: int, b: int) -> int: assert "custom_multiply" in tools # Call the tool by its custom name - result = await mcp._mcp_call_tool("custom_multiply", {"a": 5, "b": 3}) - assert result[0].text == "15" # type: ignore[attr-defined] + async with Client(mcp) as client: + result = await client.call_tool("custom_multiply", {"a": 5, "b": 3}) + assert result.data == 15 # Original name should not be registered assert "multiply" not in tools @@ -356,8 +365,9 @@ def standalone_function(x: int, y: int) -> int: assert tools["direct_call_tool"] is result_fn # Verify it can be called - result = await mcp._mcp_call_tool("direct_call_tool", {"x": 5, "y": 3}) - assert result[0].text == "8" # type: ignore[attr-defined] + async with Client(mcp) as client: + result = await client.call_tool("direct_call_tool", {"x": 5, "y": 3}) + assert result.data == 8 async def test_tool_decorator_with_string_name(self): """Test that @tool("custom_name") syntax works correctly.""" @@ -374,8 +384,9 @@ def my_function(x: int) -> str: assert "my_function" not in tools # Original name should not be registered # Verify it can be called - result = await mcp._mcp_call_tool("string_named_tool", {"x": 42}) - assert result[0].text == "Result: 42" # type: ignore[attr-defined] + async with Client(mcp) as client: + result = await client.call_tool("string_named_tool", {"x": 42}) + assert result.data == "Result: 42" async def test_tool_decorator_conflicting_names_error(self): """Test that providing both positional and keyword name raises an error.""" @@ -390,6 +401,17 @@ async def test_tool_decorator_conflicting_names_error(self): def my_function(x: int) -> str: return f"Result: {x}" + async def test_tool_decorator_with_output_schema(self): + mcp = FastMCP() + + with pytest.raises( + ValueError, match='Output schemas must have "type" set to "object"' + ): + + @mcp.tool(output_schema={"type": "integer"}) + def my_function(x: int) -> str: + return f"Result: {x}" + class TestResourceDecorator: async def test_no_resources_before_decorator(self): diff --git a/tests/server/test_server_interactions.py b/tests/server/test_server_interactions.py index 36abc0c4b..f37434194 100644 --- a/tests/server/test_server_interactions.py +++ b/tests/server/test_server_interactions.py @@ -2,11 +2,11 @@ import datetime import json import uuid +from dataclasses import dataclass from enum import Enum from pathlib import Path -from typing import Annotated, Literal +from typing import Annotated, Any, Literal -import pydantic_core import pytest from mcp import McpError from mcp.types import ( @@ -17,7 +17,8 @@ TextContent, TextResourceContents, ) -from pydantic import AnyUrl, Field +from pydantic import AnyUrl, BaseModel, Field, TypeAdapter +from typing_extensions import 
TypedDict from fastmcp import Client, Context, FastMCP from fastmcp.client.transports import FastMCPTransport @@ -25,10 +26,26 @@ from fastmcp.prompts.prompt import Prompt, PromptMessage from fastmcp.resources import FileResource, ResourceTemplate from fastmcp.resources.resource import FunctionResource -from fastmcp.tools.tool import Tool +from fastmcp.tools.tool import Tool, ToolResult from fastmcp.utilities.types import Audio, File, Image +class PersonTypedDict(TypedDict): + name: str + age: int + + +class PersonModel(BaseModel): + name: str + age: int + + +@dataclass +class PersonDataclass: + name: str + age: int + + @pytest.fixture def tool_server(): mcp = FastMCP() @@ -72,7 +89,7 @@ def mixed_content_tool() -> list[TextContent | ImageContent | EmbeddedResource]: ), ] - @mcp.tool + @mcp.tool(output_schema=None) def mixed_list_fn(image_path: str) -> list: return [ "text message", @@ -81,7 +98,7 @@ def mixed_list_fn(image_path: str) -> list: TextContent(type="text", text="direct content"), ] - @mcp.tool + @mcp.tool(output_schema=None) def mixed_audio_list_fn(audio_path: str) -> list: return [ "text message", @@ -90,7 +107,7 @@ def mixed_audio_list_fn(audio_path: str) -> list: TextContent(type="text", text="direct content"), ] - @mcp.tool + @mcp.tool(output_schema=None) def mixed_file_list_fn(file_path: str) -> list: return [ "text message", @@ -117,26 +134,24 @@ async def test_list_tools(self, tool_server: FastMCP): async with Client(tool_server) as client: assert len(await client.list_tools()) == 11 - async def test_call_tool(self, tool_server: FastMCP): + async def test_call_tool_mcp(self, tool_server: FastMCP): async with Client(tool_server) as client: - result = await client.call_tool("add", {"x": 1, "y": 2}) - assert result[0].text == "3" # type: ignore[attr-defined] + result = await client.call_tool_mcp("add", {"x": 1, "y": 2}) + assert result.content[0].text == "3" # type: ignore[attr-defined] + assert result.structuredContent == {"result": 3} - async def test_call_tool_as_client(self, tool_server: FastMCP): + async def test_call_tool(self, tool_server: FastMCP): async with Client(tool_server) as client: result = await client.call_tool("add", {"x": 1, "y": 2}) - assert result[0].text == "3" # type: ignore[attr-defined] + assert result.content[0].text == "3" # type: ignore[attr-defined] + assert result.structured_content == {"result": 3} + assert result.data == 3 async def test_call_tool_error(self, tool_server: FastMCP): async with Client(tool_server) as client: with pytest.raises(Exception): await client.call_tool("error_tool", {}) - async def test_call_tool_error_as_client(self, tool_server: FastMCP): - async with Client(tool_server) as client: - with pytest.raises(Exception): - await client.call_tool("error_tool", {}) - async def test_call_tool_error_as_client_raw(self): """Test raising and catching errors from a tool.""" mcp = FastMCP() @@ -154,13 +169,14 @@ def error_tool(): async def test_tool_returns_list(self, tool_server: FastMCP): async with Client(tool_server) as client: result = await client.call_tool("list_tool", {}) - assert result[0].text == '[\n "x",\n 2\n]' # type: ignore[attr-defined] + assert result.content[0].text == '[\n "x",\n 2\n]' # type: ignore[attr-defined] + assert result.data == ["x", 2] async def test_file_text_tool(self, tool_server: FastMCP): async with Client(tool_server) as client: result = await client.call_tool("file_text_tool", {}) - assert len(result) == 1 - embedded = result[0] + assert len(result.content) == 1 + embedded = result.content[0] 
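# A minimal, illustrative sketch of the output_schema=None opt-out used by the
# mixed-content fixtures above, assuming an in-memory FastMCP server and client;
# the tool name and return value here are hypothetical, not part of this suite.
from fastmcp import Client, FastMCP

sketch_mcp = FastMCP()


@sketch_mcp.tool(output_schema=None)
def mixed() -> list:
    # A heterogeneous list has no useful object schema, so this tool opts out;
    # callers then receive content blocks but no structured content.
    return ["text message", {"key": "value"}]


async def sketch_opt_out() -> None:
    async with Client(sketch_mcp) as client:
        result = await client.call_tool("mixed", {})
        assert result.structured_content is None
        assert result.data is None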
assert isinstance(embedded, EmbeddedResource) resource = embedded.resource assert isinstance(resource, TextResourceContents) @@ -222,7 +238,7 @@ async def test_call_included_tool(self): async with Client(mcp) as client: result_1 = await client.call_tool("tool_1", {}) - assert result_1[0].text == "1" # type: ignore[attr-defined] + assert result_1.data == 1 with pytest.raises(ToolError, match="Unknown tool"): await client.call_tool("tool_2", {}) @@ -235,7 +251,7 @@ async def test_call_excluded_tool(self): await client.call_tool("tool_1", {}) result_2 = await client.call_tool("tool_2", {}) - assert result_2[0].text == "2" # type: ignore[attr-defined] + assert result_2.data == 2 class TestToolReturnTypes: @@ -248,7 +264,7 @@ def string_tool() -> str: async with Client(mcp) as client: result = await client.call_tool("string_tool", {}) - assert result[0].text == "Hello, world!" # type: ignore[attr-defined] + assert result.data == "Hello, world!" async def test_bytes(self, tmp_path: Path): mcp = FastMCP() @@ -259,7 +275,7 @@ def bytes_tool() -> bytes: async with Client(mcp) as client: result = await client.call_tool("bytes_tool", {}) - assert result[0].text == '"Hello, world!"' # type: ignore[attr-defined] + assert result.data == "Hello, world!" async def test_uuid(self): mcp = FastMCP() @@ -272,7 +288,7 @@ def uuid_tool() -> uuid.UUID: async with Client(mcp) as client: result = await client.call_tool("uuid_tool", {}) - assert result[0].text == pydantic_core.to_json(test_uuid).decode() # type: ignore[attr-defined] + assert result.data == str(test_uuid) async def test_path(self): mcp = FastMCP() @@ -285,7 +301,7 @@ def path_tool() -> Path: async with Client(mcp) as client: result = await client.call_tool("path_tool", {}) - assert result[0].text == pydantic_core.to_json(test_path).decode() # type: ignore[attr-defined] + assert result.data == str(test_path) async def test_datetime(self): mcp = FastMCP() @@ -298,7 +314,7 @@ def datetime_tool() -> datetime.datetime: async with Client(mcp) as client: result = await client.call_tool("datetime_tool", {}) - assert result[0].text == pydantic_core.to_json(dt).decode() # type: ignore[attr-defined] + assert result.data == dt async def test_image(self, tmp_path: Path): mcp = FastMCP() @@ -313,7 +329,8 @@ def image_tool(path: str) -> Image: async with Client(mcp) as client: result = await client.call_tool("image_tool", {"path": str(image_path)}) - content = result[0] + assert result.structured_content is None + content = result.content[0] assert isinstance(content, ImageContent) assert content.type == "image" assert content.mimeType == "image/png" @@ -334,7 +351,7 @@ def audio_tool(path: str) -> Audio: async with Client(mcp) as client: result = await client.call_tool("audio_tool", {"path": str(audio_path)}) - content = result[0] + content = result.content[0] assert isinstance(content, AudioContent) assert content.type == "audio" assert content.mimeType == "audio/wav" @@ -355,7 +372,7 @@ def file_tool(path: str) -> File: async with Client(mcp) as client: result = await client.call_tool("file_tool", {"path": str(file_path)}) - content = result[0] + content = result.content[0] assert isinstance(content, EmbeddedResource) assert content.type == "resource" resource = content.resource @@ -371,10 +388,10 @@ def file_tool(path: str) -> File: async def test_tool_mixed_content(self, tool_server: FastMCP): async with Client(tool_server) as client: result = await client.call_tool("mixed_content_tool", {}) - assert len(result) == 3 - content1 = result[0] - content2 = 
result[1] - content3 = result[2] + assert len(result.content) == 3 + content1 = result.content[0] + content2 = result.content[1] + content3 = result.content[2] assert isinstance(content1, TextContent) assert content1.text == "Hello" assert isinstance(content2, ImageContent) @@ -402,18 +419,18 @@ async def test_tool_mixed_list_with_image( result = await client.call_tool( "mixed_list_fn", {"image_path": str(image_path)} ) - assert len(result) == 3 + assert len(result.content) == 3 # Check text conversion - content1 = result[0] + content1 = result.content[0] assert isinstance(content1, TextContent) assert json.loads(content1.text) == ["text message", {"key": "value"}] # Check image conversion - content2 = result[1] + content2 = result.content[1] assert isinstance(content2, ImageContent) assert content2.mimeType == "image/png" assert base64.b64decode(content2.data) == b"test image data" # Check direct TextContent - content3 = result[2] + content3 = result.content[2] assert isinstance(content3, TextContent) assert content3.text == "direct content" @@ -430,18 +447,18 @@ async def test_tool_mixed_list_with_audio( result = await client.call_tool( "mixed_audio_list_fn", {"audio_path": str(audio_path)} ) - assert len(result) == 3 + assert len(result.content) == 3 # Check text conversion - content1 = result[0] + content1 = result.content[0] assert isinstance(content1, TextContent) assert json.loads(content1.text) == ["text message", {"key": "value"}] # Check audio conversion - content2 = result[1] + content2 = result.content[1] assert isinstance(content2, AudioContent) assert content2.mimeType == "audio/wav" assert base64.b64decode(content2.data) == b"test audio data" # Check direct TextContent - content3 = result[2] + content3 = result.content[2] assert isinstance(content3, TextContent) assert content3.text == "direct content" @@ -458,13 +475,13 @@ async def test_tool_mixed_list_with_file( result = await client.call_tool( "mixed_file_list_fn", {"file_path": str(file_path)} ) - assert len(result) == 3 + assert len(result.content) == 3 # Check text conversion - content1 = result[0] + content1 = result.content[0] assert isinstance(content1, TextContent) assert json.loads(content1.text) == ["text message", {"key": "value"}] # Check file conversion - content2 = result[1] + content2 = result.content[1] assert isinstance(content2, EmbeddedResource) assert content2.type == "resource" resource = content2.resource @@ -473,7 +490,7 @@ async def test_tool_mixed_list_with_file( blob_data = getattr(resource, "blob") assert base64.b64decode(blob_data) == b"test file data" # Check direct TextContent - content3 = result[2] + content3 = result.content[2] assert isinstance(content3, TextContent) assert content3.text == "direct content" @@ -540,9 +557,10 @@ def process_image(image: bytes) -> Image: result = await client.call_tool( "process_image", {"image": b"fake png data"} ) - assert isinstance(result[0], ImageContent) - assert result[0].mimeType == "image/png" - assert result[0].data == base64.b64encode(b"fake png data").decode() + assert result.structured_content is None + assert isinstance(result.content[0], ImageContent) + assert result.content[0].mimeType == "image/png" + assert result.content[0].data == base64.b64encode(b"fake png data").decode() async def test_tool_with_invalid_input(self): mcp = FastMCP() @@ -660,7 +678,7 @@ def analyze(x: Literal["a", "b"]) -> str: async with Client(mcp) as client: result = await client.call_tool("analyze", {"x": "a"}) - assert result[0].text == "a" # type: 
ignore[attr-defined] + assert result.data == "a" async def test_enum_type_validation_error(self): mcp = FastMCP() @@ -695,7 +713,7 @@ def analyze(x: MyEnum) -> str: async with Client(mcp) as client: result = await client.call_tool("analyze", {"x": "red"}) - assert result[0].text == "red" # type: ignore[attr-defined] + assert result.data == "red" async def test_union_type_validation(self): mcp = FastMCP() @@ -706,10 +724,10 @@ def analyze(x: int | float) -> str: async with Client(mcp) as client: result = await client.call_tool("analyze", {"x": 1}) - assert result[0].text == "1" # type: ignore[attr-defined] + assert result.data == "1" result = await client.call_tool("analyze", {"x": 1.0}) - assert result[0].text == "1.0" # type: ignore[attr-defined] + assert result.data == "1.0" with pytest.raises( ToolError, @@ -730,7 +748,7 @@ def send_path(path: Path) -> str: async with Client(mcp) as client: result = await client.call_tool("send_path", {"path": str(test_path)}) - assert result[0].text == str(test_path) # type: ignore[attr-defined] + assert result.data == str(test_path) async def test_path_type_error(self): mcp = FastMCP() @@ -757,7 +775,7 @@ def send_uuid(x: uuid.UUID) -> str: async with Client(mcp) as client: result = await client.call_tool("send_uuid", {"x": test_uuid}) - assert result[0].text == str(test_uuid) # type: ignore[attr-defined] + assert result.data == str(test_uuid) async def test_uuid_type_error(self): mcp = FastMCP() @@ -781,7 +799,7 @@ def send_datetime(x: datetime.datetime) -> str: async with Client(mcp) as client: result = await client.call_tool("send_datetime", {"x": dt}) - assert result[0].text == dt.isoformat() # type: ignore[attr-defined] + assert result.data == dt.isoformat() async def test_datetime_type_parse_string(self): mcp = FastMCP() @@ -794,7 +812,7 @@ def send_datetime(x: datetime.datetime) -> str: result = await client.call_tool( "send_datetime", {"x": "2021-01-01T00:00:00"} ) - assert result[0].text == "2021-01-01T00:00:00" # type: ignore[attr-defined] + assert result.data == "2021-01-01T00:00:00" async def test_datetime_type_error(self): mcp = FastMCP() @@ -816,7 +834,7 @@ def send_date(x: datetime.date) -> str: async with Client(mcp) as client: result = await client.call_tool("send_date", {"x": datetime.date.today()}) - assert result[0].text == datetime.date.today().isoformat() # type: ignore[attr-defined] + assert result.data == datetime.date.today().isoformat() async def test_date_type_parse_string(self): mcp = FastMCP() @@ -827,7 +845,7 @@ def send_date(x: datetime.date) -> str: async with Client(mcp) as client: result = await client.call_tool("send_date", {"x": "2021-01-01"}) - assert result[0].text == "2021-01-01" # type: ignore[attr-defined] + assert result.data == "2021-01-01" async def test_timedelta_type(self): mcp = FastMCP() @@ -840,7 +858,7 @@ def send_timedelta(x: datetime.timedelta) -> str: result = await client.call_tool( "send_timedelta", {"x": datetime.timedelta(days=1)} ) - assert result[0].text == "1 day, 0:00:00" # type: ignore[attr-defined] + assert result.data == "1 day, 0:00:00" async def test_timedelta_type_parse_int(self): """Test that invalid timedelta input raises validation error.""" @@ -859,6 +877,270 @@ def send_timedelta(x: datetime.timedelta) -> str: await client.call_tool("send_timedelta", {"x": 1000}) +class TestToolOutputSchema: + @pytest.mark.parametrize("annotation", [str, int, float, bool, list, AnyUrl]) + async def test_simple_output_schema(self, annotation): + mcp = FastMCP() + + @mcp.tool + def f() -> 
annotation: # type: ignore + return "hello" + + async with Client(mcp) as client: + tools = await client.list_tools() + assert len(tools) == 1 + + type_schema = TypeAdapter(annotation).json_schema() + # this line will fail until MCP adds output schemas!! + assert tools[0].outputSchema == { + "type": "object", + "properties": {"result": type_schema}, + "x-fastmcp-wrap-result": True, + } + + @pytest.mark.parametrize( + "annotation", + [dict[str, int | str], PersonTypedDict, PersonModel, PersonDataclass], + ) + async def test_structured_output_schema(self, annotation): + mcp = FastMCP() + + @mcp.tool + def f() -> annotation: # type: ignore[valid-type] + return {"name": "John", "age": 30} + + async with Client(mcp) as client: + tools = await client.list_tools() + + type_schema = TypeAdapter(annotation).json_schema() + assert len(tools) == 1 + assert tools[0].outputSchema == type_schema + + async def test_disabled_output_schema_no_structured_content(self): + mcp = FastMCP() + + @mcp.tool(output_schema=None) + def f() -> int: + return 42 + + async with Client(mcp) as client: + result = await client.call_tool("f", {}) + assert result.content[0].text == "42" # type: ignore[attr-defined] + assert result.structured_content is None + assert result.data is None + + async def test_manual_structured_content(self): + mcp = FastMCP() + + @mcp.tool + def f() -> ToolResult: + return ToolResult( + content="Hello, world!", structured_content={"message": "Hello, world!"} + ) + + assert f.output_schema is None + + async with Client(mcp) as client: + result = await client.call_tool("f", {}) + assert result.content[0].text == "Hello, world!" # type: ignore[attr-defined] + assert result.structured_content == {"message": "Hello, world!"} + assert result.data == {"message": "Hello, world!"} + + async def test_output_schema_false_full_handshake(self): + """Test that output_schema=False works through full client/server + handshake. 
We test this by returning a scalar, which requires an output + schema to serialize.""" + mcp = FastMCP() + + @mcp.tool(output_schema=False) # type: ignore[arg-type] + def simple_tool() -> int: + return 42 + + async with Client(mcp) as client: + # List tools and verify output schema is None + tools = await client.list_tools() + tool = next(t for t in tools if t.name == "simple_tool") + assert tool.outputSchema is None + + # Call tool and verify no structured content + result = await client.call_tool("simple_tool", {}) + assert result.structured_content is None + assert result.data is None + assert result.content[0].text == "42" # type: ignore[attr-defined] + + async def test_output_schema_explicit_object_full_handshake(self): + """Test explicit object output schema through full client/server handshake.""" + mcp = FastMCP() + + @mcp.tool( + output_schema={ + "type": "object", + "properties": { + "greeting": {"type": "string"}, + "count": {"type": "integer"}, + }, + "required": ["greeting"], + } + ) + def explicit_tool() -> dict[str, Any]: + return {"greeting": "Hello", "count": 42} + + async with Client(mcp) as client: + # List tools and verify exact schema is preserved + tools = await client.list_tools() + tool = next(t for t in tools if t.name == "explicit_tool") + expected_schema = { + "type": "object", + "properties": { + "greeting": {"type": "string"}, + "count": {"type": "integer"}, + }, + "required": ["greeting"], + } + assert tool.outputSchema == expected_schema + + # Call tool and verify structured content matches return value directly + result = await client.call_tool("explicit_tool", {}) + assert result.structured_content == {"greeting": "Hello", "count": 42} + # Client deserializes according to schema, so check fields + assert result.data.greeting == "Hello" # type: ignore[attr-defined] + assert result.data.count == 42 # type: ignore[attr-defined] + + async def test_output_schema_wrapped_primitive_full_handshake(self): + """Test wrapped primitive output schema through full client/server handshake.""" + mcp = FastMCP() + + @mcp.tool + def primitive_tool() -> str: + return "Hello, primitives!" + + async with Client(mcp) as client: + # List tools and verify schema shows wrapped structure + tools = await client.list_tools() + tool = next(t for t in tools if t.name == "primitive_tool") + expected_schema = { + "type": "object", + "properties": {"result": {"type": "string"}}, + "x-fastmcp-wrap-result": True, + } + assert tool.outputSchema == expected_schema + + # Call tool and verify structured content is wrapped + result = await client.call_tool("primitive_tool", {}) + assert result.structured_content == {"result": "Hello, primitives!"} + assert result.data == "Hello, primitives!" 
# Client unwraps for convenience + + async def test_output_schema_complex_type_full_handshake(self): + """Test complex type output schema through full client/server handshake.""" + mcp = FastMCP() + + @mcp.tool + def complex_tool() -> list[dict[str, int]]: + return [{"a": 1, "b": 2}, {"c": 3, "d": 4}] + + async with Client(mcp) as client: + # List tools and verify schema shows wrapped array + tools = await client.list_tools() + tool = next(t for t in tools if t.name == "complex_tool") + expected_inner_schema = TypeAdapter(list[dict[str, int]]).json_schema() + expected_schema = { + "type": "object", + "properties": {"result": expected_inner_schema}, + "x-fastmcp-wrap-result": True, + } + assert tool.outputSchema == expected_schema + + # Call tool and verify structured content is wrapped + result = await client.call_tool("complex_tool", {}) + expected_data = [{"a": 1, "b": 2}, {"c": 3, "d": 4}] + assert result.structured_content == {"result": expected_data} + # Client deserializes - just verify we got data back + assert result.data is not None + + async def test_output_schema_dataclass_full_handshake(self): + """Test dataclass output schema through full client/server handshake.""" + mcp = FastMCP() + + @dataclass + class User: + name: str + age: int + + @mcp.tool + def dataclass_tool() -> User: + return User(name="Alice", age=30) + + async with Client(mcp) as client: + # List tools and verify schema is object type (not wrapped) + tools = await client.list_tools() + tool = next(t for t in tools if t.name == "dataclass_tool") + expected_schema = TypeAdapter(User).json_schema() + assert tool.outputSchema == expected_schema + assert ( + tool.outputSchema and "x-fastmcp-wrap-result" not in tool.outputSchema + ) + + # Call tool and verify structured content is direct + result = await client.call_tool("dataclass_tool", {}) + assert result.structured_content == {"name": "Alice", "age": 30} + # Client deserializes according to schema + assert result.data.name == "Alice" # type: ignore[attr-defined] + assert result.data.age == 30 # type: ignore[attr-defined] + + async def test_output_schema_mixed_content_types(self): + """Test tools with mixed content and output schemas.""" + mcp = FastMCP() + + @mcp.tool + def mixed_output() -> list[Any]: + # Return mixed content that includes MCP types and regular data + return [ + "text message", + {"structured": "data"}, + TextContent(type="text", text="direct MCP content"), + ] + + async with Client(mcp) as client: + result = await client.call_tool("mixed_output", {}) + + # Should have multiple content blocks + assert len(result.content) >= 2 + + # Should have structured output with wrapped result + expected_data = [ + "text message", + {"structured": "data"}, + { + "type": "text", + "text": "direct MCP content", + "annotations": None, + "_meta": None, + }, + ] + assert result.structured_content == {"result": expected_data} + + async def test_output_schema_serialization_edge_cases(self): + """Test edge cases in output schema serialization.""" + mcp = FastMCP() + + @mcp.tool + def edge_case_tool() -> tuple[int, str]: + return (42, "hello") + + async with Client(mcp) as client: + # Verify tuple gets proper schema + tools = await client.list_tools() + tool = next(t for t in tools if t.name == "edge_case_tool") + + # Tuples should be wrapped since they're not object type + assert tool.outputSchema and "x-fastmcp-wrap-result" in tool.outputSchema + + result = await client.call_tool("edge_case_tool", {}) + # Should be wrapped with result key + assert 
result.structured_content == {"result": [42, "hello"]} + assert result.data == [42, "hello"] + + class TestToolContextInjection: """Test context injection in tools.""" @@ -887,9 +1169,7 @@ def tool_with_context(x: int, ctx: Context) -> str: async with Client(mcp) as client: result = await client.call_tool("tool_with_context", {"x": 42}) - assert len(result) == 1 - content = result[0] - assert content.text == "1" # type: ignore[attr-defined] + assert result.data == "1" async def test_async_context(self): """Test that context works in async functions.""" @@ -902,9 +1182,7 @@ async def async_tool(x: int, ctx: Context) -> str: async with Client(mcp) as client: result = await client.call_tool("async_tool", {"x": 42}) - assert len(result) == 1 - content = result[0] - assert content.text == "Async request 1: 42" # type: ignore[attr-defined] + assert result.data == "Async request 1: 42" async def test_optional_context(self): """Test that context is optional.""" @@ -916,9 +1194,7 @@ def no_context(x: int) -> int: async with Client(mcp) as client: result = await client.call_tool("no_context", {"x": 21}) - assert len(result) == 1 - content = result[0] - assert content.text == "42" # type: ignore[attr-defined] + assert result.data == 42 async def test_context_resource_access(self): """Test that context can access resources.""" @@ -938,9 +1214,9 @@ async def tool_with_resource(ctx: Context) -> str: async with Client(mcp) as client: result = await client.call_tool("tool_with_resource", {}) - assert len(result) == 1 - content = result[0] - assert "Read resource: resource data" in content.text # type: ignore[attr-defined] + assert ( + result.data == "Read resource: resource data with mime type text/plain" + ) async def test_tool_decorator_with_tags(self): """Test that the tool decorator properly sets tags.""" @@ -968,7 +1244,7 @@ async def __call__(self, x: int, ctx: Context) -> int: async with Client(mcp) as client: result = await client.call_tool("MyTool", {"x": 2}) - assert result[0].text == "3" # type: ignore[attr-defined] + assert result.data == 3 class TestToolEnabled: diff --git a/tests/server/test_tool_annotations.py b/tests/server/test_tool_annotations.py index 07a21077d..d0f03876c 100644 --- a/tests/server/test_tool_annotations.py +++ b/tests/server/test_tool_annotations.py @@ -218,8 +218,4 @@ def create_item(name: str, value: int) -> dict[str, Any]: result = await client.call_tool( "create_item", {"name": "test_item", "value": 42} ) - assert len(result) == 1 - - # The result should contain the expected JSON - assert '"name": "test_item"' in result[0].text # type: ignore[attr-defined] - assert '"value": 42' in result[0].text # type: ignore[attr-defined] + assert result.data == {"name": "test_item", "value": 42} diff --git a/tests/server/test_tool_exclude_args.py b/tests/server/test_tool_exclude_args.py index 025555ebc..fec695f88 100644 --- a/tests/server/test_tool_exclude_args.py +++ b/tests/server/test_tool_exclude_args.py @@ -1,7 +1,6 @@ from typing import Any import pytest -from mcp.types import TextContent from fastmcp import Client, FastMCP from fastmcp.tools.tool import Tool @@ -92,9 +91,4 @@ def create_item( result = await client.call_tool( "create_item", {"name": "test_item", "value": 42} ) - assert len(result) == 1 - assert isinstance(result[0], TextContent) - - # The result should contain the expected JSON - assert '"name": "test_item"' in result[0].text - assert '"value": 42' in result[0].text + assert result.data == {"name": "test_item", "value": 42} diff --git 
a/tests/test_examples.py b/tests/test_examples.py
index 0fa1da3a4..0edeee96e 100644
--- a/tests/test_examples.py
+++ b/tests/test_examples.py
@@ -10,9 +10,9 @@ async def test_simple_echo():
     from examples.simple_echo import mcp
 
     async with Client(mcp) as client:
-        result = await client.call_tool("echo", {"text": "hello"})
-        assert len(result) == 1
-        assert result[0].text == "hello"  # type: ignore[attr-defined]
+        result = await client.call_tool_mcp("echo", {"text": "hello"})
+        assert len(result.content) == 1
+        assert result.content[0].text == "hello"  # type: ignore[attr-defined]
 
 
 async def test_complex_inputs():
@@ -21,11 +21,11 @@ async def test_complex_inputs():
     async with Client(mcp) as client:
         tank = {"shrimp": [{"name": "bob"}, {"name": "alice"}]}
-        result = await client.call_tool(
+        result = await client.call_tool_mcp(
             "name_shrimp", {"tank": tank, "extra_names": ["charlie"]}
         )
-        assert len(result) == 1
-        assert result[0].text == '[\n "bob",\n "alice",\n "charlie"\n]'  # type: ignore[attr-defined]
+        assert len(result.content) == 1
+        assert result.content[0].text == '[\n "bob",\n "alice",\n "charlie"\n]'  # type: ignore[attr-defined]
 
 
 async def test_desktop(monkeypatch):
@@ -34,9 +34,9 @@ async def test_desktop(monkeypatch):
     async with Client(mcp) as client:
         # Test the add function
-        result = await client.call_tool("add", {"a": 1, "b": 2})
-        assert len(result) == 1
-        assert result[0].text == "3"  # type: ignore[attr-defined]
+        result = await client.call_tool_mcp("add", {"a": 1, "b": 2})
+        assert len(result.content) == 1
+        assert result.content[0].text == "3"  # type: ignore[attr-defined]
 
     async with Client(mcp) as client:
         result = await client.read_resource(AnyUrl("greeting://rooter12"))
@@ -49,9 +49,9 @@ async def test_echo():
     from examples.echo import mcp
 
     async with Client(mcp) as client:
-        result = await client.call_tool("echo_tool", {"text": "hello"})
-        assert len(result) == 1
-        assert result[0].text == "hello"  # type: ignore[attr-defined]
+        result = await client.call_tool_mcp("echo_tool", {"text": "hello"})
+        assert len(result.content) == 1
+        assert result.content[0].text == "hello"  # type: ignore[attr-defined]
 
     async with Client(mcp) as client:
         result = await client.read_resource(AnyUrl("echo://static"))
diff --git a/tests/tools/test_tool.py b/tests/tools/test_tool.py
index d3da4c02a..f94e73fd3 100644
--- a/tests/tools/test_tool.py
+++ b/tests/tools/test_tool.py
@@ -1,4 +1,6 @@
 import json
+from dataclasses import dataclass
+from typing import Annotated, Any
 
 import pytest
 from mcp.types import (
@@ -8,7 +10,8 @@
     TextContent,
     TextResourceContents,
 )
-from pydantic import AnyUrl, BaseModel
+from pydantic import AnyUrl, BaseModel, Field, TypeAdapter
+from typing_extensions import TypedDict
 
 from fastmcp.tools.tool import Tool, _convert_to_content
 from fastmcp.utilities.types import Audio, File, Image
@@ -29,6 +32,13 @@ def add(a: int, b: int) -> int:
         assert len(tool.parameters["properties"]) == 2
         assert tool.parameters["properties"]["a"]["type"] == "integer"
         assert tool.parameters["properties"]["b"]["type"] == "integer"
+        # With primitive wrapping, int return type becomes object with result property
+        expected_schema = {
+            "type": "object",
+            "properties": {"result": {"type": "integer"}},
+            "x-fastmcp-wrap-result": True,
+        }
+        assert tool.output_schema == expected_schema
 
     async def test_async_function(self):
         """Test registering and running an async function."""
@@ -100,7 +110,7 @@ def image_tool(data: bytes) -> Image:
         result = await tool.run({"data": "test.png"})
 
         assert
tool.parameters["properties"]["data"]["type"] == "string" - assert isinstance(result[0], ImageContent) + assert isinstance(result.content[0], ImageContent) async def test_tool_with_audio_return(self): def audio_tool(data: bytes) -> Audio: @@ -110,7 +120,7 @@ def audio_tool(data: bytes) -> Audio: result = await tool.run({"data": "test.wav"}) assert tool.parameters["properties"]["data"]["type"] == "string" - assert isinstance(result[0], AudioContent) + assert isinstance(result.content[0], AudioContent) async def test_tool_with_file_return(self): def file_tool(data: bytes) -> File: @@ -120,11 +130,11 @@ def file_tool(data: bytes) -> File: result = await tool.run({"data": "test.bin"}) assert tool.parameters["properties"]["data"]["type"] == "string" - assert len(result) == 1 - assert isinstance(result[0], EmbeddedResource) - assert result[0].type == "resource" - assert hasattr(result[0], "resource") - resource = result[0].resource + assert len(result.content) == 1 + assert isinstance(result.content[0], EmbeddedResource) + assert result.content[0].type == "resource" + assert hasattr(result.content[0], "resource") + resource = result.content[0].resource assert resource.mimeType == "application/octet-stream" def test_non_callable_fn(self): @@ -236,8 +246,490 @@ def process_list(items: list[int]) -> int: tool = Tool.from_function(process_list, serializer=custom_serializer) result = await tool.run(arguments={"items": [1, 2, 3, 4, 5]}) - assert isinstance(result[0], TextContent) - assert result[0].text == "Custom serializer: 15" + # Custom serializer affects unstructured content + assert isinstance(result.content[0], TextContent) + assert result.content[0].text == "Custom serializer: 15" + # Structured output should have the raw value + assert result.structured_content == {"result": 15} + + +class TestToolFromFunctionOutputSchema: + async def test_no_return_annotation(self): + def func(): + pass + + tool = Tool.from_function(func) + assert tool.output_schema is None + + @pytest.mark.parametrize( + "annotation", + [ + int, + float, + bool, + str, + int | float, + list, + list[int], + list[int | float], + dict, + dict[str, Any], + dict[str, int | None], + tuple[int, str], + set[int], + list[tuple[int, str]], + ], + ) + async def test_simple_return_annotation(self, annotation): + def func() -> annotation: # type: ignore + return 1 + + tool = Tool.from_function(func) + + base_schema = TypeAdapter(annotation).json_schema() + + # Non-object types get wrapped + schema_type = base_schema.get("type") + is_object_type = schema_type == "object" + + if not is_object_type: + # Non-object types get wrapped + expected_schema = { + "type": "object", + "properties": {"result": base_schema}, + "x-fastmcp-wrap-result": True, + } + assert tool.output_schema == expected_schema + else: + # Object types remain unwrapped + assert tool.output_schema == base_schema + + @pytest.mark.parametrize( + "annotation", + [ + AnyUrl, + Annotated[int, Field(ge=1)], + Annotated[int, Field(ge=1)], + ], + ) + async def test_complex_return_annotation(self, annotation): + def func() -> annotation: # type: ignore + return 1 + + tool = Tool.from_function(func) + base_schema = TypeAdapter(annotation).json_schema() + + expected_schema = { + "type": "object", + "properties": {"result": base_schema}, + "x-fastmcp-wrap-result": True, + } + assert tool.output_schema == expected_schema + + async def test_none_return_annotation(self): + def func() -> None: + pass + + tool = Tool.from_function(func) + assert tool.output_schema is None + + async def 
test_any_return_annotation(self): + def func() -> Any: + return 1 + + tool = Tool.from_function(func) + assert tool.output_schema is None + + @pytest.mark.parametrize( + "annotation, expected", + [ + (Image, ImageContent), + (Audio, AudioContent), + (File, EmbeddedResource), + (Image | int, ImageContent | int), + (Image | Audio, ImageContent | AudioContent), + (list[Image | Audio], list[ImageContent | AudioContent]), + ], + ) + async def test_converted_return_annotation(self, annotation, expected): + def func() -> annotation: # type: ignore + return 1 + + tool = Tool.from_function(func) + # Image, Audio, File types don't generate output schemas since they're converted to content directly + assert tool.output_schema is None + + async def test_dataclass_return_annotation(self): + @dataclass + class Person: + name: str + age: int + + def func() -> Person: + return Person(name="John", age=30) + + tool = Tool.from_function(func) + assert tool.output_schema == TypeAdapter(Person).json_schema() + + async def test_base_model_return_annotation(self): + class Person(BaseModel): + name: str + age: int + + def func() -> Person: + return Person(name="John", age=30) + + tool = Tool.from_function(func) + assert tool.output_schema == TypeAdapter(Person).json_schema() + + async def test_typeddict_return_annotation(self): + class Person(TypedDict): + name: str + age: int + + def func() -> Person: + return Person(name="John", age=30) + + tool = Tool.from_function(func) + assert tool.output_schema == TypeAdapter(Person).json_schema() + + async def test_unserializable_return_annotation(self): + class Unserializable: + def __init__(self, data: Any): + self.data = data + + def func() -> Unserializable: + return Unserializable(data="test") + + tool = Tool.from_function(func) + assert tool.output_schema is None + + async def test_mixed_unserializable_return_annotation(self): + class Unserializable: + def __init__(self, data: Any): + self.data = data + + def func() -> Unserializable | int: + return Unserializable(data="test") + + tool = Tool.from_function(func) + assert tool.output_schema is None + + async def test_provided_output_schema_takes_precedence_over_json_compatible_annotation( + self, + ): + """Test that provided output_schema takes precedence over inferred schema from JSON-compatible annotation.""" + + def func() -> dict[str, int]: + return {"a": 1, "b": 2} + + # Provide a custom output schema that differs from the inferred one + custom_schema = {"type": "object", "description": "Custom schema"} + + tool = Tool.from_function(func, output_schema=custom_schema) + assert tool.output_schema == custom_schema + + async def test_provided_output_schema_takes_precedence_over_complex_annotation( + self, + ): + """Test that provided output_schema takes precedence over inferred schema from complex annotation.""" + + def func() -> list[dict[str, int | float]]: + return [{"a": 1, "b": 2.5}] + + # Provide a custom output schema that differs from the inferred one + custom_schema = {"type": "object", "properties": {"custom": {"type": "string"}}} + + tool = Tool.from_function(func, output_schema=custom_schema) + assert tool.output_schema == custom_schema + + async def test_provided_output_schema_takes_precedence_over_unserializable_annotation( + self, + ): + """Test that provided output_schema takes precedence over None schema from unserializable annotation.""" + + class Unserializable: + def __init__(self, data: Any): + self.data = data + + def func() -> Unserializable: + return Unserializable(data="test") + + # Provide 
a custom output schema even though the annotation is unserializable + custom_schema = { + "type": "object", + "properties": {"items": {"type": "array", "items": {"type": "string"}}}, + } + + tool = Tool.from_function(func, output_schema=custom_schema) + assert tool.output_schema == custom_schema + + async def test_provided_output_schema_takes_precedence_over_no_annotation(self): + """Test that provided output_schema takes precedence over None schema from no annotation.""" + + def func(): + return "hello" + + # Provide a custom output schema even though there's no return annotation + custom_schema = { + "type": "object", + "properties": {"value": {"type": "number", "minimum": 0}}, + } + + tool = Tool.from_function(func, output_schema=custom_schema) + assert tool.output_schema == custom_schema + + async def test_provided_output_schema_takes_precedence_over_converted_annotation( + self, + ): + """Test that provided output_schema takes precedence over converted schema from Image/Audio/File annotations.""" + + def func() -> Image: + return Image(data=b"test") + + # Provide a custom output schema that differs from the converted ImageContent schema + custom_schema = { + "type": "object", + "properties": {"custom_image": {"type": "string"}}, + } + + tool = Tool.from_function(func, output_schema=custom_schema) + assert tool.output_schema == custom_schema + + async def test_provided_output_schema_takes_precedence_over_union_annotation(self): + """Test that provided output_schema takes precedence over inferred schema from union annotation.""" + + def func() -> str | int | None: + return "hello" + + # Provide a custom output schema that differs from the inferred union schema + custom_schema = {"type": "object", "properties": {"flag": {"type": "boolean"}}} + + tool = Tool.from_function(func, output_schema=custom_schema) + assert tool.output_schema == custom_schema + + async def test_provided_output_schema_takes_precedence_over_pydantic_annotation( + self, + ): + """Test that provided output_schema takes precedence over inferred schema from Pydantic model annotation.""" + + class Person(BaseModel): + name: str + age: int + + def func() -> Person: + return Person(name="John", age=30) + + # Provide a custom output schema that differs from the inferred Person schema + custom_schema = { + "type": "object", + "properties": {"numbers": {"type": "array", "items": {"type": "number"}}}, + } + + tool = Tool.from_function(func, output_schema=custom_schema) + assert tool.output_schema == custom_schema + + async def test_output_schema_false_allows_automatic_structured_content(self): + """Test that output_schema=False still allows automatic structured content for dict-like objects.""" + + def func() -> dict[str, str]: + return {"message": "Hello, world!"} + + tool = Tool.from_function(func, output_schema=False) + assert tool.output_schema is None + + result = await tool.run({}) + # Dict objects automatically become structured content even without schema + assert result.structured_content == {"message": "Hello, world!"} + assert len(result.content) == 1 + assert result.content[0].text == '{\n "message": "Hello, world!"\n}' # type: ignore[attr-defined] + + async def test_output_schema_none_disables_structured_content(self): + """Test that output_schema=None explicitly disables structured content.""" + + def func() -> int: + return 42 + + tool = Tool.from_function(func, output_schema=None) + assert tool.output_schema is None + + result = await tool.run({}) + assert result.structured_content is None + assert 
len(result.content) == 1 + assert result.content[0].text == "42" # type: ignore[attr-defined] + + async def test_output_schema_inferred_when_not_specified(self): + """Test that output schema is inferred when not explicitly specified.""" + + def func() -> int: + return 42 + + # Don't specify output_schema - should infer and wrap + tool = Tool.from_function(func) + expected_schema = { + "type": "object", + "properties": {"result": {"type": "integer"}}, + "x-fastmcp-wrap-result": True, + } + assert tool.output_schema == expected_schema + + result = await tool.run({}) + assert result.structured_content == {"result": 42} + + async def test_explicit_object_schema_with_dict_return(self): + """Test that explicit object schemas work when function returns a dict.""" + + def func() -> dict[str, int]: + return {"value": 42} + + # Provide explicit object schema + explicit_schema = { + "type": "object", + "properties": {"value": {"type": "integer", "minimum": 0}}, + } + tool = Tool.from_function(func, output_schema=explicit_schema) + assert tool.output_schema == explicit_schema # Schema not wrapped + assert tool.output_schema and "x-fastmcp-wrap-result" not in tool.output_schema + + result = await tool.run({}) + # Dict result with object schema is used directly + assert result.structured_content == {"value": 42} + assert result.content[0].text == '{\n "value": 42\n}' # type: ignore[attr-defined] + + async def test_explicit_object_schema_with_non_dict_return_fails(self): + """Test that explicit object schemas fail when function returns non-dict.""" + + def func() -> int: + return 42 + + # Provide explicit object schema but return non-dict + explicit_schema = { + "type": "object", + "properties": {"value": {"type": "integer"}}, + } + tool = Tool.from_function(func, output_schema=explicit_schema) + + # Should fail because int is not dict-compatible with object schema + with pytest.raises(ValueError, match="structured_content must be a dict"): + await tool.run({}) + + async def test_object_output_schema_not_wrapped(self): + """Test that object-type output schemas are never wrapped.""" + + def func() -> dict[str, int]: + return {"value": 42} + + # Object schemas should never be wrapped, even when inferred + tool = Tool.from_function(func) + expected_schema = TypeAdapter(dict[str, int]).json_schema() + assert tool.output_schema == expected_schema # Not wrapped + assert tool.output_schema and "x-fastmcp-wrap-result" not in tool.output_schema + + result = await tool.run({}) + assert result.structured_content == {"value": 42} # Direct value + + async def test_structured_content_interaction_with_wrapping(self): + """Test that structured content works correctly with schema wrapping.""" + + def func() -> str: + return "hello" + + # Inferred schema should wrap string type + tool = Tool.from_function(func) + expected_schema = { + "type": "object", + "properties": {"result": {"type": "string"}}, + "x-fastmcp-wrap-result": True, + } + assert tool.output_schema == expected_schema + + result = await tool.run({}) + # Unstructured content + assert len(result.content) == 1 + assert result.content[0].text == "hello" # type: ignore[attr-defined] + # Structured content should be wrapped + assert result.structured_content == {"result": "hello"} + + async def test_structured_content_with_explicit_object_schema(self): + """Test structured content with explicit object schema.""" + + def func() -> dict[str, str]: + return {"greeting": "hello"} + + # Provide explicit object schema + explicit_schema = { + "type": "object", + 
"properties": {"greeting": {"type": "string"}}, + "required": ["greeting"], + } + tool = Tool.from_function(func, output_schema=explicit_schema) + assert tool.output_schema == explicit_schema + + result = await tool.run({}) + # Should use direct value since explicit schema doesn't have wrap marker + assert result.structured_content == {"greeting": "hello"} + + async def test_structured_content_with_custom_wrapper_schema(self): + """Test structured content with custom schema that includes wrap marker.""" + + def func() -> str: + return "world" + + # Custom schema with wrap marker + custom_schema = { + "type": "object", + "properties": {"message": {"type": "string"}}, + "x-fastmcp-wrap-result": True, + } + tool = Tool.from_function(func, output_schema=custom_schema) + assert tool.output_schema == custom_schema + + result = await tool.run({}) + # Should wrap with "result" key due to wrap marker + assert result.structured_content == {"result": "world"} + + async def test_none_vs_false_output_schema_behavior(self): + """Test the difference between None and False for output_schema.""" + + def func() -> int: + return 123 + + # None should disable + tool_none = Tool.from_function(func, output_schema=None) + assert tool_none.output_schema is None + + # False should also disable + tool_false = Tool.from_function(func, output_schema=False) + assert tool_false.output_schema is None + + # Both should have same behavior + result_none = await tool_none.run({}) + result_false = await tool_false.run({}) + + assert result_none.structured_content is None + assert result_false.structured_content is None + assert result_none.content[0].text == result_false.content[0].text == "123" # type: ignore[attr-defined] + + async def test_non_object_output_schema_raises_error(self): + """Test that providing a non-object output schema raises a ValueError.""" + + def func() -> int: + return 42 + + # Test various non-object schemas that should raise errors + non_object_schemas = [ + {"type": "string"}, + {"type": "integer", "minimum": 0}, + {"type": "number"}, + {"type": "boolean"}, + {"type": "array", "items": {"type": "string"}}, + ] + + for schema in non_object_schemas: + with pytest.raises( + ValueError, match='Output schemas must have "type" set to "object"' + ): + Tool.from_function(func, output_schema=schema) class TestConvertResultToContent: @@ -537,3 +1029,199 @@ def test_process_as_single_item_flag(self): 1, {"type": "text", "text": "hello", "annotations": None, "_meta": None}, ] + + +class TestAutomaticStructuredContent: + """Tests for automatic structured content generation based on return types.""" + + async def test_dict_return_creates_structured_content_without_schema(self): + """Test that dict returns automatically create structured content even without output schema.""" + + def get_user_data(user_id: str) -> dict: + return {"name": "Alice", "age": 30, "active": True} + + # No explicit output schema provided + tool = Tool.from_function(get_user_data) + + result = await tool.run({"user_id": "123"}) + + # Should have both content and structured content + assert len(result.content) == 1 + assert isinstance(result.content[0], TextContent) + assert result.structured_content == {"name": "Alice", "age": 30, "active": True} + + async def test_dataclass_return_creates_structured_content_without_schema(self): + """Test that dataclass returns automatically create structured content even without output schema.""" + + @dataclass + class UserProfile: + name: str + age: int + email: str + + def get_profile(user_id: str) -> 
UserProfile: + return UserProfile(name="Bob", age=25, email="bob@example.com") + + # No explicit output schema, but dataclass should still create structured content + tool = Tool.from_function(get_profile, output_schema=False) + + result = await tool.run({"user_id": "456"}) + + # Should have both content and structured content + assert len(result.content) == 1 + assert isinstance(result.content[0], TextContent) + # Dataclass should serialize to dict + assert result.structured_content == { + "name": "Bob", + "age": 25, + "email": "bob@example.com", + } + + async def test_pydantic_model_return_creates_structured_content_without_schema( + self, + ): + """Test that Pydantic model returns automatically create structured content even without output schema.""" + + class UserData(BaseModel): + username: str + score: int + verified: bool + + def get_user_stats(user_id: str) -> UserData: + return UserData(username="charlie", score=100, verified=True) + + # Explicitly disable output schema to test automatic structured content + tool = Tool.from_function(get_user_stats, output_schema=False) + + result = await tool.run({"user_id": "789"}) + + # Should have both content and structured content + assert len(result.content) == 1 + assert isinstance(result.content[0], TextContent) + # Pydantic model should serialize to dict + assert result.structured_content == { + "username": "charlie", + "score": 100, + "verified": True, + } + + async def test_int_return_no_structured_content_without_schema(self): + """Test that int returns don't create structured content without output schema.""" + + def calculate_sum(a: int, b: int): + """No return annotation.""" + return a + b + + # No output schema + tool = Tool.from_function(calculate_sum) + + result = await tool.run({"a": 5, "b": 3}) + + # Should only have content, no structured content + assert len(result.content) == 1 + assert isinstance(result.content[0], TextContent) + assert result.content[0].text == "8" + assert result.structured_content is None + + async def test_str_return_no_structured_content_without_schema(self): + """Test that str returns don't create structured content without output schema.""" + + def get_greeting(name: str): + """No return annotation.""" + return f"Hello, {name}!" + + # No output schema + tool = Tool.from_function(get_greeting) + + result = await tool.run({"name": "World"}) + + # Should only have content, no structured content + assert len(result.content) == 1 + assert isinstance(result.content[0], TextContent) + assert result.content[0].text == "Hello, World!" 
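# A condensed, illustrative sketch of the rule these tests exercise, assuming the
# same Tool.from_function helper; the function names below are hypothetical.
# Object-like returns (dict, dataclass, BaseModel) produce structured content even
# without an output schema, while bare primitives only do so when a return
# annotation lets the server wrap them as {"result": ...}.
from fastmcp.tools.tool import Tool


def untyped_sum(a: int, b: int):
    return a + b  # no annotation -> no output schema -> no structured content


def typed_sum(a: int, b: int) -> int:
    return a + b  # annotated -> wrapped schema -> {"result": ...}


async def sketch_wrapping() -> None:
    plain = await Tool.from_function(untyped_sum).run({"a": 5, "b": 3})
    assert plain.structured_content is None

    wrapped = await Tool.from_function(typed_sum).run({"a": 5, "b": 3})
    assert wrapped.structured_content == {"result": 8}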
+ assert result.structured_content is None + + async def test_list_return_no_structured_content_without_schema(self): + """Test that list returns don't create structured content without output schema.""" + + def get_numbers(): + """No return annotation.""" + return [1, 2, 3, 4, 5] + + # No output schema + tool = Tool.from_function(get_numbers) + + result = await tool.run({}) + + # Should only have content, no structured content + assert len(result.content) == 1 + assert isinstance(result.content[0], TextContent) + assert result.structured_content is None + + async def test_int_return_with_schema_creates_structured_content(self): + """Test that int returns DO create structured content when there's an output schema.""" + + def calculate_sum(a: int, b: int) -> int: + """With return annotation.""" + return a + b + + # Output schema should be auto-generated from annotation + tool = Tool.from_function(calculate_sum) + assert tool.output_schema is not None + + result = await tool.run({"a": 5, "b": 3}) + + # Should have both content and structured content + assert len(result.content) == 1 + assert isinstance(result.content[0], TextContent) + assert result.content[0].text == "8" + assert result.structured_content == {"result": 8} + + async def test_client_automatic_deserialization_with_dict_result(self): + """Test that clients automatically deserialize dict results from structured content.""" + from fastmcp import FastMCP + from fastmcp.client import Client + + mcp = FastMCP() + + @mcp.tool + def get_user_info(user_id: str) -> dict: + return {"name": "Alice", "age": 30, "active": True} + + async with Client(mcp) as client: + result = await client.call_tool("get_user_info", {"user_id": "123"}) + + # Client should provide the deserialized data + assert result.data == {"name": "Alice", "age": 30, "active": True} + assert result.structured_content == { + "name": "Alice", + "age": 30, + "active": True, + } + assert len(result.content) == 1 + + async def test_client_automatic_deserialization_with_dataclass_result(self): + """Test that clients automatically deserialize dataclass results from structured content.""" + from fastmcp import FastMCP + from fastmcp.client import Client + + mcp = FastMCP() + + @dataclass + class UserProfile: + name: str + age: int + verified: bool + + @mcp.tool + def get_profile(user_id: str) -> UserProfile: + return UserProfile(name="Bob", age=25, verified=True) + + async with Client(mcp) as client: + result = await client.call_tool("get_profile", {"user_id": "456"}) + + # Client should deserialize back to a dataclass (type name will match) + assert result.data.__class__.__name__ == "UserProfile" + assert result.data.name == "Bob" + assert result.data.age == 25 + assert result.data.verified is True diff --git a/tests/tools/test_tool_manager.py b/tests/tools/test_tool_manager.py index de6ab041e..5ebf50c95 100644 --- a/tests/tools/test_tool_manager.py +++ b/tests/tools/test_tool_manager.py @@ -125,7 +125,8 @@ def image_tool(data: bytes) -> Image: tool = await manager.get_tool("image_tool") result = await tool.run({"data": "test.png"}) assert tool.parameters["properties"]["data"]["type"] == "string" - assert isinstance(result[0], ImageContent) + assert isinstance(result.content[0], ImageContent) + assert result.structured_content is None def test_add_noncallable_tool(self): manager = ToolManager() @@ -353,7 +354,8 @@ def add(a: int, b: int) -> int: manager.add_tool(tool) result = await manager.call_tool("add", {"a": 1, "b": 2}) - assert result[0].text == "3" # type: 
ignore[attr-defined] + assert result.content[0].text == "3" # type: ignore[attr-defined] + assert result.structured_content == {"result": 3} async def test_call_async_tool(self): async def double(n: int) -> int: @@ -364,7 +366,8 @@ async def double(n: int) -> int: tool = Tool.from_function(double) manager.add_tool(tool) result = await manager.call_tool("double", {"n": 5}) - assert result[0].text == "10" # type: ignore[attr-defined] + assert result.content[0].text == "10" # type: ignore[attr-defined] + assert result.structured_content == {"result": 10} async def test_call_tool_callable_object(self): class Adder: @@ -378,7 +381,8 @@ def __call__(self, x: int, y: int) -> int: tool = Tool.from_function(Adder()) manager.add_tool(tool) result = await manager.call_tool("Adder", {"x": 1, "y": 2}) - assert result[0].text == "3" # type: ignore[attr-defined] + assert result.content[0].text == "3" # type: ignore[attr-defined] + assert result.structured_content == {"result": 3} async def test_call_tool_callable_object_async(self): class Adder: @@ -392,7 +396,8 @@ async def __call__(self, x: int, y: int) -> int: tool = Tool.from_function(Adder()) manager.add_tool(tool) result = await manager.call_tool("Adder", {"x": 1, "y": 2}) - assert result[0].text == "3" # type: ignore[attr-defined] + assert result.content[0].text == "3" # type: ignore[attr-defined] + assert result.structured_content == {"result": 3} async def test_call_tool_with_default_args(self): def add(a: int, b: int = 1) -> int: @@ -404,7 +409,8 @@ def add(a: int, b: int = 1) -> int: manager.add_tool(tool) result = await manager.call_tool("add", {"a": 1}) - assert result[0].text == "2" # type: ignore[attr-defined] + assert result.content[0].text == "2" # type: ignore[attr-defined] + assert result.structured_content == {"result": 2} async def test_call_tool_with_missing_args(self): def add(a: int, b: int) -> int: @@ -431,7 +437,8 @@ def sum_vals(vals: list[int]) -> int: manager.add_tool(tool) result = await manager.call_tool("sum_vals", {"vals": [1, 2, 3]}) - assert result[0].text == "6" # type: ignore[attr-defined] + assert result.content[0].text == "6" # type: ignore[attr-defined] + assert result.structured_content == {"result": 6} async def test_call_tool_with_list_str_or_str_input(self): def concat_strs(vals: list[str] | str) -> str: @@ -443,10 +450,12 @@ def concat_strs(vals: list[str] | str) -> str: # Try both with plain python object and with JSON list result = await manager.call_tool("concat_strs", {"vals": ["a", "b", "c"]}) - assert result[0].text == "abc" # type: ignore[attr-defined] + assert result.content[0].text == "abc" # type: ignore[attr-defined] + assert result.structured_content == {"result": "abc"} result = await manager.call_tool("concat_strs", {"vals": "a"}) - assert result[0].text == "a" # type: ignore[attr-defined] + assert result.content[0].text == "a" # type: ignore[attr-defined] + assert result.structured_content == {"result": "a"} async def test_call_tool_with_complex_model(self): class MyShrimpTank(BaseModel): @@ -477,7 +486,8 @@ def name_shrimp(tank: MyShrimpTank, ctx: Context | None) -> list[str]: }, ) - assert result[0].text == '[\n "rex",\n "gertrude"\n]' # type: ignore[attr-defined] + assert result.content[0].text == '[\n "rex",\n "gertrude"\n]' # type: ignore[attr-defined] + assert result.structured_content == {"result": ["rex", "gertrude"]} async def test_call_tool_with_custom_serializer(self): """Test that a custom serializer provided to FastMCP is used by tools.""" @@ -496,7 +506,8 @@ def get_data() -> dict: 
return {"key": "value", "number": 123} result = await manager.call_tool("get_data", {}) - assert result[0].text == 'CUSTOM:{"key": "value", "number": 123}' # type: ignore[attr-defined] + assert result.content[0].text == 'CUSTOM:{"key": "value", "number": 123}' # type: ignore[attr-defined] + assert result.structured_content == {"key": "value", "number": 123} async def test_call_tool_with_list_result_custom_serializer(self): """Test that a custom serializer provided to FastMCP is used by tools that return lists.""" @@ -518,9 +529,15 @@ def get_data() -> list[dict]: result = await manager.call_tool("get_data", {}) assert ( - result[0].text # type: ignore[attr-defined] + result.content[0].text # type: ignore[attr-defined] == 'CUSTOM:[{"key": "value", "number": 123}, {"key": "value2", "number": 456}]' # type: ignore[attr-defined] ) + assert result.structured_content == { + "result": [ + {"key": "value", "number": 123}, + {"key": "value2", "number": 456}, + ] + } async def test_custom_serializer_fallback_on_error(self): """Test that a broken custom serializer gracefully falls back.""" @@ -538,7 +555,11 @@ def get_data() -> uuid.UUID: return uuid_result result = await manager.call_tool("get_data", {}) - assert result[0].text == pydantic_core.to_json(uuid_result).decode() # type: ignore[attr-defined] + assert ( + result.content[0].text # type: ignore[attr-defined] + == pydantic_core.to_json(uuid_result).decode() + ) + assert result.structured_content == {"result": str(uuid_result)} class TestToolSchema: @@ -608,7 +629,8 @@ def tool_with_context(x: int, ctx: Context) -> str: async with context: result = await manager.call_tool("tool_with_context", {"x": 42}) - assert result[0].text == "42" # type: ignore[attr-defined] + assert result.content[0].text == "42" # type: ignore[attr-defined] + assert result.structured_content == {"result": "42"} async def test_context_injection_async(self): """Test that context is properly injected in async tools.""" @@ -626,7 +648,8 @@ async def async_tool(x: int, ctx: Context) -> str: async with context: result = await manager.call_tool("async_tool", {"x": 42}) - assert result[0].text == "42" # type: ignore[attr-defined] + assert result.content[0].text == "42" # type: ignore[attr-defined] + assert result.structured_content == {"result": "42"} async def test_context_optional(self): """Test that context is optional when calling tools.""" @@ -644,7 +667,8 @@ def tool_with_context(x: int, ctx: Context | None) -> int: async with context: result = await manager.call_tool("tool_with_context", {"x": 42}) - assert result[0].text == "42" # type: ignore[attr-defined] + assert result.content[0].text == "42" # type: ignore[attr-defined] + assert result.structured_content == {"result": 42} def test_parameterized_context_parameter_detection(self): """Test that context parameters are properly detected in @@ -752,7 +776,8 @@ def multiply(a: int, b: int) -> int: # Tool should be callable by its custom name result = await manager.call_tool("custom_multiply", {"a": 5, "b": 3}) - assert result[0].text == "15" # type: ignore[attr-defined] + assert result.content[0].text == "15" # type: ignore[attr-defined] + assert result.structured_content == {"result": 15} # Original name should not be registered with pytest.raises(NotFoundError, match="Tool 'multiply' not found"): diff --git a/tests/tools/test_tool_transform.py b/tests/tools/test_tool_transform.py index 0498c6c5c..5e976b381 100644 --- a/tests/tools/test_tool_transform.py +++ b/tests/tools/test_tool_transform.py @@ -4,14 +4,15 @@ import 
pytest from dirty_equals import IsList -from pydantic import BaseModel, Field +from mcp.types import TextContent +from pydantic import BaseModel, Field, TypeAdapter from typing_extensions import TypedDict from fastmcp import FastMCP from fastmcp.client.client import Client from fastmcp.exceptions import ToolError from fastmcp.tools import Tool, forward, forward_raw -from fastmcp.tools.tool import FunctionTool +from fastmcp.tools.tool import FunctionTool, ToolResult from fastmcp.tools.tool_transform import ArgTransform, TransformedTool @@ -52,7 +53,8 @@ async def test_tool_defaults_are_maintained_on_unmapped_args(add_tool): add_tool, transform_args={"old_x": ArgTransform(name="new_x")} ) result = await new_tool.run(arguments={"new_x": 1}) - assert result[0].text == "11" # type: ignore[attr-defined] + # The parent tool returns int which gets wrapped as structured output + assert result.structured_content == {"result": 11} async def test_tool_defaults_are_maintained_on_mapped_args(add_tool): @@ -60,7 +62,8 @@ async def test_tool_defaults_are_maintained_on_mapped_args(add_tool): add_tool, transform_args={"old_y": ArgTransform(name="new_y")} ) result = await new_tool.run(arguments={"old_x": 1}) - assert result[0].text == "11" # type: ignore[attr-defined] + # The parent tool returns int which gets wrapped as structured output + assert result.structured_content == {"result": 11} def test_tool_change_arg_name(add_tool): @@ -87,7 +90,7 @@ async def test_tool_drop_arg(add_tool): ) assert sorted(new_tool.parameters["properties"]) == ["old_x"] result = await new_tool.run(arguments={"old_x": 1}) - assert result[0].text == "11" # type: ignore[attr-defined] + assert result.structured_content == {"result": 11} async def test_dropped_args_error_if_provided(add_tool): @@ -109,7 +112,7 @@ async def test_hidden_arg_with_constant_default(add_tool): assert sorted(new_tool.parameters["properties"]) == ["old_x"] # Should pass old_x=5 and old_y=20 to parent result = await new_tool.run(arguments={"old_x": 5}) - assert result[0].text == "25" # type: ignore[attr-defined] + assert result.structured_content == {"result": 25} async def test_hidden_arg_without_default_uses_parent_default(add_tool): @@ -121,13 +124,14 @@ async def test_hidden_arg_without_default_uses_parent_default(add_tool): assert sorted(new_tool.parameters["properties"]) == ["old_x"] # Should pass old_x=3 and let parent use its default old_y=10 result = await new_tool.run(arguments={"old_x": 3}) - assert result[0].text == "13" # type: ignore[attr-defined] + assert result.content[0].text == "13" # type: ignore[attr-defined] + assert result.structured_content == {"result": 13} async def test_mixed_hidden_args_with_custom_function(add_tool): """Test custom function with both hidden constant and hidden default parameters.""" - async def custom_fn(visible_x: int) -> int: + async def custom_fn(visible_x: int) -> ToolResult: # This custom function should receive the transformed visible parameter # and the hidden parameters should be automatically handled result = await forward(visible_x=visible_x) @@ -146,7 +150,8 @@ async def custom_fn(visible_x: int) -> int: assert sorted(new_tool.parameters["properties"]) == ["visible_x"] # Should pass visible_x=7 as old_x=7 and old_y=25 to parent result = await new_tool.run(arguments={"visible_x": 7}) - assert result[0].text == "32" # type: ignore[attr-defined] + assert result.content[0].text == "32" # type: ignore[attr-defined] + assert result.structured_content == {"result": 32} async def 
test_hide_required_param_without_default_raises_error(): @@ -184,13 +189,13 @@ def tool_with_required_param(required_param: int, optional_param: int = 10) -> i assert sorted(new_tool.parameters["properties"]) == ["optional_param"] # Should pass required_param=5 and optional_param=20 to parent result = await new_tool.run(arguments={"optional_param": 20}) - assert result[0].text == "25" # type: ignore[attr-defined] + assert result.structured_content == {"result": 25} async def test_forward_with_argument_mapping(add_tool): """Test that forward() applies argument mapping correctly.""" - async def custom_fn(new_x: int, new_y: int = 5) -> int: + async def custom_fn(new_x: int, new_y: int = 5) -> ToolResult: return await forward(new_x=new_x, new_y=new_y) new_tool = Tool.from_tool( @@ -203,11 +208,12 @@ async def custom_fn(new_x: int, new_y: int = 5) -> int: ) result = await new_tool.run(arguments={"new_x": 2, "new_y": 3}) - assert result[0].text == "5" # type: ignore[attr-defined] + assert result.content[0].text == "5" # type: ignore[attr-defined] + assert result.structured_content == {"result": 5} async def test_forward_with_incorrect_args_raises_error(add_tool): - async def custom_fn(new_x: int, new_y: int = 5) -> int: + async def custom_fn(new_x: int, new_y: int = 5) -> ToolResult: # the forward should use the new args, not the old ones return await forward(old_x=new_x, old_y=new_y) @@ -228,7 +234,7 @@ async def custom_fn(new_x: int, new_y: int = 5) -> int: async def test_forward_raw_without_argument_mapping(add_tool): """Test that forward_raw() calls parent directly without mapping.""" - async def custom_fn(new_x: int, new_y: int = 5) -> int: + async def custom_fn(new_x: int, new_y: int = 5) -> ToolResult: # Call parent directly with original argument names result = await forward_raw(old_x=new_x, old_y=new_y) return result @@ -243,17 +249,19 @@ async def custom_fn(new_x: int, new_y: int = 5) -> int: ) result = await new_tool.run(arguments={"new_x": 2, "new_y": 3}) - assert result[0].text == "5" # type: ignore[attr-defined] + assert result.content[0].text == "5" # type: ignore[attr-defined] + assert result.structured_content == {"result": 5} async def test_custom_fn_with_kwargs_and_no_transform_args(add_tool): async def custom_fn(extra: int, **kwargs) -> int: sum = await forward(**kwargs) - return int(sum[0].text) + extra # type: ignore[attr-defined] + return int(sum.content[0].text) + extra # type: ignore[attr-defined] new_tool = Tool.from_tool(add_tool, transform_fn=custom_fn) result = await new_tool.run(arguments={"extra": 1, "old_x": 2, "old_y": 3}) - assert result[0].text == "6" # type: ignore[attr-defined] + assert result.content[0].text == "6" # type: ignore[attr-defined] + assert result.structured_content == {"result": 6} assert new_tool.parameters["required"] == IsList( "extra", "old_x", check_order=False ) @@ -263,20 +271,21 @@ async def custom_fn(extra: int, **kwargs) -> int: async def test_fn_with_kwargs_passes_through_original_args(add_tool): - async def custom_fn(new_y: int = 5, **kwargs) -> int: + async def custom_fn(new_y: int = 5, **kwargs) -> ToolResult: assert kwargs == {"old_y": 3} result = await forward(old_x=new_y, **kwargs) return result new_tool = Tool.from_tool(add_tool, transform_fn=custom_fn) result = await new_tool.run(arguments={"new_y": 2, "old_y": 3}) - assert result[0].text == "5" # type: ignore[attr-defined] + assert result.content[0].text == "5" # type: ignore[attr-defined] + assert result.structured_content == {"result": 5} async def 
test_fn_with_kwargs_receives_transformed_arg_names(add_tool):
     """Test that **kwargs receives arguments with their transformed names from transform_args."""

-    async def custom_fn(new_x: int, **kwargs) -> int:
+    async def custom_fn(new_x: int, **kwargs) -> ToolResult:
         # kwargs should contain 'old_y': 3 (transformed name), not 'old_y': 3 (original name)
         assert kwargs == {"old_y": 3}
         result = await forward(new_x=new_x, **kwargs)
@@ -288,13 +297,16 @@ async def custom_fn(new_x: int, **kwargs) -> int:
         transform_args={"old_x": ArgTransform(name="new_x")},
     )
     result = await new_tool.run(arguments={"new_x": 2, "old_y": 3})
-    assert result[0].text == "5"  # type: ignore[attr-defined]
+    assert result.content[0].text == "5"  # type: ignore[attr-defined]
+    assert result.structured_content == {"result": 5}


 async def test_fn_with_kwargs_handles_partial_explicit_args(add_tool):
     """Test that function can explicitly handle some transformed args while others pass through kwargs."""

-    async def custom_fn(new_x: int, some_other_param: str = "default", **kwargs) -> int:
+    async def custom_fn(
+        new_x: int, some_other_param: str = "default", **kwargs
+    ) -> ToolResult:
         # x is explicitly handled, y should come through kwargs with transformed name
         assert kwargs == {"old_y": 7}
         result = await forward(new_x=new_x, **kwargs)
@@ -308,13 +320,14 @@ async def custom_fn(new_x: int, some_other_param: str = "default", **kwargs) ->
     result = await new_tool.run(
         arguments={"new_x": 3, "old_y": 7, "some_other_param": "test"}
     )
-    assert result[0].text == "10"  # type: ignore[attr-defined]
+    assert result.content[0].text == "10"  # type: ignore[attr-defined]
+    assert result.structured_content == {"result": 10}


 async def test_fn_with_kwargs_mixed_mapped_and_unmapped_args(add_tool):
     """Test **kwargs behavior with mix of mapped and unmapped arguments."""

-    async def custom_fn(new_x: int, **kwargs) -> int:
+    async def custom_fn(new_x: int, **kwargs) -> ToolResult:
         # new_x is explicitly handled, old_y should pass through kwargs with original name (unmapped)
         assert kwargs == {"old_y": 5}
         result = await forward(new_x=new_x, **kwargs)
@@ -326,13 +339,14 @@ async def custom_fn(new_x: int, **kwargs) -> int:
         transform_args={"old_x": ArgTransform(name="new_x")},
     )  # only map 'a'
     result = await new_tool.run(arguments={"new_x": 1, "old_y": 5})
-    assert result[0].text == "6"  # type: ignore[attr-defined]
+    assert result.content[0].text == "6"  # type: ignore[attr-defined]
+    assert result.structured_content == {"result": 6}


 async def test_fn_with_kwargs_dropped_args_not_in_kwargs(add_tool):
     """Test that dropped arguments don't appear in **kwargs."""

-    async def custom_fn(new_x: int, **kwargs) -> int:
+    async def custom_fn(new_x: int, **kwargs) -> ToolResult:
         # 'b' was dropped, so kwargs should be empty
         assert kwargs == {}
         # Can't use 'old_y' since it was dropped, so just use 'old_x' mapped to 'new_x'
@@ -349,7 +363,7 @@ async def custom_fn(new_x: int, **kwargs) -> int:
     )  # drop 'old_y'
     result = await new_tool.run(arguments={"new_x": 8})
     # 8 + 10 (default value of b in parent)
-    assert result[0].text == "18"  # type: ignore[attr-defined]
+    assert result.content[0].text == "18"  # type: ignore[attr-defined]


 async def test_forward_outside_context_raises_error():
@@ -469,18 +483,18 @@ async def test_tool_transform_chaining(add_tool):
     tool2 = Tool.from_tool(tool1, transform_args={"x": ArgTransform(name="final_x")})

     result = await tool2.run(arguments={"final_x": 5})
-    assert result[0].text == "15"  # type: ignore[attr-defined]
+    assert result.content[0].text
== "15" # type: ignore[attr-defined] # Transform tool1 with custom function that handles all parameters async def custom(final_x: int, **kwargs) -> str: result = await forward(final_x=final_x, **kwargs) - return f"custom {result[0].text}" # Extract text from content + return f"custom {result.content[0].text}" # Extract text from content # type: ignore[attr-defined] tool3 = Tool.from_tool( tool1, transform_fn=custom, transform_args={"x": ArgTransform(name="final_x")} ) result = await tool3.run(arguments={"final_x": 3, "old_y": 5}) - assert result[0].text == "custom 8" # type: ignore[attr-defined] + assert result.content[0].text == "custom 8" # type: ignore[attr-defined] class MyModel(BaseModel): @@ -608,7 +622,7 @@ def base(x: int, y: str = "base_default") -> str: # Function signature has different types/defaults than ArgTransform async def custom_fn(x: str = "function_default", **kwargs) -> str: result = await forward(x=x, **kwargs) - return f"custom: {result}" + return f"custom: {result.content[0].text}" # type: ignore[attr-defined] tool = Tool.from_tool( base, @@ -635,7 +649,7 @@ async def custom_fn(x: str = "function_default", **kwargs) -> str: # Test it works at runtime result = await tool.run(arguments={"y": "test"}) # Should use ArgTransform default of 42 - assert "42: test" in result[0].text # type: ignore[attr-defined] + assert "42: test" in result.content[0].text # type: ignore[attr-defined] def test_arg_transform_combined_attributes(): @@ -680,7 +694,7 @@ async def custom_fn(x: str, y: int = 10) -> str: # Convert string back to int for the original function result = await forward_raw(x=int(x), y=y) # Extract the text from the result - result_text = result[0].text + result_text = result.content[0].text # type: ignore[attr-defined] return f"String input '{x}' converted to result: {result_text}" tool = Tool.from_tool( @@ -692,8 +706,8 @@ async def custom_fn(x: str, y: int = 10) -> str: # Test it works with string input result = await tool.run(arguments={"x": "5", "y": 3}) - assert "String input '5'" in result[0].text # type: ignore[attr-defined] - assert "result: 8" in result[0].text # type: ignore[attr-defined] + assert "String input '5'" in result.content[0].text # type: ignore[attr-defined] + assert "result: 8" in result.content[0].text # type: ignore[attr-defined] class TestProxy: @@ -728,7 +742,7 @@ async def test_transform_proxy(self, proxy_server: FastMCP): async with Client(proxy_server) as client: # The tool should be registered with its transformed name result = await client.call_tool("add_transformed", {"new_x": 1, "old_y": 2}) - assert result[0].text == "3" # type: ignore[attr-defined] + assert result.content[0].text == "3" # type: ignore[attr-defined] async def test_arg_transform_default_factory(): @@ -751,7 +765,7 @@ def base_tool(x: int, timestamp: float) -> str: # Should work without providing timestamp (gets value from factory) result = await new_tool.run(arguments={"x": 42}) - assert result[0].text == "42_12345.0" # type: ignore[attr-defined] + assert result.content[0].text == "42_12345.0" # type: ignore[attr-defined] async def test_arg_transform_default_factory_called_each_time(): @@ -779,11 +793,11 @@ def base_tool(x: int, counter: int = 0) -> str: # First call result1 = await new_tool.run(arguments={"x": 1}) - assert result1[0].text == "1_1" # type: ignore[attr-defined] + assert result1.content[0].text == "1_1" # type: ignore[attr-defined] # Second call should get a different value result2 = await new_tool.run(arguments={"x": 2}) - assert result2[0].text == 
"2_2" # type: ignore[attr-defined] + assert result2.content[0].text == "2_2" # type: ignore[attr-defined] async def test_arg_transform_hidden_with_default_factory(): @@ -808,7 +822,7 @@ def make_request_id(): # Should pass hidden request_id with factory value result = await new_tool.run(arguments={"x": 42}) - assert result[0].text == "42_req_123" # type: ignore[attr-defined] + assert result.content[0].text == "42_req_123" # type: ignore[attr-defined] async def test_arg_transform_default_and_factory_raises_error(): @@ -845,7 +859,7 @@ def base_tool(optional_param: int = 42) -> str: # Should work when parameter is provided result = await new_tool.run(arguments={"optional_param": 100}) - assert result[0].text == "value: 100" # type: ignore + assert result.content[0].text == "value: 100" # type: ignore # Should fail when parameter is not provided with pytest.raises(TypeError, match="Missing required argument"): @@ -892,7 +906,7 @@ def base_tool(optional_param: int = 42) -> str: # Should work with new name result = await new_tool.run(arguments={"new_param": 200}) - assert result[0].text == "value: 200" # type: ignore + assert result.content[0].text == "value: 200" # type: ignore async def test_arg_transform_required_true_with_default_raises_error(): @@ -934,7 +948,7 @@ def base_tool(required_param: int, optional_param: int = 42) -> str: # Should work as expected result = await new_tool.run(arguments={"req": 1}) - assert result[0].text == "values: 1, 42" # type: ignore + assert result.content[0].text == "values: 1, 42" # type: ignore async def test_arg_transform_hide_and_required_raises_error(): @@ -966,7 +980,7 @@ def add(x: int, y: int = 10) -> int: assert {tool.name for tool in tools} == {"new_add"} result = await client.call_tool("new_add", {"x": 1, "y": 2}) - assert result[0].text == "3" # type: ignore[attr-defined] + assert result.content[0].text == "3" # type: ignore[attr-defined] with pytest.raises(ToolError): await client.call_tool("add", {"x": 1, "y": 2}) @@ -1019,3 +1033,266 @@ def test_arg_transform_examples_in_schema(add_tool): ) prop3 = get_property(new_tool3, "old_x") assert "examples" not in prop3 + + +class TestTransformToolOutputSchema: + """Test output schema handling in transformed tools.""" + + @pytest.fixture + def base_string_tool(self) -> FunctionTool: + """Tool that returns a string (gets wrapped).""" + + def string_tool(x: int) -> str: + return f"Result: {x}" + + return Tool.from_function(string_tool) + + @pytest.fixture + def base_dict_tool(self) -> FunctionTool: + """Tool that returns a dict (object type, not wrapped).""" + + def dict_tool(x: int) -> dict[str, int]: + return {"value": x} + + return Tool.from_function(dict_tool) + + def test_transform_inherits_parent_output_schema(self, base_string_tool): + """Test that transformed tool inherits parent's output schema by default.""" + new_tool = Tool.from_tool(base_string_tool) + + # Should inherit parent's wrapped string schema + expected_schema = { + "type": "object", + "properties": {"result": {"type": "string"}}, + "x-fastmcp-wrap-result": True, + } + assert new_tool.output_schema == expected_schema + assert new_tool.output_schema == base_string_tool.output_schema + + def test_transform_with_explicit_output_schema_false(self, base_string_tool): + """Test that output_schema=False disables structured output.""" + new_tool = Tool.from_tool(base_string_tool, output_schema=False) + + assert new_tool.output_schema is None + + async def test_transform_output_schema_false_runtime(self, base_string_tool): + """Test runtime 
behavior with output_schema=False.""" + new_tool = Tool.from_tool(base_string_tool, output_schema=False) + + # Debug: check that output_schema is actually None + assert new_tool.output_schema is None, ( + f"Expected None, got {new_tool.output_schema}" + ) + + result = await new_tool.run({"x": 5}) + assert result.structured_content is None + assert result.content[0].text == "Result: 5" # type: ignore[attr-defined] + + def test_transform_with_explicit_output_schema_dict(self, base_string_tool): + """Test that explicit output schema overrides parent.""" + custom_schema = { + "type": "object", + "properties": {"message": {"type": "string"}}, + } + new_tool = Tool.from_tool(base_string_tool, output_schema=custom_schema) + + assert new_tool.output_schema == custom_schema + assert new_tool.output_schema != base_string_tool.output_schema + + async def test_transform_explicit_schema_runtime(self, base_string_tool): + """Test runtime behavior with explicit output schema.""" + custom_schema = {"type": "string", "minLength": 1} + new_tool = Tool.from_tool(base_string_tool, output_schema=custom_schema) + + result = await new_tool.run({"x": 10}) + # Non-object explicit schemas disable structured content + assert result.structured_content is None + assert result.content[0].text == "Result: 10" # type: ignore[attr-defined] + + def test_transform_with_custom_function_inferred_schema(self, base_dict_tool): + """Test that custom function's output schema is inferred.""" + + async def custom_fn(x: int) -> str: + result = await forward(x=x) + return f"Custom: {result.content[0].text}" # type: ignore[attr-defined] + + new_tool = Tool.from_tool(base_dict_tool, transform_fn=custom_fn) + + # Should infer string schema from custom function and wrap it + expected_schema = { + "type": "object", + "properties": {"result": {"type": "string"}}, + "x-fastmcp-wrap-result": True, + } + assert new_tool.output_schema == expected_schema + + async def test_transform_custom_function_runtime(self, base_dict_tool): + """Test runtime behavior with custom function that has inferred schema.""" + + async def custom_fn(x: int) -> str: + result = await forward(x=x) + return f"Custom: {result.content[0].text}" # type: ignore[attr-defined] + + new_tool = Tool.from_tool(base_dict_tool, transform_fn=custom_fn) + + result = await new_tool.run({"x": 3}) + # Should wrap string result + assert result.structured_content == {"result": 'Custom: {\n "value": 3\n}'} + + def test_transform_custom_function_fallback_to_parent(self, base_string_tool): + """Test that custom function without output annotation falls back to parent.""" + + async def custom_fn(x: int): + # No return annotation - should fallback to parent schema + result = await forward(x=x) + return result + + new_tool = Tool.from_tool(base_string_tool, transform_fn=custom_fn) + + # Should use parent's schema since custom function has no annotation + assert new_tool.output_schema == base_string_tool.output_schema + + def test_transform_custom_function_explicit_overrides(self, base_string_tool): + """Test that explicit output_schema overrides both custom function and parent.""" + + async def custom_fn(x: int) -> dict[str, str]: + return {"custom": "value"} + + explicit_schema = {"type": "array", "items": {"type": "number"}} + new_tool = Tool.from_tool( + base_string_tool, transform_fn=custom_fn, output_schema=explicit_schema + ) + + # Explicit schema should win + assert new_tool.output_schema == explicit_schema + + async def test_transform_custom_function_object_return(self, 
+        """Test custom function returning object type."""
+
+        async def custom_fn(x: int) -> dict[str, int]:
+            await forward(x=x)
+            return {"original": x, "transformed": x * 2}
+
+        new_tool = Tool.from_tool(base_string_tool, transform_fn=custom_fn)
+
+        # Object types should not be wrapped
+        expected_schema = TypeAdapter(dict[str, int]).json_schema()
+        assert new_tool.output_schema == expected_schema
+        assert "x-fastmcp-wrap-result" not in new_tool.output_schema  # type: ignore[attr-defined]
+
+        result = await new_tool.run({"x": 4})
+        # Direct value, not wrapped
+        assert result.structured_content == {"original": 4, "transformed": 8}
+
+    async def test_transform_preserves_wrap_marker_behavior(self, base_string_tool):
+        """Test that wrap marker behavior is preserved through transformation."""
+        new_tool = Tool.from_tool(base_string_tool)
+
+        result = await new_tool.run({"x": 7})
+        # Should wrap because parent schema has wrap marker
+        assert result.structured_content == {"result": "Result: 7"}
+        assert "x-fastmcp-wrap-result" in new_tool.output_schema  # type: ignore[attr-defined]
+
+    def test_transform_chained_output_schema_inheritance(self, base_string_tool):
+        """Test output schema inheritance through multiple transformations."""
+        # First transformation keeps parent schema
+        tool1 = Tool.from_tool(base_string_tool)
+        assert tool1.output_schema == base_string_tool.output_schema
+
+        # Second transformation also inherits
+        tool2 = Tool.from_tool(tool1)
+        assert (
+            tool2.output_schema == tool1.output_schema == base_string_tool.output_schema
+        )
+
+        # Third transformation with explicit override
+        custom_schema = {"type": "number"}
+        tool3 = Tool.from_tool(tool2, output_schema=custom_schema)
+        assert tool3.output_schema == custom_schema
+        assert tool3.output_schema != tool2.output_schema
+
+    async def test_transform_mixed_structured_unstructured_content(
+        self, base_string_tool
+    ):
+        """Test transformation handling of mixed content types."""
+
+        async def custom_fn(x: int):
+            # Return mixed content including ToolResult
+            if x == 1:
+                return ["text", {"data": x}]
+            else:
+                # Return ToolResult directly
+                return ToolResult(
+                    content=[TextContent(type="text", text=f"Custom: {x}")],
+                    structured_content={"custom_value": x},
+                )
+
+        new_tool = Tool.from_tool(base_string_tool, transform_fn=custom_fn)
+
+        # Test mixed content return
+        result1 = await new_tool.run({"x": 1})
+        assert result1.structured_content == {"result": ["text", {"data": 1}]}
+
+        # Test ToolResult return
+        result2 = await new_tool.run({"x": 2})
+        assert result2.structured_content == {"custom_value": 2}
+        assert result2.content[0].text == "Custom: 2"  # type: ignore[attr-defined]
+
+    def test_transform_output_schema_with_arg_transforms(self, base_string_tool):
+        """Test that output schema works correctly with argument transformations."""
+
+        async def custom_fn(new_x: int) -> dict[str, str]:
+            result = await forward(new_x=new_x)
+            return {"transformed": result.content[0].text}  # type: ignore[attr-defined]
+
+        new_tool = Tool.from_tool(
+            base_string_tool,
+            transform_fn=custom_fn,
+            transform_args={"x": ArgTransform(name="new_x")},
+        )
+
+        # Should infer object schema from custom function
+        expected_schema = TypeAdapter(dict[str, str]).json_schema()
+        assert new_tool.output_schema == expected_schema
+
+    async def test_transform_output_schema_none_vs_false(self, base_string_tool):
+        """Test None vs False behavior for output_schema in transforms."""
+        # None (default) should use smart fallback (inherit from parent)
+        tool_none = Tool.from_tool(base_string_tool)  # default output_schema=None
+        assert tool_none.output_schema == base_string_tool.output_schema  # Inherits
+
+        # False should explicitly disable
+        tool_false = Tool.from_tool(base_string_tool, output_schema=False)
+        assert tool_false.output_schema is None
+
+        # Different behavior at runtime
+        result_none = await tool_none.run({"x": 5})
+        result_false = await tool_false.run({"x": 5})
+
+        assert result_none.structured_content == {
+            "result": "Result: 5"
+        }  # Inherits wrapping
+        assert result_false.structured_content is None  # Disabled
+        assert result_none.content[0].text == result_false.content[0].text  # type: ignore[attr-defined]
+
+    async def test_transform_output_schema_with_tool_result_return(
+        self, base_string_tool
+    ):
+        """Test transform when custom function returns ToolResult directly."""
+
+        async def custom_fn(x: int) -> ToolResult:
+            # Custom function returns ToolResult - should bypass schema handling
+            return ToolResult(
+                content=[TextContent(type="text", text=f"Direct: {x}")],
+                structured_content={"direct_value": x, "doubled": x * 2},
+            )
+
+        new_tool = Tool.from_tool(base_string_tool, transform_fn=custom_fn)
+
+        # ToolResult return type should result in None output schema
+        assert new_tool.output_schema is None
+
+        result = await new_tool.run({"x": 6})
+        # Should use ToolResult content directly
+        assert result.content[0].text == "Direct: 6"  # type: ignore[attr-defined]
+        assert result.structured_content == {"direct_value": 6, "doubled": 12}
diff --git a/tests/utilities/test_json_schema_type.py b/tests/utilities/test_json_schema_type.py
new file mode 100644
index 000000000..866facbce
--- /dev/null
+++ b/tests/utilities/test_json_schema_type.py
@@ -0,0 +1,1441 @@
+from datetime import datetime
+from typing import Any, Union
+
+import pytest
+from pydantic import AnyUrl, BaseModel, TypeAdapter, ValidationError
+
+from fastmcp.utilities.json_schema_type import (
+    _hash_schema,
+    _merge_defaults,
+    json_schema_to_type,
+)
+
+
+class TestSimpleTypes:
+    """Test suite for basic type validation."""
+
+    @pytest.fixture
+    def simple_string(self):
+        return json_schema_to_type({"type": "string"})
+
+    @pytest.fixture
+    def simple_number(self):
+        return json_schema_to_type({"type": "number"})
+
+    @pytest.fixture
+    def simple_integer(self):
+        return json_schema_to_type({"type": "integer"})
+
+    @pytest.fixture
+    def simple_boolean(self):
+        return json_schema_to_type({"type": "boolean"})
+
+    @pytest.fixture
+    def simple_null(self):
+        return json_schema_to_type({"type": "null"})
+
+    def test_string_accepts_string(self, simple_string):
+        validator = TypeAdapter(simple_string)
+        assert validator.validate_python("test") == "test"
+
+    def test_string_rejects_number(self, simple_string):
+        validator = TypeAdapter(simple_string)
+        with pytest.raises(ValidationError):
+            validator.validate_python(123)
+
+    def test_number_accepts_float(self, simple_number):
+        validator = TypeAdapter(simple_number)
+        assert validator.validate_python(123.45) == 123.45
+
+    def test_number_accepts_integer(self, simple_number):
+        validator = TypeAdapter(simple_number)
+        assert validator.validate_python(123) == 123
+
+    def test_number_accepts_numeric_string(self, simple_number):
+        validator = TypeAdapter(simple_number)
+        assert validator.validate_python("123.45") == 123.45
+        assert validator.validate_python("123") == 123
+
+    def test_number_rejects_invalid_string(self, simple_number):
+        validator = TypeAdapter(simple_number)
+        with pytest.raises(ValidationError):
+
validator.validate_python("not a number") + + def test_integer_accepts_integer(self, simple_integer): + validator = TypeAdapter(simple_integer) + assert validator.validate_python(123) == 123 + + def test_integer_accepts_integer_string(self, simple_integer): + validator = TypeAdapter(simple_integer) + assert validator.validate_python("123") == 123 + + def test_integer_rejects_float(self, simple_integer): + validator = TypeAdapter(simple_integer) + with pytest.raises(ValidationError): + validator.validate_python(123.45) + + def test_integer_rejects_float_string(self, simple_integer): + validator = TypeAdapter(simple_integer) + with pytest.raises(ValidationError): + validator.validate_python("123.45") + + def test_boolean_accepts_boolean(self, simple_boolean): + validator = TypeAdapter(simple_boolean) + assert validator.validate_python(True) is True + assert validator.validate_python(False) is False + + def test_boolean_accepts_boolean_strings(self, simple_boolean): + validator = TypeAdapter(simple_boolean) + assert validator.validate_python("true") is True + assert validator.validate_python("True") is True + assert validator.validate_python("false") is False + assert validator.validate_python("False") is False + + def test_boolean_rejects_invalid_string(self, simple_boolean): + validator = TypeAdapter(simple_boolean) + with pytest.raises(ValidationError): + validator.validate_python("not a boolean") + + def test_null_accepts_none(self, simple_null): + validator = TypeAdapter(simple_null) + assert validator.validate_python(None) is None + + def test_null_rejects_false(self, simple_null): + validator = TypeAdapter(simple_null) + with pytest.raises(ValidationError): + validator.validate_python(False) + + +class TestStringConstraints: + """Test suite for string constraint validation.""" + + @pytest.fixture + def min_length_string(self): + return json_schema_to_type({"type": "string", "minLength": 3}) + + @pytest.fixture + def max_length_string(self): + return json_schema_to_type({"type": "string", "maxLength": 5}) + + @pytest.fixture + def pattern_string(self): + return json_schema_to_type({"type": "string", "pattern": "^[A-Z][a-z]+$"}) + + @pytest.fixture + def email_string(self): + return json_schema_to_type({"type": "string", "format": "email"}) + + def test_min_length_accepts_valid(self, min_length_string): + validator = TypeAdapter(min_length_string) + assert validator.validate_python("test") == "test" + + def test_min_length_rejects_short(self, min_length_string): + validator = TypeAdapter(min_length_string) + with pytest.raises(ValidationError): + validator.validate_python("ab") + + def test_max_length_accepts_valid(self, max_length_string): + validator = TypeAdapter(max_length_string) + assert validator.validate_python("test") == "test" + + def test_max_length_rejects_long(self, max_length_string): + validator = TypeAdapter(max_length_string) + with pytest.raises(ValidationError): + validator.validate_python("toolong") + + def test_pattern_accepts_valid(self, pattern_string): + validator = TypeAdapter(pattern_string) + assert validator.validate_python("Hello") == "Hello" + + def test_pattern_rejects_invalid(self, pattern_string): + validator = TypeAdapter(pattern_string) + with pytest.raises(ValidationError): + validator.validate_python("hello") + + def test_email_accepts_valid(self, email_string): + validator = TypeAdapter(email_string) + result = validator.validate_python("test@example.com") + assert result == "test@example.com" + + def test_email_rejects_invalid(self, email_string): + 
validator = TypeAdapter(email_string) + with pytest.raises(ValidationError): + validator.validate_python("not-an-email") + + +class TestNumberConstraints: + """Test suite for numeric constraint validation.""" + + @pytest.fixture + def multiple_of_number(self): + return json_schema_to_type({"type": "number", "multipleOf": 0.5}) + + @pytest.fixture + def min_number(self): + return json_schema_to_type({"type": "number", "minimum": 0}) + + @pytest.fixture + def exclusive_min_number(self): + return json_schema_to_type({"type": "number", "exclusiveMinimum": 0}) + + @pytest.fixture + def max_number(self): + return json_schema_to_type({"type": "number", "maximum": 100}) + + @pytest.fixture + def exclusive_max_number(self): + return json_schema_to_type({"type": "number", "exclusiveMaximum": 100}) + + def test_multiple_of_accepts_valid(self, multiple_of_number): + validator = TypeAdapter(multiple_of_number) + assert validator.validate_python(2.5) == 2.5 + + def test_multiple_of_rejects_invalid(self, multiple_of_number): + validator = TypeAdapter(multiple_of_number) + with pytest.raises(ValidationError): + validator.validate_python(2.7) + + def test_minimum_accepts_equal(self, min_number): + validator = TypeAdapter(min_number) + assert validator.validate_python(0) == 0 + + def test_minimum_rejects_less(self, min_number): + validator = TypeAdapter(min_number) + with pytest.raises(ValidationError): + validator.validate_python(-1) + + def test_exclusive_minimum_rejects_equal(self, exclusive_min_number): + validator = TypeAdapter(exclusive_min_number) + with pytest.raises(ValidationError): + validator.validate_python(0) + + def test_maximum_accepts_equal(self, max_number): + validator = TypeAdapter(max_number) + assert validator.validate_python(100) == 100 + + def test_maximum_rejects_greater(self, max_number): + validator = TypeAdapter(max_number) + with pytest.raises(ValidationError): + validator.validate_python(101) + + def test_exclusive_maximum_rejects_equal(self, exclusive_max_number): + validator = TypeAdapter(exclusive_max_number) + with pytest.raises(ValidationError): + validator.validate_python(100) + + +class TestArrayTypes: + """Test suite for array validation.""" + + @pytest.fixture + def string_array(self): + return json_schema_to_type({"type": "array", "items": {"type": "string"}}) + + @pytest.fixture + def min_items_array(self): + return json_schema_to_type( + {"type": "array", "items": {"type": "string"}, "minItems": 2} + ) + + @pytest.fixture + def max_items_array(self): + return json_schema_to_type( + {"type": "array", "items": {"type": "string"}, "maxItems": 3} + ) + + @pytest.fixture + def unique_items_array(self): + return json_schema_to_type( + {"type": "array", "items": {"type": "string"}, "uniqueItems": True} + ) + + def test_array_accepts_valid_items(self, string_array): + validator = TypeAdapter(string_array) + assert validator.validate_python(["a", "b"]) == ["a", "b"] + + def test_array_rejects_invalid_items(self, string_array): + validator = TypeAdapter(string_array) + with pytest.raises(ValidationError): + validator.validate_python([1, "b"]) + + def test_min_items_accepts_valid(self, min_items_array): + validator = TypeAdapter(min_items_array) + assert validator.validate_python(["a", "b"]) == ["a", "b"] + + def test_min_items_rejects_too_few(self, min_items_array): + validator = TypeAdapter(min_items_array) + with pytest.raises(ValidationError): + validator.validate_python(["a"]) + + def test_max_items_accepts_valid(self, max_items_array): + validator = 
TypeAdapter(max_items_array) + assert validator.validate_python(["a", "b", "c"]) == ["a", "b", "c"] + + def test_max_items_rejects_too_many(self, max_items_array): + validator = TypeAdapter(max_items_array) + with pytest.raises(ValidationError): + validator.validate_python(["a", "b", "c", "d"]) + + def test_unique_items_accepts_unique(self, unique_items_array): + validator = TypeAdapter(unique_items_array) + assert isinstance(validator.validate_python(["a", "b"]), set) + + def test_unique_items_converts_duplicates(self, unique_items_array): + validator = TypeAdapter(unique_items_array) + result = validator.validate_python(["a", "a", "b"]) + assert result == {"a", "b"} + + +class TestObjectTypes: + """Test suite for object validation.""" + + @pytest.fixture + def simple_object(self): + return json_schema_to_type( + { + "type": "object", + "properties": {"name": {"type": "string"}, "age": {"type": "integer"}}, + } + ) + + @pytest.fixture + def required_object(self): + return json_schema_to_type( + { + "type": "object", + "properties": {"name": {"type": "string"}, "age": {"type": "integer"}}, + "required": ["name"], + } + ) + + @pytest.fixture + def nested_object(self): + return json_schema_to_type( + { + "type": "object", + "properties": { + "user": { + "type": "object", + "properties": { + "name": {"type": "string"}, + "age": {"type": "integer"}, + }, + "required": ["name"], + } + }, + } + ) + + @pytest.mark.parametrize( + "input_type, expected_type", + [ + # Plain dict becomes dict[str, Any] (JSON Schema accurate) + (dict, dict[str, Any]), + # dict[str, Any] stays the same + (dict[str, Any], dict[str, Any]), + # Simple typed dicts work correctly + (dict[str, str], dict[str, str]), + (dict[str, int], dict[str, int]), + # Union value types work + (dict[str, str | int], dict[str, str | int]), + # Key types are constrained to str in JSON Schema + (dict[int, list[str]], dict[str, list[str]]), + # Union key types become str (JSON Schema limitation) + (dict[str | int, str | None], dict[str, str | None]), + ], + ) + def test_dict_types_are_generated_correctly(self, input_type, expected_type): + schema = TypeAdapter(input_type).json_schema() + generated_type = json_schema_to_type(schema) + assert generated_type == expected_type + + def test_object_accepts_valid(self, simple_object): + validator = TypeAdapter(simple_object) + result = validator.validate_python({"name": "test", "age": 30}) + assert result.name == "test" + assert result.age == 30 + + def test_object_accepts_extra_properties(self, simple_object): + validator = TypeAdapter(simple_object) + result = validator.validate_python( + {"name": "test", "age": 30, "extra": "field"} + ) + assert result.name == "test" + assert result.age == 30 + assert not hasattr(result, "extra") + + def test_required_accepts_valid(self, required_object): + validator = TypeAdapter(required_object) + result = validator.validate_python({"name": "test"}) + assert result.name == "test" + assert result.age is None + + def test_required_rejects_missing(self, required_object): + validator = TypeAdapter(required_object) + with pytest.raises(ValidationError): + validator.validate_python({}) + + def test_nested_accepts_valid(self, nested_object): + validator = TypeAdapter(nested_object) + result = validator.validate_python({"user": {"name": "test", "age": 30}}) + assert result.user.name == "test" + assert result.user.age == 30 + + def test_nested_rejects_invalid(self, nested_object): + validator = TypeAdapter(nested_object) + with pytest.raises(ValidationError): + 
validator.validate_python({"user": {"age": 30}}) + + +class TestDefaultValues: + """Test suite for default value handling.""" + + @pytest.fixture + def simple_defaults(self): + return json_schema_to_type( + { + "type": "object", + "properties": { + "name": {"type": "string", "default": "anonymous"}, + "age": {"type": "integer", "default": 0}, + }, + } + ) + + @pytest.fixture + def nested_defaults(self): + return json_schema_to_type( + { + "type": "object", + "properties": { + "user": { + "type": "object", + "properties": { + "name": {"type": "string", "default": "anonymous"}, + "settings": { + "type": "object", + "properties": { + "theme": {"type": "string", "default": "light"} + }, + "default": {"theme": "dark"}, + }, + }, + "default": {"name": "guest", "settings": {"theme": "system"}}, + } + }, + } + ) + + def test_simple_defaults_empty_object(self, simple_defaults): + validator = TypeAdapter(simple_defaults) + result = validator.validate_python({}) + assert result.name == "anonymous" + assert result.age == 0 + + def test_simple_defaults_partial_override(self, simple_defaults): + validator = TypeAdapter(simple_defaults) + result = validator.validate_python({"name": "test"}) + assert result.name == "test" + assert result.age == 0 + + def test_nested_defaults_empty_object(self, nested_defaults): + validator = TypeAdapter(nested_defaults) + result = validator.validate_python({}) + assert result.user.name == "guest" + assert result.user.settings.theme == "system" + + def test_nested_defaults_partial_override(self, nested_defaults): + validator = TypeAdapter(nested_defaults) + result = validator.validate_python({"user": {"name": "test"}}) + assert result.user.name == "test" + assert result.user.settings.theme == "system" + + +class TestUnionTypes: + """Test suite for testing union type behaviors.""" + + @pytest.fixture + def heterogeneous_union(self): + return json_schema_to_type({"type": ["string", "number", "boolean", "null"]}) + + @pytest.fixture + def union_with_constraints(self): + return json_schema_to_type( + {"type": ["string", "number"], "minLength": 3, "minimum": 0} + ) + + @pytest.fixture + def union_with_formats(self): + return json_schema_to_type({"type": ["string", "null"], "format": "email"}) + + @pytest.fixture + def nested_union_array(self): + return json_schema_to_type( + {"type": "array", "items": {"type": ["string", "number"]}} + ) + + @pytest.fixture + def nested_union_object(self): + return json_schema_to_type( + { + "type": "object", + "properties": { + "id": {"type": ["string", "integer"]}, + "data": { + "type": ["object", "null"], + "properties": {"value": {"type": "string"}}, + }, + }, + } + ) + + def test_heterogeneous_accepts_string(self, heterogeneous_union): + validator = TypeAdapter(heterogeneous_union) + assert validator.validate_python("test") == "test" + + def test_heterogeneous_accepts_number(self, heterogeneous_union): + validator = TypeAdapter(heterogeneous_union) + assert validator.validate_python(123.45) == 123.45 + + def test_heterogeneous_accepts_boolean(self, heterogeneous_union): + validator = TypeAdapter(heterogeneous_union) + assert validator.validate_python(True) is True + + def test_heterogeneous_accepts_null(self, heterogeneous_union): + validator = TypeAdapter(heterogeneous_union) + assert validator.validate_python(None) is None + + def test_heterogeneous_rejects_array(self, heterogeneous_union): + validator = TypeAdapter(heterogeneous_union) + with pytest.raises(ValidationError): + validator.validate_python([]) + + def 
test_constrained_string_valid(self, union_with_constraints): + validator = TypeAdapter(union_with_constraints) + assert validator.validate_python("test") == "test" + + def test_constrained_string_invalid(self, union_with_constraints): + validator = TypeAdapter(union_with_constraints) + with pytest.raises(ValidationError): + validator.validate_python("ab") + + def test_constrained_number_valid(self, union_with_constraints): + validator = TypeAdapter(union_with_constraints) + assert validator.validate_python(10) == 10 + + def test_constrained_number_invalid(self, union_with_constraints): + validator = TypeAdapter(union_with_constraints) + with pytest.raises(ValidationError): + validator.validate_python(-1) + + def test_format_valid_email(self, union_with_formats): + validator = TypeAdapter(union_with_formats) + result = validator.validate_python("test@example.com") + assert isinstance(result, str) + + def test_format_valid_null(self, union_with_formats): + validator = TypeAdapter(union_with_formats) + assert validator.validate_python(None) is None + + def test_format_invalid_email(self, union_with_formats): + validator = TypeAdapter(union_with_formats) + with pytest.raises(ValidationError): + validator.validate_python("not-an-email") + + def test_nested_array_mixed_types(self, nested_union_array): + validator = TypeAdapter(nested_union_array) + result = validator.validate_python(["test", 123, "abc"]) + assert result == ["test", 123, "abc"] + + def test_nested_array_rejects_invalid(self, nested_union_array): + validator = TypeAdapter(nested_union_array) + with pytest.raises(ValidationError): + validator.validate_python(["test", ["not", "allowed"], "abc"]) + + def test_nested_object_string_id(self, nested_union_object): + validator = TypeAdapter(nested_union_object) + result = validator.validate_python({"id": "abc123", "data": {"value": "test"}}) + assert result.id == "abc123" + assert result.data.value == "test" + + def test_nested_object_integer_id(self, nested_union_object): + validator = TypeAdapter(nested_union_object) + result = validator.validate_python({"id": 123, "data": None}) + assert result.id == 123 + assert result.data is None + + +class TestFormatTypes: + """Test suite for format type validation.""" + + @pytest.fixture + def datetime_format(self): + return json_schema_to_type({"type": "string", "format": "date-time"}) + + @pytest.fixture + def email_format(self): + return json_schema_to_type({"type": "string", "format": "email"}) + + @pytest.fixture + def uri_format(self): + return json_schema_to_type({"type": "string", "format": "uri"}) + + @pytest.fixture + def uri_reference_format(self): + return json_schema_to_type({"type": "string", "format": "uri-reference"}) + + @pytest.fixture + def json_format(self): + return json_schema_to_type({"type": "string", "format": "json"}) + + @pytest.fixture + def mixed_formats_object(self): + return json_schema_to_type( + { + "type": "object", + "properties": { + "full_uri": {"type": "string", "format": "uri"}, + "ref_uri": {"type": "string", "format": "uri-reference"}, + }, + } + ) + + def test_datetime_valid(self, datetime_format): + validator = TypeAdapter(datetime_format) + result = validator.validate_python("2024-01-17T12:34:56Z") + assert isinstance(result, datetime) + + def test_datetime_invalid(self, datetime_format): + validator = TypeAdapter(datetime_format) + with pytest.raises(ValidationError): + validator.validate_python("not-a-date") + + def test_email_valid(self, email_format): + validator = TypeAdapter(email_format) + result 
= validator.validate_python("test@example.com") + assert isinstance(result, str) + + def test_email_invalid(self, email_format): + validator = TypeAdapter(email_format) + with pytest.raises(ValidationError): + validator.validate_python("not-an-email") + + def test_uri_valid(self, uri_format): + validator = TypeAdapter(uri_format) + result = validator.validate_python("https://example.com") + assert isinstance(result, AnyUrl) + + def test_uri_invalid(self, uri_format): + validator = TypeAdapter(uri_format) + with pytest.raises(ValidationError): + validator.validate_python("not-a-uri") + + def test_uri_reference_valid(self, uri_reference_format): + validator = TypeAdapter(uri_reference_format) + result = validator.validate_python("https://example.com") + assert isinstance(result, str) + + def test_uri_reference_relative_valid(self, uri_reference_format): + validator = TypeAdapter(uri_reference_format) + result = validator.validate_python("/path/to/resource") + assert isinstance(result, str) + + def test_uri_reference_invalid(self, uri_reference_format): + validator = TypeAdapter(uri_reference_format) + result = validator.validate_python("not a uri") + assert isinstance(result, str) + + def test_json_valid(self, json_format): + validator = TypeAdapter(json_format) + result = validator.validate_python('{"key": "value"}') + assert isinstance(result, dict) + + def test_json_invalid(self, json_format): + validator = TypeAdapter(json_format) + with pytest.raises(ValidationError): + validator.validate_python("{invalid json}") + + def test_mixed_formats_object(self, mixed_formats_object): + validator = TypeAdapter(mixed_formats_object) + result = validator.validate_python( + {"full_uri": "https://example.com", "ref_uri": "/path/to/resource"} + ) + assert isinstance(result.full_uri, AnyUrl) + assert isinstance(result.ref_uri, str) + + +class TestCircularReferences: + """Test suite for circular reference handling.""" + + @pytest.fixture + def self_referential(self): + return json_schema_to_type( + { + "type": "object", + "properties": {"name": {"type": "string"}, "child": {"$ref": "#"}}, + } + ) + + @pytest.fixture + def mutually_recursive(self): + return json_schema_to_type( + { + "type": "object", + "definitions": { + "Person": { + "type": "object", + "properties": { + "name": {"type": "string"}, + "friend": {"$ref": "#/definitions/Pet"}, + }, + }, + "Pet": { + "type": "object", + "properties": { + "name": {"type": "string"}, + "owner": {"$ref": "#/definitions/Person"}, + }, + }, + }, + "properties": {"person": {"$ref": "#/definitions/Person"}}, + } + ) + + def test_self_ref_single_level(self, self_referential): + validator = TypeAdapter(self_referential) + result = validator.validate_python( + {"name": "parent", "child": {"name": "child"}} + ) + assert result.name == "parent" + assert result.child.name == "child" + assert result.child.child is None + + def test_self_ref_multiple_levels(self, self_referential): + validator = TypeAdapter(self_referential) + result = validator.validate_python( + { + "name": "grandparent", + "child": {"name": "parent", "child": {"name": "child"}}, + } + ) + assert result.name == "grandparent" + assert result.child.name == "parent" + assert result.child.child.name == "child" + + def test_mutual_recursion_single_level(self, mutually_recursive): + validator = TypeAdapter(mutually_recursive) + result = validator.validate_python( + {"person": {"name": "Alice", "friend": {"name": "Spot"}}} + ) + assert result.person.name == "Alice" + assert result.person.friend.name == "Spot" 
+ assert result.person.friend.owner is None + + def test_mutual_recursion_multiple_levels(self, mutually_recursive): + validator = TypeAdapter(mutually_recursive) + result = validator.validate_python( + { + "person": { + "name": "Alice", + "friend": {"name": "Spot", "owner": {"name": "Bob"}}, + } + } + ) + assert result.person.name == "Alice" + assert result.person.friend.name == "Spot" + assert result.person.friend.owner.name == "Bob" + + +class TestIdentifierNormalization: + """Test suite for handling non-standard property names.""" + + @pytest.fixture + def special_chars(self): + return json_schema_to_type( + { + "type": "object", + "properties": { + "@type": {"type": "string"}, + "first-name": {"type": "string"}, + "last.name": {"type": "string"}, + "2nd_address": {"type": "string"}, + "$ref": {"type": "string"}, + }, + } + ) + + def test_normalizes_special_chars(self, special_chars): + validator = TypeAdapter(special_chars) + result = validator.validate_python( + { + "@type": "person", + "first-name": "Alice", + "last.name": "Smith", + "2nd_address": "456 Oak St", + "$ref": "12345", + } + ) + assert result.field_type == "person" # @type -> field_type + assert result.first_name == "Alice" # first-name -> first_name + assert result.last_name == "Smith" # last.name -> last_name + assert ( + result.field_2nd_address == "456 Oak St" + ) # 2nd_address -> field_2nd_address + assert result.field_ref == "12345" # $ref -> field_ref + + +class TestConstantValues: + """Test suite for constant value validation.""" + + @pytest.fixture + def string_const(self): + return json_schema_to_type({"type": "string", "const": "production"}) + + @pytest.fixture + def number_const(self): + return json_schema_to_type({"type": "number", "const": 42.5}) + + @pytest.fixture + def boolean_const(self): + return json_schema_to_type({"type": "boolean", "const": True}) + + @pytest.fixture + def null_const(self): + return json_schema_to_type({"type": "null", "const": None}) + + @pytest.fixture + def object_with_consts(self): + return json_schema_to_type( + { + "type": "object", + "properties": { + "env": {"const": "production"}, + "version": {"const": 1}, + "enabled": {"const": True}, + }, + } + ) + + def test_string_const_valid(self, string_const): + validator = TypeAdapter(string_const) + assert validator.validate_python("production") == "production" + + def test_string_const_invalid(self, string_const): + validator = TypeAdapter(string_const) + with pytest.raises(ValidationError): + validator.validate_python("development") + + def test_number_const_valid(self, number_const): + validator = TypeAdapter(number_const) + assert validator.validate_python(42.5) == 42.5 + + def test_number_const_invalid(self, number_const): + validator = TypeAdapter(number_const) + with pytest.raises(ValidationError): + validator.validate_python(42) + + def test_boolean_const_valid(self, boolean_const): + validator = TypeAdapter(boolean_const) + assert validator.validate_python(True) is True + + def test_boolean_const_invalid(self, boolean_const): + validator = TypeAdapter(boolean_const) + with pytest.raises(ValidationError): + validator.validate_python(False) + + def test_null_const_valid(self, null_const): + validator = TypeAdapter(null_const) + assert validator.validate_python(None) is None + + def test_null_const_invalid(self, null_const): + validator = TypeAdapter(null_const) + with pytest.raises(ValidationError): + validator.validate_python(False) + + def test_object_consts_valid(self, object_with_consts): + validator = 
TypeAdapter(object_with_consts) + result = validator.validate_python( + {"env": "production", "version": 1, "enabled": True} + ) + assert result.env == "production" + assert result.version == 1 + assert result.enabled is True + + def test_object_consts_invalid(self, object_with_consts): + validator = TypeAdapter(object_with_consts) + with pytest.raises(ValidationError): + validator.validate_python( + { + "env": "production", + "version": 2, # Wrong constant + "enabled": True, + } + ) + + +class TestSchemaCaching: + """Test suite for schema caching behavior.""" + + def test_identical_schemas_reuse_class(self): + schema = {"type": "object", "properties": {"name": {"type": "string"}}} + + class1 = json_schema_to_type(schema) + class2 = json_schema_to_type(schema) + assert class1 is class2 + + def test_different_names_different_classes(self): + schema = {"type": "object", "properties": {"name": {"type": "string"}}} + + class1 = json_schema_to_type(schema, name="Class1") + class2 = json_schema_to_type(schema, name="Class2") + assert class1 is not class2 + assert class1.__name__ == "Class1" + assert class2.__name__ == "Class2" + + def test_nested_schema_caching(self): + schema = { + "type": "object", + "properties": { + "nested": {"type": "object", "properties": {"name": {"type": "string"}}} + }, + } + + class1 = json_schema_to_type(schema) + class2 = json_schema_to_type(schema) + + # Both main classes and their nested classes should be identical + assert class1 is class2 + assert ( + class1.__dataclass_fields__["nested"].type + is class2.__dataclass_fields__["nested"].type + ) + + +class TestSchemaHashing: + """Test suite for schema hashing utility.""" + + def test_deterministic_hash(self): + schema = {"type": "object", "properties": {"name": {"type": "string"}}} + hash1 = _hash_schema(schema) + hash2 = _hash_schema(schema) + assert hash1 == hash2 + assert isinstance(hash1, str) + assert len(hash1) == 64 # SHA-256 hash length + + def test_different_schemas_different_hashes(self): + schema1 = {"type": "object", "properties": {"name": {"type": "string"}}} + schema2 = {"type": "object", "properties": {"age": {"type": "integer"}}} + assert _hash_schema(schema1) != _hash_schema(schema2) + + def test_order_independent_hash(self): + schema1 = {"properties": {"name": {"type": "string"}}, "type": "object"} + schema2 = {"type": "object", "properties": {"name": {"type": "string"}}} + assert _hash_schema(schema1) == _hash_schema(schema2) + + def test_nested_schema_hash(self): + schema = { + "type": "object", + "properties": { + "nested": {"type": "object", "properties": {"name": {"type": "string"}}} + }, + } + hash1 = _hash_schema(schema) + assert isinstance(hash1, str) + assert len(hash1) == 64 + + +class TestDefaultMerging: + """Test suite for default value merging behavior.""" + + def test_simple_merge(self): + defaults = {"name": "anonymous", "age": 0} + data = {"name": "test"} + result = _merge_defaults(data, {"properties": {}}, defaults) + assert result["name"] == "test" + assert result["age"] == 0 + + def test_nested_merge(self): + defaults = {"user": {"name": "anonymous", "settings": {"theme": "light"}}} + data = {"user": {"name": "test"}} + result = _merge_defaults(data, {"properties": {}}, defaults) + assert result["user"]["name"] == "test" + assert result["user"]["settings"]["theme"] == "light" + + def test_array_merge(self): + defaults = { + "items": [ + {"name": "item1", "done": False}, + {"name": "item2", "done": False}, + ] + } + data = {"items": [{"name": "custom", "done": True}]} + result 
= _merge_defaults(data, {"properties": {}}, defaults) + assert len(result["items"]) == 1 + assert result["items"][0]["name"] == "custom" + assert result["items"][0]["done"] is True + + def test_empty_data_uses_defaults(self): + schema = { + "properties": { + "user": { + "type": "object", + "properties": { + "name": {"type": "string", "default": "anonymous"}, + "settings": {"type": "object", "default": {"theme": "light"}}, + }, + "default": {"name": "guest", "settings": {"theme": "dark"}}, + } + } + } + result = _merge_defaults({}, schema) + assert result["user"]["name"] == "guest" + assert result["user"]["settings"]["theme"] == "dark" + + def test_property_level_defaults(self): + schema = { + "properties": { + "name": {"type": "string", "default": "anonymous"}, + "age": {"type": "integer", "default": 0}, + } + } + result = _merge_defaults({}, schema) + assert result["name"] == "anonymous" + assert result["age"] == 0 + + def test_nested_property_defaults(self): + schema = { + "properties": { + "user": { + "type": "object", + "properties": { + "name": {"type": "string", "default": "anonymous"}, + "settings": { + "type": "object", + "properties": { + "theme": {"type": "string", "default": "light"} + }, + }, + }, + } + } + } + result = _merge_defaults({"user": {"settings": {}}}, schema) + assert result["user"]["name"] == "anonymous" + assert result["user"]["settings"]["theme"] == "light" + + def test_default_priority(self): + schema = { + "properties": { + "settings": { + "type": "object", + "properties": {"theme": {"type": "string", "default": "light"}}, + "default": {"theme": "dark"}, + } + }, + "default": {"settings": {"theme": "system"}}, + } + + # Test priority: data > parent default > object default > property default + result1 = _merge_defaults({}, schema) # Uses schema default + assert result1["settings"]["theme"] == "system" + + result2 = _merge_defaults({"settings": {}}, schema) # Uses object default + assert result2["settings"]["theme"] == "dark" + + result3 = _merge_defaults( + {"settings": {"theme": "custom"}}, schema + ) # Uses provided data + assert result3["settings"]["theme"] == "custom" + + +class TestEdgeCases: + """Test suite for edge cases and corner scenarios.""" + + def test_empty_schema(self): + schema = {} + result = json_schema_to_type(schema) + assert result is object + + def test_schema_without_type(self): + schema = {"properties": {"name": {"type": "string"}}} + Type = json_schema_to_type(schema) + validator = TypeAdapter(Type) + result = validator.validate_python({"name": "test"}) + assert result.name == "test" + + def test_recursive_defaults(self): + schema = { + "type": "object", + "properties": { + "node": { + "type": "object", + "properties": {"value": {"type": "string"}, "next": {"$ref": "#"}}, + "default": {"value": "default", "next": None}, + } + }, + } + Type = json_schema_to_type(schema) + validator = TypeAdapter(Type) + result = validator.validate_python({}) + assert result.node.value == "default" + assert result.node.next is None + + def test_mixed_type_array(self): + schema = { + "type": "array", + "items": [{"type": "string"}, {"type": "number"}, {"type": "boolean"}], + } + Type = json_schema_to_type(schema) + validator = TypeAdapter(Type) + result = validator.validate_python(["test", 123, True]) + assert result == ["test", 123, True] + + +class TestNameHandling: + """Test suite for schema name handling.""" + + def test_name_from_title(self): + schema = { + "type": "object", + "title": "Person", + "properties": {"name": {"type": "string"}}, + } + Type 
+        assert Type.__name__ == "Person"
+
+    def test_explicit_name_overrides_title(self):
+        schema = {
+            "type": "object",
+            "title": "Person",
+            "properties": {"name": {"type": "string"}},
+        }
+        Type = json_schema_to_type(schema, name="CustomPerson")
+        assert Type.__name__ == "CustomPerson"
+
+    def test_default_name_without_title(self):
+        schema = {"type": "object", "properties": {"name": {"type": "string"}}}
+        Type = json_schema_to_type(schema)
+        assert Type.__name__ == "Root"
+
+    def test_name_only_allowed_for_objects(self):
+        schema = {"type": "string"}
+        with pytest.raises(ValueError, match="Can not apply name to non-object schema"):
+            json_schema_to_type(schema, name="StringType")
+
+    def test_nested_object_names(self):
+        schema = {
+            "type": "object",
+            "title": "Parent",
+            "properties": {
+                "child": {
+                    "type": "object",
+                    "title": "Child",
+                    "properties": {"name": {"type": "string"}},
+                }
+            },
+        }
+        Type = json_schema_to_type(schema)
+        assert Type.__name__ == "Parent"
+        assert Type.__dataclass_fields__["child"].type.__origin__ is Union
+        assert Type.__dataclass_fields__["child"].type.__args__[0].__name__ == "Child"
+        assert Type.__dataclass_fields__["child"].type.__args__[1] is type(None)
+
+    def test_recursive_schema_naming(self):
+        schema = {
+            "type": "object",
+            "title": "Node",
+            "properties": {"next": {"$ref": "#"}},
+        }
+        Type = json_schema_to_type(schema)
+        assert Type.__name__ == "Node"
+        assert Type.__dataclass_fields__["next"].type.__origin__ is Union
+        assert (
+            Type.__dataclass_fields__["next"].type.__args__[0].__forward_arg__ == "Node"
+        )
+        assert Type.__dataclass_fields__["next"].type.__args__[1] is type(None)
+
+    def test_name_caching_with_different_titles(self):
+        """Ensure schemas with different titles create different cached classes"""
+        schema1 = {
+            "type": "object",
+            "title": "Type1",
+            "properties": {"name": {"type": "string"}},
+        }
+        schema2 = {
+            "type": "object",
+            "title": "Type2",
+            "properties": {"name": {"type": "string"}},
+        }
+        Type1 = json_schema_to_type(schema1)
+        Type2 = json_schema_to_type(schema2)
+        assert Type1 is not Type2
+        assert Type1.__name__ == "Type1"
+        assert Type2.__name__ == "Type2"
+
+    def test_recursive_schema_with_invalid_python_name(self):
+        """Test that recursive schemas work with titles that aren't valid Python identifiers"""
+        schema = {
+            "type": "object",
+            "title": "My Complex Type!",
+            "properties": {"name": {"type": "string"}, "child": {"$ref": "#"}},
+        }
+        Type = json_schema_to_type(schema)
+        # The class should get a sanitized name
+        assert Type.__name__ == "My_Complex_Type"
+        # Create an instance to verify the recursive reference works
+        validator = TypeAdapter(Type)
+        result = validator.validate_python(
+            {"name": "parent", "child": {"name": "child", "child": None}}
+        )
+        assert result.name == "parent"
+        assert result.child.name == "child"
+        assert result.child.child is None
+
+
+class TestAdditionalProperties:
+    """Test suite for additionalProperties handling."""
+
+    @pytest.fixture
+    def dict_only_schema(self):
+        """Schema with no properties but additionalProperties=True -> dict[str, Any]"""
+        return json_schema_to_type({"type": "object", "additionalProperties": True})
+
+    @pytest.fixture
+    def properties_with_additional(self):
+        """Schema with properties AND additionalProperties=True -> BaseModel"""
+        return json_schema_to_type(
+            {
+                "type": "object",
+                "properties": {
+                    "name": {"type": "string"},
+                    "age": {"type": "integer"},
+                },
+                "additionalProperties": True,
+            }
+        )
+
+    @pytest.fixture
+    def properties_without_additional(self):
+        """Schema with properties but no additionalProperties -> dataclass"""
+        return json_schema_to_type(
+            {
+                "type": "object",
+                "properties": {
+                    "name": {"type": "string"},
+                    "age": {"type": "integer"},
+                },
+            }
+        )
+
+    @pytest.fixture
+    def required_properties_with_additional(self):
+        """Schema with required properties AND additionalProperties=True -> BaseModel"""
+        return json_schema_to_type(
+            {
+                "type": "object",
+                "properties": {
+                    "name": {"type": "string"},
+                    "age": {"type": "integer"},
+                },
+                "required": ["name"],
+                "additionalProperties": True,
+            }
+        )
+
+    def test_dict_only_returns_dict_type(self, dict_only_schema):
+        """Test that schema with no properties + additionalProperties=True returns dict[str, Any]"""
+        import typing
+
+        assert dict_only_schema == dict[str, typing.Any]
+
+    def test_dict_only_accepts_any_data(self, dict_only_schema):
+        """Test that pure dict accepts arbitrary key-value pairs"""
+        validator = TypeAdapter(dict_only_schema)
+        data = {"anything": "works", "numbers": 123, "nested": {"key": "value"}}
+        result = validator.validate_python(data)
+        assert result == data
+        assert isinstance(result, dict)
+
+    def test_properties_with_additional_returns_basemodel(
+        self, properties_with_additional
+    ):
+        """Test that schema with properties + additionalProperties=True returns BaseModel"""
+        assert issubclass(properties_with_additional, BaseModel)
+
+    def test_properties_with_additional_accepts_extra_fields(
+        self, properties_with_additional
+    ):
+        """Test that BaseModel with extra='allow' accepts additional properties"""
+        validator = TypeAdapter(properties_with_additional)
+        data = {
+            "name": "Alice",
+            "age": 30,
+            "extra": "field",
+            "another": {"nested": "data"},
+        }
+        result = validator.validate_python(data)
+
+        # Check standard properties
+        assert result.name == "Alice"
+        assert result.age == 30
+
+        # Check extra properties are preserved with dot access
+        assert hasattr(result, "extra")
+        assert result.extra == "field"
+        assert hasattr(result, "another")
+        assert result.another == {"nested": "data"}
+
+    def test_properties_with_additional_validates_known_fields(
+        self, properties_with_additional
+    ):
+        """Test that BaseModel still validates known fields"""
+        validator = TypeAdapter(properties_with_additional)
+
+        # Should accept valid data
+        result = validator.validate_python({"name": "Alice", "age": 30, "extra": "ok"})
+        assert result.name == "Alice"
+        assert result.age == 30
+        assert result.extra == "ok"
+
+        # Should reject invalid types for known fields
+        with pytest.raises(ValidationError):
+            validator.validate_python({"name": "Alice", "age": "not_a_number"})
+
+    def test_properties_without_additional_is_dataclass(
+        self, properties_without_additional
+    ):
+        """Test that schema with properties but no additionalProperties returns dataclass"""
+        assert not issubclass(properties_without_additional, BaseModel)
+        assert hasattr(properties_without_additional, "__dataclass_fields__")
+
+    def test_properties_without_additional_ignores_extra_fields(
+        self, properties_without_additional
+    ):
+        """Test that dataclass ignores extra properties (current behavior)"""
+        validator = TypeAdapter(properties_without_additional)
+        data = {"name": "Alice", "age": 30, "extra": "ignored"}
+        result = validator.validate_python(data)
+
+        # Check standard properties
+        assert result.name == "Alice"
+        assert result.age == 30
+
+        # Check extra property is ignored
+        assert not hasattr(result, "extra")
+
+    def test_required_properties_with_additional(
+        self, required_properties_with_additional
+    ):
+        """Test BaseModel with required fields and additional properties"""
+        validator = TypeAdapter(required_properties_with_additional)
+
+        # Should accept valid data with required field
+        result = validator.validate_python({"name": "Alice", "extra": "field"})
+        assert result.name == "Alice"
+        assert result.age is None  # Optional field
+        assert result.extra == "field"
+
+        # Should reject missing required field
+        with pytest.raises(ValidationError):
+            validator.validate_python({"age": 30, "extra": "field"})
+
+    def test_nested_additional_properties(self):
+        """Test nested objects with additionalProperties"""
+        schema = {
+            "type": "object",
+            "properties": {
+                "user": {
+                    "type": "object",
+                    "properties": {"name": {"type": "string"}},
+                    "additionalProperties": True,
+                },
+                "settings": {
+                    "type": "object",
+                    "properties": {"theme": {"type": "string"}},
+                },
+            },
+            "additionalProperties": True,
+        }
+
+        Type = json_schema_to_type(schema)
+        validator = TypeAdapter(Type)
+
+        data = {
+            "user": {"name": "Alice", "extra_user_field": "value"},
+            "settings": {"theme": "dark", "extra_settings_field": "ignored"},
+            "top_level_extra": "preserved",
+        }
+
+        result = validator.validate_python(data)
+
+        # Check top-level extra field (BaseModel)
+        assert result.top_level_extra == "preserved"
+
+        # Check nested user extra field (BaseModel)
+        assert result.user.name == "Alice"
+        assert result.user.extra_user_field == "value"
+
+        # Check nested settings - should be dataclass
+        assert result.settings.theme == "dark"
+        # Note: When nested in BaseModel with extra='allow', Pydantic may preserve extra fields
+        # even on dataclass children. The important thing is that settings is still a dataclass.
+        assert not issubclass(type(result.settings), BaseModel)
+
+    def test_additional_properties_false_vs_missing(self):
+        """Test difference between additionalProperties: false and missing additionalProperties"""
+        # Schema with explicit additionalProperties: false
+        schema_false = {
+            "type": "object",
+            "properties": {"name": {"type": "string"}},
+            "additionalProperties": False,
+        }
+
+        # Schema with no additionalProperties key
+        schema_missing = {
+            "type": "object",
+            "properties": {"name": {"type": "string"}},
+        }
+
+        Type_false = json_schema_to_type(schema_false)
+        Type_missing = json_schema_to_type(schema_missing)
+
+        # Both should create dataclasses (not BaseModel)
+        assert not issubclass(Type_false, BaseModel)
+        assert not issubclass(Type_missing, BaseModel)
+        assert hasattr(Type_false, "__dataclass_fields__")
+        assert hasattr(Type_missing, "__dataclass_fields__")
+
+    def test_additional_properties_with_defaults(self):
+        """Test additionalProperties with default values"""
+        schema = {
+            "type": "object",
+            "properties": {
+                "name": {"type": "string", "default": "anonymous"},
+                "age": {"type": "integer", "default": 0},
+            },
+            "additionalProperties": True,
+        }
+
+        Type = json_schema_to_type(schema)
+        validator = TypeAdapter(Type)
+
+        # Test with extra fields and defaults
+        result = validator.validate_python({"extra": "field"})
+        assert result.name == "anonymous"
+        assert result.age == 0
+        assert result.extra == "field"
+
+    def test_additional_properties_type_consistency(self):
+        """Test that the same schema always returns the same type"""
+        schema = {
+            "type": "object",
+            "properties": {"name": {"type": "string"}},
+            "additionalProperties": True,
+        }
+
+        Type1 = json_schema_to_type(schema)
+        Type2 = json_schema_to_type(schema)
+
+        # Should be the same cached class
+        assert Type1 is Type2
+        assert issubclass(Type1, BaseModel)
diff --git a/tests/utilities/test_mcp_config.py b/tests/utilities/test_mcp_config.py
index 7775e12bc..f6136e793 100644
--- a/tests/utilities/test_mcp_config.py
+++ b/tests/utilities/test_mcp_config.py
@@ -136,8 +136,8 @@ def add(a: int, b: int) -> int:
         result_1 = await client.call_tool("test_1_add", {"a": 1, "b": 2})
         result_2 = await client.call_tool("test_2_add", {"a": 1, "b": 2})
 
-        assert result_1[0].text == "3"  # type: ignore[attr-dict]
-        assert result_2[0].text == "3"  # type: ignore[attr-dict]
+        assert result_1.data == 3
+        assert result_2.data == 3
 
 
 async def test_remote_config_default_no_auth():
diff --git a/tests/utilities/test_types.py b/tests/utilities/test_types.py
index 1f8338318..4ac9109c0 100644
--- a/tests/utilities/test_types.py
+++ b/tests/utilities/test_types.py
@@ -12,6 +12,7 @@
     find_kwarg_by_type,
     is_class_member_of_type,
     issubclass_safe,
+    replace_type,
 )
 
 
@@ -536,3 +537,29 @@ def func(a: int, b, c: str):
         pass
 
     assert find_kwarg_by_type(func, str) == "c"
+
+
+class TestReplaceType:
+    @pytest.mark.parametrize(
+        "input,type_map,expected",
+        [
+            (int, {}, int),
+            (int, {int: str}, str),
+            (int, {int: int}, int),
+            (int, {int: float, bool: str}, float),
+            (bool, {int: float, bool: str}, str),
+            (int, {int: list[int]}, list[int]),
+            (list[int], {int: str}, list[str]),
+            (list[int], {int: list[str]}, list[list[str]]),
+            (
+                list[int],
+                {int: float, list[int]: bool},
+                bool,
+            ),  # list[int] will match before int
+            (list[int | bool], {int: str}, list[str | bool]),
+            (list[list[int]], {int: str}, list[list[str]]),
+        ],
+    )
+    def test_replace_type(self, input, type_map, expected):
+        """Test replacing a type with another type."""
+        assert replace_type(input, type_map) == expected
diff --git a/uv.lock b/uv.lock
index 87a47f5ee..ebfad50b7 100644
--- a/uv.lock
+++ b/uv.lock
@@ -378,6 +378,28 @@ wheels = [
     { url = "https://files.pythonhosted.org/packages/91/a1/cf2472db20f7ce4a6be1253a81cfdf85ad9c7885ffbed7047fb72c24cf87/distlib-0.3.9-py2.py3-none-any.whl", hash = "sha256:47f8c22fd27c27e25a65601af709b38e4f0a45ea4fc2e710f65755fa8caaaf87", size = 468973, upload-time = "2024-10-09T18:35:44.272Z" },
 ]
 
+[[package]]
+name = "dnspython"
+version = "2.7.0"
+source = { registry = "https://pypi.org/simple" }
+sdist = { url = "https://files.pythonhosted.org/packages/b5/4a/263763cb2ba3816dd94b08ad3a33d5fdae34ecb856678773cc40a3605829/dnspython-2.7.0.tar.gz", hash = "sha256:ce9c432eda0dc91cf618a5cedf1a4e142651196bbcd2c80e89ed5a907e5cfaf1", size = 345197, upload-time = "2024-10-05T20:14:59.362Z" }
+wheels = [
+    { url = "https://files.pythonhosted.org/packages/68/1b/e0a87d256e40e8c888847551b20a017a6b98139178505dc7ffb96f04e954/dnspython-2.7.0-py3-none-any.whl", hash = "sha256:b4c34b7d10b51bcc3a5071e7b8dee77939f1e878477eeecc965e9835f63c6c86", size = 313632, upload-time = "2024-10-05T20:14:57.687Z" },
+]
+
+[[package]]
+name = "email-validator"
+version = "2.2.0"
+source = { registry = "https://pypi.org/simple" }
+dependencies = [
+    { name = "dnspython" },
+    { name = "idna" },
+]
+sdist = { url = "https://files.pythonhosted.org/packages/48/ce/13508a1ec3f8bb981ae4ca79ea40384becc868bfae97fd1c942bb3a001b1/email_validator-2.2.0.tar.gz", hash = "sha256:cb690f344c617a714f22e66ae771445a1ceb46821152df8e165c5f9a364582b7", size = 48967, upload-time = "2024-06-20T11:30:30.034Z" }
+wheels = [
+    { url = "https://files.pythonhosted.org/packages/d7/ee/bf0adb559ad3c786f12bcbc9296b3f5675f529199bef03e2df281fa1fadb/email_validator-2.2.0-py3-none-any.whl", hash = "sha256:561977c2d73ce3611850a06fa56b414621e0c8faa9d66f2611407d87465da631", size = 33521, upload-time = "2024-06-20T11:30:28.248Z" },
+]
+
 [[package]]
 name = "exceptiongroup"
 version = "1.3.0"
@@ -444,6 +466,7 @@ dependencies = [
     { name = "httpx" },
     { name = "mcp" },
     { name = "openapi-pydantic" },
+    { name = "pydantic", extra = ["email"] },
     { name = "python-dotenv" },
     { name = "rich" },
     { name = "typer" },
@@ -484,6 +507,7 @@ requires-dist = [
     { name = "httpx", specifier = ">=0.28.1" },
     { name = "mcp", specifier = ">=1.10.0" },
     { name = "openapi-pydantic", specifier = ">=0.5.1" },
+    { name = "pydantic", extras = ["email"], specifier = ">=2.11.7" },
    { name = "python-dotenv", specifier = ">=1.1.0" },
     { name = "rich", specifier = ">=13.9.4" },
     { name = "typer", specifier = ">=0.15.2" },
@@ -935,6 +959,11 @@ wheels = [
     { url = "https://files.pythonhosted.org/packages/6a/c0/ec2b1c8712ca690e5d61979dee872603e92b8a32f94cc1b72d53beab008a/pydantic-2.11.7-py3-none-any.whl", hash = "sha256:dde5df002701f6de26248661f6835bbe296a47bf73990135c7d07ce741b9623b", size = 444782, upload-time = "2025-06-14T08:33:14.905Z" },
 ]
 
+[package.optional-dependencies]
+email = [
+    { name = "email-validator" },
+]
+
 [[package]]
 name = "pydantic-core"
 version = "2.33.2"