Build async queue, spsc to replace use of flume channels
This commit is contained in:
parent
6d5e171e25
commit
9c29ab9fc2
|
@ -13,7 +13,6 @@ edition = "2021"
|
|||
# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html
|
||||
|
||||
[dependencies]
|
||||
flume = "0.11.0"
|
||||
metrics = "0.22.0"
|
||||
tracing = "0.1.40"
|
||||
|
||||
|
|
34
src/drop_notifier.rs
Normal file
34
src/drop_notifier.rs
Normal file
|
@ -0,0 +1,34 @@
|
|||
use std::sync::Arc;
|
||||
|
||||
use crate::notify::Notify;
|
||||
|
||||
pub(super) fn notifier() -> (DropNotifier, DropListener) {
|
||||
let notify = Arc::new(Notify::new());
|
||||
|
||||
(
|
||||
DropNotifier {
|
||||
notify: Arc::clone(¬ify),
|
||||
},
|
||||
DropListener { notify },
|
||||
)
|
||||
}
|
||||
|
||||
pub(super) struct DropNotifier {
|
||||
notify: Arc<Notify>,
|
||||
}
|
||||
|
||||
pub(super) struct DropListener {
|
||||
notify: Arc<Notify>,
|
||||
}
|
||||
|
||||
impl DropListener {
|
||||
pub(super) async fn listen(self) {
|
||||
self.notify.listen().await.await
|
||||
}
|
||||
}
|
||||
|
||||
impl Drop for DropNotifier {
|
||||
fn drop(&mut self) {
|
||||
self.notify.notify_one();
|
||||
}
|
||||
}
|
42
src/executor.rs
Normal file
42
src/executor.rs
Normal file
|
@ -0,0 +1,42 @@
|
|||
use std::{
|
||||
future::Future,
|
||||
sync::Arc,
|
||||
task::{Context, Poll, Wake},
|
||||
};
|
||||
|
||||
struct ThreadWaker {
|
||||
thread: std::thread::Thread,
|
||||
}
|
||||
|
||||
impl Wake for ThreadWaker {
|
||||
fn wake(self: Arc<Self>) {
|
||||
self.thread.unpark();
|
||||
}
|
||||
|
||||
fn wake_by_ref(self: &Arc<Self>) {
|
||||
self.thread.unpark();
|
||||
}
|
||||
}
|
||||
|
||||
pub(super) fn block_on<F>(fut: F) -> F::Output
|
||||
where
|
||||
F: Future,
|
||||
{
|
||||
let thread_waker = Arc::new(ThreadWaker {
|
||||
thread: std::thread::current(),
|
||||
})
|
||||
.into();
|
||||
|
||||
let mut ctx = Context::from_waker(&thread_waker);
|
||||
|
||||
let mut fut = std::pin::pin!(fut);
|
||||
|
||||
loop {
|
||||
if let Poll::Ready(out) = fut.as_mut().poll(&mut ctx) {
|
||||
return out;
|
||||
}
|
||||
|
||||
// doesn't race - unpark followed by park will result in park returning immediately
|
||||
std::thread::park();
|
||||
}
|
||||
}
|
152
src/lib.rs
152
src/lib.rs
|
@ -1,15 +1,26 @@
|
|||
#![doc = include_str!("../README.md")]
|
||||
#![deny(missing_docs)]
|
||||
|
||||
mod drop_notifier;
|
||||
mod executor;
|
||||
mod notify;
|
||||
mod queue;
|
||||
mod selector;
|
||||
mod spsc;
|
||||
|
||||
use std::{
|
||||
future::Future,
|
||||
num::{NonZeroU16, NonZeroUsize},
|
||||
sync::{atomic::AtomicU64, Arc, Mutex},
|
||||
thread::JoinHandle,
|
||||
time::Instant,
|
||||
};
|
||||
|
||||
use drop_notifier::{DropListener, DropNotifier};
|
||||
use executor::block_on;
|
||||
use queue::Queue;
|
||||
use selector::select;
|
||||
|
||||
/// Configuration builder for the CpuPool
|
||||
#[derive(Debug)]
|
||||
pub struct Config {
|
||||
|
@ -240,62 +251,67 @@ impl CpuPool {
|
|||
/// # })
|
||||
/// # }
|
||||
/// ```
|
||||
pub async fn spawn<F, T>(&self, send_fn: F) -> Result<T, Canceled>
|
||||
pub fn spawn<F, T>(&self, send_fn: F) -> impl Future<Output = Result<T, Canceled>> + '_
|
||||
where
|
||||
F: FnOnce() -> T + Send + 'static,
|
||||
T: Send + 'static,
|
||||
{
|
||||
let (response_tx, response) = flume::bounded(1);
|
||||
let (response_tx, response_rx) = spsc::channel();
|
||||
|
||||
let send_fn = Box::new(move || {
|
||||
let output = (send_fn)();
|
||||
|
||||
if response_tx.send(output).is_err() {
|
||||
tracing::trace!("Receiver hung up");
|
||||
match response_tx.blocking_send(output) {
|
||||
Ok(()) => (), // sent
|
||||
Err(Canceled) => tracing::warn!("receiver hung up"),
|
||||
}
|
||||
});
|
||||
|
||||
let res = self.state.sender.try_send(send_fn);
|
||||
let opt = self.state.queue.try_push(send_fn);
|
||||
|
||||
let current_threads = self
|
||||
.state
|
||||
.current_threads
|
||||
.load(std::sync::atomic::Ordering::Acquire);
|
||||
|
||||
if (self.state.sender_len() > current_threads || self.state.sender.is_full())
|
||||
&& self.push_thread()
|
||||
{
|
||||
let pushed = match self.state.queue.is_full_or() {
|
||||
Ok(()) => self.push_thread(),
|
||||
Err(len) if len > current_threads as usize => self.push_thread(),
|
||||
Err(_) => false,
|
||||
};
|
||||
|
||||
if pushed {
|
||||
tracing::trace!("Pushed thread");
|
||||
}
|
||||
|
||||
match res {
|
||||
Err(flume::TrySendError::Full(item)) => {
|
||||
if self.state.sender.send_async(item).await.is_err() {
|
||||
return Err(Canceled);
|
||||
async {
|
||||
match opt {
|
||||
Some(item) => self.state.queue.push(item).await,
|
||||
None => {}
|
||||
}
|
||||
|
||||
let current_threads = self
|
||||
.state
|
||||
.current_threads
|
||||
.load(std::sync::atomic::Ordering::Acquire);
|
||||
|
||||
match self.state.queue.is_full_or() {
|
||||
Ok(()) => {
|
||||
self.push_thread();
|
||||
}
|
||||
Err(len) if len > current_threads as usize => {
|
||||
self.push_thread();
|
||||
}
|
||||
Err(len) if len < current_threads.ilog2() as usize => {
|
||||
if let Some(thread) = self.pop_thread() {
|
||||
thread.reap().await;
|
||||
}
|
||||
}
|
||||
Err(_) => {}
|
||||
}
|
||||
Err(flume::TrySendError::Disconnected(_)) => {
|
||||
return Err(Canceled);
|
||||
}
|
||||
Ok(()) => {}
|
||||
|
||||
response_rx.recv().await
|
||||
}
|
||||
|
||||
let current_threads = self
|
||||
.state
|
||||
.current_threads
|
||||
.load(std::sync::atomic::Ordering::Acquire);
|
||||
|
||||
let sender_len = self.state.sender_len();
|
||||
|
||||
if sender_len > current_threads || self.state.sender.is_full() {
|
||||
self.push_thread();
|
||||
} else if sender_len < u64::from(current_threads.ilog2()) {
|
||||
if let Some(thread) = self.pop_thread() {
|
||||
thread.reap().await;
|
||||
}
|
||||
}
|
||||
|
||||
response.into_recv_async().await.map_err(|_| Canceled)
|
||||
}
|
||||
|
||||
/// Attempt to close the CpuPool
|
||||
|
@ -330,8 +346,8 @@ impl CpuPool {
|
|||
thread.signal.take();
|
||||
}
|
||||
|
||||
for thread in &mut threads {
|
||||
let _ = thread.closed.recv_async().await;
|
||||
for mut thread in threads {
|
||||
thread.closed.listen().await;
|
||||
|
||||
if let Some(handle) = thread.handle.take() {
|
||||
handle.join().expect("Thread panicked");
|
||||
|
@ -376,7 +392,7 @@ impl CpuPool {
|
|||
.thread_id
|
||||
.fetch_add(1, std::sync::atomic::Ordering::Relaxed);
|
||||
|
||||
let thread = spawn(self.state.name, thread_id, self.state.receiver.clone());
|
||||
let thread = spawn(self.state.name, thread_id, self.state.queue.clone());
|
||||
|
||||
self.state
|
||||
.threads
|
||||
|
@ -439,8 +455,7 @@ struct CpuPoolState {
|
|||
max_threads: NonZeroU16,
|
||||
current_threads: AtomicU64,
|
||||
thread_id: AtomicU64,
|
||||
sender: flume::Sender<SendFn>,
|
||||
receiver: flume::Receiver<SendFn>,
|
||||
queue: Queue<SendFn>,
|
||||
threads: Mutex<ThreadVec>,
|
||||
}
|
||||
|
||||
|
@ -453,13 +468,12 @@ impl CpuPoolState {
|
|||
) -> Self {
|
||||
let thread_capacity = usize::from(u16::from(max_threads));
|
||||
|
||||
let (sender, receiver) =
|
||||
flume::bounded(usize::from(buffer_multiplier).saturating_mul(thread_capacity));
|
||||
let queue = queue::bounded(usize::from(buffer_multiplier).saturating_mul(thread_capacity));
|
||||
|
||||
let start_threads = u64::from(u16::from(min_threads));
|
||||
|
||||
let threads = ThreadVec::new(start_threads, thread_capacity, |i| {
|
||||
spawn(name, i, receiver.clone())
|
||||
spawn(name, i, queue.clone())
|
||||
});
|
||||
|
||||
let current_threads = AtomicU64::new(start_threads);
|
||||
|
@ -471,8 +485,7 @@ impl CpuPoolState {
|
|||
max_threads,
|
||||
current_threads,
|
||||
thread_id,
|
||||
sender,
|
||||
receiver,
|
||||
queue,
|
||||
threads: Mutex::new(threads),
|
||||
}
|
||||
}
|
||||
|
@ -480,10 +493,6 @@ impl CpuPoolState {
|
|||
fn take_threads(&mut self) -> Vec<Thread> {
|
||||
self.threads.lock().expect("threads lock poison").take()
|
||||
}
|
||||
|
||||
fn sender_len(&self) -> u64 {
|
||||
u64::try_from(self.sender.len()).unwrap_or(u64::MAX)
|
||||
}
|
||||
}
|
||||
|
||||
impl std::fmt::Debug for CpuPoolState {
|
||||
|
@ -543,15 +552,15 @@ impl Drop for ThreadVec {
|
|||
|
||||
struct Thread {
|
||||
handle: Option<JoinHandle<()>>,
|
||||
signal: Option<SendOnDrop>,
|
||||
closed: flume::Receiver<()>,
|
||||
signal: Option<DropNotifier>,
|
||||
closed: DropListener,
|
||||
}
|
||||
|
||||
impl Thread {
|
||||
async fn reap(mut self) {
|
||||
self.signal.take();
|
||||
|
||||
let _ = self.closed.recv_async().await;
|
||||
self.closed.listen().await;
|
||||
|
||||
if let Some(handle) = self.handle.take() {
|
||||
handle.join().expect("Thread panicked");
|
||||
|
@ -559,32 +568,19 @@ impl Thread {
|
|||
}
|
||||
}
|
||||
|
||||
struct SendOnDrop {
|
||||
sender: flume::Sender<()>,
|
||||
}
|
||||
|
||||
impl Drop for SendOnDrop {
|
||||
fn drop(&mut self) {
|
||||
let _ = self.sender.try_send(());
|
||||
}
|
||||
}
|
||||
|
||||
fn spawn(name: &'static str, id: u64, receiver: flume::Receiver<SendFn>) -> Thread {
|
||||
let (closed_tx, closed) = flume::bounded(1);
|
||||
let (signal, signal_rx) = flume::bounded(1);
|
||||
|
||||
let signal = SendOnDrop { sender: signal };
|
||||
let closed_tx = SendOnDrop { sender: closed_tx };
|
||||
fn spawn(name: &'static str, id: u64, receiver: Queue<SendFn>) -> Thread {
|
||||
let (closed_notifier, closed_listener) = drop_notifier::notifier();
|
||||
let (signal_notifier, signal_listener) = drop_notifier::notifier();
|
||||
|
||||
let handle = std::thread::Builder::new()
|
||||
.name(format!("{name}-{id}"))
|
||||
.spawn(move || run(name, id, receiver, signal_rx, closed_tx))
|
||||
.spawn(move || run(name, id, receiver, signal_listener, closed_notifier))
|
||||
.expect("Failed to spawn new thread");
|
||||
|
||||
Thread {
|
||||
handle: Some(handle),
|
||||
signal: Some(signal),
|
||||
closed,
|
||||
signal: Some(signal_notifier),
|
||||
closed: closed_listener,
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -616,7 +612,7 @@ impl MetricsGuard {
|
|||
impl Drop for MetricsGuard {
|
||||
fn drop(&mut self) {
|
||||
metrics::counter!(format!("async-cpupool.{}.closed", self.name), "clean" => (!self.armed).to_string()).increment(1);
|
||||
metrics::histogram!(format!("async-cpupool.{}.duration", self.name), "clean" => (!self.armed).to_string()).record(self.start.elapsed().as_secs_f64());
|
||||
metrics::histogram!(format!("async-cpupool.{}.seconds", self.name), "clean" => (!self.armed).to_string()).record(self.start.elapsed().as_secs_f64());
|
||||
tracing::trace!("Stopping {}-{}", self.name, self.id);
|
||||
}
|
||||
}
|
||||
|
@ -624,16 +620,18 @@ impl Drop for MetricsGuard {
|
|||
fn run(
|
||||
name: &'static str,
|
||||
id: u64,
|
||||
receiver: flume::Receiver<SendFn>,
|
||||
signal: flume::Receiver<()>,
|
||||
closed_tx: SendOnDrop,
|
||||
receiver: Queue<SendFn>,
|
||||
signal: DropListener,
|
||||
closed_tx: DropNotifier,
|
||||
) {
|
||||
let guard = MetricsGuard::guard(name, id);
|
||||
|
||||
let mut signal = std::pin::pin!(signal.listen());
|
||||
|
||||
loop {
|
||||
match selector::blocking_select(signal.recv_async(), receiver.recv_async()) {
|
||||
selector::Either::Left(_) | selector::Either::Right(Err(_)) => break,
|
||||
selector::Either::Right(Ok(send_fn)) => invoke_send_fn(name, send_fn),
|
||||
match block_on(select(&mut signal, receiver.pop())) {
|
||||
selector::Either::Left(_) => break,
|
||||
selector::Either::Right(send_fn) => invoke_send_fn(name, send_fn),
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -651,7 +649,7 @@ fn invoke_send_fn(name: &'static str, send_fn: SendFn) {
|
|||
}));
|
||||
|
||||
metrics::counter!(format!("async-cpupool.{name}.operation.end"), "complete" => res.is_ok().to_string()).increment(1);
|
||||
metrics::histogram!(format!("async-cpupool.{name}.operation.duration"), "complete" => res.is_ok().to_string()).record(start.elapsed().as_secs_f64());
|
||||
metrics::histogram!(format!("async-cpupool.{name}.operation.seconds"), "complete" => res.is_ok().to_string()).record(start.elapsed().as_secs_f64());
|
||||
|
||||
if let Err(e) = res {
|
||||
tracing::trace!("panic in spawned task: {e:?}");
|
||||
|
|
208
src/notify.rs
Normal file
208
src/notify.rs
Normal file
|
@ -0,0 +1,208 @@
|
|||
use std::{
|
||||
collections::VecDeque,
|
||||
future::{poll_fn, Future},
|
||||
sync::{
|
||||
atomic::{AtomicU8, Ordering},
|
||||
Arc, Mutex,
|
||||
},
|
||||
task::{Poll, Waker},
|
||||
};
|
||||
|
||||
const UNNOTIFIED: u8 = 0b0000;
|
||||
const NOTIFIED_ONE: u8 = 0b0001;
|
||||
const RESOLVED: u8 = 0b0010;
|
||||
|
||||
#[derive(Clone, Copy, Debug, PartialEq, Eq, PartialOrd, Ord, Hash)]
|
||||
struct NotifyId(u64);
|
||||
|
||||
pub(super) struct Notify {
|
||||
state: Mutex<NotifyState>,
|
||||
}
|
||||
|
||||
struct NotifyState {
|
||||
listeners: VecDeque<(NotifyId, Arc<AtomicU8>, Waker)>,
|
||||
token: u64,
|
||||
next_id: u64,
|
||||
}
|
||||
|
||||
pub(super) struct Listener<'a> {
|
||||
state: &'a Mutex<NotifyState>,
|
||||
waker: Waker,
|
||||
woken: Arc<AtomicU8>,
|
||||
id: NotifyId,
|
||||
}
|
||||
|
||||
impl Notify {
|
||||
pub(super) fn new() -> Self {
|
||||
metrics::counter!("async-cpupool.notify.created").increment(1);
|
||||
|
||||
Notify {
|
||||
state: Mutex::new(NotifyState {
|
||||
listeners: VecDeque::new(),
|
||||
token: 0,
|
||||
next_id: 0,
|
||||
}),
|
||||
}
|
||||
}
|
||||
|
||||
// although this is an async fn, it is not capable of yielding to the executor
|
||||
pub(super) async fn listen(&self) -> Listener<'_> {
|
||||
poll_fn(|cx| Poll::Ready(self.make_listener(cx.waker().clone()))).await
|
||||
}
|
||||
|
||||
pub(super) fn make_listener(&self, waker: Waker) -> Listener<'_> {
|
||||
let (id, woken) = self
|
||||
.state
|
||||
.lock()
|
||||
.expect("not poisoned")
|
||||
.insert(waker.clone());
|
||||
|
||||
Listener {
|
||||
state: &self.state,
|
||||
waker,
|
||||
woken,
|
||||
id,
|
||||
}
|
||||
}
|
||||
|
||||
pub(super) fn notify_one(&self) {
|
||||
self.state.lock().expect("not poisoned").notify_one();
|
||||
}
|
||||
}
|
||||
|
||||
impl NotifyState {
|
||||
fn notify_one(&mut self) {
|
||||
loop {
|
||||
if let Some((_, woken, waker)) = self.listeners.pop_front() {
|
||||
// can't use "weak" because we need to know failure means we failed fr fr to avoid
|
||||
// popping unwoken listeners
|
||||
match woken.compare_exchange(
|
||||
UNNOTIFIED,
|
||||
NOTIFIED_ONE,
|
||||
Ordering::Release,
|
||||
Ordering::Relaxed,
|
||||
) {
|
||||
Ok(_) => waker.wake(),
|
||||
|
||||
// if this listener isn't unnotified (races Listener Drop) we should wake the
|
||||
// next one
|
||||
Err(_) => continue,
|
||||
}
|
||||
} else {
|
||||
self.token += 1;
|
||||
}
|
||||
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
fn insert(&mut self, waker: Waker) -> (NotifyId, Arc<AtomicU8>) {
|
||||
let id = NotifyId(self.next_id);
|
||||
self.next_id += 1;
|
||||
|
||||
let token = if self.token > 0 {
|
||||
self.token -= 1;
|
||||
true
|
||||
} else {
|
||||
false
|
||||
};
|
||||
|
||||
// don't insert waker if token is true - next poll will be ready
|
||||
let woken = if token {
|
||||
Arc::new(AtomicU8::new(NOTIFIED_ONE))
|
||||
} else {
|
||||
let woken = Arc::new(AtomicU8::new(UNNOTIFIED));
|
||||
self.listeners.push_back((id, Arc::clone(&woken), waker));
|
||||
woken
|
||||
};
|
||||
|
||||
(id, woken)
|
||||
}
|
||||
|
||||
fn remove(&mut self, id: NotifyId) {
|
||||
if let Some(index) = self.find(&id) {
|
||||
self.listeners.remove(index);
|
||||
}
|
||||
}
|
||||
|
||||
fn find(&self, needle_id: &NotifyId) -> Option<usize> {
|
||||
self.listeners
|
||||
.binary_search_by_key(needle_id, |(haystack_id, _, _)| *haystack_id)
|
||||
.ok()
|
||||
}
|
||||
|
||||
fn update(&mut self, id: NotifyId, needle_waker: &Waker) {
|
||||
if let Some(index) = self.find(&id) {
|
||||
if let Some((_, _, haystack_waker)) = self.listeners.get_mut(index) {
|
||||
if !needle_waker.will_wake(haystack_waker) {
|
||||
*haystack_waker = needle_waker.clone();
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl Drop for Notify {
|
||||
fn drop(&mut self) {
|
||||
metrics::counter!("async-cpupool.notify.dropped").increment(1);
|
||||
}
|
||||
}
|
||||
|
||||
impl Drop for Listener<'_> {
|
||||
fn drop(&mut self) {
|
||||
// races compare_exchange in notify_one
|
||||
let flags = self.woken.swap(RESOLVED, Ordering::AcqRel);
|
||||
|
||||
if flags == RESOLVED {
|
||||
return;
|
||||
} else if flags == NOTIFIED_ONE {
|
||||
let mut guard = self.state.lock().expect("not poisoned");
|
||||
guard.notify_one();
|
||||
} else if flags == UNNOTIFIED {
|
||||
let mut guard = self.state.lock().expect("not poisoned");
|
||||
guard.remove(self.id);
|
||||
} else {
|
||||
unreachable!("No other states exist")
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl Future for Listener<'_> {
|
||||
type Output = ();
|
||||
|
||||
fn poll(
|
||||
mut self: std::pin::Pin<&mut Self>,
|
||||
cx: &mut std::task::Context<'_>,
|
||||
) -> Poll<Self::Output> {
|
||||
let mut flags = self.woken.load(Ordering::Acquire);
|
||||
|
||||
loop {
|
||||
if flags == UNNOTIFIED {
|
||||
break;
|
||||
} else if flags == RESOLVED {
|
||||
return Poll::Ready(());
|
||||
} else {
|
||||
match self.woken.compare_exchange_weak(
|
||||
flags,
|
||||
RESOLVED,
|
||||
Ordering::Release,
|
||||
Ordering::Acquire,
|
||||
) {
|
||||
Ok(_) => return Poll::Ready(()),
|
||||
Err(updated) => flags = updated,
|
||||
};
|
||||
}
|
||||
}
|
||||
|
||||
if !self.waker.will_wake(cx.waker()) {
|
||||
self.waker = cx.waker().clone();
|
||||
|
||||
self.state
|
||||
.lock()
|
||||
.expect("not poisoned")
|
||||
.update(self.id, cx.waker());
|
||||
}
|
||||
|
||||
Poll::Pending
|
||||
}
|
||||
}
|
110
src/queue.rs
Normal file
110
src/queue.rs
Normal file
|
@ -0,0 +1,110 @@
|
|||
use std::{
|
||||
collections::VecDeque,
|
||||
sync::{Arc, Mutex},
|
||||
};
|
||||
|
||||
use crate::notify::Notify;
|
||||
|
||||
pub(super) fn bounded<T>(capacity: usize) -> Queue<T> {
|
||||
Queue::bounded(capacity)
|
||||
}
|
||||
|
||||
pub(crate) struct Queue<T> {
|
||||
inner: Arc<QueueState<T>>,
|
||||
capacity: usize,
|
||||
}
|
||||
|
||||
struct QueueState<T> {
|
||||
queue: Mutex<VecDeque<T>>,
|
||||
push_notify: Notify,
|
||||
pop_notify: Notify,
|
||||
}
|
||||
|
||||
impl<T> Queue<T> {
|
||||
pub(super) fn bounded(capacity: usize) -> Self {
|
||||
Self {
|
||||
inner: Arc::new(QueueState {
|
||||
queue: Mutex::new(VecDeque::new()),
|
||||
push_notify: Notify::new(),
|
||||
pop_notify: Notify::new(),
|
||||
}),
|
||||
capacity,
|
||||
}
|
||||
}
|
||||
|
||||
pub(super) fn len(&self) -> usize {
|
||||
self.inner.queue.lock().expect("not poisoned").len()
|
||||
}
|
||||
|
||||
pub(super) fn is_full_or(&self) -> Result<(), usize> {
|
||||
let len = self.len();
|
||||
|
||||
if len >= self.capacity {
|
||||
Ok(())
|
||||
} else {
|
||||
Err(len)
|
||||
}
|
||||
}
|
||||
|
||||
pub(super) async fn push(&self, mut item: T) {
|
||||
loop {
|
||||
let listener = self.inner.push_notify.listen().await;
|
||||
|
||||
if let Some(returned_item) = self.try_push_impl(item) {
|
||||
item = returned_item;
|
||||
|
||||
listener.await;
|
||||
} else {
|
||||
self.inner.pop_notify.notify_one();
|
||||
return;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub(super) fn try_push(&self, item: T) -> Option<T> {
|
||||
match self.try_push_impl(item) {
|
||||
Some(item) => Some(item),
|
||||
None => {
|
||||
self.inner.pop_notify.notify_one();
|
||||
None
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn try_push_impl(&self, item: T) -> Option<T> {
|
||||
let mut guard = self.inner.queue.lock().expect("not poisoned");
|
||||
|
||||
if self.capacity <= guard.len() {
|
||||
Some(item)
|
||||
} else {
|
||||
guard.push_back(item);
|
||||
None
|
||||
}
|
||||
}
|
||||
|
||||
pub(super) async fn pop(&self) -> T {
|
||||
loop {
|
||||
let listener = self.inner.pop_notify.listen().await;
|
||||
|
||||
if let Some(item) = self.try_pop() {
|
||||
self.inner.push_notify.notify_one();
|
||||
return item;
|
||||
}
|
||||
|
||||
listener.await;
|
||||
}
|
||||
}
|
||||
|
||||
fn try_pop(&self) -> Option<T> {
|
||||
self.inner.queue.lock().expect("not poisoned").pop_front()
|
||||
}
|
||||
}
|
||||
|
||||
impl<T> Clone for Queue<T> {
|
||||
fn clone(&self) -> Self {
|
||||
Self {
|
||||
inner: Arc::clone(&self.inner),
|
||||
capacity: self.capacity,
|
||||
}
|
||||
}
|
||||
}
|
|
@ -8,20 +8,6 @@ use std::{
|
|||
task::{Context, Poll, Wake, Waker},
|
||||
};
|
||||
|
||||
struct ThreadWaker {
|
||||
thread: std::thread::Thread,
|
||||
}
|
||||
|
||||
impl Wake for ThreadWaker {
|
||||
fn wake(self: Arc<Self>) {
|
||||
self.thread.unpark();
|
||||
}
|
||||
|
||||
fn wake_by_ref(self: &Arc<Self>) {
|
||||
self.thread.unpark();
|
||||
}
|
||||
}
|
||||
|
||||
pub(super) enum Either<L, R> {
|
||||
Left(L),
|
||||
Right(R),
|
||||
|
@ -93,41 +79,10 @@ where
|
|||
}
|
||||
}
|
||||
|
||||
pub(super) fn blocking_select<Left, Right>(
|
||||
pub(super) async fn select<Left, Right>(
|
||||
left: Left,
|
||||
right: Right,
|
||||
) -> Either<Left::Output, Right::Output>
|
||||
where
|
||||
Left: Future,
|
||||
Right: Future,
|
||||
{
|
||||
block_on(select(left, right))
|
||||
}
|
||||
|
||||
fn block_on<F>(fut: F) -> F::Output
|
||||
where
|
||||
F: Future,
|
||||
{
|
||||
let thread_waker = Arc::new(ThreadWaker {
|
||||
thread: std::thread::current(),
|
||||
})
|
||||
.into();
|
||||
|
||||
let mut ctx = Context::from_waker(&thread_waker);
|
||||
|
||||
let mut fut = std::pin::pin!(fut);
|
||||
|
||||
loop {
|
||||
if let Poll::Ready(out) = fut.as_mut().poll(&mut ctx) {
|
||||
return out;
|
||||
}
|
||||
|
||||
// doesn't race - unpark followed by park will result in park returning immediately
|
||||
std::thread::park();
|
||||
}
|
||||
}
|
||||
|
||||
async fn select<Left, Right>(left: Left, right: Right) -> Either<Left::Output, Right::Output>
|
||||
where
|
||||
Left: Future,
|
||||
Right: Future,
|
||||
|
|
64
src/spsc.rs
Normal file
64
src/spsc.rs
Normal file
|
@ -0,0 +1,64 @@
|
|||
use crate::{
|
||||
drop_notifier::{DropListener, DropNotifier},
|
||||
executor::block_on,
|
||||
queue::Queue,
|
||||
selector::{select, Either},
|
||||
Canceled,
|
||||
};
|
||||
|
||||
pub(super) fn channel<T>() -> (Sender<T>, Receiver<T>) {
|
||||
let queue = crate::queue::bounded(1);
|
||||
|
||||
let (send_notifier, send_listener) = crate::drop_notifier::notifier();
|
||||
|
||||
let (recv_notifier, recv_listener) = crate::drop_notifier::notifier();
|
||||
|
||||
(
|
||||
Sender {
|
||||
queue: queue.clone(),
|
||||
send_notifier,
|
||||
recv_listener,
|
||||
},
|
||||
Receiver {
|
||||
queue,
|
||||
recv_notifier,
|
||||
send_listener,
|
||||
},
|
||||
)
|
||||
}
|
||||
|
||||
pub(super) struct Sender<T> {
|
||||
queue: Queue<T>,
|
||||
#[allow(unused)]
|
||||
send_notifier: DropNotifier,
|
||||
recv_listener: DropListener,
|
||||
}
|
||||
|
||||
pub(super) struct Receiver<T> {
|
||||
queue: Queue<T>,
|
||||
#[allow(unused)]
|
||||
recv_notifier: DropNotifier,
|
||||
send_listener: DropListener,
|
||||
}
|
||||
|
||||
impl<T> Sender<T> {
|
||||
pub(super) async fn send(self, item: T) -> Result<(), Canceled> {
|
||||
match select(self.queue.push(item), self.recv_listener.listen()).await {
|
||||
Either::Left(()) => Ok(()),
|
||||
Either::Right(()) => Err(Canceled),
|
||||
}
|
||||
}
|
||||
|
||||
pub(super) fn blocking_send(self, item: T) -> Result<(), Canceled> {
|
||||
block_on(self.send(item))
|
||||
}
|
||||
}
|
||||
|
||||
impl<T> Receiver<T> {
|
||||
pub(super) async fn recv(self) -> Result<T, Canceled> {
|
||||
match select(self.queue.pop(), self.send_listener.listen()).await {
|
||||
Either::Left(item) => Ok(item),
|
||||
Either::Right(()) => Err(Canceled),
|
||||
}
|
||||
}
|
||||
}
|
Loading…
Reference in a new issue