vectordb/src/pool.rs

201 lines
5.3 KiB
Rust

use std::sync::{
atomic::{AtomicUsize, Ordering},
Arc,
};
use crate::{reaper::Reaper, thread::Thread};
/// A resizable pool of worker threads that all consume from one shared bounded channel.
///
/// Every worker runs `func` with a reference to the shared `state`, a clone of the message
/// receiver, and a stop receiver handed to it by [`Thread`]. The worker count is kept between
/// `lower_limit` and `upper_limit`. [`Pool::send_blocking`] and [`Pool::send_async`] grow the
/// pool when the channel is full and shrink it when the channel is empty.
pub(super) struct Pool<State, Message> {
    /// Handles of the running workers; the most recently pushed one is reaped first.
    threads: std::sync::Mutex<Vec<Thread>>,
    /// State shared by every worker.
    state: Arc<State>,
    /// Receives workers after they have been cancelled.
    reaper: Reaper,
    /// Worker body run on each spawned thread.
    func: fn(&State, flume::Receiver<Message>, flume::Receiver<()>),
    /// Sending half of the work channel.
    tx: flume::Sender<Message>,
    /// Receiving half, cloned for each new worker. The pool keeps this one alive itself,
    /// so the channel keeps a receiver even while no worker is running.
    rx: flume::Receiver<Message>,
    /// Number of worker slots in use. A slot is reserved here *before* its thread is pushed
    /// into `threads`, so this can briefly run ahead of `threads.len()`.
    size: AtomicUsize,
    /// The pool never shrinks below this many workers.
    lower_limit: usize,
    /// The pool never grows beyond this many workers.
    upper_limit: usize,
    /// Thread name given to every worker, including ones spawned after construction.
    name: String,
}

/// Configures and creates a [`Pool`]. Obtain one from [`Pool::builder`].
pub(super) struct PoolBuilder<State, Message> {
    state: Arc<State>,
    reaper: Reaper,
    func: fn(&State, flume::Receiver<Message>, flume::Receiver<()>),
    /// Workers started by `finish`, and the floor the pool shrinks to.
    lower_limit: usize,
    /// Ceiling the pool grows to.
    upper_limit: usize,
    /// Worker thread name; `None` falls back to `"vectordb-pool"`.
    name: Option<String>,
}

/// Starts one worker thread called `name` that runs `func(&state, rx, stopper)`.
fn spawn_worker<State, Message>(
    name: String,
    state: Arc<State>,
    rx: flume::Receiver<Message>,
    func: fn(&State, flume::Receiver<Message>, flume::Receiver<()>),
) -> Thread
where
    State: Send + Sync + 'static,
    Message: Send + 'static,
{
    Thread::build(name).spawn(move |stopper| {
        (func)(&state, rx, stopper);
    })
}

impl<State, Message> PoolBuilder<State, Message> {
    /// Creates the work channel and starts `lower_limit` workers.
    ///
    /// The channel buffers at most 8 messages. A full channel is what makes the `send_*`
    /// methods start additional workers.
    pub(super) fn finish(self) -> Pool<State, Message>
    where
        State: Send + Sync + 'static,
        Message: Send + 'static,
    {
        let PoolBuilder {
            state,
            reaper,
            func,
            lower_limit,
            upper_limit,
            name,
        } = self;
        // Resolve the name once. The pool keeps it so later spawns use the same name.
        let name = name.unwrap_or_else(|| String::from("vectordb-pool"));
        let (tx, rx) = flume::bounded(8);
        let threads = (0..lower_limit)
            .map(|_| spawn_worker(name.clone(), Arc::clone(&state), rx.clone(), func))
            .collect();
        Pool {
            threads: std::sync::Mutex::new(threads),
            state,
            reaper,
            func,
            tx,
            rx,
            size: AtomicUsize::new(lower_limit),
            lower_limit,
            upper_limit,
            name,
        }
    }

    /// Sets the thread name used for every worker the pool starts.
    pub(super) fn with_name(mut self, name: String) -> Self {
        self.name = Some(name);
        self
    }

    /// Sets the minimum worker count; `finish` starts exactly this many workers.
    pub(super) fn with_lower_limit(mut self, lower_limit: usize) -> Self {
        self.lower_limit = lower_limit;
        self
    }

    /// Sets the maximum worker count the pool may grow to.
    pub(super) fn with_upper_limit(mut self, upper_limit: usize) -> Self {
        self.upper_limit = upper_limit;
        self
    }
}

