From eabbde7312f0fdd86b26dafd430d8a82260f9f3b Mon Sep 17 00:00:00 2001 From: Linfeng Qian Date: Sat, 18 Nov 2023 19:12:02 +0800 Subject: [PATCH 1/3] refactor: remove glommio runtime --- README-CN.md | 4 +- README.md | 4 +- akasa-core/Cargo.toml | 3 - akasa-core/src/server/mod.rs | 2 - akasa-core/src/server/rt_glommio.rs | 229 ---------------------------- akasa/src/main.rs | 23 +-- 6 files changed, 4 insertions(+), 261 deletions(-) delete mode 100644 akasa-core/src/server/rt_glommio.rs diff --git a/README-CN.md b/README-CN.md index a9156ec..57c3382 100644 --- a/README-CN.md +++ b/README-CN.md @@ -3,14 +3,13 @@ Akasa 是一个 Rust 写的高性能,低延迟,高度可扩展的 MQTT 服务器。 -Akasa 用 [glommio][glommio] 来实现高性能低延迟的网络 IO. 它底层的 MQTT 协议消息编解码器 ([mqtt-proto][mqtt-proto]) 是为了高性能和 async 环境而精心设计实现的。 +它底层的 MQTT 协议消息编解码器 ([mqtt-proto][mqtt-proto]) 是为了高性能和 async 环境而精心设计实现的。 ## 特性 - [x] 完全支持 MQTT v3.1/v3.1.1/v5.0 - [x] 支持 TLS (包括双向认证) - [x] 支持 WebSocket (包括 TLS 支持) - [x] 支持 [Proxy Protocol V2][proxy-protocol] -- [x] 使用 `io_uring` ([glommio][glommio]) 来实现高性能低延迟 IO (非 Linux 环境可以用 tokio) - [x] 使用 [Hook trait][hook-trait] 来扩展服务器 - [x] 用一个密码文件来支持简单的认证 - [ ] 基于 Raft 的服务器集群 (*敬请期待*) @@ -95,7 +94,6 @@ Akasa 会有一个企业版本,企业版中的额外功能包括: [mqtt-proto]: https://github.com/akasamq/mqtt-proto [mqtt-proto-fuzz]: https://github.com/akasamq/mqtt-proto/tree/master/fuzz [proxy-protocol]: https://www.haproxy.org/download/1.8/doc/proxy-protocol.txt -[glommio]: https://github.com/DataDog/glommio [bsl]: https://mariadb.com/bsl-faq-mariadb/ [hook-trait]: https://github.com/akasamq/akasa/blob/5ade2d788d9a919671f81b01d720155caf8e4e2d/akasa-core/src/hook.rs#L43 [tensorflow]: https://blog.tensorflow.org/2020/09/supercharging-tensorflowjs-webassembly.html diff --git a/README.md b/README.md index bc440bc..364835f 100644 --- a/README.md +++ b/README.md @@ -6,14 +6,13 @@ English | [简体中文](README-CN.md) Akasa is a high performance, low latency and high extendable MQTT server in Rust. -It uses [glommio][glommio] for high performance and low latency network IO. The underlying MQTT protocol message codec ([mqtt-proto][mqtt-proto]) is carefully crafted for high performance and async environment. +The underlying MQTT protocol message codec ([mqtt-proto][mqtt-proto]) is carefully crafted for high performance and async environment. ## Features - [x] Full support MQTT v3.1/v3.1.1/v5.0 - [x] Support TLS (include two-way authentication) - [x] Support WebSocket (include TLS support) - [x] Support [Proxy Protocol V2][proxy-protocol] -- [x] Use `io_uring` ([glommio][glommio]) for high performance low latency IO (can use tokio on non-Linux OS) - [x] Use a [Hook trait][hook-trait] to extend the server - [x] Simple password file based authentication - [ ] Raft based cluster (*coming soon*) @@ -100,7 +99,6 @@ Akasa will have an enterprise edition. 
In this edition, it provides:
 [mqtt-proto]: https://github.com/akasamq/mqtt-proto
 [mqtt-proto-fuzz]: https://github.com/akasamq/mqtt-proto/tree/master/fuzz
 [proxy-protocol]: https://www.haproxy.org/download/1.8/doc/proxy-protocol.txt
-[glommio]: https://github.com/DataDog/glommio
 [bsl]: https://mariadb.com/bsl-faq-mariadb/
 [hook-trait]: https://github.com/akasamq/akasa/blob/5ade2d788d9a919671f81b01d720155caf8e4e2d/akasa-core/src/hook.rs#L43
 [tensorflow]: https://blog.tensorflow.org/2020/09/supercharging-tensorflowjs-webassembly.html
diff --git a/akasa-core/Cargo.toml b/akasa-core/Cargo.toml
index 5a46cd6..bbae8dd 100644
--- a/akasa-core/Cargo.toml
+++ b/akasa-core/Cargo.toml
@@ -37,9 +37,6 @@ crc32c = "0.6.3"
 openssl = "0.10.51"
 async-tungstenite = "0.21.0"
 
-[target.'cfg(target_os = "linux")'.dependencies]
-glommio = { version = "0.8.0" }
-
 [dev-dependencies]
 futures-sink = "0.3.26"
 tokio-util = "0.7.7"
diff --git a/akasa-core/src/server/mod.rs b/akasa-core/src/server/mod.rs
index a06b8b4..17040b8 100644
--- a/akasa-core/src/server/mod.rs
+++ b/akasa-core/src/server/mod.rs
@@ -1,7 +1,5 @@
 mod io_compat;
 mod proxy;
-#[cfg(target_os = "linux")]
-pub mod rt_glommio;
 pub mod rt_tokio;
 #[allow(dead_code)]
 mod tls;
diff --git a/akasa-core/src/server/rt_glommio.rs b/akasa-core/src/server/rt_glommio.rs
deleted file mode 100644
index 32bd0d3..0000000
--- a/akasa-core/src/server/rt_glommio.rs
+++ /dev/null
@@ -1,229 +0,0 @@
-use std::future::Future;
-use std::io;
-use std::os::unix::io::AsRawFd;
-use std::rc::Rc;
-use std::sync::Arc;
-use std::time::Duration;
-
-use glommio::{
-    net::TcpListener,
-    spawn_local,
-    timer::{sleep, TimerActionRepeat},
-    CpuSet, Latency, LocalExecutorPoolBuilder, PoolPlacement, Shares, TaskQueueHandle,
-};
-
-use super::{build_tls_context, handle_accept, ConnectionArgs};
-use crate::config::{Listener, ProxyMode, TlsListener};
-use crate::hook::Hook;
-use crate::state::{Executor, GlobalState};
-
-pub fn start<H>(hook_handler: H, global: Arc<GlobalState>) -> io::Result<()>
-where
-    H: Hook + Clone + Send + Sync + 'static,
-{
-    let cpu_set = CpuSet::online().expect("online cpus");
-    let cpu_num = num_cpus::get();
-    let placement = PoolPlacement::MaxSpread(cpu_num, Some(cpu_set));
-
-    let mqtts_tls_acceptor = global
-        .config
-        .listeners
-        .mqtts
-        .as_ref()
-        .map(|listener| {
-            log::info!("Building TLS context for mqtts...");
-            build_tls_context(listener)
-        })
-        .transpose()?;
-    let wss_tls_acceptor = global
-        .config
-        .listeners
-        .wss
-        .as_ref()
-        .map(|listener| {
-            log::info!("Building TLS context for wss...");
-            build_tls_context(listener)
-        })
-        .transpose()?;
-
-    LocalExecutorPoolBuilder::new(placement)
-        .on_all_shards(move || async move {
-            let id = glommio::executor().id();
-            // Do clean up tasks, such as:
-            //   * kick out keep alive timeout connections
-            let gc_queue = glommio::executor().create_task_queue(
-                Shares::default(),
-                Latency::Matters(Duration::from_secs(15)),
-                "gc",
-            );
-            let executor = Rc::new(GlommioExecutor::new(id, gc_queue));
-
-            let listeners = &global.config.listeners;
-            let tasks: Vec<_> = [
-                listeners
-                    .mqtt
-                    .as_ref()
-                    .map(|Listener { addr, proxy_mode }| ConnectionArgs {
-                        addr: *addr,
-                        proxy: proxy_mode.is_some(),
-                        proxy_tls_termination: *proxy_mode == Some(ProxyMode::TlsTermination),
-                        websocket: false,
-                        tls_acceptor: None,
-                    }),
-                listeners
-                    .mqtts
-                    .as_ref()
-                    .map(|TlsListener { addr, proxy, .. }| ConnectionArgs {
-                        addr: *addr,
-                        proxy: *proxy,
-                        proxy_tls_termination: false,
-                        websocket: false,
-                        tls_acceptor: mqtts_tls_acceptor.map(Into::into),
-                    }),
-                listeners
-                    .ws
-                    .as_ref()
-                    .map(|Listener { addr, proxy_mode }| ConnectionArgs {
-                        addr: *addr,
-                        proxy: proxy_mode.is_some(),
-                        proxy_tls_termination: *proxy_mode == Some(ProxyMode::TlsTermination),
-                        websocket: true,
-                        tls_acceptor: None,
-                    }),
-                listeners
-                    .wss
-                    .as_ref()
-                    .map(|TlsListener { addr, proxy, .. }| ConnectionArgs {
-                        addr: *addr,
-                        proxy: *proxy,
-                        proxy_tls_termination: false,
-                        websocket: true,
-                        tls_acceptor: wss_tls_acceptor.map(Into::into),
-                    }),
-            ]
-            .into_iter()
-            .flatten()
-            .map(|conn_args| {
-                let global = Arc::clone(&global);
-                let hook_handler = hook_handler.clone();
-                let executor = Rc::clone(&executor);
-                spawn_local(async move {
-                    loop {
-                        let global = Arc::clone(&global);
-                        let hook_handler = hook_handler.clone();
-                        let executor = Rc::clone(&executor);
-                        if let Err(err) =
-                            listen(conn_args.clone(), hook_handler, executor, global).await
-                        {
-                            log::error!("Listen error: {:?}", err);
-                            sleep(Duration::from_secs(1)).await;
-                        }
-                    }
-                })
-                .detach()
-            })
-            .collect();
-
-            if tasks.is_empty() {
-                log::error!("No binding address in config");
-            }
-            for task in tasks {
-                let _ = task.await;
-            }
-        })
-        .expect("executor pool")
-        .join_all();
-    Ok(())
-}
-
-async fn listen<H: Hook + Clone + Send + Sync + 'static>(
-    conn_args: ConnectionArgs,
-    hook_handler: H,
-    executor: Rc<GlommioExecutor>,
-    global: Arc<GlobalState>,
-) -> io::Result<()> {
-    let addr = conn_args.addr;
-    let listener = TcpListener::bind(addr)?;
-    let listen_type = match (conn_args.websocket, conn_args.tls_acceptor.is_some()) {
-        (false, false) => "mqtt",
-        (false, true) => "mqtts",
-        (true, false) => "ws",
-        (true, true) => "wss",
-    };
-    let listen_type = if conn_args.proxy {
-        format!("{listen_type}(proxy)")
-    } else {
-        listen_type.to_owned()
-    };
-    log::info!("Listen {listen_type}@{addr} success! (glommio)");
-    loop {
-        let conn = listener.accept().await?.buffered();
-        let conn_args = conn_args.clone();
-        let fd = conn.as_raw_fd();
-        let peer = conn.peer_addr()?;
-        log::debug!("executor {:03}, #{} {} connected", executor.id(), fd, peer);
-        spawn_local({
-            let hook_handler = hook_handler.clone();
-            let executor = Rc::clone(&executor);
-            let global = Arc::clone(&global);
-            async move {
-                let _ = handle_accept(
-                    conn,
-                    conn_args,
-                    peer,
-                    hook_handler,
-                    executor,
-                    Arc::clone(&global),
-                )
-                .await;
-            }
-        })
-        .detach();
-    }
-}
-
-struct GlommioExecutor {
-    id: usize,
-    gc_queue: TaskQueueHandle,
-}
-
-impl GlommioExecutor {
-    fn new(id: usize, gc_queue: TaskQueueHandle) -> GlommioExecutor {
-        GlommioExecutor { id, gc_queue }
-    }
-}
-
-impl Executor for GlommioExecutor {
-    fn id(&self) -> usize {
-        self.id
-    }
-
-    fn spawn_local<F>(&self, future: F)
-    where
-        F: Future + Send + 'static,
-        F::Output: Send + 'static,
-    {
-        spawn_local(future).detach();
-    }
-
-    fn spawn_sleep<F>(&self, duration: Duration, task: F)
-    where
-        F: Future + Send + 'static,
-    {
-        spawn_local(async move {
-            sleep(duration).await;
-            task.await;
-        })
-        .detach();
-    }
-
-    fn spawn_interval<G, F>(&self, action_gen: G) -> io::Result<()>
-    where
-        G: (Fn() -> F) + Send + Sync + 'static,
-        F: Future<Output = Option<Duration>> + Send + 'static,
-    {
-        TimerActionRepeat::repeat_into(action_gen, self.gc_queue)
-            .map(|_| ())
-            .map_err(|_err| io::Error::from(io::ErrorKind::Other))
-    }
-}
diff --git a/akasa/src/main.rs b/akasa/src/main.rs
index c765094..e1152e6 100644
--- a/akasa/src/main.rs
+++ b/akasa/src/main.rs
@@ -32,14 +32,6 @@ enum Commands {
         /// The config file path
         #[clap(long, value_name = "FILE")]
         config: PathBuf,
-
-        /// Async runtime
-        #[cfg(target_os = "linux")]
-        #[clap(long, default_value_t = Runtime::Glommio, value_enum)]
-        runtime: Runtime,
-        #[cfg(not(target_os = "linux"))]
-        #[clap(long, default_value_t = Runtime::Tokio, value_enum)]
-        runtime: Runtime,
     },
 
     /// Generate default config to stdout
@@ -84,13 +76,6 @@ enum Commands {
     },
 }
 
-#[derive(ValueEnum, Clone, Debug)]
-enum Runtime {
-    #[cfg(target_os = "linux")]
-    Glommio,
-    Tokio,
-}
-
 #[derive(ValueEnum, Clone, Debug)]
 enum HashAlgorithm {
     Sha256,
@@ -106,7 +91,7 @@ fn main() -> anyhow::Result<()> {
     log::debug!("{:#?}", cli);
 
     match cli.command {
-        Commands::Start { config, runtime } => {
+        Commands::Start { config } => {
             let config: Config = {
                 let content = fs::read_to_string(config)?;
                 serde_yaml::from_str(&content)
@@ -129,11 +114,7 @@ fn main() -> anyhow::Result<()> {
             let mut global_state = GlobalState::new(config);
             global_state.auth_passwords = auth_passwords;
             let global = Arc::new(global_state);
-            match runtime {
-                #[cfg(target_os = "linux")]
-                Runtime::Glommio => server::rt_glommio::start(hook_handler, global)?,
-                Runtime::Tokio => server::rt_tokio::start(hook_handler, global)?,
-            }
+            server::rt_tokio::start(hook_handler, global)?;
         }
         Commands::DefaultConfig { allow_anonymous } => {
             let config = if allow_anonymous {

From 132ee3e8f39a132b09105d2b12d71e5743bf32d7 Mon Sep 17 00:00:00 2001
From: Linfeng Qian
Date: Sat, 18 Nov 2023 20:04:30 +0800
Subject: [PATCH 2/3] refactor: update mqtt-proto

---
 akasa-core/Cargo.toml | 3 +-
 akasa-core/src/protocols/mqtt/online_loop.rs | 6 +-
 akasa-core/src/protocols/mqtt/v3/message.rs | 6 +-
 .../src/protocols/mqtt/v3/packet/common.rs | 2 +-
 .../src/protocols/mqtt/v3/packet/connect.rs | 2 +-
 akasa-core/src/protocols/mqtt/v5/message.rs | 6 +-
 .../src/protocols/mqtt/v5/packet/common.rs | 2 +-
 .../src/protocols/mqtt/v5/packet/connect.rs | 2 +-
akasa-core/src/server/io_compat.rs | 42 --- akasa-core/src/server/mod.rs | 59 ++- akasa-core/src/server/proxy.rs | 2 +- akasa-core/src/server/{rt_tokio.rs => rt.rs} | 5 +- akasa-core/src/server/tls-test-keys/cert.pem | 21 -- akasa-core/src/server/tls-test-keys/key.pem | 28 -- akasa-core/src/server/tls.rs | 337 ------------------ akasa-core/src/tests/utils.rs | 18 +- akasa/src/main.rs | 2 +- docs/chinese/getting-started.md | 4 +- docs/english/getting-started.md | 4 +- 19 files changed, 57 insertions(+), 494 deletions(-) delete mode 100644 akasa-core/src/server/io_compat.rs rename akasa-core/src/server/{rt_tokio.rs => rt.rs} (97%) delete mode 100644 akasa-core/src/server/tls-test-keys/cert.pem delete mode 100644 akasa-core/src/server/tls-test-keys/key.pem delete mode 100644 akasa-core/src/server/tls.rs diff --git a/akasa-core/Cargo.toml b/akasa-core/Cargo.toml index bbae8dd..dee872e 100644 --- a/akasa-core/Cargo.toml +++ b/akasa-core/Cargo.toml @@ -23,6 +23,8 @@ parking_lot = "0.12.1" serde = { version = "1.0.147", features = ["derive"] } thiserror = "1.0.38" tokio = { version = "1.23.0", features = ["full"] } +tokio-tungstenite = "0.20.1" +tokio-openssl = "0.6.3" uuid = { version = "1.2.2", features = ["v4"] } rand = { version = "0.8.5", features = ["getrandom"] } ahash = "0.8.3" @@ -35,7 +37,6 @@ base64 = "0.21.0" ring = "0.16" crc32c = "0.6.3" openssl = "0.10.51" -async-tungstenite = "0.21.0" [dev-dependencies] futures-sink = "0.3.26" diff --git a/akasa-core/src/protocols/mqtt/online_loop.rs b/akasa-core/src/protocols/mqtt/online_loop.rs index 5ec4174..caad7d6 100644 --- a/akasa-core/src/protocols/mqtt/online_loop.rs +++ b/akasa-core/src/protocols/mqtt/online_loop.rs @@ -11,13 +11,11 @@ use flume::{ r#async::{RecvStream, SendSink}, Sender, }; -use futures_lite::{ - io::{AsyncRead, AsyncWrite}, - Stream, -}; +use futures_lite::Stream; use futures_sink::Sink; use hashbrown::HashMap; use mqtt_proto::{v3, v5, GenericPollPacket, GenericPollPacketState, PollHeader, QoS, VarBytes}; +use tokio::io::{AsyncRead, AsyncWrite}; use crate::hook::{handle_request, Hook, HookAction, HookRequest, HookResponse}; use crate::state::{ClientId, ClientReceiver, ControlMessage, GlobalState, NormalMessage}; diff --git a/akasa-core/src/protocols/mqtt/v3/message.rs b/akasa-core/src/protocols/mqtt/v3/message.rs index d9f9f61..8aaa958 100644 --- a/akasa-core/src/protocols/mqtt/v3/message.rs +++ b/akasa-core/src/protocols/mqtt/v3/message.rs @@ -5,10 +5,7 @@ use std::net::SocketAddr; use std::sync::Arc; use flume::{Receiver, Sender}; -use futures_lite::{ - io::{AsyncRead, AsyncWrite}, - FutureExt, -}; +use futures_lite::FutureExt; use hashbrown::HashMap; use mqtt_proto::{ v3::{ @@ -17,6 +14,7 @@ use mqtt_proto::{ }, Error, Pid, Protocol, QoS, QosPid, }; +use tokio::io::{AsyncRead, AsyncWrite}; use crate::hook::{ handle_request, Hook, HookAction, HookRequest, HookResponse, LockedHookContext, PublishAction, diff --git a/akasa-core/src/protocols/mqtt/v3/packet/common.rs b/akasa-core/src/protocols/mqtt/v3/packet/common.rs index a638aaa..6f23a52 100644 --- a/akasa-core/src/protocols/mqtt/v3/packet/common.rs +++ b/akasa-core/src/protocols/mqtt/v3/packet/common.rs @@ -1,11 +1,11 @@ use std::io; use std::time::Instant; -use futures_lite::io::AsyncWrite; use mqtt_proto::{ v3::{Packet, Publish}, QoS, QosPid, }; +use tokio::io::AsyncWrite; use crate::protocols::mqtt::{get_unix_ts, PendingPacketStatus}; use crate::state::ClientId; diff --git a/akasa-core/src/protocols/mqtt/v3/packet/connect.rs 
b/akasa-core/src/protocols/mqtt/v3/packet/connect.rs
index 6654fa4..8349d48 100644
--- a/akasa-core/src/protocols/mqtt/v3/packet/connect.rs
+++ b/akasa-core/src/protocols/mqtt/v3/packet/connect.rs
@@ -2,11 +2,11 @@ use std::io;
 use std::sync::Arc;
 use std::time::Instant;
 
-use futures_lite::io::AsyncWrite;
 use mqtt_proto::{
     v3::{Connack, Connect, ConnectReturnCode},
     Protocol,
 };
+use tokio::io::AsyncWrite;
 
 use crate::protocols::mqtt::{check_password, start_keep_alive_timer};
 use crate::state::{AddClientReceipt, ClientReceiver, Executor, GlobalState};
diff --git a/akasa-core/src/protocols/mqtt/v5/message.rs b/akasa-core/src/protocols/mqtt/v5/message.rs
index eed6b37..0dc4380 100644
--- a/akasa-core/src/protocols/mqtt/v5/message.rs
+++ b/akasa-core/src/protocols/mqtt/v5/message.rs
@@ -7,10 +7,7 @@ use std::time::{Duration, Instant};
 
 use bytes::Bytes;
 use flume::{Receiver, Sender};
-use futures_lite::{
-    io::{AsyncRead, AsyncWrite},
-    FutureExt,
-};
+use futures_lite::FutureExt;
 use hashbrown::HashMap;
 use mqtt_proto::{
     v5::{
@@ -20,6 +17,7 @@ use mqtt_proto::{
     },
     Error, Pid, Protocol, QoS, QosPid,
 };
+use tokio::io::{AsyncRead, AsyncWrite};
 
 use crate::hook::{
     handle_request, Hook, HookAction, HookRequest, HookResponse, LockedHookContext, PublishAction,
diff --git a/akasa-core/src/protocols/mqtt/v5/packet/common.rs b/akasa-core/src/protocols/mqtt/v5/packet/common.rs
index a018c57..eb9b88c 100644
--- a/akasa-core/src/protocols/mqtt/v5/packet/common.rs
+++ b/akasa-core/src/protocols/mqtt/v5/packet/common.rs
@@ -3,7 +3,6 @@ use std::io;
 use std::sync::Arc;
 use std::time::Instant;
 
-use futures_lite::io::AsyncWrite;
 use mqtt_proto::{
     v5::{
         Connack, ConnackProperties, ConnectReasonCode, Disconnect, DisconnectProperties,
@@ -11,6 +10,7 @@ use mqtt_proto::{
     },
     QoS, QosPid,
 };
+use tokio::io::AsyncWrite;
 
 use crate::protocols::mqtt::{get_unix_ts, PendingPacketStatus};
 use crate::state::ClientId;
diff --git a/akasa-core/src/protocols/mqtt/v5/packet/connect.rs b/akasa-core/src/protocols/mqtt/v5/packet/connect.rs
index e2a4336..200a5d6 100644
--- a/akasa-core/src/protocols/mqtt/v5/packet/connect.rs
+++ b/akasa-core/src/protocols/mqtt/v5/packet/connect.rs
@@ -3,7 +3,6 @@ use std::sync::Arc;
 use std::time::Instant;
 
 use bytes::Bytes;
-use futures_lite::io::AsyncWrite;
 use mqtt_proto::{
     v5::{
         Auth, AuthProperties, AuthReasonCode, Connack, ConnackProperties, Connect,
@@ -12,6 +11,7 @@ use mqtt_proto::{
     QoS,
 };
 use scram::server::{AuthenticationStatus, ScramServer};
+use tokio::io::AsyncWrite;
 
 use crate::config::SaslMechanism;
 use crate::protocols::mqtt::{check_password, start_keep_alive_timer};
diff --git a/akasa-core/src/server/io_compat.rs b/akasa-core/src/server/io_compat.rs
deleted file mode 100644
index dd06f0e..0000000
--- a/akasa-core/src/server/io_compat.rs
+++ /dev/null
@@ -1,42 +0,0 @@
-use std::io;
-use std::pin::Pin;
-use std::task::{Context, Poll};
-
-use futures_lite::io::{AsyncRead, AsyncWrite};
-use tokio::io::{self as tokio_io, ReadBuf};
-
-pub struct IoWrapper<S>(S);
-
-impl<S> IoWrapper<S> {
-    pub fn new(inner: S) -> IoWrapper<S> {
-        IoWrapper(inner)
-    }
-}
-
-impl<S: tokio_io::AsyncRead + Unpin> AsyncRead for IoWrapper<S> {
-    fn poll_read(
-        mut self: Pin<&mut Self>,
-        cx: &mut Context<'_>,
-        buf: &mut [u8],
-    ) -> Poll<io::Result<usize>> {
-        let mut read_buf = ReadBuf::new(buf);
-        tokio_io::AsyncRead::poll_read(Pin::new(&mut self.0), cx, &mut read_buf)
-            .map_ok(|()| read_buf.capacity() - read_buf.remaining())
-    }
-}
-
-impl<S: tokio_io::AsyncWrite + Unpin> AsyncWrite for IoWrapper<S> {
-    fn poll_write(
-        mut self: Pin<&mut Self>,
-        cx: &mut Context<'_>,
-        buf: &[u8],
-    ) -> Poll<io::Result<usize>> {
-        tokio_io::AsyncWrite::poll_write(Pin::new(&mut self.0), cx, buf)
-    }
-    fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
-        tokio_io::AsyncWrite::poll_flush(Pin::new(&mut self.0), cx)
-    }
-    fn poll_close(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
-        tokio_io::AsyncWrite::poll_shutdown(Pin::new(&mut self.0), cx)
-    }
-}
diff --git a/akasa-core/src/server/mod.rs b/akasa-core/src/server/mod.rs
index 17040b8..5929ed0 100644
--- a/akasa-core/src/server/mod.rs
+++ b/akasa-core/src/server/mod.rs
@@ -1,8 +1,5 @@
-mod io_compat;
 mod proxy;
-pub mod rt_tokio;
-#[allow(dead_code)]
-mod tls;
+pub mod rt;
 
 use std::cmp;
 use std::io;
@@ -12,28 +9,26 @@
 use std::sync::Arc;
 use std::task::{Context, Poll};
 use std::time::Duration;
 
-use async_tungstenite::{
-    tungstenite::{http, Message},
-    WebSocketStream,
-};
 use flume::bounded;
-use futures_lite::{
-    io::{AsyncRead, AsyncWrite},
-    FutureExt, Stream,
-};
+use futures_lite::{FutureExt, Stream};
 use futures_sink::Sink;
 use futures_util::TryFutureExt;
 use mqtt_proto::{decode_raw_header, v3, v5, Error, Protocol};
 use openssl::ssl::{NameType, Ssl, SslAcceptor, SslFiletype, SslMethod, SslVerifyMode};
+use tokio::io::{AsyncRead, AsyncWrite, ReadBuf};
+use tokio_openssl::SslStream;
+use tokio_tungstenite::{
+    accept_hdr_async,
+    tungstenite::{http, Message},
+    WebSocketStream,
+};
 
 use crate::config::TlsListener;
 use crate::hook::Hook;
 use crate::protocols::mqtt;
 use crate::state::{Executor, GlobalState};
-use io_compat::IoWrapper;
 use proxy::{parse_header, Addresses};
-use tls::SslStream;
 
 const CONNECT_TIMEOUT_SECS: u64 = 5;
 
@@ -128,7 +123,7 @@
 
     // Handle WebSocket
     let mut ws_wrapper = if conn_args.websocket {
-        let stream = match async_tungstenite::accept_hdr_async(
+        let stream = match accept_hdr_async(
             tls_wrapper,
             |req: &http::Request<_>, mut resp: http::Response<_>| {
                 if let Some(protocol) = req.headers().get("Sec-WebSocket-Protocol") {
@@ -269,8 +264,8 @@ impl<C: AsyncRead + AsyncWrite + Unpin> AsyncRead for TlsWrapper<C> {
     fn poll_read(
         self: Pin<&mut Self>,
         cx: &mut Context<'_>,
-        buf: &mut [u8],
-    ) -> Poll<io::Result<usize>> {
+        buf: &mut ReadBuf,
+    ) -> Poll<io::Result<()>> {
         match self.get_mut() {
             TlsWrapper::Raw(conn) => Pin::new(conn).poll_read(cx, buf),
             TlsWrapper::Tls(tls_stream) => Pin::new(tls_stream).poll_read(cx, buf),
@@ -295,10 +290,10 @@ impl<C: AsyncRead + AsyncWrite + Unpin> AsyncWrite for TlsWrapper<C> {
             TlsWrapper::Tls(tls_stream) => Pin::new(tls_stream).poll_flush(cx),
         }
     }
-    fn poll_close(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
+    fn poll_shutdown(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
         match self.get_mut() {
-            TlsWrapper::Raw(conn) => Pin::new(conn).poll_close(cx),
-            TlsWrapper::Tls(tls_stream) => Pin::new(tls_stream).poll_close(cx),
+            TlsWrapper::Raw(conn) => Pin::new(conn).poll_shutdown(cx),
+            TlsWrapper::Tls(tls_stream) => Pin::new(tls_stream).poll_shutdown(cx),
         }
     }
 }
@@ -350,8 +345,8 @@ impl<C: AsyncRead + AsyncWrite + Unpin> AsyncRead for WebSocketWrapper<C> {
     fn poll_read(
         self: Pin<&mut Self>,
         cx: &mut Context<'_>,
-        buf: &mut [u8],
-    ) -> Poll<io::Result<usize>> {
+        buf: &mut ReadBuf,
+    ) -> Poll<io::Result<()>> {
         match self.get_mut() {
             WebSocketWrapper::Raw(conn) => Pin::new(conn).poll_read(cx, buf),
             WebSocketWrapper::WebSocket {
@@ -361,9 +356,9 @@ impl<C: AsyncRead + AsyncWrite + Unpin> AsyncRead for WebSocketWrapper<C> {
                 pending_pong,
                 closed,
             } => {
-                fn copy_data(buf: &mut [u8], data: &[u8], data_idx: &mut usize) -> usize {
-                    let amt = cmp::min(data.len() - *data_idx, buf.len());
-                    buf[0..amt].copy_from_slice(&data[*data_idx..*data_idx + amt]);
+                fn copy_data(buf: &mut ReadBuf, data: &[u8], data_idx: &mut usize) -> usize {
+                    let amt = 
cmp::min(data.len() - *data_idx, buf.remaining()); + buf.put_slice(&data[*data_idx..*data_idx + amt]); *data_idx += amt; amt } @@ -371,7 +366,8 @@ impl AsyncRead for WebSocketWrapper { ws_send_pong(stream, pending_pong, cx)?; } if *read_data_idx < read_data.len() { - return Poll::Ready(Ok(copy_data(buf, read_data, read_data_idx))); + copy_data(buf, read_data, read_data_idx); + return Poll::Ready(Ok(())); } loop { match Pin::new(&mut *stream).poll_next(cx) { @@ -382,9 +378,10 @@ impl AsyncRead for WebSocketWrapper { } *read_data = bin; *read_data_idx = 0; - return Poll::Ready(Ok(copy_data(buf, read_data, read_data_idx))); + copy_data(buf, read_data, read_data_idx); + return Poll::Ready(Ok(())); } - Message::Close(_) => return Poll::Ready(Ok(0)), + Message::Close(_) => return Poll::Ready(Ok(())), Message::Ping(data) => { *pending_pong = Some(data); ws_send_pong(stream, pending_pong, cx)?; @@ -406,7 +403,7 @@ impl AsyncRead for WebSocketWrapper { log::debug!("WebSocket read error: {:?}", err); return Poll::Ready(Err(io::ErrorKind::BrokenPipe.into())); } - Poll::Ready(None) => return Poll::Ready(Ok(0)), + Poll::Ready(None) => return Poll::Ready(Ok(())), Poll::Pending => return Poll::Pending, } } @@ -462,9 +459,9 @@ impl AsyncWrite for WebSocketWrapper { } } - fn poll_close(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + fn poll_shutdown(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { match self.get_mut() { - WebSocketWrapper::Raw(conn) => Pin::new(conn).poll_close(cx), + WebSocketWrapper::Raw(conn) => Pin::new(conn).poll_shutdown(cx), WebSocketWrapper::WebSocket { stream, closed, .. } => { if !*closed { let mut sink = Pin::new(&mut *stream); diff --git a/akasa-core/src/server/proxy.rs b/akasa-core/src/server/proxy.rs index 9a5f894..f204723 100644 --- a/akasa-core/src/server/proxy.rs +++ b/akasa-core/src/server/proxy.rs @@ -11,7 +11,7 @@ use std::io; use std::net::{Ipv4Addr, Ipv6Addr}; -use futures_lite::io::{AsyncRead, AsyncReadExt}; +use tokio::io::{AsyncRead, AsyncReadExt}; /// The prefix of the PROXY protocol header. const PROTOCOL_PREFIX: &[u8] = b"\r\n\r\n\0\r\nQUIT\n"; diff --git a/akasa-core/src/server/rt_tokio.rs b/akasa-core/src/server/rt.rs similarity index 97% rename from akasa-core/src/server/rt_tokio.rs rename to akasa-core/src/server/rt.rs index 779bc83..c43ca37 100644 --- a/akasa-core/src/server/rt_tokio.rs +++ b/akasa-core/src/server/rt.rs @@ -5,7 +5,7 @@ use std::time::Duration; use tokio::{net::TcpListener, runtime::Runtime}; -use super::{build_tls_context, handle_accept, ConnectionArgs, IoWrapper}; +use super::{build_tls_context, handle_accept, ConnectionArgs}; use crate::config::{Listener, ProxyMode, TlsListener}; use crate::hook::Hook; use crate::state::{Executor, GlobalState}; @@ -138,14 +138,13 @@ async fn listen { - stream: S, - context: usize, -} - -impl fmt::Debug for StreamWrapper -where - S: fmt::Debug, -{ - fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result { - fmt::Debug::fmt(&self.stream, fmt) - } -} - -impl StreamWrapper { - /// # Safety - /// - /// Must be called with `context` set to a valid pointer to a live `Context` object, and the - /// wrapper must be pinned in memory. 
- unsafe fn parts(&mut self) -> (Pin<&mut S>, &mut Context<'_>) { - debug_assert_ne!(self.context, 0); - let stream = Pin::new_unchecked(&mut self.stream); - let context = &mut *(self.context as *mut _); - (stream, context) - } -} - -impl Read for StreamWrapper -where - S: AsyncRead, -{ - fn read(&mut self, buf: &mut [u8]) -> io::Result { - let (stream, cx) = unsafe { self.parts() }; - match stream.poll_read(cx, buf)? { - Poll::Ready(nread) => Ok(nread), - Poll::Pending => Err(io::Error::from(io::ErrorKind::WouldBlock)), - } - } -} - -impl Write for StreamWrapper -where - S: AsyncWrite, -{ - fn write(&mut self, buf: &[u8]) -> io::Result { - let (stream, cx) = unsafe { self.parts() }; - match stream.poll_write(cx, buf) { - Poll::Ready(r) => r, - Poll::Pending => Err(io::Error::from(io::ErrorKind::WouldBlock)), - } - } - - fn flush(&mut self) -> io::Result<()> { - let (stream, cx) = unsafe { self.parts() }; - match stream.poll_flush(cx) { - Poll::Ready(r) => r, - Poll::Pending => Err(io::Error::from(io::ErrorKind::WouldBlock)), - } - } -} - -fn cvt(r: io::Result) -> Poll> { - match r { - Ok(v) => Poll::Ready(Ok(v)), - Err(ref e) if e.kind() == io::ErrorKind::WouldBlock => Poll::Pending, - Err(e) => Poll::Ready(Err(e)), - } -} - -fn cvt_ossl(r: Result) -> Poll> { - match r { - Ok(v) => Poll::Ready(Ok(v)), - Err(e) => match e.code() { - ErrorCode::WANT_READ | ErrorCode::WANT_WRITE => Poll::Pending, - _ => Poll::Ready(Err(e)), - }, - } -} - -/// An asynchronous version of [`openssl::ssl::SslStream`]. -#[derive(Debug)] -pub struct SslStream(ssl::SslStream>); - -impl SslStream -where - S: AsyncRead + AsyncWrite, -{ - /// Like [`SslStream::new`](ssl::SslStream::new). - pub fn new(ssl: Ssl, stream: S) -> Result { - ssl::SslStream::new(ssl, StreamWrapper { stream, context: 0 }).map(SslStream) - } - - /// Like [`SslStream::connect`](ssl::SslStream::connect). - pub fn poll_connect( - self: Pin<&mut Self>, - cx: &mut Context<'_>, - ) -> Poll> { - self.with_context(cx, |s| cvt_ossl(s.connect())) - } - - /// A convenience method wrapping [`poll_connect`](Self::poll_connect). - pub async fn connect(mut self: Pin<&mut Self>) -> Result<(), ssl::Error> { - future::poll_fn(|cx| self.as_mut().poll_connect(cx)).await - } - - /// Like [`SslStream::accept`](ssl::SslStream::accept). - pub fn poll_accept(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { - self.with_context(cx, |s| cvt_ossl(s.accept())) - } - - /// A convenience method wrapping [`poll_accept`](Self::poll_accept). - pub async fn accept(mut self: Pin<&mut Self>) -> Result<(), ssl::Error> { - future::poll_fn(|cx| self.as_mut().poll_accept(cx)).await - } - - /// Like [`SslStream::do_handshake`](ssl::SslStream::do_handshake). - pub fn poll_do_handshake( - self: Pin<&mut Self>, - cx: &mut Context<'_>, - ) -> Poll> { - self.with_context(cx, |s| cvt_ossl(s.do_handshake())) - } - - /// A convenience method wrapping [`poll_do_handshake`](Self::poll_do_handshake). - pub async fn do_handshake(mut self: Pin<&mut Self>) -> Result<(), ssl::Error> { - future::poll_fn(|cx| self.as_mut().poll_do_handshake(cx)).await - } - - /// Like [`SslStream::read_early_data`](ssl::SslStream::read_early_data). - #[cfg(ossl111)] - pub fn poll_read_early_data( - self: Pin<&mut Self>, - cx: &mut Context<'_>, - buf: &mut [u8], - ) -> Poll> { - self.with_context(cx, |s| cvt_ossl(s.read_early_data(buf))) - } - - /// A convenience method wrapping [`poll_read_early_data`](Self::poll_read_early_data). 
- #[cfg(ossl111)] - pub async fn read_early_data( - mut self: Pin<&mut Self>, - buf: &mut [u8], - ) -> Result { - future::poll_fn(|cx| self.as_mut().poll_read_early_data(cx, buf)).await - } - - /// Like [`SslStream::write_early_data`](ssl::SslStream::write_early_data). - #[cfg(ossl111)] - pub fn poll_write_early_data( - self: Pin<&mut Self>, - cx: &mut Context<'_>, - buf: &[u8], - ) -> Poll> { - self.with_context(cx, |s| cvt_ossl(s.write_early_data(buf))) - } - - /// A convenience method wrapping [`poll_write_early_data`](Self::poll_write_early_data). - #[cfg(ossl111)] - pub async fn write_early_data( - mut self: Pin<&mut Self>, - buf: &[u8], - ) -> Result { - future::poll_fn(|cx| self.as_mut().poll_write_early_data(cx, buf)).await - } -} - -impl SslStream { - /// Returns a shared reference to the `Ssl` object associated with this stream. - pub fn ssl(&self) -> &SslRef { - self.0.ssl() - } - - /// Returns a shared reference to the underlying stream. - pub fn get_ref(&self) -> &S { - &self.0.get_ref().stream - } - - /// Returns a mutable reference to the underlying stream. - pub fn get_mut(&mut self) -> &mut S { - &mut self.0.get_mut().stream - } - - /// Returns a pinned mutable reference to the underlying stream. - pub fn get_pin_mut(self: Pin<&mut Self>) -> Pin<&mut S> { - unsafe { Pin::new_unchecked(&mut self.get_unchecked_mut().0.get_mut().stream) } - } - - fn with_context(self: Pin<&mut Self>, ctx: &mut Context<'_>, f: F) -> R - where - F: FnOnce(&mut ssl::SslStream>) -> R, - { - let this = unsafe { self.get_unchecked_mut() }; - this.0.get_mut().context = ctx as *mut _ as usize; - let r = f(&mut this.0); - this.0.get_mut().context = 0; - r - } -} - -impl AsyncRead for SslStream -where - S: AsyncRead + AsyncWrite, -{ - fn poll_read( - self: Pin<&mut Self>, - ctx: &mut Context<'_>, - buf: &mut [u8], - ) -> Poll> { - self.with_context(ctx, |s| { - // This isn't really "proper", but rust-openssl doesn't currently expose a suitable interface even though - // OpenSSL itself doesn't require the buffer to be initialized. So this is good enough for now. - match cvt(s.read(buf))? 
{ - Poll::Ready(nread) => Poll::Ready(Ok(nread)), - Poll::Pending => Poll::Pending, - } - }) - } -} - -impl AsyncWrite for SslStream -where - S: AsyncRead + AsyncWrite, -{ - fn poll_write(self: Pin<&mut Self>, ctx: &mut Context, buf: &[u8]) -> Poll> { - self.with_context(ctx, |s| cvt(s.write(buf))) - } - - fn poll_flush(self: Pin<&mut Self>, ctx: &mut Context) -> Poll> { - self.with_context(ctx, |s| cvt(s.flush())) - } - - fn poll_close(mut self: Pin<&mut Self>, ctx: &mut Context) -> Poll> { - match self.as_mut().with_context(ctx, |s| s.shutdown()) { - Ok(ShutdownResult::Sent) | Ok(ShutdownResult::Received) => {} - Err(ref e) if e.code() == ErrorCode::ZERO_RETURN => {} - Err(ref e) if e.code() == ErrorCode::WANT_READ || e.code() == ErrorCode::WANT_WRITE => { - return Poll::Pending; - } - Err(e) => { - return Poll::Ready(Err(e - .into_io_error() - .unwrap_or_else(|e| io::Error::new(io::ErrorKind::Other, e)))); - } - } - - self.get_pin_mut().poll_close(ctx) - } -} - -#[cfg(test)] -mod test { - use super::super::IoWrapper; - use super::SslStream; - use futures_lite::io::{AsyncReadExt, AsyncWrite, AsyncWriteExt}; - use futures_util::future; - use openssl::ssl::{Ssl, SslAcceptor, SslConnector, SslFiletype, SslMethod}; - use std::pin::Pin; - use tokio::net::{TcpListener, TcpStream}; - - #[tokio::test] - async fn server() { - let listener = TcpListener::bind("127.0.0.1:0").await.unwrap(); - let addr = listener.local_addr().unwrap(); - - let server = async move { - let mut acceptor = SslAcceptor::mozilla_intermediate(SslMethod::tls()).unwrap(); - acceptor - .set_private_key_file("src/server/tls-test-keys/key.pem", SslFiletype::PEM) - .unwrap(); - acceptor - .set_certificate_chain_file("src/server/tls-test-keys/cert.pem") - .unwrap(); - let acceptor = acceptor.build(); - - let ssl = Ssl::new(acceptor.context()).unwrap(); - let stream = listener.accept().await.unwrap().0; - let mut stream = SslStream::new(ssl, IoWrapper::new(stream)).unwrap(); - - Pin::new(&mut stream).accept().await.unwrap(); - - let mut buf = [0; 4]; - stream.read_exact(&mut buf).await.unwrap(); - assert_eq!(&buf, b"asdf"); - - stream.write_all(b"jkl;").await.unwrap(); - - future::poll_fn(|ctx| Pin::new(&mut stream).poll_close(ctx)) - .await - .unwrap() - }; - - let client = async { - let mut connector = SslConnector::builder(SslMethod::tls()).unwrap(); - connector - .set_ca_file("src/server/tls-test-keys/cert.pem") - .unwrap(); - let ssl = connector - .build() - .configure() - .unwrap() - .into_ssl("localhost") - .unwrap(); - - let stream = TcpStream::connect(&addr).await.unwrap(); - let mut stream = SslStream::new(ssl, IoWrapper::new(stream)).unwrap(); - - Pin::new(&mut stream).connect().await.unwrap(); - - stream.write_all(b"asdf").await.unwrap(); - - let mut buf = vec![]; - stream.read_to_end(&mut buf).await.unwrap(); - assert_eq!(buf, b"jkl;"); - }; - - future::join(server, client).await; - } -} diff --git a/akasa-core/src/tests/utils.rs b/akasa-core/src/tests/utils.rs index 29ede27..188dd3b 100644 --- a/akasa-core/src/tests/utils.rs +++ b/akasa-core/src/tests/utils.rs @@ -7,10 +7,10 @@ use std::sync::Arc; use std::task::{Context, Poll}; use async_trait::async_trait; -use futures_lite::io::{AsyncRead, AsyncWrite}; use futures_sink::Sink; use mqtt_proto::{v3, v5}; use rand::{rngs::OsRng, RngCore}; +use tokio::io::{AsyncRead, AsyncWrite, ReadBuf}; use tokio::{ sync::mpsc::{channel, error::TryRecvError, Receiver, Sender}, task::JoinHandle, @@ -22,7 +22,7 @@ use crate::hook::{ Hook, HookAction, HookConnectCode, HookPublishCode, 
HookResult, HookSubscribeCode, HookUnsubscribeCode, }; -use crate::server::{handle_accept, rt_tokio::TokioExecutor, ConnectionArgs}; +use crate::server::{handle_accept, rt::TokioExecutor, ConnectionArgs}; use crate::state::{AuthPassword, GlobalState, HashAlgorithm}; use crate::{hash_password, SessionV3, SessionV5, MIN_SALT_LEN}; @@ -135,8 +135,8 @@ impl AsyncRead for MockConn { fn poll_read( mut self: Pin<&mut Self>, cx: &mut Context<'_>, - buf: &mut [u8], - ) -> Poll> { + buf: &mut ReadBuf, + ) -> Poll> { // let peer = self.peer.clone(); if self.data_in.is_empty() { self.data_in = match self.chan_in.poll_recv(cx) { @@ -148,13 +148,13 @@ impl AsyncRead for MockConn { }; } if self.data_in.is_empty() { - return Poll::Ready(Ok(0)); + return Poll::Ready(Ok(())); } - let amt = cmp::min(buf.len(), self.data_in.len()); + let amt = cmp::min(buf.remaining(), self.data_in.len()); let mut other = self.data_in.split_off(amt); mem::swap(&mut other, &mut self.data_in); - buf[..amt].copy_from_slice(&other); - Poll::Ready(Ok(amt)) + buf.put_slice(&other); + Poll::Ready(Ok(())) } } @@ -218,7 +218,7 @@ impl AsyncWrite for MockConn { .map_err(|_| io::Error::from(io::ErrorKind::BrokenPipe)) } - fn poll_close(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + fn poll_shutdown(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { Pin::new(&mut self.chan_out) .as_mut() .poll_close(cx) diff --git a/akasa/src/main.rs b/akasa/src/main.rs index e1152e6..281f920 100644 --- a/akasa/src/main.rs +++ b/akasa/src/main.rs @@ -114,7 +114,7 @@ fn main() -> anyhow::Result<()> { let mut global_state = GlobalState::new(config); global_state.auth_passwords = auth_passwords; let global = Arc::new(global_state); - server::rt_tokio::start(hook_handler, global)?; + server::rt::start(hook_handler, global)?; } Commands::DefaultConfig { allow_anonymous } => { let config = if allow_anonymous { diff --git a/docs/chinese/getting-started.md b/docs/chinese/getting-started.md index 4b066bd..83aa625 100644 --- a/docs/chinese/getting-started.md +++ b/docs/chinese/getting-started.md @@ -69,5 +69,5 @@ auth: # ws: None, # wss: None, # } -#[2023-00-00T00:00:00Z INFO akasa_core::server::rt_tokio] Listen mqtt@127.0.0.1:1883 success! (tokio) -``` \ No newline at end of file +#[2023-00-00T00:00:00Z INFO akasa_core::server::rt] Listen mqtt@127.0.0.1:1883 success! (tokio) +``` diff --git a/docs/english/getting-started.md b/docs/english/getting-started.md index 03b425c..e81d0af 100644 --- a/docs/english/getting-started.md +++ b/docs/english/getting-started.md @@ -69,5 +69,5 @@ The final step is to start the server: # ws: None, # wss: None, # } -#[2023-00-00T00:00:00Z INFO akasa_core::server::rt_tokio] Listen mqtt@127.0.0.1:1883 success! (tokio) -``` \ No newline at end of file +#[2023-00-00T00:00:00Z INFO akasa_core::server::rt] Listen mqtt@127.0.0.1:1883 success! 
(tokio)
+```

From fa35930c20181995eb837678f0851a38fe6188c3 Mon Sep 17 00:00:00 2001
From: Linfeng Qian
Date: Sat, 18 Nov 2023 20:23:08 +0800
Subject: [PATCH 3/3] refactor: remove executor abstraction layer

---
 akasa-core/src/protocols/mqtt/common.rs | 17 ++--
 akasa-core/src/protocols/mqtt/v3/message.rs | 34 ++-----
 .../src/protocols/mqtt/v3/packet/connect.rs | 6 +-
 akasa-core/src/protocols/mqtt/v5/message.rs | 97 +++++++------------
 .../src/protocols/mqtt/v5/packet/connect.rs | 11 +--
 akasa-core/src/server/mod.rs | 9 +-
 akasa-core/src/server/rt.rs | 60 +-----------
 akasa-core/src/state.rs | 83 +---------------
 akasa-core/src/tests/utils.rs | 12 +--
 9 files changed, 72 insertions(+), 257 deletions(-)

diff --git a/akasa-core/src/protocols/mqtt/common.rs b/akasa-core/src/protocols/mqtt/common.rs
index a96b72d..b1dc7e7 100644
--- a/akasa-core/src/protocols/mqtt/common.rs
+++ b/akasa-core/src/protocols/mqtt/common.rs
@@ -4,13 +4,12 @@ use std::time::{Duration, Instant};
 
 use parking_lot::RwLock;
 
-use crate::state::{ClientId, ControlMessage, Executor, GlobalState};
+use crate::state::{ClientId, ControlMessage, GlobalState};
 
-pub(crate) fn start_keep_alive_timer<E: Executor>(
+pub(crate) fn start_keep_alive_timer(
     keep_alive: u16,
     client_id: ClientId,
     last_packet_time: &Arc<RwLock<Instant>>,
-    executor: &E,
     global: &Arc<GlobalState>,
 ) -> io::Result<()> {
     // FIXME: if keep_alive is zero, set a default keep_alive value from config
@@ -19,7 +18,7 @@
         log::debug!("{} keep alive: {:?}", client_id, half_interval * 2);
         let last_packet_time = Arc::clone(last_packet_time);
         let global = Arc::clone(global);
-        if let Err(err) = executor.spawn_interval(move || {
+        let action_gen = move || {
            // Need clone twice: https://stackoverflow.com/a/68462908/1274372
            let last_packet_time = Arc::clone(&last_packet_time);
            let global = Arc::clone(&global);
@@ -45,10 +44,12 @@
                 }
                 None
             }
-        }) {
-            log::error!("spawn executor keep alive timer failed: {:?}", err);
-            return Err(err);
-        }
+        };
+        tokio::spawn(async move {
+            while let Some(duration) = action_gen().await {
+                tokio::time::sleep(duration).await;
+            }
+        });
     }
     Ok(())
 }
diff --git a/akasa-core/src/protocols/mqtt/v3/message.rs b/akasa-core/src/protocols/mqtt/v3/message.rs
index 8aaa958..5772719 100644
--- a/akasa-core/src/protocols/mqtt/v3/message.rs
+++ b/akasa-core/src/protocols/mqtt/v3/message.rs
@@ -23,9 +23,7 @@ use crate::hook::{
 use crate::protocols::mqtt::{
     BroadcastPackets, OnlineLoop, OnlineSession, PendingPackets, WritePacket,
 };
-use crate::state::{
-    ClientId, ClientReceiver, ControlMessage, Executor, GlobalState, NormalMessage,
-};
+use crate::state::{ClientId, ClientReceiver, ControlMessage, GlobalState, NormalMessage};
 
 use super::{
     packet::{
@@ -43,7 +41,6 @@ use super::{
 #[allow(clippy::too_many_arguments)]
 pub async fn handle_connection<
     T: AsyncRead + AsyncWrite + Unpin,
-    E: Executor,
     H: Hook + Clone + Send + Sync + 'static,
 >(
     conn: T,
@@ -52,7 +49,6 @@ pub async fn handle_connection<
     protocol: Protocol,
     timeout_receiver: Receiver<()>,
     hook_handler: H,
-    executor: E,
     global: Arc<GlobalState>,
 ) -> io::Result<()> {
     match handle_online(
@@ -62,26 +58,23 @@ pub async fn handle_connection<
         protocol,
         timeout_receiver,
         &hook_handler,
-        &executor,
         &global,
     )
     .await
     {
         Ok(Some((session, receiver))) => {
             log::info!(
-                "executor {:03}, {}({}) go to offline, total {} clients ({} online)",
-                executor.id(),
+                "{}({}) go to offline, total {} clients ({} online)",
                 session.client_id,
                 peer,
                 global.clients_count(),
                 global.online_clients_count(),
             );
-            
executor.spawn_local(handle_offline(session, receiver, global)); + tokio::spawn(handle_offline(session, receiver, global)); } Ok(None) => { log::info!( - "executor {:03}, {} finished, total {} clients ({} online)", - executor.id(), + "{} finished, total {} clients ({} online)", peer, global.clients_count(), global.online_clients_count(), @@ -89,8 +82,7 @@ pub async fn handle_connection< } Err(err) => { log::info!( - "executor {:03}, {} error: {}, total {} clients ({} online)", - executor.id(), + "{} error: {}, total {} clients ({} online)", peer, err, global.clients_count(), @@ -105,7 +97,6 @@ pub async fn handle_connection< #[allow(clippy::too_many_arguments)] async fn handle_online< T: AsyncRead + AsyncWrite + Unpin, - E: Executor, H: Hook + Clone + Send + Sync + 'static, >( mut conn: T, @@ -114,7 +105,6 @@ async fn handle_online< protocol: Protocol, timeout_receiver: Receiver<()>, hook_handler: &H, - executor: &E, global: &Arc, ) -> io::Result> { let mut session = Session::new(&global.config, peer); @@ -141,15 +131,8 @@ async fn handle_online< before_connect_hook(peer, &packet, hook_handler, global).await?; } - let session_present = handle_connect( - &mut session, - &mut receiver, - packet, - &mut conn, - executor, - global, - ) - .await?; + let session_present = + handle_connect(&mut session, &mut receiver, packet, &mut conn, global).await?; if !session.connected { log::info!("{} not connected", session.peer); @@ -167,8 +150,7 @@ async fn handle_online< let receiver = receiver.expect("receiver"); log::info!( - "executor {:03}, {} connected, total {} clients ({} online) ", - executor.id(), + "{} connected, total {} clients ({} online) ", session.peer, global.clients_count(), global.online_clients_count(), diff --git a/akasa-core/src/protocols/mqtt/v3/packet/connect.rs b/akasa-core/src/protocols/mqtt/v3/packet/connect.rs index 8349d48..3d1658c 100644 --- a/akasa-core/src/protocols/mqtt/v3/packet/connect.rs +++ b/akasa-core/src/protocols/mqtt/v3/packet/connect.rs @@ -9,17 +9,16 @@ use mqtt_proto::{ use tokio::io::AsyncWrite; use crate::protocols::mqtt::{check_password, start_keep_alive_timer}; -use crate::state::{AddClientReceipt, ClientReceiver, Executor, GlobalState}; +use crate::state::{AddClientReceipt, ClientReceiver, GlobalState}; use super::super::Session; use super::common::write_packet; -pub(crate) async fn handle_connect( +pub(crate) async fn handle_connect( session: &mut Session, receiver: &mut Option, packet: Connect, conn: &mut T, - executor: &E, global: &Arc, ) -> io::Result { log::debug!( @@ -152,7 +151,6 @@ clean session : {} session.keep_alive, session.client_id, &session.last_packet_time, - executor, global, )?; diff --git a/akasa-core/src/protocols/mqtt/v5/message.rs b/akasa-core/src/protocols/mqtt/v5/message.rs index 0dc4380..6bf5af6 100644 --- a/akasa-core/src/protocols/mqtt/v5/message.rs +++ b/akasa-core/src/protocols/mqtt/v5/message.rs @@ -26,9 +26,7 @@ use crate::hook::{ use crate::protocols::mqtt::{ BroadcastPackets, OnlineLoop, OnlineSession, PendingPackets, WritePacket, }; -use crate::state::{ - ClientId, ClientReceiver, ControlMessage, Executor, GlobalState, NormalMessage, -}; +use crate::state::{ClientId, ClientReceiver, ControlMessage, GlobalState, NormalMessage}; use super::{ packet::{ @@ -49,7 +47,6 @@ use super::{ #[allow(clippy::too_many_arguments)] pub async fn handle_connection< T: AsyncRead + AsyncWrite + Unpin, - E: Executor, H: Hook + Clone + Send + Sync + 'static, >( conn: T, @@ -58,7 +55,6 @@ pub async fn handle_connection< protocol: Protocol, 
timeout_receiver: Receiver<()>, hook_handler: H, - executor: E, global: Arc, ) -> io::Result<()> { match handle_online( @@ -68,44 +64,40 @@ pub async fn handle_connection< protocol, timeout_receiver, &hook_handler, - &executor, &global, ) .await { Ok(Some((session, receiver))) => { log::info!( - "executor {:03}, {}({}) go to offline, total {} clients ({} online)", - executor.id(), + "{}({}) go to offline, total {} clients ({} online)", session.client_id, peer, global.clients_count(), global.online_clients_count(), ); let session_expiry = Duration::from_secs(session.session_expiry_interval as u64); - executor.spawn_sleep(session_expiry, { - let client_id = session.client_id; - let connected_time = session.connected_time.expect("connected time"); - let global = Arc::clone(&global); - async move { - if let Some(sender) = global.get_client_control_sender(&client_id) { - let msg = ControlMessage::SessionExpired { connected_time }; - if let Err(err) = sender.send_async(msg).await { - log::warn!( - "send session expired message to {} error: {:?}", - client_id, - err - ); - } + let client_id = session.client_id; + let connected_time = session.connected_time.expect("connected time"); + let global_clone = Arc::clone(&global); + tokio::spawn(async move { + tokio::time::sleep(session_expiry).await; + if let Some(sender) = global_clone.get_client_control_sender(&client_id) { + let msg = ControlMessage::SessionExpired { connected_time }; + if let Err(err) = sender.send_async(msg).await { + log::warn!( + "send session expired message to {} error: {:?}", + client_id, + err + ); } } }); - executor.spawn_local(handle_offline(session, receiver, global)); + tokio::spawn(handle_offline(session, receiver, global)); } Ok(None) => { log::info!( - "executor {:03}, {} finished, total {} clients ({} online)", - executor.id(), + "{} finished, total {} clients ({} online)", peer, global.clients_count(), global.online_clients_count(), @@ -113,8 +105,7 @@ pub async fn handle_connection< } Err(err) => { log::info!( - "executor {:03}, {} error: {}, total {} clients ({} online)", - executor.id(), + "{} error: {}, total {} clients ({} online)", peer, err, global.clients_count(), @@ -129,7 +120,6 @@ pub async fn handle_connection< #[allow(clippy::too_many_arguments)] async fn handle_online< T: AsyncRead + AsyncWrite + Unpin, - E: Executor, H: Hook + Clone + Send + Sync + 'static, >( mut conn: T, @@ -138,7 +128,6 @@ async fn handle_online< protocol: Protocol, timeout_receiver: Receiver<()>, hook_handler: &H, - executor: &E, global: &Arc, ) -> io::Result> { let mut session = Session::new(&global.config, peer); @@ -185,15 +174,8 @@ async fn handle_online< before_connect_hook(&mut session, &mut conn, peer, &packet, hook_handler, global).await?; } - let mut session_present = handle_connect( - &mut session, - &mut receiver, - packet, - &mut conn, - executor, - global, - ) - .await?; + let mut session_present = + handle_connect(&mut session, &mut receiver, packet, &mut conn, global).await?; // * Scram challenge only need 1 round. // * Kerberos challenge need 2 rounds. 
@@ -253,7 +235,6 @@ async fn handle_online< &mut receiver, Some(server_final), &mut conn, - executor, global, ) .await?; @@ -296,8 +277,7 @@ async fn handle_online< let receiver = receiver.expect("receiver"); log::info!( - "executor {:03}, {} connected, total {} clients ({} online) ", - executor.id(), + "{} connected, total {} clients ({} online) ", session.peer, global.clients_count(), global.online_clients_count(), @@ -328,7 +308,7 @@ async fn handle_online< ); // FIXME: check all place depend on session.disconnected if !session.client_disconnected { - handle_will(&mut session, executor, global).await?; + handle_will(&mut session, global).await?; } broadcast_packets(&mut session).await; if session.session_expiry_interval == 0 { @@ -702,11 +682,7 @@ async fn handle_offline(mut session: Session, receiver: ClientReceiver, global: } #[inline] -async fn handle_will( - session: &mut Session, - executor: &E, - global: &Arc, -) -> io::Result<()> { +async fn handle_will(session: &mut Session, global: &Arc) -> io::Result<()> { if let Some(last_will) = session.last_will.as_ref() { let delay_interval = last_will.properties.delay_interval.unwrap_or(0); if delay_interval == 0 { @@ -717,20 +693,19 @@ async fn handle_will( ); send_will(session, global)?; } else if delay_interval < session.session_expiry_interval { - executor.spawn_sleep(Duration::from_secs(delay_interval as u64), { - let client_id = session.client_id; - let connected_time = session.connected_time.expect("connected time (will)"); - let global = Arc::clone(global); - async move { - if let Some(sender) = global.get_client_control_sender(&client_id) { - let msg = ControlMessage::WillDelayReached { connected_time }; - if let Err(err) = sender.send_async(msg).await { - log::warn!( - "send will delay reached message to {} error: {:?}", - client_id, - err - ); - } + let client_id = session.client_id; + let connected_time = session.connected_time.expect("connected time (will)"); + let global = Arc::clone(global); + tokio::spawn(async move { + tokio::time::sleep(Duration::from_secs(delay_interval as u64)).await; + if let Some(sender) = global.get_client_control_sender(&client_id) { + let msg = ControlMessage::WillDelayReached { connected_time }; + if let Err(err) = sender.send_async(msg).await { + log::warn!( + "send will delay reached message to {} error: {:?}", + client_id, + err + ); } } }); diff --git a/akasa-core/src/protocols/mqtt/v5/packet/connect.rs b/akasa-core/src/protocols/mqtt/v5/packet/connect.rs index 200a5d6..c8b5ca9 100644 --- a/akasa-core/src/protocols/mqtt/v5/packet/connect.rs +++ b/akasa-core/src/protocols/mqtt/v5/packet/connect.rs @@ -15,17 +15,16 @@ use tokio::io::AsyncWrite; use crate::config::SaslMechanism; use crate::protocols::mqtt::{check_password, start_keep_alive_timer}; -use crate::state::{AddClientReceipt, ClientReceiver, Executor, GlobalState}; +use crate::state::{AddClientReceipt, ClientReceiver, GlobalState}; use super::super::{ScramStage, Session, TracedRng}; use super::common::{build_error_connack, build_error_disconnect, write_packet}; -pub(crate) async fn handle_connect( +pub(crate) async fn handle_connect( session: &mut Session, receiver: &mut Option, packet: Connect, conn: &mut T, - executor: &E, global: &Arc, ) -> io::Result { log::debug!( @@ -209,16 +208,15 @@ pub(crate) async fn handle_connect( write_packet(session.client_id, conn, &err_pkt).await?; Ok(false) } else { - session_connect(session, receiver, None, conn, executor, global).await + session_connect(session, receiver, None, conn, global).await } } 
-pub(crate) async fn session_connect( +pub(crate) async fn session_connect( session: &mut Session, receiver: &mut Option, auth_data: Option, conn: &mut T, - executor: &E, global: &Arc, ) -> io::Result { let mut session_present = false; @@ -267,7 +265,6 @@ pub(crate) async fn session_connect( session.keep_alive, session.client_id, &session.last_packet_time, - executor, global, )?; diff --git a/akasa-core/src/server/mod.rs b/akasa-core/src/server/mod.rs index 5929ed0..361d582 100644 --- a/akasa-core/src/server/mod.rs +++ b/akasa-core/src/server/mod.rs @@ -26,7 +26,7 @@ use tokio_tungstenite::{ use crate::config::TlsListener; use crate::hook::Hook; use crate::protocols::mqtt; -use crate::state::{Executor, GlobalState}; +use crate::state::GlobalState; use proxy::{parse_header, Addresses}; @@ -34,19 +34,18 @@ const CONNECT_TIMEOUT_SECS: u64 = 5; pub async fn handle_accept< T: AsyncRead + AsyncWrite + Unpin, - E: Executor, H: Hook + Clone + Send + Sync + 'static, >( mut conn: T, conn_args: ConnectionArgs, mut peer: SocketAddr, hook_handler: H, - executor: E, global: Arc, ) -> io::Result<()> { // If the client don't send enough data in 5 seconds, disconnect it. let (timeout_sender, timeout_receiver) = bounded(1); - executor.spawn_sleep(Duration::from_secs(CONNECT_TIMEOUT_SECS), async move { + tokio::spawn(async move { + tokio::time::sleep(Duration::from_secs(CONNECT_TIMEOUT_SECS)).await; if timeout_sender.send_async(()).await.is_ok() { log::info!("connection timeout: {}", peer); } @@ -185,7 +184,6 @@ pub async fn handle_accept< protocol, timeout_receiver, hook_handler, - executor, global, ) .await?; @@ -199,7 +197,6 @@ pub async fn handle_accept< protocol, timeout_receiver, hook_handler, - executor, global, ) .await?; diff --git a/akasa-core/src/server/rt.rs b/akasa-core/src/server/rt.rs index c43ca37..f84ab50 100644 --- a/akasa-core/src/server/rt.rs +++ b/akasa-core/src/server/rt.rs @@ -1,4 +1,3 @@ -use std::future::Future; use std::io; use std::sync::Arc; use std::time::Duration; @@ -8,7 +7,7 @@ use tokio::{net::TcpListener, runtime::Runtime}; use super::{build_tls_context, handle_accept, ConnectionArgs}; use crate::config::{Listener, ProxyMode, TlsListener}; use crate::hook::Hook; -use crate::state::{Executor, GlobalState}; +use crate::state::GlobalState; pub fn start(hook_handler: H, global: Arc) -> io::Result<()> where @@ -38,8 +37,6 @@ where .transpose()?; rt.block_on(async move { - let executor = Arc::new(TokioExecutor {}); - let listeners = &global.config.listeners; let tasks: Vec<_> = [ listeners @@ -88,15 +85,11 @@ where .map(|conn_args| { let global = Arc::clone(&global); let hook_handler = hook_handler.clone(); - let executor = Arc::clone(&executor); tokio::spawn(async move { loop { let global = Arc::clone(&global); let hook_handler = hook_handler.clone(); - let executor = Arc::clone(&executor); - if let Err(err) = - listen(conn_args.clone(), hook_handler, executor, global).await - { + if let Err(err) = listen(conn_args.clone(), hook_handler, global).await { log::error!("Listen error: {:?}", err); tokio::time::sleep(Duration::from_secs(1)).await; } @@ -115,10 +108,9 @@ where Ok(()) } -async fn listen( +async fn listen( conn_args: ConnectionArgs, hook_handler: H, - executor: Arc, global: Arc, ) -> io::Result<()> { let addr = conn_args.addr; @@ -140,53 +132,9 @@ async fn listen(&self, future: F) - where - F: Future + Send + 'static, - F::Output: Send + 'static, - { - tokio::spawn(future); - } - - fn spawn_sleep(&self, duration: Duration, task: F) - where - F: Future + Send + 'static, - 
{
-        tokio::spawn(async move {
-            tokio::time::sleep(duration).await;
-            task.await;
-        });
-    }
-
-    fn spawn_interval<G, F>(&self, action_gen: G) -> io::Result<()>
-    where
-        G: (Fn() -> F) + Send + Sync + 'static,
-        F: Future<Output = Option<Duration>> + Send + 'static,
-    {
-        tokio::spawn(async move {
-            while let Some(duration) = action_gen().await {
-                tokio::time::sleep(duration).await;
-            }
+        let _ = handle_accept(conn, conn_args, peer, hook_handler, Arc::clone(&global)).await;
         });
-        Ok(())
-    }
 }
diff --git a/akasa-core/src/state.rs b/akasa-core/src/state.rs
index e60d577..0eb38f6 100644
--- a/akasa-core/src/state.rs
+++ b/akasa-core/src/state.rs
@@ -1,11 +1,11 @@
 use std::fmt;
-use std::future::Future;
+
 use std::io;
 use std::num::NonZeroU32;
-use std::rc::Rc;
+
 use std::sync::atomic::{AtomicU64, Ordering};
-use std::sync::Arc;
-use std::time::{Duration, Instant};
+
+use std::time::Instant;
 
 use bytes::Bytes;
 use dashmap::DashMap;
@@ -216,81 +216,6 @@ impl GlobalState {
     }
 }
 
-pub trait Executor {
-    fn id(&self) -> usize {
-        0
-    }
-    fn spawn_local<F>(&self, future: F)
-    where
-        F: Future + Send + 'static,
-        F::Output: Send + 'static;
-
-    fn spawn_sleep<F>(&self, duration: Duration, task: F)
-    where
-        F: Future + Send + 'static;
-
-    fn spawn_interval<G, F>(&self, action_gen: G) -> io::Result<()>
-    where
-        G: (Fn() -> F) + Send + Sync + 'static,
-        F: Future<Output = Option<Duration>> + Send + 'static;
-}
-
-impl<E: Executor> Executor for Rc<E> {
-    fn id(&self) -> usize {
-        self.as_ref().id()
-    }
-    fn spawn_local<F>(&self, future: F)
-    where
-        F: Future + Send + 'static,
-        F::Output: Send + 'static,
-    {
-        self.as_ref().spawn_local(future);
-    }
-
-    fn spawn_sleep<F>(&self, duration: Duration, task: F)
-    where
-        F: Future + Send + 'static,
-    {
-        self.as_ref().spawn_sleep(duration, task);
-    }
-
-    fn spawn_interval<G, F>(&self, action_gen: G) -> io::Result<()>
-    where
-        G: (Fn() -> F) + Send + Sync + 'static,
-        F: Future<Output = Option<Duration>> + Send + 'static,
-    {
-        self.as_ref().spawn_interval(action_gen)
-    }
-}
-impl<E: Executor> Executor for Arc<E> {
-    fn id(&self) -> usize {
-        self.as_ref().id()
-    }
-
-    fn spawn_local<F>(&self, future: F)
-    where
-        F: Future + Send + 'static,
-        F::Output: Send + 'static,
-    {
-        self.as_ref().spawn_local(future);
-    }
-
-    fn spawn_sleep<F>(&self, duration: Duration, task: F)
-    where
-        F: Future + Send + 'static,
-    {
-        self.as_ref().spawn_sleep(duration, task);
-    }
-
-    fn spawn_interval<G, F>(&self, action_gen: G) -> io::Result<()>
-    where
-        G: (Fn() -> F) + Send + Sync + 'static,
-        F: Future<Output = Option<Duration>> + Send + 'static,
-    {
-        self.as_ref().spawn_interval(action_gen)
-    }
-}
-
 #[derive(Debug, Clone)]
 pub enum ControlMessage {
     /// The v3.x client of the session connected, send the kept session to the connection loop
diff --git a/akasa-core/src/tests/utils.rs b/akasa-core/src/tests/utils.rs
index 188dd3b..5fca219 100644
--- a/akasa-core/src/tests/utils.rs
+++ b/akasa-core/src/tests/utils.rs
@@ -22,7 +22,7 @@ use crate::hook::{
     Hook, HookAction, HookConnectCode, HookPublishCode, HookResult, HookSubscribeCode,
     HookUnsubscribeCode,
 };
-use crate::server::{handle_accept, rt::TokioExecutor, ConnectionArgs};
+use crate::server::{handle_accept, ConnectionArgs};
 use crate::state::{AuthPassword, GlobalState, HashAlgorithm};
 use crate::{hash_password, SessionV3, SessionV5, MIN_SALT_LEN};
 
@@ -101,7 +101,6 @@ impl MockConn {
 impl MockConnControl {
     pub fn start(&self, conn: MockConn) -> JoinHandle<io::Result<()>> {
         let peer = conn.peer;
-        let executor = Arc::new(TokioExecutor {});
         let global = Arc::clone(&self.global);
         let hook_handler = TestHook;
 
@@ -112,14 +111,7 @@ impl MockConnControl {
             websocket: false,
             tls_acceptor: None,
         };
-        
tokio::spawn(handle_accept( - conn, - conn_args, - peer, - hook_handler, - executor, - global, - )) + tokio::spawn(handle_accept(conn, conn_args, peer, hook_handler, global)) } pub fn try_read_packet_is_empty(&mut self) -> bool {
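
A note on the `AsyncRead` signature change that drives most of the mechanical edits in patch 2: futures-lite's `poll_read` takes a `&mut [u8]` and returns `Poll<io::Result<usize>>`, with `Ok(0)` signalling EOF, while tokio's `poll_read` fills a `ReadBuf` and returns `Poll<io::Result<()>>`, signalling EOF by returning `Ok(())` without writing any bytes. That is why `MockConn`, `TlsWrapper` and `WebSocketWrapper` switch from `buf.len()`/`copy_from_slice` to `buf.remaining()`/`put_slice`, and from `Ok(0)`/`Ok(amt)` to `Ok(())`. A minimal sketch of the tokio-style contract, using a hypothetical `InMemory` stream that is not part of these patches:

```rust
use std::io;
use std::pin::Pin;
use std::task::{Context, Poll};

use tokio::io::{AsyncRead, ReadBuf};

/// Hypothetical stream over a fixed byte buffer, for illustration only.
struct InMemory {
    data: Vec<u8>,
    pos: usize,
}

impl AsyncRead for InMemory {
    fn poll_read(
        mut self: Pin<&mut Self>,
        _cx: &mut Context<'_>,
        buf: &mut ReadBuf<'_>,
    ) -> Poll<io::Result<()>> {
        // Copy as much as fits into the caller's buffer. With tokio, the
        // number of bytes read is implied by how much `buf` was filled;
        // returning Ok(()) with nothing filled is how EOF is reported.
        let amt = (self.data.len() - self.pos).min(buf.remaining());
        let (start, end) = (self.pos, self.pos + amt);
        buf.put_slice(&self.data[start..end]);
        self.pos = end;
        Poll::Ready(Ok(()))
    }
}
```

The same shift explains the `poll_close` -> `poll_shutdown` renames throughout patch 2: tokio's `AsyncWrite` names the close method `poll_shutdown`, so every wrapper that previously forwarded `poll_close` now forwards `poll_shutdown` instead.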