diff --git a/Cargo.toml b/Cargo.toml index 9eff9a3..6748de7 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -18,6 +18,7 @@ members = [ "examples/no_macro/basic", "examples/no_macro/blinky", + "examples/no_macro/blinky_async", "examples/no_macro/bench", "examples/no_macro/history", "examples/no_macro/calculator", diff --git a/README.md b/README.md index b82329e..2a2ff67 100644 --- a/README.md +++ b/README.md @@ -494,6 +494,16 @@ Short answer: nothing. `#[state_machine]` simply parses the underlying `impl` bl I would say they serve a different purpose. The [typestate pattern](http://cliffle.com/blog/rust-typestate/) is very useful for designing an API as it is able to enforce the validity of operations at compile time by making each state a unique type. But `statig` is designed to model a dynamic system where events originate externally and the order of operations is determined at run time. More concretely, this means that the state machine is going to sit in a loop where events are read from a queue and submitted to the state machine using the `handle()` method. If we want to do the same with a state machine that uses the typestate pattern we'd have to use an enum to wrap all our different states and match events to operations on these states. This means extra boilerplate code for little advantage as the order of operations is unknown so it can't be checked at compile time. On the other hand `statig` gives you the ability to create a hierarchy of states which I find to be invaluable as state machines grow in complexity. +## Testing + +Install the following dependencies: + +```sh +sudo apt install cmake libfontconfig1-dev +cargo test --workspace +``` + + --- ## Credits diff --git a/examples/macro/async_blinky/src/main.rs b/examples/macro/async_blinky/src/main.rs index 4d0eb9a..9f35e5e 100644 --- a/examples/macro/async_blinky/src/main.rs +++ b/examples/macro/async_blinky/src/main.rs @@ -1,9 +1,12 @@ #![allow(unused)] use futures::executor; +use futures::future::poll_fn; use statig::prelude::*; use std::fmt::Debug; +use std::future::Future; use std::io::Write; +use std::pin::Pin; use std::thread::spawn; #[derive(Debug, Default)] @@ -29,6 +32,8 @@ pub enum Event { superstate(derive(Debug)), // Set the `on_transition` callback. on_transition = "Self::on_transition", + // Set the `on_transition_async` callback. + on_transition_async = "Self::on_transition_async", // Set the `on_dispatch` callback. on_dispatch = "Self::on_dispatch" )] @@ -83,6 +88,19 @@ impl Blinky { println!("transitioned from `{source:?}` to `{target:?}`"); } + async fn transitioning(&mut self, from: &State, to: &State) { + println!("transitioning from {:?} to {:?}", from, to); + } + + fn on_transition_async<'a>( + &'a mut self, + source: &'a State, + target: &'a State, + ) -> Pin + Send + 'a>> { + println!("transitioned async from `{source:?}` to `{target:?}`"); + Box::pin(self.transitioning(source, target)) + } + fn on_dispatch(&mut self, state: StateOrSuperstate, event: &Event) { println!("dispatching `{event:?}` to `{state:?}`"); } diff --git a/examples/no_macro/blinky_async/Cargo.toml b/examples/no_macro/blinky_async/Cargo.toml new file mode 100644 index 0000000..032376f --- /dev/null +++ b/examples/no_macro/blinky_async/Cargo.toml @@ -0,0 +1,8 @@ +[package] +name = "no_macro_blink_async" +version = "0.1.0" +edition = "2021" + +[dependencies] +statig = { path = "../../../statig", features = ["async"] } +tokio = { version = "*", features = ["full"] } diff --git a/examples/no_macro/blinky_async/src/main.rs b/examples/no_macro/blinky_async/src/main.rs new file mode 100644 index 0000000..e9c7deb --- /dev/null +++ b/examples/no_macro/blinky_async/src/main.rs @@ -0,0 +1,150 @@ +#![allow(unused)] + +use statig::awaitable::{self, *}; +use std::{ + future::{poll_fn, Future}, + io::Write, + pin::Pin, + task::Poll, +}; + +#[derive(Default)] +pub struct Blinky { + field: String, +} + +// The event that will be handled by the state machine. +pub enum Event { + TimerElapsed, + ButtonPressed, +} + +// The enum representing all states of the state machine. These are +// the states you can actually transition to. +#[derive(Debug)] +pub enum State { + LedOn, + LedOff, + NotBlinking, +} + +// The enum representing the superstates of the system. You can not transition +// to a superstate, but instead they define shared behavior of underlying states or +// superstates. +pub enum Superstate { + Blinking, +} + +// The `statig` trait needs to be implemented on the type that will +// imlement the state machine. +impl IntoStateMachine for Blinky { + /// The enum that represents the state. + type State = State; + + type Superstate<'sub> = Superstate; + + /// The event type that will be submitted to the state machine. + type Event<'evt> = Event; + + type Context<'ctx> = (); + + /// The initial state of the state machine. + const INITIAL: State = State::LedOn; + + const ON_TRANSITION_ASYNC: for<'fut> fn( + &'fut mut Self, + &'fut Self::State, + &'fut Self::State, + ) + -> Pin + Send + 'fut>> = |blinky, from, to| { + println!("transitioned from {:?} to {:?}", from, to); + Box::pin(blinky.transition_and_print_internal_state(from, to)) + }; +} + +// Implement the `statig::State` trait for the state enum. +impl awaitable::State for State { + fn call_handler<'fut>( + &'fut mut self, + blinky: &'fut mut Blinky, + event: &'fut Event, + _: &'fut mut (), + ) -> Pin> + Send + 'fut)>> { + match self { + State::LedOn => Box::pin(Blinky::timer_elapsed_turn_off(event)), + State::LedOff => Box::pin(Blinky::timer_elapsed_turn_on(event)), + State::NotBlinking => Box::pin(Blinky::not_blinking_button_pressed(event)), + } + } + + fn superstate<'fut>(&mut self) -> Option { + match self { + State::LedOn => Some(Superstate::Blinking), + State::LedOff => Some(Superstate::Blinking), + State::NotBlinking => None, + } + } +} + +// Implement the `statig::Superstate` trait for the superstate enum. +impl awaitable::Superstate for Superstate { + fn call_handler<'fut>( + &'fut mut self, + blinky: &'fut mut Blinky, + event: &'fut Event, + _: &'fut mut (), + ) -> Pin> + Send + 'fut)>> { + Box::pin(match self { + Superstate::Blinking => Blinky::blinking_button_pressed(event), + }) + } +} + +impl Blinky { + async fn transition_and_print_internal_state(&mut self, from: &State, to: &State) { + println!( + "transitioned (current test value is: {}) from {:?} to {:?}", + self.field, from, to + ); + } + async fn timer_elapsed_turn_off(event: &Event) -> Response { + match event { + Event::TimerElapsed => Transition(State::LedOff), + _ => Super, + } + } + + async fn timer_elapsed_turn_on(event: &Event) -> Response { + match event { + Event::TimerElapsed => Transition(State::LedOn), + _ => Super, + } + } + + async fn blinking_button_pressed(event: &Event) -> Response { + match event { + Event::ButtonPressed => Transition(State::NotBlinking), + _ => Super, + } + } + + async fn not_blinking_button_pressed(event: &Event) -> Response { + match event { + Event::ButtonPressed => Transition(State::LedOn), + _ => Super, + } + } +} + +#[tokio::main] +async fn main() { + let mut state_machine = Blinky { + field: "test field value".to_string(), + } + .state_machine(); + + state_machine.handle(&Event::TimerElapsed).await; + state_machine.handle(&Event::ButtonPressed).await; + state_machine.handle(&Event::TimerElapsed).await; + state_machine.handle(&Event::ButtonPressed).await; +} diff --git a/macro/src/analyze.rs b/macro/src/analyze.rs index 1e1f90d..50c955b 100644 --- a/macro/src/analyze.rs +++ b/macro/src/analyze.rs @@ -50,6 +50,8 @@ pub struct StateMachine { pub visibility: Visibility, /// Optional `on_transition` callback. pub on_transition: Option, + /// Optional `on_transition_async` callback. + pub on_transition_async: Option, /// Optional `on_dispatch` callback. pub on_dispatch: Option, } @@ -180,6 +182,7 @@ pub fn analyze_state_machine(attribute_args: &AttributeArgs, item_impl: &ItemImp let mut superstate_derives = Vec::new(); let mut on_transition = None; + let mut on_transition_async = None; let mut on_dispatch = None; let mut visibility = parse_quote!(pub); @@ -224,6 +227,14 @@ pub fn analyze_state_machine(attribute_args: &AttributeArgs, item_impl: &ItemImp _ => abort!(name_value, "must be a string literal"), } } + NestedMeta::Meta(Meta::NameValue(name_value)) + if name_value.path.is_ident("on_transition_async") => + { + on_transition_async = match &name_value.lit { + Lit::Str(input_pat) => Some(input_pat.parse().unwrap()), + _ => abort!(name_value, "must be a string literal"), + } + } NestedMeta::Meta(Meta::NameValue(name_value)) if name_value.path.is_ident("on_dispatch") => { @@ -341,6 +352,7 @@ pub fn analyze_state_machine(attribute_args: &AttributeArgs, item_impl: &ItemImp superstate_derives, on_dispatch, on_transition, + on_transition_async, event_ident, context_ident, visibility, @@ -660,6 +672,7 @@ fn valid_state_analyze() { superstate_ident, superstate_derives, on_transition, + on_transition_async: None, on_dispatch, event_ident, context_ident, diff --git a/macro/src/codegen.rs b/macro/src/codegen.rs index 42d1a1b..118bd75 100644 --- a/macro/src/codegen.rs +++ b/macro/src/codegen.rs @@ -37,7 +37,6 @@ pub fn codegen(ir: Ir) -> TokenStream { #superstate_impl ) } - fn codegen_state_machine_impl(ir: &Ir) -> ItemImpl { let shared_storage_type = &ir.state_machine.shared_storage_type; let (impl_generics, _, where_clause) = @@ -66,6 +65,13 @@ fn codegen_state_machine_impl(ir: &Ir) -> ItemImpl { ), }; + let on_transition_async = match &ir.state_machine.on_transition_async { + None => quote!(), + Some(on_transition_async) => quote!( + const ON_TRANSITION_ASYNC: for <'a> fn(&'a mut Self, &'a Self::State, &'a Self::State) -> core::pin::Pin + Send + 'a>> = #on_transition_async; + ), + }; + let on_dispatch = match &ir.state_machine.on_dispatch { None => quote!(), Some(on_dispatch) => quote!( @@ -84,6 +90,8 @@ fn codegen_state_machine_impl(ir: &Ir) -> ItemImpl { #on_transition + #on_transition_async + #on_dispatch } ) diff --git a/macro/src/lower.rs b/macro/src/lower.rs index d32a1da..84314d6 100644 --- a/macro/src/lower.rs +++ b/macro/src/lower.rs @@ -58,6 +58,8 @@ pub struct StateMachine { pub superstate_generics: Generics, /// The path of the `on_transition` callback. pub on_transition: Option, + /// The path of the `on_transition_async` callback. + pub on_transition_async: Option, /// The path of the `on_dispatch` callback. pub on_dispatch: Option, /// The visibility for the derived types, @@ -137,6 +139,8 @@ pub fn lower(model: &Model) -> Ir { let state_ident = model.state_machine.state_ident.clone(); let superstate_ident = model.state_machine.superstate_ident.clone(); let on_transition = model.state_machine.on_transition.clone(); + + let on_transition_async = model.state_machine.on_transition_async.clone(); let on_dispatch = model.state_machine.on_dispatch.clone(); let event_ident = model.state_machine.event_ident.clone(); let context_ident = model.state_machine.context_ident.clone(); @@ -421,6 +425,7 @@ pub fn lower(model: &Model) -> Ir { superstate_derives, superstate_generics, on_transition, + on_transition_async, on_dispatch, visibility, event_ident, @@ -708,6 +713,7 @@ fn create_analyze_state_machine() -> analyze::StateMachine { superstate_ident: parse_quote!(Superstate), superstate_derives: vec![parse_quote!(Copy), parse_quote!(Clone)], on_transition: None, + on_transition_async: None, on_dispatch: None, visibility: parse_quote!(pub), event_ident: parse_quote!(input), @@ -733,6 +739,7 @@ fn create_lower_state_machine() -> StateMachine { superstate_derives: vec![parse_quote!(Copy), parse_quote!(Clone)], superstate_generics, on_transition: None, + on_transition_async: None, on_dispatch: None, visibility: parse_quote!(pub), event_ident: parse_quote!(input), diff --git a/statig/src/awaitable/state_machine.rs b/statig/src/awaitable/state_machine.rs index 015918d..9fad851 100644 --- a/statig/src/awaitable/state_machine.rs +++ b/statig/src/awaitable/state_machine.rs @@ -1,14 +1,11 @@ use core::fmt::Debug; -use super::awaitable; +use super::{awaitable, State, Superstate}; use crate::{Inner, IntoStateMachine}; /// A state machine where the shared storage is of type `Self`. -pub trait IntoStateMachineExt: IntoStateMachine -where - Self: Send, - for<'sub> Self::Superstate<'sub>: awaitable::Superstate + Send, - Self::State: awaitable::State + Send, +pub trait IntoStateMachineExt: + for<'sub> IntoStateMachine: Superstate + Send, State: State> + Send { /// Create a state machine that will be lazily initialized. fn state_machine(self) -> StateMachine @@ -36,11 +33,8 @@ where } } -impl IntoStateMachineExt for T -where - Self: IntoStateMachine + Send, - for<'sub> Self::Superstate<'sub>: awaitable::Superstate + Send, - Self::State: awaitable::State + Send, +impl IntoStateMachineExt for T where + T: for<'sub> IntoStateMachine, Superstate<'sub>: Superstate + Send> + Send { } @@ -55,7 +49,7 @@ where impl StateMachine where - M: IntoStateMachine + Send, + M: IntoStateMachineExt, M::State: awaitable::State + 'static + Send, for<'sub> M::Superstate<'sub>: awaitable::Superstate + Send, { @@ -233,21 +227,21 @@ where /// A state machine that has been initialized. pub struct InitializedStateMachine where - M: IntoStateMachine, + M: IntoStateMachineExt, { inner: Inner, } impl InitializedStateMachine where - M: IntoStateMachine + Send, + M: IntoStateMachineExt, M::State: awaitable::State + 'static + Send, for<'sub> M::Superstate<'sub>: awaitable::Superstate + Send, { /// Handle the given event. pub async fn handle(&mut self, event: &M::Event<'_>) where - for<'ctx> M: IntoStateMachine = ()>, + for<'ctx> M: IntoStateMachineExt = ()>, for<'evt> M::Event<'evt>: Send + Sync, for<'ctx> M::Context<'ctx>: Send + Sync, { @@ -257,7 +251,7 @@ where /// Handle the given event. pub async fn handle_with_context(&mut self, event: &M::Event<'_>, context: &mut M::Context<'_>) where - M: IntoStateMachine, + M: IntoStateMachineExt, for<'evt> M::Event<'evt>: Send + Sync, for<'ctx> M::Context<'ctx>: Send + Sync, { @@ -290,7 +284,7 @@ where impl Clone for InitializedStateMachine where - M: IntoStateMachine + Clone, + M: IntoStateMachineExt + Clone, M::State: Clone, { fn clone(&self) -> Self { @@ -302,7 +296,7 @@ where impl Debug for InitializedStateMachine where - M: IntoStateMachine + Debug, + M: IntoStateMachineExt + Debug, M::State: Debug, { fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result { @@ -315,7 +309,7 @@ where impl PartialEq for InitializedStateMachine where - M: IntoStateMachine + PartialEq, + M: IntoStateMachineExt + PartialEq, M::State: PartialEq, { fn eq(&self, other: &Self) -> bool { @@ -325,14 +319,14 @@ where impl Eq for InitializedStateMachine where - M: IntoStateMachine + PartialEq + Eq, - M::State: PartialEq + Eq, + M: IntoStateMachineExt + Eq, + M::State: Eq, { } impl core::ops::Deref for InitializedStateMachine where - M: IntoStateMachine, + M: IntoStateMachineExt, { type Target = M; @@ -379,16 +373,14 @@ where /// execute all the entry actions into the initial state. pub struct UninitializedStateMachine where - M: IntoStateMachine, + M: IntoStateMachineExt, { inner: Inner, } impl UninitializedStateMachine where - M: IntoStateMachine + Send, - M::State: awaitable::State + 'static + Send, - for<'sub> M::Superstate<'sub>: awaitable::Superstate + Send, + M: IntoStateMachineExt, { /// Initialize the state machine by executing all entry actions towards /// the initial state. @@ -416,9 +408,10 @@ where /// ``` pub async fn init(self) -> InitializedStateMachine where - for<'ctx> M: IntoStateMachine = ()>, + for<'ctx> M: IntoStateMachineExt = ()>, for<'evt> M::Event<'evt>: Send + Sync, for<'ctx> M::Context<'ctx>: Send + Sync, + M::State: Send + 'static, { let mut state_machine = InitializedStateMachine { inner: self.inner }; state_machine.inner.async_init_with_context(&mut ()).await; @@ -453,6 +446,7 @@ where where for<'evt> M::Event<'evt>: Send + Sync, for<'ctx> M::Context<'ctx>: Send + Sync, + M::State: Send + 'static, { let mut state_machine = InitializedStateMachine { inner: self.inner }; state_machine.inner.async_init_with_context(context).await; @@ -462,7 +456,7 @@ where impl Clone for UninitializedStateMachine where - M: IntoStateMachine + Clone, + M: IntoStateMachineExt + Clone, M::State: Clone, { fn clone(&self) -> Self { @@ -474,7 +468,7 @@ where impl Debug for UninitializedStateMachine where - M: IntoStateMachine + Debug, + M: IntoStateMachineExt + Debug, M::State: Debug, { fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result { @@ -487,7 +481,7 @@ where impl PartialEq for UninitializedStateMachine where - M: IntoStateMachine + PartialEq, + M: IntoStateMachineExt + PartialEq, M::State: PartialEq, { fn eq(&self, other: &Self) -> bool { @@ -497,14 +491,14 @@ where impl Eq for UninitializedStateMachine where - M: IntoStateMachine + PartialEq + Eq, - M::State: PartialEq + Eq, + M: IntoStateMachineExt + PartialEq + Eq, + M::State: State + PartialEq + Eq, { } impl core::ops::Deref for UninitializedStateMachine where - M: IntoStateMachine, + M: IntoStateMachineExt, { type Target = M; diff --git a/statig/src/inner.rs b/statig/src/inner.rs index 46abe02..4337b89 100644 --- a/statig/src/inner.rs +++ b/statig/src/inner.rs @@ -1,3 +1,4 @@ +use crate::awaitable::IntoStateMachineExt; #[cfg(feature = "async")] use crate::awaitable::{self, StateExt as _}; use crate::blocking::{self, StateExt as _}; @@ -58,7 +59,7 @@ where #[cfg(feature = "async")] impl Inner where - M: IntoStateMachine + Send, + M: IntoStateMachineExt, for<'evt> M::Event<'evt>: Send + Sync, for<'ctx> M::Context<'ctx>: Send + Sync, M::State: awaitable::State + Send + 'static, @@ -106,6 +107,7 @@ where .await; M::ON_TRANSITION(&mut self.shared_storage, &target, &self.state); + M::ON_TRANSITION_ASYNC(&mut self.shared_storage, &target, &self.state).await; } } diff --git a/statig/src/into_state_machine.rs b/statig/src/into_state_machine.rs index bffea31..b37e1bf 100644 --- a/statig/src/into_state_machine.rs +++ b/statig/src/into_state_machine.rs @@ -1,3 +1,5 @@ +use core::{future::Future, pin::Pin}; + use crate::StateOrSuperstate; /// Trait for transorming a type into a state machine. @@ -29,4 +31,14 @@ where /// Method that is called *after* every transition. const ON_TRANSITION: fn(&mut Self, &Self::State, &Self::State) = |_, _, _| {}; + + const ON_TRANSITION_ASYNC: for<'fut> fn( + &'fut mut Self, + from: &'fut Self::State, + to: &'fut Self::State, + ) + -> Pin + Send + 'fut>> = |_, _, _| { + use std::task::Poll; + Box::pin(std::future::poll_fn(|_| Poll::Ready(()))) + }; }