async-cpupool/src/lib.rs
2024-04-15 17:59:13 -05:00

658 lines
17 KiB
Rust

#![doc = include_str!("../README.md")]
#![deny(missing_docs)]
mod drop_notifier;
mod executor;
mod notify;
mod queue;
mod selector;
mod spsc;
mod sync;
use std::{
future::Future,
num::{NonZeroU16, NonZeroUsize},
sync::{atomic::AtomicU64, Arc, Mutex},
thread::JoinHandle,
time::Instant,
};
use drop_notifier::{DropListener, DropNotifier};
use executor::block_on;
use queue::Queue;
use selector::select;
/// Configuration builder for the CpuPool
///
/// Create one with [`Config::new`] (or [`CpuPool::configure`]), adjust it with the builder
/// methods, and finish with [`Config::build`]. The builder is `Clone`, so a single configuration
/// can be reused to build several pools.
#[derive(Clone, Debug)]
pub struct Config {
    /// Prefix used when naming spawned threads
    name: &'static str,
    /// Multiplier applied to `max_threads` to size the job queue
    buffer_multiplier: usize,
    /// Lower bound on the number of worker threads
    min_threads: u16,
    /// Upper bound on the number of worker threads
    max_threads: u16,
}
impl Config {
/// Create a new configuration builder with the default configuration
pub fn new() -> Self {
Config {
name: "cpupool",
buffer_multiplier: 8,
min_threads: 1,
max_threads: 4,
}
}
/// Set the name for the CpuPool
///
/// This is used for setting the names of spawned threads
///
/// default: `"cpupool"`
///
/// Example:
/// ```rust
/// # use async_cpupool::Config;
/// Config::new().name("sig-pool");
/// ```
pub fn name(mut self, name: &'static str) -> Self {
self.name = name;
self
}
/// Set the multiplier for the internal queue's buffer size
///
/// This value must be at least 1. the buffer's size will be equal to `max_threads * buffer_multiplier`
///
/// default: `8`
///
/// Example:
/// ```rust
/// # use async_cpupool::Config;
/// Config::new().buffer_multiplier(4);
/// ```
pub fn buffer_multiplier(mut self, buffer_multiplier: usize) -> Self {
self.buffer_multiplier = buffer_multiplier;
self
}
/// Set the minimum allowed number of running threads
///
/// When there is little work to do, threads will be reaped until just this number remain
///
/// default: `1`
///
/// Example:
/// ```rust
/// # use async_cpupool::Config;
/// Config::new().min_threads(2);
/// ```
pub fn min_threads(mut self, min_threads: u16) -> Self {
self.min_threads = min_threads;
self
}
/// Set the maximum allowed number of running threads
///
/// When the threadpool is under load, threads will be spawned until this limit is reached
///
/// default: `4`
///
/// Example:
/// ```rust
/// # use async_cpupool::Config;
/// Config::new().max_threads(16);
/// ```
pub fn max_threads(mut self, max_threads: u16) -> Self {
self.max_threads = max_threads;
self
}
/// Create a CpuPool with the given configuration, spawning `min_threads` threads
///
/// This will error if `min_threads` is greater than `max_threads`, or if `buffer_multiplier`,
/// `max_threads`, or `min_threads` are `0`
///
/// Example:
/// ```rust
/// # use async_cpupool::Config;
/// # fn main() -> Result<(), Box<dyn std::error::Error>> {
/// let pool = Config::new()
/// .name("sig-pool")
/// .min_threads(4)
/// .max_threads(16)
/// .buffer_multiplier(2)
/// .build()?;
/// # Ok(())
/// # }
/// ```
pub fn build(self) -> Result<CpuPool, ConfigError> {
let Config {
name,
buffer_multiplier,
min_threads,
max_threads,
} = self;
if max_threads < min_threads {
return Err(ConfigError::ThreadCount);
}
let buffer_multiplier = buffer_multiplier
.try_into()
.map_err(|_| ConfigError::BufferMultiplier)?;
let max_threads = max_threads
.try_into()
.map_err(|_| ConfigError::MaxThreads)?;
let min_threads = min_threads
.try_into()
.map_err(|_| ConfigError::MinThreads)?;
Ok(CpuPool {
state: Arc::new(CpuPoolState::new(
name,
buffer_multiplier,
min_threads,
max_threads,
)),
})
}
}
impl Default for Config {
fn default() -> Self {
Self::new()
}
}
/// Errors created by invalid configuration of the CpuPool
///
/// Returned from [`Config::build`].
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum ConfigError {
    /// The configured maximum threads value is lower than the configured minimum threads value
    ThreadCount,
    /// The buffer_multiplier is 0
    BufferMultiplier,
    /// The max_threads value is 0
    MaxThreads,
    /// The min_threads value is 0
    MinThreads,
}
impl std::fmt::Display for ConfigError {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
match self {
Self::ThreadCount => write!(f, "min_threads cannot be higher than max_threads"),
Self::BufferMultiplier => write!(f, "buffer_multiplier cannot be zero"),
Self::MaxThreads => write!(f, "max_threads cannot be zero"),
Self::MinThreads => write!(f, "min_threads cannot be zero"),
}
}
}
impl std::error::Error for ConfigError {}
/// The blocking operation was canceled due to a panic
///
/// The future returned by [`CpuPool::spawn`] resolves to this error in place of the
/// operation's output.
#[derive(Debug)]
pub struct Canceled;

