137 lines
3.3 KiB
Rust
137 lines
3.3 KiB
Rust
use std::{
|
|
future::Future,
|
|
pin::Pin,
|
|
sync::{
|
|
atomic::{AtomicBool, Ordering},
|
|
Arc,
|
|
},
|
|
task::{Context, Poll, Wake, Waker},
|
|
};
|
|
|
|
/// Polls two futures concurrently and resolves to a tuple of both outputs.
///
/// Each sub-future is polled through its own [`JoinWaker`]. A sub-future is re-polled only after
/// it has signalled a wake-up, or after the parent task's waker has changed. A wake-up from one
/// side therefore does not make the other side be polled needlessly.
///
/// # Panics
///
/// The returned future panics if it is polled again after it has resolved.
pub fn join<T, U>(t: T, u: U) -> impl Future<Output = (T::Output, U::Output)> + Unpin
where
    T: Future + Unpin,
    T::Output: Unpin,
    U: Future + Unpin,
    U::Output: Unpin,
{
    Join {
        t_state: JoinState::new(t),
        u_state: JoinState::new(u),
        last_waker: None,
    }
}

/// Waker handed to a sub-future.
///
/// It records that the sub-future asked to be polled again, then forwards the wake-up to the
/// parent task's waker.
struct JoinWaker {
    /// Flag shared with the owning `JoinState`.
    woken: Arc<AtomicBool>,
    /// Parent task's waker captured when the sub-future was last polled.
    parent: Waker,
}

/// Progress of a single sub-future.
enum FutureState<T: Future> {
    /// Not finished yet; must be polled again to make progress.
    Running(T),
    /// Finished. Holds the output until `Join` hands it out, then `None`.
    Ready(Option<T::Output>),
}

/// A sub-future together with its "needs polling" flag.
struct JoinState<T: Future> {
    /// Set by `JoinWaker` (and by `Join` when the parent waker changes); cleared just before
    /// the sub-future is polled.
    woken: Arc<AtomicBool>,
    state: FutureState<T>,
}

/// Future returned by [`join`].
struct Join<T: Future, U: Future> {
    t_state: JoinState<T>,
    u_state: JoinState<U>,
    /// Parent waker passed to the most recent `poll`, used to detect waker changes.
    last_waker: Option<Waker>,
}

impl<T: Future> JoinState<T> {
    /// Wraps `fut`, flagged as woken so that the first poll of `Join` polls it.
    fn new(fut: T) -> Self {
        Self {
            woken: Arc::new(AtomicBool::new(true)),
            state: FutureState::Running(fut),
        }
    }

    /// Returns `true` while the sub-future has not produced its output.
    fn is_running(&self) -> bool {
        matches!(self.state, FutureState::Running(_))
    }

    /// Requests that the sub-future be polled on the next `check_for_poll`.
    fn mark_woken(&self) {
        self.woken.store(true, Ordering::Release);
    }

    /// Returns `true` if the sub-future is still running and was woken since it was last
    /// polled. Clears the woken flag as a side effect (only when running).
    fn check_for_poll(&self) -> bool {
        self.is_running() && self.woken.swap(false, Ordering::AcqRel)
    }
}

impl Wake for JoinWaker {
    fn wake(self: Arc<Self>) {
        self.wake_by_ref();
    }

    fn wake_by_ref(self: &Arc<Self>) {
        // Publish the flag before waking the parent, so the poll that the wake-up triggers
        // sees that this sub-future needs polling.
        self.woken.store(true, Ordering::Release);
        self.parent.wake_by_ref();
    }
}

impl<T> Future for JoinState<T>
where
    T: Future + Unpin,
    T::Output: Unpin,
{
    /// The sub-future's output is stored in `state`, not returned.
    type Output = ();

    /// Polls the wrapped sub-future once. On completion, stores its output in `state`.
    ///
    /// # Panics
    ///
    /// Panics if the sub-future has already completed.
    fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
        let this = self.get_mut();
        let fut = match &mut this.state {
            FutureState::Running(fut) => fut,
            FutureState::Ready(_) => panic!("Future polled after completion"),
        };

        // Route the sub-future's wake-ups through `JoinWaker`, so they set this state's flag
        // before waking the parent task.
        let waker = Waker::from(Arc::new(JoinWaker {
            woken: Arc::clone(&this.woken),
            parent: cx.waker().clone(),
        }));

        match Pin::new(fut).poll(&mut Context::from_waker(&waker)) {
            Poll::Ready(output) => {
                // Replacing the state drops the completed sub-future.
                this.state = FutureState::Ready(Some(output));
                Poll::Ready(())
            }
            Poll::Pending => Poll::Pending,
        }
    }
}

impl<T, U> Future for Join<T, U>
where
    T: Future + Unpin,
    T::Output: Unpin,
    U: Future + Unpin,
    U::Output: Unpin,
{
    type Output = (T::Output, U::Output);

    /// Polls each sub-future that was woken. Resolves once both have produced their outputs.
    ///
    /// # Panics
    ///
    /// Panics if polled again after returning `Poll::Ready`.
    fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
        let this = self.get_mut();

        // `Future::poll` only promises to wake the waker from the most recent call. A
        // sub-future that is not re-polled keeps a `JoinWaker` wrapping whatever parent waker
        // was current when it was last polled. So when the parent waker changes, re-poll every
        // running sub-future to hand it the new one.
        let same_waker = this
            .last_waker
            .as_ref()
            .map_or(false, |last| last.will_wake(cx.waker()));
        if !same_waker {
            this.last_waker = Some(cx.waker().clone());
            this.t_state.mark_woken();
            this.u_state.mark_woken();
        }

        // The outcome of each poll is observed through the `state` fields below.
        if this.t_state.check_for_poll() {
            let _ = Pin::new(&mut this.t_state).poll(cx);
        }
        if this.u_state.check_for_poll() {
            let _ = Pin::new(&mut this.u_state).poll(cx);
        }

        match (&mut this.t_state.state, &mut this.u_state.state) {
            (FutureState::Ready(t_out), FutureState::Ready(u_out)) => {
                // The outputs are `None` only if they were already handed out, i.e. this
                // future is being polled after completion.
                let t = t_out
                    .take()
                    .expect("`join` future polled after completion");
                let u = u_out
                    .take()
                    .expect("`join` future polled after completion");
                Poll::Ready((t, u))
            }
            _ => Poll::Pending,
        }
    }
}
|