From b2674f06d0d2a69860aec478d3f282b41a11f769 Mon Sep 17 00:00:00 2001
From: asonix
Date: Sun, 10 Sep 2023 20:08:01 -0400
Subject: [PATCH] More streme

---
 src/migrate_store.rs |   3 +-
 src/queue/cleanup.rs |   3 +-
 src/repo.rs          |  98 +++++++-------------------
 src/repo/migrate.rs  |   3 +-
 src/repo/postgres.rs | 159 +++++++++++++++++--------------------
 5 files changed, 92 insertions(+), 174 deletions(-)

diff --git a/src/migrate_store.rs b/src/migrate_store.rs
index ac3b638..3b47812 100644
--- a/src/migrate_store.rs
+++ b/src/migrate_store.rs
@@ -107,7 +107,8 @@ where
     }
 
     // Hashes are read in a consistent order
-    let mut stream = repo.hashes().into_streamer();
+    let stream = std::pin::pin!(repo.hashes());
+    let mut stream = stream.into_streamer();
 
     let state = Rc::new(MigrateState {
         repo: repo.clone(),
diff --git a/src/queue/cleanup.rs b/src/queue/cleanup.rs
index 8217e3b..2b91edb 100644
--- a/src/queue/cleanup.rs
+++ b/src/queue/cleanup.rs
@@ -138,7 +138,8 @@ async fn alias(repo: &ArcRepo, alias: Alias, token: DeleteToken) -> Result<(), E
 
 #[tracing::instrument(skip_all)]
 async fn all_variants(repo: &ArcRepo) -> Result<(), Error> {
-    let mut hash_stream = repo.hashes().into_streamer();
+    let hash_stream = std::pin::pin!(repo.hashes());
+    let mut hash_stream = hash_stream.into_streamer();
 
     while let Some(res) = hash_stream.next().await {
         let hash = res?;
diff --git a/src/repo.rs b/src/repo.rs
index fddda9a..01f5c6e 100644
--- a/src/repo.rs
+++ b/src/repo.rs
@@ -8,10 +8,10 @@ use crate::{
     config,
     details::Details,
     error_code::{ErrorCode, OwnedErrorCode},
-    future::LocalBoxFuture,
     stream::LocalBoxStream,
 };
 use base64::Engine;
+use futures_core::Stream;
 use std::{fmt::Debug, sync::Arc};
 use url::Url;
 use uuid::Uuid;
@@ -572,81 +572,29 @@ impl HashPage {
     }
 }
 
-type PageFuture = LocalBoxFuture<'static, Result<HashPage, RepoError>>;
-
-pub(crate) struct HashStream {
-    repo: Option<ArcRepo>,
-    page_future: Option<PageFuture>,
-    page: Option<HashPage>,
-}
-
-impl futures_core::Stream for HashStream {
-    type Item = Result<Hash, RepoError>;
-
-    fn poll_next(
-        self: std::pin::Pin<&mut Self>,
-        cx: &mut std::task::Context<'_>,
-    ) -> std::task::Poll<Option<Self::Item>> {
-        let this = self.get_mut();
-
-        loop {
-            let Some(repo) = &this.repo else {
-                return std::task::Poll::Ready(None);
-            };
-
-            let slug = if let Some(page) = &mut this.page {
-                // popping last in page is fine - we reversed them
-                if let Some(hash) = page.hashes.pop() {
-                    return std::task::Poll::Ready(Some(Ok(hash)));
-                }
-
-                let slug = page.next();
-                this.page.take();
-
-                if let Some(slug) = slug {
-                    Some(slug)
-                } else {
-                    this.repo.take();
-                    return std::task::Poll::Ready(None);
-                }
-            } else {
-                None
-            };
-
-            if let Some(page_future) = &mut this.page_future {
-                let res = std::task::ready!(page_future.as_mut().poll(cx));
-
-                this.page_future.take();
-
-                match res {
-                    Ok(mut page) => {
-                        // reverse because we use `pop` to fetch next
-                        page.hashes.reverse();
-
-                        this.page = Some(page);
-                    }
-                    Err(e) => {
-                        this.repo.take();
-
-                        return std::task::Poll::Ready(Some(Err(e)));
-                    }
-                }
-            } else {
-                let repo = repo.clone();
-
-                this.page_future = Some(Box::pin(async move { repo.hash_page(slug, 100).await }));
-            }
-        }
-    }
-}
-
 impl dyn FullRepo {
-    pub(crate) fn hashes(self: &Arc<Self>) -> HashStream {
-        HashStream {
-            repo: Some(self.clone()),
-            page_future: None,
-            page: None,
-        }
+    pub(crate) fn hashes(self: &Arc<Self>) -> impl Stream<Item = Result<Hash, RepoError>> {
+        let repo = self.clone();
+
+        streem::try_from_fn(|yielder| async move {
+            let mut slug = None;
+
+            loop {
+                let page = repo.hash_page(slug, 100).await?;
+
+                slug = page.next();
+
+                for hash in page.hashes {
+                    yielder.yield_ok(hash).await;
+                }
+
+                if slug.is_none() {
+                    break;
+                }
+            }
+
+            Ok(())
+        })
     }
 }
 
diff --git a/src/repo/migrate.rs b/src/repo/migrate.rs
index 7809fed..8083696 100644
--- a/src/repo/migrate.rs
+++ b/src/repo/migrate.rs
@@ -35,7 +35,8 @@ pub(crate) async fn migrate_repo(old_repo: ArcRepo, new_repo: ArcRepo) -> Result
     tracing::warn!("Checks complete, migrating repo");
     tracing::warn!("{total_size} hashes will be migrated");
 
-    let mut hash_stream = old_repo.hashes().into_streamer();
+    let hash_stream = std::pin::pin!(old_repo.hashes());
+    let mut hash_stream = hash_stream.into_streamer();
 
     let mut index = 0;
     while let Some(res) = hash_stream.next().await {
diff --git a/src/repo/postgres.rs b/src/repo/postgres.rs
index 459d56b..99f0b3e 100644
--- a/src/repo/postgres.rs
+++ b/src/repo/postgres.rs
@@ -21,6 +21,7 @@ use diesel_async::{
     },
     AsyncConnection, AsyncPgConnection, RunQueryDsl,
 };
+use futures_core::Stream;
 use tokio::sync::Notify;
 use tokio_postgres::{tls::NoTlsStream, AsyncMessage, Connection, NoTls, Notification, Socket};
 use tracing::Instrument;
@@ -30,7 +31,7 @@ use uuid::Uuid;
 use crate::{
     details::Details,
     error_code::{ErrorCode, OwnedErrorCode},
-    future::{LocalBoxFuture, WithMetrics, WithTimeout},
+    future::{WithMetrics, WithTimeout},
     serde_str::Serde,
     stream::LocalBoxStream,
 };
@@ -1477,33 +1478,29 @@ impl AliasAccessRepo for PostgresRepo {
         &self,
         timestamp: time::OffsetDateTime,
     ) -> Result<LocalBoxStream<'static, Result<Alias, RepoError>>, RepoError> {
-        Ok(Box::pin(PageStream {
-            inner: self.inner.clone(),
-            future: None,
-            current: Vec::new(),
-            older_than: to_primitive(timestamp),
-            next: Box::new(|inner, older_than| {
-                Box::pin(async move {
-                    use schema::proxies::dsl::*;
+        Ok(Box::pin(page_stream(
+            self.inner.clone(),
+            to_primitive(timestamp),
+            |inner, older_than| async move {
+                use schema::proxies::dsl::*;
 
-                    let mut conn = inner.get_connection().await?;
+                let mut conn = inner.get_connection().await?;
 
-                    let vec = proxies
-                        .select((accessed, alias))
-                        .filter(accessed.lt(older_than))
-                        .order(accessed.desc())
-                        .limit(100)
-                        .get_results(&mut conn)
-                        .with_metrics("pict-rs.postgres.alias-access.older-aliases")
-                        .with_timeout(Duration::from_secs(5))
-                        .await
-                        .map_err(|_| PostgresError::DbTimeout)?
-                        .map_err(PostgresError::Diesel)?;
+                let vec = proxies
+                    .select((accessed, alias))
+                    .filter(accessed.lt(older_than))
+                    .order(accessed.desc())
+                    .limit(100)
+                    .get_results(&mut conn)
+                    .with_metrics("pict-rs.postgres.alias-access.older-aliases")
+                    .with_timeout(Duration::from_secs(5))
+                    .await
+                    .map_err(|_| PostgresError::DbTimeout)?
+                    .map_err(PostgresError::Diesel)?;
 
-                    Ok(vec)
-                })
-            }),
-        }))
+                Ok(vec)
+            },
+        )))
     }
 
     async fn remove_alias_access(&self, _: Alias) -> Result<(), RepoError> {
@@ -1570,33 +1567,29 @@ impl VariantAccessRepo for PostgresRepo {
         &self,
         timestamp: time::OffsetDateTime,
     ) -> Result<LocalBoxStream<'static, Result<(Hash, String), RepoError>>, RepoError> {
-        Ok(Box::pin(PageStream {
-            inner: self.inner.clone(),
-            future: None,
-            current: Vec::new(),
-            older_than: to_primitive(timestamp),
-            next: Box::new(|inner, older_than| {
-                Box::pin(async move {
-                    use schema::variants::dsl::*;
+        Ok(Box::pin(page_stream(
+            self.inner.clone(),
+            to_primitive(timestamp),
+            |inner, older_than| async move {
+                use schema::variants::dsl::*;
 
-                    let mut conn = inner.get_connection().await?;
+                let mut conn = inner.get_connection().await?;
 
-                    let vec = variants
-                        .select((accessed, (hash, variant)))
-                        .filter(accessed.lt(older_than))
-                        .order(accessed.desc())
-                        .limit(100)
-                        .get_results(&mut conn)
-                        .with_metrics("pict-rs.postgres.variant-access.older-variants")
-                        .with_timeout(Duration::from_secs(5))
-                        .await
-                        .map_err(|_| PostgresError::DbTimeout)?
-                        .map_err(PostgresError::Diesel)?;
+                let vec = variants
+                    .select((accessed, (hash, variant)))
+                    .filter(accessed.lt(older_than))
+                    .order(accessed.desc())
+                    .limit(100)
+                    .get_results(&mut conn)
+                    .with_metrics("pict-rs.postgres.variant-access.older-variants")
+                    .with_timeout(Duration::from_secs(5))
+                    .await
+                    .map_err(|_| PostgresError::DbTimeout)?
+                    .map_err(PostgresError::Diesel)?;
 
-                    Ok(vec)
-                })
-            }),
-        }))
+                Ok(vec)
+            },
+        )))
     }
 
     async fn remove_variant_access(&self, _: Hash, _: String) -> Result<(), RepoError> {
@@ -1778,62 +1771,36 @@ impl FullRepo for PostgresRepo {
     }
 }
 
-type NextFuture<I> = LocalBoxFuture<'static, Result<Vec<(time::PrimitiveDateTime, I)>, RepoError>>;
-
-struct PageStream<I> {
+fn page_stream<I, F, Fut>(
     inner: Arc<Inner>,
-    future: Option<NextFuture<I>>,
-    current: Vec<I>,
-    older_than: time::PrimitiveDateTime,
-    next: Box<dyn Fn(Arc<Inner>, time::PrimitiveDateTime) -> NextFuture<I>>,
-}
-
-impl<I> futures_core::Stream for PageStream<I>
+    mut older_than: time::PrimitiveDateTime,
+    next: F,
+) -> impl Stream<Item = Result<I, RepoError>>
 where
-    I: Unpin,
+    F: Fn(Arc<Inner>, time::PrimitiveDateTime) -> Fut,
+    Fut: std::future::Future<Output = Result<Vec<(time::PrimitiveDateTime, I)>, RepoError>>
+        + 'static,
+    I: 'static,
 {
-    type Item = Result<I, RepoError>;
-
-    fn poll_next(
-        self: std::pin::Pin<&mut Self>,
-        cx: &mut std::task::Context<'_>,
-    ) -> std::task::Poll<Option<Self::Item>> {
-        let this = self.get_mut();
-
+    streem::try_from_fn(|yielder| async move {
         loop {
-            // Pop because we reversed the list
-            if let Some(alias) = this.current.pop() {
-                return std::task::Poll::Ready(Some(Ok(alias)));
-            }
+            let mut page = (next)(inner.clone(), older_than).await?;
 
-            if let Some(future) = this.future.as_mut() {
-                let res = std::task::ready!(future.as_mut().poll(cx));
-
-                this.future.take();
-
-                match res {
-                    Ok(page) if page.is_empty() => {
-                        return std::task::Poll::Ready(None);
-                    }
-                    Ok(page) => {
-                        let (mut timestamps, mut aliases): (Vec<_>, Vec<_>) =
-                            page.into_iter().unzip();
-                        // reverse because we use .pop() to get next
-                        aliases.reverse();
-
-                        this.current = aliases;
-                        this.older_than = timestamps.pop().expect("Verified nonempty");
-                    }
-                    Err(e) => return std::task::Poll::Ready(Some(Err(e))),
+            if let Some((last_time, last_item)) = page.pop() {
+                for (_, item) in page {
+                    yielder.yield_ok(item).await;
                 }
-            } else {
-                let inner = this.inner.clone();
-                let older_than = this.older_than;
+
+                yielder.yield_ok(last_item).await;
 
-                this.future = Some((this.next)(inner, older_than));
+                older_than = last_time;
+            } else {
+                break;
             }
         }
-    }
+
+        Ok(())
+    })
 }
 
 impl std::fmt::Debug for PostgresRepo {
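
Note (not part of the patch above): every hunk replaces a hand-rolled futures_core::Stream implementation with the same generator-style pattern: streem::try_from_fn drives a loop that fetches a page, yields each item through yielder.yield_ok(..), and breaks once the source is exhausted, while callers pin the resulting impl Stream with std::pin::pin! and consume it via into_streamer(). The sketch below is a minimal stand-alone illustration of that pattern under stated assumptions: fetch_page, items, PAGE_SIZE, and the u64 item type are hypothetical stand-ins for hash_page / page_stream and their types; tokio is assumed only as a runtime for the example; and streem::IntoStreamer is assumed to be the extension trait behind .into_streamer() at the patch's call sites.

use futures_core::Stream;
use streem::IntoStreamer; // assumed: the streem trait providing `.into_streamer()`

const PAGE_SIZE: usize = 3;

// Hypothetical paged source standing in for `hash_page` / the Diesel page queries:
// returns up to PAGE_SIZE items starting at `offset`, and a short page at the end.
async fn fetch_page(offset: u64) -> Result<Vec<u64>, std::io::Error> {
    Ok((offset..7).take(PAGE_SIZE).collect())
}

// Generator-style stream, the same shape as the new `hashes()` / `page_stream()`:
// fetch a page, yield its items, stop once a short page signals the end.
fn items() -> impl Stream<Item = Result<u64, std::io::Error>> {
    streem::try_from_fn(|yielder| async move {
        let mut offset = 0;

        loop {
            let page = fetch_page(offset).await?;
            let len = page.len();

            for item in page {
                yielder.yield_ok(item).await;
            }

            if len < PAGE_SIZE {
                break;
            }

            offset += len as u64;
        }

        Ok(())
    })
}

// Consumption mirrors the updated call sites: pin the stream, wrap it in a
// streamer, then drive it with `next()`.
#[tokio::main] // assumed runtime for the example only
async fn main() -> Result<(), std::io::Error> {
    let stream = std::pin::pin!(items());
    let mut streamer = stream.into_streamer();

    while let Some(res) = streamer.next().await {
        println!("{}", res?);
    }

    Ok(())
}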