Remove 'Inner' type to avoid double Arc-ing

This commit is contained in:
asonix 2024-01-28 12:27:35 -06:00
parent 0697ad93a0
commit 7b5db89dc8
2 changed files with 30 additions and 26 deletions

View file

@ -1,4 +1,4 @@
use std::{sync::Arc, time::Duration}; use std::time::Duration;
use actix_web::{web, App, HttpServer}; use actix_web::{web, App, HttpServer};
@ -8,7 +8,7 @@ async fn index() -> &'static str {
#[actix_web::main] #[actix_web::main]
async fn main() -> Result<(), Box<dyn std::error::Error>> { async fn main() -> Result<(), Box<dyn std::error::Error>> {
let initial_key = read_key().await?; let initial_key = read_key().await?.unwrap();
let (tx, rx) = rustls_resolver::channel::<32>(initial_key); let (tx, rx) = rustls_resolver::channel::<32>(initial_key);
@ -31,7 +31,7 @@ async fn main() -> Result<(), Box<dyn std::error::Error>> {
let server_config = rustls::ServerConfig::builder() let server_config = rustls::ServerConfig::builder()
.with_safe_defaults() .with_safe_defaults()
.with_no_client_auth() .with_no_client_auth()
.with_cert_resolver(Arc::new(rx)); .with_cert_resolver(rx);
HttpServer::new(|| App::new().route("/", web::get().to(index))) HttpServer::new(|| App::new().route("/", web::get().to(index)))
.bind_rustls_021("0.0.0.0:8443", server_config)? .bind_rustls_021("0.0.0.0:8443", server_config)?

View file

@ -9,27 +9,33 @@ use std::{
use rand::{seq::SliceRandom, thread_rng}; use rand::{seq::SliceRandom, thread_rng};
use rustls::sign::CertifiedKey; use rustls::sign::CertifiedKey;
/// Create a new resolver channel. The Sender can update the key, and the Receiver can be
/// registered with a rustls server
///
/// the SHARDS const generic controlls the size of the channel. A larger channel will reduce
/// contention, leading to faster reads but more writes per call to `update`.
pub fn channel<const SHARDS: usize>( pub fn channel<const SHARDS: usize>(
initial: CertifiedKey, initial: CertifiedKey,
) -> (ChannelSender<SHARDS>, ChannelResolver<SHARDS>) { ) -> (ChannelSender<SHARDS>, Arc<ChannelResolver<SHARDS>>) {
let inner = Arc::new(Inner::<SHARDS>::new(initial)); let resolver = Arc::new(ChannelResolver::<SHARDS>::new(initial));
( (
ChannelSender { ChannelSender {
inner: Arc::clone(&inner), inner: Arc::clone(&resolver),
}, },
ChannelResolver { inner }, resolver,
) )
} }
/// The Send half of the channel. This is used for updating the server with new keys
#[derive(Clone)] #[derive(Clone)]
pub struct ChannelSender<const SHARDS: usize> { pub struct ChannelSender<const SHARDS: usize> {
inner: Arc<Inner<SHARDS>>, inner: Arc<ChannelResolver<SHARDS>>,
} }
#[derive(Clone)] // The Receive half of the channel. This is registerd with rustls to provide the server with keys
pub struct ChannelResolver<const SHARDS: usize> { pub struct ChannelResolver<const SHARDS: usize> {
inner: Arc<Inner<SHARDS>>, locks: [Shard; SHARDS],
} }
struct Shard { struct Shard {
@ -51,6 +57,8 @@ impl Shard {
fn update(&self, key: CertifiedKey) { fn update(&self, key: CertifiedKey) {
let mut guard = self.lock.write().unwrap(); let mut guard = self.lock.write().unwrap();
// update generation while lock is held - ensures any "after" accesses to generation queue
// at the read-lock for a new key
self.generation.fetch_add(1, Ordering::Release); self.generation.fetch_add(1, Ordering::Release);
*guard = Arc::new(key); *guard = Arc::new(key);
} }
@ -61,6 +69,7 @@ impl Shard {
let key = LOCAL_KEY.with(|local_key| { let key = LOCAL_KEY.with(|local_key| {
local_key.get().and_then(|refcell| { local_key.get().and_then(|refcell| {
let borrowed = refcell.borrow(); let borrowed = refcell.borrow();
// if TLS generation is the same, we can safely return TLS key
if borrowed.0 == generation { if borrowed.0 == generation {
Some(Arc::clone(&borrowed.1)) Some(Arc::clone(&borrowed.1))
} else { } else {
@ -72,6 +81,7 @@ impl Shard {
if let Some(key) = key { if let Some(key) = key {
key key
} else { } else {
// slow path, take a read lock and update TLS with new key
let key = Arc::clone(&self.lock.read().unwrap()); let key = Arc::clone(&self.lock.read().unwrap());
LOCAL_KEY.with(|local_key| { LOCAL_KEY.with(|local_key| {
@ -86,11 +96,14 @@ impl Shard {
} }
} }
struct Inner<const SHARDS: usize> { impl<const SHARDS: usize> ChannelSender<SHARDS> {
locks: [Shard; SHARDS], /// Update the key in the channel
pub fn update(&self, key: CertifiedKey) {
self.inner.update(key);
}
} }
impl<const SHARDS: usize> Inner<SHARDS> { impl<const SHARDS: usize> ChannelResolver<SHARDS> {
fn new(key: CertifiedKey) -> Self { fn new(key: CertifiedKey) -> Self {
Self { Self {
locks: [(); SHARDS].map(|()| Shard::new(key.clone())), locks: [(); SHARDS].map(|()| Shard::new(key.clone())),
@ -103,20 +116,11 @@ impl<const SHARDS: usize> Inner<SHARDS> {
} }
} }
fn read(&self) -> Arc<CertifiedKey> { // exposed for benching
self.locks.choose(&mut thread_rng()).unwrap().read() #[doc(hidden)]
}
}
impl<const SHARDS: usize> ChannelSender<SHARDS> {
pub fn update(&self, key: CertifiedKey) {
self.inner.update(key);
}
}
impl<const SHARDS: usize> ChannelResolver<SHARDS> {
pub fn read(&self) -> Arc<CertifiedKey> { pub fn read(&self) -> Arc<CertifiedKey> {
self.inner.read() // choose random shard to reduce contention. unwrap since slice is always non-empty
self.locks.choose(&mut thread_rng()).unwrap().read()
} }
} }