316 lines
8.9 KiB
Rust
316 lines
8.9 KiB
Rust
use std::{
|
|
future::Future,
|
|
pin::Pin,
|
|
sync::{
|
|
atomic::{AtomicBool, Ordering},
|
|
Arc,
|
|
},
|
|
task::{Context, Poll, Wake, Waker},
|
|
};
|
|
|
|
/// Upper bound on how many child futures a single `JoinAll::poll` call will poll
/// before it wakes its own task and returns `Poll::Pending` (value: 16).
const MAX_CONSECUTIVE_POLLS: i32 = 1 << 4;
|
|
|
|
pub fn join_all<T>(futures: Vec<T>) -> JoinAll<T>
|
|
where
|
|
T: Future + Unpin,
|
|
T::Output: Unpin,
|
|
{
|
|
JoinAll {
|
|
futures: futures
|
|
.into_iter()
|
|
.map(|fut| (Arc::new(AtomicBool::new(true)), JoinState::Pending(fut)))
|
|
.collect(),
|
|
next_poll_index: 0,
|
|
total_complete: 0,
|
|
}
|
|
}
|
|
|
|
pub struct JoinAll<T>
|
|
where
|
|
T: Future,
|
|
{
|
|
futures: Vec<(Arc<AtomicBool>, JoinState<T>)>,
|
|
next_poll_index: usize,
|
|
total_complete: usize,
|
|
}
|
|
|
|
/// Waker handed to a single child future.
///
/// Waking it sets `marker` (the child's "needs polling" flag) and then forwards the
/// wake-up to `waker`, the waker of the task that polls the parent future.
struct SmartWaker {
    /// Waker of the parent task.
    waker: Waker,
    /// Per-child flag, set to `true` whenever this waker fires.
    marker: Arc<AtomicBool>,
}

impl Wake for SmartWaker {
    fn wake(self: Arc<Self>) {
        // Consuming the `Arc` does not give ownership of the inner `Waker` (other
        // clones of this `Arc` may exist), so cloning it only to call the by-value
        // `wake` is a wasted refcount round-trip. Forward by reference instead.
        Self::wake_by_ref(&self);
    }

    fn wake_by_ref(self: &Arc<Self>) {
        // Publish the flag before waking the parent, so that the parent's next
        // poll (which reads the flag with acquire ordering) observes it.
        self.marker.store(true, Ordering::Release);
        self.waker.wake_by_ref();
    }
}
|
|
|
|
/// State of one child future's slot.
enum JoinState<T>
where
    T: Future,
{
    /// Placeholder left in the slot while its contents are moved out by `take`.
    Polling,
    /// The child has not produced its output yet.
    Pending(T),
    /// The child finished with this output.
    Complete(T::Output),
}

impl<T> JoinState<T>
where
    T: Future + Unpin,
{
    /// Moves the current state out and leaves `JoinState::Polling` in its place.
    fn take(&mut self) -> JoinState<T> {
        core::mem::replace(self, Self::Polling)
    }
}
|
|
|
|
impl<T> Future for JoinAll<T>
|
|
where
|
|
T: Future + Unpin,
|
|
<T as Future>::Output: Unpin,
|
|
{
|
|
type Output = Vec<<T as Future>::Output>;
|
|
|
|
fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
|
|
let futures_len = self.futures.len();
|
|
let current_index = self.next_poll_index % futures_len;
|
|
|
|
let mut poll_count = 0;
|
|
let mut total_complete = self.total_complete;
|
|
|
|
let (beginning, ending) = self.futures.split_at_mut(current_index);
|
|
|
|
let futures_iter = ending.iter_mut().chain(beginning.iter_mut()).enumerate();
|
|
|
|
for (idx, (pollable, join_state)) in futures_iter {
|
|
if poll_count < MAX_CONSECUTIVE_POLLS && pollable.load(Ordering::Acquire) {
|
|
pollable.store(false, Ordering::Release);
|
|
let new_state = match join_state.take() {
|
|
JoinState::Pending(mut fut) => {
|
|
poll_count += 1;
|
|
|
|
let waker = Arc::new(SmartWaker {
|
|
waker: cx.waker().clone(),
|
|
marker: Arc::clone(pollable),
|
|
})
|
|
.into();
|
|
|
|
let mut smart_context = Context::from_waker(&waker);
|
|
|
|
match Pin::new(&mut fut).poll(&mut smart_context) {
|
|
Poll::Ready(output) => {
|
|
total_complete += 1;
|
|
JoinState::Complete(output)
|
|
}
|
|
Poll::Pending => JoinState::Pending(fut),
|
|
}
|
|
}
|
|
otherwise => otherwise,
|
|
};
|
|
|
|
*join_state = new_state;
|
|
} else if poll_count >= MAX_CONSECUTIVE_POLLS {
|
|
self.next_poll_index = (idx + current_index) % futures_len;
|
|
self.total_complete = total_complete;
|
|
cx.waker().wake_by_ref();
|
|
return Poll::Pending;
|
|
}
|
|
}
|
|
|
|
self.total_complete = total_complete;
|
|
|
|
if self.total_complete == self.futures.len() {
|
|
return Poll::Ready(self.futures.iter_mut().map(|(_, join_state)| {
|
|
match join_state.take() {
|
|
JoinState::Complete(output) => output,
|
|
_ => unreachable!("All futures should be complete when total_complete matches inner length"),
|
|
}
|
|
}).collect());
|
|
}
|
|
|
|
Poll::Pending
|
|
}
|
|
}
|
|
|
|
#[cfg(test)]
mod tests {
    use super::{join_all, MAX_CONSECUTIVE_POLLS};
    use std::{
        future::{ready, Future},
        pin::Pin,
        sync::{Arc, Mutex},
        task::{Context, Poll, Wake, Waker},
    };

