jive-joinmap/src/lib.rs

315 lines
7.4 KiB
Rust

use std::{
collections::{BTreeMap, VecDeque},
sync::{Arc, Mutex, Weak},
};
use jive::task::sync::{AbortHandle, JoinError, JoinHandle};
/// A collection of spawned tasks that are identified only by an internally
/// generated counter; a thin wrapper around [`JoinMap`] keyed by `u64`.
pub struct JoinSet<V> {
// Key handed to the next spawned task; incremented on every spawn.
index: u64,
// Underlying keyed collection that owns the task handles.
map: JoinMap<u64, V>,
}
/// A collection of spawned tasks, each tracked under a caller-chosen key, whose
/// results are collected one at a time with [`JoinMap::join_next`].
///
/// Keys are stored behind `Arc` so that per-task wakers can refer to them
/// through `Weak` references without owning them.
pub struct JoinMap<K, V> {
// Keys whose per-task waker fired since the last poll. Pushed by
// `JoinMapWaker`, drained by `JoinNext`; `Weak` so that entries for removed
// tasks are simply skipped.
woken: Arc<Mutex<VecDeque<Weak<K>>>>,
// Handles that have been polled at least once (a per-key waker is registered).
handles: BTreeMap<Arc<K>, JoinHandle<V>>,
// Handles that `join_next` has not polled yet.
unpolled: BTreeMap<Arc<K>, JoinHandle<V>>,
// Whether `Drop` aborts every tracked task; `true` unless `detach_on_drop` was called.
abort_on_drop: bool,
}
impl<V> JoinSet<V> {
    /// Creates an empty set.
    ///
    /// Tasks still tracked when the set is dropped are aborted unless
    /// [`JoinSet::detach_on_drop`] was called.
    pub fn new() -> Self {
        JoinSet {
            index: 0,
            map: JoinMap::new(),
        }
    }

    /// Makes the set leave its tasks running, instead of aborting them, when
    /// it is dropped.
    pub fn detach_on_drop(mut self) -> Self {
        self.map = self.map.detach_on_drop();
        self
    }

    /// Returns the key for the next spawned task and advances the counter.
    fn next_key(&mut self) -> u64 {
        let key = self.index;
        self.index += 1;
        key
    }

    /// Spawns `future` via `jive::spawn` and adds it to the set.
    ///
    /// Returns an [`AbortHandle`] for the new task.
    pub fn spawn<Fut>(&mut self, future: Fut) -> AbortHandle
    where
        V: Send + 'static,
        Fut: std::future::Future<Output = V> + Send + 'static,
    {
        let key = self.next_key();
        // Keys come from a monotonically increasing counter, so nothing is displaced.
        let (abort_handle, _) = self.map.insert(key, future);
        abort_handle
    }

    /// Spawns `future` via `jive::spawn_local` and adds it to the set.
    ///
    /// Unlike [`JoinSet::spawn`], the future itself does not have to be `Send`
    /// (matching [`JoinMap::insert_local`]); only its output does.
    ///
    /// Returns an [`AbortHandle`] for the new task.
    pub fn spawn_local<Fut>(&mut self, future: Fut) -> AbortHandle
    where
        V: Send + 'static,
        Fut: std::future::Future<Output = V> + 'static,
    {
        let key = self.next_key();
        let (abort_handle, _) = self.map.insert_local(key, future);
        abort_handle
    }

    /// Runs `func` via `jive::spawn_blocking` and adds it to the set.
    ///
    /// No abort handle is returned for blocking tasks.
    pub fn spawn_blocking<F>(&mut self, func: F)
    where
        V: Send + 'static,
        F: Fn() -> V + Send + 'static,
    {
        let key = self.next_key();
        self.map.insert_blocking(key, func);
    }

    /// Waits for the next task in the set to finish and returns its result,
    /// or `None` if the set holds no tasks.
    ///
    /// # Cancel Safety
    ///
    /// This future is strictly cancel-safe.
    pub async fn join_next(&mut self) -> Option<Result<V, JoinError>> {
        self.map.join_next().await.map(|(_, res)| res)
    }

    /// Number of tasks currently tracked by the set.
    pub fn len(&self) -> usize {
        self.map.len()
    }

