pict-rs-admin/src/main.rs
asonix 122c17424c
All checks were successful
continuous-integration/drone/push Build is passing
continuous-integration/drone/tag Build is passing
Support TLS upstream & downstream
2024-02-01 22:12:33 -06:00

489 lines
14 KiB
Rust

use std::{net::SocketAddr, path::PathBuf, time::Duration};
use actix_web::{
body::BodyStream, error::ErrorInternalServerError, web, App, HttpResponse, HttpServer,
};
use clap::Parser;
use console_subscriber::ConsoleLayer;
use opentelemetry::KeyValue;
use opentelemetry_otlp::WithExportConfig;
use opentelemetry_sdk::{propagation::TraceContextPropagator, Resource};
use reqwest::{redirect::Policy, Client};
use reqwest_middleware::{ClientBuilder, ClientWithMiddleware};
use reqwest_tracing::TracingMiddleware;
use rustls_021::ServerConfig;
use tracing_actix_web::TracingLogger;
use tracing_error::ErrorLayer;
use tracing_subscriber::{
filter::Targets, layer::SubscriberExt, registry::LookupSpan, Layer, Registry,
};
use url::Url;
/// Web admin interface for browsing and purging media stored in pict-rs.
#[derive(Debug, Parser)]
struct Args {
    /// Address (IP and port) the admin web server listens on
    #[clap(
        short,
        long,
        env = "PICTRS_ADMIN__BIND_ADDRESS",
        default_value = "[::]:8084"
    )]
    bind_address: SocketAddr,

    /// Base URL of the pict-rs server to administer
    #[clap(
        long,
        env = "PICTRS_ADMIN__PICTRS_ENDPOINT",
        default_value = "http://localhost:8080"
    )]
    pict_rs_endpoint: Url,

    /// pict-rs API key, sent in the `x-api-token` header on internal API
    /// requests
    #[clap(long, env = "PICTRS_ADMIN__PICTRS_API_KEY")]
    pict_rs_api_key: String,

    /// OTLP (gRPC) collector endpoint; traces are exported there when set
    #[clap(long, env = "PICTRS_ADMIN__OPENTELEMETRY_URL")]
    opentelemetry_url: Option<Url>,

    /// Address for the tokio-console server; console instrumentation is only
    /// enabled when set
    #[clap(long, env = "PICTRS_ADMIN__CONSOLE_ADDRESS")]
    console_address: Option<SocketAddr>,

    /// Event buffer capacity for the tokio-console layer (only used together
    /// with a console address)
    #[clap(long, env = "PICTRS_ADMIN__CONSOLE_EVENT_BUFFER_SIZE")]
    console_event_buffer_size: Option<usize>,

    /// PEM file of extra root certificates to trust when connecting to pict-rs
    #[clap(long, env = "PICTRS_ADMIN__CERTIFICATE")]
    pict_rs_certificate: Option<PathBuf>,

    /// PEM certificate chain for serving the admin UI over TLS. TLS is only
    /// enabled when a server private key is also given; both files are
    /// re-read every 30 seconds
    #[clap(long, env = "PICTRS_ADMIN__SERVER_CERTIFICATE")]
    server_certificate: Option<PathBuf>,

    /// PEM private key matching the server certificate. TLS is only enabled
    /// when both are given
    #[clap(long, env = "PICTRS_ADMIN__SERVER_PRIVATE_KEY")]
    server_private_key: Option<PathBuf>,
}
/// Connection details plus a shared HTTP client for the upstream pict-rs server.
///
/// Cloned into each actix app instance via `web::Data`.
#[derive(Clone)]
struct PictrsClient {
    /// reqwest client wrapped with tracing middleware; used for every upstream request.
    client: ClientWithMiddleware,
    /// Base URL of pict-rs. Each request clones it and overwrites the path.
    pict_rs_endpoint: Url,
    /// Sent as the `x-api-token` header on `/internal/*` requests.
    pict_rs_api_key: String,
}
/// Marker for a successful pict-rs reply: deserializes only from the string `"ok"`.
///
/// Used to make the `Ok` variant of the untagged `PageResponse` match only
/// successful replies.
#[derive(Clone, Copy, Debug, serde::Deserialize)]
enum OkMessage {
    #[serde(rename = "ok")]
    Ok,
}
/// Media metadata reported by pict-rs for a stored hash.
///
/// NOTE(review): only `content_type` is read in this file; the other fields are
/// presumably consumed by the generated templates — confirm before removing any.
#[derive(Clone, Debug, serde::Deserialize)]
struct PictrsDetails {
    // `u16`: a reported dimension above 65535 would fail deserialization.
    width: u16,
    height: u16,
    frames: Option<u32>,
    /// MIME type; values starting with `video` are treated as video.
    content_type: String,
    /// Kept as the raw string pict-rs sent; not parsed here.
    created_at: String,
}
/// One stored media entry in a page listing: its hash, the aliases pointing at
/// it, and its details when pict-rs supplies them.
#[derive(Clone, Debug, serde::Deserialize)]
struct PictrsHash {
    /// Hash identifier as sent by pict-rs (presumably hex-encoded, per the name).
    hex: String,
    /// Aliases for this hash; the first one is used for image and purge links.
    aliases: Vec<String>,
    /// `None` when pict-rs sent no details for this hash.
    details: Option<PictrsDetails>,
}
/// One page of the pict-rs hash listing, as returned by `/internal/hashes`.
///
/// NOTE(review): presumably `pub` because the generated `templates` module takes
/// it as a parameter of a public render function — confirm before narrowing.
#[derive(Clone, Debug, serde::Deserialize)]
pub struct PictrsPage {
    // Deserialized to mirror the response shape but never read, hence the allow.
    #[allow(dead_code)]
    limit: usize,
    /// Slug of this page; carried into purge links so the user returns here.
    current: Option<String>,
    /// Slug of the previous page, if any.
    prev: Option<String>,
    /// Slug of the next page, if any.
    next: Option<String>,
    /// Media entries on this page.
    hashes: Vec<PictrsHash>,
}
/// Body of pict-rs' `/internal/hashes` reply.
///
/// `untagged`: serde tries `Ok` first, which requires `msg` to be `"ok"` and a
/// `page` to be present. Otherwise it falls back to `Err`, which accepts any
/// string `msg`.
#[derive(Clone, Debug, serde::Deserialize)]
#[serde(untagged)]
enum PageResponse {
    Ok {
        // Only present to select this variant; never read afterwards.
        #[allow(dead_code)]
        msg: OkMessage,
        page: PictrsPage,
    },
    Err {
        /// Error text from pict-rs, surfaced to the user as a 500.
        msg: String,
    },
}
/// Query string `?alias=…` sent to pict-rs' `/internal/purge` endpoint.
#[derive(Debug, serde::Serialize)]
struct AliasQuery<'a> {
    alias: &'a str,
}
impl PictrsHash {
    /// Link to this media through the admin UI's own `/image/{alias}` proxy
    /// route, built from the first alias; `None` when there are no aliases.
    fn media_link(&self) -> Option<String> {
        let alias = self.aliases.first()?;
        Some(format!("/image/{alias}"))
    }

