diff --git a/Cargo.toml b/Cargo.toml index 7dbc558..4cc76a6 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -16,6 +16,9 @@ edition = "2021" metrics = "0.22.0" tracing = "0.1.40" -[dev-dependencies] +[target.'cfg(loom)'.dependencies] +loom = { version = "0.7", features = ["futures"] } + +[target.'cfg(not(loom))'.dev-dependencies] smol = "2.0.0" tracing-subscriber = { version = "0.3.18", features = ["env-filter"] } diff --git a/examples/demo.rs b/examples/demo.rs index f495413..e77f0be 100644 --- a/examples/demo.rs +++ b/examples/demo.rs @@ -1,6 +1,9 @@ +#[cfg(not(loom))] use std::time::Duration; +#[cfg(not(loom))] use tracing_subscriber::{fmt, prelude::*, EnvFilter}; +#[cfg(not(loom))] fn main() -> Result<(), Box> { tracing_subscriber::registry() .with(fmt::layer()) @@ -52,3 +55,6 @@ fn main() -> Result<(), Box> { Ok(()) }) } + +#[cfg(loom)] +fn main() {} diff --git a/src/drop_notifier.rs b/src/drop_notifier.rs index 8326d13..b680165 100644 --- a/src/drop_notifier.rs +++ b/src/drop_notifier.rs @@ -1,4 +1,4 @@ -use std::sync::Arc; +use crate::sync::Arc; use crate::notify::Notify; diff --git a/src/lib.rs b/src/lib.rs index e6c3e7e..efeac16 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -7,6 +7,7 @@ mod notify; mod queue; mod selector; mod spsc; +mod sync; use std::{ future::Future, diff --git a/src/notify.rs b/src/notify.rs index b5b25e3..1017ac9 100644 --- a/src/notify.rs +++ b/src/notify.rs @@ -1,13 +1,11 @@ use std::{ collections::VecDeque, future::{poll_fn, Future}, - sync::{ - atomic::{AtomicU8, Ordering}, - Arc, Mutex, - }, task::{Poll, Waker}, }; +use crate::sync::{Arc, AtomicU8, Mutex, Ordering}; + const UNNOTIFIED: u8 = 0b0000; const NOTIFIED_ONE: u8 = 0b0001; const RESOLVED: u8 = 0b0010; @@ -210,3 +208,213 @@ impl Future for Listener<'_> { Poll::Pending } } + +#[cfg(all(test, loom))] +mod tests { + use super::Notify; + use std::future::Future; + + struct NoopWaker; + + impl std::task::Wake for NoopWaker { + fn wake(self: std::sync::Arc) {} + fn wake_by_ref(self: &std::sync::Arc) {} + } + + fn noop_waker() -> std::task::Waker { + std::sync::Arc::new(NoopWaker).into() + } + + #[test] + fn dropped_notified_listener() { + loom::model(|| { + loom::future::block_on(async { + let notify = Notify::new(); + + let listener = notify.listen().await; + + notify.notify_one(); + drop(listener); + + assert_eq!( + notify.state.lock().unwrap().token, + 1, + "Dropped notify should not have consumed token" + ); + }); + }); + } + + #[test] + fn threaded_dropped_notified_listener() { + loom::model(|| { + let notify = loom::sync::Arc::new(Notify::new()); + + let notify2 = notify.clone(); + let handle = loom::thread::spawn(move || { + loom::future::block_on(async move { + drop(notify2.listen().await); + }); + }); + + notify.notify_one(); + + handle.join().unwrap(); + + assert_eq!( + notify.state.lock().unwrap().token, + 1, + "Dropped notify should not have consumed token" + ); + }); + } + + #[test] + fn notified_listener() { + loom::model(|| { + loom::future::block_on(async { + let notify = Notify::new(); + + let mut listener = notify.listen().await; + + notify.notify_one(); + + let waker = noop_waker(); + let mut cx = std::task::Context::from_waker(&waker); + + assert_eq!( + std::pin::Pin::new(&mut listener).poll(&mut cx), + std::task::Poll::Ready(()), + "Polled listen should be notified" + ); + assert_eq!( + notify.state.lock().unwrap().token, + 0, + "Dropped notify should have consumed token" + ); + }) + }); + } + + #[test] + fn threaded_notified_listener() { + loom::model(|| { + let notify = loom::sync::Arc::new(Notify::new()); + + let notify2 = notify.clone(); + let handle = loom::thread::spawn(move || { + loom::future::block_on(async move { + notify2.listen().await.await; + }); + }); + + notify.notify_one(); + + handle.join().unwrap(); + + assert_eq!( + notify.state.lock().unwrap().token, + 0, + "Dropped notify should have consumed token" + ); + }); + } + + #[test] + fn multiple_listeners() { + loom::model(|| { + loom::future::block_on(async { + let notify = Notify::new(); + + let mut listener_1 = notify.listen().await; + let mut listener_2 = notify.listen().await; + + notify.notify_one(); + + let waker = noop_waker(); + let mut cx = std::task::Context::from_waker(&waker); + + assert_eq!( + std::pin::Pin::new(&mut listener_1).poll(&mut cx), + std::task::Poll::Ready(()), + "Polled listen_1 should be notified" + ); + + assert_eq!( + std::pin::Pin::new(&mut listener_2).poll(&mut cx), + std::task::Poll::Pending, + "Polled listen_2 should not be notified" + ); + }); + }); + } + + #[test] + fn multiple_notifies() { + loom::model(|| { + loom::future::block_on(async { + let notify = Notify::new(); + + let mut listener_1 = notify.listen().await; + let mut listener_2 = notify.listen().await; + + notify.notify_one(); + notify.notify_one(); + + let waker = noop_waker(); + let mut cx = std::task::Context::from_waker(&waker); + + assert_eq!( + std::pin::Pin::new(&mut listener_1).poll(&mut cx), + std::task::Poll::Ready(()), + "Polled listen_1 should be notified" + ); + + assert_eq!( + std::pin::Pin::new(&mut listener_2).poll(&mut cx), + std::task::Poll::Ready(()), + "Polled listen_2 should be notified" + ); + + assert_eq!( + notify.state.lock().unwrap().token, + 0, + "notifies should have consumed tokens" + ); + }); + }); + } + + #[test] + fn threaded_multiple_notifies() { + loom::model(|| { + let notify = loom::sync::Arc::new(Notify::new()); + + let notify2 = notify.clone(); + let handle1 = loom::thread::spawn(move || { + loom::future::block_on(async move { + notify2.listen().await.await; + }) + }); + + let notify2 = notify.clone(); + let handle2 = loom::thread::spawn(move || { + loom::future::block_on(async move { + notify2.listen().await.await; + }) + }); + + notify.notify_one(); + notify.notify_one(); + + handle1.join().unwrap(); + handle2.join().unwrap(); + + assert_eq!( + notify.state.lock().unwrap().token, + 0, + "threaded notifies should have consumed tokens" + ); + }); + } +} diff --git a/src/queue.rs b/src/queue.rs index 7db49b3..708b499 100644 --- a/src/queue.rs +++ b/src/queue.rs @@ -1,9 +1,7 @@ -use std::{ - collections::VecDeque, - sync::{Arc, Mutex}, -}; +use std::collections::VecDeque; use crate::notify::Notify; +use crate::sync::{Arc, Mutex}; pub(super) fn bounded(capacity: usize) -> Queue { Queue::bounded(capacity) diff --git a/src/selector.rs b/src/selector.rs index f181877..72a8332 100644 --- a/src/selector.rs +++ b/src/selector.rs @@ -1,13 +1,11 @@ use std::{ future::Future, pin::Pin, - sync::{ - atomic::{AtomicBool, Ordering}, - Arc, - }, task::{Context, Poll, Wake, Waker}, }; +use crate::sync::{Arc, AtomicBool, Ordering}; + pub(super) enum Either { Left(L), Right(R), @@ -27,16 +25,16 @@ struct SelectWaker { } impl Wake for SelectWaker { - fn wake_by_ref(self: &Arc) { + fn wake_by_ref(self: &std::sync::Arc) { self.flag.store(true, Ordering::Release); self.inner.wake_by_ref(); } - fn wake(self: Arc) { + fn wake(self: std::sync::Arc) { self.flag.store(true, Ordering::Release); - match Arc::try_unwrap(self) { + match std::sync::Arc::try_unwrap(self) { Ok(this) => this.inner.wake(), Err(this) => this.inner.wake_by_ref(), } @@ -51,7 +49,7 @@ where type Output = Either; fn poll(mut self: std::pin::Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { - let left_waker = Arc::new(SelectWaker { + let left_waker = std::sync::Arc::new(SelectWaker { inner: cx.waker().clone(), flag: self.left_woken.clone(), }) @@ -63,7 +61,7 @@ where return Poll::Ready(Either::Left(left_out)); } - let right_waker = Arc::new(SelectWaker { + let right_waker = std::sync::Arc::new(SelectWaker { inner: cx.waker().clone(), flag: self.right_woken.clone(), }) diff --git a/src/sync.rs b/src/sync.rs new file mode 100644 index 0000000..e348296 --- /dev/null +++ b/src/sync.rs @@ -0,0 +1,11 @@ +#[cfg(not(loom))] +pub(crate) use std::sync::{ + atomic::{AtomicBool, AtomicU8, Ordering}, + Arc, Mutex, +}; + +#[cfg(loom)] +pub(crate) use loom::sync::{ + atomic::{AtomicBool, AtomicU8, Ordering}, + Arc, Mutex, +};