// async-cpupool/src/notify.rs
// 2024-04-15 17:59:13 -05:00
//
// 421 lines
// 12 KiB
// Rust

use std::{
collections::VecDeque,
future::{poll_fn, Future},
task::{Poll, Waker},
};
use crate::sync::{Arc, AtomicU8, Mutex, Ordering};
/// Listener state: queued and still waiting for `notify_one` to pick it.
const UNNOTIFIED: u8 = 0;
/// Listener state: handed a notification (by `notify_one`, or from a banked token at
/// registration) that has not yet been consumed by `poll` or passed on by `Drop`.
const NOTIFIED_ONE: u8 = 1;
/// Listener state: finished — its notification was consumed by `poll`, or it was dropped.
const RESOLVED: u8 = 2;
/// Identifier handed to each listener by `NotifyState::insert`.
///
/// Ids are issued in increasing order and queued listeners are only ever appended at the back,
/// so the listener queue stays sorted by id; `NotifyState::find` depends on that ordering.
#[derive(Clone, Copy, Debug, Hash, PartialEq, Eq, PartialOrd, Ord)]
struct NotifyId(u64);
/// Hands out notifications one at a time to async listeners.
///
/// A notification that arrives while no listener is waiting is banked as a token and consumed
/// by the next listener that registers.
pub(super) struct Notify {
    /// All listener bookkeeping, guarded by a single lock.
    state: Mutex<NotifyState>,
}
/// Lock-protected bookkeeping shared by a `Notify` and the `Listener`s created from it.
struct NotifyState {
    /// Notifications that found no waiting listener; consumed by later registrations.
    token: u64,
    /// Id that the next registered listener will receive.
    next_id: u64,
    /// Waiting listeners in registration (ascending id) order: id, shared state flag, and the
    /// waker to call when the listener is notified.
    listeners: VecDeque<(NotifyId, Arc<AtomicU8>, Waker)>,
}
/// A registration with a `Notify`; as a future it completes once it receives a notification.
pub(super) struct Listener<'a> {
    /// Shared state of the `Notify` this listener was created from.
    state: &'a Mutex<NotifyState>,
    /// The waker most recently registered for this listener.
    waker: Waker,
    /// State flag (`UNNOTIFIED`, `NOTIFIED_ONE` or `RESOLVED`), shared with this listener's
    /// queue entry when it has one.
    woken: Arc<AtomicU8>,
    /// Key used to locate this listener's queue entry.
    id: NotifyId,
}
impl Notify {
    /// Creates a `Notify` with an empty listener queue and no banked tokens.
    pub(super) fn new() -> Self {
        metrics::counter!("async-cpupool.notify.created").increment(1);

        let initial = NotifyState {
            listeners: VecDeque::new(),
            token: 0,
            next_id: 0,
        };

        Notify {
            state: Mutex::new(initial),
        }
    }

    /// Registers a listener that is woken through the waker of the task awaiting this call.
    ///
    /// Although this is an `async fn`, it never yields to the executor: the inner `poll_fn`
    /// returns `Ready` on its very first poll.
    pub(super) async fn listen(&self) -> Listener<'_> {
        poll_fn(|cx| {
            let current = cx.waker().clone();
            Poll::Ready(self.make_listener(current))
        })
        .await
    }

    /// Registers a listener that is woken through `waker`.
    ///
    /// If a notification was banked as a token, it is consumed here and the returned listener
    /// completes on its first poll.
    pub(super) fn make_listener(&self, waker: Waker) -> Listener<'_> {
        let mut guard = self.state.lock().expect("not poisoned");
        let (id, woken) = guard.insert(waker.clone());
        drop(guard);

        Listener {
            state: &self.state,
            waker,
            woken,
            id,
        }
    }

    /// Wakes the oldest waiting listener, or banks a token if no listener is waiting.
    pub(super) fn notify_one(&self) {
        let mut guard = self.state.lock().expect("not poisoned");
        guard.notify_one();
    }
}
impl NotifyState {
    /// Hands one notification to the oldest listener still waiting, or banks it as a token.
    ///
    /// Queue entries whose flag is no longer `UNNOTIFIED` belong to listeners that were dropped
    /// concurrently (racing `Listener`'s `Drop`); they are discarded and the next entry is tried.
    fn notify_one(&mut self) {
        while let Some((_, woken, waker)) = self.listeners.pop_front() {
            metrics::counter!("async-cpupool.notify.removed").increment(1);

            // Must be the strong compare_exchange: a spurious failure would make us discard a
            // listener that is genuinely still waiting.
            let claimed = woken
                .compare_exchange(
                    UNNOTIFIED,
                    NOTIFIED_ONE,
                    Ordering::Release,
                    Ordering::Relaxed,
                )
                .is_ok();

            if claimed {
                waker.wake();
                return;
            }
        }

        // Nobody was waiting: remember the notification for the next listener.
        self.token += 1;
    }