impl std::fmt::Display for Canceled {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.write_str("Blocking operation has panicked")
    }
}

impl std::error::Error for Canceled {}
/// The CPUPool handle
///
/// Cloning is cheap: every clone shares the same pool state behind an `Arc`. The pool closes
/// itself once the last clone is dropped, or when [`CpuPool::close`] succeeds.
#[derive(Debug, Clone)]
pub struct CpuPool {
    state: Arc<CpuPoolState>,
}
impl CpuPool {
    /// Create a new CpuPool with the default configuration
    ///
    /// Example:
    /// ```rust
    /// # use async_cpupool::CpuPool;
    /// let pool = CpuPool::new();
    /// ```
    pub fn new() -> Self {
        Config::default().build().expect("Defaults are valid")
    }

    /// Create a configuration builder to customize the CpuPool
    ///
    /// Example:
    /// ```rust
    /// # use async_cpupool::CpuPool;
    /// # fn main() -> Result<(), Box<dyn std::error::Error>> {
    /// let pool = CpuPool::configure().build()?;
    /// # Ok(())
    /// # }
    /// ```
    pub fn configure() -> Config {
        Config::default()
    }

    /// Spawn a blocking operation on the CpuPool
    ///
    /// The operation is offered to the queue immediately. If it cannot be enqueued right away,
    /// it is enqueued when the returned future is first polled. Calling this may spawn a new
    /// worker thread. Polling the returned future may also spawn one, or reap one, depending on
    /// how deep the queue is relative to the current thread count. When a thread is reaped, the
    /// future waits for that thread to exit before yielding the result.
    ///
    /// # Errors
    ///
    /// The returned future resolves to [`Canceled`] if the operation panicked instead of
    /// producing a value.
    ///
    /// Example:
    /// ```rust
    /// # use async_cpupool::CpuPool;
    /// # fn main() -> Result<(), Box<dyn std::error::Error>> {
    /// # smol::block_on(async {
    /// let pool = CpuPool::new();
    ///
    /// pool.spawn(|| std::thread::sleep(std::time::Duration::from_secs(3))).await?;
    /// # Ok(())
    /// # })
    /// # }
    /// ```
    pub fn spawn<F, T>(&self, send_fn: F) -> impl Future<Output = Result<T, Canceled>> + '_
    where
        F: FnOnce() -> T + Send + 'static,
        T: Send + 'static,
    {
        let (response_tx, response_rx) = spsc::channel();

        // Wrap the user's closure so its output is forwarded to the returned future.
        let send_fn = Box::new(move || {
            let output = (send_fn)();
            match response_tx.blocking_send(output) {
                Ok(()) => (), // sent
                Err(Canceled) => tracing::warn!("receiver hung up"),
            }
        });

        // Try to enqueue synchronously. If that fails, the job is handed back and pushed
        // asynchronously inside the returned future.
        let opt = self.state.queue.try_push(send_fn);

        let current_threads = self.current_threads();
        let pushed = match self.state.queue.is_full_or() {
            Ok(()) => self.push_thread(),
            Err(len) if len > current_threads as usize => self.push_thread(),
            Err(_) => false,
        };
        if pushed {
            tracing::trace!("Pushed thread");
        }

        async {
            if let Some(item) = opt {
                self.state.queue.push(item).await;
            }

            let current_threads = self.current_threads();
            match self.state.queue.is_full_or() {
                Ok(()) => {
                    self.push_thread();
                }
                Err(len) if len > current_threads as usize => {
                    self.push_thread();
                }
                // `current_threads` never drops below `min_threads` (which is at least 1), so
                // `ilog2` cannot panic on zero here.
                Err(len) if len < current_threads.ilog2() as usize => {
                    if let Some(thread) = self.pop_thread() {
                        thread.reap().await;
                    }
                }
                Err(_) => {}
            }

            response_rx.recv().await
        }
    }

