async-cpupool/src/lib.rs

500 lines
12 KiB
Rust

mod selector;
use std::{
num::{NonZeroU16, NonZeroUsize},
sync::{atomic::AtomicU64, Arc, Mutex},
thread::JoinHandle,
time::Instant,
};
/// Builder for a [`CpuPool`].
///
/// Start from [`Config::new`] (or [`CpuPool::configure`]), adjust values with the chained
/// setters, then call [`Config::build`] to validate them and create the pool.
#[derive(Debug)]
pub struct Config {
/// Prefix for worker thread names (`{name}-{id}`) and for the emitted metric names.
name: &'static str,
/// The job queue holds `buffer_multiplier * max_threads` entries.
buffer_multiplier: usize,
/// Number of threads started with the pool; the pool never shrinks below this count.
min_threads: u16,
/// Upper bound on the number of threads the pool will grow to.
max_threads: u16,
}
impl Config {
    /// Creates a configuration with the default values: name `"cpupool"`, a buffer
    /// multiplier of 8, one thread minimum and two threads maximum.
    pub fn new() -> Self {
        Self {
            name: "cpupool",
            buffer_multiplier: 8,
            min_threads: 1,
            max_threads: 2,
        }
    }

    /// Sets the name used for worker threads and metrics.
    pub fn name(self, name: &'static str) -> Self {
        Self { name, ..self }
    }

    /// Sets how many queued jobs are allowed per thread of `max_threads`.
    pub fn buffer_multiplier(self, buffer_multiplier: usize) -> Self {
        Self {
            buffer_multiplier,
            ..self
        }
    }

    /// Sets the number of threads the pool starts with and never drops below.
    pub fn min_threads(self, min_threads: u16) -> Self {
        Self {
            min_threads,
            ..self
        }
    }

    /// Sets the largest number of threads the pool may run.
    pub fn max_threads(self, max_threads: u16) -> Self {
        Self {
            max_threads,
            ..self
        }
    }

    /// Validates the configuration and creates the pool.
    ///
    /// # Errors
    ///
    /// Checks run in this order and the first failure is returned:
    /// - [`ConfigError::ThreadCount`] when `max_threads < min_threads`
    /// - [`ConfigError::BufferMultiplier`] when `buffer_multiplier` is zero
    /// - [`ConfigError::MaxThreads`] when `max_threads` is zero
    /// - [`ConfigError::MinThreads`] when `min_threads` is zero
    pub fn build(self) -> Result<CpuPool, ConfigError> {
        if self.max_threads < self.min_threads {
            return Err(ConfigError::ThreadCount);
        }

        // The pool state only accepts non-zero values; zero maps to the matching error.
        let buffer_multiplier =
            NonZeroUsize::new(self.buffer_multiplier).ok_or(ConfigError::BufferMultiplier)?;
        let max_threads = NonZeroU16::new(self.max_threads).ok_or(ConfigError::MaxThreads)?;
        let min_threads = NonZeroU16::new(self.min_threads).ok_or(ConfigError::MinThreads)?;

        let state = CpuPoolState::new(self.name, buffer_multiplier, min_threads, max_threads);

        Ok(CpuPool {
            state: Arc::new(state),
        })
    }
}
impl Default for Config {
fn default() -> Self {
Self::new()
}
}
/// Reasons [`Config::build`] can reject a configuration.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum ConfigError {
    /// `max_threads` is lower than `min_threads`.
    ThreadCount,
    /// `buffer_multiplier` is zero.
    BufferMultiplier,
    /// `max_threads` is zero.
    MaxThreads,
    /// `min_threads` is zero.
    MinThreads,
}

impl std::fmt::Display for ConfigError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let message = match self {
            Self::ThreadCount => "max_threads must be greater than or equal to min_threads",
            Self::BufferMultiplier => "buffer_multiplier must be greater than zero",
            Self::MaxThreads => "max_threads must be greater than zero",
            Self::MinThreads => "min_threads must be greater than zero",
        };
        f.write_str(message)
    }
}

// Lets callers propagate the error with `?` into `Box<dyn Error>` and similar types.
impl std::error::Error for ConfigError {}
/// Error returned by [`CpuPool::spawn`] when the job could not be queued or its result
/// never arrived.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub struct Canceled;

