Simple CPU Pool
This commit is contained in:
commit
eec9ab16fd
4
.gitignore
vendored
Normal file
4
.gitignore
vendored
Normal file
|
@ -0,0 +1,4 @@
|
|||
/target
|
||||
/Cargo.lock
|
||||
/.direnv
|
||||
/.envrc
|
10
Cargo.toml
Normal file
10
Cargo.toml
Normal file
|
@ -0,0 +1,10 @@
|
|||
[package]
|
||||
name = "async-cpupool"
|
||||
version = "0.1.0"
|
||||
edition = "2021"
|
||||
|
||||
# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html
|
||||
|
||||
[dependencies]
|
||||
flume = "0.11.0"
|
||||
tracing = "0.1.40"
|
61
flake.lock
Normal file
61
flake.lock
Normal file
|
@ -0,0 +1,61 @@
|
|||
{
|
||||
"nodes": {
|
||||
"flake-utils": {
|
||||
"inputs": {
|
||||
"systems": "systems"
|
||||
},
|
||||
"locked": {
|
||||
"lastModified": 1694529238,
|
||||
"narHash": "sha256-zsNZZGTGnMOf9YpHKJqMSsa0dXbfmxeoJ7xHlrt+xmY=",
|
||||
"owner": "numtide",
|
||||
"repo": "flake-utils",
|
||||
"rev": "ff7b65b44d01cf9ba6a71320833626af21126384",
|
||||
"type": "github"
|
||||
},
|
||||
"original": {
|
||||
"owner": "numtide",
|
||||
"repo": "flake-utils",
|
||||
"type": "github"
|
||||
}
|
||||
},
|
||||
"nixpkgs": {
|
||||
"locked": {
|
||||
"lastModified": 1700612854,
|
||||
"narHash": "sha256-yrQ8osMD+vDLGFX7pcwsY/Qr5PUd6OmDMYJZzZi0+zc=",
|
||||
"owner": "NixOS",
|
||||
"repo": "nixpkgs",
|
||||
"rev": "19cbff58383a4ae384dea4d1d0c823d72b49d614",
|
||||
"type": "github"
|
||||
},
|
||||
"original": {
|
||||
"owner": "NixOS",
|
||||
"ref": "nixos-unstable",
|
||||
"repo": "nixpkgs",
|
||||
"type": "github"
|
||||
}
|
||||
},
|
||||
"root": {
|
||||
"inputs": {
|
||||
"flake-utils": "flake-utils",
|
||||
"nixpkgs": "nixpkgs"
|
||||
}
|
||||
},
|
||||
"systems": {
|
||||
"locked": {
|
||||
"lastModified": 1681028828,
|
||||
"narHash": "sha256-Vy1rq5AaRuLzOxct8nz4T6wlgyUR7zLU309k9mBC768=",
|
||||
"owner": "nix-systems",
|
||||
"repo": "default",
|
||||
"rev": "da67096a3b9bf56a91d16901293e51ba5b49a27e",
|
||||
"type": "github"
|
||||
},
|
||||
"original": {
|
||||
"owner": "nix-systems",
|
||||
"repo": "default",
|
||||
"type": "github"
|
||||
}
|
||||
}
|
||||
},
|
||||
"root": "root",
|
||||
"version": 7
|
||||
}
|
25
flake.nix
Normal file
25
flake.nix
Normal file
|
@ -0,0 +1,25 @@
|
|||
{
|
||||
description = "async-cpupool";
|
||||
|
||||
inputs = {
|
||||
nixpkgs.url = "github:NixOS/nixpkgs/nixos-unstable";
|
||||
flake-utils.url = "github:numtide/flake-utils";
|
||||
};
|
||||
|
||||
outputs = { self, nixpkgs, flake-utils }:
|
||||
flake-utils.lib.eachDefaultSystem (system:
|
||||
let
|
||||
pkgs = import nixpkgs {
|
||||
inherit system;
|
||||
};
|
||||
in
|
||||
{
|
||||
packages.default = pkgs.hello;
|
||||
|
||||
devShell = with pkgs; mkShell {
|
||||
nativeBuildInputs = [ cargo cargo-outdated clippy rust-analyzer rustc rustfmt ];
|
||||
|
||||
RUST_SRC_PATH = "${pkgs.rust.packages.stable.rustPlatform.rustLibSrc}";
|
||||
};
|
||||
});
|
||||
}
|
343
src/lib.rs
Normal file
343
src/lib.rs
Normal file
|
@ -0,0 +1,343 @@
|
|||
use std::{
|
||||
num::{NonZeroU16, NonZeroUsize},
|
||||
sync::{atomic::AtomicU64, Arc, Mutex},
|
||||
thread::JoinHandle,
|
||||
};
|
||||
|
||||
#[derive(Debug)]
|
||||
pub struct Config {
|
||||
buffer_size: NonZeroUsize,
|
||||
min_threads: NonZeroU16,
|
||||
max_threads: NonZeroU16,
|
||||
}
|
||||
|
||||
impl Config {
|
||||
pub fn new() -> Self {
|
||||
Config {
|
||||
buffer_size: 8usize.try_into().expect("valid nonzero usize"),
|
||||
min_threads: 1u16.try_into().expect("valid nonzero u16"),
|
||||
max_threads: 2u16.try_into().expect("valid nonzero u16"),
|
||||
}
|
||||
}
|
||||
|
||||
pub fn buffer_size(mut self, buffer_size: NonZeroUsize) -> Self {
|
||||
self.buffer_size = buffer_size;
|
||||
self
|
||||
}
|
||||
|
||||
pub fn min_threads(mut self, min_threads: NonZeroU16) -> Self {
|
||||
self.min_threads = min_threads;
|
||||
self
|
||||
}
|
||||
|
||||
pub fn max_threads(mut self, max_threads: NonZeroU16) -> Self {
|
||||
self.max_threads = max_threads;
|
||||
self
|
||||
}
|
||||
|
||||
pub fn build(self) -> Result<CpuPool, ConfigError> {
|
||||
if self.max_threads < self.min_threads {
|
||||
return Err(ConfigError::ThreadCount);
|
||||
}
|
||||
|
||||
Ok(CpuPool {
|
||||
state: Arc::new(CpuPoolState::new(self)),
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
impl Default for Config {
|
||||
fn default() -> Self {
|
||||
Self::new()
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
pub enum ConfigError {
|
||||
ThreadCount,
|
||||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
pub struct Canceled;
|
||||
|
||||
pub struct CpuPool {
|
||||
state: Arc<CpuPoolState>,
|
||||
}
|
||||
|
||||
impl CpuPool {
|
||||
pub fn new() -> Self {
|
||||
Self {
|
||||
state: Arc::new(CpuPoolState::new(Config::default())),
|
||||
}
|
||||
}
|
||||
|
||||
pub fn configure() -> Config {
|
||||
Config::default()
|
||||
}
|
||||
|
||||
pub async fn spawn<F, T>(&self, send_fn: F) -> Result<T, Canceled>
|
||||
where
|
||||
F: FnOnce() -> T + Send + 'static,
|
||||
T: Send + 'static,
|
||||
{
|
||||
if self.state.sender.is_full() {
|
||||
self.push_thread();
|
||||
}
|
||||
|
||||
let (response_tx, response) = flume::bounded(1);
|
||||
|
||||
self.state
|
||||
.sender
|
||||
.send_async(Box::new(move || {
|
||||
let output = (send_fn)();
|
||||
|
||||
if response_tx.send(output).is_err() {
|
||||
tracing::trace!("Receiver hung up");
|
||||
}
|
||||
}))
|
||||
.await
|
||||
.expect("No receiver");
|
||||
|
||||
if self.state.sender.is_empty() {
|
||||
if let Some(thread) = self.pop_thread() {
|
||||
thread.reap().await;
|
||||
}
|
||||
}
|
||||
|
||||
response.recv_async().await.map_err(|_| Canceled)
|
||||
}
|
||||
|
||||
fn push_thread(&self) {
|
||||
let current_threads = self
|
||||
.state
|
||||
.current_threads
|
||||
.load(std::sync::atomic::Ordering::Acquire);
|
||||
|
||||
if current_threads >= u64::from(u16::from(self.state.max_threads)) {
|
||||
return;
|
||||
}
|
||||
|
||||
if self
|
||||
.state
|
||||
.current_threads
|
||||
.compare_exchange(
|
||||
current_threads,
|
||||
current_threads + 1,
|
||||
std::sync::atomic::Ordering::AcqRel,
|
||||
std::sync::atomic::Ordering::Relaxed,
|
||||
)
|
||||
.is_err()
|
||||
{
|
||||
return;
|
||||
}
|
||||
|
||||
// we updated the count, so we have authorization to spawn a new thread
|
||||
|
||||
let thread_id = self
|
||||
.state
|
||||
.thread_id
|
||||
.fetch_add(1, std::sync::atomic::Ordering::Relaxed);
|
||||
|
||||
let thread = spawn(thread_id, self.state.receiver.clone());
|
||||
|
||||
self.state.threads.lock().unwrap().push(thread);
|
||||
}
|
||||
|
||||
fn pop_thread(&self) -> Option<Thread> {
|
||||
let current_threads = self
|
||||
.state
|
||||
.current_threads
|
||||
.load(std::sync::atomic::Ordering::Acquire);
|
||||
|
||||
if current_threads <= u64::from(u16::from(self.state.min_threads)) {
|
||||
return None;
|
||||
}
|
||||
|
||||
if self
|
||||
.state
|
||||
.current_threads
|
||||
.compare_exchange(
|
||||
current_threads,
|
||||
current_threads - 1,
|
||||
std::sync::atomic::Ordering::AcqRel,
|
||||
std::sync::atomic::Ordering::Relaxed,
|
||||
)
|
||||
.is_err()
|
||||
{
|
||||
return None;
|
||||
}
|
||||
|
||||
// we updated the count, so we have authorization to reap a thread
|
||||
|
||||
self.state.threads.lock().unwrap().pop()
|
||||
}
|
||||
}
|
||||
|
||||
impl Default for CpuPool {
|
||||
fn default() -> Self {
|
||||
Self::new()
|
||||
}
|
||||
}
|
||||
|
||||
type SendFn = Box<dyn FnOnce() + Send>;
|
||||
|
||||
struct CpuPoolState {
|
||||
min_threads: NonZeroU16,
|
||||
max_threads: NonZeroU16,
|
||||
current_threads: AtomicU64,
|
||||
thread_id: AtomicU64,
|
||||
sender: flume::Sender<SendFn>,
|
||||
receiver: flume::Receiver<SendFn>,
|
||||
threads: Mutex<ThreadVec>,
|
||||
}
|
||||
|
||||
impl CpuPoolState {
|
||||
fn new(
|
||||
Config {
|
||||
buffer_size,
|
||||
min_threads,
|
||||
max_threads,
|
||||
}: Config,
|
||||
) -> Self {
|
||||
let (sender, receiver) = flume::bounded(usize::from(buffer_size));
|
||||
|
||||
let start_threads = u64::from(u16::from(min_threads));
|
||||
|
||||
let threads = ThreadVec::new(start_threads, usize::from(u16::from(max_threads)), |i| {
|
||||
spawn(i, receiver.clone())
|
||||
});
|
||||
|
||||
let current_threads = AtomicU64::new(start_threads);
|
||||
let thread_id = AtomicU64::new(start_threads);
|
||||
|
||||
CpuPoolState {
|
||||
min_threads,
|
||||
max_threads,
|
||||
current_threads,
|
||||
thread_id,
|
||||
sender,
|
||||
receiver,
|
||||
threads: Mutex::new(threads),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
struct ThreadVec {
|
||||
threads: Vec<Thread>,
|
||||
}
|
||||
|
||||
impl ThreadVec {
|
||||
fn new<F>(start_threads: u64, max_threads: usize, spawn: F) -> Self
|
||||
where
|
||||
F: Fn(u64) -> Thread,
|
||||
{
|
||||
let mut threads = Vec::with_capacity(max_threads);
|
||||
|
||||
for i in 0..start_threads {
|
||||
threads.push((spawn)(i));
|
||||
}
|
||||
|
||||
Self { threads }
|
||||
}
|
||||
|
||||
fn push(&mut self, thread: Thread) {
|
||||
self.threads.push(thread);
|
||||
}
|
||||
|
||||
fn pop(&mut self) -> Option<Thread> {
|
||||
self.threads.pop()
|
||||
}
|
||||
}
|
||||
|
||||
impl Drop for ThreadVec {
|
||||
fn drop(&mut self) {
|
||||
for thread in &mut self.threads {
|
||||
thread.signal.take();
|
||||
}
|
||||
|
||||
for thread in &mut self.threads {
|
||||
if let Some(handle) = thread.handle.take() {
|
||||
handle.join().expect("Thread panicked");
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
struct Thread {
|
||||
handle: Option<JoinHandle<()>>,
|
||||
signal: Option<SendOnDrop>,
|
||||
closed: flume::Receiver<()>,
|
||||
}
|
||||
|
||||
impl Thread {
|
||||
async fn reap(mut self) {
|
||||
self.signal.take();
|
||||
|
||||
let _ = self.closed.recv_async().await;
|
||||
|
||||
if let Some(handle) = self.handle.take() {
|
||||
handle.join().expect("Thread panicked");
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
struct SendOnDrop {
|
||||
sender: flume::Sender<()>,
|
||||
}
|
||||
|
||||
impl Drop for SendOnDrop {
|
||||
fn drop(&mut self) {
|
||||
let _ = self.sender.try_send(());
|
||||
}
|
||||
}
|
||||
|
||||
fn spawn(i: u64, receiver: flume::Receiver<SendFn>) -> Thread {
|
||||
let (signal, signal_rx) = flume::bounded(1);
|
||||
let (closed_tx, closed) = flume::bounded(1);
|
||||
|
||||
let signal = SendOnDrop { sender: signal };
|
||||
let closed_tx = SendOnDrop { sender: closed_tx };
|
||||
|
||||
let handle = std::thread::Builder::new()
|
||||
.name(format!("cpupool-{i}"))
|
||||
.spawn(move || run(receiver, signal_rx, closed_tx))
|
||||
.expect("Failed to spawn new thread");
|
||||
|
||||
Thread {
|
||||
handle: Some(handle),
|
||||
signal: Some(signal),
|
||||
closed,
|
||||
}
|
||||
}
|
||||
|
||||
fn run(receiver: flume::Receiver<SendFn>, signal_rx: flume::Receiver<()>, closed_tx: SendOnDrop) {
|
||||
loop {
|
||||
let bail = flume::Selector::new()
|
||||
.recv(&receiver, |res| {
|
||||
if let Ok(send_fn) = res {
|
||||
invoke_send_fn(send_fn);
|
||||
false
|
||||
} else {
|
||||
true
|
||||
}
|
||||
})
|
||||
.recv(&signal_rx, |_res| true)
|
||||
.wait();
|
||||
|
||||
if bail {
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
drop(closed_tx);
|
||||
}
|
||||
|
||||
fn invoke_send_fn(send_fn: SendFn) {
|
||||
let res = std::panic::catch_unwind(std::panic::AssertUnwindSafe(move || {
|
||||
(send_fn)();
|
||||
}));
|
||||
|
||||
if res.is_err() {
|
||||
tracing::trace!("panic in spawned task");
|
||||
}
|
||||
}
|
Loading…
Reference in a new issue