async-cpupool/src/selector.rs

102 lines
2.2 KiB
Rust

use std::{
future::Future,
pin::Pin,
sync::{
atomic::{AtomicBool, Ordering},
Arc,
},
task::{Context, Poll, Wake, Waker},
};
/// Outcome of racing two futures: records which side completed and carries
/// that side's output.
pub(super) enum Either<L, R> {
/// The left-hand future completed; holds its output.
Left(L),
/// The right-hand future completed; holds its output.
Right(R),
}
/// Future that drives two inner futures and resolves with the output of
/// whichever one is ready first (see the `Future` impl for the polling order).
struct Select<F1, F2> {
/// Inner future that is polled first on every `poll` call.
left: F1,
/// Flag shared with the waker handed to `left`; that waker stores `true`
/// into it whenever it is woken. The visible code only ever writes this
/// flag; nothing shown here reads it back.
left_woken: Arc<AtomicBool>,
/// Inner future that is polled only when `left` returned `Pending`.
right: F2,
/// Flag shared with the waker handed to `right`; same role as `left_woken`.
right_woken: Arc<AtomicBool>,
}
/// Waker wrapper that records a wake-up in `flag` and then forwards the
/// wake-up to `inner`.
struct SelectWaker {
/// Waker that every wake-up is forwarded to. `Select::poll` fills this with
/// a clone of the waker from the context it was polled with.
inner: Waker,
/// Shared flag set to `true` each time this waker is woken.
flag: Arc<AtomicBool>,
}
impl Wake for SelectWaker {
fn wake_by_ref(self: &Arc<Self>) {
self.flag.store(true, Ordering::Release);
self.inner.wake_by_ref();
}
fn wake(self: Arc<Self>) {
self.flag.store(true, Ordering::Release);
match Arc::try_unwrap(self) {
Ok(this) => this.inner.wake(),
Err(this) => this.inner.wake_by_ref(),
}
}
}
impl<F1, F2> Future for Select<F1, F2>
where
    F1: Future + Unpin,
    F2: Future + Unpin,
{
    type Output = Either<F1::Output, F2::Output>;

    /// Polls `left`, then `right`, returning the first output that is ready.
    ///
    /// `left` is always polled first, so it wins when both sides are ready on
    /// the same call. `right` is not polled at all on a call where `left`
    /// completes. Each inner future is given its own freshly allocated
    /// `SelectWaker`, which wraps a clone of the caller's waker together with
    /// that side's flag.
    fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
        // Both inner futures are `Unpin`, so the pin can be released to get
        // plain mutable access to the fields.
        let this = self.get_mut();

        let left_waker = Waker::from(Arc::new(SelectWaker {
            inner: cx.waker().clone(),
            flag: Arc::clone(&this.left_woken),
        }));
        let mut left_cx = Context::from_waker(&left_waker);

        match Pin::new(&mut this.left).poll(&mut left_cx) {
            Poll::Ready(output) => return Poll::Ready(Either::Left(output)),
            Poll::Pending => {}
        }

        // Only reached when the left side is still pending.
        let right_waker = Waker::from(Arc::new(SelectWaker {
            inner: cx.waker().clone(),
            flag: Arc::clone(&this.right_woken),
        }));
        let mut right_cx = Context::from_waker(&right_waker);

        Pin::new(&mut this.right)
            .poll(&mut right_cx)
            .map(Either::Right)
    }
}
/// Races `left` against `right` and resolves with the output of whichever
/// future is ready first, tagged as `Either::Left` or `Either::Right`.
///
/// Both futures are pinned locally with `pin!`, so neither has to be `Unpin`.
/// The race itself is run by `Select`, which polls `left` before `right`.
pub(super) async fn select<Left, Right>(
    left: Left,
    right: Right,
) -> Either<Left::Output, Right::Output>
where
    Left: Future,
    Right: Future,
{
    // Each side gets its own wake flag, and both flags start out set.
    let fresh_flag = || Arc::new(AtomicBool::new(true));

    let pinned_left = std::pin::pin!(left);
    let pinned_right = std::pin::pin!(right);

    let race = Select {
        left: pinned_left,
        left_woken: fresh_flag(),
        right: pinned_right,
        right_woken: fresh_flag(),
    };

    race.await
}