144 lines
4.3 KiB
Rust
144 lines
4.3 KiB
Rust
use std::{
|
|
cell::{OnceCell, RefCell},
|
|
sync::{
|
|
atomic::{AtomicU64, Ordering},
|
|
Arc, RwLock,
|
|
},
|
|
};
|
|
|
|
use nanorand::Rng;
|
|
use rustls::sign::CertifiedKey;
|
|
|
|
/// Create a new resolver channel. The Sender can update the key, and the Receiver can be
|
|
/// registered with a rustls server
|
|
///
|
|
/// the SHARDS const generic controlls the size of the channel. A larger channel will reduce
|
|
/// contention, leading to faster reads but more writes per call to `update`.
|
|
pub fn channel<const SHARDS: usize>(
|
|
initial: CertifiedKey,
|
|
) -> (ChannelSender, Arc<ChannelResolver<SHARDS>>) {
|
|
let resolver = Arc::new(ChannelResolver::<SHARDS>::new(initial));
|
|
|
|
let cloned_resolver = Arc::clone(&resolver);
|
|
let inner = cloned_resolver as Arc<ErasedChannelResolver>;
|
|
|
|
(ChannelSender { inner }, resolver)
|
|
}
|
|
|
|
/// The Send half of the channel. This is used for updating the server with new keys
|
|
#[derive(Clone)]
|
|
pub struct ChannelSender {
|
|
inner: Arc<ErasedChannelResolver>,
|
|
}
|
|
|
|
mod sealed {
|
|
use std::sync::{atomic::AtomicU64, Arc, RwLock};
|
|
|
|
use rustls::sign::CertifiedKey;
|
|
|
|
pub struct ChannelResolverInner<L: ?Sized> {
|
|
pub(super) locks: L,
|
|
}
|
|
|
|
pub struct Shard {
|
|
pub(super) generation: AtomicU64,
|
|
pub(super) lock: RwLock<Arc<CertifiedKey>>,
|
|
}
|
|
}
|
|
|
|
// The Receive half of the channel. This is registerd with rustls to provide the server with keys
|
|
pub type ChannelResolver<const SHARDS: usize> =
|
|
sealed::ChannelResolverInner<[sealed::Shard; SHARDS]>;
|
|
type ErasedChannelResolver = sealed::ChannelResolverInner<[sealed::Shard]>;
|
|
|
|
thread_local! {
|
|
static LOCAL_KEY: OnceCell<RefCell<(u64, Arc<CertifiedKey>)>> = OnceCell::new();
|
|
}
|
|
|
|
impl sealed::Shard {
|
|
fn new(key: CertifiedKey) -> Self {
|
|
Self {
|
|
generation: AtomicU64::new(0),
|
|
lock: RwLock::new(Arc::new(key)),
|
|
}
|
|
}
|
|
|
|
fn update(&self, key: CertifiedKey) {
|
|
{
|
|
*self.lock.write().unwrap() = Arc::new(key);
|
|
}
|
|
// update generation after lock is released. reduces lock contention with readers
|
|
self.generation.fetch_add(1, Ordering::AcqRel);
|
|
}
|
|
|
|
fn read(&self) -> Arc<CertifiedKey> {
|
|
let generation = self.generation.load(Ordering::Acquire);
|
|
|
|
let key = LOCAL_KEY.with(|local_key| {
|
|
local_key.get().and_then(|refcell| {
|
|
let borrowed = refcell.borrow();
|
|
// if TLS generation is the same, we can safely return TLS key
|
|
if borrowed.0 == generation {
|
|
Some(Arc::clone(&borrowed.1))
|
|
} else {
|
|
None
|
|
}
|
|
})
|
|
});
|
|
|
|
if let Some(key) = key {
|
|
key
|
|
} else {
|
|
// slow path, take a read lock and update TLS with new key
|
|
let key = Arc::clone(&self.lock.read().unwrap());
|
|
|
|
LOCAL_KEY.with(|local_key| {
|
|
let guard = local_key.get_or_init(|| RefCell::new((generation, Arc::clone(&key))));
|
|
if guard.borrow().0 != generation {
|
|
*guard.borrow_mut() = (generation, Arc::clone(&key))
|
|
}
|
|
});
|
|
|
|
key
|
|
}
|
|
}
|
|
}
|
|
|
|
impl ChannelSender {
|
|
/// Update the key in the channel
|
|
pub fn update(&self, key: CertifiedKey) {
|
|
for lock in &self.inner.locks {
|
|
lock.update(key.clone());
|
|
}
|
|
}
|
|
}
|
|
|
|
impl<const SHARDS: usize> ChannelResolver<SHARDS> {
|
|
fn new(key: CertifiedKey) -> Self {
|
|
Self {
|
|
locks: [(); SHARDS].map(|()| sealed::Shard::new(key.clone())),
|
|
}
|
|
}
|
|
|
|
// exposed for benching
|
|
#[doc(hidden)]
|
|
pub fn read(&self) -> Arc<CertifiedKey> {
|
|
// choose random shard to reduce contention. unwrap since slice is always non-empty
|
|
self.locks[nanorand::tls_rng().generate_range(0..SHARDS)].read()
|
|
}
|
|
}
|
|
|
|
impl<const SHARDS: usize> std::fmt::Debug for ChannelResolver<SHARDS> {
|
|
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
|
f.debug_struct("ChannelResolver")
|
|
.field("locks", &format!("[Lock; {SHARDS}]"))
|
|
.finish()
|
|
}
|
|
}
|
|
|
|
impl<const SHARDS: usize> rustls::server::ResolvesServerCert for ChannelResolver<SHARDS> {
|
|
fn resolve(&self, _: rustls::server::ClientHello) -> Option<Arc<CertifiedKey>> {
|
|
Some(self.read())
|
|
}
|
|
}
|