use std::{ future::Future, pin::Pin, sync::{ atomic::{AtomicBool, Ordering}, Arc, }, task::{Context, Poll, Wake, Waker}, }; pub(super) enum Either { Left(L), Right(R), } struct Select { left: F1, left_woken: Arc, right: F2, right_woken: Arc, } struct SelectWaker { inner: Waker, flag: Arc, } impl Wake for SelectWaker { fn wake_by_ref(self: &Arc) { self.flag.store(true, Ordering::Release); self.inner.wake_by_ref(); } fn wake(self: Arc) { self.flag.store(true, Ordering::Release); match Arc::try_unwrap(self) { Ok(this) => this.inner.wake(), Err(this) => this.inner.wake_by_ref(), } } } impl Future for Select where F1: Future + Unpin, F2: Future + Unpin, { type Output = Either; fn poll(mut self: std::pin::Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { let left_waker = Arc::new(SelectWaker { inner: cx.waker().clone(), flag: self.left_woken.clone(), }) .into(); let mut left_ctx = Context::from_waker(&left_waker); if let Poll::Ready(left_out) = Pin::new(&mut self.left).poll(&mut left_ctx) { return Poll::Ready(Either::Left(left_out)); } let right_waker = Arc::new(SelectWaker { inner: cx.waker().clone(), flag: self.right_woken.clone(), }) .into(); let mut right_ctx = Context::from_waker(&right_waker); if let Poll::Ready(right_out) = Pin::new(&mut self.right).poll(&mut right_ctx) { return Poll::Ready(Either::Right(right_out)); } Poll::Pending } } pub(super) async fn select( left: Left, right: Right, ) -> Either where Left: Future, Right: Future, { let left = std::pin::pin!(left); let right = std::pin::pin!(right); Select { left, left_woken: Arc::new(AtomicBool::new(true)), right, right_woken: Arc::new(AtomicBool::new(true)), } .await }