421 lines
12 KiB
Rust
421 lines
12 KiB
Rust
use std::{
|
|
collections::VecDeque,
|
|
future::{poll_fn, Future},
|
|
task::{Poll, Waker},
|
|
};
|
|
|
|
use crate::sync::{Arc, AtomicU8, Mutex, Ordering};
|
|
|
|
// Values stored in a listener's shared `woken` flag.

/// The listener is still waiting for a notification.
const UNNOTIFIED: u8 = 0;
/// The listener has been handed a notification that it has not consumed yet.
const NOTIFIED_ONE: u8 = 1;
/// The listener is finished: `poll` consumed its notification or the listener was dropped.
const RESOLVED: u8 = 2;
|
|
|
|
/// Identifier handed to each listener.
///
/// Ids are taken from a counter that only increases, so ordering ids also orders
/// listeners by registration time.
#[derive(Clone, Copy, Debug, PartialEq, Eq, PartialOrd, Ord, Hash)]
struct NotifyId(u64);
|
|
|
|
pub(super) struct Notify {
|
|
state: Mutex<NotifyState>,
|
|
}
|
|
|
|
struct NotifyState {
|
|
listeners: VecDeque<(NotifyId, Arc<AtomicU8>, Waker)>,
|
|
token: u64,
|
|
next_id: u64,
|
|
}
|
|
|
|
pub(super) struct Listener<'a> {
|
|
state: &'a Mutex<NotifyState>,
|
|
waker: Waker,
|
|
woken: Arc<AtomicU8>,
|
|
id: NotifyId,
|
|
}
|
|
|
|
impl Notify {
|
|
pub(super) fn new() -> Self {
|
|
metrics::counter!("async-cpupool.notify.created").increment(1);
|
|
|
|
Notify {
|
|
state: Mutex::new(NotifyState {
|
|
listeners: VecDeque::new(),
|
|
token: 0,
|
|
next_id: 0,
|
|
}),
|
|
}
|
|
}
|
|
|
|
// although this is an async fn, it is not capable of yielding to the executor
|
|
pub(super) async fn listen(&self) -> Listener<'_> {
|
|
poll_fn(|cx| Poll::Ready(self.make_listener(cx.waker().clone()))).await
|
|
}
|
|
|
|
pub(super) fn make_listener(&self, waker: Waker) -> Listener<'_> {
|
|
let (id, woken) = self
|
|
.state
|
|
.lock()
|
|
.expect("not poisoned")
|
|
.insert(waker.clone());
|
|
|
|
Listener {
|
|
state: &self.state,
|
|
waker,
|
|
woken,
|
|
id,
|
|
}
|
|
}
|
|
|
|
pub(super) fn notify_one(&self) {
|
|
self.state.lock().expect("not poisoned").notify_one();
|
|
}
|
|
}
|
|
|
|
impl NotifyState {
|
|
fn notify_one(&mut self) {
|
|
loop {
|
|
if let Some((_, woken, waker)) = self.listeners.pop_front() {
|
|
metrics::counter!("async-cpupool.notify.removed").increment(1);
|
|
// can't use "weak" because we need to know failure means we failed fr fr to avoid
|
|
// popping unwoken listeners
|
|
match woken.compare_exchange(
|
|
UNNOTIFIED,
|
|
NOTIFIED_ONE,
|
|
Ordering::Release,
|
|
Ordering::Relaxed,
|
|
) {
|
|
Ok(_) => waker.wake(),
|
|
|
|
// if this listener isn't unnotified (races Listener Drop) we should wake the
|
|
// next one
|
|
Err(_) => continue,
|
|
}
|
|
} else {
|
|
self.token += 1;
|
|
}
|
|
|
|
break;
|
|
}
|
|
}
|
|
|
|
fn insert(&mut self, waker: Waker) -> (NotifyId, Arc<AtomicU8>) {
|
|
let id = NotifyId(self.next_id);
|
|
self.next_id += 1;
|
|
|
|
let token = if self.token > 0 {
|
|
self.token -= 1;
|
|
true
|
|
} else {
|
|
false
|
|
};
|
|
|
|
// don't insert waker if token is true - next poll will be ready
|
|
let woken = if token {
|
|
Arc::new(AtomicU8::new(NOTIFIED_ONE))
|
|
} else {
|
|
let woken = Arc::new(AtomicU8::new(UNNOTIFIED));
|
|
self.listeners.push_back((id, Arc::clone(&woken), waker));
|
|
metrics::counter!("async-cpupool.notify.inserted").increment(1);
|
|
|
|
woken
|
|
};
|
|
|
|
(id, woken)
|
|
}
|
|
|
|
fn remove(&mut self, id: NotifyId) {
|
|
if let Some(index) = self.find(&id) {
|
|
self.listeners.remove(index);
|
|
metrics::counter!("async-cpupool.notify.removed").increment(1);
|
|
}
|
|
}
|
|
|
|
fn find(&self, needle_id: &NotifyId) -> Option<usize> {
|
|
self.listeners
|
|
.binary_search_by_key(needle_id, |(haystack_id, _, _)| *haystack_id)
|
|
.ok()
|
|
}
|
|
|
|
fn update(&mut self, id: NotifyId, needle_waker: &Waker) {
|
|
if let Some(index) = self.find(&id) {
|
|
if let Some((_, _, haystack_waker)) = self.listeners.get_mut(index) {
|
|
if !needle_waker.will_wake(haystack_waker) {
|
|
*haystack_waker = needle_waker.clone();
|
|
}
|
|
}
|
|
}
|
|
}
|
|
}
|
|
|
|
impl Drop for Notify {
|
|
fn drop(&mut self) {
|
|
metrics::counter!("async-cpupool.notify.dropped").increment(1);
|
|
}
|
|
}
|
|
|
|
impl Drop for Listener<'_> {
|
|
fn drop(&mut self) {
|
|
// races compare_exchange in notify_one
|
|
let flags = self.woken.swap(RESOLVED, Ordering::AcqRel);
|
|
|
|
if flags == RESOLVED {
|
|
// do nothing
|
|
} else if flags == NOTIFIED_ONE {
|
|
let mut guard = self.state.lock().expect("not poisoned");
|
|
guard.notify_one();
|
|
} else if flags == UNNOTIFIED {
|
|
let mut guard = self.state.lock().expect("not poisoned");
|
|
guard.remove(self.id);
|
|
} else {
|
|
unreachable!("No other states exist")
|
|
}
|
|
}
|
|
}
|
|
|
|
impl Future for Listener<'_> {
|
|
type Output = ();
|
|
|
|
fn poll(
|
|
mut self: std::pin::Pin<&mut Self>,
|
|
cx: &mut std::task::Context<'_>,
|
|
) -> Poll<Self::Output> {
|
|
let mut flags = self.woken.load(Ordering::Acquire);
|
|
|
|
loop {
|
|
if flags == UNNOTIFIED {
|
|
break;
|
|
} else if flags == RESOLVED {
|
|
return Poll::Ready(());
|
|
} else {
|
|
match self.woken.compare_exchange_weak(
|
|
flags,
|
|
RESOLVED,
|
|
Ordering::Release,
|
|
Ordering::Acquire,
|
|
) {
|
|
Ok(_) => return Poll::Ready(()),
|
|
Err(updated) => flags = updated,
|
|
};
|
|
}
|
|
}
|
|
|
|
if !self.waker.will_wake(cx.waker()) {
|
|
self.waker = cx.waker().clone();
|
|
|
|
self.state
|
|
.lock()
|
|
.expect("not poisoned")
|
|
.update(self.id, cx.waker());
|
|
}
|
|
|
|
Poll::Pending
|
|
}
|
|
}
|
|
|
|
#[cfg(all(test, loom))]
mod tests {
    use super::Notify;
    use std::future::Future;

