Attempt to drain multipart before returning errors

This commit is contained in:
asonix 2023-07-16 13:49:21 -05:00
parent 46e5834b60
commit 20f80df9fd
5 changed files with 200 additions and 78 deletions

View file

@ -1,7 +1,7 @@
[package]
name = "actix-form-data"
description = "Multipart Form Data for Actix Web"
version = "0.7.0-beta.2"
version = "0.7.0-beta.3"
license = "GPL-3.0"
authors = ["asonix <asonix@asonix.dog>"]
repository = "https://git.asonix.dog/asonix/actix-form-data.git"
@ -13,16 +13,18 @@ edition = "2021"
actix-multipart = { version = "0.6.0", default-features = false }
actix-rt = "2.5.0"
actix-web = { version = "4.0.0", default-features = false }
futures-util = "0.3.17"
futures-core = "0.3.28"
mime = "0.3.16"
thiserror = "1.0"
tokio = { version = "1", default-features = false, features = ["sync"] }
tokio = { version = "1", default-features = false, features = ["macros", "sync"] }
tokio-stream = "0.1.14"
tracing = "0.1.15"
[dev-dependencies]
async-fs = "1.2.1"
anyhow = "1.0"
futures-lite = "1.4.0"
futures-util = "0.3.17"
serde = { version = "1.0", features = ["derive"] }
serde_json = "1.0"
thiserror = "1.0"

View file