impl std::fmt::Display for Canceled {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.write_str("task was canceled before producing a result")
    }
}

// Lets callers propagate the error with `?` into `Box<dyn Error>` and similar types.
impl std::error::Error for Canceled {}
/// Handle to a pool of worker threads that run blocking closures.
///
/// Cloning the handle is cheap: all clones share the same state through an `Arc`.
#[derive(Clone)]
pub struct CpuPool {
/// Shared pool state: job queue, thread handles and thread counters.
state: Arc<CpuPoolState>,
}
impl CpuPool {
    /// Creates a pool from the default [`Config`].
    ///
    /// # Panics
    ///
    /// Panics if an initial worker thread cannot be spawned.
    pub fn new() -> Self {
        Config::default().build().expect("Defaults are valid")
    }

    /// Returns a default [`Config`] for customising a pool before building it.
    pub fn configure() -> Config {
        Config::default()
    }

    /// Runs `send_fn` on a pool thread and resolves to its return value.
    ///
    /// The call first tries to enqueue the job without waiting; if the queue is full it
    /// waits for space. Before and after enqueuing, the queue length is compared with the
    /// current thread count to decide whether to start another thread or to retire one.
    ///
    /// # Errors
    ///
    /// Returns [`Canceled`] when the job queue is disconnected, or when the response
    /// channel closes without a value (for example because the closure panicked and was
    /// dropped without sending its output).
    pub async fn spawn<F, T>(&self, send_fn: F) -> Result<T, Canceled>
    where
        F: FnOnce() -> T + Send + 'static,
        T: Send + 'static,
    {
        let (response_tx, response) = flume::bounded(1);

        // Wrap the caller's closure so the worker only sees a type-erased `FnOnce()`; the
        // result travels back over the single-slot response channel.
        let send_fn = Box::new(move || {
            let output = (send_fn)();

            if response_tx.send(output).is_err() {
                tracing::trace!("Receiver hung up");
            }
        });

        let res = self.state.sender.try_send(send_fn);

        let current_threads = self
            .state
            .current_threads
            .load(std::sync::atomic::Ordering::Acquire);

        // Grow early when there is more queued work than threads, or no queue space left.
        if (self.state.sender_len() > current_threads || self.state.sender.is_full())
            && self.push_thread()
        {
            tracing::trace!("Pushed thread");
        }

        match res {
            Err(flume::TrySendError::Full(item)) => {
                // The queue was full: wait for a slot instead of dropping the job.
                if self.state.sender.send_async(item).await.is_err() {
                    return Err(Canceled);
                }
            }
            Err(flume::TrySendError::Disconnected(_)) => {
                return Err(Canceled);
            }
            Ok(()) => {}
        }

        let current_threads = self
            .state
            .current_threads
            .load(std::sync::atomic::Ordering::Acquire);
        let sender_len = self.state.sender_len();

        if sender_len > current_threads || self.state.sender.is_full() {
            self.push_thread();
        } else if sender_len < u64::from(current_threads.ilog2()) {
            // The queue is short relative to the thread count: retire one thread and wait
            // for it to exit before awaiting the response.
            if let Some(thread) = self.pop_thread() {
                thread.reap().await;
            }
        }

        response.into_recv_async().await.map_err(|_| Canceled)
    }

    /// Shuts the pool down, waiting for every worker thread to exit.
    ///
    /// Returns `false` without stopping anything when other clones of this handle still
    /// exist; returns `true` once all threads have been signalled and joined.
    ///
    /// # Panics
    ///
    /// Panics if joining a worker thread reports that it panicked.
    pub async fn close(self) -> bool {
        let Some(mut state) = Arc::into_inner(self.state) else {
            return false;
        };

        let mut threads = state.take_threads();

        // Signal every thread first so they all wind down concurrently.
        for thread in &mut threads {
            thread.signal.take();
        }

        for thread in &mut threads {
            // Wait for the worker's closed notification before the blocking `join`.
            let _ = thread.closed.recv_async().await;

            if let Some(handle) = thread.handle.take() {
                handle.join().expect("Thread panicked");
            }
        }

        true
    }

