// NOTE(review): several generic argument lists appear to be missing from this copy
// (e.g. `Pin + 'a>>`, `VerifyMiddleware(S)`, `.remove::>()`, `web::Data::::extract`,
// `Ready>`, `Poll>`), so it does not compile as written. TODO: restore the exact type
// parameters from version control rather than guessing them.
use actix_web::{
    dev::{Payload, Service, ServiceRequest, Transform},
    http::StatusCode,
    web, FromRequest, HttpMessage, HttpRequest, HttpResponse, ResponseError,
};
use std::{
    future::{ready, Future, Ready},
    pin::Pin,
    task::{Context, Poll},
};
use tokio::sync::oneshot;
use uuid::Uuid;

/// Pinned future alias with lifetime `'a` and output `T`, used as the `Future` type of
/// the extractor and the middleware below (both build their futures with `Box::pin`).
// NOTE(review): presumably `Pin<Box<dyn Future<Output = T> + 'a>>` — confirm.
type LocalBoxFuture<'a, T> = Pin + 'a>>;

/// Middleware factory: `new_transform` wraps the inner service in [`VerifyMiddleware`].
pub(crate) struct Verify;

/// Middleware service produced by [`Verify`]; field `.0` is the wrapped inner service.
pub(crate) struct VerifyMiddleware(S);

/// Extractor that succeeds only for requests whose token `VerifyMiddleware`
/// successfully verified.
#[derive(Debug)]
pub struct ValidToken {
    /// The `token` field of the `crate::Token` that passed verification.
    pub(crate) token: Uuid,
}

impl FromRequest for ValidToken {
    type Error = TokenError;
    type Future = LocalBoxFuture<'static, Result>;

    /// Takes the oneshot receiver that `VerifyMiddleware::call` stored in the request
    /// extensions and waits for the middleware's verdict.
    ///
    /// # Errors
    ///
    /// Returns [`TokenError`] in two cases:
    /// - No receiver is present, e.g. the middleware did not run for this request, or an
    ///   earlier extraction already removed it.
    /// - The sender was dropped without sending, because no token was supplied or
    ///   verification failed.
    fn from_request(req: &HttpRequest, _: &mut Payload) -> Self::Future {
        // The receiver is removed rather than borrowed, so at most one extraction per
        // request can observe the verdict.
        let res = req
            .extensions_mut()
            .remove::>()
            .ok_or(TokenError);

        Box::pin(async move { res?.await.map_err(|_| TokenError) })
    }
}

/// Error returned by the [`ValidToken`] extractor; rendered as `401 Unauthorized`.
#[derive(Clone, Debug, thiserror::Error)]
#[error("Invalid token")]
pub struct TokenError;

/// Error raised inside the middleware and `verify`; also rendered as `401 Unauthorized`.
#[derive(Clone, Debug, thiserror::Error)]
#[error("Invalid token")]
struct VerifyError;

impl ResponseError for TokenError {
    /// Always `401 Unauthorized`.
    fn status_code(&self) -> StatusCode {
        StatusCode::UNAUTHORIZED
    }

    /// Plain-text response whose body is the error's `Display` text ("Invalid token").
    fn error_response(&self) -> HttpResponse {
        HttpResponse::build(self.status_code())
            .content_type(mime::TEXT_PLAIN.essence_str())
            .body(self.to_string())
    }
}

impl ResponseError for VerifyError {
    /// Always `401 Unauthorized`.
    fn status_code(&self) -> StatusCode {
        StatusCode::UNAUTHORIZED
    }

    /// Plain-text response whose body is the error's `Display` text ("Invalid token").
    fn error_response(&self) -> HttpResponse {
        HttpResponse::build(self.status_code())
            .content_type(mime::TEXT_PLAIN.essence_str())
            .body(self.to_string())
    }
}

