Build async queue, spsc to replace use of flume channels

This commit is contained in:
asonix 2024-04-13 12:13:26 -05:00
parent 6d5e171e25
commit 9c29ab9fc2
8 changed files with 534 additions and 124 deletions

View file

@ -13,7 +13,6 @@ edition = "2021"
# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html
[dependencies]
flume = "0.11.0"
metrics = "0.22.0"
tracing = "0.1.40"

34
src/drop_notifier.rs Normal file
View file

@ -0,0 +1,34 @@
use std::sync::Arc;
use crate::notify::Notify;
pub(super) fn notifier() -> (DropNotifier, DropListener) {
let notify = Arc::new(Notify::new());
(
DropNotifier {
notify: Arc::clone(&notify),
},
DropListener { notify },
)
}
pub(super) struct DropNotifier {
notify: Arc<Notify>,
}
pub(super) struct DropListener {
notify: Arc<Notify>,
}
impl DropListener {
pub(super) async fn listen(self) {
self.notify.listen().await.await
}
}
impl Drop for DropNotifier {
fn drop(&mut self) {
self.notify.notify_one();
}
}

42
src/executor.rs Normal file
View file

@ -0,0 +1,42 @@
use std::{
future::Future,
sync::Arc,
task::{Context, Poll, Wake},
};
struct ThreadWaker {
thread: std::thread::Thread,
}
impl Wake for ThreadWaker {
fn wake(self: Arc<Self>) {
self.thread.unpark();
}
fn wake_by_ref(self: &Arc<Self>) {
self.thread.unpark();
}
}
pub(super) fn block_on<F>(fut: F) -> F::Output
where
F: Future,
{
let thread_waker = Arc::new(ThreadWaker {
thread: std::thread::current(),
})
.into();
let mut ctx = Context::from_waker(&thread_waker);
let mut fut = std::pin::pin!(fut);
loop {
if let Poll::Ready(out) = fut.as_mut().poll(&mut ctx) {
return out;
}
// doesn't race - unpark followed by park will result in park returning immediately
std::thread::park();
}
}

View file

