use std::{
    future::Future,
    pin::Pin,
    sync::{
        atomic::{AtomicBool, Ordering},
        Arc,
    },
    task::{Context, Poll, Wake, Waker},
};

/// Maximum number of inner futures polled during a single call to `JoinAll::poll`.
///
/// Once this many inner futures have been polled and at least one more still needs a poll,
/// `JoinAll` wakes its own task and returns `Poll::Pending`, which limits how much work a single
/// poll of a large join can do before control goes back to the executor.
const MAX_CONSECUTIVE_POLLS: usize = 16;

/// Creates a future that drives every future in `futures` concurrently and resolves to a `Vec`
/// of their outputs, in the same order as the input.
///
/// Each inner future is re-polled only after it has signalled the waker it was given. At most
/// `MAX_CONSECUTIVE_POLLS` inner futures are polled per call to `poll`. When more work remains,
/// the combined future wakes itself and yields, and the next poll resumes where this one
/// stopped.
///
/// An empty input resolves to an empty `Vec` on the first poll.
pub fn join_all<T>(futures: Vec<T>) -> JoinAll<T>
where
    T: Future + Unpin,
    T::Output: Unpin,
{
    // Every future starts out flagged so that the first poll visits all of them.
    let entries = futures
        .into_iter()
        .map(|fut| (Arc::new(AtomicBool::new(true)), JoinState::Pending(fut)))
        .collect();
    JoinAll {
        futures: entries,
        next_poll_index: 0,
        total_complete: 0,
    }
}

/// Future returned by [`join_all`].
#[must_use = "futures do nothing unless you `.await` or poll them"]
pub struct JoinAll<T>
where
    T: Future,
{
    /// One entry per input future: a flag that the future's waker sets when it wants another
    /// poll, and the future's current state.
    futures: Vec<(Arc<AtomicBool>, JoinState<T>)>,
    /// Index at which the next `poll` starts its round-robin pass. Always `< futures.len()`,
    /// or `0` when there are no futures.
    next_poll_index: usize,
    /// Number of inner futures that have produced their output.
    total_complete: usize,
}

/// Waker handed to each inner future.
///
/// Waking it marks that inner future as needing a poll and then wakes the task that owns the
/// `JoinAll`.
struct SmartWaker {
    waker: Waker,
    marker: Arc<AtomicBool>,
}

impl Wake for SmartWaker {
    fn wake(self: Arc<Self>) {
        // Waking by reference is equivalent and avoids cloning the parent waker.
        Self::wake_by_ref(&self);
    }

    fn wake_by_ref(self: &Arc<Self>) {
        // Publish the flag before waking the parent so the parent's next poll observes it.
        self.marker.store(true, Ordering::Release);
        self.waker.wake_by_ref();
    }
}

/// Lifecycle of a single inner future.
enum JoinState<T>
where
    T: Future,
{
    /// Placeholder left behind after the state has been moved out with `take`.
    Polling,
    /// The future has not produced its output yet.
    Pending(T),
    /// The future finished; its output is held until the whole join completes.
    Complete(T::Output),
}

impl<T> JoinState<T>
where
    T: Future,
{
    /// Moves the current state out, leaving `JoinState::Polling` in its place.
    fn take(&mut self) -> JoinState<T> {
        std::mem::replace(self, JoinState::Polling)
    }
}

