1287 lines
36 KiB
Rust
1287 lines
36 KiB
Rust
mod hypertree;
|
|
mod pool;
|
|
mod reaper;
|
|
mod thread;
|
|
|
|
use std::{collections::HashMap, marker::PhantomData, path::Path};
|
|
|
|
use hypertree::{HypertreeMeta, SimilarityStyle};
|
|
use pool::Pool;
|
|
use reaper::Reaper;
|
|
use redb::{
|
|
CommitError, CompactionError, Database, DatabaseError, MultimapTable, MultimapTableDefinition,
|
|
ReadTransaction, ReadableMultimapTable, ReadableTable, RedbKey, RedbValue, StorageError, Table,
|
|
TableDefinition, TableError, TableHandle, TransactionError, TypeName, WriteTransaction,
|
|
};
|
|
use thread::Thread;
|
|
|
|
#[derive(Clone, Debug)]
|
|
pub struct VectorDb<const N: usize> {
|
|
inner: std::sync::Arc<VectorMeta<N>>,
|
|
_reaper: Reaper,
|
|
}
|
|
|
|
#[derive(Debug)]
|
|
struct VectorMeta<const N: usize> {
|
|
reader: Pool<VectorRepo<N>, ReaderCommand<N>>,
|
|
insert_vector: Pool<VectorRepo<N>, InsertVectorsCommand<N>>,
|
|
_maintenance: Thread,
|
|
}
|
|
|
|
enum ReaderCommand<const N: usize> {
|
|
FindSimilarities {
|
|
vector: std::sync::Arc<Vector<N>>,
|
|
threshold: Option<f32>,
|
|
limit: usize,
|
|
similarity_style: SimilarityStyle,
|
|
responder: oneshot::Sender<Result<Vec<VectorId>, TreeError>>,
|
|
parent: tracing::Span,
|
|
},
|
|
GetVectors {
|
|
vector_ids: Vec<VectorId>,
|
|
responder: oneshot::Sender<Result<HashMap<VectorId, Vector<N>>, TreeError>>,
|
|
parent: tracing::Span,
|
|
},
|
|
}
|
|
|
|
struct InsertVectorsCommand<const N: usize> {
|
|
vectors: Vec<Vector<N>>,
|
|
responder: oneshot::Sender<Result<Vec<VectorId>, TreeError>>,
|
|
parent: tracing::Span,
|
|
}
|
|
|
|
#[derive(Debug)]
|
|
struct VectorRepo<const N: usize> {
|
|
database: Database,
|
|
hypertrees: Vec<HypertreeMeta<N>>,
|
|
maintenance_sender: flume::Sender<()>,
|
|
last_rebuild: std::sync::atomic::AtomicU64,
|
|
}
|
|
|
|
/// A fixed-size vector of `N` `f32` components.
#[derive(Clone, Debug, PartialEq, PartialOrd)]
pub struct Vector<const N: usize>([f32; N]);
|
|
|
|
#[derive(Clone, Debug, PartialEq, Eq, PartialOrd, Ord)]
|
|
struct VectorBytes<'a, const N: usize> {
|
|
vector_bytes: std::borrow::Cow<'a, [u8]>,
|
|
_phantom: PhantomData<Vector<N>>,
|
|
}
|
|
|
|
/// Public identifier handed out for every inserted vector.
///
/// Several `VectorId`s may refer to the same stored vector when identical
/// vectors are inserted more than once.
#[derive(Clone, Copy, Debug, PartialEq, Eq, PartialOrd, Ord, Hash)]
#[repr(transparent)]
pub struct VectorId(u128);
|
|
|
|
/// Identifier of a distinct stored vector (one per unique byte encoding).
#[derive(Clone, Copy, Debug, PartialEq, Eq, PartialOrd, Ord, Hash)]
#[repr(transparent)]
struct InternalVectorId(u128);
|
|
|
|
/// Reasons why stored bytes could not be decoded.
#[derive(Debug)]
pub enum DecodeError {
    /// No bytes were present.
    Empty,
    /// The number of bytes does not match the expected encoding length.
    InvalidLength,
    /// A direction bound was invalid; not produced by the code in this file.
    Bound,
}
|
|
|
|
#[derive(Debug)]
|
|
pub enum TreeError {
|
|
Commit(CommitError),
|
|
Compact(CompactionError),
|
|
Database(DatabaseError),
|
|
Decode(DecodeError),
|
|
Io(std::io::Error),
|
|
Storage(StorageError),
|
|
Table(TableError),
|
|
Transaction(TransactionError),
|
|
Canceled,
|
|
Closed,
|
|
}
|
|
|
|
const IDS_TABLE: TableDefinition<'static, &'static str, u128> =
|
|
TableDefinition::new("vectordb::ids");
|
|
|
|
const VECTOR_INSERT_TABLE: TableDefinition<'static, VectorId, InternalVectorId> =
|
|
TableDefinition::new("vectordb::vector_insert_table");
|
|
|
|
const INVERSE_VECTOR_INSERT_MULTIMAP: MultimapTableDefinition<'static, InternalVectorId, VectorId> =
|
|
MultimapTableDefinition::new("vectordb::inverse_vector_insert_multimap");
|
|
|
|
const fn inverse_vector_table_definition<const N: usize>(
|
|
) -> TableDefinition<'static, VectorBytes<'static, N>, InternalVectorId> {
|
|
TableDefinition::new("vectordb::inverse_vector_table")
|
|
}
|
|
|
|
const fn vector_table_definition<const N: usize>(
|
|
) -> TableDefinition<'static, InternalVectorId, VectorBytes<'static, N>> {
|
|
TableDefinition::new("vectordb::vector_table")
|
|
}
|
|
|
|
impl<const N: usize> VectorDb<N> {
|
|
#[tracing::instrument(level = "debug", skip(db_directory), fields(directory = ?db_directory.as_ref()))]
|
|
pub fn open<P: AsRef<Path>>(
|
|
db_directory: P,
|
|
target_hypertree_count: usize,
|
|
) -> Result<Self, TreeError> {
|
|
let reaper = Reaper::new();
|
|
|
|
let (tx, rx) = flume::bounded(1);
|
|
|
|
let repo = VectorRepo::open(db_directory, target_hypertree_count, reaper.clone(), tx)?;
|
|
|
|
let meta = VectorMeta::build(repo, reaper.clone(), rx);
|
|
|
|
Ok(Self {
|
|
inner: std::sync::Arc::new(meta),
|
|
_reaper: reaper,
|
|
})
|
|
}
|
|
|
|
#[tracing::instrument(level = "debug", skip(self))]
|
|
pub fn insert_vector_blocking(&self, vector: Vector<N>) -> Result<VectorId, TreeError> {
|
|
let vec = self.inner.insert_vectors_blocking(vec![vector])?;
|
|
|
|
Ok(vec[0])
|
|
}
|
|
|
|
#[tracing::instrument(level = "debug", skip(self))]
|
|
pub async fn insert_vector(&self, vector: Vector<N>) -> Result<VectorId, TreeError> {
|
|
let vec = self.inner.insert_vectors(vec![vector]).await?;
|
|
|
|
Ok(vec[0])
|
|
}
|
|
|
|
#[tracing::instrument(level = "debug", skip(self, vectors))]
|
|
pub fn insert_vectors_blocking(
|
|
&self,
|
|
vectors: Vec<Vector<N>>,
|
|
) -> Result<Vec<VectorId>, TreeError> {
|
|
self.inner.insert_vectors_blocking(vectors)
|
|
}
|
|
|
|
#[tracing::instrument(level = "debug", skip(self, vectors))]
|
|
pub async fn insert_vectors(
|
|
&self,
|
|
vectors: Vec<Vector<N>>,
|
|
) -> Result<Vec<VectorId>, TreeError> {
|
|
self.inner.insert_vectors(vectors).await
|
|
}
|
|
|
|
#[tracing::instrument(level = "debug", skip(self))]
|
|
pub fn find_similarities_blocking(
|
|
&self,
|
|
vector: Vector<N>,
|
|
threshold: Option<f32>,
|
|
limit: usize,
|
|
) -> Result<Vec<VectorId>, TreeError> {
|
|
self.inner
|
|
.find_similarities_blocking(vector, threshold, limit, SimilarityStyle::Similar)
|
|
}
|
|
|
|
#[tracing::instrument(level = "debug", skip(self))]
|
|
pub async fn find_similarities(
|
|
&self,
|
|
vector: Vector<N>,
|
|
threshold: Option<f32>,
|
|
limit: usize,
|
|
) -> Result<Vec<VectorId>, TreeError> {
|
|
self.inner
|
|
.find_similarities(vector, threshold, limit, SimilarityStyle::Similar)
|
|
.await
|
|
}
|
|
|
|
#[tracing::instrument(level = "debug", skip(self))]
|
|
pub fn find_furthest_similarities_blocking(
|
|
&self,
|
|
vector: Vector<N>,
|
|
threshold: Option<f32>,
|
|
limit: usize,
|
|
) -> Result<Vec<VectorId>, TreeError> {
|
|
self.inner.find_similarities_blocking(
|
|
vector,
|
|
threshold,
|
|
limit,
|
|
SimilarityStyle::FurthestSimilar,
|
|
)
|
|
}
|
|
|
|
#[tracing::instrument(level = "debug", skip(self))]
|
|
pub async fn find_furthest_similarities(
|
|
&self,
|
|
vector: Vector<N>,
|
|
threshold: Option<f32>,
|
|
limit: usize,
|
|
) -> Result<Vec<VectorId>, TreeError> {
|
|
self.inner
|
|
.find_similarities(vector, threshold, limit, SimilarityStyle::FurthestSimilar)
|
|
.await
|
|
}
|
|
|
|
#[tracing::instrument(level = "debug", skip(self))]
|
|
pub fn find_dissimilarities_blocking(
|
|
&self,
|
|
vector: Vector<N>,
|
|
threshold: Option<f32>,
|
|
limit: usize,
|
|
) -> Result<Vec<VectorId>, TreeError> {
|
|
self.inner
|
|
.find_similarities_blocking(vector, threshold, limit, SimilarityStyle::Dissimilar)
|
|
}
|
|
|
|
#[tracing::instrument(level = "debug", skip(self))]
|
|
pub async fn find_dissimilarities(
|
|
&self,
|
|
vector: Vector<N>,
|
|
threshold: Option<f32>,
|
|
limit: usize,
|
|
) -> Result<Vec<VectorId>, TreeError> {
|
|
self.inner
|
|
.find_similarities(vector, threshold, limit, SimilarityStyle::Dissimilar)
|
|
.await
|
|
}
|
|
|
|
#[tracing::instrument(level = "debug", skip(self))]
|
|
pub fn find_closest_dissimilarities_blocking(
|
|
&self,
|
|
vector: Vector<N>,
|
|
threshold: Option<f32>,
|
|
limit: usize,
|
|
) -> Result<Vec<VectorId>, TreeError> {
|
|
self.inner.find_similarities_blocking(
|
|
vector,
|
|
threshold,
|
|
limit,
|
|
SimilarityStyle::ClosestDissimilar,
|
|
)
|
|
}
|
|
|
|
#[tracing::instrument(level = "debug", skip(self))]
|
|
pub async fn find_closest_dissimilarities(
|
|
&self,
|
|
vector: Vector<N>,
|
|
threshold: Option<f32>,
|
|
limit: usize,
|
|
) -> Result<Vec<VectorId>, TreeError> {
|
|
self.inner
|
|
.find_similarities(vector, threshold, limit, SimilarityStyle::ClosestDissimilar)
|
|
.await
|
|
}
|
|
|
|
#[tracing::instrument(level = "debug", skip(self))]
|
|
pub async fn get_vector(&self, vector_id: VectorId) -> Result<Option<Vector<N>>, TreeError> {
|
|
let mut hm = self.inner.get_vectors(vec![vector_id]).await?;
|
|
|
|
Ok(hm.remove(&vector_id))
|
|
}
|
|
|
|
#[tracing::instrument(level = "debug", skip(self))]
|
|
pub fn get_vector_blocking(&self, vector_id: VectorId) -> Result<Option<Vector<N>>, TreeError> {
|
|
let mut hm = self.inner.get_vectors_blocking(vec![vector_id])?;
|
|
|
|
Ok(hm.remove(&vector_id))
|
|
}
|
|
|
|
#[tracing::instrument(level = "debug", skip(self, vector_ids))]
|
|
pub async fn get_vectors(
|
|
&self,
|
|
vector_ids: Vec<VectorId>,
|
|
) -> Result<HashMap<VectorId, Vector<N>>, TreeError> {
|
|
self.inner.get_vectors(vector_ids).await
|
|
}
|
|
|
|
#[tracing::instrument(level = "debug", skip(self, vector_ids))]
|
|
pub fn get_vectors_blocking(
|
|
&self,
|
|
vector_ids: Vec<VectorId>,
|
|
) -> Result<HashMap<VectorId, Vector<N>>, TreeError> {
|
|
self.inner.get_vectors_blocking(vector_ids)
|
|
}
|
|
}
|
|
|
|
impl<const N: usize> VectorMeta<N> {
|
|
fn build(
|
|
repo: VectorRepo<N>,
|
|
reaper: Reaper,
|
|
maintenance_receiver: flume::Receiver<()>,
|
|
) -> Self {
|
|
let state = std::sync::Arc::new(repo);
|
|
|
|
let parallelism = std::thread::available_parallelism()
|
|
.map(usize::from)
|
|
.unwrap_or(1);
|
|
|
|
let reader = Pool::builder(std::sync::Arc::clone(&state), reaper.clone(), reader_runner)
|
|
.with_lower_limit(2)
|
|
.with_upper_limit(16 * parallelism)
|
|
.with_name(String::from("vectordb-reader"))
|
|
.finish();
|
|
|
|
let insert_vector =
|
|
Pool::builder(std::sync::Arc::clone(&state), reaper, insert_vectors_runner)
|
|
.with_name(String::from("vectordb-inserter"))
|
|
.finish();
|
|
|
|
let maintenance =
|
|
Thread::build(String::from("vectordb-maintenance")).spawn(move |stopper| {
|
|
let mut index = 0;
|
|
|
|
with_stopper(maintenance_receiver, stopper, |()| {
|
|
if let Err(e) = state.hypertrees[index].perform_maintenance() {
|
|
tracing::warn!("Error peforming maintenance: {e:?}");
|
|
}
|
|
|
|
index += 1;
|
|
if index >= state.hypertrees.len() {
|
|
index = 0;
|
|
}
|
|
})
|
|
});
|
|
|
|
Self {
|
|
reader,
|
|
insert_vector,
|
|
_maintenance: maintenance,
|
|
}
|
|
}
|
|
|
|
async fn insert_vectors(&self, vectors: Vec<Vector<N>>) -> Result<Vec<VectorId>, TreeError> {
|
|
let (responder, rx) = oneshot::channel();
|
|
|
|
if self
|
|
.insert_vector
|
|
.send_async(InsertVectorsCommand {
|
|
vectors,
|
|
responder,
|
|
parent: tracing::Span::current(),
|
|
})
|
|
.await
|
|
.is_err()
|
|
{
|
|
return Err(TreeError::Closed);
|
|
}
|
|
|
|
rx.await?
|
|
}
|
|
|
|
fn insert_vectors_blocking(&self, vectors: Vec<Vector<N>>) -> Result<Vec<VectorId>, TreeError> {
|
|
let (responder, rx) = oneshot::channel();
|
|
|
|
if self
|
|
.insert_vector
|
|
.send_blocking(InsertVectorsCommand {
|
|
vectors,
|
|
responder,
|
|
parent: tracing::Span::current(),
|
|
})
|
|
.is_err()
|
|
{
|
|
return Err(TreeError::Closed);
|
|
}
|
|
|
|
rx.recv()?
|
|
}
|
|
|
|
async fn get_vectors(
|
|
&self,
|
|
vector_ids: Vec<VectorId>,
|
|
) -> Result<HashMap<VectorId, Vector<N>>, TreeError> {
|
|
let (responder, rx) = oneshot::channel();
|
|
|
|
if self
|
|
.reader
|
|
.send_async(ReaderCommand::GetVectors {
|
|
vector_ids,
|
|
responder,
|
|
parent: tracing::Span::current(),
|
|
})
|
|
.await
|
|
.is_err()
|
|
{
|
|
return Err(TreeError::Closed);
|
|
}
|
|
|
|
rx.await?
|
|
}
|
|
|
|
fn get_vectors_blocking(
|
|
&self,
|
|
vector_ids: Vec<VectorId>,
|
|
) -> Result<HashMap<VectorId, Vector<N>>, TreeError> {
|
|
let (responder, rx) = oneshot::channel();
|
|
|
|
if self
|
|
.reader
|
|
.send_blocking(ReaderCommand::GetVectors {
|
|
vector_ids,
|
|
responder,
|
|
parent: tracing::Span::current(),
|
|
})
|
|
.is_err()
|
|
{
|
|
return Err(TreeError::Closed);
|
|
}
|
|
|
|
rx.recv()?
|
|
}
|
|
|
|
async fn find_similarities(
|
|
&self,
|
|
vector: Vector<N>,
|
|
threshold: Option<f32>,
|
|
limit: usize,
|
|
similarity_style: SimilarityStyle,
|
|
) -> Result<Vec<VectorId>, TreeError> {
|
|
let (responder, rx) = oneshot::channel();
|
|
|
|
if self
|
|
.reader
|
|
.send_async(ReaderCommand::FindSimilarities {
|
|
vector: std::sync::Arc::new(vector),
|
|
threshold,
|
|
limit,
|
|
similarity_style,
|
|
responder,
|
|
parent: tracing::Span::current(),
|
|
})
|
|
.await
|
|
.is_err()
|
|
{
|
|
return Err(TreeError::Closed);
|
|
}
|
|
|
|
rx.await?
|
|
}
|
|
|
|
fn find_similarities_blocking(
|
|
&self,
|
|
vector: Vector<N>,
|
|
threshold: Option<f32>,
|
|
limit: usize,
|
|
similarity_style: SimilarityStyle,
|
|
) -> Result<Vec<VectorId>, TreeError> {
|
|
let (responder, rx) = oneshot::channel();
|
|
|
|
if self
|
|
.reader
|
|
.send_blocking(ReaderCommand::FindSimilarities {
|
|
vector: std::sync::Arc::new(vector),
|
|
threshold,
|
|
limit,
|
|
similarity_style,
|
|
responder,
|
|
parent: tracing::Span::current(),
|
|
})
|
|
.is_err()
|
|
{
|
|
return Err(TreeError::Closed);
|
|
}
|
|
|
|
rx.recv()?
|
|
}
|
|
}
|
|
|
|
impl<const N: usize> VectorRepo<N> {
|
|
fn open<P: AsRef<Path>>(
|
|
db_directory: P,
|
|
target_hypertree_count: usize,
|
|
reaper: Reaper,
|
|
maintenance_sender: flume::Sender<()>,
|
|
) -> Result<Self, TreeError> {
|
|
std::fs::create_dir_all(db_directory.as_ref())?;
|
|
|
|
let mut database = Database::create(db_directory.as_ref().join("vectordb.redb"))?;
|
|
|
|
database.check_integrity()?;
|
|
database.compact()?;
|
|
|
|
let txn = database.begin_write()?;
|
|
txn.open_table(vector_table_definition::<N>())?;
|
|
txn.commit()?;
|
|
|
|
let max_bucket_size = 128;
|
|
|
|
let hypertrees = (0..target_hypertree_count)
|
|
.map(|i| {
|
|
let tree = hypertree::HypertreeRepo::open(
|
|
db_directory
|
|
.as_ref()
|
|
.join("hypertrees")
|
|
.join(format!("tree-{i}")),
|
|
max_bucket_size,
|
|
)?;
|
|
|
|
tree.cleanup()?;
|
|
|
|
Ok(HypertreeMeta::build(tree, reaper.clone()))
|
|
})
|
|
.collect::<Result<Vec<_>, TreeError>>()?;
|
|
|
|
Ok(VectorRepo {
|
|
database,
|
|
hypertrees,
|
|
last_rebuild: std::sync::atomic::AtomicU64::new(0),
|
|
maintenance_sender,
|
|
})
|
|
}
|
|
|
|
fn insert_many_vectors(&self, vectors: Vec<Vector<N>>) -> Result<Vec<VectorId>, TreeError> {
|
|
let mut txn = self.database.begin_write()?;
|
|
|
|
let (ids, internal_ids) = insert_many_vectors(
|
|
&mut txn,
|
|
&vectors,
|
|
&self.last_rebuild,
|
|
&self.maintenance_sender,
|
|
)?;
|
|
|
|
let vectors = std::sync::Arc::from(
|
|
internal_ids
|
|
.into_iter()
|
|
.zip(vectors.into_iter())
|
|
.collect::<Vec<_>>(),
|
|
);
|
|
|
|
self.hypertrees
|
|
.iter()
|
|
.map(|hypertree| hypertree.add_many_to_index(std::sync::Arc::clone(&vectors)))
|
|
.collect::<Result<Vec<_>, _>>()?
|
|
.into_iter()
|
|
.try_for_each(|receiver| receiver.recv()?)?;
|
|
|
|
txn.commit()?;
|
|
|
|
Ok(ids)
|
|
}
|
|
|
|
fn get_vectors(
|
|
&self,
|
|
vector_ids: &[VectorId],
|
|
) -> Result<HashMap<VectorId, Vector<N>>, TreeError> {
|
|
let txn = self.database.begin_read()?;
|
|
|
|
get_vectors(&txn, vector_ids)
|
|
}
|
|
|
|
fn find_similarities(
|
|
&self,
|
|
vector: std::sync::Arc<Vector<N>>,
|
|
threshold: Option<f32>,
|
|
limit: usize,
|
|
similarity_style: SimilarityStyle,
|
|
) -> Result<Vec<VectorId>, TreeError> {
|
|
let txn = self.database.begin_read()?;
|
|
|
|
let inverse_vector_insert_multimap =
|
|
txn.open_multimap_table(INVERSE_VECTOR_INSERT_MULTIMAP)?;
|
|
|
|
let mut candidates = self
|
|
.hypertrees
|
|
.iter()
|
|
.map(|hypertree| {
|
|
hypertree.find_similarities(
|
|
std::sync::Arc::clone(&vector),
|
|
threshold,
|
|
limit,
|
|
similarity_style,
|
|
)
|
|
})
|
|
.collect::<Result<Vec<_>, _>>()?
|
|
.into_iter()
|
|
.map(|receiver| receiver.recv()?)
|
|
.collect::<Result<Vec<_>, _>>()?
|
|
.into_iter()
|
|
.flatten()
|
|
.collect::<Vec<_>>();
|
|
|
|
match similarity_style {
|
|
SimilarityStyle::Dissimilar | SimilarityStyle::FurthestSimilar => {
|
|
candidates.sort_by(|a, b| b.0.partial_cmp(&a.0).expect("Numbers are finite"));
|
|
}
|
|
SimilarityStyle::Similar | SimilarityStyle::ClosestDissimilar => {
|
|
candidates.sort_by(|a, b| a.0.partial_cmp(&b.0).expect("Numbers are finite"));
|
|
}
|
|
}
|
|
candidates.dedup_by_key(|a| a.1);
|
|
candidates.truncate(limit);
|
|
|
|
let vectors = candidates
|
|
.into_iter()
|
|
.map(|(_, internal_vector_id, _)| {
|
|
inverse_vector_insert_multimap
|
|
.get(internal_vector_id)
|
|
.and_then(|iter| {
|
|
iter.map(|res| res.map(|value| value.value()))
|
|
.collect::<Result<Vec<_>, _>>()
|
|
})
|
|
})
|
|
.collect::<Result<Vec<_>, _>>()?
|
|
.into_iter()
|
|
.flatten()
|
|
.collect();
|
|
|
|
Ok(vectors)
|
|
}
|
|
}
|
|
|
|
fn reader_runner<const N: usize>(
|
|
repo: &VectorRepo<N>,
|
|
rx: flume::Receiver<ReaderCommand<N>>,
|
|
stopper: flume::Receiver<()>,
|
|
) {
|
|
with_stopper(rx, stopper, |message| match message {
|
|
ReaderCommand::FindSimilarities {
|
|
vector,
|
|
threshold,
|
|
limit,
|
|
similarity_style,
|
|
responder,
|
|
parent,
|
|
} => {
|
|
let span = tracing::debug_span!(parent: parent, "FindSimilarities");
|
|
let guard = span.enter();
|
|
|
|
let res = std::panic::catch_unwind(|| {
|
|
repo.find_similarities(vector, threshold, limit, similarity_style)
|
|
});
|
|
|
|
match res {
|
|
Ok(res) => {
|
|
if responder.send(res).is_err() {
|
|
tracing::warn!("Requester disconnected");
|
|
}
|
|
}
|
|
Err(e) => {
|
|
tracing::error!("operation panicked: {e:?}");
|
|
}
|
|
}
|
|
|
|
drop(guard);
|
|
}
|
|
ReaderCommand::GetVectors {
|
|
vector_ids,
|
|
responder,
|
|
parent,
|
|
} => {
|
|
let span = tracing::debug_span!(parent: parent, "GetVectors");
|
|
let guard = span.enter();
|
|
|
|
let res = std::panic::catch_unwind(|| repo.get_vectors(&vector_ids));
|
|
|
|
match res {
|
|
Ok(res) => {
|
|
if responder.send(res).is_err() {
|
|
tracing::warn!("Requester disconnected");
|
|
}
|
|
}
|
|
Err(e) => {
|
|
tracing::error!("operation panicked: {e:?}");
|
|
}
|
|
}
|
|
|
|
drop(guard);
|
|
}
|
|
})
|
|
}
|
|
|
|
fn insert_vectors_runner<const N: usize>(
|
|
repo: &VectorRepo<N>,
|
|
rx: flume::Receiver<InsertVectorsCommand<N>>,
|
|
stopper: flume::Receiver<()>,
|
|
) {
|
|
with_stopper(rx, stopper, |command| {
|
|
let InsertVectorsCommand {
|
|
vectors,
|
|
responder,
|
|
parent,
|
|
} = command;
|
|
|
|
let span = tracing::debug_span!(parent: parent, "InsertVectors");
|
|
let guard = span.enter();
|
|
|
|
let res = std::panic::catch_unwind(|| repo.insert_many_vectors(vectors));
|
|
|
|
match res {
|
|
Ok(res) => {
|
|
if responder.send(res).is_err() {
|
|
tracing::warn!("Requester disconnected");
|
|
}
|
|
}
|
|
Err(e) => {
|
|
tracing::error!("operation panicked: {e:?}");
|
|
}
|
|
}
|
|
|
|
drop(guard);
|
|
})
|
|
}
|
|
|
|
fn with_stopper<Msg, Callback>(
|
|
rx: flume::Receiver<Msg>,
|
|
stopper: flume::Receiver<()>,
|
|
mut cb: Callback,
|
|
) where
|
|
Callback: FnMut(Msg),
|
|
{
|
|
let stopping = std::sync::atomic::AtomicBool::new(false);
|
|
|
|
while !stopping.load(std::sync::atomic::Ordering::Acquire) {
|
|
flume::Selector::new()
|
|
.recv(&stopper, |_| {
|
|
stopping.store(true, std::sync::atomic::Ordering::Release)
|
|
})
|
|
.recv(&rx, |res| match res {
|
|
Ok(msg) => (cb)(msg),
|
|
Err(_) => stopping.store(true, std::sync::atomic::Ordering::Release),
|
|
})
|
|
.wait();
|
|
}
|
|
}
|
|
|
|
fn insert_many_vectors<const N: usize>(
|
|
transaction: &mut WriteTransaction<'_>,
|
|
vectors: &[Vector<N>],
|
|
len: &std::sync::atomic::AtomicU64,
|
|
maintenance_sender: &flume::Sender<()>,
|
|
) -> Result<(Vec<VectorId>, Vec<InternalVectorId>), TreeError> {
|
|
let mut ids_table = transaction.open_table(IDS_TABLE)?;
|
|
let mut inverse_vector_insert_multimap =
|
|
transaction.open_multimap_table(INVERSE_VECTOR_INSERT_MULTIMAP)?;
|
|
let mut inverse_vector_table = transaction.open_table(inverse_vector_table_definition())?;
|
|
let mut vector_insert_table = transaction.open_table(VECTOR_INSERT_TABLE)?;
|
|
let mut vector_table = transaction.open_table(vector_table_definition())?;
|
|
|
|
let (vectors, internal_vectors) = vectors
|
|
.iter()
|
|
.map(|vector| {
|
|
do_insert_vector(
|
|
&mut ids_table,
|
|
&mut inverse_vector_insert_multimap,
|
|
&mut inverse_vector_table,
|
|
&mut vector_insert_table,
|
|
&mut vector_table,
|
|
vector,
|
|
)
|
|
})
|
|
.collect::<Result<Vec<_>, _>>()?
|
|
.into_iter()
|
|
.unzip::<_, _, _, Vec<Option<InternalVectorId>>>();
|
|
|
|
let internal_vectors = internal_vectors.into_iter().flatten().collect();
|
|
|
|
let vectors_len = vector_table.len()?;
|
|
let previous_len = len.load(std::sync::atomic::Ordering::Acquire);
|
|
if vectors_len > 0
|
|
&& vectors_len / 2 > previous_len
|
|
&& len
|
|
.compare_exchange(
|
|
previous_len,
|
|
vectors_len,
|
|
std::sync::atomic::Ordering::AcqRel,
|
|
std::sync::atomic::Ordering::Relaxed,
|
|
)
|
|
.is_ok()
|
|
{
|
|
let _ = maintenance_sender.try_send(());
|
|
}
|
|
|
|
Ok((vectors, internal_vectors))
|
|
}
|
|
|
|
fn do_insert_vector<'db, 'txn, const N: usize>(
|
|
ids_table: &mut Table<'db, 'txn, &'static str, u128>,
|
|
inverse_vector_insert_multimap: &mut MultimapTable<'db, 'txn, InternalVectorId, VectorId>,
|
|
inverse_vector_table: &mut Table<'db, 'txn, VectorBytes<N>, InternalVectorId>,
|
|
vector_insert_table: &mut Table<'db, 'txn, VectorId, InternalVectorId>,
|
|
vector_table: &mut Table<'db, 'txn, InternalVectorId, VectorBytes<N>>,
|
|
vector: &Vector<N>,
|
|
) -> Result<(VectorId, Option<InternalVectorId>), TreeError> {
|
|
let vector_id = VectorId(next_id(ids_table, VECTOR_INSERT_TABLE)?);
|
|
|
|
let vector_bytes = vector.encode();
|
|
|
|
let internal_vector_id_opt = inverse_vector_table
|
|
.get(&vector_bytes)?
|
|
.map(|id| id.value());
|
|
|
|
let internal_vector_id = if let Some(internal_vector_id) = internal_vector_id_opt {
|
|
vector_insert_table.insert(vector_id, internal_vector_id)?;
|
|
inverse_vector_insert_multimap.insert(internal_vector_id, vector_id)?;
|
|
|
|
None
|
|
} else {
|
|
let internal_vector_id =
|
|
InternalVectorId(next_id(ids_table, vector_table_definition::<N>())?);
|
|
|
|
vector_table.insert(internal_vector_id, &vector_bytes)?;
|
|
inverse_vector_table.insert(&vector_bytes, internal_vector_id)?;
|
|
vector_insert_table.insert(vector_id, internal_vector_id)?;
|
|
inverse_vector_insert_multimap.insert(internal_vector_id, vector_id)?;
|
|
|
|
Some(internal_vector_id)
|
|
};
|
|
|
|
Ok((vector_id, internal_vector_id))
|
|
}
|
|
|
|
fn get_vectors<const N: usize>(
|
|
transaction: &ReadTransaction<'_>,
|
|
vector_ids: &[VectorId],
|
|
) -> Result<HashMap<VectorId, Vector<N>>, TreeError> {
|
|
let vector_insert_table = transaction.open_table(VECTOR_INSERT_TABLE)?;
|
|
let vector_table = transaction.open_table(vector_table_definition())?;
|
|
|
|
let mut out = HashMap::with_capacity(vector_ids.len());
|
|
|
|
for vector_id in vector_ids {
|
|
let Some(internal_vector_id) = vector_insert_table.get(*vector_id)? else {
|
|
continue;
|
|
};
|
|
|
|
let Some(vector) = do_get_vector(&vector_table, internal_vector_id.value())? else {
|
|
continue;
|
|
};
|
|
|
|
out.insert(*vector_id, vector);
|
|
}
|
|
|
|
Ok(out)
|
|
}
|
|
|
|
fn do_get_vector<const N: usize, T>(
|
|
vector_table: &T,
|
|
internal_vector_id: InternalVectorId,
|
|
) -> Result<Option<Vector<N>>, TreeError>
|
|
where
|
|
T: ReadableTable<InternalVectorId, VectorBytes<'static, N>>,
|
|
{
|
|
let Some(vector_bytes) = vector_table.get(internal_vector_id)? else {
|
|
return Ok(None);
|
|
};
|
|
|
|
let vector = Vector::decode(&vector_bytes.value())?;
|
|
|
|
Ok(Some(vector))
|
|
}
|
|
|
|
fn next_id<H>(
|
|
table: &mut Table<'_, '_, &'static str, u128>,
|
|
handle: H,
|
|
) -> Result<u128, StorageError>
|
|
where
|
|
H: TableHandle,
|
|
{
|
|
let mut next_id = table.get(handle.name())?.map(|id| id.value()).unwrap_or(0);
|
|
|
|
while let Some(id) = table.insert(handle.name(), next_id + 1)? {
|
|
let id = id.value();
|
|
|
|
if id == next_id {
|
|
break;
|
|
}
|
|
|
|
next_id = id;
|
|
}
|
|
|
|
Ok(next_id)
|
|
}
|
|
|
|
/// Byte representation of a single `f32`.
type F32ByteArray = [u8; 4];

