128 lines
3.1 KiB
Rust
128 lines
3.1 KiB
Rust
use std::{
|
|
cell::{OnceCell, RefCell},
|
|
sync::{
|
|
atomic::{AtomicU64, Ordering},
|
|
Arc, RwLock,
|
|
},
|
|
};
|
|
|
|
use rand::{seq::SliceRandom, thread_rng};
|
|
use rustls::sign::CertifiedKey;
|
|
|
|
pub fn channel<const SHARDS: usize>(
|
|
initial: CertifiedKey,
|
|
) -> (ChannelSender<SHARDS>, ChannelResolver<SHARDS>) {
|
|
let inner = Arc::new(Inner::<SHARDS>::new(initial));
|
|
|
|
(
|
|
ChannelSender {
|
|
inner: Arc::clone(&inner),
|
|
},
|
|
ChannelResolver { inner },
|
|
)
|
|
}
|
|
|
|
#[derive(Clone)]
|
|
pub struct ChannelSender<const SHARDS: usize> {
|
|
inner: Arc<Inner<SHARDS>>,
|
|
}
|
|
|
|
#[derive(Clone)]
|
|
pub struct ChannelResolver<const SHARDS: usize> {
|
|
inner: Arc<Inner<SHARDS>>,
|
|
}
|
|
|
|
struct Shard {
|
|
generation: AtomicU64,
|
|
lock: RwLock<Arc<CertifiedKey>>,
|
|
}
|
|
|
|
thread_local! {
|
|
static LOCAL_KEY: OnceCell<RefCell<(u64, Arc<CertifiedKey>)>> = OnceCell::new();
|
|
}
|
|
|
|
impl Shard {
|
|
fn new(key: CertifiedKey) -> Self {
|
|
Self {
|
|
generation: AtomicU64::new(0),
|
|
lock: RwLock::new(Arc::new(key)),
|
|
}
|
|
}
|
|
|
|
fn update(&self, key: CertifiedKey) {
|
|
let mut guard = self.lock.write().unwrap();
|
|
self.generation.fetch_add(1, Ordering::Release);
|
|
*guard = Arc::new(key);
|
|
}
|
|
|
|
fn read(&self) -> Arc<CertifiedKey> {
|
|
let generation = self.generation.load(Ordering::Acquire);
|
|
|
|
let key = LOCAL_KEY.with(|local_key| {
|
|
local_key.get().and_then(|refcell| {
|
|
let borrowed = refcell.borrow();
|
|
if borrowed.0 == generation {
|
|
Some(Arc::clone(&borrowed.1))
|
|
} else {
|
|
None
|
|
}
|
|
})
|
|
});
|
|
|
|
if let Some(key) = key {
|
|
key
|
|
} else {
|
|
let key = Arc::clone(&self.lock.read().unwrap());
|
|
|
|
LOCAL_KEY.with(|local_key| {
|
|
let guard = local_key.get_or_init(|| RefCell::new((generation, Arc::clone(&key))));
|
|
if guard.borrow().0 != generation {
|
|
*guard.borrow_mut() = (generation, Arc::clone(&key))
|
|
}
|
|
});
|
|
|
|
key
|
|
}
|
|
}
|
|
}
|
|
|
|
struct Inner<const SHARDS: usize> {
|
|
locks: [Shard; SHARDS],
|
|
}
|
|
|
|
impl<const SHARDS: usize> Inner<SHARDS> {
|
|
fn new(key: CertifiedKey) -> Self {
|
|
Self {
|
|
locks: [(); SHARDS].map(|()| Shard::new(key.clone())),
|
|
}
|
|
}
|
|
|
|
fn update(&self, key: CertifiedKey) {
|
|
for lock in &self.locks {
|
|
lock.update(key.clone());
|
|
}
|
|
}
|
|
|
|
fn read(&self) -> Arc<CertifiedKey> {
|
|
self.locks.choose(&mut thread_rng()).unwrap().read()
|
|
}
|
|
}
|
|
|
|
impl<const SHARDS: usize> ChannelSender<SHARDS> {
|
|
pub fn update(&self, key: CertifiedKey) {
|
|
self.inner.update(key);
|
|
}
|
|
}
|
|
|
|
impl<const SHARDS: usize> ChannelResolver<SHARDS> {
|
|
pub fn read(&self) -> Arc<CertifiedKey> {
|
|
self.inner.read()
|
|
}
|
|
}
|
|
|
|
impl<const SHARDS: usize> rustls::server::ResolvesServerCert for ChannelResolver<SHARDS> {
|
|
fn resolve(&self, _: rustls::server::ClientHello) -> Option<Arc<CertifiedKey>> {
|
|
Some(self.read())
|
|
}
|
|
}
|