181 lines
3.9 KiB
Rust
181 lines
3.9 KiB
Rust
use std::{
|
|
collections::BTreeMap,
|
|
future::Future,
|
|
pin::Pin,
|
|
sync::{Arc, Mutex},
|
|
task::{Context, Poll, Waker},
|
|
thread::Thread,
|
|
};
|
|
|
|
pub struct Notify {
|
|
state: Arc<Mutex<NotifyState>>,
|
|
}
|
|
|
|
struct NotifyState {
|
|
available: usize,
|
|
wakers: BTreeMap<usize, WakerOrThread>,
|
|
next_id: usize,
|
|
}
|
|
|
|
pub struct Listener {
|
|
id: Option<usize>,
|
|
state: Arc<Mutex<NotifyState>>,
|
|
}
|
|
|
|
struct Listen<'a> {
|
|
listener: &'a mut Listener,
|
|
}
|
|
|
|
/// How a registered waiter gets woken up.
enum WakerOrThread {
    /// Async waiter: woken through its task `Waker`.
    Waker(Waker),
    /// Blocking waiter: woken by unparking its thread.
    Thread(Thread),
}
|
|
|
|
impl Notify {
|
|
pub fn new() -> Self {
|
|
Self {
|
|
state: Arc::new(Mutex::new(NotifyState {
|
|
available: 0,
|
|
wakers: BTreeMap::new(),
|
|
next_id: 0,
|
|
})),
|
|
}
|
|
}
|
|
|
|
pub fn listener(&self) -> Listener {
|
|
Listener {
|
|
id: None,
|
|
state: Arc::clone(&self.state),
|
|
}
|
|
}
|
|
|
|
pub fn notify_one(&self) {
|
|
self.notify(1);
|
|
}
|
|
|
|
pub fn notify(&self, count: usize) {
|
|
let mut guard = self.state.lock().unwrap();
|
|
|
|
let wakers = if let Some(key) = guard.wakers.keys().skip(count).copied().next() {
|
|
let mut wakers = guard.wakers.split_off(&key);
|
|
std::mem::swap(&mut guard.wakers, &mut wakers);
|
|
wakers
|
|
} else {
|
|
std::mem::take(&mut guard.wakers)
|
|
};
|
|
|
|
guard.available = guard.available.saturating_add(count);
|
|
|
|
for (_, waker) in wakers {
|
|
match waker {
|
|
WakerOrThread::Waker(waker) => waker.wake(),
|
|
WakerOrThread::Thread(thread) => thread.unpark(),
|
|
}
|
|
}
|
|
}
|
|
}
|
|
|
|
impl Listener {
|
|
pub fn consume_notifications(&self, count: usize) {
|
|
let mut guard = self.state.lock().unwrap();
|
|
guard.available = guard.available.saturating_sub(count);
|
|
}
|
|
|
|
pub fn consume_all_notifications(&self) {
|
|
self.consume_notifications(usize::MAX);
|
|
}
|
|
|
|
pub async fn listen(&mut self) {
|
|
Listen { listener: self }.await
|
|
}
|
|
|
|
pub fn listen_blocking(&self) {
|
|
loop {
|
|
{
|
|
let mut guard = self.state.lock().unwrap();
|
|
|
|
if guard.available > 0 {
|
|
guard.available -= 1;
|
|
|
|
return;
|
|
}
|
|
|
|
let id = guard.next_id();
|
|
guard
|
|
.wakers
|
|
.insert(id, WakerOrThread::Thread(std::thread::current()));
|
|
}
|
|
|
|
std::thread::park();
|
|
}
|
|
}
|
|
}
|
|
|
|
impl NotifyState {
|
|
fn next_id(&mut self) -> usize {
|
|
let id = self.next_id;
|
|
|
|
self.next_id = self.next_id.wrapping_add(1);
|
|
|
|
while self.wakers.contains_key(&self.next_id) {
|
|
self.next_id = self.next_id.wrapping_add(1);
|
|
|
|
if self.next_id == id {
|
|
panic!("exausted usize space for ID generation");
|
|
}
|
|
}
|
|
|
|
id
|
|
}
|
|
}
|
|
|
|
impl Drop for Listener {
|
|
fn drop(&mut self) {
|
|
if let Some(id) = self.id.take() {
|
|
self.state.lock().unwrap().wakers.remove(&id);
|
|
}
|
|
}
|
|
}
|
|
|
|
impl Default for Notify {
|
|
fn default() -> Self {
|
|
Self::new()
|
|
}
|
|
}
|
|
|
|
impl<'a> Future for Listen<'a> {
|
|
type Output = ();
|
|
|
|
fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
|
|
let mut this = self.as_mut();
|
|
let listener = &mut this.listener;
|
|
let state = &listener.state;
|
|
let listener_id = &mut listener.id;
|
|
|
|
let mut guard = state.lock().unwrap();
|
|
|
|
if guard.available > 0 {
|
|
guard.available -= 1;
|
|
|
|
return Poll::Ready(());
|
|
}
|
|
|
|
let id = guard.next_id();
|
|
*listener_id = Some(id);
|
|
guard
|
|
.wakers
|
|
.insert(id, WakerOrThread::Waker(cx.waker().clone()));
|
|
|
|
Poll::Pending
|
|
}
|
|
}
|
|
|
|
impl Clone for Listener {
|
|
fn clone(&self) -> Self {
|
|
Self {
|
|
id: None,
|
|
state: Arc::clone(&self.state),
|
|
}
|
|
}
|
|
}
|