From 7b5db89dc8e0a2c25cc55e8af80f4072a8eba3dd Mon Sep 17 00:00:00 2001 From: asonix Date: Sun, 28 Jan 2024 12:27:35 -0600 Subject: [PATCH] Remove 'Inner' type to avoid double Arc-ing --- examples/demo.rs | 6 +++--- src/lib.rs | 50 ++++++++++++++++++++++++++---------------------- 2 files changed, 30 insertions(+), 26 deletions(-) diff --git a/examples/demo.rs b/examples/demo.rs index 0bef87c..ead4f91 100644 --- a/examples/demo.rs +++ b/examples/demo.rs @@ -1,4 +1,4 @@ -use std::{sync::Arc, time::Duration}; +use std::time::Duration; use actix_web::{web, App, HttpServer}; @@ -8,7 +8,7 @@ async fn index() -> &'static str { #[actix_web::main] async fn main() -> Result<(), Box> { - let initial_key = read_key().await?; + let initial_key = read_key().await?.unwrap(); let (tx, rx) = rustls_resolver::channel::<32>(initial_key); @@ -31,7 +31,7 @@ async fn main() -> Result<(), Box> { let server_config = rustls::ServerConfig::builder() .with_safe_defaults() .with_no_client_auth() - .with_cert_resolver(Arc::new(rx)); + .with_cert_resolver(rx); HttpServer::new(|| App::new().route("/", web::get().to(index))) .bind_rustls_021("0.0.0.0:8443", server_config)? diff --git a/src/lib.rs b/src/lib.rs index 893b474..6589ed4 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -9,27 +9,33 @@ use std::{ use rand::{seq::SliceRandom, thread_rng}; use rustls::sign::CertifiedKey; +/// Create a new resolver channel. The Sender can update the key, and the Receiver can be +/// registered with a rustls server +/// +/// the SHARDS const generic controlls the size of the channel. A larger channel will reduce +/// contention, leading to faster reads but more writes per call to `update`. pub fn channel( initial: CertifiedKey, -) -> (ChannelSender, ChannelResolver) { - let inner = Arc::new(Inner::::new(initial)); +) -> (ChannelSender, Arc>) { + let resolver = Arc::new(ChannelResolver::::new(initial)); ( ChannelSender { - inner: Arc::clone(&inner), + inner: Arc::clone(&resolver), }, - ChannelResolver { inner }, + resolver, ) } +/// The Send half of the channel. This is used for updating the server with new keys #[derive(Clone)] pub struct ChannelSender { - inner: Arc>, + inner: Arc>, } -#[derive(Clone)] +// The Receive half of the channel. This is registerd with rustls to provide the server with keys pub struct ChannelResolver { - inner: Arc>, + locks: [Shard; SHARDS], } struct Shard { @@ -51,6 +57,8 @@ impl Shard { fn update(&self, key: CertifiedKey) { let mut guard = self.lock.write().unwrap(); + // update generation while lock is held - ensures any "after" accesses to generation queue + // at the read-lock for a new key self.generation.fetch_add(1, Ordering::Release); *guard = Arc::new(key); } @@ -61,6 +69,7 @@ impl Shard { let key = LOCAL_KEY.with(|local_key| { local_key.get().and_then(|refcell| { let borrowed = refcell.borrow(); + // if TLS generation is the same, we can safely return TLS key if borrowed.0 == generation { Some(Arc::clone(&borrowed.1)) } else { @@ -72,6 +81,7 @@ impl Shard { if let Some(key) = key { key } else { + // slow path, take a read lock and update TLS with new key let key = Arc::clone(&self.lock.read().unwrap()); LOCAL_KEY.with(|local_key| { @@ -86,11 +96,14 @@ impl Shard { } } -struct Inner { - locks: [Shard; SHARDS], +impl ChannelSender { + /// Update the key in the channel + pub fn update(&self, key: CertifiedKey) { + self.inner.update(key); + } } -impl Inner { +impl ChannelResolver { fn new(key: CertifiedKey) -> Self { Self { locks: [(); SHARDS].map(|()| Shard::new(key.clone())), @@ -103,20 +116,11 @@ impl Inner { } } - fn read(&self) -> Arc { - self.locks.choose(&mut thread_rng()).unwrap().read() - } -} - -impl ChannelSender { - pub fn update(&self, key: CertifiedKey) { - self.inner.update(key); - } -} - -impl ChannelResolver { + // exposed for benching + #[doc(hidden)] pub fn read(&self) -> Arc { - self.inner.read() + // choose random shard to reduce contention. unwrap since slice is always non-empty + self.locks.choose(&mut thread_rng()).unwrap().read() } }