use crate::{profiles::Profile, State}; use actix_session::{Session, UserSession}; use actix_web::{ dev::{Payload, Service, ServiceRequest, ServiceResponse, Transform}, http::StatusCode, web::Data, FromRequest, HttpMessage, HttpRequest, HttpResponse, ResponseError, }; use event_listener::Event; use futures_core::future::LocalBoxFuture; use futures_util::future::{ok, Ready}; use hyaenidae_accounts::Authenticated; use std::{ cell::{Cell, RefCell}, rc::Rc, }; use uuid::Uuid; impl FromRequest for Profile { type Config = (); type Error = actix_web::Error; type Future = LocalBoxFuture<'static, Result>; fn from_request(req: &HttpRequest, _: &mut Payload) -> Self::Future { let opt = req .extensions() .get::() .map(|extractor| (extractor.clone(), extractor.2.listen())); let user_fut = Authenticated::extract(&req); let state_fut = Data::::extract(&req); Box::pin(async move { let (extractor, listen_fut) = match opt { Some(tuple) => tuple, None => return Err(ServerError("/404".to_owned()).into()), }; if let Some(profile) = extractor.0.borrow().as_ref() { return Ok(profile.clone()); } if !extractor.1.get() { listen_fut.await; if let Some(profile) = extractor.0.borrow().as_ref() { return Ok(profile.clone()); } } let state = state_fut.await?; let user_id = user_fut.await?.user().id(); let has_profiles = state .profiles .store .profiles .for_local(user_id) .next() .is_some(); let redirect = if has_profiles { "/profiles/change" } else { "/profiles/create/handle" }; Err(ServerError(redirect.to_owned()).into()) }) } } #[derive(Clone, Debug, serde::Deserialize, serde::Serialize)] pub(super) struct ProfileData { id: Uuid, } impl ProfileData { pub(super) fn set_data(id: Uuid, session: &Session) -> Result<(), SessionError> { session .set("profile-data", ProfileData { id }) .map_err(|_| SessionError::Set) } fn data(session: &Session) -> Result, actix_web::Error> { Ok(session.get("profile-data")?) } } #[derive(Debug, thiserror::Error)] pub(super) enum SessionError { #[error("Error setting key")] Set, } #[derive(Debug, thiserror::Error)] #[error("Redirecting to {0}")] struct ServerError(String); impl ResponseError for ServerError { fn status_code(&self) -> StatusCode { StatusCode::SEE_OTHER } fn error_response(&self) -> HttpResponse { HttpResponse::SeeOther() .header("Location", self.0.clone()) .finish() } } pub(crate) struct CurrentProfile(pub State); pub(crate) struct ProfileMiddleware(State, S); #[derive(Clone)] struct ProfileExtractor(Rc>>, Rc>, Rc); struct ProfileDropGuard(Rc); fn profile_extractor() -> (ProfileExtractor, ProfileDropGuard) { let event = Rc::new(Event::new()); let state = Rc::new(RefCell::new(None)); let flag = Rc::new(Cell::new(false)); ( ProfileExtractor(state, flag, Rc::clone(&event)), ProfileDropGuard(event), ) } impl Drop for ProfileDropGuard { fn drop(&mut self) { self.0.notify(usize::MAX); } } impl Transform for CurrentProfile where S: Service, Error = actix_web::Error>, S::Future: 'static, { type Request = S::Request; type Response = S::Response; type Error = S::Error; type InitError = (); type Transform = ProfileMiddleware; type Future = Ready>; fn new_transform(&self, service: S) -> Self::Future { ok(ProfileMiddleware(self.0.clone(), service)) } } impl Service for ProfileMiddleware where S: Service, Error = actix_web::Error>, S::Future: 'static, { type Request = S::Request; type Response = S::Response; type Error = S::Error; type Future = LocalBoxFuture<'static, Result>; fn poll_ready( &mut self, cx: &mut std::task::Context<'_>, ) -> std::task::Poll> { self.1.poll_ready(cx) } fn call(&mut self, req: S::Request) -> Self::Future { let session = req.get_session(); let (extractor, drop_guard) = profile_extractor(); req.extensions_mut().insert(extractor.clone()); let (req, pl) = req.into_parts(); let user_fut = Authenticated::extract(&req); let req = ServiceRequest::from_parts(req, pl) .map_err(|_| ()) .expect("Request has been cloned"); let fut = self.1.call(req); let state = self.0.clone(); Box::pin(async move { let user_id = if let Ok(auth) = user_fut.await { auth.user().id() } else { return fut.await; }; if let Some(ProfileData { id: profile_id }) = ProfileData::data(&session)? { match Profile::from_id(profile_id, &state) { Ok(profile) => { if profile .inner .local_owner() .map(|id| id == user_id) .unwrap_or(false) { *extractor.0.borrow_mut() = Some(profile); } } Err(e) => { log::error!("Error fetching profile {}: {}", profile_id, e); return Err(ServerError("/500".to_owned()).into()); } } } extractor.1.set(true); drop(drop_guard); fut.await }) } }