use std::{ collections::VecDeque, future::{poll_fn, Future}, task::{Poll, Waker}, }; use crate::sync::{Arc, AtomicU8, Mutex, Ordering}; const UNNOTIFIED: u8 = 0b0000; const NOTIFIED_ONE: u8 = 0b0001; const RESOLVED: u8 = 0b0010; #[derive(Clone, Copy, Debug, PartialEq, Eq, PartialOrd, Ord, Hash)] struct NotifyId(u64); pub(super) struct Notify { state: Mutex, } struct NotifyState { listeners: VecDeque<(NotifyId, Arc, Waker)>, token: u64, next_id: u64, } pub(super) struct Listener<'a> { state: &'a Mutex, waker: Waker, woken: Arc, id: NotifyId, } impl Notify { pub(super) fn new() -> Self { metrics::counter!("async-cpupool.notify.created").increment(1); Notify { state: Mutex::new(NotifyState { listeners: VecDeque::new(), token: 0, next_id: 0, }), } } // although this is an async fn, it is not capable of yielding to the executor pub(super) async fn listen(&self) -> Listener<'_> { poll_fn(|cx| Poll::Ready(self.make_listener(cx.waker().clone()))).await } pub(super) fn make_listener(&self, waker: Waker) -> Listener<'_> { let (id, woken) = self .state .lock() .expect("not poisoned") .insert(waker.clone()); Listener { state: &self.state, waker, woken, id, } } pub(super) fn notify_one(&self) { self.state.lock().expect("not poisoned").notify_one(); } } impl NotifyState { fn notify_one(&mut self) { loop { if let Some((_, woken, waker)) = self.listeners.pop_front() { metrics::counter!("async-cpupool.notify.removed").increment(1); // can't use "weak" because we need to know failure means we failed fr fr to avoid // popping unwoken listeners match woken.compare_exchange( UNNOTIFIED, NOTIFIED_ONE, Ordering::Release, Ordering::Relaxed, ) { Ok(_) => waker.wake(), // if this listener isn't unnotified (races Listener Drop) we should wake the // next one Err(_) => continue, } } else { self.token += 1; } break; } } fn insert(&mut self, waker: Waker) -> (NotifyId, Arc) { let id = NotifyId(self.next_id); self.next_id += 1; let token = if self.token > 0 { self.token -= 1; true } else { false }; // don't insert waker if token is true - next poll will be ready let woken = if token { Arc::new(AtomicU8::new(NOTIFIED_ONE)) } else { let woken = Arc::new(AtomicU8::new(UNNOTIFIED)); self.listeners.push_back((id, Arc::clone(&woken), waker)); metrics::counter!("async-cpupool.notify.inserted").increment(1); woken }; (id, woken) } fn remove(&mut self, id: NotifyId) { if let Some(index) = self.find(&id) { self.listeners.remove(index); metrics::counter!("async-cpupool.notify.removed").increment(1); } } fn find(&self, needle_id: &NotifyId) -> Option { self.listeners .binary_search_by_key(needle_id, |(haystack_id, _, _)| *haystack_id) .ok() } fn update(&mut self, id: NotifyId, needle_waker: &Waker) { if let Some(index) = self.find(&id) { if let Some((_, _, haystack_waker)) = self.listeners.get_mut(index) { if !needle_waker.will_wake(haystack_waker) { *haystack_waker = needle_waker.clone(); } } } } } impl Drop for Notify { fn drop(&mut self) { metrics::counter!("async-cpupool.notify.dropped").increment(1); } } impl Drop for Listener<'_> { fn drop(&mut self) { // races compare_exchange in notify_one let flags = self.woken.swap(RESOLVED, Ordering::AcqRel); if flags == RESOLVED { // do nothing } else if flags == NOTIFIED_ONE { let mut guard = self.state.lock().expect("not poisoned"); guard.notify_one(); } else if flags == UNNOTIFIED { let mut guard = self.state.lock().expect("not poisoned"); guard.remove(self.id); } else { unreachable!("No other states exist") } } } impl Future for Listener<'_> { type Output = (); fn poll( mut self: std::pin::Pin<&mut Self>, cx: &mut std::task::Context<'_>, ) -> Poll { let mut flags = self.woken.load(Ordering::Acquire); loop { if flags == UNNOTIFIED { break; } else if flags == RESOLVED { return Poll::Ready(()); } else { match self.woken.compare_exchange_weak( flags, RESOLVED, Ordering::Release, Ordering::Acquire, ) { Ok(_) => return Poll::Ready(()), Err(updated) => flags = updated, }; } } if !self.waker.will_wake(cx.waker()) { self.waker = cx.waker().clone(); self.state .lock() .expect("not poisoned") .update(self.id, cx.waker()); } Poll::Pending } } #[cfg(all(test, loom))] mod tests { use super::Notify; use std::future::Future; struct NoopWaker; impl std::task::Wake for NoopWaker { fn wake(self: std::sync::Arc) {} fn wake_by_ref(self: &std::sync::Arc) {} } fn noop_waker() -> std::task::Waker { std::sync::Arc::new(NoopWaker).into() } #[test] fn dropped_notified_listener() { loom::model(|| { loom::future::block_on(async { let notify = Notify::new(); let listener = notify.listen().await; notify.notify_one(); drop(listener); assert_eq!( notify.state.lock().unwrap().token, 1, "Dropped notify should not have consumed token" ); }); }); } #[test] fn threaded_dropped_notified_listener() { loom::model(|| { let notify = loom::sync::Arc::new(Notify::new()); let notify2 = notify.clone(); let handle = loom::thread::spawn(move || { loom::future::block_on(async move { drop(notify2.listen().await); }); }); notify.notify_one(); handle.join().unwrap(); assert_eq!( notify.state.lock().unwrap().token, 1, "Dropped notify should not have consumed token" ); }); } #[test] fn notified_listener() { loom::model(|| { loom::future::block_on(async { let notify = Notify::new(); let mut listener = notify.listen().await; notify.notify_one(); let waker = noop_waker(); let mut cx = std::task::Context::from_waker(&waker); assert_eq!( std::pin::Pin::new(&mut listener).poll(&mut cx), std::task::Poll::Ready(()), "Polled listen should be notified" ); assert_eq!( notify.state.lock().unwrap().token, 0, "Dropped notify should have consumed token" ); }) }); } #[test] fn threaded_notified_listener() { loom::model(|| { let notify = loom::sync::Arc::new(Notify::new()); let notify2 = notify.clone(); let handle = loom::thread::spawn(move || { loom::future::block_on(async move { notify2.listen().await.await; }); }); notify.notify_one(); handle.join().unwrap(); assert_eq!( notify.state.lock().unwrap().token, 0, "Dropped notify should have consumed token" ); }); } #[test] fn multiple_listeners() { loom::model(|| { loom::future::block_on(async { let notify = Notify::new(); let mut listener_1 = notify.listen().await; let mut listener_2 = notify.listen().await; notify.notify_one(); let waker = noop_waker(); let mut cx = std::task::Context::from_waker(&waker); assert_eq!( std::pin::Pin::new(&mut listener_1).poll(&mut cx), std::task::Poll::Ready(()), "Polled listen_1 should be notified" ); assert_eq!( std::pin::Pin::new(&mut listener_2).poll(&mut cx), std::task::Poll::Pending, "Polled listen_2 should not be notified" ); }); }); } #[test] fn multiple_notifies() { loom::model(|| { loom::future::block_on(async { let notify = Notify::new(); let mut listener_1 = notify.listen().await; let mut listener_2 = notify.listen().await; notify.notify_one(); notify.notify_one(); let waker = noop_waker(); let mut cx = std::task::Context::from_waker(&waker); assert_eq!( std::pin::Pin::new(&mut listener_1).poll(&mut cx), std::task::Poll::Ready(()), "Polled listen_1 should be notified" ); assert_eq!( std::pin::Pin::new(&mut listener_2).poll(&mut cx), std::task::Poll::Ready(()), "Polled listen_2 should be notified" ); assert_eq!( notify.state.lock().unwrap().token, 0, "notifies should have consumed tokens" ); }); }); } #[test] fn threaded_multiple_notifies() { loom::model(|| { let notify = loom::sync::Arc::new(Notify::new()); let notify2 = notify.clone(); let handle1 = loom::thread::spawn(move || { loom::future::block_on(async move { notify2.listen().await.await; }) }); let notify2 = notify.clone(); let handle2 = loom::thread::spawn(move || { loom::future::block_on(async move { notify2.listen().await.await; }) }); notify.notify_one(); notify.notify_one(); handle1.join().unwrap(); handle2.join().unwrap(); assert_eq!( notify.state.lock().unwrap().token, 0, "threaded notifies should have consumed tokens" ); }); } }