Remove 'Inner' type to avoid double Arc-ing
This commit is contained in:
parent
0697ad93a0
commit
7b5db89dc8
2 changed files with 30 additions and 26 deletions
|
@ -1,4 +1,4 @@
|
|||
use std::{sync::Arc, time::Duration};
|
||||
use std::time::Duration;
|
||||
|
||||
use actix_web::{web, App, HttpServer};
|
||||
|
||||
|
@ -8,7 +8,7 @@ async fn index() -> &'static str {
|
|||
|
||||
#[actix_web::main]
|
||||
async fn main() -> Result<(), Box<dyn std::error::Error>> {
|
||||
let initial_key = read_key().await?;
|
||||
let initial_key = read_key().await?.unwrap();
|
||||
|
||||
let (tx, rx) = rustls_resolver::channel::<32>(initial_key);
|
||||
|
||||
|
@ -31,7 +31,7 @@ async fn main() -> Result<(), Box<dyn std::error::Error>> {
|
|||
let server_config = rustls::ServerConfig::builder()
|
||||
.with_safe_defaults()
|
||||
.with_no_client_auth()
|
||||
.with_cert_resolver(Arc::new(rx));
|
||||
.with_cert_resolver(rx);
|
||||
|
||||
HttpServer::new(|| App::new().route("/", web::get().to(index)))
|
||||
.bind_rustls_021("0.0.0.0:8443", server_config)?
|
||||
|
|
50
src/lib.rs
50
src/lib.rs
|
@ -9,27 +9,33 @@ use std::{
|
|||
use rand::{seq::SliceRandom, thread_rng};
|
||||
use rustls::sign::CertifiedKey;
|
||||
|
||||
/// Create a new resolver channel. The Sender can update the key, and the Receiver can be
|
||||
/// registered with a rustls server
|
||||
///
|
||||
/// the SHARDS const generic controlls the size of the channel. A larger channel will reduce
|
||||
/// contention, leading to faster reads but more writes per call to `update`.
|
||||
pub fn channel<const SHARDS: usize>(
|
||||
initial: CertifiedKey,
|
||||
) -> (ChannelSender<SHARDS>, ChannelResolver<SHARDS>) {
|
||||
let inner = Arc::new(Inner::<SHARDS>::new(initial));
|
||||
) -> (ChannelSender<SHARDS>, Arc<ChannelResolver<SHARDS>>) {
|
||||
let resolver = Arc::new(ChannelResolver::<SHARDS>::new(initial));
|
||||
|
||||
(
|
||||
ChannelSender {
|
||||
inner: Arc::clone(&inner),
|
||||
inner: Arc::clone(&resolver),
|
||||
},
|
||||
ChannelResolver { inner },
|
||||
resolver,
|
||||
)
|
||||
}
|
||||
|
||||
/// The Send half of the channel. This is used for updating the server with new keys
|
||||
#[derive(Clone)]
|
||||
pub struct ChannelSender<const SHARDS: usize> {
|
||||
inner: Arc<Inner<SHARDS>>,
|
||||
inner: Arc<ChannelResolver<SHARDS>>,
|
||||
}
|
||||
|
||||
#[derive(Clone)]
|
||||
// The Receive half of the channel. This is registerd with rustls to provide the server with keys
|
||||
pub struct ChannelResolver<const SHARDS: usize> {
|
||||
inner: Arc<Inner<SHARDS>>,
|
||||
locks: [Shard; SHARDS],
|
||||
}
|
||||
|
||||
struct Shard {
|
||||
|
@ -51,6 +57,8 @@ impl Shard {
|
|||
|
||||
fn update(&self, key: CertifiedKey) {
|
||||
let mut guard = self.lock.write().unwrap();
|
||||
// update generation while lock is held - ensures any "after" accesses to generation queue
|
||||
// at the read-lock for a new key
|
||||
self.generation.fetch_add(1, Ordering::Release);
|
||||
*guard = Arc::new(key);
|
||||
}
|
||||
|
@ -61,6 +69,7 @@ impl Shard {
|
|||
let key = LOCAL_KEY.with(|local_key| {
|
||||
local_key.get().and_then(|refcell| {
|
||||
let borrowed = refcell.borrow();
|
||||
// if TLS generation is the same, we can safely return TLS key
|
||||
if borrowed.0 == generation {
|
||||
Some(Arc::clone(&borrowed.1))
|
||||
} else {
|
||||
|
@ -72,6 +81,7 @@ impl Shard {
|
|||
if let Some(key) = key {
|
||||
key
|
||||
} else {
|
||||
// slow path, take a read lock and update TLS with new key
|
||||
let key = Arc::clone(&self.lock.read().unwrap());
|
||||
|
||||
LOCAL_KEY.with(|local_key| {
|
||||
|
@ -86,11 +96,14 @@ impl Shard {
|
|||
}
|
||||
}
|
||||
|
||||
struct Inner<const SHARDS: usize> {
|
||||
locks: [Shard; SHARDS],
|
||||
impl<const SHARDS: usize> ChannelSender<SHARDS> {
|
||||
/// Update the key in the channel
|
||||
pub fn update(&self, key: CertifiedKey) {
|
||||
self.inner.update(key);
|
||||
}
|
||||
}
|
||||
|
||||
impl<const SHARDS: usize> Inner<SHARDS> {
|
||||
impl<const SHARDS: usize> ChannelResolver<SHARDS> {
|
||||
fn new(key: CertifiedKey) -> Self {
|
||||
Self {
|
||||
locks: [(); SHARDS].map(|()| Shard::new(key.clone())),
|
||||
|
@ -103,20 +116,11 @@ impl<const SHARDS: usize> Inner<SHARDS> {
|
|||
}
|
||||
}
|
||||
|
||||
fn read(&self) -> Arc<CertifiedKey> {
|
||||
self.locks.choose(&mut thread_rng()).unwrap().read()
|
||||
}
|
||||
}
|
||||
|
||||
impl<const SHARDS: usize> ChannelSender<SHARDS> {
|
||||
pub fn update(&self, key: CertifiedKey) {
|
||||
self.inner.update(key);
|
||||
}
|
||||
}
|
||||
|
||||
impl<const SHARDS: usize> ChannelResolver<SHARDS> {
|
||||
// exposed for benching
|
||||
#[doc(hidden)]
|
||||
pub fn read(&self) -> Arc<CertifiedKey> {
|
||||
self.inner.read()
|
||||
// choose random shard to reduce contention. unwrap since slice is always non-empty
|
||||
self.locks.choose(&mut thread_rng()).unwrap().read()
|
||||
}
|
||||
}
|
||||
|
||||
|
|
Loading…
Reference in a new issue