185 lines
5 KiB
Rust
185 lines
5 KiB
Rust
use std::sync::Arc;
|
|
|
|
use crate::{pool::Pool, reaper::Reaper, InternalVectorId, TreeError, Vector};
|
|
|
|
pub(super) use self::db::{HypertreeRepo, SimilarityStyle};
|
|
|
|
mod db;
|
|
|
|
type SimilaritiesResult<const N: usize> =
|
|
Result<Vec<(f32, InternalVectorId, Vector<N>)>, TreeError>;
|
|
|
|
#[derive(Debug)]
|
|
pub(super) struct HypertreeMeta<const N: usize> {
|
|
hypertree: Arc<db::HypertreeRepo<N>>,
|
|
add_to_index: Pool<db::HypertreeRepo<N>, AddToIndexCommand<N>>,
|
|
find_similarities: Pool<db::HypertreeRepo<N>, FindSimilaritiesCommand<N>>,
|
|
}
|
|
|
|
struct FindSimilaritiesCommand<const N: usize> {
|
|
vector: Arc<Vector<N>>,
|
|
threshold: Option<f32>,
|
|
limit: usize,
|
|
similarity_style: SimilarityStyle,
|
|
responder: oneshot::Sender<SimilaritiesResult<N>>,
|
|
parent: tracing::Span,
|
|
}
|
|
|
|
struct AddToIndexCommand<const N: usize> {
|
|
vectors: Arc<[(InternalVectorId, Vector<N>)]>,
|
|
responder: oneshot::Sender<Result<(), TreeError>>,
|
|
parent: tracing::Span,
|
|
}
|
|
|
|
impl<const N: usize> HypertreeMeta<N> {
|
|
pub(super) fn build(hypertree: db::HypertreeRepo<N>, reaper: Reaper) -> Self {
|
|
let hypertree = Arc::new(hypertree);
|
|
|
|
let parallelism = std::thread::available_parallelism()
|
|
.map(usize::from)
|
|
.unwrap_or(1);
|
|
|
|
let add_to_index =
|
|
Pool::builder(Arc::clone(&hypertree), reaper.clone(), add_to_index_runner)
|
|
.with_name(String::from("vectordb-hypertree-index"))
|
|
.finish();
|
|
|
|
let find_similarities = Pool::builder(Arc::clone(&hypertree), reaper, similarities_runner)
|
|
.with_lower_limit(2)
|
|
.with_upper_limit(16 * parallelism)
|
|
.with_name(String::from("vectordb-hypertree-similarities"))
|
|
.finish();
|
|
|
|
Self {
|
|
hypertree,
|
|
add_to_index,
|
|
find_similarities,
|
|
}
|
|
}
|
|
|
|
pub(super) fn perform_maintenance(&self) -> Result<(), TreeError> {
|
|
if self.hypertree.rebuild_hypertree()? {
|
|
self.hypertree.cleanup()?;
|
|
}
|
|
|
|
Ok(())
|
|
}
|
|
|
|
pub(super) fn add_many_to_index(
|
|
&self,
|
|
vectors: Arc<[(InternalVectorId, Vector<N>)]>,
|
|
) -> Result<oneshot::Receiver<Result<(), TreeError>>, TreeError> {
|
|
let (responder, rx) = oneshot::channel();
|
|
|
|
if self
|
|
.add_to_index
|
|
.send_blocking(AddToIndexCommand {
|
|
vectors,
|
|
responder,
|
|
parent: tracing::Span::current(),
|
|
})
|
|
.is_err()
|
|
{
|
|
return Err(TreeError::Closed);
|
|
}
|
|
|
|
Ok(rx)
|
|
}
|
|
|
|
pub(super) fn find_similarities(
|
|
&self,
|
|
vector: Arc<Vector<N>>,
|
|
threshold: Option<f32>,
|
|
limit: usize,
|
|
similarity_style: SimilarityStyle,
|
|
) -> Result<oneshot::Receiver<SimilaritiesResult<N>>, TreeError> {
|
|
let (responder, rx) = oneshot::channel();
|
|
|
|
if self
|
|
.find_similarities
|
|
.send_blocking(FindSimilaritiesCommand {
|
|
vector,
|
|
threshold,
|
|
limit,
|
|
similarity_style,
|
|
responder,
|
|
parent: tracing::Span::current(),
|
|
})
|
|
.is_err()
|
|
{
|
|
return Err(TreeError::Closed);
|
|
}
|
|
|
|
Ok(rx)
|
|
}
|
|
}
|
|
|
|
fn similarities_runner<const N: usize>(
|
|
hypertree: &db::HypertreeRepo<N>,
|
|
rx: flume::Receiver<FindSimilaritiesCommand<N>>,
|
|
stopper: flume::Receiver<()>,
|
|
) {
|
|
crate::with_stopper(rx, stopper, |command| {
|
|
let FindSimilaritiesCommand {
|
|
vector,
|
|
threshold,
|
|
limit,
|
|
similarity_style,
|
|
responder,
|
|
parent,
|
|
} = command;
|
|
|
|
let span = tracing::debug_span!(parent: parent, "FindSimilarities");
|
|
let guard = span.enter();
|
|
|
|
let res = std::panic::catch_unwind(move || {
|
|
hypertree.find_similarities(&vector, threshold, limit, similarity_style)
|
|
});
|
|
|
|
match res {
|
|
Ok(res) => {
|
|
if responder.send(res).is_err() {
|
|
tracing::warn!("Requester disconnected");
|
|
}
|
|
}
|
|
Err(e) => {
|
|
tracing::error!("operation panicked: {e:?}");
|
|
}
|
|
}
|
|
|
|
drop(guard);
|
|
})
|
|
}
|
|
|
|
fn add_to_index_runner<const N: usize>(
|
|
hypertree: &db::HypertreeRepo<N>,
|
|
rx: flume::Receiver<AddToIndexCommand<N>>,
|
|
stopper: flume::Receiver<()>,
|
|
) {
|
|
crate::with_stopper(rx, stopper, |command| {
|
|
let AddToIndexCommand {
|
|
vectors,
|
|
responder,
|
|
parent,
|
|
} = command;
|
|
|
|
let span = tracing::debug_span!(parent: parent, "AddToIndex");
|
|
let guard = span.enter();
|
|
|
|
let res = std::panic::catch_unwind(move || hypertree.add_many_to_index(&vectors));
|
|
|
|
match res {
|
|
Ok(res) => {
|
|
if responder.send(res).is_err() {
|
|
tracing::warn!("Requester disconnected");
|
|
}
|
|
}
|
|
Err(e) => {
|
|
tracing::error!("operation panicked: {e:?}");
|
|
}
|
|
}
|
|
|
|
drop(guard);
|
|
})
|
|
}
|