rustls-resolver/src/lib.rs
2024-02-03 21:39:30 -06:00

144 lines
4.3 KiB
Rust

use std::{
cell::{OnceCell, RefCell},
sync::{
atomic::{AtomicU64, Ordering},
Arc, RwLock,
},
};
use nanorand::Rng;
use rustls::sign::CertifiedKey;
/// Build a resolver channel seeded with `initial`.
///
/// Returns a pair: the [`ChannelSender`], which pushes replacement keys, and the
/// [`ChannelResolver`] (behind an `Arc`), which can be registered with a rustls server.
/// Both halves refer to the same shards.
///
/// The `SHARDS` const generic controls the size of the channel. A larger channel reduces
/// contention, leading to faster reads, at the cost of more writes per call to `update`.
pub fn channel<const SHARDS: usize>(
    initial: CertifiedKey,
) -> (ChannelSender, Arc<ChannelResolver<SHARDS>>) {
    let receiver = Arc::new(ChannelResolver::<SHARDS>::new(initial));

    // The sender keeps its own handle to the same allocation, viewed through the
    // unsized (slice-backed) alias so that `ChannelSender` needs no const parameter.
    let sender_handle = Arc::clone(&receiver);
    let inner: Arc<ErasedChannelResolver> = sender_handle;

    (ChannelSender { inner }, receiver)
}
/// The send half of the channel. This is used for updating the server with new keys.
///
/// Cloning a sender only clones the inner `Arc`, so every clone updates the same shards.
#[derive(Clone)]
pub struct ChannelSender {
// Size-erased view of the resolver's shards (a slice instead of a fixed-size array), which
// lets this type exist without a `SHARDS` const parameter.
inner: Arc<ErasedChannelResolver>,
}
// Private module holding the concrete types behind the public `ChannelResolver` alias. The
// types are `pub` so they can appear in public signatures, but the module itself is not
// exported and the fields are only visible to the parent module.
mod sealed {
use std::sync::{atomic::AtomicU64, Arc, RwLock};
use rustls::sign::CertifiedKey;
/// Container for the shards. `L` is either a fixed-size array of [`Shard`]s or, in the
/// size-erased form, a slice of them (hence the `?Sized` bound).
pub struct ChannelResolverInner<L: ?Sized> {
pub(super) locks: L,
}
/// One independently locked copy of the current key.
pub struct Shard {
// Counter bumped on every update of this shard; readers compare it against the
// generation stored in their thread-local cache to decide whether the lock can be skipped.
pub(super) generation: AtomicU64,
// The key currently held by this shard.
pub(super) lock: RwLock<Arc<CertifiedKey>>,
}
}
/// The receive half of the channel. This is registered with rustls to provide the server with
/// keys.
pub type ChannelResolver<const SHARDS: usize> =
sealed::ChannelResolverInner<[sealed::Shard; SHARDS]>;
// Size-erased form of `ChannelResolver`: the shard array viewed as a slice, so code that does
// not care about the shard count needs no const parameter.
type ErasedChannelResolver = sealed::ChannelResolverInner<[sealed::Shard]>;
// Per-thread cache of the most recently read key, stored together with the generation it was
// read at. The cell stays uninitialised until the thread performs its first read.
//
// NOTE(review): there is exactly one slot per thread and it is keyed by the generation number
// alone, so it is shared by every shard the thread reads from; a cache hit is decided purely
// by comparing generation numbers.
thread_local! {
static LOCAL_KEY: OnceCell<RefCell<(u64, Arc<CertifiedKey>)>> = OnceCell::new();
}
impl sealed::Shard {
    /// Creates a shard holding `key`, with its generation counter starting at 0.
    fn new(key: CertifiedKey) -> Self {
        Self {
            generation: AtomicU64::new(0),
            lock: RwLock::new(Arc::new(key)),
        }
    }

