It joins
This commit is contained in:
commit
67c7725dba
2
.gitignore
vendored
Normal file
2
.gitignore
vendored
Normal file
|
@ -0,0 +1,2 @@
|
|||
/target
|
||||
Cargo.lock
|
8
Cargo.toml
Normal file
8
Cargo.toml
Normal file
|
@ -0,0 +1,8 @@
|
|||
[package]
|
||||
name = "async-join"
|
||||
version = "0.1.0"
|
||||
edition = "2018"
|
||||
|
||||
# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html
|
||||
|
||||
[dependencies]
|
315
src/lib.rs
Normal file
315
src/lib.rs
Normal file
|
@ -0,0 +1,315 @@
|
|||
use std::{
|
||||
future::Future,
|
||||
pin::Pin,
|
||||
sync::{
|
||||
atomic::{AtomicBool, Ordering},
|
||||
Arc,
|
||||
},
|
||||
task::{Context, Poll, Wake, Waker},
|
||||
};
|
||||
|
||||
const MAX_CONSECUTIVE_POLLS: i32 = 16;
|
||||
|
||||
pub fn join_all<T>(futures: Vec<T>) -> JoinAll<T>
|
||||
where
|
||||
T: Future + Unpin,
|
||||
T::Output: Unpin,
|
||||
{
|
||||
JoinAll {
|
||||
futures: futures
|
||||
.into_iter()
|
||||
.map(|fut| (Arc::new(AtomicBool::new(true)), JoinState::Pending(fut)))
|
||||
.collect(),
|
||||
next_poll_index: 0,
|
||||
total_complete: 0,
|
||||
}
|
||||
}
|
||||
|
||||
pub struct JoinAll<T>
|
||||
where
|
||||
T: Future,
|
||||
{
|
||||
futures: Vec<(Arc<AtomicBool>, JoinState<T>)>,
|
||||
next_poll_index: usize,
|
||||
total_complete: usize,
|
||||
}
|
||||
|
||||
struct SmartWaker {
|
||||
waker: Waker,
|
||||
marker: Arc<AtomicBool>,
|
||||
}
|
||||
|
||||
impl Wake for SmartWaker {
|
||||
fn wake(self: Arc<Self>) {
|
||||
self.marker.store(true, Ordering::Release);
|
||||
self.waker.clone().wake();
|
||||
}
|
||||
|
||||
fn wake_by_ref(self: &Arc<Self>) {
|
||||
self.marker.store(true, Ordering::Release);
|
||||
self.waker.wake_by_ref();
|
||||
}
|
||||
}
|
||||
|
||||
enum JoinState<T>
|
||||
where
|
||||
T: Future,
|
||||
{
|
||||
Polling,
|
||||
Pending(T),
|
||||
Complete(<T as Future>::Output),
|
||||
}
|
||||
|
||||
impl<T> JoinState<T>
|
||||
where
|
||||
T: Future + Unpin,
|
||||
{
|
||||
fn take(&mut self) -> JoinState<T> {
|
||||
std::mem::replace(self, JoinState::Polling)
|
||||
}
|
||||
}
|
||||
|
||||
impl<T> Future for JoinAll<T>
|
||||
where
|
||||
T: Future + Unpin,
|
||||
<T as Future>::Output: Unpin,
|
||||
{
|
||||
type Output = Vec<<T as Future>::Output>;
|
||||
|
||||
fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
|
||||
let futures_len = self.futures.len();
|
||||
let current_index = self.next_poll_index % futures_len;
|
||||
|
||||
let mut poll_count = 0;
|
||||
let mut total_complete = self.total_complete;
|
||||
|
||||
let (beginning, ending) = self.futures.split_at_mut(current_index);
|
||||
|
||||
let futures_iter = ending.iter_mut().chain(beginning.iter_mut()).enumerate();
|
||||
|
||||
for (idx, (pollable, join_state)) in futures_iter {
|
||||
if poll_count < MAX_CONSECUTIVE_POLLS && pollable.load(Ordering::Acquire) {
|
||||
pollable.store(false, Ordering::Release);
|
||||
let new_state = match join_state.take() {
|
||||
JoinState::Pending(mut fut) => {
|
||||
poll_count += 1;
|
||||
|
||||
let waker = Arc::new(SmartWaker {
|
||||
waker: cx.waker().clone(),
|
||||
marker: Arc::clone(pollable),
|
||||
})
|
||||
.into();
|
||||
|
||||
let mut smart_context = Context::from_waker(&waker);
|
||||
|
||||
match Pin::new(&mut fut).poll(&mut smart_context) {
|
||||
Poll::Ready(output) => {
|
||||
total_complete += 1;
|
||||
JoinState::Complete(output)
|
||||
}
|
||||
Poll::Pending => JoinState::Pending(fut),
|
||||
}
|
||||
}
|
||||
otherwise => otherwise,
|
||||
};
|
||||
|
||||
*join_state = new_state;
|
||||
} else if poll_count >= MAX_CONSECUTIVE_POLLS {
|
||||
self.next_poll_index = (idx + current_index) % futures_len;
|
||||
self.total_complete = total_complete;
|
||||
cx.waker().wake_by_ref();
|
||||
return Poll::Pending;
|
||||
}
|
||||
}
|
||||
|
||||
self.total_complete = total_complete;
|
||||
|
||||
if self.total_complete == self.futures.len() {
|
||||
return Poll::Ready(self.futures.iter_mut().map(|(_, join_state)| {
|
||||
match join_state.take() {
|
||||
JoinState::Complete(output) => output,
|
||||
_ => unreachable!("All futures should be complete when total_complete matches inner length"),
|
||||
}
|
||||
}).collect());
|
||||
}
|
||||
|
||||
Poll::Pending
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::{join_all, MAX_CONSECUTIVE_POLLS};
|
||||
use std::{
|
||||
future::{ready, Future},
|
||||
pin::Pin,
|
||||
sync::{Arc, Mutex},
|
||||
task::{Context, Poll, Wake},
|
||||
};
|
||||
|
||||
struct Faker {
|
||||
wake_count: Mutex<usize>,
|
||||
}
|
||||
|
||||
impl Wake for Faker {
|
||||
fn wake(self: Arc<Self>) {
|
||||
*self.wake_count.lock().unwrap() += 1;
|
||||
}
|
||||
fn wake_by_ref(self: &Arc<Self>) {
|
||||
*self.wake_count.lock().unwrap() += 1;
|
||||
}
|
||||
}
|
||||
|
||||
fn with_context(f: impl FnOnce(&mut Context<'_>)) -> usize {
|
||||
let ctx = Arc::new(Faker {
|
||||
wake_count: Mutex::new(0),
|
||||
});
|
||||
let waker = Arc::clone(&ctx).into();
|
||||
|
||||
(f)(&mut Context::from_waker(&waker));
|
||||
|
||||
let x = *ctx.wake_count.lock().unwrap();
|
||||
x
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn it_produces_a_vec_on_success() {
|
||||
let wake_count = with_context(move |ctx| {
|
||||
let futures = vec![ready(1), ready(2), ready(3), ready(4), ready(5)];
|
||||
let mut fut = join_all(futures);
|
||||
|
||||
match Pin::new(&mut fut).poll(ctx) {
|
||||
Poll::Ready(vec) => assert_eq!(vec, vec![1, 2, 3, 4, 5]),
|
||||
Poll::Pending => panic!("Should have returned Ready"),
|
||||
}
|
||||
});
|
||||
|
||||
assert_eq!(wake_count, 0, "Shouldn't have caused any wakes");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn it_returns_pending_and_wakes_once_for_too_many_futures() {
|
||||
let wake_count = with_context(move |ctx| {
|
||||
let futures = (0..(MAX_CONSECUTIVE_POLLS + 1))
|
||||
.map(ready)
|
||||
.collect::<Vec<_>>();
|
||||
let mut fut = join_all(futures);
|
||||
|
||||
if let Poll::Ready(_) = Pin::new(&mut fut).poll(ctx) {
|
||||
panic!("Shouldn't have returned ready on first poll");
|
||||
}
|
||||
|
||||
match Pin::new(&mut fut).poll(ctx) {
|
||||
Poll::Ready(vec) => {
|
||||
assert_eq!(vec, (0..MAX_CONSECUTIVE_POLLS + 1).collect::<Vec<_>>())
|
||||
}
|
||||
Poll::Pending => panic!("Should have completed on second poll"),
|
||||
}
|
||||
});
|
||||
|
||||
assert_eq!(wake_count, 1, "Should have woken once");
|
||||
}
|
||||
|
||||
struct WakeOnPoll(usize);
|
||||
|
||||
impl Future for WakeOnPoll {
|
||||
type Output = ();
|
||||
|
||||
fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
|
||||
self.0 += 1;
|
||||
cx.waker().wake_by_ref();
|
||||
Poll::Pending
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn waker_is_woken() {
|
||||
let wake_count = with_context(move |ctx| {
|
||||
let mut wake_on_poll = WakeOnPoll(0);
|
||||
let futures = vec![&mut wake_on_poll];
|
||||
|
||||
let mut fut = join_all(futures);
|
||||
|
||||
if Pin::new(&mut fut).poll(ctx).is_ready() {
|
||||
panic!("Future should be pending forever");
|
||||
}
|
||||
|
||||
assert_eq!(wake_on_poll.0, 1, "Should have polled future once");
|
||||
});
|
||||
|
||||
assert_eq!(wake_count, 1, "Should have woken on poll")
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn poll_woken_future() {
|
||||
let wake_count = with_context(move |ctx| {
|
||||
let mut wake_on_poll = WakeOnPoll(0);
|
||||
let futures = vec![&mut wake_on_poll];
|
||||
|
||||
let mut fut = join_all(futures);
|
||||
|
||||
if Pin::new(&mut fut).poll(ctx).is_ready() {
|
||||
panic!("Future should be pending forever");
|
||||
}
|
||||
|
||||
if Pin::new(&mut fut).poll(ctx).is_ready() {
|
||||
panic!("Future should be pending forever");
|
||||
}
|
||||
|
||||
assert_eq!(wake_on_poll.0, 2, "Should have polled future twice");
|
||||
});
|
||||
|
||||
assert_eq!(wake_count, 2, "Should have woken on poll twice")
|
||||
}
|
||||
|
||||
struct DummyPoll(usize);
|
||||
|
||||
impl Future for DummyPoll {
|
||||
type Output = ();
|
||||
|
||||
fn poll(mut self: Pin<&mut Self>, _: &mut Context<'_>) -> Poll<Self::Output> {
|
||||
self.0 += 1;
|
||||
Poll::Pending
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn dont_poll_unwoken_future() {
|
||||
let wake_count = with_context(move |ctx| {
|
||||
let mut dummy_poll = DummyPoll(0);
|
||||
let mut wake_on_poll = WakeOnPoll(0);
|
||||
let mut dummy_poll_2 = DummyPoll(0);
|
||||
let mut wake_on_poll_2 = WakeOnPoll(0);
|
||||
let futures = vec![
|
||||
&mut dummy_poll as &mut (dyn Future<Output = ()> + Unpin),
|
||||
&mut wake_on_poll,
|
||||
&mut dummy_poll_2,
|
||||
&mut wake_on_poll_2,
|
||||
];
|
||||
|
||||
let mut fut = join_all(futures);
|
||||
|
||||
for _ in 0..5 {
|
||||
if Pin::new(&mut fut).poll(ctx).is_ready() {
|
||||
panic!("Future should be pending forever");
|
||||
}
|
||||
}
|
||||
|
||||
assert_eq!(dummy_poll.0, 1, "Should have polled dummy future only once");
|
||||
assert_eq!(
|
||||
wake_on_poll.0, 5,
|
||||
"Should have polled waking future five times"
|
||||
);
|
||||
assert_eq!(
|
||||
dummy_poll_2.0, 1,
|
||||
"Should have polled second dummy future only once"
|
||||
);
|
||||
assert_eq!(
|
||||
wake_on_poll_2.0, 5,
|
||||
"Should have polled second waking future five times"
|
||||
);
|
||||
});
|
||||
|
||||
assert_eq!(wake_count, 10, "Should have woken on poll five times")
|
||||
}
|
||||
}
|
Loading…
Reference in a new issue