Safe notifications

This commit is contained in:
Aode (Lion) 2022-02-24 19:52:50 -06:00
commit 03e28bff03
5 changed files with 249 additions and 0 deletions

2
.gitignore vendored Normal file
View file

@ -0,0 +1,2 @@
/target
Cargo.lock

11
Cargo.toml Normal file
View file

@ -0,0 +1,11 @@
[package]
name = "notify"
version = "0.1.0"
edition = "2021"
# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html
[dependencies]
[dev-dependencies]
jive = { git = "https://git.asonix.dog/safe-async/jive" }

40
examples/jive.rs Normal file
View file

@ -0,0 +1,40 @@
use std::time::Duration;
fn main() -> Result<(), Box<dyn std::error::Error>> {
jive::block_on(async move {
let notify = notify::Notify::new();
println!("Queuing one notify");
notify.notify_one();
let mut listener = notify.listener();
listener.listen().await;
println!("immediate");
let mut handles = Vec::new();
for i in 0..4 {
let mut listener = notify.listener();
handles.push(jive::spawn(async move {
for _ in 0..5 {
listener.listen().await;
println!("woken {}!", i);
}
}));
}
jive::spawn(async move {
loop {
jive::time::sleep(Duration::from_millis(200)).await;
notify.notify_one();
}
});
for handle in handles {
let _ = handle.await;
}
})?;
Ok(())
}

34
examples/thread.rs Normal file
View file

@ -0,0 +1,34 @@
use std::time::Duration;
fn main() {
let notify = notify::Notify::new();
println!("Queuing one notify");
notify.notify_one();
let listener = notify.listener();
listener.listen_blocking();
println!("immediate");
let mut handles = Vec::new();
for i in 0..4 {
let listener = notify.listener();
handles.push(std::thread::spawn(move || {
for _ in 0..5 {
listener.listen_blocking();
println!("woken {}!", i);
}
}));
}
std::thread::spawn(move || loop {
std::thread::sleep(Duration::from_millis(200));
notify.notify_one();
});
for handle in handles {
let _ = handle.join();
}
}

162
src/lib.rs Normal file
View file

@ -0,0 +1,162 @@
use std::{
collections::BTreeMap,
future::Future,
pin::Pin,
sync::{Arc, Mutex},
task::{Context, Poll, Waker},
thread::Thread,
};
pub struct Notify {
state: Arc<Mutex<NotifyState>>,
}
struct NotifyState {
available: usize,
wakers: BTreeMap<usize, WakerOrThread>,
next_id: usize,
}
pub struct Listener {
id: Option<usize>,
state: Arc<Mutex<NotifyState>>,
}
struct Listen<'a> {
listener: &'a mut Listener,
}
enum WakerOrThread {
Waker(Waker),
Thread(Thread),
}
impl Notify {
pub fn new() -> Self {
Self {
state: Arc::new(Mutex::new(NotifyState {
available: 0,
wakers: BTreeMap::new(),
next_id: 0,
})),
}
}
pub fn listener(&self) -> Listener {
Listener {
id: None,
state: Arc::clone(&self.state),
}
}
pub fn notify_one(&self) {
self.notify(1);
}
pub fn notify(&self, count: usize) {
let mut guard = self.state.lock().unwrap();
let wakers = if let Some(key) = guard.wakers.keys().skip(count).copied().next() {
let mut wakers = guard.wakers.split_off(&key);
std::mem::swap(&mut guard.wakers, &mut wakers);
wakers
} else {
std::mem::take(&mut guard.wakers)
};
guard.available = guard.available.saturating_add(count);
for (_, waker) in wakers {
match waker {
WakerOrThread::Waker(waker) => waker.wake(),
WakerOrThread::Thread(thread) => thread.unpark(),
}
}
}
}
impl Listener {
pub async fn listen(&mut self) {
Listen { listener: self }.await
}
pub fn listen_blocking(&self) {
loop {
{
let mut guard = self.state.lock().unwrap();
if guard.available > 0 {
guard.available -= 1;
return;
}
let id = guard.next_id();
guard
.wakers
.insert(id, WakerOrThread::Thread(std::thread::current()));
}
std::thread::park();
}
}
}
impl NotifyState {
fn next_id(&mut self) -> usize {
let id = self.next_id;
self.next_id = self.next_id.wrapping_add(1);
while self.wakers.contains_key(&self.next_id) {
self.next_id = self.next_id.wrapping_add(1);
if self.next_id == id {
panic!("exausted usize space for ID generation");
}
}
id
}
}
impl Drop for Listener {
fn drop(&mut self) {
if let Some(id) = self.id.take() {
self.state.lock().unwrap().wakers.remove(&id);
}
}
}
impl Default for Notify {
fn default() -> Self {
Self::new()
}
}
impl<'a> Future for Listen<'a> {
type Output = ();
fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
let mut this = self.as_mut();
let listener = &mut this.listener;
let state = &listener.state;
let listener_id = &mut listener.id;
let mut guard = state.lock().unwrap();
if guard.available > 0 {
guard.available -= 1;
return Poll::Ready(());
}
let id = guard.next_id();
*listener_id = Some(id);
guard
.wakers
.insert(id, WakerOrThread::Waker(cx.waker().clone()));
Poll::Pending
}
}