From 1cd8c0da4bb67ab4cb7fa87b50f693ba26899c68 Mon Sep 17 00:00:00 2001 From: asonix Date: Sat, 27 Jan 2024 21:03:41 -0600 Subject: [PATCH] Implement resolver channel, benchmark it --- .gitignore | 5 ++ Cargo.toml | 18 ++++++ benches/parallel_access.rs | 105 ++++++++++++++++++++++++++++++ flake.lock | 61 ++++++++++++++++++ flake.nix | 27 ++++++++ setup-tls.sh | 7 ++ src/lib.rs | 127 +++++++++++++++++++++++++++++++++++++ 7 files changed, 350 insertions(+) create mode 100644 .gitignore create mode 100644 Cargo.toml create mode 100644 benches/parallel_access.rs create mode 100644 flake.lock create mode 100644 flake.nix create mode 100755 setup-tls.sh create mode 100644 src/lib.rs diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000..ac7df3d --- /dev/null +++ b/.gitignore @@ -0,0 +1,5 @@ +/target +/Cargo.lock +/out +/.direnv +/.envrc diff --git a/Cargo.toml b/Cargo.toml new file mode 100644 index 0000000..39ce41e --- /dev/null +++ b/Cargo.toml @@ -0,0 +1,18 @@ +[package] +name = "rustls-resolver" +version = "0.1.0" +edition = "2021" + +# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html + +[dependencies] +rand = "0.8.5" +rustls = "0.21" + +[dev-dependencies] +criterion = "0.3" +rustls-pemfile = "2.0.0" + +[[bench]] +name = "parallel_access" +harness = false diff --git a/benches/parallel_access.rs b/benches/parallel_access.rs new file mode 100644 index 0000000..2e212df --- /dev/null +++ b/benches/parallel_access.rs @@ -0,0 +1,105 @@ +use std::{ + io::BufReader, + sync::{ + atomic::{AtomicBool, Ordering}, + Arc, + }, +}; + +use criterion::{black_box, criterion_group, criterion_main, BenchmarkId, Criterion}; +use rustls::sign::CertifiedKey; +use rustls_resolver::channel; + +fn prepare_key() -> CertifiedKey { + let certfile = std::fs::File::open("./out/example.crt").unwrap(); + let mut reader = BufReader::new(certfile); + let certs = rustls_pemfile::certs(&mut reader) + .map(|res| res.map(|c| rustls::Certificate(c.to_vec()))) + .collect::, _>>() + .unwrap(); + + let keyfile = std::fs::File::open("./out/example.key").unwrap(); + let mut reader = BufReader::new(keyfile); + let private_key = rustls_pemfile::private_key(&mut reader).unwrap().unwrap(); + + let private_key = + rustls::sign::any_supported_type(&rustls::PrivateKey(Vec::from(private_key.secret_der()))) + .unwrap(); + + CertifiedKey::new(certs, private_key) +} + +fn parallel_bench(c: &mut Criterion, name: &str, writer: bool) { + let mut group = c.benchmark_group(name); + + for i in [1u64, 2, 4, 8, 16, 32, 64].iter() { + group.bench_with_input(BenchmarkId::from_parameter(i), i, |b, i| { + let key = prepare_key(); + let (tx, rx) = channel::(key.clone()); + + let go = Arc::new(AtomicBool::new(true)); + + let mut handles = (0..*i) + .map(|_| { + let rx = rx.clone(); + let go = go.clone(); + std::thread::spawn(move || { + while go.load(Ordering::Relaxed) { + let _key = black_box(rx.read()); + } + }) + }) + .collect::>(); + + if writer { + let go = go.clone(); + handles.push(std::thread::spawn(move || { + while go.load(Ordering::Relaxed) { + tx.update(key.clone()) + } + })); + } + + b.iter(|| { + let _key = black_box(rx.read()); + }); + + go.store(false, Ordering::Relaxed); + + for handle in handles { + handle.join().unwrap(); + } + }); + } + + group.finish(); +} + +pub fn criterion_benchmark(c: &mut Criterion) { + c.bench_function("sequential_access", |b| { + let (_, rx) = channel::<1>(prepare_key()); + + b.iter(|| { + let _key = black_box(rx.read()); + }) + }); + + parallel_bench::<1>(c, "parallel_access_1", false); + parallel_bench::<2>(c, "parallel_access_2", false); + parallel_bench::<4>(c, "parallel_access_4", false); + parallel_bench::<8>(c, "parallel_access_8", false); + parallel_bench::<16>(c, "parallel_access_16", false); + parallel_bench::<32>(c, "parallel_access_32", false); + parallel_bench::<64>(c, "parallel_access_64", false); + + parallel_bench::<1>(c, "parallel_access_with_writer_1", true); + parallel_bench::<2>(c, "parallel_access_with_writer_2", true); + parallel_bench::<4>(c, "parallel_access_with_writer_4", true); + parallel_bench::<8>(c, "parallel_access_with_writer_8", true); + parallel_bench::<16>(c, "parallel_access_with_writer_16", true); + parallel_bench::<32>(c, "parallel_access_with_writer_32", true); + parallel_bench::<64>(c, "parallel_access_with_writer_64", true); +} + +criterion_group!(benches, criterion_benchmark); +criterion_main!(benches); diff --git a/flake.lock b/flake.lock new file mode 100644 index 0000000..c7d657f --- /dev/null +++ b/flake.lock @@ -0,0 +1,61 @@ +{ + "nodes": { + "flake-utils": { + "inputs": { + "systems": "systems" + }, + "locked": { + "lastModified": 1705309234, + "narHash": "sha256-uNRRNRKmJyCRC/8y1RqBkqWBLM034y4qN7EprSdmgyA=", + "owner": "numtide", + "repo": "flake-utils", + "rev": "1ef2e671c3b0c19053962c07dbda38332dcebf26", + "type": "github" + }, + "original": { + "owner": "numtide", + "repo": "flake-utils", + "type": "github" + } + }, + "nixpkgs": { + "locked": { + "lastModified": 1706191920, + "narHash": "sha256-eLihrZAPZX0R6RyM5fYAWeKVNuQPYjAkCUBr+JNvtdE=", + "owner": "NixOS", + "repo": "nixpkgs", + "rev": "ae5c332cbb5827f6b1f02572496b141021de335f", + "type": "github" + }, + "original": { + "owner": "NixOS", + "ref": "nixos-unstable", + "repo": "nixpkgs", + "type": "github" + } + }, + "root": { + "inputs": { + "flake-utils": "flake-utils", + "nixpkgs": "nixpkgs" + } + }, + "systems": { + "locked": { + "lastModified": 1681028828, + "narHash": "sha256-Vy1rq5AaRuLzOxct8nz4T6wlgyUR7zLU309k9mBC768=", + "owner": "nix-systems", + "repo": "default", + "rev": "da67096a3b9bf56a91d16901293e51ba5b49a27e", + "type": "github" + }, + "original": { + "owner": "nix-systems", + "repo": "default", + "type": "github" + } + } + }, + "root": "root", + "version": 7 +} diff --git a/flake.nix b/flake.nix new file mode 100644 index 0000000..4f9d08f --- /dev/null +++ b/flake.nix @@ -0,0 +1,27 @@ +{ + description = "rustls-resolver"; + + inputs = { + nixpkgs.url = "github:NixOS/nixpkgs/nixos-unstable"; + flake-utils.url = "github:numtide/flake-utils"; + }; + + outputs = { self, nixpkgs, flake-utils }: + flake-utils.lib.eachDefaultSystem (system: + let + pkgs = import nixpkgs { + inherit system; + }; + in + { + packages = rec { + default = pkgs.hello; + }; + + devShell = with pkgs; mkShell { + nativeBuildInputs = [ cargo cargo-outdated certstrap clippy gcc rust-analyzer rustc rustfmt ]; + + RUST_SRC_PATH = "${pkgs.rust.packages.stable.rustPlatform.rustLibSrc}"; + }; + }); +} diff --git a/setup-tls.sh b/setup-tls.sh new file mode 100755 index 0000000..ae5cc3c --- /dev/null +++ b/setup-tls.sh @@ -0,0 +1,7 @@ +#!/usr/bin/env bash + +set -xe + +certstrap init --common-name exampleCA +certstrap request-cert --common-name example --domain localhost +certstrap sign example --CA exampleCA diff --git a/src/lib.rs b/src/lib.rs new file mode 100644 index 0000000..893b474 --- /dev/null +++ b/src/lib.rs @@ -0,0 +1,127 @@ +use std::{ + cell::{OnceCell, RefCell}, + sync::{ + atomic::{AtomicU64, Ordering}, + Arc, RwLock, + }, +}; + +use rand::{seq::SliceRandom, thread_rng}; +use rustls::sign::CertifiedKey; + +pub fn channel( + initial: CertifiedKey, +) -> (ChannelSender, ChannelResolver) { + let inner = Arc::new(Inner::::new(initial)); + + ( + ChannelSender { + inner: Arc::clone(&inner), + }, + ChannelResolver { inner }, + ) +} + +#[derive(Clone)] +pub struct ChannelSender { + inner: Arc>, +} + +#[derive(Clone)] +pub struct ChannelResolver { + inner: Arc>, +} + +struct Shard { + generation: AtomicU64, + lock: RwLock>, +} + +thread_local! { + static LOCAL_KEY: OnceCell)>> = OnceCell::new(); +} + +impl Shard { + fn new(key: CertifiedKey) -> Self { + Self { + generation: AtomicU64::new(0), + lock: RwLock::new(Arc::new(key)), + } + } + + fn update(&self, key: CertifiedKey) { + let mut guard = self.lock.write().unwrap(); + self.generation.fetch_add(1, Ordering::Release); + *guard = Arc::new(key); + } + + fn read(&self) -> Arc { + let generation = self.generation.load(Ordering::Acquire); + + let key = LOCAL_KEY.with(|local_key| { + local_key.get().and_then(|refcell| { + let borrowed = refcell.borrow(); + if borrowed.0 == generation { + Some(Arc::clone(&borrowed.1)) + } else { + None + } + }) + }); + + if let Some(key) = key { + key + } else { + let key = Arc::clone(&self.lock.read().unwrap()); + + LOCAL_KEY.with(|local_key| { + let guard = local_key.get_or_init(|| RefCell::new((generation, Arc::clone(&key)))); + if guard.borrow().0 != generation { + *guard.borrow_mut() = (generation, Arc::clone(&key)) + } + }); + + key + } + } +} + +struct Inner { + locks: [Shard; SHARDS], +} + +impl Inner { + fn new(key: CertifiedKey) -> Self { + Self { + locks: [(); SHARDS].map(|()| Shard::new(key.clone())), + } + } + + fn update(&self, key: CertifiedKey) { + for lock in &self.locks { + lock.update(key.clone()); + } + } + + fn read(&self) -> Arc { + self.locks.choose(&mut thread_rng()).unwrap().read() + } +} + +impl ChannelSender { + pub fn update(&self, key: CertifiedKey) { + self.inner.update(key); + } +} + +impl ChannelResolver { + pub fn read(&self) -> Arc { + self.inner.read() + } +} + +impl rustls::server::ResolvesServerCert for ChannelResolver { + fn resolve(&self, _: rustls::server::ClientHello) -> Option> { + Some(self.read()) + } +}