// Builds a `VerifyMiddleware` around the inner service; construction always succeeds
// (`InitError = ()` and the future is an already-resolved `ready(Ok(..))`).
impl Transform for Verify
where
    S: Service,
    S::Future: 'static,
{
    type Response = S::Response;
    type Error = S::Error;
    type InitError = ();
    type Transform = VerifyMiddleware;
    type Future = Ready>;

    fn new_transform(&self, service: S) -> Self::Future {
        ready(Ok(VerifyMiddleware(service)))
    }
}

impl Service for VerifyMiddleware
where
    S: Service,
    S::Future: 'static,
{
    type Response = S::Response;
    type Error = S::Error;
    type Future = LocalBoxFuture<'static, Result>;

    /// Readiness is delegated to the wrapped service.
    fn poll_ready(&self, cx: &mut Context<'_>) -> Poll> {
        self.0.poll_ready(cx)
    }

    /// Verifies the request's token (if any), then runs the wrapped service.
    ///
    /// A oneshot receiver is placed in the request extensions for [`ValidToken`].
    /// Before the inner service future is awaited, the sender is resolved: a verdict is
    /// sent on success, and the sender is dropped otherwise. The extractor therefore
    /// never waits on it indefinitely.
    ///
    /// A missing or invalid token does not reject the request by itself. Only handlers
    /// that extract `ValidToken` fail. Failures of the state, path or token extractors
    /// short-circuit with their error, and the inner service future is never awaited.
    fn call(&self, req: ServiceRequest) -> Self::Future {
        // Split the request so the extractors can be created from the bare `HttpRequest`.
        // NOTE(review): `extract` is not given `pl`, so the token extractor presumably
        // must not read the request body — confirm it uses headers/query/path only.
        let (req, pl) = req.into_parts();
        let state_fut = web::Data::::extract(&req);
        let token_fut = Option::>::extract(&req);
        let path_fut = web::Path::::extract(&req);
        let req = ServiceRequest::from_parts(req, pl);

        // Hand the receiving half to the `ValidToken` extractor via request extensions.
        let (tx, rx) = oneshot::channel();
        req.extensions_mut().insert(rx);

        // The inner service is called now, but its future is only awaited after `tx`
        // has been resolved below.
        let service_fut = self.0.call(req);

        Box::pin(async move {
            let state = state_fut.await?;
            let path = path_fut.await?;
            let token = token_fut.await?;

            if let Some(token) = token {
                let token = token.into_inner();
                if verify(&path, token.clone(), &state).await.is_ok() {
                    // Per tokio's `oneshot::Sender::send`, this errors only when the
                    // receiver has already been dropped.
                    tx.send(ValidToken { token: token.token })
                        .map_err(|_| VerifyError)?;
                } else {
                    // The explicit drop is required. `tx` is owned by this future and would
                    // otherwise live until after `service_fut.await`, leaving a
                    // `ValidToken` extraction waiting forever.
                    drop(tx);
                }
            } else {
                // No token supplied: close the channel so `ValidToken` fails immediately.
                drop(tx);
            }

            service_fut.await
        })
    }
}

/// Checks `token` against the token storage that `state.store` returns for `path`.
///
/// # Errors
///
/// Returns [`VerifyError`] if any of the following happens:
/// - the store lookup fails;
/// - no token storage exists for `path`;
/// - the blocking task fails;
/// - the storage's own check returns an error;
/// - the check reports that the token does not match.
async fn verify(
    path: &crate::CollectionPath,
    token: crate::Token,
    state: &crate::State,
) -> Result<(), VerifyError> {
    let token_storage = state
        .store
        .token(path)
        .await
        .map_err(|_| VerifyError)?
        .ok_or(VerifyError)?;

    // Runs the check via `web::block` (actix's blocking thread pool).
    // NOTE(review): presumably because `verify` is CPU-heavy, e.g. password-style
    // hashing — confirm.
    let verified = web::block(move || token_storage.verify(&token))
        .await
        .map_err(|_| VerifyError)?
        .map_err(|_| VerifyError)?;

    if !verified {
        return Err(VerifyError);
    }

    Ok(())
}