async-cpupool/src/notify.rs

247 lines
6.4 KiB
Rust

use std::{
collections::VecDeque,
future::{poll_fn, Future},
task::{Poll, Waker},
};
use crate::sync::{Arc, AtomicU8, Mutex, Ordering};
// Per-thread tally used only in test/loom builds: `increment_notify_count` adds one,
// `decrement_notify_count` subtracts one and `notify_count` reads the current value.
// `Wrapping` makes the arithmetic wrap around instead of panicking on overflow/underflow.
thread_local! {
#[cfg(any(loom, test))]
static NOTIFY_COUNT: std::cell::RefCell<std::num::Wrapping<u64>> = std::cell::RefCell::new(std::num::Wrapping(0));
}
/// Adds one to the calling thread's `NOTIFY_COUNT`.
///
/// The counting statement is compiled only under `cfg(any(loom, test))`; in every other build
/// this function has an empty body.
#[inline(always)]
fn increment_notify_count() {
    #[cfg(any(loom, test))]
    NOTIFY_COUNT.with(|count| *count.borrow_mut() += 1);
}
/// Subtracts one from the calling thread's `NOTIFY_COUNT` (wrapping on underflow).
///
/// The counting statement is compiled only under `cfg(any(loom, test))`; in every other build
/// this function has an empty body.
#[inline(always)]
fn decrement_notify_count() {
    #[cfg(any(loom, test))]
    NOTIFY_COUNT.with(|count| *count.borrow_mut() -= 1);
}
/// Returns the current value of the calling thread's `NOTIFY_COUNT`.
///
/// Only available in test and loom builds.
#[cfg(any(test, loom))]
#[doc(hidden)]
pub fn notify_count() -> u64 {
    NOTIFY_COUNT.with(|count| count.borrow().0)
}
/// Listener state: no notification has been delivered to it yet.
const UNNOTIFIED: u8 = 0;
/// Listener state: a notification was delivered but the listener has not consumed it yet.
const NOTIFIED_ONE: u8 = 1 << 0;
/// Listener state: the listener consumed its notification or was dropped.
const RESOLVED: u8 = 1 << 1;
/// Identifier assigned to a listener when it registers.
///
/// Ids are handed out from an increasing counter, so the derived ordering matches registration
/// order.
#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash)]
struct NotifyId(u64);
/// Notification primitive: hands out `Listener`s and wakes one of them per `notify_one` call.
#[doc(hidden)]
pub struct Notify {
// Listener queue and stored-notification counter; locked by every operation on this type.
state: Mutex<NotifyState>,
}
/// Mutex-protected bookkeeping shared by a `Notify` and the `Listener`s created from it.
struct NotifyState {
// Waiting listeners in registration order: (id, shared state flag, waker to wake).
listeners: VecDeque<(NotifyId, Arc<AtomicU8>, Waker)>,
// Notifications that arrived while no listener could take them; `insert` consumes one per
// newly registered listener.
token: u64,
// Id given to the next registered listener; incremented on every `insert`.
next_id: u64,
}
/// Future for a single registration on a `Notify`; it becomes ready once it has been notified.
#[doc(hidden)]
pub struct Listener<'a> {
// State of the `Notify` this listener was created from.
state: &'a Mutex<NotifyState>,
// Most recent waker this listener has been given (at creation or by a later poll).
waker: Waker,
// State flag (`UNNOTIFIED`, `NOTIFIED_ONE` or `RESOLVED`), shared with this listener's queue
// entry when it has one.
woken: Arc<AtomicU8>,
// Id used to look up this listener's queue entry.
id: NotifyId,
}
impl Notify {
    /// Creates a `Notify` with no waiting listeners and no stored notifications.
    ///
    /// Also bumps the test-only notify count and the `async-cpupool.notify.created` metric.
    #[doc(hidden)]
    pub fn new() -> Self {
        increment_notify_count();
        metrics::counter!("async-cpupool.notify.created").increment(1);

        Notify {
            state: Mutex::new(NotifyState {
                listeners: VecDeque::new(),
                token: 0,
                next_id: 0,
            }),
        }
    }

