Attempt to drain multipart before returning errors
This commit is contained in:
parent
46e5834b60
commit
20f80df9fd
|
@ -1,7 +1,7 @@
|
|||
[package]
|
||||
name = "actix-form-data"
|
||||
description = "Multipart Form Data for Actix Web"
|
||||
version = "0.7.0-beta.2"
|
||||
version = "0.7.0-beta.3"
|
||||
license = "GPL-3.0"
|
||||
authors = ["asonix <asonix@asonix.dog>"]
|
||||
repository = "https://git.asonix.dog/asonix/actix-form-data.git"
|
||||
|
@ -13,16 +13,18 @@ edition = "2021"
|
|||
actix-multipart = { version = "0.6.0", default-features = false }
|
||||
actix-rt = "2.5.0"
|
||||
actix-web = { version = "4.0.0", default-features = false }
|
||||
futures-util = "0.3.17"
|
||||
futures-core = "0.3.28"
|
||||
mime = "0.3.16"
|
||||
thiserror = "1.0"
|
||||
tokio = { version = "1", default-features = false, features = ["sync"] }
|
||||
tokio = { version = "1", default-features = false, features = ["macros", "sync"] }
|
||||
tokio-stream = "0.1.14"
|
||||
tracing = "0.1.15"
|
||||
|
||||
[dev-dependencies]
|
||||
async-fs = "1.2.1"
|
||||
anyhow = "1.0"
|
||||
futures-lite = "1.4.0"
|
||||
futures-util = "0.3.17"
|
||||
serde = { version = "1.0", features = ["derive"] }
|
||||
serde_json = "1.0"
|
||||
thiserror = "1.0"
|
||||
|
|
55
src/error.rs
55
src/error.rs
|
@ -23,7 +23,7 @@ use std::{
|
|||
};
|
||||
|
||||
use actix_web::{
|
||||
error::{PayloadError, ResponseError, ParseError},
|
||||
error::{ParseError, PayloadError, ResponseError},
|
||||
http::StatusCode,
|
||||
HttpResponse,
|
||||
};
|
||||
|
@ -58,6 +58,8 @@ pub enum Error {
|
|||
FileCount,
|
||||
#[error("File too large")]
|
||||
FileSize,
|
||||
#[error("Task panicked")]
|
||||
Panicked,
|
||||
}
|
||||
|
||||
#[derive(Debug, thiserror::Error)]
|
||||
|
@ -79,10 +81,7 @@ pub enum MultipartError {
|
|||
#[error("Multipart stream is not consumed")]
|
||||
NotConsumed,
|
||||
#[error("An error occured processing field `{field_name}`: `{zource}`")]
|
||||
Field {
|
||||
field_name: String,
|
||||
zource: String,
|
||||
},
|
||||
Field { field_name: String, zource: String },
|
||||
#[error("Duplicate field found for: `{0}")]
|
||||
DuplicateField(String),
|
||||
#[error("Field with name `{0}` is required")]
|
||||
|
@ -96,24 +95,51 @@ pub enum MultipartError {
|
|||
impl From<actix_multipart::MultipartError> for Error {
|
||||
fn from(value: actix_multipart::MultipartError) -> Self {
|
||||
match value {
|
||||
actix_multipart::MultipartError::NoContentDisposition => Error::Multipart(MultipartError::NoContentDisposition),
|
||||
actix_multipart::MultipartError::NoContentType => Error::Multipart(MultipartError::NoContentType),
|
||||
actix_multipart::MultipartError::ParseContentType => Error::Multipart(MultipartError::ParseContentType),
|
||||
actix_multipart::MultipartError::NoContentDisposition => {
|
||||
Error::Multipart(MultipartError::NoContentDisposition)
|
||||
}
|
||||
actix_multipart::MultipartError::NoContentType => {
|
||||
Error::Multipart(MultipartError::NoContentType)
|
||||
}
|
||||
actix_multipart::MultipartError::ParseContentType => {
|
||||
Error::Multipart(MultipartError::ParseContentType)
|
||||
}
|
||||
actix_multipart::MultipartError::Boundary => Error::Multipart(MultipartError::Boundary),
|
||||
actix_multipart::MultipartError::Nested => Error::Multipart(MultipartError::Nested),
|
||||
actix_multipart::MultipartError::Incomplete => Error::Multipart(MultipartError::Incomplete),
|
||||
actix_multipart::MultipartError::Incomplete => {
|
||||
Error::Multipart(MultipartError::Incomplete)
|
||||
}
|
||||
actix_multipart::MultipartError::Parse(e) => Error::Multipart(MultipartError::Parse(e)),
|
||||
actix_multipart::MultipartError::Payload(e) => Error::Payload(e),
|
||||
actix_multipart::MultipartError::NotConsumed => Error::Multipart(MultipartError::NotConsumed),
|
||||
actix_multipart::MultipartError::Field { field_name, source } => Error::Multipart(MultipartError::Field { field_name, zource: source.to_string() }),
|
||||
actix_multipart::MultipartError::DuplicateField(s) => Error::Multipart(MultipartError::DuplicateField(s)),
|
||||
actix_multipart::MultipartError::MissingField(s) => Error::Multipart(MultipartError::MissingField(s)),
|
||||
actix_multipart::MultipartError::UnsupportedField(s) => Error::Multipart(MultipartError::UnsupportedField(s)),
|
||||
actix_multipart::MultipartError::NotConsumed => {
|
||||
Error::Multipart(MultipartError::NotConsumed)
|
||||
}
|
||||
actix_multipart::MultipartError::Field { field_name, source } => {
|
||||
Error::Multipart(MultipartError::Field {
|
||||
field_name,
|
||||
zource: source.to_string(),
|
||||
})
|
||||
}
|
||||
actix_multipart::MultipartError::DuplicateField(s) => {
|
||||
Error::Multipart(MultipartError::DuplicateField(s))
|
||||
}
|
||||
actix_multipart::MultipartError::MissingField(s) => {
|
||||
Error::Multipart(MultipartError::MissingField(s))
|
||||
}
|
||||
actix_multipart::MultipartError::UnsupportedField(s) => {
|
||||
Error::Multipart(MultipartError::UnsupportedField(s))
|
||||
}
|
||||
e => Error::Multipart(MultipartError::Unknown(e.to_string())),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl From<tokio::task::JoinError> for Error {
|
||||
fn from(_: tokio::task::JoinError) -> Self {
|
||||
Self::Panicked
|
||||
}
|
||||
}
|
||||
|
||||
impl ResponseError for Error {
|
||||
fn status_code(&self) -> StatusCode {
|
||||
match *self {
|
||||
|
@ -125,6 +151,7 @@ impl ResponseError for Error {
|
|||
fn error_response(&self) -> HttpResponse {
|
||||
match *self {
|
||||
Error::Payload(ref e) => e.error_response(),
|
||||
Error::Panicked => HttpResponse::InternalServerError().finish(),
|
||||
Error::Multipart(_)
|
||||
| Error::ParseField(_)
|
||||
| Error::ParseInt(_)
|
||||
|
|
|
@ -3,7 +3,7 @@ use crate::{
|
|||
upload::handle_multipart,
|
||||
};
|
||||
use actix_web::{dev::Payload, FromRequest, HttpRequest, ResponseError};
|
||||
use std::{future::Future, pin::Pin};
|
||||
use std::{future::Future, pin::Pin, rc::Rc};
|
||||
|
||||
pub trait FormData {
|
||||
type Item: 'static;
|
||||
|
@ -27,14 +27,14 @@ where
|
|||
|
||||
fn from_request(req: &HttpRequest, payload: &mut Payload) -> Self::Future {
|
||||
let multipart = actix_multipart::Multipart::new(req.headers(), payload.take());
|
||||
let form = T::form(req);
|
||||
let form = Rc::new(T::form(req));
|
||||
|
||||
Box::pin(async move {
|
||||
let uploaded = match handle_multipart(multipart, &form).await {
|
||||
let uploaded = match handle_multipart(multipart, Rc::clone(&form)).await {
|
||||
Ok(Ok(uploaded)) => uploaded,
|
||||
Ok(Err(e)) => return Err(e.into()),
|
||||
Err(e) => {
|
||||
if let Some(f) = form.transform_error {
|
||||
if let Some(f) = &form.transform_error {
|
||||
return Err((f)(e));
|
||||
} else {
|
||||
return Err(e.into());
|
||||
|
|
10
src/types.rs
10
src/types.rs
|
@ -19,7 +19,7 @@
|
|||
|
||||
use crate::Error;
|
||||
use actix_web::web::Bytes;
|
||||
use futures_util::Stream;
|
||||
use futures_core::Stream;
|
||||
use mime::Mime;
|
||||
use std::{
|
||||
collections::{HashMap, VecDeque},
|
||||
|
@ -154,10 +154,10 @@ impl<T> From<MultipartContent<T>> for Value<T> {
|
|||
|
||||
pub type FileFn<T, E> = Box<
|
||||
dyn Fn(
|
||||
String,
|
||||
Option<Mime>,
|
||||
Pin<Box<dyn Stream<Item = Result<Bytes, Error>>>>,
|
||||
) -> Pin<Box<dyn Future<Output = Result<T, E>>>>
|
||||
String,
|
||||
Option<Mime>,
|
||||
Pin<Box<dyn Stream<Item = Result<Bytes, Error>>>>,
|
||||
) -> Pin<Box<dyn Future<Output = Result<T, E>>>>,
|
||||
>;
|
||||
|
||||
/// The field type represents a field in the form-data that is allowed to be parsed.
|
||||
|
|
197
src/upload.rs
197
src/upload.rs
|
@ -25,20 +25,27 @@ use crate::{
|
|||
},
|
||||
};
|
||||
use actix_web::web::BytesMut;
|
||||
use futures_util::{
|
||||
select,
|
||||
stream::{FuturesUnordered, StreamExt},
|
||||
};
|
||||
use std::{
|
||||
collections::HashMap,
|
||||
path::Path,
|
||||
sync::{
|
||||
atomic::{AtomicUsize, Ordering},
|
||||
Arc,
|
||||
},
|
||||
};
|
||||
use std::{collections::HashMap, future::poll_fn, path::Path, pin::Pin, rc::Rc};
|
||||
use tokio::task::JoinSet;
|
||||
use tracing::trace;
|
||||
|
||||
struct Streamer<S>(S, bool);
|
||||
|
||||
impl<S> Streamer<S> {
|
||||
async fn next(&mut self) -> Option<S::Item>
|
||||
where
|
||||
S: futures_core::Stream + Unpin,
|
||||
{
|
||||
if self.1 {
|
||||
return None;
|
||||
}
|
||||
|
||||
let opt = poll_fn(|cx| Pin::new(&mut self.0).poll_next(cx)).await;
|
||||
self.1 = opt.is_none();
|
||||
opt
|
||||
}
|
||||
}
|
||||
|
||||
fn consolidate<T>(mf: MultipartForm<T>) -> Value<T> {
|
||||
mf.into_iter().fold(
|
||||
Value::Map(HashMap::new()),
|
||||
|
@ -115,34 +122,41 @@ where
|
|||
|
||||
let filename = filename.ok_or(Error::Filename)?.to_owned();
|
||||
|
||||
let file_size = Arc::new(AtomicUsize::new(0));
|
||||
|
||||
let content_type = field.content_type().cloned();
|
||||
|
||||
let max_file_size = form.max_file_size;
|
||||
|
||||
let result = file_fn(
|
||||
filename.clone(),
|
||||
content_type.clone(),
|
||||
Box::pin(field.then(move |res| {
|
||||
let file_size = file_size.clone();
|
||||
async move {
|
||||
match res {
|
||||
Ok(bytes) => {
|
||||
let size = file_size.fetch_add(bytes.len(), Ordering::Relaxed);
|
||||
let (tx, rx) = tokio::sync::mpsc::channel(8);
|
||||
|
||||
if size + bytes.len() > max_file_size {
|
||||
return Err(Error::FileSize);
|
||||
}
|
||||
let consume_fut = async move {
|
||||
let mut file_size = 0;
|
||||
let mut exceeded_size = false;
|
||||
|
||||
Ok(bytes)
|
||||
}
|
||||
Err(e) => Err(Error::from(e)),
|
||||
}
|
||||
let mut stream = Streamer(field, false);
|
||||
|
||||
while let Some(res) = stream.next().await {
|
||||
if exceeded_size {
|
||||
tracing::trace!("Dropping oversized bytes");
|
||||
continue;
|
||||
}
|
||||
})),
|
||||
)
|
||||
.await;
|
||||
|
||||
if let Ok(bytes) = &res {
|
||||
file_size += bytes.len();
|
||||
if file_size > max_file_size {
|
||||
exceeded_size = true;
|
||||
let _ = tx.send(Err(Error::FileSize)).await;
|
||||
}
|
||||
};
|
||||
|
||||
let _ = tx.send(res.map_err(Error::from)).await;
|
||||
}
|
||||
};
|
||||
|
||||
let stream = tokio_stream::wrappers::ReceiverStream::new(rx);
|
||||
|
||||
let result_fut = file_fn(filename.clone(), content_type.clone(), Box::pin(stream));
|
||||
|
||||
let (_, result) = tokio::join!(consume_fut, result_fut);
|
||||
|
||||
match result {
|
||||
Ok(result) => Ok(Ok(MultipartContent::File(FileMeta {
|
||||
|
@ -155,7 +169,7 @@ where
|
|||
}
|
||||
|
||||
async fn handle_form_data<'a, T, E>(
|
||||
mut field: actix_multipart::Field,
|
||||
field: actix_multipart::Field,
|
||||
term: FieldTerminator<'a, T, E>,
|
||||
form: &Form<T, E>,
|
||||
) -> Result<MultipartContent<T>, Error>
|
||||
|
@ -166,15 +180,43 @@ where
|
|||
trace!("In handle_form_data, term: {:?}", term);
|
||||
let mut bytes = BytesMut::new();
|
||||
|
||||
while let Some(res) = field.next().await {
|
||||
let b = res?;
|
||||
let mut exceeded_size = false;
|
||||
let mut error = None;
|
||||
|
||||
let mut stream = Streamer(field, false);
|
||||
|
||||
while let Some(res) = stream.next().await {
|
||||
if exceeded_size {
|
||||
tracing::trace!("Dropping oversized bytes");
|
||||
continue;
|
||||
}
|
||||
|
||||
if error.is_some() {
|
||||
tracing::trace!("Draining field while error exists");
|
||||
continue;
|
||||
}
|
||||
|
||||
let b = match res {
|
||||
Ok(bytes) => bytes,
|
||||
Err(e) if error.is_none() => {
|
||||
error = Some(e);
|
||||
continue;
|
||||
}
|
||||
Err(_) => continue,
|
||||
};
|
||||
|
||||
if bytes.len() + b.len() > form.max_field_size {
|
||||
return Err(Error::FieldSize);
|
||||
exceeded_size = true;
|
||||
continue;
|
||||
}
|
||||
|
||||
bytes.extend(b);
|
||||
}
|
||||
|
||||
if let Some(error) = error {
|
||||
return Err(error.into());
|
||||
}
|
||||
|
||||
if let FieldTerminator::Bytes = term {
|
||||
return Ok(MultipartContent::Bytes(bytes.freeze()));
|
||||
}
|
||||
|
@ -197,7 +239,7 @@ where
|
|||
|
||||
async fn handle_stream_field<T, E>(
|
||||
field: actix_multipart::Field,
|
||||
form: &Form<T, E>,
|
||||
form: Rc<Form<T, E>>,
|
||||
) -> Result<Result<MultipartHash<T>, E>, Error>
|
||||
where
|
||||
T: 'static,
|
||||
|
@ -214,12 +256,12 @@ where
|
|||
|
||||
let content = match term {
|
||||
FieldTerminator::File(file_fn) => {
|
||||
match handle_file_upload(field, content_disposition.filename, form, file_fn).await? {
|
||||
match handle_file_upload(field, content_disposition.filename, &form, file_fn).await? {
|
||||
Ok(content) => content,
|
||||
Err(e) => return Ok(Err(e)),
|
||||
}
|
||||
}
|
||||
term => handle_form_data(field, term, form).await?,
|
||||
term => handle_form_data(field, term, &form).await?,
|
||||
};
|
||||
|
||||
Ok(Ok((name, content)))
|
||||
|
@ -228,7 +270,7 @@ where
|
|||
/// Handle multipart streams from Actix Web
|
||||
pub async fn handle_multipart<T, E>(
|
||||
m: actix_multipart::Multipart,
|
||||
form: &Form<T, E>,
|
||||
form: Rc<Form<T, E>>,
|
||||
) -> Result<Result<Value<T>, E>, Error>
|
||||
where
|
||||
T: 'static,
|
||||
|
@ -238,35 +280,86 @@ where
|
|||
let mut file_count: u32 = 0;
|
||||
let mut field_count: u32 = 0;
|
||||
|
||||
let mut unordered = FuturesUnordered::new();
|
||||
let mut set = JoinSet::new();
|
||||
|
||||
let mut m = m.fuse();
|
||||
let mut m = Streamer(m, false);
|
||||
|
||||
loop {
|
||||
select! {
|
||||
let mut stream_finished = false;
|
||||
|
||||
let mut error: Option<Error> = None;
|
||||
let mut provided_error: Option<E> = None;
|
||||
|
||||
'outer: loop {
|
||||
tokio::select! {
|
||||
opt = m.next() => {
|
||||
if let Some(res) = opt {
|
||||
unordered.push(handle_stream_field(res?, form));
|
||||
if error.is_some() || provided_error.is_some() {
|
||||
continue;
|
||||
}
|
||||
|
||||
match res {
|
||||
Ok(field) => {
|
||||
set.spawn_local(handle_stream_field(field, Rc::clone(&form)));
|
||||
},
|
||||
Err(e) => error = Some(e.into()),
|
||||
}
|
||||
} else {
|
||||
stream_finished = true;
|
||||
}
|
||||
}
|
||||
opt = unordered.next() => {
|
||||
opt = set.join_next() => {
|
||||
if error.is_some() || provided_error.is_some() {
|
||||
if stream_finished {
|
||||
break 'outer;
|
||||
}
|
||||
|
||||
continue;
|
||||
}
|
||||
|
||||
if let Some(res) = opt {
|
||||
let (name_parts, content) = match res? {
|
||||
Ok(tup) => tup,
|
||||
Err(e) => return Ok(Err(e)),
|
||||
let (name_parts, content) = match res {
|
||||
Ok(Ok(Ok(tup))) => tup,
|
||||
Ok(Ok(Err(e))) => {
|
||||
provided_error = Some(e);
|
||||
continue;
|
||||
}
|
||||
Ok(Err(e)) => {
|
||||
error = Some(e);
|
||||
continue;
|
||||
},
|
||||
Err(e) => {
|
||||
error = Some(e.into());
|
||||
continue;
|
||||
},
|
||||
};
|
||||
|
||||
let (l, r) = match count(&content, file_count, field_count, &form) {
|
||||
Ok(tup) => tup,
|
||||
Err(e) => {
|
||||
error = Some(e);
|
||||
continue;
|
||||
}
|
||||
};
|
||||
|
||||
let (l, r) = count(&content, file_count, field_count, form)?;
|
||||
file_count = l;
|
||||
field_count = r;
|
||||
|
||||
multipart_form.push((name_parts, content));
|
||||
} else if stream_finished {
|
||||
break 'outer;
|
||||
}
|
||||
}
|
||||
complete => break,
|
||||
}
|
||||
}
|
||||
|
||||
if let Some(e) = provided_error {
|
||||
return Ok(Err(e));
|
||||
}
|
||||
|
||||
if let Some(e) = error {
|
||||
return Err(e);
|
||||
}
|
||||
|
||||
Ok(Ok(consolidate(multipart_form)))
|
||||
}
|
||||
|
||||
|
|
Loading…
Reference in a new issue