Implement resolver channel, benchmark it

This commit is contained in:
asonix 2024-01-27 21:03:41 -06:00
commit 1cd8c0da4b
7 changed files with 350 additions and 0 deletions

5
.gitignore vendored Normal file
View file

@ -0,0 +1,5 @@
/target
/Cargo.lock
/out
/.direnv
/.envrc

18
Cargo.toml Normal file
View file

@ -0,0 +1,18 @@
[package]
name = "rustls-resolver"
version = "0.1.0"
edition = "2021"
# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html
[dependencies]
rand = "0.8.5"
rustls = "0.21"
[dev-dependencies]
criterion = "0.3"
rustls-pemfile = "2.0.0"
[[bench]]
name = "parallel_access"
harness = false

105
benches/parallel_access.rs Normal file
View file

@ -0,0 +1,105 @@
use std::{
io::BufReader,
sync::{
atomic::{AtomicBool, Ordering},
Arc,
},
};
use criterion::{black_box, criterion_group, criterion_main, BenchmarkId, Criterion};
use rustls::sign::CertifiedKey;
use rustls_resolver::channel;
fn prepare_key() -> CertifiedKey {
let certfile = std::fs::File::open("./out/example.crt").unwrap();
let mut reader = BufReader::new(certfile);
let certs = rustls_pemfile::certs(&mut reader)
.map(|res| res.map(|c| rustls::Certificate(c.to_vec())))
.collect::<Result<Vec<_>, _>>()
.unwrap();
let keyfile = std::fs::File::open("./out/example.key").unwrap();
let mut reader = BufReader::new(keyfile);
let private_key = rustls_pemfile::private_key(&mut reader).unwrap().unwrap();
let private_key =
rustls::sign::any_supported_type(&rustls::PrivateKey(Vec::from(private_key.secret_der())))
.unwrap();
CertifiedKey::new(certs, private_key)
}
fn parallel_bench<const N: usize>(c: &mut Criterion, name: &str, writer: bool) {
let mut group = c.benchmark_group(name);
for i in [1u64, 2, 4, 8, 16, 32, 64].iter() {
group.bench_with_input(BenchmarkId::from_parameter(i), i, |b, i| {
let key = prepare_key();
let (tx, rx) = channel::<N>(key.clone());
let go = Arc::new(AtomicBool::new(true));
let mut handles = (0..*i)
.map(|_| {
let rx = rx.clone();
let go = go.clone();
std::thread::spawn(move || {
while go.load(Ordering::Relaxed) {
let _key = black_box(rx.read());
}
})
})
.collect::<Vec<_>>();
if writer {
let go = go.clone();
handles.push(std::thread::spawn(move || {
while go.load(Ordering::Relaxed) {
tx.update(key.clone())
}
}));
}
b.iter(|| {
let _key = black_box(rx.read());
});
go.store(false, Ordering::Relaxed);
for handle in handles {
handle.join().unwrap();
}
});
}
group.finish();
}
pub fn criterion_benchmark(c: &mut Criterion) {
c.bench_function("sequential_access", |b| {
let (_, rx) = channel::<1>(prepare_key());
b.iter(|| {
let _key = black_box(rx.read());
})
});
parallel_bench::<1>(c, "parallel_access_1", false);
parallel_bench::<2>(c, "parallel_access_2", false);
parallel_bench::<4>(c, "parallel_access_4", false);
parallel_bench::<8>(c, "parallel_access_8", false);
parallel_bench::<16>(c, "parallel_access_16", false);
parallel_bench::<32>(c, "parallel_access_32", false);
parallel_bench::<64>(c, "parallel_access_64", false);
parallel_bench::<1>(c, "parallel_access_with_writer_1", true);
parallel_bench::<2>(c, "parallel_access_with_writer_2", true);
parallel_bench::<4>(c, "parallel_access_with_writer_4", true);
parallel_bench::<8>(c, "parallel_access_with_writer_8", true);
parallel_bench::<16>(c, "parallel_access_with_writer_16", true);
parallel_bench::<32>(c, "parallel_access_with_writer_32", true);
parallel_bench::<64>(c, "parallel_access_with_writer_64", true);
}
criterion_group!(benches, criterion_benchmark);
criterion_main!(benches);

61
flake.lock Normal file
View file

@ -0,0 +1,61 @@
{
"nodes": {
"flake-utils": {
"inputs": {
"systems": "systems"
},
"locked": {
"lastModified": 1705309234,
"narHash": "sha256-uNRRNRKmJyCRC/8y1RqBkqWBLM034y4qN7EprSdmgyA=",
"owner": "numtide",
"repo": "flake-utils",
"rev": "1ef2e671c3b0c19053962c07dbda38332dcebf26",
"type": "github"
},
"original": {
"owner": "numtide",
"repo": "flake-utils",
"type": "github"
}
},
"nixpkgs": {
"locked": {
"lastModified": 1706191920,
"narHash": "sha256-eLihrZAPZX0R6RyM5fYAWeKVNuQPYjAkCUBr+JNvtdE=",
"owner": "NixOS",
"repo": "nixpkgs",
"rev": "ae5c332cbb5827f6b1f02572496b141021de335f",
"type": "github"
},
"original": {
"owner": "NixOS",
"ref": "nixos-unstable",
"repo": "nixpkgs",
"type": "github"
}
},
"root": {
"inputs": {
"flake-utils": "flake-utils",
"nixpkgs": "nixpkgs"
}
},
"systems": {
"locked": {
"lastModified": 1681028828,
"narHash": "sha256-Vy1rq5AaRuLzOxct8nz4T6wlgyUR7zLU309k9mBC768=",
"owner": "nix-systems",
"repo": "default",
"rev": "da67096a3b9bf56a91d16901293e51ba5b49a27e",
"type": "github"
},
"original": {
"owner": "nix-systems",
"repo": "default",
"type": "github"
}
}
},
"root": "root",
"version": 7
}

