diff --git a/src/error.rs b/src/error.rs index 679f70b..940f221 100644 --- a/src/error.rs +++ b/src/error.rs @@ -118,7 +118,7 @@ pub(crate) enum UploadError { Range, #[error("Hit limit")] - Limit(#[from] super::LimitError), + Limit(#[from] crate::stream::LimitError), } impl From for UploadError { diff --git a/src/main.rs b/src/main.rs index 8d5549d..8286c17 100644 --- a/src/main.rs +++ b/src/main.rs @@ -14,13 +14,11 @@ use std::{ collections::BTreeSet, future::ready, path::PathBuf, - pin::Pin, sync::atomic::{AtomicU64, Ordering}, - task::{Context, Poll}, time::SystemTime, }; use tokio::{io::AsyncReadExt, sync::Semaphore}; -use tracing::{debug, error, info, instrument}; +use tracing::{debug, info, instrument}; use tracing_actix_web::TracingLogger; use tracing_awc::Tracing; use tracing_futures::Instrument; @@ -44,6 +42,7 @@ mod range; mod repo; mod serde_str; mod store; +mod stream; mod tmp_file; mod upload_manager; mod validate; @@ -61,6 +60,7 @@ use self::{ repo::{Alias, DeleteToken, Repo}, serde_str::Serde, store::{file_store::FileStore, object_store::ObjectStore, Store}, + stream::StreamLimit, upload_manager::{UploadManager, UploadManagerSession}, }; @@ -138,59 +138,6 @@ struct UrlQuery { url: String, } -pin_project_lite::pin_project! { - struct Limit { - #[pin] - inner: S, - - count: u64, - limit: u64, - } -} - -impl Limit { - fn new(inner: S, limit: u64) -> Self { - Limit { - inner, - count: 0, - limit, - } - } -} - -#[derive(Debug, thiserror::Error)] -#[error("Resonse body larger than size limit")] -struct LimitError; - -impl Stream for Limit -where - S: Stream>, - E: From, -{ - type Item = Result; - - fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { - let this = self.as_mut().project(); - - let limit = this.limit; - let count = this.count; - let inner = this.inner; - - inner.poll_next(cx).map(|opt| { - opt.map(|res| match res { - Ok(bytes) => { - *count += bytes.len() as u64; - if *count > *limit { - return Err(LimitError.into()); - } - Ok(bytes) - } - Err(e) => Err(e), - }) - }) - } -} - /// download an image from a URL #[instrument(name = "Downloading file", skip(client, manager))] async fn download( @@ -205,10 +152,9 @@ async fn download( return Err(UploadError::Download(res.status()).into()); } - let stream = Limit::new( - res.map_err(Error::from), - (CONFIG.media.max_file_size * MEGABYTES) as u64, - ); + let stream = res + .map_err(Error::from) + .limit((CONFIG.media.max_file_size * MEGABYTES) as u64); futures_util::pin_mut!(stream); diff --git a/src/repo/sled.rs b/src/repo/sled.rs index a049861..b65b562 100644 --- a/src/repo/sled.rs +++ b/src/repo/sled.rs @@ -4,9 +4,11 @@ use crate::{ Alias, AliasRepo, AlreadyExists, DeleteToken, Details, HashRepo, Identifier, IdentifierRepo, QueueRepo, SettingsRepo, }, + stream::from_iterator, }; +use futures_util::Stream; use sled::{Db, IVec, Tree}; -use std::sync::Arc; +use std::{pin::Pin, sync::Arc}; use tokio::sync::Notify; use super::BaseRepo; @@ -205,65 +207,8 @@ impl IdentifierRepo for SledRepo { } } -type BoxIterator<'a, T> = Box + Send + 'a>; - -type HashIterator = BoxIterator<'static, Result>; - type StreamItem = Result; - -type NextFutResult = Result<(HashIterator, Option), Error>; - -pub(crate) struct HashStream { - hashes: Option, - next_fut: Option>, -} - -impl futures_util::Stream for HashStream { - type Item = StreamItem; - - fn poll_next( - self: std::pin::Pin<&mut Self>, - cx: &mut std::task::Context<'_>, - ) -> std::task::Poll> { - let this = self.get_mut(); - - if let Some(mut fut) = this.next_fut.take() { - match fut.as_mut().poll(cx) { - std::task::Poll::Ready(Ok((iter, opt))) => { - this.hashes = Some(iter); - std::task::Poll::Ready(opt) - } - std::task::Poll::Ready(Err(e)) => std::task::Poll::Ready(Some(Err(e))), - std::task::Poll::Pending => { - this.next_fut = Some(fut); - std::task::Poll::Pending - } - } - } else if let Some(mut iter) = this.hashes.take() { - let fut = Box::pin(async move { - actix_rt::task::spawn_blocking(move || { - let opt = iter.next(); - - (iter, opt) - }) - .await - .map(|(iter, opt)| { - ( - iter, - opt.map(|res| res.map_err(SledError::from).map_err(Error::from)), - ) - }) - .map_err(SledError::from) - .map_err(Error::from) - }); - - this.next_fut = Some(fut); - std::pin::Pin::new(this).poll_next(cx) - } else { - std::task::Poll::Ready(None) - } - } -} +type LocalBoxStream<'a, T> = Pin + 'a>>; fn hash_alias_key(hash: &IVec, alias: &Alias) -> Vec { let mut v = hash.to_vec(); @@ -273,15 +218,16 @@ fn hash_alias_key(hash: &IVec, alias: &Alias) -> Vec { #[async_trait::async_trait(?Send)] impl HashRepo for SledRepo { - type Stream = HashStream; + type Stream = LocalBoxStream<'static, StreamItem>; async fn hashes(&self) -> Self::Stream { - let iter = self.hashes.iter().keys(); + let iter = self + .hashes + .iter() + .keys() + .map(|res| res.map_err(Error::from)); - HashStream { - hashes: Some(Box::new(iter)), - next_fut: None, - } + Box::pin(from_iterator(iter)) } #[tracing::instrument] diff --git a/src/store/object_store.rs b/src/store/object_store.rs index edddd71..f2bf76b 100644 --- a/src/store/object_store.rs +++ b/src/store/object_store.rs @@ -2,22 +2,16 @@ use crate::{ error::Error, repo::{Repo, SettingsRepo}, store::Store, + stream::StreamTimeout, }; -use actix_rt::time::Sleep; use actix_web::web::Bytes; -use futures_util::{stream::Stream, TryStreamExt}; +use futures_util::{Stream, StreamExt}; use s3::{ client::Client, command::Command, creds::Credentials, request_trait::Request, Bucket, Region, }; use std::{ - future::Future, pin::Pin, string::FromUtf8Error, - sync::{ - atomic::{AtomicBool, Ordering}, - Arc, - }, - task::{Context, Poll, Wake, Waker}, time::{Duration, Instant}, }; use storage_path_generator::{Generator, Path}; @@ -58,17 +52,6 @@ pub(crate) struct ObjectStore { client: reqwest::Client, } -pin_project_lite::pin_project! { - struct Timeout { - sleep: Option>>, - - woken: Arc, - - #[pin] - inner: S, - } -} - #[async_trait::async_trait(?Send)] impl Store for ObjectStore { type Identifier = ObjectId; @@ -139,11 +122,12 @@ impl Store for ObjectStore { let allotted = allotted.saturating_sub(now.elapsed()); - let stream = response - .bytes_stream() - .map_err(|e| std::io::Error::new(std::io::ErrorKind::Other, e)); + let stream = response.bytes_stream().timeout(allotted).map(|res| { + res.map_err(|e| std::io::Error::new(std::io::ErrorKind::Other, e)) + .and_then(|res| res.map_err(|e| std::io::Error::new(std::io::ErrorKind::Other, e))) + }); - Ok(request_span.in_scope(|| Box::pin(timeout(allotted, stream)))) + Ok(request_span.in_scope(|| Box::pin(stream))) } #[tracing::instrument(skip(writer))] @@ -266,67 +250,6 @@ async fn init_generator(repo: &Repo) -> Result { } } -fn timeout(duration: Duration, stream: S) -> impl Stream> -where - S: Stream>, -{ - Timeout { - sleep: Some(Box::pin(actix_rt::time::sleep(duration))), - woken: Arc::new(AtomicBool::new(true)), - inner: stream, - } -} - -struct TimeoutWaker { - woken: Arc, - inner: Waker, -} - -impl Wake for TimeoutWaker { - fn wake(self: Arc) { - self.wake_by_ref() - } - - fn wake_by_ref(self: &Arc) { - self.woken.store(true, Ordering::Release); - self.inner.wake_by_ref(); - } -} - -impl Stream for Timeout -where - S: Stream>, -{ - type Item = std::io::Result; - - fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { - let this = self.as_mut().project(); - - if this.woken.swap(false, Ordering::Acquire) { - if let Some(mut sleep) = this.sleep.take() { - let timeout_waker = Arc::new(TimeoutWaker { - woken: Arc::clone(this.woken), - inner: cx.waker().clone(), - }) - .into(); - let mut timeout_cx = Context::from_waker(&timeout_waker); - if let Poll::Ready(()) = sleep.as_mut().poll(&mut timeout_cx) { - return Poll::Ready(Some(Err(std::io::Error::new( - std::io::ErrorKind::Other, - Error::from(ObjectError::Elapsed), - )))); - } else { - *this.sleep = Some(sleep); - } - } else { - return Poll::Ready(None); - } - } - - this.inner.poll_next(cx) - } -} - impl std::fmt::Debug for ObjectStore { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { f.debug_struct("ObjectStore") diff --git a/src/stream.rs b/src/stream.rs new file mode 100644 index 0000000..7ded6bd --- /dev/null +++ b/src/stream.rs @@ -0,0 +1,228 @@ +use actix_rt::{task::JoinHandle, time::Sleep}; +use actix_web::web::Bytes; +use futures_util::Stream; +use std::{ + future::Future, + pin::Pin, + sync::{ + atomic::{AtomicBool, Ordering}, + Arc, + }, + task::{Context, Poll, Wake, Waker}, + time::Duration, +}; + +pub(crate) trait StreamLimit { + fn limit(self, limit: u64) -> Limit + where + Self: Sized, + { + Limit { + inner: self, + count: 0, + limit, + } + } +} + +pub(crate) trait StreamTimeout { + fn timeout(self, duration: Duration) -> Timeout + where + Self: Sized, + { + Timeout { + sleep: actix_rt::time::sleep(duration), + inner: self, + expired: false, + woken: Arc::new(AtomicBool::new(true)), + } + } +} + +pub(crate) fn from_iterator( + iterator: I, +) -> IterStream { + IterStream { + state: IterStreamState::New { iterator }, + } +} + +impl StreamLimit for S where S: Stream> {} +impl StreamTimeout for S where S: Stream {} + +pin_project_lite::pin_project! { + pub(crate) struct Limit { + #[pin] + inner: S, + + count: u64, + limit: u64, + } +} + +pin_project_lite::pin_project! { + pub(crate) struct Timeout { + #[pin] + sleep: Sleep, + + #[pin] + inner: S, + + expired: bool, + woken: Arc, + } +} + +enum IterStreamState { + New { + iterator: I, + }, + Running { + handle: JoinHandle<()>, + receiver: tokio::sync::mpsc::Receiver, + }, + Pending, +} + +pub(crate) struct IterStream { + state: IterStreamState, +} + +struct TimeoutWaker { + woken: Arc, + inner: Waker, +} + +#[derive(Debug, thiserror::Error)] +#[error("Resonse body larger than size limit")] +pub(crate) struct LimitError; + +#[derive(Debug, thiserror::Error)] +#[error("Timeout in body")] +pub(crate) struct TimeoutError; + +impl Stream for Limit +where + S: Stream>, + E: From, +{ + type Item = Result; + + fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + let this = self.as_mut().project(); + + let limit = this.limit; + let count = this.count; + let inner = this.inner; + + inner.poll_next(cx).map(|opt| { + opt.map(|res| match res { + Ok(bytes) => { + *count += bytes.len() as u64; + if *count > *limit { + return Err(LimitError.into()); + } + Ok(bytes) + } + Err(e) => Err(e), + }) + }) + } +} + +impl Wake for TimeoutWaker { + fn wake(self: Arc) { + self.wake_by_ref() + } + + fn wake_by_ref(self: &Arc) { + self.woken.store(true, Ordering::Release); + self.inner.wake_by_ref(); + } +} + +impl Stream for Timeout +where + S: Stream, +{ + type Item = Result; + + fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + let this = self.as_mut().project(); + + if *this.expired { + return Poll::Ready(None); + } + + if this.woken.swap(false, Ordering::Acquire) { + let timeout_waker = Arc::new(TimeoutWaker { + woken: Arc::clone(this.woken), + inner: cx.waker().clone(), + }) + .into(); + + let mut timeout_cx = Context::from_waker(&timeout_waker); + + if this.sleep.poll(&mut timeout_cx).is_ready() { + *this.expired = true; + return Poll::Ready(Some(Err(TimeoutError))); + } + } + + this.inner.poll_next(cx).map(|opt| opt.map(Ok)) + } +} + +impl Stream for IterStream +where + I: IntoIterator + Send + Unpin + 'static, + T: Send + 'static, +{ + type Item = T; + + fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + let this = self.as_mut().get_mut(); + + match std::mem::replace(&mut this.state, IterStreamState::Pending) { + IterStreamState::New { iterator } => { + let (sender, receiver) = tokio::sync::mpsc::channel(1); + + let mut handle = actix_rt::task::spawn_blocking(move || { + let iterator = iterator.into_iter(); + + for item in iterator { + if sender.blocking_send(item).is_err() { + break; + } + } + }); + + if Pin::new(&mut handle).poll(cx).is_ready() { + return Poll::Ready(None); + } + + this.state = IterStreamState::Running { handle, receiver }; + } + IterStreamState::Running { + mut handle, + mut receiver, + } => match Pin::new(&mut receiver).poll_recv(cx) { + Poll::Ready(Some(item)) => { + if Pin::new(&mut handle).poll(cx).is_ready() { + return Poll::Ready(Some(item)); + } + + this.state = IterStreamState::Running { handle, receiver }; + } + Poll::Ready(None) => return Poll::Ready(None), + Poll::Pending => { + this.state = IterStreamState::Running { handle, receiver }; + return Poll::Pending; + } + }, + IterStreamState::Pending => return Poll::Ready(None), + } + + self.poll_next(cx) + } +}