diff --git a/README.md b/README.md
index 15ef4ad..4ffb07e 100644
--- a/README.md
+++ b/README.md
@@ -253,9 +253,27 @@ Example:
 ### API
 
 pict-rs offers the following endpoints:
-- `POST /image` for uploading an image. Uploaded content must be valid multipart/form-data with an
+- `POST /image?{args}` for uploading an image. Uploaded content must be valid multipart/form-data with an
     image array located within the `images[]` key
 
+    The {args} query serves multiple purposes for image uploads. The first is to provide
+    request-level validations for the uploaded media. Available keys are as follows:
+    - max_width: maximum width, in pixels, allowed for the uploaded media
+    - max_height: maximum height, in pixels, allowed for the uploaded media
+    - max_area: maximum area, in pixels, allowed for the uploaded media
+    - max_frame_count: maximum number of frames permitted for animations and videos
+    - max_file_size: maximum file size, in megabytes, allowed for the uploaded media
+    - allow_image: whether to permit still images in the upload
+    - allow_animation: whether to permit animations in the upload
+    - allow_video: whether to permit video in the upload
+
+    These validations apply in addition to the validations specified in the pict-rs configuration,
+    so uploaded media will be rejected if any of the validations fail.
+
+    The second purpose of the {args} query is to provide preprocess steps for the uploaded image.
+    The format is the same as in the process.{ext} endpoint. Images uploaded with these steps
+    provided will be processed before saving.
+
     This endpoint returns the following JSON structure on success with a 201 Created status
     ```json
     {
@@ -294,7 +312,9 @@ pict-rs offers the following endpoints:
       "msg": "ok"
     }
     ```
-- `POST /image/backgrounded` Upload an image, like the `/image` endpoint, but don't wait to validate and process it.
+- `POST /image/backgrounded?{args}` Upload an image, like the `/image` endpoint, but don't wait to validate and process it.
+    The {args} query has the same format as the inline image upload endpoint.
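
For illustration, an upload that combines request-level validations with a preprocess step could look like the sketch below (it works the same for `/image` and `/image/backgrounded`). The client code is hypothetical and not part of pict-rs; reqwest with the `multipart` feature is assumed, and `blur` is used as an example filter:

```rust
// Hypothetical client-side upload, not part of pict-rs. max_width/max_height
// are request-level validations from the list above; blur=0.3 is an example
// preprocess step in process.{ext} format.
use reqwest::multipart::{Form, Part};

async fn upload(bytes: Vec<u8>) -> Result<(), reqwest::Error> {
    let part = Part::bytes(bytes)
        .file_name("example.png")
        .mime_str("image/png")?;

    let form = Form::new().part("images[]", part);

    reqwest::Client::new()
        .post("http://localhost:8080/image?max_width=4000&max_height=4000&blur=0.3")
        .multipart(form)
        .send()
        .await?
        .error_for_status()?;

    Ok(())
}
```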
+ This endpoint returns the following JSON structure on success with a 202 Accepted status ```json { diff --git a/src/concurrent_processor.rs b/src/concurrent_processor.rs deleted file mode 100644 index b15b521..0000000 --- a/src/concurrent_processor.rs +++ /dev/null @@ -1,172 +0,0 @@ -use crate::{ - details::Details, - error::{Error, UploadError}, - repo::Hash, -}; - -use dashmap::{mapref::entry::Entry, DashMap}; -use flume::{r#async::RecvFut, Receiver, Sender}; -use std::{ - future::Future, - path::PathBuf, - pin::Pin, - sync::Arc, - task::{Context, Poll}, -}; -use tracing::Span; - -type OutcomeReceiver = Receiver<(Details, Arc)>; - -type ProcessMapKey = (Hash, PathBuf); - -type ProcessMapInner = DashMap; - -#[derive(Debug, Default, Clone)] -pub(crate) struct ProcessMap { - process_map: Arc, -} - -impl ProcessMap { - pub(super) fn new() -> Self { - Self::default() - } - - pub(super) async fn process( - &self, - hash: Hash, - path: PathBuf, - fut: Fut, - ) -> Result<(Details, Arc), Error> - where - Fut: Future), Error>>, - { - let key = (hash.clone(), path.clone()); - - let (sender, receiver) = flume::bounded(1); - - let entry = self.process_map.entry(key.clone()); - - let (state, span) = match entry { - Entry::Vacant(vacant) => { - vacant.insert(receiver); - - let span = tracing::info_span!( - "Processing image", - hash = ?hash, - path = ?path, - completed = &tracing::field::Empty, - ); - - metrics::counter!(crate::init_metrics::PROCESS_MAP_INSERTED).increment(1); - - (CancelState::Sender { sender }, span) - } - Entry::Occupied(receiver) => { - let span = tracing::info_span!( - "Waiting for processed image", - hash = ?hash, - path = ?path, - ); - - let receiver = receiver.get().clone().into_recv_async(); - - (CancelState::Receiver { receiver }, span) - } - }; - - CancelSafeProcessor { - cancel_token: CancelToken { - span, - key, - state, - process_map: self.clone(), - }, - fut, - } - .await - } - - fn remove(&self, key: &ProcessMapKey) -> Option { - self.process_map.remove(key).map(|(_, v)| v) - } -} - -struct CancelToken { - span: Span, - key: ProcessMapKey, - state: CancelState, - process_map: ProcessMap, -} - -enum CancelState { - Sender { - sender: Sender<(Details, Arc)>, - }, - Receiver { - receiver: RecvFut<'static, (Details, Arc)>, - }, -} - -impl CancelState { - const fn is_sender(&self) -> bool { - matches!(self, Self::Sender { .. }) - } -} - -pin_project_lite::pin_project! 
{ - struct CancelSafeProcessor { - cancel_token: CancelToken, - - #[pin] - fut: F, - } -} - -impl Future for CancelSafeProcessor -where - F: Future), Error>>, -{ - type Output = Result<(Details, Arc), Error>; - - fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { - let this = self.as_mut().project(); - - let span = &this.cancel_token.span; - let process_map = &this.cancel_token.process_map; - let state = &mut this.cancel_token.state; - let key = &this.cancel_token.key; - let fut = this.fut; - - span.in_scope(|| match state { - CancelState::Sender { sender } => { - let res = std::task::ready!(fut.poll(cx)); - - if process_map.remove(key).is_some() { - metrics::counter!(crate::init_metrics::PROCESS_MAP_REMOVED).increment(1); - } - - if let Ok(tup) = &res { - let _ = sender.try_send(tup.clone()); - } - - Poll::Ready(res) - } - CancelState::Receiver { ref mut receiver } => Pin::new(receiver) - .poll(cx) - .map(|res| res.map_err(|_| UploadError::Canceled.into())), - }) - } -} - -impl Drop for CancelToken { - fn drop(&mut self) { - if self.state.is_sender() { - let completed = self.process_map.remove(&self.key).is_none(); - self.span.record("completed", completed); - - if !completed { - metrics::counter!(crate::init_metrics::PROCESS_MAP_REMOVED).increment(1); - } - } - } -} diff --git a/src/generate.rs b/src/generate.rs index 18d03bd..c270186 100644 --- a/src/generate.rs +++ b/src/generate.rs @@ -2,18 +2,17 @@ mod ffmpeg; mod magick; use crate::{ - concurrent_processor::ProcessMap, details::Details, error::{Error, UploadError}, formats::{ImageFormat, InputProcessableFormat, InternalVideoFormat, ProcessableFormat}, future::{WithMetrics, WithPollTimer, WithTimeout}, - repo::{Hash, VariantAlreadyExists}, + repo::{Hash, NotificationEntry, VariantAlreadyExists}, state::State, store::Store, }; use std::{ - path::PathBuf, + future::Future, sync::Arc, time::{Duration, Instant}, }; @@ -48,13 +47,12 @@ impl Drop for MetricsGuard { } } -#[tracing::instrument(skip(state, process_map, original_details, hash))] +#[tracing::instrument(skip(state, original_details, hash))] pub(crate) async fn generate( state: &State, - process_map: &ProcessMap, format: InputProcessableFormat, - thumbnail_path: PathBuf, - thumbnail_args: Vec, + variant: String, + variant_args: Vec, original_details: &Details, hash: Hash, ) -> Result<(Details, Arc), Error> { @@ -67,25 +65,122 @@ pub(crate) async fn generate( Ok((original_details.clone(), identifier)) } else { - let process_fut = process( - state, - format, - thumbnail_path.clone(), - thumbnail_args, - original_details, - hash.clone(), - ) - .with_poll_timer("process-future"); + let mut attempts = 0; + let tup = loop { + if attempts > 2 { + return Err(UploadError::ProcessTimeout.into()); + } - let (details, identifier) = process_map - .process(hash, thumbnail_path, process_fut) - .with_poll_timer("process-map-future") - .with_timeout(Duration::from_secs(state.config.media.process_timeout * 4)) - .with_metrics(crate::init_metrics::GENERATE_PROCESS) - .await - .map_err(|_| UploadError::ProcessTimeout)??; + match state + .repo + .claim_variant_processing_rights(hash.clone(), variant.clone()) + .await? 
+            {
+                Ok(()) => {
+                    // process
+                    let process_future = process(
+                        state,
+                        format,
+                        variant.clone(),
+                        variant_args,
+                        original_details,
+                        hash.clone(),
+                    )
+                    .with_poll_timer("process-future");
+
+                    let res = heartbeat(state, hash.clone(), variant.clone(), process_future)
+                        .with_poll_timer("heartbeat-future")
+                        .with_timeout(Duration::from_secs(state.config.media.process_timeout * 4))
+                        .with_metrics(crate::init_metrics::GENERATE_PROCESS)
+                        .await
+                        .map_err(|_| Error::from(UploadError::ProcessTimeout));
+
+                    state
+                        .repo
+                        .notify_variant(hash.clone(), variant.clone())
+                        .await?;
+
+                    break res???;
+                }
+                Err(entry) => {
+                    if let Some(tuple) = wait_timeout(
+                        hash.clone(),
+                        variant.clone(),
+                        entry,
+                        state,
+                        Duration::from_secs(20),
+                    )
+                    .await?
+                    {
+                        break tuple;
+                    }
+
+                    attempts += 1;
+                }
+            }
+        };
+
+        Ok(tup)
+    }
+}
+
+pub(crate) async fn wait_timeout<S: Store + 'static>(
+    hash: Hash,
+    variant: String,
+    mut entry: NotificationEntry,
+    state: &State<S>,
+    timeout: Duration,
+) -> Result<Option<(Details, Arc<str>)>, Error> {
+    let notified = entry.notified_timeout(timeout);
+
+    if let Some(identifier) = state
+        .repo
+        .variant_identifier(hash.clone(), variant.clone())
+        .await?
+    {
+        let details = crate::ensure_details_identifier(state, &identifier).await?;
+
+        return Ok(Some((details, identifier)));
+    }
+
+    match notified.await {
+        Ok(()) => tracing::debug!("notified"),
+        Err(_) => tracing::debug!("timeout"),
+    }
+
+    Ok(None)
+}
+
+async fn heartbeat<S, O>(
+    state: &State<S>,
+    hash: Hash,
+    variant: String,
+    future: impl Future<Output = O>,
+) -> Result<O, Error> {
+    let repo = state.repo.clone();
+
+    let handle = crate::sync::abort_on_drop(crate::sync::spawn("heartbeat-task", async move {
+        let mut interval = tokio::time::interval(Duration::from_secs(5));
+
+        loop {
+            interval.tick().await;
+
+            if let Err(e) = repo.variant_heartbeat(hash.clone(), variant.clone()).await {
+                break Error::from(e);
+            }
+        }
+    }));
+
+    let future = std::pin::pin!(future);
+
+    tokio::select! {
+        biased;
+        output = future => {
+            Ok(output)
+        }
+        res = handle => {
+            Err(res.map_err(|_| UploadError::Canceled)?)
+        }
+    }
+}
@@ -93,8 +188,8 @@ pub(crate) async fn generate(
 async fn process<S: Store + 'static>(
     state: &State<S>,
     output_format: InputProcessableFormat,
-    thumbnail_path: PathBuf,
-    thumbnail_args: Vec<String>,
+    variant: String,
+    variant_args: Vec<String>,
     original_details: &Details,
     hash: Hash,
 ) -> Result<(Details, Arc<str>), Error> {
@@ -120,7 +215,7 @@ async fn process(
     let stream = state.store.to_stream(&identifier, None, None).await?;
 
     let bytes =
-        crate::magick::process_image_command(state, thumbnail_args, input_format, format, quality)
+        crate::magick::process_image_command(state, variant_args, input_format, format, quality)
             .await?
             .drive_with_stream(stream)
             .into_bytes_stream()
@@ -142,19 +237,21 @@ async fn process(
     )
     .await?;
 
-    if let Err(VariantAlreadyExists) = state
+    let identifier = if let Err(VariantAlreadyExists) = state
         .repo
-        .relate_variant_identifier(
-            hash,
-            thumbnail_path.to_string_lossy().to_string(),
-            &identifier,
-        )
+        .relate_variant_identifier(hash.clone(), variant.clone(), &identifier)
         .await?
     {
         state.store.remove(&identifier).await?;
-    }
-
-    state.repo.relate_details(&identifier, &details).await?;
+        state
+            .repo
+            .variant_identifier(hash, variant)
+            .await?
+            .ok_or(UploadError::MissingIdentifier)?
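
The `heartbeat` helper above races the processing future against a task that refreshes the claim every 5 seconds. A minimal, self-contained sketch of that select pattern, with a `println!` standing in for `repo.variant_heartbeat` (the helper name and shape here are illustrative):

```rust
use std::time::Duration;

// Run `work` while a background task ticks every 5 seconds, then stop the
// ticker once work finishes. The println! stands in for the repo heartbeat.
async fn with_heartbeat<O>(work: impl std::future::Future<Output = O>) -> O {
    let mut beat = tokio::spawn(async {
        let mut interval = tokio::time::interval(Duration::from_secs(5));
        loop {
            interval.tick().await;
            println!("heartbeat: still processing");
        }
    });

    let work = std::pin::pin!(work);

    let out = tokio::select! {
        biased;
        out = work => out,
        // Only reachable if the heartbeat task panics or is aborted.
        _ = &mut beat => panic!("heartbeat task stopped unexpectedly"),
    };

    beat.abort();
    out
}
```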
+ } else { + state.repo.relate_details(&identifier, &details).await?; + identifier + }; guard.disarm(); diff --git a/src/lib.rs b/src/lib.rs index 89da947..9e29df8 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -1,7 +1,6 @@ mod backgrounded; mod blurhash; mod bytes_stream; -mod concurrent_processor; mod config; mod details; mod discover; @@ -57,7 +56,6 @@ use state::State; use std::{ marker::PhantomData, path::Path, - path::PathBuf, rc::Rc, sync::{Arc, OnceLock}, time::{Duration, SystemTime}, @@ -71,7 +69,6 @@ use tracing_actix_web::TracingLogger; use self::{ backgrounded::Backgrounded, - concurrent_processor::ProcessMap, config::{Configuration, Operation}, details::Details, either::Either, @@ -123,6 +120,7 @@ async fn ensure_details( ensure_details_identifier(state, &identifier).await } +#[tracing::instrument(skip(state))] async fn ensure_details_identifier( state: &State, identifier: &Arc, @@ -775,7 +773,7 @@ fn prepare_process( config: &Configuration, operations: Vec<(String, String)>, ext: &str, -) -> Result<(InputProcessableFormat, PathBuf, Vec), Error> { +) -> Result<(InputProcessableFormat, String, Vec), Error> { let operations = operations .into_iter() .filter(|(k, _)| config.media.filters.contains(&k.to_lowercase())) @@ -785,10 +783,9 @@ fn prepare_process( .parse::() .map_err(|_| UploadError::UnsupportedProcessExtension)?; - let (thumbnail_path, thumbnail_args) = - self::processor::build_chain(&operations, &format.to_string())?; + let (variant, variant_args) = self::processor::build_chain(&operations, &format.to_string())?; - Ok((format, thumbnail_path, thumbnail_args)) + Ok((format, variant, variant_args)) } #[tracing::instrument(name = "Fetching derived details", skip(state))] @@ -799,7 +796,7 @@ async fn process_details( ) -> Result { let alias = alias_from_query(source.into(), &state).await?; - let (_, thumbnail_path, _) = prepare_process(&state.config, operations, ext.as_str())?; + let (_, variant, _) = prepare_process(&state.config, operations, ext.as_str())?; let hash = state .repo @@ -807,18 +804,16 @@ async fn process_details( .await? .ok_or(UploadError::MissingAlias)?; - let thumbnail_string = thumbnail_path.to_string_lossy().to_string(); - if !state.config.server.read_only { state .repo - .accessed_variant(hash.clone(), thumbnail_string.clone()) + .accessed_variant(hash.clone(), variant.clone()) .await?; } let identifier = state .repo - .variant_identifier(hash, thumbnail_string) + .variant_identifier(hash, variant) .await? .ok_or(UploadError::MissingAlias)?; @@ -848,20 +843,16 @@ async fn not_found_hash(repo: &ArcRepo) -> Result, Error> } /// Process files -#[tracing::instrument(name = "Serving processed image", skip(state, process_map))] +#[tracing::instrument(name = "Serving processed image", skip(state))] async fn process( range: Option>, web::Query(ProcessQuery { source, operations }): web::Query, ext: web::Path, state: web::Data>, - process_map: web::Data, ) -> Result { let alias = proxy_alias_from_query(source.into(), &state).await?; - let (format, thumbnail_path, thumbnail_args) = - prepare_process(&state.config, operations, ext.as_str())?; - - let path_string = thumbnail_path.to_string_lossy().to_string(); + let (format, variant, variant_args) = prepare_process(&state.config, operations, ext.as_str())?; let (hash, alias, not_found) = if let Some(hash) = state.repo.hash(&alias).await? 
{ (hash, alias, false) @@ -876,13 +867,13 @@ async fn process( if !state.config.server.read_only { state .repo - .accessed_variant(hash.clone(), path_string.clone()) + .accessed_variant(hash.clone(), variant.clone()) .await?; } let identifier_opt = state .repo - .variant_identifier(hash.clone(), path_string) + .variant_identifier(hash.clone(), variant.clone()) .await?; let (details, identifier) = if let Some(identifier) = identifier_opt { @@ -894,18 +885,34 @@ async fn process( return Err(UploadError::ReadOnly.into()); } - let original_details = ensure_details(&state, &alias).await?; + queue_generate(&state.repo, format, alias, variant.clone(), variant_args).await?; - generate::generate( - &state, - &process_map, - format, - thumbnail_path, - thumbnail_args, - &original_details, - hash, - ) - .await? + let mut attempts = 0; + loop { + if attempts > 6 { + return Err(UploadError::ProcessTimeout.into()); + } + + let entry = state + .repo + .variant_waiter(hash.clone(), variant.clone()) + .await?; + + let opt = generate::wait_timeout( + hash.clone(), + variant.clone(), + entry, + &state, + Duration::from_secs(5), + ) + .await?; + + if let Some(tuple) = opt { + break tuple; + } + + attempts += 1; + } }; if let Some(public_url) = state.store.public_url(&identifier) { @@ -936,9 +943,8 @@ async fn process_head( } }; - let (_, thumbnail_path, _) = prepare_process(&state.config, operations, ext.as_str())?; + let (_, variant, _) = prepare_process(&state.config, operations, ext.as_str())?; - let path_string = thumbnail_path.to_string_lossy().to_string(); let Some(hash) = state.repo.hash(&alias).await? else { // Invalid alias return Ok(HttpResponse::NotFound().finish()); @@ -947,14 +953,11 @@ async fn process_head( if !state.config.server.read_only { state .repo - .accessed_variant(hash.clone(), path_string.clone()) + .accessed_variant(hash.clone(), variant.clone()) .await?; } - let identifier_opt = state - .repo - .variant_identifier(hash.clone(), path_string) - .await?; + let identifier_opt = state.repo.variant_identifier(hash.clone(), variant).await?; if let Some(identifier) = identifier_opt { let details = ensure_details_identifier(&state, &identifier).await?; @@ -973,7 +976,7 @@ async fn process_head( /// Process files #[tracing::instrument(name = "Spawning image process", skip(state))] -async fn process_backgrounded( +async fn process_backgrounded( web::Query(ProcessQuery { source, operations }): web::Query, ext: web::Path, state: web::Data>, @@ -990,10 +993,9 @@ async fn process_backgrounded( } }; - let (target_format, process_path, process_args) = + let (target_format, variant, variant_args) = prepare_process(&state.config, operations, ext.as_str())?; - let path_string = process_path.to_string_lossy().to_string(); let Some(hash) = state.repo.hash(&source).await? 
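
The serve path above no longer processes inline: it enqueues a generate job, then waits in 5-second `wait_timeout` windows for up to 7 attempts (roughly 35 seconds) before answering with `ProcessTimeout`. The same bounded-polling shape in isolation (a hypothetical helper, not from the codebase):

```rust
use std::future::Future;

// Run an async probe up to `attempts` times, returning early once it yields
// a value. In pict-rs the probe is generate::wait_timeout with a 5s window.
async fn bounded_wait<T, F, Fut>(attempts: usize, mut probe: F) -> Option<T>
where
    F: FnMut() -> Fut,
    Fut: Future<Output = Option<T>>,
{
    for _ in 0..attempts {
        if let Some(value) = probe().await {
            return Some(value);
        }
    }
    None
}
```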
else { // Invalid alias return Ok(HttpResponse::BadRequest().finish()); @@ -1001,7 +1003,7 @@ async fn process_backgrounded( let identifier_opt = state .repo - .variant_identifier(hash.clone(), path_string) + .variant_identifier(hash.clone(), variant.clone()) .await?; if identifier_opt.is_some() { @@ -1012,14 +1014,7 @@ async fn process_backgrounded( return Err(UploadError::ReadOnly.into()); } - queue_generate( - &state.repo, - target_format, - source, - process_path, - process_args, - ) - .await?; + queue_generate(&state.repo, target_format, source, variant, variant_args).await?; Ok(HttpResponse::Accepted().finish()) } @@ -1591,14 +1586,12 @@ fn json_config() -> web::JsonConfig { fn configure_endpoints( config: &mut web::ServiceConfig, state: State, - process_map: ProcessMap, extra_config: F, ) { config .app_data(query_config()) .app_data(json_config()) .app_data(web::Data::new(state.clone())) - .app_data(web::Data::new(process_map.clone())) .route("/healthz", web::get().to(healthz::)) .service( web::scope("/image") @@ -1706,12 +1699,12 @@ fn spawn_cleanup(state: State) { }); } -fn spawn_workers(state: State, process_map: ProcessMap) +fn spawn_workers(state: State) where S: Store + 'static, { crate::sync::spawn("cleanup-worker", queue::process_cleanup(state.clone())); - crate::sync::spawn("process-worker", queue::process_images(state, process_map)); + crate::sync::spawn("process-worker", queue::process_images(state)); } fn watch_keys(tls: Tls, sender: ChannelSender) -> DropHandle<()> { @@ -1737,8 +1730,6 @@ async fn launch< state: State, extra_config: F, ) -> color_eyre::Result<()> { - let process_map = ProcessMap::new(); - let address = state.config.server.address; let tls = Tls::from_config(&state.config); @@ -1748,18 +1739,15 @@ async fn launch< let server = HttpServer::new(move || { let extra_config = extra_config.clone(); let state = state.clone(); - let process_map = process_map.clone(); - spawn_workers(state.clone(), process_map.clone()); + spawn_workers(state.clone()); App::new() .wrap(TracingLogger::default()) .wrap(Deadline) .wrap(Metrics) .wrap(Payload::new()) - .configure(move |sc| { - configure_endpoints(sc, state.clone(), process_map.clone(), extra_config) - }) + .configure(move |sc| configure_endpoints(sc, state.clone(), extra_config)) }); if let Some(tls) = tls { diff --git a/src/processor.rs b/src/processor.rs index dec45e8..d7e3121 100644 --- a/src/processor.rs +++ b/src/processor.rs @@ -91,7 +91,7 @@ impl ResizeKind { pub(crate) fn build_chain( args: &[(String, String)], ext: &str, -) -> Result<(PathBuf, Vec), Error> { +) -> Result<(String, Vec), Error> { fn parse(key: &str, value: &str) -> Result, Error> { if key == P::NAME { return Ok(Some(P::parse(key, value).ok_or(UploadError::ParsePath)?)); @@ -122,7 +122,7 @@ pub(crate) fn build_chain( path.push(ext); - Ok((path, args)) + Ok((path.to_string_lossy().to_string(), args)) } impl Processor for Identity { diff --git a/src/queue.rs b/src/queue.rs index 96cebf8..f20b457 100644 --- a/src/queue.rs +++ b/src/queue.rs @@ -1,5 +1,4 @@ use crate::{ - concurrent_processor::ProcessMap, error::{Error, UploadError}, formats::InputProcessableFormat, future::{LocalBoxFuture, WithPollTimer}, @@ -12,7 +11,6 @@ use crate::{ use std::{ ops::Deref, - path::PathBuf, sync::Arc, time::{Duration, Instant}, }; @@ -63,7 +61,7 @@ enum Process { Generate { target_format: InputProcessableFormat, source: Serde, - process_path: PathBuf, + process_path: String, process_args: Vec, }, } @@ -178,13 +176,13 @@ pub(crate) async fn queue_generate( repo: 
&ArcRepo, target_format: InputProcessableFormat, source: Alias, - process_path: PathBuf, + variant: String, process_args: Vec, ) -> Result<(), Error> { let job = serde_json::to_value(Process::Generate { target_format, source: Serde::new(source), - process_path, + process_path: variant, process_args, }) .map_err(UploadError::PushJob)?; @@ -196,8 +194,8 @@ pub(crate) async fn process_cleanup(state: State) { process_jobs(state, CLEANUP_QUEUE, cleanup::perform).await } -pub(crate) async fn process_images(state: State, process_map: ProcessMap) { - process_image_jobs(state, process_map, PROCESS_QUEUE, process::perform).await +pub(crate) async fn process_images(state: State) { + process_jobs(state, PROCESS_QUEUE, process::perform).await } struct MetricsGuard { @@ -357,7 +355,7 @@ where let (job_id, job) = state .repo .pop(queue, worker_id) - .with_poll_timer("pop-cleanup") + .with_poll_timer("pop-job") .await?; let guard = MetricsGuard::guard(worker_id, queue); @@ -369,99 +367,13 @@ where job_id, (callback)(state, job), ) - .with_poll_timer("cleanup-job-and-heartbeat") - .await; - - state - .repo - .complete_job(queue, worker_id, job_id, job_result(&res)) - .with_poll_timer("cleanup-job-complete") - .await?; - - res?; - - guard.disarm(); - - Ok(()) as Result<(), Error> - } - .instrument(tracing::info_span!("tick", %queue, %worker_id)) - .await?; - } -} - -async fn process_image_jobs( - state: State, - process_map: ProcessMap, - queue: &'static str, - callback: F, -) where - S: Store, - for<'a> F: Fn(&'a State, &'a ProcessMap, serde_json::Value) -> JobFuture<'a> + Copy, -{ - let worker_id = uuid::Uuid::new_v4(); - - loop { - tracing::trace!("process_image_jobs: looping"); - - crate::sync::cooperate().await; - - let res = image_job_loop(&state, &process_map, worker_id, queue, callback) - .with_poll_timer("image-job-loop") - .await; - - if let Err(e) = res { - tracing::warn!("Error processing jobs: {}", format!("{e}")); - tracing::warn!("{}", format!("{e:?}")); - - if e.is_disconnected() { - tokio::time::sleep(Duration::from_secs(10)).await; - } - - continue; - } - - break; - } -} - -async fn image_job_loop( - state: &State, - process_map: &ProcessMap, - worker_id: uuid::Uuid, - queue: &'static str, - callback: F, -) -> Result<(), Error> -where - S: Store, - for<'a> F: Fn(&'a State, &'a ProcessMap, serde_json::Value) -> JobFuture<'a> + Copy, -{ - loop { - tracing::trace!("image_job_loop: looping"); - - crate::sync::cooperate().await; - - async { - let (job_id, job) = state - .repo - .pop(queue, worker_id) - .with_poll_timer("pop-process") - .await?; - - let guard = MetricsGuard::guard(worker_id, queue); - - let res = heartbeat( - &state.repo, - queue, - worker_id, - job_id, - (callback)(state, process_map, job), - ) - .with_poll_timer("process-job-and-heartbeat") + .with_poll_timer("job-and-heartbeat") .await; state .repo .complete_job(queue, worker_id, job_id, job_result(&res)) + .with_poll_timer("job-complete") .await?; res?; diff --git a/src/queue/process.rs b/src/queue/process.rs index 76ff626..b2343d6 100644 --- a/src/queue/process.rs +++ b/src/queue/process.rs @@ -2,7 +2,6 @@ use time::Instant; use tracing::{Instrument, Span}; use crate::{ - concurrent_processor::ProcessMap, error::{Error, UploadError}, formats::InputProcessableFormat, future::WithPollTimer, @@ -14,15 +13,11 @@ use crate::{ store::Store, UploadQuery, }; -use std::{path::PathBuf, sync::Arc}; +use std::sync::Arc; use super::{JobContext, JobFuture, JobResult}; -pub(super) fn perform<'a, S>( - state: &'a State, - process_map: &'a 
ProcessMap, - job: serde_json::Value, -) -> JobFuture<'a> +pub(super) fn perform(state: &State, job: serde_json::Value) -> JobFuture<'_> where S: Store + 'static, { @@ -58,7 +53,6 @@ where } => { generate( state, - process_map, target_format, Serde::into_inner(source), process_path, @@ -178,13 +172,12 @@ where Ok(()) } -#[tracing::instrument(skip(state, process_map, process_path, process_args))] +#[tracing::instrument(skip(state, variant, process_args))] async fn generate( state: &State, - process_map: &ProcessMap, target_format: InputProcessableFormat, source: Alias, - process_path: PathBuf, + variant: String, process_args: Vec, ) -> JobResult { let hash = state @@ -195,10 +188,9 @@ async fn generate( .ok_or(UploadError::MissingAlias) .abort()?; - let path_string = process_path.to_string_lossy().to_string(); let identifier_opt = state .repo - .variant_identifier(hash.clone(), path_string) + .variant_identifier(hash.clone(), variant.clone()) .await .retry()?; @@ -211,9 +203,8 @@ async fn generate( crate::generate::generate( state, - process_map, target_format, - process_path, + variant, process_args, &original_details, hash, diff --git a/src/repo.rs b/src/repo.rs index ed87b7a..9cbccc5 100644 --- a/src/repo.rs +++ b/src/repo.rs @@ -3,6 +3,7 @@ mod delete_token; mod hash; mod metrics; mod migrate; +mod notification_map; use crate::{ config, @@ -23,6 +24,7 @@ pub(crate) use alias::Alias; pub(crate) use delete_token::DeleteToken; pub(crate) use hash::Hash; pub(crate) use migrate::{migrate_04, migrate_repo}; +pub(crate) use notification_map::NotificationEntry; pub(crate) type ArcRepo = Arc; @@ -103,6 +105,7 @@ pub(crate) trait FullRepo: + AliasRepo + QueueRepo + HashRepo + + VariantRepo + StoreMigrationRepo + AliasAccessRepo + VariantAccessRepo @@ -653,20 +656,6 @@ pub(crate) trait HashRepo: BaseRepo { async fn identifier(&self, hash: Hash) -> Result>, RepoError>; - async fn relate_variant_identifier( - &self, - hash: Hash, - variant: String, - identifier: &Arc, - ) -> Result, RepoError>; - async fn variant_identifier( - &self, - hash: Hash, - variant: String, - ) -> Result>, RepoError>; - async fn variants(&self, hash: Hash) -> Result)>, RepoError>; - async fn remove_variant(&self, hash: Hash, variant: String) -> Result<(), RepoError>; - async fn relate_blurhash(&self, hash: Hash, blurhash: Arc) -> Result<(), RepoError>; async fn blurhash(&self, hash: Hash) -> Result>, RepoError>; @@ -726,6 +715,96 @@ where T::identifier(self, hash).await } + async fn relate_blurhash(&self, hash: Hash, blurhash: Arc) -> Result<(), RepoError> { + T::relate_blurhash(self, hash, blurhash).await + } + + async fn blurhash(&self, hash: Hash) -> Result>, RepoError> { + T::blurhash(self, hash).await + } + + async fn relate_motion_identifier( + &self, + hash: Hash, + identifier: &Arc, + ) -> Result<(), RepoError> { + T::relate_motion_identifier(self, hash, identifier).await + } + + async fn motion_identifier(&self, hash: Hash) -> Result>, RepoError> { + T::motion_identifier(self, hash).await + } + + async fn cleanup_hash(&self, hash: Hash) -> Result<(), RepoError> { + T::cleanup_hash(self, hash).await + } +} + +#[async_trait::async_trait(?Send)] +pub(crate) trait VariantRepo: BaseRepo { + async fn claim_variant_processing_rights( + &self, + hash: Hash, + variant: String, + ) -> Result, RepoError>; + + async fn variant_waiter( + &self, + hash: Hash, + variant: String, + ) -> Result; + + async fn variant_heartbeat(&self, hash: Hash, variant: String) -> Result<(), RepoError>; + + async fn notify_variant(&self, hash: 
Hash, variant: String) -> Result<(), RepoError>;
+
+    async fn relate_variant_identifier(
+        &self,
+        hash: Hash,
+        variant: String,
+        identifier: &Arc<str>,
+    ) -> Result<Result<(), VariantAlreadyExists>, RepoError>;
+
+    async fn variant_identifier(
+        &self,
+        hash: Hash,
+        variant: String,
+    ) -> Result<Option<Arc<str>>, RepoError>;
+
+    async fn variants(&self, hash: Hash) -> Result<Vec<(String, Arc<str>)>, RepoError>;
+
+    async fn remove_variant(&self, hash: Hash, variant: String) -> Result<(), RepoError>;
+}
+
+#[async_trait::async_trait(?Send)]
+impl<T> VariantRepo for Arc<T>
+where
+    T: VariantRepo,
+{
+    async fn claim_variant_processing_rights(
+        &self,
+        hash: Hash,
+        variant: String,
+    ) -> Result<Result<(), NotificationEntry>, RepoError> {
+        T::claim_variant_processing_rights(self, hash, variant).await
+    }
+
+    async fn variant_waiter(
+        &self,
+        hash: Hash,
+        variant: String,
+    ) -> Result<NotificationEntry, RepoError> {
+        T::variant_waiter(self, hash, variant).await
+    }
+
+    async fn variant_heartbeat(&self, hash: Hash, variant: String) -> Result<(), RepoError> {
+        T::variant_heartbeat(self, hash, variant).await
+    }
+
+    async fn notify_variant(&self, hash: Hash, variant: String) -> Result<(), RepoError> {
+        T::notify_variant(self, hash, variant).await
+    }
+
     async fn relate_variant_identifier(
         &self,
         hash: Hash,
@@ -750,30 +829,6 @@ where
     async fn remove_variant(&self, hash: Hash, variant: String) -> Result<(), RepoError> {
         T::remove_variant(self, hash, variant).await
     }
-
-    async fn relate_blurhash(&self, hash: Hash, blurhash: Arc<str>) -> Result<(), RepoError> {
-        T::relate_blurhash(self, hash, blurhash).await
-    }
-
-    async fn blurhash(&self, hash: Hash) -> Result<Option<Arc<str>>, RepoError> {
-        T::blurhash(self, hash).await
-    }
-
-    async fn relate_motion_identifier(
-        &self,
-        hash: Hash,
-        identifier: &Arc<str>,
-    ) -> Result<(), RepoError> {
-        T::relate_motion_identifier(self, hash, identifier).await
-    }
-
-    async fn motion_identifier(&self, hash: Hash) -> Result<Option<Arc<str>>, RepoError> {
-        T::motion_identifier(self, hash).await
-    }
-
-    async fn cleanup_hash(&self, hash: Hash) -> Result<(), RepoError> {
-        T::cleanup_hash(self, hash).await
-    }
 }
 
 #[async_trait::async_trait(?Send)]
diff --git a/src/repo/notification_map.rs b/src/repo/notification_map.rs
new file mode 100644
index 0000000..5e2b0d0
--- /dev/null
+++ b/src/repo/notification_map.rs
@@ -0,0 +1,94 @@
+use dashmap::DashMap;
+use std::{
+    future::Future,
+    sync::{
+        atomic::{AtomicBool, Ordering},
+        Arc, Weak,
+    },
+    time::Duration,
+};
+use tokio::sync::Notify;
+
+use crate::future::WithTimeout;
+
+type Map = Arc<DashMap<Arc<str>, Weak<NotificationEntryInner>>>;
+
+#[derive(Clone)]
+pub(super) struct NotificationMap {
+    map: Map,
+}
+
+pub(crate) struct NotificationEntry {
+    inner: Arc<NotificationEntryInner>,
+}
+
+struct NotificationEntryInner {
+    key: Arc<str>,
+    map: Map,
+    notify: Notify,
+    armed: AtomicBool,
+}
+
+impl NotificationMap {
+    pub(super) fn new() -> Self {
+        Self {
+            map: Arc::new(DashMap::new()),
+        }
+    }
+
+    pub(super) fn register_interest(&self, key: Arc<str>) -> NotificationEntry {
+        let new_entry = Arc::new(NotificationEntryInner {
+            key: key.clone(),
+            map: self.map.clone(),
+            notify: crate::sync::bare_notify(),
+            armed: AtomicBool::new(false),
+        });
+
+        let mut key_entry = self
+            .map
+            .entry(key)
+            .or_insert_with(|| Arc::downgrade(&new_entry));
+
+        let upgraded_entry = key_entry.value().upgrade();
+
+        let inner = if let Some(entry) = upgraded_entry {
+            entry
+        } else {
+            *key_entry.value_mut() = Arc::downgrade(&new_entry);
+            new_entry
+        };
+
+        inner.armed.store(true, Ordering::Release);
+
+        NotificationEntry { inner }
+    }
+
+    pub(super) fn notify(&self, key: &str) {
+        if let Some(notifier) = self.map.get(key).and_then(|v| v.upgrade()) {
notifier.notify.notify_waiters(); + } + } +} + +impl NotificationEntry { + pub(crate) fn notified_timeout( + &mut self, + duration: Duration, + ) -> impl Future> + '_ { + self.inner.notify.notified().with_timeout(duration) + } +} + +impl Default for NotificationMap { + fn default() -> Self { + Self::new() + } +} + +impl Drop for NotificationEntryInner { + fn drop(&mut self) { + if self.armed.load(Ordering::Acquire) { + self.map.remove(&self.key); + } + } +} diff --git a/src/repo/postgres.rs b/src/repo/postgres.rs index faa5792..ebb4914 100644 --- a/src/repo/postgres.rs +++ b/src/repo/postgres.rs @@ -4,6 +4,7 @@ mod schema; use std::{ collections::{BTreeSet, VecDeque}, + future::Future, path::PathBuf, sync::{ atomic::{AtomicU64, Ordering}, @@ -43,10 +44,11 @@ use self::job_status::JobStatus; use super::{ metrics::{PopMetricsGuard, PushMetricsGuard, WaitMetricsGuard}, + notification_map::{NotificationEntry, NotificationMap}, Alias, AliasAccessRepo, AliasAlreadyExists, AliasRepo, BaseRepo, DeleteToken, DetailsRepo, FullRepo, Hash, HashAlreadyExists, HashPage, HashRepo, JobId, JobResult, OrderedHash, ProxyRepo, QueueRepo, RepoError, SettingsRepo, StoreMigrationRepo, UploadId, UploadRepo, - UploadResult, VariantAccessRepo, VariantAlreadyExists, + UploadResult, VariantAccessRepo, VariantAlreadyExists, VariantRepo, }; #[derive(Clone)] @@ -62,6 +64,7 @@ struct Inner { notifier_pool: Pool, queue_notifications: DashMap>, upload_notifications: DashMap>, + keyed_notifications: NotificationMap, } struct UploadInterest { @@ -81,6 +84,10 @@ struct UploadNotifierState<'a> { inner: &'a Inner, } +struct KeyedNotifierState<'a> { + inner: &'a Inner, +} + #[derive(Debug, thiserror::Error)] pub(crate) enum ConnectPostgresError { #[error("Failed to connect to postgres for migrations")] @@ -102,7 +109,7 @@ pub(crate) enum PostgresError { Pool(#[source] RunError), #[error("Error in database")] - Diesel(#[source] diesel::result::Error), + Diesel(#[from] diesel::result::Error), #[error("Error deserializing hex value")] Hex(#[source] hex::FromHexError), @@ -331,6 +338,7 @@ impl PostgresRepo { notifier_pool, queue_notifications: DashMap::new(), upload_notifications: DashMap::new(), + keyed_notifications: NotificationMap::new(), }); let handle = crate::sync::abort_on_drop(crate::sync::spawn_sendable( @@ -363,8 +371,97 @@ impl PostgresRepo { .with_poll_timer("postgres-get-notifier-connection") .await } + + async fn insert_keyed_notifier( + &self, + input_key: &str, + ) -> Result, PostgresError> { + use schema::keyed_notifications::dsl::*; + + let mut conn = self.get_connection().await?; + + let timestamp = to_primitive(time::OffsetDateTime::now_utc()); + + diesel::delete(keyed_notifications) + .filter(heartbeat.le(timestamp.saturating_sub(time::Duration::minutes(2)))) + .execute(&mut conn) + .with_timeout(Duration::from_secs(5)) + .await + .map_err(|_| PostgresError::DbTimeout)? 
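
`NotificationMap` above stores only `Weak` references, so a waiter that gives up and drops its `NotificationEntry` unregisters itself. A self-contained miniature of the register/notify shape (simplified: no re-use of live entries and no `Drop` cleanup, both of which the real type handles):

```rust
use std::sync::{Arc, Weak};

use dashmap::DashMap;
use tokio::sync::Notify;

// Waiters hold Arc<Notify>; the map holds Weak, so dropped waiters leave
// only dead entries behind instead of keeping the Notify alive.
#[derive(Default)]
struct MiniNotificationMap {
    map: DashMap<Arc<str>, Weak<Notify>>,
}

impl MiniNotificationMap {
    fn register_interest(&self, key: Arc<str>) -> Arc<Notify> {
        let notify = Arc::new(Notify::new());
        self.map.insert(key, Arc::downgrade(&notify));
        notify
    }

    fn notify(&self, key: &str) {
        if let Some(notify) = self.map.get(key).and_then(|weak| weak.upgrade()) {
            notify.notify_waiters();
        }
    }
}
```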
+ .map_err(PostgresError::Diesel)?; + + let res = diesel::insert_into(keyed_notifications) + .values(key.eq(input_key)) + .execute(&mut conn) + .with_timeout(Duration::from_secs(5)) + .await + .map_err(|_| PostgresError::DbTimeout)?; + + match res { + Ok(_) => Ok(Ok(())), + Err(diesel::result::Error::DatabaseError( + diesel::result::DatabaseErrorKind::UniqueViolation, + _, + )) => Ok(Err(AlreadyInserted)), + Err(e) => Err(PostgresError::Diesel(e)), + } + } + + async fn keyed_notifier_heartbeat(&self, input_key: &str) -> Result<(), PostgresError> { + use schema::keyed_notifications::dsl::*; + + let mut conn = self.get_connection().await?; + + let timestamp = to_primitive(time::OffsetDateTime::now_utc()); + + diesel::update(keyed_notifications) + .filter(key.eq(input_key)) + .set(heartbeat.eq(timestamp)) + .execute(&mut conn) + .with_timeout(Duration::from_secs(5)) + .await + .map_err(|_| PostgresError::DbTimeout)? + .map_err(PostgresError::Diesel)?; + + Ok(()) + } + + fn listen_on_key(&self, key: Arc) -> NotificationEntry { + self.inner.keyed_notifications.register_interest(key) + } + + async fn register_interest(&self) -> Result<(), PostgresError> { + let mut notifier_conn = self.get_notifier_connection().await?; + + diesel::sql_query("LISTEN keyed_notification_channel;") + .execute(&mut notifier_conn) + .with_timeout(Duration::from_secs(5)) + .await + .map_err(|_| PostgresError::DbTimeout)? + .map_err(PostgresError::Diesel)?; + + Ok(()) + } + + async fn clear_keyed_notifier(&self, input_key: String) -> Result<(), PostgresError> { + use schema::keyed_notifications::dsl::*; + + let mut conn = self.get_connection().await?; + + diesel::delete(keyed_notifications) + .filter(key.eq(input_key)) + .execute(&mut conn) + .with_timeout(Duration::from_secs(5)) + .await + .map_err(|_| PostgresError::DbTimeout)? 
+ .map_err(PostgresError::Diesel)?; + + Ok(()) + } } +struct AlreadyInserted; + struct GetConnectionMetricsGuard { start: Instant, armed: bool, @@ -437,13 +534,15 @@ impl Inner { } impl UploadInterest { - async fn notified_timeout(&self, timeout: Duration) -> Result<(), tokio::time::error::Elapsed> { + fn notified_timeout( + &self, + timeout: Duration, + ) -> impl Future> + '_ { self.interest .as_ref() .expect("interest exists") .notified() .with_timeout(timeout) - .await } } @@ -511,6 +610,12 @@ impl<'a> UploadNotifierState<'a> { } } +impl<'a> KeyedNotifierState<'a> { + fn handle(&self, key: &str) { + self.inner.keyed_notifications.notify(key); + } +} + type BoxFuture<'a, T> = std::pin::Pin + Send + 'a>>; type ConfigFn = Box BoxFuture<'_, ConnectionResult> + Send + Sync + 'static>; @@ -529,6 +634,8 @@ async fn delegate_notifications( let upload_notifier_state = UploadNotifierState { inner: &inner }; + let keyed_notifier_state = KeyedNotifierState { inner: &inner }; + while let Ok(notification) = receiver.recv_async().await { tracing::trace!("delegate_notifications: looping"); metrics::counter!(crate::init_metrics::POSTGRES_NOTIFICATION).increment(1); @@ -542,6 +649,10 @@ async fn delegate_notifications( // new upload finished upload_notifier_state.handle(notification.payload()); } + "keyed_notification_channel" => { + // new keyed notification + keyed_notifier_state.handle(notification.payload()); + } channel => { tracing::info!( "Unhandled postgres notification: {channel}: {}", @@ -863,110 +974,6 @@ impl HashRepo for PostgresRepo { Ok(opt.map(Arc::from)) } - #[tracing::instrument(level = "debug", skip(self))] - async fn relate_variant_identifier( - &self, - input_hash: Hash, - input_variant: String, - input_identifier: &Arc, - ) -> Result, RepoError> { - use schema::variants::dsl::*; - - let mut conn = self.get_connection().await?; - - let res = diesel::insert_into(variants) - .values(( - hash.eq(&input_hash), - variant.eq(&input_variant), - identifier.eq(input_identifier.as_ref()), - )) - .execute(&mut conn) - .with_metrics(crate::init_metrics::POSTGRES_VARIANTS_RELATE_VARIANT_IDENTIFIER) - .with_timeout(Duration::from_secs(5)) - .await - .map_err(|_| PostgresError::DbTimeout)?; - - match res { - Ok(_) => Ok(Ok(())), - Err(diesel::result::Error::DatabaseError( - diesel::result::DatabaseErrorKind::UniqueViolation, - _, - )) => Ok(Err(VariantAlreadyExists)), - Err(e) => Err(PostgresError::Diesel(e).into()), - } - } - - #[tracing::instrument(level = "debug", skip(self))] - async fn variant_identifier( - &self, - input_hash: Hash, - input_variant: String, - ) -> Result>, RepoError> { - use schema::variants::dsl::*; - - let mut conn = self.get_connection().await?; - - let opt = variants - .select(identifier) - .filter(hash.eq(&input_hash)) - .filter(variant.eq(&input_variant)) - .get_result::(&mut conn) - .with_metrics(crate::init_metrics::POSTGRES_VARIANTS_IDENTIFIER) - .with_timeout(Duration::from_secs(5)) - .await - .map_err(|_| PostgresError::DbTimeout)? - .optional() - .map_err(PostgresError::Diesel)? 
- .map(Arc::from); - - Ok(opt) - } - - #[tracing::instrument(level = "debug", skip(self))] - async fn variants(&self, input_hash: Hash) -> Result)>, RepoError> { - use schema::variants::dsl::*; - - let mut conn = self.get_connection().await?; - - let vec = variants - .select((variant, identifier)) - .filter(hash.eq(&input_hash)) - .get_results::<(String, String)>(&mut conn) - .with_metrics(crate::init_metrics::POSTGRES_VARIANTS_FOR_HASH) - .with_timeout(Duration::from_secs(5)) - .await - .map_err(|_| PostgresError::DbTimeout)? - .map_err(PostgresError::Diesel)? - .into_iter() - .map(|(s, i)| (s, Arc::from(i))) - .collect(); - - Ok(vec) - } - - #[tracing::instrument(level = "debug", skip(self))] - async fn remove_variant( - &self, - input_hash: Hash, - input_variant: String, - ) -> Result<(), RepoError> { - use schema::variants::dsl::*; - - let mut conn = self.get_connection().await?; - - diesel::delete(variants) - .filter(hash.eq(&input_hash)) - .filter(variant.eq(&input_variant)) - .execute(&mut conn) - .with_metrics(crate::init_metrics::POSTGRES_VARIANTS_REMOVE) - .with_timeout(Duration::from_secs(5)) - .await - .map_err(|_| PostgresError::DbTimeout)? - .map_err(PostgresError::Diesel)?; - - Ok(()) - } - #[tracing::instrument(level = "debug", skip(self))] async fn relate_blurhash( &self, @@ -1083,6 +1090,167 @@ impl HashRepo for PostgresRepo { } } +#[async_trait::async_trait(?Send)] +impl VariantRepo for PostgresRepo { + #[tracing::instrument(level = "debug", skip(self))] + async fn claim_variant_processing_rights( + &self, + hash: Hash, + variant: String, + ) -> Result, RepoError> { + let key = Arc::from(format!("{}{variant}", hash.to_base64())); + let entry = self.listen_on_key(Arc::clone(&key)); + + self.register_interest().await?; + + if self + .variant_identifier(hash.clone(), variant.clone()) + .await? + .is_some() + { + return Ok(Err(entry)); + } + + match self.insert_keyed_notifier(&key).await? 
{ + Ok(()) => Ok(Ok(())), + Err(AlreadyInserted) => Ok(Err(entry)), + } + } + + async fn variant_waiter( + &self, + hash: Hash, + variant: String, + ) -> Result { + let key = Arc::from(format!("{}{variant}", hash.to_base64())); + let entry = self.listen_on_key(key); + + self.register_interest().await?; + + Ok(entry) + } + + #[tracing::instrument(level = "debug", skip(self))] + async fn variant_heartbeat(&self, hash: Hash, variant: String) -> Result<(), RepoError> { + let key = format!("{}{variant}", hash.to_base64()); + + self.keyed_notifier_heartbeat(&key) + .await + .map_err(Into::into) + } + + #[tracing::instrument(level = "trace", skip(self))] + async fn notify_variant(&self, hash: Hash, variant: String) -> Result<(), RepoError> { + let key = format!("{}{variant}", hash.to_base64()); + + self.clear_keyed_notifier(key).await.map_err(Into::into) + } + + #[tracing::instrument(level = "debug", skip(self))] + async fn relate_variant_identifier( + &self, + input_hash: Hash, + input_variant: String, + input_identifier: &Arc, + ) -> Result, RepoError> { + use schema::variants::dsl::*; + + let mut conn = self.get_connection().await?; + + let res = diesel::insert_into(variants) + .values(( + hash.eq(&input_hash), + variant.eq(&input_variant), + identifier.eq(input_identifier.to_string()), + )) + .execute(&mut conn) + .with_metrics(crate::init_metrics::POSTGRES_VARIANTS_RELATE_VARIANT_IDENTIFIER) + .with_timeout(Duration::from_secs(5)) + .await + .map_err(|_| PostgresError::DbTimeout)?; + + match res { + Ok(_) => Ok(Ok(())), + Err(diesel::result::Error::DatabaseError( + diesel::result::DatabaseErrorKind::UniqueViolation, + _, + )) => Ok(Err(VariantAlreadyExists)), + Err(e) => Err(PostgresError::Diesel(e).into()), + } + } + + #[tracing::instrument(level = "debug", skip(self))] + async fn variant_identifier( + &self, + input_hash: Hash, + input_variant: String, + ) -> Result>, RepoError> { + use schema::variants::dsl::*; + + let mut conn = self.get_connection().await?; + + let opt = variants + .select(identifier) + .filter(hash.eq(&input_hash)) + .filter(variant.eq(&input_variant)) + .get_result::(&mut conn) + .with_metrics(crate::init_metrics::POSTGRES_VARIANTS_IDENTIFIER) + .with_timeout(Duration::from_secs(5)) + .await + .map_err(|_| PostgresError::DbTimeout)? + .optional() + .map_err(PostgresError::Diesel)? + .map(Arc::from); + + Ok(opt) + } + + #[tracing::instrument(level = "debug", skip(self))] + async fn variants(&self, input_hash: Hash) -> Result)>, RepoError> { + use schema::variants::dsl::*; + + let mut conn = self.get_connection().await?; + + let vec = variants + .select((variant, identifier)) + .filter(hash.eq(&input_hash)) + .get_results::<(String, String)>(&mut conn) + .with_metrics(crate::init_metrics::POSTGRES_VARIANTS_FOR_HASH) + .with_timeout(Duration::from_secs(5)) + .await + .map_err(|_| PostgresError::DbTimeout)? + .map_err(PostgresError::Diesel)? + .into_iter() + .map(|(s, i)| (s, Arc::from(i))) + .collect(); + + Ok(vec) + } + + #[tracing::instrument(level = "debug", skip(self))] + async fn remove_variant( + &self, + input_hash: Hash, + input_variant: String, + ) -> Result<(), RepoError> { + use schema::variants::dsl::*; + + let mut conn = self.get_connection().await?; + + diesel::delete(variants) + .filter(hash.eq(&input_hash)) + .filter(variant.eq(&input_variant)) + .execute(&mut conn) + .with_metrics(crate::init_metrics::POSTGRES_VARIANTS_REMOVE) + .with_timeout(Duration::from_secs(5)) + .await + .map_err(|_| PostgresError::DbTimeout)? 
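
On postgres, `claim_variant_processing_rights` is an INSERT into `keyed_notifications` whose primary key makes it a distributed try-lock: a `UniqueViolation` means another node already owns the claim, and rows whose heartbeat is more than two minutes old are deleted first so a crashed holder cannot block the variant forever. Releasing the claim deletes the row, which fires the trigger added in the V0014 migration below and fans out over `LISTEN keyed_notification_channel`. The outcome mapping in isolation (a hypothetical helper with the same shape as `insert_keyed_notifier`):

```rust
// Hypothetical distillation of the claim outcome mapping: a unique-key INSERT
// either wins the claim, loses to a concurrent holder (UniqueViolation), or
// surfaces a real database error.
struct AlreadyInserted;

fn classify_claim(
    res: Result<usize, diesel::result::Error>,
) -> Result<Result<(), AlreadyInserted>, diesel::result::Error> {
    match res {
        Ok(_) => Ok(Ok(())),
        Err(diesel::result::Error::DatabaseError(
            diesel::result::DatabaseErrorKind::UniqueViolation,
            _,
        )) => Ok(Err(AlreadyInserted)),
        Err(e) => Err(e),
    }
}
```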
+ .map_err(PostgresError::Diesel)?; + + Ok(()) + } +} + #[async_trait::async_trait(?Send)] impl AliasRepo for PostgresRepo { #[tracing::instrument(level = "debug", skip(self))] @@ -1279,16 +1447,22 @@ impl DetailsRepo for PostgresRepo { let value = serde_json::to_value(&input_details.inner).map_err(PostgresError::SerializeDetails)?; - diesel::insert_into(details) + let res = diesel::insert_into(details) .values((identifier.eq(input_identifier.as_ref()), json.eq(&value))) .execute(&mut conn) .with_metrics(crate::init_metrics::POSTGRES_DETAILS_RELATE) .with_timeout(Duration::from_secs(5)) .await - .map_err(|_| PostgresError::DbTimeout)? - .map_err(PostgresError::Diesel)?; + .map_err(|_| PostgresError::DbTimeout)?; - Ok(()) + match res { + Ok(_) + | Err(diesel::result::Error::DatabaseError( + diesel::result::DatabaseErrorKind::UniqueViolation, + _, + )) => Ok(()), + Err(e) => Err(PostgresError::Diesel(e).into()), + } } #[tracing::instrument(level = "debug", skip(self))] diff --git a/src/repo/postgres/migrations/V0014__add_keyed_notifications.rs b/src/repo/postgres/migrations/V0014__add_keyed_notifications.rs new file mode 100644 index 0000000..9b46661 --- /dev/null +++ b/src/repo/postgres/migrations/V0014__add_keyed_notifications.rs @@ -0,0 +1,50 @@ +use barrel::backend::Pg; +use barrel::functions::AutogenFunction; +use barrel::{types, Migration}; + +pub(crate) fn migration() -> String { + let mut m = Migration::new(); + + m.create_table("keyed_notifications", |t| { + t.add_column( + "key", + types::text().primary(true).unique(true).nullable(false), + ); + t.add_column( + "heartbeat", + types::datetime() + .nullable(false) + .default(AutogenFunction::CurrentTimestamp), + ); + + t.add_index( + "keyed_notifications_heartbeat_index", + types::index(["heartbeat"]), + ); + }); + + m.inject_custom( + r#" +CREATE OR REPLACE FUNCTION keyed_notify() + RETURNS trigger AS +$$ +BEGIN + PERFORM pg_notify('keyed_notification_channel', OLD.key); + RETURN NEW; +END; +$$ LANGUAGE plpgsql; + "# + .trim(), + ); + + m.inject_custom( + r#" +CREATE TRIGGER keyed_notification_removed + AFTER DELETE + ON keyed_notifications + FOR EACH ROW +EXECUTE PROCEDURE keyed_notify(); + "#, + ); + m.make::().to_string() +} diff --git a/src/repo/postgres/schema.rs b/src/repo/postgres/schema.rs index fa772b7..9a634b7 100644 --- a/src/repo/postgres/schema.rs +++ b/src/repo/postgres/schema.rs @@ -48,6 +48,13 @@ diesel::table! { } } +diesel::table! { + keyed_notifications (key) { + key -> Text, + heartbeat -> Timestamp, + } +} + diesel::table! 
{ proxies (url) { url -> Text, @@ -109,6 +116,7 @@ diesel::allow_tables_to_appear_in_same_query!( details, hashes, job_queue, + keyed_notifications, proxies, refinery_schema_history, settings, diff --git a/src/repo/sled.rs b/src/repo/sled.rs index 6042116..ff45e04 100644 --- a/src/repo/sled.rs +++ b/src/repo/sled.rs @@ -5,6 +5,7 @@ use crate::{ serde_str::Serde, stream::{from_iterator, LocalBoxStream}, }; +use dashmap::DashMap; use sled::{transaction::TransactionError, Db, IVec, Transactional, Tree}; use std::{ collections::HashMap, @@ -22,10 +23,11 @@ use uuid::Uuid; use super::{ hash::Hash, metrics::{PopMetricsGuard, PushMetricsGuard, WaitMetricsGuard}, + notification_map::{NotificationEntry, NotificationMap}, Alias, AliasAccessRepo, AliasAlreadyExists, AliasRepo, BaseRepo, DeleteToken, Details, DetailsRepo, FullRepo, HashAlreadyExists, HashPage, HashRepo, JobId, JobResult, OrderedHash, ProxyRepo, QueueRepo, RepoError, SettingsRepo, StoreMigrationRepo, UploadId, UploadRepo, - UploadResult, VariantAccessRepo, VariantAlreadyExists, + UploadResult, VariantAccessRepo, VariantAlreadyExists, VariantRepo, }; macro_rules! b { @@ -113,6 +115,8 @@ pub(crate) struct SledRepo { migration_identifiers: Tree, cache_capacity: u64, export_path: PathBuf, + variant_process_map: DashMap<(Hash, String), time::OffsetDateTime>, + notifications: NotificationMap, db: Db, } @@ -156,6 +160,8 @@ impl SledRepo { migration_identifiers: db.open_tree("pict-rs-migration-identifiers-tree")?, cache_capacity, export_path, + variant_process_map: DashMap::new(), + notifications: NotificationMap::new(), db, }) } @@ -1331,88 +1337,6 @@ impl HashRepo for SledRepo { Ok(opt.map(try_into_arc_str).transpose()?) } - #[tracing::instrument(level = "trace", skip(self))] - async fn relate_variant_identifier( - &self, - hash: Hash, - variant: String, - identifier: &Arc, - ) -> Result, RepoError> { - let hash = hash.to_bytes(); - - let key = variant_key(&hash, &variant); - let value = identifier.clone(); - - let hash_variant_identifiers = self.hash_variant_identifiers.clone(); - - crate::sync::spawn_blocking("sled-io", move || { - hash_variant_identifiers - .compare_and_swap(key, Option::<&[u8]>::None, Some(value.as_bytes())) - .map(|res| res.map_err(|_| VariantAlreadyExists)) - }) - .await - .map_err(|_| RepoError::Canceled)? - .map_err(SledError::from) - .map_err(RepoError::from) - } - - #[tracing::instrument(level = "trace", skip(self))] - async fn variant_identifier( - &self, - hash: Hash, - variant: String, - ) -> Result>, RepoError> { - let hash = hash.to_bytes(); - - let key = variant_key(&hash, &variant); - - let opt = b!( - self.hash_variant_identifiers, - hash_variant_identifiers.get(key) - ); - - Ok(opt.map(try_into_arc_str).transpose()?) 
- } - - #[tracing::instrument(level = "debug", skip(self))] - async fn variants(&self, hash: Hash) -> Result)>, RepoError> { - let hash = hash.to_ivec(); - - let vec = b!( - self.hash_variant_identifiers, - Ok(hash_variant_identifiers - .scan_prefix(hash.clone()) - .filter_map(|res| res.ok()) - .filter_map(|(key, ivec)| { - let identifier = try_into_arc_str(ivec).ok(); - - let variant = variant_from_key(&hash, &key); - if variant.is_none() { - tracing::warn!("Skipping a variant: {}", String::from_utf8_lossy(&key)); - } - - Some((variant?, identifier?)) - }) - .collect::>()) as Result, SledError> - ); - - Ok(vec) - } - - #[tracing::instrument(level = "trace", skip(self))] - async fn remove_variant(&self, hash: Hash, variant: String) -> Result<(), RepoError> { - let hash = hash.to_bytes(); - - let key = variant_key(&hash, &variant); - - b!( - self.hash_variant_identifiers, - hash_variant_identifiers.remove(key) - ); - - Ok(()) - } - #[tracing::instrument(level = "trace", skip(self))] async fn relate_blurhash(&self, hash: Hash, blurhash: Arc) -> Result<(), RepoError> { b!( @@ -1528,6 +1452,167 @@ impl HashRepo for SledRepo { } } +#[async_trait::async_trait(?Send)] +impl VariantRepo for SledRepo { + #[tracing::instrument(level = "trace", skip(self))] + async fn claim_variant_processing_rights( + &self, + hash: Hash, + variant: String, + ) -> Result, RepoError> { + let key = (hash.clone(), variant.clone()); + let now = time::OffsetDateTime::now_utc(); + let entry = self + .notifications + .register_interest(Arc::from(format!("{}{variant}", hash.to_base64()))); + + match self.variant_process_map.entry(key.clone()) { + dashmap::mapref::entry::Entry::Occupied(mut occupied_entry) => { + if occupied_entry + .get() + .saturating_add(time::Duration::minutes(2)) + > now + { + return Ok(Err(entry)); + } + + occupied_entry.insert(now); + } + dashmap::mapref::entry::Entry::Vacant(vacant_entry) => { + vacant_entry.insert(now); + } + } + + if self.variant_identifier(hash, variant).await?.is_some() { + self.variant_process_map.remove(&key); + return Ok(Err(entry)); + } + + Ok(Ok(())) + } + + async fn variant_waiter( + &self, + hash: Hash, + variant: String, + ) -> Result { + let entry = self + .notifications + .register_interest(Arc::from(format!("{}{variant}", hash.to_base64()))); + + Ok(entry) + } + + #[tracing::instrument(level = "trace", skip(self))] + async fn variant_heartbeat(&self, hash: Hash, variant: String) -> Result<(), RepoError> { + let key = (hash, variant); + let now = time::OffsetDateTime::now_utc(); + + if let dashmap::mapref::entry::Entry::Occupied(mut occupied_entry) = + self.variant_process_map.entry(key) + { + occupied_entry.insert(now); + } + + Ok(()) + } + + #[tracing::instrument(level = "trace", skip(self))] + async fn notify_variant(&self, hash: Hash, variant: String) -> Result<(), RepoError> { + let key = (hash.clone(), variant.clone()); + self.variant_process_map.remove(&key); + + let key = format!("{}{variant}", hash.to_base64()); + self.notifications.notify(&key); + + Ok(()) + } + + #[tracing::instrument(level = "trace", skip(self))] + async fn relate_variant_identifier( + &self, + hash: Hash, + variant: String, + identifier: &Arc, + ) -> Result, RepoError> { + let hash = hash.to_bytes(); + + let key = variant_key(&hash, &variant); + let value = identifier.clone(); + + let hash_variant_identifiers = self.hash_variant_identifiers.clone(); + + let out = crate::sync::spawn_blocking("sled-io", move || { + hash_variant_identifiers + .compare_and_swap(key, Option::<&[u8]>::None, 
Some(value.as_bytes())) + .map(|res| res.map_err(|_| VariantAlreadyExists)) + }) + .await + .map_err(|_| RepoError::Canceled)? + .map_err(SledError::from) + .map_err(RepoError::from)?; + + Ok(out) + } + + #[tracing::instrument(level = "trace", skip(self))] + async fn variant_identifier( + &self, + hash: Hash, + variant: String, + ) -> Result>, RepoError> { + let hash = hash.to_bytes(); + + let key = variant_key(&hash, &variant); + + let opt = b!( + self.hash_variant_identifiers, + hash_variant_identifiers.get(key) + ); + + Ok(opt.map(try_into_arc_str).transpose()?) + } + + #[tracing::instrument(level = "debug", skip(self))] + async fn variants(&self, hash: Hash) -> Result)>, RepoError> { + let hash = hash.to_ivec(); + + let vec = b!( + self.hash_variant_identifiers, + Ok(hash_variant_identifiers + .scan_prefix(hash.clone()) + .filter_map(|res| res.ok()) + .filter_map(|(key, ivec)| { + let identifier = try_into_arc_str(ivec).ok(); + + let variant = variant_from_key(&hash, &key); + if variant.is_none() { + tracing::warn!("Skipping a variant: {}", String::from_utf8_lossy(&key)); + } + + Some((variant?, identifier?)) + }) + .collect::>()) as Result, SledError> + ); + + Ok(vec) + } + + #[tracing::instrument(level = "trace", skip(self))] + async fn remove_variant(&self, hash: Hash, variant: String) -> Result<(), RepoError> { + let hash = hash.to_bytes(); + + let key = variant_key(&hash, &variant); + + b!( + self.hash_variant_identifiers, + hash_variant_identifiers.remove(key) + ); + + Ok(()) + } +} + fn hash_alias_key(hash: &IVec, alias: &IVec) -> Vec { let mut v = hash.to_vec(); v.extend_from_slice(alias);
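
The sled backend has no shared database, so its claim is a process-local `DashMap` of last-heartbeat timestamps: an entry younger than two minutes is a live claim, and anything older is treated as abandoned and taken over. The staleness rule in isolation (illustrative names, using the `dashmap` and `time` crates as above):

```rust
use dashmap::DashMap;
use time::{Duration, OffsetDateTime};

// An entry is a live claim while its heartbeat is under two minutes old;
// stale or missing entries are claimed by inserting a fresh timestamp.
fn try_claim(map: &DashMap<String, OffsetDateTime>, key: String) -> bool {
    let now = OffsetDateTime::now_utc();

    match map.entry(key) {
        dashmap::mapref::entry::Entry::Occupied(mut occupied) => {
            if occupied.get().saturating_add(Duration::minutes(2)) > now {
                return false; // live claim held by another task
            }
            occupied.insert(now); // stale claim: take it over
            true
        }
        dashmap::mapref::entry::Entry::Vacant(vacant) => {
            vacant.insert(now);
            true
        }
    }
}
```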