//! Types for setting up Digest middleware verification use actix_web::{ dev::{Body, Payload, Service, ServiceRequest, ServiceResponse, Transform}, error::PayloadError, http::{header::HeaderValue, StatusCode}, FromRequest, HttpMessage, HttpRequest, HttpResponse, ResponseError, }; use bytes::{Bytes, BytesMut}; use futures::{ future::{err, ok, ready, Ready}, stream::once, Stream, StreamExt, }; use std::{ cell::RefCell, future::Future, pin::Pin, rc::Rc, task::{Context, Poll}, }; use super::{DigestPart, DigestVerify}; #[derive(Copy, Clone, Debug)] /// A type implementing FromRequest that can be used in route handler to guard for verified /// digests /// /// This is only required when the [`VerifyDigest`] middleware is set to optional pub struct DigestVerified; /// The VerifyDigest middleware /// /// ```rust,ignore /// let middleware = VerifyDigest::new(MyVerify::new()) /// .optional(); /// /// HttpServer::new(move || { /// App::new() /// .wrap(middleware.clone()) /// .route("/protected", web::post().to(|_: DigestVerified| "Verified Digest Header")) /// .route("/unprotected", web::post().to(|| "No verification required")) /// }) /// ``` pub struct VerifyDigest(bool, T); #[doc(hidden)] pub struct VerifyMiddleware(Rc>, bool, T); #[derive(Debug, thiserror::Error)] #[error("Error verifying digest")] #[doc(hidden)] pub struct VerifyError; impl VerifyDigest where T: DigestVerify + Clone, { /// Produce a new VerifyDigest with a user-provided [`Digestverify`] type pub fn new(verify_digest: T) -> Self { VerifyDigest(true, verify_digest) } /// Mark verifying the Digest as optional /// /// If a digest is present in the request, it will be verified, but it is not required to be /// present pub fn optional(self) -> Self { VerifyDigest(false, self.1) } } impl FromRequest for DigestVerified { type Error = VerifyError; type Future = Ready>; type Config = (); fn from_request(req: &HttpRequest, _: &mut Payload) -> Self::Future { ready( req.extensions() .get::() .map(|s| *s) .ok_or(VerifyError), ) } } impl Transform for VerifyDigest where T: DigestVerify + Clone + 'static, S: Service< Request = ServiceRequest, Response = ServiceResponse, Error = actix_web::Error, > + 'static, S::Error: 'static, { type Request = ServiceRequest; type Response = ServiceResponse; type Error = actix_web::Error; type Transform = VerifyMiddleware; type InitError = (); type Future = Ready>; fn new_transform(&self, service: S) -> Self::Future { ok(VerifyMiddleware( Rc::new(RefCell::new(service)), self.0, self.1.clone(), )) } } impl Service for VerifyMiddleware where T: DigestVerify + Clone + 'static, S: Service< Request = ServiceRequest, Response = ServiceResponse, Error = actix_web::Error, > + 'static, S::Error: 'static, { type Request = ServiceRequest; type Response = ServiceResponse; type Error = actix_web::Error; type Future = Pin>>>; fn poll_ready(&mut self, cx: &mut Context) -> Poll> { self.0.borrow_mut().poll_ready(cx) } fn call(&mut self, mut req: ServiceRequest) -> Self::Future { if let Some(digest) = req.headers().get("Digest") { let vec = match parse_digest(digest) { Some(vec) => vec, None => return Box::pin(err(VerifyError.into())), }; let mut payload = req.take_payload(); let service = self.0.clone(); let mut verify_digest = self.2.clone(); Box::pin(async move { let mut output_bytes = BytesMut::new(); while let Some(res) = payload.next().await { let bytes = res?; output_bytes.extend(bytes); } let bytes = output_bytes.freeze(); if verify_digest.verify(&vec, &bytes.as_ref()) { req.set_payload( (Box::pin(once(ok(bytes))) as Pin> + 'static>>) .into(), ); req.extensions_mut().insert(DigestVerified); service.borrow_mut().call(req).await } else { Err(VerifyError.into()) } }) } else if self.1 { Box::pin(err(VerifyError.into())) } else { Box::pin(self.0.borrow_mut().call(req)) } } } fn parse_digest(h: &HeaderValue) -> Option> { let h = h.to_str().ok()?.split(";").next()?; let v: Vec<_> = h .split(",") .filter_map(|p| { let mut iter = p.splitn(2, "="); iter.next() .and_then(|alg| iter.next().map(|value| (alg, value))) }) .map(|(alg, value)| DigestPart { algorithm: alg.to_owned(), digest: value.to_owned(), }) .collect(); if v.is_empty() { None } else { Some(v) } } impl ResponseError for VerifyError { fn status_code(&self) -> StatusCode { StatusCode::BAD_REQUEST } fn error_response(&self) -> HttpResponse { HttpResponse::BadRequest().finish() } }