From 369a1e8a96030d645c95a7ccaea3457812bd0ee5 Mon Sep 17 00:00:00 2001 From: asonix Date: Sun, 10 Sep 2023 13:20:35 -0400 Subject: [PATCH] Replace spawned tasks with inline payload stream processing --- actix/src/digest/middleware.rs | 92 ++++++++++++++++++++++------------ 1 file changed, 59 insertions(+), 33 deletions(-) diff --git a/actix/src/digest/middleware.rs b/actix/src/digest/middleware.rs index 071f628..9db03b1 100644 --- a/actix/src/digest/middleware.rs +++ b/actix/src/digest/middleware.rs @@ -1,6 +1,6 @@ //! Types for setting up Digest middleware verification -use crate::{DefaultSpawner, Spawn}; +use crate::{Canceled, DefaultSpawner, Spawn}; use super::{DigestPart, DigestVerify}; use actix_web::{ @@ -16,7 +16,7 @@ use std::{ pin::Pin, task::{Context, Poll}, }; -use streem::IntoStreamer; +use streem::{from_fn::Yielder, IntoStreamer}; use tokio::sync::{mpsc, oneshot}; use tracing::{debug, Span}; use tracing_error::SpanTrace; @@ -207,30 +207,23 @@ where ))); } }; - let payload = req.take_payload(); + let spawner = self.1.clone(); - - let (tx, rx) = mpsc::channel(1); + let digest = self.3.clone(); let (verify_tx, verify_rx) = oneshot::channel(); - let f1 = span - .in_scope(|| verify_payload(spawner, vec, self.3.clone(), payload, tx, verify_tx)); + let payload = req.take_payload(); let payload: Pin> + 'static>> = - Box::pin(RxStream(rx)); + Box::pin(streem::try_from_fn(|yielder| async move { + verify_payload(yielder, spawner, vec, digest, payload, verify_tx).await + })); req.set_payload(payload.into()); + req.extensions_mut().insert(VerifiedReceiver { rx: Some(verify_rx), }); - let f2 = self.0.call(req); - - Box::pin(async move { - let handle1 = actix_web::rt::spawn(f1); - let handle2 = actix_web::rt::spawn(f2); - - handle1.await.expect("verify panic")?; - handle2.await.expect("inner panic") - }) + Box::pin(self.0.call(req)) } else if self.2 { Box::pin(ready(Err(VerifyError::new( &span, @@ -243,46 +236,79 @@ where } } -#[tracing::instrument(name = "Verify Payload", skip(spawner, verify_digest, payload, tx))] +fn canceled_error(error: Canceled) -> PayloadError { + PayloadError::Io(std::io::Error::new(std::io::ErrorKind::Other, error)) +} + +fn verified_error(error: VerifyError) -> PayloadError { + PayloadError::Io(std::io::Error::new(std::io::ErrorKind::Other, error)) +} + async fn verify_payload( + yielder: Yielder>, spawner: Spawner, vec: Vec, mut verify_digest: T, payload: Payload, - tx: mpsc::Sender, verify_tx: oneshot::Sender<()>, -) -> Result<(), actix_web::Error> +) -> Result<(), PayloadError> where T: DigestVerify + Clone + Send + 'static, Spawner: Spawn, { let mut payload = payload.into_streamer(); - while let Some(bytes) = payload.try_next().await? { - let bytes2 = bytes.clone(); - verify_digest = spawner - .spawn_blocking(move || { - verify_digest.update(bytes2.as_ref()); - Ok(verify_digest) as Result - }) - .await??; + let mut error = None; - tx.send(bytes) - .await - .map_err(|_| VerifyError::new(&Span::current(), VerifyErrorKind::Dropped))?; + while let Some(bytes) = payload.try_next().await? { + if error.is_none() { + let bytes2 = bytes.clone(); + let mut verify_digest2 = verify_digest.clone(); + + let task = spawner.spawn_blocking(move || { + verify_digest2.update(bytes2.as_ref()); + Ok(verify_digest2) as Result + }); + + yielder.yield_ok(bytes).await; + + match task.await { + Ok(Ok(digest)) => { + verify_digest = digest; + } + Ok(Err(e)) => { + error = Some(verified_error(e)); + } + Err(e) => { + error = Some(canceled_error(e)); + } + } + } else { + yielder.yield_ok(bytes).await; + } + } + + if let Some(error) = error { + return Err(error); } let verified = spawner .spawn_blocking(move || Ok(verify_digest.verify(&vec)) as Result<_, VerifyError>) - .await??; + .await + .map_err(canceled_error)? + .map_err(verified_error)?; if verified { if verify_tx.send(()).is_err() { debug!("handler dropped"); } + Ok(()) } else { - Err(VerifyError::new(&Span::current(), VerifyErrorKind::Verify).into()) + Err(verified_error(VerifyError::new( + &Span::current(), + VerifyErrorKind::Verify, + ))) } }