From 9c29ab9fc2be1041cd614a9305a76f02aa790005 Mon Sep 17 00:00:00 2001 From: asonix Date: Sat, 13 Apr 2024 12:13:26 -0500 Subject: [PATCH] Build async queue, spsc to replace use of flume channels --- Cargo.toml | 1 - src/drop_notifier.rs | 34 +++++++ src/executor.rs | 42 +++++++++ src/lib.rs | 152 ++++++++++++++++--------------- src/notify.rs | 208 +++++++++++++++++++++++++++++++++++++++++++ src/queue.rs | 110 +++++++++++++++++++++++ src/selector.rs | 47 +--------- src/spsc.rs | 64 +++++++++++++ 8 files changed, 534 insertions(+), 124 deletions(-) create mode 100644 src/drop_notifier.rs create mode 100644 src/executor.rs create mode 100644 src/notify.rs create mode 100644 src/queue.rs create mode 100644 src/spsc.rs diff --git a/Cargo.toml b/Cargo.toml index 0e218ae..b32fe72 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -13,7 +13,6 @@ edition = "2021" # See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html [dependencies] -flume = "0.11.0" metrics = "0.22.0" tracing = "0.1.40" diff --git a/src/drop_notifier.rs b/src/drop_notifier.rs new file mode 100644 index 0000000..8326d13 --- /dev/null +++ b/src/drop_notifier.rs @@ -0,0 +1,34 @@ +use std::sync::Arc; + +use crate::notify::Notify; + +pub(super) fn notifier() -> (DropNotifier, DropListener) { + let notify = Arc::new(Notify::new()); + + ( + DropNotifier { + notify: Arc::clone(¬ify), + }, + DropListener { notify }, + ) +} + +pub(super) struct DropNotifier { + notify: Arc, +} + +pub(super) struct DropListener { + notify: Arc, +} + +impl DropListener { + pub(super) async fn listen(self) { + self.notify.listen().await.await + } +} + +impl Drop for DropNotifier { + fn drop(&mut self) { + self.notify.notify_one(); + } +} diff --git a/src/executor.rs b/src/executor.rs new file mode 100644 index 0000000..ba591a4 --- /dev/null +++ b/src/executor.rs @@ -0,0 +1,42 @@ +use std::{ + future::Future, + sync::Arc, + task::{Context, Poll, Wake}, +}; + +struct ThreadWaker { + thread: std::thread::Thread, +} + +impl Wake for ThreadWaker { + fn wake(self: Arc) { + self.thread.unpark(); + } + + fn wake_by_ref(self: &Arc) { + self.thread.unpark(); + } +} + +pub(super) fn block_on(fut: F) -> F::Output +where + F: Future, +{ + let thread_waker = Arc::new(ThreadWaker { + thread: std::thread::current(), + }) + .into(); + + let mut ctx = Context::from_waker(&thread_waker); + + let mut fut = std::pin::pin!(fut); + + loop { + if let Poll::Ready(out) = fut.as_mut().poll(&mut ctx) { + return out; + } + + // doesn't race - unpark followed by park will result in park returning immediately + std::thread::park(); + } +} diff --git a/src/lib.rs b/src/lib.rs index df7e594..96fac7e 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -1,15 +1,26 @@ #![doc = include_str!("../README.md")] #![deny(missing_docs)] +mod drop_notifier; +mod executor; +mod notify; +mod queue; mod selector; +mod spsc; use std::{ + future::Future, num::{NonZeroU16, NonZeroUsize}, sync::{atomic::AtomicU64, Arc, Mutex}, thread::JoinHandle, time::Instant, }; +use drop_notifier::{DropListener, DropNotifier}; +use executor::block_on; +use queue::Queue; +use selector::select; + /// Configuration builder for the CpuPool #[derive(Debug)] pub struct Config { @@ -240,62 +251,67 @@ impl CpuPool { /// # }) /// # } /// ``` - pub async fn spawn(&self, send_fn: F) -> Result + pub fn spawn(&self, send_fn: F) -> impl Future> + '_ where F: FnOnce() -> T + Send + 'static, T: Send + 'static, { - let (response_tx, response) = flume::bounded(1); + let (response_tx, response_rx) = spsc::channel(); let send_fn = Box::new(move || { let output = (send_fn)(); - if response_tx.send(output).is_err() { - tracing::trace!("Receiver hung up"); + match response_tx.blocking_send(output) { + Ok(()) => (), // sent + Err(Canceled) => tracing::warn!("receiver hung up"), } }); - let res = self.state.sender.try_send(send_fn); + let opt = self.state.queue.try_push(send_fn); let current_threads = self .state .current_threads .load(std::sync::atomic::Ordering::Acquire); - if (self.state.sender_len() > current_threads || self.state.sender.is_full()) - && self.push_thread() - { + let pushed = match self.state.queue.is_full_or() { + Ok(()) => self.push_thread(), + Err(len) if len > current_threads as usize => self.push_thread(), + Err(_) => false, + }; + + if pushed { tracing::trace!("Pushed thread"); } - match res { - Err(flume::TrySendError::Full(item)) => { - if self.state.sender.send_async(item).await.is_err() { - return Err(Canceled); + async { + match opt { + Some(item) => self.state.queue.push(item).await, + None => {} + } + + let current_threads = self + .state + .current_threads + .load(std::sync::atomic::Ordering::Acquire); + + match self.state.queue.is_full_or() { + Ok(()) => { + self.push_thread(); } + Err(len) if len > current_threads as usize => { + self.push_thread(); + } + Err(len) if len < current_threads.ilog2() as usize => { + if let Some(thread) = self.pop_thread() { + thread.reap().await; + } + } + Err(_) => {} } - Err(flume::TrySendError::Disconnected(_)) => { - return Err(Canceled); - } - Ok(()) => {} + + response_rx.recv().await } - - let current_threads = self - .state - .current_threads - .load(std::sync::atomic::Ordering::Acquire); - - let sender_len = self.state.sender_len(); - - if sender_len > current_threads || self.state.sender.is_full() { - self.push_thread(); - } else if sender_len < u64::from(current_threads.ilog2()) { - if let Some(thread) = self.pop_thread() { - thread.reap().await; - } - } - - response.into_recv_async().await.map_err(|_| Canceled) } /// Attempt to close the CpuPool @@ -330,8 +346,8 @@ impl CpuPool { thread.signal.take(); } - for thread in &mut threads { - let _ = thread.closed.recv_async().await; + for mut thread in threads { + thread.closed.listen().await; if let Some(handle) = thread.handle.take() { handle.join().expect("Thread panicked"); @@ -376,7 +392,7 @@ impl CpuPool { .thread_id .fetch_add(1, std::sync::atomic::Ordering::Relaxed); - let thread = spawn(self.state.name, thread_id, self.state.receiver.clone()); + let thread = spawn(self.state.name, thread_id, self.state.queue.clone()); self.state .threads @@ -439,8 +455,7 @@ struct CpuPoolState { max_threads: NonZeroU16, current_threads: AtomicU64, thread_id: AtomicU64, - sender: flume::Sender, - receiver: flume::Receiver, + queue: Queue, threads: Mutex, } @@ -453,13 +468,12 @@ impl CpuPoolState { ) -> Self { let thread_capacity = usize::from(u16::from(max_threads)); - let (sender, receiver) = - flume::bounded(usize::from(buffer_multiplier).saturating_mul(thread_capacity)); + let queue = queue::bounded(usize::from(buffer_multiplier).saturating_mul(thread_capacity)); let start_threads = u64::from(u16::from(min_threads)); let threads = ThreadVec::new(start_threads, thread_capacity, |i| { - spawn(name, i, receiver.clone()) + spawn(name, i, queue.clone()) }); let current_threads = AtomicU64::new(start_threads); @@ -471,8 +485,7 @@ impl CpuPoolState { max_threads, current_threads, thread_id, - sender, - receiver, + queue, threads: Mutex::new(threads), } } @@ -480,10 +493,6 @@ impl CpuPoolState { fn take_threads(&mut self) -> Vec { self.threads.lock().expect("threads lock poison").take() } - - fn sender_len(&self) -> u64 { - u64::try_from(self.sender.len()).unwrap_or(u64::MAX) - } } impl std::fmt::Debug for CpuPoolState { @@ -543,15 +552,15 @@ impl Drop for ThreadVec { struct Thread { handle: Option>, - signal: Option, - closed: flume::Receiver<()>, + signal: Option, + closed: DropListener, } impl Thread { async fn reap(mut self) { self.signal.take(); - let _ = self.closed.recv_async().await; + self.closed.listen().await; if let Some(handle) = self.handle.take() { handle.join().expect("Thread panicked"); @@ -559,32 +568,19 @@ impl Thread { } } -struct SendOnDrop { - sender: flume::Sender<()>, -} - -impl Drop for SendOnDrop { - fn drop(&mut self) { - let _ = self.sender.try_send(()); - } -} - -fn spawn(name: &'static str, id: u64, receiver: flume::Receiver) -> Thread { - let (closed_tx, closed) = flume::bounded(1); - let (signal, signal_rx) = flume::bounded(1); - - let signal = SendOnDrop { sender: signal }; - let closed_tx = SendOnDrop { sender: closed_tx }; +fn spawn(name: &'static str, id: u64, receiver: Queue) -> Thread { + let (closed_notifier, closed_listener) = drop_notifier::notifier(); + let (signal_notifier, signal_listener) = drop_notifier::notifier(); let handle = std::thread::Builder::new() .name(format!("{name}-{id}")) - .spawn(move || run(name, id, receiver, signal_rx, closed_tx)) + .spawn(move || run(name, id, receiver, signal_listener, closed_notifier)) .expect("Failed to spawn new thread"); Thread { handle: Some(handle), - signal: Some(signal), - closed, + signal: Some(signal_notifier), + closed: closed_listener, } } @@ -616,7 +612,7 @@ impl MetricsGuard { impl Drop for MetricsGuard { fn drop(&mut self) { metrics::counter!(format!("async-cpupool.{}.closed", self.name), "clean" => (!self.armed).to_string()).increment(1); - metrics::histogram!(format!("async-cpupool.{}.duration", self.name), "clean" => (!self.armed).to_string()).record(self.start.elapsed().as_secs_f64()); + metrics::histogram!(format!("async-cpupool.{}.seconds", self.name), "clean" => (!self.armed).to_string()).record(self.start.elapsed().as_secs_f64()); tracing::trace!("Stopping {}-{}", self.name, self.id); } } @@ -624,16 +620,18 @@ impl Drop for MetricsGuard { fn run( name: &'static str, id: u64, - receiver: flume::Receiver, - signal: flume::Receiver<()>, - closed_tx: SendOnDrop, + receiver: Queue, + signal: DropListener, + closed_tx: DropNotifier, ) { let guard = MetricsGuard::guard(name, id); + let mut signal = std::pin::pin!(signal.listen()); + loop { - match selector::blocking_select(signal.recv_async(), receiver.recv_async()) { - selector::Either::Left(_) | selector::Either::Right(Err(_)) => break, - selector::Either::Right(Ok(send_fn)) => invoke_send_fn(name, send_fn), + match block_on(select(&mut signal, receiver.pop())) { + selector::Either::Left(_) => break, + selector::Either::Right(send_fn) => invoke_send_fn(name, send_fn), } } @@ -651,7 +649,7 @@ fn invoke_send_fn(name: &'static str, send_fn: SendFn) { })); metrics::counter!(format!("async-cpupool.{name}.operation.end"), "complete" => res.is_ok().to_string()).increment(1); - metrics::histogram!(format!("async-cpupool.{name}.operation.duration"), "complete" => res.is_ok().to_string()).record(start.elapsed().as_secs_f64()); + metrics::histogram!(format!("async-cpupool.{name}.operation.seconds"), "complete" => res.is_ok().to_string()).record(start.elapsed().as_secs_f64()); if let Err(e) = res { tracing::trace!("panic in spawned task: {e:?}"); diff --git a/src/notify.rs b/src/notify.rs new file mode 100644 index 0000000..36ac45a --- /dev/null +++ b/src/notify.rs @@ -0,0 +1,208 @@ +use std::{ + collections::VecDeque, + future::{poll_fn, Future}, + sync::{ + atomic::{AtomicU8, Ordering}, + Arc, Mutex, + }, + task::{Poll, Waker}, +}; + +const UNNOTIFIED: u8 = 0b0000; +const NOTIFIED_ONE: u8 = 0b0001; +const RESOLVED: u8 = 0b0010; + +#[derive(Clone, Copy, Debug, PartialEq, Eq, PartialOrd, Ord, Hash)] +struct NotifyId(u64); + +pub(super) struct Notify { + state: Mutex, +} + +struct NotifyState { + listeners: VecDeque<(NotifyId, Arc, Waker)>, + token: u64, + next_id: u64, +} + +pub(super) struct Listener<'a> { + state: &'a Mutex, + waker: Waker, + woken: Arc, + id: NotifyId, +} + +impl Notify { + pub(super) fn new() -> Self { + metrics::counter!("async-cpupool.notify.created").increment(1); + + Notify { + state: Mutex::new(NotifyState { + listeners: VecDeque::new(), + token: 0, + next_id: 0, + }), + } + } + + // although this is an async fn, it is not capable of yielding to the executor + pub(super) async fn listen(&self) -> Listener<'_> { + poll_fn(|cx| Poll::Ready(self.make_listener(cx.waker().clone()))).await + } + + pub(super) fn make_listener(&self, waker: Waker) -> Listener<'_> { + let (id, woken) = self + .state + .lock() + .expect("not poisoned") + .insert(waker.clone()); + + Listener { + state: &self.state, + waker, + woken, + id, + } + } + + pub(super) fn notify_one(&self) { + self.state.lock().expect("not poisoned").notify_one(); + } +} + +impl NotifyState { + fn notify_one(&mut self) { + loop { + if let Some((_, woken, waker)) = self.listeners.pop_front() { + // can't use "weak" because we need to know failure means we failed fr fr to avoid + // popping unwoken listeners + match woken.compare_exchange( + UNNOTIFIED, + NOTIFIED_ONE, + Ordering::Release, + Ordering::Relaxed, + ) { + Ok(_) => waker.wake(), + + // if this listener isn't unnotified (races Listener Drop) we should wake the + // next one + Err(_) => continue, + } + } else { + self.token += 1; + } + + break; + } + } + + fn insert(&mut self, waker: Waker) -> (NotifyId, Arc) { + let id = NotifyId(self.next_id); + self.next_id += 1; + + let token = if self.token > 0 { + self.token -= 1; + true + } else { + false + }; + + // don't insert waker if token is true - next poll will be ready + let woken = if token { + Arc::new(AtomicU8::new(NOTIFIED_ONE)) + } else { + let woken = Arc::new(AtomicU8::new(UNNOTIFIED)); + self.listeners.push_back((id, Arc::clone(&woken), waker)); + woken + }; + + (id, woken) + } + + fn remove(&mut self, id: NotifyId) { + if let Some(index) = self.find(&id) { + self.listeners.remove(index); + } + } + + fn find(&self, needle_id: &NotifyId) -> Option { + self.listeners + .binary_search_by_key(needle_id, |(haystack_id, _, _)| *haystack_id) + .ok() + } + + fn update(&mut self, id: NotifyId, needle_waker: &Waker) { + if let Some(index) = self.find(&id) { + if let Some((_, _, haystack_waker)) = self.listeners.get_mut(index) { + if !needle_waker.will_wake(haystack_waker) { + *haystack_waker = needle_waker.clone(); + } + } + } + } +} + +impl Drop for Notify { + fn drop(&mut self) { + metrics::counter!("async-cpupool.notify.dropped").increment(1); + } +} + +impl Drop for Listener<'_> { + fn drop(&mut self) { + // races compare_exchange in notify_one + let flags = self.woken.swap(RESOLVED, Ordering::AcqRel); + + if flags == RESOLVED { + return; + } else if flags == NOTIFIED_ONE { + let mut guard = self.state.lock().expect("not poisoned"); + guard.notify_one(); + } else if flags == UNNOTIFIED { + let mut guard = self.state.lock().expect("not poisoned"); + guard.remove(self.id); + } else { + unreachable!("No other states exist") + } + } +} + +impl Future for Listener<'_> { + type Output = (); + + fn poll( + mut self: std::pin::Pin<&mut Self>, + cx: &mut std::task::Context<'_>, + ) -> Poll { + let mut flags = self.woken.load(Ordering::Acquire); + + loop { + if flags == UNNOTIFIED { + break; + } else if flags == RESOLVED { + return Poll::Ready(()); + } else { + match self.woken.compare_exchange_weak( + flags, + RESOLVED, + Ordering::Release, + Ordering::Acquire, + ) { + Ok(_) => return Poll::Ready(()), + Err(updated) => flags = updated, + }; + } + } + + if !self.waker.will_wake(cx.waker()) { + self.waker = cx.waker().clone(); + + self.state + .lock() + .expect("not poisoned") + .update(self.id, cx.waker()); + } + + Poll::Pending + } +} diff --git a/src/queue.rs b/src/queue.rs new file mode 100644 index 0000000..e1363d4 --- /dev/null +++ b/src/queue.rs @@ -0,0 +1,110 @@ +use std::{ + collections::VecDeque, + sync::{Arc, Mutex}, +}; + +use crate::notify::Notify; + +pub(super) fn bounded(capacity: usize) -> Queue { + Queue::bounded(capacity) +} + +pub(crate) struct Queue { + inner: Arc>, + capacity: usize, +} + +struct QueueState { + queue: Mutex>, + push_notify: Notify, + pop_notify: Notify, +} + +impl Queue { + pub(super) fn bounded(capacity: usize) -> Self { + Self { + inner: Arc::new(QueueState { + queue: Mutex::new(VecDeque::new()), + push_notify: Notify::new(), + pop_notify: Notify::new(), + }), + capacity, + } + } + + pub(super) fn len(&self) -> usize { + self.inner.queue.lock().expect("not poisoned").len() + } + + pub(super) fn is_full_or(&self) -> Result<(), usize> { + let len = self.len(); + + if len >= self.capacity { + Ok(()) + } else { + Err(len) + } + } + + pub(super) async fn push(&self, mut item: T) { + loop { + let listener = self.inner.push_notify.listen().await; + + if let Some(returned_item) = self.try_push_impl(item) { + item = returned_item; + + listener.await; + } else { + self.inner.pop_notify.notify_one(); + return; + } + } + } + + pub(super) fn try_push(&self, item: T) -> Option { + match self.try_push_impl(item) { + Some(item) => Some(item), + None => { + self.inner.pop_notify.notify_one(); + None + } + } + } + + fn try_push_impl(&self, item: T) -> Option { + let mut guard = self.inner.queue.lock().expect("not poisoned"); + + if self.capacity <= guard.len() { + Some(item) + } else { + guard.push_back(item); + None + } + } + + pub(super) async fn pop(&self) -> T { + loop { + let listener = self.inner.pop_notify.listen().await; + + if let Some(item) = self.try_pop() { + self.inner.push_notify.notify_one(); + return item; + } + + listener.await; + } + } + + fn try_pop(&self) -> Option { + self.inner.queue.lock().expect("not poisoned").pop_front() + } +} + +impl Clone for Queue { + fn clone(&self) -> Self { + Self { + inner: Arc::clone(&self.inner), + capacity: self.capacity, + } + } +} diff --git a/src/selector.rs b/src/selector.rs index 5599c9d..f181877 100644 --- a/src/selector.rs +++ b/src/selector.rs @@ -8,20 +8,6 @@ use std::{ task::{Context, Poll, Wake, Waker}, }; -struct ThreadWaker { - thread: std::thread::Thread, -} - -impl Wake for ThreadWaker { - fn wake(self: Arc) { - self.thread.unpark(); - } - - fn wake_by_ref(self: &Arc) { - self.thread.unpark(); - } -} - pub(super) enum Either { Left(L), Right(R), @@ -93,41 +79,10 @@ where } } -pub(super) fn blocking_select( +pub(super) async fn select( left: Left, right: Right, ) -> Either -where - Left: Future, - Right: Future, -{ - block_on(select(left, right)) -} - -fn block_on(fut: F) -> F::Output -where - F: Future, -{ - let thread_waker = Arc::new(ThreadWaker { - thread: std::thread::current(), - }) - .into(); - - let mut ctx = Context::from_waker(&thread_waker); - - let mut fut = std::pin::pin!(fut); - - loop { - if let Poll::Ready(out) = fut.as_mut().poll(&mut ctx) { - return out; - } - - // doesn't race - unpark followed by park will result in park returning immediately - std::thread::park(); - } -} - -async fn select(left: Left, right: Right) -> Either where Left: Future, Right: Future, diff --git a/src/spsc.rs b/src/spsc.rs new file mode 100644 index 0000000..40b81c4 --- /dev/null +++ b/src/spsc.rs @@ -0,0 +1,64 @@ +use crate::{ + drop_notifier::{DropListener, DropNotifier}, + executor::block_on, + queue::Queue, + selector::{select, Either}, + Canceled, +}; + +pub(super) fn channel() -> (Sender, Receiver) { + let queue = crate::queue::bounded(1); + + let (send_notifier, send_listener) = crate::drop_notifier::notifier(); + + let (recv_notifier, recv_listener) = crate::drop_notifier::notifier(); + + ( + Sender { + queue: queue.clone(), + send_notifier, + recv_listener, + }, + Receiver { + queue, + recv_notifier, + send_listener, + }, + ) +} + +pub(super) struct Sender { + queue: Queue, + #[allow(unused)] + send_notifier: DropNotifier, + recv_listener: DropListener, +} + +pub(super) struct Receiver { + queue: Queue, + #[allow(unused)] + recv_notifier: DropNotifier, + send_listener: DropListener, +} + +impl Sender { + pub(super) async fn send(self, item: T) -> Result<(), Canceled> { + match select(self.queue.push(item), self.recv_listener.listen()).await { + Either::Left(()) => Ok(()), + Either::Right(()) => Err(Canceled), + } + } + + pub(super) fn blocking_send(self, item: T) -> Result<(), Canceled> { + block_on(self.send(item)) + } +} + +impl Receiver { + pub(super) async fn recv(self) -> Result { + match select(self.queue.pop(), self.send_listener.listen()).await { + Either::Left(item) => Ok(item), + Either::Right(()) => Err(Canceled), + } + } +}