From eec9ab16fd7875bfd388ca1905fd82607b000728 Mon Sep 17 00:00:00 2001
From: asonix
Date: Fri, 24 Nov 2023 13:42:10 -0600
Subject: [PATCH] Simple CPU Pool

---
 .gitignore |   4 +
 Cargo.toml |  10 ++
 flake.lock |  61 ++++++++++
 flake.nix  |  25 ++++
 src/lib.rs | 343 +++++++++++++++++++++++++++++++++++++++++++++++++++++
 5 files changed, 443 insertions(+)
 create mode 100644 .gitignore
 create mode 100644 Cargo.toml
 create mode 100644 flake.lock
 create mode 100644 flake.nix
 create mode 100644 src/lib.rs
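
Usage sketch (reviewer note, not part of the applied diff): the snippet below
shows how the public API added in src/lib.rs is meant to be driven from async
code. The `example` function, the closure, the thread counts, and the expected
sum are placeholders chosen for illustration; only `Config`, `CpuPool`,
`spawn`, and `Canceled` come from this patch. The pool grows by one worker
whenever the job channel is full (up to max_threads) and retires an idle
worker (down to min_threads) once the channel drains.

    use std::num::NonZeroU16;

    use async_cpupool::{Config, CpuPool};

    // Assumes some executor (e.g. tokio) polls this future; no async runtime
    // is a dependency of this crate.
    async fn example() {
        // Keep at least 1 worker alive, grow to at most 4.
        let pool: CpuPool = Config::new()
            .min_threads(NonZeroU16::new(1).unwrap())
            .max_threads(NonZeroU16::new(4).unwrap())
            .build()
            .expect("max_threads must not be smaller than min_threads");

        // Offload CPU-bound work; the future resolves with the closure's output.
        let sum = pool
            .spawn(|| (0u64..1_000_000).sum::<u64>())
            .await
            .expect("worker dropped the response");

        assert_eq!(sum, 499_999_500_000);
    }
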
usize"), + min_threads: 1u16.try_into().expect("valid nonzero u16"), + max_threads: 2u16.try_into().expect("valid nonzero u16"), + } + } + + pub fn buffer_size(mut self, buffer_size: NonZeroUsize) -> Self { + self.buffer_size = buffer_size; + self + } + + pub fn min_threads(mut self, min_threads: NonZeroU16) -> Self { + self.min_threads = min_threads; + self + } + + pub fn max_threads(mut self, max_threads: NonZeroU16) -> Self { + self.max_threads = max_threads; + self + } + + pub fn build(self) -> Result { + if self.max_threads < self.min_threads { + return Err(ConfigError::ThreadCount); + } + + Ok(CpuPool { + state: Arc::new(CpuPoolState::new(self)), + }) + } +} + +impl Default for Config { + fn default() -> Self { + Self::new() + } +} + +#[derive(Debug)] +pub enum ConfigError { + ThreadCount, +} + +#[derive(Debug)] +pub struct Canceled; + +pub struct CpuPool { + state: Arc, +} + +impl CpuPool { + pub fn new() -> Self { + Self { + state: Arc::new(CpuPoolState::new(Config::default())), + } + } + + pub fn configure() -> Config { + Config::default() + } + + pub async fn spawn(&self, send_fn: F) -> Result + where + F: FnOnce() -> T + Send + 'static, + T: Send + 'static, + { + if self.state.sender.is_full() { + self.push_thread(); + } + + let (response_tx, response) = flume::bounded(1); + + self.state + .sender + .send_async(Box::new(move || { + let output = (send_fn)(); + + if response_tx.send(output).is_err() { + tracing::trace!("Receiver hung up"); + } + })) + .await + .expect("No receiver"); + + if self.state.sender.is_empty() { + if let Some(thread) = self.pop_thread() { + thread.reap().await; + } + } + + response.recv_async().await.map_err(|_| Canceled) + } + + fn push_thread(&self) { + let current_threads = self + .state + .current_threads + .load(std::sync::atomic::Ordering::Acquire); + + if current_threads >= u64::from(u16::from(self.state.max_threads)) { + return; + } + + if self + .state + .current_threads + .compare_exchange( + current_threads, + current_threads + 1, + std::sync::atomic::Ordering::AcqRel, + std::sync::atomic::Ordering::Relaxed, + ) + .is_err() + { + return; + } + + // we updated the count, so we have authorization to spawn a new thread + + let thread_id = self + .state + .thread_id + .fetch_add(1, std::sync::atomic::Ordering::Relaxed); + + let thread = spawn(thread_id, self.state.receiver.clone()); + + self.state.threads.lock().unwrap().push(thread); + } + + fn pop_thread(&self) -> Option { + let current_threads = self + .state + .current_threads + .load(std::sync::atomic::Ordering::Acquire); + + if current_threads <= u64::from(u16::from(self.state.min_threads)) { + return None; + } + + if self + .state + .current_threads + .compare_exchange( + current_threads, + current_threads - 1, + std::sync::atomic::Ordering::AcqRel, + std::sync::atomic::Ordering::Relaxed, + ) + .is_err() + { + return None; + } + + // we updated the count, so we have authorization to reap a thread + + self.state.threads.lock().unwrap().pop() + } +} + +impl Default for CpuPool { + fn default() -> Self { + Self::new() + } +} + +type SendFn = Box; + +struct CpuPoolState { + min_threads: NonZeroU16, + max_threads: NonZeroU16, + current_threads: AtomicU64, + thread_id: AtomicU64, + sender: flume::Sender, + receiver: flume::Receiver, + threads: Mutex, +} + +impl CpuPoolState { + fn new( + Config { + buffer_size, + min_threads, + max_threads, + }: Config, + ) -> Self { + let (sender, receiver) = flume::bounded(usize::from(buffer_size)); + + let start_threads = u64::from(u16::from(min_threads)); + + 
+        let threads = ThreadVec::new(start_threads, usize::from(u16::from(max_threads)), |i| {
+            spawn(i, receiver.clone())
+        });
+
+        let current_threads = AtomicU64::new(start_threads);
+        let thread_id = AtomicU64::new(start_threads);
+
+        CpuPoolState {
+            min_threads,
+            max_threads,
+            current_threads,
+            thread_id,
+            sender,
+            receiver,
+            threads: Mutex::new(threads),
+        }
+    }
+}
+
+struct ThreadVec {
+    threads: Vec<Thread>,
+}
+
+impl ThreadVec {
+    fn new<F>(start_threads: u64, max_threads: usize, spawn: F) -> Self
+    where
+        F: Fn(u64) -> Thread,
+    {
+        let mut threads = Vec::with_capacity(max_threads);
+
+        for i in 0..start_threads {
+            threads.push((spawn)(i));
+        }
+
+        Self { threads }
+    }
+
+    fn push(&mut self, thread: Thread) {
+        self.threads.push(thread);
+    }
+
+    fn pop(&mut self) -> Option<Thread> {
+        self.threads.pop()
+    }
+}
+
+impl Drop for ThreadVec {
+    fn drop(&mut self) {
+        for thread in &mut self.threads {
+            thread.signal.take();
+        }
+
+        for thread in &mut self.threads {
+            if let Some(handle) = thread.handle.take() {
+                handle.join().expect("Thread panicked");
+            }
+        }
+    }
+}
+
+struct Thread {
+    handle: Option<JoinHandle<()>>,
+    signal: Option<SendOnDrop>,
+    closed: flume::Receiver<()>,
+}
+
+impl Thread {
+    async fn reap(mut self) {
+        self.signal.take();
+
+        let _ = self.closed.recv_async().await;
+
+        if let Some(handle) = self.handle.take() {
+            handle.join().expect("Thread panicked");
+        }
+    }
+}
+
+struct SendOnDrop {
+    sender: flume::Sender<()>,
+}
+
+impl Drop for SendOnDrop {
+    fn drop(&mut self) {
+        let _ = self.sender.try_send(());
+    }
+}
+
+fn spawn(i: u64, receiver: flume::Receiver<SendFn>) -> Thread {
+    let (signal, signal_rx) = flume::bounded(1);
+    let (closed_tx, closed) = flume::bounded(1);
+
+    let signal = SendOnDrop { sender: signal };
+    let closed_tx = SendOnDrop { sender: closed_tx };
+
+    let handle = std::thread::Builder::new()
+        .name(format!("cpupool-{i}"))
+        .spawn(move || run(receiver, signal_rx, closed_tx))
+        .expect("Failed to spawn new thread");
+
+    Thread {
+        handle: Some(handle),
+        signal: Some(signal),
+        closed,
+    }
+}
+
+fn run(receiver: flume::Receiver<SendFn>, signal_rx: flume::Receiver<()>, closed_tx: SendOnDrop) {
+    loop {
+        let bail = flume::Selector::new()
+            .recv(&receiver, |res| {
+                if let Ok(send_fn) = res {
+                    invoke_send_fn(send_fn);
+                    false
+                } else {
+                    true
+                }
+            })
+            .recv(&signal_rx, |_res| true)
+            .wait();
+
+        if bail {
+            break;
+        }
+    }
+
+    drop(closed_tx);
+}
+
+fn invoke_send_fn(send_fn: SendFn) {
+    let res = std::panic::catch_unwind(std::panic::AssertUnwindSafe(move || {
+        (send_fn)();
+    }));
+
+    if res.is_err() {
+        tracing::trace!("panic in spawned task");
+    }
+}
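
Failure-semantics note (reviewer commentary, not part of the applied diff): a
panic inside a spawned closure is caught on the worker thread by
invoke_send_fn, so the worker survives and keeps serving jobs; the caller only
sees Err(Canceled), because the response sender is dropped without a value
being sent. A minimal sketch, assuming a pool built as in the note above; the
function name and closure are placeholders:

    use async_cpupool::{Canceled, CpuPool};

    async fn panic_is_reported_as_canceled(pool: &CpuPool) {
        // The worker catches the panic and logs it at trace level; the only
        // signal the caller receives is the Canceled error.
        let res: Result<(), Canceled> = pool.spawn(|| panic!("boom")).await;
        assert!(res.is_err());
    }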