diff --git a/src/error.rs b/src/error.rs index 0c4067b..7ce0ac2 100644 --- a/src/error.rs +++ b/src/error.rs @@ -65,6 +65,8 @@ pub enum Error { MissingMiddleware, #[error("Impossible Error! Middleware exists, didn't fail, and didn't send value")] TxDropped, + #[error("Panic in spawned future")] + Panic, } impl From for Error { @@ -73,6 +75,12 @@ impl From for Error { } } +impl From for Error { + fn from(_: crate::spawn::Canceled) -> Self { + Error::Panic + } +} + impl ResponseError for Error { fn status_code(&self) -> StatusCode { match *self { @@ -100,7 +108,7 @@ impl ResponseError for Error { | Error::Filename | Error::FileCount | Error::FileSize => HttpResponse::BadRequest().finish(), - Error::MissingMiddleware | Error::TxDropped => { + Error::Panic | Error::MissingMiddleware | Error::TxDropped => { HttpResponse::InternalServerError().finish() } } diff --git a/src/lib.rs b/src/lib.rs index 551cd1a..eaea46e 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -76,6 +76,7 @@ mod error; mod middleware; +mod spawn; mod types; mod upload; @@ -84,3 +85,5 @@ pub use self::{ types::{Field, FileMeta, Form, Value}, upload::handle_multipart, }; + +use self::spawn::spawn; diff --git a/src/spawn.rs b/src/spawn.rs new file mode 100644 index 0000000..898eda5 --- /dev/null +++ b/src/spawn.rs @@ -0,0 +1,23 @@ +use log::error; +use std::future::Future; +use tokio::sync::oneshot::channel; + +#[derive(Debug, thiserror::Error)] +#[error("Task panicked")] +pub(crate) struct Canceled; + +pub(crate) async fn spawn(f: F) -> Result +where + F: Future + 'static, + T: 'static, +{ + let (tx, rx) = channel(); + + actix_rt::spawn(async move { + if let Err(_) = tx.send(f.await) { + error!("rx dropped (this shouldn't happen)"); + } + }); + + rx.await.map_err(|_| Canceled) +} diff --git a/src/upload.rs b/src/upload.rs index 2a15a52..b4fca7d 100644 --- a/src/upload.rs +++ b/src/upload.rs @@ -25,7 +25,10 @@ use crate::{ }, }; use bytes::BytesMut; -use futures::stream::StreamExt; +use futures::{ + select, + stream::{FuturesUnordered, StreamExt}, +}; use log::trace; use std::{ collections::HashMap, @@ -207,35 +210,62 @@ async fn handle_stream_field( } /// Handle multipart streams from Actix Web -pub async fn handle_multipart( - mut m: actix_multipart::Multipart, - form: Form, -) -> Result { +pub async fn handle_multipart(m: actix_multipart::Multipart, form: Form) -> Result { let mut multipart_form = Vec::new(); let mut file_count: u32 = 0; let mut field_count: u32 = 0; - while let Some(res) = m.next().await { - let field = res?; - let (name_parts, content) = handle_stream_field(field, form.clone()).await?; + let mut unordered = FuturesUnordered::new(); - match content { - MultipartContent::File(_) => { - file_count += 1; - if file_count >= form.max_files { - return Err(Error::FileCount); + let mut m = m.fuse(); + + loop { + select! { + opt = m.next() => { + if let Some(res) = opt { + let field = res?; + + unordered.push(crate::spawn(handle_stream_field(field, form.clone()))); } } - _ => { - field_count += 1; - if field_count >= form.max_fields { - return Err(Error::FieldCount); + opt = unordered.next() => { + if let Some(res) = opt { + let (name_parts, content) = res??; + + let (l, r) = count(&content, file_count, field_count, &form)?; + file_count = l; + field_count = r; + + multipart_form.push((name_parts, content)); } } + complete => break, } - - multipart_form.push((name_parts, content)); } Ok(consolidate(multipart_form)) } + +fn count( + content: &MultipartContent, + mut file_count: u32, + mut field_count: u32, + form: &Form, +) -> Result<(u32, u32), Error> { + match content { + MultipartContent::File(_) => { + file_count += 1; + if file_count >= form.max_files { + return Err(Error::FileCount); + } + } + _ => { + field_count += 1; + if field_count >= form.max_fields { + return Err(Error::FieldCount); + } + } + } + + Ok((file_count, field_count)) +}