Compare commits

...

2 commits

Author SHA1 Message Date
asonix b43a435e64 Broken!!!!! 2024-03-30 09:36:31 -05:00
asonix 6e9239fa36 Move variant methods into variant repo trait 2024-03-28 12:04:40 -05:00
10 changed files with 561 additions and 545 deletions

View file

@ -1,172 +0,0 @@
use crate::{
details::Details,
error::{Error, UploadError},
repo::Hash,
};
use dashmap::{mapref::entry::Entry, DashMap};
use flume::{r#async::RecvFut, Receiver, Sender};
use std::{
future::Future,
path::PathBuf,
pin::Pin,
sync::Arc,
task::{Context, Poll},
};
use tracing::Span;
type OutcomeReceiver = Receiver<(Details, Arc<str>)>;
type ProcessMapKey = (Hash, PathBuf);
type ProcessMapInner = DashMap<ProcessMapKey, OutcomeReceiver>;
#[derive(Debug, Default, Clone)]
pub(crate) struct ProcessMap {
process_map: Arc<ProcessMapInner>,
}
impl ProcessMap {
pub(super) fn new() -> Self {
Self::default()
}
pub(super) async fn process<Fut>(
&self,
hash: Hash,
path: PathBuf,
fut: Fut,
) -> Result<(Details, Arc<str>), Error>
where
Fut: Future<Output = Result<(Details, Arc<str>), Error>>,
{
let key = (hash.clone(), path.clone());
let (sender, receiver) = flume::bounded(1);
let entry = self.process_map.entry(key.clone());
let (state, span) = match entry {
Entry::Vacant(vacant) => {
vacant.insert(receiver);
let span = tracing::info_span!(
"Processing image",
hash = ?hash,
path = ?path,
completed = &tracing::field::Empty,
);
metrics::counter!(crate::init_metrics::PROCESS_MAP_INSERTED).increment(1);
(CancelState::Sender { sender }, span)
}
Entry::Occupied(receiver) => {
let span = tracing::info_span!(
"Waiting for processed image",
hash = ?hash,
path = ?path,
);
let receiver = receiver.get().clone().into_recv_async();
(CancelState::Receiver { receiver }, span)
}
};
CancelSafeProcessor {
cancel_token: CancelToken {
span,
key,
state,
process_map: self.clone(),
},
fut,
}
.await
}
fn remove(&self, key: &ProcessMapKey) -> Option<OutcomeReceiver> {
self.process_map.remove(key).map(|(_, v)| v)
}
}
struct CancelToken {
span: Span,
key: ProcessMapKey,
state: CancelState,
process_map: ProcessMap,
}
enum CancelState {
Sender {
sender: Sender<(Details, Arc<str>)>,
},
Receiver {
receiver: RecvFut<'static, (Details, Arc<str>)>,
},
}
impl CancelState {
const fn is_sender(&self) -> bool {
matches!(self, Self::Sender { .. })
}
}
pin_project_lite::pin_project! {
struct CancelSafeProcessor<F> {
cancel_token: CancelToken,
#[pin]
fut: F,
}
}
impl<F> Future for CancelSafeProcessor<F>
where
F: Future<Output = Result<(Details, Arc<str>), Error>>,
{
type Output = Result<(Details, Arc<str>), Error>;
fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
let this = self.as_mut().project();
let span = &this.cancel_token.span;
let process_map = &this.cancel_token.process_map;
let state = &mut this.cancel_token.state;
let key = &this.cancel_token.key;
let fut = this.fut;
span.in_scope(|| match state {
CancelState::Sender { sender } => {
let res = std::task::ready!(fut.poll(cx));
if process_map.remove(key).is_some() {
metrics::counter!(crate::init_metrics::PROCESS_MAP_REMOVED).increment(1);
}
if let Ok(tup) = &res {
let _ = sender.try_send(tup.clone());
}
Poll::Ready(res)
}
CancelState::Receiver { ref mut receiver } => Pin::new(receiver)
.poll(cx)
.map(|res| res.map_err(|_| UploadError::Canceled.into())),
})
}
}
impl Drop for CancelToken {
fn drop(&mut self) {
if self.state.is_sender() {
let completed = self.process_map.remove(&self.key).is_none();
self.span.record("completed", completed);
if !completed {
metrics::counter!(crate::init_metrics::PROCESS_MAP_REMOVED).increment(1);
}
}
}
}

View file

