Ensure veirfy before returning DigestVerify

This commit is contained in:
asonix 2023-09-09 18:05:18 -04:00
parent 92a73f0313
commit f0dc14d5f1
2 changed files with 25 additions and 9 deletions

View file

@ -1,7 +1,7 @@
[package] [package]
name = "http-signature-normalization-actix" name = "http-signature-normalization-actix"
description = "An HTTP Signatures library that leaves the signing to you" description = "An HTTP Signatures library that leaves the signing to you"
version = "0.10.2" version = "0.10.3"
authors = ["asonix <asonix@asonix.dog>"] authors = ["asonix <asonix@asonix.dog>"]
license = "AGPL-3.0" license = "AGPL-3.0"
readme = "README.md" readme = "README.md"

View file

@ -17,7 +17,7 @@ use std::{
task::{Context, Poll}, task::{Context, Poll},
}; };
use streem::IntoStreamer; use streem::IntoStreamer;
use tokio::sync::mpsc; use tokio::sync::{mpsc, oneshot};
use tracing::{debug, Span}; use tracing::{debug, Span};
use tracing_error::SpanTrace; use tracing_error::SpanTrace;
@ -125,22 +125,30 @@ where
} }
} }
struct VerifiedReceiver {
rx: Option<oneshot::Receiver<()>>,
}
impl FromRequest for DigestVerified { impl FromRequest for DigestVerified {
type Error = VerifyError; type Error = VerifyError;
type Future = Ready<Result<Self, Self::Error>>; type Future = LocalBoxFuture<'static, Result<Self, Self::Error>>;
fn from_request(req: &HttpRequest, _: &mut Payload) -> Self::Future { fn from_request(req: &HttpRequest, _: &mut Payload) -> Self::Future {
let res = req let res = req
.extensions() .extensions_mut()
.get::<Self>() .get_mut::<VerifiedReceiver>()
.copied() .and_then(|r| r.rx.take())
.ok_or_else(|| VerifyError::new(&Span::current(), VerifyErrorKind::Extension)); .ok_or_else(|| VerifyError::new(&Span::current(), VerifyErrorKind::Extension));
if res.is_err() { if res.is_err() {
debug!("Failed to fetch DigestVerified from request"); debug!("Failed to fetch DigestVerified from request");
} }
ready(res) Box::pin(async move {
res?.await
.map_err(|_| VerifyError::new(&Span::current(), VerifyErrorKind::Dropped))
.map(|()| DigestVerified)
})
} }
} }
@ -203,12 +211,16 @@ where
let spawner = self.1.clone(); let spawner = self.1.clone();
let (tx, rx) = mpsc::channel(1); let (tx, rx) = mpsc::channel(1);
let f1 = span.in_scope(|| verify_payload(spawner, vec, self.3.clone(), payload, tx)); let (verify_tx, verify_rx) = oneshot::channel();
let f1 = span
.in_scope(|| verify_payload(spawner, vec, self.3.clone(), payload, tx, verify_tx));
let payload: Pin<Box<dyn Stream<Item = Result<web::Bytes, PayloadError>> + 'static>> = let payload: Pin<Box<dyn Stream<Item = Result<web::Bytes, PayloadError>> + 'static>> =
Box::pin(RxStream(rx)); Box::pin(RxStream(rx));
req.set_payload(payload.into()); req.set_payload(payload.into());
req.extensions_mut().insert(DigestVerified); req.extensions_mut().insert(VerifiedReceiver {
rx: Some(verify_rx),
});
let f2 = self.0.call(req); let f2 = self.0.call(req);
@ -238,6 +250,7 @@ async fn verify_payload<T, Spawner>(
mut verify_digest: T, mut verify_digest: T,
payload: Payload, payload: Payload,
tx: mpsc::Sender<web::Bytes>, tx: mpsc::Sender<web::Bytes>,
verify_tx: oneshot::Sender<()>,
) -> Result<(), actix_web::Error> ) -> Result<(), actix_web::Error>
where where
T: DigestVerify + Clone + Send + 'static, T: DigestVerify + Clone + Send + 'static,
@ -264,6 +277,9 @@ where
.await??; .await??;
if verified { if verified {
if verify_tx.send(()).is_err() {
debug!("handler dropped");
}
Ok(()) Ok(())
} else { } else {
Err(VerifyError::new(&Span::current(), VerifyErrorKind::Verify).into()) Err(VerifyError::new(&Span::current(), VerifyErrorKind::Verify).into())