    /// Attempt to close the CpuPool
    ///
    /// This operation returns `true` when the pool was successfully closed, or `false` if there
    /// exist other references to the pool, preventing closure.
    ///
    /// It is not required to call close to close a CpuPool. CpuPools will automatically close
    /// themselves when all clones are dropped. This is simply a method to integrate better with
    /// async runtimes.
    ///
    /// Example:
    /// ```rust
    /// # use async_cpupool::CpuPool;
    /// # fn main() -> Result<(), Box<dyn std::error::Error>> {
    /// # smol::block_on(async {
    /// let pool = CpuPool::new();
    ///
    /// let closed = pool.close().await;
    /// assert!(closed);
    /// # Ok(())
    /// # })
    /// # }
    /// ```
    pub async fn close(self) -> bool {
        let Some(mut state) = Arc::into_inner(self.state) else {
            return false;
        };

        let mut threads = state.take_threads();

        // Signal every thread first, so they can all wind down concurrently...
        for thread in &mut threads {
            thread.signal.take();
        }

        // ...then wait for each one to report that it has finished before joining it, so the
        // blocking `join` call should return promptly.
        for mut thread in threads {
            thread.closed.listen().await;
            if let Some(handle) = thread.handle.take() {
                handle.join().expect("Thread panicked");
            }
        }

        true
    }

    /// Snapshot of the number of threads the pool currently accounts for
    fn current_threads(&self) -> u64 {
        self.state
            .current_threads
            .load(std::sync::atomic::Ordering::Acquire)
    }

    /// Try to spawn one additional worker thread
    ///
    /// Returns `true` if a thread was spawned. Returns `false` when the pool is already at
    /// `max_threads`, or when another caller changed the thread count concurrently (the
    /// compare-exchange lost the race).
    fn push_thread(&self) -> bool {
        let current_threads = self.current_threads();

        if current_threads >= u64::from(u16::from(self.state.max_threads)) {
            tracing::trace!("At thread maximum");
            return false;
        }

        if self
            .state
            .current_threads
            .compare_exchange(
                current_threads,
                current_threads + 1,
                std::sync::atomic::Ordering::AcqRel,
                std::sync::atomic::Ordering::Relaxed,
            )
            .is_err()
        {
            tracing::trace!("Didn't acquire spawn authorization");
            return false;
        }

        // we updated the count, so we have authorization to spawn a new thread
        let thread_id = self
            .state
            .thread_id
            .fetch_add(1, std::sync::atomic::Ordering::Relaxed);

        let thread = spawn(self.state.name, thread_id, self.state.queue.clone());

        self.state
            .threads
            .lock()
            .expect("threads lock poison")
            .push(thread);

        true
    }

    /// Try to detach one worker thread for reaping
    ///
    /// Returns `None` when the pool is at `min_threads`, when the compare-exchange lost a race,
    /// or when no thread handle was available to pop. The caller is responsible for reaping the
    /// returned thread.
    fn pop_thread(&self) -> Option<Thread> {
        let current_threads = self.current_threads();

        if current_threads <= u64::from(u16::from(self.state.min_threads)) {
            // Logged at trace (like "At thread maximum"): this runs on every `spawn` that
            // sees a shallow queue while the pool is at its minimum.
            tracing::trace!("At thread minimum");
            return None;
        }

        if self
            .state
            .current_threads
            .compare_exchange(
                current_threads,
                current_threads - 1,
                std::sync::atomic::Ordering::AcqRel,
                std::sync::atomic::Ordering::Relaxed,
            )
            .is_err()
        {
            tracing::trace!("Didn't acquire reap authorization");
            return None;
        }

        // we updated the count, so we have authorization to reap a thread
        self.state
            .threads
            .lock()
            .expect("threads lock poison")
            .pop()
    }
}
impl Default for CpuPool {
fn default() -> Self {
Self::new()
}
}
/// A type-erased job queued on the pool
type SendFn = Box<dyn FnOnce() + Send + 'static>;

