Simplify stream implementations with streem

This commit is contained in:
asonix 2023-12-10 18:16:25 -06:00
parent 8b422644fb
commit 0f7614ec3b
2 changed files with 83 additions and 114 deletions

View file

@ -19,7 +19,7 @@
use std::{ use std::{
num::{ParseFloatError, ParseIntError}, num::{ParseFloatError, ParseIntError},
string::FromUtf8Error, str::Utf8Error,
}; };
use actix_web::{ use actix_web::{
@ -35,7 +35,7 @@ pub enum Error {
#[error("Error in multipart creation")] #[error("Error in multipart creation")]
Multipart(#[from] MultipartError), Multipart(#[from] MultipartError),
#[error("Failed to parse field")] #[error("Failed to parse field")]
ParseField(#[from] FromUtf8Error), ParseField(#[from] Utf8Error),
#[error("Failed to parse int")] #[error("Failed to parse int")]
ParseInt(#[from] ParseIntError), ParseInt(#[from] ParseIntError),
#[error("Failed to parse float")] #[error("Failed to parse float")]

View file

@ -25,43 +25,9 @@ use crate::{
}, },
}; };
use actix_web::web::BytesMut; use actix_web::web::BytesMut;
use std::{collections::HashMap, future::poll_fn, path::Path, pin::Pin, rc::Rc}; use std::{collections::HashMap, path::Path, rc::Rc};
use tokio::{sync::mpsc::Receiver, task::JoinSet}; use streem::IntoStreamer;
use tracing::trace; use tokio::task::JoinSet;
/// Thin adapter that lets a [`futures_core::Stream`] be consumed with a plain
/// `next().await` call and that remembers when the stream has ended.
///
/// Field `0` is the wrapped stream; field `1` becomes `true` once the stream
/// has yielded `None`.
struct Streamer<S>(S, bool);

impl<S> Streamer<S>
where
    S: futures_core::Stream + Unpin,
{
    /// Waits for the next item of the wrapped stream.
    ///
    /// Returns `None` immediately, without polling the stream again, once an
    /// earlier call has already observed the end of the stream.
    async fn next(&mut self) -> Option<S::Item> {
        let Self(inner, finished) = self;

        if *finished {
            return None;
        }

        let item = poll_fn(|cx| Pin::new(&mut *inner).poll_next(cx)).await;

        // `finished` is known to be false here, so only the end-of-stream
        // case needs to be recorded.
        if item.is_none() {
            *finished = true;
        }

        item
    }

    /// Reports whether the wrapped stream has already yielded `None`.
    fn is_closed(&self) -> bool {
        let Self(_, finished) = self;
        *finished
    }
}
/// Adapter exposing a channel [`Receiver`] through the
/// [`futures_core::Stream`] interface.
struct ReceiverStream<T>(Receiver<T>);

impl<T> futures_core::Stream for ReceiverStream<T> {
    type Item = T;

    /// Forwards the poll to the wrapped receiver's `poll_recv`.
    fn poll_next(
        self: Pin<&mut Self>,
        cx: &mut std::task::Context<'_>,
    ) -> std::task::Poll<Option<Self::Item>> {
        let Self(receiver) = self.get_mut();
        receiver.poll_recv(cx)
    }
}
fn consolidate<T>(mf: MultipartForm<T>) -> Value<T> { fn consolidate<T>(mf: MultipartForm<T>) -> Value<T> {
mf.into_iter().fold( mf.into_iter().fold(
@ -100,16 +66,13 @@ fn parse_multipart_name(name: String) -> Result<Vec<NamePart>, Error> {
NamePart::Map(part.to_owned()) NamePart::Map(part.to_owned())
} }
}) })
.fold(Ok(vec![]), |acc, part| match acc { .try_fold(vec![], |mut v, part| {
Ok(mut v) => { if v.is_empty() && !part.is_map() {
if v.is_empty() && !part.is_map() { return Err(Error::ContentDisposition);
return Err(Error::ContentDisposition);
}
v.push(part);
Ok(v)
} }
Err(e) => Err(e),
v.push(part);
Ok(v)
}) })
} }
@ -143,42 +106,40 @@ where
let max_file_size = form.max_file_size; let max_file_size = form.max_file_size;
let (tx, rx) = tokio::sync::mpsc::channel(8); let field_stream = streem::try_from_fn(move |yielder| async move {
let consume_fut = async move {
let mut file_size = 0; let mut file_size = 0;
let mut exceeded_size = false;
let mut stream = Streamer(field, false); let mut stream = field.into_streamer();
while let Some(res) = stream.next().await { while let Some(bytes) = stream.try_next().await? {
tracing::trace!("Bytes from field"); tracing::trace!("Bytes from field");
if exceeded_size { file_size += bytes.len();
tracing::trace!("Dropping oversized bytes");
continue; if file_size > max_file_size {
drop(bytes);
while stream.try_next().await?.is_some() {
tracing::trace!("Dropping oversized bytes");
}
return Err(Error::FileSize);
} }
if let Ok(bytes) = &res { yielder.yield_ok(bytes).await;
file_size += bytes.len();
if file_size > max_file_size {
exceeded_size = true;
let _ = tx.send(Err(Error::FileSize)).await;
}
};
let _ = tx.send(res.map_err(Error::from)).await;
} }
drop(tx);
tracing::debug!("Finished consuming field"); tracing::debug!("Finished consuming field");
};
let stream = ReceiverStream(rx); Ok(())
});
let result_fut = file_fn(filename.clone(), content_type.clone(), Box::pin(stream)); let result = file_fn(
filename.clone(),
let (_, result) = tokio::join!(consume_fut, result_fut); content_type.clone(),
Box::pin(field_stream),
)
.await;
match result { match result {
Ok(result) => Ok(Ok(MultipartContent::File(FileMeta { Ok(result) => Ok(Ok(MultipartContent::File(FileMeta {
@ -199,59 +160,54 @@ where
T: 'static, T: 'static,
E: 'static, E: 'static,
{ {
trace!("In handle_form_data, term: {:?}", term); tracing::trace!("In handle_form_data, term: {:?}", term);
let mut bytes = BytesMut::new(); let mut buf = Vec::new();
let mut exceeded_size = false; let mut stream = field.into_streamer();
let mut error = None;
let mut stream = Streamer(field, false); while let Some(bytes) = stream.try_next().await? {
while let Some(res) = stream.next().await {
tracing::trace!("bytes from field"); tracing::trace!("bytes from field");
if exceeded_size { if buf.len() + bytes.len() > form.max_field_size {
tracing::trace!("Dropping oversized bytes"); drop(buf);
continue;
}
if error.is_some() { while stream.try_next().await?.is_some() {
tracing::trace!("Draining field while error exists"); tracing::trace!("Dropping oversized bytes");
continue;
}
let b = match res {
Ok(bytes) => bytes,
Err(e) if error.is_none() => {
error = Some(e);
continue;
} }
Err(_) => continue,
};
if bytes.len() + b.len() > form.max_field_size { return Err(Error::FieldSize);
exceeded_size = true;
continue;
} }
bytes.extend(b); buf.push(bytes);
} }
let bytes = match buf.len() {
0 => return Err(Error::FieldSize),
1 => buf.pop().expect("contains an element"),
_ => {
let total_length = buf.iter().map(|b| b.len()).sum();
let mut bytes = BytesMut::with_capacity(total_length);
for b in buf {
bytes.extend(b);
}
bytes.freeze()
}
};
tracing::debug!("Finished consuming field"); tracing::debug!("Finished consuming field");
if let Some(error) = error {
return Err(error.into());
}
if let FieldTerminator::Bytes = term { if let FieldTerminator::Bytes = term {
return Ok(MultipartContent::Bytes(bytes.freeze())); return Ok(MultipartContent::Bytes(bytes));
} }
let s = String::from_utf8(bytes.to_vec()).map_err(Error::ParseField)?; let s = std::str::from_utf8(&bytes).map_err(Error::ParseField)?;
match term { match term {
FieldTerminator::Bytes | FieldTerminator::File(_) => Err(Error::FieldType), FieldTerminator::Bytes | FieldTerminator::File(_) => Err(Error::FieldType),
FieldTerminator::Text => Ok(MultipartContent::Text(s)), FieldTerminator::Text => Ok(MultipartContent::Text(String::from(s))),
FieldTerminator::Float => s FieldTerminator::Float => s
.parse() .parse()
.map_err(Error::ParseFloat) .map_err(Error::ParseFloat)
@ -294,6 +250,7 @@ where
} }
/// Handle multipart streams from Actix Web /// Handle multipart streams from Actix Web
#[tracing::instrument(level = "TRACE", skip_all)]
pub async fn handle_multipart<T, E>( pub async fn handle_multipart<T, E>(
m: actix_multipart::Multipart, m: actix_multipart::Multipart,
form: Rc<Form<T, E>>, form: Rc<Form<T, E>>,
@ -308,25 +265,33 @@ where
let mut set = JoinSet::new(); let mut set = JoinSet::new();
let mut m = Streamer(m, false); let mut m = m.into_streamer();
let mut error: Option<Error> = None; let mut error: Option<Error> = None;
let mut provided_error: Option<E> = None; let mut provided_error: Option<E> = None;
let mut is_closed = false;
let mut stream_error = false;
'outer: loop { 'outer: loop {
tracing::trace!("multipart loop"); tracing::trace!("multipart loop");
if error.is_some() || provided_error.is_some() { if error.is_some() || provided_error.is_some() {
while let Some(res) = m.next().await { set.abort_all();
if let Ok(field) = res {
let mut stream = Streamer(field, false); if !stream_error {
while let Some(_res) = stream.next().await { while let Some(res) = m.next().await {
tracing::trace!("Throwing away uploaded bytes, we have an error"); if let Ok(field) = res {
let mut stream = field.into_streamer();
while stream.next().await.is_some() {
tracing::trace!("Throwing away uploaded bytes, we have an error");
}
} else {
break;
} }
} }
} }
while let Some(_) = set.join_next().await { while set.join_next().await.is_some() {
tracing::trace!("Throwing away joined result"); tracing::trace!("Throwing away joined result");
} }
@ -334,14 +299,18 @@ where
} }
tokio::select! { tokio::select! {
opt = m.next(), if !m.is_closed() => { opt = m.next(), if !is_closed => {
tracing::trace!("Selected stream"); tracing::trace!("Selected stream");
is_closed = opt.is_none();
if let Some(res) = opt { if let Some(res) = opt {
match res { match res {
Ok(field) => { Ok(field) => {
set.spawn_local(handle_stream_field(field, Rc::clone(&form))); set.spawn_local(handle_stream_field(field, Rc::clone(&form)));
}, },
Err(e) => { Err(e) => {
is_closed = true;
stream_error = true;
error = Some(e.into()); error = Some(e.into());
continue 'outer; continue 'outer;
} }