diff --git a/src/error.rs b/src/error.rs index bbe3c85..cec2cc2 100644 --- a/src/error.rs +++ b/src/error.rs @@ -128,6 +128,9 @@ pub(crate) enum UploadError { #[error("Command failed")] Status, + + #[error(transparent)] + Limit(#[from] super::LimitError), } impl From for UploadError { @@ -152,6 +155,7 @@ impl ResponseError for Error { fn status_code(&self) -> StatusCode { match self.kind { UploadError::DuplicateAlias + | UploadError::Limit(_) | UploadError::NoFiles | UploadError::Upload(_) | UploadError::ParseReq(_) => StatusCode::BAD_REQUEST, diff --git a/src/file.rs b/src/file.rs index 2c8e79d..703251f 100644 --- a/src/file.rs +++ b/src/file.rs @@ -11,12 +11,18 @@ pub(crate) use io_uring::File; pub(crate) use tokio_file::File; pin_project_lite::pin_project! { - struct CrateError { + pub(super) struct CrateError { #[pin] inner: S } } +impl CrateError { + pub(super) fn new(inner: S) -> Self { + CrateError { inner } + } +} + impl Stream for CrateError where S: Stream>, @@ -112,11 +118,10 @@ mod tokio_file { (None, None) => Either::right(self.inner), }; - Ok(super::CrateError { - inner: BytesFreezer { - inner: FramedRead::new(obj, BytesCodec::new()), - }, - }) + Ok(super::CrateError::new(BytesFreezer::new(FramedRead::new( + obj, + BytesCodec::new(), + )))) } } @@ -127,6 +132,12 @@ mod tokio_file { } } + impl BytesFreezer { + fn new(inner: S) -> Self { + BytesFreezer { inner } + } + } + impl Stream for BytesFreezer where S: Stream> + Unpin, diff --git a/src/main.rs b/src/main.rs index c6b3931..4cd0125 100644 --- a/src/main.rs +++ b/src/main.rs @@ -57,6 +57,7 @@ use self::{ config::{Config, Format}, either::Either, error::{Error, UploadError}, + file::CrateError, middleware::{Deadline, Internal}, upload_manager::{Details, UploadManager, UploadManagerSession}, validate::{image_webp, video_mp4}, @@ -340,6 +341,59 @@ struct UrlQuery { url: String, } +pin_project_lite::pin_project! { + struct Limit { + #[pin] + inner: S, + + count: u64, + limit: u64, + } +} + +impl Limit { + fn new(inner: S, limit: u64) -> Self { + Limit { + inner, + count: 0, + limit, + } + } +} + +#[derive(Debug, thiserror::Error)] +#[error("Resonse body larger than size limit")] +struct LimitError; + +impl Stream for Limit +where + S: Stream>, + E: From, +{ + type Item = Result; + + fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + let this = self.as_mut().project(); + + let limit = this.limit; + let count = this.count; + let inner = this.inner; + + inner.poll_next(cx).map(|opt| { + opt.map(|res| match res { + Ok(bytes) => { + *count += bytes.len() as u64; + if *count > *limit { + return Err(LimitError.into()); + } + Ok(bytes) + } + Err(e) => Err(e), + }) + }) + } +} + /// download an image from a URL #[instrument(name = "Downloading file", skip(client, manager))] async fn download( @@ -347,15 +401,19 @@ async fn download( manager: web::Data, query: web::Query, ) -> Result { - let mut res = client.get(&query.url).propagate().send().await?; + let res = client.get(&query.url).propagate().send().await?; if !res.status().is_success() { return Err(UploadError::Download(res.status()).into()); } - let fut = res.body().limit(CONFIG.max_file_size() * MEGABYTES); + let mut stream = Limit::new( + CrateError::new(res), + (CONFIG.max_file_size() * MEGABYTES) as u64, + ); - let stream = Box::pin(once(fut)); + // SAFETY: stream is shadowed, so original cannot not be moved + let stream = unsafe { Pin::new_unchecked(&mut stream) }; let permit = PROCESS_SEMAPHORE.acquire().await?; let session = manager.session().upload(stream).await?; @@ -743,7 +801,7 @@ async fn ranged_file_resp( Ok(srv_response( builder, - Box::pin(stream), + stream, details.content_type(), 7 * DAYS, details.system_time(), @@ -759,7 +817,7 @@ fn srv_response( modified: SystemTime, ) -> HttpResponse where - S: Stream> + Unpin + 'static, + S: Stream> + 'static, E: std::error::Error + 'static, actix_web::Error: From, { @@ -772,7 +830,8 @@ where ])) .insert_header((ACCEPT_RANGES, "bytes")) .content_type(ext.to_string()) - .streaming(stream) + // TODO: remove pin when actix-web drops Unpin requirement + .streaming(Box::pin(stream)) } #[derive(Debug, serde::Deserialize)] diff --git a/src/stream.rs b/src/stream.rs index 33135dc..cfd2994 100644 --- a/src/stream.rs +++ b/src/stream.rs @@ -22,12 +22,19 @@ pub(crate) struct Process { span: Span, } -pub(crate) struct ProcessRead { - inner: I, - span: Span, - err_recv: Receiver, - err_closed: bool, - handle: JoinHandle<()>, +struct DropHandle { + inner: JoinHandle<()>, +} + +pin_project_lite::pin_project! { + struct ProcessRead { + #[pin] + inner: I, + span: Span, + err_recv: Receiver, + err_closed: bool, + handle: DropHandle, + } } impl Process { @@ -86,13 +93,13 @@ impl Process { .instrument(span), ); - Some(Box::pin(ProcessRead { + Some(ProcessRead { inner: stdout, span: self.span, err_recv: rx, err_closed: false, - handle, - })) + handle: DropHandle { inner: handle }, + }) } pub(crate) fn file_read( @@ -129,30 +136,36 @@ impl Process { .instrument(span), ); - Some(Box::pin(ProcessRead { + Some(ProcessRead { inner: stdout, span: self.span, err_recv: rx, err_closed: false, - handle, - })) + handle: DropHandle { inner: handle }, + }) } } impl AsyncRead for ProcessRead where - I: AsyncRead + Unpin, + I: AsyncRead, { fn poll_read( mut self: Pin<&mut Self>, cx: &mut Context<'_>, buf: &mut ReadBuf<'_>, ) -> Poll> { - let span = self.span.clone(); + let this = self.as_mut().project(); + + let span = this.span; + let err_recv = this.err_recv; + let err_closed = this.err_closed; + let inner = this.inner; + span.in_scope(|| { - if !self.err_closed { - if let Poll::Ready(res) = Pin::new(&mut self.err_recv).poll(cx) { - self.err_closed = true; + if !*err_closed { + if let Poll::Ready(res) = Pin::new(err_recv).poll(cx) { + *err_closed = true; if let Ok(err) = res { let display = format!("{}", err); let debug = format!("{:?}", err); @@ -163,7 +176,7 @@ where } } - if let Poll::Ready(res) = Pin::new(&mut self.inner).poll_read(cx, buf) { + if let Poll::Ready(res) = inner.poll_read(cx, buf) { if let Err(err) = &res { let display = format!("{}", err); let debug = format!("{:?}", err); @@ -178,9 +191,9 @@ where } } -impl Drop for ProcessRead { +impl Drop for DropHandle { fn drop(&mut self) { - self.handle.abort(); + self.inner.abort(); } } diff --git a/src/upload_manager/session.rs b/src/upload_manager/session.rs index 3bc0c47..718f194 100644 --- a/src/upload_manager/session.rs +++ b/src/upload_manager/session.rs @@ -9,15 +9,13 @@ use crate::{ }, }; use actix_web::web; -use futures_util::stream::{LocalBoxStream, StreamExt}; +use futures_util::stream::{Stream, StreamExt}; use std::path::PathBuf; use tokio::io::AsyncRead; use tracing::{debug, instrument, warn, Span}; use tracing_futures::Instrument; use uuid::Uuid; -type UploadStream = LocalBoxStream<'static, Result>; - pub(crate) struct UploadManagerSession { manager: UploadManager, alias: Option, @@ -136,7 +134,7 @@ impl UploadManagerSession { alias: String, content_type: mime::Mime, validate: bool, - mut stream: UploadStream, + mut stream: impl Stream> + Unpin, ) -> Result where Error: From, @@ -177,7 +175,10 @@ impl UploadManagerSession { /// Upload the file, discarding bytes if it's already present, or saving if it's new #[instrument(skip(self, stream))] - pub(crate) async fn upload(mut self, mut stream: UploadStream) -> Result + pub(crate) async fn upload( + mut self, + mut stream: impl Stream> + Unpin, + ) -> Result where Error: From, {