/// State shared by every clone of a [`CpuPool`]
struct CpuPoolState {
    /// Prefix for worker thread names
    name: &'static str,
    /// The pool does not reap below this many threads
    min_threads: NonZeroU16,
    /// The pool does not spawn beyond this many threads
    max_threads: NonZeroU16,
    /// Number of threads the pool currently accounts for. It is updated via compare-exchange
    /// before a thread is spawned or reaped.
    current_threads: AtomicU64,
    /// Next id to hand to a newly spawned thread
    thread_id: AtomicU64,
    /// Bounded job queue shared with every worker
    queue: Queue<SendFn>,
    /// Handles for the worker threads
    threads: Mutex<ThreadVec>,
}
impl CpuPoolState {
    /// Build the shared pool state and eagerly spawn `min_threads` workers
    ///
    /// The job queue is bounded to `buffer_multiplier * max_threads` entries (saturating).
    /// Thread ids `0..min_threads` go to the initial workers, and later ids continue from there.
    fn new(
        name: &'static str,
        buffer_multiplier: NonZeroUsize,
        min_threads: NonZeroU16,
        max_threads: NonZeroU16,
    ) -> Self {
        let capacity = usize::from(max_threads.get());
        let queue = queue::bounded(buffer_multiplier.get().saturating_mul(capacity));
        let initial = u64::from(min_threads.get());

        let threads = ThreadVec::new(initial, capacity, |id| spawn(name, id, queue.clone()));

        CpuPoolState {
            name,
            min_threads,
            max_threads,
            current_threads: AtomicU64::new(initial),
            // ids below `initial` were handed out to the threads spawned above
            thread_id: AtomicU64::new(initial),
            queue,
            threads: Mutex::new(threads),
        }
    }

    /// Remove and return every tracked thread, leaving the list empty
    ///
    /// Exclusive access (`&mut self`) means no locking is needed; a poisoned mutex still panics.
    fn take_threads(&mut self) -> Vec<Thread> {
        self.threads.get_mut().expect("threads lock poison").take()
    }
}
impl std::fmt::Debug for CpuPoolState {
    /// Show only the static configuration; runtime counters, the queue and thread handles are
    /// omitted.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let mut out = f.debug_struct("CpuPoolState");
        out.field("name", &self.name);
        out.field("min_threads", &self.min_threads);
        out.field("max_threads", &self.max_threads);
        out.finish()
    }
}
/// Owned list of worker threads
///
/// When dropped, it signals and then joins every thread it still holds.
struct ThreadVec {
    threads: Vec<Thread>,
}
impl ThreadVec {
    /// Create the list with room for `max_threads` entries.
    ///
    /// `spawn` is called for each id in `0..start_threads`, in order.
    fn new<F>(start_threads: u64, max_threads: usize, spawn: F) -> Self
    where
        F: Fn(u64) -> Thread,
    {
        let mut threads = Vec::with_capacity(max_threads);
        threads.extend((0..start_threads).map(spawn));
        Self { threads }
    }

    /// Track one more thread
    fn push(&mut self, thread: Thread) {
        self.threads.push(thread)
    }

    /// Remove the most recently tracked thread, if any
    fn pop(&mut self) -> Option<Thread> {
        self.threads.pop()
    }

    /// Move every tracked thread out, leaving the list empty
    fn take(&mut self) -> Vec<Thread> {
        std::mem::replace(&mut self.threads, Vec::new())
    }
}
impl Drop for ThreadVec {
    /// Stop and join every remaining thread
    ///
    /// All threads are signalled before any is joined, so they wind down concurrently rather
    /// than one at a time.
    fn drop(&mut self) {
        // Dropping each notifier is the stop signal.
        self.threads.iter_mut().for_each(|thread| {
            thread.signal.take();
        });

        for handle in self
            .threads
            .iter_mut()
            .filter_map(|thread| thread.handle.take())
        {
            handle.join().expect("Thread panicked");
        }
    }
}
/// Handle to a single worker thread
struct Thread {
    /// OS join handle; `None` once it has been joined
    handle: Option<JoinHandle<()>>,
    /// Dropping this notifier asks the worker to stop; `None` once the signal has been sent
    signal: Option<DropNotifier>,
    /// Awaited to learn that the worker has dropped its matching notifier on exit
    closed: DropListener,
}

