Simple CPU Pool
This commit is contained in:
commit
eec9ab16fd
4
.gitignore
vendored
Normal file
4
.gitignore
vendored
Normal file
|
@ -0,0 +1,4 @@
|
||||||
|
/target
|
||||||
|
/Cargo.lock
|
||||||
|
/.direnv
|
||||||
|
/.envrc
|
10
Cargo.toml
Normal file
10
Cargo.toml
Normal file
|
@ -0,0 +1,10 @@
|
||||||
|
[package]
|
||||||
|
name = "async-cpupool"
|
||||||
|
version = "0.1.0"
|
||||||
|
edition = "2021"
|
||||||
|
|
||||||
|
# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html
|
||||||
|
|
||||||
|
[dependencies]
|
||||||
|
flume = "0.11.0"
|
||||||
|
tracing = "0.1.40"
|
61
flake.lock
Normal file
61
flake.lock
Normal file
|
@ -0,0 +1,61 @@
|
||||||
|
{
|
||||||
|
"nodes": {
|
||||||
|
"flake-utils": {
|
||||||
|
"inputs": {
|
||||||
|
"systems": "systems"
|
||||||
|
},
|
||||||
|
"locked": {
|
||||||
|
"lastModified": 1694529238,
|
||||||
|
"narHash": "sha256-zsNZZGTGnMOf9YpHKJqMSsa0dXbfmxeoJ7xHlrt+xmY=",
|
||||||
|
"owner": "numtide",
|
||||||
|
"repo": "flake-utils",
|
||||||
|
"rev": "ff7b65b44d01cf9ba6a71320833626af21126384",
|
||||||
|
"type": "github"
|
||||||
|
},
|
||||||
|
"original": {
|
||||||
|
"owner": "numtide",
|
||||||
|
"repo": "flake-utils",
|
||||||
|
"type": "github"
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"nixpkgs": {
|
||||||
|
"locked": {
|
||||||
|
"lastModified": 1700612854,
|
||||||
|
"narHash": "sha256-yrQ8osMD+vDLGFX7pcwsY/Qr5PUd6OmDMYJZzZi0+zc=",
|
||||||
|
"owner": "NixOS",
|
||||||
|
"repo": "nixpkgs",
|
||||||
|
"rev": "19cbff58383a4ae384dea4d1d0c823d72b49d614",
|
||||||
|
"type": "github"
|
||||||
|
},
|
||||||
|
"original": {
|
||||||
|
"owner": "NixOS",
|
||||||
|
"ref": "nixos-unstable",
|
||||||
|
"repo": "nixpkgs",
|
||||||
|
"type": "github"
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"root": {
|
||||||
|
"inputs": {
|
||||||
|
"flake-utils": "flake-utils",
|
||||||
|
"nixpkgs": "nixpkgs"
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"systems": {
|
||||||
|
"locked": {
|
||||||
|
"lastModified": 1681028828,
|
||||||
|
"narHash": "sha256-Vy1rq5AaRuLzOxct8nz4T6wlgyUR7zLU309k9mBC768=",
|
||||||
|
"owner": "nix-systems",
|
||||||
|
"repo": "default",
|
||||||
|
"rev": "da67096a3b9bf56a91d16901293e51ba5b49a27e",
|
||||||
|
"type": "github"
|
||||||
|
},
|
||||||
|
"original": {
|
||||||
|
"owner": "nix-systems",
|
||||||
|
"repo": "default",
|
||||||
|
"type": "github"
|
||||||
|
}
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"root": "root",
|
||||||
|
"version": 7
|
||||||
|
}
|
25
flake.nix
Normal file
25
flake.nix
Normal file
|
@ -0,0 +1,25 @@
|
||||||
|
{
|
||||||
|
description = "async-cpupool";
|
||||||
|
|
||||||
|
inputs = {
|
||||||
|
nixpkgs.url = "github:NixOS/nixpkgs/nixos-unstable";
|
||||||
|
flake-utils.url = "github:numtide/flake-utils";
|
||||||
|
};
|
||||||
|
|
||||||
|
outputs = { self, nixpkgs, flake-utils }:
|
||||||
|
flake-utils.lib.eachDefaultSystem (system:
|
||||||
|
let
|
||||||
|
pkgs = import nixpkgs {
|
||||||
|
inherit system;
|
||||||
|
};
|
||||||
|
in
|
||||||
|
{
|
||||||
|
packages.default = pkgs.hello;
|
||||||
|
|
||||||
|
devShell = with pkgs; mkShell {
|
||||||
|
nativeBuildInputs = [ cargo cargo-outdated clippy rust-analyzer rustc rustfmt ];
|
||||||
|
|
||||||
|
RUST_SRC_PATH = "${pkgs.rust.packages.stable.rustPlatform.rustLibSrc}";
|
||||||
|
};
|
||||||
|
});
|
||||||
|
}
|
343
src/lib.rs
Normal file
343
src/lib.rs
Normal file
|
@ -0,0 +1,343 @@
|
||||||
|
use std::{
|
||||||
|
num::{NonZeroU16, NonZeroUsize},
|
||||||
|
sync::{atomic::AtomicU64, Arc, Mutex},
|
||||||
|
thread::JoinHandle,
|
||||||
|
};
|
||||||
|
|
||||||
|
#[derive(Debug)]
|
||||||
|
pub struct Config {
|
||||||
|
buffer_size: NonZeroUsize,
|
||||||
|
min_threads: NonZeroU16,
|
||||||
|
max_threads: NonZeroU16,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl Config {
|
||||||
|
pub fn new() -> Self {
|
||||||
|
Config {
|
||||||
|
buffer_size: 8usize.try_into().expect("valid nonzero usize"),
|
||||||
|
min_threads: 1u16.try_into().expect("valid nonzero u16"),
|
||||||
|
max_threads: 2u16.try_into().expect("valid nonzero u16"),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn buffer_size(mut self, buffer_size: NonZeroUsize) -> Self {
|
||||||
|
self.buffer_size = buffer_size;
|
||||||
|
self
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn min_threads(mut self, min_threads: NonZeroU16) -> Self {
|
||||||
|
self.min_threads = min_threads;
|
||||||
|
self
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn max_threads(mut self, max_threads: NonZeroU16) -> Self {
|
||||||
|
self.max_threads = max_threads;
|
||||||
|
self
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn build(self) -> Result<CpuPool, ConfigError> {
|
||||||
|
if self.max_threads < self.min_threads {
|
||||||
|
return Err(ConfigError::ThreadCount);
|
||||||
|
}
|
||||||
|
|
||||||
|
Ok(CpuPool {
|
||||||
|
state: Arc::new(CpuPoolState::new(self)),
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl Default for Config {
|
||||||
|
fn default() -> Self {
|
||||||
|
Self::new()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Debug)]
|
||||||
|
pub enum ConfigError {
|
||||||
|
ThreadCount,
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Debug)]
|
||||||
|
pub struct Canceled;
|
||||||
|
|
||||||
|
pub struct CpuPool {
|
||||||
|
state: Arc<CpuPoolState>,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl CpuPool {
|
||||||
|
pub fn new() -> Self {
|
||||||
|
Self {
|
||||||
|
state: Arc::new(CpuPoolState::new(Config::default())),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn configure() -> Config {
|
||||||
|
Config::default()
|
||||||
|
}
|
||||||
|
|
||||||
|
pub async fn spawn<F, T>(&self, send_fn: F) -> Result<T, Canceled>
|
||||||
|
where
|
||||||
|
F: FnOnce() -> T + Send + 'static,
|
||||||
|
T: Send + 'static,
|
||||||
|
{
|
||||||
|
if self.state.sender.is_full() {
|
||||||
|
self.push_thread();
|
||||||
|
}
|
||||||
|
|
||||||
|
let (response_tx, response) = flume::bounded(1);
|
||||||
|
|
||||||
|
self.state
|
||||||
|
.sender
|
||||||
|
.send_async(Box::new(move || {
|
||||||
|
let output = (send_fn)();
|
||||||
|
|
||||||
|
if response_tx.send(output).is_err() {
|
||||||
|
tracing::trace!("Receiver hung up");
|
||||||
|
}
|
||||||
|
}))
|
||||||
|
.await
|
||||||
|
.expect("No receiver");
|
||||||
|
|
||||||
|
if self.state.sender.is_empty() {
|
||||||
|
if let Some(thread) = self.pop_thread() {
|
||||||
|
thread.reap().await;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
response.recv_async().await.map_err(|_| Canceled)
|
||||||
|
}
|
||||||
|
|
||||||
|
fn push_thread(&self) {
|
||||||
|
let current_threads = self
|
||||||
|
.state
|
||||||
|
.current_threads
|
||||||
|
.load(std::sync::atomic::Ordering::Acquire);
|
||||||
|
|
||||||
|
if current_threads >= u64::from(u16::from(self.state.max_threads)) {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
if self
|
||||||
|
.state
|
||||||
|
.current_threads
|
||||||
|
.compare_exchange(
|
||||||
|
current_threads,
|
||||||
|
current_threads + 1,
|
||||||
|
std::sync::atomic::Ordering::AcqRel,
|
||||||
|
std::sync::atomic::Ordering::Relaxed,
|
||||||
|
)
|
||||||
|
.is_err()
|
||||||
|
{
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
// we updated the count, so we have authorization to spawn a new thread
|
||||||
|
|
||||||
|
let thread_id = self
|
||||||
|
.state
|
||||||
|
.thread_id
|
||||||
|
.fetch_add(1, std::sync::atomic::Ordering::Relaxed);
|
||||||
|
|
||||||
|
let thread = spawn(thread_id, self.state.receiver.clone());
|
||||||
|
|
||||||
|
self.state.threads.lock().unwrap().push(thread);
|
||||||
|
}
|
||||||
|
|
||||||
|
fn pop_thread(&self) -> Option<Thread> {
|
||||||
|
let current_threads = self
|
||||||
|
.state
|
||||||
|
.current_threads
|
||||||
|
.load(std::sync::atomic::Ordering::Acquire);
|
||||||
|
|
||||||
|
if current_threads <= u64::from(u16::from(self.state.min_threads)) {
|
||||||
|
return None;
|
||||||
|
}
|
||||||
|
|
||||||
|
if self
|
||||||
|
.state
|
||||||
|
.current_threads
|
||||||
|
.compare_exchange(
|
||||||
|
current_threads,
|
||||||
|
current_threads - 1,
|
||||||
|
std::sync::atomic::Ordering::AcqRel,
|
||||||
|
std::sync::atomic::Ordering::Relaxed,
|
||||||
|
)
|
||||||
|
.is_err()
|
||||||
|
{
|
||||||
|
return None;
|
||||||
|
}
|
||||||
|
|
||||||
|
// we updated the count, so we have authorization to reap a thread
|
||||||
|
|
||||||
|
self.state.threads.lock().unwrap().pop()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl Default for CpuPool {
|
||||||
|
fn default() -> Self {
|
||||||
|
Self::new()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
type SendFn = Box<dyn FnOnce() + Send>;
|
||||||
|
|
||||||
|
struct CpuPoolState {
|
||||||
|
min_threads: NonZeroU16,
|
||||||
|
max_threads: NonZeroU16,
|
||||||
|
current_threads: AtomicU64,
|
||||||
|
thread_id: AtomicU64,
|
||||||
|
sender: flume::Sender<SendFn>,
|
||||||
|
receiver: flume::Receiver<SendFn>,
|
||||||
|
threads: Mutex<ThreadVec>,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl CpuPoolState {
|
||||||
|
fn new(
|
||||||
|
Config {
|
||||||
|
buffer_size,
|
||||||
|
min_threads,
|
||||||
|
max_threads,
|
||||||
|
}: Config,
|
||||||
|
) -> Self {
|
||||||
|
let (sender, receiver) = flume::bounded(usize::from(buffer_size));
|
||||||
|
|
||||||
|
let start_threads = u64::from(u16::from(min_threads));
|
||||||
|
|
||||||
|
let threads = ThreadVec::new(start_threads, usize::from(u16::from(max_threads)), |i| {
|
||||||
|
spawn(i, receiver.clone())
|
||||||
|
});
|
||||||
|
|
||||||
|
let current_threads = AtomicU64::new(start_threads);
|
||||||
|
let thread_id = AtomicU64::new(start_threads);
|
||||||
|
|
||||||
|
CpuPoolState {
|
||||||
|
min_threads,
|
||||||
|
max_threads,
|
||||||
|
current_threads,
|
||||||
|
thread_id,
|
||||||
|
sender,
|
||||||
|
receiver,
|
||||||
|
threads: Mutex::new(threads),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
struct ThreadVec {
|
||||||
|
threads: Vec<Thread>,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl ThreadVec {
|
||||||
|
fn new<F>(start_threads: u64, max_threads: usize, spawn: F) -> Self
|
||||||
|
where
|
||||||
|
F: Fn(u64) -> Thread,
|
||||||
|
{
|
||||||
|
let mut threads = Vec::with_capacity(max_threads);
|
||||||
|
|
||||||
|
for i in 0..start_threads {
|
||||||
|
threads.push((spawn)(i));
|
||||||
|
}
|
||||||
|
|
||||||
|
Self { threads }
|
||||||
|
}
|
||||||
|
|
||||||
|
fn push(&mut self, thread: Thread) {
|
||||||
|
self.threads.push(thread);
|
||||||
|
}
|
||||||
|
|
||||||
|
fn pop(&mut self) -> Option<Thread> {
|
||||||
|
self.threads.pop()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl Drop for ThreadVec {
|
||||||
|
fn drop(&mut self) {
|
||||||
|
for thread in &mut self.threads {
|
||||||
|
thread.signal.take();
|
||||||
|
}
|
||||||
|
|
||||||
|
for thread in &mut self.threads {
|
||||||
|
if let Some(handle) = thread.handle.take() {
|
||||||
|
handle.join().expect("Thread panicked");
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
struct Thread {
|
||||||
|
handle: Option<JoinHandle<()>>,
|
||||||
|
signal: Option<SendOnDrop>,
|
||||||
|
closed: flume::Receiver<()>,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl Thread {
|
||||||
|
async fn reap(mut self) {
|
||||||
|
self.signal.take();
|
||||||
|
|
||||||
|
let _ = self.closed.recv_async().await;
|
||||||
|
|
||||||
|
if let Some(handle) = self.handle.take() {
|
||||||
|
handle.join().expect("Thread panicked");
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
struct SendOnDrop {
|
||||||
|
sender: flume::Sender<()>,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl Drop for SendOnDrop {
|
||||||
|
fn drop(&mut self) {
|
||||||
|
let _ = self.sender.try_send(());
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
fn spawn(i: u64, receiver: flume::Receiver<SendFn>) -> Thread {
|
||||||
|
let (signal, signal_rx) = flume::bounded(1);
|
||||||
|
let (closed_tx, closed) = flume::bounded(1);
|
||||||
|
|
||||||
|
let signal = SendOnDrop { sender: signal };
|
||||||
|
let closed_tx = SendOnDrop { sender: closed_tx };
|
||||||
|
|
||||||
|
let handle = std::thread::Builder::new()
|
||||||
|
.name(format!("cpupool-{i}"))
|
||||||
|
.spawn(move || run(receiver, signal_rx, closed_tx))
|
||||||
|
.expect("Failed to spawn new thread");
|
||||||
|
|
||||||
|
Thread {
|
||||||
|
handle: Some(handle),
|
||||||
|
signal: Some(signal),
|
||||||
|
closed,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
fn run(receiver: flume::Receiver<SendFn>, signal_rx: flume::Receiver<()>, closed_tx: SendOnDrop) {
|
||||||
|
loop {
|
||||||
|
let bail = flume::Selector::new()
|
||||||
|
.recv(&receiver, |res| {
|
||||||
|
if let Ok(send_fn) = res {
|
||||||
|
invoke_send_fn(send_fn);
|
||||||
|
false
|
||||||
|
} else {
|
||||||
|
true
|
||||||
|
}
|
||||||
|
})
|
||||||
|
.recv(&signal_rx, |_res| true)
|
||||||
|
.wait();
|
||||||
|
|
||||||
|
if bail {
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
drop(closed_tx);
|
||||||
|
}
|
||||||
|
|
||||||
|
fn invoke_send_fn(send_fn: SendFn) {
|
||||||
|
let res = std::panic::catch_unwind(std::panic::AssertUnwindSafe(move || {
|
||||||
|
(send_fn)();
|
||||||
|
}));
|
||||||
|
|
||||||
|
if res.is_err() {
|
||||||
|
tracing::trace!("panic in spawned task");
|
||||||
|
}
|
||||||
|
}
|
Loading…
Reference in a new issue