@ -2,7 +2,6 @@ mod ffmpeg;
mod magick;
use crate::{
concurrent_processor::ProcessMap,
details::Details,
error::{Error, UploadError},
formats::{ImageFormat, InputProcessableFormat, InternalVideoFormat, ProcessableFormat},
@ -13,6 +12,7 @@ use crate::{
};
use std::{
future::Future,
path::PathBuf,
sync::Arc,
time::{Duration, Instant},
@ -48,10 +48,9 @@ impl Drop for MetricsGuard {
}
}
#[tracing::instrument(skip(state, process_map, original_details, hash))]
#[tracing::instrument(skip(state, original_details, hash))]
pub(crate) async fn generate<S: Store + 'static>(
state: &State<S>,
process_map: &ProcessMap,
format: InputProcessableFormat,
thumbnail_path: PathBuf,
thumbnail_args: Vec<String>,
@ -67,33 +66,97 @@ pub(crate) async fn generate<S: Store + 'static>(
Ok((original_details.clone(), identifier))
} else {
let process_fut = process(
state,
format,
thumbnail_path.clone(),
thumbnail_args,
original_details,
hash.clone(),
)
.with_poll_timer("process-future");
let variant = thumbnail_path.to_string_lossy().to_string();
let (details, identifier) = process_map
.process(hash, thumbnail_path, process_fut)
.with_poll_timer("process-map-future")
.with_timeout(Duration::from_secs(state.config.media.process_timeout * 4))
.with_metrics(crate::init_metrics::GENERATE_PROCESS)
.await
.map_err(|_| UploadError::ProcessTimeout)??;
let mut attempts = 0;
let (details, identifier) = loop {
if attempts > 4 {
todo!("return error");
}
match state
.repo
.claim_variant_processing_rights(hash.clone(), variant.clone())
.await?
{
Ok(()) => {
// process
let process_future = process(
state,
format,
variant.clone(),
thumbnail_args,
original_details,
hash.clone(),
)
.with_poll_timer("process-future");
let tuple = heartbeat(state, hash.clone(), variant.clone(), process_future)
.with_poll_timer("heartbeat-future")
.await??;
break tuple;
}
Err(_) => match state
.repo
.await_variant(hash.clone(), variant.clone())
.await?
{
Some(identifier) => {
let details = crate::ensure_details_identifier(state, &identifier).await?;
break (details, identifier);
}
None => {
attempts += 1;
continue;
}
},
}
};
Ok((details, identifier))
}
}
async fn heartbeat<S, O>(
state: &State<S>,
hash: Hash,
variant: String,
future: impl Future<Output = O>,
) -> Result<O, Error> {
let repo = state.repo.clone();
let handle = crate::sync::abort_on_drop(crate::sync::spawn("heartbeat-task", async move {
let mut interval = tokio::time::interval(Duration::from_secs(5));
loop {
interval.tick().await;
if let Err(e) = repo.variant_heartbeat(hash.clone(), variant.clone()).await {
break Error::from(e);
}
}
}));
let future = std::pin::pin!(future);
tokio::select! {
biased;
output = future => {
Ok(output)
}
res = handle => {
Err(res.map_err(|_| UploadError::Canceled)?)
}
}
}
#[tracing::instrument(skip(state, hash))]
async fn process<S: Store + 'static>(
state: &State<S>,
output_format: InputProcessableFormat,
thumbnail_path: PathBuf,
variant: String,
thumbnail_args: Vec<String>,
original_details: &Details,
hash: Hash,
@ -142,19 +205,21 @@ async fn process<S: Store + 'static>(
)
.await?;
if let Err(VariantAlreadyExists) = state
let identifier = if let Err(VariantAlreadyExists) = state
.repo
.relate_variant_identifier(
hash,
thumbnail_path.to_string_lossy().to_string(),
&identifier,
)
.relate_variant_identifier(hash.clone(), variant.clone(), &identifier)
.await?
{
state.store.remove(&identifier).await?;
}
state.repo.relate_details(&identifier, &details).await?;
state
.repo
.variant_identifier(hash, variant)
.await?
.ok_or(UploadError::MissingIdentifier)?
} else {
state.repo.relate_details(&identifier, &details).await?;
identifier
};
guard.disarm();

View file

