diff --git a/src/repo/postgres.rs b/src/repo/postgres.rs index e2db982..fc489ec 100644 --- a/src/repo/postgres.rs +++ b/src/repo/postgres.rs @@ -4,6 +4,7 @@ mod schema; use std::{ collections::{BTreeSet, VecDeque}, + future::Future, path::PathBuf, sync::{ atomic::{AtomicU64, Ordering}, @@ -62,7 +63,32 @@ struct Inner { notifier_pool: Pool, queue_notifications: DashMap>, upload_notifications: DashMap>, - keyed_notifications: DashMap>, + keyed_notifications: DashMap, Weak>, +} + +struct NotificationEntry { + key: Arc, + inner: Arc, + notify: Notify, +} + +impl Drop for NotificationEntry { + fn drop(&mut self) { + self.inner.keyed_notifications.remove(self.key.as_ref()); + } +} + +struct KeyListener { + entry: Arc, +} + +impl KeyListener { + fn notified_timeout( + &self, + timeout: Duration, + ) -> impl Future> + '_ { + self.entry.notify.notified().with_timeout(timeout) + } } struct UploadInterest { @@ -379,7 +405,7 @@ impl PostgresRepo { let mut conn = self.get_connection().await?; let res = diesel::insert_into(keyed_notifications) - .values((key.eq(input_key))) + .values(key.eq(input_key)) .execute(&mut conn) .with_timeout(Duration::from_secs(5)) .await @@ -395,8 +421,61 @@ impl PostgresRepo { } } - async fn listen_on_key(&self, input_key: &str) -> Result, PostgresError> { - todo!() + async fn keyed_notifier_heartbeat(&self, input_key: &str) -> Result<(), PostgresError> { + use schema::keyed_notifications::dsl::*; + + let mut conn = self.get_connection().await?; + + let timestamp = to_primitive(time::OffsetDateTime::now_utc()); + + diesel::update(keyed_notifications) + .filter(key.eq(input_key)) + .set(heartbeat.eq(timestamp)) + .execute(&mut conn) + .with_timeout(Duration::from_secs(5)) + .await + .map_err(|_| PostgresError::DbTimeout)? + .map_err(PostgresError::Diesel)?; + + Ok(()) + } + + fn listen_on_key(&self, key: Arc) -> KeyListener { + let new_entry = Arc::new(NotificationEntry { + key: key.clone(), + inner: Arc::clone(&self.inner), + notify: crate::sync::bare_notify(), + }); + + let mut entry = self + .inner + .keyed_notifications + .entry(key) + .or_insert_with(|| Arc::downgrade(&new_entry)); + + let upgraded = entry.value().upgrade(); + + let entry = if let Some(existing_entry) = upgraded { + existing_entry + } else { + *entry.value_mut() = Arc::downgrade(&new_entry); + new_entry + }; + + KeyListener { entry } + } + + async fn register_interest(&self) -> Result<(), PostgresError> { + let mut notifier_conn = self.get_notifier_connection().await?; + + diesel::sql_query("LISTEN upload_completion_channel;") + .execute(&mut notifier_conn) + .with_timeout(Duration::from_secs(5)) + .await + .map_err(|_| PostgresError::DbTimeout)? + .map_err(PostgresError::Diesel)?; + + Ok(()) } async fn clear_keyed_notifier(&self, input_key: &str) -> Result<(), PostgresError> { @@ -409,13 +488,13 @@ impl PostgresRepo { .execute(&mut conn) .with_timeout(Duration::from_secs(5)) .await - .map_err(|_| PostgresError::DbTimeout)??; + .map_err(|_| PostgresError::DbTimeout)? + .map_err(PostgresError::Diesel)?; Ok(()) } } -struct TimedOut; struct AlreadyInserted; struct GetConnectionMetricsGuard { @@ -490,13 +569,15 @@ impl Inner { } impl UploadInterest { - async fn notified_timeout(&self, timeout: Duration) -> Result<(), tokio::time::error::Elapsed> { + fn notified_timeout( + &self, + timeout: Duration, + ) -> impl Future> + '_ { self.interest .as_ref() .expect("interest exists") .notified() .with_timeout(timeout) - .await } } @@ -566,13 +647,13 @@ impl<'a> UploadNotifierState<'a> { impl<'a> KeyedNotifierState<'a> { fn handle(&self, key: &str) { - if let Some(notifier) = self + if let Some(notification_entry) = self .inner .keyed_notifications .remove(key) .and_then(|(_, weak)| weak.upgrade()) { - notifier.notify_waiters(); + notification_entry.notify.notify_waiters(); } } } @@ -1059,12 +1140,29 @@ impl VariantRepo for PostgresRepo { hash: Hash, variant: String, ) -> Result, RepoError> { - todo!() + if self + .variant_identifier(hash.clone(), variant.clone()) + .await? + .is_some() + { + return Ok(Err(VariantAlreadyExists)); + } + + let key = format!("{}{variant}", hash.to_base64()); + + match self.insert_keyed_notifier(&key).await? { + Ok(()) => Ok(Ok(())), + Err(AlreadyInserted) => Ok(Err(VariantAlreadyExists)), + } } #[tracing::instrument(level = "debug", skip(self))] async fn variant_heartbeat(&self, hash: Hash, variant: String) -> Result<(), RepoError> { - todo!() + let key = format!("{}{variant}", hash.to_base64()); + + self.keyed_notifier_heartbeat(&key) + .await + .map_err(Into::into) } #[tracing::instrument(level = "debug", skip(self))] @@ -1073,7 +1171,26 @@ impl VariantRepo for PostgresRepo { hash: Hash, variant: String, ) -> Result>, RepoError> { - todo!() + let key = Arc::from(format!("{}{}", hash.to_base64(), variant.clone())); + + let listener = self.listen_on_key(key); + let notified = listener.notified_timeout(Duration::from_secs(10)); + + self.register_interest().await?; + + if let Some(identifier) = self + .variant_identifier(hash.clone(), variant.clone()) + .await? + { + return Ok(Some(identifier)); + } + + match notified.await { + Ok(()) => tracing::debug!("notified"), + Err(_) => tracing::trace!("timeout"), + } + + self.variant_identifier(hash, variant).await } #[tracing::instrument(level = "debug", skip(self))] @@ -1099,6 +1216,12 @@ impl VariantRepo for PostgresRepo { .await .map_err(|_| PostgresError::DbTimeout)?; + let key = format!("{}{}", input_hash.to_base64(), input_variant.clone()); + match self.clear_keyed_notifier(&key).await { + Ok(()) => {} + Err(e) => tracing::warn!("Failed to clear notifier: {e}"), + } + match res { Ok(_) => Ok(Ok(())), Err(diesel::result::Error::DatabaseError(