Remove 'Inner' type to avoid double Arc-ing
This commit is contained in:
parent
0697ad93a0
commit
7b5db89dc8
2 changed files with 30 additions and 26 deletions
|
@ -1,4 +1,4 @@
|
||||||
use std::{sync::Arc, time::Duration};
|
use std::time::Duration;
|
||||||
|
|
||||||
use actix_web::{web, App, HttpServer};
|
use actix_web::{web, App, HttpServer};
|
||||||
|
|
||||||
|
@ -8,7 +8,7 @@ async fn index() -> &'static str {
|
||||||
|
|
||||||
#[actix_web::main]
|
#[actix_web::main]
|
||||||
async fn main() -> Result<(), Box<dyn std::error::Error>> {
|
async fn main() -> Result<(), Box<dyn std::error::Error>> {
|
||||||
let initial_key = read_key().await?;
|
let initial_key = read_key().await?.unwrap();
|
||||||
|
|
||||||
let (tx, rx) = rustls_resolver::channel::<32>(initial_key);
|
let (tx, rx) = rustls_resolver::channel::<32>(initial_key);
|
||||||
|
|
||||||
|
@ -31,7 +31,7 @@ async fn main() -> Result<(), Box<dyn std::error::Error>> {
|
||||||
let server_config = rustls::ServerConfig::builder()
|
let server_config = rustls::ServerConfig::builder()
|
||||||
.with_safe_defaults()
|
.with_safe_defaults()
|
||||||
.with_no_client_auth()
|
.with_no_client_auth()
|
||||||
.with_cert_resolver(Arc::new(rx));
|
.with_cert_resolver(rx);
|
||||||
|
|
||||||
HttpServer::new(|| App::new().route("/", web::get().to(index)))
|
HttpServer::new(|| App::new().route("/", web::get().to(index)))
|
||||||
.bind_rustls_021("0.0.0.0:8443", server_config)?
|
.bind_rustls_021("0.0.0.0:8443", server_config)?
|
||||||
|
|
50
src/lib.rs
50
src/lib.rs
|
@ -9,27 +9,33 @@ use std::{
|
||||||
use rand::{seq::SliceRandom, thread_rng};
|
use rand::{seq::SliceRandom, thread_rng};
|
||||||
use rustls::sign::CertifiedKey;
|
use rustls::sign::CertifiedKey;
|
||||||
|
|
||||||
|
/// Create a new resolver channel. The Sender can update the key, and the Receiver can be
|
||||||
|
/// registered with a rustls server
|
||||||
|
///
|
||||||
|
/// the SHARDS const generic controlls the size of the channel. A larger channel will reduce
|
||||||
|
/// contention, leading to faster reads but more writes per call to `update`.
|
||||||
pub fn channel<const SHARDS: usize>(
|
pub fn channel<const SHARDS: usize>(
|
||||||
initial: CertifiedKey,
|
initial: CertifiedKey,
|
||||||
) -> (ChannelSender<SHARDS>, ChannelResolver<SHARDS>) {
|
) -> (ChannelSender<SHARDS>, Arc<ChannelResolver<SHARDS>>) {
|
||||||
let inner = Arc::new(Inner::<SHARDS>::new(initial));
|
let resolver = Arc::new(ChannelResolver::<SHARDS>::new(initial));
|
||||||
|
|
||||||
(
|
(
|
||||||
ChannelSender {
|
ChannelSender {
|
||||||
inner: Arc::clone(&inner),
|
inner: Arc::clone(&resolver),
|
||||||
},
|
},
|
||||||
ChannelResolver { inner },
|
resolver,
|
||||||
)
|
)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// The Send half of the channel. This is used for updating the server with new keys
|
||||||
#[derive(Clone)]
|
#[derive(Clone)]
|
||||||
pub struct ChannelSender<const SHARDS: usize> {
|
pub struct ChannelSender<const SHARDS: usize> {
|
||||||
inner: Arc<Inner<SHARDS>>,
|
inner: Arc<ChannelResolver<SHARDS>>,
|
||||||
}
|
}
|
||||||
|
|
||||||
#[derive(Clone)]
|
// The Receive half of the channel. This is registerd with rustls to provide the server with keys
|
||||||
pub struct ChannelResolver<const SHARDS: usize> {
|
pub struct ChannelResolver<const SHARDS: usize> {
|
||||||
inner: Arc<Inner<SHARDS>>,
|
locks: [Shard; SHARDS],
|
||||||
}
|
}
|
||||||
|
|
||||||
struct Shard {
|
struct Shard {
|
||||||
|
@ -51,6 +57,8 @@ impl Shard {
|
||||||
|
|
||||||
fn update(&self, key: CertifiedKey) {
|
fn update(&self, key: CertifiedKey) {
|
||||||
let mut guard = self.lock.write().unwrap();
|
let mut guard = self.lock.write().unwrap();
|
||||||
|
// update generation while lock is held - ensures any "after" accesses to generation queue
|
||||||
|
// at the read-lock for a new key
|
||||||
self.generation.fetch_add(1, Ordering::Release);
|
self.generation.fetch_add(1, Ordering::Release);
|
||||||
*guard = Arc::new(key);
|
*guard = Arc::new(key);
|
||||||
}
|
}
|
||||||
|
@ -61,6 +69,7 @@ impl Shard {
|
||||||
let key = LOCAL_KEY.with(|local_key| {
|
let key = LOCAL_KEY.with(|local_key| {
|
||||||
local_key.get().and_then(|refcell| {
|
local_key.get().and_then(|refcell| {
|
||||||
let borrowed = refcell.borrow();
|
let borrowed = refcell.borrow();
|
||||||
|
// if TLS generation is the same, we can safely return TLS key
|
||||||
if borrowed.0 == generation {
|
if borrowed.0 == generation {
|
||||||
Some(Arc::clone(&borrowed.1))
|
Some(Arc::clone(&borrowed.1))
|
||||||
} else {
|
} else {
|
||||||
|
@ -72,6 +81,7 @@ impl Shard {
|
||||||
if let Some(key) = key {
|
if let Some(key) = key {
|
||||||
key
|
key
|
||||||
} else {
|
} else {
|
||||||
|
// slow path, take a read lock and update TLS with new key
|
||||||
let key = Arc::clone(&self.lock.read().unwrap());
|
let key = Arc::clone(&self.lock.read().unwrap());
|
||||||
|
|
||||||
LOCAL_KEY.with(|local_key| {
|
LOCAL_KEY.with(|local_key| {
|
||||||
|
@ -86,11 +96,14 @@ impl Shard {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
struct Inner<const SHARDS: usize> {
|
impl<const SHARDS: usize> ChannelSender<SHARDS> {
|
||||||
locks: [Shard; SHARDS],
|
/// Update the key in the channel
|
||||||
|
pub fn update(&self, key: CertifiedKey) {
|
||||||
|
self.inner.update(key);
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
impl<const SHARDS: usize> Inner<SHARDS> {
|
impl<const SHARDS: usize> ChannelResolver<SHARDS> {
|
||||||
fn new(key: CertifiedKey) -> Self {
|
fn new(key: CertifiedKey) -> Self {
|
||||||
Self {
|
Self {
|
||||||
locks: [(); SHARDS].map(|()| Shard::new(key.clone())),
|
locks: [(); SHARDS].map(|()| Shard::new(key.clone())),
|
||||||
|
@ -103,20 +116,11 @@ impl<const SHARDS: usize> Inner<SHARDS> {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
fn read(&self) -> Arc<CertifiedKey> {
|
// exposed for benching
|
||||||
self.locks.choose(&mut thread_rng()).unwrap().read()
|
#[doc(hidden)]
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
impl<const SHARDS: usize> ChannelSender<SHARDS> {
|
|
||||||
pub fn update(&self, key: CertifiedKey) {
|
|
||||||
self.inner.update(key);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
impl<const SHARDS: usize> ChannelResolver<SHARDS> {
|
|
||||||
pub fn read(&self) -> Arc<CertifiedKey> {
|
pub fn read(&self) -> Arc<CertifiedKey> {
|
||||||
self.inner.read()
|
// choose random shard to reduce contention. unwrap since slice is always non-empty
|
||||||
|
self.locks.choose(&mut thread_rng()).unwrap().read()
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
Loading…
Reference in a new issue