Safe notifications
This commit is contained in:
commit
03e28bff03
2
.gitignore
vendored
Normal file
2
.gitignore
vendored
Normal file
|
@ -0,0 +1,2 @@
|
|||
/target
|
||||
Cargo.lock
|
11
Cargo.toml
Normal file
11
Cargo.toml
Normal file
|
@ -0,0 +1,11 @@
|
|||
[package]
|
||||
name = "notify"
|
||||
version = "0.1.0"
|
||||
edition = "2021"
|
||||
|
||||
# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html
|
||||
|
||||
[dependencies]
|
||||
|
||||
[dev-dependencies]
|
||||
jive = { git = "https://git.asonix.dog/safe-async/jive" }
|
40
examples/jive.rs
Normal file
40
examples/jive.rs
Normal file
|
@ -0,0 +1,40 @@
|
|||
use std::time::Duration;
|
||||
|
||||
fn main() -> Result<(), Box<dyn std::error::Error>> {
|
||||
jive::block_on(async move {
|
||||
let notify = notify::Notify::new();
|
||||
|
||||
println!("Queuing one notify");
|
||||
notify.notify_one();
|
||||
|
||||
let mut listener = notify.listener();
|
||||
|
||||
listener.listen().await;
|
||||
println!("immediate");
|
||||
|
||||
let mut handles = Vec::new();
|
||||
for i in 0..4 {
|
||||
let mut listener = notify.listener();
|
||||
|
||||
handles.push(jive::spawn(async move {
|
||||
for _ in 0..5 {
|
||||
listener.listen().await;
|
||||
println!("woken {}!", i);
|
||||
}
|
||||
}));
|
||||
}
|
||||
|
||||
jive::spawn(async move {
|
||||
loop {
|
||||
jive::time::sleep(Duration::from_millis(200)).await;
|
||||
notify.notify_one();
|
||||
}
|
||||
});
|
||||
|
||||
for handle in handles {
|
||||
let _ = handle.await;
|
||||
}
|
||||
})?;
|
||||
|
||||
Ok(())
|
||||
}
|
34
examples/thread.rs
Normal file
34
examples/thread.rs
Normal file
|
@ -0,0 +1,34 @@
|
|||
use std::time::Duration;
|
||||
|
||||
fn main() {
|
||||
let notify = notify::Notify::new();
|
||||
|
||||
println!("Queuing one notify");
|
||||
notify.notify_one();
|
||||
|
||||
let listener = notify.listener();
|
||||
|
||||
listener.listen_blocking();
|
||||
println!("immediate");
|
||||
|
||||
let mut handles = Vec::new();
|
||||
for i in 0..4 {
|
||||
let listener = notify.listener();
|
||||
|
||||
handles.push(std::thread::spawn(move || {
|
||||
for _ in 0..5 {
|
||||
listener.listen_blocking();
|
||||
println!("woken {}!", i);
|
||||
}
|
||||
}));
|
||||
}
|
||||
|
||||
std::thread::spawn(move || loop {
|
||||
std::thread::sleep(Duration::from_millis(200));
|
||||
notify.notify_one();
|
||||
});
|
||||
|
||||
for handle in handles {
|
||||
let _ = handle.join();
|
||||
}
|
||||
}
|
162
src/lib.rs
Normal file
162
src/lib.rs
Normal file
|
@ -0,0 +1,162 @@
|
|||
use std::{
|
||||
collections::BTreeMap,
|
||||
future::Future,
|
||||
pin::Pin,
|
||||
sync::{Arc, Mutex},
|
||||
task::{Context, Poll, Waker},
|
||||
thread::Thread,
|
||||
};
|
||||
|
||||
pub struct Notify {
|
||||
state: Arc<Mutex<NotifyState>>,
|
||||
}
|
||||
|
||||
struct NotifyState {
|
||||
available: usize,
|
||||
wakers: BTreeMap<usize, WakerOrThread>,
|
||||
next_id: usize,
|
||||
}
|
||||
|
||||
pub struct Listener {
|
||||
id: Option<usize>,
|
||||
state: Arc<Mutex<NotifyState>>,
|
||||
}
|
||||
|
||||
struct Listen<'a> {
|
||||
listener: &'a mut Listener,
|
||||
}
|
||||
|
||||
enum WakerOrThread {
|
||||
Waker(Waker),
|
||||
Thread(Thread),
|
||||
}
|
||||
|
||||
impl Notify {
|
||||
pub fn new() -> Self {
|
||||
Self {
|
||||
state: Arc::new(Mutex::new(NotifyState {
|
||||
available: 0,
|
||||
wakers: BTreeMap::new(),
|
||||
next_id: 0,
|
||||
})),
|
||||
}
|
||||
}
|
||||
|
||||
pub fn listener(&self) -> Listener {
|
||||
Listener {
|
||||
id: None,
|
||||
state: Arc::clone(&self.state),
|
||||
}
|
||||
}
|
||||
|
||||
pub fn notify_one(&self) {
|
||||
self.notify(1);
|
||||
}
|
||||
|
||||
pub fn notify(&self, count: usize) {
|
||||
let mut guard = self.state.lock().unwrap();
|
||||
|
||||
let wakers = if let Some(key) = guard.wakers.keys().skip(count).copied().next() {
|
||||
let mut wakers = guard.wakers.split_off(&key);
|
||||
std::mem::swap(&mut guard.wakers, &mut wakers);
|
||||
wakers
|
||||
} else {
|
||||
std::mem::take(&mut guard.wakers)
|
||||
};
|
||||
|
||||
guard.available = guard.available.saturating_add(count);
|
||||
|
||||
for (_, waker) in wakers {
|
||||
match waker {
|
||||
WakerOrThread::Waker(waker) => waker.wake(),
|
||||
WakerOrThread::Thread(thread) => thread.unpark(),
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl Listener {
|
||||
pub async fn listen(&mut self) {
|
||||
Listen { listener: self }.await
|
||||
}
|
||||
|
||||
pub fn listen_blocking(&self) {
|
||||
loop {
|
||||
{
|
||||
let mut guard = self.state.lock().unwrap();
|
||||
|
||||
if guard.available > 0 {
|
||||
guard.available -= 1;
|
||||
|
||||
return;
|
||||
}
|
||||
|
||||
let id = guard.next_id();
|
||||
guard
|
||||
.wakers
|
||||
.insert(id, WakerOrThread::Thread(std::thread::current()));
|
||||
}
|
||||
|
||||
std::thread::park();
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl NotifyState {
|
||||
fn next_id(&mut self) -> usize {
|
||||
let id = self.next_id;
|
||||
|
||||
self.next_id = self.next_id.wrapping_add(1);
|
||||
|
||||
while self.wakers.contains_key(&self.next_id) {
|
||||
self.next_id = self.next_id.wrapping_add(1);
|
||||
|
||||
if self.next_id == id {
|
||||
panic!("exausted usize space for ID generation");
|
||||
}
|
||||
}
|
||||
|
||||
id
|
||||
}
|
||||
}
|
||||
|
||||
impl Drop for Listener {
|
||||
fn drop(&mut self) {
|
||||
if let Some(id) = self.id.take() {
|
||||
self.state.lock().unwrap().wakers.remove(&id);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl Default for Notify {
|
||||
fn default() -> Self {
|
||||
Self::new()
|
||||
}
|
||||
}
|
||||
|
||||
impl<'a> Future for Listen<'a> {
|
||||
type Output = ();
|
||||
|
||||
fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
|
||||
let mut this = self.as_mut();
|
||||
let listener = &mut this.listener;
|
||||
let state = &listener.state;
|
||||
let listener_id = &mut listener.id;
|
||||
|
||||
let mut guard = state.lock().unwrap();
|
||||
|
||||
if guard.available > 0 {
|
||||
guard.available -= 1;
|
||||
|
||||
return Poll::Ready(());
|
||||
}
|
||||
|
||||
let id = guard.next_id();
|
||||
*listener_id = Some(id);
|
||||
guard
|
||||
.wakers
|
||||
.insert(id, WakerOrThread::Waker(cx.waker().clone()));
|
||||
|
||||
Poll::Pending
|
||||
}
|
||||
}
|
Loading…
Reference in a new issue