use actix_web::{ dev::{Payload, Service, ServiceRequest, Transform}, http::StatusCode, web, FromRequest, HttpMessage, HttpRequest, HttpResponse, ResponseError, }; use futures::{ channel::oneshot, future::{ok, LocalBoxFuture, Ready}, }; use std::task::{Context, Poll}; use uuid::Uuid; pub(crate) struct Verify; pub(crate) struct VerifyMiddleware(S); pub struct ValidToken { pub(crate) token: Uuid, } impl FromRequest for ValidToken { type Error = TokenError; type Future = LocalBoxFuture<'static, Result>; type Config = (); fn from_request(req: &HttpRequest, _: &mut Payload) -> Self::Future { let res = req .extensions_mut() .remove::>() .ok_or(TokenError); Box::pin(async move { res?.await.map_err(|_| TokenError) }) } } #[derive(Clone, Debug, thiserror::Error)] #[error("Invalid token")] pub struct TokenError; #[derive(Clone, Debug, thiserror::Error)] #[error("Invalid token")] struct VerifyError; impl ResponseError for TokenError { fn status_code(&self) -> StatusCode { StatusCode::UNAUTHORIZED } fn error_response(&self) -> HttpResponse { HttpResponse::build(self.status_code()) .content_type(mime::TEXT_PLAIN.essence_str()) .body(self.to_string()) } } impl ResponseError for VerifyError { fn status_code(&self) -> StatusCode { StatusCode::UNAUTHORIZED } fn error_response(&self) -> HttpResponse { HttpResponse::build(self.status_code()) .content_type(mime::TEXT_PLAIN.essence_str()) .body(self.to_string()) } } impl Transform for Verify where S: Service, S::Future: 'static, { type Request = S::Request; type Response = S::Response; type Error = S::Error; type InitError = (); type Transform = VerifyMiddleware; type Future = Ready>; fn new_transform(&self, service: S) -> Self::Future { ok(VerifyMiddleware(service)) } } impl Service for VerifyMiddleware where S: Service, S::Future: 'static, { type Request = S::Request; type Response = S::Response; type Error = S::Error; type Future = LocalBoxFuture<'static, Result>; fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll> { self.0.poll_ready(cx) } fn call(&mut self, req: S::Request) -> Self::Future { let (req, pl) = req.into_parts(); let state_fut = web::Data::::extract(&req); let token_fut = Option::>::extract(&req); let path_fut = web::Path::::extract(&req); let req = ServiceRequest::from_parts(req, pl) .map_err(|_| VerifyError) .unwrap(); let (tx, rx) = oneshot::channel(); req.extensions_mut().insert(rx); let service_fut = self.0.call(req); Box::pin(async move { let state = state_fut.await?; let path = path_fut.await?; let token = token_fut.await?; if let Some(token) = token { let token = token.into_inner(); if verify(&path, token.clone(), &state).await.is_ok() { tx.send(ValidToken { token: token.token }) .map_err(|_| VerifyError)?; } else { drop(tx); } } else { drop(tx); } service_fut.await }) } } async fn verify( path: &crate::AggregationPath, token: crate::Token, state: &crate::State, ) -> Result<(), VerifyError> { let token_storage = state.store.token(path).await.map_err(|_| VerifyError)?; let verified = web::block(move || token_storage.verify(&token)) .await .map_err(|_| VerifyError)?; if !verified { return Err(VerifyError); } Ok(()) }