    /// Registers a new listener that uses the current task's waker.
    ///
    /// # Panics
    ///
    /// Panics if the internal mutex is poisoned.
    // although this is an async fn, it is not capable of yielding to the executor
    #[doc(hidden)]
    pub async fn listen(&self) -> Listener<'_> {
        // `poll_fn` is used purely to get access to the task's waker; it is ready immediately.
        poll_fn(|cx| Poll::Ready(self.make_listener(cx.waker().clone()))).await
    }

    /// Returns the number of stored notifications that no listener has consumed yet.
    ///
    /// # Panics
    ///
    /// Panics if the internal mutex is poisoned.
    #[doc(hidden)]
    #[cfg(any(loom, test))]
    pub fn token(&self) -> u64 {
        // Same poison handling (and panic message) as every other lock site in this file.
        self.state.lock().expect("not poisoned").token
    }

    /// Registers a listener for `waker` and returns the handle that tracks it.
    ///
    /// # Panics
    ///
    /// Panics if the internal mutex is poisoned.
    pub(super) fn make_listener(&self, waker: Waker) -> Listener<'_> {
        // The state keeps its own clone of the waker; the listener keeps the original so it can
        // later tell whether it is being polled with a different one.
        let (id, woken) = self
            .state
            .lock()
            .expect("not poisoned")
            .insert(waker.clone());

        Listener {
            state: &self.state,
            waker,
            woken,
            id,
        }
    }

    /// Wakes the oldest listener that is still waiting, or stores the notification for the next
    /// listener to register when none is waiting.
    ///
    /// # Panics
    ///
    /// Panics if the internal mutex is poisoned.
    #[doc(hidden)]
    pub fn notify_one(&self) {
        self.state.lock().expect("not poisoned").notify_one();
    }
}
impl NotifyState {
    /// Delivers one notification.
    ///
    /// Listeners are popped from the front of the queue until one is found that is still
    /// `UNNOTIFIED`; that one is switched to `NOTIFIED_ONE` and its waker is woken. If the queue
    /// runs out first, the notification is stored in `token` for a later `insert`.
    fn notify_one(&mut self) {
        while let Some((_, woken, waker)) = self.listeners.pop_front() {
            metrics::counter!("async-cpupool.notify.removed").increment(1);

            // The strong compare_exchange is required: a spurious failure would make us discard
            // a listener that was in fact still waiting.
            let claimed = woken
                .compare_exchange(
                    UNNOTIFIED,
                    NOTIFIED_ONE,
                    Ordering::Release,
                    Ordering::Relaxed,
                )
                .is_ok();

            if claimed {
                waker.wake();
                return;
            }

            // This listener was no longer unnotified (races Listener Drop); move on to the next.
        }

        // Nobody was waiting: remember the notification.
        self.token += 1;
    }

    /// Registers a listener and returns its id together with its shared state flag.
    ///
    /// When a stored notification is available it is consumed: the flag starts out as
    /// `NOTIFIED_ONE` and the waker is not queued, because the next poll will be ready.
    /// Otherwise the flag starts as `UNNOTIFIED` and the listener is appended to the queue.
    fn insert(&mut self, waker: Waker) -> (NotifyId, Arc<AtomicU8>) {
        let id = NotifyId(self.next_id);
        self.next_id += 1;

        if self.token > 0 {
            self.token -= 1;
            return (id, Arc::new(AtomicU8::new(NOTIFIED_ONE)));
        }

        let woken = Arc::new(AtomicU8::new(UNNOTIFIED));
        self.listeners.push_back((id, Arc::clone(&woken), waker));
        metrics::counter!("async-cpupool.notify.inserted").increment(1);

        (id, woken)
    }

    /// Removes the queue entry for `id`, if it is still queued.
    fn remove(&mut self, id: NotifyId) {
        let Some(index) = self.find(&id) else {
            return;
        };

        self.listeners.remove(index);
        metrics::counter!("async-cpupool.notify.removed").increment(1);
    }

    /// Returns the queue position of the entry with `needle_id`, if present.
    ///
    /// Uses a binary search, which relies on entries being queued in increasing id order.
    fn find(&self, needle_id: &NotifyId) -> Option<usize> {
        self.listeners
            .binary_search_by(|(queued_id, _, _)| queued_id.cmp(needle_id))
            .ok()
    }

    /// Replaces the queued waker for `id` with a clone of `needle_waker`, unless the queued
    /// waker would already wake the same task. Does nothing when `id` is not queued.
    fn update(&mut self, id: NotifyId, needle_waker: &Waker) {
        let Some(index) = self.find(&id) else {
            return;
        };
        let Some((_, _, queued_waker)) = self.listeners.get_mut(index) else {
            return;
        };

        if !needle_waker.will_wake(queued_waker) {
            *queued_waker = needle_waker.clone();
        }
    }
}
impl Drop for Notify {
    /// Records the destruction of this `Notify` in the test-only notify count and in the
    /// `async-cpupool.notify.dropped` metric.
    fn drop(&mut self) {
        decrement_notify_count();

        metrics::counter!("async-cpupool.notify.dropped").increment(1);
    }
}
impl Drop for Listener<'_> {
    /// Marks the listener as `RESOLVED` and cleans up according to the state it was in.
    fn drop(&mut self) {
        // races compare_exchange in notify_one
        match self.woken.swap(RESOLVED, Ordering::AcqRel) {
            // Already resolved: do nothing.
            RESOLVED => {}
            // Notified but never consumed: hand the notification on.
            NOTIFIED_ONE => {
                self.state.lock().expect("not poisoned").notify_one();
            }
            // Never notified: take the entry out of the queue.
            UNNOTIFIED => {
                self.state.lock().expect("not poisoned").remove(self.id);
            }
            _ => unreachable!("No other states exist"),
        }
    }
}
impl Future for Listener<'_> {
type Output = ();
fn poll(
mut self: std::pin::Pin<&mut Self>,
cx: &mut std::task::Context<'_>,
) -> Poll<Self::Output> {
let mut flags = self.woken.load(Ordering::Acquire);
loop {
if flags == UNNOTIFIED {
break;
} else if flags == RESOLVED {
return Poll::Ready(());
} else {
match self.woken.compare_exchange_weak(
flags,
RESOLVED,
Ordering::Release,
Ordering::Acquire,
) {
Ok(_) => return Poll::Ready(()),
Err(updated) => flags = updated,
};
}
}
if !self.waker.will_wake(cx.waker()) {
self.waker = cx.waker().clone();
self.state
.lock()
.expect("not poisoned")
.update(self.id, cx.waker());
}
Poll::Pending
}
}