impl<State, Message> Pool<State, Message>
where
    State: Send + Sync + 'static,
    Message: Send + 'static,
{
    /// Starts configuring a pool whose workers run `func` over `state`.
    ///
    /// Both limits default to 1 and the thread name defaults to `"vectordb-pool"`.
    pub(super) fn builder(
        state: Arc<State>,
        reaper: Reaper,
        func: fn(&State, flume::Receiver<Message>, flume::Receiver<()>),
    ) -> PoolBuilder<State, Message> {
        PoolBuilder {
            state,
            reaper,
            func,
            lower_limit: 1,
            upper_limit: 1,
            name: None,
        }
    }

    /// Resizes the pool if needed, then sends `message`, blocking while the channel is full.
    ///
    /// # Errors
    /// Returns the message back if the channel is disconnected. The pool itself holds a
    /// receiver, so this is not expected while `self` is alive.
    pub(super) fn send_blocking(&self, message: Message) -> Result<(), Message> {
        self.prepare_send();
        self.tx.send(message).map_err(|e| e.into_inner())
    }

    /// Async counterpart of [`Pool::send_blocking`].
    ///
    /// Resizing happens synchronously before the first `.await`.
    ///
    /// # Errors
    /// Returns the message back if the channel is disconnected.
    pub(super) async fn send_async(&self, message: Message) -> Result<(), Message> {
        self.prepare_send();
        self.tx
            .send_async(message)
            .await
            .map_err(|e| e.into_inner())
    }

    /// Starts one more worker if the pool is below `upper_limit`.
    fn spawn_handler(&self) {
        let upper = self.upper_limit;
        // Atomically reserve a slot; fails (no spawn) once the ceiling is reached.
        let reserved = self
            .size
            .fetch_update(Ordering::AcqRel, Ordering::Acquire, |size| {
                if size < upper {
                    Some(size + 1)
                } else {
                    None
                }
            })
            .is_ok();
        if reserved {
            let thread = spawn_worker(
                self.name.clone(),
                Arc::clone(&self.state),
                self.rx.clone(),
                self.func,
            );
            self.threads.lock().unwrap().push(thread);
        }
    }

    /// Cancels the most recently pushed worker and hands it to the reaper, if the pool is
    /// above `lower_limit`.
    fn reap_handler(&self) {
        let lower = self.lower_limit;
        let released = self
            .size
            .fetch_update(Ordering::AcqRel, Ordering::Acquire, |size| {
                if size > lower {
                    Some(size - 1)
                } else {
                    None
                }
            })
            .is_ok();
        if !released {
            return;
        }
        // `spawn_handler` reserves its slot before pushing the thread. The vector can
        // therefore be empty here (e.g. with `lower_limit == 0`) even though `size` was above
        // the floor. In that case nothing was reaped, so the slot is returned instead of
        // panicking.
        let popped = self.threads.lock().unwrap().pop();
        match popped {
            Some(thread) => {
                thread.cancel();
                // Best-effort hand-off; a failed send is deliberately ignored.
                let _ = self.reaper.send(thread);
            }
            None => {
                self.size.fetch_add(1, Ordering::AcqRel);
            }
        }
    }

    /// Grows the pool when the channel is full and shrinks it when the channel is empty.
    ///
    /// This runs on every send, so an idle pool loses one worker per send until it reaches
    /// `lower_limit`.
    ///
    /// NOTE(review): with `lower_limit == 0` and no running workers, queued messages wait
    /// until the channel fills and a worker is spawned.
    fn prepare_send(&self) {
        if self.tx.is_full() {
            self.spawn_handler();
        }
        if self.tx.is_empty() {
            self.reap_handler();
        }
    }
}
/// Debug output lists the pool's configuration and channel halves. The worker list is shown
/// as a fixed placeholder string rather than being locked and printed.
impl<State, Message> std::fmt::Debug for Pool<State, Message>
where
    State: std::fmt::Debug,
{
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let Self {
            state,
            reaper,
            func,
            tx,
            rx,
            size,
            lower_limit,
            upper_limit,
            ..
        } = self;

        let mut out = f.debug_struct("Pool");
        out.field("threads", &"Mutex<Vec<Thread>>");
        out.field("state", state);
        out.field("reaper", reaper);
        out.field("tx", tx);
        out.field("rx", rx);
        out.field("size", size);
        out.field("lower_limit", lower_limit);
        out.field("upper_limit", upper_limit);
        out.field("func", func);
        out.finish()
    }
}