From 19147e2035194f99cadf420b1cb6f0a170f4f23a Mon Sep 17 00:00:00 2001 From: asonix Date: Mon, 15 Jan 2024 18:11:08 -0500 Subject: [PATCH] postgres: allow connecting to TLS-enabled databases --- Cargo.lock | 28 +++++++- Cargo.toml | 6 ++ src/config/commandline.rs | 8 +++ src/config/defaults.rs | 6 +- src/config/file.rs | 2 + src/repo.rs | 9 ++- src/repo/postgres.rs | 144 ++++++++++++++++++++++++++++++++++---- 7 files changed, 184 insertions(+), 19 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index d235fd0..8a6cd63 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1838,6 +1838,8 @@ dependencies = [ "reqwest", "reqwest-middleware", "reqwest-tracing", + "rustls 0.22.2", + "rustls-pemfile 2.0.0", "rusty-s3", "serde", "serde-tuple-vec-map", @@ -1864,6 +1866,7 @@ dependencies = [ "tracing-subscriber", "url", "uuid", + "webpki-roots 0.26.0", ] [[package]] @@ -2233,7 +2236,7 @@ dependencies = [ "percent-encoding", "pin-project-lite", "rustls 0.21.10", - "rustls-pemfile", + "rustls-pemfile 1.0.4", "serde", "serde_json", "serde_urlencoded", @@ -2247,7 +2250,7 @@ dependencies = [ "wasm-bindgen-futures", "wasm-streams", "web-sys", - "webpki-roots", + "webpki-roots 0.25.3", "winreg", ] @@ -2370,6 +2373,8 @@ version = "0.22.2" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "e87c9956bd9807afa1f77e0f7594af32566e830e088a5576d27c5b6f30f49d41" dependencies = [ + "log", + "ring 0.17.7", "rustls-pki-types", "rustls-webpki 0.102.1", "subtle", @@ -2385,6 +2390,16 @@ dependencies = [ "base64 0.21.7", ] +[[package]] +name = "rustls-pemfile" +version = "2.0.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "35e4980fa29e4c4b212ffb3db068a564cbf560e51d3944b7c88bd8bf5bec64f4" +dependencies = [ + "base64 0.21.7", + "rustls-pki-types", +] + [[package]] name = "rustls-pki-types" version = "1.1.0" @@ -3522,6 +3537,15 @@ version = "0.25.3" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "1778a42e8b3b90bff8d0f5032bf22250792889a5cdc752aa0020c84abe3aaf10" +[[package]] +name = "webpki-roots" +version = "0.26.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0de2cfda980f21be5a7ed2eadb3e6fe074d56022bea2cdeb1a62eb220fc04188" +dependencies = [ + "rustls-pki-types", +] + [[package]] name = "whoami" version = "1.4.1" diff --git a/Cargo.toml b/Cargo.toml index 7d6fd71..f97f426 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -46,6 +46,10 @@ refinery = { version = "0.8.10", features = ["tokio-postgres", "postgres"] } reqwest = { version = "0.11.18", default-features = false, features = ["json", "rustls-tls", "stream"] } reqwest-middleware = "0.2.2" reqwest-tracing = { version = "0.4.5" } +# pinned to tokio-postgres-rustls +rustls = "0.22.0" +# pinned to rustls +rustls-pemfile = "2.0.0" rusty-s3 = "0.5.0" serde = { version = "1.0", features = ["derive"] } serde-tuple-vec-map = "1.0.1" @@ -81,6 +85,8 @@ tracing-subscriber = { version = "0.3.0", features = [ ] } url = { version = "2.2", features = ["serde"] } uuid = { version = "1", features = ["serde", "std", "v4", "v7"] } +# pinned to rustls +webpki-roots = "0.26.0" [dependencies.tracing-actix-web] version = "0.7.8" diff --git a/src/config/commandline.rs b/src/config/commandline.rs index f929d3e..4077604 100644 --- a/src/config/commandline.rs +++ b/src/config/commandline.rs @@ -1415,6 +1415,14 @@ pub(super) struct Postgres { /// The URL of the postgres database #[arg(short, long)] pub(super) url: Url, + + /// whether to connect to postgres via TLS + #[arg(short, long)] + pub(super) use_tls: bool, + + /// The path to the root certificate for postgres' CA + #[arg(short, long)] + pub(super) certificate_file: Option, } #[derive(Debug, Parser, serde::Serialize)] diff --git a/src/config/defaults.rs b/src/config/defaults.rs index a56489a..8a67fda 100644 --- a/src/config/defaults.rs +++ b/src/config/defaults.rs @@ -389,7 +389,11 @@ impl From for crate::config::file::Sled { impl From for crate::config::file::Postgres { fn from(value: crate::config::commandline::Postgres) -> Self { - crate::config::file::Postgres { url: value.url } + crate::config::file::Postgres { + url: value.url, + use_tls: value.use_tls, + certificate_file: value.certificate_file, + } } } diff --git a/src/config/file.rs b/src/config/file.rs index bb85d95..e75271f 100644 --- a/src/config/file.rs +++ b/src/config/file.rs @@ -458,4 +458,6 @@ pub(crate) struct Sled { #[serde(rename_all = "snake_case")] pub(crate) struct Postgres { pub(crate) url: Url, + pub(crate) use_tls: bool, + pub(crate) certificate_file: Option, } diff --git a/src/repo.rs b/src/repo.rs index eaada7a..acff51b 100644 --- a/src/repo.rs +++ b/src/repo.rs @@ -802,8 +802,13 @@ impl Repo { Ok(Self::Sled(repo)) } - config::Repo::Postgres(config::Postgres { url }) => { - let repo = self::postgres::PostgresRepo::connect(url).await?; + config::Repo::Postgres(config::Postgres { + url, + use_tls, + certificate_file, + }) => { + let repo = + self::postgres::PostgresRepo::connect(url, use_tls, certificate_file).await?; Ok(Self::Postgres(repo)) } diff --git a/src/repo/postgres.rs b/src/repo/postgres.rs index fa5a5ab..7fd92dd 100644 --- a/src/repo/postgres.rs +++ b/src/repo/postgres.rs @@ -4,6 +4,7 @@ mod schema; use std::{ collections::{BTreeSet, VecDeque}, + path::PathBuf, sync::{ atomic::{AtomicU64, Ordering}, Arc, Weak, @@ -22,7 +23,8 @@ use diesel_async::{ }; use futures_core::Stream; use tokio::sync::Notify; -use tokio_postgres::{tls::NoTlsStream, AsyncMessage, Connection, NoTls, Notification, Socket}; +use tokio_postgres::{AsyncMessage, Connection, NoTls, Notification, Socket}; +use tokio_postgres_rustls::MakeRustlsConnect; use tracing::Instrument; use url::Url; use uuid::Uuid; @@ -82,6 +84,9 @@ pub(crate) enum ConnectPostgresError { #[error("Failed to connect to postgres for migrations")] ConnectForMigration(#[source] tokio_postgres::Error), + #[error("Failed to build TLS configuration")] + Tls(#[source] TlsError), + #[error("Failed to run migrations")] Migration(#[source] refinery::Error), @@ -116,6 +121,21 @@ pub(crate) enum PostgresError { DbTimeout, } +#[derive(Debug, thiserror::Error)] +pub(crate) enum TlsError { + #[error("Couldn't read configured certificate file")] + ReadCertificate(#[source] std::io::Error), + + #[error("Couldn't parse configured certificate file: {0:?}")] + ParseCertificate(rustls_pemfile::Error), + + #[error("Configured certificate file is not a certificate")] + NotCertificate, + + #[error("Couldn't add certificate to root store")] + AddCertificate(#[source] rustls::Error), +} + impl PostgresError { pub(super) const fn error_code(&self) -> ErrorCode { match self { @@ -146,13 +166,90 @@ impl PostgresError { } } -impl PostgresRepo { - pub(crate) async fn connect(postgres_url: Url) -> Result { - let (mut client, conn) = tokio_postgres::connect(postgres_url.as_str(), NoTls) +async fn build_tls_connector( + certificate_file: Option, +) -> Result { + let mut cert_store = rustls::RootCertStore { + roots: Vec::from(webpki_roots::TLS_SERVER_ROOTS), + }; + + if let Some(certificate_file) = certificate_file { + let bytes = tokio::fs::read(certificate_file) + .await + .map_err(TlsError::ReadCertificate)?; + + let opt = + rustls_pemfile::read_one_from_slice(&bytes).map_err(TlsError::ParseCertificate)?; + let (item, _remainder) = opt.ok_or(TlsError::NotCertificate)?; + + let cert = if let rustls_pemfile::Item::X509Certificate(cert) = item { + cert + } else { + return Err(TlsError::NotCertificate); + }; + + cert_store.add(cert).map_err(TlsError::AddCertificate)?; + } + + let config = rustls::ClientConfig::builder() + .with_root_certificates(cert_store) + .with_no_client_auth(); + + let tls = MakeRustlsConnect::new(config); + + Ok(tls) +} + +async fn connect_for_migrations( + postgres_url: &Url, + tls_connector: Option, +) -> Result< + ( + tokio_postgres::Client, + DropHandle>, + ), + ConnectPostgresError, +> { + let tup = if let Some(connector) = tls_connector { + let (client, conn) = tokio_postgres::connect(postgres_url.as_str(), connector) .await .map_err(ConnectPostgresError::ConnectForMigration)?; - let handle = crate::sync::abort_on_drop(crate::sync::spawn("postgres-migrations", conn)); + ( + client, + crate::sync::abort_on_drop(crate::sync::spawn("postgres-connection", conn)), + ) + } else { + let (client, conn) = tokio_postgres::connect(postgres_url.as_str(), NoTls) + .await + .map_err(ConnectPostgresError::ConnectForMigration)?; + + ( + client, + crate::sync::abort_on_drop(crate::sync::spawn("postgres-connection", conn)), + ) + }; + + Ok(tup) +} + +impl PostgresRepo { + pub(crate) async fn connect( + postgres_url: Url, + use_tls: bool, + certificate_file: Option, + ) -> Result { + let connector = if use_tls { + Some( + build_tls_connector(certificate_file) + .await + .map_err(ConnectPostgresError::Tls)?, + ) + } else { + None + }; + + let (mut client, handle) = connect_for_migrations(&postgres_url, connector.clone()).await?; embedded::migrations::runner() .run_async(&mut client) @@ -169,7 +266,7 @@ impl PostgresRepo { let (tx, rx) = flume::bounded(10); let mut config = ManagerConfig::default(); - config.custom_setup = build_handler(tx); + config.custom_setup = build_handler(tx, connector); let mgr = AsyncDieselConnectionManager::::new_with_config( postgres_url, @@ -388,22 +485,39 @@ async fn delegate_notifications( tracing::warn!("Notification delegator shutting down"); } -fn build_handler(sender: flume::Sender) -> ConfigFn { +fn build_handler( + sender: flume::Sender, + connector: Option, +) -> ConfigFn { Box::new( move |config: &str| -> BoxFuture<'_, ConnectionResult> { let sender = sender.clone(); + let connector = connector.clone(); let connect_span = tracing::trace_span!(parent: None, "connect future"); Box::pin( async move { - let (client, conn) = - tokio_postgres::connect(config, tokio_postgres::tls::NoTls) + let client = if let Some(connector) = connector { + let (client, conn) = tokio_postgres::connect(config, connector) .await .map_err(|e| ConnectionError::BadConnection(e.to_string()))?; - // not very cash money (structured concurrency) of me - spawn_db_notification_task(sender, conn); + // not very cash money (structured concurrency) of me + spawn_db_notification_task(sender, conn); + + client + } else { + let (client, conn) = + tokio_postgres::connect(config, tokio_postgres::tls::NoTls) + .await + .map_err(|e| ConnectionError::BadConnection(e.to_string()))?; + + // not very cash money (structured concurrency) of me + spawn_db_notification_task(sender, conn); + + client + }; AsyncPgConnection::try_from(client).await } @@ -413,10 +527,12 @@ fn build_handler(sender: flume::Sender) -> ConfigFn { ) } -fn spawn_db_notification_task( +fn spawn_db_notification_task( sender: flume::Sender, - mut conn: Connection, -) { + mut conn: Connection, +) where + S: tokio_postgres::tls::TlsStream + Unpin + 'static, +{ crate::sync::spawn("postgres-notifications", async move { while let Some(res) = std::future::poll_fn(|cx| conn.poll_message(cx)).await { tracing::trace!("db_notification_task: looping");