From 9a03220bb33af686fdfe618cf92f867ececb22ec Mon Sep 17 00:00:00 2001 From: asonix Date: Sat, 1 Jul 2023 21:54:07 -0500 Subject: [PATCH] Finish rewrite, strongly type IDs, parallelize hypertrees --- .gitignore | 2 +- examples/run.rs | 8 +- src/flatten_results.rs | 49 -- src/hypertree.rs | 718 ++++++++++++++++++------ src/lib.rs | 1200 +++++++++++++++++----------------------- 5 files changed, 1063 insertions(+), 914 deletions(-) delete mode 100644 src/flatten_results.rs diff --git a/.gitignore b/.gitignore index 5e41ccc..59678aa 100644 --- a/.gitignore +++ b/.gitignore @@ -2,4 +2,4 @@ /Cargo.lock /.envrc /.direnv -/db.redb +/vectordb diff --git a/examples/run.rs b/examples/run.rs index cdda717..f32a329 100644 --- a/examples/run.rs +++ b/examples/run.rs @@ -3,9 +3,10 @@ use std::time::Instant; use rand::{thread_rng, Rng}; fn main() { - let db = vectordb::VectorDb::open("./db.redb", String::from("moments"), 4).expect("Launched"); + let db = + vectordb::VectorDb::<7, 2>::open("./vectordb", String::from("moments")).expect("Launched"); - let total_chunks = 8; + let total_chunks = 32; let chunk_size = 1024 * 32; for _ in 0..total_chunks { @@ -17,8 +18,11 @@ fn main() { let now = Instant::now(); db.insert_many_vectors(&vectors).expect("Inserted vectors"); println!("Inserting vectors - Took {:?}", now.elapsed()); + std::thread::sleep(std::time::Duration::from_secs(5)); } + std::thread::sleep(std::time::Duration::from_secs(30)); + let existing_vector = thread_rng().gen::<[f32; 7]>().into(); db.insert_vector(&existing_vector).expect("Insert works"); diff --git a/src/flatten_results.rs b/src/flatten_results.rs deleted file mode 100644 index 37b1919..0000000 --- a/src/flatten_results.rs +++ /dev/null @@ -1,49 +0,0 @@ -pub(crate) struct FlattenResults { - iterator: I, - nested: Option, -} - -pub(crate) trait IteratorExt: std::iter::Iterator> { - fn flatten_results(self) -> FlattenResults - where - Self: Sized; -} - -impl IteratorExt for I -where - I: std::iter::Iterator>, - J: std::iter::Iterator>, -{ - fn flatten_results(self) -> FlattenResults { - FlattenResults { - iterator: self, - nested: None, - } - } -} - -impl std::iter::Iterator for FlattenResults -where - I: std::iter::Iterator>, - J: std::iter::Iterator>, -{ - type Item = Result; - - fn next(&mut self) -> Option { - if let Some(nested) = self.nested.as_mut() { - if let Some(res) = nested.next() { - return Some(res); - } - self.nested.take(); - } - - match self.iterator.next() { - Some(Ok(nested)) => { - self.nested = Some(nested); - self.next() - } - Some(Err(e)) => Some(Err(e)), - None => None, - } - } -} diff --git a/src/hypertree.rs b/src/hypertree.rs index 4532d28..15fc258 100644 --- a/src/hypertree.rs +++ b/src/hypertree.rs @@ -1,42 +1,31 @@ +use std::sync::atomic::AtomicBool; + use rand::{seq::IteratorRandom, thread_rng}; use redb::{ - Database, MultimapTable, MultimapTableDefinition, Range, ReadOnlyMultimapTable, ReadOnlyTable, - ReadableMultimapTable, ReadableTable, RedbKey, RedbValue, StorageError, Table, TableDefinition, - TypeName, + Database, MultimapTable, MultimapTableDefinition, MultimapTableHandle, Range, + ReadOnlyMultimapTable, ReadOnlyTable, ReadableMultimapTable, ReadableTable, RedbKey, RedbValue, + StorageError, Table, TableDefinition, TypeName, }; -use crate::{next_id, next_multimap_id, DecodeError, TreeError, Vector, VectorBytes}; +use crate::{next_id, DecodeError, InternalVectorId, TreeError, Vector, VectorBytes}; +#[derive(Debug)] pub(super) struct HypertreeDb { database: Database, + ingest_queue: Database, + 
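    // `rebuilding` diverts incoming vector ids into `ingest_queue` while the
    // index is rebuilt; `cleanup` blocks a new rebuild while that queue is
    // drained back into the fresh tree (see `rebuild_hypertree` and `cleanup`
    // below).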
rebuilding: AtomicBool, + cleanup: AtomicBool, max_bucket_size: usize, _size: std::marker::PhantomData>, } -const IDS_TABLE: TableDefinition<'static, &'static str, u128> = - TableDefinition::new("vectordb::ids"); +#[derive(Clone, Copy, Debug, PartialEq, Eq, PartialOrd, Ord)] +#[repr(transparent)] +struct HyperplaneId(u128); -const BUCKET_MULTIMAP: MultimapTableDefinition<'static, u128, u128> = - MultimapTableDefinition::new("vectordb::bucket_table"); - -const fn hypertree_table_def( -) -> TableDefinition<'static, HyperplaneListBytes<'static, N>, u128> { - TableDefinition::new("vectordb::hypertree_table") -} - -const fn hyperplane_table_def( -) -> TableDefinition<'static, u128, HyperplaneBytes<'static, N>> { - TableDefinition::new("vectordb::hyperplane_table") -} - -const fn inverse_hyperplane_table_def( -) -> TableDefinition<'static, HyperplaneBytes<'static, N>, u128> { - TableDefinition::new("vectordb::inverse_hyperplane_table") -} - -const fn hyperplane_byte_length(n: usize) -> usize { - (n + 1) * crate::F32_BYTE_ARRAY_SIZE -} +#[derive(Clone, Copy, Debug, PartialEq, Eq, PartialOrd, Ord)] +#[repr(transparent)] +struct BucketId(u128); #[derive(Clone, Copy, Debug, PartialEq, Eq, PartialOrd, Ord)] enum Bounded { @@ -54,22 +43,22 @@ enum BoundedByte { #[derive(Debug)] struct HyperplaneNode { - parents: HyperplaneList, - hyperplane_id: u128, - below_id: u128, - above_id: u128, + parents: Hypertree, + hyperplane_id: HyperplaneId, + below_id: BucketId, + above_id: BucketId, } #[derive(Debug)] -struct HyperplaneList { - hyperplanes: Vec<(u128, Bounded)>, +struct Hypertree { + hyperplanes: Vec<(HyperplaneId, Bounded)>, _size: std::marker::PhantomData>, } #[derive(Clone, Debug)] -struct HyperplaneListBytes<'a, const N: usize> { +struct HypertreeBytes<'a, const N: usize> { hyperplanes: std::borrow::Cow<'a, [u8]>, - _size: std::marker::PhantomData>, + _size: std::marker::PhantomData>, } #[derive(Clone, Debug)] @@ -85,28 +74,80 @@ struct HyperplaneBytes<'a, const N: usize> { } #[derive(Clone, Debug)] -struct HyperplaneListBytesRange<'a, const N: usize> { - lower: HyperplaneListBytes<'a, N>, - upper: HyperplaneListBytes<'a, N>, +struct HypertreeBytesRange<'a, const N: usize> { + lower: HypertreeBytes<'a, N>, + upper: HypertreeBytes<'a, N>, +} + +struct SetOnDrop<'a>(&'a AtomicBool); + +const IDS_TABLE: TableDefinition<'static, &'static str, u128> = + TableDefinition::new("vectordb::ids"); + +const BUCKET_MULTIMAP: MultimapTableDefinition<'static, BucketId, InternalVectorId> = + MultimapTableDefinition::new("vectordb::bucket_table"); + +const QUEUE_TABLE: MultimapTableDefinition<'static, &'static str, InternalVectorId> = + MultimapTableDefinition::new("vectordb::queue_table"); + +const REBUILD_TABLE: TableDefinition<'static, &'static str, u64> = + TableDefinition::new("vectordb::rebuild_table"); + +const QUEUE_KEY: &'static str = "queue"; + +const REBUILD_KEY: &'static str = "rebuild"; + +const fn internal_vector_table( +) -> TableDefinition<'static, InternalVectorId, VectorBytes<'static, N>> { + TableDefinition::new("vectordb::internal_vector_table") +} + +const fn hypertree_table_def( +) -> TableDefinition<'static, HypertreeBytes<'static, N>, BucketId> { + TableDefinition::new("vectordb::hypertree_table") +} + +const fn hyperplane_table_def( +) -> TableDefinition<'static, HyperplaneId, HyperplaneBytes<'static, N>> { + TableDefinition::new("vectordb::hyperplane_table") +} + +const fn inverse_hyperplane_table_def( +) -> TableDefinition<'static, HyperplaneBytes<'static, N>, HyperplaneId> { + 
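    // Keyed by the encoded hyperplane bytes so structurally identical
    // hyperplanes resolve to one HyperplaneId; see `insert_hyperplane`.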
TableDefinition::new("vectordb::inverse_hyperplane_table") +} + +const fn hyperplane_byte_length(n: usize) -> usize { + (n + 1) * crate::F32_BYTE_ARRAY_SIZE } fn do_add_to_index<'db, 'txn, const N: usize, T>( ids_table: &mut Table<'db, 'txn, &'static str, u128>, vector_table: &T, - hyperplane_table: &mut Table<'db, 'txn, u128, HyperplaneBytes>, - inverse_hyperplane_table: &mut Table<'db, 'txn, HyperplaneBytes, u128>, - hypertree_table: &mut Table<'db, 'txn, HyperplaneListBytes<'static, N>, u128>, - bucket_multimap: &mut MultimapTable<'db, 'txn, u128, u128>, - insert_vector_id: u128, + internal_vector_table: &mut Table<'db, 'txn, InternalVectorId, VectorBytes>, + hyperplane_table: &mut Table<'db, 'txn, HyperplaneId, HyperplaneBytes>, + inverse_hyperplane_table: &mut Table<'db, 'txn, HyperplaneBytes, HyperplaneId>, + hypertree_table: &mut Table<'db, 'txn, HypertreeBytes<'static, N>, BucketId>, + bucket_multimap: &mut MultimapTable<'db, 'txn, BucketId, InternalVectorId>, + insert_vector_id: InternalVectorId, max_bucket_size: usize, ) -> Result<(), TreeError> where - T: ReadableTable>, + T: ReadableTable>, { - let Some(insert_vector) = crate::do_get_vector(vector_table, insert_vector_id)?? else { + if internal_vector_table.get(insert_vector_id)?.is_some() { return Ok(()); + } + + let Some(insert_vector) = crate::do_get_vector(vector_table, insert_vector_id)? else { + // TODO: maybe should error? + todo!("Didn't find vector") }; + // This will add vectors to our tree vector store even if we have no tree yet + // This is fine, since rebuilding the tree will include this vector + internal_vector_table.insert(insert_vector_id, insert_vector.encode())?; + let Some(hyperplane_list) = get_first_list(hypertree_table)? else { return Ok(()); }; @@ -118,42 +159,67 @@ where &insert_vector, )??; - if let Some((hyperplane_list, bucket_id)) = bucket_opt { - bucket_multimap.insert(bucket_id, insert_vector_id)?; + let Some((hyperplane_list, bucket_id)) = bucket_opt else { + // TODO: maybe should error? + todo!("Didn't find bucket") + }; - let size = bucket_multimap.get(bucket_id)?.count(); + bucket_multimap.insert(bucket_id, insert_vector_id)?; - if size > max_bucket_size { - build_hyperplane( - ids_table, - vector_table, - hyperplane_table, - inverse_hyperplane_table, - hypertree_table, - bucket_multimap, - bucket_id, - hyperplane_list, - max_bucket_size, - )?; - } + let size = bucket_multimap.get(bucket_id)?.count(); + + if size > max_bucket_size { + build_hyperplane( + ids_table, + internal_vector_table, + hyperplane_table, + inverse_hyperplane_table, + hypertree_table, + bucket_multimap, + bucket_id, + hyperplane_list, + max_bucket_size, + )?; } Ok(()) } +fn next_multimap_id<'db, 'txn, H>( + table: &mut Table<'db, 'txn, &'static str, u128>, + handle: H, +) -> Result +where + H: MultimapTableHandle, +{ + let mut next_id = table.get(handle.name())?.map(|id| id.value()).unwrap_or(0); + + while let Some(id) = table.insert(handle.name(), next_id + 1)? 
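    // redb's `insert` returns the previously stored value, which inside a
    // single write transaction always matches what was just read, so this loop
    // is defensive. A minimal sketch of the counter behaviour, assuming a
    // plain map in place of redb (`counters` is hypothetical):
    //
    //     fn next(counters: &mut std::collections::HashMap<&'static str, u128>) -> u128 {
    //         let read = *counters.get("bucket").unwrap_or(&0);
    //         counters.insert("bucket", read + 1);
    //         read // hand out `read`, persist `read + 1` for the next caller
    //     }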
{ + let id = id.value(); + + if id == next_id { + break; + } + + next_id = id; + } + + Ok(next_id) +} + fn do_build_hypertree<'db, 'txn, const N: usize, T>( - bucket_multimap: &mut MultimapTable<'db, 'txn, u128, u128>, - hyperplane_table: &mut Table<'db, 'txn, u128, HyperplaneBytes>, - hypertree_table: &mut Table<'db, 'txn, HyperplaneListBytes, u128>, + bucket_multimap: &mut MultimapTable<'db, 'txn, BucketId, InternalVectorId>, + hyperplane_table: &mut Table<'db, 'txn, HyperplaneId, HyperplaneBytes>, + hypertree_table: &mut Table<'db, 'txn, HypertreeBytes, BucketId>, ids_table: &mut Table<'db, 'txn, &str, u128>, - inverse_hyperplane_table: &mut Table<'db, 'txn, HyperplaneBytes, u128>, + inverse_hyperplane_table: &mut Table<'db, 'txn, HyperplaneBytes, HyperplaneId>, vector_table: &T, max_bucket_size: usize, ) -> Result<(), TreeError> where - T: ReadableTable>, + T: ReadableTable>, { - let source_bucket = next_multimap_id(ids_table, &BUCKET_MULTIMAP)?; + let source_bucket = BucketId(next_multimap_id(ids_table, BUCKET_MULTIMAP)?); for res in vector_table.iter()? { let (id, _) = res?; @@ -172,7 +238,7 @@ where hypertree_table, bucket_multimap, source_bucket, - HyperplaneList { + Hypertree { hyperplanes: Vec::new(), _size: std::marker::PhantomData, }, @@ -224,19 +290,19 @@ where fn build_hyperplane<'db, 'txn, const N: usize, T>( ids_table: &mut Table<'db, 'txn, &'static str, u128>, vector_table: &T, - hyperplane_table: &mut Table<'db, 'txn, u128, HyperplaneBytes>, - inverse_hyperplane_table: &mut Table<'db, 'txn, HyperplaneBytes, u128>, - hypertree_table: &mut Table<'db, 'txn, HyperplaneListBytes, u128>, - bucket_multimap: &mut MultimapTable<'db, 'txn, u128, u128>, - source_bucket: u128, - parents: HyperplaneList, + hyperplane_table: &mut Table<'db, 'txn, HyperplaneId, HyperplaneBytes>, + inverse_hyperplane_table: &mut Table<'db, 'txn, HyperplaneBytes, HyperplaneId>, + hypertree_table: &mut Table<'db, 'txn, HypertreeBytes, BucketId>, + bucket_multimap: &mut MultimapTable<'db, 'txn, BucketId, InternalVectorId>, + source_bucket: BucketId, + parents: Hypertree, max_bucket_size: usize, ) -> Result>, TreeError> where - T: ReadableTable>, + T: ReadableTable>, { - let below_id = next_multimap_id(ids_table, &BUCKET_MULTIMAP)?; - let above_id = next_multimap_id(ids_table, &BUCKET_MULTIMAP)?; + let below_id = BucketId(next_multimap_id(ids_table, BUCKET_MULTIMAP)?); + let above_id = BucketId(next_multimap_id(ids_table, BUCKET_MULTIMAP)?); if bucket_multimap.get(source_bucket)?.count() <= max_bucket_size { return Ok(None); @@ -324,17 +390,17 @@ where fn insert_hyperplane<'db, 'txn, const N: usize>( ids_table: &mut Table<'db, 'txn, &'static str, u128>, - hyperplane_table: &mut Table<'db, 'txn, u128, HyperplaneBytes>, - inverse_hyperplane_table: &mut Table<'db, 'txn, HyperplaneBytes, u128>, + hyperplane_table: &mut Table<'db, 'txn, HyperplaneId, HyperplaneBytes>, + inverse_hyperplane_table: &mut Table<'db, 'txn, HyperplaneBytes, HyperplaneId>, hyperplane: &Hyperplane, -) -> Result { +) -> Result { let encoded = hyperplane.encode(); if let Some(id) = inverse_hyperplane_table.get(&encoded)? 
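    // content-addressed lookup: if this exact hyperplane was stored before,
    // reuse its existing id instead of minting a new one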
{ return Ok(id.value()); } - let id = next_id(ids_table, &hyperplane_table_def::())?; + let id = HyperplaneId(next_id(ids_table, hyperplane_table_def::())?); hyperplane_table.insert(id, &encoded)?; inverse_hyperplane_table.insert(encoded, id)?; @@ -343,15 +409,15 @@ fn insert_hyperplane<'db, 'txn, const N: usize>( } fn do_get_similar_vectors<'txn, const N: usize, T>( - hypertree_table: &ReadOnlyTable<'txn, HyperplaneListBytes, u128>, - hyperplane_table: &ReadOnlyTable<'txn, u128, HyperplaneBytes>, - bucket_multimap: &ReadOnlyMultimapTable<'txn, u128, u128>, + hypertree_table: &ReadOnlyTable<'txn, HypertreeBytes, BucketId>, + hyperplane_table: &ReadOnlyTable<'txn, HyperplaneId, HyperplaneBytes>, + bucket_multimap: &ReadOnlyMultimapTable<'txn, BucketId, InternalVectorId>, vector_table: &T, - query_vector: Vector, + query_vector: &Vector, limit: usize, -) -> Result, TreeError> +) -> Result)>, TreeError> where - T: ReadableTable>, + T: ReadableTable>, { let Some(hyperplane_list) = get_first_list(hypertree_table)? else { return Ok(vec![]); @@ -361,7 +427,7 @@ where hypertree_table, hyperplane_table, hyperplane_list, - &query_vector, + query_vector, )??; let Some((_, bucket_id)) = bucket_opt else { @@ -389,8 +455,7 @@ where }; let new_vector = match super::do_get_vector(vector_table, vector_id) { - Ok(Ok(opt)) => opt?, - Ok(Err(e)) => return Some(Err(TreeError::from(e))), + Ok(opt) => opt?, Err(e) => return Some(Err(TreeError::from(e))), }; @@ -406,22 +471,14 @@ where candidates.sort_by(|a, b| a.0.partial_cmp(&b.0).expect("Numbers are finite")); candidates.dedup_by_key(|a| a.1); - candidates.dedup_by_key(|a| a.2.encode()); candidates.truncate(limit); - let vectors = candidates - .into_iter() - .map(|(_, vector_id, _)| vector_id) - .collect::>(); - - Ok(vectors) + Ok(candidates) } -fn get_first_list( - hypertree_table: &T, -) -> Result>, TreeError> +fn get_first_list(hypertree_table: &T) -> Result>, TreeError> where - T: ReadableTable, u128>, + T: ReadableTable, BucketId>, { let Some(res) = hypertree_table.iter()?.next() else { return Ok(None); @@ -429,7 +486,7 @@ where let (k, _) = res?; - let list = HyperplaneList::decode(&k.value())?; + let list = Hypertree::decode(&k.value())?; Ok(Some(list)) } @@ -437,12 +494,12 @@ where fn find_similarity_bucket( hypertree_table: &T, hyperplane_table: &U, - mut hyperplane_list: HyperplaneList, + mut hyperplane_list: Hypertree, vector: &Vector, -) -> Result, u128)>, DecodeError>, StorageError> +) -> Result, BucketId)>, DecodeError>, StorageError> where - T: ReadableTable, u128>, - U: ReadableTable>, + T: ReadableTable, BucketId>, + U: ReadableTable>, { let mut depth = 0; @@ -485,12 +542,12 @@ where fn next_list<'table, const N: usize, T, V>( table: &'table T, - hyperplane_list: &HyperplaneList, + hyperplane_list: &Hypertree, depth: usize, direction: Bounded, -) -> Result>, DecodeError>, StorageError> +) -> Result>, DecodeError>, StorageError> where - T: ReadableTable, V> + 'table, + T: ReadableTable, V> + 'table, V: RedbValue + 'static, { let mut range = scan(table, hyperplane_list, depth, direction)?; @@ -501,17 +558,17 @@ where let (hyperplane_list_bytes, _) = res?; - Ok(HyperplaneList::decode(&hyperplane_list_bytes.value()).map(Some)) + Ok(Hypertree::decode(&hyperplane_list_bytes.value()).map(Some)) } fn scan<'table, const N: usize, T, V>( table: &'table T, - hyperplane_list: &HyperplaneList, + hyperplane_list: &Hypertree, depth: usize, direction: Bounded, -) -> Result, V>, StorageError> +) -> Result, V>, StorageError> where - T: ReadableTable, V> + 
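    // `to_range` produces a half-open byte range covering every stored
    // hypertree path that extends `hyperplane_list` with `direction` at
    // `depth`, so this is a prefix scan over the big-endian key encoding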
'table, + T: ReadableTable, V> + 'table, V: RedbValue, { let range = hyperplane_list.to_range(depth, direction); @@ -521,28 +578,42 @@ where impl HypertreeDb { pub(super) fn open>( - path: P, + tree_dir: P, max_bucket_size: usize, ) -> Result { - let mut database = Database::create(path)?; + std::fs::create_dir_all(tree_dir.as_ref())?; + + let mut database = Database::create(tree_dir.as_ref().join("hypertree.redb"))?; + let mut ingest_queue = Database::create(tree_dir.as_ref().join("ingest.redb"))?; + database.check_integrity()?; + ingest_queue.check_integrity()?; + database.compact()?; + ingest_queue.compact()?; + + let txn = database.begin_write()?; + txn.open_table(REBUILD_TABLE)?; + txn.commit()?; Ok(Self { database, + ingest_queue, + rebuilding: AtomicBool::new(false), + cleanup: AtomicBool::new(false), max_bucket_size, _size: std::marker::PhantomData, }) } - pub(super) fn find_similarities<'db, 'txn, T>( + pub(super) fn find_similarities( &self, vector_table: &T, - query_vector: Vector, + query_vector: &Vector, limit: usize, - ) -> Result, TreeError> + ) -> Result)>, TreeError> where - T: ReadableTable>, + T: ReadableTable>, { let txn = self.database.begin_read()?; @@ -560,14 +631,25 @@ impl HypertreeDb { ) } - pub(super) fn add_to_index<'db, 'txn, T>( + pub(super) fn add_to_index( &self, - vector_id: u128, + vector_id: InternalVectorId, vector_table: &T, ) -> Result<(), TreeError> where - T: ReadableTable>, + T: ReadableTable>, { + if self.rebuilding.load(std::sync::atomic::Ordering::Acquire) { + let txn = self.ingest_queue.begin_write()?; + + txn.open_multimap_table(QUEUE_TABLE)? + .insert(QUEUE_KEY, vector_id)?; + + txn.commit()?; + + return Ok(()); + } + let txn = self.database.begin_write()?; let mut ids_table = txn.open_table(IDS_TABLE)?; @@ -575,10 +657,12 @@ impl HypertreeDb { let mut inverse_hyperplane_table = txn.open_table(inverse_hyperplane_table_def())?; let mut hypertree_table = txn.open_table(hypertree_table_def())?; let mut bucket_multimap = txn.open_multimap_table(BUCKET_MULTIMAP)?; + let mut internal_vector_table = txn.open_table(internal_vector_table())?; do_add_to_index( &mut ids_table, vector_table, + &mut internal_vector_table, &mut hyperplane_table, &mut inverse_hyperplane_table, &mut hypertree_table, @@ -587,6 +671,13 @@ impl HypertreeDb { self.max_bucket_size, )?; + drop(internal_vector_table); + drop(bucket_multimap); + drop(hypertree_table); + drop(inverse_hyperplane_table); + drop(hyperplane_table); + drop(ids_table); + txn.commit()?; Ok(()) @@ -594,12 +685,28 @@ impl HypertreeDb { pub(super) fn add_many_to_index( &self, - vectors: &[u128], + vectors: &[InternalVectorId], vector_table: &T, ) -> Result<(), TreeError> where - T: ReadableTable>, + T: ReadableTable>, { + if self.rebuilding.load(std::sync::atomic::Ordering::Acquire) { + let txn = self.ingest_queue.begin_write()?; + + let mut queue_table = txn.open_multimap_table(QUEUE_TABLE)?; + + for vector_id in vectors { + queue_table.insert(QUEUE_KEY, *vector_id)?; + } + + drop(queue_table); + + txn.commit()?; + + return Ok(()); + } + let txn = self.database.begin_write()?; let mut ids_table = txn.open_table(IDS_TABLE)?; @@ -607,11 +714,13 @@ impl HypertreeDb { let mut inverse_hyperplane_table = txn.open_table(inverse_hyperplane_table_def())?; let mut hypertree_table = txn.open_table(hypertree_table_def())?; let mut bucket_multimap = txn.open_multimap_table(BUCKET_MULTIMAP)?; + let mut internal_vector_table = txn.open_table(internal_vector_table())?; for vector_id in vectors { do_add_to_index( &mut 
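                // every vector in the slice shares the one write transaction
                // opened above, rather than paying a commit per insert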
ids_table, vector_table, + &mut internal_vector_table, &mut hyperplane_table, &mut inverse_hyperplane_table, &mut hypertree_table, @@ -621,36 +730,164 @@ impl HypertreeDb { )?; } + drop(internal_vector_table); + drop(bucket_multimap); + drop(hypertree_table); + drop(inverse_hyperplane_table); + drop(hyperplane_table); + drop(ids_table); + txn.commit()?; Ok(()) } - pub(super) fn rebuild_hypertree<'db, 'txn, T>(&self, vector_table: &T) -> Result<(), TreeError> + pub(super) fn should_rebuild(&self, vector_table: &T) -> Result where - T: ReadableTable>, + T: ReadableTable>, { + let txn = self.database.begin_read()?; + + let vector_len = vector_table.len()?; + + if vector_len == 0 { + return Ok(false); + } + + if vector_len / 2 < u64::try_from(self.max_bucket_size).expect("bucket size is reasonable") + { + return Ok(false); + } + + let rebuild_table = txn.open_table(REBUILD_TABLE)?; + + if let Some(previous_size) = rebuild_table.get(REBUILD_KEY)? { + if vector_len / 2 < previous_size.value() { + return Ok(false); + } + } + + Ok(true) + } + + pub(super) fn is_rebuilding(&self) -> bool { + self.rebuilding.load(std::sync::atomic::Ordering::Acquire) + } + + pub(super) fn rebuild_hypertree(&self, vector_table: &T) -> Result<(), TreeError> + where + T: ReadableTable>, + { + println!("Rebuild"); + if self + .rebuilding + .swap(true, std::sync::atomic::Ordering::AcqRel) + { + println!("Already rebuilding"); + return Ok(()); + } + + if self.cleanup.load(std::sync::atomic::Ordering::Acquire) { + println!("In cleanup"); + return Ok(()); + } + + let rebuild_guard = SetOnDrop(&self.rebuilding); + let txn = self.database.begin_write()?; - let mut ids_table = txn.open_table(IDS_TABLE)?; - let mut hyperplane_table = txn.open_table(hyperplane_table_def())?; - let mut inverse_hyperplane_table = txn.open_table(inverse_hyperplane_table_def())?; - let mut hypertree_table = txn.open_table(hypertree_table_def())?; let mut bucket_multimap = txn.open_multimap_table(BUCKET_MULTIMAP)?; + let mut hyperplane_table = txn.open_table(hyperplane_table_def())?; + let mut hypertree_table = txn.open_table(hypertree_table_def())?; + let mut ids_table = txn.open_table(IDS_TABLE)?; + let mut inverse_hyperplane_table = txn.open_table(inverse_hyperplane_table_def())?; + let mut rebuild_table = txn.open_table(REBUILD_TABLE)?; - ids_table.drain(..)?; - hyperplane_table.drain(..)?; - inverse_hyperplane_table.drain(..)?; - hypertree_table.drain(..)?; + let vector_len = vector_table.len()?; + + if vector_len == 0 { + println!("No vectors"); + return Ok(()); + } + + if vector_len / 2 < u64::try_from(self.max_bucket_size).expect("bucket size is reasonable") + { + println!("Too small to build"); + return Ok(()); + } + + if let Some(previous_size) = rebuild_table.get(REBUILD_KEY)? { + if vector_len / 2 < previous_size.value() { + println!("Too soon after last build"); + return Ok(()); + } + } + + rebuild_table.insert(REBUILD_KEY, vector_len)?; + + println!("Cleaning tables"); + // TODO: drain fix + let id_keys = ids_table + .iter()? + .map(|res| res.map(|(name, _)| name.value().to_string())) + .collect::, _>>()?; + + for key in id_keys { + ids_table.remove(key.as_str())?; + } + + let hyperplane_ids = hyperplane_table + .iter()? + .map(|res| res.map(|(id, _)| id.value())) + .collect::, _>>()?; + + for id in hyperplane_ids { + hyperplane_table.remove(id)?; + } + + let hyperplanes = inverse_hyperplane_table + .iter()? 
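        // same "drain fix" workaround as above: materialise the keys first,
        // then remove them, since the table can't be mutated while its
        // iterator is live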
+ .map(|res| { + res.map_err(TreeError::from).and_then(|(plane, _)| { + Hyperplane::decode(&plane.value()).map_err(TreeError::from) + }) + }) + .collect::, _>>()?; + + for plane in hyperplanes { + inverse_hyperplane_table.remove(plane.encode())?; + } + + let hypertrees = hypertree_table + .iter()? + .map(|res| { + res.map_err(TreeError::from) + .and_then(|(tree, _)| Hypertree::decode(&tree.value()).map_err(TreeError::from)) + }) + .collect::, _>>()?; + + for tree in hypertrees { + hypertree_table.remove(tree.encode())?; + } + + /* + * TODO: drain fix + ids_table.drain::<&'static str>(..)?; + hyperplane_table.drain::(..)?; + inverse_hyperplane_table.drain::>(..)?; + hypertree_table.drain::>(..)?; + */ let buckets = bucket_multimap .iter()? .map(|res| res.map(|(k, _)| k.value())) .collect::, _>>()?; + for bucket in buckets { bucket_multimap.remove_all(bucket)?; } + println!("Building hypertree"); do_build_hypertree( &mut bucket_multimap, &mut hyperplane_table, @@ -659,7 +896,76 @@ impl HypertreeDb { &mut inverse_hyperplane_table, vector_table, self.max_bucket_size, - ) + )?; + + drop(rebuild_table); + drop(inverse_hyperplane_table); + drop(ids_table); + drop(hypertree_table); + drop(hyperplane_table); + drop(bucket_multimap); + + txn.commit()?; + + self.cleanup + .store(true, std::sync::atomic::Ordering::Release); + drop(rebuild_guard); + + Ok(()) + } + + pub(super) fn cleanup(&self, vector_table: &T) -> Result + where + T: ReadableTable>, + { + println!("Performing cleanup"); + let cleanup_guard = SetOnDrop(&self.cleanup); + self.cleanup + .store(true, std::sync::atomic::Ordering::Release); + + let txn = self.ingest_queue.begin_write()?; + + let mut queue_table = txn.open_multimap_table(QUEUE_TABLE)?; + + let mut requeue = Vec::new(); + + let mut chunk = Vec::with_capacity(1024); + for res in queue_table.remove_all(QUEUE_KEY)? 
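    // drain the ids queued during the rebuild in chunks of 1024; ids whose
    // vectors aren't visible in `vector_table` yet are requeued, and the
    // caller loops until nothing needs requeueing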
{ + let id = res?.value(); + + if vector_table.get(id)?.is_none() { + requeue.push(id); + continue; + } + + chunk.push(id); + + if chunk.len() < 1024 { + continue; + } + + self.add_many_to_index(&chunk, vector_table)?; + + chunk.clear(); + } + + let needs_requeue = requeue.len() > 0; + + for id in requeue { + queue_table.insert(QUEUE_KEY, id)?; + } + + self.add_many_to_index(&chunk, vector_table)?; + + drop(queue_table); + + txn.commit()?; + + drop(cleanup_guard); + + println!("Cleanup complete"); + + Ok(needs_requeue) } } @@ -715,13 +1021,13 @@ impl Hyperplane { } } -impl HyperplaneList { +impl Hypertree { const SEGMENT_LEN: usize = 17; // u128 + u8 - fn append(&self, hyperplane_id: u128, bound: Bounded) -> Self { + fn append(&self, hyperplane_id: HyperplaneId, bound: Bounded) -> Self { let mut hyperplanes = self.hyperplanes.clone(); hyperplanes.push((hyperplane_id, bound)); - HyperplaneList { + Hypertree { hyperplanes, _size: std::marker::PhantomData, } @@ -734,11 +1040,11 @@ impl HyperplaneList { .unwrap_or(false) } - fn get(&self, depth: usize) -> Option { + fn get(&self, depth: usize) -> Option { self.hyperplanes.get(depth).map(|(h, _)| *h) } - fn to_range(&self, depth: usize, direction: Bounded) -> HyperplaneListBytesRange<'static, N> { + fn to_range(&self, depth: usize, direction: Bounded) -> HypertreeBytesRange<'static, N> { let bound_capacity = (depth + 1) * Self::SEGMENT_LEN; let mut lower = Vec::with_capacity(bound_capacity); @@ -767,19 +1073,19 @@ impl HyperplaneList { } } - HyperplaneListBytesRange { - lower: HyperplaneListBytes { + HypertreeBytesRange { + lower: HypertreeBytes { hyperplanes: std::borrow::Cow::Owned(lower), _size: std::marker::PhantomData, }, - upper: HyperplaneListBytes { + upper: HypertreeBytes { hyperplanes: std::borrow::Cow::Owned(upper), _size: std::marker::PhantomData, }, } } - fn encode(&self) -> HyperplaneListBytes<'static, N> { + fn encode(&self) -> HypertreeBytes<'static, N> { let capacity = self.hyperplanes.len() * Self::SEGMENT_LEN; let mut bytes = Vec::with_capacity(capacity); @@ -789,20 +1095,20 @@ impl HyperplaneList { bytes.push(b.encode().to_byte()); } - HyperplaneListBytes { + HypertreeBytes { hyperplanes: std::borrow::Cow::Owned(bytes), _size: std::marker::PhantomData, } } fn decode( - HyperplaneListBytes { hyperplanes, _size }: &HyperplaneListBytes, + HypertreeBytes { hyperplanes, _size }: &HypertreeBytes, ) -> Result { if hyperplanes.len() % Self::SEGMENT_LEN != 0 { return Err(DecodeError::InvalidLength); } - Ok(HyperplaneList { + Ok(Hypertree { hyperplanes: hyperplanes .chunks_exact(Self::SEGMENT_LEN) .map(|chunk| { @@ -810,7 +1116,7 @@ impl HyperplaneList { for (slot, byte) in id_bytes.iter_mut().zip(&chunk[..16]) { *slot = *byte; } - let hyperplane_id = u128::from_be_bytes(id_bytes); + let hyperplane_id = HyperplaneId::from_be_bytes(id_bytes); let bounded = Bounded::decode(BoundedByte::from_byte(chunk[16]))?; @@ -822,6 +1128,16 @@ impl HyperplaneList { } } +impl HyperplaneId { + fn from_be_bytes(bytes: [u8; 16]) -> Self { + Self(u128::from_be_bytes(bytes)) + } + + fn to_be_bytes(&self) -> [u8; 16] { + self.0.to_be_bytes() + } +} + impl Bounded { const fn to_range_bounds(&self) -> (BoundedByte, BoundedByte) { match self { @@ -878,14 +1194,14 @@ impl Vector { } } -impl<'a, const N: usize> std::ops::RangeBounds> - for HyperplaneListBytesRange<'a, N> +impl<'a, const N: usize> std::ops::RangeBounds> + for HypertreeBytesRange<'a, N> { - fn start_bound(&self) -> std::ops::Bound<&HyperplaneListBytes<'a, N>> { + fn start_bound(&self) -> 
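    // included lower / excluded upper, i.e. keys matching
    //
    //     lower_bytes <= key && key < upper_bytes
    //
    // which is exactly the half-open range `Hypertree::to_range` builds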
std::ops::Bound<&HypertreeBytes<'a, N>> { std::ops::Bound::Included(&self.lower) } - fn end_bound(&self) -> std::ops::Bound<&HyperplaneListBytes<'a, N>> { + fn end_bound(&self) -> std::ops::Bound<&HypertreeBytes<'a, N>> { std::ops::Bound::Excluded(&self.upper) } } @@ -933,8 +1249,8 @@ impl RedbKey for HyperplaneBytes<'static, N> { } } -impl RedbValue for HyperplaneListBytes<'static, N> { - type SelfType<'a> = HyperplaneListBytes<'a, N>; +impl RedbValue for HypertreeBytes<'static, N> { + type SelfType<'a> = HypertreeBytes<'a, N>; type AsBytes<'a> = &'a [u8]; fn fixed_width() -> Option { @@ -946,12 +1262,12 @@ impl RedbValue for HyperplaneListBytes<'static, N> { Self: 'a, { assert_eq!( - data.len() % HyperplaneList::::SEGMENT_LEN, + data.len() % Hypertree::::SEGMENT_LEN, 0, "Byte length is not Hyperplane List byte length" ); - HyperplaneListBytes { + HypertreeBytes { hyperplanes: std::borrow::Cow::Borrowed(data), _size: std::marker::PhantomData, } @@ -966,12 +1282,98 @@ impl RedbValue for HyperplaneListBytes<'static, N> { } fn type_name() -> TypeName { - TypeName::new("vectordb::HyperplaneListBytes") + TypeName::new("vectordb::HypertreeBytes") } } -impl RedbKey for HyperplaneListBytes<'static, N> { +impl RedbKey for HypertreeBytes<'static, N> { fn compare(data1: &[u8], data2: &[u8]) -> std::cmp::Ordering { data1.cmp(data2) } } + +impl RedbValue for HyperplaneId { + type SelfType<'a> = HyperplaneId; + type AsBytes<'a> = ::AsBytes<'a>; + + fn fixed_width() -> Option { + u128::fixed_width() + } + + fn from_bytes<'a>(data: &'a [u8]) -> Self::SelfType<'a> + where + Self: 'a, + { + Self(u128::from_bytes(data)) + } + + fn as_bytes<'a, 'b: 'a>(value: &'a Self::SelfType<'b>) -> Self::AsBytes<'a> + where + Self: 'a, + Self: 'b, + { + u128::as_bytes(&value.0) + } + + fn type_name() -> TypeName { + TypeName::new("vectordb::HyperplaneId") + } +} + +impl RedbKey for HyperplaneId { + fn compare(data1: &[u8], data2: &[u8]) -> std::cmp::Ordering { + u128::compare(data1, data2) + } +} + +impl RedbValue for BucketId { + type SelfType<'a> = BucketId; + type AsBytes<'a> = ::AsBytes<'a>; + + fn fixed_width() -> Option { + u128::fixed_width() + } + + fn from_bytes<'a>(data: &'a [u8]) -> Self::SelfType<'a> + where + Self: 'a, + { + Self(u128::from_bytes(data)) + } + + fn as_bytes<'a, 'b: 'a>(value: &'a Self::SelfType<'b>) -> Self::AsBytes<'a> + where + Self: 'a, + Self: 'b, + { + u128::as_bytes(&value.0) + } + + fn type_name() -> TypeName { + TypeName::new("vectordb::BucketId") + } +} + +impl RedbKey for BucketId { + fn compare(data1: &[u8], data2: &[u8]) -> std::cmp::Ordering { + u128::compare(data1, data2) + } +} + +impl std::fmt::Display for HyperplaneId { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + u128::fmt(&self.0, f) + } +} + +impl std::fmt::Display for BucketId { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + u128::fmt(&self.0, f) + } +} + +impl<'a> Drop for SetOnDrop<'a> { + fn drop(&mut self) { + self.0.store(false, std::sync::atomic::Ordering::Release); + } +} diff --git a/src/lib.rs b/src/lib.rs index 34f3311..13faeff 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -1,116 +1,91 @@ -mod flatten_results; mod hypertree; use std::{marker::PhantomData, path::Path}; +use rayon::prelude::{IntoParallelRefIterator, ParallelIterator}; use redb::{ CommitError, CompactionError, Database, DatabaseError, MultimapTable, MultimapTableDefinition, - MultimapTableHandle, ReadTransaction, ReadableMultimapTable, ReadableTable, RedbKey, RedbValue, - StorageError, Table, 
TableDefinition, TableError, TableHandle, TransactionError, TypeName, - WriteTransaction, + ReadTransaction, ReadableMultimapTable, ReadableTable, RedbKey, RedbValue, StorageError, Table, + TableDefinition, TableError, TableHandle, TransactionError, TypeName, WriteTransaction, }; -pub struct VectorDb { +#[derive(Clone, Debug)] +pub struct VectorDb { + inner: std::sync::Arc>, +} + +#[derive(Debug)] +struct VectorDbInner { database: Database, + hypertrees: [hypertree::HypertreeDb; H], + next_rebuild: std::sync::atomic::AtomicUsize, namespace: String, - max_bucket_size: usize, - target_hypertrees: u64, _size: PhantomData>, } -impl VectorDb { - pub fn open>( - db_path: P, - namespace: String, - target_hypertrees: u64, - ) -> Result> { - let mut database = Database::create(db_path)?; +#[derive(Clone, Debug, PartialEq, PartialOrd)] +pub struct Vector([f32; N]); - database.check_integrity()?; - database.compact()?; +#[derive(Clone, Debug, PartialEq, Eq, PartialOrd, Ord)] +struct VectorBytes<'a, const N: usize> { + vector_bytes: std::borrow::Cow<'a, [u8]>, + _phantom: PhantomData>, +} - Ok(VectorDb { - database, - namespace, - max_bucket_size: 128, - target_hypertrees, - _size: PhantomData, - }) - } +#[derive(Clone, Copy, Debug, PartialEq, Eq, PartialOrd, Ord)] +#[repr(transparent)] +pub struct VectorId(u128); - pub fn insert_vector(&self, vector: &Vector) -> Result { - let mut txn = self.database.begin_write()?; +#[derive(Clone, Copy, Debug, PartialEq, Eq, PartialOrd, Ord)] +#[repr(transparent)] +struct InternalVectorId(u128); - let id = insert_vector(&mut txn, &self.namespace, vector)?; +#[derive(Debug)] +pub enum DecodeError { + Empty, + InvalidLength, + Bound, +} - add_to_index::(&mut txn, &self.namespace, id, self.max_bucket_size)?; - - manage_hypertrees::( - &mut txn, - &self.namespace, - self.max_bucket_size, - self.target_hypertrees, - )?; - - txn.commit()?; - - Ok(id) - } - - pub fn insert_many_vectors(&self, vectors: &[Vector]) -> Result, TreeError> { - let mut txn = self.database.begin_write()?; - - let ids = insert_many_vectors(&mut txn, &self.namespace, vectors)?; - - add_many_to_index::(&mut txn, &self.namespace, &ids, self.max_bucket_size)?; - - manage_hypertrees::( - &mut txn, - &self.namespace, - self.max_bucket_size, - self.target_hypertrees, - )?; - - txn.commit()?; - - Ok(ids) - } - - pub fn get_vector(&self, vector_id: u128) -> Result>, TreeError> { - let mut txn = self.database.begin_read()?; - - Ok(get_vector(&mut txn, &self.namespace, vector_id)??) - } - - pub fn create_hypertree(&self) -> Result<(), TreeError> { - let mut txn = self.database.begin_write()?; - - build_hypertree::(&mut txn, &self.namespace, self.max_bucket_size)?; - - txn.commit()?; - - Ok(()) - } - - pub fn find_similarities( - &self, - vector: Vector, - limit: usize, - ) -> Result, TreeError> { - let mut txn = self.database.begin_read()?; - - get_similar_vectors(&mut txn, &self.namespace, vector, limit) - } +#[derive(Debug)] +pub enum TreeError { + Commit(CommitError), + Compact(CompactionError), + Database(DatabaseError), + Decode(DecodeError), + Io(std::io::Error), + Storage(StorageError), + Table(TableError), + Transaction(TransactionError), } const IDS_TABLE: TableDefinition<'static, &'static str, u128> = TableDefinition::new("vectordb::ids"); -macro_rules! vector_multimap { +macro_rules! 
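// Each macro below expands to a namespaced, dimension-parameterised table
// definition; e.g. `vector_insert_table!(def, ns, N)` becomes roughly:
//
//     let def = format!("vectord::vector_insert_table::<{}>::{}", N, ns);
//     let def = TableDefinition::<'_, VectorId, InternalVectorId>::new(&def);
//
// mapping the externally visible VectorId to the content-deduplicated
// InternalVectorId that the hypertrees actually index.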
vector_insert_table { ($variable_name:ident, $namespace:ident, $n:ident) => { - let $variable_name = format!("vectord::vector_multimap::<{}>::{}", $n, $namespace); + let $variable_name = format!("vectord::vector_insert_table::<{}>::{}", $n, $namespace); let $variable_name = - MultimapTableDefinition::<'_, VectorBytes<'static, $n>, u128>::new(&$variable_name); + TableDefinition::<'_, VectorId, InternalVectorId>::new(&$variable_name); + }; +} + +macro_rules! inverse_vector_insert_multimap { + ($variable_name:ident, $namespace:ident, $n:ident) => { + let $variable_name = format!( + "vectord::inverse_vector_insert_multimap::<{}>::{}", + $n, $namespace + ); + let $variable_name = + MultimapTableDefinition::<'_, InternalVectorId, VectorId>::new(&$variable_name); + }; +} + +macro_rules! inverse_vector_table { + ($variable_name:ident, $namespace:ident, $n:ident) => { + let $variable_name = format!("vectord::inverse_vector_table::<{}>::{}", $n, $namespace); + let $variable_name = + TableDefinition::<'_, VectorBytes<'static, $n>, InternalVectorId>::new(&$variable_name); }; } @@ -118,90 +93,263 @@ macro_rules! vector_table { ($variable_name:ident, $namespace:ident, $n:ident) => { let $variable_name = format!("vectordb::vector_table::<{}>::{}", $n, $namespace); let $variable_name = - TableDefinition::<'_, u128, VectorBytes<'static, $n>>::new(&$variable_name); + TableDefinition::<'_, InternalVectorId, VectorBytes<'static, $n>>::new(&$variable_name); }; } -macro_rules! bucket_multimap { - ($variable_name:ident, $namespace:ident, $n:ident) => { - let $variable_name = format!("vectordb::bucket_table::<{}>::{}", $n, $namespace); - let $variable_name = MultimapTableDefinition::<'_, u128, u128>::new(&$variable_name); - }; -} +impl VectorDb { + pub fn open>(db_directory: P, namespace: String) -> Result { + std::fs::create_dir_all(db_directory.as_ref())?; -macro_rules! hypertree_table { - ($variable_name:ident, $namespace:ident, $n:ident) => { - let $variable_name = format!("vectordb::hypertree_table::<{}>::{}", $n, $namespace); - let $variable_name = - TableDefinition::<'_, HyperplaneBytesList<'static, N>, u128>::new(&$variable_name); - }; -} + let mut database = Database::create(db_directory.as_ref().join("vectordb.redb"))?; -macro_rules! hyperplane_table { - ($variable_name:ident, $namespace:ident, $n:ident) => { - let $variable_name = format!("vectordb::hyperplane_table::<{}>::{}", $n, $namespace); - let $variable_name = - TableDefinition::<'_, u128, HyperplaneBytes<'static, N>>::new(&$variable_name); - }; -} + database.check_integrity()?; + database.compact()?; -macro_rules! inverse_hyperplane_table { - ($variable_name:ident, $namespace:ident, $n:ident) => { - let $variable_name = format!( - "vectordb::inverse_hyperplane_table::<{}>::{}", - $n, $namespace - ); - let $variable_name = - TableDefinition::<'_, HyperplaneBytes<'static, N>, u128>::new(&$variable_name); - }; -} + let max_bucket_size = 128; -macro_rules! toplevel_table { - ($variable_name:ident, $namespace:ident, $n:ident) => { - let $variable_name = format!("vectordb::toplevel_table::<{}>::{}", $n, $namespace); - let $variable_name = TableDefinition::<'_, u128, u64>::new(&$variable_name); - }; -} + let hypertrees = (0..H) + .map(|i| { + hypertree::HypertreeDb::open( + db_directory + .as_ref() + .join("hypertrees") + .join(format!("tree-{i}")), + max_bucket_size, + ) + }) + .collect::, _>>()? 
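            // Vec -> [HypertreeDb; H]: the iterator above produced exactly H
            // trees, so `try_into` cannot fail here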
+ .try_into() + .expect("Correct number of hypertree dbs"); -fn insert_hyperplane<'db, 'txn, const N: usize>( - ids_table: &mut Table<'db, 'txn, &'static str, u128>, - hyperplane_table: &mut Table<'db, 'txn, u128, HyperplaneBytes>, - inverse_hyperplane_table: &mut Table<'db, 'txn, HyperplaneBytes, u128>, - namespace: &str, - hyperplane: &Hyperplane, -) -> Result { - hyperplane_table!(hyperplane_def, namespace, N); - - let encoded = hyperplane.encode(); - - if let Some(id) = inverse_hyperplane_table.get(&encoded)? { - return Ok(id.value()); + Ok(VectorDb { + inner: std::sync::Arc::new(VectorDbInner { + database, + hypertrees, + next_rebuild: std::sync::atomic::AtomicUsize::new(0), + namespace, + _size: PhantomData, + }), + }) } - let id = next_id(ids_table, &hyperplane_def)?; + pub fn insert_vector(&self, vector: &Vector) -> Result { + let mut txn = self.inner.database.begin_write()?; - hyperplane_table.insert(id, &encoded)?; - inverse_hyperplane_table.insert(encoded, id)?; + let (id, internal_id) = insert_vector(&mut txn, &self.inner.namespace, vector)?; - Ok(id) + let namespace = &self.inner.namespace; + vector_table!(vector_def, namespace, N); + let vector_table = txn.open_table(vector_def)?; + + if let Some(internal_id) = internal_id { + self.inner + .hypertrees + .par_iter() + .try_for_each(|hypertree| hypertree.add_to_index(internal_id, &vector_table))?; + } + + drop(vector_table); + txn.commit()?; + + self.try_rebuild()?; + + Ok(id) + } + + pub fn insert_many_vectors(&self, vectors: &[Vector]) -> Result, TreeError> { + let mut txn = self.inner.database.begin_write()?; + + let (ids, internal_ids) = insert_many_vectors(&mut txn, &self.inner.namespace, vectors)?; + + let namespace = &self.inner.namespace; + vector_table!(vector_def, namespace, N); + let vector_table = txn.open_table(vector_def)?; + + self.inner + .hypertrees + .par_iter() + .try_for_each(|hypertree| hypertree.add_many_to_index(&internal_ids, &vector_table))?; + + drop(vector_table); + txn.commit()?; + + self.try_rebuild()?; + + Ok(ids) + } + + pub fn get_vector(&self, vector_id: VectorId) -> Result>, TreeError> { + let mut txn = self.inner.database.begin_read()?; + + Ok(get_vector(&mut txn, &self.inner.namespace, vector_id)?) + } + + pub fn find_similarities( + &self, + vector: Vector, + limit: usize, + ) -> Result, TreeError> { + let txn = self.inner.database.begin_read()?; + + let namespace = &self.inner.namespace; + inverse_vector_insert_multimap!(inverse_vector_insert_multidef, namespace, N); + vector_table!(vector_def, namespace, N); + let inverse_vector_insert_multimap = + txn.open_multimap_table(inverse_vector_insert_multidef)?; + let vector_table = txn.open_table(vector_def)?; + + let mut candidates = self + .inner + .hypertrees + .par_iter() + .map(|hypertree| hypertree.find_similarities(&vector_table, &vector, limit)) + .collect::, _>>()? + .into_iter() + .flatten() + .collect::>(); + + candidates.sort_by(|a, b| a.0.partial_cmp(&b.0).expect("Numbers are finite")); + candidates.dedup_by_key(|a| a.1); + candidates.truncate(limit); + + let vectors = candidates + .into_iter() + .map(|(_, internal_vector_id, _)| { + inverse_vector_insert_multimap + .get(internal_vector_id) + .and_then(|iter| { + iter.map(|res| res.map(|value| value.value())) + .collect::, _>>() + }) + }) + .collect::, _>>()? 
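            // each InternalVectorId fans back out to every external VectorId
            // that inserted the same vector bytes, hence the flatten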
+ .into_iter() + .flatten() + .collect(); + + Ok(vectors) + } + + fn try_rebuild(&self) -> Result<(), TreeError> { + let namespace = &self.inner.namespace; + vector_table!(vector_def, namespace, N); + + let next_rebuild = self.next_rebuild(); + + let txn = self.inner.database.begin_read()?; + let vector_table = txn.open_table(vector_def)?; + + if self.inner.hypertrees[next_rebuild].should_rebuild(&vector_table)? { + let this = self.clone(); + + rayon::spawn(move || { + if !this.progress_next_rebuild(next_rebuild) { + // another thread progressed the rebuild first + return; + } + + let namespace = &this.inner.namespace; + vector_table!(vector_def, namespace, N); + + let Ok(txn) = this.inner.database.begin_read() else { + eprintln!("Failed to read db"); + return; + }; + + let Ok(vector_table) = txn.open_table(vector_def) else { + eprintln!("Failed to open vector table"); + return; + }; + + for hypertree in &this.inner.hypertrees { + if hypertree.is_rebuilding() { + return; + } + } + + let Ok(_) = this.inner.hypertrees[next_rebuild].rebuild_hypertree(&vector_table) else { + eprintln!("Failed to rebuild hypertree"); + return; + }; + + loop { + let Ok(txn) = this.inner.database.begin_read() else { + eprintln!("Failed to read db"); + return; + }; + + let Ok(vector_table) = txn.open_table(vector_def) else { + eprintln!("Failed to open vector table"); + return; + }; + + let Ok(needs_requeue) = this.inner.hypertrees[next_rebuild].cleanup(&vector_table) else { + eprintln!("Failed to cleanup vector table"); + return; + }; + + if !needs_requeue { + break; + } + + println!("Needs requeue, looping"); + rayon::yield_now(); + } + }); + } + + Ok(()) + } + + fn next_rebuild(&self) -> usize { + self.inner + .next_rebuild + .load(std::sync::atomic::Ordering::Acquire) + } + + fn progress_next_rebuild(&self, last_rebuild: usize) -> bool { + let next_rebuild = if last_rebuild + 1 >= H { + 0 + } else { + last_rebuild + 1 + }; + + match self.inner.next_rebuild.compare_exchange( + last_rebuild, + next_rebuild, + std::sync::atomic::Ordering::AcqRel, + std::sync::atomic::Ordering::Relaxed, + ) { + Ok(_) => true, + Err(_) => false, + } + } } fn insert_vector<'db, const N: usize>( transaction: &mut WriteTransaction<'db>, namespace: &str, vector: &Vector, -) -> Result { +) -> Result<(VectorId, Option), TreeError> { + inverse_vector_insert_multimap!(inverse_vector_insert_multidef, namespace, N); + inverse_vector_table!(inverse_vector_def, namespace, N); + vector_insert_table!(vector_insert_def, namespace, N); vector_table!(vector_def, namespace, N); - vector_multimap!(vector_multidef, namespace, N); let mut ids_table = transaction.open_table(IDS_TABLE)?; + let mut inverse_vector_insert_multimap = + transaction.open_multimap_table(inverse_vector_insert_multidef)?; + let mut inverse_vector_table = transaction.open_table(inverse_vector_def)?; + let mut vector_insert_table = transaction.open_table(vector_insert_def)?; let mut vector_table = transaction.open_table(vector_def)?; - let mut vector_multimap = transaction.open_multimap_table(vector_multidef)?; do_insert_vector( &mut ids_table, + &mut inverse_vector_insert_multimap, + &mut inverse_vector_table, + &mut vector_insert_table, &mut vector_table, - &mut vector_multimap, namespace, vector, ) @@ -211,551 +359,117 @@ fn insert_many_vectors<'db, const N: usize>( transaction: &mut WriteTransaction<'db>, namespace: &str, vectors: &[Vector], -) -> Result, TreeError> { +) -> Result<(Vec, Vec), TreeError> { + inverse_vector_insert_multimap!(inverse_vector_insert_multidef, 
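    // batch variant: collects (VectorId, Option<InternalVectorId>) pairs and
    // keeps only the internal ids of vectors not seen before, since duplicate
    // vectors are already indexed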
namespace, N); + inverse_vector_table!(inverse_vector_def, namespace, N); + vector_insert_table!(vector_insert_def, namespace, N); vector_table!(vector_def, namespace, N); - vector_multimap!(vector_multidef, namespace, N); let mut ids_table = transaction.open_table(IDS_TABLE)?; + let mut inverse_vector_insert_multimap = + transaction.open_multimap_table(inverse_vector_insert_multidef)?; + let mut inverse_vector_table = transaction.open_table(inverse_vector_def)?; + let mut vector_insert_table = transaction.open_table(vector_insert_def)?; let mut vector_table = transaction.open_table(vector_def)?; - let mut vector_multimap = transaction.open_multimap_table(vector_multidef)?; - vectors + let (vectors, internal_vectors) = vectors .iter() .map(|vector| { do_insert_vector( &mut ids_table, + &mut inverse_vector_insert_multimap, + &mut inverse_vector_table, + &mut vector_insert_table, &mut vector_table, - &mut vector_multimap, namespace, vector, ) }) - .collect::, _>>() + .collect::, _>>()? + .into_iter() + .unzip::<_, _, _, Vec>>(); + + let internal_vectors = internal_vectors.into_iter().filter_map(|id| id).collect(); + + Ok((vectors, internal_vectors)) } fn do_insert_vector<'db, 'txn, const N: usize>( ids_table: &mut Table<'db, 'txn, &'static str, u128>, - vector_table: &mut Table<'db, 'txn, u128, VectorBytes>, - vector_multimap: &mut MultimapTable<'db, 'txn, VectorBytes, u128>, + inverse_vector_insert_multimap: &mut MultimapTable<'db, 'txn, InternalVectorId, VectorId>, + inverse_vector_table: &mut Table<'db, 'txn, VectorBytes, InternalVectorId>, + vector_insert_table: &mut Table<'db, 'txn, VectorId, InternalVectorId>, + vector_table: &mut Table<'db, 'txn, InternalVectorId, VectorBytes>, namespace: &str, vector: &Vector, -) -> Result { +) -> Result<(VectorId, Option), TreeError> { + vector_insert_table!(vector_insert_def, namespace, N); vector_table!(vector_def, namespace, N); - let vector_id = next_id(ids_table, &vector_def)?; + let vector_id = VectorId(next_id(ids_table, vector_insert_def)?); let vector_bytes = vector.encode(); - vector_table.insert(vector_id, vector_bytes.clone())?; - vector_multimap.insert(vector_bytes, vector_id)?; + let internal_vector_id_opt = inverse_vector_table + .get(&vector_bytes)? + .map(|id| id.value()); - Ok(vector_id) + let internal_vector_id = if let Some(internal_vector_id) = internal_vector_id_opt { + vector_insert_table.insert(vector_id, internal_vector_id)?; + inverse_vector_insert_multimap.insert(internal_vector_id, vector_id)?; + + None + } else { + let internal_vector_id = InternalVectorId(next_id(ids_table, vector_def)?); + + vector_table.insert(internal_vector_id, &vector_bytes)?; + inverse_vector_table.insert(&vector_bytes, internal_vector_id)?; + vector_insert_table.insert(vector_id, internal_vector_id)?; + inverse_vector_insert_multimap.insert(internal_vector_id, vector_id)?; + + Some(internal_vector_id) + }; + + Ok((vector_id, internal_vector_id)) } fn get_vector<'db, const N: usize>( transaction: &mut ReadTransaction<'db>, namespace: &str, - vector_id: u128, -) -> Result>, DecodeError>, TreeError> { + vector_id: VectorId, +) -> Result>, TreeError> { + vector_insert_table!(vector_insert_def, namespace, N); vector_table!(vector_def, namespace, N); + let vector_insert_table = transaction.open_table(vector_insert_def)?; let vector_table = transaction.open_table(vector_def)?; - do_get_vector(&vector_table, vector_id).map_err(TreeError::from) + let Some(internal_vector_id) = vector_insert_table.get(vector_id)? 
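    // external -> internal indirection: a missing mapping means this VectorId
    // was never issued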
else {
+        return Ok(None);
+    };
+
+    do_get_vector(&vector_table, internal_vector_id.value())
 }

 fn do_get_vector<const N: usize, T>(
     vector_table: &T,
-    vector_id: u128,
-) -> Result<Option<Result<Vector<N>, DecodeError>>, StorageError>
+    internal_vector_id: InternalVectorId,
+) -> Result<Option<Vector<N>>, TreeError>
 where
-    T: ReadableTable<u128, VectorBytes<'static, N>>,
+    T: ReadableTable<InternalVectorId, VectorBytes<'static, N>>,
 {
-    let opt = vector_table.get(vector_id)?;
-
-    Ok(opt
-        .map(|vector| Vector::decode(&vector.value()))
-        .transpose())
+    let Some(vector_bytes) = vector_table.get(internal_vector_id)?
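    // vectors are stored encoded; `Vector::decode` validates the byte length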
else { + return Ok(None); }; - if creation_size > vector_len / 2 { - return Ok(()); - } + let vector = Vector::decode(&vector_bytes.value())?; - if toplevel_size >= target_hypertrees { - println!("Removing hypertree {}", hyperplane_id); - do_remove_hypertree( - &mut toplevel_table, - &mut hypertree_table, - &mut bucket_multimap, - hyperplane_id, - )?; - } - - println!("Building hypertree"); - do_build_hypertree( - &mut bucket_multimap, - &mut hyperplane_table, - &mut hypertree_table, - &mut ids_table, - &mut inverse_hyperplane_table, - &mut toplevel_table, - &vector_multimap, - &vector_table, - namespace, - max_bucket_size, - )?; - - Ok(()) -} - -fn do_remove_hypertree<'db, 'txn, const N: usize>( - toplevel_table: &mut Table<'db, 'txn, u128, u64>, - hypertree_table: &mut Table<'db, 'txn, HyperplaneBytesList<'static, N>, u128>, - bucket_multimap: &mut MultimapTable<'db, 'txn, u128, u128>, - toplevel_hyperplane_id: u128, -) -> Result<(), TreeError> { - toplevel_table.remove(toplevel_hyperplane_id)?; - - let mut lower_bound_bytes = Vec::with_capacity(HyperplaneList::::SEGMENT_LEN); - let mut upper_bound_bytes = Vec::with_capacity(HyperplaneList::::SEGMENT_LEN); - - lower_bound_bytes.extend(toplevel_hyperplane_id.to_be_bytes()); - upper_bound_bytes.extend(toplevel_hyperplane_id.to_be_bytes()); - - lower_bound_bytes.push(BoundedByte::BelowBound.to_byte()); - upper_bound_bytes.push(BoundedByte::AboveBound.to_byte()); - - let lower = HyperplaneBytesList { - hyperplanes: std::borrow::Cow::Owned(lower_bound_bytes), - _size: PhantomData, - }; - - let upper = HyperplaneBytesList { - hyperplanes: std::borrow::Cow::Owned(upper_bound_bytes), - _size: PhantomData, - }; - - // TODO: convert to drain - for res in hypertree_table.range(HyperplaneBytesListRange { lower, upper })? 
{ - let (_, bucket_id) = res?; - let bucket_id = bucket_id.value(); - - bucket_multimap.remove_all(bucket_id)?.for_each(|_| ()); - } - - Ok(()) -} - -fn add_to_index<'db, const N: usize>( - transaction: &mut WriteTransaction<'db>, - namespace: &str, - insert_vector_id: u128, - max_bucket_size: usize, -) -> Result<(), TreeError> { - toplevel_table!(toplevel_def, namespace, N); - vector_table!(vector_def, namespace, N); - vector_multimap!(vector_multidef, namespace, N); - hyperplane_table!(hyperplane_def, namespace, N); - inverse_hyperplane_table!(inverse_hyperplane_def, namespace, N); - hypertree_table!(hypertree_def, namespace, N); - bucket_multimap!(bucket_multidef, namespace, N); - - let mut ids_table = transaction.open_table(IDS_TABLE)?; - let vector_table = transaction.open_table(vector_def)?; - let vector_multimap = transaction.open_multimap_table(vector_multidef)?; - let mut hyperplane_table = transaction.open_table(hyperplane_def)?; - let mut inverse_hyperplane_table = transaction.open_table(inverse_hyperplane_def)?; - let toplevel_table = transaction.open_table(toplevel_def)?; - let mut hypertree_table = transaction.open_table(hypertree_def)?; - let mut bucket_multimap = transaction.open_multimap_table(bucket_multidef)?; - - do_add_to_index( - &mut ids_table, - &vector_table, - &vector_multimap, - &mut hyperplane_table, - &mut inverse_hyperplane_table, - &toplevel_table, - &mut hypertree_table, - &mut bucket_multimap, - namespace, - insert_vector_id, - max_bucket_size, - ) -} - -fn add_many_to_index<'db, const N: usize>( - transaction: &mut WriteTransaction<'db>, - namespace: &str, - insert_vector_ids: &[u128], - max_bucket_size: usize, -) -> Result<(), TreeError> { - vector_table!(vector_def, namespace, N); - vector_multimap!(vector_multidef, namespace, N); - hyperplane_table!(hyperplane_def, namespace, N); - inverse_hyperplane_table!(inverse_hyperplane_def, namespace, N); - toplevel_table!(toplevel_def, namespace, N); - hypertree_table!(hypertree_def, namespace, N); - bucket_multimap!(bucket_multidef, namespace, N); - - let mut ids_table = transaction.open_table(IDS_TABLE)?; - let vector_table = transaction.open_table(vector_def)?; - let vector_multimap = transaction.open_multimap_table(vector_multidef)?; - let mut hyperplane_table = transaction.open_table(hyperplane_def)?; - let mut inverse_hyperplane_table = transaction.open_table(inverse_hyperplane_def)?; - let toplevel_table = transaction.open_table(toplevel_def)?; - let mut hypertree_table = transaction.open_table(hypertree_def)?; - let mut bucket_multimap = transaction.open_multimap_table(bucket_multidef)?; - - for insert_vector_id in insert_vector_ids { - do_add_to_index( - &mut ids_table, - &vector_table, - &vector_multimap, - &mut hyperplane_table, - &mut inverse_hyperplane_table, - &toplevel_table, - &mut hypertree_table, - &mut bucket_multimap, - namespace, - *insert_vector_id, - max_bucket_size, - )?; - } - - Ok(()) -} - -fn do_add_to_index<'db, 'txn, const N: usize>( - ids_table: &mut Table<'db, 'txn, &'static str, u128>, - vector_table: &Table<'db, 'txn, u128, VectorBytes<'static, N>>, - vector_multimap: &MultimapTable<'db, 'txn, VectorBytes, u128>, - hyperplane_table: &mut Table<'db, 'txn, u128, HyperplaneBytes>, - inverse_hyperplane_table: &mut Table<'db, 'txn, HyperplaneBytes, u128>, - toplevel_table: &Table<'db, 'txn, u128, u64>, - hypertree_table: &mut Table<'db, 'txn, HyperplaneBytesList<'static, N>, u128>, - bucket_multimap: &mut MultimapTable<'db, 'txn, u128, u128>, - namespace: &str, - insert_vector_id: 
u128, - max_bucket_size: usize, -) -> Result<(), TreeError> { - let Some(insert_vector) = do_get_vector(vector_table, insert_vector_id)?? else { - return Ok(()); - }; - - if let Some(existing_id) = vector_multimap.get(insert_vector.encode())?.next() { - if insert_vector_id != existing_id?.value() { - return Ok(()); - } - }; - - let toplevel = get_toplevel_hyperplane_lists(toplevel_table, hypertree_table)??; - - let buckets = toplevel - .into_par_iter() - .filter_map(|hyperplane_list| { - find_similarity_bucket( - hypertree_table, - hyperplane_table, - hyperplane_list, - &insert_vector, - ) - .map_err(TreeError::from) - .and_then(|res| res.map_err(TreeError::from)) - .transpose() - }) - .collect::, _>>()?; - - for (hyperplane_list, bucket_id) in buckets { - bucket_multimap.insert(bucket_id, insert_vector_id)?; - - let size = bucket_multimap.get(bucket_id)?.count(); - - if size > max_bucket_size { - build_hyperplane( - ids_table, - vector_table, - hyperplane_table, - inverse_hyperplane_table, - hypertree_table, - bucket_multimap, - namespace, - bucket_id, - hyperplane_list, - max_bucket_size, - )?; - } - } - - Ok(()) -} - -fn get_similar_vectors<'db, const N: usize>( - transaction: &mut ReadTransaction<'db>, - namespace: &str, - query_vector: Vector, - limit: usize, -) -> Result, TreeError> { - toplevel_table!(toplevel_def, namespace, N); - hypertree_table!(hypertree_def, namespace, N); - hyperplane_table!(hyperplane_def, namespace, N); - bucket_multimap!(bucket_multidef, namespace, N); - vector_table!(vector_def, namespace, N); - vector_multimap!(vector_multidef, namespace, N); - - let vector_table = transaction.open_table(vector_def)?; - let toplevel_table = transaction.open_table(toplevel_def)?; - let hypertree_table = transaction.open_table(hypertree_def)?; - let hyperplane_table = transaction.open_table(hyperplane_def)?; - let bucket_multimap = transaction.open_multimap_table(bucket_multidef)?; - - let toplevel = get_toplevel_hyperplane_lists(&toplevel_table, &hypertree_table)??; - - let mut candidates = toplevel - .into_par_iter() - .filter_map(|hyperplane_list| { - find_similarity_bucket( - &hypertree_table, - &hyperplane_table, - hyperplane_list, - &query_vector, - ) - .map_err(TreeError::from) - .and_then(|res| res.map_err(TreeError::from)) - .transpose() - }) - .map(|bucket_res| { - let bucket_id = match bucket_res { - Ok((_, bucket_id)) => bucket_id, - Err(e) => return Err(e), - }; - - let bucket = match bucket_multimap.get(bucket_id) { - Ok(bucket) => bucket, - Err(e) => return Err(TreeError::from(e)), - }; - let size = bucket.count(); - - println!("candidates from bucket {}, size {}", bucket_id, size); - - let bucket = match bucket_multimap.get(bucket_id) { - Ok(bucket) => bucket, - Err(e) => return Err(TreeError::from(e)), - }; - - bucket - .filter_map(|res| { - let vector_id = match res { - Ok(vector_id) => vector_id.value(), - Err(e) => return Some(Err(TreeError::from(e))), - }; - - let new_vector = match do_get_vector(&vector_table, vector_id) { - Ok(Ok(opt)) => opt?, - Ok(Err(e)) => return Some(Err(TreeError::from(e))), - Err(e) => return Some(Err(TreeError::from(e))), - }; - - let closeness = query_vector.squared_euclidean_distance(&new_vector); - - if closeness.is_finite() { - Some(Ok((closeness, vector_id, new_vector))) - } else { - None - } - }) - .collect::, _>>() - }) - .collect::)>>, _>>()? 
- .into_iter() - .flatten() - .collect::>(); - - candidates.sort_by(|a, b| a.0.partial_cmp(&b.0).expect("Numbers are finite")); - candidates.dedup_by_key(|a| a.1); - candidates.dedup_by_key(|a| a.2.encode()); - candidates.truncate(limit); - - let vector_multimap = transaction.open_multimap_table(vector_multidef)?; - - let vectors = candidates - .into_iter() - .map(|(_, _, vector)| { - let iter = match vector_multimap.get(&vector.encode()) { - Ok(iter) => iter, - Err(e) => return Err(TreeError::from(e)), - }; - - let iter = iter.map(|res| res.map(|id| id.value()).map_err(TreeError::from)); - - Ok(iter) - }) - .flatten_results() - .collect::, _>>()?; - - Ok(vectors) -} - -#[derive(Debug)] -pub enum TreeError { - Commit(CommitError), - Compact(CompactionError), - Database(DatabaseError), - Decode(DecodeError), - Storage(StorageError), - Table(TableError), - Transaction(TransactionError), -} - -impl From for TreeError { - fn from(value: CommitError) -> Self { - Self::Commit(value) - } -} - -impl From for TreeError { - fn from(value: DatabaseError) -> Self { - Self::Database(value) - } -} - -impl From for TreeError { - fn from(value: DecodeError) -> Self { - Self::Decode(value) - } -} - -impl From for TreeError { - fn from(value: CompactionError) -> Self { - Self::Compact(value) - } -} - -impl From for TreeError { - fn from(value: StorageError) -> Self { - Self::Storage(value) - } -} - -impl From for TreeError { - fn from(value: TableError) -> Self { - Self::Table(value) - } -} - -impl From for TreeError { - fn from(value: TransactionError) -> Self { - Self::Transaction(value) - } -} - -fn build_hypertree<'db, const N: usize>( - transaction: &mut WriteTransaction<'db>, - namespace: &str, - max_bucket_size: usize, -) -> Result<(), TreeError> { - bucket_multimap!(bucket_multidef, namespace, N); - inverse_hyperplane_table!(inverse_hyperplane_def, namespace, N); - hyperplane_table!(hyperplane_def, namespace, N); - hypertree_table!(hypertree_def, namespace, N); - toplevel_table!(toplevel_def, namespace, N); - vector_multimap!(vector_multidef, namespace, N); - vector_table!(vector_def, namespace, N); - - let mut bucket_multimap = transaction.open_multimap_table(bucket_multidef)?; - let mut hyperplane_table = transaction.open_table(hyperplane_def)?; - let mut hypertree_table = transaction.open_table(hypertree_def)?; - let mut ids_table = transaction.open_table(IDS_TABLE)?; - let mut inverse_hyperplane_table = transaction.open_table(inverse_hyperplane_def)?; - let mut toplevel_table = transaction.open_table(toplevel_def)?; - let vector_multimap = transaction.open_multimap_table(vector_multidef)?; - let vector_table = transaction.open_table(vector_def)?; - - do_build_hypertree( - &mut bucket_multimap, - &mut hyperplane_table, - &mut hypertree_table, - &mut ids_table, - &mut inverse_hyperplane_table, - &mut toplevel_table, - &vector_multimap, - &vector_table, - namespace, - max_bucket_size, - ) + Ok(Some(vector)) } fn next_id<'db, 'txn, H>( table: &mut Table<'db, 'txn, &'static str, u128>, - handle: &H, + handle: H, ) -> Result where H: TableHandle, @@ -775,48 +489,14 @@ where Ok(next_id) } -fn next_multimap_id<'db, 'txn, H>( - table: &mut Table<'db, 'txn, &'static str, u128>, - handle: &H, -) -> Result -where - H: MultimapTableHandle, -{ - let mut next_id = table.get(handle.name())?.map(|id| id.value()).unwrap_or(0); - - while let Some(id) = table.insert(handle.name(), next_id + 1)? 
{ - let id = id.value(); - - if id == next_id { - break; - } - - next_id = id; - } - - Ok(next_id) -} - -#[derive(Clone, Debug, PartialEq, Eq, PartialOrd, Ord)] -struct VectorBytes<'a, const N: usize> { - vector_bytes: std::borrow::Cow<'a, [u8]>, - _phantom: PhantomData>, -} - -#[derive(Clone, Debug, PartialEq, PartialOrd)] -pub struct Vector([f32; N]); - -#[derive(Debug)] -pub enum DecodeError { - Empty, - InvalidLength, - Bound, -} - type F32ByteArray = [u8; 4]; const F32_BYTE_ARRAY_SIZE: usize = std::mem::size_of::(); +const fn vector_byte_length(n: usize) -> usize { + n * crate::F32_BYTE_ARRAY_SIZE +} + impl Vector { const BYTES_LEN: usize = N * F32_BYTE_ARRAY_SIZE; @@ -857,14 +537,29 @@ impl Vector { Ok(Self(array)) } -} -const fn hyperplane_byte_length(n: usize) -> usize { - (n + 1) * crate::F32_BYTE_ARRAY_SIZE -} + fn average(&self, rhs: &Self) -> Self { + let mut index = 0; -const fn vector_byte_length(n: usize) -> usize { - n * crate::F32_BYTE_ARRAY_SIZE + // TODO: array.zip + Vector([(); N].map(|_| { + let avg = (self.0[index] + rhs.0[index]) / 2.0; + index += 1; + avg + })) + } + + fn dot_product(&self, rhs: &Self) -> f32 { + self.0.iter().zip(rhs.0.iter()).map(|(a, b)| a * b).sum() + } + + pub fn squared_euclidean_distance(&self, rhs: &Self) -> f32 { + self.0 + .iter() + .zip(rhs.0.iter()) + .map(|(a, b)| (a - b).powi(2)) + .sum() + } } impl RedbValue for VectorBytes<'static, N> { @@ -910,28 +605,71 @@ impl RedbKey for VectorBytes<'static, N> { } } -impl Vector { - fn average(&self, rhs: &Self) -> Self { - let mut index = 0; +impl RedbValue for VectorId { + type SelfType<'a> = VectorId; + type AsBytes<'a> = ::AsBytes<'a>; - // TODO: array.zip - Vector([(); N].map(|_| { - let avg = (self.0[index] + rhs.0[index]) / 2.0; - index += 1; - avg - })) + fn fixed_width() -> Option { + u128::fixed_width() } - fn dot_product(&self, rhs: &Self) -> f32 { - self.0.iter().zip(rhs.0.iter()).map(|(a, b)| a * b).sum() + fn from_bytes<'a>(data: &'a [u8]) -> Self::SelfType<'a> + where + Self: 'a, + { + Self(u128::from_bytes(data)) } - pub fn squared_euclidean_distance(&self, rhs: &Self) -> f32 { - self.0 - .iter() - .zip(rhs.0.iter()) - .map(|(a, b)| (a - b).powi(2)) - .sum() + fn as_bytes<'a, 'b: 'a>(value: &'a Self::SelfType<'b>) -> Self::AsBytes<'a> + where + Self: 'a, + Self: 'b, + { + u128::as_bytes(&value.0) + } + + fn type_name() -> TypeName { + TypeName::new("vectordb::VectorId") + } +} + +impl RedbKey for VectorId { + fn compare(data1: &[u8], data2: &[u8]) -> std::cmp::Ordering { + u128::compare(data1, data2) + } +} + +impl RedbValue for InternalVectorId { + type SelfType<'a> = InternalVectorId; + type AsBytes<'a> = ::AsBytes<'a>; + + fn fixed_width() -> Option { + u128::fixed_width() + } + + fn from_bytes<'a>(data: &'a [u8]) -> Self::SelfType<'a> + where + Self: 'a, + { + Self(u128::from_bytes(data)) + } + + fn as_bytes<'a, 'b: 'a>(value: &'a Self::SelfType<'b>) -> Self::AsBytes<'a> + where + Self: 'a, + Self: 'b, + { + u128::as_bytes(&value.0) + } + + fn type_name() -> TypeName { + TypeName::new("vectordb::InternalVectorId") + } +} + +impl RedbKey for InternalVectorId { + fn compare(data1: &[u8], data2: &[u8]) -> std::cmp::Ordering { + u128::compare(data1, data2) } } @@ -1046,3 +784,57 @@ impl<'a, const N: usize> std::ops::Sub<&'a Vector> for &'a Vector { })) } } + +impl From for TreeError { + fn from(value: CommitError) -> Self { + Self::Commit(value) + } +} + +impl From for TreeError { + fn from(value: CompactionError) -> Self { + Self::Compact(value) + } +} + +impl From for TreeError 
{ + fn from(value: DatabaseError) -> Self { + Self::Database(value) + } +} + +impl From for TreeError { + fn from(value: DecodeError) -> Self { + Self::Decode(value) + } +} + +impl From for TreeError { + fn from(value: std::io::Error) -> Self { + Self::Io(value) + } +} + +impl From for TreeError { + fn from(value: StorageError) -> Self { + Self::Storage(value) + } +} + +impl From for TreeError { + fn from(value: TableError) -> Self { + Self::Table(value) + } +} + +impl From for TreeError { + fn from(value: TransactionError) -> Self { + Self::Transaction(value) + } +} + +impl std::fmt::Display for VectorId { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + u128::fmt(&self.0, f) + } +}