use std::{ future::Future, pin::Pin, sync::{ atomic::{AtomicBool, Ordering}, Arc, }, task::{Context, Poll, Wake, Waker}, }; struct ThreadWaker { thread: std::thread::Thread, } impl Wake for ThreadWaker { fn wake(self: Arc) { self.thread.unpark(); } fn wake_by_ref(self: &Arc) { self.thread.unpark(); } } pub(super) enum Either { Left(L), Right(R), } struct Select { left: F1, left_woken: Arc, right: F2, right_woken: Arc, } struct SelectWaker { inner: Waker, flag: Arc, } impl Wake for SelectWaker { fn wake_by_ref(self: &Arc) { self.flag.store(true, Ordering::Release); self.inner.wake_by_ref(); } fn wake(self: Arc) { self.flag.store(true, Ordering::Release); match Arc::try_unwrap(self) { Ok(this) => this.inner.wake(), Err(this) => this.inner.wake_by_ref(), } } } impl Future for Select where F1: Future + Unpin, F2: Future + Unpin, { type Output = Either; fn poll(mut self: std::pin::Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { let left_waker = Arc::new(SelectWaker { inner: cx.waker().clone(), flag: self.left_woken.clone(), }) .into(); let mut left_ctx = Context::from_waker(&left_waker); if let Poll::Ready(left_out) = Pin::new(&mut self.left).poll(&mut left_ctx) { return Poll::Ready(Either::Left(left_out)); } let right_waker = Arc::new(SelectWaker { inner: cx.waker().clone(), flag: self.right_woken.clone(), }) .into(); let mut right_ctx = Context::from_waker(&right_waker); if let Poll::Ready(right_out) = Pin::new(&mut self.right).poll(&mut right_ctx) { return Poll::Ready(Either::Right(right_out)); } Poll::Pending } } pub(super) fn blocking_select( left: Left, right: Right, ) -> Either where Left: Future, Right: Future, { block_on(select(left, right)) } fn block_on(fut: F) -> F::Output where F: Future, { let thread_waker = Arc::new(ThreadWaker { thread: std::thread::current(), }) .into(); let mut ctx = Context::from_waker(&thread_waker); let mut fut = std::pin::pin!(fut); loop { if let Poll::Ready(out) = fut.as_mut().poll(&mut ctx) { return out; } // doesn't race - unpark followed by park will result in park returning immediately std::thread::park(); } } async fn select(left: Left, right: Right) -> Either where Left: Future, Right: Future, { let left = std::pin::pin!(left); let right = std::pin::pin!(right); Select { left, left_woken: Arc::new(AtomicBool::new(true)), right, right_woken: Arc::new(AtomicBool::new(true)), } .await }