Simplify stream implementations with streem
This commit is contained in:
parent
8b422644fb
commit
0f7614ec3b
|
@ -19,7 +19,7 @@
|
||||||
|
|
||||||
use std::{
|
use std::{
|
||||||
num::{ParseFloatError, ParseIntError},
|
num::{ParseFloatError, ParseIntError},
|
||||||
string::FromUtf8Error,
|
str::Utf8Error,
|
||||||
};
|
};
|
||||||
|
|
||||||
use actix_web::{
|
use actix_web::{
|
||||||
|
@ -35,7 +35,7 @@ pub enum Error {
|
||||||
#[error("Error in multipart creation")]
|
#[error("Error in multipart creation")]
|
||||||
Multipart(#[from] MultipartError),
|
Multipart(#[from] MultipartError),
|
||||||
#[error("Failed to parse field")]
|
#[error("Failed to parse field")]
|
||||||
ParseField(#[from] FromUtf8Error),
|
ParseField(#[from] Utf8Error),
|
||||||
#[error("Failed to parse int")]
|
#[error("Failed to parse int")]
|
||||||
ParseInt(#[from] ParseIntError),
|
ParseInt(#[from] ParseIntError),
|
||||||
#[error("Failed to parse float")]
|
#[error("Failed to parse float")]
|
||||||
|
|
193
src/upload.rs
193
src/upload.rs
|
@ -25,43 +25,9 @@ use crate::{
|
||||||
},
|
},
|
||||||
};
|
};
|
||||||
use actix_web::web::BytesMut;
|
use actix_web::web::BytesMut;
|
||||||
use std::{collections::HashMap, future::poll_fn, path::Path, pin::Pin, rc::Rc};
|
use std::{collections::HashMap, path::Path, rc::Rc};
|
||||||
use tokio::{sync::mpsc::Receiver, task::JoinSet};
|
use streem::IntoStreamer;
|
||||||
use tracing::trace;
|
use tokio::task::JoinSet;
|
||||||
|
|
||||||
struct Streamer<S>(S, bool);
|
|
||||||
|
|
||||||
impl<S> Streamer<S>
|
|
||||||
where
|
|
||||||
S: futures_core::Stream + Unpin,
|
|
||||||
{
|
|
||||||
async fn next(&mut self) -> Option<S::Item> {
|
|
||||||
if self.1 {
|
|
||||||
return None;
|
|
||||||
}
|
|
||||||
|
|
||||||
let opt = poll_fn(|cx| Pin::new(&mut self.0).poll_next(cx)).await;
|
|
||||||
self.1 = opt.is_none();
|
|
||||||
opt
|
|
||||||
}
|
|
||||||
|
|
||||||
fn is_closed(&self) -> bool {
|
|
||||||
self.1
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
struct ReceiverStream<T>(Receiver<T>);
|
|
||||||
|
|
||||||
impl<T> futures_core::Stream for ReceiverStream<T> {
|
|
||||||
type Item = T;
|
|
||||||
|
|
||||||
fn poll_next(
|
|
||||||
mut self: Pin<&mut Self>,
|
|
||||||
cx: &mut std::task::Context<'_>,
|
|
||||||
) -> std::task::Poll<Option<Self::Item>> {
|
|
||||||
Pin::new(&mut self.0).poll_recv(cx)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
fn consolidate<T>(mf: MultipartForm<T>) -> Value<T> {
|
fn consolidate<T>(mf: MultipartForm<T>) -> Value<T> {
|
||||||
mf.into_iter().fold(
|
mf.into_iter().fold(
|
||||||
|
@ -100,16 +66,13 @@ fn parse_multipart_name(name: String) -> Result<Vec<NamePart>, Error> {
|
||||||
NamePart::Map(part.to_owned())
|
NamePart::Map(part.to_owned())
|
||||||
}
|
}
|
||||||
})
|
})
|
||||||
.fold(Ok(vec![]), |acc, part| match acc {
|
.try_fold(vec![], |mut v, part| {
|
||||||
Ok(mut v) => {
|
if v.is_empty() && !part.is_map() {
|
||||||
if v.is_empty() && !part.is_map() {
|
return Err(Error::ContentDisposition);
|
||||||
return Err(Error::ContentDisposition);
|
|
||||||
}
|
|
||||||
|
|
||||||
v.push(part);
|
|
||||||
Ok(v)
|
|
||||||
}
|
}
|
||||||
Err(e) => Err(e),
|
|
||||||
|
v.push(part);
|
||||||
|
Ok(v)
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -143,42 +106,40 @@ where
|
||||||
|
|
||||||
let max_file_size = form.max_file_size;
|
let max_file_size = form.max_file_size;
|
||||||
|
|
||||||
let (tx, rx) = tokio::sync::mpsc::channel(8);
|
let field_stream = streem::try_from_fn(move |yielder| async move {
|
||||||
|
|
||||||
let consume_fut = async move {
|
|
||||||
let mut file_size = 0;
|
let mut file_size = 0;
|
||||||
let mut exceeded_size = false;
|
|
||||||
|
|
||||||
let mut stream = Streamer(field, false);
|
let mut stream = field.into_streamer();
|
||||||
|
|
||||||
while let Some(res) = stream.next().await {
|
while let Some(bytes) = stream.try_next().await? {
|
||||||
tracing::trace!("Bytes from field");
|
tracing::trace!("Bytes from field");
|
||||||
|
|
||||||
if exceeded_size {
|
file_size += bytes.len();
|
||||||
tracing::trace!("Dropping oversized bytes");
|
|
||||||
continue;
|
if file_size > max_file_size {
|
||||||
|
drop(bytes);
|
||||||
|
|
||||||
|
while stream.try_next().await?.is_some() {
|
||||||
|
tracing::trace!("Dropping oversized bytes");
|
||||||
|
}
|
||||||
|
|
||||||
|
return Err(Error::FileSize);
|
||||||
}
|
}
|
||||||
|
|
||||||
if let Ok(bytes) = &res {
|
yielder.yield_ok(bytes).await;
|
||||||
file_size += bytes.len();
|
|
||||||
if file_size > max_file_size {
|
|
||||||
exceeded_size = true;
|
|
||||||
let _ = tx.send(Err(Error::FileSize)).await;
|
|
||||||
}
|
|
||||||
};
|
|
||||||
|
|
||||||
let _ = tx.send(res.map_err(Error::from)).await;
|
|
||||||
}
|
}
|
||||||
|
|
||||||
drop(tx);
|
|
||||||
tracing::debug!("Finished consuming field");
|
tracing::debug!("Finished consuming field");
|
||||||
};
|
|
||||||
|
|
||||||
let stream = ReceiverStream(rx);
|
Ok(())
|
||||||
|
});
|
||||||
|
|
||||||
let result_fut = file_fn(filename.clone(), content_type.clone(), Box::pin(stream));
|
let result = file_fn(
|
||||||
|
filename.clone(),
|
||||||
let (_, result) = tokio::join!(consume_fut, result_fut);
|
content_type.clone(),
|
||||||
|
Box::pin(field_stream),
|
||||||
|
)
|
||||||
|
.await;
|
||||||
|
|
||||||
match result {
|
match result {
|
||||||
Ok(result) => Ok(Ok(MultipartContent::File(FileMeta {
|
Ok(result) => Ok(Ok(MultipartContent::File(FileMeta {
|
||||||
|
@ -199,59 +160,54 @@ where
|
||||||
T: 'static,
|
T: 'static,
|
||||||
E: 'static,
|
E: 'static,
|
||||||
{
|
{
|
||||||
trace!("In handle_form_data, term: {:?}", term);
|
tracing::trace!("In handle_form_data, term: {:?}", term);
|
||||||
let mut bytes = BytesMut::new();
|
let mut buf = Vec::new();
|
||||||
|
|
||||||
let mut exceeded_size = false;
|
let mut stream = field.into_streamer();
|
||||||
let mut error = None;
|
|
||||||
|
|
||||||
let mut stream = Streamer(field, false);
|
while let Some(bytes) = stream.try_next().await? {
|
||||||
|
|
||||||
while let Some(res) = stream.next().await {
|
|
||||||
tracing::trace!("bytes from field");
|
tracing::trace!("bytes from field");
|
||||||
|
|
||||||
if exceeded_size {
|
if buf.len() + bytes.len() > form.max_field_size {
|
||||||
tracing::trace!("Dropping oversized bytes");
|
drop(buf);
|
||||||
continue;
|
|
||||||
}
|
|
||||||
|
|
||||||
if error.is_some() {
|
while stream.try_next().await?.is_some() {
|
||||||
tracing::trace!("Draining field while error exists");
|
tracing::trace!("Dropping oversized bytes");
|
||||||
continue;
|
|
||||||
}
|
|
||||||
|
|
||||||
let b = match res {
|
|
||||||
Ok(bytes) => bytes,
|
|
||||||
Err(e) if error.is_none() => {
|
|
||||||
error = Some(e);
|
|
||||||
continue;
|
|
||||||
}
|
}
|
||||||
Err(_) => continue,
|
|
||||||
};
|
|
||||||
|
|
||||||
if bytes.len() + b.len() > form.max_field_size {
|
return Err(Error::FieldSize);
|
||||||
exceeded_size = true;
|
|
||||||
continue;
|
|
||||||
}
|
}
|
||||||
|
|
||||||
bytes.extend(b);
|
buf.push(bytes);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
let bytes = match buf.len() {
|
||||||
|
0 => return Err(Error::FieldSize),
|
||||||
|
1 => buf.pop().expect("contains an element"),
|
||||||
|
_ => {
|
||||||
|
let total_length = buf.iter().map(|b| b.len()).sum();
|
||||||
|
|
||||||
|
let mut bytes = BytesMut::with_capacity(total_length);
|
||||||
|
|
||||||
|
for b in buf {
|
||||||
|
bytes.extend(b);
|
||||||
|
}
|
||||||
|
|
||||||
|
bytes.freeze()
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
tracing::debug!("Finished consuming field");
|
tracing::debug!("Finished consuming field");
|
||||||
|
|
||||||
if let Some(error) = error {
|
|
||||||
return Err(error.into());
|
|
||||||
}
|
|
||||||
|
|
||||||
if let FieldTerminator::Bytes = term {
|
if let FieldTerminator::Bytes = term {
|
||||||
return Ok(MultipartContent::Bytes(bytes.freeze()));
|
return Ok(MultipartContent::Bytes(bytes));
|
||||||
}
|
}
|
||||||
|
|
||||||
let s = String::from_utf8(bytes.to_vec()).map_err(Error::ParseField)?;
|
let s = std::str::from_utf8(&bytes).map_err(Error::ParseField)?;
|
||||||
|
|
||||||
match term {
|
match term {
|
||||||
FieldTerminator::Bytes | FieldTerminator::File(_) => Err(Error::FieldType),
|
FieldTerminator::Bytes | FieldTerminator::File(_) => Err(Error::FieldType),
|
||||||
FieldTerminator::Text => Ok(MultipartContent::Text(s)),
|
FieldTerminator::Text => Ok(MultipartContent::Text(String::from(s))),
|
||||||
FieldTerminator::Float => s
|
FieldTerminator::Float => s
|
||||||
.parse()
|
.parse()
|
||||||
.map_err(Error::ParseFloat)
|
.map_err(Error::ParseFloat)
|
||||||
|
@ -294,6 +250,7 @@ where
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Handle multipart streams from Actix Web
|
/// Handle multipart streams from Actix Web
|
||||||
|
#[tracing::instrument(level = "TRACE", skip_all)]
|
||||||
pub async fn handle_multipart<T, E>(
|
pub async fn handle_multipart<T, E>(
|
||||||
m: actix_multipart::Multipart,
|
m: actix_multipart::Multipart,
|
||||||
form: Rc<Form<T, E>>,
|
form: Rc<Form<T, E>>,
|
||||||
|
@ -308,25 +265,33 @@ where
|
||||||
|
|
||||||
let mut set = JoinSet::new();
|
let mut set = JoinSet::new();
|
||||||
|
|
||||||
let mut m = Streamer(m, false);
|
let mut m = m.into_streamer();
|
||||||
|
|
||||||
let mut error: Option<Error> = None;
|
let mut error: Option<Error> = None;
|
||||||
let mut provided_error: Option<E> = None;
|
let mut provided_error: Option<E> = None;
|
||||||
|
let mut is_closed = false;
|
||||||
|
let mut stream_error = false;
|
||||||
|
|
||||||
'outer: loop {
|
'outer: loop {
|
||||||
tracing::trace!("multipart loop");
|
tracing::trace!("multipart loop");
|
||||||
|
|
||||||
if error.is_some() || provided_error.is_some() {
|
if error.is_some() || provided_error.is_some() {
|
||||||
while let Some(res) = m.next().await {
|
set.abort_all();
|
||||||
if let Ok(field) = res {
|
|
||||||
let mut stream = Streamer(field, false);
|
if !stream_error {
|
||||||
while let Some(_res) = stream.next().await {
|
while let Some(res) = m.next().await {
|
||||||
tracing::trace!("Throwing away uploaded bytes, we have an error");
|
if let Ok(field) = res {
|
||||||
|
let mut stream = field.into_streamer();
|
||||||
|
while stream.next().await.is_some() {
|
||||||
|
tracing::trace!("Throwing away uploaded bytes, we have an error");
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
break;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
while let Some(_) = set.join_next().await {
|
while set.join_next().await.is_some() {
|
||||||
tracing::trace!("Throwing away joined result");
|
tracing::trace!("Throwing away joined result");
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -334,14 +299,18 @@ where
|
||||||
}
|
}
|
||||||
|
|
||||||
tokio::select! {
|
tokio::select! {
|
||||||
opt = m.next(), if !m.is_closed() => {
|
opt = m.next(), if !is_closed => {
|
||||||
tracing::trace!("Selected stream");
|
tracing::trace!("Selected stream");
|
||||||
|
is_closed = opt.is_none();
|
||||||
|
|
||||||
if let Some(res) = opt {
|
if let Some(res) = opt {
|
||||||
match res {
|
match res {
|
||||||
Ok(field) => {
|
Ok(field) => {
|
||||||
set.spawn_local(handle_stream_field(field, Rc::clone(&form)));
|
set.spawn_local(handle_stream_field(field, Rc::clone(&form)));
|
||||||
},
|
},
|
||||||
Err(e) => {
|
Err(e) => {
|
||||||
|
is_closed = true;
|
||||||
|
stream_error = true;
|
||||||
error = Some(e.into());
|
error = Some(e.into());
|
||||||
continue 'outer;
|
continue 'outer;
|
||||||
}
|
}
|
||||||
|
|
Loading…
Reference in a new issue