201 lines
5.3 KiB
Rust
201 lines
5.3 KiB
Rust
use std::sync::{
|
|
atomic::{AtomicUsize, Ordering},
|
|
Arc,
|
|
};
|
|
|
|
use crate::{reaper::Reaper, thread::Thread};
|
|
|
|
pub(super) struct Pool<State, Message> {
|
|
threads: std::sync::Mutex<Vec<Thread>>,
|
|
state: Arc<State>,
|
|
reaper: Reaper,
|
|
func: fn(&State, flume::Receiver<Message>, flume::Receiver<()>),
|
|
tx: flume::Sender<Message>,
|
|
rx: flume::Receiver<Message>,
|
|
size: AtomicUsize,
|
|
lower_limit: usize,
|
|
upper_limit: usize,
|
|
}
|
|
|
|
pub(super) struct PoolBuilder<State, Message> {
|
|
state: Arc<State>,
|
|
reaper: Reaper,
|
|
func: fn(&State, flume::Receiver<Message>, flume::Receiver<()>),
|
|
lower_limit: usize,
|
|
upper_limit: usize,
|
|
name: Option<String>,
|
|
}
|
|
|
|
impl<State, Message> PoolBuilder<State, Message> {
|
|
pub(super) fn finish(self) -> Pool<State, Message>
|
|
where
|
|
State: Send + Sync + 'static,
|
|
Message: Send + 'static,
|
|
{
|
|
let (tx, rx) = flume::bounded(8);
|
|
|
|
let handler_rx = rx.clone();
|
|
|
|
let func = self.func;
|
|
|
|
let threads = (0..self.lower_limit)
|
|
.map(|_| {
|
|
let thread_state = Arc::clone(&self.state);
|
|
let handler_rx = handler_rx.clone();
|
|
let name = self
|
|
.name
|
|
.clone()
|
|
.unwrap_or_else(|| String::from("vectordb-pool"));
|
|
|
|
Thread::build(name).spawn(move |stopper| {
|
|
(func)(&thread_state, handler_rx, stopper);
|
|
})
|
|
})
|
|
.collect();
|
|
|
|
let size = AtomicUsize::new(self.lower_limit);
|
|
|
|
Pool {
|
|
threads: std::sync::Mutex::new(threads),
|
|
state: self.state,
|
|
reaper: self.reaper,
|
|
func: self.func,
|
|
tx,
|
|
rx,
|
|
size,
|
|
lower_limit: self.lower_limit,
|
|
upper_limit: self.upper_limit,
|
|
}
|
|
}
|
|
|
|
pub(super) fn with_name(mut self, name: String) -> Self {
|
|
self.name = Some(name);
|
|
self
|
|
}
|
|
|
|
pub(super) fn with_lower_limit(mut self, lower_limit: usize) -> Self {
|
|
self.lower_limit = lower_limit;
|
|
self
|
|
}
|
|
|
|
pub(super) fn with_upper_limit(mut self, upper_limit: usize) -> Self {
|
|
self.upper_limit = upper_limit;
|
|
self
|
|
}
|
|
}
|
|
|
|
impl<State, Message> Pool<State, Message>
|
|
where
|
|
State: Send + Sync + 'static,
|
|
Message: Send + 'static,
|
|
{
|
|
pub(super) fn builder(
|
|
state: Arc<State>,
|
|
reaper: Reaper,
|
|
func: fn(&State, flume::Receiver<Message>, flume::Receiver<()>),
|
|
) -> PoolBuilder<State, Message> {
|
|
PoolBuilder {
|
|
state,
|
|
reaper,
|
|
func,
|
|
lower_limit: 1,
|
|
upper_limit: 1,
|
|
name: None,
|
|
}
|
|
}
|
|
|
|
pub(super) fn send_blocking(&self, message: Message) -> Result<(), Message> {
|
|
self.prepare_send();
|
|
self.tx.send(message).map_err(|e| e.into_inner())
|
|
}
|
|
|
|
pub(super) async fn send_async(&self, message: Message) -> Result<(), Message> {
|
|
self.prepare_send();
|
|
self.tx
|
|
.send_async(message)
|
|
.await
|
|
.map_err(|e| e.into_inner())
|
|
}
|
|
|
|
fn spawn_handler(&self) {
|
|
let mut size = self.size.load(Ordering::Acquire);
|
|
|
|
while size < self.upper_limit {
|
|
match self.size.compare_exchange_weak(
|
|
size,
|
|
size + 1,
|
|
Ordering::AcqRel,
|
|
Ordering::Relaxed,
|
|
) {
|
|
Ok(_) => break,
|
|
Err(new_size) => size = new_size,
|
|
}
|
|
}
|
|
|
|
if size < self.upper_limit {
|
|
let state = Arc::clone(&self.state);
|
|
let handler_rx = self.rx.clone();
|
|
let func = self.func;
|
|
|
|
let thread = Thread::build(String::from("vectordb-pool")).spawn(move |stopper| {
|
|
(func)(&state, handler_rx, stopper);
|
|
});
|
|
|
|
self.threads.lock().unwrap().push(thread);
|
|
}
|
|
}
|
|
|
|
fn reap_handler(&self) {
|
|
let mut size = self.size.load(Ordering::Acquire);
|
|
|
|
while size > self.lower_limit {
|
|
match self.size.compare_exchange_weak(
|
|
size,
|
|
size - 1,
|
|
Ordering::AcqRel,
|
|
Ordering::Relaxed,
|
|
) {
|
|
Ok(_) => break,
|
|
Err(new_size) => size = new_size,
|
|
}
|
|
}
|
|
|
|
if size > self.lower_limit {
|
|
let thread = self.threads.lock().unwrap().pop().expect("Size is > 1");
|
|
|
|
thread.cancel();
|
|
|
|
let _ = self.reaper.send(thread);
|
|
}
|
|
}
|
|
|
|
fn prepare_send(&self) {
|
|
if self.tx.is_full() {
|
|
self.spawn_handler();
|
|
}
|
|
|
|
if self.tx.is_empty() {
|
|
self.reap_handler();
|
|
}
|
|
}
|
|
}
|
|
|
|
impl<State, Message> std::fmt::Debug for Pool<State, Message>
|
|
where
|
|
State: std::fmt::Debug,
|
|
{
|
|
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
|
f.debug_struct("Pool")
|
|
.field("threads", &"Mutex<Vec<Thread>>")
|
|
.field("state", &self.state)
|
|
.field("reaper", &self.reaper)
|
|
.field("tx", &self.tx)
|
|
.field("rx", &self.rx)
|
|
.field("size", &self.size)
|
|
.field("lower_limit", &self.lower_limit)
|
|
.field("upper_limit", &self.upper_limit)
|
|
.field("func", &self.func)
|
|
.finish()
|
|
}
|
|
}
|