diff --git a/crates/apollo-mcp-server/src/errors.rs b/crates/apollo-mcp-server/src/errors.rs index e19dc152..ea9989ba 100644 --- a/crates/apollo-mcp-server/src/errors.rs +++ b/crates/apollo-mcp-server/src/errors.rs @@ -93,7 +93,7 @@ pub enum ServerError { StartupError(#[from] JoinError), #[error("Failed to initialize MCP server")] - McpInitializeError(#[from] rmcp::service::ServerInitializeError), + McpInitializeError(#[from] Box>), #[error(transparent)] UrlParseError(ParseError), diff --git a/crates/apollo-mcp-server/src/health.rs b/crates/apollo-mcp-server/src/health.rs index 5345696a..0af59336 100644 --- a/crates/apollo-mcp-server/src/health.rs +++ b/crates/apollo-mcp-server/src/health.rs @@ -7,11 +7,12 @@ use std::{ sync::{ Arc, - atomic::{AtomicBool, AtomicUsize, Ordering}, + atomic::{AtomicBool, Ordering}, }, time::Duration, }; +use crate::telemetry::Telemetry; use axum::http::StatusCode; use schemars::JsonSchema; use serde::{Deserialize, Serialize}; @@ -112,15 +113,13 @@ pub struct HealthCheck { config: HealthCheckConfig, live: Arc, ready: Arc, - rejected: Arc, ticker: Arc>, } impl HealthCheck { - pub fn new(config: HealthCheckConfig) -> Self { + pub fn new(config: HealthCheckConfig, telemetry: Arc) -> Self { let live = Arc::new(AtomicBool::new(true)); // Start as live - let ready = Arc::new(AtomicBool::new(true)); // Start as ready - let rejected = Arc::new(AtomicUsize::new(0)); + let ready = Arc::new(AtomicBool::new(true)); // Start as ready; let allowed = config.readiness.allowed; let sampling_interval = config.readiness.interval.sampling; @@ -130,8 +129,8 @@ impl HealthCheck { .unready .unwrap_or(2 * sampling_interval); - let my_rejected = rejected.clone(); let my_ready = ready.clone(); + let telemetry_clone = Arc::clone(&telemetry); let ticker = tokio::spawn(async move { loop { @@ -139,11 +138,11 @@ impl HealthCheck { let mut interval = tokio::time::interval_at(start, sampling_interval); loop { interval.tick().await; - if my_rejected.load(Ordering::Relaxed) > allowed { + if telemetry_clone.errors() > allowed { debug!("Health check readiness threshold exceeded, marking as unready"); my_ready.store(false, Ordering::SeqCst); tokio::time::sleep(recovery_interval).await; - my_rejected.store(0, Ordering::Relaxed); + telemetry_clone.set_error_count(0); my_ready.store(true, Ordering::SeqCst); debug!("Health check readiness restored"); break; @@ -156,15 +155,10 @@ impl HealthCheck { config, live, ready, - rejected, ticker: Arc::new(ticker), } } - pub fn record_rejection(&self) { - self.rejected.fetch_add(1, Ordering::Relaxed); - } - pub fn config(&self) -> &HealthCheckConfig { &self.config } @@ -215,6 +209,7 @@ impl Drop for HealthCheck { #[cfg(test)] mod tests { use super::*; + use crate::telemetry::InMemoryTelemetry; use tokio::time::{Duration, sleep}; #[test] @@ -234,7 +229,9 @@ mod tests { config.readiness.interval.sampling = Duration::from_millis(50); config.readiness.interval.unready = Some(Duration::from_millis(100)); - let health_check = HealthCheck::new(config); + let mock_telemetry: Arc = Arc::new(InMemoryTelemetry::new()); + + let health_check = HealthCheck::new(config, Arc::clone(&mock_telemetry)); // Should be live and ready initially assert!(health_check.live.load(Ordering::SeqCst)); @@ -242,7 +239,7 @@ mod tests { // Record rejections beyond threshold for _ in 0..5 { - health_check.record_rejection(); + mock_telemetry.record_error(); } // Wait for the ticker to process diff --git a/crates/apollo-mcp-server/src/lib.rs b/crates/apollo-mcp-server/src/lib.rs index 21b89ce8..0a09014f 100644 --- a/crates/apollo-mcp-server/src/lib.rs +++ b/crates/apollo-mcp-server/src/lib.rs @@ -11,3 +11,7 @@ pub mod operations; pub mod sanitize; pub(crate) mod schema_tree_shake; pub mod server; +pub mod server_config; +pub mod telemetry; + +pub mod server_handler; diff --git a/crates/apollo-mcp-server/src/main.rs b/crates/apollo-mcp-server/src/main.rs index ae5102e6..b9f5cc98 100644 --- a/crates/apollo-mcp-server/src/main.rs +++ b/crates/apollo-mcp-server/src/main.rs @@ -1,5 +1,4 @@ -use std::path::PathBuf; - +use crate::runtime::Serve; use apollo_mcp_registry::platform_api::operation_collections::collection_poller::CollectionSource; use apollo_mcp_registry::uplink::persisted_queries::ManifestSource; use apollo_mcp_registry::uplink::schema::SchemaSource; @@ -7,11 +6,18 @@ use apollo_mcp_server::custom_scalar_map::CustomScalarMap; use apollo_mcp_server::errors::ServerError; use apollo_mcp_server::operations::OperationSource; use apollo_mcp_server::server::Server; +use apollo_mcp_server::server_config::ServerConfig; +use apollo_mcp_server::server_handler::ApolloMcpServerHandler; +use apollo_mcp_server::telemetry::{InMemoryTelemetry, Telemetry}; use clap::Parser; use clap::builder::Styles; use clap::builder::styling::{AnsiColor, Effects}; use runtime::IdOrDefault; use runtime::logging::Logging; +use std::path::PathBuf; +use std::sync::Arc; +use tokio::sync::RwLock; +use tokio_util::sync::CancellationToken; use tracing::{info, warn}; mod runtime; @@ -109,11 +115,15 @@ async fn main() -> anyhow::Result<()> { .then(|| config.graphos.graph_ref()) .transpose()?; - Ok(Server::builder() - .transport(config.transport) - .schema_source(schema_source) - .operation_source(operation_source) - .endpoint(config.endpoint.into_inner()) + let telemetry: Option> = config + .health_check + .enabled + .then(|| Arc::new(InMemoryTelemetry::new()) as Arc); + let server_handler = + ApolloMcpServerHandler::new(config.headers.clone(), config.endpoint.clone(), telemetry); + let cancellation_token = CancellationToken::new(); + + let server_config = ServerConfig::builder() .maybe_explorer_graph_ref(explorer_graph_ref) .headers(config.headers) .execute_introspection(config.introspection.execute.enabled) @@ -133,8 +143,25 @@ async fn main() -> anyhow::Result<()> { ) .search_leaf_depth(config.introspection.search.leaf_depth) .index_memory_bytes(config.introspection.search.index_memory_bytes) - .health_check(config.health_check) + .build(); + + Server::builder() + .schema_source(schema_source) + .operation_source(operation_source) + .server_handler(Arc::new(RwLock::new(server_handler.clone()))) + .cancellation_token(cancellation_token.child_token()) + .server_config(server_config) .build() .start() - .await?) + .await?; + + Serve::serve( + server_handler, + config.transport, + cancellation_token, + config.health_check, + ) + .await?; + + Ok(()) } diff --git a/crates/apollo-mcp-server/src/runtime.rs b/crates/apollo-mcp-server/src/runtime.rs index 7b42fd53..16ef3a55 100644 --- a/crates/apollo-mcp-server/src/runtime.rs +++ b/crates/apollo-mcp-server/src/runtime.rs @@ -12,6 +12,7 @@ mod operation_source; mod overrides; mod schema_source; mod schemas; +mod serve; use std::path::Path; @@ -22,6 +23,7 @@ use figment::{ }; pub use operation_source::{IdOrDefault, OperationSource}; pub use schema_source::SchemaSource; +pub use serve::Serve; /// Separator to use when drilling down into nested options in the env figment const ENV_NESTED_SEPARATOR: &str = "__"; diff --git a/crates/apollo-mcp-server/src/runtime/endpoint.rs b/crates/apollo-mcp-server/src/runtime/endpoint.rs index bacb203b..f6f816a1 100644 --- a/crates/apollo-mcp-server/src/runtime/endpoint.rs +++ b/crates/apollo-mcp-server/src/runtime/endpoint.rs @@ -13,13 +13,6 @@ use url::Url; #[derive(Debug)] pub struct Endpoint(Url); -impl Endpoint { - /// Unwrap the endpoint into its inner URL - pub fn into_inner(self) -> Url { - self.0 - } -} - impl Default for Endpoint { fn default() -> Self { Self(defaults::endpoint()) diff --git a/crates/apollo-mcp-server/src/runtime/serve.rs b/crates/apollo-mcp-server/src/runtime/serve.rs new file mode 100644 index 00000000..4774b80e --- /dev/null +++ b/crates/apollo-mcp-server/src/runtime/serve.rs @@ -0,0 +1,185 @@ +use apollo_mcp_server::auth::Config; +use apollo_mcp_server::errors::ServerError; +use apollo_mcp_server::health::{HealthCheck, HealthCheckConfig}; +use apollo_mcp_server::server::Transport; +use apollo_mcp_server::server::states::shutdown_signal; +use apollo_mcp_server::server_handler::ApolloMcpServerHandler; +use apollo_mcp_server::telemetry::{InMemoryTelemetry, Telemetry}; +use axum::extract::Query; +use axum::routing::get; +use axum::{Json, Router}; +use http::StatusCode; +use rmcp::service::{RunningService, ServerInitializeError}; +use rmcp::transport::sse_server::SseServerConfig; +use rmcp::transport::streamable_http_server::session::local::LocalSessionManager; +use rmcp::transport::{SseServer, StreamableHttpService, stdio}; +use rmcp::{RoleServer, ServiceExt}; +use serde_json::json; +use std::io::Error; +use std::net::{IpAddr, SocketAddr}; +use std::sync::Arc; +use tokio_util::sync::CancellationToken; +use tracing::{Instrument, error, info, trace}; + +// Helper to enable auth +macro_rules! with_auth { + ($router:expr, $auth:ident) => {{ + let mut router = $router; + if let Some(auth) = $auth { + router = auth.enable_middleware(router); + } + + router + }}; +} + +pub struct Serve; + +impl Serve { + pub async fn serve( + server_handler: ApolloMcpServerHandler, + transport: Transport, + cancellation_token: CancellationToken, + health_check_config: HealthCheckConfig, + ) -> Result<(), ServerError> { + match transport { + Transport::StreamableHttp { + auth, + address, + port, + } => { + serve_streamable_http(auth, address, port, server_handler, health_check_config) + .await?; + } + Transport::SSE { + auth, + address, + port, + } => { + serve_sse(auth, address, port, server_handler, cancellation_token).await?; + } + Transport::Stdio => { + let service = serve_stdio(server_handler) + .await + .map_err(|e| ServerError::McpInitializeError(e.into()))?; + service.waiting().await.map_err(ServerError::StartupError)?; + } + } + + Ok(()) + } +} + +// Create health check if enabled (only for StreamableHttp transport) +fn create_health_check(config: HealthCheckConfig) -> Option { + let telemetry: Arc = Arc::new(InMemoryTelemetry::new()); + Some(HealthCheck::new(config, telemetry)) +} + +async fn serve_streamable_http( + auth: Option, + address: IpAddr, + port: u16, + server_handler: ApolloMcpServerHandler, + health_check_config: HealthCheckConfig, +) -> Result<(), ServerError> { + info!(port = ?port, address = ?address, "Starting MCP server in Streamable HTTP mode"); + let listen_address = SocketAddr::new(address, port); + let service = StreamableHttpService::new( + move || Ok(server_handler.clone()), + LocalSessionManager::default().into(), + Default::default(), + ); + + let mut router = with_auth!(Router::new().nest_service("/mcp", service), auth); + + // Add health check endpoint if configured + if health_check_config.enabled { + if let Some(health_check) = create_health_check(health_check_config) { + let health_router = Router::new() + .route(&health_check.config().path, get(health_endpoint)) + .with_state(health_check.clone()); + router = router.merge(health_router); + } + } + + let tcp_listener = tokio::net::TcpListener::bind(listen_address).await?; + tokio::spawn(async move { + // Health check is already active from creation + if let Err(e) = axum::serve(tcp_listener, router) + .with_graceful_shutdown(shutdown_signal()) + .await + { + // This can never really happen + error!("Failed to start MCP server: {e:?}"); + } + }); + + Ok(()) +} + +async fn serve_sse( + auth: Option, + address: IpAddr, + port: u16, + server_handler: ApolloMcpServerHandler, + cancellation_token: CancellationToken, +) -> Result<(), Error> { + info!(port = ?port, address = ?address, "Starting MCP server in SSE mode"); + let listen_address = SocketAddr::new(address, port); + + let (server, router) = SseServer::new(SseServerConfig { + bind: listen_address, + sse_path: "/sse".to_string(), + post_path: "/message".to_string(), + ct: cancellation_token, + sse_keep_alive: None, + }); + + // Optionally wrap the router with auth, if enabled + let router = with_auth!(router, auth); + + // Start up the SSE server + // Note: Until RMCP consolidates SSE with the same tower system as StreamableHTTP, + // we need to basically copy the implementation of `SseServer::serve_with_config` here. + let listener = tokio::net::TcpListener::bind(server.config.bind).await?; + let ct = server.config.ct.child_token(); + let axum_server = axum::serve(listener, router).with_graceful_shutdown(async move { + ct.cancelled().await; + info!("mcp server cancelled"); + }); + + tokio::spawn( + async move { + if let Err(e) = axum_server.await { + error!(error = %e, "mcp shutdown with error"); + } + } + .instrument(tracing::info_span!("mcp-server", bind_address = %server.config.bind)), + ); + + server.with_service(move || server_handler.clone()); + Ok(()) +} + +async fn serve_stdio( + server_handler: ApolloMcpServerHandler, +) -> Result, ServerInitializeError> { + info!("Starting MCP server in stdio mode"); + server_handler.serve(stdio()).await.inspect_err(|e| { + error!("serving error: {:?}", e); + }) +} + +/// Health check endpoint handler +async fn health_endpoint( + axum::extract::State(health_check): axum::extract::State, + Query(params): Query>, +) -> Result<(StatusCode, Json), StatusCode> { + let query = params.keys().next().map(|k| k.as_str()); + let (health, status_code) = health_check.get_health_state(query); + + trace!(?health, query = ?query, "health check"); + + Ok((status_code, Json(json!(health)))) +} diff --git a/crates/apollo-mcp-server/src/server.rs b/crates/apollo-mcp-server/src/server.rs index 96c0d772..a6b1d96b 100644 --- a/crates/apollo-mcp-server/src/server.rs +++ b/crates/apollo-mcp-server/src/server.rs @@ -1,44 +1,31 @@ -use std::net::{IpAddr, Ipv4Addr}; - use apollo_mcp_registry::uplink::schema::SchemaSource; use bon::bon; -use reqwest::header::{CONTENT_TYPE, HeaderMap, HeaderValue}; +use reqwest::header::{CONTENT_TYPE, HeaderValue}; use schemars::JsonSchema; use serde::Deserialize; -use url::Url; +use std::net::{IpAddr, Ipv4Addr}; +use std::sync::Arc; +use tokio::sync::RwLock; +use tokio_util::sync::CancellationToken; use crate::auth; -use crate::custom_scalar_map::CustomScalarMap; use crate::errors::ServerError; use crate::event::Event as ServerEvent; -use crate::health::HealthCheckConfig; -use crate::operations::{MutationMode, OperationSource}; +use crate::operations::OperationSource; -mod states; +pub mod states; +use crate::server_config::ServerConfig; +use crate::server_handler::ApolloMcpServerHandler; use states::StateMachine; /// An Apollo MCP Server pub struct Server { - transport: Transport, schema_source: SchemaSource, operation_source: OperationSource, - endpoint: Url, - headers: HeaderMap, - execute_introspection: bool, - validate_introspection: bool, - introspect_introspection: bool, - introspect_minify: bool, - search_minify: bool, - search_introspection: bool, - explorer_graph_ref: Option, - custom_scalar_map: Option, - mutation_mode: MutationMode, - disable_type_description: bool, - disable_schema_description: bool, - search_leaf_depth: usize, - index_memory_bytes: usize, - health_check: HealthCheckConfig, + server_handler: Arc>, + cancellation_token: CancellationToken, + server_config: ServerConfig, } #[derive(Debug, Clone, Deserialize, Default, JsonSchema)] @@ -96,51 +83,24 @@ impl Transport { impl Server { #[builder] pub fn new( - transport: Transport, + mut server_config: ServerConfig, schema_source: SchemaSource, operation_source: OperationSource, - endpoint: Url, - headers: HeaderMap, - execute_introspection: bool, - validate_introspection: bool, - introspect_introspection: bool, - search_introspection: bool, - introspect_minify: bool, - search_minify: bool, - explorer_graph_ref: Option, - #[builder(required)] custom_scalar_map: Option, - mutation_mode: MutationMode, - disable_type_description: bool, - disable_schema_description: bool, - search_leaf_depth: usize, - index_memory_bytes: usize, - health_check: HealthCheckConfig, + server_handler: Arc>, + cancellation_token: CancellationToken, ) -> Self { - let headers = { - let mut headers = headers.clone(); + server_config.headers = { + let mut headers = server_config.headers.clone(); headers.insert(CONTENT_TYPE, HeaderValue::from_static("application/json")); headers }; + Self { - transport, schema_source, operation_source, - endpoint, - headers, - execute_introspection, - validate_introspection, - introspect_introspection, - search_introspection, - introspect_minify, - search_minify, - explorer_graph_ref, - custom_scalar_map, - mutation_mode, - disable_type_description, - disable_schema_description, - search_leaf_depth, - index_memory_bytes, - health_check, + server_handler, + cancellation_token, + server_config, } } diff --git a/crates/apollo-mcp-server/src/server/states.rs b/crates/apollo-mcp-server/src/server/states.rs index 81211cda..a9cf66ad 100644 --- a/crates/apollo-mcp-server/src/server/states.rs +++ b/crates/apollo-mcp-server/src/server/states.rs @@ -2,17 +2,11 @@ use apollo_compiler::{Schema, validation::Valid}; use apollo_federation::{ApiSchemaOptions, Supergraph}; use apollo_mcp_registry::uplink::schema::{SchemaState, event::Event as SchemaEvent}; use futures::{FutureExt as _, Stream, StreamExt as _, stream}; -use reqwest::header::HeaderMap; -use url::Url; +use std::sync::Arc; -use crate::{ - custom_scalar_map::CustomScalarMap, - errors::{OperationError, ServerError}, - health::HealthCheckConfig, - operations::MutationMode, -}; +use crate::errors::{OperationError, ServerError}; -use super::{Server, ServerEvent, Transport}; +use super::{Server, ServerEvent}; mod configuring; mod operations_configured; @@ -28,27 +22,6 @@ use starting::Starting; pub(super) struct StateMachine {} -/// Common configuration options for the states -struct Config { - transport: Transport, - endpoint: Url, - headers: HeaderMap, - execute_introspection: bool, - validate_introspection: bool, - introspect_introspection: bool, - search_introspection: bool, - introspect_minify: bool, - search_minify: bool, - explorer_graph_ref: Option, - custom_scalar_map: Option, - mutation_mode: MutationMode, - disable_type_description: bool, - disable_schema_description: bool, - search_leaf_depth: usize, - index_memory_bytes: usize, - health_check: HealthCheckConfig, -} - impl StateMachine { pub(crate) async fn start(self, server: Server) -> Result<(), ServerError> { let schema_stream = server @@ -61,25 +34,7 @@ impl StateMachine { let mut stream = stream::select_all(vec![schema_stream, operation_stream, ctrl_c_stream]); let mut state = State::Configuring(Configuring { - config: Config { - transport: server.transport, - endpoint: server.endpoint, - headers: server.headers, - execute_introspection: server.execute_introspection, - validate_introspection: server.validate_introspection, - introspect_introspection: server.introspect_introspection, - search_introspection: server.search_introspection, - introspect_minify: server.introspect_minify, - search_minify: server.search_minify, - explorer_graph_ref: server.explorer_graph_ref, - custom_scalar_map: server.custom_scalar_map, - mutation_mode: server.mutation_mode, - disable_type_description: server.disable_type_description, - disable_schema_description: server.disable_schema_description, - search_leaf_depth: server.search_leaf_depth, - index_memory_bytes: server.index_memory_bytes, - health_check: server.health_check, - }, + config: server.server_config, }); while let Some(event) = stream.next().await { @@ -129,15 +84,23 @@ impl StateMachine { State::Error(ServerError::Operation(OperationError::Collection(e))) } ServerEvent::Shutdown => match state { - State::Running(running) => { - running.cancellation_token.cancel(); + State::Running(_running) => { + server.cancellation_token.cancel(); State::Stopping } _ => State::Stopping, }, }; if let State::Starting(starting) = state { - state = starting.start().await.into(); + server + .server_handler + .write() + .await + .configure(&starting.config, starting.schema.clone())?; + state = starting + .start(Arc::clone(&server.server_handler)) + .await + .into(); } if matches!(&state, State::Error(_) | State::Stopping) { break; @@ -171,7 +134,7 @@ impl StateMachine { } #[allow(clippy::expect_used)] -async fn shutdown_signal() { +pub async fn shutdown_signal() { let ctrl_c = async { tokio::signal::ctrl_c() .await diff --git a/crates/apollo-mcp-server/src/server/states/configuring.rs b/crates/apollo-mcp-server/src/server/states/configuring.rs index b91db3ae..52520d2c 100644 --- a/crates/apollo-mcp-server/src/server/states/configuring.rs +++ b/crates/apollo-mcp-server/src/server/states/configuring.rs @@ -1,12 +1,12 @@ use apollo_compiler::{Schema, validation::Valid}; use tracing::debug; +use super::{OperationsConfigured, SchemaConfigured}; +use crate::server_config::ServerConfig; use crate::{errors::ServerError, operations::RawOperation}; -use super::{Config, OperationsConfigured, SchemaConfigured}; - pub(super) struct Configuring { - pub(super) config: Config, + pub(super) config: ServerConfig, } impl Configuring { diff --git a/crates/apollo-mcp-server/src/server/states/operations_configured.rs b/crates/apollo-mcp-server/src/server/states/operations_configured.rs index 742d4320..b305668a 100644 --- a/crates/apollo-mcp-server/src/server/states/operations_configured.rs +++ b/crates/apollo-mcp-server/src/server/states/operations_configured.rs @@ -1,12 +1,11 @@ use apollo_compiler::{Schema, validation::Valid}; use tracing::debug; +use crate::server_config::ServerConfig; use crate::{errors::ServerError, operations::RawOperation, server::states::Starting}; -use super::Config; - pub(super) struct OperationsConfigured { - pub(super) config: Config, + pub(super) config: ServerConfig, pub(super) operations: Vec, } diff --git a/crates/apollo-mcp-server/src/server/states/running.rs b/crates/apollo-mcp-server/src/server/states/running.rs index 219dae3f..06e1109e 100644 --- a/crates/apollo-mcp-server/src/server/states/running.rs +++ b/crates/apollo-mcp-server/src/server/states/running.rs @@ -1,58 +1,24 @@ -use std::ops::Deref as _; use std::sync::Arc; use apollo_compiler::{Schema, validation::Valid}; -use headers::HeaderMapExt as _; -use reqwest::header::HeaderMap; -use rmcp::model::Implementation; -use rmcp::{ - Peer, RoleServer, ServerHandler, ServiceError, - model::{ - CallToolRequestParam, CallToolResult, ErrorCode, InitializeRequestParam, InitializeResult, - ListToolsResult, PaginatedRequestParam, ServerCapabilities, ServerInfo, - }, - service::RequestContext, -}; -use serde_json::Value; use tokio::sync::{Mutex, RwLock}; -use tokio_util::sync::CancellationToken; use tracing::{debug, error}; -use url::Url; +use crate::server_handler::ApolloMcpServerHandler; use crate::{ - auth::ValidToken, custom_scalar_map::CustomScalarMap, - errors::{McpError, ServerError}, - explorer::{EXPLORER_TOOL_NAME, Explorer}, - graphql::{self, Executable as _}, - health::HealthCheck, - introspection::tools::{ - execute::{EXECUTE_TOOL_NAME, Execute}, - introspect::{INTROSPECT_TOOL_NAME, Introspect}, - search::{SEARCH_TOOL_NAME, Search}, - validate::{VALIDATE_TOOL_NAME, Validate}, - }, + errors::ServerError, operations::{MutationMode, Operation, RawOperation}, }; #[derive(Clone)] -pub(super) struct Running { +pub struct Running { pub(super) schema: Arc>>, - pub(super) operations: Arc>>, - pub(super) headers: HeaderMap, - pub(super) endpoint: Url, - pub(super) execute_tool: Option, - pub(super) introspect_tool: Option, - pub(super) search_tool: Option, - pub(super) explorer_tool: Option, - pub(super) validate_tool: Option, + pub(super) server_handler: Arc>, pub(super) custom_scalar_map: Option, - pub(super) peers: Arc>>>, - pub(super) cancellation_token: CancellationToken, pub(super) mutation_mode: MutationMode, pub(super) disable_type_description: bool, pub(super) disable_schema_description: bool, - pub(super) health_check: Option, } impl Running { @@ -63,7 +29,10 @@ impl Running { // Update the operations based on the new schema. This is necessary because the MCP tool // input schemas and description are derived from the schema. let operations: Vec = self - .operations + .server_handler + .read() + .await + .operations() .lock() .await .iter() @@ -90,13 +59,17 @@ impl Running { operations.len(), serde_json::to_string_pretty(&operations)? ); - *self.operations.lock().await = operations; + *self.server_handler.read().await.operations().lock().await = operations; // Update the schema itself *self.schema.lock().await = schema; // Notify MCP clients that tools have changed - Self::notify_tool_list_changed(self.peers.clone()).await; + self.server_handler + .read() + .await + .notify_tool_list_changed(self.server_handler.read().await.peers()) + .await; Ok(self) } @@ -132,186 +105,24 @@ impl Running { updated_operations.len(), serde_json::to_string_pretty(&updated_operations)? ); - *self.operations.lock().await = updated_operations; + *self.server_handler.write().await.operations().lock().await = updated_operations; } // Notify MCP clients that tools have changed - Self::notify_tool_list_changed(self.peers.clone()).await; + self.server_handler + .read() + .await + .notify_tool_list_changed(self.server_handler.read().await.peers()) + .await; Ok(self) } - - /// Notify any peers that tools have changed. Drops unreachable peers from the list. - async fn notify_tool_list_changed(peers: Arc>>>) { - let mut peers = peers.write().await; - if !peers.is_empty() { - debug!( - "Operations changed, notifying {} peers of tool change", - peers.len() - ); - } - let mut retained_peers = Vec::new(); - for peer in peers.iter() { - if !peer.is_transport_closed() { - match peer.notify_tool_list_changed().await { - Ok(_) => retained_peers.push(peer.clone()), - Err(ServiceError::TransportSend(_) | ServiceError::TransportClosed) => { - error!("Failed to notify peer of tool list change - dropping peer",); - } - Err(e) => { - error!("Failed to notify peer of tool list change {:?}", e); - retained_peers.push(peer.clone()); - } - } - } - } - *peers = retained_peers; - } -} - -impl ServerHandler for Running { - async fn initialize( - &self, - _request: InitializeRequestParam, - context: RequestContext, - ) -> Result { - // TODO: how to remove these? - let mut peers = self.peers.write().await; - peers.push(context.peer); - Ok(self.get_info()) - } - - async fn call_tool( - &self, - request: CallToolRequestParam, - context: RequestContext, - ) -> Result { - let result = match request.name.as_ref() { - INTROSPECT_TOOL_NAME => { - self.introspect_tool - .as_ref() - .ok_or(tool_not_found(&request.name))? - .execute(convert_arguments(request)?) - .await - } - SEARCH_TOOL_NAME => { - self.search_tool - .as_ref() - .ok_or(tool_not_found(&request.name))? - .execute(convert_arguments(request)?) - .await - } - EXPLORER_TOOL_NAME => { - self.explorer_tool - .as_ref() - .ok_or(tool_not_found(&request.name))? - .execute(convert_arguments(request)?) - .await - } - EXECUTE_TOOL_NAME => { - self.execute_tool - .as_ref() - .ok_or(tool_not_found(&request.name))? - .execute(graphql::Request { - input: Value::from(request.arguments.clone()), - endpoint: &self.endpoint, - headers: self.headers.clone(), - }) - .await - } - VALIDATE_TOOL_NAME => { - self.validate_tool - .as_ref() - .ok_or(tool_not_found(&request.name))? - .execute(convert_arguments(request)?) - .await - } - _ => { - // Optionally extract the validated token and propagate it to upstream servers - // if found - let mut headers = self.headers.clone(); - if let Some(token) = context.extensions.get::() { - headers.typed_insert(token.deref().clone()); - } - - let graphql_request = graphql::Request { - input: Value::from(request.arguments.clone()), - endpoint: &self.endpoint, - headers, - }; - self.operations - .lock() - .await - .iter() - .find(|op| op.as_ref().name == request.name) - .ok_or(tool_not_found(&request.name))? - .execute(graphql_request) - .await - } - }; - - // Track errors for health check - if let (Err(_), Some(health_check)) = (&result, &self.health_check) { - health_check.record_rejection(); - } - - result - } - - async fn list_tools( - &self, - _request: Option, - _context: RequestContext, - ) -> Result { - Ok(ListToolsResult { - next_cursor: None, - tools: self - .operations - .lock() - .await - .iter() - .map(|op| op.as_ref().clone()) - .chain(self.execute_tool.as_ref().iter().map(|e| e.tool.clone())) - .chain(self.introspect_tool.as_ref().iter().map(|e| e.tool.clone())) - .chain(self.search_tool.as_ref().iter().map(|e| e.tool.clone())) - .chain(self.explorer_tool.as_ref().iter().map(|e| e.tool.clone())) - .chain(self.validate_tool.as_ref().iter().map(|e| e.tool.clone())) - .collect(), - }) - } - - fn get_info(&self) -> ServerInfo { - ServerInfo { - server_info: Implementation { - name: "Apollo MCP Server".to_string(), - version: env!("CARGO_PKG_VERSION").to_string(), - }, - capabilities: ServerCapabilities::builder() - .enable_tools() - .enable_tool_list_changed() - .build(), - ..Default::default() - } - } -} - -fn tool_not_found(name: &str) -> McpError { - McpError::new( - ErrorCode::METHOD_NOT_FOUND, - format!("Tool {name} not found"), - None, - ) -} - -fn convert_arguments( - arguments: CallToolRequestParam, -) -> Result { - serde_json::from_value(Value::from(arguments.arguments)) - .map_err(|_| McpError::new(ErrorCode::INVALID_PARAMS, "Invalid input".to_string(), None)) } #[cfg(test)] mod tests { use super::*; + use http::HeaderMap; + use url::Url; #[tokio::test] async fn invalid_operations_should_not_crash_server() { @@ -320,23 +131,19 @@ mod tests { .validate() .unwrap(); + let server_handler = ApolloMcpServerHandler::new( + HeaderMap::new(), + Url::parse("http://localhost:8080/graphql").unwrap(), + None, + ); + let running = Running { schema: Arc::new(Mutex::new(schema)), - operations: Arc::new(Mutex::new(vec![])), - headers: HeaderMap::new(), - endpoint: "http://localhost:4000".parse().unwrap(), - execute_tool: None, - introspect_tool: None, - search_tool: None, - explorer_tool: None, - validate_tool: None, custom_scalar_map: None, - peers: Arc::new(RwLock::new(vec![])), - cancellation_token: CancellationToken::new(), mutation_mode: MutationMode::None, disable_type_description: false, disable_schema_description: false, - health_check: None, + server_handler: Arc::new(RwLock::new(server_handler)), }; let operations = vec![ @@ -355,9 +162,10 @@ mod tests { ]; let updated_running = running.update_operations(operations).await.unwrap(); - let updated_operations = updated_running.operations.lock().await; + let updated_operations = updated_running.server_handler.read().await.operations(); + let operations_guard = updated_operations.lock().await; - assert_eq!(updated_operations.len(), 1); - assert_eq!(updated_operations.first().unwrap().as_ref().name, "Valid"); + assert_eq!(operations_guard.len(), 1); + assert_eq!(operations_guard.first().unwrap().as_ref().name, "Valid"); } } diff --git a/crates/apollo-mcp-server/src/server/states/schema_configured.rs b/crates/apollo-mcp-server/src/server/states/schema_configured.rs index 54377df8..b99ebf31 100644 --- a/crates/apollo-mcp-server/src/server/states/schema_configured.rs +++ b/crates/apollo-mcp-server/src/server/states/schema_configured.rs @@ -1,12 +1,12 @@ use apollo_compiler::{Schema, validation::Valid}; use tracing::debug; +use super::Starting; +use crate::server_config::ServerConfig; use crate::{errors::ServerError, operations::RawOperation}; -use super::{Config, Starting}; - pub(super) struct SchemaConfigured { - pub(super) config: Config, + pub(super) config: ServerConfig, pub(super) schema: Valid, } diff --git a/crates/apollo-mcp-server/src/server/states/starting.rs b/crates/apollo-mcp-server/src/server/states/starting.rs index a23b137b..301a8016 100644 --- a/crates/apollo-mcp-server/src/server/states/starting.rs +++ b/crates/apollo-mcp-server/src/server/states/starting.rs @@ -1,41 +1,25 @@ -use std::{net::SocketAddr, sync::Arc}; +use std::sync::Arc; -use apollo_compiler::{Name, Schema, ast::OperationType, validation::Valid}; -use axum::{Router, extract::Query, http::StatusCode, response::Json, routing::get}; -use rmcp::transport::StreamableHttpService; -use rmcp::transport::streamable_http_server::session::local::LocalSessionManager; -use rmcp::{ - ServiceExt as _, - transport::{SseServer, sse_server::SseServerConfig, stdio}, -}; -use serde_json::json; +use apollo_compiler::{Schema, validation::Valid}; use tokio::sync::{Mutex, RwLock}; -use tokio_util::sync::CancellationToken; -use tracing::{Instrument as _, debug, error, info, trace}; +use tracing::{debug, error}; -use crate::{ - errors::ServerError, - explorer::Explorer, - health::HealthCheck, - introspection::tools::{ - execute::Execute, introspect::Introspect, search::Search, validate::Validate, - }, - operations::{MutationMode, RawOperation}, - server::Transport, -}; - -use super::{Config, Running, shutdown_signal}; +use super::Running; +use crate::server_config::ServerConfig; +use crate::server_handler::ApolloMcpServerHandler; +use crate::{errors::ServerError, operations::RawOperation}; pub(super) struct Starting { - pub(super) config: Config, + pub(super) config: ServerConfig, pub(super) schema: Valid, pub(super) operations: Vec, } impl Starting { - pub(super) async fn start(self) -> Result { - let peers = Arc::new(RwLock::new(Vec::new())); - + pub(super) async fn start( + self, + server_handler: Arc>, + ) -> Result { let operations: Vec<_> = self .operations .into_iter() @@ -61,210 +45,21 @@ impl Starting { serde_json::to_string_pretty(&operations)? ); - let execute_tool = self - .config - .execute_introspection - .then(|| Execute::new(self.config.mutation_mode)); - - let root_query_type = self - .config - .introspect_introspection - .then(|| { - self.schema - .root_operation(OperationType::Query) - .map(Name::as_str) - .map(|s| s.to_string()) - }) - .flatten(); - let root_mutation_type = self - .config - .introspect_introspection - .then(|| { - matches!(self.config.mutation_mode, MutationMode::All) - .then(|| { - self.schema - .root_operation(OperationType::Mutation) - .map(Name::as_str) - .map(|s| s.to_string()) - }) - .flatten() - }) - .flatten(); + server_handler + .write() + .await + .configure(&self.config, self.schema.clone())?; let schema = Arc::new(Mutex::new(self.schema)); - let introspect_tool = self.config.introspect_introspection.then(|| { - Introspect::new( - schema.clone(), - root_query_type, - root_mutation_type, - self.config.introspect_minify, - ) - }); - let validate_tool = self - .config - .validate_introspection - .then(|| Validate::new(schema.clone())); - let search_tool = if self.config.search_introspection { - Some(Search::new( - schema.clone(), - matches!(self.config.mutation_mode, MutationMode::All), - self.config.search_leaf_depth, - self.config.index_memory_bytes, - self.config.search_minify, - )?) - } else { - None - }; - - let explorer_tool = self.config.explorer_graph_ref.map(Explorer::new); - - let cancellation_token = CancellationToken::new(); - - // Create health check if enabled (only for StreamableHttp transport) - let health_check = match (&self.config.transport, self.config.health_check.enabled) { - ( - Transport::StreamableHttp { - auth: _, - address: _, - port: _, - }, - true, - ) => Some(HealthCheck::new(self.config.health_check.clone())), - _ => None, // No health check for SSE, Stdio, or when disabled - }; let running = Running { schema, - operations: Arc::new(Mutex::new(operations)), - headers: self.config.headers, - endpoint: self.config.endpoint, - execute_tool, - introspect_tool, - search_tool, - explorer_tool, - validate_tool, + server_handler, custom_scalar_map: self.config.custom_scalar_map, - peers, - cancellation_token: cancellation_token.clone(), mutation_mode: self.config.mutation_mode, disable_type_description: self.config.disable_type_description, disable_schema_description: self.config.disable_schema_description, - health_check: health_check.clone(), }; - // Helper to enable auth - macro_rules! with_auth { - ($router:expr, $auth:ident) => {{ - let mut router = $router; - if let Some(auth) = $auth { - router = auth.enable_middleware(router); - } - - router - }}; - } - match self.config.transport { - Transport::StreamableHttp { - auth, - address, - port, - } => { - info!(port = ?port, address = ?address, "Starting MCP server in Streamable HTTP mode"); - let running = running.clone(); - let listen_address = SocketAddr::new(address, port); - let service = StreamableHttpService::new( - move || Ok(running.clone()), - LocalSessionManager::default().into(), - Default::default(), - ); - let mut router = - with_auth!(axum::Router::new().nest_service("/mcp", service), auth); - - // Add health check endpoint if configured - if let Some(health_check) = health_check.filter(|h| h.config().enabled) { - let health_router = Router::new() - .route(&health_check.config().path, get(health_endpoint)) - .with_state(health_check.clone()); - router = router.merge(health_router); - } - - let tcp_listener = tokio::net::TcpListener::bind(listen_address).await?; - tokio::spawn(async move { - // Health check is already active from creation - if let Err(e) = axum::serve(tcp_listener, router) - .with_graceful_shutdown(shutdown_signal()) - .await - { - // This can never really happen - error!("Failed to start MCP server: {e:?}"); - } - }); - } - Transport::SSE { - auth, - address, - port, - } => { - info!(port = ?port, address = ?address, "Starting MCP server in SSE mode"); - let running = running.clone(); - let listen_address = SocketAddr::new(address, port); - - let (server, router) = SseServer::new(SseServerConfig { - bind: listen_address, - sse_path: "/sse".to_string(), - post_path: "/message".to_string(), - ct: cancellation_token, - sse_keep_alive: None, - }); - - // Optionally wrap the router with auth, if enabled - let router = with_auth!(router, auth); - - // Start up the SSE server - // Note: Until RMCP consolidates SSE with the same tower system as StreamableHTTP, - // we need to basically copy the implementation of `SseServer::serve_with_config` here. - let listener = tokio::net::TcpListener::bind(server.config.bind).await?; - let ct = server.config.ct.child_token(); - let axum_server = - axum::serve(listener, router).with_graceful_shutdown(async move { - ct.cancelled().await; - tracing::info!("mcp server cancelled"); - }); - - tokio::spawn( - async move { - if let Err(e) = axum_server.await { - tracing::error!(error = %e, "mcp shutdown with error"); - } - } - .instrument( - tracing::info_span!("mcp-server", bind_address = %server.config.bind), - ), - ); - - server.with_service(move || running.clone()); - } - Transport::Stdio => { - info!("Starting MCP server in stdio mode"); - let service = running.clone().serve(stdio()).await.inspect_err(|e| { - error!("serving error: {:?}", e); - })?; - service.waiting().await.map_err(ServerError::StartupError)?; - } - } - Ok(running) } } - -/// Health check endpoint handler -async fn health_endpoint( - axum::extract::State(health_check): axum::extract::State, - Query(params): Query>, -) -> Result<(StatusCode, Json), StatusCode> { - let query = params.keys().next().map(|k| k.as_str()); - let (health, status_code) = health_check.get_health_state(query); - - trace!(?health, query = ?query, "health check"); - - Ok((status_code, Json(json!(health)))) -} diff --git a/crates/apollo-mcp-server/src/server_config.rs b/crates/apollo-mcp-server/src/server_config.rs new file mode 100644 index 00000000..5c57a16e --- /dev/null +++ b/crates/apollo-mcp-server/src/server_config.rs @@ -0,0 +1,60 @@ +use crate::custom_scalar_map::CustomScalarMap; +use crate::operations::MutationMode; +use bon::bon; +use http::HeaderMap; + +/// Common configuration options for the server +pub struct ServerConfig { + pub(crate) headers: HeaderMap, + pub(crate) execute_introspection: bool, + pub(crate) validate_introspection: bool, + pub(crate) introspect_introspection: bool, + pub(crate) search_introspection: bool, + pub(crate) introspect_minify: bool, + pub(crate) search_minify: bool, + pub(crate) explorer_graph_ref: Option, + pub(crate) custom_scalar_map: Option, + pub(crate) mutation_mode: MutationMode, + pub(crate) disable_type_description: bool, + pub(crate) disable_schema_description: bool, + pub(crate) search_leaf_depth: usize, + pub(crate) index_memory_bytes: usize, +} + +#[bon] +impl ServerConfig { + #[builder] + pub fn new( + headers: HeaderMap, + execute_introspection: bool, + validate_introspection: bool, + introspect_introspection: bool, + search_introspection: bool, + introspect_minify: bool, + search_minify: bool, + explorer_graph_ref: Option, + #[builder(required)] custom_scalar_map: Option, + mutation_mode: MutationMode, + disable_type_description: bool, + disable_schema_description: bool, + search_leaf_depth: usize, + index_memory_bytes: usize, + ) -> Self { + Self { + headers, + execute_introspection, + validate_introspection, + introspect_introspection, + search_introspection, + introspect_minify, + search_minify, + explorer_graph_ref, + custom_scalar_map, + mutation_mode, + disable_type_description, + disable_schema_description, + search_leaf_depth, + index_memory_bytes, + } + } +} diff --git a/crates/apollo-mcp-server/src/server_handler.rs b/crates/apollo-mcp-server/src/server_handler.rs new file mode 100644 index 00000000..cc41d991 --- /dev/null +++ b/crates/apollo-mcp-server/src/server_handler.rs @@ -0,0 +1,303 @@ +use crate::auth::ValidToken; +use crate::errors::{McpError, ServerError}; +use crate::explorer::{EXPLORER_TOOL_NAME, Explorer}; +use crate::graphql; +use crate::graphql::Executable; +use crate::introspection::tools::execute::{EXECUTE_TOOL_NAME, Execute}; +use crate::introspection::tools::introspect::{INTROSPECT_TOOL_NAME, Introspect}; +use crate::introspection::tools::search::{SEARCH_TOOL_NAME, Search}; +use crate::introspection::tools::validate::{VALIDATE_TOOL_NAME, Validate}; +use crate::operations::{MutationMode, Operation}; +use crate::server_config::ServerConfig; +use crate::telemetry::Telemetry; +use apollo_compiler::ast::OperationType; +use apollo_compiler::validation::Valid; +use apollo_compiler::{Name, Schema}; +use headers::HeaderMapExt; +use http::HeaderMap; +use rmcp::model::{ + CallToolRequestParam, CallToolResult, ErrorCode, Implementation, InitializeRequestParam, + InitializeResult, ListToolsResult, PaginatedRequestParam, ServerCapabilities, ServerInfo, +}; +use rmcp::service::RequestContext; +use rmcp::{Peer, RoleServer, ServerHandler, ServiceError}; +use serde_json::Value; +use std::ops::Deref; +use std::sync::Arc; +use tokio::sync::{Mutex, RwLock}; +use tracing::{debug, error}; +use url::Url; + +#[derive(Clone)] +pub struct ApolloMcpServerHandler { + pub(super) operations: Arc>>, + pub(super) headers: HeaderMap, + pub(super) endpoint: Url, + pub(super) execute_tool: Option, + pub(super) introspect_tool: Option, + pub(super) search_tool: Option, + pub(super) explorer_tool: Option, + pub(super) validate_tool: Option, + pub(super) peers: Arc>>>, + pub(super) telemetry: Option>, +} + +impl ApolloMcpServerHandler { + pub fn new( + headers: HeaderMap, + endpoint: Url, + telemetry: Option>, + ) -> ApolloMcpServerHandler { + Self { + operations: Arc::new(Mutex::new(Vec::new())), + headers, + endpoint, + execute_tool: None, + introspect_tool: None, + search_tool: None, + explorer_tool: None, + validate_tool: None, + peers: Arc::new(RwLock::new(Vec::new())), + telemetry, + } + } + + pub(crate) fn configure( + &mut self, + config: &ServerConfig, + schema: Valid, + ) -> Result<(), ServerError> { + let root_query_type = config + .introspect_introspection + .then(|| { + schema + .root_operation(OperationType::Query) + .map(Name::as_str) + .map(|s| s.to_string()) + }) + .flatten(); + let root_mutation_type = config + .introspect_introspection + .then(|| { + matches!(config.mutation_mode, MutationMode::All) + .then(|| { + schema + .root_operation(OperationType::Mutation) + .map(Name::as_str) + .map(|s| s.to_string()) + }) + .flatten() + }) + .flatten(); + + let schema = Arc::new(Mutex::new(schema)); + self.execute_tool = config + .execute_introspection + .then(|| Execute::new(config.mutation_mode)); + + self.introspect_tool = config.introspect_introspection.then(|| { + Introspect::new( + schema.clone(), + root_query_type, + root_mutation_type, + config.introspect_minify, + ) + }); + self.validate_tool = config + .validate_introspection + .then(|| Validate::new(schema.clone())); + + self.search_tool = if config.search_introspection { + Some(Search::new( + schema.clone(), + matches!(config.mutation_mode, MutationMode::All), + config.search_leaf_depth, + config.index_memory_bytes, + config.search_minify, + )?) + } else { + None + }; + + self.explorer_tool = config.explorer_graph_ref.clone().map(Explorer::new); + + self.peers = Arc::new(RwLock::new(Vec::new())); + + Ok(()) + } + + pub(crate) fn peers(&self) -> Arc>>> { + Arc::clone(&self.peers) + } + + pub(crate) fn operations(&self) -> Arc>> { + Arc::clone(&self.operations) + } + + pub(crate) async fn notify_tool_list_changed(&self, peers: Arc>>>) { + let mut peers = peers.write().await; + if !peers.is_empty() { + debug!( + "Operations changed, notifying {} peers of tool change", + peers.len() + ); + } + let mut retained_peers = Vec::new(); + for peer in peers.iter() { + if !peer.is_transport_closed() { + match peer.notify_tool_list_changed().await { + Ok(_) => retained_peers.push(peer.clone()), + Err(ServiceError::TransportSend(_) | ServiceError::TransportClosed) => { + error!("Failed to notify peer of tool list change - dropping peer",); + } + Err(e) => { + error!("Failed to notify peer of tool list change {:?}", e); + retained_peers.push(peer.clone()); + } + } + } + } + *peers = retained_peers; + } +} + +impl ServerHandler for ApolloMcpServerHandler { + async fn initialize( + &self, + _request: InitializeRequestParam, + context: RequestContext, + ) -> Result { + // TODO: how to remove these? + let mut peers = self.peers.write().await; + peers.push(context.peer); + Ok(self.get_info()) + } + + async fn call_tool( + &self, + request: CallToolRequestParam, + context: RequestContext, + ) -> Result { + let result = match request.name.as_ref() { + INTROSPECT_TOOL_NAME => { + self.introspect_tool + .as_ref() + .ok_or(tool_not_found(&request.name))? + .execute(convert_arguments(request)?) + .await + } + SEARCH_TOOL_NAME => { + self.search_tool + .as_ref() + .ok_or(tool_not_found(&request.name))? + .execute(convert_arguments(request)?) + .await + } + EXPLORER_TOOL_NAME => { + self.explorer_tool + .as_ref() + .ok_or(tool_not_found(&request.name))? + .execute(convert_arguments(request)?) + .await + } + EXECUTE_TOOL_NAME => { + self.execute_tool + .as_ref() + .ok_or(tool_not_found(&request.name))? + .execute(graphql::Request { + input: Value::from(request.arguments.clone()), + endpoint: &self.endpoint, + headers: self.headers.clone(), + }) + .await + } + VALIDATE_TOOL_NAME => { + self.validate_tool + .as_ref() + .ok_or(tool_not_found(&request.name))? + .execute(convert_arguments(request)?) + .await + } + _ => { + // Optionally extract the validated token and propagate it to upstream servers + // if found + let mut headers = self.headers.clone(); + if let Some(token) = context.extensions.get::() { + headers.typed_insert(token.deref().clone()); + } + + let graphql_request = graphql::Request { + input: Value::from(request.arguments.clone()), + endpoint: &self.endpoint, + headers, + }; + self.operations + .lock() + .await + .iter() + .find(|op| op.as_ref().name == request.name) + .ok_or(tool_not_found(&request.name))? + .execute(graphql_request) + .await + } + }; + + // Track errors for health check + if let (Err(_), Some(telemetry)) = (&result, &self.telemetry) { + telemetry.record_error() + } + + result + } + + async fn list_tools( + &self, + _request: Option, + _context: RequestContext, + ) -> Result { + Ok(ListToolsResult { + next_cursor: None, + tools: self + .operations + .lock() + .await + .iter() + .map(|op| op.as_ref().clone()) + .chain(self.execute_tool.as_ref().iter().map(|e| e.tool.clone())) + .chain(self.introspect_tool.as_ref().iter().map(|e| e.tool.clone())) + .chain(self.search_tool.as_ref().iter().map(|e| e.tool.clone())) + .chain(self.explorer_tool.as_ref().iter().map(|e| e.tool.clone())) + .chain(self.validate_tool.as_ref().iter().map(|e| e.tool.clone())) + .collect(), + }) + } + + fn get_info(&self) -> ServerInfo { + ServerInfo { + server_info: Implementation { + name: "Apollo MCP Server".to_string(), + version: env!("CARGO_PKG_VERSION").to_string(), + }, + capabilities: ServerCapabilities::builder() + .enable_tools() + .enable_tool_list_changed() + .build(), + ..Default::default() + } + } +} + +fn tool_not_found(name: &str) -> McpError { + McpError::new( + ErrorCode::METHOD_NOT_FOUND, + format!("Tool {name} not found"), + None, + ) +} + +fn convert_arguments( + arguments: CallToolRequestParam, +) -> Result { + serde_json::from_value(Value::from(arguments.arguments)) + .map_err(|_| McpError::new(ErrorCode::INVALID_PARAMS, "Invalid input".to_string(), None)) +} diff --git a/crates/apollo-mcp-server/src/telemetry.rs b/crates/apollo-mcp-server/src/telemetry.rs new file mode 100644 index 00000000..0a59b2f0 --- /dev/null +++ b/crates/apollo-mcp-server/src/telemetry.rs @@ -0,0 +1,40 @@ +use std::sync::Arc; +use std::sync::atomic::{AtomicUsize, Ordering}; + +pub struct InMemoryTelemetry { + errored: Arc, +} + +impl Default for InMemoryTelemetry { + fn default() -> Self { + Self { + errored: Arc::new(AtomicUsize::new(0)), + } + } +} + +impl InMemoryTelemetry { + pub fn new() -> Self { + Self::default() + } +} + +pub trait Telemetry: Send + Sync { + fn errors(&self) -> usize; + fn set_error_count(&self, errors: usize); + fn record_error(&self); +} + +impl Telemetry for InMemoryTelemetry { + fn errors(&self) -> usize { + self.errored.load(Ordering::Relaxed) + } + + fn set_error_count(&self, errors: usize) { + self.errored.store(errors, Ordering::Relaxed) + } + + fn record_error(&self) { + self.errored.fetch_add(1, Ordering::Relaxed); + } +}