join-all/src/lib.rs
2021-09-09 20:11:03 -05:00

316 lines
8.9 KiB
Rust

use std::{
future::Future,
pin::Pin,
sync::{
atomic::{AtomicBool, Ordering},
Arc,
},
task::{Context, Poll, Wake, Waker},
};
/// Upper bound on how many child futures a single call to `JoinAll::poll`
/// will poll. Once the budget is used up and work remains, `poll` wakes its
/// own task and returns `Poll::Pending`, so the rest is picked up on a later
/// call instead of in one long, uninterrupted sweep.
const MAX_CONSECUTIVE_POLLS: i32 = 16;
/// Creates a [`JoinAll`] future that drives every future in `futures` to
/// completion and resolves to their outputs in the order they were supplied.
///
/// Each child starts with its "needs polling" flag raised, so `JoinAll` polls
/// it without first waiting for a wake-up from that child.
pub fn join_all<T>(futures: Vec<T>) -> JoinAll<T>
where
    T: Future + Unpin,
    T::Output: Unpin,
{
    let children = futures
        .into_iter()
        .map(|child| {
            let needs_poll = Arc::new(AtomicBool::new(true));
            (needs_poll, JoinState::Pending(child))
        })
        .collect();

    JoinAll {
        futures: children,
        next_poll_index: 0,
        total_complete: 0,
    }
}
/// Future returned by [`join_all`]: resolves to the outputs of every wrapped
/// future, in the order the futures were supplied.
// Like the std futures, warn when the value is dropped without being polled:
// a `JoinAll` does no work on its own.
#[must_use = "futures do nothing unless you `.await` or poll them"]
pub struct JoinAll<T>
where
    T: Future,
{
    /// One entry per child: a "needs polling" flag (raised by that child's
    /// waker) and the child's current state.
    futures: Vec<(Arc<AtomicBool>, JoinState<T>)>,
    /// Position at which the next call to `poll` resumes its round-robin sweep.
    next_poll_index: usize,
    /// Number of children that have produced their output so far.
    total_complete: usize,
}
/// Per-child waker handed to each inner future.
///
/// Waking it raises that child's `marker` flag and then forwards the wake-up
/// to `waker`, the waker of the task polling the whole `JoinAll`.
struct SmartWaker {
    /// Shared with the matching entry in `JoinAll::futures`; set to `true` on wake.
    marker: Arc<AtomicBool>,
    /// The outer task's waker, cloned from the `Context` given to `JoinAll::poll`.
    waker: Waker,
}
impl Wake for SmartWaker {
    /// Flags this child as needing a poll and wakes the outer task.
    ///
    /// Other clones of the `Arc` may still be alive, so there is nothing to
    /// move out of `self`; forwarding to `wake_by_ref` avoids cloning the
    /// inner waker just to consume the clone.
    fn wake(self: Arc<Self>) {
        Self::wake_by_ref(&self);
    }

