vectordb/src/hypertree.rs

185 lines
5 KiB
Rust

use std::sync::Arc;
use crate::{pool::Pool, reaper::Reaper, InternalVectorId, TreeError, Vector};
pub(super) use self::db::{HypertreeRepo, SimilarityStyle};
mod db;
/// Outcome of a similarity search: on success, a list of
/// `(f32, InternalVectorId, Vector<N>)` tuples; on failure, a `TreeError`.
///
/// NOTE(review): the `f32` is presumably the similarity score for the paired
/// vector — confirm against `db::HypertreeRepo::find_similarities`.
type SimilaritiesResult<const N: usize> =
Result<Vec<(f32, InternalVectorId, Vector<N>)>, TreeError>;
/// Front end for a `db::HypertreeRepo` that routes indexing and similarity
/// requests through two separate worker pools.
///
/// The repo is shared (via `Arc`) between this struct and both pools.
#[derive(Debug)]
pub(super) struct HypertreeMeta<const N: usize> {
// Shared repo handle; used directly by `perform_maintenance`.
hypertree: Arc<db::HypertreeRepo<N>>,
// Pool that processes `AddToIndexCommand`s (see `add_to_index_runner`).
add_to_index: Pool<db::HypertreeRepo<N>, AddToIndexCommand<N>>,
// Pool that processes `FindSimilaritiesCommand`s (see `similarities_runner`).
find_similarities: Pool<db::HypertreeRepo<N>, FindSimilaritiesCommand<N>>,
}
/// Request message handed to the similarity-search pool.
///
/// `vector`, `threshold`, `limit` and `similarity_style` are forwarded
/// unchanged to `db::HypertreeRepo::find_similarities` by the worker.
struct FindSimilaritiesCommand<const N: usize> {
// Query vector.
vector: Arc<Vector<N>>,
// Optional cutoff forwarded to the repo.
// NOTE(review): exact semantics (min similarity vs. max distance) live in `db` — confirm.
threshold: Option<f32>,
// Result-count argument forwarded to the repo (presumably a maximum — confirm).
limit: usize,
// Which similarity measure the repo should use.
similarity_style: SimilarityStyle,
// Channel on which the worker sends the search result back to the requester.
responder: oneshot::Sender<SimilaritiesResult<N>>,
// Span that was current when the request was made; the worker's
// "FindSimilarities" span is created as its child.
parent: tracing::Span,
}
/// Request message handed to the indexing pool.
struct AddToIndexCommand<const N: usize> {
// Batch of (id, vector) pairs passed to `db::HypertreeRepo::add_many_to_index`.
vectors: Arc<[(InternalVectorId, Vector<N>)]>,
// Channel on which the worker sends the indexing result back to the requester.
responder: oneshot::Sender<Result<(), TreeError>>,
// Span that was current when the request was made; the worker's
// "AddToIndex" span is created as its child.
parent: tracing::Span,
}
impl<const N: usize> HypertreeMeta<N> {
    /// Wraps `hypertree` in an `Arc` and starts the two worker pools that share it.
    ///
    /// * The indexing pool runs `add_to_index_runner` with the pool builder's
    ///   default limits.
    /// * The similarity pool runs `similarities_runner` with a lower limit of 2
    ///   and an upper limit of 16 × the available parallelism (treated as 1 when
    ///   the parallelism cannot be determined).
    ///
    /// NOTE(review): the limits presumably bound the number of worker threads —
    /// confirm against `Pool`.
    pub(super) fn build(hypertree: db::HypertreeRepo<N>, reaper: Reaper) -> Self {
        let repo = Arc::new(hypertree);

        let cores = std::thread::available_parallelism().map_or(1, usize::from);

        // The indexing pool is created first and gets a clone of the reaper; the
        // similarity pool below consumes the original.
        let index_builder =
            Pool::builder(Arc::clone(&repo), reaper.clone(), add_to_index_runner);
        let index_pool = index_builder
            .with_name(String::from("vectordb-hypertree-index"))
            .finish();

        let search_builder = Pool::builder(Arc::clone(&repo), reaper, similarities_runner);
        let search_pool = search_builder
            .with_lower_limit(2)
            .with_upper_limit(16 * cores)
            .with_name(String::from("vectordb-hypertree-similarities"))
            .finish();

        Self {
            hypertree: repo,
            add_to_index: index_pool,
            find_similarities: search_pool,
        }
    }