    /// The content type, but only when the details describe a video (content
    /// type beginning with `video`). `None` for other media or missing details.
    fn video_type(&self) -> Option<String> {
        self.details
            .as_ref()
            .map(|details| details.content_type.as_str())
            .filter(|content_type| content_type.starts_with("video"))
            .map(String::from)
    }
}
impl PictrsPage {
    /// Index-page URL for the given slug, if there is one.
    fn slug_link(slug: Option<&str>) -> Option<String> {
        slug.map(|slug| format!("/?slug={slug}"))
    }

    /// Link to the previous page, or `None` when this is the first page.
    fn prev_link(&self) -> Option<String> {
        Self::slug_link(self.prev.as_deref())
    }

    /// Link to the next page, or `None` when this is the last page.
    fn next_link(&self) -> Option<String> {
        Self::slug_link(self.next.as_deref())
    }

    /// Link to the purge confirmation page for `hash`'s first alias.
    ///
    /// The link carries this page's slug, if any, so the user returns here
    /// afterwards. Returns `None` when the hash has no aliases.
    fn purge_link(&self, hash: &PictrsHash) -> Option<String> {
        let alias = hash.aliases.first()?;
        let link = match self.current.as_deref() {
            Some(slug) => format!("/purge/{alias}?slug={slug}"),
            None => format!("/purge/{alias}"),
        };
        Some(link)
    }
}
impl PictrsClient {
    /// Fetch one page of stored hashes from pict-rs' `/internal/hashes` endpoint.
    ///
    /// `query` is forwarded as the request's query string. The body is decoded
    /// as a [`PageResponse`], whose `Err` variant carries pict-rs' own error
    /// message, so the HTTP status is deliberately not checked here.
    ///
    /// # Errors
    /// Fails if the request cannot be sent or the body is not a valid
    /// `PageResponse`.
    async fn page(&self, query: &PageQuery) -> Result<PageResponse, reqwest_middleware::Error> {
        let mut url = self.pict_rs_endpoint.clone();
        url.set_path("/internal/hashes");

        let response = self
            .client
            .get(url.as_str())
            .header("x-api-token", &self.pict_rs_api_key)
            .query(query)
            .send()
            .await?;

        response.json().await.map_err(From::from)
    }