    /// Allocates an id and a shared state flag for a new listener.
    ///
    /// When a token is banked it is consumed, the flag starts as `NOTIFIED_ONE`, and the waker is
    /// not queued (the listener's first poll will complete). Otherwise the listener is appended
    /// to the queue as `UNNOTIFIED`.
    fn insert(&mut self, waker: Waker) -> (NotifyId, Arc<AtomicU8>) {
        let id = NotifyId(self.next_id);
        self.next_id += 1;

        if self.token > 0 {
            self.token -= 1;
            return (id, Arc::new(AtomicU8::new(NOTIFIED_ONE)));
        }

        let woken = Arc::new(AtomicU8::new(UNNOTIFIED));
        self.listeners.push_back((id, Arc::clone(&woken), waker));
        metrics::counter!("async-cpupool.notify.inserted").increment(1);
        (id, woken)
    }

    /// Removes the listener with `id` from the queue, if it is still queued.
    fn remove(&mut self, id: NotifyId) {
        match self.find(&id) {
            Some(index) => {
                self.listeners.remove(index);
                metrics::counter!("async-cpupool.notify.removed").increment(1);
            }
            None => {}
        }
    }

    /// Returns the queue index of the listener with `needle_id`.
    ///
    /// Binary search is valid because ids are issued in increasing order and only appended at the
    /// back, so the queue is always sorted by id.
    fn find(&self, needle_id: &NotifyId) -> Option<usize> {
        self.listeners
            .binary_search_by(|(candidate, _, _)| candidate.cmp(needle_id))
            .ok()
    }

    /// Replaces the stored waker of queued listener `id` with `needle_waker`, unless the stored
    /// one would already wake the same task. Does nothing if the listener is no longer queued.
    fn update(&mut self, id: NotifyId, needle_waker: &Waker) {
        let entry = self
            .find(&id)
            .and_then(|index| self.listeners.get_mut(index));

        if let Some((_, _, stored)) = entry {
            if !needle_waker.will_wake(stored) {
                *stored = needle_waker.clone();
            }
        }
    }
}
impl Drop for Notify {
    /// Records the destruction in the `async-cpupool.notify.dropped` counter.
    fn drop(&mut self) {
        let dropped = metrics::counter!("async-cpupool.notify.dropped");
        dropped.increment(1);
    }
}
impl Drop for Listener<'_> {
    /// Retires the listener so that no notification is lost and no queue entry is left behind.
    ///
    /// The swap to `RESOLVED` races the `compare_exchange` in `NotifyState::notify_one`; the
    /// previous state decides who is responsible for cleanup.
    fn drop(&mut self) {
        match self.woken.swap(RESOLVED, Ordering::AcqRel) {
            // Already completed by `poll`: nothing left to do.
            RESOLVED => {}
            // Notified but never consumed: hand the notification to the next listener (or bank it
            // as a token) so it is not lost.
            NOTIFIED_ONE => self.state.lock().expect("not poisoned").notify_one(),
            // Still waiting: drop our queue entry. If `notify_one` has already popped it, its
            // compare_exchange sees `RESOLVED`, fails, and it moves on to the next listener.
            UNNOTIFIED => self.state.lock().expect("not poisoned").remove(self.id),
            _ => unreachable!("No other states exist"),
        }
    }
}
impl Future for Listener<'_> {
type Output = ();
fn poll(
mut self: std::pin::Pin<&mut Self>,
cx: &mut std::task::Context<'_>,
) -> Poll<Self::Output> {
let mut flags = self.woken.load(Ordering::Acquire);
loop {
if flags == UNNOTIFIED {
break;
} else if flags == RESOLVED {
return Poll::Ready(());
} else {
match self.woken.compare_exchange_weak(
flags,
RESOLVED,
Ordering::Release,
Ordering::Acquire,
) {
Ok(_) => return Poll::Ready(()),
Err(updated) => flags = updated,
};
}
}
if !self.waker.will_wake(cx.waker()) {
self.waker = cx.waker().clone();
self.state
.lock()
.expect("not poisoned")
.update(self.id, cx.waker());
}
Poll::Pending
}
}
#[cfg(all(test, loom))]
mod tests {
    use super::Notify;
    use std::future::Future;

    struct NoopWaker;

    impl std::task::Wake for NoopWaker {
        fn wake(self: std::sync::Arc<Self>) {}
        fn wake_by_ref(self: &std::sync::Arc<Self>) {}
    }

    fn noop_waker() -> std::task::Waker {
        std::sync::Arc::new(NoopWaker).into()
    }

