postgres: allow connecting to TLS-enabled databases

This commit is contained in:
asonix 2024-01-15 18:11:08 -05:00
parent f6087d65be
commit 19147e2035
7 changed files with 184 additions and 19 deletions

28
Cargo.lock generated
View file

@ -1838,6 +1838,8 @@ dependencies = [
"reqwest",
"reqwest-middleware",
"reqwest-tracing",
"rustls 0.22.2",
"rustls-pemfile 2.0.0",
"rusty-s3",
"serde",
"serde-tuple-vec-map",
@ -1864,6 +1866,7 @@ dependencies = [
"tracing-subscriber",
"url",
"uuid",
"webpki-roots 0.26.0",
]
[[package]]
@ -2233,7 +2236,7 @@ dependencies = [
"percent-encoding",
"pin-project-lite",
"rustls 0.21.10",
"rustls-pemfile",
"rustls-pemfile 1.0.4",
"serde",
"serde_json",
"serde_urlencoded",
@ -2247,7 +2250,7 @@ dependencies = [
"wasm-bindgen-futures",
"wasm-streams",
"web-sys",
"webpki-roots",
"webpki-roots 0.25.3",
"winreg",
]
@ -2370,6 +2373,8 @@ version = "0.22.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "e87c9956bd9807afa1f77e0f7594af32566e830e088a5576d27c5b6f30f49d41"
dependencies = [
"log",
"ring 0.17.7",
"rustls-pki-types",
"rustls-webpki 0.102.1",
"subtle",
@ -2385,6 +2390,16 @@ dependencies = [
"base64 0.21.7",
]
[[package]]
name = "rustls-pemfile"
version = "2.0.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "35e4980fa29e4c4b212ffb3db068a564cbf560e51d3944b7c88bd8bf5bec64f4"
dependencies = [
"base64 0.21.7",
"rustls-pki-types",
]
[[package]]
name = "rustls-pki-types"
version = "1.1.0"
@ -3522,6 +3537,15 @@ version = "0.25.3"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "1778a42e8b3b90bff8d0f5032bf22250792889a5cdc752aa0020c84abe3aaf10"
[[package]]
name = "webpki-roots"
version = "0.26.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "0de2cfda980f21be5a7ed2eadb3e6fe074d56022bea2cdeb1a62eb220fc04188"
dependencies = [
"rustls-pki-types",
]
[[package]]
name = "whoami"
version = "1.4.1"

View file

@ -46,6 +46,10 @@ refinery = { version = "0.8.10", features = ["tokio-postgres", "postgres"] }
reqwest = { version = "0.11.18", default-features = false, features = ["json", "rustls-tls", "stream"] }
reqwest-middleware = "0.2.2"
reqwest-tracing = { version = "0.4.5" }
# pinned to tokio-postgres-rustls
rustls = "0.22.0"
# pinned to rustls
rustls-pemfile = "2.0.0"
rusty-s3 = "0.5.0"
serde = { version = "1.0", features = ["derive"] }
serde-tuple-vec-map = "1.0.1"
@ -81,6 +85,8 @@ tracing-subscriber = { version = "0.3.0", features = [
] }
url = { version = "2.2", features = ["serde"] }
uuid = { version = "1", features = ["serde", "std", "v4", "v7"] }
# pinned to rustls
webpki-roots = "0.26.0"
[dependencies.tracing-actix-web]
version = "0.7.8"

View file

@ -1415,6 +1415,14 @@ pub(super) struct Postgres {
/// The URL of the postgres database
#[arg(short, long)]
pub(super) url: Url,
/// whether to connect to postgres via TLS
#[arg(short, long)]
pub(super) use_tls: bool,
/// The path to the root certificate for postgres' CA
#[arg(short, long)]
pub(super) certificate_file: Option<PathBuf>,
}
#[derive(Debug, Parser, serde::Serialize)]

View file

@ -389,7 +389,11 @@ impl From<crate::config::commandline::Sled> for crate::config::file::Sled {
impl From<crate::config::commandline::Postgres> for crate::config::file::Postgres {
fn from(value: crate::config::commandline::Postgres) -> Self {
crate::config::file::Postgres { url: value.url }
crate::config::file::Postgres {
url: value.url,
use_tls: value.use_tls,
certificate_file: value.certificate_file,
}
}
}

View file

@ -458,4 +458,6 @@ pub(crate) struct Sled {
#[serde(rename_all = "snake_case")]
pub(crate) struct Postgres {
pub(crate) url: Url,
pub(crate) use_tls: bool,
pub(crate) certificate_file: Option<PathBuf>,
}

View file

@ -802,8 +802,13 @@ impl Repo {
Ok(Self::Sled(repo))
}
config::Repo::Postgres(config::Postgres { url }) => {
let repo = self::postgres::PostgresRepo::connect(url).await?;
config::Repo::Postgres(config::Postgres {
url,
use_tls,
certificate_file,
}) => {
let repo =
self::postgres::PostgresRepo::connect(url, use_tls, certificate_file).await?;
Ok(Self::Postgres(repo))
}

View file

@ -4,6 +4,7 @@ mod schema;
use std::{
collections::{BTreeSet, VecDeque},
path::PathBuf,
sync::{
atomic::{AtomicU64, Ordering},
Arc, Weak,
@ -22,7 +23,8 @@ use diesel_async::{
};
use futures_core::Stream;
use tokio::sync::Notify;
use tokio_postgres::{tls::NoTlsStream, AsyncMessage, Connection, NoTls, Notification, Socket};
use tokio_postgres::{AsyncMessage, Connection, NoTls, Notification, Socket};
use tokio_postgres_rustls::MakeRustlsConnect;
use tracing::Instrument;
use url::Url;
use uuid::Uuid;
@ -82,6 +84,9 @@ pub(crate) enum ConnectPostgresError {
#[error("Failed to connect to postgres for migrations")]
ConnectForMigration(#[source] tokio_postgres::Error),
#[error("Failed to build TLS configuration")]
Tls(#[source] TlsError),
#[error("Failed to run migrations")]
Migration(#[source] refinery::Error),
@ -116,6 +121,21 @@ pub(crate) enum PostgresError {
DbTimeout,
}
#[derive(Debug, thiserror::Error)]
pub(crate) enum TlsError {
#[error("Couldn't read configured certificate file")]
ReadCertificate(#[source] std::io::Error),
#[error("Couldn't parse configured certificate file: {0:?}")]
ParseCertificate(rustls_pemfile::Error),
#[error("Configured certificate file is not a certificate")]
NotCertificate,
#[error("Couldn't add certificate to root store")]
AddCertificate(#[source] rustls::Error),
}
impl PostgresError {
pub(super) const fn error_code(&self) -> ErrorCode {
match self {
@ -146,13 +166,90 @@ impl PostgresError {
}
}
impl PostgresRepo {
pub(crate) async fn connect(postgres_url: Url) -> Result<Self, ConnectPostgresError> {
let (mut client, conn) = tokio_postgres::connect(postgres_url.as_str(), NoTls)
async fn build_tls_connector(
certificate_file: Option<PathBuf>,
) -> Result<MakeRustlsConnect, TlsError> {
let mut cert_store = rustls::RootCertStore {
roots: Vec::from(webpki_roots::TLS_SERVER_ROOTS),
};
if let Some(certificate_file) = certificate_file {
let bytes = tokio::fs::read(certificate_file)
.await
.map_err(TlsError::ReadCertificate)?;
let opt =
rustls_pemfile::read_one_from_slice(&bytes).map_err(TlsError::ParseCertificate)?;
let (item, _remainder) = opt.ok_or(TlsError::NotCertificate)?;
let cert = if let rustls_pemfile::Item::X509Certificate(cert) = item {
cert
} else {
return Err(TlsError::NotCertificate);
};
cert_store.add(cert).map_err(TlsError::AddCertificate)?;
}
let config = rustls::ClientConfig::builder()
.with_root_certificates(cert_store)
.with_no_client_auth();
let tls = MakeRustlsConnect::new(config);
Ok(tls)
}
async fn connect_for_migrations(
postgres_url: &Url,
tls_connector: Option<MakeRustlsConnect>,
) -> Result<
(
tokio_postgres::Client,
DropHandle<Result<(), tokio_postgres::Error>>,
),
ConnectPostgresError,
> {
let tup = if let Some(connector) = tls_connector {
let (client, conn) = tokio_postgres::connect(postgres_url.as_str(), connector)
.await
.map_err(ConnectPostgresError::ConnectForMigration)?;
let handle = crate::sync::abort_on_drop(crate::sync::spawn("postgres-migrations", conn));
(
client,
crate::sync::abort_on_drop(crate::sync::spawn("postgres-connection", conn)),
)
} else {
let (client, conn) = tokio_postgres::connect(postgres_url.as_str(), NoTls)
.await
.map_err(ConnectPostgresError::ConnectForMigration)?;
(
client,
crate::sync::abort_on_drop(crate::sync::spawn("postgres-connection", conn)),
)
};
Ok(tup)
}
impl PostgresRepo {
pub(crate) async fn connect(
postgres_url: Url,
use_tls: bool,
certificate_file: Option<PathBuf>,
) -> Result<Self, ConnectPostgresError> {
let connector = if use_tls {
Some(
build_tls_connector(certificate_file)
.await
.map_err(ConnectPostgresError::Tls)?,
)
} else {
None
};
let (mut client, handle) = connect_for_migrations(&postgres_url, connector.clone()).await?;
embedded::migrations::runner()
.run_async(&mut client)
@ -169,7 +266,7 @@ impl PostgresRepo {
let (tx, rx) = flume::bounded(10);
let mut config = ManagerConfig::default();
config.custom_setup = build_handler(tx);
config.custom_setup = build_handler(tx, connector);
let mgr = AsyncDieselConnectionManager::<AsyncPgConnection>::new_with_config(
postgres_url,
@ -388,22 +485,39 @@ async fn delegate_notifications(
tracing::warn!("Notification delegator shutting down");
}
fn build_handler(sender: flume::Sender<Notification>) -> ConfigFn {
fn build_handler(
sender: flume::Sender<Notification>,
connector: Option<MakeRustlsConnect>,
) -> ConfigFn {
Box::new(
move |config: &str| -> BoxFuture<'_, ConnectionResult<AsyncPgConnection>> {
let sender = sender.clone();
let connector = connector.clone();
let connect_span = tracing::trace_span!(parent: None, "connect future");
Box::pin(
async move {
let (client, conn) =
tokio_postgres::connect(config, tokio_postgres::tls::NoTls)
let client = if let Some(connector) = connector {
let (client, conn) = tokio_postgres::connect(config, connector)
.await
.map_err(|e| ConnectionError::BadConnection(e.to_string()))?;
// not very cash money (structured concurrency) of me
spawn_db_notification_task(sender, conn);
// not very cash money (structured concurrency) of me
spawn_db_notification_task(sender, conn);
client
} else {
let (client, conn) =
tokio_postgres::connect(config, tokio_postgres::tls::NoTls)
.await
.map_err(|e| ConnectionError::BadConnection(e.to_string()))?;
// not very cash money (structured concurrency) of me
spawn_db_notification_task(sender, conn);
client
};
AsyncPgConnection::try_from(client).await
}
@ -413,10 +527,12 @@ fn build_handler(sender: flume::Sender<Notification>) -> ConfigFn {
)
}
fn spawn_db_notification_task(
fn spawn_db_notification_task<S>(
sender: flume::Sender<Notification>,
mut conn: Connection<Socket, NoTlsStream>,
) {
mut conn: Connection<Socket, S>,
) where
S: tokio_postgres::tls::TlsStream + Unpin + 'static,
{
crate::sync::spawn("postgres-notifications", async move {
while let Some(res) = std::future::poll_fn(|cx| conn.poll_message(cx)).await {
tracing::trace!("db_notification_task: looping");