impl Thread {
    /// Ask this worker to stop, wait for it to report that it finished, then join it
    async fn reap(mut self) {
        drop(self.signal.take());
        self.closed.listen().await;

        if let Some(handle) = self.handle.take() {
            handle.join().expect("Thread panicked");
        }
    }
}
/// Start a worker thread named `{name}-{id}` that runs [`run`] against `receiver`
///
/// Two notifier pairs are created. `signal` flows from the pool to the worker and requests
/// shutdown. `closed` flows from the worker back to the pool and reports that the worker is
/// done.
///
/// # Panics
///
/// Panics if the operating system fails to create the thread.
fn spawn(name: &'static str, id: u64, receiver: Queue<SendFn>) -> Thread {
    let (closed_notifier, closed_listener) = drop_notifier::notifier();
    let (signal_notifier, signal_listener) = drop_notifier::notifier();

    let worker = move || run(name, id, receiver, signal_listener, closed_notifier);

    let handle = std::thread::Builder::new()
        .name(format!("{name}-{id}"))
        .spawn(worker)
        .expect("Failed to spawn new thread");

    Thread {
        closed: closed_listener,
        signal: Some(signal_notifier),
        handle: Some(handle),
    }
}
/// Records lifecycle metrics for one worker thread
///
/// Creating the guard counts a launch. Dropping it counts a close and records the thread's
/// lifetime in seconds. Both metrics are labelled `clean = "true"` only if
/// [`MetricsGuard::disarm`] was called, so any other drop (e.g. while unwinding) reports
/// `clean = "false"`.
struct MetricsGuard {
    /// Pool name used in metric keys and log lines
    name: &'static str,
    /// Worker thread id
    id: u64,
    /// When the worker started
    start: Instant,
    /// `true` until the worker exits normally
    armed: bool,
}

impl MetricsGuard {
    /// Log the thread start, count the launch, and start the lifetime clock
    fn guard(name: &'static str, id: u64) -> Self {
        tracing::trace!("Starting {name}-{id}");
        metrics::counter!(format!("async-cpupool.{name}.thread.launched")).increment(1);

        MetricsGuard {
            name,
            id,
            start: Instant::now(),
            armed: true,
        }
    }

    /// Mark the thread as having exited cleanly, then drop the guard
    fn disarm(mut self) {
        self.armed = false;
    }
}

impl Drop for MetricsGuard {
    fn drop(&mut self) {
        let clean = !self.armed;

        metrics::counter!(
            format!("async-cpupool.{}.thread.closed", self.name),
            "clean" => clean.to_string()
        )
        .increment(1);
        metrics::histogram!(
            format!("async-cpupool.{}.thread.seconds", self.name),
            "clean" => clean.to_string()
        )
        .record(self.start.elapsed().as_secs_f64());

        tracing::trace!("Stopping {}-{}", self.name, self.id);
    }
}
/// Worker loop executed on each pool thread
///
/// Each iteration blocks on `select` between the shutdown signal and the next job from
/// `receiver`. Jobs are run until the signal side wins. On that exit, the metrics guard is
/// disarmed (recording a clean close) and `closed_tx` is dropped last. The pool awaits that
/// drop before joining the thread.
fn run(
    name: &'static str,
    id: u64,
    receiver: Queue<SendFn>,
    signal: DropListener,
    closed_tx: DropNotifier,
) {
    let guard = MetricsGuard::guard(name, id);
    let mut shutdown = std::pin::pin!(signal.listen());

    // `Either` has exactly two variants, so this stops as soon as the signal (Left) fires.
    while let selector::Either::Right(send_fn) = block_on(select(&mut shutdown, receiver.pop()))
    {
        invoke_send_fn(name, send_fn);
    }

    guard.disarm();
    drop(closed_tx);
}
/// Run a single queued job, catching any panic so the worker thread survives
///
/// Emits a start counter, an end counter and a duration histogram. The end counter and the
/// histogram are labelled `complete = "true"` when the job returned without panicking.
fn invoke_send_fn(name: &'static str, send_fn: SendFn) {
    let start = Instant::now();
    metrics::counter!(format!("async-cpupool.{name}.operation.start")).increment(1);

    let res = std::panic::catch_unwind(std::panic::AssertUnwindSafe(move || {
        (send_fn)();
    }));

    metrics::counter!(format!("async-cpupool.{name}.operation.end"), "complete" => res.is_ok().to_string()).increment(1);
    metrics::histogram!(format!("async-cpupool.{name}.operation.seconds"), "complete" => res.is_ok().to_string()).record(start.elapsed().as_secs_f64());

    if let Err(payload) = res {
        // `Box<dyn Any + Send>`'s Debug impl prints only `Any { .. }`. Extract the message that
        // `panic!` stores (a `&'static str` or a `String`) when there is one.
        let message: &str = if let Some(s) = payload.downcast_ref::<&'static str>() {
            *s
        } else if let Some(s) = payload.downcast_ref::<String>() {
            s.as_str()
        } else {
            "<non-string panic payload>"
        };
        tracing::trace!("panic in spawned task: {message}");
    }
}