join/src/lib.rs
2022-03-07 19:25:22 -06:00

137 lines
3.3 KiB
Rust

use std::{
future::Future,
pin::Pin,
sync::{
atomic::{AtomicBool, Ordering},
Arc,
},
task::{Context, Poll, Wake, Waker},
};
/// Returns a future that drives `t` and `u` concurrently and resolves to
/// `(t_output, u_output)` once both have completed.
///
/// Each inner future is re-polled only after it has signalled a wakeup through
/// the waker it was given, or when the outer task's waker changes. A wakeup
/// for one future therefore does not cause the other one to be polled again.
///
/// Polling the returned future again after it has resolved keeps returning
/// `Poll::Pending`.
pub fn join<T, U>(t: T, u: U) -> impl Future<Output = (T::Output, U::Output)> + Unpin
where
    T: Future + Unpin,
    T::Output: Unpin,
    U: Future + Unpin,
    U::Output: Unpin,
{
    Join {
        t_state: JoinState::new(t),
        u_state: JoinState::new(u),
        parent: None,
    }
}

/// Waker handed to an inner future. When woken it records that this particular
/// future asked to be polled, then forwards the wakeup to the outer task.
struct JoinWaker {
    /// Per-child flag shared with the owning [`JoinState`].
    woken: Arc<AtomicBool>,
    /// Outer task's waker at the time the child was last polled.
    parent: Waker,
}

/// Lifecycle of one inner future.
enum FutureState<T: Future> {
    /// Not finished yet; polled again when woken.
    Running(T),
    /// Finished. Holds the output until [`Join`] takes it (`None` afterwards).
    Ready(Option<T::Output>),
}

/// One inner future plus the flag recording whether it has requested a re-poll.
struct JoinState<T: Future> {
    /// Set by this child's [`JoinWaker`]; starts `true` so the first poll happens.
    woken: Arc<AtomicBool>,
    state: FutureState<T>,
}

/// Future returned by [`join`].
struct Join<T: Future, U: Future> {
    t_state: JoinState<T>,
    u_state: JoinState<U>,
    /// Outer waker seen on the most recent poll. Used to detect when a
    /// different waker is supplied, so the children can be handed the new one.
    parent: Option<Waker>,
}

impl<T: Future> JoinState<T> {
    /// Wraps a not-yet-polled future, marked as woken so the first poll drives it.
    fn new(fut: T) -> Self {
        JoinState {
            woken: Arc::new(AtomicBool::new(true)),
            state: FutureState::Running(fut),
        }
    }

    /// Whether the wrapped future has not completed yet.
    fn is_running(&self) -> bool {
        matches!(self.state, FutureState::Running(_))
    }

    /// Returns `true` if the wrapped future should be polled now: it is still
    /// running, and either it was woken since its last poll or `force` is set.
    ///
    /// For a running future the woken flag is always cleared here, before the
    /// poll happens. A wakeup raised during or after that poll sets it again,
    /// so the wakeup is not lost.
    fn check_for_poll(&self, force: bool) -> bool {
        self.is_running() && (self.woken.swap(false, Ordering::AcqRel) || force)
    }
}

impl Wake for JoinWaker {
    fn wake(self: Arc<Self>) {
        self.wake_by_ref();
    }

    fn wake_by_ref(self: &Arc<Self>) {
        // Release pairs with the AcqRel swap in `JoinState::check_for_poll`.
        self.woken.store(true, Ordering::Release);
        self.parent.wake_by_ref();
    }
}

impl<T> Future for JoinState<T>
where
    T: Future + Unpin,
    T::Output: Unpin,
{
    type Output = ();

    /// Polls the wrapped future once. It receives a fresh [`JoinWaker`] that
    /// sets this state's `woken` flag and then wakes `cx`'s waker.
    ///
    /// Resolves to `()` once the inner future completes. Its output is then
    /// stored in `FutureState::Ready` for [`Join`] to collect.
    ///
    /// # Panics
    ///
    /// Panics if the inner future has already completed.
    fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
        let waker = Arc::new(JoinWaker {
            woken: Arc::clone(&self.woken),
            parent: cx.waker().clone(),
        })
        .into();
        let mut context = Context::from_waker(&waker);
        let this = self.get_mut();
        // Temporarily move the future out so it can be polled by value-pin.
        match std::mem::replace(&mut this.state, FutureState::Ready(None)) {
            FutureState::Running(mut fut) => match Pin::new(&mut fut).poll(&mut context) {
                Poll::Ready(res) => {
                    this.state = FutureState::Ready(Some(res));
                    Poll::Ready(())
                }
                Poll::Pending => {
                    this.state = FutureState::Running(fut);
                    Poll::Pending
                }
            },
            FutureState::Ready(_) => panic!("Future polled after completion"),
        }
    }
}

impl<T, U> Future for Join<T, U>
where
    T: Future + Unpin,
    T::Output: Unpin,
    U: Future + Unpin,
    U::Output: Unpin,
{
    type Output = (T::Output, U::Output);

    fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
        let this = self.get_mut();

        // Only the waker from the most recent poll may be relied on. A child
        // that was not re-polled still holds a JoinWaker wrapping an older
        // parent waker. So when the parent waker changes, force-poll every
        // running child to hand it the current one.
        let parent_changed = match &this.parent {
            Some(prev) => !prev.will_wake(cx.waker()),
            None => true,
        };
        if parent_changed {
            this.parent = Some(cx.waker().clone());
        }

        // `ready` becomes true only when a child completes during this poll,
        // which is the only moment both outputs can first be available.
        let mut ready = false;
        if this.t_state.check_for_poll(parent_changed) {
            ready |= Pin::new(&mut this.t_state).poll(cx).is_ready();
        }
        if this.u_state.check_for_poll(parent_changed) {
            ready |= Pin::new(&mut this.u_state).poll(cx).is_ready();
        }

        if ready {
            if let (FutureState::Ready(t_opt), FutureState::Ready(u_opt)) =
                (&mut this.t_state.state, &mut this.u_state.state)
            {
                // Outputs are taken exactly once. After that neither child is
                // running, so `ready` can never be true again.
                let t_out = t_opt.take().expect("first output is taken only once");
                let u_out = u_opt.take().expect("second output is taken only once");
                return Poll::Ready((t_out, u_out));
            }
        }
        Poll::Pending
    }
}