    /// Waker that ignores every wake-up; these tests poll listeners by hand.
    struct NoopWaker;

    impl std::task::Wake for NoopWaker {
        fn wake(self: std::sync::Arc<Self>) {}

        fn wake_by_ref(self: &std::sync::Arc<Self>) {}
    }

    fn noop_waker() -> std::task::Waker {
        let inner = std::sync::Arc::new(NoopWaker);
        inner.into()
    }

    /// Number of notifications currently banked in `notify` (not yet claimed by a listener).
    fn banked_tokens(notify: &Notify) -> u64 {
        notify.state.lock().unwrap().token
    }

    /// Polls `future` a single time using `cx`.
    fn poll_once<F>(
        future: &mut F,
        cx: &mut std::task::Context<'_>,
    ) -> std::task::Poll<F::Output>
    where
        F: Future + Unpin,
    {
        std::pin::Pin::new(future).poll(cx)
    }

    /// Spawns a loom thread that registers a listener on `notify`, then either awaits it
    /// (`await_listener == true`) or drops it right away.
    fn spawn_listener(
        notify: &loom::sync::Arc<Notify>,
        await_listener: bool,
    ) -> loom::thread::JoinHandle<()> {
        let notify = notify.clone();

        loom::thread::spawn(move || {
            loom::future::block_on(async move {
                let listener = notify.listen().await;

                if await_listener {
                    listener.await;
                } else {
                    drop(listener);
                }
            });
        })
    }