    /// Rebuilds the hypertree and, only when the rebuild reports `true`, runs
    /// the repo's cleanup step.
    ///
    /// # Errors
    ///
    /// Propagates any error from `rebuild_hypertree` or `cleanup`.
    pub(super) fn perform_maintenance(&self) -> Result<(), TreeError> {
        let rebuilt = self.hypertree.rebuild_hypertree()?;
        if rebuilt {
            self.hypertree.cleanup()?;
        }
        Ok(())
    }

    /// Queues `vectors` for insertion into the index.
    ///
    /// The current tracing span is captured so the worker can parent its own
    /// span to it. On success the returned receiver yields the indexing result
    /// once a worker has processed the command.
    ///
    /// # Errors
    ///
    /// Returns `TreeError::Closed` when the command cannot be handed to the pool.
    pub(super) fn add_many_to_index(
        &self,
        vectors: Arc<[(InternalVectorId, Vector<N>)]>,
    ) -> Result<oneshot::Receiver<Result<(), TreeError>>, TreeError> {
        let (responder, rx) = oneshot::channel();

        let command = AddToIndexCommand {
            vectors,
            responder,
            parent: tracing::Span::current(),
        };

        match self.add_to_index.send_blocking(command) {
            Ok(_) => Ok(rx),
            Err(_) => Err(TreeError::Closed),
        }
    }

    /// Queues a similarity search for `vector`.
    ///
    /// `threshold`, `limit` and `similarity_style` are forwarded unchanged to
    /// the worker. The current tracing span is captured so the worker can parent
    /// its own span to it. On success the returned receiver yields the search
    /// result once a worker has processed the command.
    ///
    /// # Errors
    ///
    /// Returns `TreeError::Closed` when the command cannot be handed to the pool.
    pub(super) fn find_similarities(
        &self,
        vector: Arc<Vector<N>>,
        threshold: Option<f32>,
        limit: usize,
        similarity_style: SimilarityStyle,
    ) -> Result<oneshot::Receiver<SimilaritiesResult<N>>, TreeError> {
        let (responder, rx) = oneshot::channel();

        let command = FindSimilaritiesCommand {
            vector,
            threshold,
            limit,
            similarity_style,
            responder,
            parent: tracing::Span::current(),
        };

        match self.find_similarities.send_blocking(command) {
            Ok(_) => Ok(rx),
            Err(_) => Err(TreeError::Closed),
        }
    }
}
fn similarities_runner<const N: usize>(
hypertree: &db::HypertreeRepo<N>,
rx: flume::Receiver<FindSimilaritiesCommand<N>>,
stopper: flume::Receiver<()>,
) {
crate::with_stopper(rx, stopper, |command| {
let FindSimilaritiesCommand {
vector,
threshold,
limit,
similarity_style,
responder,
parent,
} = command;
let span = tracing::debug_span!(parent: parent, "FindSimilarities");
let guard = span.enter();
let res = std::panic::catch_unwind(move || {
hypertree.find_similarities(&vector, threshold, limit, similarity_style)
});
match res {
Ok(res) => {
if responder.send(res).is_err() {
tracing::warn!("Requester disconnected");
}
}
Err(e) => {
tracing::error!("operation panicked: {e:?}");
}
}
drop(guard);
})
}
fn add_to_index_runner<const N: usize>(
hypertree: &db::HypertreeRepo<N>,
rx: flume::Receiver<AddToIndexCommand<N>>,
stopper: flume::Receiver<()>,
) {
crate::with_stopper(rx, stopper, |command| {
let AddToIndexCommand {
vectors,
responder,
parent,
} = command;
let span = tracing::debug_span!(parent: parent, "AddToIndex");
let guard = span.enter();
let res = std::panic::catch_unwind(move || hypertree.add_many_to_index(&vectors));
match res {
Ok(res) => {
if responder.send(res).is_err() {
tracing::warn!("Requester disconnected");
}
}
Err(e) => {
tracing::error!("operation panicked: {e:?}");
}
}
drop(guard);
})
}