    /// Flags this child as needing a poll and wakes the outer task.
    fn wake_by_ref(self: &Arc<Self>) {
        // Raise the flag before waking, with Release ordering, so the poller's
        // acquiring read of the flag also sees whatever happened before this wake.
        self.marker.store(true, Ordering::Release);
        self.waker.wake_by_ref();
    }
}
/// Lifecycle of one child future inside a `JoinAll`.
enum JoinState<T: Future> {
    /// Placeholder left behind by `take` while the real state is moved out.
    Polling,
    /// Still running; holds the child future itself.
    Pending(T),
    /// Finished; holds the child's output until the whole join completes.
    Complete(T::Output),
}
impl<T> JoinState<T>
where
T: Future + Unpin,
{
fn take(&mut self) -> JoinState<T> {
std::mem::replace(self, JoinState::Polling)
}
}
impl<T> Future for JoinAll<T>
where
    T: Future + Unpin,
    <T as Future>::Output: Unpin,
{
    type Output = Vec<<T as Future>::Output>;

    /// Drives the child futures and resolves once all of them have finished.
    ///
    /// Each call sweeps the children round-robin, starting where the previous
    /// call stopped. It polls only those children whose flag is raised, either
    /// initially or by their `SmartWaker` since their last poll. At most
    /// `MAX_CONSECUTIVE_POLLS` children are polled per call. If that budget
    /// runs out while children are still unfinished, the task wakes itself and
    /// returns `Poll::Pending` so the remaining children are handled on the
    /// next call.
    ///
    /// An empty join resolves immediately to an empty `Vec`.
    ///
    /// # Panics
    ///
    /// Panics if polled again after it has already returned `Poll::Ready` for
    /// a non-empty set of futures.
    fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
        // `JoinAll` is `Unpin` under these bounds. Working through a plain
        // `&mut` lets the sweep borrow `futures` while updating the other fields.
        let this = &mut *self;
        let futures_len = this.futures.len();

        // Nothing to wait for. Without this guard the modulo below would divide by zero.
        if futures_len == 0 {
            return Poll::Ready(Vec::new());
        }

        // `total_complete` is only stored as `futures_len` by a call that then
        // returns `Ready`, because the budget path below never yields once
        // everything is done. Seeing it here means we were re-polled after completion.
        assert!(
            this.total_complete < futures_len,
            "`JoinAll` polled after completion"
        );

        let current_index = this.next_poll_index % futures_len;
        let mut poll_count = 0;
        let mut total_complete = this.total_complete;

        // Visit `[current_index..]` then `[..current_index]`; `offset` counts
        // positions from `current_index`.
        let (beginning, ending) = this.futures.split_at_mut(current_index);
        let rotated = ending.iter_mut().chain(beginning.iter_mut()).enumerate();
        for (offset, (pollable, join_state)) in rotated {
            if poll_count >= MAX_CONSECUTIVE_POLLS {
                if total_complete == futures_len {
                    // Everything has finished already. Return the outputs now
                    // rather than yielding for a pointless extra round.
                    break;
                }
                // Budget exhausted. Resume from here next time, and wake ourselves
                // because the children not visited yet may already be flagged.
                this.next_poll_index = (offset + current_index) % futures_len;
                this.total_complete = total_complete;
                cx.waker().wake_by_ref();
                return Poll::Pending;
            }

            // Read and clear the flag in one atomic step. With a separate `load`
            // then `store(false)`, a concurrent wake landing in between would be
            // overwritten and lost. Acquire pairs with `SmartWaker`'s Release store.
            if !pollable.swap(false, Ordering::AcqRel) {
                continue;
            }

            // A raised flag on an already-finished child is a stray wake-up; ignore it.
            if let JoinState::Pending(fut) = join_state {
                poll_count += 1;
                let waker: Waker = Arc::new(SmartWaker {
                    waker: cx.waker().clone(),
                    marker: Arc::clone(pollable),
                })
                .into();
                let mut child_cx = Context::from_waker(&waker);
                if let Poll::Ready(output) = Pin::new(fut).poll(&mut child_cx) {
                    total_complete += 1;
                    *join_state = JoinState::Complete(output);
                }
            }
        }

        this.total_complete = total_complete;
        if total_complete < futures_len {
            // Unfinished children are expected (per the `Future` contract) to wake
            // their `SmartWaker` when they can make progress, which re-polls us.
            return Poll::Pending;
        }

        Poll::Ready(
            this.futures
                .iter_mut()
                .map(|(_, join_state)| match join_state.take() {
                    JoinState::Complete(output) => output,
                    _ => unreachable!(
                        "All futures should be complete when total_complete matches inner length"
                    ),
                })
                .collect(),
        )
    }
}
#[cfg(test)]
mod tests {
    use super::{join_all, MAX_CONSECUTIVE_POLLS};
    use std::{
        future::{ready, Future},
        pin::Pin,
        sync::{Arc, Mutex},
        task::{Context, Poll, Wake},
    };

    /// Outer waker that only counts how many times it has been woken.
    struct Faker {
        wake_count: Mutex<usize>,
    }

    impl Wake for Faker {
        fn wake(self: Arc<Self>) {
            *self.wake_count.lock().unwrap() += 1;
        }

        fn wake_by_ref(self: &Arc<Self>) {
            *self.wake_count.lock().unwrap() += 1;
        }
    }

