500 lines
12 KiB
Rust
500 lines
12 KiB
Rust
mod selector;
|
|
|
|
use std::{
|
|
num::{NonZeroU16, NonZeroUsize},
|
|
sync::{atomic::AtomicU64, Arc, Mutex},
|
|
thread::JoinHandle,
|
|
time::Instant,
|
|
};
|
|
|
|
/// Builder-style settings for a [`CpuPool`]; finish with [`Config::build`].
#[derive(Debug)]
pub struct Config {
    /// Prefix for worker thread names and metric keys.
    name: &'static str,
    /// Queued jobs allowed per potential worker (queue capacity = this × `max_threads`).
    buffer_multiplier: usize,
    /// Fewest worker threads the pool keeps.
    min_threads: u16,
    /// Most worker threads the pool may run.
    max_threads: u16,
}
|
|
|
|
impl Config {
|
|
pub fn new() -> Self {
|
|
Config {
|
|
name: "cpupool",
|
|
buffer_multiplier: 8,
|
|
min_threads: 1,
|
|
max_threads: 2,
|
|
}
|
|
}
|
|
|
|
pub fn name(mut self, name: &'static str) -> Self {
|
|
self.name = name;
|
|
self
|
|
}
|
|
|
|
pub fn buffer_multiplier(mut self, buffer_multiplier: usize) -> Self {
|
|
self.buffer_multiplier = buffer_multiplier;
|
|
self
|
|
}
|
|
|
|
pub fn min_threads(mut self, min_threads: u16) -> Self {
|
|
self.min_threads = min_threads;
|
|
self
|
|
}
|
|
|
|
pub fn max_threads(mut self, max_threads: u16) -> Self {
|
|
self.max_threads = max_threads;
|
|
self
|
|
}
|
|
|
|
pub fn build(self) -> Result<CpuPool, ConfigError> {
|
|
let Config {
|
|
name,
|
|
buffer_multiplier,
|
|
min_threads,
|
|
max_threads,
|
|
} = self;
|
|
|
|
if max_threads < min_threads {
|
|
return Err(ConfigError::ThreadCount);
|
|
}
|
|
|
|
let buffer_multiplier = buffer_multiplier
|
|
.try_into()
|
|
.map_err(|_| ConfigError::BufferMultiplier)?;
|
|
|
|
let max_threads = max_threads
|
|
.try_into()
|
|
.map_err(|_| ConfigError::MaxThreads)?;
|
|
|
|
let min_threads = min_threads
|
|
.try_into()
|
|
.map_err(|_| ConfigError::MinThreads)?;
|
|
|
|
Ok(CpuPool {
|
|
state: Arc::new(CpuPoolState::new(
|
|
name,
|
|
buffer_multiplier,
|
|
min_threads,
|
|
max_threads,
|
|
)),
|
|
})
|
|
}
|
|
}
|
|
|
|
impl Default for Config {
|
|
fn default() -> Self {
|
|
Self::new()
|
|
}
|
|
}
|
|
|
|
/// Reasons [`Config::build`] can reject a configuration.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum ConfigError {
    /// `max_threads` was less than `min_threads`.
    ThreadCount,
    /// `buffer_multiplier` was zero.
    BufferMultiplier,
    /// `max_threads` was zero.
    MaxThreads,
    /// `min_threads` was zero.
    MinThreads,
}

impl std::fmt::Display for ConfigError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let message = match self {
            Self::ThreadCount => "max_threads must not be less than min_threads",
            Self::BufferMultiplier => "buffer_multiplier must be greater than zero",
            Self::MaxThreads => "max_threads must be greater than zero",
            Self::MinThreads => "min_threads must be greater than zero",
        };

        f.write_str(message)
    }
}

// Lets callers propagate the error with `?` into `Box<dyn Error>` and similar.
impl std::error::Error for ConfigError {}
|
|
|
|
/// Error from [`CpuPool::spawn`] when a job's result could not be delivered.
///
/// This is returned when the job queue is disconnected, or when the job's result channel
/// closes without a value (for example because the job panicked).
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct Canceled;

impl std::fmt::Display for Canceled {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.write_str("the job was canceled before producing a result")
    }
}