    /// Ask pict-rs to purge the media behind `alias` via `/internal/purge`.
    ///
    /// # Errors
    /// Fails if the request cannot be sent or pict-rs answers with a non-2xx
    /// status (e.g. a bad API key or unknown alias). The caller must not
    /// report a rejected purge as success.
    async fn purge(&self, alias: &str) -> Result<(), reqwest_middleware::Error> {
        let mut url = self.pict_rs_endpoint.clone();
        url.set_path("/internal/purge");

        self.client
            .post(url.as_str())
            .header("x-api-token", &self.pict_rs_api_key)
            .query(&AliasQuery { alias })
            .send()
            .await?
            .error_for_status()?;

        Ok(())
    }

    /// Stream the original file for `alias` from pict-rs back to the browser.
    ///
    /// The upstream status, end-to-end headers and body are forwarded. Body
    /// bytes are streamed rather than buffered.
    ///
    /// # Errors
    /// Fails if the upstream request cannot be sent. Upstream error statuses
    /// are passed through to the client unchanged.
    async fn proxy_image(&self, alias: &str) -> Result<HttpResponse, reqwest_middleware::Error> {
        // Hop-by-hop headers describe only the pict-rs <-> admin connection and
        // must not be forwarded by a proxy (RFC 2616 §13.5.1, RFC 9110 §7.6.1).
        const HOP_BY_HOP: [&str; 9] = [
            "connection",
            "keep-alive",
            "proxy-authenticate",
            "proxy-authorization",
            "proxy-connection",
            "te",
            "trailer",
            "transfer-encoding",
            "upgrade",
        ];

        let mut url = self.pict_rs_endpoint.clone();
        url.set_path(&format!("/image/original/{alias}"));

        let response = self.client.get(url.as_str()).send().await?;

        let mut client_res = HttpResponse::build(response.status());
        for (name, value) in response
            .headers()
            .iter()
            .filter(|(name, _)| !HOP_BY_HOP.contains(&name.as_str()))
        {
            // `HeaderMap::iter` yields repeated headers once per value;
            // `append_header` keeps them all, whereas `insert_header` would
            // keep only the last.
            client_res.append_header((name.clone(), value.clone()));
        }

        Ok(client_res.body(BodyStream::new(response.bytes_stream())))
    }
}
/// Pagination query accepted on `/` and re-serialized as the query string of
/// pict-rs' `/internal/hashes` request.
#[derive(Debug, serde::Deserialize, serde::Serialize)]
struct PageQuery {
    /// Page slug taken from a previous listing's `prev`/`next`.
    slug: Option<String>,
    /// Passed through to pict-rs unchanged; its format is not interpreted here.
    timestamp: Option<String>,
}
/// `GET /`: render one page of stored media as minified HTML.
///
/// The `slug`/`timestamp` query parameters are forwarded to pict-rs as-is.
/// Transport failures, undecodable replies and pict-rs error messages all
/// become a 500 response.
async fn index(
    web::Query(query): web::Query<PageQuery>,
    client: web::Data<PictrsClient>,
) -> Result<HttpResponse, actix_web::Error> {
    let reply = client
        .page(&query)
        .await
        .map_err(ErrorInternalServerError)?;

    let page = match reply {
        PageResponse::Err { msg } => return Err(ErrorInternalServerError(msg)),
        PageResponse::Ok { page, .. } => page,
    };

    let mut html = Vec::new();
    templates::index_html(&mut html, &page).map_err(ErrorInternalServerError)?;
    let minified = minify_html::minify(&html, &minify_html::Cfg::spec_compliant());

    Ok(HttpResponse::Ok().content_type("text/html").body(minified))
}
async fn image(
alias: web::Path<String>,
client: web::Data<PictrsClient>,
) -> Result<HttpResponse, actix_web::Error> {
client
.proxy_image(&alias)
.await
.map_err(ErrorInternalServerError)
}
/// Query parameters accepted by `/purge/{alias}`.
#[derive(Debug, serde::Deserialize)]
struct ConfirmQuery {
    /// Confirmation flag; the generated confirmation link sends `confirm=1`.
    confirm: Option<u8>,
    /// Index-page slug to return to after purging.
    slug: Option<String>,
}
/// `GET /purge/{path}`: confirm, then purge, the media behind an alias.
///
/// Without a non-zero `confirm` parameter this renders a confirmation page.
/// Its purge link repeats the request with `confirm=1`. With a non-zero
/// `confirm`, the alias is purged and the browser is redirected (303 See
/// Other) back to the index page it came from, keeping `slug`.
///
/// NOTE(review): performing a destructive action on GET is exposed to link
/// prefetching and CSRF; consider moving the confirmed purge to a POST.
async fn purge(
    alias: web::Path<String>,
    client: web::Data<PictrsClient>,
    web::Query(ConfirmQuery { confirm, slug }): web::Query<ConfirmQuery>,
) -> Result<HttpResponse, actix_web::Error> {
    let return_link = slug
        .as_ref()
        .map(|s| format!("/?slug={s}"))
        .unwrap_or_else(|| String::from("/"));

    // Only a non-zero value confirms; `confirm=0` must never trigger a purge.
    if matches!(confirm, Some(c) if c != 0) {
        client
            .purge(&alias)
            .await
            .map_err(ErrorInternalServerError)?;

        return Ok(HttpResponse::SeeOther()
            .insert_header(("location", return_link))
            .finish());
    }

    let purge_link = match slug {
        Some(s) => format!("/purge/{alias}?slug={s}&confirm=1"),
        None => format!("/purge/{alias}?confirm=1"),
    };
    let image_link = format!("/image/{alias}");

    let mut buf = Vec::new();
    templates::purge_html(&mut buf, &image_link, &purge_link, &return_link)
        .map_err(ErrorInternalServerError)?;
    let body = minify_html::minify(&buf, &minify_html::Cfg::spec_compliant());

    Ok(HttpResponse::Ok().content_type("text/html").body(body))
}
/// `GET /static/{path}`: serve a static file embedded in the generated
/// `templates` module, or 404 when no file has that name.
async fn serve_static(name: web::Path<String>) -> HttpResponse {
    match templates::statics::StaticFile::get(&name) {
        Some(file) => HttpResponse::Ok().body(file.content),
        None => HttpResponse::NotFound().finish(),
    }
}
/// Set up global tracing: W3C trace-context propagation, `log` bridging, and a
/// formatted console logger filtered by `RUST_LOG` (default `info`).
///
/// When `console_address` is set, a tokio-console layer serving on that
/// address is added. `console_event_buffer_size` overrides its event buffer
/// capacity. OpenTelemetry export is handled by `init_subscriber`.
///
/// # Errors
/// Fails if the `log` bridge is already installed, `RUST_LOG` does not parse
/// as target filters, or the global subscriber cannot be installed.
fn init_tracing(
    service_name: &'static str,
    opentelemetry_url: Option<&Url>,
    console_address: Option<SocketAddr>,
    console_event_buffer_size: Option<usize>,
) -> color_eyre::Result<()> {
    opentelemetry::global::set_text_map_propagator(TraceContextPropagator::new());
    tracing_log::LogTracer::init()?;

    let directives = std::env::var("RUST_LOG").unwrap_or_else(|_| "info".into());
    let targets: Targets = directives.parse()?;

    let fmt_layer = tracing_subscriber::fmt::layer().with_filter(targets.clone());
    let base = Registry::default()
        .with(fmt_layer)
        .with(ErrorLayer::default());

    match console_address {
        Some(addr) => {
            let console_builder = ConsoleLayer::builder().with_default_env().server_addr(addr);
            let console_layer = match console_event_buffer_size {
                Some(capacity) => console_builder.event_buffer_capacity(capacity).spawn(),
                None => console_builder.spawn(),
            };

            init_subscriber(
                base.with(console_layer),
                targets,
                opentelemetry_url,
                service_name,
            )?;
            tracing::info!("Starting console on {addr}");
        }
        None => init_subscriber(base, targets, opentelemetry_url, service_name)?,
    }

    Ok(())
}
/// Install `subscriber` as the global default tracing subscriber.
///
/// When `opentelemetry_url` is set, an OpenTelemetry layer filtered by
/// `targets` is added first. It exports spans over OTLP/gRPC (tonic) to that
/// URL in batches on the Tokio runtime, tagged with `service.name =
/// service_name`. `targets` is unused when no URL is given.
///
/// # Errors
/// Fails if the OTLP pipeline cannot be installed or a global default
/// subscriber has already been set.
fn init_subscriber<S>(
    subscriber: S,
    targets: Targets,
    opentelemetry_url: Option<&Url>,
    service_name: &'static str,
) -> color_eyre::Result<()>
where
    S: SubscriberExt + Send + Sync,
    for<'a> S: LookupSpan<'a>,
{
    match opentelemetry_url {
        None => tracing::subscriber::set_global_default(subscriber)?,
        Some(url) => {
            let resource = Resource::new(vec![KeyValue::new("service.name", service_name)]);
            let exporter = opentelemetry_otlp::new_exporter()
                .tonic()
                .with_endpoint(url.as_str());

            let tracer = opentelemetry_otlp::new_pipeline()
                .tracing()
                .with_trace_config(opentelemetry_sdk::trace::config().with_resource(resource))
                .with_exporter(exporter)
                .install_batch(opentelemetry_sdk::runtime::Tokio)?;

            let otel_layer = tracing_opentelemetry::layer()
                .with_tracer(tracer)
                .with_filter(targets);

            tracing::subscriber::set_global_default(subscriber.with(otel_layer))?;
        }
    }

    Ok(())
}
async fn add_root_certificates(
builder: reqwest::ClientBuilder,
path: &PathBuf,
) -> color_eyre::Result<reqwest::ClientBuilder> {
let bytes = tokio::fs::read(path).await?;
let res = rustls_pemfile::certs(&mut bytes.as_slice()).try_fold(builder, |builder, res| {
let cert = res?;
let cert = reqwest::Certificate::from_der(&cert)?;
Ok(builder.add_root_certificate(cert))
});
res
}
async fn open_keys(
certificate: &PathBuf,
private_key: &PathBuf,
) -> color_eyre::Result<rustls_021::sign::CertifiedKey> {
let cert_bytes = tokio::fs::read(certificate).await?;
let key_bytes = tokio::fs::read(private_key).await?;
let certs = rustls_pemfile::certs(&mut cert_bytes.as_slice())
.map(|res| res.map(|c| rustls_021::Certificate(c.to_vec())))
.collect::<Result<Vec<_>, _>>()?;
let key = rustls_pemfile::private_key(&mut key_bytes.as_slice())?
.ok_or_else(|| color_eyre::eyre::eyre!("No key in keyfile"))?;
let signing_key =
rustls_021::sign::any_supported_type(&rustls_021::PrivateKey(key.secret_der().to_vec()))?;
Ok(rustls_021::sign::CertifiedKey::new(certs, signing_key))
}
#[actix_web::main]
async fn main() -> color_eyre::Result<()> {
let Args {
bind_address,
pict_rs_endpoint,
pict_rs_api_key,
opentelemetry_url,
console_address,
console_event_buffer_size,
pict_rs_certificate,
server_certificate,
server_private_key,
} = Args::parse();
init_tracing(
"pict-rs-admin",
opentelemetry_url.as_ref(),
console_address,
console_event_buffer_size,
)?;
let builder = Client::builder()
.user_agent("pict-rs-admin v0.2.0")
.redirect(Policy::none());
let client = if let Some(path) = pict_rs_certificate {
add_root_certificates(builder, &path).await?.build()?
} else {
builder.build()?
};
let client = ClientBuilder::new(client)
.with(TracingMiddleware::default())
.build();
let client = PictrsClient {
client,
pict_rs_endpoint,
pict_rs_api_key,
};
let server = HttpServer::new(move || {
App::new()
.wrap(TracingLogger::default())
.app_data(web::Data::new(client.clone()))
.route("/", web::get().to(index))
.route("/image/{path}", web::get().to(image))
.route("/purge/{path}", web::get().to(purge))
.route("/static/{path}", web::get().to(serve_static))
});
if let Some((cert, key)) = server_certificate.zip(server_private_key) {
let certified_key = open_keys(&cert, &key).await?;
let (tx, rx) = rustls_channel_resolver::channel::<32>(certified_key);
let handle = actix_web::rt::spawn(async move {
let mut interval = actix_web::rt::time::interval(Duration::from_secs(30));
interval.tick().await;
loop {
interval.tick().await;
match open_keys(&cert, &key).await {
Ok(certified_key) => tx.update(certified_key),
Err(e) => tracing::error!("Failed to read TLS keys {e}"),
}
}
});
let server_config = ServerConfig::builder()
.with_safe_defaults()
.with_no_client_auth()
.with_cert_resolver(rx);
tracing::info!("Starting pict-rs-admin with TLS on {bind_address}");
server
.bind_rustls_021(bind_address, server_config)?
.run()
.await?;
handle.abort();
let _ = handle.await;
} else {
tracing::info!("Starting pict-rs-admin on {bind_address}");
server.bind(bind_address)?.run().await?;
}
Ok(())
}
// Pulls in the build-time generated `templates` module from OUT_DIR. It
// provides `index_html`, `purge_html` and the embedded `statics::StaticFile`
// used above. Presumably generated by ructe from build.rs — confirm against
// the build script.
include!(concat!(env!("OUT_DIR"), "/templates.rs"));