diff --git a/.config/nats.dic b/.config/nats.dic index 1cebf663d..a993a92db 100644 --- a/.config/nats.dic +++ b/.config/nats.dic @@ -146,3 +146,4 @@ ObjectMetadata S2 inactive_threshold max_ack_pending +footgun diff --git a/async-nats/src/connector.rs b/async-nats/src/connector.rs index c7b6e8728..9bca6f11f 100644 --- a/async-nats/src/connector.rs +++ b/async-nats/src/connector.rs @@ -20,6 +20,7 @@ use crate::AuthError; use crate::ClientError; use crate::ClientOp; use crate::ConnectError; +use crate::ConnectErrorKind; use crate::ConnectInfo; use crate::Event; use crate::Protocol; @@ -60,6 +61,7 @@ pub(crate) struct ConnectorOptions { pub(crate) read_buffer_capacity: u16, pub(crate) reconnect_delay_callback: Box Duration + Send + Sync + 'static>, pub(crate) auth_callback: Option, Result>>, + pub(crate) max_reconnects: Option, } /// Maintains a list of servers and establishes connections. @@ -100,16 +102,24 @@ impl Connector { }) } - pub(crate) async fn connect(&mut self) -> (ServerInfo, Connection) { + pub(crate) async fn connect(&mut self) -> Result<(ServerInfo, Connection), ConnectError> { loop { match self.try_connect().await { - Ok(inner) => return inner, - Err(error) => { - self.events_tx - .send(Event::ClientError(ClientError::Other(error.to_string()))) - .await - .ok(); - } + Ok(inner) => return Ok(inner), + Err(error) => match error.kind() { + ConnectErrorKind::MaxReconnects => { + return Err(ConnectError::with_source( + crate::ConnectErrorKind::MaxReconnects, + error, + )) + } + other => { + self.events_tx + .send(Event::ClientError(ClientError::Other(other.to_string()))) + .await + .ok(); + } + }, } } } @@ -126,6 +136,15 @@ impl Connector { for (server_addr, _) in servers { self.attempts += 1; + if let Some(max_reconnects) = self.options.max_reconnects { + if self.attempts > max_reconnects { + self.events_tx + .send(Event::ClientError(ClientError::MaxReconnects)) + .await + .ok(); + return Err(ConnectError::new(crate::ConnectErrorKind::MaxReconnects)); + } + } let duration = (self.options.reconnect_delay_callback)(self.attempts); diff --git a/async-nats/src/lib.rs b/async-nats/src/lib.rs index 45244189f..d1ce061c1 100644 --- a/async-nats/src/lib.rs +++ b/async-nats/src/lib.rs @@ -584,7 +584,9 @@ impl ConnectionHandler { ExitReason::Disconnected(err) => { debug!(?err, "disconnected"); - self.handle_disconnect().await; + if self.handle_disconnect().await.is_err() { + break; + }; debug!("reconnected"); } ExitReason::Closed => break, @@ -796,16 +798,16 @@ impl ConnectionHandler { } } - async fn handle_disconnect(&mut self) { + async fn handle_disconnect(&mut self) -> Result<(), ConnectError> { self.pending_pings = 0; self.connector.events_tx.try_send(Event::Disconnected).ok(); self.connector.state_tx.send(State::Disconnected).ok(); - self.handle_reconnect().await; + self.handle_reconnect().await } - async fn handle_reconnect(&mut self) { - let (info, connection) = self.connector.connect().await; + async fn handle_reconnect(&mut self) -> Result<(), ConnectError> { + let (info, connection) = self.connector.connect().await?; self.connection = connection; let _ = self.info_sender.send(info); @@ -829,6 +831,7 @@ impl ConnectionHandler { } self.connector.events_tx.try_send(Event::Connected).ok(); + Ok(()) } } @@ -874,6 +877,7 @@ pub async fn connect_with_options( read_buffer_capacity: options.read_buffer_capacity, reconnect_delay_callback: options.reconnect_delay_callback, auth_callback: options.auth_callback, + max_reconnects: options.max_reconnects, }, events_tx, state_tx, @@ -912,7 +916,13 @@ pub async fn connect_with_options( task::spawn(async move { if connection.is_none() && options.retry_on_initial_connect { - let (info, connection_ok) = connector.connect().await; + let (info, connection_ok) = match connector.connect().await { + Ok((info, connection)) => (info, connection), + Err(err) => { + error!("connection closed: {}", err); + return; + } + }; info_sender.send(info).ok(); connection = Some(connection_ok); } @@ -1034,6 +1044,8 @@ pub enum ConnectErrorKind { Tls, /// Other IO error. Io, + /// Reached the maximum number of reconnects. + MaxReconnects, } impl Display for ConnectErrorKind { @@ -1046,6 +1058,7 @@ impl Display for ConnectErrorKind { Self::TimedOut => write!(f, "timed out"), Self::Tls => write!(f, "TLS error"), Self::Io => write!(f, "IO error"), + Self::MaxReconnects => write!(f, "reached maximum number of reconnects"), } } } @@ -1224,11 +1237,13 @@ pub enum ServerError { #[derive(Clone, Debug, Eq, PartialEq)] pub enum ClientError { Other(String), + MaxReconnects, } impl std::fmt::Display for ClientError { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { match self { Self::Other(error) => write!(f, "nats: {error}"), + Self::MaxReconnects => write!(f, "nats: max reconnects reached"), } } } diff --git a/async-nats/src/options.rs b/async-nats/src/options.rs index 7842bba03..46c9c0a1a 100644 --- a/async-nats/src/options.rs +++ b/async-nats/src/options.rs @@ -101,7 +101,7 @@ impl Default for ConnectOptions { no_echo: false, retry_on_failed_connect: false, reconnect_buffer_size: 8 * 1024 * 1024, - max_reconnects: Some(60), + max_reconnects: None, connection_timeout: Duration::from_secs(5), tls_required: false, tls_first: false, @@ -829,6 +829,31 @@ impl ConnectOptions { self } + /// Specifies the number of consecutive reconnect attempts the client will + /// make before giving up. This is useful for preventing zombie services + /// from endlessly reaching the servers, but it can also be a footgun and + /// surprise for users who do not expect that the client can give up + /// entirely. + /// + /// Pass `None` or `0` for no limit. + /// + /// # Examples + /// ``` + /// # #[tokio::main] + /// # async fn main() -> Result<(), async_nats::Error> { + /// async_nats::ConnectOptions::new() + /// .max_reconnects(None) + /// .connect("demo.nats.io") + /// .await?; + /// # Ok(()) + /// # } + /// ``` + pub fn max_reconnects>>(mut self, max_reconnects: T) -> ConnectOptions { + let val: Option = max_reconnects.into(); + self.max_reconnects = if val == Some(0) { None } else { val }; + self + } + /// By default, a server may advertise other servers in the cluster known to it. /// By setting this option, the client will ignore the advertised servers. /// This may be useful if the client may not be able to reach them. diff --git a/async-nats/tests/client_tests.rs b/async-nats/tests/client_tests.rs index aea71422c..03ab866b9 100644 --- a/async-nats/tests/client_tests.rs +++ b/async-nats/tests/client_tests.rs @@ -837,4 +837,33 @@ mod client { .await .unwrap(); } + + #[tokio::test] + async fn max_reconnects() { + let (tx, mut rx) = tokio::sync::mpsc::unbounded_channel(); + let _client = ConnectOptions::new() + .max_reconnects(5) + .retry_on_initial_connect() + .event_callback(move |event| { + let tx = tx.clone(); + async move { + println!("event: {event}"); + tx.send(event).unwrap(); + } + }) + .connect("localhost:7777") + .await + .unwrap(); + + for _ in 0..5 { + match rx.recv().await.unwrap() { + Event::ClientError(async_nats::ClientError::Other(_)) => (), + other => panic!("unexpected event: {:?}", other), + }; + } + assert_eq!( + rx.recv().await.unwrap(), + Event::ClientError(async_nats::ClientError::MaxReconnects) + ); + } }