Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
116 changes: 100 additions & 16 deletions crates/rmcp-macros/src/tool.rs
Original file line number Diff line number Diff line change
Expand Up @@ -177,30 +177,114 @@ pub(crate) fn tool(attr: TokenStream, input: TokenStream) -> syn::Result<TokenSt
pub(crate) fn tool_impl_item(attr: TokenStream, mut input: ItemImpl) -> syn::Result<TokenStream> {
let tool_impl_attr: ToolImplItemAttrs = syn::parse2(attr)?;
let tool_box_ident = tool_impl_attr.tool_box;

// get all tool function ident
let mut tool_fn_idents = Vec::new();
for item in &input.items {
if let syn::ImplItem::Fn(method) = item {
for attr in &method.attrs {
if attr.path().is_ident(TOOL_IDENT) {
tool_fn_idents.push(method.sig.ident.clone());
}
}
}
}

// handle different cases
if input.trait_.is_some() {
if let Some(ident) = tool_box_ident {
input.items.push(parse_quote!(
rmcp::tool_box!(@derive #ident);
));
// check if there are generic parameters
if !input.generics.params.is_empty() {
// for trait implementation with generic parameters, directly use the already generated *_inner method

// generate call_tool method
input.items.push(parse_quote! {
async fn call_tool(
&self,
request: rmcp::model::CallToolRequestParam,
context: rmcp::service::RequestContext<rmcp::RoleServer>,
) -> Result<rmcp::model::CallToolResult, rmcp::Error> {
self.call_tool_inner(request, context).await
}
});

// generate list_tools method
input.items.push(parse_quote! {
async fn list_tools(
&self,
request: rmcp::model::PaginatedRequestParam,
context: rmcp::service::RequestContext<rmcp::RoleServer>,
) -> Result<rmcp::model::ListToolsResult, rmcp::Error> {
self.list_tools_inner(request, context).await
}
});
} else {
// if there are no generic parameters, add tool box derive
input.items.push(parse_quote!(
rmcp::tool_box!(@derive #ident);
));
}
}
} else if let Some(ident) = tool_box_ident {
let mut tool_fn_idents = Vec::new();
for item in &input.items {
if let syn::ImplItem::Fn(method) = item {
for attr in &method.attrs {
if attr.path().is_ident(TOOL_IDENT) {
tool_fn_idents.push(method.sig.ident.clone());
// if it is a normal impl block
if !input.generics.params.is_empty() {
// if there are generic parameters, not use tool_box! macro, but generate code directly

// create call code for each tool function
let match_arms = tool_fn_idents.iter().map(|ident| {
let attr_fn = Ident::new(&format!("{}_tool_attr", ident), ident.span());
let call_fn = Ident::new(&format!("{}_tool_call", ident), ident.span());
quote! {
name if name == Self::#attr_fn().name => {
Self::#call_fn(tcc).await
}
}
}
});

let tool_attrs = tool_fn_idents.iter().map(|ident| {
let attr_fn = Ident::new(&format!("{}_tool_attr", ident), ident.span());
quote! { Self::#attr_fn() }
});

// implement call_tool method
input.items.push(parse_quote! {
async fn call_tool_inner(
&self,
request: rmcp::model::CallToolRequestParam,
context: rmcp::service::RequestContext<rmcp::RoleServer>,
) -> Result<rmcp::model::CallToolResult, rmcp::Error> {
let tcc = rmcp::handler::server::tool::ToolCallContext::new(self, request, context);
match tcc.name() {
#(#match_arms,)*
_ => Err(rmcp::Error::invalid_params("tool not found", None)),
}
}
});

// implement list_tools method
input.items.push(parse_quote! {
async fn list_tools_inner(
&self,
_: rmcp::model::PaginatedRequestParam,
_: rmcp::service::RequestContext<rmcp::RoleServer>,
) -> Result<rmcp::model::ListToolsResult, rmcp::Error> {
Ok(rmcp::model::ListToolsResult {
next_cursor: None,
tools: vec![#(#tool_attrs),*],
})
}
});
} else {
// if there are no generic parameters, use the original tool_box! macro
let this_type_ident = &input.self_ty;
input.items.push(parse_quote!(
rmcp::tool_box!(#this_type_ident {
#(#tool_fn_idents),*
} #ident);
));
}
let this_type_ident = &input.self_ty;
input.items.push(parse_quote!(
rmcp::tool_box!(#this_type_ident {
#(#tool_fn_idents),*
} #ident);
));
}

Ok(quote! {
#input
})
Expand Down
47 changes: 47 additions & 0 deletions crates/rmcp/tests/test_tool_macros.rs
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
use std::sync::Arc;

use rmcp::{ServerHandler, handler::server::tool::ToolCallContext, tool};
use schemars::JsonSchema;
use serde::{Deserialize, Serialize};
Expand All @@ -21,6 +23,7 @@ impl ServerHandler for Server {
}
}
}

#[derive(Debug, Clone, Default)]
pub struct Server {}

Expand All @@ -35,6 +38,40 @@ impl Server {
async fn empty_param(&self) {}
}

// define generic service trait
pub trait DataService: Send + Sync + 'static {
fn get_data(&self) -> String;
}

// mock service for test
#[derive(Clone)]
struct MockDataService;
impl DataService for MockDataService {
fn get_data(&self) -> String {
"mock data".to_string()
}
}

// define generic server
#[derive(Debug, Clone)]
pub struct GenericServer<DS: DataService> {
data_service: Arc<DS>,
}

#[tool(tool_box)]
impl<DS: DataService> GenericServer<DS> {
pub fn new(data_service: DS) -> Self {
Self {
data_service: Arc::new(data_service),
}
}

#[tool(description = "Get data from the service")]
async fn get_data(&self) -> String {
self.data_service.get_data()
}
}

#[tokio::test]
async fn test_tool_macros() {
let server = Server::default();
Expand All @@ -52,4 +89,14 @@ async fn test_tool_macros_with_empty_param() {
assert!(_attr.input_schema.get("properties").is_none());
}

#[tokio::test]
async fn test_tool_macros_with_generics() {
let mock_service = MockDataService;
let server = GenericServer::new(mock_service);
let _attr = GenericServer::<MockDataService>::get_data_tool_attr();
let _get_data_call_fn = GenericServer::<MockDataService>::get_data_tool_call;
let _get_data_fn = GenericServer::<MockDataService>::get_data;
assert_eq!(server.get_data().await, "mock data");
}

impl GetWeatherRequest {}
6 changes: 5 additions & 1 deletion examples/servers/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -39,4 +39,8 @@ path = "src/axum.rs"

[[example]]
name = "servers_axum_router"
path = "src/axum_router.rs"
path = "src/axum_router.rs"

[[example]]
name = "servers_generic_server"
path = "src/generic_service.rs"
1 change: 1 addition & 0 deletions examples/servers/src/common/counter.rs
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ pub struct Counter {
}
#[tool(tool_box)]
impl Counter {
#[allow(dead_code)]
pub fn new() -> Self {
Self {
counter: Arc::new(Mutex::new(0)),
Expand Down
73 changes: 73 additions & 0 deletions examples/servers/src/common/generic_service.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,73 @@
use std::sync::Arc;

use rmcp::{
ServerHandler,
model::{ServerCapabilities, ServerInfo},
schemars, tool,
};

#[allow(dead_code)]
pub trait DataService: Send + Sync + 'static {
fn get_data(&self) -> String;
fn set_data(&mut self, data: String);
}

#[derive(Debug, Clone)]
pub struct MemoryDataService {
data: String,
}

impl MemoryDataService {
#[allow(dead_code)]
pub fn new(initial_data: impl Into<String>) -> Self {
Self {
data: initial_data.into(),
}
}
}

impl DataService for MemoryDataService {
fn get_data(&self) -> String {
self.data.clone()
}

fn set_data(&mut self, data: String) {
self.data = data;
}
}

#[derive(Debug, Clone)]
pub struct GenericService<DS: DataService> {
#[allow(dead_code)]
data_service: Arc<DS>,
}

#[tool(tool_box)]
impl<DS: DataService> GenericService<DS> {
pub fn new(data_service: DS) -> Self {
Self {
data_service: Arc::new(data_service),
}
}

#[tool(description = "get memory from service")]
pub async fn get_data(&self) -> String {
self.data_service.get_data()
}

#[tool(description = "set memory to service")]
pub async fn set_data(&self, #[tool(param)] data: String) -> String {
let new_data = data.clone();
format!("Current memory: {}", new_data)
}
}

impl<DS: DataService> ServerHandler for GenericService<DS> {
fn get_info(&self) -> ServerInfo {
ServerInfo {
instructions: Some("generic data service".into()),
capabilities: ServerCapabilities::builder().enable_tools().build(),
..Default::default()
}
}
}
1 change: 1 addition & 0 deletions examples/servers/src/common/mod.rs
Original file line number Diff line number Diff line change
@@ -1,2 +1,3 @@
pub mod calculator;
pub mod counter;
pub mod generic_service;
18 changes: 18 additions & 0 deletions examples/servers/src/generic_service.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
use std::error::Error;
mod common;
use common::generic_service::{GenericService, MemoryDataService};
use rmcp::serve_server;

#[tokio::main]
async fn main() -> Result<(), Box<dyn Error>> {
let memory_service = MemoryDataService::new("initial data");

let generic_service = GenericService::new(memory_service);

println!("start server, connect to standard input/output");

let io = (tokio::io::stdin(), tokio::io::stdout());

serve_server(generic_service, io).await?;
Ok(())
}
Loading