@ -1,15 +1,26 @@
#![doc = include_str!("../README.md")]
#![deny(missing_docs)]
mod drop_notifier;
mod executor;
mod notify;
mod queue;
mod selector;
mod spsc;
use std::{
future::Future,
num::{NonZeroU16, NonZeroUsize},
sync::{atomic::AtomicU64, Arc, Mutex},
thread::JoinHandle,
time::Instant,
};
use drop_notifier::{DropListener, DropNotifier};
use executor::block_on;
use queue::Queue;
use selector::select;
/// Configuration builder for the CpuPool
#[derive(Debug)]
pub struct Config {
@ -240,62 +251,67 @@ impl CpuPool {
/// # })
/// # }
/// ```
pub async fn spawn<F, T>(&self, send_fn: F) -> Result<T, Canceled>
pub fn spawn<F, T>(&self, send_fn: F) -> impl Future<Output = Result<T, Canceled>> + '_
where
F: FnOnce() -> T + Send + 'static,
T: Send + 'static,
{
let (response_tx, response) = flume::bounded(1);
let (response_tx, response_rx) = spsc::channel();
let send_fn = Box::new(move || {
let output = (send_fn)();
if response_tx.send(output).is_err() {
tracing::trace!("Receiver hung up");
match response_tx.blocking_send(output) {
Ok(()) => (), // sent
Err(Canceled) => tracing::warn!("receiver hung up"),
}
});
let res = self.state.sender.try_send(send_fn);
let opt = self.state.queue.try_push(send_fn);
let current_threads = self
.state
.current_threads
.load(std::sync::atomic::Ordering::Acquire);
if (self.state.sender_len() > current_threads || self.state.sender.is_full())
&& self.push_thread()
{
let pushed = match self.state.queue.is_full_or() {
Ok(()) => self.push_thread(),
Err(len) if len > current_threads as usize => self.push_thread(),
Err(_) => false,
};
if pushed {
tracing::trace!("Pushed thread");
}
match res {
Err(flume::TrySendError::Full(item)) => {
if self.state.sender.send_async(item).await.is_err() {
return Err(Canceled);
async {
match opt {
Some(item) => self.state.queue.push(item).await,
None => {}
}
let current_threads = self
.state
.current_threads
.load(std::sync::atomic::Ordering::Acquire);
match self.state.queue.is_full_or() {
Ok(()) => {
self.push_thread();
}
Err(len) if len > current_threads as usize => {
self.push_thread();
}
Err(len) if len < current_threads.ilog2() as usize => {
if let Some(thread) = self.pop_thread() {
thread.reap().await;
}
}
Err(_) => {}
}
Err(flume::TrySendError::Disconnected(_)) => {
return Err(Canceled);
}
Ok(()) => {}
response_rx.recv().await
}
let current_threads = self
.state
.current_threads
.load(std::sync::atomic::Ordering::Acquire);
let sender_len = self.state.sender_len();
if sender_len > current_threads || self.state.sender.is_full() {
self.push_thread();
} else if sender_len < u64::from(current_threads.ilog2()) {
if let Some(thread) = self.pop_thread() {
thread.reap().await;
}
}
response.into_recv_async().await.map_err(|_| Canceled)
}
/// Attempt to close the CpuPool
@ -330,8 +346,8 @@ impl CpuPool {
thread.signal.take();
}
for thread in &mut threads {
let _ = thread.closed.recv_async().await;
for mut thread in threads {
thread.closed.listen().await;
if let Some(handle) = thread.handle.take() {
handle.join().expect("Thread panicked");
@ -376,7 +392,7 @@ impl CpuPool {
.thread_id
.fetch_add(1, std::sync::atomic::Ordering::Relaxed);
let thread = spawn(self.state.name, thread_id, self.state.receiver.clone());
let thread = spawn(self.state.name, thread_id, self.state.queue.clone());
self.state
.threads
@ -439,8 +455,7 @@ struct CpuPoolState {
max_threads: NonZeroU16,
current_threads: AtomicU64,
thread_id: AtomicU64,
sender: flume::Sender<SendFn>,
receiver: flume::Receiver<SendFn>,
queue: Queue<SendFn>,
threads: Mutex<ThreadVec>,
}
@ -453,13 +468,12 @@ impl CpuPoolState {
) -> Self {
let thread_capacity = usize::from(u16::from(max_threads));
let (sender, receiver) =
flume::bounded(usize::from(buffer_multiplier).saturating_mul(thread_capacity));
let queue = queue::bounded(usize::from(buffer_multiplier).saturating_mul(thread_capacity));
let start_threads = u64::from(u16::from(min_threads));
let threads = ThreadVec::new(start_threads, thread_capacity, |i| {
spawn(name, i, receiver.clone())
spawn(name, i, queue.clone())
});
let current_threads = AtomicU64::new(start_threads);
@ -471,8 +485,7 @@ impl CpuPoolState {
max_threads,
current_threads,
thread_id,
sender,
receiver,
queue,
threads: Mutex::new(threads),
}
}
@ -480,10 +493,6 @@ impl CpuPoolState {
fn take_threads(&mut self) -> Vec<Thread> {
self.threads.lock().expect("threads lock poison").take()
}
fn sender_len(&self) -> u64 {
u64::try_from(self.sender.len()).unwrap_or(u64::MAX)
}
}
impl std::fmt::Debug for CpuPoolState {
@ -543,15 +552,15 @@ impl Drop for ThreadVec {
struct Thread {
handle: Option<JoinHandle<()>>,
signal: Option<SendOnDrop>,
closed: flume::Receiver<()>,
signal: Option<DropNotifier>,
closed: DropListener,
}
impl Thread {
async fn reap(mut self) {
self.signal.take();
let _ = self.closed.recv_async().await;
self.closed.listen().await;
if let Some(handle) = self.handle.take() {
handle.join().expect("Thread panicked");
@ -559,32 +568,19 @@ impl Thread {
}
}
struct SendOnDrop {
sender: flume::Sender<()>,
}
impl Drop for SendOnDrop {
fn drop(&mut self) {
let _ = self.sender.try_send(());
}
}
fn spawn(name: &'static str, id: u64, receiver: flume::Receiver<SendFn>) -> Thread {
let (closed_tx, closed) = flume::bounded(1);
let (signal, signal_rx) = flume::bounded(1);
let signal = SendOnDrop { sender: signal };
let closed_tx = SendOnDrop { sender: closed_tx };
fn spawn(name: &'static str, id: u64, receiver: Queue<SendFn>) -> Thread {
let (closed_notifier, closed_listener) = drop_notifier::notifier();
let (signal_notifier, signal_listener) = drop_notifier::notifier();
let handle = std::thread::Builder::new()
.name(format!("{name}-{id}"))
.spawn(move || run(name, id, receiver, signal_rx, closed_tx))
.spawn(move || run(name, id, receiver, signal_listener, closed_notifier))
.expect("Failed to spawn new thread");
Thread {
handle: Some(handle),
signal: Some(signal),
closed,
signal: Some(signal_notifier),
closed: closed_listener,
}
}
@ -616,7 +612,7 @@ impl MetricsGuard {
impl Drop for MetricsGuard {
fn drop(&mut self) {
metrics::counter!(format!("async-cpupool.{}.closed", self.name), "clean" => (!self.armed).to_string()).increment(1);
metrics::histogram!(format!("async-cpupool.{}.duration", self.name), "clean" => (!self.armed).to_string()).record(self.start.elapsed().as_secs_f64());
metrics::histogram!(format!("async-cpupool.{}.seconds", self.name), "clean" => (!self.armed).to_string()).record(self.start.elapsed().as_secs_f64());
tracing::trace!("Stopping {}-{}", self.name, self.id);
}
}
@ -624,16 +620,18 @@ impl Drop for MetricsGuard {
fn run(
name: &'static str,
id: u64,
receiver: flume::Receiver<SendFn>,
signal: flume::Receiver<()>,
closed_tx: SendOnDrop,
receiver: Queue<SendFn>,
signal: DropListener,
closed_tx: DropNotifier,
) {
let guard = MetricsGuard::guard(name, id);
let mut signal = std::pin::pin!(signal.listen());
loop {
match selector::blocking_select(signal.recv_async(), receiver.recv_async()) {
selector::Either::Left(_) | selector::Either::Right(Err(_)) => break,
selector::Either::Right(Ok(send_fn)) => invoke_send_fn(name, send_fn),
match block_on(select(&mut signal, receiver.pop())) {
selector::Either::Left(_) => break,
selector::Either::Right(send_fn) => invoke_send_fn(name, send_fn),
}
}
@ -651,7 +649,7 @@ fn invoke_send_fn(name: &'static str, send_fn: SendFn) {
}));
metrics::counter!(format!("async-cpupool.{name}.operation.end"), "complete" => res.is_ok().to_string()).increment(1);
metrics::histogram!(format!("async-cpupool.{name}.operation.duration"), "complete" => res.is_ok().to_string()).record(start.elapsed().as_secs_f64());
metrics::histogram!(format!("async-cpupool.{name}.operation.seconds"), "complete" => res.is_ok().to_string()).record(start.elapsed().as_secs_f64());
if let Err(e) = res {
tracing::trace!("panic in spawned task: {e:?}");

208
src/notify.rs Normal file
View file

@ -0,0 +1,208 @@
use std::{
collections::VecDeque,
future::{poll_fn, Future},
sync::{
atomic::{AtomicU8, Ordering},
Arc, Mutex,
},
task::{Poll, Waker},
};
const UNNOTIFIED: u8 = 0b0000;
const NOTIFIED_ONE: u8 = 0b0001;
const RESOLVED: u8 = 0b0010;
#[derive(Clone, Copy, Debug, PartialEq, Eq, PartialOrd, Ord, Hash)]
struct NotifyId(u64);
pub(super) struct Notify {
state: Mutex<NotifyState>,
}
struct NotifyState {
listeners: VecDeque<(NotifyId, Arc<AtomicU8>, Waker)>,
token: u64,
next_id: u64,
}
pub(super) struct Listener<'a> {
state: &'a Mutex<NotifyState>,
waker: Waker,
woken: Arc<AtomicU8>,
id: NotifyId,
}
impl Notify {
pub(super) fn new() -> Self {
metrics::counter!("async-cpupool.notify.created").increment(1);
Notify {
state: Mutex::new(NotifyState {
listeners: VecDeque::new(),
token: 0,
next_id: 0,
}),
}
}
// although this is an async fn, it is not capable of yielding to the executor
pub(super) async fn listen(&self) -> Listener<'_> {
poll_fn(|cx| Poll::Ready(self.make_listener(cx.waker().clone()))).await
}
pub(super) fn make_listener(&self, waker: Waker) -> Listener<'_> {
let (id, woken) = self
.state
.lock()
.expect("not poisoned")
.insert(waker.clone());
Listener {
state: &self.state,
waker,
woken,
id,
}
}
pub(super) fn notify_one(&self) {
self.state.lock().expect("not poisoned").notify_one();
}
}
impl NotifyState {
fn notify_one(&mut self) {
loop {
if let Some((_, woken, waker)) = self.listeners.pop_front() {
// can't use "weak" because we need to know failure means we failed fr fr to avoid
// popping unwoken listeners
match woken.compare_exchange(
UNNOTIFIED,
NOTIFIED_ONE,
Ordering::Release,
Ordering::Relaxed,
) {
Ok(_) => waker.wake(),
// if this listener isn't unnotified (races Listener Drop) we should wake the
// next one
Err(_) => continue,
}
} else {
self.token += 1;
}
break;
}
}
fn insert(&mut self, waker: Waker) -> (NotifyId, Arc<AtomicU8>) {
let id = NotifyId(self.next_id);
self.next_id += 1;
let token = if self.token > 0 {
self.token -= 1;
true
} else {
false
};
// don't insert waker if token is true - next poll will be ready
let woken = if token {
Arc::new(AtomicU8::new(NOTIFIED_ONE))
} else {
let woken = Arc::new(AtomicU8::new(UNNOTIFIED));
self.listeners.push_back((id, Arc::clone(&woken), waker));
woken
};
(id, woken)
}
fn remove(&mut self, id: NotifyId) {
if let Some(index) = self.find(&id) {
self.listeners.remove(index);
}
}
fn find(&self, needle_id: &NotifyId) -> Option<usize> {
self.listeners
.binary_search_by_key(needle_id, |(haystack_id, _, _)| *haystack_id)
.ok()
}
fn update(&mut self, id: NotifyId, needle_waker: &Waker) {
if let Some(index) = self.find(&id) {
if let Some((_, _, haystack_waker)) = self.listeners.get_mut(index) {
if !needle_waker.will_wake(haystack_waker) {
*haystack_waker = needle_waker.clone();
}
}
}
}
}
impl Drop for Notify {
fn drop(&mut self) {
metrics::counter!("async-cpupool.notify.dropped").increment(1);
}
}
impl Drop for Listener<'_> {
fn drop(&mut self) {
// races compare_exchange in notify_one
let flags = self.woken.swap(RESOLVED, Ordering::AcqRel);
if flags == RESOLVED {
return;
} else if flags == NOTIFIED_ONE {
let mut guard = self.state.lock().expect("not poisoned");
guard.notify_one();
} else if flags == UNNOTIFIED {
let mut guard = self.state.lock().expect("not poisoned");
guard.remove(self.id);
} else {
unreachable!("No other states exist")
}
}
}
impl Future for Listener<'_> {
type Output = ();
fn poll(
mut self: std::pin::Pin<&mut Self>,
cx: &mut std::task::Context<'_>,
) -> Poll<Self::Output> {
let mut flags = self.woken.load(Ordering::Acquire);
loop {
if flags == UNNOTIFIED {
break;
} else if flags == RESOLVED {
return Poll::Ready(());
} else {
match self.woken.compare_exchange_weak(
flags,
RESOLVED,
Ordering::Release,
Ordering::Acquire,
) {
Ok(_) => return Poll::Ready(()),
Err(updated) => flags = updated,
};
}
}
if !self.waker.will_wake(cx.waker()) {
self.waker = cx.waker().clone();
self.state
.lock()
.expect("not poisoned")
.update(self.id, cx.waker());
}
Poll::Pending
}
}

110
src/queue.rs Normal file
View file

@ -0,0 +1,110 @@
use std::{
collections::VecDeque,
sync::{Arc, Mutex},
};
use crate::notify::Notify;
pub(super) fn bounded<T>(capacity: usize) -> Queue<T> {
Queue::bounded(capacity)
}
pub(crate) struct Queue<T> {
inner: Arc<QueueState<T>>,
capacity: usize,
}
struct QueueState<T> {
queue: Mutex<VecDeque<T>>,
push_notify: Notify,
pop_notify: Notify,
}
impl<T> Queue<T> {
pub(super) fn bounded(capacity: usize) -> Self {
Self {
inner: Arc::new(QueueState {
queue: Mutex::new(VecDeque::new()),
push_notify: Notify::new(),
pop_notify: Notify::new(),
}),
capacity,
}
}
pub(super) fn len(&self) -> usize {
self.inner.queue.lock().expect("not poisoned").len()
}
pub(super) fn is_full_or(&self) -> Result<(), usize> {
let len = self.len();
if len >= self.capacity {
Ok(())
} else {
Err(len)
}
}
pub(super) async fn push(&self, mut item: T) {
loop {
let listener = self.inner.push_notify.listen().await;
if let Some(returned_item) = self.try_push_impl(item) {
item = returned_item;
listener.await;
} else {
self.inner.pop_notify.notify_one();
return;
}
}
}
pub(super) fn try_push(&self, item: T) -> Option<T> {
match self.try_push_impl(item) {
Some(item) => Some(item),
None => {
self.inner.pop_notify.notify_one();
None
}
}
}
fn try_push_impl(&self, item: T) -> Option<T> {
let mut guard = self.inner.queue.lock().expect("not poisoned");
if self.capacity <= guard.len() {
Some(item)
} else {
guard.push_back(item);
None
}
}
pub(super) async fn pop(&self) -> T {
loop {
let listener = self.inner.pop_notify.listen().await;
if let Some(item) = self.try_pop() {
self.inner.push_notify.notify_one();
return item;
}
listener.await;
}
}
fn try_pop(&self) -> Option<T> {
self.inner.queue.lock().expect("not poisoned").pop_front()
}
}
impl<T> Clone for Queue<T> {
fn clone(&self) -> Self {
Self {
inner: Arc::clone(&self.inner),
capacity: self.capacity,
}
}
}

View file

@ -8,20 +8,6 @@ use std::{
task::{Context, Poll, Wake, Waker},
};
struct ThreadWaker {
thread: std::thread::Thread,
}
impl Wake for ThreadWaker {
fn wake(self: Arc<Self>) {
self.thread.unpark();
}
fn wake_by_ref(self: &Arc<Self>) {
self.thread.unpark();
}
}
pub(super) enum Either<L, R> {
Left(L),
Right(R),
@ -93,41 +79,10 @@ where
}
}
pub(super) fn blocking_select<Left, Right>(
pub(super) async fn select<Left, Right>(
left: Left,
right: Right,
) -> Either<Left::Output, Right::Output>
where
Left: Future,
Right: Future,
{
block_on(select(left, right))
}
fn block_on<F>(fut: F) -> F::Output
where
F: Future,
{
let thread_waker = Arc::new(ThreadWaker {
thread: std::thread::current(),
})
.into();
let mut ctx = Context::from_waker(&thread_waker);
let mut fut = std::pin::pin!(fut);
loop {
if let Poll::Ready(out) = fut.as_mut().poll(&mut ctx) {
return out;
}
// doesn't race - unpark followed by park will result in park returning immediately
std::thread::park();
}
}
async fn select<Left, Right>(left: Left, right: Right) -> Either<Left::Output, Right::Output>
where
Left: Future,
Right: Future,

64
src/spsc.rs Normal file
View file

@ -0,0 +1,64 @@
use crate::{
drop_notifier::{DropListener, DropNotifier},
executor::block_on,
queue::Queue,
selector::{select, Either},
Canceled,
};
pub(super) fn channel<T>() -> (Sender<T>, Receiver<T>) {
let queue = crate::queue::bounded(1);
let (send_notifier, send_listener) = crate::drop_notifier::notifier();
let (recv_notifier, recv_listener) = crate::drop_notifier::notifier();
(
Sender {
queue: queue.clone(),
send_notifier,
recv_listener,
},
Receiver {
queue,
recv_notifier,
send_listener,
},
)
}
pub(super) struct Sender<T> {
queue: Queue<T>,
#[allow(unused)]
send_notifier: DropNotifier,
recv_listener: DropListener,
}
pub(super) struct Receiver<T> {
queue: Queue<T>,
#[allow(unused)]
recv_notifier: DropNotifier,
send_listener: DropListener,
}
impl<T> Sender<T> {
pub(super) async fn send(self, item: T) -> Result<(), Canceled> {
match select(self.queue.push(item), self.recv_listener.listen()).await {
Either::Left(()) => Ok(()),
Either::Right(()) => Err(Canceled),
}
}
pub(super) fn blocking_send(self, item: T) -> Result<(), Canceled> {
block_on(self.send(item))
}
}
impl<T> Receiver<T> {
pub(super) async fn recv(self) -> Result<T, Canceled> {
match select(self.queue.pop(), self.send_listener.listen()).await {
Either::Left(item) => Ok(item),
Either::Right(()) => Err(Canceled),
}
}
}