use std::{ collections::VecDeque, future::{poll_fn, Future}, sync::{ atomic::{AtomicU8, Ordering}, Arc, Mutex, }, task::{Poll, Waker}, }; const UNNOTIFIED: u8 = 0b0000; const NOTIFIED_ONE: u8 = 0b0001; const RESOLVED: u8 = 0b0010; #[derive(Clone, Copy, Debug, PartialEq, Eq, PartialOrd, Ord, Hash)] struct NotifyId(u64); pub(super) struct Notify { state: Mutex, } struct NotifyState { listeners: VecDeque<(NotifyId, Arc, Waker)>, token: u64, next_id: u64, } pub(super) struct Listener<'a> { state: &'a Mutex, waker: Waker, woken: Arc, id: NotifyId, } impl Notify { pub(super) fn new() -> Self { metrics::counter!("async-cpupool.notify.created").increment(1); Notify { state: Mutex::new(NotifyState { listeners: VecDeque::new(), token: 0, next_id: 0, }), } } // although this is an async fn, it is not capable of yielding to the executor pub(super) async fn listen(&self) -> Listener<'_> { poll_fn(|cx| Poll::Ready(self.make_listener(cx.waker().clone()))).await } pub(super) fn make_listener(&self, waker: Waker) -> Listener<'_> { let (id, woken) = self .state .lock() .expect("not poisoned") .insert(waker.clone()); Listener { state: &self.state, waker, woken, id, } } pub(super) fn notify_one(&self) { self.state.lock().expect("not poisoned").notify_one(); } } impl NotifyState { fn notify_one(&mut self) { loop { if let Some((_, woken, waker)) = self.listeners.pop_front() { metrics::counter!("async-cpupool.notify.removed").increment(1); // can't use "weak" because we need to know failure means we failed fr fr to avoid // popping unwoken listeners match woken.compare_exchange( UNNOTIFIED, NOTIFIED_ONE, Ordering::Release, Ordering::Relaxed, ) { Ok(_) => waker.wake(), // if this listener isn't unnotified (races Listener Drop) we should wake the // next one Err(_) => continue, } } else { self.token += 1; } break; } } fn insert(&mut self, waker: Waker) -> (NotifyId, Arc) { let id = NotifyId(self.next_id); self.next_id += 1; let token = if self.token > 0 { self.token -= 1; true } else { false }; // don't insert waker if token is true - next poll will be ready let woken = if token { Arc::new(AtomicU8::new(NOTIFIED_ONE)) } else { let woken = Arc::new(AtomicU8::new(UNNOTIFIED)); self.listeners.push_back((id, Arc::clone(&woken), waker)); metrics::counter!("async-cpupool.notify.inserted").increment(1); woken }; (id, woken) } fn remove(&mut self, id: NotifyId) { if let Some(index) = self.find(&id) { self.listeners.remove(index); metrics::counter!("async-cpupool.notify.removed").increment(1); } } fn find(&self, needle_id: &NotifyId) -> Option { self.listeners .binary_search_by_key(needle_id, |(haystack_id, _, _)| *haystack_id) .ok() } fn update(&mut self, id: NotifyId, needle_waker: &Waker) { if let Some(index) = self.find(&id) { if let Some((_, _, haystack_waker)) = self.listeners.get_mut(index) { if !needle_waker.will_wake(haystack_waker) { *haystack_waker = needle_waker.clone(); } } } } } impl Drop for Notify { fn drop(&mut self) { metrics::counter!("async-cpupool.notify.dropped").increment(1); } } impl Drop for Listener<'_> { fn drop(&mut self) { // races compare_exchange in notify_one let flags = self.woken.swap(RESOLVED, Ordering::AcqRel); if flags == RESOLVED { return; } else if flags == NOTIFIED_ONE { let mut guard = self.state.lock().expect("not poisoned"); guard.notify_one(); } else if flags == UNNOTIFIED { let mut guard = self.state.lock().expect("not poisoned"); guard.remove(self.id); } else { unreachable!("No other states exist") } } } impl Future for Listener<'_> { type Output = (); fn poll( mut self: std::pin::Pin<&mut Self>, cx: &mut std::task::Context<'_>, ) -> Poll { let mut flags = self.woken.load(Ordering::Acquire); loop { if flags == UNNOTIFIED { break; } else if flags == RESOLVED { return Poll::Ready(()); } else { match self.woken.compare_exchange_weak( flags, RESOLVED, Ordering::Release, Ordering::Acquire, ) { Ok(_) => return Poll::Ready(()), Err(updated) => flags = updated, }; } } if !self.waker.will_wake(cx.waker()) { self.waker = cx.waker().clone(); self.state .lock() .expect("not poisoned") .update(self.id, cx.waker()); } Poll::Pending } }