27
flake.nix Normal file
View file

@ -0,0 +1,27 @@
{
description = "rustls-resolver";
inputs = {
nixpkgs.url = "github:NixOS/nixpkgs/nixos-unstable";
flake-utils.url = "github:numtide/flake-utils";
};
outputs = { self, nixpkgs, flake-utils }:
flake-utils.lib.eachDefaultSystem (system:
let
pkgs = import nixpkgs {
inherit system;
};
in
{
packages = rec {
default = pkgs.hello;
};
devShell = with pkgs; mkShell {
nativeBuildInputs = [ cargo cargo-outdated certstrap clippy gcc rust-analyzer rustc rustfmt ];
RUST_SRC_PATH = "${pkgs.rust.packages.stable.rustPlatform.rustLibSrc}";
};
});
}

7
setup-tls.sh Executable file
View file

@ -0,0 +1,7 @@
#!/usr/bin/env bash
set -xe
certstrap init --common-name exampleCA
certstrap request-cert --common-name example --domain localhost
certstrap sign example --CA exampleCA

127
src/lib.rs Normal file
View file

@ -0,0 +1,127 @@
use std::{
cell::{OnceCell, RefCell},
sync::{
atomic::{AtomicU64, Ordering},
Arc, RwLock,
},
};
use rand::{seq::SliceRandom, thread_rng};
use rustls::sign::CertifiedKey;
pub fn channel<const SHARDS: usize>(
initial: CertifiedKey,
) -> (ChannelSender<SHARDS>, ChannelResolver<SHARDS>) {
let inner = Arc::new(Inner::<SHARDS>::new(initial));
(
ChannelSender {
inner: Arc::clone(&inner),
},
ChannelResolver { inner },
)
}
#[derive(Clone)]
pub struct ChannelSender<const SHARDS: usize> {
inner: Arc<Inner<SHARDS>>,
}
#[derive(Clone)]
pub struct ChannelResolver<const SHARDS: usize> {
inner: Arc<Inner<SHARDS>>,
}
struct Shard {
generation: AtomicU64,
lock: RwLock<Arc<CertifiedKey>>,
}
thread_local! {
static LOCAL_KEY: OnceCell<RefCell<(u64, Arc<CertifiedKey>)>> = OnceCell::new();
}
impl Shard {
fn new(key: CertifiedKey) -> Self {
Self {
generation: AtomicU64::new(0),
lock: RwLock::new(Arc::new(key)),
}
}
fn update(&self, key: CertifiedKey) {
let mut guard = self.lock.write().unwrap();
self.generation.fetch_add(1, Ordering::Release);
*guard = Arc::new(key);
}
fn read(&self) -> Arc<CertifiedKey> {
let generation = self.generation.load(Ordering::Acquire);
let key = LOCAL_KEY.with(|local_key| {
local_key.get().and_then(|refcell| {
let borrowed = refcell.borrow();
if borrowed.0 == generation {
Some(Arc::clone(&borrowed.1))
} else {
None
}
})
});
if let Some(key) = key {
key
} else {
let key = Arc::clone(&self.lock.read().unwrap());
LOCAL_KEY.with(|local_key| {
let guard = local_key.get_or_init(|| RefCell::new((generation, Arc::clone(&key))));
if guard.borrow().0 != generation {
*guard.borrow_mut() = (generation, Arc::clone(&key))
}
});
key
}
}
}
struct Inner<const SHARDS: usize> {
locks: [Shard; SHARDS],
}
impl<const SHARDS: usize> Inner<SHARDS> {
fn new(key: CertifiedKey) -> Self {
Self {
locks: [(); SHARDS].map(|()| Shard::new(key.clone())),
}
}
fn update(&self, key: CertifiedKey) {
for lock in &self.locks {
lock.update(key.clone());
}
}
fn read(&self) -> Arc<CertifiedKey> {
self.locks.choose(&mut thread_rng()).unwrap().read()
}
}
impl<const SHARDS: usize> ChannelSender<SHARDS> {
pub fn update(&self, key: CertifiedKey) {
self.inner.update(key);
}
}
impl<const SHARDS: usize> ChannelResolver<SHARDS> {
pub fn read(&self) -> Arc<CertifiedKey> {
self.inner.read()
}
}
impl<const SHARDS: usize> rustls::server::ResolvesServerCert for ChannelResolver<SHARDS> {
fn resolve(&self, _: rustls::server::ClientHello) -> Option<Arc<CertifiedKey>> {
Some(self.read())
}
}