    /// Runs `f` with a `Context` backed by a fresh `Faker` and returns how many
    /// wake-ups `f` triggered on it.
    fn with_context(f: impl FnOnce(&mut Context<'_>)) -> usize {
        let ctx = Arc::new(Faker {
            wake_count: Mutex::new(0),
        });
        let waker = Arc::clone(&ctx).into();
        f(&mut Context::from_waker(&waker));
        // Bind before returning: as a tail expression, the `MutexGuard` temporary
        // would be dropped after `ctx` (pre-2024 editions), which borrowck rejects.
        let wake_count = *ctx.wake_count.lock().unwrap();
        wake_count
    }

    #[test]
    fn it_produces_a_vec_on_success() {
        let wake_count = with_context(move |ctx| {
            let futures = vec![ready(1), ready(2), ready(3), ready(4), ready(5)];
            let mut fut = join_all(futures);
            match Pin::new(&mut fut).poll(ctx) {
                Poll::Ready(vec) => assert_eq!(vec, vec![1, 2, 3, 4, 5]),
                Poll::Pending => panic!("Should have returned Ready"),
            }
        });
        assert_eq!(wake_count, 0, "Shouldn't have caused any wakes");
    }

    #[test]
    fn it_returns_pending_and_wakes_once_for_too_many_futures() {
        let wake_count = with_context(move |ctx| {
            let futures = (0..(MAX_CONSECUTIVE_POLLS + 1))
                .map(ready)
                .collect::<Vec<_>>();
            let mut fut = join_all(futures);
            if Pin::new(&mut fut).poll(ctx).is_ready() {
                panic!("Shouldn't have returned ready on first poll");
            }
            match Pin::new(&mut fut).poll(ctx) {
                Poll::Ready(vec) => {
                    assert_eq!(vec, (0..MAX_CONSECUTIVE_POLLS + 1).collect::<Vec<_>>())
                }
                Poll::Pending => panic!("Should have completed on second poll"),
            }
        });
        assert_eq!(wake_count, 1, "Should have woken once");
    }

    /// Never completes; counts its polls and wakes its waker every time it is polled.
    struct WakeOnPoll(usize);

    impl Future for WakeOnPoll {
        type Output = ();
        fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
            self.0 += 1;
            cx.waker().wake_by_ref();
            Poll::Pending
        }
    }

    #[test]
    fn waker_is_woken() {
        let wake_count = with_context(move |ctx| {
            let mut wake_on_poll = WakeOnPoll(0);
            let futures = vec![&mut wake_on_poll];
            let mut fut = join_all(futures);
            if Pin::new(&mut fut).poll(ctx).is_ready() {
                panic!("Future should be pending forever");
            }
            assert_eq!(wake_on_poll.0, 1, "Should have polled future once");
        });
        assert_eq!(wake_count, 1, "Should have woken on poll")
    }

    #[test]
    fn poll_woken_future() {
        let wake_count = with_context(move |ctx| {
            let mut wake_on_poll = WakeOnPoll(0);
            let futures = vec![&mut wake_on_poll];
            let mut fut = join_all(futures);
            if Pin::new(&mut fut).poll(ctx).is_ready() {
                panic!("Future should be pending forever");
            }
            if Pin::new(&mut fut).poll(ctx).is_ready() {
                panic!("Future should be pending forever");
            }
            assert_eq!(wake_on_poll.0, 2, "Should have polled future twice");
        });
        assert_eq!(wake_count, 2, "Should have woken on poll twice")
    }

    /// Never completes and never wakes; only counts its polls.
    struct DummyPoll(usize);

    impl Future for DummyPoll {
        type Output = ();
        fn poll(mut self: Pin<&mut Self>, _: &mut Context<'_>) -> Poll<Self::Output> {
            self.0 += 1;
            Poll::Pending
        }
    }

    #[test]
    fn dont_poll_unwoken_future() {
        let wake_count = with_context(move |ctx| {
            let mut dummy_poll = DummyPoll(0);
            let mut wake_on_poll = WakeOnPoll(0);
            let mut dummy_poll_2 = DummyPoll(0);
            let mut wake_on_poll_2 = WakeOnPoll(0);
            let futures = vec![
                &mut dummy_poll as &mut (dyn Future<Output = ()> + Unpin),
                &mut wake_on_poll,
                &mut dummy_poll_2,
                &mut wake_on_poll_2,
            ];
            let mut fut = join_all(futures);
            for _ in 0..5 {
                if Pin::new(&mut fut).poll(ctx).is_ready() {
                    panic!("Future should be pending forever");
                }
            }
            assert_eq!(dummy_poll.0, 1, "Should have polled dummy future only once");
            assert_eq!(
                wake_on_poll.0, 5,
                "Should have polled waking future five times"
            );
            assert_eq!(
                dummy_poll_2.0, 1,
                "Should have polled second dummy future only once"
            );
            assert_eq!(
                wake_on_poll_2.0, 5,
                "Should have polled second waking future five times"
            );
        });
        // Two waking futures, each polled (and therefore waking) five times.
        assert_eq!(
            wake_count, 10,
            "Should have woken ten times (five per waking future)"
        )
    }
}