    /// A listener notified and then dropped without being polled must return its
    /// notification as a token.
    #[test]
    fn dropped_notified_listener() {
        loom::model(|| {
            loom::future::block_on(async {
                let notify = Notify::new();
                let listener = notify.listen().await;
                notify.notify_one();
                drop(listener);
                assert_eq!(
                    notify.state.lock().unwrap().token,
                    1,
                    "Dropped listener should not have consumed token"
                );
            });
        });
    }

    /// Same as `dropped_notified_listener`, with the notification racing the listener on
    /// another thread.
    #[test]
    fn threaded_dropped_notified_listener() {
        loom::model(|| {
            let notify = loom::sync::Arc::new(Notify::new());
            let notify2 = notify.clone();
            let handle = loom::thread::spawn(move || {
                loom::future::block_on(async move {
                    drop(notify2.listen().await);
                });
            });
            notify.notify_one();
            handle.join().unwrap();
            assert_eq!(
                notify.state.lock().unwrap().token,
                1,
                "Dropped listener should not have consumed token"
            );
        });
    }

    /// A notified listener completes when polled and leaves no banked token behind.
    #[test]
    fn notified_listener() {
        loom::model(|| {
            loom::future::block_on(async {
                let notify = Notify::new();
                let mut listener = notify.listen().await;
                notify.notify_one();
                let waker = noop_waker();
                let mut cx = std::task::Context::from_waker(&waker);
                assert_eq!(
                    std::pin::Pin::new(&mut listener).poll(&mut cx),
                    std::task::Poll::Ready(()),
                    "Polled listener should be notified"
                );
                assert_eq!(
                    notify.state.lock().unwrap().token,
                    0,
                    "Polled listener should have consumed token"
                );
            })
        });
    }

    /// Same as `notified_listener`, with the notification racing the listener on another
    /// thread.
    #[test]
    fn threaded_notified_listener() {
        loom::model(|| {
            let notify = loom::sync::Arc::new(Notify::new());
            let notify2 = notify.clone();
            let handle = loom::thread::spawn(move || {
                loom::future::block_on(async move {
                    notify2.listen().await.await;
                });
            });
            notify.notify_one();
            handle.join().unwrap();
            assert_eq!(
                notify.state.lock().unwrap().token,
                0,
                "Awaited listener should have consumed token"
            );
        });
    }

    /// A single notification wakes only the oldest listener.
    #[test]
    fn multiple_listeners() {
        loom::model(|| {
            loom::future::block_on(async {
                let notify = Notify::new();
                let mut listener_1 = notify.listen().await;
                let mut listener_2 = notify.listen().await;
                notify.notify_one();
                let waker = noop_waker();
                let mut cx = std::task::Context::from_waker(&waker);
                assert_eq!(
                    std::pin::Pin::new(&mut listener_1).poll(&mut cx),
                    std::task::Poll::Ready(()),
                    "Polled listener_1 should be notified"
                );
                assert_eq!(
                    std::pin::Pin::new(&mut listener_2).poll(&mut cx),
                    std::task::Poll::Pending,
                    "Polled listener_2 should not be notified"
                );
            });
        });
    }

    /// Two notifications wake two listeners and leave no banked token.
    #[test]
    fn multiple_notifies() {
        loom::model(|| {
            loom::future::block_on(async {
                let notify = Notify::new();
                let mut listener_1 = notify.listen().await;
                let mut listener_2 = notify.listen().await;
                notify.notify_one();
                notify.notify_one();
                let waker = noop_waker();
                let mut cx = std::task::Context::from_waker(&waker);
                assert_eq!(
                    std::pin::Pin::new(&mut listener_1).poll(&mut cx),
                    std::task::Poll::Ready(()),
                    "Polled listener_1 should be notified"
                );
                assert_eq!(
                    std::pin::Pin::new(&mut listener_2).poll(&mut cx),
                    std::task::Poll::Ready(()),
                    "Polled listener_2 should be notified"
                );
                assert_eq!(
                    notify.state.lock().unwrap().token,
                    0,
                    "notifies should have consumed tokens"
                );
            });
        });
    }

    /// Two notifications racing two listeners on separate threads are both consumed.
    #[test]
    fn threaded_multiple_notifies() {
        loom::model(|| {
            let notify = loom::sync::Arc::new(Notify::new());
            let notify2 = notify.clone();
            let handle1 = loom::thread::spawn(move || {
                loom::future::block_on(async move {
                    notify2.listen().await.await;
                })
            });
            let notify2 = notify.clone();
            let handle2 = loom::thread::spawn(move || {
                loom::future::block_on(async move {
                    notify2.listen().await.await;
                })
            });
            notify.notify_one();
            notify.notify_one();
            handle1.join().unwrap();
            handle2.join().unwrap();
            assert_eq!(
                notify.state.lock().unwrap().token,
                0,
                "threaded notifies should have consumed tokens"
            );
        });
    }
}