rustls-resolver/src/lib.rs

128 lines
3.1 KiB
Rust

use std::{
cell::{OnceCell, RefCell},
sync::{
atomic::{AtomicU64, Ordering},
Arc, RwLock,
},
};
use rand::{seq::SliceRandom, thread_rng};
use rustls::sign::CertifiedKey;
pub fn channel<const SHARDS: usize>(
initial: CertifiedKey,
) -> (ChannelSender<SHARDS>, ChannelResolver<SHARDS>) {
let inner = Arc::new(Inner::<SHARDS>::new(initial));
(
ChannelSender {
inner: Arc::clone(&inner),
},
ChannelResolver { inner },
)
}
/// Write half of a [`channel`]: publishes new certificates to the paired
/// [`ChannelResolver`]s.
///
/// Cloning only clones the inner `Arc`; every clone updates the same shared state.
#[derive(Clone)]
pub struct ChannelSender<const SHARDS: usize> {
inner: Arc<Inner<SHARDS>>,
}
/// Read half of a [`channel`]: hands out the currently published certificate
/// and implements `rustls::server::ResolvesServerCert`.
///
/// Cloning only clones the inner `Arc`; every clone reads the same shared state.
#[derive(Clone)]
pub struct ChannelResolver<const SHARDS: usize> {
inner: Arc<Inner<SHARDS>>,
}
/// One independently locked copy of the current certificate.
///
/// `generation` mirrors the generation tag of the key stored in `lock`. It is
/// only written while the write lock is held. Readers use it to validate their
/// thread-local cache without touching the lock.
struct Shard {
    generation: AtomicU64,
    lock: RwLock<Arc<CertifiedKey>>,
}

/// Source of process-wide unique generation numbers.
///
/// Every key installed into any shard of any channel is tagged with a fresh
/// value from this counter, so a generation number identifies exactly one key.
/// That is what makes it sound for all shards of all channels on a thread to
/// share the single `LOCAL_KEY` slot. With per-shard counters starting at 0,
/// two channels would alias each other's cached certificates.
static NEXT_GENERATION: AtomicU64 = AtomicU64::new(0);

/// Returns a generation number never handed out before in this process.
fn next_generation() -> u64 {
    // Relaxed is enough: an atomic RMW alone guarantees each caller gets a
    // distinct, increasing value; no other memory is published through it.
    NEXT_GENERATION.fetch_add(1, Ordering::Relaxed)
}

thread_local! {
    /// Per-thread cache of the most recently read `(generation, key)` pair.
    ///
    /// Shared by every shard of every channel on this thread. This is correct
    /// only because generations are globally unique (see `NEXT_GENERATION`).
    /// The cached `Arc` keeps that key alive until this thread's next cache
    /// miss, or until the thread exits.
    static LOCAL_KEY: OnceCell<RefCell<(u64, Arc<CertifiedKey>)>> = OnceCell::new();
}

impl Shard {
    /// Creates a shard holding `key`, tagged with `generation`.
    fn new(generation: u64, key: CertifiedKey) -> Self {
        Self {
            generation: AtomicU64::new(generation),
            lock: RwLock::new(Arc::new(key)),
        }
    }

    /// Installs `key` under `generation`, unless the shard already holds a key
    /// from a newer update.
    ///
    /// The newer-wins check makes all shards converge on the same key when
    /// several `ChannelSender` clones update concurrently: the update that drew
    /// the highest generation wins everywhere.
    fn update(&self, generation: u64, key: CertifiedKey) {
        let previous = {
            // The stored value is a plain `Arc`, which is always in a valid
            // state, so a poisoned lock is safe to recover from.
            let mut guard = self
                .lock
                .write()
                .unwrap_or_else(std::sync::PoisonError::into_inner);
            // `generation` is only written under the write lock, which we hold.
            if generation <= self.generation.load(Ordering::Relaxed) {
                return;
            }
            let previous = std::mem::replace(&mut *guard, Arc::new(key));
            self.generation.store(generation, Ordering::Release);
            previous
        };
        // Release the old key only after the write lock has been dropped.
        drop(previous);
    }

    /// Returns the shard's current key.
    ///
    /// If this thread's cache already holds the key for the shard's current
    /// generation, that cached key is returned without locking.
    fn read(&self) -> Arc<CertifiedKey> {
        let current = self.generation.load(Ordering::Acquire);
        // `try_with` rather than `with`: while thread-locals are being torn
        // down the cache is unavailable, and we fall back to the lock.
        let cached = LOCAL_KEY
            .try_with(|slot| {
                slot.get().and_then(|cell| {
                    let entry = cell.borrow();
                    (entry.0 == current).then(|| Arc::clone(&entry.1))
                })
            })
            .ok()
            .flatten();
        if let Some(key) = cached {
            return key;
        }

        // Take the key and its generation together under the read lock. The
        // cache entry is then tagged with the generation that key really
        // belongs to, not one observed before a concurrent update.
        let (generation, key) = {
            let guard = self
                .lock
                .read()
                .unwrap_or_else(std::sync::PoisonError::into_inner);
            // Stable while the read lock is held; the lock provides ordering.
            (self.generation.load(Ordering::Relaxed), Arc::clone(&guard))
        };
        // Caching is best-effort; an unavailable thread-local is ignored.
        let _ = LOCAL_KEY.try_with(|slot| {
            let cell = slot.get_or_init(|| RefCell::new((generation, Arc::clone(&key))));
            let mut entry = cell.borrow_mut();
            if entry.0 != generation {
                *entry = (generation, Arc::clone(&key));
            }
        });
        key
    }
}

/// State shared by a sender/resolver pair: `SHARDS` copies of the current key,
/// each behind its own lock.
///
/// Readers pick a shard at random (see `read`), which spreads lock traffic
/// across `SHARDS` locks.
struct Inner<const SHARDS: usize> {
    locks: [Shard; SHARDS],
}

impl<const SHARDS: usize> Inner<SHARDS> {
    /// Creates `SHARDS` shards, each holding its own clone of `key`, all tagged
    /// with one fresh generation.
    fn new(key: CertifiedKey) -> Self {
        let generation = next_generation();
        Self {
            locks: [(); SHARDS].map(|()| Shard::new(generation, key.clone())),
        }
    }

    /// Installs a clone of `key` in every shard under one freshly allocated
    /// generation.
    ///
    /// Shards are updated one after another, so concurrent reads may still see
    /// the previous key until the loop finishes.
    fn update(&self, key: CertifiedKey) {
        let generation = next_generation();
        for shard in &self.locks {
            shard.update(generation, key.clone());
        }
    }

    /// Reads the current key from a randomly chosen shard.
    ///
    /// # Panics
    ///
    /// Panics if `SHARDS` is zero.
    fn read(&self) -> Arc<CertifiedKey> {
        self.locks
            .choose(&mut thread_rng())
            .expect("ChannelResolver needs at least one shard (SHARDS > 0)")
            .read()
    }
}
impl<const SHARDS: usize> ChannelSender<SHARDS> {
pub fn update(&self, key: CertifiedKey) {
self.inner.update(key);
}
}
impl<const SHARDS: usize> ChannelResolver<SHARDS> {
pub fn read(&self) -> Arc<CertifiedKey> {
self.inner.read()
}
}
impl<const SHARDS: usize> rustls::server::ResolvesServerCert for ChannelResolver<SHARDS> {
fn resolve(&self, _: rustls::server::ClientHello) -> Option<Arc<CertifiedKey>> {
Some(self.read())
}
}