diff --git a/Cargo.lock b/Cargo.lock index c59d6efd..a3f9cf03 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1852,6 +1852,7 @@ dependencies = [ "sha2", "sled", "storage-path-generator", + "streem", "thiserror", "time", "tokio", @@ -2646,6 +2647,16 @@ version = "0.1.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "7f11d35dae9818c4313649da4a97c8329e29357a7fe584526c1d78f5b63ef836" +[[package]] +name = "streem" +version = "0.1.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "641396a5ae90767cb12d21832444ab760841ee887717d802b2c456c4f8199114" +dependencies = [ + "futures-core", + "pin-project-lite", +] + [[package]] name = "stringprep" version = "0.1.4" diff --git a/Cargo.toml b/Cargo.toml index 071a322e..084d65ad 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -57,6 +57,7 @@ serde_urlencoded = "0.7.1" sha2 = "0.10.0" sled = { version = "0.34.7" } storage-path-generator = "0.1.0" +streem = "0.1.1" thiserror = "1.0" time = { version = "0.3.0", features = ["serde", "serde-well-known"] } tokio = { version = "1", features = ["full", "tracing"] } diff --git a/src/backgrounded.rs b/src/backgrounded.rs index a1aa87aa..31c784c1 100644 --- a/src/backgrounded.rs +++ b/src/backgrounded.rs @@ -4,7 +4,6 @@ use crate::{ error::Error, repo::{ArcRepo, UploadId}, store::Store, - stream::StreamMap, }; use actix_web::web::Bytes; use futures_core::Stream; @@ -34,7 +33,7 @@ impl Backgrounded { pub(crate) async fn proxy(repo: ArcRepo, store: S, stream: P) -> Result where S: Store, - P: Stream> + Unpin + 'static, + P: Stream> + 'static, { let mut this = Self { repo, @@ -50,12 +49,13 @@ impl Backgrounded { async fn do_proxy(&mut self, store: S, stream: P) -> Result<(), Error> where S: Store, - P: Stream> + Unpin + 'static, + P: Stream> + 'static, { self.upload_id = Some(self.repo.create_upload().await?); - let stream = - stream.map(|res| res.map_err(|e| std::io::Error::new(std::io::ErrorKind::Other, e))); + let stream = Box::pin(crate::stream::map_err(stream, |e| { + std::io::Error::new(std::io::ErrorKind::Other, e) + })); // use octet-stream, we don't know the upload's real type yet let identifier = store.save_stream(stream, APPLICATION_OCTET_STREAM).await?; diff --git a/src/details.rs b/src/details.rs index f5b9a86c..9141cd2f 100644 --- a/src/details.rs +++ b/src/details.rs @@ -7,9 +7,9 @@ use crate::{ formats::{InternalFormat, InternalVideoFormat}, serde_str::Serde, store::Store, - stream::IntoStreamer, }; use actix_web::web; +use streem::IntoStreamer; use time::{format_description::well_known::Rfc3339, OffsetDateTime}; #[derive(Copy, Clone, Debug, serde::Deserialize, serde::Serialize)] diff --git a/src/file.rs b/src/file.rs index bb753b48..a8b0fc12 100644 --- a/src/file.rs +++ b/src/file.rs @@ -6,14 +6,11 @@ pub(crate) use tokio_file::File; #[cfg(not(feature = "io-uring"))] mod tokio_file { - use crate::{ - store::file_store::FileError, - stream::{IntoStreamer, StreamMap}, - Either, - }; + use crate::{store::file_store::FileError, Either}; use actix_web::web::{Bytes, BytesMut}; use futures_core::Stream; use std::{io::SeekFrom, path::Path}; + use streem::IntoStreamer; use tokio::io::{AsyncRead, AsyncReadExt, AsyncSeekExt, AsyncWrite, AsyncWriteExt}; use tokio_util::codec::{BytesCodec, FramedRead}; @@ -100,14 +97,17 @@ mod tokio_file { (None, None) => Either::right(self.inner), }; - Ok(FramedRead::new(obj, BytesCodec::new()).map(|res| res.map(BytesMut::freeze))) + Ok(crate::stream::map_ok( + FramedRead::new(obj, BytesCodec::new()), + BytesMut::freeze, + )) } } } #[cfg(feature = "io-uring")] mod io_uring { - use crate::{store::file_store::FileError, stream::IntoStreamer}; + use crate::store::file_store::FileError; use actix_web::web::{Bytes, BytesMut}; use futures_core::Stream; use std::{ @@ -118,6 +118,7 @@ mod io_uring { pin::Pin, task::{Context, Poll}, }; + use streem::IntoStreamer; use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt}; use tokio_uring::{ buf::{IoBuf, IoBufMut}, diff --git a/src/ingest.rs b/src/ingest.rs index 2e394c7a..062c82f8 100644 --- a/src/ingest.rs +++ b/src/ingest.rs @@ -7,12 +7,12 @@ use crate::{ formats::{InternalFormat, Validations}, repo::{Alias, ArcRepo, DeleteToken, Hash}, store::Store, - stream::{IntoStreamer, MakeSend}, }; use actix_web::web::Bytes; use futures_core::Stream; use reqwest::Body; use reqwest_middleware::ClientWithMiddleware; +use streem::IntoStreamer; use tracing::{Instrument, Span}; mod hasher; @@ -30,10 +30,11 @@ pub(crate) struct Session { #[tracing::instrument(skip(stream))] async fn aggregate(stream: S) -> Result where - S: Stream> + Unpin, + S: Stream>, { let mut buf = BytesStream::new(); + let stream = std::pin::pin!(stream); let mut stream = stream.into_streamer(); while let Some(res) = stream.next().await { @@ -48,7 +49,7 @@ pub(crate) async fn ingest( repo: &ArcRepo, store: &S, client: &ClientWithMiddleware, - stream: impl Stream> + Unpin + 'static, + stream: impl Stream> + 'static, declared_alias: Option, media: &crate::config::Media, ) -> Result @@ -117,12 +118,12 @@ where }; if let Some(endpoint) = &media.external_validation { - let stream = store.to_stream(&identifier, None, None).await?.make_send(); + let stream = store.to_stream(&identifier, None, None).await?; let response = client .post(endpoint.as_str()) .header("Content-Type", input_type.media_type().as_ref()) - .body(Body::wrap_stream(stream)) + .body(Body::wrap_stream(crate::stream::make_send(stream))) .send() .instrument(tracing::info_span!("external-validation")) .await?; diff --git a/src/lib.rs b/src/lib.rs index a66c890d..8c91ab3e 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -54,6 +54,7 @@ use std::{ sync::Arc, time::{Duration, SystemTime}, }; +use streem::IntoStreamer; use tokio::sync::Semaphore; use tracing::Instrument; use tracing_actix_web::TracingLogger; @@ -74,7 +75,7 @@ use self::{ repo::{sled::SledRepo, Alias, DeleteToken, Hash, Repo, UploadId, UploadResult}, serde_str::Serde, store::{file_store::FileStore, object_store::ObjectStore, Store}, - stream::{empty, once, StreamLimit, StreamMap, StreamTimeout}, + stream::{empty, once}, }; pub use self::config::{ConfigSource, PictRsConfiguration}; @@ -165,14 +166,14 @@ impl FormData for Upload { let span = tracing::info_span!("file-upload", ?filename); - let stream = stream.map(|res| res.map_err(Error::from)); - Box::pin( async move { if config.server.read_only { return Err(UploadError::ReadOnly.into()); } + let stream = crate::stream::from_err(stream); + ingest::ingest(&repo, &**store, &client, stream, None, &config.media) .await } @@ -230,14 +231,14 @@ impl FormData for Import { let span = tracing::info_span!("file-import", ?filename); - let stream = stream.map(|res| res.map_err(Error::from)); - Box::pin( async move { if config.server.read_only { return Err(UploadError::ReadOnly.into()); } + let stream = crate::stream::from_err(stream); + ingest::ingest( &repo, &**store, @@ -368,14 +369,14 @@ impl FormData for BackgroundedUpload { let span = tracing::info_span!("file-proxy", ?filename); - let stream = stream.map(|res| res.map_err(Error::from)); - Box::pin( async move { if read_only { return Err(UploadError::ReadOnly.into()); } + let stream = crate::stream::from_err(stream); + Backgrounded::proxy(repo, store, stream).await } .instrument(span), @@ -488,7 +489,7 @@ struct UrlQuery { } async fn ingest_inline( - stream: impl Stream> + Unpin + 'static, + stream: impl Stream> + 'static, repo: &ArcRepo, store: &S, client: &ClientWithMiddleware, @@ -527,7 +528,7 @@ async fn download_stream( client: &ClientWithMiddleware, url: &str, config: &Configuration, -) -> Result> + Unpin + 'static, Error> { +) -> Result> + 'static, Error> { if config.server.read_only { return Err(UploadError::ReadOnly.into()); } @@ -538,10 +539,10 @@ async fn download_stream( return Err(UploadError::Download(res.status()).into()); } - let stream = res - .bytes_stream() - .map(|res| res.map_err(Error::from)) - .limit((config.media.max_file_size * MEGABYTES) as u64); + let stream = crate::stream::limit( + config.media.max_file_size * MEGABYTES, + crate::stream::from_err(res.bytes_stream()), + ); Ok(stream) } @@ -551,7 +552,7 @@ async fn download_stream( skip(stream, repo, store, client, config) )] async fn do_download_inline( - stream: impl Stream> + Unpin + 'static, + stream: impl Stream> + 'static, repo: web::Data, store: web::Data, client: &ClientWithMiddleware, @@ -574,7 +575,7 @@ async fn do_download_inline( #[tracing::instrument(name = "Downloading file in background", skip(stream, repo, store))] async fn do_download_backgrounded( - stream: impl Stream> + Unpin + 'static, + stream: impl Stream> + 'static, repo: web::Data, store: web::Data, ) -> Result { @@ -1325,9 +1326,7 @@ async fn ranged_file_resp( ( builder, Either::left(Either::left( - range::chop_store(range, store, &identifier, len) - .await? - .map(|res| res.map_err(Error::from)), + range::chop_store(range, store, &identifier, len).await?, )), ) } else { @@ -1341,10 +1340,7 @@ async fn ranged_file_resp( } } else { //No Range header in the request - return the entire document - let stream = store - .to_stream(&identifier, None, None) - .await? - .map(|res| res.map_err(Error::from)); + let stream = crate::stream::from_err(store.to_stream(&identifier, None, None).await?); if not_found { (HttpResponse::NotFound(), Either::right(stream)) @@ -1375,10 +1371,18 @@ where E: std::error::Error + 'static, actix_web::Error: From, { - let stream = stream.timeout(Duration::from_secs(5)).map(|res| match res { - Ok(Ok(item)) => Ok(item), - Ok(Err(e)) => Err(actix_web::Error::from(e)), - Err(e) => Err(Error::from(e).into()), + let stream = crate::stream::timeout(Duration::from_secs(5), stream); + + let stream = streem::try_from_fn(|yielder| async move { + let stream = std::pin::pin!(stream); + let mut streamer = stream.into_streamer(); + + while let Some(res) = streamer.next().await { + let item = res.map_err(Error::from)??; + yielder.yield_ok(item).await; + } + + Ok(()) as Result<(), actix_web::Error> }); srv_head(builder, ext, expires, modified).streaming(stream) diff --git a/src/migrate_store.rs b/src/migrate_store.rs index d2ce9022..ac3b6382 100644 --- a/src/migrate_store.rs +++ b/src/migrate_store.rs @@ -7,12 +7,13 @@ use std::{ time::{Duration, Instant}, }; +use streem::IntoStreamer; + use crate::{ details::Details, error::{Error, UploadError}, repo::{ArcRepo, Hash}, store::Store, - stream::IntoStreamer, }; pub(super) async fn migrate_store( diff --git a/src/queue/cleanup.rs b/src/queue/cleanup.rs index ea3ff750..8217e3bd 100644 --- a/src/queue/cleanup.rs +++ b/src/queue/cleanup.rs @@ -1,5 +1,7 @@ use std::sync::Arc; +use streem::IntoStreamer; + use crate::{ config::Configuration, error::{Error, UploadError}, @@ -8,7 +10,6 @@ use crate::{ repo::{Alias, ArcRepo, DeleteToken, Hash}, serde_str::Serde, store::Store, - stream::IntoStreamer, }; pub(super) fn perform<'a, S>( diff --git a/src/queue/process.rs b/src/queue/process.rs index 29fccf7e..480a5043 100644 --- a/src/queue/process.rs +++ b/src/queue/process.rs @@ -12,7 +12,6 @@ use crate::{ repo::{Alias, ArcRepo, UploadId, UploadResult}, serde_str::Serde, store::Store, - stream::StreamMap, }; use std::{path::PathBuf, sync::Arc}; @@ -131,10 +130,7 @@ where let media = media.clone(); let error_boundary = crate::sync::spawn(async move { - let stream = store2 - .to_stream(&ident, None, None) - .await? - .map(|res| res.map_err(Error::from)); + let stream = crate::stream::from_err(store2.to_stream(&ident, None, None).await?); let session = crate::ingest::ingest(&repo, &store2, &client, stream, declared_alias, &media) diff --git a/src/repo/migrate.rs b/src/repo/migrate.rs index e10177a8..7809fedc 100644 --- a/src/repo/migrate.rs +++ b/src/repo/migrate.rs @@ -1,5 +1,6 @@ use std::sync::Arc; +use streem::IntoStreamer; use tokio::task::JoinSet; use crate::{ @@ -12,7 +13,6 @@ use crate::{ SledRepo as OldSledRepo, }, store::Store, - stream::IntoStreamer, }; const MIGRATE_CONCURRENCY: usize = 32; diff --git a/src/store/object_store.rs b/src/store/object_store.rs index 324a0ebc..7046f69a 100644 --- a/src/store/object_store.rs +++ b/src/store/object_store.rs @@ -1,9 +1,6 @@ use crate::{ - bytes_stream::BytesStream, - error_code::ErrorCode, - repo::ArcRepo, - store::Store, - stream::{IntoStreamer, LocalBoxStream, StreamMap}, + bytes_stream::BytesStream, error_code::ErrorCode, repo::ArcRepo, store::Store, + stream::LocalBoxStream, }; use actix_rt::task::JoinError; use actix_web::{ @@ -21,6 +18,7 @@ use reqwest_middleware::{ClientWithMiddleware, RequestBuilder}; use rusty_s3::{actions::S3Action, Bucket, BucketError, Credentials, UrlStyle}; use std::{string::FromUtf8Error, sync::Arc, time::Duration}; use storage_path_generator::{Generator, Path}; +use streem::IntoStreamer; use tokio::io::{AsyncRead, AsyncWrite, AsyncWriteExt}; use tokio_util::io::ReaderStream; use tracing::Instrument; @@ -393,11 +391,10 @@ impl Store for ObjectStore { return Err(status_error(response).await); } - Ok(Box::pin( - response - .bytes_stream() - .map(|res| res.map_err(payload_to_io_error)), - )) + Ok(Box::pin(crate::stream::map_err( + response.bytes_stream(), + payload_to_io_error, + ))) } #[tracing::instrument(skip(self, writer))] diff --git a/src/stream.rs b/src/stream.rs index 7a67a9e5..f7d44492 100644 --- a/src/stream.rs +++ b/src/stream.rs @@ -1,269 +1,170 @@ -use actix_rt::{task::JoinHandle, time::Sleep}; use actix_web::web::Bytes; -use flume::r#async::RecvStream; use futures_core::Stream; -use std::{ - future::Future, - marker::PhantomData, - pin::Pin, - sync::{ - atomic::{AtomicBool, Ordering}, - Arc, - }, - task::{Context, Poll, Wake, Waker}, - time::Duration, -}; +use std::{pin::Pin, time::Duration}; +use streem::IntoStreamer; -pub(crate) trait MakeSend: Stream> +pub(crate) fn make_send(stream: S) -> impl Stream + Send where - T: 'static, + S: Stream + 'static, + S::Item: Send + Sync, { - fn make_send(self) -> MakeSendStream - where - Self: Sized + 'static, - { - let (tx, rx) = crate::sync::channel(4); + let (tx, rx) = crate::sync::channel(1); - MakeSendStream { - handle: crate::sync::spawn(async move { - let this = std::pin::pin!(self); + let handle = crate::sync::spawn(async move { + let stream = std::pin::pin!(stream); + let mut streamer = stream.into_streamer(); - let mut stream = this.into_streamer(); - - while let Some(res) = stream.next().await { - if tx.send_async(res).await.is_err() { - return; - } - } - }), - rx: rx.into_stream(), - } - } -} - -impl MakeSend for S -where - S: Stream>, - T: 'static, -{ -} - -pub(crate) struct MakeSendStream -where - T: 'static, -{ - handle: actix_rt::task::JoinHandle<()>, - rx: flume::r#async::RecvStream<'static, std::io::Result>, -} - -impl Stream for MakeSendStream -where - T: 'static, -{ - type Item = std::io::Result; - - fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { - match Pin::new(&mut self.rx).poll_next(cx) { - Poll::Ready(opt) => Poll::Ready(opt), - Poll::Pending if std::task::ready!(Pin::new(&mut self.handle).poll(cx)).is_err() => { - Poll::Ready(Some(Err(std::io::Error::new( - std::io::ErrorKind::UnexpectedEof, - "Stream panicked", - )))) + while let Some(res) = streamer.next().await { + if tx.send_async(res).await.is_err() { + break; } - Poll::Pending => Poll::Pending, } - } + }); + + streem::from_fn(|yiedler| async move { + let mut stream = rx.into_stream().into_streamer(); + + while let Some(res) = stream.next().await { + yiedler.yield_(res).await; + } + + let _ = handle.await; + }) } -pin_project_lite::pin_project! { - pub(crate) struct Map { - #[pin] - stream: S, - func: F, - } +pub(crate) fn from_iterator(iterator: I, buffer: usize) -> impl Stream + Send +where + I: IntoIterator + Send + 'static, + I::Item: Send + Sync, +{ + let (tx, rx) = crate::sync::channel(buffer); + + let handle = crate::sync::spawn_blocking(move || { + for value in iterator { + if tx.send(value).is_err() { + break; + } + } + }); + + streem::from_fn(|yielder| async move { + let mut stream = rx.into_stream().into_streamer(); + + while let Some(res) = stream.next().await { + yielder.yield_(res).await; + } + + let _ = handle.await; + }) } -pub(crate) trait StreamMap: Stream { - fn map(self, func: F) -> Map - where - F: FnMut(Self::Item) -> U, - Self: Sized, - { - Map { stream: self, func } - } +pub(crate) fn map_ok(stream: S, f: F) -> impl Stream> +where + S: Stream>, + T2: 'static, + E: 'static, + F: Fn(T1) -> T2 + Copy, +{ + streem::from_fn(|yielder| async move { + let stream = std::pin::pin!(stream); + let mut streamer = stream.into_streamer(); + + while let Some(res) = streamer.next().await { + yielder.yield_(res.map(f)).await; + } + }) } -impl StreamMap for T where T: Stream {} +pub(crate) fn map_err(stream: S, f: F) -> impl Stream> +where + S: Stream>, + T: 'static, + E2: 'static, + F: Fn(E1) -> E2 + Copy, +{ + streem::from_fn(|yielder| async move { + let stream = std::pin::pin!(stream); + let mut streamer = stream.into_streamer(); -impl Stream for Map + while let Some(res) = streamer.next().await { + yielder.yield_(res.map_err(f)).await; + } + }) +} + +pub(crate) fn from_err(stream: S) -> impl Stream> +where + S: Stream>, + T: 'static, + E1: Into, + E2: 'static, +{ + map_err(stream, Into::into) +} + +pub(crate) fn empty() -> impl Stream +where + T: 'static, +{ + streem::from_fn(|_| std::future::ready(())) +} + +pub(crate) fn once(value: T) -> impl Stream +where + T: 'static, +{ + streem::from_fn(|yielder| yielder.yield_(value)) +} + +pub(crate) fn timeout( + duration: Duration, + stream: S, +) -> impl Stream> where S: Stream, - F: FnMut(S::Item) -> U, + S::Item: 'static, { - type Item = U; + streem::try_from_fn(|yielder| async move { + actix_rt::time::timeout(duration, async move { + let stream = std::pin::pin!(stream); + let mut streamer = stream.into_streamer(); - fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { - let this = self.project(); - - let value = std::task::ready!(this.stream.poll_next(cx)); - - Poll::Ready(value.map(this.func)) - } + while let Some(res) = streamer.next().await { + yielder.yield_ok(res).await; + } + }) + .await + .map_err(|_| TimeoutError) + }) } -pub(crate) struct Empty(PhantomData); - -impl Stream for Empty { - type Item = T; - - fn poll_next(self: Pin<&mut Self>, _: &mut Context<'_>) -> Poll> { - Poll::Ready(None) - } -} - -pub(crate) fn empty() -> Empty { - Empty(PhantomData) -} - -pub(crate) struct Once(Option); - -impl Stream for Once +pub(crate) fn limit(limit: usize, stream: S) -> impl Stream> where - T: Unpin, + S: Stream>, + E: From + 'static, { - type Item = T; + streem::try_from_fn(|yielder| async move { + let stream = std::pin::pin!(stream); + let mut streamer = stream.into_streamer(); - fn poll_next(mut self: Pin<&mut Self>, _: &mut Context<'_>) -> Poll> { - Poll::Ready(self.0.take()) - } -} + let mut count = 0; -pub(crate) fn once(value: T) -> Once { - Once(Some(value)) + while let Some(bytes) = streamer.try_next().await? { + count += bytes.len(); + + if count > limit { + return Err(LimitError.into()); + } + + yielder.yield_ok(bytes).await; + } + + Ok(()) + }) } pub(crate) type LocalBoxStream<'a, T> = Pin + 'a>>; -pub(crate) trait StreamLimit { - fn limit(self, limit: u64) -> Limit - where - Self: Sized, - { - Limit { - inner: self, - count: 0, - limit, - } - } -} - -pub(crate) trait StreamTimeout { - fn timeout(self, duration: Duration) -> Timeout - where - Self: Sized, - { - Timeout { - sleep: actix_rt::time::sleep(duration), - inner: self, - expired: false, - woken: Arc::new(AtomicBool::new(true)), - } - } -} - -pub(crate) trait IntoStreamer: Stream { - fn into_streamer(self) -> Streamer - where - Self: Sized, - { - Streamer(Some(self)) - } -} - -impl IntoStreamer for T where T: Stream + Unpin {} - -pub(crate) fn from_iterator( - iterator: I, - buffer: usize, -) -> IterStream { - IterStream { - state: IterStreamState::New { iterator, buffer }, - } -} - -impl StreamLimit for S where S: Stream> {} -impl StreamTimeout for S where S: Stream {} - -pub(crate) struct Streamer(Option); - -impl Streamer { - pub(crate) async fn next(&mut self) -> Option - where - S: Stream + Unpin, - { - let stream = self.0.as_mut().take()?; - - let opt = std::future::poll_fn(|cx| Pin::new(&mut *stream).poll_next(cx)).await; - - if opt.is_none() { - self.0.take(); - } - - opt - } -} - -pin_project_lite::pin_project! { - pub(crate) struct Limit { - #[pin] - inner: S, - - count: u64, - limit: u64, - } -} - -pin_project_lite::pin_project! { - pub(crate) struct Timeout { - #[pin] - sleep: Sleep, - - #[pin] - inner: S, - - expired: bool, - woken: Arc, - } -} - -enum IterStreamState -where - T: 'static, -{ - New { - iterator: I, - buffer: usize, - }, - Running { - handle: JoinHandle<()>, - receiver: RecvStream<'static, T>, - }, - Pending, -} - -pub(crate) struct IterStream -where - T: 'static, -{ - state: IterStreamState, -} - -struct TimeoutWaker { - woken: Arc, - inner: Waker, -} - #[derive(Debug, thiserror::Error)] #[error("Resonse body larger than size limit")] pub(crate) struct LimitError; @@ -271,135 +172,3 @@ pub(crate) struct LimitError; #[derive(Debug, thiserror::Error)] #[error("Timeout in body")] pub(crate) struct TimeoutError; - -impl Stream for Limit -where - S: Stream>, - E: From, -{ - type Item = Result; - - fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { - let this = self.as_mut().project(); - - let limit = this.limit; - let count = this.count; - let inner = this.inner; - - inner.poll_next(cx).map(|opt| { - opt.map(|res| match res { - Ok(bytes) => { - *count += bytes.len() as u64; - if *count > *limit { - return Err(LimitError.into()); - } - Ok(bytes) - } - Err(e) => Err(e), - }) - }) - } -} - -impl Wake for TimeoutWaker { - fn wake(self: Arc) { - self.wake_by_ref() - } - - fn wake_by_ref(self: &Arc) { - self.woken.store(true, Ordering::Release); - self.inner.wake_by_ref(); - } -} - -impl Stream for Timeout -where - S: Stream, -{ - type Item = Result; - - fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { - let this = self.as_mut().project(); - - if *this.expired { - return Poll::Ready(None); - } - - if this.woken.swap(false, Ordering::Acquire) { - let timeout_waker = Arc::new(TimeoutWaker { - woken: Arc::clone(this.woken), - inner: cx.waker().clone(), - }) - .into(); - - let mut timeout_cx = Context::from_waker(&timeout_waker); - - if this.sleep.poll(&mut timeout_cx).is_ready() { - *this.expired = true; - return Poll::Ready(Some(Err(TimeoutError))); - } - } - - this.inner.poll_next(cx).map(|opt| opt.map(Ok)) - } -} - -impl Stream for IterStream -where - I: IntoIterator + Send + Unpin + 'static, - T: Send + 'static, -{ - type Item = T; - - fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { - let this = self.as_mut().get_mut(); - - match std::mem::replace(&mut this.state, IterStreamState::Pending) { - IterStreamState::New { iterator, buffer } => { - let (sender, receiver) = crate::sync::channel(buffer); - - let mut handle = crate::sync::spawn_blocking(move || { - let iterator = iterator.into_iter(); - - for item in iterator { - if sender.send(item).is_err() { - break; - } - } - }); - - if Pin::new(&mut handle).poll(cx).is_ready() { - return Poll::Ready(None); - } - - this.state = IterStreamState::Running { - handle, - receiver: receiver.into_stream(), - }; - - self.poll_next(cx) - } - IterStreamState::Running { - mut handle, - mut receiver, - } => match Pin::new(&mut receiver).poll_next(cx) { - Poll::Ready(Some(item)) => { - this.state = IterStreamState::Running { handle, receiver }; - - Poll::Ready(Some(item)) - } - Poll::Ready(None) => Poll::Ready(None), - Poll::Pending => { - if Pin::new(&mut handle).poll(cx).is_ready() { - return Poll::Ready(None); - } - - this.state = IterStreamState::Running { handle, receiver }; - - Poll::Pending - } - }, - IterStreamState::Pending => panic!("Polled after completion"), - } - } -}