@ -1,7 +1,6 @@
mod backgrounded;
mod blurhash;
mod bytes_stream;
mod concurrent_processor;
mod config;
mod details;
mod discover;
@ -71,7 +70,6 @@ use tracing_actix_web::TracingLogger;
use self::{
backgrounded::Backgrounded,
concurrent_processor::ProcessMap,
config::{Configuration, Operation},
details::Details,
either::Either,
@ -848,13 +846,12 @@ async fn not_found_hash(repo: &ArcRepo) -> Result<Option<(Alias, Hash)>, Error>
}
/// Process files
#[tracing::instrument(name = "Serving processed image", skip(state, process_map))]
#[tracing::instrument(name = "Serving processed image", skip(state))]
async fn process<S: Store + 'static>(
range: Option<web::Header<Range>>,
web::Query(ProcessQuery { source, operations }): web::Query<ProcessQuery>,
ext: web::Path<String>,
state: web::Data<State<S>>,
process_map: web::Data<ProcessMap>,
) -> Result<HttpResponse, Error> {
let alias = proxy_alias_from_query(source.into(), &state).await?;
@ -898,7 +895,6 @@ async fn process<S: Store + 'static>(
generate::generate(
&state,
&process_map,
format,
thumbnail_path,
thumbnail_args,
@ -1591,14 +1587,12 @@ fn json_config() -> web::JsonConfig {
fn configure_endpoints<S: Store + 'static, F: Fn(&mut web::ServiceConfig)>(
config: &mut web::ServiceConfig,
state: State<S>,
process_map: ProcessMap,
extra_config: F,
) {
config
.app_data(query_config())
.app_data(json_config())
.app_data(web::Data::new(state.clone()))
.app_data(web::Data::new(process_map.clone()))
.route("/healthz", web::get().to(healthz::<S>))
.service(
web::scope("/image")
@ -1706,12 +1700,12 @@ fn spawn_cleanup<S>(state: State<S>) {
});
}
fn spawn_workers<S>(state: State<S>, process_map: ProcessMap)
fn spawn_workers<S>(state: State<S>)
where
S: Store + 'static,
{
crate::sync::spawn("cleanup-worker", queue::process_cleanup(state.clone()));
crate::sync::spawn("process-worker", queue::process_images(state, process_map));
crate::sync::spawn("process-worker", queue::process_images(state));
}
fn watch_keys(tls: Tls, sender: ChannelSender) -> DropHandle<()> {
@ -1737,8 +1731,6 @@ async fn launch<
state: State<S>,
extra_config: F,
) -> color_eyre::Result<()> {
let process_map = ProcessMap::new();
let address = state.config.server.address;
let tls = Tls::from_config(&state.config);
@ -1748,18 +1740,15 @@ async fn launch<
let server = HttpServer::new(move || {
let extra_config = extra_config.clone();
let state = state.clone();
let process_map = process_map.clone();
spawn_workers(state.clone(), process_map.clone());
spawn_workers(state.clone());
App::new()
.wrap(TracingLogger::default())
.wrap(Deadline)
.wrap(Metrics)
.wrap(Payload::new())
.configure(move |sc| {
configure_endpoints(sc, state.clone(), process_map.clone(), extra_config)
})
.configure(move |sc| configure_endpoints(sc, state.clone(), extra_config))
});
if let Some(tls) = tls {

View file

@ -1,5 +1,4 @@
use crate::{
concurrent_processor::ProcessMap,
error::{Error, UploadError},
formats::InputProcessableFormat,
future::{LocalBoxFuture, WithPollTimer},
@ -196,8 +195,8 @@ pub(crate) async fn process_cleanup<S: Store + 'static>(state: State<S>) {
process_jobs(state, CLEANUP_QUEUE, cleanup::perform).await
}
pub(crate) async fn process_images<S: Store + 'static>(state: State<S>, process_map: ProcessMap) {
process_image_jobs(state, process_map, PROCESS_QUEUE, process::perform).await
pub(crate) async fn process_images<S: Store + 'static>(state: State<S>) {
process_jobs(state, PROCESS_QUEUE, process::perform).await
}
struct MetricsGuard {
@ -357,7 +356,7 @@ where
let (job_id, job) = state
.repo
.pop(queue, worker_id)
.with_poll_timer("pop-cleanup")
.with_poll_timer("pop-job")
.await?;
let guard = MetricsGuard::guard(worker_id, queue);
@ -369,99 +368,13 @@ where
job_id,
(callback)(state, job),
)
.with_poll_timer("cleanup-job-and-heartbeat")
.await;
state
.repo
.complete_job(queue, worker_id, job_id, job_result(&res))
.with_poll_timer("cleanup-job-complete")
.await?;
res?;
guard.disarm();
Ok(()) as Result<(), Error>
}
.instrument(tracing::info_span!("tick", %queue, %worker_id))
.await?;
}
}
async fn process_image_jobs<S, F>(
state: State<S>,
process_map: ProcessMap,
queue: &'static str,
callback: F,
) where
S: Store,
for<'a> F: Fn(&'a State<S>, &'a ProcessMap, serde_json::Value) -> JobFuture<'a> + Copy,
{
let worker_id = uuid::Uuid::new_v4();
loop {
tracing::trace!("process_image_jobs: looping");
crate::sync::cooperate().await;
let res = image_job_loop(&state, &process_map, worker_id, queue, callback)
.with_poll_timer("image-job-loop")
.await;
if let Err(e) = res {
tracing::warn!("Error processing jobs: {}", format!("{e}"));
tracing::warn!("{}", format!("{e:?}"));
if e.is_disconnected() {
tokio::time::sleep(Duration::from_secs(10)).await;
}
continue;
}
break;
}
}
async fn image_job_loop<S, F>(
state: &State<S>,
process_map: &ProcessMap,
worker_id: uuid::Uuid,
queue: &'static str,
callback: F,
) -> Result<(), Error>
where
S: Store,
for<'a> F: Fn(&'a State<S>, &'a ProcessMap, serde_json::Value) -> JobFuture<'a> + Copy,
{
loop {
tracing::trace!("image_job_loop: looping");
crate::sync::cooperate().await;
async {
let (job_id, job) = state
.repo
.pop(queue, worker_id)
.with_poll_timer("pop-process")
.await?;
let guard = MetricsGuard::guard(worker_id, queue);
let res = heartbeat(
&state.repo,
queue,
worker_id,
job_id,
(callback)(state, process_map, job),
)
.with_poll_timer("process-job-and-heartbeat")
.with_poll_timer("job-and-heartbeat")
.await;
state
.repo
.complete_job(queue, worker_id, job_id, job_result(&res))
.with_poll_timer("job-complete")
.await?;
res?;

View file

@ -2,7 +2,6 @@ use time::Instant;
use tracing::{Instrument, Span};
use crate::{
concurrent_processor::ProcessMap,
error::{Error, UploadError},
formats::InputProcessableFormat,
future::WithPollTimer,
@ -18,11 +17,7 @@ use std::{path::PathBuf, sync::Arc};
use super::{JobContext, JobFuture, JobResult};
pub(super) fn perform<'a, S>(
state: &'a State<S>,
process_map: &'a ProcessMap,
job: serde_json::Value,
) -> JobFuture<'a>
pub(super) fn perform<'a, S>(state: &'a State<S>, job: serde_json::Value) -> JobFuture<'a>
where
S: Store + 'static,
{
@ -58,7 +53,6 @@ where
} => {
generate(
state,
process_map,
target_format,
Serde::into_inner(source),
process_path,
@ -178,10 +172,9 @@ where
Ok(())
}
#[tracing::instrument(skip(state, process_map, process_path, process_args))]
#[tracing::instrument(skip(state, process_path, process_args))]
async fn generate<S: Store + 'static>(
state: &State<S>,
process_map: &ProcessMap,
target_format: InputProcessableFormat,
source: Alias,
process_path: PathBuf,
@ -211,7 +204,6 @@ async fn generate<S: Store + 'static>(
crate::generate::generate(
state,
process_map,
target_format,
process_path,
process_args,

View file

@ -103,6 +103,7 @@ pub(crate) trait FullRepo:
+ AliasRepo
+ QueueRepo
+ HashRepo
+ VariantRepo
+ StoreMigrationRepo
+ AliasAccessRepo
+ VariantAccessRepo
@ -653,20 +654,6 @@ pub(crate) trait HashRepo: BaseRepo {
async fn identifier(&self, hash: Hash) -> Result<Option<Arc<str>>, RepoError>;
async fn relate_variant_identifier(
&self,
hash: Hash,
variant: String,
identifier: &Arc<str>,
) -> Result<Result<(), VariantAlreadyExists>, RepoError>;
async fn variant_identifier(
&self,
hash: Hash,
variant: String,
) -> Result<Option<Arc<str>>, RepoError>;
async fn variants(&self, hash: Hash) -> Result<Vec<(String, Arc<str>)>, RepoError>;
async fn remove_variant(&self, hash: Hash, variant: String) -> Result<(), RepoError>;
async fn relate_blurhash(&self, hash: Hash, blurhash: Arc<str>) -> Result<(), RepoError>;
async fn blurhash(&self, hash: Hash) -> Result<Option<Arc<str>>, RepoError>;
@ -726,6 +713,90 @@ where
T::identifier(self, hash).await
}
async fn relate_blurhash(&self, hash: Hash, blurhash: Arc<str>) -> Result<(), RepoError> {
T::relate_blurhash(self, hash, blurhash).await
}
async fn blurhash(&self, hash: Hash) -> Result<Option<Arc<str>>, RepoError> {
T::blurhash(self, hash).await
}
async fn relate_motion_identifier(
&self,
hash: Hash,
identifier: &Arc<str>,
) -> Result<(), RepoError> {
T::relate_motion_identifier(self, hash, identifier).await
}
async fn motion_identifier(&self, hash: Hash) -> Result<Option<Arc<str>>, RepoError> {
T::motion_identifier(self, hash).await
}
async fn cleanup_hash(&self, hash: Hash) -> Result<(), RepoError> {
T::cleanup_hash(self, hash).await
}
}
#[async_trait::async_trait(?Send)]
pub(crate) trait VariantRepo: BaseRepo {
async fn claim_variant_processing_rights(
&self,
hash: Hash,
variant: String,
) -> Result<Result<(), VariantAlreadyExists>, RepoError>;
async fn variant_heartbeat(&self, hash: Hash, variant: String) -> Result<(), RepoError>;
async fn await_variant(
&self,
hash: Hash,
variant: String,
) -> Result<Option<Arc<str>>, RepoError>;
async fn relate_variant_identifier(
&self,
hash: Hash,
variant: String,
identifier: &Arc<str>,
) -> Result<Result<(), VariantAlreadyExists>, RepoError>;
async fn variant_identifier(
&self,
hash: Hash,
variant: String,
) -> Result<Option<Arc<str>>, RepoError>;
async fn variants(&self, hash: Hash) -> Result<Vec<(String, Arc<str>)>, RepoError>;
async fn remove_variant(&self, hash: Hash, variant: String) -> Result<(), RepoError>;
}
#[async_trait::async_trait(?Send)]
impl<T> VariantRepo for Arc<T>
where
T: VariantRepo,
{
async fn claim_variant_processing_rights(
&self,
hash: Hash,
variant: String,
) -> Result<Result<(), VariantAlreadyExists>, RepoError> {
T::claim_variant_processing_rights(self, hash, variant).await
}
async fn variant_heartbeat(&self, hash: Hash, variant: String) -> Result<(), RepoError> {
T::variant_heartbeat(self, hash, variant).await
}
async fn await_variant(
&self,
hash: Hash,
variant: String,
) -> Result<Option<Arc<str>>, RepoError> {
T::await_variant(self, hash, variant).await
}
async fn relate_variant_identifier(
&self,
hash: Hash,
@ -750,30 +821,6 @@ where
async fn remove_variant(&self, hash: Hash, variant: String) -> Result<(), RepoError> {
T::remove_variant(self, hash, variant).await
}
async fn relate_blurhash(&self, hash: Hash, blurhash: Arc<str>) -> Result<(), RepoError> {
T::relate_blurhash(self, hash, blurhash).await
}
async fn blurhash(&self, hash: Hash) -> Result<Option<Arc<str>>, RepoError> {
T::blurhash(self, hash).await
}
async fn relate_motion_identifier(
&self,
hash: Hash,
identifier: &Arc<str>,
) -> Result<(), RepoError> {
T::relate_motion_identifier(self, hash, identifier).await
}
async fn motion_identifier(&self, hash: Hash) -> Result<Option<Arc<str>>, RepoError> {
T::motion_identifier(self, hash).await
}
async fn cleanup_hash(&self, hash: Hash) -> Result<(), RepoError> {
T::cleanup_hash(self, hash).await
}
}
#[async_trait::async_trait(?Send)]

View file

@ -46,7 +46,7 @@ use super::{
Alias, AliasAccessRepo, AliasAlreadyExists, AliasRepo, BaseRepo, DeleteToken, DetailsRepo,
FullRepo, Hash, HashAlreadyExists, HashPage, HashRepo, JobId, JobResult, OrderedHash,
ProxyRepo, QueueRepo, RepoError, SettingsRepo, StoreMigrationRepo, UploadId, UploadRepo,
UploadResult, VariantAccessRepo, VariantAlreadyExists,
UploadResult, VariantAccessRepo, VariantAlreadyExists, VariantRepo,
};
#[derive(Clone)]
@ -62,6 +62,7 @@ struct Inner {
notifier_pool: Pool<AsyncPgConnection>,
queue_notifications: DashMap<String, Arc<Notify>>,
upload_notifications: DashMap<UploadId, Weak<Notify>>,
keyed_notifications: DashMap<String, Arc<Notify>>,
}
struct UploadInterest {
@ -81,6 +82,10 @@ struct UploadNotifierState<'a> {
inner: &'a Inner,
}
struct KeyedNotifierState<'a> {
inner: &'a Inner,
}
#[derive(Debug, thiserror::Error)]
pub(crate) enum ConnectPostgresError {
#[error("Failed to connect to postgres for migrations")]
@ -331,6 +336,7 @@ impl PostgresRepo {
notifier_pool,
queue_notifications: DashMap::new(),
upload_notifications: DashMap::new(),
keyed_notifications: DashMap::new(),
});
let handle = crate::sync::abort_on_drop(crate::sync::spawn_sendable(
@ -363,8 +369,55 @@ impl PostgresRepo {
.with_poll_timer("postgres-get-notifier-connection")
.await
}
async fn insert_keyed_notifier(
&self,
input_key: &str,
) -> Result<Result<(), AlreadyInserted>, PostgresError> {
use schema::keyed_notifications::dsl::*;
let mut conn = self.get_connection().await?;
let res = diesel::insert_into(keyed_notifications)
.values((key.eq(input_key)))
.execute(&mut conn)
.with_timeout(Duration::from_secs(5))
.await
.map_err(|_| PostgresError::DbTimeout)?;
match res {
Ok(_) => Ok(Ok(())),
Err(diesel::result::Error::DatabaseError(
diesel::result::DatabaseErrorKind::UniqueViolation,
_,
)) => Ok(Err(AlreadyInserted)),
Err(e) => Err(PostgresError::Diesel(e).into()),
}
}
async fn listen_on_key(&self, input_key: &str) -> Result<Result<(), TimedOut>, PostgresError> {
todo!()
}
async fn clear_keyed_notifier(&self, input_key: &str) -> Result<(), PostgresError> {
use schema::keyed_notifications::dsl::*;
let mut conn = self.get_connection().await?;
diesel::delete(keyed_notifications)
.filter(key.eq(input_key))
.execute(&mut conn)
.with_timeout(Duration::from_secs(5))
.await
.map_err(|_| PostgresError::DbTimeout)??;
Ok(())
}
}
struct TimedOut;
struct AlreadyInserted;
struct GetConnectionMetricsGuard {
start: Instant,
armed: bool,
@ -511,6 +564,19 @@ impl<'a> UploadNotifierState<'a> {
}
}
impl<'a> KeyedNotifierState<'a> {
fn handle(&self, key: &str) {
if let Some(notifier) = self
.inner
.keyed_notifications
.remove(key)
.and_then(|(_, weak)| weak.upgrade())
{
notifier.notify_waiters();
}
}
}
type BoxFuture<'a, T> = std::pin::Pin<Box<dyn std::future::Future<Output = T> + Send + 'a>>;
type ConfigFn =
Box<dyn Fn(&str) -> BoxFuture<'_, ConnectionResult<AsyncPgConnection>> + Send + Sync + 'static>;
@ -529,6 +595,8 @@ async fn delegate_notifications(
let upload_notifier_state = UploadNotifierState { inner: &inner };
let keyed_notifier_state = KeyedNotifierState { inner: &inner };
while let Ok(notification) = receiver.recv_async().await {
tracing::trace!("delegate_notifications: looping");
metrics::counter!(crate::init_metrics::POSTGRES_NOTIFICATION).increment(1);
@ -542,6 +610,10 @@ async fn delegate_notifications(
// new upload finished
upload_notifier_state.handle(notification.payload());
}
"keyed_notification_channel" => {
// new keyed notification
keyed_notifier_state.handle(notification.payload());
}
channel => {
tracing::info!(
"Unhandled postgres notification: {channel}: {}",
@ -863,110 +935,6 @@ impl HashRepo for PostgresRepo {
Ok(opt.map(Arc::from))
}
#[tracing::instrument(level = "debug", skip(self))]
async fn relate_variant_identifier(
&self,
input_hash: Hash,
input_variant: String,
input_identifier: &Arc<str>,
) -> Result<Result<(), VariantAlreadyExists>, RepoError> {
use schema::variants::dsl::*;
let mut conn = self.get_connection().await?;
let res = diesel::insert_into(variants)
.values((
hash.eq(&input_hash),
variant.eq(&input_variant),
identifier.eq(input_identifier.as_ref()),
))
.execute(&mut conn)
.with_metrics(crate::init_metrics::POSTGRES_VARIANTS_RELATE_VARIANT_IDENTIFIER)
.with_timeout(Duration::from_secs(5))
.await
.map_err(|_| PostgresError::DbTimeout)?;
match res {
Ok(_) => Ok(Ok(())),
Err(diesel::result::Error::DatabaseError(
diesel::result::DatabaseErrorKind::UniqueViolation,
_,
)) => Ok(Err(VariantAlreadyExists)),
Err(e) => Err(PostgresError::Diesel(e).into()),
}
}
#[tracing::instrument(level = "debug", skip(self))]
async fn variant_identifier(
&self,
input_hash: Hash,
input_variant: String,
) -> Result<Option<Arc<str>>, RepoError> {
use schema::variants::dsl::*;
let mut conn = self.get_connection().await?;
let opt = variants
.select(identifier)
.filter(hash.eq(&input_hash))
.filter(variant.eq(&input_variant))
.get_result::<String>(&mut conn)
.with_metrics(crate::init_metrics::POSTGRES_VARIANTS_IDENTIFIER)
.with_timeout(Duration::from_secs(5))
.await
.map_err(|_| PostgresError::DbTimeout)?
.optional()
.map_err(PostgresError::Diesel)?
.map(Arc::from);
Ok(opt)
}
#[tracing::instrument(level = "debug", skip(self))]
async fn variants(&self, input_hash: Hash) -> Result<Vec<(String, Arc<str>)>, RepoError> {
use schema::variants::dsl::*;
let mut conn = self.get_connection().await?;
let vec = variants
.select((variant, identifier))
.filter(hash.eq(&input_hash))
.get_results::<(String, String)>(&mut conn)
.with_metrics(crate::init_metrics::POSTGRES_VARIANTS_FOR_HASH)
.with_timeout(Duration::from_secs(5))
.await
.map_err(|_| PostgresError::DbTimeout)?
.map_err(PostgresError::Diesel)?
.into_iter()
.map(|(s, i)| (s, Arc::from(i)))
.collect();
Ok(vec)
}
#[tracing::instrument(level = "debug", skip(self))]
async fn remove_variant(
&self,
input_hash: Hash,
input_variant: String,
) -> Result<(), RepoError> {
use schema::variants::dsl::*;
let mut conn = self.get_connection().await?;
diesel::delete(variants)
.filter(hash.eq(&input_hash))
.filter(variant.eq(&input_variant))
.execute(&mut conn)
.with_metrics(crate::init_metrics::POSTGRES_VARIANTS_REMOVE)
.with_timeout(Duration::from_secs(5))
.await
.map_err(|_| PostgresError::DbTimeout)?
.map_err(PostgresError::Diesel)?;
Ok(())
}
#[tracing::instrument(level = "debug", skip(self))]
async fn relate_blurhash(
&self,
@ -1083,6 +1051,136 @@ impl HashRepo for PostgresRepo {
}
}
#[async_trait::async_trait(?Send)]
impl VariantRepo for PostgresRepo {
#[tracing::instrument(level = "debug", skip(self))]
async fn claim_variant_processing_rights(
&self,
hash: Hash,
variant: String,
) -> Result<Result<(), VariantAlreadyExists>, RepoError> {
todo!()
}
#[tracing::instrument(level = "debug", skip(self))]
async fn variant_heartbeat(&self, hash: Hash, variant: String) -> Result<(), RepoError> {
todo!()
}
#[tracing::instrument(level = "debug", skip(self))]
async fn await_variant(
&self,
hash: Hash,
variant: String,
) -> Result<Option<Arc<str>>, RepoError> {
todo!()
}
#[tracing::instrument(level = "debug", skip(self))]
async fn relate_variant_identifier(
&self,
input_hash: Hash,
input_variant: String,
input_identifier: &Arc<str>,
) -> Result<Result<(), VariantAlreadyExists>, RepoError> {
use schema::variants::dsl::*;
let mut conn = self.get_connection().await?;
let res = diesel::insert_into(variants)
.values((
hash.eq(&input_hash),
variant.eq(&input_variant),
identifier.eq(input_identifier.as_ref()),
))
.execute(&mut conn)
.with_metrics(crate::init_metrics::POSTGRES_VARIANTS_RELATE_VARIANT_IDENTIFIER)
.with_timeout(Duration::from_secs(5))
.await
.map_err(|_| PostgresError::DbTimeout)?;
match res {
Ok(_) => Ok(Ok(())),
Err(diesel::result::Error::DatabaseError(
diesel::result::DatabaseErrorKind::UniqueViolation,
_,
)) => Ok(Err(VariantAlreadyExists)),
Err(e) => Err(PostgresError::Diesel(e).into()),
}
}
#[tracing::instrument(level = "debug", skip(self))]
async fn variant_identifier(
&self,
input_hash: Hash,
input_variant: String,
) -> Result<Option<Arc<str>>, RepoError> {
use schema::variants::dsl::*;
let mut conn = self.get_connection().await?;
let opt = variants
.select(identifier)
.filter(hash.eq(&input_hash))
.filter(variant.eq(&input_variant))
.get_result::<String>(&mut conn)
.with_metrics(crate::init_metrics::POSTGRES_VARIANTS_IDENTIFIER)
.with_timeout(Duration::from_secs(5))
.await
.map_err(|_| PostgresError::DbTimeout)?
.optional()
.map_err(PostgresError::Diesel)?
.map(Arc::from);
Ok(opt)
}
#[tracing::instrument(level = "debug", skip(self))]
async fn variants(&self, input_hash: Hash) -> Result<Vec<(String, Arc<str>)>, RepoError> {
use schema::variants::dsl::*;
let mut conn = self.get_connection().await?;
let vec = variants
.select((variant, identifier))
.filter(hash.eq(&input_hash))
.get_results::<(String, String)>(&mut conn)
.with_metrics(crate::init_metrics::POSTGRES_VARIANTS_FOR_HASH)
.with_timeout(Duration::from_secs(5))
.await
.map_err(|_| PostgresError::DbTimeout)?
.map_err(PostgresError::Diesel)?
.into_iter()
.map(|(s, i)| (s, Arc::from(i)))
.collect();
Ok(vec)
}
#[tracing::instrument(level = "debug", skip(self))]
async fn remove_variant(
&self,
input_hash: Hash,
input_variant: String,
) -> Result<(), RepoError> {
use schema::variants::dsl::*;
let mut conn = self.get_connection().await?;
diesel::delete(variants)
.filter(hash.eq(&input_hash))
.filter(variant.eq(&input_variant))
.execute(&mut conn)
.with_metrics(crate::init_metrics::POSTGRES_VARIANTS_REMOVE)
.with_timeout(Duration::from_secs(5))
.await
.map_err(|_| PostgresError::DbTimeout)?
.map_err(PostgresError::Diesel)?;
Ok(())
}
}
#[async_trait::async_trait(?Send)]
impl AliasRepo for PostgresRepo {
#[tracing::instrument(level = "debug", skip(self))]

View file

@ -0,0 +1,50 @@
use barrel::backend::Pg;
use barrel::functions::AutogenFunction;
use barrel::{types, Migration};
pub(crate) fn migration() -> String {
let mut m = Migration::new();
m.create_table("keyed_notifications", |t| {
t.add_column(
"key",
types::text().primary(true).unique(true).nullable(false),
);
t.add_column(
"heartbeat",
types::datetime()
.nullable(false)
.default(AutogenFunction::CurrentTimestamp),
);
t.add_index(
"keyed_notifications_heartbeat_index",
types::index(["heartbeat"]),
);
});
m.inject_custom(
r#"
CREATE OR REPLACE FUNCTION keyed_notify()
RETURNS trigger AS
$$
BEGIN
PERFORM pg_notify('keyed_notification_channel', OLD.key);
RETURN NEW;
END;
$$ LANGUAGE plpgsql;
"#
.trim(),
);
m.inject_custom(
r#"
CREATE TRIGGER keyed_notification_removed
AFTER DELETE
ON keyed_notifications
FOR EACH ROW
EXECUTE PROCEDURE keyed_notify();
"#,
);
m.make::<Pg>().to_string()
}

View file

@ -48,6 +48,13 @@ diesel::table! {
}
}
diesel::table! {
keyed_notifications (key) {
key -> Text,
heartbeat -> Timestamp,
}
}
diesel::table! {
proxies (url) {
url -> Text,
@ -109,6 +116,7 @@ diesel::allow_tables_to_appear_in_same_query!(
details,
hashes,
job_queue,
keyed_notifications,
proxies,
refinery_schema_history,
settings,

View file

@ -25,7 +25,7 @@ use super::{
Alias, AliasAccessRepo, AliasAlreadyExists, AliasRepo, BaseRepo, DeleteToken, Details,
DetailsRepo, FullRepo, HashAlreadyExists, HashPage, HashRepo, JobId, JobResult, OrderedHash,
ProxyRepo, QueueRepo, RepoError, SettingsRepo, StoreMigrationRepo, UploadId, UploadRepo,
UploadResult, VariantAccessRepo, VariantAlreadyExists,
UploadResult, VariantAccessRepo, VariantAlreadyExists, VariantRepo,
};
macro_rules! b {
@ -1331,88 +1331,6 @@ impl HashRepo for SledRepo {
Ok(opt.map(try_into_arc_str).transpose()?)
}
#[tracing::instrument(level = "trace", skip(self))]
async fn relate_variant_identifier(
&self,
hash: Hash,
variant: String,
identifier: &Arc<str>,
) -> Result<Result<(), VariantAlreadyExists>, RepoError> {
let hash = hash.to_bytes();
let key = variant_key(&hash, &variant);
let value = identifier.clone();
let hash_variant_identifiers = self.hash_variant_identifiers.clone();
crate::sync::spawn_blocking("sled-io", move || {
hash_variant_identifiers
.compare_and_swap(key, Option::<&[u8]>::None, Some(value.as_bytes()))
.map(|res| res.map_err(|_| VariantAlreadyExists))
})
.await
.map_err(|_| RepoError::Canceled)?
.map_err(SledError::from)
.map_err(RepoError::from)
}
#[tracing::instrument(level = "trace", skip(self))]
async fn variant_identifier(
&self,
hash: Hash,
variant: String,
) -> Result<Option<Arc<str>>, RepoError> {
let hash = hash.to_bytes();
let key = variant_key(&hash, &variant);
let opt = b!(
self.hash_variant_identifiers,
hash_variant_identifiers.get(key)
);
Ok(opt.map(try_into_arc_str).transpose()?)
}
#[tracing::instrument(level = "debug", skip(self))]
async fn variants(&self, hash: Hash) -> Result<Vec<(String, Arc<str>)>, RepoError> {
let hash = hash.to_ivec();
let vec = b!(
self.hash_variant_identifiers,
Ok(hash_variant_identifiers
.scan_prefix(hash.clone())
.filter_map(|res| res.ok())
.filter_map(|(key, ivec)| {
let identifier = try_into_arc_str(ivec).ok();
let variant = variant_from_key(&hash, &key);
if variant.is_none() {
tracing::warn!("Skipping a variant: {}", String::from_utf8_lossy(&key));
}
Some((variant?, identifier?))
})
.collect::<Vec<_>>()) as Result<Vec<_>, SledError>
);
Ok(vec)
}
#[tracing::instrument(level = "trace", skip(self))]
async fn remove_variant(&self, hash: Hash, variant: String) -> Result<(), RepoError> {
let hash = hash.to_bytes();
let key = variant_key(&hash, &variant);
b!(
self.hash_variant_identifiers,
hash_variant_identifiers.remove(key)
);
Ok(())
}
#[tracing::instrument(level = "trace", skip(self))]
async fn relate_blurhash(&self, hash: Hash, blurhash: Arc<str>) -> Result<(), RepoError> {
b!(
@ -1528,6 +1446,114 @@ impl HashRepo for SledRepo {
}
}
#[async_trait::async_trait(?Send)]
impl VariantRepo for SledRepo {
#[tracing::instrument(level = "trace", skip(self))]
async fn claim_variant_processing_rights(
&self,
hash: Hash,
variant: String,
) -> Result<Result<(), VariantAlreadyExists>, RepoError> {
todo!()
}
#[tracing::instrument(level = "trace", skip(self))]
async fn variant_heartbeat(&self, hash: Hash, variant: String) -> Result<(), RepoError> {
todo!()
}
#[tracing::instrument(level = "trace", skip(self))]
async fn await_variant(
&self,
hash: Hash,
variant: String,
) -> Result<Option<Arc<str>>, RepoError> {
todo!()
}
#[tracing::instrument(level = "trace", skip(self))]
async fn relate_variant_identifier(
&self,
hash: Hash,
variant: String,
identifier: &Arc<str>,
) -> Result<Result<(), VariantAlreadyExists>, RepoError> {
let hash = hash.to_bytes();
let key = variant_key(&hash, &variant);
let value = identifier.clone();
let hash_variant_identifiers = self.hash_variant_identifiers.clone();
crate::sync::spawn_blocking("sled-io", move || {
hash_variant_identifiers
.compare_and_swap(key, Option::<&[u8]>::None, Some(value.as_bytes()))
.map(|res| res.map_err(|_| VariantAlreadyExists))
})
.await
.map_err(|_| RepoError::Canceled)?
.map_err(SledError::from)
.map_err(RepoError::from)
}
#[tracing::instrument(level = "trace", skip(self))]
async fn variant_identifier(
&self,
hash: Hash,
variant: String,
) -> Result<Option<Arc<str>>, RepoError> {
let hash = hash.to_bytes();
let key = variant_key(&hash, &variant);
let opt = b!(
self.hash_variant_identifiers,
hash_variant_identifiers.get(key)
);
Ok(opt.map(try_into_arc_str).transpose()?)
}
#[tracing::instrument(level = "debug", skip(self))]
async fn variants(&self, hash: Hash) -> Result<Vec<(String, Arc<str>)>, RepoError> {
let hash = hash.to_ivec();
let vec = b!(
self.hash_variant_identifiers,
Ok(hash_variant_identifiers
.scan_prefix(hash.clone())
.filter_map(|res| res.ok())
.filter_map(|(key, ivec)| {
let identifier = try_into_arc_str(ivec).ok();
let variant = variant_from_key(&hash, &key);
if variant.is_none() {
tracing::warn!("Skipping a variant: {}", String::from_utf8_lossy(&key));
}
Some((variant?, identifier?))
})
.collect::<Vec<_>>()) as Result<Vec<_>, SledError>
);
Ok(vec)
}
#[tracing::instrument(level = "trace", skip(self))]
async fn remove_variant(&self, hash: Hash, variant: String) -> Result<(), RepoError> {
let hash = hash.to_bytes();
let key = variant_key(&hash, &variant);
b!(
self.hash_variant_identifiers,
hash_variant_identifiers.remove(key)
);
Ok(())
}
}
fn hash_alias_key(hash: &IVec, alias: &IVec) -> Vec<u8> {
let mut v = hash.to_vec();
v.extend_from_slice(alias);