streaming/src/from_fn.rs

268 lines
6.1 KiB
Rust

//! Types and methods for constructing streams
mod type_stack;
use std::{
cell::{OnceCell, RefCell},
future::Future,
marker::PhantomData,
pin::Pin,
task::{Context, Poll},
};
use self::type_stack::TypeStack;
pin_project_lite::pin_project! {
/// Stream returned by [`from_fn`] and [`try_from_fn`]
///
/// Wraps the future produced by the user's function and surfaces every value that
/// future yields as a stream item. The stream ends once the future completes.
pub struct FromFn<F, I> {
// Marker for the item type `I`; no value of `I` is stored in the struct itself.
yielded: PhantomData<I>,
// The wrapped future; it is polled in place from `poll_next`.
#[pin]
future: F,
// Set once `future` has completed; later polls return `None` without polling it again.
closed: bool,
}
}
pin_project_lite::pin_project! {
/// Future wrapped inside the stream returned by [`try_from_fn`]
///
/// Drives the user's fallible future; if that future finishes with `Err`, the error
/// is handed to the stream as an item instead of being returned.
pub struct TryYield<F, T, E> {
// Marker for the `Result<T, E>` item type; nothing of that type is stored here.
yielded: PhantomData<Result<T, E>>,
// The user's future, resolving to `Result<(), E>` (see the `Future` impl for this type).
#[pin]
future: F,
// Set once `future` has completed, so it is never polled again.
closed: bool,
}
}
/// Handle given to the stream's function, used to emit items into the stream
pub struct Yielder<I>(PhantomData<I>);

/// Future returned by the `yield_*` methods on [`Yielder`]
///
/// Awaiting it suspends the calling future so the stream can hand the item over.
#[must_use]
pub struct YieldFuture<I>(Option<I>);

impl<I> Unpin for YieldFuture<I> {}

impl<I: 'static> Yielder<I> {
    /// Yield a value in the stream
    ///
    /// This only wraps `item`; nothing is delivered until the returned future is awaited.
    pub fn yield_(&self, item: I) -> YieldFuture<I> {
        let pending = Some(item);
        YieldFuture(pending)
    }
}

impl<T: 'static, E: 'static> Yielder<Result<T, E>> {
    /// Yield an Ok value in the stream
    pub fn yield_ok(&self, item: T) -> YieldFuture<Result<T, E>> {
        YieldFuture(Some(Ok(item)))
    }

    /// Yield an Err value in the stream
    pub fn yield_err(&self, error: E) -> YieldFuture<Result<T, E>> {
        YieldFuture(Some(Err(error)))
    }
}
/// Build a stream from an async function that emits items through a [`Yielder`]
///
/// `func` is called immediately with a fresh [`Yielder`]; the future it returns is
/// stored in the resulting [`FromFn`] and polled whenever the stream is polled.
///
/// Example:
/// ```rust
/// let input_stream = streem::from_fn(|yielder| async move {
///     for i in 0..10 {
///         yielder.yield_(i).await;
///     }
/// });
/// ```
pub fn from_fn<F, Fut, I>(func: F) -> FromFn<Fut, I>
where
    F: FnOnce(Yielder<I>) -> Fut,
    Fut: Future<Output = ()>,
{
    let yielder = Yielder(PhantomData);
    let future = func(yielder);

    FromFn {
        yielded: PhantomData,
        future,
        closed: false,
    }
}
/// Build a stream of `Result`s from a fallible async function
///
/// `func` is called immediately with a fresh [`Yielder`]. The future it returns is
/// wrapped in a [`TryYield`], which passes an `Err` returned by that future on to the
/// stream as an item.
///
/// Example:
/// ```rust
/// fn fallible_fn(i: i32) -> Result<i32, String> {
/// # Ok(i)
/// }
///
/// let input_stream = streem::try_from_fn(|yielder| async move {
///     for i in 0..10 {
///         let value = fallible_fn(i)?;
///
///         yielder.yield_ok(value).await;
///     }
///
///     Ok(()) as Result<_, String>
/// });
/// ```
pub fn try_from_fn<F, Fut, T, E>(func: F) -> FromFn<TryYield<Fut, T, E>, Result<T, E>>
where
    F: FnOnce(Yielder<Result<T, E>>) -> Fut,
    Fut: Future<Output = Result<(), E>>,
{
    let inner = TryYield {
        yielded: PhantomData,
        future: func(Yielder(PhantomData)),
        closed: false,
    };

    FromFn {
        yielded: PhantomData,
        future: inner,
        closed: false,
    }
}
thread_local! {
// Per-thread stack of yield slots; the inner `RefCell<TypeStack>` is created lazily
// through `get_or_init_yield_stack`. `YieldGuard` pushes and pops slots, and the yield
// futures write into them.
static YIELD_STACK: OnceCell<RefCell<TypeStack>> = OnceCell::new();
}
/// Borrow the stack held in `once_cell`, creating an empty one on first access
fn get_or_init_yield_stack(once_cell: &OnceCell<RefCell<TypeStack>>) -> &RefCell<TypeStack> {
    let make_stack = || RefCell::new(TypeStack::new());
    once_cell.get_or_init(make_stack)
}
/// Single-item hand-off cell between a yielding future and the stream polling it
enum YieldSlot<T> {
    /// An item has been deposited and is waiting to be taken
    Filled(T),
    /// Nothing has been deposited
    Empty,
}