    /// Replaces the key held by this shard and then advances the generation by one.
    ///
    /// # Panics
    ///
    /// Panics if the shard's lock has been poisoned.
    fn update(&self, key: CertifiedKey) {
        // Allocate before taking the lock so the critical section is only a pointer swap.
        let fresh = Arc::new(key);

        // Move the previous key out instead of overwriting it in place: overwriting would run
        // the old key's destructor while readers are still blocked on the write lock.
        let previous = {
            let mut slot = self
                .lock
                .write()
                .expect("shard lock poisoned: a writer panicked while holding it");
            std::mem::replace(&mut *slot, fresh)
        };

        // update generation after lock is released. reduces lock contention with readers.
        // The key is written before the bump, so a reader that observes the new generation
        // and then takes the lock is guaranteed to see a key at least this new.
        self.generation.fetch_add(1, Ordering::AcqRel);

        // Only now release the old key, outside the lock and after readers were notified.
        drop(previous);
    }

    /// Returns the current key, preferring the thread-local cache over the lock.
    ///
    /// The cache is a single per-thread slot holding `(generation, key)`. If the cached
    /// generation equals this shard's current generation the cached key is returned without
    /// touching the lock; otherwise the key is read under the read lock and the cache is
    /// refreshed with the generation that was observed *before* locking.
    ///
    /// NOTE(review): the cache is matched on the generation number only. It cannot tell apart
    /// two shards that report the same generation, so callers must make sure shards holding
    /// different keys never share a generation value.
    ///
    /// # Panics
    ///
    /// Panics if the shard's lock has been poisoned.
    fn read(&self) -> Arc<CertifiedKey> {
        let generation = self.generation.load(Ordering::Acquire);

        // Fast path: the thread already holds a key for this generation.
        let cached = LOCAL_KEY.with(|local_key| {
            let cell = local_key.get()?;
            let entry = cell.borrow();
            if entry.0 == generation {
                Some(Arc::clone(&entry.1))
            } else {
                None
            }
        });
        if let Some(key) = cached {
            return key;
        }

        // slow path, take a read lock and update TLS with new key
        let key = Arc::clone(
            &self
                .lock
                .read()
                .expect("shard lock poisoned: a writer panicked while holding it"),
        );
        LOCAL_KEY.with(|local_key| {
            // First read on this thread initialises the slot; later reads overwrite it.
            let cell = local_key.get_or_init(|| RefCell::new((generation, Arc::clone(&key))));
            let mut entry = cell.borrow_mut();
            if entry.0 != generation {
                *entry = (generation, Arc::clone(&key));
            }
        });
        key
    }
}
impl ChannelSender {
    /// Update the key in the channel.
    ///
    /// Every shard receives its own clone of `key`, one shard after another in slice order.
    ///
    /// NOTE(review): nothing here serialises concurrent calls made through cloned senders, so
    /// two overlapping updates could interleave across shards — confirm callers update from a
    /// single place.
    pub fn update(&self, key: CertifiedKey) {
        self.inner
            .locks
            .iter()
            .for_each(|shard| shard.update(key.clone()));
    }
}
impl<const SHARDS: usize> ChannelResolver<SHARDS> {
fn new(key: CertifiedKey) -> Self {
Self {
locks: [(); SHARDS].map(|()| sealed::Shard::new(key.clone())),
}
}
// exposed for benching
#[doc(hidden)]
pub fn read(&self) -> Arc<CertifiedKey> {
// choose random shard to reduce contention. unwrap since slice is always non-empty
self.locks[nanorand::tls_rng().generate_range(0..SHARDS)].read()
}
}
impl<const SHARDS: usize> std::fmt::Debug for ChannelResolver<SHARDS> {
    /// Formats the resolver as `ChannelResolver { locks: .. }`, showing only the shard count
    /// (rendered as a string) rather than the keys themselves.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let shard_summary = format!("[Lock; {SHARDS}]");
        let mut builder = f.debug_struct("ChannelResolver");
        builder.field("locks", &shard_summary);
        builder.finish()
    }
}
impl<const SHARDS: usize> rustls::server::ResolvesServerCert for ChannelResolver<SHARDS> {
    /// Always resolves to the channel's current key; the client hello is not inspected.
    fn resolve(&self, _client_hello: rustls::server::ClientHello) -> Option<Arc<CertifiedKey>> {
        let current = self.read();
        Some(current)
    }
}