diff --git a/actix/Cargo.toml b/actix/Cargo.toml index 9c32f15..dc1ec18 100644 --- a/actix/Cargo.toml +++ b/actix/Cargo.toml @@ -1,7 +1,7 @@ [package] name = "http-signature-normalization-actix" description = "An HTTP Signatures library that leaves the signing to you" -version = "0.9.1" +version = "0.10.0" authors = ["asonix "] license = "AGPL-3.0" readme = "README.md" diff --git a/actix/examples/client.rs b/actix/examples/client.rs index 46cc204..02099ce 100644 --- a/actix/examples/client.rs +++ b/actix/examples/client.rs @@ -45,7 +45,7 @@ async fn main() -> Result<(), Box> { tracing::subscriber::set_global_default(subscriber)?; - let config = Config::default().require_header("accept").require_digest(); + let config = Config::new().require_header("accept").require_digest(); request(config.clone()).await?; request(config.mastodon_compat()).await?; diff --git a/actix/examples/server.rs b/actix/examples/server.rs index cf42c3b..c296551 100644 --- a/actix/examples/server.rs +++ b/actix/examples/server.rs @@ -63,7 +63,7 @@ async fn main() -> Result<(), Box> { tracing::subscriber::set_global_default(subscriber)?; - let config = Config::default().require_header("accept").require_digest(); + let config = Config::new().require_header("accept").require_digest(); HttpServer::new(move || { App::new() diff --git a/actix/src/digest/middleware.rs b/actix/src/digest/middleware.rs index fe016ef..18c8669 100644 --- a/actix/src/digest/middleware.rs +++ b/actix/src/digest/middleware.rs @@ -1,5 +1,7 @@ //! Types for setting up Digest middleware verification +use crate::{DefaultSpawner, Spawn}; + use super::{DigestPart, DigestVerify}; use actix_web::{ body::MessageBody, @@ -42,10 +44,10 @@ pub struct DigestVerified; /// .route("/unprotected", web::post().to(|| "No verification required")) /// }) /// ``` -pub struct VerifyDigest(bool, T); +pub struct VerifyDigest(Spawner, bool, T); #[doc(hidden)] -pub struct VerifyMiddleware(S, bool, T); +pub struct VerifyMiddleware(S, Spawner, bool, T); #[derive(Debug, thiserror::Error)] #[error("Error verifying digest")] @@ -98,7 +100,22 @@ where { /// Produce a new VerifyDigest with a user-provided [`Digestverify`] type pub fn new(verify_digest: T) -> Self { - VerifyDigest(true, verify_digest) + VerifyDigest(DefaultSpawner, true, verify_digest) + } +} + +impl VerifyDigest +where + T: DigestVerify + Clone, +{ + /// Set the spawner used for verifying bytes in the request + /// + /// By default this value uses `actix_web::web::block` + pub fn spawner(self, spawner: NewSpawner) -> VerifyDigest + where + NewSpawner: Spawn, + { + VerifyDigest(spawner, self.1, self.2) } /// Mark verifying the Digest as optional @@ -106,7 +123,7 @@ where /// If a digest is present in the request, it will be verified, but it is not required to be /// present pub fn optional(self) -> Self { - VerifyDigest(false, self.1) + VerifyDigest(self.0, false, self.2) } } @@ -129,30 +146,37 @@ impl FromRequest for DigestVerified { } } -impl Transform for VerifyDigest +impl Transform for VerifyDigest where T: DigestVerify + Clone + Send + 'static, S: Service, Error = actix_web::Error> + 'static, S::Error: 'static, B: MessageBody + 'static, + Spawner: Spawn + Clone + 'static, { type Response = ServiceResponse; type Error = actix_web::Error; - type Transform = VerifyMiddleware; + type Transform = VerifyMiddleware; type InitError = (); type Future = Ready>; fn new_transform(&self, service: S) -> Self::Future { - ready(Ok(VerifyMiddleware(service, self.0, self.1.clone()))) + ready(Ok(VerifyMiddleware( + service, + self.0.clone(), + self.1, + self.2.clone(), + ))) } } -impl Service for VerifyMiddleware +impl Service for VerifyMiddleware where T: DigestVerify + Clone + Send + 'static, S: Service, Error = actix_web::Error> + 'static, S::Error: 'static, B: MessageBody + 'static, + Spawner: Spawn + Clone + 'static, { type Response = ServiceResponse; type Error = actix_web::Error; @@ -165,7 +189,7 @@ where fn call(&self, mut req: ServiceRequest) -> Self::Future { let span = tracing::info_span!( "Verify digest", - digest.required = tracing::field::display(&self.1), + digest.required = tracing::field::display(&self.2), ); if let Some(digest) = req.headers().get("Digest") { @@ -178,9 +202,10 @@ where } }; let payload = req.take_payload(); + let spawner = self.1.clone(); let (tx, rx) = mpsc::channel(1); - let f1 = span.in_scope(|| verify_payload(vec, self.2.clone(), payload, tx)); + let f1 = span.in_scope(|| verify_payload(spawner, vec, self.3.clone(), payload, tx)); let payload: Pin> + 'static>> = Box::pin(RxStream(rx).map(Ok)); @@ -193,7 +218,7 @@ where let (_, res) = futures_util::future::join(f1, f2).await; res }) - } else if self.1 { + } else if self.2 { Box::pin(ready(Err(VerifyError::new( &span, VerifyErrorKind::MissingDigest, @@ -205,8 +230,9 @@ where } } -#[tracing::instrument(name = "Verify Payload", skip(verify_digest, payload, tx))] -async fn verify_payload( +#[tracing::instrument(name = "Verify Payload", skip(spawner, verify_digest, payload, tx))] +async fn verify_payload( + spawner: Spawner, vec: Vec, mut verify_digest: T, mut payload: Payload, @@ -214,23 +240,26 @@ async fn verify_payload( ) -> Result<(), actix_web::Error> where T: DigestVerify + Clone + Send + 'static, + Spawner: Spawn, { while let Some(res) = payload.next().await { let bytes = res?; let bytes2 = bytes.clone(); - verify_digest = web::block(move || { - verify_digest.update(bytes2.as_ref()); - Ok(verify_digest) as Result - }) - .await??; + verify_digest = spawner + .spawn_blocking(move || { + verify_digest.update(bytes2.as_ref()); + Ok(verify_digest) as Result + }) + .await??; tx.send(bytes) .await .map_err(|_| VerifyError::new(&Span::current(), VerifyErrorKind::Dropped))?; } - let verified = - web::block(move || Ok(verify_digest.verify(&vec)) as Result<_, VerifyError>).await??; + let verified = spawner + .spawn_blocking(move || Ok(verify_digest.verify(&vec)) as Result<_, VerifyError>) + .await??; if verified { Ok(()) diff --git a/actix/src/lib.rs b/actix/src/lib.rs index 63d57a8..ef51b8f 100644 --- a/actix/src/lib.rs +++ b/actix/src/lib.rs @@ -257,7 +257,7 @@ pub mod verify { } #[cfg(feature = "client")] -pub use self::client::{Canceled, PrepareSignError, Sign, Spawn}; +pub use self::client::{PrepareSignError, Sign}; #[cfg(feature = "server")] pub use self::server::{PrepareVerifyError, SignatureVerify}; @@ -285,9 +285,68 @@ pub struct Config { #[derive(Clone, Copy, Debug, Default)] pub struct DefaultSpawner; +/// An error that indicates a blocking operation panicked and cannot return a response +#[derive(Debug)] +pub struct Canceled; + +impl std::fmt::Display for Canceled { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!(f, "Operation was canceled") + } +} + +impl std::error::Error for Canceled {} + +/// A trait dictating how to spawn a future onto a blocking threadpool. By default, +/// http-signature-normalization-actix will use actix_rt's built-in blocking threadpool, but this +/// can be customized +pub trait Spawn { + /// The future type returned by spawn_blocking + type Future: std::future::Future>; + + /// Spawn the blocking function onto the threadpool + fn spawn_blocking(&self, func: Func) -> Self::Future + where + Func: FnOnce() -> Out + Send + 'static, + Out: Send + 'static; +} + +/// The future returned by DefaultSpawner when spawning blocking operations on the actix_rt +/// blocking threadpool +pub struct DefaultSpawnerFuture { + inner: actix_rt::task::JoinHandle, +} + +impl Spawn for DefaultSpawner { + type Future = DefaultSpawnerFuture; + + fn spawn_blocking(&self, func: Func) -> Self::Future + where + Func: FnOnce() -> Out + Send + 'static, + Out: Send + 'static, + { + DefaultSpawnerFuture { + inner: actix_rt::task::spawn_blocking(func), + } + } +} + +impl std::future::Future for DefaultSpawnerFuture { + type Output = Result; + + fn poll( + mut self: std::pin::Pin<&mut Self>, + cx: &mut std::task::Context<'_>, + ) -> std::task::Poll { + let res = std::task::ready!(std::pin::Pin::new(&mut self.inner).poll(cx)); + + std::task::Poll::Ready(res.map_err(|_| Canceled)) + } +} + #[cfg(feature = "client")] mod client { - use super::{Config, DefaultSpawner, RequiredError}; + use super::{Config, RequiredError, Spawn}; use actix_http::header::{InvalidHeaderValue, ToStrError}; use actix_rt::task::JoinError; use std::{fmt::Display, future::Future, pin::Pin}; @@ -354,65 +413,6 @@ mod client { /// Invalid Date header InvalidHeader(#[from] actix_http::header::InvalidHeaderValue), } - - /// An error that indicates a blocking operation panicked and cannot return a response - #[derive(Debug)] - pub struct Canceled; - - impl std::fmt::Display for Canceled { - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - write!(f, "Operation was canceled") - } - } - - impl std::error::Error for Canceled {} - - /// A trait dictating how to spawn a future onto a blocking threadpool. By default, - /// http-signature-normalization-actix will use actix_rt's built-in blocking threadpool, but this - /// can be customized - pub trait Spawn { - /// The future type returned by spawn_blocking - type Future: std::future::Future>; - - /// Spawn the blocking function onto the threadpool - fn spawn_blocking(&self, func: Func) -> Self::Future - where - Func: FnOnce() -> Out + Send + 'static, - Out: Send + 'static; - } - - /// The future returned by DefaultSpawner when spawning blocking operations on the actix_rt - /// blocking threadpool - pub struct DefaultSpawnerFuture { - inner: actix_rt::task::JoinHandle, - } - - impl Spawn for DefaultSpawner { - type Future = DefaultSpawnerFuture; - - fn spawn_blocking(&self, func: Func) -> Self::Future - where - Func: FnOnce() -> Out + Send + 'static, - Out: Send + 'static, - { - DefaultSpawnerFuture { - inner: actix_rt::task::spawn_blocking(func), - } - } - } - - impl std::future::Future for DefaultSpawnerFuture { - type Output = Result; - - fn poll( - mut self: std::pin::Pin<&mut Self>, - cx: &mut std::task::Context<'_>, - ) -> std::task::Poll { - let res = std::task::ready!(std::pin::Pin::new(&mut self.inner).poll(cx)); - - std::task::Poll::Ready(res.map_err(|_| Canceled)) - } - } } #[cfg(feature = "server")] @@ -485,6 +485,16 @@ mod server { } } } + + impl actix_web::ResponseError for super::Canceled { + fn status_code(&self) -> actix_http::StatusCode { + actix_http::StatusCode::INTERNAL_SERVER_ERROR + } + + fn error_response(&self) -> actix_web::HttpResponse { + actix_web::HttpResponse::new(self.status_code()) + } + } } impl Config { diff --git a/actix/src/middleware.rs b/actix/src/middleware.rs index 571f021..693ead4 100644 --- a/actix/src/middleware.rs +++ b/actix/src/middleware.rs @@ -1,6 +1,6 @@ //! Types for verifying requests with Actix Web -use crate::{Config, PrepareVerifyError, SignatureVerify}; +use crate::{Config, PrepareVerifyError, SignatureVerify, Spawn}; use actix_web::{ body::MessageBody, dev::{Payload, Service, ServiceRequest, ServiceResponse, Transform}, @@ -45,11 +45,11 @@ impl SignatureVerified { /// .route("/unprotected", web::post().to(|| "No verification required")) /// }) /// ``` -pub struct VerifySignature(T, Config, HeaderKind); +pub struct VerifySignature(T, Config, HeaderKind); #[derive(Debug)] #[doc(hidden)] -pub struct VerifyMiddleware(Rc, Config, HeaderKind, T); +pub struct VerifyMiddleware(Rc, Config, HeaderKind, T); #[derive(Copy, Clone, Debug, Eq, Ord, PartialEq, PartialOrd)] enum HeaderKind { @@ -126,7 +126,7 @@ impl VerifyError { } } -impl VerifySignature +impl VerifySignature where T: SignatureVerify, { @@ -135,7 +135,10 @@ where /// /// By default, this middleware expects to verify Signature headers, and requires the presence /// of the header - pub fn new(verify_signature: T, config: Config) -> Self { + pub fn new(verify_signature: T, config: Config) -> Self + where + Spawner: Spawn, + { VerifySignature(verify_signature, config, HeaderKind::Signature) } @@ -145,7 +148,7 @@ where } } -impl VerifyMiddleware +impl VerifyMiddleware where T: SignatureVerify + Clone + 'static, T::Future: 'static, @@ -257,16 +260,17 @@ impl FromRequest for SignatureVerified { } } -impl Transform for VerifySignature +impl Transform for VerifySignature where T: SignatureVerify + Clone + 'static, S: Service, Error = actix_web::Error> + 'static, S::Error: 'static, B: MessageBody + 'static, + Spawner: Clone, { type Response = ServiceResponse; type Error = actix_web::Error; - type Transform = VerifyMiddleware; + type Transform = VerifyMiddleware; type InitError = (); type Future = Ready>; @@ -280,7 +284,7 @@ where } } -impl Service for VerifyMiddleware +impl Service for VerifyMiddleware where T: SignatureVerify + Clone + 'static, S: Service, Error = actix_web::Error> + 'static,