impl<T> YieldSlot<T> {
    /// `Filled(item)` becomes `Some(item)`, `Empty` becomes `None`
    fn into_option(self) -> Option<T> {
        if let Self::Filled(item) = self {
            Some(item)
        } else {
            None
        }
    }

    /// `Filled(item)` becomes `Poll::Ready(item)`, `Empty` becomes `Poll::Pending`
    fn into_poll(self) -> Poll<T> {
        self.into_option().map_or(Poll::Pending, Poll::Ready)
    }
}
/// Guard tied to one `YieldSlot<T>` entry on the thread-local yield stack.
///
/// The `bool` is the "armed" flag: while it is `true`, dropping the guard pops a
/// `YieldSlot<T>` from the stack. `disarm` clears the flag and returns the slot instead.
struct YieldGuard<T: 'static>(bool, PhantomData<T>);
impl<T: 'static> YieldGuard<T> {
    /// Push an empty `YieldSlot<T>` onto the thread-local stack and return an armed guard
    fn guard() -> Self {
        YIELD_STACK.with(|cell| {
            let mut stack = get_or_init_yield_stack(cell).borrow_mut();
            stack.push::<YieldSlot<T>>(YieldSlot::Empty);
        });

        Self(true, PhantomData)
    }

    /// Consume the guard and take its slot back off the stack
    ///
    /// The guard is marked as disarmed first so that its `Drop` does not pop a second time.
    /// Panics with "slot exists" if no `YieldSlot<T>` could be popped.
    fn disarm(mut self) -> YieldSlot<T> {
        self.0 = false;

        let popped = YIELD_STACK
            .with(|cell| get_or_init_yield_stack(cell).borrow_mut().pop::<YieldSlot<T>>());

        popped.expect("slot exists")
    }
}
impl<T: 'static> Drop for YieldGuard<T> {
    /// Pop this guard's slot if the guard was dropped without being disarmed
    /// (for example when the polled future unwinds).
    fn drop(&mut self) {
        if !self.0 {
            // Already disarmed: `disarm` has taken the slot.
            return;
        }

        // The popped slot (if any) is returned out of the closure and discarded here.
        YIELD_STACK.with(|cell| get_or_init_yield_stack(cell).borrow_mut().pop::<YieldSlot<T>>());
    }
}
impl<F, I> futures_core::Stream for FromFn<F, I>
where
F: Future<Output = ()>,
I: 'static,
{
type Item = I;
fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
let this = self.project();
if *this.closed {
return Poll::Ready(None);
}
let guard = YieldGuard::<I>::guard();
*this.closed = this.future.poll(cx).is_ready();
let item = guard.disarm();
if *this.closed {
Poll::Ready(item.into_option())
} else {
item.into_poll().map(Some)
}
}
}
impl<I> Future for YieldFuture<I>
where
    I: 'static,
{
    type Output = ();

    /// Hand the held item to the stream currently being polled.
    ///
    /// While the item is still held, it is moved into the `YieldSlot<I>` found on the
    /// thread-local stack and `Poll::Pending` is returned so the stream can emit it.
    /// Once the item has been handed over, the next poll returns `Poll::Ready(())`.
    ///
    /// If the slot already contains an item (another yield future was polled earlier in
    /// the same stream poll, e.g. through a join combinator), this future keeps its item
    /// and stays pending rather than overwriting, and silently dropping, the earlier
    /// one. It retries on the next poll.
    ///
    /// No waker is registered: `FromFn::poll_next` returns `Ready(Some(..))` whenever the
    /// slot was filled, so the stream's consumer is what drives the next poll.
    ///
    /// # Panics
    ///
    /// Panics with "Slot exists" if no `YieldSlot<I>` is found on the stack, presumably
    /// when the future is polled outside of the owning stream's `poll_next`.
    fn poll(mut self: Pin<&mut Self>, _: &mut Context<'_>) -> Poll<Self::Output> {
        if self.0.is_none() {
            // The item was delivered on an earlier poll.
            return Poll::Ready(());
        }

        YIELD_STACK.with(|c| {
            let mut stack = get_or_init_yield_stack(c).borrow_mut();
            let entry = stack.get_mut::<YieldSlot<I>>().expect("Slot exists");

            // Only deposit into an empty slot so an item placed earlier in this poll
            // is not lost.
            if matches!(*entry, YieldSlot::Empty) {
                if let Some(item) = self.0.take() {
                    *entry = YieldSlot::Filled(item);
                }
            }
        });

        Poll::Pending
    }
}
impl<F, T, E> Future for TryYield<F, T, E>
where
F: Future<Output = Result<(), E>>,
Result<T, E>: 'static,
{
type Output = ();
fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
if self.closed {
return Poll::Ready(());
}
let this = self.project();
let res = std::task::ready!(this.future.poll(cx));
*this.closed = true;
if let Err(e) = res {
YIELD_STACK.with(|c| {
let mut guard = get_or_init_yield_stack(c).borrow_mut();
let entry = guard
.get_mut::<YieldSlot<Result<T, E>>>()
.expect("Slot exists");
*entry = YieldSlot::Filled(Err(e));
});
return Poll::Pending;
}
Poll::Ready(())
}
}