diff --git a/.config/nats.dic b/.config/nats.dic index fec16ee97..0a635906d 100644 --- a/.config/nats.dic +++ b/.config/nats.dic @@ -156,3 +156,7 @@ create_consumer_strict_on_stream leafnodes get_stream get_stream_no_info +lifecycle +AtomicU64 +with_deleted +StreamInfoBuilder diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml index 0cd348b0c..1463ede3d 100644 --- a/.github/workflows/test.yml +++ b/.github/workflows/test.yml @@ -164,7 +164,7 @@ jobs: - name: Install msrv Rust on ubuntu-latest id: install-rust - uses: dtolnay/rust-toolchain@1.70.0 + uses: dtolnay/rust-toolchain@1.79.0 - name: Cache the build artifacts uses: Swatinem/rust-cache@v2 with: diff --git a/async-nats/CHANGELOG.md b/async-nats/CHANGELOG.md index 0b8e17b21..e03012477 100644 --- a/async-nats/CHANGELOG.md +++ b/async-nats/CHANGELOG.md @@ -1,5 +1,15 @@ -# v0.36.0 +# v0.37.0 +## Overview +A smaller release containing stats and Watcher improvements. + +## What's Changed +* Add Client stats by @Jarema in https://github.com/nats-io/nats.rs/pull/1314 +* Improve kv::Watcher without messages by @Jarema in https://github.com/nats-io/nats.rs/pull/1321 + +**Full Changelog**: https://github.com/nats-io/nats.rs/compare/async-nats/v0.36.0...async-nats/v0.37.0 + +# v0.36.0 ## Overview This release adds a useful `futures::Sink`, and ability to get `Stream` handle without IO call, among other changes. diff --git a/async-nats/Cargo.toml b/async-nats/Cargo.toml index 97ae22cc7..229ab568d 100644 --- a/async-nats/Cargo.toml +++ b/async-nats/Cargo.toml @@ -1,9 +1,9 @@ [package] name = "async-nats" authors = ["Tomasz Pietrek ", "Casper Beyer "] -version = "0.36.0" +version = "0.37.0" edition = "2021" -rust = "1.74.0" +rust = "1.79.0" description = "A async Rust NATS client" license = "Apache-2.0" documentation = "https://docs.rs/async-nats" @@ -41,6 +41,8 @@ ring = { version = "0.17", optional = true } rand = "0.8" webpki = { package = "rustls-webpki", version = "0.102" } portable-atomic = "1" +tokio-websockets = { version = "0.10", features = ["client", "rand", "rustls-native-roots"], optional = true } +pin-project = "1.0" [dev-dependencies] ring = "0.17" @@ -57,13 +59,13 @@ jsonschema = "0.17.1" # for -Z minimal-versions num = "0.4.1" - [features] default = ["server_2_10", "ring"] # Enables Service API for the client. service = [] -aws-lc-rs = ["dep:aws-lc-rs", "tokio-rustls/aws-lc-rs"] -ring = ["dep:ring", "tokio-rustls/ring"] +websockets = ["dep:tokio-websockets"] +aws-lc-rs = ["dep:aws-lc-rs", "tokio-rustls/aws-lc-rs", "tokio-websockets/aws-lc-rs"] +ring = ["dep:ring", "tokio-rustls/ring", "tokio-websockets/ring"] fips = ["aws-lc-rs", "tokio-rustls/fips"] # All experimental features are part of this feature flag. experimental = ["service"] diff --git a/async-nats/src/auth.rs b/async-nats/src/auth.rs index e04e72455..968c5c6da 100644 --- a/async-nats/src/auth.rs +++ b/async-nats/src/auth.rs @@ -5,7 +5,7 @@ pub struct Auth { pub jwt: Option, pub nkey: Option, pub(crate) signature_callback: Option>>, - pub signature: Option, + pub signature: Option>, pub username: Option, pub password: Option, pub token: Option, diff --git a/async-nats/src/client.rs b/async-nats/src/client.rs index 4ff1140c8..5f85f67c7 100644 --- a/async-nats/src/client.rs +++ b/async-nats/src/client.rs @@ -83,6 +83,7 @@ pub struct Client { inbox_prefix: Arc, request_timeout: Option, max_payload: Arc, + connection_stats: Arc, } impl Sink for Client { @@ -108,6 +109,7 @@ impl Sink for Client { } impl Client { + #[allow(clippy::too_many_arguments)] pub(crate) fn new( info: tokio::sync::watch::Receiver, state: tokio::sync::watch::Receiver, @@ -116,6 +118,7 @@ impl Client { inbox_prefix: String, request_timeout: Option, max_payload: Arc, + statistics: Arc, ) -> Client { let poll_sender = PollSender::new(sender.clone()); Client { @@ -128,9 +131,25 @@ impl Client { inbox_prefix: inbox_prefix.into(), request_timeout, max_payload, + connection_stats: statistics, } } + /// Returns the default timeout for requests set when creating the client. + /// + /// # Examples + /// ```no_run + /// # #[tokio::main] + /// # async fn main() -> Result<(), async_nats::Error> { + /// let client = async_nats::connect("demo.nats.io").await?; + /// println!("default request timeout: {:?}", client.timeout()); + /// # Ok(()) + /// # } + /// ``` + pub fn timeout(&self) -> Option { + self.request_timeout + } + /// Returns last received info from the server. /// /// # Examples @@ -612,6 +631,39 @@ impl Client { Ok(()) } + /// Drains all subscriptions, stops any new messages from being published, and flushes any remaining + /// messages, then closes the connection. Once completed, any associated streams associated with the + /// client will be closed, and further client commands will fail + /// + /// # Examples + /// + /// ```no_run + /// # #[tokio::main] + /// # async fn main() -> Result<(), async_nats::Error> { + /// use futures::StreamExt; + /// let client = async_nats::connect("demo.nats.io").await?; + /// let mut subscription = client.subscribe("events.>").await?; + /// + /// client.drain().await?; + /// + /// # // existing subscriptions are closed and further commands will fail + /// assert!(subscription.next().await.is_none()); + /// client + /// .subscribe("events.>") + /// .await + /// .expect_err("Expected further commands to fail"); + /// + /// # Ok(()) + /// # } + /// ``` + pub async fn drain(&self) -> Result<(), DrainError> { + // Drain all subscriptions + self.sender.send(Command::Drain { sid: None }).await?; + + // Remaining process is handled on the handler-side + Ok(()) + } + /// Returns the current state of the connection. /// /// # Examples @@ -649,6 +701,26 @@ impl Client { .await .map_err(Into::into) } + + /// Returns struct representing statistics of the whole lifecycle of the client. + /// This includes number of bytes sent/received, number of messages sent/received, + /// and number of times the connection was established. + /// As this returns [Arc] with [AtomicU64] fields, it can be safely reused and shared + /// across threads. + /// + /// # Examples + /// ```no_run + /// # #[tokio::main] + /// # async fn main() -> Result<(), async_nats::Error> { + /// let client = async_nats::connect("demo.nats.io").await?; + /// let statistics = client.statistics(); + /// println!("client statistics: {:#?}", statistics); + /// # Ok(()) + /// # } + /// ``` + pub fn statistics(&self) -> Arc { + self.connection_stats.clone() + } } /// Used for building customized requests. @@ -769,6 +841,16 @@ impl From> for SubscribeError { } } +#[derive(Error, Debug)] +#[error("failed to send drain: {0}")] +pub struct DrainError(#[source] crate::Error); + +impl From> for DrainError { + fn from(err: tokio::sync::mpsc::error::SendError) -> Self { + DrainError(Box::new(err)) + } +} + #[derive(Clone, Copy, Debug, PartialEq)] pub enum RequestErrorKind { /// There are services listening on requested subject, but they didn't respond @@ -826,3 +908,19 @@ impl Display for FlushErrorKind { } pub type FlushError = Error; + +/// Represents statistics for the instance of the client throughout its lifecycle. +#[derive(Default, Debug)] +pub struct Statistics { + /// Number of bytes received. This does not include the protocol overhead. + pub in_bytes: AtomicU64, + /// Number of bytes sent. This doe not include the protocol overhead. + pub out_bytes: AtomicU64, + /// Number of messages received. + pub in_messages: AtomicU64, + /// Number of messages sent. + pub out_messages: AtomicU64, + /// Number of times connection was established. + /// Initial connect will be counted as well, then all successful reconnects. + pub connects: AtomicU64, +} diff --git a/async-nats/src/connection.rs b/async-nats/src/connection.rs index 4d7352266..1533d3f9a 100644 --- a/async-nats/src/connection.rs +++ b/async-nats/src/connection.rs @@ -19,15 +19,25 @@ use std::future::{self, Future}; use std::io::IoSlice; use std::pin::Pin; use std::str::{self, FromStr}; +use std::sync::atomic::Ordering; +use std::sync::Arc; use std::task::{Context, Poll}; +#[cfg(feature = "websockets")] +use { + futures::{SinkExt, StreamExt}, + pin_project::pin_project, + tokio::io::ReadBuf, + tokio_websockets::WebSocketStream, +}; + use bytes::{Buf, Bytes, BytesMut}; use tokio::io::{self, AsyncRead, AsyncReadExt, AsyncWrite}; use crate::header::{HeaderMap, HeaderName, IntoHeaderValue}; use crate::status::StatusCode; use crate::subject::Subject; -use crate::{ClientOp, ServerError, ServerOp}; +use crate::{ClientOp, ServerError, ServerOp, Statistics}; /// Soft limit for the amount of bytes in [`Connection::write_buf`] /// and [`Connection::flattened_writes`]. @@ -80,12 +90,17 @@ pub(crate) struct Connection { write_buf_len: usize, flattened_writes: BytesMut, can_flush: bool, + statistics: Arc, } /// Internal representation of the connection. /// Holds connection with NATS Server and communicates with `Client` via channels. impl Connection { - pub(crate) fn new(stream: Box, read_buffer_capacity: usize) -> Self { + pub(crate) fn new( + stream: Box, + read_buffer_capacity: usize, + statistics: Arc, + ) -> Self { Self { stream, read_buf: BytesMut::with_capacity(read_buffer_capacity), @@ -93,6 +108,7 @@ impl Connection { write_buf_len: 0, flattened_writes: BytesMut::new(), can_flush: false, + statistics, } } @@ -407,7 +423,10 @@ impl Connection { Poll::Pending => Poll::Pending, Poll::Ready(Ok(0)) if self.read_buf.is_empty() => Poll::Ready(Ok(None)), Poll::Ready(Ok(0)) => Poll::Ready(Err(io::ErrorKind::ConnectionReset.into())), - Poll::Ready(Ok(_n)) => continue, + Poll::Ready(Ok(n)) => { + self.statistics.in_bytes.add(n as u64, Ordering::Relaxed); + continue; + } Poll::Ready(Err(err)) => Poll::Ready(Err(err)), }; } @@ -544,6 +563,7 @@ impl Connection { match Pin::new(&mut self.stream).poll_write(cx, buf) { Poll::Pending => return Poll::Pending, Poll::Ready(Ok(n)) => { + self.statistics.out_bytes.add(n as u64, Ordering::Relaxed); self.write_buf_len -= n; self.can_flush = true; @@ -564,7 +584,6 @@ impl Connection { } } } - /// Write the internal buffers into the write stream using vectored write operations /// /// Writes [`WRITE_VECTORED_CHUNKS`] at a time. More efficient _if_ @@ -595,6 +614,7 @@ impl Connection { match Pin::new(&mut self.stream).poll_write_vectored(cx, &writes[..writes_len]) { Poll::Pending => return Poll::Pending, Poll::Ready(Ok(mut n)) => { + self.statistics.out_bytes.add(n as u64, Ordering::Relaxed); self.write_buf_len -= n; self.can_flush = true; @@ -671,16 +691,120 @@ impl Connection { } } +#[cfg(feature = "websockets")] +#[pin_project] +pub(crate) struct WebSocketAdapter { + #[pin] + pub(crate) inner: WebSocketStream, + pub(crate) read_buf: BytesMut, +} + +#[cfg(feature = "websockets")] +impl WebSocketAdapter { + pub(crate) fn new(inner: WebSocketStream) -> Self { + Self { + inner, + read_buf: BytesMut::new(), + } + } +} + +#[cfg(feature = "websockets")] +impl AsyncRead for WebSocketAdapter +where + T: AsyncRead + AsyncWrite + Unpin, +{ + fn poll_read( + self: Pin<&mut Self>, + cx: &mut Context<'_>, + buf: &mut ReadBuf<'_>, + ) -> Poll> { + let mut this = self.project(); + + loop { + // If we have data in the read buffer, let's move it to the output buffer. + if !this.read_buf.is_empty() { + let len = std::cmp::min(buf.remaining(), this.read_buf.len()); + buf.put_slice(&this.read_buf.split_to(len)); + return Poll::Ready(Ok(())); + } + + match this.inner.poll_next_unpin(cx) { + Poll::Ready(Some(Ok(message))) => { + this.read_buf.extend_from_slice(message.as_payload()); + } + Poll::Ready(Some(Err(e))) => { + return Poll::Ready(Err(std::io::Error::new(std::io::ErrorKind::Other, e))); + } + Poll::Ready(None) => { + return Poll::Ready(Err(std::io::Error::new( + std::io::ErrorKind::UnexpectedEof, + "WebSocket closed", + ))); + } + Poll::Pending => { + return Poll::Pending; + } + } + } + } +} + +#[cfg(feature = "websockets")] +impl AsyncWrite for WebSocketAdapter +where + T: AsyncRead + AsyncWrite + Unpin, +{ + fn poll_write( + self: Pin<&mut Self>, + cx: &mut Context<'_>, + buf: &[u8], + ) -> Poll> { + let mut this = self.project(); + + let data = buf.to_vec(); + match this.inner.poll_ready_unpin(cx) { + Poll::Ready(Ok(())) => match this + .inner + .start_send_unpin(tokio_websockets::Message::binary(data)) + { + Ok(()) => Poll::Ready(Ok(buf.len())), + Err(e) => Poll::Ready(Err(std::io::Error::new(std::io::ErrorKind::Other, e))), + }, + Poll::Ready(Err(e)) => { + Poll::Ready(Err(std::io::Error::new(std::io::ErrorKind::Other, e))) + } + Poll::Pending => Poll::Pending, + } + } + + fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + self.project() + .inner + .poll_flush_unpin(cx) + .map_err(|e| std::io::Error::new(std::io::ErrorKind::Other, e)) + } + + fn poll_shutdown(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + self.project() + .inner + .poll_close_unpin(cx) + .map_err(|e| std::io::Error::new(std::io::ErrorKind::Other, e)) + } +} + #[cfg(test)] mod read_op { + use std::sync::Arc; + use super::Connection; - use crate::{HeaderMap, ServerError, ServerInfo, ServerOp, StatusCode}; + use crate::{HeaderMap, ServerError, ServerInfo, ServerOp, Statistics, StatusCode}; use tokio::io::{self, AsyncWriteExt}; #[tokio::test] async fn ok() { let (stream, mut server) = io::duplex(128); - let mut connection = Connection::new(Box::new(stream), 0); + let mut connection = Connection::new(Box::new(stream), 0, Arc::new(Statistics::default())); server.write_all(b"+OK\r\n").await.unwrap(); let result = connection.read_op().await.unwrap(); @@ -690,7 +814,7 @@ mod read_op { #[tokio::test] async fn ping() { let (stream, mut server) = io::duplex(128); - let mut connection = Connection::new(Box::new(stream), 0); + let mut connection = Connection::new(Box::new(stream), 0, Arc::new(Statistics::default())); server.write_all(b"PING\r\n").await.unwrap(); let result = connection.read_op().await.unwrap(); @@ -700,7 +824,7 @@ mod read_op { #[tokio::test] async fn pong() { let (stream, mut server) = io::duplex(128); - let mut connection = Connection::new(Box::new(stream), 0); + let mut connection = Connection::new(Box::new(stream), 0, Arc::new(Statistics::default())); server.write_all(b"PONG\r\n").await.unwrap(); let result = connection.read_op().await.unwrap(); @@ -710,7 +834,7 @@ mod read_op { #[tokio::test] async fn info() { let (stream, mut server) = io::duplex(128); - let mut connection = Connection::new(Box::new(stream), 0); + let mut connection = Connection::new(Box::new(stream), 0, Arc::new(Statistics::default())); server.write_all(b"INFO {}\r\n").await.unwrap(); server.flush().await.unwrap(); @@ -737,7 +861,7 @@ mod read_op { #[tokio::test] async fn error() { let (stream, mut server) = io::duplex(128); - let mut connection = Connection::new(Box::new(stream), 0); + let mut connection = Connection::new(Box::new(stream), 0, Arc::new(Statistics::default())); server.write_all(b"INFO {}\r\n").await.unwrap(); let result = connection.read_op().await.unwrap(); @@ -759,7 +883,7 @@ mod read_op { #[tokio::test] async fn message() { let (stream, mut server) = io::duplex(128); - let mut connection = Connection::new(Box::new(stream), 0); + let mut connection = Connection::new(Box::new(stream), 0, Arc::new(Statistics::default())); server .write_all(b"MSG FOO.BAR 9 11\r\nHello World\r\n") @@ -906,7 +1030,7 @@ mod read_op { #[tokio::test] async fn unknown() { let (stream, mut server) = io::duplex(128); - let mut connection = Connection::new(Box::new(stream), 0); + let mut connection = Connection::new(Box::new(stream), 0, Arc::new(Statistics::default())); server.write_all(b"ONE\r\n").await.unwrap(); connection.read_op().await.unwrap_err(); @@ -956,14 +1080,16 @@ mod read_op { #[cfg(test)] mod write_op { + use std::sync::Arc; + use super::Connection; - use crate::{ClientOp, ConnectInfo, HeaderMap, Protocol}; + use crate::{ClientOp, ConnectInfo, HeaderMap, Protocol, Statistics}; use tokio::io::{self, AsyncBufReadExt, BufReader}; #[tokio::test] async fn publish() { let (stream, server) = io::duplex(128); - let mut connection = Connection::new(Box::new(stream), 0); + let mut connection = Connection::new(Box::new(stream), 0, Arc::new(Statistics::default())); connection .easy_write_and_flush( @@ -1032,7 +1158,7 @@ mod write_op { #[tokio::test] async fn subscribe() { let (stream, server) = io::duplex(128); - let mut connection = Connection::new(Box::new(stream), 0); + let mut connection = Connection::new(Box::new(stream), 0, Arc::new(Statistics::default())); connection .easy_write_and_flush( @@ -1071,7 +1197,7 @@ mod write_op { #[tokio::test] async fn unsubscribe() { let (stream, server) = io::duplex(128); - let mut connection = Connection::new(Box::new(stream), 0); + let mut connection = Connection::new(Box::new(stream), 0, Arc::new(Statistics::default())); connection .easy_write_and_flush([ClientOp::Unsubscribe { sid: 11, max: None }].iter()) @@ -1102,7 +1228,7 @@ mod write_op { #[tokio::test] async fn ping() { let (stream, server) = io::duplex(128); - let mut connection = Connection::new(Box::new(stream), 0); + let mut connection = Connection::new(Box::new(stream), 0, Arc::new(Statistics::default())); let mut reader = BufReader::new(server); let mut buffer = String::new(); @@ -1120,7 +1246,7 @@ mod write_op { #[tokio::test] async fn pong() { let (stream, server) = io::duplex(128); - let mut connection = Connection::new(Box::new(stream), 0); + let mut connection = Connection::new(Box::new(stream), 0, Arc::new(Statistics::default())); let mut reader = BufReader::new(server); let mut buffer = String::new(); @@ -1138,7 +1264,7 @@ mod write_op { #[tokio::test] async fn connect() { let (stream, server) = io::duplex(1024); - let mut connection = Connection::new(Box::new(stream), 0); + let mut connection = Connection::new(Box::new(stream), 0, Arc::new(Statistics::default())); let mut reader = BufReader::new(server); let mut buffer = String::new(); diff --git a/async-nats/src/connector.rs b/async-nats/src/connector.rs index 87112c038..9c61a8553 100644 --- a/async-nats/src/connector.rs +++ b/async-nats/src/connector.rs @@ -12,8 +12,11 @@ // limitations under the License. use crate::auth::Auth; +use crate::client::Statistics; use crate::connection::Connection; use crate::connection::State; +#[cfg(feature = "websockets")] +use crate::connection::WebSocketAdapter; use crate::options::CallbackArg1; use crate::tls; use crate::AuthError; @@ -40,6 +43,7 @@ use std::cmp; use std::io; use std::path::PathBuf; use std::sync::atomic::AtomicUsize; +use std::sync::atomic::Ordering; use std::sync::Arc; use std::time::Duration; use tokio::net::TcpStream; @@ -70,6 +74,7 @@ pub(crate) struct Connector { /// A map of servers and number of connect attempts. servers: Vec<(ServerAddr, usize)>, options: ConnectorOptions, + pub(crate) connect_stats: Arc, attempts: usize, pub(crate) events_tx: tokio::sync::mpsc::Sender, pub(crate) state_tx: tokio::sync::watch::Sender, @@ -93,6 +98,7 @@ impl Connector { events_tx: tokio::sync::mpsc::Sender, state_tx: tokio::sync::watch::Sender, max_payload: Arc, + connect_stats: Arc, ) -> Result { let servers = addrs.to_server_addrs()?.map(|addr| (addr, 0)).collect(); @@ -103,13 +109,16 @@ impl Connector { events_tx, state_tx, max_payload, + connect_stats, }) } pub(crate) async fn connect(&mut self) -> Result<(ServerInfo, Connection), ConnectError> { loop { match self.try_connect().await { - Ok(inner) => return Ok(inner), + Ok(inner) => { + return Ok(inner); + } Err(error) => match error.kind() { ConnectErrorKind::MaxReconnects => { return Err(ConnectError::with_source( @@ -161,7 +170,11 @@ impl Connector { .map_err(|err| ConnectError::with_source(crate::ConnectErrorKind::Dns, err))?; for socket_addr in socket_addrs { match self - .try_connect_to(&socket_addr, server_addr.tls_required(), server_addr.host()) + .try_connect_to( + &socket_addr, + server_addr.tls_required(), + server_addr.clone(), + ) .await { Ok((server_info, mut connection)) => { @@ -284,6 +297,7 @@ impl Connector { Some(_) => { tracing::debug!("connected to {}", server_info.port); self.attempts = 0; + self.connect_stats.connects.add(1, Ordering::Relaxed); self.events_tx.send(Event::Connected).await.ok(); self.state_tx.send(State::Connected).ok(); self.max_payload.store( @@ -313,21 +327,76 @@ impl Connector { &self, socket_addr: &SocketAddr, tls_required: bool, - tls_host: &str, + server_addr: ServerAddr, ) -> Result<(ServerInfo, Connection), ConnectError> { - let tcp_stream = tokio::time::timeout( - self.options.connection_timeout, - TcpStream::connect(socket_addr), - ) - .await - .map_err(|_| ConnectError::new(crate::ConnectErrorKind::TimedOut))??; - - tcp_stream.set_nodelay(true)?; + let mut connection = match server_addr.scheme() { + #[cfg(feature = "websockets")] + "ws" => { + let ws = tokio::time::timeout( + self.options.connection_timeout, + tokio_websockets::client::Builder::new() + .uri(format!("{}://{}", server_addr.scheme(), socket_addr).as_str()) + .map_err(|err| { + ConnectError::with_source(crate::ConnectErrorKind::ServerParse, err) + })? + .connect(), + ) + .await + .map_err(|_| ConnectError::new(crate::ConnectErrorKind::TimedOut))? + .map_err(|err| ConnectError::with_source(crate::ConnectErrorKind::Io, err))?; - let mut connection = Connection::new( - Box::new(tcp_stream), - self.options.read_buffer_capacity.into(), - ); + let con = WebSocketAdapter::new(ws.0); + Connection::new(Box::new(con), 0, self.connect_stats.clone()) + } + #[cfg(feature = "websockets")] + "wss" => { + let domain = webpki::types::ServerName::try_from(server_addr.host()) + .map_err(|err| ConnectError::with_source(crate::ConnectErrorKind::Tls, err))?; + let tls_config = + Arc::new(tls::config_tls(&self.options).await.map_err(|err| { + ConnectError::with_source(crate::ConnectErrorKind::Tls, err) + })?); + let tls_connector = tokio_rustls::TlsConnector::from(tls_config); + let ws = tokio::time::timeout( + self.options.connection_timeout, + tokio_websockets::client::Builder::new() + .connector(&tokio_websockets::Connector::Rustls(tls_connector)) + .uri( + format!( + "{}://{}:{}", + server_addr.scheme(), + domain.to_str(), + server_addr.port() + ) + .as_str(), + ) + .map_err(|err| { + ConnectError::with_source(crate::ConnectErrorKind::ServerParse, err) + })? + .connect(), + ) + .await + .map_err(|_| ConnectError::new(crate::ConnectErrorKind::TimedOut))? + .map_err(|err| ConnectError::with_source(crate::ConnectErrorKind::Io, err))?; + let con = WebSocketAdapter::new(ws.0); + Connection::new(Box::new(con), 0, self.connect_stats.clone()) + } + _ => { + let tcp_stream = tokio::time::timeout( + self.options.connection_timeout, + TcpStream::connect(socket_addr), + ) + .await + .map_err(|_| ConnectError::new(crate::ConnectErrorKind::TimedOut))??; + tcp_stream.set_nodelay(true)?; + + Connection::new( + Box::new(tcp_stream), + self.options.read_buffer_capacity.into(), + self.connect_stats.clone(), + ) + } + }; let tls_connection = |connection: Connection| async { let tls_config = Arc::new( @@ -337,20 +406,24 @@ impl Connector { ); let tls_connector = tokio_rustls::TlsConnector::from(tls_config); - let domain = webpki::types::ServerName::try_from(tls_host) + let domain = webpki::types::ServerName::try_from(server_addr.host()) .map_err(|err| ConnectError::with_source(crate::ConnectErrorKind::Tls, err))?; let tls_stream = tls_connector .connect(domain.to_owned(), connection.stream) .await?; - Ok::(Connection::new(Box::new(tls_stream), 0)) + Ok::(Connection::new( + Box::new(tls_stream), + 0, + self.connect_stats.clone(), + )) }; // If `tls_first` was set, establish TLS connection before getting INFO. // There is no point in checking if tls is required, because // the connection has to be be upgraded to TLS anyway as it's different flow. - if self.options.tls_first { + if self.options.tls_first && !server_addr.is_websocket() { connection = tls_connection(connection).await?; } @@ -373,6 +446,7 @@ impl Connector { // If `tls_first` was not set, establish TLS connection if it is required. if !self.options.tls_first + && !server_addr.is_websocket() && (self.options.tls_required || info.tls_required || tls_required) { connection = tls_connection(connection).await?; diff --git a/async-nats/src/header.rs b/async-nats/src/header.rs index 2dec615ed..cc595cb42 100644 --- a/async-nats/src/header.rs +++ b/async-nats/src/header.rs @@ -110,6 +110,10 @@ impl HeaderMap { pub fn is_empty(&self) -> bool { self.inner.is_empty() } + + pub fn len(&self) -> usize { + self.inner.len() + } } impl HeaderMap { diff --git a/async-nats/src/jetstream/context.rs b/async-nats/src/jetstream/context.rs index eb79d243c..ae13e8429 100644 --- a/async-nats/src/jetstream/context.rs +++ b/async-nats/src/jetstream/context.rs @@ -330,7 +330,7 @@ impl Context { /// let client = async_nats::connect("localhost:4222").await?; /// let jetstream = async_nats::jetstream::new(client); /// - /// let stream = jetstream.get_stream("events").await?; + /// let stream = jetstream.get_stream_no_info("events").await?; /// # Ok(()) /// # } /// ``` diff --git a/async-nats/src/jetstream/kv/mod.rs b/async-nats/src/jetstream/kv/mod.rs index 9f48668e0..439a20d91 100644 --- a/async-nats/src/jetstream/kv/mod.rs +++ b/async-nats/src/jetstream/kv/mod.rs @@ -601,6 +601,8 @@ impl Store { })?; Ok(Watch { + no_messages: deliver_policy != DeliverPolicy::New + && consumer.cached_info().num_pending == 0, subscription: consumer.messages().await.map_err(|err| match err.kind() { crate::jetstream::consumer::StreamErrorKind::TimedOut => { WatchError::new(WatchErrorKind::TimedOut) @@ -1072,6 +1074,7 @@ impl Store { /// A structure representing a watch on a key-value bucket, yielding values whenever there are changes. pub struct Watch { + no_messages: bool, seen_current: bool, subscription: super::consumer::push::Ordered, prefix: String, @@ -1085,6 +1088,9 @@ impl futures::Stream for Watch { mut self: std::pin::Pin<&mut Self>, cx: &mut std::task::Context<'_>, ) -> std::task::Poll> { + if self.no_messages { + return Poll::Ready(None); + } match self.subscription.poll_next_unpin(cx) { Poll::Ready(message) => match message { None => Poll::Ready(None), diff --git a/async-nats/src/jetstream/stream.rs b/async-nats/src/jetstream/stream.rs index efd0bb952..691c6dfd4 100755 --- a/async-nats/src/jetstream/stream.rs +++ b/async-nats/src/jetstream/stream.rs @@ -13,9 +13,8 @@ // //! Manage operations on a [Stream], create/delete/update [Consumer]. -#[cfg(feature = "server_2_10")] -use std::collections::HashMap; use std::{ + collections::{self, HashMap}, fmt::{self, Debug, Display}, future::IntoFuture, io::{self, ErrorKind}, @@ -31,7 +30,7 @@ use crate::{ use base64::engine::general_purpose::STANDARD; use base64::engine::Engine; use bytes::Bytes; -use futures::{future::BoxFuture, TryFutureExt}; +use futures::{future::BoxFuture, FutureExt, TryFutureExt}; use serde::{Deserialize, Deserializer, Serialize}; use serde_json::json; use time::{serde::rfc3339, OffsetDateTime}; @@ -192,6 +191,79 @@ impl Stream { } } + /// Retrieves [[Info]] from the server and returns a [[futures::Stream]] that allows + /// iterating over all subjects in the stream fetched via paged API. + /// + /// # Examples + /// + /// ```no_run + /// # #[tokio::main] + /// # async fn main() -> Result<(), async_nats::Error> { + /// use futures::TryStreamExt; + /// let client = async_nats::connect("localhost:4222").await?; + /// let jetstream = async_nats::jetstream::new(client); + /// + /// let mut stream = jetstream.get_stream("events").await?; + /// + /// let mut info = stream.info_with_subjects("events.>").await?; + /// + /// while let Some((subject, count)) = info.try_next().await? { + /// println!("Subject: {} count: {}", subject, count); + /// } + /// # Ok(()) + /// # } + /// ``` + pub async fn info_with_subjects>( + &self, + subjects_filter: F, + ) -> Result { + let subjects_filter = subjects_filter.as_ref().to_string(); + // TODO: validate the subject and decide if this should be a `Subject` + let info = stream_info_with_details( + self.context.clone(), + self.name.clone(), + 0, + false, + subjects_filter.clone(), + ) + .await?; + + Ok(InfoWithSubjects::new( + self.context.clone(), + info, + subjects_filter, + )) + } + + /// Creates a builder that allows to customize `Stream::Info`. + /// + /// # Examples + /// ```no_run + /// # #[tokio::main] + /// # async fn main() -> Result<(), async_nats::Error> { + /// use futures::TryStreamExt; + /// let client = async_nats::connect("localhost:4222").await?; + /// let jetstream = async_nats::jetstream::new(client); + /// + /// let mut stream = jetstream.get_stream("events").await?; + /// + /// let mut info = stream + /// .info_builder() + /// .with_deleted(true) + /// .subjects("events.>") + /// .fetch() + /// .await?; + /// + /// while let Some((subject, count)) = info.try_next().await? { + /// println!("Subject: {} count: {}", subject, count); + /// } + /// # Ok(()) + /// # } + /// ``` + pub fn info_builder(&self) -> StreamInfoBuilder { + StreamInfoBuilder::new(self.context.clone(), self.name.clone()) + } + /// Gets next message for a [Stream]. /// /// Requires a [Stream] with `allow_direct` set to `true`. @@ -494,7 +566,7 @@ impl Stream { /// /// ```no_run /// #[tokio::main] - /// # async fn mains() -> Result<(), async_nats::Error> { + /// # async fn main() -> Result<(), async_nats::Error> { /// use futures::StreamExt; /// use futures::TryStreamExt; /// @@ -550,7 +622,7 @@ impl Stream { /// /// ```no_run /// #[tokio::main] - /// # async fn mains() -> Result<(), async_nats::Error> { + /// # async fn main() -> Result<(), async_nats::Error> { /// use futures::StreamExt; /// use futures::TryStreamExt; /// @@ -995,6 +1067,47 @@ impl Stream { } } +pub struct StreamInfoBuilder { + pub(crate) context: Context, + pub(crate) name: String, + pub(crate) deleted: bool, + pub(crate) subject: String, +} + +impl StreamInfoBuilder { + fn new(context: Context, name: String) -> Self { + Self { + context, + name, + deleted: false, + subject: "".to_string(), + } + } + + pub fn with_deleted(mut self, deleted: bool) -> Self { + self.deleted = deleted; + self + } + + pub fn subjects>(mut self, subject: S) -> Self { + self.subject = subject.into(); + self + } + + pub async fn fetch(self) -> Result { + let info = stream_info_with_details( + self.context.clone(), + self.name.clone(), + 0, + self.deleted, + self.subject.clone(), + ) + .await?; + + Ok(InfoWithSubjects::new(self.context, info, self.subject)) + } +} + /// `StreamConfig` determines the properties for a stream. /// There are sensible defaults for most. If no subjects are /// given the name will be used as the only subject. @@ -1246,6 +1359,122 @@ pub enum StorageType { Memory = 1, } +async fn stream_info_with_details( + context: Context, + stream: String, + offset: usize, + deleted_details: bool, + subjects_filter: String, +) -> Result { + let subject = format!("STREAM.INFO.{}", stream); + + let payload = StreamInfoRequest { + offset, + deleted_details, + subjects_filter, + }; + + let response: Response = context.request(subject, &payload).await?; + + match response { + Response::Ok(info) => Ok(info), + Response::Err { error } => Err(error.into()), + } +} + +type InfoRequest = BoxFuture<'static, Result>; + +#[derive(Debug, Clone, PartialEq, Eq, Serialize)] +pub struct StreamInfoRequest { + offset: usize, + deleted_details: bool, + subjects_filter: String, +} + +pub struct InfoWithSubjects { + stream: String, + context: Context, + pub info: Info, + offset: usize, + subjects: collections::hash_map::IntoIter, + info_request: Option, + subjects_filter: String, + pages_done: bool, +} + +impl InfoWithSubjects { + pub fn new(context: Context, mut info: Info, subject: String) -> Self { + let subjects = info.state.subjects.take().unwrap_or_default(); + let name = info.config.name.clone(); + InfoWithSubjects { + context, + info, + pages_done: subjects.is_empty(), + offset: subjects.len(), + subjects: subjects.into_iter(), + subjects_filter: subject, + stream: name, + info_request: None, + } + } +} + +impl futures::Stream for InfoWithSubjects { + type Item = Result<(String, usize), InfoError>; + + fn poll_next( + mut self: Pin<&mut Self>, + cx: &mut std::task::Context<'_>, + ) -> Poll> { + match self.subjects.next() { + Some((subject, count)) => Poll::Ready(Some(Ok((subject, count)))), + None => { + // If we have already requested all pages, stop the iterator. + if self.pages_done { + return Poll::Ready(None); + } + let stream = self.stream.clone(); + let context = self.context.clone(); + let subjects_filter = self.subjects_filter.clone(); + let offset = self.offset; + match self + .info_request + .get_or_insert_with(|| { + Box::pin(stream_info_with_details( + context, + stream, + offset, + false, + subjects_filter, + )) + }) + .poll_unpin(cx) + { + Poll::Ready(resp) => match resp { + Ok(info) => { + let subjects = info.state.subjects.clone(); + self.offset += subjects.as_ref().map_or_else(|| 0, |s| s.len()); + self.info_request = None; + let subjects = subjects.unwrap_or_default(); + self.subjects = info.state.subjects.unwrap_or_default().into_iter(); + let total = info.paged_info.map(|info| info.total).unwrap_or(0); + if total <= self.offset || subjects.is_empty() { + self.pages_done = true; + } + match self.subjects.next() { + Some((subject, count)) => Poll::Ready(Some(Ok((subject, count)))), + None => Poll::Ready(None), + } + } + Err(err) => Poll::Ready(Some(Err(err))), + }, + Poll::Pending => Poll::Pending, + } + } + } + } +} + /// Shows config and current state for this stream. #[derive(Debug, Deserialize, Clone, PartialEq, Eq)] pub struct Info { @@ -1264,6 +1493,15 @@ pub struct Info { /// Information about sources configs if present. #[serde(default)] pub sources: Vec, + #[serde(flatten)] + paged_info: Option, +} + +#[derive(Debug, Deserialize, Clone, PartialEq, Eq)] +pub struct PagedInfo { + offset: usize, + total: usize, + limit: usize, } #[derive(Deserialize)] @@ -1272,7 +1510,7 @@ pub struct DeleteStatus { } /// information about the given stream. -#[derive(Debug, Deserialize, Clone, Copy, PartialEq, Eq)] +#[derive(Debug, Deserialize, Clone, PartialEq, Eq)] pub struct State { /// The number of messages contained in this stream pub messages: u64, @@ -1292,6 +1530,18 @@ pub struct State { pub last_timestamp: time::OffsetDateTime, /// The number of consumers configured to consume this stream pub consumer_count: usize, + /// The number of subjects in the stream + #[serde(default, rename = "num_subjects")] + pub subjects_count: u64, + /// The number of deleted messages in the stream + #[serde(default, rename = "num_deleted")] + pub deleted_count: Option, + /// The list of deleted subjects from the Stream. + /// This field will be filled only if [[StreamInfoBuilder::with_deleted]] option is set. + #[serde(default)] + pub deleted: Option>, + + pub(crate) subjects: Option>, } /// A raw stream message in the representation it is stored. diff --git a/async-nats/src/lib.rs b/async-nats/src/lib.rs index 3981f172c..69bb26acf 100755 --- a/async-nats/src/lib.rs +++ b/async-nats/src/lib.rs @@ -213,6 +213,7 @@ use std::pin::Pin; use std::slice; use std::str::{self, FromStr}; use std::sync::atomic::AtomicUsize; +use std::sync::atomic::Ordering; use std::sync::Arc; use std::task::{Context, Poll}; use tokio::io::ErrorKind; @@ -251,7 +252,9 @@ mod connector; mod options; pub use auth::Auth; -pub use client::{Client, PublishError, Request, RequestError, RequestErrorKind, SubscribeError}; +pub use client::{ + Client, PublishError, Request, RequestError, RequestErrorKind, Statistics, SubscribeError, +}; pub use options::{AuthError, ConnectOptions}; mod crypto; @@ -375,6 +378,9 @@ pub(crate) enum Command { Flush { observer: oneshot::Sender<()>, }, + Drain { + sid: Option, + }, Reconnect, } @@ -408,6 +414,7 @@ struct Subscription { queue_group: Option, delivered: u64, max: Option, + is_draining: bool, } #[derive(Debug)] @@ -428,6 +435,7 @@ pub(crate) struct ConnectionHandler { ping_interval: Interval, should_reconnect: bool, flush_observers: Vec>, + is_draining: bool, } impl ConnectionHandler { @@ -450,6 +458,7 @@ impl ConnectionHandler { ping_interval, should_reconnect: false, flush_observers: Vec::new(), + is_draining: false, } } @@ -529,6 +538,20 @@ impl ConnectionHandler { } } + // Before handling any commands, drop any subscriptions which are draining + // Note: safe to assume subscription drain has completed at this point, as we would have flushed + // all outgoing UNSUB messages in the previous call to this fn, and we would have processed and + // delivered any remaining messages to the subscription in the loop above. + self.handler.subscriptions.retain(|_, s| !s.is_draining); + + if self.handler.is_draining { + // The entire connection is draining. This means we flushed outgoing messages in the previous + // call to this fn, we handled any remaining messages from the server in the loop above, and + // all subs were drained, so drain is complete and we should exit instead of processing any + // further messages + return Poll::Ready(ExitReason::Closed); + } + // WARNING: after the following loop `handle_command`, // or other functions which call `enqueue_write_op`, // cannot be called anymore. Runtime wakeups won't @@ -627,7 +650,11 @@ impl ConnectionHandler { }; debug!("reconnected"); } - ExitReason::Closed => break, + ExitReason::Closed => { + // Safe to ignore result as we're shutting down anyway + self.connector.events_tx.try_send(Event::Closed).ok(); + break; + } ExitReason::ReconnectRequested => { debug!("reconnect requested"); // Should be ok to ingore error, as that means we are not in connected state. @@ -667,6 +694,11 @@ impl ConnectionHandler { description, length, } => { + self.connector + .connect_stats + .in_messages + .add(1, Ordering::Relaxed); + if let Some(subscription) = self.subscriptions.get_mut(&sid) { let message: Message = Message { subject, @@ -765,6 +797,26 @@ impl ConnectionHandler { Command::Flush { observer } => { self.flush_observers.push(observer); } + Command::Drain { sid } => { + let mut drain_sub = |sid: u64, sub: &mut Subscription| { + sub.is_draining = true; + self.connection + .enqueue_write_op(&ClientOp::Unsubscribe { sid, max: None }); + }; + + if let Some(sid) = sid { + if let Some(sub) = self.subscriptions.get_mut(&sid) { + drain_sub(sid, sub); + } + } else { + // sid isn't set, so drain the whole client + self.connector.events_tx.try_send(Event::Draining).ok(); + self.is_draining = true; + for (&sid, sub) in self.subscriptions.iter_mut() { + drain_sub(sid, sub); + } + } + } Command::Subscribe { sid, subject, @@ -777,6 +829,7 @@ impl ConnectionHandler { max: None, subject: subject.to_owned(), queue_group: queue_group.to_owned(), + is_draining: false, }; self.subscriptions.insert(sid, subscription); @@ -814,13 +867,19 @@ impl ConnectionHandler { senders: HashMap::new(), }) }; + self.connector + .connect_stats + .out_messages + .add(1, Ordering::Relaxed); multiplexer.senders.insert(token.to_owned(), sender); + let respond: Subject = format!("{}{}", multiplexer.prefix, token).into(); + let pub_op = ClientOp::Publish { subject, payload, - respond: Some(format!("{}{}", multiplexer.prefix, token).into()), + respond: Some(respond), headers, }; @@ -833,6 +892,24 @@ impl ConnectionHandler { reply: respond, headers, }) => { + self.connector + .connect_stats + .out_messages + .add(1, Ordering::Relaxed); + + let header_len = headers + .as_ref() + .map(|headers| headers.len()) + .unwrap_or_default(); + + self.connector.connect_stats.out_bytes.add( + (payload.len() + + respond.as_ref().map_or_else(|| 0, |r| r.len()) + + subject.len() + + header_len) as u64, + Ordering::Relaxed, + ); + self.connection.enqueue_write_op(&ClientOp::Publish { subject, payload, @@ -907,6 +984,7 @@ pub async fn connect_with_options( let (state_tx, state_rx) = tokio::sync::watch::channel(State::Pending); // We're setting it to the default server payload size. let max_payload = Arc::new(AtomicUsize::new(1024 * 1024)); + let statistics = Arc::new(Statistics::default()); let mut connector = Connector::new( addrs, @@ -931,6 +1009,7 @@ pub async fn connect_with_options( events_tx, state_tx, max_payload.clone(), + statistics.clone(), ) .map_err(|err| ConnectError::with_source(ConnectErrorKind::ServerParse, err))?; @@ -954,6 +1033,7 @@ pub async fn connect_with_options( options.inbox_prefix, options.request_timeout, max_payload, + statistics, ); task::spawn(async move { @@ -991,6 +1071,8 @@ pub enum Event { Connected, Disconnected, LameDuckMode, + Draining, + Closed, SlowConsumer(u64), ServerError(ServerError), ClientError(ClientError), @@ -1002,6 +1084,8 @@ impl fmt::Display for Event { Event::Connected => write!(f, "connected"), Event::Disconnected => write!(f, "disconnected"), Event::LameDuckMode => write!(f, "lame duck mode detected"), + Event::Draining => write!(f, "draining"), + Event::Closed => write!(f, "closed"), Event::SlowConsumer(sid) => write!(f, "slow consumers for subscription {sid}"), Event::ServerError(err) => write!(f, "server error: {err}"), Event::ClientError(err) => write!(f, "client error: {err}"), @@ -1216,6 +1300,48 @@ impl Subscriber { .await?; Ok(()) } + + /// Unsubscribes immediately but leaves the subscription open to allow any in-flight messages + /// on the subscription to be delivered. The stream will be closed after any remaining messages + /// are delivered + /// + /// # Examples + /// ```no_run + /// # use futures::StreamExt; + /// # #[tokio::main] + /// # async fn main() -> Result<(), async_nats::Error> { + /// let client = async_nats::connect("demo.nats.io").await?; + /// + /// let mut subscriber = client.subscribe("test").await?; + /// + /// tokio::spawn({ + /// let task_client = client.clone(); + /// async move { + /// loop { + /// _ = task_client.publish("test", "data".into()).await; + /// } + /// } + /// }); + /// + /// client.flush().await?; + /// subscriber.drain().await?; + /// + /// while let Some(message) = subscriber.next().await { + /// println!("message received: {:?}", message); + /// } + /// println!("no more messages, unsubscribed"); + /// # Ok(()) + /// # } + /// ``` + pub async fn drain(&mut self) -> Result<(), UnsubscribeError> { + self.sender + .send(Command::Drain { + sid: Some(self.sid), + }) + .await?; + + Ok(()) + } } #[derive(Error, Debug, PartialEq)] @@ -1420,7 +1546,11 @@ impl FromStr for ServerAddr { impl ServerAddr { /// Check if the URL is a valid NATS server address. pub fn from_url(url: Url) -> io::Result { - if url.scheme() != "nats" && url.scheme() != "tls" { + if url.scheme() != "nats" + && url.scheme() != "tls" + && url.scheme() != "ws" + && url.scheme() != "wss" + { return Err(std::io::Error::new( ErrorKind::InvalidInput, format!("invalid scheme for NATS server URL: {}", url.scheme()), @@ -1445,6 +1575,10 @@ impl ServerAddr { self.0.username() != "" } + pub fn scheme(&self) -> &str { + self.0.scheme() + } + /// Returns the host. pub fn host(&self) -> &str { match self.0.host() { @@ -1458,6 +1592,10 @@ impl ServerAddr { } } + pub fn is_websocket(&self) -> bool { + self.0.scheme() == "ws" || self.0.scheme() == "wss" + } + /// Returns the port. pub fn port(&self) -> u16 { self.0.port().unwrap_or(4222) diff --git a/async-nats/src/service/mod.rs b/async-nats/src/service/mod.rs index c1dcdc25f..fc9589542 100644 --- a/async-nats/src/service/mod.rs +++ b/async-nats/src/service/mod.rs @@ -639,7 +639,7 @@ impl Group { queue_group: Z, ) -> Group { Group { - prefix: prefix.to_string(), + prefix: format!("{}.{}", self.prefix, prefix.to_string()), stats: self.stats.clone(), client: self.client.clone(), shutdown_tx: self.shutdown_tx.clone(), @@ -891,3 +891,30 @@ impl std::fmt::Debug for StatsHandler { write!(f, "Stats handler") } } + +#[cfg(test)] +mod tests { + use super::*; + + #[tokio::test] + async fn test_group_with_queue_group() { + let server = nats_server::run_basic_server(); + let client = crate::connect(server.client_url()).await.unwrap(); + + let group = Group { + prefix: "test".to_string(), + stats: Arc::new(Mutex::new(Endpoints { + endpoints: HashMap::new(), + })), + client, + shutdown_tx: tokio::sync::broadcast::channel(1).0, + subjects: Arc::new(Mutex::new(vec![])), + queue_group: "default".to_string(), + }; + + let new_group = group.group_with_queue_group("v1", "custom_queue"); + + assert_eq!(new_group.prefix, "test.v1"); + assert_eq!(new_group.queue_group, "custom_queue"); + } +} diff --git a/async-nats/tests/client_tests.rs b/async-nats/tests/client_tests.rs index f125e9765..0f4374cd5 100644 --- a/async-nats/tests/client_tests.rs +++ b/async-nats/tests/client_tests.rs @@ -22,6 +22,7 @@ mod client { use futures::stream::StreamExt; use std::path::PathBuf; use std::str::FromStr; + use std::sync::atomic::Ordering; use std::time::Duration; #[tokio::test] @@ -886,6 +887,28 @@ mod client { .unwrap(); } + #[tokio::test] + async fn custom_auth_callback_jwt() { + let server = nats_server::run_server("tests/configs/jwt.conf"); + + ConnectOptions::with_auth_callback(move |nonce| async move { + let mut auth = async_nats::Auth::new(); + auth.jwt = Some("eyJ0eXAiOiJKV1QiLCJhbGciOiJlZDI1NTE5LW5rZXkifQ.".to_owned() + + "eyJqdGkiOiJMN1dBT1hJU0tPSUZNM1QyNEhMQ09ENzJRT1czQkNVWEdETjRKVU1SSUtHTlQ3RzdZVFRRIiwiaWF0IjoxNjUxNzkwOTgyLCJpc3MiOiJBRFRRUzdaQ0ZWSk5XNTcyNkdPWVhXNVRTQ1pGTklRU0hLMlpHWVVCQ0Q1RDc3T1ROTE9PS1pPWiIsIm5hbWUiOiJUZXN0V" + + "XNlciIsInN1YiI6IlVBRkhHNkZVRDJVVTRTREZWQUZVTDVMREZPMlhNNFdZTTc2VU5YVFBKWUpLN0VFTVlSQkhUMlZFIiwibmF0cyI6eyJwdWIiOnt9LCJzdWIiOnt9LCJzdWJzIjotMSwiZGF0YSI6LTEsInBheWxvYWQiOi0xLCJ0eXBlIjoidXNlciIsInZlcnNpb24iOjJ9fQ." + + "bp2-Jsy33l4ayF7Ku1MNdJby4WiMKUrG-rSVYGBusAtV3xP4EdCa-zhSNUaBVIL3uYPPCQYCEoM1pCUdOnoJBg"); + + let key_pair = nkeys::KeyPair::from_seed("SUACH75SWCM5D2JMJM6EKLR2WDARVGZT4QC6LX3AGHSWOMVAKERABBBRWM").unwrap(); + let sign = key_pair.sign(&nonce).map_err(async_nats::AuthError::new)?; + auth.signature = Some(sign); + + Ok(auth) + }) + .connect(server.client_url()) + .await + .unwrap(); + } + #[tokio::test] async fn max_reconnects() { let (tx, mut rx) = tokio::sync::mpsc::unbounded_channel(); @@ -931,4 +954,230 @@ mod client { .await .unwrap(); } + + #[tokio::test] + async fn client_statistics() { + let server = nats_server::run_basic_server(); + + let (tx, mut rx) = tokio::sync::mpsc::channel(1); + let client = async_nats::ConnectOptions::new() + .event_callback(move |event| { + let tx = tx.clone(); + async move { + if let Event::Connected = event { + tx.send(()).await.unwrap(); + } + } + }) + .connect(server.client_url()) + .await + .unwrap(); + + tokio::time::timeout(Duration::from_secs(5), rx.recv()) + .await + .unwrap() + .unwrap(); + let stats = client.statistics(); + + assert_eq!(stats.in_messages.load(Ordering::Relaxed), 0); + assert_eq!(stats.out_messages.load(Ordering::Relaxed), 0); + assert!(stats.in_bytes.load(Ordering::Relaxed) != 0); + assert!(stats.out_bytes.load(Ordering::Relaxed) != 0); + assert_eq!(stats.connects.load(Ordering::Relaxed), 1); + + let mut responder = client.subscribe("request").await.unwrap(); + tokio::task::spawn({ + let client = client.clone(); + async move { + let msg = responder.next().await.unwrap(); + client + .publish(msg.reply.unwrap(), "response".into()) + .await + .unwrap(); + } + }); + client.request("request", "data".into()).await.unwrap(); + + let mut sub = client.subscribe("test").await.unwrap(); + client.publish("test", "data".into()).await.unwrap(); + client.publish("test", "data".into()).await.unwrap(); + sub.next().await.unwrap(); + sub.next().await.unwrap(); + + client.flush().await.unwrap(); + client.force_reconnect().await.unwrap(); + + tokio::time::timeout(Duration::from_secs(5), rx.recv()) + .await + .unwrap() + .unwrap(); + + assert_eq!(stats.in_messages.load(Ordering::Relaxed), 4); + assert_eq!(stats.out_messages.load(Ordering::Relaxed), 4); + assert!(stats.in_bytes.load(Ordering::Relaxed) != 0); + assert!(stats.out_bytes.load(Ordering::Relaxed) != 0); + assert_eq!(stats.connects.load(Ordering::Relaxed), 2); + } + + #[tokio::test] + async fn client_timeout() { + let server = nats_server::run_basic_server(); + let client = async_nats::connect(server.client_url()).await.unwrap(); + + assert_eq!(client.timeout(), Some(Duration::from_secs(10))); + + let client = async_nats::ConnectOptions::new() + .request_timeout(Some(Duration::from_secs(30))) + .connect(server.client_url()) + .await + .unwrap(); + + assert_eq!(client.timeout(), Some(Duration::from_secs(30))); + + let client = async_nats::ConnectOptions::new() + .request_timeout(None) + .connect(server.client_url()) + .await + .unwrap(); + + assert_eq!(client.timeout(), None); + } + + #[tokio::test] + async fn drain_subscription_basic() { + use std::error::Error; + let server = nats_server::run_basic_server(); + let client = async_nats::connect(server.client_url()).await.unwrap(); + + let mut sub = client.subscribe("test").await.unwrap(); + + // publish some data + client.publish("test", "data".into()).await.unwrap(); + client.flush().await.unwrap(); + + // confirm we receive that data + assert!(sub.next().await.is_some()); + + // now drain the subscription + let result = sub.drain().await; + match result { + Ok(()) => println!("ok"), + Err(err) => { + println!("error: {}", err); + println!("source: {:?}", err.source()) + } + } + + // assert the stream is closed after draining + assert!(sub.next().await.is_none()); + + // confirm we can still reconnect and send messages on a new subscription + let mut sub2 = client.subscribe("test2").await.unwrap(); + client.publish("test2", "data".into()).await.unwrap(); + client.flush().await.unwrap(); + assert!(sub2.next().await.is_some()); + } + + #[tokio::test] + async fn drain_subscription_unsub_after() { + let server = nats_server::run_basic_server(); + let client = async_nats::connect(server.client_url()).await.unwrap(); + + let mut sub = client.subscribe("test").await.unwrap(); + + sub.unsubscribe_after(120) + .await + .expect("Expected to send unsub_after"); + + // publish some data + client.publish("test", "data".into()).await.unwrap(); + client.publish("test", "data".into()).await.unwrap(); + client.flush().await.unwrap(); + + // Send the drain command + sub.drain().await.expect("Expected to drain the sub"); + + // we should receive all published data then close immediately + assert!(sub.next().await.is_some()); + assert!(sub.next().await.is_some()); + assert!(sub.next().await.is_none()); + } + + #[tokio::test] + async fn drain_subscription_active() { + let server = nats_server::run_basic_server(); + let client = async_nats::connect(server.client_url()).await.unwrap(); + + // spawn a task to constantly write to the subscription + let constant_writer = tokio::spawn({ + let client = client.clone(); + async move { + loop { + client.publish("test", "data".into()).await.unwrap(); + client.flush().await.unwrap(); + } + } + }); + + let mut sub = client.subscribe("test").await.unwrap(); + + // confirm we receive some data + assert!(sub.next().await.is_some()); + + // now drain the subscription + sub.drain().await.unwrap(); + + // yield to the runtime to ensure constant_writer gets a chance to publish a message or two to the subject + tokio::time::sleep(Duration::from_millis(1)).await; + + // assert the subscription stream is closed after draining + let sleep_fut = async move { while sub.next().await.is_some() {} }; + tokio::time::timeout(Duration::from_secs(10), sleep_fut) + .await + .expect("Expected stream to drain within 10s"); + + // assert constant_writer doesn't fail to write after the only sub is drained (i.e. client operations still work fine) + assert!(!constant_writer.is_finished()); + + // confirm we can still reconnect and receive messages on the same subject on a new subscription + let mut sub2 = client.subscribe("test").await.unwrap(); + assert!(sub2.next().await.is_some()); + } + + #[tokio::test] + async fn drain_client_basic() { + let server = nats_server::run_basic_server(); + let client = async_nats::connect(server.client_url()).await.unwrap(); + + let mut sub = client.subscribe("test").await.unwrap(); + + // publish some data + client.publish("test", "data".into()).await.unwrap(); + client.flush().await.unwrap(); + + // confirm we receive that data + assert!(sub.next().await.is_some()); + + // now drain the client + client.drain().await.unwrap(); + + // assert the sub's stream is closed after draining + assert!(sub.next().await.is_none()); + + // we should not be able to perform any more operations on a drained client + client + .subscribe("test2") + .await + .expect_err("Expected client to be drained"); + + client + .publish("test", "data".into()) + .await + .expect_err("Expected client to be drained"); + + // we should be able to connect with a new client + let _client2 = async_nats::connect(server.client_url()) + .await + .expect("Expected to be able to create a new client"); + } } diff --git a/async-nats/tests/configs/ws.conf b/async-nats/tests/configs/ws.conf new file mode 100644 index 000000000..842e36d23 --- /dev/null +++ b/async-nats/tests/configs/ws.conf @@ -0,0 +1,5 @@ +jetstream {} +websocket { + port: 8444 + no_tls: true +} diff --git a/async-nats/tests/configs/ws_tls.conf b/async-nats/tests/configs/ws_tls.conf new file mode 100644 index 000000000..f202900ce --- /dev/null +++ b/async-nats/tests/configs/ws_tls.conf @@ -0,0 +1,15 @@ +authorization { + user: derek + password: porkchop + timeout: 1 +} + +websocket { + tls { + +cert_file: "./tests/configs/certs/server-cert.pem" + key_file: "./tests/configs/certs/server-key.pem" + ca_file: "./tests/configs/certs/rootCA.pem" + } + port: 8445 +} diff --git a/async-nats/tests/jetstream_tests.rs b/async-nats/tests/jetstream_tests.rs index 2c6965ec6..f2e107bc6 100755 --- a/async-nats/tests/jetstream_tests.rs +++ b/async-nats/tests/jetstream_tests.rs @@ -36,14 +36,16 @@ mod jetstream { self, push, AckPolicy, DeliverPolicy, Info, OrderedPullConsumer, OrderedPushConsumer, PullConsumer, PushConsumer, ReplayPolicy, }; - use async_nats::jetstream::context::{GetStreamByNameErrorKind, Publish, PublishErrorKind}; + use async_nats::jetstream::context::{ + GetStreamByNameErrorKind, Publish, PublishAckFuture, PublishErrorKind, + }; use async_nats::jetstream::response::Response; + #[cfg(feature = "server_2_10")] + use async_nats::jetstream::stream::ConsumerLimits; use async_nats::jetstream::stream::{ self, ConsumerCreateStrictErrorKind, ConsumerUpdateErrorKind, DirectGetErrorKind, DiscardPolicy, StorageType, }; - #[cfg(feature = "server_2_10")] - use async_nats::jetstream::stream::{Compression, ConsumerLimits, Source, SubjectTransform}; use async_nats::jetstream::AckKind; use async_nats::ConnectOptions; use futures::stream::{StreamExt, TryStreamExt}; @@ -936,6 +938,14 @@ mod jetstream { .stream_sequence, 3 ); + + let info = stream + .info_builder() + .with_deleted(true) + .fetch() + .await + .unwrap(); + assert_eq!(info.info.state.deleted_count, Some(1)); } #[tokio::test] @@ -3260,7 +3270,9 @@ mod jetstream { .await .unwrap(); - assert_eq!(stream.info().await.unwrap().config.metadata, metadata); + let info = stream.info().await.unwrap(); + assert_eq!(info.config.metadata.get("key"), metadata.get("key")); + assert_eq!(info.config.metadata.get("other"), metadata.get("other")); let mut consumer = stream .create_consumer(async_nats::jetstream::consumer::pull::Config { @@ -3271,7 +3283,9 @@ mod jetstream { .await .unwrap(); - assert_eq!(consumer.info().await.unwrap().config.metadata, metadata); + let info = consumer.info().await.unwrap(); + assert_eq!(info.config.metadata.get("key"), metadata.get("key")); + assert_eq!(info.config.metadata.get("other"), metadata.get("other")); } #[tokio::test] @@ -3533,80 +3547,6 @@ mod jetstream { .unwrap(); } - #[cfg(feature = "server_2_10")] - #[tokio::test] - async fn stream_config() { - let server = nats_server::run_server("tests/configs/jetstream.conf"); - let client = async_nats::connect(server.client_url()).await.unwrap(); - - let jetstream = async_nats::jetstream::new(client); - - let config = async_nats::jetstream::stream::Config { - name: "EVENTS".to_string(), - max_bytes: 1024 * 1024, - max_messages: 1_000_000, - max_messages_per_subject: 100, - discard: DiscardPolicy::New, - discard_new_per_subject: true, - subjects: vec!["events.>".to_string()], - retention: stream::RetentionPolicy::WorkQueue, - max_consumers: 10, - max_age: Duration::from_secs(900), - max_message_size: 1024 * 1024, - storage: StorageType::Memory, - num_replicas: 1, - no_ack: true, - duplicate_window: Duration::from_secs(90), - template_owner: "".to_string(), - sealed: false, - description: Some("A Stream".to_string()), - allow_rollup: true, - deny_delete: false, - deny_purge: false, - republish: Some(stream::Republish { - source: "data.>".to_string(), - destination: "dest.>".to_string(), - headers_only: true, - }), - allow_direct: true, - mirror_direct: false, - mirror: None, - sources: Some(vec![Source { - name: "source_one_of_many".to_string(), - start_sequence: Some(5), - start_time: Some(OffsetDateTime::now_utc()), - filter_subject: Some("filter".to_string()), - external: Some(stream::External { - api_prefix: "API.PREFIX".to_string(), - delivery_prefix: Some("delivery_prefix".to_string()), - }), - domain: None, - subject_transforms: vec![SubjectTransform { - source: "source".to_string(), - destination: "dest".to_string(), - }], - }]), - metadata: HashMap::from([("key".to_string(), "value".to_string())]), - subject_transform: Some(SubjectTransform { - source: "source".to_string(), - destination: "dest".to_string(), - }), - compression: Some(Compression::S2), - consumer_limits: Some(ConsumerLimits { - inactive_threshold: Duration::from_secs(120), - max_ack_pending: 150, - }), - first_sequence: Some(505), - placement: Some(stream::Placement { - cluster: Some("CLUSTER".to_string()), - tags: vec!["tag".to_string()], - }), - }; - - let mut stream = jetstream.create_stream(config.clone()).await.unwrap(); - assert_eq!(config, stream.info().await.unwrap().config); - } - #[tokio::test] async fn limits() { let server = nats_server::run_server("tests/configs/jetstream.conf"); @@ -3772,4 +3712,63 @@ mod jetstream { let err = jetstream.stream_by_subject("foo").await.unwrap_err(); assert_eq!(err.kind(), GetStreamByNameErrorKind::NotFound); } + + #[tokio::test] + async fn stream_subjects() { + let server = nats_server::run_server("tests/configs/jetstream.conf"); + let client = async_nats::connect(server.client_url()).await.unwrap(); + + let jetstream = async_nats::jetstream::new(client); + + let stream = jetstream + .create_stream(stream::Config { + name: "events".to_string(), + subjects: vec!["events.>".to_string()], + ..Default::default() + }) + .await + .unwrap(); + + let (tx, mut rx) = tokio::sync::mpsc::channel::(1000); + + let (done_tx, done_rx) = tokio::sync::oneshot::channel(); + tokio::task::spawn(async move { + while let Some(ack) = rx.recv().await { + ack.await.unwrap(); + } + done_tx.send(()).unwrap(); + }); + + for i in 0..220_000 { + let ack = jetstream + .publish(format!("events.{i}"), "data".into()) + .await + .unwrap(); + tx.send(ack).await.unwrap(); + } + drop(tx); + done_rx.await.unwrap(); + + let info = stream.info_with_subjects("events.>").await.unwrap(); + + let i = info.info.clone(); + let count = info.count().await; + println!("messages: {:?}", i.state.messages); + println!("count: {count}"); + assert!(count.eq(&220_000)); + + let info = stream + .info_builder() + .subjects("events.>") + .with_deleted(true) + .fetch() + .await + .unwrap(); + + let i = info.info.clone(); + let count = info.count().await; + println!("messages: {:?}", i.state.messages); + println!("count: {count}"); + assert!(count.eq(&220_000)); + } } diff --git a/async-nats/tests/kv_tests.rs b/async-nats/tests/kv_tests.rs index 1901ca092..19b670478 100644 --- a/async-nats/tests/kv_tests.rs +++ b/async-nats/tests/kv_tests.rs @@ -532,6 +532,29 @@ mod kv { } } } + + #[tokio::test] + async fn watch_no_messages() { + let server = nats_server::run_server("tests/configs/jetstream.conf"); + let client = async_nats::connect(server.client_url()).await.unwrap(); + + let context = async_nats::jetstream::new(client); + let kv = context + .create_key_value(async_nats::jetstream::kv::Config { + bucket: "history".to_string(), + description: "test_description".to_string(), + history: 15, + storage: StorageType::File, + num_replicas: 1, + ..Default::default() + }) + .await + .unwrap(); + + let mut watcher = kv.watch_with_history("foo").await.unwrap(); + assert!(watcher.next().await.is_none()); + } + #[tokio::test] async fn watch() { let server = nats_server::run_server("tests/configs/jetstream.conf"); diff --git a/async-nats/tests/websocket_test.rs b/async-nats/tests/websocket_test.rs new file mode 100644 index 000000000..e3230cc28 --- /dev/null +++ b/async-nats/tests/websocket_test.rs @@ -0,0 +1,78 @@ +// Copyright 2024 The NATS Authors +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#[cfg(feature = "websockets")] +mod websockets { + use std::path::PathBuf; + + use futures::StreamExt; + + #[tokio::test] + async fn core() { + let _server = nats_server::run_server("tests/configs/ws.conf"); + tokio::time::sleep(std::time::Duration::from_secs(2)).await; + let client = async_nats::ConnectOptions::new() + .retry_on_initial_connect() + .connect("ws://localhost:8444") + .await + .unwrap(); + + // Simple pub/sub + let mut sub = client.subscribe("foo").await.unwrap(); + client.publish("foo", "hello".into()).await.unwrap(); + assert_eq!(sub.next().await.unwrap().payload, "hello"); + + // Large messages + let payload = bytes::Bytes::from(vec![22; 1024 * 1024]); + + let mut sub = client.subscribe("foo").await.unwrap().take(10); + for _ in 0..10 { + client.publish("foo", payload.clone()).await.unwrap(); + } + while let Some(msg) = sub.next().await { + assert_eq!(msg.payload, payload); + } + + // Request/reply + let mut requests = client.subscribe("foo").await.unwrap(); + tokio::task::spawn({ + let client = client.clone(); + async move { + let request = requests.next().await.unwrap(); + client + .publish(request.reply.unwrap(), request.payload) + .await + .unwrap(); + } + }); + let response = client.request("foo", "hello".into()).await.unwrap(); + assert_eq!(response.payload, "hello"); + } + + #[tokio::test] + async fn tls() { + let path = PathBuf::from(env!("CARGO_MANIFEST_DIR")); + let _server = nats_server::run_server("tests/configs/ws_tls.conf"); + tokio::time::sleep(std::time::Duration::from_secs(2)).await; + let client = async_nats::ConnectOptions::new() + .user_and_password("derek".into(), "porkchop".into()) + .add_root_certificates(path.join("tests/configs/certs/rootCA.pem")) + .connect("wss://localhost:8445") + .await + .unwrap(); + + let mut sub = client.subscribe("foo").await.unwrap(); + client.publish("foo", "hello".into()).await.unwrap(); + assert_eq!(sub.next().await.unwrap().payload, "hello"); + } +}