    /// Tries to start one more worker thread.
    ///
    /// Returns `false` when the pool is already at `max_threads` or when another caller
    /// changed the thread count concurrently; returns `true` after a thread was spawned.
    fn push_thread(&self) -> bool {
        let current_threads = self
            .state
            .current_threads
            .load(std::sync::atomic::Ordering::Acquire);

        if current_threads >= u64::from(u16::from(self.state.max_threads)) {
            tracing::trace!("At thread maximum");
            return false;
        }

        if self
            .state
            .current_threads
            .compare_exchange(
                current_threads,
                current_threads + 1,
                std::sync::atomic::Ordering::AcqRel,
                std::sync::atomic::Ordering::Relaxed,
            )
            .is_err()
        {
            tracing::trace!("Didn't acquire spawn authorization");
            return false;
        }

        // we updated the count, so we have authorization to spawn a new thread
        let thread_id = self
            .state
            .thread_id
            .fetch_add(1, std::sync::atomic::Ordering::Relaxed);

        let thread = spawn(self.state.name, thread_id, self.state.receiver.clone());

        self.state
            .threads
            .lock()
            .expect("threads lock poison")
            .push(thread);

        true
    }

    /// Tries to remove one worker thread from the pool so the caller can reap it.
    ///
    /// Returns `None` when the pool is already at `min_threads`, when another caller
    /// changed the thread count concurrently, or when no thread handle is stored.
    fn pop_thread(&self) -> Option<Thread> {
        let current_threads = self
            .state
            .current_threads
            .load(std::sync::atomic::Ordering::Acquire);

        if current_threads <= u64::from(u16::from(self.state.min_threads)) {
            // `spawn` reaches this on any call made while the pool is idle at its
            // minimum, so log at trace level like the "At thread maximum" case.
            tracing::trace!("At thread minimum");
            return None;
        }

        if self
            .state
            .current_threads
            .compare_exchange(
                current_threads,
                current_threads - 1,
                std::sync::atomic::Ordering::AcqRel,
                std::sync::atomic::Ordering::Relaxed,
            )
            .is_err()
        {
            tracing::trace!("Didn't acquire reap authorization");
            return None;
        }

        // we updated the count, so we have authorization to reap a thread
        self.state
            .threads
            .lock()
            .expect("threads lock poison")
            .pop()
    }
}
impl Default for CpuPool {
fn default() -> Self {
Self::new()
}
}
/// A type-erased job: a boxed closure that is sent to a worker thread and called once.
type SendFn = Box<dyn FnOnce() + Send>;
/// State shared by every clone of a [`CpuPool`].
struct CpuPoolState {
/// Name passed to `spawn` for every worker thread.
name: &'static str,
/// Number of threads started up front; `pop_thread` will not go below it.
min_threads: NonZeroU16,
/// `push_thread` refuses to spawn once the count reaches this value.
max_threads: NonZeroU16,
/// Thread count used for scaling decisions; changed by compare-exchange in
/// `push_thread` / `pop_thread`.
current_threads: AtomicU64,
/// Id handed to the next spawned thread; starts at `min_threads` because the initial
/// threads take ids `0..min_threads`.
thread_id: AtomicU64,
/// Producer side of the bounded job queue.
sender: flume::Sender<SendFn>,
/// Consumer side of the job queue; a clone is given to each worker thread.
receiver: flume::Receiver<SendFn>,
/// Handles of the spawned worker threads.
threads: Mutex<ThreadVec>,
}
impl CpuPoolState {
    /// Creates the job queue and starts `min_threads` worker threads.
    ///
    /// The queue capacity is `buffer_multiplier * max_threads` (saturating), and the
    /// thread list is pre-sized for `max_threads` entries.
    fn new(
        name: &'static str,
        buffer_multiplier: NonZeroUsize,
        min_threads: NonZeroU16,
        max_threads: NonZeroU16,
    ) -> Self {
        let thread_limit = usize::from(max_threads.get());
        let queue_capacity = buffer_multiplier.get().saturating_mul(thread_limit);
        let (sender, receiver) = flume::bounded(queue_capacity);

        let initial_threads = u64::from(min_threads.get());
        let threads = ThreadVec::new(initial_threads, thread_limit, |id| {
            spawn(name, id, receiver.clone())
        });

        Self {
            name,
            min_threads,
            max_threads,
            // Both counters start at the number of threads already running: ids
            // `0..initial_threads` are taken by the threads spawned above.
            current_threads: AtomicU64::new(initial_threads),
            thread_id: AtomicU64::new(initial_threads),
            sender,
            receiver,
            threads: Mutex::new(threads),
        }
    }