    /// Returns `true` if the set tracks no tasks.
    pub fn is_empty(&self) -> bool {
        self.map.is_empty()
    }
}
impl<K, V> JoinMap<K, V> {
    /// Creates an empty map whose tasks are aborted when the map is dropped.
    pub fn new() -> Self {
        Self {
            woken: Arc::new(Mutex::new(VecDeque::new())),
            handles: BTreeMap::new(),
            unpolled: BTreeMap::new(),
            abort_on_drop: true,
        }
    }

    /// Returns the map configured to leave its tasks running, rather than
    /// aborting them, when it is dropped.
    pub fn detach_on_drop(mut self) -> Self {
        self.abort_on_drop = false;
        self
    }

    /// Spawns `future` via `jive::spawn` and tracks it under `key`.
    ///
    /// A handle already stored under an equal key is removed from the map
    /// first (it is not aborted) and returned together with the new task's
    /// [`AbortHandle`].
    pub fn insert<Fut>(&mut self, key: K, future: Fut) -> (AbortHandle, Option<JoinHandle<V>>)
    where
        K: Ord + Send + Sync + 'static,
        V: Send + 'static,
        Fut: std::future::Future<Output = V> + Send + 'static,
    {
        let shared_key = Arc::new(key);
        let displaced = self.remove(&shared_key);
        let spawned = jive::spawn(future);
        let abort = spawned.abort_handle();
        self.unpolled.insert(shared_key, spawned);
        (abort, displaced)
    }

    /// Spawns `future` via `jive::spawn_local` and tracks it under `key`.
    ///
    /// The future does not need to be `Send`. A handle already stored under an
    /// equal key is removed (not aborted) and returned with the new task's
    /// [`AbortHandle`].
    pub fn insert_local<Fut>(&mut self, key: K, future: Fut) -> (AbortHandle, Option<JoinHandle<V>>)
    where
        K: Ord + Send + Sync + 'static,
        V: Send + 'static,
        Fut: std::future::Future<Output = V> + 'static,
    {
        let shared_key = Arc::new(key);
        let displaced = self.remove(&shared_key);
        let spawned = jive::spawn_local(future);
        let abort = spawned.abort_handle();
        self.unpolled.insert(shared_key, spawned);
        (abort, displaced)
    }

    /// Runs `func` via `jive::spawn_blocking` and tracks it under `key`.
    ///
    /// Returns the handle previously stored under an equal key, if any; that
    /// handle is removed from the map but not aborted.
    pub fn insert_blocking<F>(&mut self, key: K, func: F) -> Option<JoinHandle<V>>
    where
        K: Ord + Send + Sync + 'static,
        V: Send + 'static,
        F: Fn() -> V + Send + 'static,
    {
        let shared_key = Arc::new(key);
        let displaced = self.remove(&shared_key);
        let spawned = jive::spawn_blocking(func);
        self.unpolled.insert(shared_key, spawned);
        displaced
    }

    /// Stops tracking the task stored under `key` and hands back its handle.
    ///
    /// Looks in the not-yet-polled handles first, then in the polled ones.
    /// Returns `None` if no task is stored under `key`.
    pub fn remove<Q>(&mut self, key: &Q) -> Option<JoinHandle<V>>
    where
        Arc<K>: std::borrow::Borrow<Q> + Ord,
        Q: Ord + ?Sized,
    {
        self.unpolled
            .remove(key)
            .or_else(|| self.handles.remove(key))
    }

    /// Waits for the next tracked task to finish and returns its key and
    /// result, or `None` if the map tracks no tasks.
    ///
    /// # Cancel Safety
    ///
    /// This future is strictly cancel-safe.
    pub async fn join_next(&mut self) -> Option<(K, Result<V, JoinError>)>
    where
        K: Ord + Send + Sync + 'static,
    {
        let next = JoinNext { map: self };
        next.await
    }

    /// Number of tasks currently tracked, polled or not.
    pub fn len(&self) -> usize {
        self.unpolled.len() + self.handles.len()
    }

