pict-rs-aggregator/src/middleware.rs
2020-12-08 16:03:18 -06:00

155 lines
4.1 KiB
Rust

use actix_web::{
dev::{Payload, Service, ServiceRequest, Transform},
http::StatusCode,
web, FromRequest, HttpMessage, HttpRequest, HttpResponse, ResponseError,
};
use futures::{
channel::oneshot,
future::{ok, LocalBoxFuture, Ready},
};
use std::task::{Context, Poll};
use uuid::Uuid;
/// Middleware factory: its `Transform` impl wraps the next service in a
/// `VerifyMiddleware`, which checks the request's token before forwarding it.
pub(crate) struct Verify;
/// Service produced by `Verify`; field `.0` is the wrapped inner service `S`
/// that every request is forwarded to after token verification.
pub(crate) struct VerifyMiddleware<S>(S);
/// Extractor proving that the request supplied a token in its query string
/// and that the token passed verification for the requested aggregation path.
///
/// It can only be extracted from a request that went through the verifying
/// middleware. If the token was missing or did not verify, extraction fails
/// with `TokenError`.
#[derive(Clone, Debug)]
pub struct ValidToken {
    /// Identifier of the verified token.
    pub(crate) token: Uuid,
}
/// Resolves a `ValidToken` from the one-shot channel the verifying middleware
/// stored in the request extensions.
impl FromRequest for ValidToken {
    type Error = TokenError;
    type Future = LocalBoxFuture<'static, Result<Self, Self::Error>>;
    type Config = ();

    fn from_request(req: &HttpRequest, _: &mut Payload) -> Self::Future {
        // Take the receiver out eagerly (not inside the future) so the
        // extensions borrow ends here and the future owns the channel end.
        let receiver = req.extensions_mut().remove::<oneshot::Receiver<Self>>();

        Box::pin(async move {
            match receiver {
                // A cancelled channel means the middleware dropped the sender:
                // no token, or the token failed verification.
                Some(rx) => rx.await.map_err(|_| TokenError),
                // No receiver means the middleware never ran for this request.
                None => Err(TokenError),
            }
        })
    }
}
#[derive(Clone, Debug, thiserror::Error)]
#[error("Invalid token")]
pub struct TokenError;
#[derive(Clone, Debug, thiserror::Error)]
#[error("Invalid token")]
struct VerifyError;
/// Renders `TokenError` as a plain-text `401 Unauthorized` response.
impl ResponseError for TokenError {
    fn status_code(&self) -> StatusCode {
        StatusCode::UNAUTHORIZED
    }

    fn error_response(&self) -> HttpResponse {
        // The body is the `Display` text from the `#[error(..)]` attribute.
        let message = self.to_string();
        let mut builder = HttpResponse::build(self.status_code());
        builder.content_type(mime::TEXT_PLAIN.essence_str());
        builder.body(message)
    }
}
/// Renders `VerifyError` as a plain-text `401 Unauthorized` response.
impl ResponseError for VerifyError {
    fn status_code(&self) -> StatusCode {
        StatusCode::UNAUTHORIZED
    }

    fn error_response(&self) -> HttpResponse {
        // The body is the `Display` text from the `#[error(..)]` attribute.
        let message = self.to_string();
        let mut builder = HttpResponse::build(self.status_code());
        builder.content_type(mime::TEXT_PLAIN.essence_str());
        builder.body(message)
    }
}
/// Wraps the next service `S` in a `VerifyMiddleware`.
impl<S> Transform<S> for Verify
where
    S: Service<Request = ServiceRequest, Error = actix_web::Error>,
    S::Future: 'static,
{
    type Request = S::Request;
    type Response = S::Response;
    type Error = S::Error;
    type Transform = VerifyMiddleware<S>;
    type InitError = ();
    type Future = Ready<Result<Self::Transform, Self::InitError>>;

    fn new_transform(&self, service: S) -> Self::Future {
        // Building the wrapper cannot fail, so the future is immediately ready.
        let middleware = VerifyMiddleware(service);
        ok(middleware)
    }
}
impl<S> Service for VerifyMiddleware<S>
where
S: Service<Request = ServiceRequest, Error = actix_web::Error>,
S::Future: 'static,
{
type Request = S::Request;
type Response = S::Response;
type Error = S::Error;
type Future = LocalBoxFuture<'static, Result<S::Response, S::Error>>;
fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
self.0.poll_ready(cx)
}
fn call(&mut self, req: S::Request) -> Self::Future {
let (req, pl) = req.into_parts();
let state_fut = web::Data::<crate::State>::extract(&req);
let token_fut = Option::<web::Query<crate::Token>>::extract(&req);
let path_fut = web::Path::<crate::AggregationPath>::extract(&req);
let req = ServiceRequest::from_parts(req, pl)
.map_err(|_| VerifyError)
.unwrap();
let (tx, rx) = oneshot::channel();
req.extensions_mut().insert(rx);
let service_fut = self.0.call(req);
Box::pin(async move {
let state = state_fut.await?;
let path = path_fut.await?;
let token = token_fut.await?;
if let Some(token) = token {
let token = token.into_inner();
if verify(&path, token.clone(), &state).await.is_ok() {
tx.send(ValidToken { token: token.token })
.map_err(|_| VerifyError)?;
} else {
drop(tx);
}
} else {
drop(tx);
}
service_fut.await
})
}
}
/// Checks `token` against the token storage that `state.store` returns for
/// `path`.
///
/// The check itself runs through `web::block`, keeping it off the async
/// executor.
///
/// # Errors
/// Every failure collapses into `VerifyError`: a failed store lookup, a failed
/// blocking call, or a token that does not verify.
async fn verify(
    path: &crate::AggregationPath,
    token: crate::Token,
    state: &crate::State,
) -> Result<(), VerifyError> {
    let storage = match state.store.token(path).await {
        Ok(storage) => storage,
        Err(_) => return Err(VerifyError),
    };

    let is_valid = web::block(move || storage.verify(&token))
        .await
        .map_err(|_| VerifyError)?;

    if is_valid {
        Ok(())
    } else {
        Err(VerifyError)
    }
}