Implement JoinMap and JoinSet

This commit is contained in:
asonix 2023-08-25 17:52:17 -05:00
parent bad22d64cc
commit d580286b22
2 changed files with 290 additions and 9 deletions

26
examples/demo.rs Normal file
View file

@ -0,0 +1,26 @@
use std::time::Duration;
fn main() {
jive::block_on(async move {
let mut set = jive_joinmap::JoinSet::new();
for _ in 0..20 {
if set.len() >= 10 {
set.join_next().await;
println!("Joined");
}
set.spawn(async move {
jive::time::sleep(Duration::from_secs(2)).await;
});
println!("Spawned");
jive::time::sleep(Duration::from_millis(250)).await;
}
while set.join_next().await.is_some() {
println!("Joined");
// drain set
}
})
}

View file

@ -1,14 +1,269 @@
pub fn add(left: usize, right: usize) -> usize {
left + right
use std::{
collections::{BTreeMap, VecDeque},
sync::{Arc, Mutex, Weak},
};
use jive::task::sync::{JoinError, JoinHandle};
pub struct JoinSet<V> {
index: u64,
map: JoinMap<u64, V>,
}
#[cfg(test)]
mod tests {
use super::*;
pub struct JoinMap<K, V> {
woken: Arc<Mutex<VecDeque<Weak<K>>>>,
handles: BTreeMap<Arc<K>, JoinHandle<V>>,
unpolled: BTreeMap<Arc<K>, JoinHandle<V>>,
}
#[test]
fn it_works() {
let result = add(2, 2);
assert_eq!(result, 4);
impl<V> JoinSet<V> {
pub fn new() -> Self {
JoinSet {
index: 0,
map: JoinMap::new(),
}
}
pub fn spawn<Fut>(&mut self, future: Fut)
where
V: Send + 'static,
Fut: std::future::Future<Output = V> + Send + 'static,
{
let index = self.index;
self.index += 1;
self.map.insert(index, future);
}
pub fn spawn_local<Fut>(&mut self, future: Fut)
where
V: Send + 'static,
Fut: std::future::Future<Output = V> + Send + 'static,
{
let index = self.index;
self.index += 1;
self.map.insert_local(index, future);
}
pub fn spawn_blocking<F>(&mut self, func: F)
where
V: Send + 'static,
F: Fn() -> V + Send + 'static,
{
let index = self.index;
self.index += 1;
self.map.insert_blocking(index, func);
}
// Cancel Safety
//
// This future is strictly cancel-safe.
pub async fn join_next(&mut self) -> Option<Result<V, JoinError>> {
self.map.join_next().await.map(|(_, res)| res)
}
pub fn len(&self) -> usize {
self.map.len()
}
pub fn is_empty(&self) -> bool {
self.map.is_empty()
}
}
impl<K, V> JoinMap<K, V> {
pub fn new() -> Self {
JoinMap {
woken: Arc::new(Mutex::new(VecDeque::new())),
handles: BTreeMap::new(),
unpolled: BTreeMap::new(),
}
}
pub fn insert<Fut>(&mut self, key: K, future: Fut) -> Option<JoinHandle<V>>
where
K: Ord + Send + Sync + 'static,
V: Send + 'static,
Fut: std::future::Future<Output = V> + Send + 'static,
{
let key = Arc::new(key);
let out = self.remove(&key);
let handle = jive::spawn(future);
self.unpolled.insert(key, handle);
out
}
pub fn insert_local<Fut>(&mut self, key: K, future: Fut) -> Option<JoinHandle<V>>
where
K: Ord + Send + Sync + 'static,
V: Send + 'static,
Fut: std::future::Future<Output = V> + 'static,
{
let key = Arc::new(key);
let out = self.remove(&key);
let handle = jive::spawn_local(future);
self.unpolled.insert(key, handle);
out
}
pub fn insert_blocking<F>(&mut self, key: K, func: F) -> Option<JoinHandle<V>>
where
K: Ord + Send + Sync + 'static,
V: Send + 'static,
F: Fn() -> V + Send + 'static,
{
let key = Arc::new(key);
let out = self.remove(&key);
let handle = jive::spawn_blocking(func);
self.unpolled.insert(key, handle);
out
}
pub fn remove<Q>(&mut self, key: &Q) -> Option<JoinHandle<V>>
where
Arc<K>: std::borrow::Borrow<Q> + Ord,
Q: Ord + ?Sized,
{
if let Some(handle) = self.unpolled.remove(key) {
return Some(handle);
}
if let Some(handle) = self.handles.remove(key) {
return Some(handle);
}
None
}
// Cancel Safety
//
// This future is strictly cancel-safe.
pub async fn join_next(&mut self) -> Option<(K, Result<V, JoinError>)>
where
K: Ord + Send + Sync + 'static,
{
JoinNext { map: self }.await
}
pub fn len(&self) -> usize {
self.handles.len() + self.unpolled.len()
}
pub fn is_empty(&self) -> bool {
self.handles.is_empty() && self.unpolled.is_empty()
}
}
struct JoinNext<'a, K, V> {
map: &'a mut JoinMap<K, V>,
}
struct JoinMapWaker<K> {
key: Weak<K>,
woken: Arc<Mutex<VecDeque<Weak<K>>>>,
inner: std::task::Waker,
}
impl<K> std::task::Wake for JoinMapWaker<K> {
fn wake(self: Arc<Self>) {
self.wake_by_ref()
}
fn wake_by_ref(self: &Arc<Self>) {
self.woken.lock().unwrap().push_back(self.key.clone());
self.inner.wake_by_ref();
}
}
impl<'a, K, V> JoinNext<'a, K, V> {
fn next_woken(&mut self) -> Option<Arc<K>> {
let mut guard = self.map.woken.lock().unwrap();
while let Some(key) = guard.pop_front() {
if let Some(key) = key.upgrade() {
return Some(key);
}
}
None
}
fn handle(&mut self, key: &Arc<K>) -> Option<&mut JoinHandle<V>>
where
K: Ord,
{
self.map.handles.get_mut(key)
}
fn remove(&mut self, key: &Arc<K>)
where
K: Ord,
{
self.map.handles.remove(key);
}
}
impl<'a, K, V> std::future::Future for JoinNext<'a, K, V>
where
K: Ord + Send + Sync + 'static,
{
type Output = Option<(K, Result<V, JoinError>)>;
fn poll(
self: std::pin::Pin<&mut Self>,
cx: &mut std::task::Context<'_>,
) -> std::task::Poll<Self::Output> {
let this = self.get_mut();
if this.map.handles.is_empty() && this.map.unpolled.is_empty() {
return std::task::Poll::Ready(None);
}
while let Some(key) = this.next_woken() {
if let Some(handle) = this.handle(&key) {
if let std::task::Poll::Ready(value) = std::pin::Pin::new(handle).poll(cx) {
this.remove(&key);
let key = Arc::into_inner(key).expect("No other holders");
return std::task::Poll::Ready(Some((key, value)));
}
}
}
while let Some((key, mut handle)) = this.map.unpolled.pop_first() {
let waker = Arc::new(JoinMapWaker {
key: Arc::downgrade(&key),
woken: this.map.woken.clone(),
inner: cx.waker().clone(),
})
.into();
let mut cx = std::task::Context::from_waker(&waker);
if let std::task::Poll::Ready(value) = std::pin::Pin::new(&mut handle).poll(&mut cx) {
let key = Arc::into_inner(key).expect("No other holders");
return std::task::Poll::Ready(Some((key, value)));
}
this.map.handles.insert(key, handle);
}
std::task::Poll::Pending
}
}