diff --git a/src/error.rs b/src/error.rs index cbc00b0..8b57bd9 100644 --- a/src/error.rs +++ b/src/error.rs @@ -19,7 +19,7 @@ use std::{ num::{ParseFloatError, ParseIntError}, - string::FromUtf8Error, + str::Utf8Error, }; use actix_web::{ @@ -35,7 +35,7 @@ pub enum Error { #[error("Error in multipart creation")] Multipart(#[from] MultipartError), #[error("Failed to parse field")] - ParseField(#[from] FromUtf8Error), + ParseField(#[from] Utf8Error), #[error("Failed to parse int")] ParseInt(#[from] ParseIntError), #[error("Failed to parse float")] diff --git a/src/upload.rs b/src/upload.rs index d106e55..ef421df 100644 --- a/src/upload.rs +++ b/src/upload.rs @@ -25,43 +25,9 @@ use crate::{ }, }; use actix_web::web::BytesMut; -use std::{collections::HashMap, future::poll_fn, path::Path, pin::Pin, rc::Rc}; -use tokio::{sync::mpsc::Receiver, task::JoinSet}; -use tracing::trace; - -struct Streamer(S, bool); - -impl Streamer -where - S: futures_core::Stream + Unpin, -{ - async fn next(&mut self) -> Option { - if self.1 { - return None; - } - - let opt = poll_fn(|cx| Pin::new(&mut self.0).poll_next(cx)).await; - self.1 = opt.is_none(); - opt - } - - fn is_closed(&self) -> bool { - self.1 - } -} - -struct ReceiverStream(Receiver); - -impl futures_core::Stream for ReceiverStream { - type Item = T; - - fn poll_next( - mut self: Pin<&mut Self>, - cx: &mut std::task::Context<'_>, - ) -> std::task::Poll> { - Pin::new(&mut self.0).poll_recv(cx) - } -} +use std::{collections::HashMap, path::Path, rc::Rc}; +use streem::IntoStreamer; +use tokio::task::JoinSet; fn consolidate(mf: MultipartForm) -> Value { mf.into_iter().fold( @@ -100,16 +66,13 @@ fn parse_multipart_name(name: String) -> Result, Error> { NamePart::Map(part.to_owned()) } }) - .fold(Ok(vec![]), |acc, part| match acc { - Ok(mut v) => { - if v.is_empty() && !part.is_map() { - return Err(Error::ContentDisposition); - } - - v.push(part); - Ok(v) + .try_fold(vec![], |mut v, part| { + if v.is_empty() && !part.is_map() { + return Err(Error::ContentDisposition); } - Err(e) => Err(e), + + v.push(part); + Ok(v) }) } @@ -143,42 +106,40 @@ where let max_file_size = form.max_file_size; - let (tx, rx) = tokio::sync::mpsc::channel(8); - - let consume_fut = async move { + let field_stream = streem::try_from_fn(move |yielder| async move { let mut file_size = 0; - let mut exceeded_size = false; - let mut stream = Streamer(field, false); + let mut stream = field.into_streamer(); - while let Some(res) = stream.next().await { + while let Some(bytes) = stream.try_next().await? { tracing::trace!("Bytes from field"); - if exceeded_size { - tracing::trace!("Dropping oversized bytes"); - continue; + file_size += bytes.len(); + + if file_size > max_file_size { + drop(bytes); + + while stream.try_next().await?.is_some() { + tracing::trace!("Dropping oversized bytes"); + } + + return Err(Error::FileSize); } - if let Ok(bytes) = &res { - file_size += bytes.len(); - if file_size > max_file_size { - exceeded_size = true; - let _ = tx.send(Err(Error::FileSize)).await; - } - }; - - let _ = tx.send(res.map_err(Error::from)).await; + yielder.yield_ok(bytes).await; } - drop(tx); tracing::debug!("Finished consuming field"); - }; - let stream = ReceiverStream(rx); + Ok(()) + }); - let result_fut = file_fn(filename.clone(), content_type.clone(), Box::pin(stream)); - - let (_, result) = tokio::join!(consume_fut, result_fut); + let result = file_fn( + filename.clone(), + content_type.clone(), + Box::pin(field_stream), + ) + .await; match result { Ok(result) => Ok(Ok(MultipartContent::File(FileMeta { @@ -199,59 +160,54 @@ where T: 'static, E: 'static, { - trace!("In handle_form_data, term: {:?}", term); - let mut bytes = BytesMut::new(); + tracing::trace!("In handle_form_data, term: {:?}", term); + let mut buf = Vec::new(); - let mut exceeded_size = false; - let mut error = None; + let mut stream = field.into_streamer(); - let mut stream = Streamer(field, false); - - while let Some(res) = stream.next().await { + while let Some(bytes) = stream.try_next().await? { tracing::trace!("bytes from field"); - if exceeded_size { - tracing::trace!("Dropping oversized bytes"); - continue; - } + if buf.len() + bytes.len() > form.max_field_size { + drop(buf); - if error.is_some() { - tracing::trace!("Draining field while error exists"); - continue; - } - - let b = match res { - Ok(bytes) => bytes, - Err(e) if error.is_none() => { - error = Some(e); - continue; + while stream.try_next().await?.is_some() { + tracing::trace!("Dropping oversized bytes"); } - Err(_) => continue, - }; - if bytes.len() + b.len() > form.max_field_size { - exceeded_size = true; - continue; + return Err(Error::FieldSize); } - bytes.extend(b); + buf.push(bytes); } + let bytes = match buf.len() { + 0 => return Err(Error::FieldSize), + 1 => buf.pop().expect("contains an element"), + _ => { + let total_length = buf.iter().map(|b| b.len()).sum(); + + let mut bytes = BytesMut::with_capacity(total_length); + + for b in buf { + bytes.extend(b); + } + + bytes.freeze() + } + }; + tracing::debug!("Finished consuming field"); - if let Some(error) = error { - return Err(error.into()); - } - if let FieldTerminator::Bytes = term { - return Ok(MultipartContent::Bytes(bytes.freeze())); + return Ok(MultipartContent::Bytes(bytes)); } - let s = String::from_utf8(bytes.to_vec()).map_err(Error::ParseField)?; + let s = std::str::from_utf8(&bytes).map_err(Error::ParseField)?; match term { FieldTerminator::Bytes | FieldTerminator::File(_) => Err(Error::FieldType), - FieldTerminator::Text => Ok(MultipartContent::Text(s)), + FieldTerminator::Text => Ok(MultipartContent::Text(String::from(s))), FieldTerminator::Float => s .parse() .map_err(Error::ParseFloat) @@ -294,6 +250,7 @@ where } /// Handle multipart streams from Actix Web +#[tracing::instrument(level = "TRACE", skip_all)] pub async fn handle_multipart( m: actix_multipart::Multipart, form: Rc>, @@ -308,25 +265,33 @@ where let mut set = JoinSet::new(); - let mut m = Streamer(m, false); + let mut m = m.into_streamer(); let mut error: Option = None; let mut provided_error: Option = None; + let mut is_closed = false; + let mut stream_error = false; 'outer: loop { tracing::trace!("multipart loop"); if error.is_some() || provided_error.is_some() { - while let Some(res) = m.next().await { - if let Ok(field) = res { - let mut stream = Streamer(field, false); - while let Some(_res) = stream.next().await { - tracing::trace!("Throwing away uploaded bytes, we have an error"); + set.abort_all(); + + if !stream_error { + while let Some(res) = m.next().await { + if let Ok(field) = res { + let mut stream = field.into_streamer(); + while stream.next().await.is_some() { + tracing::trace!("Throwing away uploaded bytes, we have an error"); + } + } else { + break; } } } - while let Some(_) = set.join_next().await { + while set.join_next().await.is_some() { tracing::trace!("Throwing away joined result"); } @@ -334,14 +299,18 @@ where } tokio::select! { - opt = m.next(), if !m.is_closed() => { + opt = m.next(), if !is_closed => { tracing::trace!("Selected stream"); + is_closed = opt.is_none(); + if let Some(res) = opt { match res { Ok(field) => { set.spawn_local(handle_stream_field(field, Rc::clone(&form))); }, Err(e) => { + is_closed = true; + stream_error = true; error = Some(e.into()); continue 'outer; }