315 lines
7.4 KiB
Rust
315 lines
7.4 KiB
Rust
use std::{
|
|
collections::{BTreeMap, VecDeque},
|
|
sync::{Arc, Mutex, Weak},
|
|
};
|
|
|
|
use jive::task::sync::{AbortHandle, JoinError, JoinHandle};
|
|
|
|
pub struct JoinSet<V> {
|
|
index: u64,
|
|
map: JoinMap<u64, V>,
|
|
}
|
|
|
|
pub struct JoinMap<K, V> {
|
|
woken: Arc<Mutex<VecDeque<Weak<K>>>>,
|
|
handles: BTreeMap<Arc<K>, JoinHandle<V>>,
|
|
unpolled: BTreeMap<Arc<K>, JoinHandle<V>>,
|
|
abort_on_drop: bool,
|
|
}
|
|
|
|
impl<V> JoinSet<V> {
|
|
pub fn new() -> Self {
|
|
JoinSet {
|
|
index: 0,
|
|
map: JoinMap::new(),
|
|
}
|
|
}
|
|
|
|
pub fn detach_on_drop(mut self) -> Self {
|
|
self.map = self.map.detach_on_drop();
|
|
self
|
|
}
|
|
|
|
pub fn spawn<Fut>(&mut self, future: Fut) -> AbortHandle
|
|
where
|
|
V: Send + 'static,
|
|
Fut: std::future::Future<Output = V> + Send + 'static,
|
|
{
|
|
let index = self.index;
|
|
self.index += 1;
|
|
|
|
let (abort_handle, _) = self.map.insert(index, future);
|
|
|
|
abort_handle
|
|
}
|
|
|
|
pub fn spawn_local<Fut>(&mut self, future: Fut) -> AbortHandle
|
|
where
|
|
V: Send + 'static,
|
|
Fut: std::future::Future<Output = V> + Send + 'static,
|
|
{
|
|
let index = self.index;
|
|
self.index += 1;
|
|
|
|
let (abort_handle, _) = self.map.insert_local(index, future);
|
|
|
|
abort_handle
|
|
}
|
|
|
|
pub fn spawn_blocking<F>(&mut self, func: F)
|
|
where
|
|
V: Send + 'static,
|
|
F: Fn() -> V + Send + 'static,
|
|
{
|
|
let index = self.index;
|
|
self.index += 1;
|
|
|
|
self.map.insert_blocking(index, func);
|
|
}
|
|
|
|
// Cancel Safety
|
|
//
|
|
// This future is strictly cancel-safe.
|
|
pub async fn join_next(&mut self) -> Option<Result<V, JoinError>> {
|
|
self.map.join_next().await.map(|(_, res)| res)
|
|
}
|
|
|
|
pub fn len(&self) -> usize {
|
|
self.map.len()
|
|
}
|
|
|
|
pub fn is_empty(&self) -> bool {
|
|
self.map.is_empty()
|
|
}
|
|
}
|
|
|
|
impl<K, V> JoinMap<K, V> {
|
|
pub fn new() -> Self {
|
|
JoinMap {
|
|
woken: Arc::new(Mutex::new(VecDeque::new())),
|
|
handles: BTreeMap::new(),
|
|
unpolled: BTreeMap::new(),
|
|
abort_on_drop: true,
|
|
}
|
|
}
|
|
|
|
pub fn detach_on_drop(mut self) -> Self {
|
|
self.abort_on_drop = false;
|
|
self
|
|
}
|
|
|
|
pub fn insert<Fut>(&mut self, key: K, future: Fut) -> (AbortHandle, Option<JoinHandle<V>>)
|
|
where
|
|
K: Ord + Send + Sync + 'static,
|
|
V: Send + 'static,
|
|
Fut: std::future::Future<Output = V> + Send + 'static,
|
|
{
|
|
let key = Arc::new(key);
|
|
|
|
let out = self.remove(&key);
|
|
|
|
let handle = jive::spawn(future);
|
|
|
|
let abort_handle = handle.abort_handle();
|
|
|
|
self.unpolled.insert(key, handle);
|
|
|
|
(abort_handle, out)
|
|
}
|
|
|
|
pub fn insert_local<Fut>(&mut self, key: K, future: Fut) -> (AbortHandle, Option<JoinHandle<V>>)
|
|
where
|
|
K: Ord + Send + Sync + 'static,
|
|
V: Send + 'static,
|
|
Fut: std::future::Future<Output = V> + 'static,
|
|
{
|
|
let key = Arc::new(key);
|
|
|
|
let out = self.remove(&key);
|
|
|
|
let handle = jive::spawn_local(future);
|
|
|
|
let abort_handle = handle.abort_handle();
|
|
|
|
self.unpolled.insert(key, handle);
|
|
|
|
(abort_handle, out)
|
|
}
|
|
|
|
pub fn insert_blocking<F>(&mut self, key: K, func: F) -> Option<JoinHandle<V>>
|
|
where
|
|
K: Ord + Send + Sync + 'static,
|
|
V: Send + 'static,
|
|
F: Fn() -> V + Send + 'static,
|
|
{
|
|
let key = Arc::new(key);
|
|
|
|
let out = self.remove(&key);
|
|
|
|
let handle = jive::spawn_blocking(func);
|
|
|
|
self.unpolled.insert(key, handle);
|
|
|
|
out
|
|
}
|
|
|
|
pub fn remove<Q>(&mut self, key: &Q) -> Option<JoinHandle<V>>
|
|
where
|
|
Arc<K>: std::borrow::Borrow<Q> + Ord,
|
|
Q: Ord + ?Sized,
|
|
{
|
|
if let Some(handle) = self.unpolled.remove(key) {
|
|
return Some(handle);
|
|
}
|
|
|
|
if let Some(handle) = self.handles.remove(key) {
|
|
return Some(handle);
|
|
}
|
|
|
|
None
|
|
}
|
|
|
|
// Cancel Safety
|
|
//
|
|
// This future is strictly cancel-safe.
|
|
pub async fn join_next(&mut self) -> Option<(K, Result<V, JoinError>)>
|
|
where
|
|
K: Ord + Send + Sync + 'static,
|
|
{
|
|
JoinNext { map: self }.await
|
|
}
|
|
|
|
pub fn len(&self) -> usize {
|
|
self.handles.len() + self.unpolled.len()
|
|
}
|
|
|
|
pub fn is_empty(&self) -> bool {
|
|
self.handles.is_empty() && self.unpolled.is_empty()
|
|
}
|
|
}
|
|
|
|
struct JoinNext<'a, K, V> {
|
|
map: &'a mut JoinMap<K, V>,
|
|
}
|
|
|
|
/// Per-task waker: on wake it records its key in the shared `woken` queue and
/// then wakes the task that is awaiting `join_next`.
struct JoinMapWaker<K> {
    /// Key of the task this waker belongs to; `Weak` so it does not keep a
    /// removed key alive.
    key: Weak<K>,
    /// Queue shared with the owning `JoinMap`.
    woken: Arc<Mutex<VecDeque<Weak<K>>>>,
    /// Waker of the task polling `join_next` when this waker was created.
    inner: std::task::Waker,
}
|
|
|
|
impl<K> std::task::Wake for JoinMapWaker<K> {
|
|
fn wake(self: Arc<Self>) {
|
|
self.wake_by_ref()
|
|
}
|
|
|
|
fn wake_by_ref(self: &Arc<Self>) {
|
|
self.woken.lock().unwrap().push_back(self.key.clone());
|
|
self.inner.wake_by_ref();
|
|
}
|
|
}
|
|
|
|
impl<'a, K, V> JoinNext<'a, K, V> {
|
|
fn next_woken(&mut self) -> Option<Arc<K>> {
|
|
let mut guard = self.map.woken.lock().unwrap();
|
|
|
|
while let Some(key) = guard.pop_front() {
|
|
if let Some(key) = key.upgrade() {
|
|
return Some(key);
|
|
}
|
|
}
|
|
|
|
None
|
|
}
|
|
|
|
fn handle(&mut self, key: &Arc<K>) -> Option<&mut JoinHandle<V>>
|
|
where
|
|
K: Ord,
|
|
{
|
|
self.map.handles.get_mut(key)
|
|
}
|
|
|
|
fn remove(&mut self, key: &Arc<K>)
|
|
where
|
|
K: Ord,
|
|
{
|
|
self.map.handles.remove(key);
|
|
}
|
|
}
|
|
|
|
impl<'a, K, V> std::future::Future for JoinNext<'a, K, V>
|
|
where
|
|
K: Ord + Send + Sync + 'static,
|
|
{
|
|
type Output = Option<(K, Result<V, JoinError>)>;
|
|
|
|
fn poll(
|
|
self: std::pin::Pin<&mut Self>,
|
|
cx: &mut std::task::Context<'_>,
|
|
) -> std::task::Poll<Self::Output> {
|
|
let this = self.get_mut();
|
|
|
|
if this.map.handles.is_empty() && this.map.unpolled.is_empty() {
|
|
return std::task::Poll::Ready(None);
|
|
}
|
|
|
|
while let Some(key) = this.next_woken() {
|
|
if let Some(handle) = this.handle(&key) {
|
|
if let std::task::Poll::Ready(value) = std::pin::Pin::new(handle).poll(cx) {
|
|
this.remove(&key);
|
|
|
|
let key = Arc::into_inner(key).expect("No other holders");
|
|
|
|
return std::task::Poll::Ready(Some((key, value)));
|
|
}
|
|
}
|
|
}
|
|
|
|
while let Some((key, mut handle)) = this.map.unpolled.pop_first() {
|
|
let waker = Arc::new(JoinMapWaker {
|
|
key: Arc::downgrade(&key),
|
|
woken: this.map.woken.clone(),
|
|
inner: cx.waker().clone(),
|
|
})
|
|
.into();
|
|
|
|
let mut cx = std::task::Context::from_waker(&waker);
|
|
|
|
if let std::task::Poll::Ready(value) = std::pin::Pin::new(&mut handle).poll(&mut cx) {
|
|
let key = Arc::into_inner(key).expect("No other holders");
|
|
|
|
return std::task::Poll::Ready(Some((key, value)));
|
|
}
|
|
|
|
this.map.handles.insert(key, handle);
|
|
}
|
|
|
|
std::task::Poll::Pending
|
|
}
|
|
}
|
|
|
|
impl<K, V> Default for JoinMap<K, V> {
|
|
fn default() -> Self {
|
|
Self::new()
|
|
}
|
|
}
|
|
|
|
impl<V> Default for JoinSet<V> {
|
|
fn default() -> Self {
|
|
Self::new()
|
|
}
|
|
}
|
|
|
|
impl<K, V> Drop for JoinMap<K, V> {
|
|
fn drop(&mut self) {
|
|
if self.abort_on_drop {
|
|
for handle in self.unpolled.values() {
|
|
handle.abort();
|
|
}
|
|
for handle in self.handles.values() {
|
|
handle.abort();
|
|
}
|
|
}
|
|
}
|
|
}
|