    /// Removes and returns every stored thread handle, leaving the list empty.
    fn take_threads(&mut self) -> Vec<Thread> {
        let mut guard = self.threads.lock().expect("threads lock poison");
        guard.take()
    }

    /// Number of jobs currently queued, clamped to `u64::MAX` if it does not fit.
    fn sender_len(&self) -> u64 {
        match u64::try_from(self.sender.len()) {
            Ok(queued) => queued,
            Err(_) => u64::MAX,
        }
    }
}
/// Owns the pool's worker thread handles; any threads still stored are signalled and
/// joined when this value is dropped.
struct ThreadVec {
/// Stored thread handles, most recently pushed last.
threads: Vec<Thread>,
}
impl ThreadVec {
    /// Builds the list by calling `spawn` once for each id in `0..start_threads`, in order.
    ///
    /// `max_threads` is only used to pre-allocate storage.
    fn new<F>(start_threads: u64, max_threads: usize, spawn: F) -> Self
    where
        F: Fn(u64) -> Thread,
    {
        let mut threads = Vec::with_capacity(max_threads);
        threads.extend((0..start_threads).map(spawn));

        Self { threads }
    }

    /// Appends a thread handle.
    fn push(&mut self, thread: Thread) {
        self.threads.push(thread)
    }

    /// Removes the most recently added thread handle, if any.
    fn pop(&mut self) -> Option<Thread> {
        self.threads.pop()
    }

    /// Moves all thread handles out, leaving this list empty.
    fn take(&mut self) -> Vec<Thread> {
        std::mem::replace(&mut self.threads, Vec::new())
    }
}
impl Drop for ThreadVec {
    /// Stops and joins every thread still stored in the list.
    ///
    /// All stop signals are released before the first join so the threads can shut down
    /// in parallel rather than one after another.
    fn drop(&mut self) {
        for stop_signal in self.threads.iter_mut().map(|thread| &mut thread.signal) {
            drop(stop_signal.take());
        }

        for handle in self
            .threads
            .iter_mut()
            .filter_map(|thread| thread.handle.take())
        {
            handle.join().expect("Thread panicked");
        }
    }
}
/// Pool-side handle to one worker thread.
struct Thread {
/// Join handle of the OS thread; taken out right before joining.
handle: Option<JoinHandle<()>>,
/// Dropping this sends the stop signal that the worker's loop selects on.
signal: Option<SendOnDrop>,
/// Receives a message when the worker drops its `closed_tx`, i.e. when `run` returns.
closed: flume::Receiver<()>,
}
impl Thread {
async fn reap(mut self) {
self.signal.take();
let _ = self.closed.recv_async().await;
if let Some(handle) = self.handle.take() {
handle.join().expect("Thread panicked");
}
}
}
/// Wrapper that tries to send `()` on its channel when it is dropped.
struct SendOnDrop {
/// Channel that is notified on drop.
sender: flume::Sender<()>,
}
impl Drop for SendOnDrop {
    /// Makes a single non-blocking attempt to notify the other side of the channel.
    fn drop(&mut self) {
        match self.sender.try_send(()) {
            // Best effort: a failed send is deliberately ignored.
            Ok(()) | Err(_) => {}
        }
    }
}
/// Starts a worker thread named `{name}-{id}` that runs jobs taken from `receiver`.
///
/// The returned [`Thread`] holds the join handle, the stop signal (sent when dropped) and
/// the channel on which the worker announces that it has finished.
///
/// # Panics
///
/// Panics if the operating system fails to create the thread.
fn spawn(name: &'static str, id: u64, receiver: flume::Receiver<SendFn>) -> Thread {
    let (closed_sender, closed) = flume::bounded(1);
    let (signal_sender, signal_receiver) = flume::bounded(1);

    let stop_signal = SendOnDrop {
        sender: signal_sender,
    };
    let closed_notifier = SendOnDrop {
        sender: closed_sender,
    };

    let thread_name = format!("{name}-{id}");
    let handle = std::thread::Builder::new()
        .name(thread_name)
        .spawn(move || run(name, id, receiver, signal_receiver, closed_notifier))
        .expect("Failed to spawn new thread");

    Thread {
        closed,
        signal: Some(stop_signal),
        handle: Some(handle),
    }
}
/// Records launch/close metrics for one worker thread; the close metrics are emitted
/// when the guard is dropped.
struct MetricsGuard {
/// Pool name, used in the metric names.
name: &'static str,
/// Worker thread id, attached as the `id` label.
id: u64,
/// When the guard was created; used for the duration histogram.
start: Instant,
/// `true` until `disarm` is called; reported inverted as the `clean` label.
armed: bool,
}
impl MetricsGuard {
    /// Logs the thread start, bumps the `launched` counter and returns an armed guard.
    fn guard(name: &'static str, id: u64) -> Self {
        tracing::trace!("Starting {name}-{id}");
        metrics::increment_counter!(format!("async-cpupool.{name}.launched"), "id" => id.to_string());

        let start = Instant::now();

        Self {
            name,
            id,
            start,
            armed: true,
        }
    }