/// Number of bytes occupied by one encoded `f32`.
const F32_BYTE_ARRAY_SIZE: usize = std::mem::size_of::<F32ByteArray>();

/// Number of bytes needed to encode a vector of `n` components.
const fn vector_byte_length(n: usize) -> usize {
    // Refer to the constant by its in-scope name: a `crate::` path only
    // resolves while this file is the crate root.
    n * F32_BYTE_ARRAY_SIZE
}
|
|
|
|
impl<const N: usize> Vector<N> {
|
|
const BYTES_LEN: usize = N * F32_BYTE_ARRAY_SIZE;
|
|
|
|
fn encode(&self) -> VectorBytes<'static, N> {
|
|
VectorBytes {
|
|
vector_bytes: std::borrow::Cow::Owned(
|
|
self.0.iter().flat_map(|f| f.to_be_bytes()).collect(),
|
|
),
|
|
_phantom: PhantomData,
|
|
}
|
|
}
|
|
|
|
fn decode(
|
|
VectorBytes {
|
|
vector_bytes,
|
|
_phantom,
|
|
}: &VectorBytes<N>,
|
|
) -> Result<Self, DecodeError> {
|
|
if vector_bytes.is_empty() {
|
|
return Err(DecodeError::Empty);
|
|
}
|
|
|
|
if vector_bytes.len() != Self::BYTES_LEN {
|
|
return Err(DecodeError::InvalidLength);
|
|
}
|
|
|
|
let mut index = 0;
|
|
// TODO: array.zip
|
|
let array = [(); N].map(|()| {
|
|
let f = f32::from_be_bytes(
|
|
vector_bytes[index..(index + F32_BYTE_ARRAY_SIZE)]
|
|
.try_into()
|
|
.expect("f32 byte array size is correct"),
|
|
);
|
|
index += F32_BYTE_ARRAY_SIZE;
|
|
f
|
|
});
|
|
|
|
Ok(Self(array))
|
|
}
|
|
|
|
fn average(&self, rhs: &Self) -> Self {
|
|
let mut index = 0;
|
|
|
|
// TODO: array.zip
|
|
Vector([(); N].map(|_| {
|
|
let avg = (self.0[index] + rhs.0[index]) / 2.0;
|
|
index += 1;
|
|
avg
|
|
}))
|
|
}
|
|
|
|
fn dot_product(&self, rhs: &Self) -> f32 {
|
|
self.0.iter().zip(rhs.0.iter()).map(|(a, b)| a * b).sum()
|
|
}
|
|
|
|
pub fn squared_euclidean_distance(&self, rhs: &Self) -> f32 {
|
|
self.0
|
|
.iter()
|
|
.zip(rhs.0.iter())
|
|
.map(|(a, b)| (a - b).powi(2))
|
|
.sum()
|
|
}
|
|
}
|
|
|
|
impl<const N: usize> RedbValue for VectorBytes<'static, N> {
|
|
type SelfType<'a> = VectorBytes<'a, N>;
|
|
type AsBytes<'a> = &'a [u8];
|
|
|
|
fn fixed_width() -> Option<usize> {
|
|
None
|
|
}
|
|
|
|
fn from_bytes<'a>(data: &'a [u8]) -> Self::SelfType<'a>
|
|
where
|
|
Self: 'a,
|
|
{
|
|
assert_eq!(
|
|
data.len(),
|
|
vector_byte_length(N),
|
|
"Byte length is not vector byte length"
|
|
);
|
|
|
|
VectorBytes {
|
|
vector_bytes: std::borrow::Cow::Borrowed(data),
|
|
_phantom: std::marker::PhantomData,
|
|
}
|
|
}
|
|
|
|
fn as_bytes<'a, 'b: 'a>(value: &'a Self::SelfType<'b>) -> Self::AsBytes<'a>
|
|
where
|
|
Self: 'a,
|
|
Self: 'b,
|
|
{
|
|
&value.vector_bytes
|
|
}
|
|
|
|
fn type_name() -> TypeName {
|
|
TypeName::new("vectordb::VectorBytes")
|
|
}
|
|
}
|
|
|
|
impl<const N: usize> RedbKey for VectorBytes<'static, N> {
|
|
fn compare(data1: &[u8], data2: &[u8]) -> std::cmp::Ordering {
|
|
data1.cmp(data2)
|
|
}
|
|
}
|
|
|
|
impl RedbValue for VectorId {
|
|
type SelfType<'a> = VectorId;
|
|
type AsBytes<'a> = <u128 as RedbValue>::AsBytes<'a>;
|
|
|
|
fn fixed_width() -> Option<usize> {
|
|
u128::fixed_width()
|
|
}
|
|
|
|
fn from_bytes<'a>(data: &'a [u8]) -> Self::SelfType<'a>
|
|
where
|
|
Self: 'a,
|
|
{
|
|
Self(u128::from_bytes(data))
|
|
}
|
|
|
|
fn as_bytes<'a, 'b: 'a>(value: &'a Self::SelfType<'b>) -> Self::AsBytes<'a>
|
|
where
|
|
Self: 'a,
|
|
Self: 'b,
|
|
{
|
|
u128::as_bytes(&value.0)
|
|
}
|
|
|
|
fn type_name() -> TypeName {
|
|
TypeName::new("vectordb::VectorId")
|
|
}
|
|
}
|
|
|
|
impl RedbKey for VectorId {
|
|
fn compare(data1: &[u8], data2: &[u8]) -> std::cmp::Ordering {
|
|
u128::compare(data1, data2)
|
|
}
|
|
}
|
|
|
|
impl RedbValue for InternalVectorId {
|
|
type SelfType<'a> = InternalVectorId;
|
|
type AsBytes<'a> = <u128 as RedbValue>::AsBytes<'a>;
|
|
|
|
fn fixed_width() -> Option<usize> {
|
|
u128::fixed_width()
|
|
}
|
|
|
|
fn from_bytes<'a>(data: &'a [u8]) -> Self::SelfType<'a>
|
|
where
|
|
Self: 'a,
|
|
{
|
|
Self(u128::from_bytes(data))
|
|
}
|
|
|
|
fn as_bytes<'a, 'b: 'a>(value: &'a Self::SelfType<'b>) -> Self::AsBytes<'a>
|
|
where
|
|
Self: 'a,
|
|
Self: 'b,
|
|
{
|
|
u128::as_bytes(&value.0)
|
|
}
|
|
|
|
fn type_name() -> TypeName {
|
|
TypeName::new("vectordb::InternalVectorId")
|
|
}
|
|
}
|
|
|
|
impl RedbKey for InternalVectorId {
|
|
fn compare(data1: &[u8], data2: &[u8]) -> std::cmp::Ordering {
|
|
u128::compare(data1, data2)
|
|
}
|
|
}
|
|
|
|
impl<const N: usize> From<[f32; N]> for Vector<N> {
|
|
fn from(value: [f32; N]) -> Self {
|
|
Self(value)
|
|
}
|
|
}
|
|
|
|
impl<const N: usize> std::ops::Mul<f32> for Vector<N> {
|
|
type Output = Vector<N>;
|
|
|
|
fn mul(self, rhs: f32) -> Self::Output {
|
|
Vector(self.0.map(|a| a * rhs))
|
|
}
|
|
}
|
|
|
|
impl<const N: usize> std::ops::Div<f32> for Vector<N> {
|
|
type Output = Vector<N>;
|
|
|
|
fn div(self, rhs: f32) -> Self::Output {
|
|
Vector(self.0.map(|a| a / rhs))
|
|
}
|
|
}
|
|
|
|
impl<const N: usize> std::ops::Add<Vector<N>> for Vector<N> {
|
|
type Output = Vector<N>;
|
|
|
|
fn add(self, rhs: Self) -> Self::Output {
|
|
let mut index = 0;
|
|
|
|
// TODO: array.zip
|
|
Vector([(); N].map(|()| {
|
|
let sum = self.0[index] + rhs.0[index];
|
|
index += 1;
|
|
sum
|
|
}))
|
|
}
|
|
}
|
|
|
|
impl<'a, const N: usize> std::ops::Add<&'a Vector<N>> for Vector<N> {
|
|
type Output = Vector<N>;
|
|
|
|
fn add(self, rhs: &'a Self) -> Self::Output {
|
|
let mut index = 0;
|
|
|
|
// TODO: array.zip
|
|
Vector([(); N].map(|()| {
|
|
let sum = self.0[index] + rhs.0[index];
|
|
index += 1;
|
|
sum
|
|
}))
|
|
}
|
|
}
|
|
|
|
impl<'a, const N: usize> std::ops::Add<&'a Vector<N>> for &'a Vector<N> {
|
|
type Output = Vector<N>;
|
|
|
|
fn add(self, rhs: &'a Vector<N>) -> Self::Output {
|
|
let mut index = 0;
|
|
|
|
// TODO: array.zip
|
|
Vector([(); N].map(|()| {
|
|
let sum = self.0[index] + rhs.0[index];
|
|
index += 1;
|
|
sum
|
|
}))
|
|
}
|
|
}
|
|
|
|
impl<const N: usize> std::ops::Sub<Vector<N>> for Vector<N> {
|
|
type Output = Vector<N>;
|
|
|
|
fn sub(self, rhs: Self) -> Self::Output {
|
|
let mut index = 0;
|
|
|
|
// TODO: array.zip
|
|
Vector([(); N].map(|()| {
|
|
let diff = self.0[index] - rhs.0[index];
|
|
index += 1;
|
|
diff
|
|
}))
|
|
}
|
|
}
|
|
|
|
impl<'a, const N: usize> std::ops::Sub<&'a Vector<N>> for Vector<N> {
|
|
type Output = Vector<N>;
|
|
|
|
fn sub(self, rhs: &'a Self) -> Self::Output {
|
|
let mut index = 0;
|
|
|
|
// TODO: array.zip
|
|
Vector([(); N].map(|()| {
|
|
let diff = self.0[index] - rhs.0[index];
|
|
index += 1;
|
|
diff
|
|
}))
|
|
}
|
|
}
|
|
|
|
impl<'a, const N: usize> std::ops::Sub<&'a Vector<N>> for &'a Vector<N> {
|
|
type Output = Vector<N>;
|
|
|
|
fn sub(self, rhs: &'a Vector<N>) -> Self::Output {
|
|
let mut index = 0;
|
|
|
|
// TODO: array.zip
|
|
Vector([(); N].map(|()| {
|
|
let diff = self.0[index] - rhs.0[index];
|
|
index += 1;
|
|
diff
|
|
}))
|
|
}
|
|
}
|
|
|
|
impl From<oneshot::RecvError> for TreeError {
|
|
fn from(_: oneshot::RecvError) -> Self {
|
|
Self::Canceled
|
|
}
|
|
}
|
|
|
|
impl From<CommitError> for TreeError {
|
|
fn from(value: CommitError) -> Self {
|
|
Self::Commit(value)
|
|
}
|
|
}
|
|
|
|
impl From<CompactionError> for TreeError {
|
|
fn from(value: CompactionError) -> Self {
|
|
Self::Compact(value)
|
|
}
|
|
}
|
|
|
|
impl From<DatabaseError> for TreeError {
|
|
fn from(value: DatabaseError) -> Self {
|
|
Self::Database(value)
|
|
}
|
|
}
|
|
|
|
impl From<DecodeError> for TreeError {
|
|
fn from(value: DecodeError) -> Self {
|
|
Self::Decode(value)
|
|
}
|
|
}
|
|
|
|
impl From<std::io::Error> for TreeError {
|
|
fn from(value: std::io::Error) -> Self {
|
|
Self::Io(value)
|
|
}
|
|
}
|
|
|
|
impl From<StorageError> for TreeError {
|
|
fn from(value: StorageError) -> Self {
|
|
Self::Storage(value)
|
|
}
|
|
}
|
|
|
|
impl From<TableError> for TreeError {
|
|
fn from(value: TableError) -> Self {
|
|
Self::Table(value)
|
|
}
|
|
}
|
|
|
|
impl From<TransactionError> for TreeError {
|
|
fn from(value: TransactionError) -> Self {
|
|
Self::Transaction(value)
|
|
}
|
|
}
|
|
|
|
impl std::fmt::Display for VectorId {
|
|
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
|
u128::fmt(&self.0, f)
|
|
}
|
|
}
|
|
|
|
impl std::fmt::Display for DecodeError {
|
|
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
|
match self {
|
|
Self::Empty => write!(f, "No bytes present"),
|
|
Self::InvalidLength => write!(f, "Invalid number of bytes for type"),
|
|
Self::Bound => write!(f, "Direction bound is invalid"),
|
|
}
|
|
}
|
|
}
|
|
|
|
impl std::fmt::Display for TreeError {
|
|
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
|
match self {
|
|
Self::Commit(_) => write!(f, "Error committing transaction"),
|
|
Self::Compact(_) => write!(f, "Error compacting database"),
|
|
Self::Database(_) => write!(f, "Error opening database"),
|
|
Self::Decode(_) => write!(f, "Error decoding bytes"),
|
|
Self::Io(_) => write!(f, "Error creating db directory"),
|
|
Self::Storage(_) => write!(f, "Error in storage"),
|
|
Self::Table(_) => write!(f, "Error openning table"),
|
|
Self::Transaction(_) => write!(f, "Error opening transaction"),
|
|
Self::Canceled => write!(f, "Operation panicked"),
|
|
Self::Closed => write!(f, "Database is closing"),
|
|
}
|
|
}
|
|
}
|
|
|
|
impl std::error::Error for DecodeError {}
|
|
|
|
impl std::error::Error for TreeError {
|
|
fn source(&self) -> Option<&(dyn std::error::Error + 'static)> {
|
|
match self {
|
|
Self::Commit(e) => Some(e),
|
|
Self::Compact(e) => Some(e),
|
|
Self::Database(e) => Some(e),
|
|
Self::Decode(e) => Some(e),
|
|
Self::Io(e) => Some(e),
|
|
Self::Storage(e) => Some(e),
|
|
Self::Table(e) => Some(e),
|
|
Self::Transaction(e) => Some(e),
|
|
Self::Canceled => None,
|
|
Self::Closed => None,
|
|
}
|
|
}
|
|
}
|