Simple CPU Pool

This commit is contained in:
asonix 2023-11-24 13:42:10 -06:00
commit eec9ab16fd
5 changed files with 443 additions and 0 deletions

4
.gitignore vendored Normal file
View file

@ -0,0 +1,4 @@
/target
/Cargo.lock
/.direnv
/.envrc

10
Cargo.toml Normal file
View file

@ -0,0 +1,10 @@
[package]
name = "async-cpupool"
version = "0.1.0"
edition = "2021"
# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html
[dependencies]
flume = "0.11.0"
tracing = "0.1.40"

61
flake.lock Normal file
View file

@ -0,0 +1,61 @@
{
"nodes": {
"flake-utils": {
"inputs": {
"systems": "systems"
},
"locked": {
"lastModified": 1694529238,
"narHash": "sha256-zsNZZGTGnMOf9YpHKJqMSsa0dXbfmxeoJ7xHlrt+xmY=",
"owner": "numtide",
"repo": "flake-utils",
"rev": "ff7b65b44d01cf9ba6a71320833626af21126384",
"type": "github"
},
"original": {
"owner": "numtide",
"repo": "flake-utils",
"type": "github"
}
},
"nixpkgs": {
"locked": {
"lastModified": 1700612854,
"narHash": "sha256-yrQ8osMD+vDLGFX7pcwsY/Qr5PUd6OmDMYJZzZi0+zc=",
"owner": "NixOS",
"repo": "nixpkgs",
"rev": "19cbff58383a4ae384dea4d1d0c823d72b49d614",
"type": "github"
},
"original": {
"owner": "NixOS",
"ref": "nixos-unstable",
"repo": "nixpkgs",
"type": "github"
}
},
"root": {
"inputs": {
"flake-utils": "flake-utils",
"nixpkgs": "nixpkgs"
}
},
"systems": {
"locked": {
"lastModified": 1681028828,
"narHash": "sha256-Vy1rq5AaRuLzOxct8nz4T6wlgyUR7zLU309k9mBC768=",
"owner": "nix-systems",
"repo": "default",
"rev": "da67096a3b9bf56a91d16901293e51ba5b49a27e",
"type": "github"
},
"original": {
"owner": "nix-systems",
"repo": "default",
"type": "github"
}
}
},
"root": "root",
"version": 7
}

25
flake.nix Normal file
View file

@ -0,0 +1,25 @@
{
description = "async-cpupool";
inputs = {
nixpkgs.url = "github:NixOS/nixpkgs/nixos-unstable";
flake-utils.url = "github:numtide/flake-utils";
};
outputs = { self, nixpkgs, flake-utils }:
flake-utils.lib.eachDefaultSystem (system:
let
pkgs = import nixpkgs {
inherit system;
};
in
{
packages.default = pkgs.hello;
devShell = with pkgs; mkShell {
nativeBuildInputs = [ cargo cargo-outdated clippy rust-analyzer rustc rustfmt ];
RUST_SRC_PATH = "${pkgs.rust.packages.stable.rustPlatform.rustLibSrc}";
};
});
}

343
src/lib.rs Normal file
View file