@ -23,7 +23,7 @@ use std::{
};
use actix_web::{
error::{PayloadError, ResponseError, ParseError},
error::{ParseError, PayloadError, ResponseError},
http::StatusCode,
HttpResponse,
};
@ -58,6 +58,8 @@ pub enum Error {
FileCount,
#[error("File too large")]
FileSize,
#[error("Task panicked")]
Panicked,
}
#[derive(Debug, thiserror::Error)]
@ -79,10 +81,7 @@ pub enum MultipartError {
#[error("Multipart stream is not consumed")]
NotConsumed,
#[error("An error occured processing field `{field_name}`: `{zource}`")]
Field {
field_name: String,
zource: String,
},
Field { field_name: String, zource: String },
#[error("Duplicate field found for: `{0}")]
DuplicateField(String),
#[error("Field with name `{0}` is required")]
@ -96,24 +95,51 @@ pub enum MultipartError {
impl From<actix_multipart::MultipartError> for Error {
fn from(value: actix_multipart::MultipartError) -> Self {
match value {
actix_multipart::MultipartError::NoContentDisposition => Error::Multipart(MultipartError::NoContentDisposition),
actix_multipart::MultipartError::NoContentType => Error::Multipart(MultipartError::NoContentType),
actix_multipart::MultipartError::ParseContentType => Error::Multipart(MultipartError::ParseContentType),
actix_multipart::MultipartError::NoContentDisposition => {
Error::Multipart(MultipartError::NoContentDisposition)
}
actix_multipart::MultipartError::NoContentType => {
Error::Multipart(MultipartError::NoContentType)
}
actix_multipart::MultipartError::ParseContentType => {
Error::Multipart(MultipartError::ParseContentType)
}
actix_multipart::MultipartError::Boundary => Error::Multipart(MultipartError::Boundary),
actix_multipart::MultipartError::Nested => Error::Multipart(MultipartError::Nested),
actix_multipart::MultipartError::Incomplete => Error::Multipart(MultipartError::Incomplete),
actix_multipart::MultipartError::Incomplete => {
Error::Multipart(MultipartError::Incomplete)
}
actix_multipart::MultipartError::Parse(e) => Error::Multipart(MultipartError::Parse(e)),
actix_multipart::MultipartError::Payload(e) => Error::Payload(e),
actix_multipart::MultipartError::NotConsumed => Error::Multipart(MultipartError::NotConsumed),
actix_multipart::MultipartError::Field { field_name, source } => Error::Multipart(MultipartError::Field { field_name, zource: source.to_string() }),
actix_multipart::MultipartError::DuplicateField(s) => Error::Multipart(MultipartError::DuplicateField(s)),
actix_multipart::MultipartError::MissingField(s) => Error::Multipart(MultipartError::MissingField(s)),
actix_multipart::MultipartError::UnsupportedField(s) => Error::Multipart(MultipartError::UnsupportedField(s)),
actix_multipart::MultipartError::NotConsumed => {
Error::Multipart(MultipartError::NotConsumed)
}
actix_multipart::MultipartError::Field { field_name, source } => {
Error::Multipart(MultipartError::Field {
field_name,
zource: source.to_string(),
})
}
actix_multipart::MultipartError::DuplicateField(s) => {
Error::Multipart(MultipartError::DuplicateField(s))
}
actix_multipart::MultipartError::MissingField(s) => {
Error::Multipart(MultipartError::MissingField(s))
}
actix_multipart::MultipartError::UnsupportedField(s) => {
Error::Multipart(MultipartError::UnsupportedField(s))
}
e => Error::Multipart(MultipartError::Unknown(e.to_string())),
}
}
}
impl From<tokio::task::JoinError> for Error {
fn from(_: tokio::task::JoinError) -> Self {
Self::Panicked
}
}
impl ResponseError for Error {
fn status_code(&self) -> StatusCode {
match *self {
@ -125,6 +151,7 @@ impl ResponseError for Error {
fn error_response(&self) -> HttpResponse {
match *self {
Error::Payload(ref e) => e.error_response(),
Error::Panicked => HttpResponse::InternalServerError().finish(),
Error::Multipart(_)
| Error::ParseField(_)
| Error::ParseInt(_)

View file

@ -3,7 +3,7 @@ use crate::{
upload::handle_multipart,
};
use actix_web::{dev::Payload, FromRequest, HttpRequest, ResponseError};
use std::{future::Future, pin::Pin};
use std::{future::Future, pin::Pin, rc::Rc};
pub trait FormData {
type Item: 'static;
@ -27,14 +27,14 @@ where
fn from_request(req: &HttpRequest, payload: &mut Payload) -> Self::Future {
let multipart = actix_multipart::Multipart::new(req.headers(), payload.take());
let form = T::form(req);
let form = Rc::new(T::form(req));
Box::pin(async move {
let uploaded = match handle_multipart(multipart, &form).await {
let uploaded = match handle_multipart(multipart, Rc::clone(&form)).await {
Ok(Ok(uploaded)) => uploaded,
Ok(Err(e)) => return Err(e.into()),
Err(e) => {
if let Some(f) = form.transform_error {
if let Some(f) = &form.transform_error {
return Err((f)(e));
} else {
return Err(e.into());

View file

@ -19,7 +19,7 @@
use crate::Error;
use actix_web::web::Bytes;
use futures_util::Stream;
use futures_core::Stream;
use mime::Mime;
use std::{
collections::{HashMap, VecDeque},
@ -154,10 +154,10 @@ impl<T> From<MultipartContent<T>> for Value<T> {
pub type FileFn<T, E> = Box<
dyn Fn(
String,
Option<Mime>,
Pin<Box<dyn Stream<Item = Result<Bytes, Error>>>>,
) -> Pin<Box<dyn Future<Output = Result<T, E>>>>
String,
Option<Mime>,
Pin<Box<dyn Stream<Item = Result<Bytes, Error>>>>,
) -> Pin<Box<dyn Future<Output = Result<T, E>>>>,
>;
/// The field type represents a field in the form-data that is allowed to be parsed.

View file

@ -25,20 +25,27 @@ use crate::{
},
};
use actix_web::web::BytesMut;
use futures_util::{
select,
stream::{FuturesUnordered, StreamExt},
};
use std::{
collections::HashMap,
path::Path,
sync::{
atomic::{AtomicUsize, Ordering},
Arc,
},
};
use std::{collections::HashMap, future::poll_fn, path::Path, pin::Pin, rc::Rc};
use tokio::task::JoinSet;
use tracing::trace;
struct Streamer<S>(S, bool);
impl<S> Streamer<S> {
async fn next(&mut self) -> Option<S::Item>
where
S: futures_core::Stream + Unpin,
{
if self.1 {
return None;
}
let opt = poll_fn(|cx| Pin::new(&mut self.0).poll_next(cx)).await;
self.1 = opt.is_none();
opt
}
}
fn consolidate<T>(mf: MultipartForm<T>) -> Value<T> {
mf.into_iter().fold(
Value::Map(HashMap::new()),
@ -115,34 +122,41 @@ where
let filename = filename.ok_or(Error::Filename)?.to_owned();
let file_size = Arc::new(AtomicUsize::new(0));
let content_type = field.content_type().cloned();
let max_file_size = form.max_file_size;
let result = file_fn(
filename.clone(),
content_type.clone(),
Box::pin(field.then(move |res| {
let file_size = file_size.clone();
async move {
match res {
Ok(bytes) => {
let size = file_size.fetch_add(bytes.len(), Ordering::Relaxed);
let (tx, rx) = tokio::sync::mpsc::channel(8);
if size + bytes.len() > max_file_size {
return Err(Error::FileSize);
}
let consume_fut = async move {
let mut file_size = 0;
let mut exceeded_size = false;
Ok(bytes)
}
Err(e) => Err(Error::from(e)),
}
let mut stream = Streamer(field, false);
while let Some(res) = stream.next().await {
if exceeded_size {
tracing::trace!("Dropping oversized bytes");
continue;
}
})),
)
.await;
if let Ok(bytes) = &res {
file_size += bytes.len();
if file_size > max_file_size {
exceeded_size = true;
let _ = tx.send(Err(Error::FileSize)).await;
}
};
let _ = tx.send(res.map_err(Error::from)).await;
}
};
let stream = tokio_stream::wrappers::ReceiverStream::new(rx);
let result_fut = file_fn(filename.clone(), content_type.clone(), Box::pin(stream));
let (_, result) = tokio::join!(consume_fut, result_fut);
match result {
Ok(result) => Ok(Ok(MultipartContent::File(FileMeta {
@ -155,7 +169,7 @@ where
}
async fn handle_form_data<'a, T, E>(
mut field: actix_multipart::Field,
field: actix_multipart::Field,
term: FieldTerminator<'a, T, E>,
form: &Form<T, E>,
) -> Result<MultipartContent<T>, Error>
@ -166,15 +180,43 @@ where
trace!("In handle_form_data, term: {:?}", term);
let mut bytes = BytesMut::new();
while let Some(res) = field.next().await {
let b = res?;
let mut exceeded_size = false;
let mut error = None;
let mut stream = Streamer(field, false);
while let Some(res) = stream.next().await {
if exceeded_size {
tracing::trace!("Dropping oversized bytes");
continue;
}
if error.is_some() {
tracing::trace!("Draining field while error exists");
continue;
}
let b = match res {
Ok(bytes) => bytes,
Err(e) if error.is_none() => {
error = Some(e);
continue;
}
Err(_) => continue,
};
if bytes.len() + b.len() > form.max_field_size {
return Err(Error::FieldSize);
exceeded_size = true;
continue;
}
bytes.extend(b);
}
if let Some(error) = error {
return Err(error.into());
}
if let FieldTerminator::Bytes = term {
return Ok(MultipartContent::Bytes(bytes.freeze()));
}
@ -197,7 +239,7 @@ where
async fn handle_stream_field<T, E>(
field: actix_multipart::Field,
form: &Form<T, E>,
form: Rc<Form<T, E>>,
) -> Result<Result<MultipartHash<T>, E>, Error>
where
T: 'static,
@ -214,12 +256,12 @@ where
let content = match term {
FieldTerminator::File(file_fn) => {
match handle_file_upload(field, content_disposition.filename, form, file_fn).await? {
match handle_file_upload(field, content_disposition.filename, &form, file_fn).await? {
Ok(content) => content,
Err(e) => return Ok(Err(e)),
}
}
term => handle_form_data(field, term, form).await?,
term => handle_form_data(field, term, &form).await?,
};
Ok(Ok((name, content)))
@ -228,7 +270,7 @@ where
/// Handle multipart streams from Actix Web
pub async fn handle_multipart<T, E>(
m: actix_multipart::Multipart,
form: &Form<T, E>,
form: Rc<Form<T, E>>,
) -> Result<Result<Value<T>, E>, Error>
where
T: 'static,
@ -238,35 +280,86 @@ where
let mut file_count: u32 = 0;
let mut field_count: u32 = 0;
let mut unordered = FuturesUnordered::new();
let mut set = JoinSet::new();
let mut m = m.fuse();
let mut m = Streamer(m, false);
loop {
select! {
let mut stream_finished = false;
let mut error: Option<Error> = None;
let mut provided_error: Option<E> = None;
'outer: loop {
tokio::select! {
opt = m.next() => {
if let Some(res) = opt {
unordered.push(handle_stream_field(res?, form));
if error.is_some() || provided_error.is_some() {
continue;
}
match res {
Ok(field) => {
set.spawn_local(handle_stream_field(field, Rc::clone(&form)));
},
Err(e) => error = Some(e.into()),
}
} else {
stream_finished = true;
}
}
opt = unordered.next() => {
opt = set.join_next() => {
if error.is_some() || provided_error.is_some() {
if stream_finished {
break 'outer;
}
continue;
}
if let Some(res) = opt {
let (name_parts, content) = match res? {
Ok(tup) => tup,
Err(e) => return Ok(Err(e)),
let (name_parts, content) = match res {
Ok(Ok(Ok(tup))) => tup,
Ok(Ok(Err(e))) => {
provided_error = Some(e);
continue;
}
Ok(Err(e)) => {
error = Some(e);
continue;
},
Err(e) => {
error = Some(e.into());
continue;
},
};
let (l, r) = match count(&content, file_count, field_count, &form) {
Ok(tup) => tup,
Err(e) => {
error = Some(e);
continue;
}
};
let (l, r) = count(&content, file_count, field_count, form)?;
file_count = l;
field_count = r;
multipart_form.push((name_parts, content));
} else if stream_finished {
break 'outer;
}
}
complete => break,
}
}
if let Some(e) = provided_error {
return Ok(Err(e));
}
if let Some(e) = error {
return Err(e);
}
Ok(Ok(consolidate(multipart_form)))
}