diff --git a/src/lib.rs b/src/lib.rs index 314a6d2..c1b63eb 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -1,11 +1,10 @@ +mod selector; + use std::{ num::{NonZeroU16, NonZeroUsize}, - sync::{ - atomic::{AtomicBool, AtomicU64}, - Arc, Mutex, - }, + sync::{atomic::AtomicU64, Arc, Mutex}, thread::JoinHandle, - time::{Duration, Instant}, + time::Instant, }; #[derive(Debug)] @@ -178,9 +177,7 @@ impl CpuPool { let mut threads = state.take_threads(); for thread in &mut threads { - thread - .signal - .store(true, std::sync::atomic::Ordering::Release); + thread.signal.take(); } for thread in &mut threads { @@ -373,9 +370,7 @@ impl ThreadVec { impl Drop for ThreadVec { fn drop(&mut self) { for thread in &mut self.threads { - thread - .signal - .store(true, std::sync::atomic::Ordering::Release); + thread.signal.take(); } for thread in &mut self.threads { @@ -388,14 +383,13 @@ impl Drop for ThreadVec { struct Thread { handle: Option>, - signal: Arc, + signal: Option, closed: flume::Receiver<()>, } impl Thread { async fn reap(mut self) { - self.signal - .store(true, std::sync::atomic::Ordering::Release); + self.signal.take(); let _ = self.closed.recv_async().await; @@ -417,19 +411,19 @@ impl Drop for SendOnDrop { fn spawn(name: &'static str, id: u64, receiver: flume::Receiver) -> Thread { let (closed_tx, closed) = flume::bounded(1); + let (signal, signal_rx) = flume::bounded(1); - let signal = Arc::new(AtomicBool::new(false)); + let signal = SendOnDrop { sender: signal }; let closed_tx = SendOnDrop { sender: closed_tx }; - let signal2 = signal.clone(); let handle = std::thread::Builder::new() .name(format!("{name}-{id}")) - .spawn(move || run(name, id, receiver, signal2, closed_tx)) + .spawn(move || run(name, id, receiver, signal_rx, closed_tx)) .expect("Failed to spawn new thread"); Thread { handle: Some(handle), - signal, + signal: Some(signal), closed, } } @@ -471,22 +465,15 @@ fn run( name: &'static str, id: u64, receiver: flume::Receiver, - signal: Arc, + signal: flume::Receiver<()>, closed_tx: SendOnDrop, ) { let guard = MetricsGuard::guard(name, id); loop { - match receiver.recv_timeout(Duration::from_millis(50)) { - Ok(send_fn) => { - invoke_send_fn(name, id, send_fn); - } - Err(flume::RecvTimeoutError::Disconnected) => break, - Err(flume::RecvTimeoutError::Timeout) => { - if signal.load(std::sync::atomic::Ordering::Acquire) { - break; - } - } + match selector::blocking_select(signal.recv_async(), receiver.recv_async()) { + selector::Either::Left(_) | selector::Either::Right(Err(_)) => break, + selector::Either::Right(Ok(send_fn)) => invoke_send_fn(name, id, send_fn), } } diff --git a/src/selector.rs b/src/selector.rs new file mode 100644 index 0000000..5599c9d --- /dev/null +++ b/src/selector.rs @@ -0,0 +1,146 @@ +use std::{ + future::Future, + pin::Pin, + sync::{ + atomic::{AtomicBool, Ordering}, + Arc, + }, + task::{Context, Poll, Wake, Waker}, +}; + +struct ThreadWaker { + thread: std::thread::Thread, +} + +impl Wake for ThreadWaker { + fn wake(self: Arc) { + self.thread.unpark(); + } + + fn wake_by_ref(self: &Arc) { + self.thread.unpark(); + } +} + +pub(super) enum Either { + Left(L), + Right(R), +} + +struct Select { + left: F1, + left_woken: Arc, + + right: F2, + right_woken: Arc, +} + +struct SelectWaker { + inner: Waker, + flag: Arc, +} + +impl Wake for SelectWaker { + fn wake_by_ref(self: &Arc) { + self.flag.store(true, Ordering::Release); + + self.inner.wake_by_ref(); + } + + fn wake(self: Arc) { + self.flag.store(true, Ordering::Release); + + match Arc::try_unwrap(self) { + Ok(this) => this.inner.wake(), + Err(this) => this.inner.wake_by_ref(), + } + } +} + +impl Future for Select +where + F1: Future + Unpin, + F2: Future + Unpin, +{ + type Output = Either; + + fn poll(mut self: std::pin::Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { + let left_waker = Arc::new(SelectWaker { + inner: cx.waker().clone(), + flag: self.left_woken.clone(), + }) + .into(); + + let mut left_ctx = Context::from_waker(&left_waker); + + if let Poll::Ready(left_out) = Pin::new(&mut self.left).poll(&mut left_ctx) { + return Poll::Ready(Either::Left(left_out)); + } + + let right_waker = Arc::new(SelectWaker { + inner: cx.waker().clone(), + flag: self.right_woken.clone(), + }) + .into(); + + let mut right_ctx = Context::from_waker(&right_waker); + + if let Poll::Ready(right_out) = Pin::new(&mut self.right).poll(&mut right_ctx) { + return Poll::Ready(Either::Right(right_out)); + } + + Poll::Pending + } +} + +pub(super) fn blocking_select( + left: Left, + right: Right, +) -> Either +where + Left: Future, + Right: Future, +{ + block_on(select(left, right)) +} + +fn block_on(fut: F) -> F::Output +where + F: Future, +{ + let thread_waker = Arc::new(ThreadWaker { + thread: std::thread::current(), + }) + .into(); + + let mut ctx = Context::from_waker(&thread_waker); + + let mut fut = std::pin::pin!(fut); + + loop { + if let Poll::Ready(out) = fut.as_mut().poll(&mut ctx) { + return out; + } + + // doesn't race - unpark followed by park will result in park returning immediately + std::thread::park(); + } +} + +async fn select(left: Left, right: Right) -> Either +where + Left: Future, + Right: Future, +{ + let left = std::pin::pin!(left); + let right = std::pin::pin!(right); + + Select { + left, + left_woken: Arc::new(AtomicBool::new(true)), + + right, + right_woken: Arc::new(AtomicBool::new(true)), + } + .await +}