diff --git a/Cargo.lock b/Cargo.lock index 29a2da4..f5a126a 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -136,7 +136,7 @@ dependencies = [ "tokio-rustls", "tokio-util", "tracing", - "webpki-roots", + "webpki-roots 0.25.3", ] [[package]] @@ -1936,6 +1936,7 @@ dependencies = [ "ructe", "rustls", "rustls-channel-resolver", + "rustls-pemfile", "serde", "serde_json", "sled", @@ -1951,6 +1952,7 @@ dependencies = [ "tracing-subscriber", "url", "uuid", + "webpki-roots 0.26.0", ] [[package]] @@ -2347,6 +2349,22 @@ dependencies = [ "rustls", ] +[[package]] +name = "rustls-pemfile" +version = "2.0.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "35e4980fa29e4c4b212ffb3db068a564cbf560e51d3944b7c88bd8bf5bec64f4" +dependencies = [ + "base64", + "rustls-pki-types", +] + +[[package]] +name = "rustls-pki-types" +version = "1.1.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9e9d979b3ce68192e42760c7810125eb6cf2ea10efae545a156063e61f314e2a" + [[package]] name = "rustls-webpki" version = "0.101.7" @@ -3130,6 +3148,15 @@ version = "0.25.3" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "1778a42e8b3b90bff8d0f5032bf22250792889a5cdc752aa0020c84abe3aaf10" +[[package]] +name = "webpki-roots" +version = "0.26.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0de2cfda980f21be5a7ed2eadb3e6fe074d56022bea2cdeb1a62eb220fc04188" +dependencies = [ + "rustls-pki-types", +] + [[package]] name = "winapi" version = "0.3.9" diff --git a/Cargo.toml b/Cargo.toml index a992e76..e5c3af8 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -30,10 +30,11 @@ opentelemetry-otlp = "0.14" qrcodegen = "1.7" rustls = "0.21" rustls-channel-resolver = "0.1.0" +rustls-pemfile = "2.0.0" serde = { version = "1.0", features = ["derive"] } serde_json = "1.0" sled = { version = "0.34.7", features = ["zstd"] } -tokio = { version = "1", default-features = false, features = ["sync"] } +tokio = { version = "1", default-features = false, features = ["fs", "sync"] } thiserror = "1.0" tracing = "0.1" tracing-error = "0.2" @@ -47,6 +48,7 @@ tracing-subscriber = { version = "0.3", features = [ ] } url = { version = "2.2", features = ["serde"] } uuid = { version = "1", features = ["serde", "v4"] } +webpki-roots = "0.26.0" [dependencies.tracing-actix-web] diff --git a/src/lib.rs b/src/lib.rs index 7970aca..1f55101 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -10,6 +10,9 @@ use actix_web::{ }; use awc::Client; use clap::Parser; +use rustls::{ + sign::CertifiedKey, Certificate, ClientConfig, OwnedTrustAnchor, PrivateKey, RootCertStore, +}; use sled::Db; use std::{ io::Cursor, @@ -43,7 +46,7 @@ pub struct Config { short, long, env = "PICTRS_AGGREGATOR_ADDR", - default_value = "0.0.0.0:8082", + default_value = "[::]:8082", help = "The address and port the server binds to" )] addr: SocketAddr, @@ -57,6 +60,15 @@ pub struct Config { )] upstream: Url, + #[arg( + short = 'H', + long, + env = "PICTRS_AGGREGATOR_HOST", + default_value = "localhost:8082", + help = "The host at which pict-rs-aggregator is accessible" + )] + host: String, + #[arg( short, long, @@ -77,9 +89,16 @@ pub struct Config { #[arg( short, + long, + env = "PICTRS_AGGREGATOR_CONSOLE_ADDRESS", + help = "The address at which to bind the tokio-console exporter" + )] + console_address: Option, + + #[arg( long, env = "PICTRS_AGGREGATOR_CONSOLE_EVENT_BUFFER_SIZE", - help = "The number of events to buffer in console. When unset, console is disabled" + help = "The number of events to buffer in console" )] console_event_buffer_size: Option, @@ -90,6 +109,62 @@ pub struct Config { help = "URL for the OpenTelemetry Colletor" )] opentelemetry_url: Option, + + #[arg( + long, + env = "PICTRS_AGGREGATOR_CERTIFICATE", + help = "The CA Certificate used to verify pict-rs' TLS certificate" + )] + certificate: Option, + + #[arg( + long, + env = "PICTRS_AGGREGATOR_SERVER_CERTIFICATE", + help = "The Certificate used to serve pict-rs-aggregator over TLS" + )] + server_certificate: Option, + + #[arg( + long, + env = "PICTRS_AGGREGATOR_SERVER_PRIVATE_KEY", + help = "The Private Key used to serve pict-rs-aggregator over TLS" + )] + server_private_key: Option, +} + +pub struct Tls { + certificate: PathBuf, + private_key: PathBuf, +} + +impl Tls { + pub fn from_config(config: &Config) -> Option { + config + .server_certificate + .as_ref() + .zip(config.server_private_key.as_ref()) + .map(|(cert, key)| Tls { + certificate: cert.clone(), + private_key: key.clone(), + }) + } + + pub async fn open_keys(&self) -> color_eyre::Result { + let cert_bytes = tokio::fs::read(&self.certificate).await?; + let key_bytes = tokio::fs::read(&self.private_key).await?; + + let certs = rustls_pemfile::certs(&mut cert_bytes.as_slice()) + .map(|res| res.map(|c| Certificate(c.to_vec()))) + .collect::, _>>()?; + + let key = rustls_pemfile::private_key(&mut key_bytes.as_slice())? + .ok_or_else(|| color_eyre::eyre::eyre!("No key in keyfile"))?; + + let signing_key = + rustls::sign::any_supported_type(&PrivateKey(Vec::from(key.secret_der())))?; + + Ok(CertifiedKey::new(certs, signing_key)) + } } pub fn accept() -> &'static str { @@ -109,6 +184,10 @@ impl Config { self.console_event_buffer_size } + pub fn console_address(&self) -> Option { + self.console_address + } + pub fn bind_address(&self) -> SocketAddr { self.addr } @@ -116,6 +195,34 @@ impl Config { pub fn opentelemetry_url(&self) -> Option<&Url> { self.opentelemetry_url.as_ref() } + + pub async fn build_rustls_client_config(&self) -> color_eyre::Result { + let mut root_store = RootCertStore { + roots: webpki_roots::TLS_SERVER_ROOTS + .iter() + .map(|root| { + OwnedTrustAnchor::from_subject_spki_name_constraints( + root.subject.to_vec(), + root.subject_public_key_info.to_vec(), + root.name_constraints.as_ref().map(|n| n.to_vec()), + ) + }) + .collect(), + }; + + if let Some(certificate) = &self.certificate { + let bytes = tokio::fs::read(certificate).await?; + + for res in rustls_pemfile::certs(&mut bytes.as_slice()) { + root_store.add(&Certificate(res?.to_vec()))?; + } + } + + Ok(ClientConfig::builder() + .with_safe_defaults() + .with_root_certificates(root_store) + .with_no_client_auth()) + } } #[derive(Clone, Copy, Debug, PartialEq, Eq, serde::Deserialize, serde::Serialize)] @@ -139,6 +246,7 @@ impl std::fmt::Display for Direction { #[derive(Clone)] pub struct State { upstream: Url, + host: String, scope: String, store: Store, startup: SystemTime, @@ -274,6 +382,7 @@ impl State { pub fn state(config: Config, scope: &str, db: Db) -> Result { Ok(State { upstream: config.upstream, + host: config.host, scope: scope.to_string(), store: Store::new(&db)?, startup: SystemTime::now(), @@ -1129,11 +1238,15 @@ async fn delete_entry( } fn qr(req: &HttpRequest, path: &web::Path, state: &web::Data) -> String { - let host = req.head().headers().get("host").unwrap(); + let host = req + .head() + .headers() + .get("host") + .and_then(|h| h.to_str().ok()) + .unwrap_or(&state.host); let url = format!( - "https://{}{}", - host.to_str().unwrap(), + "https://{host}{}", state.public_collection_path(path.collection) ); diff --git a/src/main.rs b/src/main.rs index 76f6997..c11c4bf 100644 --- a/src/main.rs +++ b/src/main.rs @@ -1,11 +1,13 @@ use actix_web::{App, HttpServer}; -use awc::Client; +use awc::{Client, Connector}; use clap::Parser; use console_subscriber::ConsoleLayer; use opentelemetry::KeyValue; use opentelemetry_otlp::WithExportConfig; use opentelemetry_sdk::{propagation::TraceContextPropagator, Resource}; -use std::time::Duration; +use pict_rs_aggregator::Tls; +use rustls::ServerConfig; +use std::{net::SocketAddr, sync::Arc, time::Duration}; use tracing::subscriber::set_global_default; use tracing_actix_web::TracingLogger; use tracing_awc::Tracing; @@ -22,6 +24,7 @@ async fn main() -> color_eyre::Result<()> { init_logger( config.opentelemetry_url(), + config.console_address(), config.console_event_buffer_size(), )?; @@ -32,32 +35,74 @@ async fn main() -> color_eyre::Result<()> { .path(db_path) .cache_capacity(config.sled_cache_capacity()) .open()?; + let bind_address = config.bind_address(); + + let rustls_client_config = config.build_rustls_client_config().await?; + + let tls = Tls::from_config(&config); + let state = pict_rs_aggregator::state(config, "", db)?; - tracing::info!("Launching on {bind_address}"); + let server = HttpServer::new(move || { + let connector = Connector::new().rustls_021(Arc::new(rustls_client_config.clone())); - HttpServer::new(move || { let client = Client::builder() .wrap(Tracing) .timeout(Duration::from_secs(30)) - .add_default_header(("User-Agent", "pict_rs_aggregator-v0.5.0-beta.3")) + .add_default_header(("User-Agent", "pict_rs_aggregator-v0.5.0")) .disable_redirects() + .connector(connector) .finish(); App::new() .wrap(TracingLogger::default()) .configure(|cfg| pict_rs_aggregator::configure(cfg, state.clone(), client)) - }) - .bind(bind_address)? - .run() - .await?; + }); + + if let Some(tls) = tls { + let key = tls.open_keys().await?; + + let (tx, rx) = rustls_channel_resolver::channel::<32>(key); + + let handle = actix_rt::spawn(async move { + let mut interval = actix_rt::time::interval(Duration::from_secs(30)); + interval.tick().await; + + loop { + interval.tick().await; + + match tls.open_keys().await { + Ok(key) => tx.update(key), + Err(e) => tracing::error!("Failed to open keys for TLS {e}"), + } + } + }); + + let server_config = ServerConfig::builder() + .with_safe_defaults() + .with_no_client_auth() + .with_cert_resolver(rx); + + tracing::info!("Serving pict-rs-aggregator over TLS on {bind_address}"); + server + .bind_rustls_021(bind_address, server_config)? + .run() + .await?; + + handle.abort(); + let _ = handle.await; + } else { + tracing::info!("Serving pict-rs-aggregator on {bind_address}"); + server.bind(bind_address)?.run().await?; + } Ok(()) } fn init_logger( opentelemetry_url: Option<&Url>, + console_addr: Option, console_event_buffer_size: Option, ) -> color_eyre::Result<()> { color_eyre::install()?; @@ -76,19 +121,24 @@ fn init_logger( .with(format_layer) .with(ErrorLayer::default()); - if let Some(buffer_size) = console_event_buffer_size { - let console_layer = ConsoleLayer::builder() - .with_default_env() - .server_addr(([0, 0, 0, 0], 6669)) - .event_buffer_capacity(buffer_size) - .spawn(); + if let Some(addr) = console_addr { + let builder = ConsoleLayer::builder().with_default_env().server_addr(addr); + + let console_layer = if let Some(buffer_size) = console_event_buffer_size { + builder.event_buffer_capacity(buffer_size).spawn() + } else { + builder.spawn() + }; let subscriber = subscriber.with(console_layer); - init_subscriber(subscriber, targets, opentelemetry_url) + init_subscriber(subscriber, targets, opentelemetry_url)?; + tracing::info!("Serving tokio-console endpoint on {addr}"); } else { - init_subscriber(subscriber, targets, opentelemetry_url) + init_subscriber(subscriber, targets, opentelemetry_url)?; } + + Ok(()) } fn init_subscriber(