diff --git a/Cargo.lock b/Cargo.lock index c59d6ef..8a9f367 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1852,6 +1852,7 @@ dependencies = [ "sha2", "sled", "storage-path-generator", + "streem", "thiserror", "time", "tokio", @@ -2646,6 +2647,16 @@ version = "0.1.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "7f11d35dae9818c4313649da4a97c8329e29357a7fe584526c1d78f5b63ef836" +[[package]] +name = "streem" +version = "0.2.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "6c8b0c8184b0fe05b37dd75d66205195cd57563c6c87cb92134a025a34a6ab34" +dependencies = [ + "futures-core", + "pin-project-lite", +] + [[package]] name = "stringprep" version = "0.1.4" diff --git a/Cargo.toml b/Cargo.toml index 071a322..753b468 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -57,6 +57,7 @@ serde_urlencoded = "0.7.1" sha2 = "0.10.0" sled = { version = "0.34.7" } storage-path-generator = "0.1.0" +streem = "0.2.0" thiserror = "1.0" time = { version = "0.3.0", features = ["serde", "serde-well-known"] } tokio = { version = "1", features = ["full", "tracing"] } diff --git a/src/backgrounded.rs b/src/backgrounded.rs index a1aa87a..31c784c 100644 --- a/src/backgrounded.rs +++ b/src/backgrounded.rs @@ -4,7 +4,6 @@ use crate::{ error::Error, repo::{ArcRepo, UploadId}, store::Store, - stream::StreamMap, }; use actix_web::web::Bytes; use futures_core::Stream; @@ -34,7 +33,7 @@ impl Backgrounded { pub(crate) async fn proxy(repo: ArcRepo, store: S, stream: P) -> Result where S: Store, - P: Stream> + Unpin + 'static, + P: Stream> + 'static, { let mut this = Self { repo, @@ -50,12 +49,13 @@ impl Backgrounded { async fn do_proxy(&mut self, store: S, stream: P) -> Result<(), Error> where S: Store, - P: Stream> + Unpin + 'static, + P: Stream> + 'static, { self.upload_id = Some(self.repo.create_upload().await?); - let stream = - stream.map(|res| res.map_err(|e| std::io::Error::new(std::io::ErrorKind::Other, e))); + let stream = Box::pin(crate::stream::map_err(stream, |e| { + std::io::Error::new(std::io::ErrorKind::Other, e) + })); // use octet-stream, we don't know the upload's real type yet let identifier = store.save_stream(stream, APPLICATION_OCTET_STREAM).await?; diff --git a/src/details.rs b/src/details.rs index f5b9a86..9141cd2 100644 --- a/src/details.rs +++ b/src/details.rs @@ -7,9 +7,9 @@ use crate::{ formats::{InternalFormat, InternalVideoFormat}, serde_str::Serde, store::Store, - stream::IntoStreamer, }; use actix_web::web; +use streem::IntoStreamer; use time::{format_description::well_known::Rfc3339, OffsetDateTime}; #[derive(Copy, Clone, Debug, serde::Deserialize, serde::Serialize)] diff --git a/src/file.rs b/src/file.rs index bb753b4..a2408df 100644 --- a/src/file.rs +++ b/src/file.rs @@ -6,14 +6,11 @@ pub(crate) use tokio_file::File; #[cfg(not(feature = "io-uring"))] mod tokio_file { - use crate::{ - store::file_store::FileError, - stream::{IntoStreamer, StreamMap}, - Either, - }; + use crate::{store::file_store::FileError, Either}; use actix_web::web::{Bytes, BytesMut}; use futures_core::Stream; use std::{io::SeekFrom, path::Path}; + use streem::IntoStreamer; use tokio::io::{AsyncRead, AsyncReadExt, AsyncSeekExt, AsyncWrite, AsyncWriteExt}; use tokio_util::codec::{BytesCodec, FramedRead}; @@ -100,24 +97,25 @@ mod tokio_file { (None, None) => Either::right(self.inner), }; - Ok(FramedRead::new(obj, BytesCodec::new()).map(|res| res.map(BytesMut::freeze))) + Ok(crate::stream::map_ok( + FramedRead::new(obj, BytesCodec::new()), + BytesMut::freeze, + )) } } } #[cfg(feature = "io-uring")] mod io_uring { - use crate::{store::file_store::FileError, stream::IntoStreamer}; + use crate::store::file_store::FileError; use actix_web::web::{Bytes, BytesMut}; use futures_core::Stream; use std::{ convert::TryInto, fs::Metadata, - future::Future, path::{Path, PathBuf}, - pin::Pin, - task::{Context, Poll}, }; + use streem::IntoStreamer; use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt}; use tokio_uring::{ buf::{IoBuf, IoBufMut}, @@ -318,20 +316,7 @@ mod io_uring { from_start: Option, len: Option, ) -> Result>, FileError> { - let size = self.metadata().await?.len(); - - let cursor = from_start.unwrap_or(0); - let size = len.unwrap_or(size - cursor) + cursor; - - Ok(BytesStream { - state: ReadFileState::File { - file: Some(self), - bytes: Some(BytesMut::new()), - }, - size, - cursor, - callback: read_file, - }) + Ok(bytes_stream(self, from_start, len)) } async fn read_at(&self, buf: T, pos: u64) -> BufResult { @@ -343,98 +328,50 @@ mod io_uring { } } - pin_project_lite::pin_project! { - struct BytesStream { - #[pin] - state: ReadFileState, - size: u64, - cursor: u64, - #[pin] - callback: F, - } - } - - pin_project_lite::pin_project! { - #[project = ReadFileStateProj] - #[project_replace = ReadFileStateProjReplace] - enum ReadFileState { - File { - file: Option, - bytes: Option, - }, - Future { - #[pin] - fut: Fut, - }, - } - } - - async fn read_file( + fn bytes_stream( file: File, - buf: BytesMut, - cursor: u64, - ) -> (File, BufResult) { - let buf_res = file.read_at(buf, cursor).await; + from_start: Option, + len: Option, + ) -> impl Stream> { + streem::try_from_fn(|yielder| async move { + let file_size = file.metadata().await?.len(); - (file, buf_res) - } + let mut cursor = from_start.unwrap_or(0); + let remaining_size = file_size.saturating_sub(cursor); + let read_until = len.unwrap_or(remaining_size) + cursor; - impl Stream for BytesStream - where - F: Fn(File, BytesMut, u64) -> Fut, - Fut: Future)> + 'static, - { - type Item = std::io::Result; + let mut bytes = BytesMut::new(); - fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { - let mut this = self.as_mut().project(); + loop { + let max_size = read_until.saturating_sub(cursor); - match this.state.as_mut().project() { - ReadFileStateProj::File { file, bytes } => { - let cursor = *this.cursor; - let max_size = *this.size - *this.cursor; - - if max_size == 0 { - return Poll::Ready(None); - } - - let capacity = max_size.min(65_356) as usize; - let mut bytes = bytes.take().unwrap(); - let file = file.take().unwrap(); - - if bytes.capacity() < capacity { - bytes.reserve(capacity - bytes.capacity()); - } - - let fut = (this.callback)(file, bytes, cursor); - - this.state.project_replace(ReadFileState::Future { fut }); - self.poll_next(cx) + if max_size == 0 { + break; } - ReadFileStateProj::Future { fut } => match fut.poll(cx) { - Poll::Pending => Poll::Pending, - Poll::Ready((file, (Ok(n), mut buf))) => { - let bytes = buf.split_off(n); - this.state.project_replace(ReadFileState::File { - file: Some(file), - bytes: Some(bytes), - }); + let capacity = max_size.min(65_356) as usize; - let n: u64 = match n.try_into() { - Ok(n) => n, - Err(_) => { - return Poll::Ready(Some(Err(std::io::ErrorKind::Other.into()))) - } - }; - *this.cursor += n; + if bytes.capacity() < capacity { + bytes.reserve(capacity - bytes.capacity()); + } - Poll::Ready(Some(Ok(buf.into()))) - } - Poll::Ready((_, (Err(e), _))) => Poll::Ready(Some(Err(e))), - }, + let (result, mut buf_) = file.read_at(bytes, cursor).await; + + let n = match result { + Ok(n) => n, + Err(e) => return Err(e), + }; + + bytes = buf_.split_off(n); + + let n: u64 = n.try_into().map_err(|_| std::io::ErrorKind::Other)?; + cursor += n; + + yielder.yield_ok(buf_.into()).await; } - } + + Ok(()) + }) } #[cfg(test)] diff --git a/src/ingest.rs b/src/ingest.rs index 2e394c7..062c82f 100644 --- a/src/ingest.rs +++ b/src/ingest.rs @@ -7,12 +7,12 @@ use crate::{ formats::{InternalFormat, Validations}, repo::{Alias, ArcRepo, DeleteToken, Hash}, store::Store, - stream::{IntoStreamer, MakeSend}, }; use actix_web::web::Bytes; use futures_core::Stream; use reqwest::Body; use reqwest_middleware::ClientWithMiddleware; +use streem::IntoStreamer; use tracing::{Instrument, Span}; mod hasher; @@ -30,10 +30,11 @@ pub(crate) struct Session { #[tracing::instrument(skip(stream))] async fn aggregate(stream: S) -> Result where - S: Stream> + Unpin, + S: Stream>, { let mut buf = BytesStream::new(); + let stream = std::pin::pin!(stream); let mut stream = stream.into_streamer(); while let Some(res) = stream.next().await { @@ -48,7 +49,7 @@ pub(crate) async fn ingest( repo: &ArcRepo, store: &S, client: &ClientWithMiddleware, - stream: impl Stream> + Unpin + 'static, + stream: impl Stream> + 'static, declared_alias: Option, media: &crate::config::Media, ) -> Result @@ -117,12 +118,12 @@ where }; if let Some(endpoint) = &media.external_validation { - let stream = store.to_stream(&identifier, None, None).await?.make_send(); + let stream = store.to_stream(&identifier, None, None).await?; let response = client .post(endpoint.as_str()) .header("Content-Type", input_type.media_type().as_ref()) - .body(Body::wrap_stream(stream)) + .body(Body::wrap_stream(crate::stream::make_send(stream))) .send() .instrument(tracing::info_span!("external-validation")) .await?; diff --git a/src/lib.rs b/src/lib.rs index a66c890..8c91ab3 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -54,6 +54,7 @@ use std::{ sync::Arc, time::{Duration, SystemTime}, }; +use streem::IntoStreamer; use tokio::sync::Semaphore; use tracing::Instrument; use tracing_actix_web::TracingLogger; @@ -74,7 +75,7 @@ use self::{ repo::{sled::SledRepo, Alias, DeleteToken, Hash, Repo, UploadId, UploadResult}, serde_str::Serde, store::{file_store::FileStore, object_store::ObjectStore, Store}, - stream::{empty, once, StreamLimit, StreamMap, StreamTimeout}, + stream::{empty, once}, }; pub use self::config::{ConfigSource, PictRsConfiguration}; @@ -165,14 +166,14 @@ impl FormData for Upload { let span = tracing::info_span!("file-upload", ?filename); - let stream = stream.map(|res| res.map_err(Error::from)); - Box::pin( async move { if config.server.read_only { return Err(UploadError::ReadOnly.into()); } + let stream = crate::stream::from_err(stream); + ingest::ingest(&repo, &**store, &client, stream, None, &config.media) .await } @@ -230,14 +231,14 @@ impl FormData for Import { let span = tracing::info_span!("file-import", ?filename); - let stream = stream.map(|res| res.map_err(Error::from)); - Box::pin( async move { if config.server.read_only { return Err(UploadError::ReadOnly.into()); } + let stream = crate::stream::from_err(stream); + ingest::ingest( &repo, &**store, @@ -368,14 +369,14 @@ impl FormData for BackgroundedUpload { let span = tracing::info_span!("file-proxy", ?filename); - let stream = stream.map(|res| res.map_err(Error::from)); - Box::pin( async move { if read_only { return Err(UploadError::ReadOnly.into()); } + let stream = crate::stream::from_err(stream); + Backgrounded::proxy(repo, store, stream).await } .instrument(span), @@ -488,7 +489,7 @@ struct UrlQuery { } async fn ingest_inline( - stream: impl Stream> + Unpin + 'static, + stream: impl Stream> + 'static, repo: &ArcRepo, store: &S, client: &ClientWithMiddleware, @@ -527,7 +528,7 @@ async fn download_stream( client: &ClientWithMiddleware, url: &str, config: &Configuration, -) -> Result> + Unpin + 'static, Error> { +) -> Result> + 'static, Error> { if config.server.read_only { return Err(UploadError::ReadOnly.into()); } @@ -538,10 +539,10 @@ async fn download_stream( return Err(UploadError::Download(res.status()).into()); } - let stream = res - .bytes_stream() - .map(|res| res.map_err(Error::from)) - .limit((config.media.max_file_size * MEGABYTES) as u64); + let stream = crate::stream::limit( + config.media.max_file_size * MEGABYTES, + crate::stream::from_err(res.bytes_stream()), + ); Ok(stream) } @@ -551,7 +552,7 @@ async fn download_stream( skip(stream, repo, store, client, config) )] async fn do_download_inline( - stream: impl Stream> + Unpin + 'static, + stream: impl Stream> + 'static, repo: web::Data, store: web::Data, client: &ClientWithMiddleware, @@ -574,7 +575,7 @@ async fn do_download_inline( #[tracing::instrument(name = "Downloading file in background", skip(stream, repo, store))] async fn do_download_backgrounded( - stream: impl Stream> + Unpin + 'static, + stream: impl Stream> + 'static, repo: web::Data, store: web::Data, ) -> Result { @@ -1325,9 +1326,7 @@ async fn ranged_file_resp( ( builder, Either::left(Either::left( - range::chop_store(range, store, &identifier, len) - .await? - .map(|res| res.map_err(Error::from)), + range::chop_store(range, store, &identifier, len).await?, )), ) } else { @@ -1341,10 +1340,7 @@ async fn ranged_file_resp( } } else { //No Range header in the request - return the entire document - let stream = store - .to_stream(&identifier, None, None) - .await? - .map(|res| res.map_err(Error::from)); + let stream = crate::stream::from_err(store.to_stream(&identifier, None, None).await?); if not_found { (HttpResponse::NotFound(), Either::right(stream)) @@ -1375,10 +1371,18 @@ where E: std::error::Error + 'static, actix_web::Error: From, { - let stream = stream.timeout(Duration::from_secs(5)).map(|res| match res { - Ok(Ok(item)) => Ok(item), - Ok(Err(e)) => Err(actix_web::Error::from(e)), - Err(e) => Err(Error::from(e).into()), + let stream = crate::stream::timeout(Duration::from_secs(5), stream); + + let stream = streem::try_from_fn(|yielder| async move { + let stream = std::pin::pin!(stream); + let mut streamer = stream.into_streamer(); + + while let Some(res) = streamer.next().await { + let item = res.map_err(Error::from)??; + yielder.yield_ok(item).await; + } + + Ok(()) as Result<(), actix_web::Error> }); srv_head(builder, ext, expires, modified).streaming(stream) diff --git a/src/migrate_store.rs b/src/migrate_store.rs index d2ce902..3b47812 100644 --- a/src/migrate_store.rs +++ b/src/migrate_store.rs @@ -7,12 +7,13 @@ use std::{ time::{Duration, Instant}, }; +use streem::IntoStreamer; + use crate::{ details::Details, error::{Error, UploadError}, repo::{ArcRepo, Hash}, store::Store, - stream::IntoStreamer, }; pub(super) async fn migrate_store( @@ -106,7 +107,8 @@ where } // Hashes are read in a consistent order - let mut stream = repo.hashes().into_streamer(); + let stream = std::pin::pin!(repo.hashes()); + let mut stream = stream.into_streamer(); let state = Rc::new(MigrateState { repo: repo.clone(), diff --git a/src/queue/cleanup.rs b/src/queue/cleanup.rs index ea3ff75..2b91edb 100644 --- a/src/queue/cleanup.rs +++ b/src/queue/cleanup.rs @@ -1,5 +1,7 @@ use std::sync::Arc; +use streem::IntoStreamer; + use crate::{ config::Configuration, error::{Error, UploadError}, @@ -8,7 +10,6 @@ use crate::{ repo::{Alias, ArcRepo, DeleteToken, Hash}, serde_str::Serde, store::Store, - stream::IntoStreamer, }; pub(super) fn perform<'a, S>( @@ -137,7 +138,8 @@ async fn alias(repo: &ArcRepo, alias: Alias, token: DeleteToken) -> Result<(), E #[tracing::instrument(skip_all)] async fn all_variants(repo: &ArcRepo) -> Result<(), Error> { - let mut hash_stream = repo.hashes().into_streamer(); + let hash_stream = std::pin::pin!(repo.hashes()); + let mut hash_stream = hash_stream.into_streamer(); while let Some(res) = hash_stream.next().await { let hash = res?; diff --git a/src/queue/process.rs b/src/queue/process.rs index 29fccf7..480a504 100644 --- a/src/queue/process.rs +++ b/src/queue/process.rs @@ -12,7 +12,6 @@ use crate::{ repo::{Alias, ArcRepo, UploadId, UploadResult}, serde_str::Serde, store::Store, - stream::StreamMap, }; use std::{path::PathBuf, sync::Arc}; @@ -131,10 +130,7 @@ where let media = media.clone(); let error_boundary = crate::sync::spawn(async move { - let stream = store2 - .to_stream(&ident, None, None) - .await? - .map(|res| res.map_err(Error::from)); + let stream = crate::stream::from_err(store2.to_stream(&ident, None, None).await?); let session = crate::ingest::ingest(&repo, &store2, &client, stream, declared_alias, &media) diff --git a/src/repo.rs b/src/repo.rs index fddda9a..01f5c6e 100644 --- a/src/repo.rs +++ b/src/repo.rs @@ -8,10 +8,10 @@ use crate::{ config, details::Details, error_code::{ErrorCode, OwnedErrorCode}, - future::LocalBoxFuture, stream::LocalBoxStream, }; use base64::Engine; +use futures_core::Stream; use std::{fmt::Debug, sync::Arc}; use url::Url; use uuid::Uuid; @@ -572,81 +572,29 @@ impl HashPage { } } -type PageFuture = LocalBoxFuture<'static, Result>; - -pub(crate) struct HashStream { - repo: Option, - page_future: Option, - page: Option, -} - -impl futures_core::Stream for HashStream { - type Item = Result; - - fn poll_next( - self: std::pin::Pin<&mut Self>, - cx: &mut std::task::Context<'_>, - ) -> std::task::Poll> { - let this = self.get_mut(); - - loop { - let Some(repo) = &this.repo else { - return std::task::Poll::Ready(None); - }; - - let slug = if let Some(page) = &mut this.page { - // popping last in page is fine - we reversed them - if let Some(hash) = page.hashes.pop() { - return std::task::Poll::Ready(Some(Ok(hash))); - } - - let slug = page.next(); - this.page.take(); - - if let Some(slug) = slug { - Some(slug) - } else { - this.repo.take(); - return std::task::Poll::Ready(None); - } - } else { - None - }; - - if let Some(page_future) = &mut this.page_future { - let res = std::task::ready!(page_future.as_mut().poll(cx)); - - this.page_future.take(); - - match res { - Ok(mut page) => { - // reverse because we use `pop` to fetch next - page.hashes.reverse(); - - this.page = Some(page); - } - Err(e) => { - this.repo.take(); - - return std::task::Poll::Ready(Some(Err(e))); - } - } - } else { - let repo = repo.clone(); - - this.page_future = Some(Box::pin(async move { repo.hash_page(slug, 100).await })); - } - } - } -} - impl dyn FullRepo { - pub(crate) fn hashes(self: &Arc) -> HashStream { - HashStream { - repo: Some(self.clone()), - page_future: None, - page: None, - } + pub(crate) fn hashes(self: &Arc) -> impl Stream> { + let repo = self.clone(); + + streem::try_from_fn(|yielder| async move { + let mut slug = None; + + loop { + let page = repo.hash_page(slug, 100).await?; + + slug = page.next(); + + for hash in page.hashes { + yielder.yield_ok(hash).await; + } + + if slug.is_none() { + break; + } + } + + Ok(()) + }) } } diff --git a/src/repo/migrate.rs b/src/repo/migrate.rs index e10177a..8083696 100644 --- a/src/repo/migrate.rs +++ b/src/repo/migrate.rs @@ -1,5 +1,6 @@ use std::sync::Arc; +use streem::IntoStreamer; use tokio::task::JoinSet; use crate::{ @@ -12,7 +13,6 @@ use crate::{ SledRepo as OldSledRepo, }, store::Store, - stream::IntoStreamer, }; const MIGRATE_CONCURRENCY: usize = 32; @@ -35,7 +35,8 @@ pub(crate) async fn migrate_repo(old_repo: ArcRepo, new_repo: ArcRepo) -> Result tracing::warn!("Checks complete, migrating repo"); tracing::warn!("{total_size} hashes will be migrated"); - let mut hash_stream = old_repo.hashes().into_streamer(); + let hash_stream = std::pin::pin!(old_repo.hashes()); + let mut hash_stream = hash_stream.into_streamer(); let mut index = 0; while let Some(res) = hash_stream.next().await { diff --git a/src/repo/postgres.rs b/src/repo/postgres.rs index 459d56b..99f0b3e 100644 --- a/src/repo/postgres.rs +++ b/src/repo/postgres.rs @@ -21,6 +21,7 @@ use diesel_async::{ }, AsyncConnection, AsyncPgConnection, RunQueryDsl, }; +use futures_core::Stream; use tokio::sync::Notify; use tokio_postgres::{tls::NoTlsStream, AsyncMessage, Connection, NoTls, Notification, Socket}; use tracing::Instrument; @@ -30,7 +31,7 @@ use uuid::Uuid; use crate::{ details::Details, error_code::{ErrorCode, OwnedErrorCode}, - future::{LocalBoxFuture, WithMetrics, WithTimeout}, + future::{WithMetrics, WithTimeout}, serde_str::Serde, stream::LocalBoxStream, }; @@ -1477,33 +1478,29 @@ impl AliasAccessRepo for PostgresRepo { &self, timestamp: time::OffsetDateTime, ) -> Result>, RepoError> { - Ok(Box::pin(PageStream { - inner: self.inner.clone(), - future: None, - current: Vec::new(), - older_than: to_primitive(timestamp), - next: Box::new(|inner, older_than| { - Box::pin(async move { - use schema::proxies::dsl::*; + Ok(Box::pin(page_stream( + self.inner.clone(), + to_primitive(timestamp), + |inner, older_than| async move { + use schema::proxies::dsl::*; - let mut conn = inner.get_connection().await?; + let mut conn = inner.get_connection().await?; - let vec = proxies - .select((accessed, alias)) - .filter(accessed.lt(older_than)) - .order(accessed.desc()) - .limit(100) - .get_results(&mut conn) - .with_metrics("pict-rs.postgres.alias-access.older-aliases") - .with_timeout(Duration::from_secs(5)) - .await - .map_err(|_| PostgresError::DbTimeout)? - .map_err(PostgresError::Diesel)?; + let vec = proxies + .select((accessed, alias)) + .filter(accessed.lt(older_than)) + .order(accessed.desc()) + .limit(100) + .get_results(&mut conn) + .with_metrics("pict-rs.postgres.alias-access.older-aliases") + .with_timeout(Duration::from_secs(5)) + .await + .map_err(|_| PostgresError::DbTimeout)? + .map_err(PostgresError::Diesel)?; - Ok(vec) - }) - }), - })) + Ok(vec) + }, + ))) } async fn remove_alias_access(&self, _: Alias) -> Result<(), RepoError> { @@ -1570,33 +1567,29 @@ impl VariantAccessRepo for PostgresRepo { &self, timestamp: time::OffsetDateTime, ) -> Result>, RepoError> { - Ok(Box::pin(PageStream { - inner: self.inner.clone(), - future: None, - current: Vec::new(), - older_than: to_primitive(timestamp), - next: Box::new(|inner, older_than| { - Box::pin(async move { - use schema::variants::dsl::*; + Ok(Box::pin(page_stream( + self.inner.clone(), + to_primitive(timestamp), + |inner, older_than| async move { + use schema::variants::dsl::*; - let mut conn = inner.get_connection().await?; + let mut conn = inner.get_connection().await?; - let vec = variants - .select((accessed, (hash, variant))) - .filter(accessed.lt(older_than)) - .order(accessed.desc()) - .limit(100) - .get_results(&mut conn) - .with_metrics("pict-rs.postgres.variant-access.older-variants") - .with_timeout(Duration::from_secs(5)) - .await - .map_err(|_| PostgresError::DbTimeout)? - .map_err(PostgresError::Diesel)?; + let vec = variants + .select((accessed, (hash, variant))) + .filter(accessed.lt(older_than)) + .order(accessed.desc()) + .limit(100) + .get_results(&mut conn) + .with_metrics("pict-rs.postgres.variant-access.older-variants") + .with_timeout(Duration::from_secs(5)) + .await + .map_err(|_| PostgresError::DbTimeout)? + .map_err(PostgresError::Diesel)?; - Ok(vec) - }) - }), - })) + Ok(vec) + }, + ))) } async fn remove_variant_access(&self, _: Hash, _: String) -> Result<(), RepoError> { @@ -1778,62 +1771,36 @@ impl FullRepo for PostgresRepo { } } -type NextFuture = LocalBoxFuture<'static, Result, RepoError>>; - -struct PageStream { +fn page_stream( inner: Arc, - future: Option>, - current: Vec, - older_than: time::PrimitiveDateTime, - next: Box, time::PrimitiveDateTime) -> NextFuture>, -} - -impl futures_core::Stream for PageStream + mut older_than: time::PrimitiveDateTime, + next: F, +) -> impl Stream> where - I: Unpin, + F: Fn(Arc, time::PrimitiveDateTime) -> Fut, + Fut: std::future::Future, RepoError>> + + 'static, + I: 'static, { - type Item = Result; - - fn poll_next( - self: std::pin::Pin<&mut Self>, - cx: &mut std::task::Context<'_>, - ) -> std::task::Poll> { - let this = self.get_mut(); - + streem::try_from_fn(|yielder| async move { loop { - // Pop because we reversed the list - if let Some(alias) = this.current.pop() { - return std::task::Poll::Ready(Some(Ok(alias))); - } + let mut page = (next)(inner.clone(), older_than).await?; - if let Some(future) = this.future.as_mut() { - let res = std::task::ready!(future.as_mut().poll(cx)); - - this.future.take(); - - match res { - Ok(page) if page.is_empty() => { - return std::task::Poll::Ready(None); - } - Ok(page) => { - let (mut timestamps, mut aliases): (Vec<_>, Vec<_>) = - page.into_iter().unzip(); - // reverse because we use .pop() to get next - aliases.reverse(); - - this.current = aliases; - this.older_than = timestamps.pop().expect("Verified nonempty"); - } - Err(e) => return std::task::Poll::Ready(Some(Err(e))), + if let Some((last_time, last_item)) = page.pop() { + for (_, item) in page { + yielder.yield_ok(item).await; } - } else { - let inner = this.inner.clone(); - let older_than = this.older_than; - this.future = Some((this.next)(inner, older_than)); + yielder.yield_ok(last_item).await; + + older_than = last_time; + } else { + break; } } - } + + Ok(()) + }) } impl std::fmt::Debug for PostgresRepo { diff --git a/src/store/object_store.rs b/src/store/object_store.rs index 324a0eb..7046f69 100644 --- a/src/store/object_store.rs +++ b/src/store/object_store.rs @@ -1,9 +1,6 @@ use crate::{ - bytes_stream::BytesStream, - error_code::ErrorCode, - repo::ArcRepo, - store::Store, - stream::{IntoStreamer, LocalBoxStream, StreamMap}, + bytes_stream::BytesStream, error_code::ErrorCode, repo::ArcRepo, store::Store, + stream::LocalBoxStream, }; use actix_rt::task::JoinError; use actix_web::{ @@ -21,6 +18,7 @@ use reqwest_middleware::{ClientWithMiddleware, RequestBuilder}; use rusty_s3::{actions::S3Action, Bucket, BucketError, Credentials, UrlStyle}; use std::{string::FromUtf8Error, sync::Arc, time::Duration}; use storage_path_generator::{Generator, Path}; +use streem::IntoStreamer; use tokio::io::{AsyncRead, AsyncWrite, AsyncWriteExt}; use tokio_util::io::ReaderStream; use tracing::Instrument; @@ -393,11 +391,10 @@ impl Store for ObjectStore { return Err(status_error(response).await); } - Ok(Box::pin( - response - .bytes_stream() - .map(|res| res.map_err(payload_to_io_error)), - )) + Ok(Box::pin(crate::stream::map_err( + response.bytes_stream(), + payload_to_io_error, + ))) } #[tracing::instrument(skip(self, writer))] diff --git a/src/stream.rs b/src/stream.rs index 7a67a9e..c52dc1b 100644 --- a/src/stream.rs +++ b/src/stream.rs @@ -1,269 +1,173 @@ -use actix_rt::{task::JoinHandle, time::Sleep}; use actix_web::web::Bytes; -use flume::r#async::RecvStream; use futures_core::Stream; -use std::{ - future::Future, - marker::PhantomData, - pin::Pin, - sync::{ - atomic::{AtomicBool, Ordering}, - Arc, - }, - task::{Context, Poll, Wake, Waker}, - time::Duration, -}; +use std::{pin::Pin, time::Duration}; +use streem::IntoStreamer; -pub(crate) trait MakeSend: Stream> +pub(crate) fn make_send(stream: S) -> impl Stream + Send where - T: 'static, + S: Stream + 'static, + S::Item: Send + Sync, { - fn make_send(self) -> MakeSendStream - where - Self: Sized + 'static, - { - let (tx, rx) = crate::sync::channel(4); + let (tx, rx) = crate::sync::channel(1); - MakeSendStream { - handle: crate::sync::spawn(async move { - let this = std::pin::pin!(self); + let handle = crate::sync::spawn(async move { + let stream = std::pin::pin!(stream); + let mut streamer = stream.into_streamer(); - let mut stream = this.into_streamer(); - - while let Some(res) = stream.next().await { - if tx.send_async(res).await.is_err() { - return; - } - } - }), - rx: rx.into_stream(), - } - } -} - -impl MakeSend for S -where - S: Stream>, - T: 'static, -{ -} - -pub(crate) struct MakeSendStream -where - T: 'static, -{ - handle: actix_rt::task::JoinHandle<()>, - rx: flume::r#async::RecvStream<'static, std::io::Result>, -} - -impl Stream for MakeSendStream -where - T: 'static, -{ - type Item = std::io::Result; - - fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { - match Pin::new(&mut self.rx).poll_next(cx) { - Poll::Ready(opt) => Poll::Ready(opt), - Poll::Pending if std::task::ready!(Pin::new(&mut self.handle).poll(cx)).is_err() => { - Poll::Ready(Some(Err(std::io::Error::new( - std::io::ErrorKind::UnexpectedEof, - "Stream panicked", - )))) + while let Some(res) = streamer.next().await { + if tx.send_async(res).await.is_err() { + break; } - Poll::Pending => Poll::Pending, } - } + }); + + streem::from_fn(|yiedler| async move { + let mut stream = rx.into_stream().into_streamer(); + + while let Some(res) = stream.next().await { + yiedler.yield_(res).await; + } + + let _ = handle.await; + }) } -pin_project_lite::pin_project! { - pub(crate) struct Map { - #[pin] - stream: S, - func: F, - } +pub(crate) fn from_iterator(iterator: I, buffer: usize) -> impl Stream + Send +where + I: IntoIterator + Send + 'static, + I::Item: Send + Sync, +{ + let (tx, rx) = crate::sync::channel(buffer); + + let handle = crate::sync::spawn_blocking(move || { + for value in iterator { + if tx.send(value).is_err() { + break; + } + } + }); + + streem::from_fn(|yielder| async move { + let mut stream = rx.into_stream().into_streamer(); + + while let Some(res) = stream.next().await { + yielder.yield_(res).await; + } + + let _ = handle.await; + }) } -pub(crate) trait StreamMap: Stream { - fn map(self, func: F) -> Map - where - F: FnMut(Self::Item) -> U, - Self: Sized, - { - Map { stream: self, func } - } +pub(crate) fn map(stream: S, f: F) -> impl Stream +where + S: Stream, + I2: 'static, + F: Fn(I1) -> I2 + Copy, +{ + streem::from_fn(|yielder| async move { + let stream = std::pin::pin!(stream); + let mut streamer = stream.into_streamer(); + + while let Some(res) = streamer.next().await { + yielder.yield_((f)(res)).await; + } + }) } -impl StreamMap for T where T: Stream {} +#[cfg(not(feature = "io-uring"))] +pub(crate) fn map_ok(stream: S, f: F) -> impl Stream> +where + S: Stream>, + T2: 'static, + E: 'static, + F: Fn(T1) -> T2 + Copy, +{ + map(stream, move |res| res.map(f)) +} -impl Stream for Map +pub(crate) fn map_err(stream: S, f: F) -> impl Stream> +where + S: Stream>, + T: 'static, + E2: 'static, + F: Fn(E1) -> E2 + Copy, +{ + map(stream, move |res| res.map_err(f)) +} + +pub(crate) fn from_err(stream: S) -> impl Stream> +where + S: Stream>, + T: 'static, + E1: Into, + E2: 'static, +{ + map_err(stream, Into::into) +} + +pub(crate) fn empty() -> impl Stream +where + T: 'static, +{ + streem::from_fn(|_| std::future::ready(())) +} + +pub(crate) fn once(value: T) -> impl Stream +where + T: 'static, +{ + streem::from_fn(|yielder| yielder.yield_(value)) +} + +pub(crate) fn timeout( + duration: Duration, + stream: S, +) -> impl Stream> where S: Stream, - F: FnMut(S::Item) -> U, + S::Item: 'static, { - type Item = U; + streem::try_from_fn(|yielder| async move { + actix_rt::time::timeout(duration, async move { + let stream = std::pin::pin!(stream); + let mut streamer = stream.into_streamer(); - fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { - let this = self.project(); - - let value = std::task::ready!(this.stream.poll_next(cx)); - - Poll::Ready(value.map(this.func)) - } + while let Some(res) = streamer.next().await { + yielder.yield_ok(res).await; + } + }) + .await + .map_err(|_| TimeoutError) + }) } -pub(crate) struct Empty(PhantomData); - -impl Stream for Empty { - type Item = T; - - fn poll_next(self: Pin<&mut Self>, _: &mut Context<'_>) -> Poll> { - Poll::Ready(None) - } -} - -pub(crate) fn empty() -> Empty { - Empty(PhantomData) -} - -pub(crate) struct Once(Option); - -impl Stream for Once +pub(crate) fn limit(limit: usize, stream: S) -> impl Stream> where - T: Unpin, + S: Stream>, + E: From + 'static, { - type Item = T; + streem::try_from_fn(|yielder| async move { + let stream = std::pin::pin!(stream); + let mut streamer = stream.into_streamer(); - fn poll_next(mut self: Pin<&mut Self>, _: &mut Context<'_>) -> Poll> { - Poll::Ready(self.0.take()) - } -} + let mut count = 0; -pub(crate) fn once(value: T) -> Once { - Once(Some(value)) + while let Some(bytes) = streamer.try_next().await? { + count += bytes.len(); + + if count > limit { + return Err(LimitError.into()); + } + + yielder.yield_ok(bytes).await; + } + + Ok(()) + }) } pub(crate) type LocalBoxStream<'a, T> = Pin + 'a>>; -pub(crate) trait StreamLimit { - fn limit(self, limit: u64) -> Limit - where - Self: Sized, - { - Limit { - inner: self, - count: 0, - limit, - } - } -} - -pub(crate) trait StreamTimeout { - fn timeout(self, duration: Duration) -> Timeout - where - Self: Sized, - { - Timeout { - sleep: actix_rt::time::sleep(duration), - inner: self, - expired: false, - woken: Arc::new(AtomicBool::new(true)), - } - } -} - -pub(crate) trait IntoStreamer: Stream { - fn into_streamer(self) -> Streamer - where - Self: Sized, - { - Streamer(Some(self)) - } -} - -impl IntoStreamer for T where T: Stream + Unpin {} - -pub(crate) fn from_iterator( - iterator: I, - buffer: usize, -) -> IterStream { - IterStream { - state: IterStreamState::New { iterator, buffer }, - } -} - -impl StreamLimit for S where S: Stream> {} -impl StreamTimeout for S where S: Stream {} - -pub(crate) struct Streamer(Option); - -impl Streamer { - pub(crate) async fn next(&mut self) -> Option - where - S: Stream + Unpin, - { - let stream = self.0.as_mut().take()?; - - let opt = std::future::poll_fn(|cx| Pin::new(&mut *stream).poll_next(cx)).await; - - if opt.is_none() { - self.0.take(); - } - - opt - } -} - -pin_project_lite::pin_project! { - pub(crate) struct Limit { - #[pin] - inner: S, - - count: u64, - limit: u64, - } -} - -pin_project_lite::pin_project! { - pub(crate) struct Timeout { - #[pin] - sleep: Sleep, - - #[pin] - inner: S, - - expired: bool, - woken: Arc, - } -} - -enum IterStreamState -where - T: 'static, -{ - New { - iterator: I, - buffer: usize, - }, - Running { - handle: JoinHandle<()>, - receiver: RecvStream<'static, T>, - }, - Pending, -} - -pub(crate) struct IterStream -where - T: 'static, -{ - state: IterStreamState, -} - -struct TimeoutWaker { - woken: Arc, - inner: Waker, -} - #[derive(Debug, thiserror::Error)] #[error("Resonse body larger than size limit")] pub(crate) struct LimitError; @@ -271,135 +175,3 @@ pub(crate) struct LimitError; #[derive(Debug, thiserror::Error)] #[error("Timeout in body")] pub(crate) struct TimeoutError; - -impl Stream for Limit -where - S: Stream>, - E: From, -{ - type Item = Result; - - fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { - let this = self.as_mut().project(); - - let limit = this.limit; - let count = this.count; - let inner = this.inner; - - inner.poll_next(cx).map(|opt| { - opt.map(|res| match res { - Ok(bytes) => { - *count += bytes.len() as u64; - if *count > *limit { - return Err(LimitError.into()); - } - Ok(bytes) - } - Err(e) => Err(e), - }) - }) - } -} - -impl Wake for TimeoutWaker { - fn wake(self: Arc) { - self.wake_by_ref() - } - - fn wake_by_ref(self: &Arc) { - self.woken.store(true, Ordering::Release); - self.inner.wake_by_ref(); - } -} - -impl Stream for Timeout -where - S: Stream, -{ - type Item = Result; - - fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { - let this = self.as_mut().project(); - - if *this.expired { - return Poll::Ready(None); - } - - if this.woken.swap(false, Ordering::Acquire) { - let timeout_waker = Arc::new(TimeoutWaker { - woken: Arc::clone(this.woken), - inner: cx.waker().clone(), - }) - .into(); - - let mut timeout_cx = Context::from_waker(&timeout_waker); - - if this.sleep.poll(&mut timeout_cx).is_ready() { - *this.expired = true; - return Poll::Ready(Some(Err(TimeoutError))); - } - } - - this.inner.poll_next(cx).map(|opt| opt.map(Ok)) - } -} - -impl Stream for IterStream -where - I: IntoIterator + Send + Unpin + 'static, - T: Send + 'static, -{ - type Item = T; - - fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { - let this = self.as_mut().get_mut(); - - match std::mem::replace(&mut this.state, IterStreamState::Pending) { - IterStreamState::New { iterator, buffer } => { - let (sender, receiver) = crate::sync::channel(buffer); - - let mut handle = crate::sync::spawn_blocking(move || { - let iterator = iterator.into_iter(); - - for item in iterator { - if sender.send(item).is_err() { - break; - } - } - }); - - if Pin::new(&mut handle).poll(cx).is_ready() { - return Poll::Ready(None); - } - - this.state = IterStreamState::Running { - handle, - receiver: receiver.into_stream(), - }; - - self.poll_next(cx) - } - IterStreamState::Running { - mut handle, - mut receiver, - } => match Pin::new(&mut receiver).poll_next(cx) { - Poll::Ready(Some(item)) => { - this.state = IterStreamState::Running { handle, receiver }; - - Poll::Ready(Some(item)) - } - Poll::Ready(None) => Poll::Ready(None), - Poll::Pending => { - if Pin::new(&mut handle).poll(cx).is_ready() { - return Poll::Ready(None); - } - - this.state = IterStreamState::Running { handle, receiver }; - - Poll::Pending - } - }, - IterStreamState::Pending => panic!("Polled after completion"), - } - } -} diff --git a/src/sync.rs b/src/sync.rs index 3111daf..10cc861 100644 --- a/src/sync.rs +++ b/src/sync.rs @@ -2,22 +2,27 @@ use std::sync::Arc; use tokio::sync::{Notify, Semaphore}; +#[track_caller] pub(crate) fn channel(bound: usize) -> (flume::Sender, flume::Receiver) { tracing::trace_span!(parent: None, "make channel").in_scope(|| flume::bounded(bound)) } +#[track_caller] pub(crate) fn notify() -> Arc { Arc::new(bare_notify()) } +#[track_caller] pub(crate) fn bare_notify() -> Notify { tracing::trace_span!(parent: None, "make notifier").in_scope(Notify::new) } +#[track_caller] pub(crate) fn bare_semaphore(permits: usize) -> Semaphore { tracing::trace_span!(parent: None, "make semaphore").in_scope(|| Semaphore::new(permits)) } +#[track_caller] pub(crate) fn spawn(future: F) -> actix_rt::task::JoinHandle where F: std::future::Future + 'static, @@ -26,6 +31,7 @@ where tracing::trace_span!(parent: None, "spawn task").in_scope(|| actix_rt::spawn(future)) } +#[track_caller] pub(crate) fn spawn_blocking(function: F) -> actix_rt::task::JoinHandle where F: FnOnce() -> Out + Send + 'static,