@ -0,0 +1,343 @@
use std::{
num::{NonZeroU16, NonZeroUsize},
sync::{atomic::AtomicU64, Arc, Mutex},
thread::JoinHandle,
};
#[derive(Debug)]
pub struct Config {
buffer_size: NonZeroUsize,
min_threads: NonZeroU16,
max_threads: NonZeroU16,
}
impl Config {
pub fn new() -> Self {
Config {
buffer_size: 8usize.try_into().expect("valid nonzero usize"),
min_threads: 1u16.try_into().expect("valid nonzero u16"),
max_threads: 2u16.try_into().expect("valid nonzero u16"),
}
}
pub fn buffer_size(mut self, buffer_size: NonZeroUsize) -> Self {
self.buffer_size = buffer_size;
self
}
pub fn min_threads(mut self, min_threads: NonZeroU16) -> Self {
self.min_threads = min_threads;
self
}
pub fn max_threads(mut self, max_threads: NonZeroU16) -> Self {
self.max_threads = max_threads;
self
}
pub fn build(self) -> Result<CpuPool, ConfigError> {
if self.max_threads < self.min_threads {
return Err(ConfigError::ThreadCount);
}
Ok(CpuPool {
state: Arc::new(CpuPoolState::new(self)),
})
}
}
impl Default for Config {
fn default() -> Self {
Self::new()
}
}
#[derive(Debug)]
pub enum ConfigError {
ThreadCount,
}
#[derive(Debug)]
pub struct Canceled;
pub struct CpuPool {
state: Arc<CpuPoolState>,
}
impl CpuPool {
pub fn new() -> Self {
Self {
state: Arc::new(CpuPoolState::new(Config::default())),
}
}
pub fn configure() -> Config {
Config::default()
}
pub async fn spawn<F, T>(&self, send_fn: F) -> Result<T, Canceled>
where
F: FnOnce() -> T + Send + 'static,
T: Send + 'static,
{
if self.state.sender.is_full() {
self.push_thread();
}
let (response_tx, response) = flume::bounded(1);
self.state
.sender
.send_async(Box::new(move || {
let output = (send_fn)();
if response_tx.send(output).is_err() {
tracing::trace!("Receiver hung up");
}
}))
.await
.expect("No receiver");
if self.state.sender.is_empty() {
if let Some(thread) = self.pop_thread() {
thread.reap().await;
}
}
response.recv_async().await.map_err(|_| Canceled)
}
fn push_thread(&self) {
let current_threads = self
.state
.current_threads
.load(std::sync::atomic::Ordering::Acquire);
if current_threads >= u64::from(u16::from(self.state.max_threads)) {
return;
}
if self
.state
.current_threads
.compare_exchange(
current_threads,
current_threads + 1,
std::sync::atomic::Ordering::AcqRel,
std::sync::atomic::Ordering::Relaxed,
)
.is_err()
{
return;
}
// we updated the count, so we have authorization to spawn a new thread
let thread_id = self
.state
.thread_id
.fetch_add(1, std::sync::atomic::Ordering::Relaxed);
let thread = spawn(thread_id, self.state.receiver.clone());
self.state.threads.lock().unwrap().push(thread);
}
fn pop_thread(&self) -> Option<Thread> {
let current_threads = self
.state
.current_threads
.load(std::sync::atomic::Ordering::Acquire);
if current_threads <= u64::from(u16::from(self.state.min_threads)) {
return None;
}
if self
.state
.current_threads
.compare_exchange(
current_threads,
current_threads - 1,
std::sync::atomic::Ordering::AcqRel,
std::sync::atomic::Ordering::Relaxed,
)
.is_err()
{
return None;
}
// we updated the count, so we have authorization to reap a thread
self.state.threads.lock().unwrap().pop()
}
}
impl Default for CpuPool {
fn default() -> Self {
Self::new()
}
}
type SendFn = Box<dyn FnOnce() + Send>;
struct CpuPoolState {
min_threads: NonZeroU16,
max_threads: NonZeroU16,
current_threads: AtomicU64,
thread_id: AtomicU64,
sender: flume::Sender<SendFn>,
receiver: flume::Receiver<SendFn>,
threads: Mutex<ThreadVec>,
}
impl CpuPoolState {
fn new(
Config {
buffer_size,
min_threads,
max_threads,
}: Config,
) -> Self {
let (sender, receiver) = flume::bounded(usize::from(buffer_size));
let start_threads = u64::from(u16::from(min_threads));
let threads = ThreadVec::new(start_threads, usize::from(u16::from(max_threads)), |i| {
spawn(i, receiver.clone())
});
let current_threads = AtomicU64::new(start_threads);
let thread_id = AtomicU64::new(start_threads);
CpuPoolState {
min_threads,
max_threads,
current_threads,
thread_id,
sender,
receiver,
threads: Mutex::new(threads),
}
}
}
struct ThreadVec {
threads: Vec<Thread>,
}
impl ThreadVec {
fn new<F>(start_threads: u64, max_threads: usize, spawn: F) -> Self
where
F: Fn(u64) -> Thread,
{
let mut threads = Vec::with_capacity(max_threads);
for i in 0..start_threads {
threads.push((spawn)(i));
}
Self { threads }
}
fn push(&mut self, thread: Thread) {
self.threads.push(thread);
}
fn pop(&mut self) -> Option<Thread> {
self.threads.pop()
}
}
impl Drop for ThreadVec {
fn drop(&mut self) {
for thread in &mut self.threads {
thread.signal.take();
}
for thread in &mut self.threads {
if let Some(handle) = thread.handle.take() {
handle.join().expect("Thread panicked");
}
}
}
}
struct Thread {
handle: Option<JoinHandle<()>>,
signal: Option<SendOnDrop>,
closed: flume::Receiver<()>,
}
impl Thread {
async fn reap(mut self) {
self.signal.take();
let _ = self.closed.recv_async().await;
if let Some(handle) = self.handle.take() {
handle.join().expect("Thread panicked");
}
}
}
struct SendOnDrop {
sender: flume::Sender<()>,
}
impl Drop for SendOnDrop {
fn drop(&mut self) {
let _ = self.sender.try_send(());
}
}
fn spawn(i: u64, receiver: flume::Receiver<SendFn>) -> Thread {
let (signal, signal_rx) = flume::bounded(1);
let (closed_tx, closed) = flume::bounded(1);
let signal = SendOnDrop { sender: signal };
let closed_tx = SendOnDrop { sender: closed_tx };
let handle = std::thread::Builder::new()
.name(format!("cpupool-{i}"))
.spawn(move || run(receiver, signal_rx, closed_tx))
.expect("Failed to spawn new thread");
Thread {
handle: Some(handle),
signal: Some(signal),
closed,
}
}
fn run(receiver: flume::Receiver<SendFn>, signal_rx: flume::Receiver<()>, closed_tx: SendOnDrop) {
loop {
let bail = flume::Selector::new()
.recv(&receiver, |res| {
if let Ok(send_fn) = res {
invoke_send_fn(send_fn);
false
} else {
true
}
})
.recv(&signal_rx, |_res| true)
.wait();
if bail {
break;
}
}
drop(closed_tx);
}
fn invoke_send_fn(send_fn: SendFn) {
let res = std::panic::catch_unwind(std::panic::AssertUnwindSafe(move || {
(send_fn)();
}));
if res.is_err() {
tracing::trace!("panic in spawned task");
}
}