// pict-rs-uploader/src/main.rs
// 2023-10-03 15:14:31 -05:00
//
// 345 lines
// 8.8 KiB
// Rust

mod pict_rs;
use std::{
fs::FileType,
path::{Path, PathBuf},
rc::Rc,
time::Duration,
};
use awc::{
http::{header::CONTENT_TYPE, StatusCode},
Client,
};
use clap::Parser;
use eyre::ErrReport;
use multipart_client_stream::{Body, Part};
use pict_rs::{Image, ImageResponse, Upload, UploadReponse};
use tokio::{
sync::{
mpsc::{Receiver, Sender},
Semaphore,
},
task::JoinSet,
};
use tracing::Instrument;
use tracing_subscriber::{filter::Targets, layer::SubscriberExt, Layer, Registry};
use url::Url;
#[actix_rt::main]
async fn main() -> color_eyre::Result<()> {
let Args {
endpoint,
ingest,
out,
} = Args::parse();
init_tracing()?;
let (state, mut receiver) = State::new(endpoint);
let file_type = tokio::fs::metadata(&ingest).await?.file_type();
let handle = actix_rt::spawn(async move { state.visit_path(&ingest, file_type).await });
let mut images = Vec::new();
while let Some(image) = receiver.recv().await {
images.push(image);
}
let json = serde_json::to_vec(&images)?;
tokio::fs::write(out, json).await?;
handle.await??;
Ok(())
}
/// Upload a file, or every file in a directory tree, to a pict-rs server and record the
/// resulting images as JSON.
#[derive(Debug, clap::Parser)]
struct Args {
    /// Base URL of the pict-rs server. Any path component is replaced with pict-rs's
    /// `/image/backgrounded` routes.
    #[clap(short, long)]
    endpoint: Url,

    /// File or directory to upload. Directories are walked recursively; symlinks found
    /// inside them are skipped.
    #[clap(short, long)]
    ingest: PathBuf,

