From 6789c24d3e8a51dbf0003497225eb454859cc25a Mon Sep 17 00:00:00 2001 From: jokemanfire Date: Sat, 5 Apr 2025 10:45:47 +0800 Subject: [PATCH] fix(macro): add generics marco types support add generic marco support Signed-off-by: jokemanfire --- crates/rmcp-macros/src/tool.rs | 116 +++++++++++++++--- crates/rmcp/tests/test_tool_macros.rs | 47 +++++++ examples/servers/Cargo.toml | 6 +- examples/servers/src/common/counter.rs | 1 + .../servers/src/common/generic_service.rs | 73 +++++++++++ examples/servers/src/common/mod.rs | 1 + examples/servers/src/generic_service.rs | 18 +++ 7 files changed, 245 insertions(+), 17 deletions(-) create mode 100644 examples/servers/src/common/generic_service.rs create mode 100644 examples/servers/src/generic_service.rs diff --git a/crates/rmcp-macros/src/tool.rs b/crates/rmcp-macros/src/tool.rs index a59e4b3f..85946574 100644 --- a/crates/rmcp-macros/src/tool.rs +++ b/crates/rmcp-macros/src/tool.rs @@ -177,30 +177,114 @@ pub(crate) fn tool(attr: TokenStream, input: TokenStream) -> syn::Result syn::Result { let tool_impl_attr: ToolImplItemAttrs = syn::parse2(attr)?; let tool_box_ident = tool_impl_attr.tool_box; + + // get all tool function ident + let mut tool_fn_idents = Vec::new(); + for item in &input.items { + if let syn::ImplItem::Fn(method) = item { + for attr in &method.attrs { + if attr.path().is_ident(TOOL_IDENT) { + tool_fn_idents.push(method.sig.ident.clone()); + } + } + } + } + + // handle different cases if input.trait_.is_some() { if let Some(ident) = tool_box_ident { - input.items.push(parse_quote!( - rmcp::tool_box!(@derive #ident); - )); + // check if there are generic parameters + if !input.generics.params.is_empty() { + // for trait implementation with generic parameters, directly use the already generated *_inner method + + // generate call_tool method + input.items.push(parse_quote! { + async fn call_tool( + &self, + request: rmcp::model::CallToolRequestParam, + context: rmcp::service::RequestContext, + ) -> Result { + self.call_tool_inner(request, context).await + } + }); + + // generate list_tools method + input.items.push(parse_quote! { + async fn list_tools( + &self, + request: rmcp::model::PaginatedRequestParam, + context: rmcp::service::RequestContext, + ) -> Result { + self.list_tools_inner(request, context).await + } + }); + } else { + // if there are no generic parameters, add tool box derive + input.items.push(parse_quote!( + rmcp::tool_box!(@derive #ident); + )); + } } } else if let Some(ident) = tool_box_ident { - let mut tool_fn_idents = Vec::new(); - for item in &input.items { - if let syn::ImplItem::Fn(method) = item { - for attr in &method.attrs { - if attr.path().is_ident(TOOL_IDENT) { - tool_fn_idents.push(method.sig.ident.clone()); + // if it is a normal impl block + if !input.generics.params.is_empty() { + // if there are generic parameters, not use tool_box! macro, but generate code directly + + // create call code for each tool function + let match_arms = tool_fn_idents.iter().map(|ident| { + let attr_fn = Ident::new(&format!("{}_tool_attr", ident), ident.span()); + let call_fn = Ident::new(&format!("{}_tool_call", ident), ident.span()); + quote! { + name if name == Self::#attr_fn().name => { + Self::#call_fn(tcc).await } } - } + }); + + let tool_attrs = tool_fn_idents.iter().map(|ident| { + let attr_fn = Ident::new(&format!("{}_tool_attr", ident), ident.span()); + quote! { Self::#attr_fn() } + }); + + // implement call_tool method + input.items.push(parse_quote! { + async fn call_tool_inner( + &self, + request: rmcp::model::CallToolRequestParam, + context: rmcp::service::RequestContext, + ) -> Result { + let tcc = rmcp::handler::server::tool::ToolCallContext::new(self, request, context); + match tcc.name() { + #(#match_arms,)* + _ => Err(rmcp::Error::invalid_params("tool not found", None)), + } + } + }); + + // implement list_tools method + input.items.push(parse_quote! { + async fn list_tools_inner( + &self, + _: rmcp::model::PaginatedRequestParam, + _: rmcp::service::RequestContext, + ) -> Result { + Ok(rmcp::model::ListToolsResult { + next_cursor: None, + tools: vec![#(#tool_attrs),*], + }) + } + }); + } else { + // if there are no generic parameters, use the original tool_box! macro + let this_type_ident = &input.self_ty; + input.items.push(parse_quote!( + rmcp::tool_box!(#this_type_ident { + #(#tool_fn_idents),* + } #ident); + )); } - let this_type_ident = &input.self_ty; - input.items.push(parse_quote!( - rmcp::tool_box!(#this_type_ident { - #(#tool_fn_idents),* - } #ident); - )); } + Ok(quote! { #input }) diff --git a/crates/rmcp/tests/test_tool_macros.rs b/crates/rmcp/tests/test_tool_macros.rs index 8bb15e0f..daa5ee3d 100644 --- a/crates/rmcp/tests/test_tool_macros.rs +++ b/crates/rmcp/tests/test_tool_macros.rs @@ -1,3 +1,5 @@ +use std::sync::Arc; + use rmcp::{ServerHandler, handler::server::tool::ToolCallContext, tool}; use schemars::JsonSchema; use serde::{Deserialize, Serialize}; @@ -21,6 +23,7 @@ impl ServerHandler for Server { } } } + #[derive(Debug, Clone, Default)] pub struct Server {} @@ -35,6 +38,40 @@ impl Server { async fn empty_param(&self) {} } +// define generic service trait +pub trait DataService: Send + Sync + 'static { + fn get_data(&self) -> String; +} + +// mock service for test +#[derive(Clone)] +struct MockDataService; +impl DataService for MockDataService { + fn get_data(&self) -> String { + "mock data".to_string() + } +} + +// define generic server +#[derive(Debug, Clone)] +pub struct GenericServer { + data_service: Arc, +} + +#[tool(tool_box)] +impl GenericServer { + pub fn new(data_service: DS) -> Self { + Self { + data_service: Arc::new(data_service), + } + } + + #[tool(description = "Get data from the service")] + async fn get_data(&self) -> String { + self.data_service.get_data() + } +} + #[tokio::test] async fn test_tool_macros() { let server = Server::default(); @@ -52,4 +89,14 @@ async fn test_tool_macros_with_empty_param() { assert!(_attr.input_schema.get("properties").is_none()); } +#[tokio::test] +async fn test_tool_macros_with_generics() { + let mock_service = MockDataService; + let server = GenericServer::new(mock_service); + let _attr = GenericServer::::get_data_tool_attr(); + let _get_data_call_fn = GenericServer::::get_data_tool_call; + let _get_data_fn = GenericServer::::get_data; + assert_eq!(server.get_data().await, "mock data"); +} + impl GetWeatherRequest {} diff --git a/examples/servers/Cargo.toml b/examples/servers/Cargo.toml index abc89244..63d9d6d2 100644 --- a/examples/servers/Cargo.toml +++ b/examples/servers/Cargo.toml @@ -39,4 +39,8 @@ path = "src/axum.rs" [[example]] name = "servers_axum_router" -path = "src/axum_router.rs" \ No newline at end of file +path = "src/axum_router.rs" + +[[example]] +name = "servers_generic_server" +path = "src/generic_service.rs" \ No newline at end of file diff --git a/examples/servers/src/common/counter.rs b/examples/servers/src/common/counter.rs index 81151205..be43ce7e 100644 --- a/examples/servers/src/common/counter.rs +++ b/examples/servers/src/common/counter.rs @@ -19,6 +19,7 @@ pub struct Counter { } #[tool(tool_box)] impl Counter { + #[allow(dead_code)] pub fn new() -> Self { Self { counter: Arc::new(Mutex::new(0)), diff --git a/examples/servers/src/common/generic_service.rs b/examples/servers/src/common/generic_service.rs new file mode 100644 index 00000000..433a4308 --- /dev/null +++ b/examples/servers/src/common/generic_service.rs @@ -0,0 +1,73 @@ +use std::sync::Arc; + +use rmcp::{ + ServerHandler, + model::{ServerCapabilities, ServerInfo}, + schemars, tool, +}; + +#[allow(dead_code)] +pub trait DataService: Send + Sync + 'static { + fn get_data(&self) -> String; + fn set_data(&mut self, data: String); +} + +#[derive(Debug, Clone)] +pub struct MemoryDataService { + data: String, +} + +impl MemoryDataService { + #[allow(dead_code)] + pub fn new(initial_data: impl Into) -> Self { + Self { + data: initial_data.into(), + } + } +} + +impl DataService for MemoryDataService { + fn get_data(&self) -> String { + self.data.clone() + } + + fn set_data(&mut self, data: String) { + self.data = data; + } +} + +#[derive(Debug, Clone)] +pub struct GenericService { + #[allow(dead_code)] + data_service: Arc, +} + +#[tool(tool_box)] +impl GenericService { + pub fn new(data_service: DS) -> Self { + Self { + data_service: Arc::new(data_service), + } + } + + #[tool(description = "get memory from service")] + pub async fn get_data(&self) -> String { + self.data_service.get_data() + } + + #[tool(description = "set memory to service")] + pub async fn set_data(&self, #[tool(param)] data: String) -> String { + let new_data = data.clone(); + format!("Current memory: {}", new_data) + } +} + +impl ServerHandler for GenericService { + fn get_info(&self) -> ServerInfo { + ServerInfo { + instructions: Some("generic data service".into()), + capabilities: ServerCapabilities::builder().enable_tools().build(), + ..Default::default() + } + } +} diff --git a/examples/servers/src/common/mod.rs b/examples/servers/src/common/mod.rs index 7c651b22..5919bccd 100644 --- a/examples/servers/src/common/mod.rs +++ b/examples/servers/src/common/mod.rs @@ -1,2 +1,3 @@ pub mod calculator; pub mod counter; +pub mod generic_service; diff --git a/examples/servers/src/generic_service.rs b/examples/servers/src/generic_service.rs new file mode 100644 index 00000000..546621e3 --- /dev/null +++ b/examples/servers/src/generic_service.rs @@ -0,0 +1,18 @@ +use std::error::Error; +mod common; +use common::generic_service::{GenericService, MemoryDataService}; +use rmcp::serve_server; + +#[tokio::main] +async fn main() -> Result<(), Box> { + let memory_service = MemoryDataService::new("initial data"); + + let generic_service = GenericService::new(memory_service); + + println!("start server, connect to standard input/output"); + + let io = (tokio::io::stdin(), tokio::io::stdout()); + + serve_server(generic_service, io).await?; + Ok(()) +}