|
1 | 1 | from __future__ import annotations as _annotations
|
2 | 2 |
|
3 | 3 | from collections.abc import Callable
|
4 |
| -from typing import TYPE_CHECKING, Any |
| 4 | +from typing import TYPE_CHECKING, Any, overload |
| 5 | + |
| 6 | +from pydantic import AnyUrl |
5 | 7 |
|
6 | 8 | from mcp.server.fastmcp.exceptions import ToolError
|
7 | 9 | from mcp.server.fastmcp.tools.base import Tool
|
@@ -39,15 +41,28 @@ def _normalize_to_uri(self, name_or_uri: str) -> str:
|
39 | 41 | """Convert name to URI if needed."""
|
40 | 42 | return normalize_to_tool_uri(name_or_uri)
|
41 | 43 |
|
42 |
| - def get_tool(self, name: str) -> Tool | None: |
| 44 | + @overload |
| 45 | + def get_tool(self, name_or_uri: str) -> Tool | None: |
| 46 | + """Get tool by name.""" |
| 47 | + ... |
| 48 | + |
| 49 | + @overload |
| 50 | + def get_tool(self, name_or_uri: AnyUrl) -> Tool | None: |
| 51 | + """Get tool by URI.""" |
| 52 | + ... |
| 53 | + |
| 54 | + def get_tool(self, name_or_uri: AnyUrl | str) -> Tool | None: |
43 | 55 | """Get tool by name or URI."""
|
44 |
| - uri = self._normalize_to_uri(name) |
| 56 | + if isinstance(name_or_uri, AnyUrl): |
| 57 | + return self._tools.get(str(name_or_uri)) |
| 58 | + uri = self._normalize_to_uri(name_or_uri) |
45 | 59 | return self._tools.get(uri)
|
46 | 60 |
|
47 |
| - def list_tools(self, uri_paths: list[str] | None = None) -> list[Tool]: |
| 61 | + def list_tools(self, uri_paths: list[AnyUrl] | None = None) -> list[Tool]: |
48 | 62 | """List all registered tools, optionally filtered by URI paths."""
|
49 | 63 | tools = list(self._tools.values())
|
50 |
| - tools = filter_by_uri_paths(tools, uri_paths, lambda t: t.uri) |
| 64 | + if uri_paths: |
| 65 | + tools = filter_by_uri_paths(tools, uri_paths) |
51 | 66 | logger.debug("Listing tools", extra={"count": len(tools), "uri_paths": uri_paths})
|
52 | 67 | return tools
|
53 | 68 |
|
@@ -77,16 +92,38 @@ def add_tool(
|
77 | 92 | self._tools[str(tool.uri)] = tool
|
78 | 93 | return tool
|
79 | 94 |
|
| 95 | + @overload |
80 | 96 | async def call_tool(
|
81 | 97 | self,
|
82 |
| - name: str, |
| 98 | + name_or_uri: str, |
83 | 99 | arguments: dict[str, Any],
|
84 | 100 | context: Context[ServerSessionT, LifespanContextT, RequestT] | None = None,
|
85 | 101 | convert_result: bool = False,
|
86 | 102 | ) -> Any:
|
87 | 103 | """Call a tool by name with arguments."""
|
88 |
| - tool = self.get_tool(name) |
| 104 | + ... |
| 105 | + |
| 106 | + @overload |
| 107 | + async def call_tool( |
| 108 | + self, |
| 109 | + name_or_uri: AnyUrl, |
| 110 | + arguments: dict[str, Any], |
| 111 | + context: Context[ServerSessionT, LifespanContextT, RequestT] | None = None, |
| 112 | + convert_result: bool = False, |
| 113 | + ) -> Any: |
| 114 | + """Call a tool by URI with arguments.""" |
| 115 | + ... |
| 116 | + |
| 117 | + async def call_tool( |
| 118 | + self, |
| 119 | + name_or_uri: AnyUrl | str, |
| 120 | + arguments: dict[str, Any], |
| 121 | + context: Context[ServerSessionT, LifespanContextT, RequestT] | None = None, |
| 122 | + convert_result: bool = False, |
| 123 | + ) -> Any: |
| 124 | + """Call a tool by name or URI with arguments.""" |
| 125 | + tool = self.get_tool(name_or_uri) |
89 | 126 | if not tool:
|
90 |
| - raise ToolError(f"Unknown tool: {name}") |
| 127 | + raise ToolError(f"Unknown tool: {name_or_uri}") |
91 | 128 |
|
92 | 129 | return await tool.run(arguments, context=context, convert_result=convert_result)
|
0 commit comments