    /// Returns `true` if the map tracks no tasks.
    pub fn is_empty(&self) -> bool {
        self.unpolled.is_empty() && self.handles.is_empty()
    }
}
/// Future created by `JoinMap::join_next`; it holds the map mutably borrowed
/// for as long as it is alive.
struct JoinNext<'a, K, V> {
map: &'a mut JoinMap<K, V>,
}
/// Per-task waker: when woken it records which key's task became ready and
/// then wakes the task that is polling the map.
struct JoinMapWaker<K> {
// Key of the task this waker belongs to; `Weak` so it does not keep a removed entry alive.
key: Weak<K>,
// Shared queue of woken keys, the same `Arc` as `JoinMap::woken`.
woken: Arc<Mutex<VecDeque<Weak<K>>>>,
// Waker of the task that was polling the map when this waker was created.
inner: std::task::Waker,
}
impl<K> std::task::Wake for JoinMapWaker<K> {
    /// Consuming wake; behaves exactly like waking by reference.
    fn wake(self: Arc<Self>) {
        self.wake_by_ref();
    }

    /// Queues this waker's key for the next `JoinNext` poll, then forwards the
    /// wake-up to the task that polled the map.
    fn wake_by_ref(self: &Arc<Self>) {
        let ready_key = Weak::clone(&self.key);
        {
            let mut queue = self.woken.lock().unwrap();
            queue.push_back(ready_key);
        }
        self.inner.wake_by_ref();
    }
}
impl<'a, K, V> JoinNext<'a, K, V> {
    /// Pops queued wake-ups until one still refers to a live key.
    ///
    /// Entries whose key has since been dropped (the task was removed) are
    /// discarded. Returns `None` once the queue is exhausted.
    fn next_woken(&mut self) -> Option<Arc<K>> {
        let mut queue = self.map.woken.lock().unwrap();
        loop {
            let candidate = queue.pop_front()?;
            if let Some(live) = candidate.upgrade() {
                return Some(live);
            }
        }
    }

    /// Looks up an already-polled handle by key.
    fn handle(&mut self, key: &Arc<K>) -> Option<&mut JoinHandle<V>>
    where
        K: Ord,
    {
        let polled = &mut self.map.handles;
        polled.get_mut(key)
    }

    /// Forgets an already-polled handle (dropping it if present).
    fn remove(&mut self, key: &Arc<K>)
    where
        K: Ord,
    {
        drop(self.map.handles.remove(key));
    }
}
impl<'a, K, V> std::future::Future for JoinNext<'a, K, V>
where
K: Ord + Send + Sync + 'static,
{
type Output = Option<(K, Result<V, JoinError>)>;
fn poll(
self: std::pin::Pin<&mut Self>,
cx: &mut std::task::Context<'_>,
) -> std::task::Poll<Self::Output> {
let this = self.get_mut();
if this.map.handles.is_empty() && this.map.unpolled.is_empty() {
return std::task::Poll::Ready(None);
}
while let Some(key) = this.next_woken() {
if let Some(handle) = this.handle(&key) {
if let std::task::Poll::Ready(value) = std::pin::Pin::new(handle).poll(cx) {
this.remove(&key);
let key = Arc::into_inner(key).expect("No other holders");
return std::task::Poll::Ready(Some((key, value)));
}
}
}
while let Some((key, mut handle)) = this.map.unpolled.pop_first() {
let waker = Arc::new(JoinMapWaker {
key: Arc::downgrade(&key),
woken: this.map.woken.clone(),
inner: cx.waker().clone(),
})
.into();
let mut cx = std::task::Context::from_waker(&waker);
if let std::task::Poll::Ready(value) = std::pin::Pin::new(&mut handle).poll(&mut cx) {
let key = Arc::into_inner(key).expect("No other holders");
return std::task::Poll::Ready(Some((key, value)));
}
this.map.handles.insert(key, handle);
}
std::task::Poll::Pending
}
}
impl<K, V> Default for JoinMap<K, V> {
fn default() -> Self {
Self::new()
}
}
impl<V> Default for JoinSet<V> {
fn default() -> Self {
Self::new()
}
}
impl<K, V> Drop for JoinMap<K, V> {
    /// Aborts every tracked task (not-yet-polled ones first) unless the map
    /// was built with `detach_on_drop`.
    fn drop(&mut self) {
        if !self.abort_on_drop {
            return;
        }
        self.unpolled
            .values()
            .chain(self.handles.values())
            .for_each(|handle| handle.abort());
    }
}