    /// Where to write the JSON array of uploaded image records.
    #[clap(short, long)]
    out: PathBuf,
}
/// Handle to the shared uploader state.
///
/// Cloning only bumps the `Rc` reference count, which is how each spawned directory-entry
/// task gets its own handle. `Rc` (not `Arc`) means handles stay on the local task set
/// they were created on.
#[derive(Clone)]
struct State {
    inner: Rc<StateInner>,
}
/// State shared by every task of one upload run.
struct StateInner {
    /// Base URL of the pict-rs server; its path is overwritten when building request URLs.
    endpoint: Url,
    /// HTTP client used for both the upload and claim requests.
    client: Client,
    /// Bounds how many uploads (including their claim long-poll) run at once.
    semaphore: Semaphore,
    /// Sends every image returned by a successful claim to `main`, which collects them.
    sender: Sender<Image>,
}
/// Accumulates the errors from several concurrently visited directory entries so they can
/// be reported together as a single error.
#[derive(Debug, Default)]
struct MultiError {
    errors: Vec<ErrReport>,
}
impl State {
fn new(endpoint: Url) -> (Self, Receiver<Image>) {
let (sender, receiver) = tokio::sync::mpsc::channel(8);
let this = Self {
inner: Rc::new(StateInner {
endpoint,
client: Client::builder()
.add_default_header(("User-Agent", "pict-rs-uploader v0.1.0"))
.finish(),
semaphore: Semaphore::new(8),
sender,
}),
};
(this, receiver)
}
async fn visit_path(&self, path: &Path, file_type: FileType) -> color_eyre::Result<()> {
if file_type.is_file() {
self.visit_file(path).await?;
} else if file_type.is_dir() {
self.visit_dir(path).await?;
}
Ok(())
}
async fn visit_file(&self, path: &Path) -> color_eyre::Result<()> {
let Some(image_response) = self.inner.upload(path).await? else {
return Ok(());
};
match image_response {
ImageResponse::Ok { files, .. } => {
for image in files {
self.inner
.sender
.send(image)
.await
.expect("receiver shouldn't be dead");
}
}
ImageResponse::Error { msg, code } => {
tracing::warn!("Upload for {path:?} failed with code {code}, message\n{msg}");
}
}
Ok(())
}
#[tracing::instrument(skip_all)]
async fn visit_dir(&self, path: &Path) -> color_eyre::Result<()> {
let mut read_dir = tokio::fs::read_dir(path).await?;
let mut set = JoinSet::new();
let mut errors = MultiError::new();
while let Some(entry) = read_dir.next_entry().await? {
if set.len() > 4 {
if let Some(res) = set.join_next().await {
match res {
Ok(Err(e)) => {
errors.push(e);
break;
}
Err(e) => {
errors.push(e.into());
break;
}
_ => {}
}
}
}
let this = self.clone();
set.spawn_local(
async move {
let file_type = entry.file_type().await?;
let path = entry.path();
this.visit_path(&path, file_type).await
}
.instrument(tracing::info_span!("spawned")),
);
}
while let Some(res) = set.join_next().await {
match res {
Ok(Err(e)) => errors.push(e),
Err(e) => errors.push(e.into()),
_ => {}
}
}
if !errors.is_empty() {
return Err(errors.into());
}
Ok(())
}
}
impl StateInner {
#[tracing::instrument(skip(self))]
async fn upload(&self, path: &Path) -> color_eyre::Result<Option<ImageResponse>> {
let guard = self.semaphore.acquire().await?;
let filename = path
.file_name()
.ok_or(eyre::eyre!("Path does not have filename"))?
.to_str()
.ok_or(eyre::eyre!("Filename is not valid utf8"))?;
let mut retries = 0;
let mut response = loop {
let file = tokio::fs::File::open(path).await?;
let body = Body::builder()
.append(Part::new("images[]", file).filename(filename))
.build();
let res = self
.client
.post(self.upload_endpoint())
.insert_header((CONTENT_TYPE, body.content_type()))
.timeout(Duration::from_secs(20))
.send_stream(body)
.await;
match res {
Ok(response) => break response,
Err(e) if retries < 10 => {
retries += 1;
tracing::warn!("Failed upload with {e}, retrying +{retries}");
}
Err(e) => {
return Err(eyre::eyre!("Failed upload with {e}"));
}
}
};
let response: UploadReponse = response.json().await?;
let mut uploads = match response {
UploadReponse::Error { msg, code } => {
tracing::warn!("Upload failed with code {code}, message\n{msg}");
return Ok(None);
}
UploadReponse::Ok { uploads, .. } => uploads,
};
let upload = uploads
.pop()
.ok_or(eyre::eyre!("Expected one upload in response"))?;
let response = self.long_poll(&upload).await?;
drop(guard);
Ok(Some(response))
}
async fn long_poll(&self, upload: &Upload) -> color_eyre::Result<ImageResponse> {
let claim_endpoint = self.claim_endpoint();
let mut response = loop {
let mut retries = 0;
let response = loop {
let res = self
.client
.get(&claim_endpoint)
.query(upload)?
.timeout(Duration::from_secs(15))
.send()
.await;
match res {
Ok(response) => break response,
Err(e) if retries < 10 => {
retries += 1;
tracing::warn!("Failed claim with {e}, retrying +{retries}");
}
Err(e) => {
return Err(eyre::eyre!("Failed claim with {e}"));
}
}
};
if response.status() != StatusCode::NO_CONTENT {
break response;
}
};
let response: ImageResponse = response.json().await?;
Ok(response)
}
fn upload_endpoint(&self) -> String {
let mut url: Url = self.endpoint.clone();
url.set_path("/image/backgrounded");
url.to_string()
}
fn claim_endpoint(&self) -> String {
let mut url: Url = self.endpoint.clone();
url.set_path("/image/backgrounded/claim");
url.to_string()
}
}
/// Installs the `color_eyre` error report hooks and a global `fmt` tracing subscriber.
///
/// The subscriber's filter comes from `RUST_LOG` as `Targets` directives, or `info` when
/// the variable is unset or unreadable.
///
/// # Errors
/// Fails if the hooks are already installed, the directives do not parse, or a global
/// subscriber has already been set.
fn init_tracing() -> color_eyre::Result<()> {
    color_eyre::install()?;

    let directives = match std::env::var("RUST_LOG") {
        Ok(value) => value,
        Err(_) => String::from("info"),
    };
    let targets = directives.parse::<Targets>()?;

    let subscriber =
        Registry::default().with(tracing_subscriber::fmt::layer().with_filter(targets));
    tracing::subscriber::set_global_default(subscriber)?;

    Ok(())
}
impl MultiError {
fn new() -> Self {
Self::default()
}
fn push(&mut self, report: ErrReport) {
self.errors.push(report);
}
fn is_empty(&self) -> bool {
self.errors.is_empty()
}
}
impl std::fmt::Display for MultiError {
    /// Writes each collected report's `Display` output, one report per line.
    ///
    /// The formatter is forwarded unchanged, so flags such as `{:#}` reach each report.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        for (index, report) in self.errors.iter().enumerate() {
            if index > 0 {
                writeln!(f)?;
            }
            // Call `Display` explicitly. A bare `report.fmt(f)` goes through `Report`'s
            // `Deref` to `dyn Error`, whose supertraits provide both `Debug::fmt` and
            // `Display::fmt`, and it skips eyre's own `Display` implementation.
            std::fmt::Display::fmt(report, f)?;
        }
        Ok(())
    }
}
// Lets `MultiError` be converted into an eyre report with `.into()`, as `visit_dir` does.
// `source()` keeps its default (`None`), so the individual reports are only visible through
// the `Display` output, not through the error's source chain.
impl std::error::Error for MultiError {}