Remove 'Inner' type to avoid double Arc-ing

This commit is contained in:
asonix 2024-01-28 12:27:35 -06:00
parent 0697ad93a0
commit 7b5db89dc8
2 changed files with 30 additions and 26 deletions

View file

@ -1,4 +1,4 @@
use std::{sync::Arc, time::Duration};
use std::time::Duration;
use actix_web::{web, App, HttpServer};
@ -8,7 +8,7 @@ async fn index() -> &'static str {
#[actix_web::main]
async fn main() -> Result<(), Box<dyn std::error::Error>> {
let initial_key = read_key().await?;
let initial_key = read_key().await?.unwrap();
let (tx, rx) = rustls_resolver::channel::<32>(initial_key);
@ -31,7 +31,7 @@ async fn main() -> Result<(), Box<dyn std::error::Error>> {
let server_config = rustls::ServerConfig::builder()
.with_safe_defaults()
.with_no_client_auth()
.with_cert_resolver(Arc::new(rx));
.with_cert_resolver(rx);
HttpServer::new(|| App::new().route("/", web::get().to(index)))
.bind_rustls_021("0.0.0.0:8443", server_config)?

View file

@ -9,27 +9,33 @@ use std::{
use rand::{seq::SliceRandom, thread_rng};
use rustls::sign::CertifiedKey;
/// Create a new resolver channel. The Sender can update the key, and the Receiver can be
/// registered with a rustls server
///
/// the SHARDS const generic controlls the size of the channel. A larger channel will reduce
/// contention, leading to faster reads but more writes per call to `update`.
pub fn channel<const SHARDS: usize>(
initial: CertifiedKey,
) -> (ChannelSender<SHARDS>, ChannelResolver<SHARDS>) {
let inner = Arc::new(Inner::<SHARDS>::new(initial));
) -> (ChannelSender<SHARDS>, Arc<ChannelResolver<SHARDS>>) {
let resolver = Arc::new(ChannelResolver::<SHARDS>::new(initial));
(
ChannelSender {
inner: Arc::clone(&inner),
inner: Arc::clone(&resolver),
},
ChannelResolver { inner },
resolver,
)
}
/// The Send half of the channel. This is used for updating the server with new keys
#[derive(Clone)]
pub struct ChannelSender<const SHARDS: usize> {
inner: Arc<Inner<SHARDS>>,
inner: Arc<ChannelResolver<SHARDS>>,
}
#[derive(Clone)]
// The Receive half of the channel. This is registerd with rustls to provide the server with keys
pub struct ChannelResolver<const SHARDS: usize> {
inner: Arc<Inner<SHARDS>>,
locks: [Shard; SHARDS],
}
struct Shard {
@ -51,6 +57,8 @@ impl Shard {
fn update(&self, key: CertifiedKey) {
let mut guard = self.lock.write().unwrap();
// update generation while lock is held - ensures any "after" accesses to generation queue
// at the read-lock for a new key
self.generation.fetch_add(1, Ordering::Release);
*guard = Arc::new(key);
}
@ -61,6 +69,7 @@ impl Shard {
let key = LOCAL_KEY.with(|local_key| {
local_key.get().and_then(|refcell| {
let borrowed = refcell.borrow();
// if TLS generation is the same, we can safely return TLS key
if borrowed.0 == generation {
Some(Arc::clone(&borrowed.1))
} else {
@ -72,6 +81,7 @@ impl Shard {
if let Some(key) = key {
key
} else {
// slow path, take a read lock and update TLS with new key
let key = Arc::clone(&self.lock.read().unwrap());
LOCAL_KEY.with(|local_key| {
@ -86,11 +96,14 @@ impl Shard {
}
}
struct Inner<const SHARDS: usize> {
locks: [Shard; SHARDS],
impl<const SHARDS: usize> ChannelSender<SHARDS> {
/// Update the key in the channel
pub fn update(&self, key: CertifiedKey) {
self.inner.update(key);
}
}
impl<const SHARDS: usize> Inner<SHARDS> {
impl<const SHARDS: usize> ChannelResolver<SHARDS> {
fn new(key: CertifiedKey) -> Self {
Self {
locks: [(); SHARDS].map(|()| Shard::new(key.clone())),
@ -103,20 +116,11 @@ impl<const SHARDS: usize> Inner<SHARDS> {
}
}
fn read(&self) -> Arc<CertifiedKey> {
self.locks.choose(&mut thread_rng()).unwrap().read()
}
}
impl<const SHARDS: usize> ChannelSender<SHARDS> {
pub fn update(&self, key: CertifiedKey) {
self.inner.update(key);
}
}
impl<const SHARDS: usize> ChannelResolver<SHARDS> {
// exposed for benching
#[doc(hidden)]
pub fn read(&self) -> Arc<CertifiedKey> {
self.inner.read()
// choose random shard to reduce contention. unwrap since slice is always non-empty
self.locks.choose(&mut thread_rng()).unwrap().read()
}
}