    #[test]
    fn dropped_notified_listener() {
        loom::model(|| {
            loom::future::block_on(async {
                let notify = Notify::new();
                let listener = notify.listen().await;

                notify.notify_one();
                drop(listener);

                assert_eq!(
                    banked_tokens(&notify),
                    1,
                    "Dropped notify should not have consumed token"
                );
            });
        });
    }

    #[test]
    fn threaded_dropped_notified_listener() {
        loom::model(|| {
            let notify = loom::sync::Arc::new(Notify::new());
            let handle = spawn_listener(&notify, false);

            notify.notify_one();
            handle.join().unwrap();

            assert_eq!(
                banked_tokens(&notify),
                1,
                "Dropped notify should not have consumed token"
            );
        });
    }

    #[test]
    fn notified_listener() {
        loom::model(|| {
            loom::future::block_on(async {
                let notify = Notify::new();
                let mut listener = notify.listen().await;

                notify.notify_one();

                let waker = noop_waker();
                let mut cx = std::task::Context::from_waker(&waker);

                assert_eq!(
                    poll_once(&mut listener, &mut cx),
                    std::task::Poll::Ready(()),
                    "Polled listen should be notified"
                );
                assert_eq!(
                    banked_tokens(&notify),
                    0,
                    "Dropped notify should have consumed token"
                );
            });
        });
    }

    #[test]
    fn threaded_notified_listener() {
        loom::model(|| {
            let notify = loom::sync::Arc::new(Notify::new());
            let handle = spawn_listener(&notify, true);

            notify.notify_one();
            handle.join().unwrap();

            assert_eq!(
                banked_tokens(&notify),
                0,
                "Dropped notify should have consumed token"
            );
        });
    }

    #[test]
    fn multiple_listeners() {
        loom::model(|| {
            loom::future::block_on(async {
                let notify = Notify::new();
                let mut first = notify.listen().await;
                let mut second = notify.listen().await;

                notify.notify_one();

                let waker = noop_waker();
                let mut cx = std::task::Context::from_waker(&waker);

                assert_eq!(
                    poll_once(&mut first, &mut cx),
                    std::task::Poll::Ready(()),
                    "Polled listen_1 should be notified"
                );
                assert_eq!(
                    poll_once(&mut second, &mut cx),
                    std::task::Poll::Pending,
                    "Polled listen_2 should not be notified"
                );
            });
        });
    }

    #[test]
    fn multiple_notifies() {
        loom::model(|| {
            loom::future::block_on(async {
                let notify = Notify::new();
                let mut first = notify.listen().await;
                let mut second = notify.listen().await;

                notify.notify_one();
                notify.notify_one();

                let waker = noop_waker();
                let mut cx = std::task::Context::from_waker(&waker);

                assert_eq!(
                    poll_once(&mut first, &mut cx),
                    std::task::Poll::Ready(()),
                    "Polled listen_1 should be notified"
                );
                assert_eq!(
                    poll_once(&mut second, &mut cx),
                    std::task::Poll::Ready(()),
                    "Polled listen_2 should be notified"
                );
                assert_eq!(
                    banked_tokens(&notify),
                    0,
                    "notifies should have consumed tokens"
                );
            });
        });
    }

    #[test]
    fn threaded_multiple_notifies() {
        loom::model(|| {
            let notify = loom::sync::Arc::new(Notify::new());

            let first = spawn_listener(&notify, true);
            let second = spawn_listener(&notify, true);

            notify.notify_one();
            notify.notify_one();

            first.join().unwrap();
            second.join().unwrap();

            assert_eq!(
                banked_tokens(&notify),
                0,
                "threaded notifies should have consumed tokens"
            );
        });
    }
}
|