From 4ec8205934197fe962420c72647c0868af86de08 Mon Sep 17 00:00:00 2001 From: asonix Date: Sun, 16 Jul 2023 14:08:18 -0500 Subject: [PATCH] Try to avoid spinning forever --- Cargo.toml | 1 - src/upload.rs | 100 ++++++++++++++++++++++++++++++++++++++------------ 2 files changed, 76 insertions(+), 25 deletions(-) diff --git a/Cargo.toml b/Cargo.toml index bfa2a7e..957afca 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -17,7 +17,6 @@ futures-core = "0.3.28" mime = "0.3.16" thiserror = "1.0" tokio = { version = "1", default-features = false, features = ["macros", "sync"] } -tokio-stream = "0.1.14" tracing = "0.1.15" [dev-dependencies] diff --git a/src/upload.rs b/src/upload.rs index 56fae67..cf73cdf 100644 --- a/src/upload.rs +++ b/src/upload.rs @@ -26,16 +26,16 @@ use crate::{ }; use actix_web::web::BytesMut; use std::{collections::HashMap, future::poll_fn, path::Path, pin::Pin, rc::Rc}; -use tokio::task::JoinSet; +use tokio::{sync::mpsc::Receiver, task::JoinSet}; use tracing::trace; struct Streamer(S, bool); -impl Streamer { - async fn next(&mut self) -> Option - where - S: futures_core::Stream + Unpin, - { +impl Streamer +where + S: futures_core::Stream + Unpin, +{ + async fn next(&mut self) -> Option { if self.1 { return None; } @@ -46,6 +46,19 @@ impl Streamer { } } +struct ReceiverStream(Receiver); + +impl futures_core::Stream for ReceiverStream { + type Item = T; + + fn poll_next( + mut self: Pin<&mut Self>, + cx: &mut std::task::Context<'_>, + ) -> std::task::Poll> { + Pin::new(&mut self.0).poll_recv(cx) + } +} + fn consolidate(mf: MultipartForm) -> Value { mf.into_iter().fold( Value::Map(HashMap::new()), @@ -152,7 +165,7 @@ where } }; - let stream = tokio_stream::wrappers::ReceiverStream::new(rx); + let stream = ReceiverStream(rx); let result_fut = file_fn(filename.clone(), content_type.clone(), Box::pin(stream)); @@ -290,46 +303,85 @@ where let mut provided_error: Option = None; 'outer: loop { + if error.is_some() || provided_error.is_some() { + while let Some(res) = m.next().await { + if let Ok(field) = res { + let mut stream = Streamer(field, false); + while let Some(_res) = stream.next().await { + tracing::trace!("Throwing away uploaded bytes, we have an error"); + } + } + } + + break 'outer; + } + + if stream_finished { + while let Some(res) = set.join_next().await { + let (name_parts, content) = match res { + Ok(Ok(Ok(tup))) => tup, + Ok(Ok(Err(e))) => { + provided_error = Some(e); + continue 'outer; + } + Ok(Err(e)) => { + error = Some(e); + continue 'outer; + } + Err(e) => { + error = Some(e.into()); + continue 'outer; + } + }; + + let (l, r) = match count(&content, file_count, field_count, &form) { + Ok(tup) => tup, + Err(e) => { + error = Some(e); + continue 'outer; + } + }; + + file_count = l; + field_count = r; + + multipart_form.push((name_parts, content)); + } + + break 'outer; + } + tokio::select! { opt = m.next() => { if let Some(res) = opt { - if error.is_some() || provided_error.is_some() { - continue; - } - match res { Ok(field) => { set.spawn_local(handle_stream_field(field, Rc::clone(&form))); }, - Err(e) => error = Some(e.into()), + Err(e) => { + error = Some(e.into()); + continue 'outer; + } } } else { stream_finished = true; } } opt = set.join_next() => { - if error.is_some() || provided_error.is_some() { - if stream_finished { - break 'outer; - } - - continue; - } - if let Some(res) = opt { let (name_parts, content) = match res { Ok(Ok(Ok(tup))) => tup, Ok(Ok(Err(e))) => { provided_error = Some(e); - continue; + continue 'outer; } Ok(Err(e)) => { error = Some(e); - continue; + continue 'outer; }, Err(e) => { error = Some(e.into()); - continue; + continue 'outer; }, }; @@ -337,7 +389,7 @@ where Ok(tup) => tup, Err(e) => { error = Some(e); - continue; + continue 'outer; } };