// Lets callers propagate the error with `?` into `Box<dyn Error>` and similar.
impl std::error::Error for Canceled {}
|
|
|
|
#[derive(Clone)]
|
|
pub struct CpuPool {
|
|
state: Arc<CpuPoolState>,
|
|
}
|
|
|
|
impl CpuPool {
|
|
pub fn new() -> Self {
|
|
Config::default().build().expect("Defaults are valid")
|
|
}
|
|
|
|
pub fn configure() -> Config {
|
|
Config::default()
|
|
}
|
|
|
|
pub async fn spawn<F, T>(&self, send_fn: F) -> Result<T, Canceled>
|
|
where
|
|
F: FnOnce() -> T + Send + 'static,
|
|
T: Send + 'static,
|
|
{
|
|
let (response_tx, response) = flume::bounded(1);
|
|
|
|
let send_fn = Box::new(move || {
|
|
let output = (send_fn)();
|
|
|
|
if response_tx.send(output).is_err() {
|
|
tracing::trace!("Receiver hung up");
|
|
}
|
|
});
|
|
|
|
let res = self.state.sender.try_send(send_fn);
|
|
|
|
let current_threads = self
|
|
.state
|
|
.current_threads
|
|
.load(std::sync::atomic::Ordering::Acquire);
|
|
|
|
if self.state.sender_len() > current_threads || self.state.sender.is_full() {
|
|
if self.push_thread() {
|
|
tracing::trace!("Pushed thread");
|
|
}
|
|
}
|
|
|
|
match res {
|
|
Err(flume::TrySendError::Full(item)) => {
|
|
if self.state.sender.send_async(item).await.is_err() {
|
|
return Err(Canceled);
|
|
}
|
|
}
|
|
Err(flume::TrySendError::Disconnected(_)) => {
|
|
return Err(Canceled);
|
|
}
|
|
Ok(()) => {}
|
|
}
|
|
|
|
let current_threads = self
|
|
.state
|
|
.current_threads
|
|
.load(std::sync::atomic::Ordering::Acquire);
|
|
|
|
let sender_len = self.state.sender_len();
|
|
|
|
if sender_len > current_threads || self.state.sender.is_full() {
|
|
self.push_thread();
|
|
} else if sender_len < u64::from(current_threads.ilog2()) {
|
|
if let Some(thread) = self.pop_thread() {
|
|
thread.reap().await;
|
|
}
|
|
}
|
|
|
|
response.into_recv_async().await.map_err(|_| Canceled)
|
|
}
|
|
|
|
pub async fn close(self) -> bool {
|
|
let Some(mut state) = Arc::into_inner(self.state) else {
|
|
return false;
|
|
};
|
|
|
|
let mut threads = state.take_threads();
|
|
|
|
for thread in &mut threads {
|
|
thread.signal.take();
|
|
}
|
|
|
|
for thread in &mut threads {
|
|
let _ = thread.closed.recv_async().await;
|
|
|
|
if let Some(handle) = thread.handle.take() {
|
|
handle.join().expect("Thread panicked");
|
|
}
|
|
}
|
|
|
|
true
|
|
}
|
|
|
|
fn push_thread(&self) -> bool {
|
|
let current_threads = self
|
|
.state
|
|
.current_threads
|
|
.load(std::sync::atomic::Ordering::Acquire);
|
|
|
|
if current_threads >= u64::from(u16::from(self.state.max_threads)) {
|
|
tracing::trace!("At thread maximum");
|
|
|
|
return false;
|
|
}
|
|
|
|
if self
|
|
.state
|
|
.current_threads
|
|
.compare_exchange(
|
|
current_threads,
|
|
current_threads + 1,
|
|
std::sync::atomic::Ordering::AcqRel,
|
|
std::sync::atomic::Ordering::Relaxed,
|
|
)
|
|
.is_err()
|
|
{
|
|
tracing::trace!("Didn't acquire spawn authorization");
|
|
|
|
return false;
|
|
}
|
|
|
|
// we updated the count, so we have authorization to spawn a new thread
|
|
|
|
let thread_id = self
|
|
.state
|
|
.thread_id
|
|
.fetch_add(1, std::sync::atomic::Ordering::Relaxed);
|
|
|
|
let thread = spawn(self.state.name, thread_id, self.state.receiver.clone());
|
|
|
|
self.state
|
|
.threads
|
|
.lock()
|
|
.expect("threads lock poison")
|
|
.push(thread);
|
|
|
|
true
|
|
}
|
|
|
|
fn pop_thread(&self) -> Option<Thread> {
|
|
let current_threads = self
|
|
.state
|
|
.current_threads
|
|
.load(std::sync::atomic::Ordering::Acquire);
|
|
|
|
if current_threads <= u64::from(u16::from(self.state.min_threads)) {
|
|
tracing::info!("At thread minimum");
|
|
|
|
return None;
|
|
}
|
|
|
|
if self
|
|
.state
|
|
.current_threads
|
|
.compare_exchange(
|
|
current_threads,
|
|
current_threads - 1,
|
|
std::sync::atomic::Ordering::AcqRel,
|
|
std::sync::atomic::Ordering::Relaxed,
|
|
)
|
|
.is_err()
|
|
{
|
|
tracing::trace!("Didn't acquire reap authorization");
|
|
|
|
return None;
|
|
}
|
|
|
|
// we updated the count, so we have authorization to reap a thread
|
|
|
|
self.state
|
|
.threads
|
|
.lock()
|
|
.expect("threads lock poison")
|
|
.pop()
|
|
}
|
|
}
|
|
|
|
impl Default for CpuPool {
|
|
fn default() -> Self {
|
|
Self::new()
|
|
}
|
|
}
|
|
|
|
/// A type-erased job queued for a worker thread: run once, movable across threads.
type SendFn = Box<dyn Send + FnOnce()>;
|
|
|
|
struct CpuPoolState {
|
|
name: &'static str,
|
|
min_threads: NonZeroU16,
|
|
max_threads: NonZeroU16,
|
|
current_threads: AtomicU64,
|
|
thread_id: AtomicU64,
|
|
sender: flume::Sender<SendFn>,
|
|
receiver: flume::Receiver<SendFn>,
|
|
threads: Mutex<ThreadVec>,
|
|
}
|
|
|
|
impl CpuPoolState {
|
|
fn new(
|
|
name: &'static str,
|
|
buffer_multiplier: NonZeroUsize,
|
|
min_threads: NonZeroU16,
|
|
max_threads: NonZeroU16,
|
|
) -> Self {
|
|
let thread_capacity = usize::from(u16::from(max_threads));
|
|
|
|
let (sender, receiver) =
|
|
flume::bounded(usize::from(buffer_multiplier).saturating_mul(thread_capacity));
|
|
|
|
let start_threads = u64::from(u16::from(min_threads));
|
|
|
|
let threads = ThreadVec::new(start_threads, thread_capacity, |i| {
|
|
spawn(name, i, receiver.clone())
|
|
});
|
|
|
|
let current_threads = AtomicU64::new(start_threads);
|
|
let thread_id = AtomicU64::new(start_threads);
|
|
|
|
CpuPoolState {
|
|
name,
|
|
min_threads,
|
|
max_threads,
|
|
current_threads,
|
|
thread_id,
|
|
sender,
|
|
receiver,
|
|
threads: Mutex::new(threads),
|
|
}
|
|
}
|
|
|
|
fn take_threads(&mut self) -> Vec<Thread> {
|
|
self.threads.lock().expect("threads lock poison").take()
|
|
}
|
|
|
|
fn sender_len(&self) -> u64 {
|
|
u64::try_from(self.sender.len()).unwrap_or(u64::MAX)
|
|
}
|
|
}
|
|
|
|
struct ThreadVec {
|
|
threads: Vec<Thread>,
|
|
}
|
|
|
|
impl ThreadVec {
|
|
fn new<F>(start_threads: u64, max_threads: usize, spawn: F) -> Self
|
|
where
|
|
F: Fn(u64) -> Thread,
|
|
{
|
|
let mut threads = Vec::with_capacity(max_threads);
|
|
|
|
for i in 0..start_threads {
|
|
threads.push((spawn)(i));
|
|
}
|
|
|
|
Self { threads }
|
|
}
|
|
|
|
fn push(&mut self, thread: Thread) {
|
|
self.threads.push(thread);
|
|
}
|
|
|
|
fn pop(&mut self) -> Option<Thread> {
|
|
self.threads.pop()
|
|
}
|
|
|
|
fn take(&mut self) -> Vec<Thread> {
|
|
std::mem::take(&mut self.threads)
|
|
}
|
|
}
|
|
|
|
impl Drop for ThreadVec {
|
|
fn drop(&mut self) {
|
|
for thread in &mut self.threads {
|
|
thread.signal.take();
|
|
}
|
|
|
|
for thread in &mut self.threads {
|
|
if let Some(handle) = thread.handle.take() {
|
|
handle.join().expect("Thread panicked");
|
|
}
|
|
}
|
|
}
|
|
}
|
|
|
|
struct Thread {
|
|
handle: Option<JoinHandle<()>>,
|
|
signal: Option<SendOnDrop>,
|
|
closed: flume::Receiver<()>,
|
|
}
|
|
|
|
impl Thread {
|
|
async fn reap(mut self) {
|
|
self.signal.take();
|
|
|
|
let _ = self.closed.recv_async().await;
|
|
|
|
if let Some(handle) = self.handle.take() {
|
|
handle.join().expect("Thread panicked");
|
|
}
|
|
}
|
|
}
|
|
|
|
struct SendOnDrop {
|
|
sender: flume::Sender<()>,
|
|
}
|
|
|
|
impl Drop for SendOnDrop {
|
|
fn drop(&mut self) {
|
|
let _ = self.sender.try_send(());
|
|
}
|
|
}
|
|
|
|
fn spawn(name: &'static str, id: u64, receiver: flume::Receiver<SendFn>) -> Thread {
|
|
let (closed_tx, closed) = flume::bounded(1);
|
|
let (signal, signal_rx) = flume::bounded(1);
|
|
|
|
let signal = SendOnDrop { sender: signal };
|
|
let closed_tx = SendOnDrop { sender: closed_tx };
|
|
|
|
let handle = std::thread::Builder::new()
|
|
.name(format!("{name}-{id}"))
|
|
.spawn(move || run(name, id, receiver, signal_rx, closed_tx))
|
|
.expect("Failed to spawn new thread");
|
|
|
|
Thread {
|
|
handle: Some(handle),
|
|
signal: Some(signal),
|
|
closed,
|
|
}
|
|
}
|
|
|
|
/// Records launch/close metrics around a worker's lifetime.
///
/// The guard starts armed. Its drop-time metrics label the shutdown as clean only if it
/// was disarmed first.
struct MetricsGuard {
    /// Pool name used in metric keys.
    name: &'static str,
    /// Worker id, used as a metric label.
    id: u64,
    /// When the worker started, for the lifetime histogram.
    start: Instant,
    /// `true` until `disarm` is called.
    armed: bool,
}
|
|
|
|
impl MetricsGuard {
|
|
fn guard(name: &'static str, id: u64) -> Self {
|
|
tracing::trace!("Starting {name}-{id}");
|
|
metrics::increment_counter!(format!("async-cpupool.{name}.launched"), "id" => id.to_string());
|
|
|
|
MetricsGuard {
|
|
name,
|
|
id,
|
|
start: Instant::now(),
|
|
armed: true,
|
|
}
|
|
}
|
|
|
|
fn disarm(mut self) {
|
|
self.armed = false;
|
|
}
|
|
}
|
|
|
|
impl Drop for MetricsGuard {
|
|
fn drop(&mut self) {
|
|
metrics::increment_counter!(format!("async-cpupool.{}.closed", self.name), "clean" => (!self.armed).to_string(), "id" => self.id.to_string());
|
|
metrics::histogram!(format!("async-cpupool.{}.duration", self.name), self.start.elapsed().as_secs_f64(), "clean" => (!self.armed).to_string(), "id" => self.id.to_string());
|
|
tracing::trace!("Stopping {}-{}", self.name, self.id);
|
|
}
|
|
}
|
|
|
|
fn run(
|
|
name: &'static str,
|
|
id: u64,
|
|
receiver: flume::Receiver<SendFn>,
|
|
signal: flume::Receiver<()>,
|
|
closed_tx: SendOnDrop,
|
|
) {
|
|
let guard = MetricsGuard::guard(name, id);
|
|
|
|
loop {
|
|
match selector::blocking_select(signal.recv_async(), receiver.recv_async()) {
|
|
selector::Either::Left(_) | selector::Either::Right(Err(_)) => break,
|
|
selector::Either::Right(Ok(send_fn)) => invoke_send_fn(name, id, send_fn),
|
|
}
|
|
}
|
|
|
|
guard.disarm();
|
|
|
|
drop(closed_tx);
|
|
}
|
|
|
|
fn invoke_send_fn(name: &'static str, id: u64, send_fn: SendFn) {
|
|
let start = Instant::now();
|
|
metrics::increment_counter!(format!("async-cpupool.{name}.operation.start"), "id" => id.to_string());
|
|
|
|
let res = std::panic::catch_unwind(std::panic::AssertUnwindSafe(move || {
|
|
(send_fn)();
|
|
}));
|
|
|
|
metrics::increment_counter!(format!("async-cpupool.{name}.operation.end"), "complete" => res.is_ok().to_string(), "id" => id.to_string());
|
|
metrics::histogram!(format!("async-cpupool.{name}.operation.duration"), start.elapsed().as_secs_f64(), "complete" => res.is_ok().to_string(), "id" => id.to_string());
|
|
|
|
if let Err(e) = res {
|
|
tracing::trace!("panic in spawned task: {e:?}");
|
|
}
|
|
}
|