Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions .config/nats.dic
Original file line number Diff line number Diff line change
Expand Up @@ -146,3 +146,4 @@ ObjectMetadata
S2
inactive_threshold
max_ack_pending
footgun
35 changes: 27 additions & 8 deletions async-nats/src/connector.rs
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ use crate::AuthError;
use crate::ClientError;
use crate::ClientOp;
use crate::ConnectError;
use crate::ConnectErrorKind;
use crate::ConnectInfo;
use crate::Event;
use crate::Protocol;
Expand Down Expand Up @@ -60,6 +61,7 @@ pub(crate) struct ConnectorOptions {
pub(crate) read_buffer_capacity: u16,
pub(crate) reconnect_delay_callback: Box<dyn Fn(usize) -> Duration + Send + Sync + 'static>,
pub(crate) auth_callback: Option<CallbackArg1<Vec<u8>, Result<Auth, AuthError>>>,
pub(crate) max_reconnects: Option<usize>,
}

/// Maintains a list of servers and establishes connections.
Expand Down Expand Up @@ -100,16 +102,24 @@ impl Connector {
})
}

pub(crate) async fn connect(&mut self) -> (ServerInfo, Connection) {
pub(crate) async fn connect(&mut self) -> Result<(ServerInfo, Connection), ConnectError> {
loop {
match self.try_connect().await {
Ok(inner) => return inner,
Err(error) => {
self.events_tx
.send(Event::ClientError(ClientError::Other(error.to_string())))
.await
.ok();
}
Ok(inner) => return Ok(inner),
Err(error) => match error.kind() {
ConnectErrorKind::MaxReconnects => {
return Err(ConnectError::with_source(
crate::ConnectErrorKind::MaxReconnects,
error,
))
}
other => {
self.events_tx
.send(Event::ClientError(ClientError::Other(other.to_string())))
.await
.ok();
}
},
}
}
}
Expand All @@ -126,6 +136,15 @@ impl Connector {

for (server_addr, _) in servers {
self.attempts += 1;
if let Some(max_reconnects) = self.options.max_reconnects {
if self.attempts > max_reconnects {
self.events_tx
.send(Event::ClientError(ClientError::MaxReconnects))
.await
.ok();
return Err(ConnectError::new(crate::ConnectErrorKind::MaxReconnects));
}
}

let duration = (self.options.reconnect_delay_callback)(self.attempts);

Expand Down
27 changes: 21 additions & 6 deletions async-nats/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -584,7 +584,9 @@ impl ConnectionHandler {
ExitReason::Disconnected(err) => {
debug!(?err, "disconnected");

self.handle_disconnect().await;
if self.handle_disconnect().await.is_err() {
break;
};
debug!("reconnected");
}
ExitReason::Closed => break,
Expand Down Expand Up @@ -796,16 +798,16 @@ impl ConnectionHandler {
}
}

async fn handle_disconnect(&mut self) {
async fn handle_disconnect(&mut self) -> Result<(), ConnectError> {
self.pending_pings = 0;
self.connector.events_tx.try_send(Event::Disconnected).ok();
self.connector.state_tx.send(State::Disconnected).ok();

self.handle_reconnect().await;
self.handle_reconnect().await
}

async fn handle_reconnect(&mut self) {
let (info, connection) = self.connector.connect().await;
async fn handle_reconnect(&mut self) -> Result<(), ConnectError> {
let (info, connection) = self.connector.connect().await?;
self.connection = connection;
let _ = self.info_sender.send(info);

Expand All @@ -829,6 +831,7 @@ impl ConnectionHandler {
}

self.connector.events_tx.try_send(Event::Connected).ok();
Ok(())
}
}

Expand Down Expand Up @@ -874,6 +877,7 @@ pub async fn connect_with_options<A: ToServerAddrs>(
read_buffer_capacity: options.read_buffer_capacity,
reconnect_delay_callback: options.reconnect_delay_callback,
auth_callback: options.auth_callback,
max_reconnects: options.max_reconnects,
},
events_tx,
state_tx,
Expand Down Expand Up @@ -912,7 +916,13 @@ pub async fn connect_with_options<A: ToServerAddrs>(

task::spawn(async move {
if connection.is_none() && options.retry_on_initial_connect {
let (info, connection_ok) = connector.connect().await;
let (info, connection_ok) = match connector.connect().await {
Ok((info, connection)) => (info, connection),
Err(err) => {
error!("connection closed: {}", err);
return;
}
};
info_sender.send(info).ok();
connection = Some(connection_ok);
}
Expand Down Expand Up @@ -1034,6 +1044,8 @@ pub enum ConnectErrorKind {
Tls,
/// Other IO error.
Io,
/// Reached the maximum number of reconnects.
MaxReconnects,
}

impl Display for ConnectErrorKind {
Expand All @@ -1046,6 +1058,7 @@ impl Display for ConnectErrorKind {
Self::TimedOut => write!(f, "timed out"),
Self::Tls => write!(f, "TLS error"),
Self::Io => write!(f, "IO error"),
Self::MaxReconnects => write!(f, "reached maximum number of reconnects"),
}
}
}
Expand Down Expand Up @@ -1224,11 +1237,13 @@ pub enum ServerError {
#[derive(Clone, Debug, Eq, PartialEq)]
pub enum ClientError {
Other(String),
MaxReconnects,
}
impl std::fmt::Display for ClientError {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
match self {
Self::Other(error) => write!(f, "nats: {error}"),
Self::MaxReconnects => write!(f, "nats: max reconnects reached"),
}
}
}
Expand Down
27 changes: 26 additions & 1 deletion async-nats/src/options.rs
Original file line number Diff line number Diff line change
Expand Up @@ -101,7 +101,7 @@ impl Default for ConnectOptions {
no_echo: false,
retry_on_failed_connect: false,
reconnect_buffer_size: 8 * 1024 * 1024,
max_reconnects: Some(60),
max_reconnects: None,
connection_timeout: Duration::from_secs(5),
tls_required: false,
tls_first: false,
Expand Down Expand Up @@ -829,6 +829,31 @@ impl ConnectOptions {
self
}

/// Specifies the number of consecutive reconnect attempts the client will
/// make before giving up. This is useful for preventing zombie services
/// from endlessly reaching the servers, but it can also be a footgun and
/// surprise for users who do not expect that the client can give up
/// entirely.
///
/// Pass `None` or `0` for no limit.
///
/// # Examples
/// ```
/// # #[tokio::main]
/// # async fn main() -> Result<(), async_nats::Error> {
/// async_nats::ConnectOptions::new()
/// .max_reconnects(None)
/// .connect("demo.nats.io")
/// .await?;
/// # Ok(())
/// # }
/// ```
pub fn max_reconnects<T: Into<Option<usize>>>(mut self, max_reconnects: T) -> ConnectOptions {
let val: Option<usize> = max_reconnects.into();
self.max_reconnects = if val == Some(0) { None } else { val };
self
}

/// By default, a server may advertise other servers in the cluster known to it.
/// By setting this option, the client will ignore the advertised servers.
/// This may be useful if the client may not be able to reach them.
Expand Down
29 changes: 29 additions & 0 deletions async-nats/tests/client_tests.rs
Original file line number Diff line number Diff line change
Expand Up @@ -837,4 +837,33 @@ mod client {
.await
.unwrap();
}

#[tokio::test]
async fn max_reconnects() {
let (tx, mut rx) = tokio::sync::mpsc::unbounded_channel();
let _client = ConnectOptions::new()
.max_reconnects(5)
.retry_on_initial_connect()
.event_callback(move |event| {
let tx = tx.clone();
async move {
println!("event: {event}");
tx.send(event).unwrap();
}
})
.connect("localhost:7777")
.await
.unwrap();

for _ in 0..5 {
match rx.recv().await.unwrap() {
Event::ClientError(async_nats::ClientError::Other(_)) => (),
other => panic!("unexpected event: {:?}", other),
};
}
assert_eq!(
rx.recv().await.unwrap(),
Event::ClientError(async_nats::ClientError::MaxReconnects)
);
}
}