From 67c7725dba6c44261e8ee581aef090276d88cbc2 Mon Sep 17 00:00:00 2001 From: "Aode (Lion)" Date: Thu, 9 Sep 2021 20:11:03 -0500 Subject: [PATCH] It joins --- .gitignore | 2 + Cargo.toml | 8 ++ src/lib.rs | 315 +++++++++++++++++++++++++++++++++++++++++++++++++++++ 3 files changed, 325 insertions(+) create mode 100644 .gitignore create mode 100644 Cargo.toml create mode 100644 src/lib.rs diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000..96ef6c0 --- /dev/null +++ b/.gitignore @@ -0,0 +1,2 @@ +/target +Cargo.lock diff --git a/Cargo.toml b/Cargo.toml new file mode 100644 index 0000000..526e6da --- /dev/null +++ b/Cargo.toml @@ -0,0 +1,8 @@ +[package] +name = "async-join" +version = "0.1.0" +edition = "2018" + +# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html + +[dependencies] diff --git a/src/lib.rs b/src/lib.rs new file mode 100644 index 0000000..48ee95b --- /dev/null +++ b/src/lib.rs @@ -0,0 +1,315 @@ +use std::{ + future::Future, + pin::Pin, + sync::{ + atomic::{AtomicBool, Ordering}, + Arc, + }, + task::{Context, Poll, Wake, Waker}, +}; + +const MAX_CONSECUTIVE_POLLS: i32 = 16; + +pub fn join_all(futures: Vec) -> JoinAll +where + T: Future + Unpin, + T::Output: Unpin, +{ + JoinAll { + futures: futures + .into_iter() + .map(|fut| (Arc::new(AtomicBool::new(true)), JoinState::Pending(fut))) + .collect(), + next_poll_index: 0, + total_complete: 0, + } +} + +pub struct JoinAll +where + T: Future, +{ + futures: Vec<(Arc, JoinState)>, + next_poll_index: usize, + total_complete: usize, +} + +struct SmartWaker { + waker: Waker, + marker: Arc, +} + +impl Wake for SmartWaker { + fn wake(self: Arc) { + self.marker.store(true, Ordering::Release); + self.waker.clone().wake(); + } + + fn wake_by_ref(self: &Arc) { + self.marker.store(true, Ordering::Release); + self.waker.wake_by_ref(); + } +} + +enum JoinState +where + T: Future, +{ + Polling, + Pending(T), + Complete(::Output), +} + +impl JoinState +where + T: Future + Unpin, +{ + fn take(&mut self) -> JoinState { + std::mem::replace(self, JoinState::Polling) + } +} + +impl Future for JoinAll +where + T: Future + Unpin, + ::Output: Unpin, +{ + type Output = Vec<::Output>; + + fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { + let futures_len = self.futures.len(); + let current_index = self.next_poll_index % futures_len; + + let mut poll_count = 0; + let mut total_complete = self.total_complete; + + let (beginning, ending) = self.futures.split_at_mut(current_index); + + let futures_iter = ending.iter_mut().chain(beginning.iter_mut()).enumerate(); + + for (idx, (pollable, join_state)) in futures_iter { + if poll_count < MAX_CONSECUTIVE_POLLS && pollable.load(Ordering::Acquire) { + pollable.store(false, Ordering::Release); + let new_state = match join_state.take() { + JoinState::Pending(mut fut) => { + poll_count += 1; + + let waker = Arc::new(SmartWaker { + waker: cx.waker().clone(), + marker: Arc::clone(pollable), + }) + .into(); + + let mut smart_context = Context::from_waker(&waker); + + match Pin::new(&mut fut).poll(&mut smart_context) { + Poll::Ready(output) => { + total_complete += 1; + JoinState::Complete(output) + } + Poll::Pending => JoinState::Pending(fut), + } + } + otherwise => otherwise, + }; + + *join_state = new_state; + } else if poll_count >= MAX_CONSECUTIVE_POLLS { + self.next_poll_index = (idx + current_index) % futures_len; + self.total_complete = total_complete; + cx.waker().wake_by_ref(); + return Poll::Pending; + } + } + + self.total_complete = total_complete; + + if self.total_complete == self.futures.len() { + return Poll::Ready(self.futures.iter_mut().map(|(_, join_state)| { + match join_state.take() { + JoinState::Complete(output) => output, + _ => unreachable!("All futures should be complete when total_complete matches inner length"), + } + }).collect()); + } + + Poll::Pending + } +} + +#[cfg(test)] +mod tests { + use super::{join_all, MAX_CONSECUTIVE_POLLS}; + use std::{ + future::{ready, Future}, + pin::Pin, + sync::{Arc, Mutex}, + task::{Context, Poll, Wake}, + }; + + struct Faker { + wake_count: Mutex, + } + + impl Wake for Faker { + fn wake(self: Arc) { + *self.wake_count.lock().unwrap() += 1; + } + fn wake_by_ref(self: &Arc) { + *self.wake_count.lock().unwrap() += 1; + } + } + + fn with_context(f: impl FnOnce(&mut Context<'_>)) -> usize { + let ctx = Arc::new(Faker { + wake_count: Mutex::new(0), + }); + let waker = Arc::clone(&ctx).into(); + + (f)(&mut Context::from_waker(&waker)); + + let x = *ctx.wake_count.lock().unwrap(); + x + } + + #[test] + fn it_produces_a_vec_on_success() { + let wake_count = with_context(move |ctx| { + let futures = vec![ready(1), ready(2), ready(3), ready(4), ready(5)]; + let mut fut = join_all(futures); + + match Pin::new(&mut fut).poll(ctx) { + Poll::Ready(vec) => assert_eq!(vec, vec![1, 2, 3, 4, 5]), + Poll::Pending => panic!("Should have returned Ready"), + } + }); + + assert_eq!(wake_count, 0, "Shouldn't have caused any wakes"); + } + + #[test] + fn it_returns_pending_and_wakes_once_for_too_many_futures() { + let wake_count = with_context(move |ctx| { + let futures = (0..(MAX_CONSECUTIVE_POLLS + 1)) + .map(ready) + .collect::>(); + let mut fut = join_all(futures); + + if let Poll::Ready(_) = Pin::new(&mut fut).poll(ctx) { + panic!("Shouldn't have returned ready on first poll"); + } + + match Pin::new(&mut fut).poll(ctx) { + Poll::Ready(vec) => { + assert_eq!(vec, (0..MAX_CONSECUTIVE_POLLS + 1).collect::>()) + } + Poll::Pending => panic!("Should have completed on second poll"), + } + }); + + assert_eq!(wake_count, 1, "Should have woken once"); + } + + struct WakeOnPoll(usize); + + impl Future for WakeOnPoll { + type Output = (); + + fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { + self.0 += 1; + cx.waker().wake_by_ref(); + Poll::Pending + } + } + + #[test] + fn waker_is_woken() { + let wake_count = with_context(move |ctx| { + let mut wake_on_poll = WakeOnPoll(0); + let futures = vec![&mut wake_on_poll]; + + let mut fut = join_all(futures); + + if Pin::new(&mut fut).poll(ctx).is_ready() { + panic!("Future should be pending forever"); + } + + assert_eq!(wake_on_poll.0, 1, "Should have polled future once"); + }); + + assert_eq!(wake_count, 1, "Should have woken on poll") + } + + #[test] + fn poll_woken_future() { + let wake_count = with_context(move |ctx| { + let mut wake_on_poll = WakeOnPoll(0); + let futures = vec![&mut wake_on_poll]; + + let mut fut = join_all(futures); + + if Pin::new(&mut fut).poll(ctx).is_ready() { + panic!("Future should be pending forever"); + } + + if Pin::new(&mut fut).poll(ctx).is_ready() { + panic!("Future should be pending forever"); + } + + assert_eq!(wake_on_poll.0, 2, "Should have polled future twice"); + }); + + assert_eq!(wake_count, 2, "Should have woken on poll twice") + } + + struct DummyPoll(usize); + + impl Future for DummyPoll { + type Output = (); + + fn poll(mut self: Pin<&mut Self>, _: &mut Context<'_>) -> Poll { + self.0 += 1; + Poll::Pending + } + } + + #[test] + fn dont_poll_unwoken_future() { + let wake_count = with_context(move |ctx| { + let mut dummy_poll = DummyPoll(0); + let mut wake_on_poll = WakeOnPoll(0); + let mut dummy_poll_2 = DummyPoll(0); + let mut wake_on_poll_2 = WakeOnPoll(0); + let futures = vec![ + &mut dummy_poll as &mut (dyn Future + Unpin), + &mut wake_on_poll, + &mut dummy_poll_2, + &mut wake_on_poll_2, + ]; + + let mut fut = join_all(futures); + + for _ in 0..5 { + if Pin::new(&mut fut).poll(ctx).is_ready() { + panic!("Future should be pending forever"); + } + } + + assert_eq!(dummy_poll.0, 1, "Should have polled dummy future only once"); + assert_eq!( + wake_on_poll.0, 5, + "Should have polled waking future five times" + ); + assert_eq!( + dummy_poll_2.0, 1, + "Should have polled second dummy future only once" + ); + assert_eq!( + wake_on_poll_2.0, 5, + "Should have polled second waking future five times" + ); + }); + + assert_eq!(wake_count, 10, "Should have woken on poll five times") + } +}