    /// Marks the shutdown as clean and drops the guard, which emits the close metrics.
    fn disarm(mut self) {
        self.armed = false;
        drop(self);
    }
}
impl Drop for MetricsGuard {
    /// Emits the `closed` counter and `duration` histogram, then logs the stop.
    ///
    /// The `clean` label is `"true"` only when the guard was disarmed first.
    fn drop(&mut self) {
        let name = self.name;
        let id = self.id;
        let clean = (!self.armed).to_string();
        let id_label = id.to_string();

        metrics::increment_counter!(format!("async-cpupool.{name}.closed"), "clean" => clean.clone(), "id" => id_label.clone());
        metrics::histogram!(format!("async-cpupool.{name}.duration"), self.start.elapsed().as_secs_f64(), "clean" => clean, "id" => id_label);
        tracing::trace!("Stopping {}-{}", name, id);
    }
}
/// Body of a worker thread: runs queued jobs until told to stop.
///
/// The loop ends when the stop-signal branch of the select completes or when receiving a
/// job fails. On a normal exit the metrics guard is disarmed, and `closed_tx` is dropped
/// last so the pool is notified that this thread has finished.
fn run(
    name: &'static str,
    id: u64,
    receiver: flume::Receiver<SendFn>,
    signal: flume::Receiver<()>,
    closed_tx: SendOnDrop,
) {
    let guard = MetricsGuard::guard(name, id);

    // Any outcome other than a successfully received job ends the loop.
    while let selector::Either::Right(Ok(send_fn)) =
        selector::blocking_select(signal.recv_async(), receiver.recv_async())
    {
        invoke_send_fn(name, id, send_fn);
    }

    guard.disarm();
    drop(closed_tx);
}
/// Runs one job on the current worker thread, recording metrics around it.
///
/// A panic inside the job is caught so the worker thread keeps running; the panic message
/// is logged at trace level. The `complete` metric label is `"true"` when the job
/// returned normally and `"false"` when it panicked.
fn invoke_send_fn(name: &'static str, id: u64, send_fn: SendFn) {
    let start = Instant::now();
    metrics::increment_counter!(format!("async-cpupool.{name}.operation.start"), "id" => id.to_string());

    let res = std::panic::catch_unwind(std::panic::AssertUnwindSafe(move || {
        (send_fn)();
    }));

    metrics::increment_counter!(format!("async-cpupool.{name}.operation.end"), "complete" => res.is_ok().to_string(), "id" => id.to_string());
    metrics::histogram!(format!("async-cpupool.{name}.operation.duration"), start.elapsed().as_secs_f64(), "complete" => res.is_ok().to_string(), "id" => id.to_string());

    if let Err(payload) = res {
        // `Debug` on a `Box<dyn Any + Send>` only prints `Any { .. }`, so extract the
        // actual message from the payload instead.
        tracing::trace!("panic in spawned task: {}", panic_message(&*payload));
    }
}

/// Extracts a readable message from a panic payload.
///
/// `panic!` produces a `&'static str` for literal messages and a `String` for formatted
/// ones; any other payload type yields a generic placeholder.
fn panic_message(payload: &(dyn std::any::Any + Send)) -> &str {
    if let Some(message) = payload.downcast_ref::<&'static str>() {
        message
    } else if let Some(message) = payload.downcast_ref::<String>() {
        message.as_str()
    } else {
        "non-string panic payload"
    }
}