    /// Test waker that only counts how many times it has been woken.
    struct WakeCounter {
        wake_count: Mutex<usize>,
    }

    impl WakeCounter {
        fn bump(&self) {
            *self.wake_count.lock().unwrap() += 1;
        }
    }

    impl Wake for WakeCounter {
        fn wake(self: Arc<Self>) {
            self.bump();
        }

        fn wake_by_ref(self: &Arc<Self>) {
            self.bump();
        }
    }

    /// Runs `f` with a counting context and returns how many wakes it caused.
    fn with_context(f: impl FnOnce(&mut Context<'_>)) -> usize {
        let counter = Arc::new(WakeCounter {
            wake_count: Mutex::new(0),
        });
        let waker: Waker = Arc::clone(&counter).into();
        let mut cx = Context::from_waker(&waker);
        f(&mut cx);

        // Bind to a local so the `MutexGuard` temporary is dropped before `counter`.
        let total = *counter.wake_count.lock().unwrap();
        total
    }

    /// Polls `fut` once and fails the test if it reports completion.
    fn poll_expect_pending<F: Future + Unpin>(fut: &mut F, cx: &mut Context<'_>) {
        if Pin::new(fut).poll(cx).is_ready() {
            panic!("Future should be pending forever");
        }
    }

    #[test]
    fn it_produces_a_vec_on_success() {
        let wakes = with_context(|cx| {
            let mut joined = join_all((1..=5).map(ready).collect::<Vec<_>>());
            if let Poll::Ready(outputs) = Pin::new(&mut joined).poll(cx) {
                assert_eq!(outputs, vec![1, 2, 3, 4, 5]);
            } else {
                panic!("Should have returned Ready");
            }
        });

        assert_eq!(wakes, 0, "Shouldn't have caused any wakes");
    }

    #[test]
    fn it_returns_pending_and_wakes_once_for_too_many_futures() {
        let wakes = with_context(|cx| {
            let expected: Vec<i32> = (0..=MAX_CONSECUTIVE_POLLS).collect();
            let mut joined = join_all(expected.iter().copied().map(ready).collect());

            if Pin::new(&mut joined).poll(cx).is_ready() {
                panic!("Shouldn't have returned ready on first poll");
            }

            match Pin::new(&mut joined).poll(cx) {
                Poll::Pending => panic!("Should have completed on second poll"),
                Poll::Ready(outputs) => assert_eq!(outputs, expected),
            }
        });

        assert_eq!(wakes, 1, "Should have woken once");
    }

    /// Never-completing future that counts its polls and wakes itself every time.
    struct WakeOnPoll(usize);

    impl Future for WakeOnPoll {
        type Output = ();

        fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
            self.get_mut().0 += 1;
            cx.waker().wake_by_ref();
            Poll::Pending
        }
    }

    #[test]
    fn waker_is_woken() {
        let wakes = with_context(|cx| {
            let mut wake_on_poll = WakeOnPoll(0);
            {
                let mut joined = join_all(vec![&mut wake_on_poll]);
                poll_expect_pending(&mut joined, cx);
            }

            assert_eq!(wake_on_poll.0, 1, "Should have polled future once");
        });

        assert_eq!(wakes, 1, "Should have woken on poll")
    }

    #[test]
    fn poll_woken_future() {
        let wakes = with_context(|cx| {
            let mut wake_on_poll = WakeOnPoll(0);
            {
                let mut joined = join_all(vec![&mut wake_on_poll]);
                for _ in 0..2 {
                    poll_expect_pending(&mut joined, cx);
                }
            }

            assert_eq!(wake_on_poll.0, 2, "Should have polled future twice");
        });

        assert_eq!(wakes, 2, "Should have woken on poll twice")
    }

    /// Never-completing future that counts its polls but never wakes anyone.
    struct DummyPoll(usize);

    impl Future for DummyPoll {
        type Output = ();

        fn poll(self: Pin<&mut Self>, _: &mut Context<'_>) -> Poll<Self::Output> {
            self.get_mut().0 += 1;
            Poll::Pending
        }
    }

    #[test]
    fn dont_poll_unwoken_future() {
        let wakes = with_context(|cx| {
            let mut dummy_poll = DummyPoll(0);
            let mut wake_on_poll = WakeOnPoll(0);
            let mut dummy_poll_2 = DummyPoll(0);
            let mut wake_on_poll_2 = WakeOnPoll(0);
            {
                let children = vec![
                    &mut dummy_poll as &mut (dyn Future<Output = ()> + Unpin),
                    &mut wake_on_poll,
                    &mut dummy_poll_2,
                    &mut wake_on_poll_2,
                ];
                let mut joined = join_all(children);
                for _ in 0..5 {
                    poll_expect_pending(&mut joined, cx);
                }
            }

            assert_eq!(dummy_poll.0, 1, "Should have polled dummy future only once");
            assert_eq!(
                wake_on_poll.0, 5,
                "Should have polled waking future five times"
            );
            assert_eq!(
                dummy_poll_2.0, 1,
                "Should have polled second dummy future only once"
            );
            assert_eq!(
                wake_on_poll_2.0, 5,
                "Should have polled second waking future five times"
            );
        });

        assert_eq!(wakes, 10, "Should have woken on poll five times")
    }
}
|