impl<T> Future for JoinAll<T>
where
    T: Future + Unpin,
    <T as Future>::Output: Unpin,
{
    type Output = Vec<<T as Future>::Output>;

    /// Polls, round-robin, the inner futures that are still pending and have been woken.
    ///
    /// Returns `Poll::Ready` with all outputs (in input order) once every inner future has
    /// completed.
    ///
    /// # Panics
    ///
    /// Panics if polled again after it has returned `Poll::Ready`.
    fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
        // `JoinAll<T>` is `Unpin` under these bounds, so a plain `&mut` is fine. It also lets
        // the loop borrow `futures` while updating the counters (disjoint fields).
        let this = self.get_mut();
        let len = this.futures.len();
        // `next_poll_index` is kept `< len`, or `0` when `len == 0`. The split below is
        // therefore valid even for an empty input, which falls straight through to `Ready`.
        let start = this.next_poll_index;
        let mut polls = 0;

        // Visit every entry once, starting where the previous yield left off.
        let (before, after) = this.futures.split_at_mut(start);
        let rotation = after.iter_mut().chain(before.iter_mut()).enumerate();
        for (offset, (pollable, state)) in rotation {
            let fut = match state {
                JoinState::Pending(fut) => fut,
                // Already finished: nothing to poll, so it must not cause a yield either.
                _ => continue,
            };
            // Not woken since its last poll, so polling it again would be wasted work.
            if !pollable.load(Ordering::Acquire) {
                continue;
            }
            if polls >= MAX_CONSECUTIVE_POLLS {
                // Budget spent while real work remains: record where to resume, schedule
                // another poll of this task, and yield to the executor.
                this.next_poll_index = (start + offset) % len;
                cx.waker().wake_by_ref();
                return Poll::Pending;
            }

            // Clear the flag *before* polling so that a wake-up arriving during the poll is kept.
            pollable.store(false, Ordering::Release);
            polls += 1;

            let waker = Waker::from(Arc::new(SmartWaker {
                waker: cx.waker().clone(),
                marker: Arc::clone(pollable),
            }));
            let mut child_cx = Context::from_waker(&waker);
            let result = Pin::new(fut).poll(&mut child_cx);
            if let Poll::Ready(output) = result {
                *state = JoinState::Complete(output);
                this.total_complete += 1;
            }
        }

        if this.total_complete < len {
            // The unfinished futures were given a `SmartWaker`; they wake this task through it
            // when they want to be polled again.
            return Poll::Pending;
        }

        Poll::Ready(
            this.futures
                .iter_mut()
                .map(|(_, state)| match state.take() {
                    JoinState::Complete(output) => output,
                    _ => unreachable!(
                        "All futures should be complete when total_complete matches inner length"
                    ),
                })
                .collect(),
        )
    }
}

#[cfg(test)]
mod tests {
    use super::{join_all, MAX_CONSECUTIVE_POLLS};
    use std::{
        future::{ready, Future},
        pin::Pin,
        sync::{Arc, Mutex},
        task::{Context, Poll, Wake},
    };

    /// Parent waker that only counts how often it was woken.
    struct Faker {
        wake_count: Mutex<usize>,
    }

    impl Wake for Faker {
        fn wake(self: Arc<Self>) {
            *self.wake_count.lock().unwrap() += 1;
        }

        fn wake_by_ref(self: &Arc<Self>) {
            *self.wake_count.lock().unwrap() += 1;
        }
    }

    /// Runs `f` with a `Context` backed by a `Faker` and returns the number of wake-ups.
    fn with_context(f: impl FnOnce(&mut Context<'_>)) -> usize {
        let ctx = Arc::new(Faker {
            wake_count: Mutex::new(0),
        });
        let waker = Arc::clone(&ctx).into();
        (f)(&mut Context::from_waker(&waker));
        // Bind first so the mutex guard is dropped before `ctx`.
        let x = *ctx.wake_count.lock().unwrap();
        x
    }

    #[test]
    fn it_produces_a_vec_on_success() {
        let wake_count = with_context(move |ctx| {
            let futures = vec![ready(1), ready(2), ready(3), ready(4), ready(5)];
            let mut fut = join_all(futures);
            match Pin::new(&mut fut).poll(ctx) {
                Poll::Ready(vec) => assert_eq!(vec, vec![1, 2, 3, 4, 5]),
                Poll::Pending => panic!("Should have returned Ready"),
            }
        });
        assert_eq!(wake_count, 0, "Shouldn't have caused any wakes");
    }

    #[test]
    fn it_resolves_an_empty_input_immediately() {
        let wake_count = with_context(move |ctx| {
            let mut fut = join_all(Vec::<std::future::Ready<i32>>::new());
            match Pin::new(&mut fut).poll(ctx) {
                Poll::Ready(vec) => assert!(vec.is_empty()),
                Poll::Pending => panic!("Should have returned Ready"),
            }
        });
        assert_eq!(wake_count, 0, "Shouldn't have caused any wakes");
    }

    #[test]
    fn it_returns_pending_and_wakes_once_for_too_many_futures() {
        let wake_count = with_context(move |ctx| {
            let futures = (0..(MAX_CONSECUTIVE_POLLS + 1))
                .map(ready)
                .collect::<Vec<_>>();
            let mut fut = join_all(futures);
            if let Poll::Ready(_) = Pin::new(&mut fut).poll(ctx) {
                panic!("Shouldn't have returned ready on first poll");
            }
            match Pin::new(&mut fut).poll(ctx) {
                Poll::Ready(vec) => {
                    assert_eq!(vec, (0..MAX_CONSECUTIVE_POLLS + 1).collect::<Vec<_>>())
                }
                Poll::Pending => panic!("Should have completed on second poll"),
            }
        });
        assert_eq!(wake_count, 1, "Should have woken once");
    }

    #[test]
    fn it_does_not_yield_again_once_everything_is_complete() {
        // With an exact multiple of the budget, the second pass spends the whole budget and
        // only finished entries remain, so it must resolve instead of yielding again.
        let count = MAX_CONSECUTIVE_POLLS * 2;
        let wake_count = with_context(move |ctx| {
            let futures = (0..count).map(ready).collect::<Vec<_>>();
            let mut fut = join_all(futures);
            if Pin::new(&mut fut).poll(ctx).is_ready() {
                panic!("Shouldn't have returned ready on first poll");
            }
            match Pin::new(&mut fut).poll(ctx) {
                Poll::Ready(vec) => assert_eq!(vec, (0..count).collect::<Vec<_>>()),
                Poll::Pending => panic!("Should have completed on second poll"),
            }
        });
        assert_eq!(wake_count, 1, "Should have woken once");
    }

    /// Future that wakes its waker on every poll and never completes.
    struct WakeOnPoll(usize);

    impl Future for WakeOnPoll {
        type Output = ();

        fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
            self.0 += 1;
            cx.waker().wake_by_ref();
            Poll::Pending
        }
    }

