diff --git a/Cargo.lock b/Cargo.lock index 5a6568e..a4729c4 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1812,6 +1812,7 @@ dependencies = [ "rustls-channel-resolver", "rustls-pemfile 2.0.0", "serde", + "tokio", "tracing", "tracing-actix-web", "tracing-error", diff --git a/Cargo.toml b/Cargo.toml index d7b2df8..2f5bb3d 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -32,6 +32,7 @@ rustls = "0.22" rustls-channel-resolver = "0.1.0" rustls-pemfile = "2.0.0" serde = { version = "1.0.188", features = ["derive"] } +tokio = { version = "1", features = ["fs"] } tracing = "0.1.37" tracing-actix-web = { version = "0.7.6", features = ["opentelemetry_0_21", "emit_event_on_error"] } tracing-error = "0.2.0" diff --git a/src/main.rs b/src/main.rs index a0bd4e2..34b6cb1 100644 --- a/src/main.rs +++ b/src/main.rs @@ -1,4 +1,4 @@ -use std::net::SocketAddr; +use std::{net::SocketAddr, path::PathBuf, time::Duration}; use actix_web::{ body::BodyStream, error::ErrorInternalServerError, web, App, HttpResponse, HttpServer, @@ -11,6 +11,7 @@ use opentelemetry_sdk::{propagation::TraceContextPropagator, Resource}; use reqwest::{redirect::Policy, Client}; use reqwest_middleware::{ClientBuilder, ClientWithMiddleware}; use reqwest_tracing::TracingMiddleware; +use rustls_021::ServerConfig; use tracing_actix_web::TracingLogger; use tracing_error::ErrorLayer; use tracing_subscriber::{ @@ -24,7 +25,7 @@ struct Args { short, long, env = "PICTRS_ADMIN__BIND_ADDRESS", - default_value = "127.0.0.1:8084" + default_value = "[::]:8084" )] bind_address: SocketAddr, #[clap( @@ -41,6 +42,12 @@ struct Args { console_address: Option, #[clap(long, env = "PICTRS_ADMIN__CONSOLE_EVENT_BUFFER_SIZE")] console_event_buffer_size: Option, + #[clap(long, env = "PICTRS_ADMIN__CERTIFICATE")] + pict_rs_certificate: Option, + #[clap(long, env = "PICTRS_ADMIN__SERVER_CERTIFICATE")] + server_certificate: Option, + #[clap(long, env = "PICTRS_ADMIN__SERVER_PRIVATE_KEY")] + server_private_key: Option, } #[derive(Clone)] @@ -276,7 +283,7 @@ async fn serve_static(name: web::Path) -> HttpResponse { fn init_tracing( service_name: &'static str, opentelemetry_url: Option<&Url>, - console_addr: Option, + console_address: Option, console_event_buffer_size: Option, ) -> color_eyre::Result<()> { opentelemetry::global::set_text_map_propagator(TraceContextPropagator::new()); @@ -293,7 +300,7 @@ fn init_tracing( .with(format_layer) .with(ErrorLayer::default()); - if let Some(addr) = console_addr { + if let Some(addr) = console_address { let builder = ConsoleLayer::builder().with_default_env().server_addr(addr); let console_layer = if let Some(buffer_size) = console_event_buffer_size { @@ -304,10 +311,12 @@ fn init_tracing( let subscriber = subscriber.with(console_layer); - init_subscriber(subscriber, targets, opentelemetry_url, service_name) + init_subscriber(subscriber, targets, opentelemetry_url, service_name)?; + tracing::info!("Starting console on {addr}"); } else { - init_subscriber(subscriber, targets, opentelemetry_url, service_name) + init_subscriber(subscriber, targets, opentelemetry_url, service_name)?; } + Ok(()) } fn init_subscriber( @@ -349,6 +358,42 @@ where Ok(()) } +async fn add_root_certificates( + builder: reqwest::ClientBuilder, + path: &PathBuf, +) -> color_eyre::Result { + let bytes = tokio::fs::read(path).await?; + + let res = rustls_pemfile::certs(&mut bytes.as_slice()).try_fold(builder, |builder, res| { + let cert = res?; + let cert = reqwest::Certificate::from_der(&cert)?; + + Ok(builder.add_root_certificate(cert)) + }); + + res +} + +async fn open_keys( + certificate: &PathBuf, + private_key: &PathBuf, +) -> color_eyre::Result { + let cert_bytes = tokio::fs::read(certificate).await?; + let key_bytes = tokio::fs::read(private_key).await?; + + let certs = rustls_pemfile::certs(&mut cert_bytes.as_slice()) + .map(|res| res.map(|c| rustls_021::Certificate(c.to_vec()))) + .collect::, _>>()?; + + let key = rustls_pemfile::private_key(&mut key_bytes.as_slice())? + .ok_or_else(|| color_eyre::eyre::eyre!("No key in keyfile"))?; + + let signing_key = + rustls_021::sign::any_supported_type(&rustls_021::PrivateKey(key.secret_der().to_vec()))?; + + Ok(rustls_021::sign::CertifiedKey::new(certs, signing_key)) +} + #[actix_web::main] async fn main() -> color_eyre::Result<()> { let Args { @@ -358,6 +403,9 @@ async fn main() -> color_eyre::Result<()> { opentelemetry_url, console_address, console_event_buffer_size, + pict_rs_certificate, + server_certificate, + server_private_key, } = Args::parse(); init_tracing( @@ -367,10 +415,15 @@ async fn main() -> color_eyre::Result<()> { console_event_buffer_size, )?; - let client = Client::builder() - .user_agent("pict-rs-admin v0.1.0") - .redirect(Policy::none()) - .build()?; + let builder = Client::builder() + .user_agent("pict-rs-admin v0.2.0") + .redirect(Policy::none()); + + let client = if let Some(path) = pict_rs_certificate { + add_root_certificates(builder, &path).await?.build()? + } else { + builder.build()? + }; let client = ClientBuilder::new(client) .with(TracingMiddleware::default()) @@ -382,7 +435,7 @@ async fn main() -> color_eyre::Result<()> { pict_rs_api_key, }; - HttpServer::new(move || { + let server = HttpServer::new(move || { App::new() .wrap(TracingLogger::default()) .app_data(web::Data::new(client.clone())) @@ -390,10 +443,44 @@ async fn main() -> color_eyre::Result<()> { .route("/image/{path}", web::get().to(image)) .route("/purge/{path}", web::get().to(purge)) .route("/static/{path}", web::get().to(serve_static)) - }) - .bind(bind_address)? - .run() - .await?; + }); + + if let Some((cert, key)) = server_certificate.zip(server_private_key) { + let certified_key = open_keys(&cert, &key).await?; + + let (tx, rx) = rustls_channel_resolver::channel::<32>(certified_key); + + let handle = actix_web::rt::spawn(async move { + let mut interval = actix_web::rt::time::interval(Duration::from_secs(30)); + interval.tick().await; + + loop { + interval.tick().await; + + match open_keys(&cert, &key).await { + Ok(certified_key) => tx.update(certified_key), + Err(e) => tracing::error!("Failed to read TLS keys {e}"), + } + } + }); + + let server_config = ServerConfig::builder() + .with_safe_defaults() + .with_no_client_auth() + .with_cert_resolver(rx); + + tracing::info!("Starting pict-rs-admin with TLS on {bind_address}"); + server + .bind_rustls_021(bind_address, server_config)? + .run() + .await?; + + handle.abort(); + let _ = handle.await; + } else { + tracing::info!("Starting pict-rs-admin on {bind_address}"); + server.bind(bind_address)?.run().await?; + } Ok(()) }