    #[test]
    fn waker_is_woken() {
        let wake_count = with_context(move |ctx| {
            let mut wake_on_poll = WakeOnPoll(0);
            let futures = vec![&mut wake_on_poll];
            let mut fut = join_all(futures);
            if Pin::new(&mut fut).poll(ctx).is_ready() {
                panic!("Future should be pending forever");
            }
            assert_eq!(wake_on_poll.0, 1, "Should have polled future once");
        });
        assert_eq!(wake_count, 1, "Should have woken on poll")
    }

    #[test]
    fn poll_woken_future() {
        let wake_count = with_context(move |ctx| {
            let mut wake_on_poll = WakeOnPoll(0);
            let futures = vec![&mut wake_on_poll];
            let mut fut = join_all(futures);
            if Pin::new(&mut fut).poll(ctx).is_ready() {
                panic!("Future should be pending forever");
            }
            if Pin::new(&mut fut).poll(ctx).is_ready() {
                panic!("Future should be pending forever");
            }
            assert_eq!(wake_on_poll.0, 2, "Should have polled future twice");
        });
        assert_eq!(wake_count, 2, "Should have woken on poll twice")
    }

    /// Future that never wakes its waker and never completes.
    struct DummyPoll(usize);

    impl Future for DummyPoll {
        type Output = ();

        fn poll(mut self: Pin<&mut Self>, _: &mut Context<'_>) -> Poll<Self::Output> {
            self.0 += 1;
            Poll::Pending
        }
    }

    #[test]
    fn dont_poll_unwoken_future() {
        let wake_count = with_context(move |ctx| {
            let mut dummy_poll = DummyPoll(0);
            let mut wake_on_poll = WakeOnPoll(0);
            let mut dummy_poll_2 = DummyPoll(0);
            let mut wake_on_poll_2 = WakeOnPoll(0);
            let futures = vec![
                &mut dummy_poll as &mut (dyn Future<Output = ()> + Unpin),
                &mut wake_on_poll,
                &mut dummy_poll_2,
                &mut wake_on_poll_2,
            ];
            let mut fut = join_all(futures);
            for _ in 0..5 {
                if Pin::new(&mut fut).poll(ctx).is_ready() {
                    panic!("Future should be pending forever");
                }
            }
            assert_eq!(dummy_poll.0, 1, "Should have polled dummy future only once");
            assert_eq!(
                wake_on_poll.0, 5,
                "Should have polled waking future five times"
            );
            assert_eq!(
                dummy_poll_2.0, 1,
                "Should have polled second dummy future only once"
            );
            assert_eq!(
                wake_on_poll_2.0, 5,
                "Should have polled second waking future five times"
            );
        });
        assert_eq!(wake_count, 10, "Should have woken on poll five times")
    }
}