// vectordb/src/hypertree/db.rs

use std::sync::atomic::AtomicBool;
use rand::{seq::IteratorRandom, thread_rng};
use redb::{
Database, MultimapTable, MultimapTableDefinition, MultimapTableHandle, Range,
ReadOnlyMultimapTable, ReadOnlyTable, ReadableMultimapTable, ReadableTable, RedbKey, RedbValue,
StorageError, Table, TableDefinition, TypeName,
};
use crate::{next_id, DecodeError, InternalVectorId, TreeError, Vector, VectorBytes};
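/// An approximate-nearest-neighbour index backed by two `redb` databases:
/// `database` holds the vectors, hyperplanes, and tree structure, while
/// `ingest_queue` buffers vectors that arrive while a rebuild is running.
///
/// A minimal usage sketch (illustrative only; constructing a `Vector<N>` and
/// an `InternalVectorId` depends on the rest of the crate):
///
/// ```ignore
/// let repo = HypertreeRepo::<4>::open("/tmp/hypertree", 128)?;
/// repo.add_many_to_index(&[(vector_id, vector)])?;
/// let hits = repo.find_similarities(&query, None, 10, SimilarityStyle::Similar)?;
/// ```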
#[derive(Debug)]
pub(crate) struct HypertreeRepo<const N: usize> {
database: Database,
ingest_queue: Database,
rebuilding: AtomicBool,
cleanup: AtomicBool,
max_bucket_size: usize,
_size: std::marker::PhantomData<Vector<N>>,
}
#[derive(Clone, Copy, Debug, PartialEq, Eq, PartialOrd, Ord)]
#[repr(transparent)]
struct HyperplaneId(u128);
#[derive(Clone, Copy, Debug, PartialEq, Eq, PartialOrd, Ord)]
#[repr(transparent)]
struct BucketId(u128);
#[derive(Clone, Copy, Debug, PartialEq, Eq, PartialOrd, Ord)]
enum Bounded {
Above,
Below,
}
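// BelowBound and AboveBound never appear in stored keys; they exist so range
// scans can bracket the Below and Above bytes (see `Bounded::to_range_bounds`).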
#[derive(Clone, Copy, Debug, PartialEq, Eq, PartialOrd, Ord)]
enum BoundedByte {
BelowBound,
Below,
Above,
AboveBound,
}
#[derive(Debug)]
struct HyperplaneNode<const N: usize> {
parents: Hypertree<N>,
hyperplane_id: HyperplaneId,
below_id: BucketId,
above_id: BucketId,
}
#[derive(Debug)]
struct Hypertree<const N: usize> {
hyperplanes: Vec<(HyperplaneId, Bounded)>,
_size: std::marker::PhantomData<Hyperplane<N>>,
}
#[derive(Clone, Debug)]
struct HypertreeBytes<'a, const N: usize> {
hyperplanes: std::borrow::Cow<'a, [u8]>,
_size: std::marker::PhantomData<Hypertree<N>>,
}
#[derive(Clone, Debug)]
struct Hyperplane<const N: usize> {
coefficients: Vector<N>,
constant: f32,
}
#[derive(Clone, Debug)]
struct HyperplaneBytes<'a, const N: usize> {
hyperplane_bytes: std::borrow::Cow<'a, [u8]>,
_phantom: std::marker::PhantomData<Hyperplane<N>>,
}
#[derive(Clone, Debug)]
struct HypertreeBytesRange<'a, const N: usize> {
lower: HypertreeBytes<'a, N>,
upper: HypertreeBytes<'a, N>,
}
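// Clears the wrapped flag when dropped, releasing a rebuild or cleanup guard
// even on early return.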
struct SetOnDrop<'a>(&'a AtomicBool);
#[derive(Clone, Copy, Debug, PartialEq, Eq, PartialOrd, Ord)]
pub(crate) enum SimilarityStyle {
Similar,
FurthestSimilar,
Dissimilar,
ClosestDissimilar,
}
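// Table layout:
// - IDS_TABLE maps a table name to its monotonically increasing id counter.
// - BUCKET_MULTIMAP maps a bucket to the vector ids it contains.
// - REBUILD_TABLE records the vector count at the last rebuild.
// - queue_table buffers vectors ingested while a rebuild is in progress.
// - internal_vector_table is the canonical vector store.
// - hypertree_table maps an encoded hyperplane path to its leaf bucket.
// - hyperplane_table / inverse_hyperplane_table map ids to hyperplanes and back.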
const IDS_TABLE: TableDefinition<'static, &'static str, u128> =
TableDefinition::new("vectordb::ids");
const BUCKET_MULTIMAP: MultimapTableDefinition<'static, BucketId, InternalVectorId> =
MultimapTableDefinition::new("vectordb::bucket_table");
const REBUILD_TABLE: TableDefinition<'static, &'static str, u64> =
TableDefinition::new("vectordb::rebuild_table");
const REBUILD_KEY: &str = "rebuild";
const fn queue_table<const N: usize>(
) -> TableDefinition<'static, InternalVectorId, VectorBytes<'static, N>> {
TableDefinition::new("vectordb::queue_table")
}
const fn internal_vector_table<const N: usize>(
) -> TableDefinition<'static, InternalVectorId, VectorBytes<'static, N>> {
TableDefinition::new("vectordb::internal_vector_table")
}
const fn hypertree_table_def<const N: usize>(
) -> TableDefinition<'static, HypertreeBytes<'static, N>, BucketId> {
TableDefinition::new("vectordb::hypertree_table")
}
const fn hyperplane_table_def<const N: usize>(
) -> TableDefinition<'static, HyperplaneId, HyperplaneBytes<'static, N>> {
TableDefinition::new("vectordb::hyperplane_table")
}
const fn inverse_hyperplane_table_def<const N: usize>(
) -> TableDefinition<'static, HyperplaneBytes<'static, N>, HyperplaneId> {
TableDefinition::new("vectordb::inverse_hyperplane_table")
}
const fn hyperplane_byte_length(n: usize) -> usize {
(n + 1) * crate::F32_BYTE_ARRAY_SIZE
}
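/// Inserts a single vector: stores it, walks the hypertree to its leaf
/// bucket, and splits that bucket with a fresh hyperplane if it has grown
/// past `max_bucket_size`.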
#[allow(clippy::too_many_arguments)]
#[tracing::instrument(level = "trace", skip_all, fields(vector_id = ?insert_vector_id, vector = ?insert_vector))]
fn do_add_to_index<'db, 'txn, const N: usize>(
ids_table: &mut Table<'db, 'txn, &'static str, u128>,
internal_vector_table: &mut Table<'db, 'txn, InternalVectorId, VectorBytes<N>>,
hyperplane_table: &mut Table<'db, 'txn, HyperplaneId, HyperplaneBytes<N>>,
inverse_hyperplane_table: &mut Table<'db, 'txn, HyperplaneBytes<N>, HyperplaneId>,
hypertree_table: &mut Table<'db, 'txn, HypertreeBytes<'static, N>, BucketId>,
bucket_multimap: &mut MultimapTable<'db, 'txn, BucketId, InternalVectorId>,
insert_vector_id: InternalVectorId,
insert_vector: &Vector<N>,
max_bucket_size: usize,
) -> Result<(), TreeError> {
if internal_vector_table.get(insert_vector_id)?.is_some() {
return Ok(());
}
    // Store the vector even if no tree exists yet; rebuilding the tree later
    // will pick it up.
internal_vector_table.insert(insert_vector_id, insert_vector.encode())?;
let Some(hyperplane_list) = get_first_list(hypertree_table)? else {
return Ok(());
};
let bucket_opt = find_similarity_bucket(
hypertree_table,
hyperplane_table,
hyperplane_list,
insert_vector,
|bound| bound,
)??;
let Some((hyperplane_list, bucket_id)) = bucket_opt else {
// TODO: maybe should error?
unreachable!("Didn't find bucket")
};
bucket_multimap.insert(bucket_id, insert_vector_id)?;
let size = bucket_multimap.get(bucket_id)?.count();
if size > max_bucket_size {
build_hyperplane(
ids_table,
internal_vector_table,
hyperplane_table,
inverse_hyperplane_table,
hypertree_table,
bucket_multimap,
bucket_id,
hyperplane_list,
max_bucket_size,
)?;
}
Ok(())
}
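/// Reserves the next id for a multimap table by reading the stored counter,
/// writing back `counter + 1`, and returning the reserved value. The
/// defensive loop re-checks the previous value returned by `insert` in case
/// the stored counter changed between the read and the write.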
fn next_multimap_id<H>(
table: &mut Table<'_, '_, &'static str, u128>,
handle: H,
) -> Result<u128, StorageError>
where
H: MultimapTableHandle,
{
let mut next_id = table.get(handle.name())?.map(|id| id.value()).unwrap_or(0);
while let Some(id) = table.insert(handle.name(), next_id + 1)? {
let id = id.value();
if id == next_id {
break;
}
next_id = id;
}
Ok(next_id)
}
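/// Rebuilds the tree from scratch: every stored vector is placed in a single
/// source bucket, which is then split repeatedly (via an explicit work list
/// rather than recursion) until every leaf bucket holds at most
/// `max_bucket_size` vectors.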
fn do_build_hypertree<'db, 'txn, const N: usize>(
bucket_multimap: &mut MultimapTable<'db, 'txn, BucketId, InternalVectorId>,
hyperplane_table: &mut Table<'db, 'txn, HyperplaneId, HyperplaneBytes<N>>,
hypertree_table: &mut Table<'db, 'txn, HypertreeBytes<N>, BucketId>,
ids_table: &mut Table<'db, 'txn, &str, u128>,
internal_vector_table: &mut Table<'db, 'txn, InternalVectorId, VectorBytes<N>>,
inverse_hyperplane_table: &mut Table<'db, 'txn, HyperplaneBytes<N>, HyperplaneId>,
max_bucket_size: usize,
) -> Result<(), TreeError> {
let source_bucket = BucketId(next_multimap_id(ids_table, BUCKET_MULTIMAP)?);
for res in internal_vector_table.iter()? {
let (id, _) = res?;
let id = id.value();
bucket_multimap.insert(source_bucket, id)?;
}
let mut nodes: Vec<HyperplaneNode<N>> = Vec::new();
if let Some(node) = build_hyperplane(
ids_table,
internal_vector_table,
hyperplane_table,
inverse_hyperplane_table,
hypertree_table,
bucket_multimap,
source_bucket,
Hypertree {
hyperplanes: Vec::new(),
_size: std::marker::PhantomData,
},
max_bucket_size,
)? {
nodes.push(node);
}
while let Some(node) = nodes.pop() {
let HyperplaneNode {
parents,
hyperplane_id,
below_id,
above_id,
} = node;
if let Some(node) = build_hyperplane(
ids_table,
internal_vector_table,
hyperplane_table,
inverse_hyperplane_table,
hypertree_table,
bucket_multimap,
below_id,
parents.append(hyperplane_id, Bounded::Below),
max_bucket_size,
)? {
nodes.push(node);
}
if let Some(node) = build_hyperplane(
ids_table,
internal_vector_table,
hyperplane_table,
inverse_hyperplane_table,
hypertree_table,
bucket_multimap,
above_id,
parents.append(hyperplane_id, Bounded::Above),
max_bucket_size,
)? {
nodes.push(node);
}
}
Ok(())
}
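/// Splits one bucket: picks two random member vectors, builds the hyperplane
/// that bisects them (coefficients `a - b`, passing through their midpoint),
/// and redistributes the bucket's vectors into new `below` and `above`
/// buckets depending on which side of the plane they fall.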
#[allow(clippy::too_many_arguments)]
fn build_hyperplane<'db, 'txn, const N: usize, T>(
ids_table: &mut Table<'db, 'txn, &'static str, u128>,
vector_table: &T,
hyperplane_table: &mut Table<'db, 'txn, HyperplaneId, HyperplaneBytes<N>>,
inverse_hyperplane_table: &mut Table<'db, 'txn, HyperplaneBytes<N>, HyperplaneId>,
hypertree_table: &mut Table<'db, 'txn, HypertreeBytes<N>, BucketId>,
bucket_multimap: &mut MultimapTable<'db, 'txn, BucketId, InternalVectorId>,
source_bucket: BucketId,
parents: Hypertree<N>,
max_bucket_size: usize,
) -> Result<Option<HyperplaneNode<N>>, TreeError>
where
T: ReadableTable<InternalVectorId, VectorBytes<'static, N>>,
{
    // Check the size guard before allocating bucket ids so an early return
    // doesn't burn counter values.
    if bucket_multimap.get(source_bucket)?.count() <= max_bucket_size {
        return Ok(None);
    }
    let below_id = BucketId(next_multimap_id(ids_table, BUCKET_MULTIMAP)?);
    let above_id = BucketId(next_multimap_id(ids_table, BUCKET_MULTIMAP)?);
let mut samples = bucket_multimap
.get(source_bucket)?
.choose_multiple(&mut thread_rng(), 2);
if samples.len() != 2 {
panic!("Invalid number of samples chosen");
}
let a_id = samples.pop().expect("Exists")?.value();
let b_id = samples.pop().expect("Exists")?.value();
drop(samples);
let (a, b) = (
vector_table.get(a_id)?.expect("Vector exists"),
vector_table.get(b_id)?.expect("Vector exists"),
);
let (a, b) = (&Vector::decode(&a.value())?, &Vector::decode(&b.value())?);
let coefficients = a - b;
let point_on_plane = a.average(b);
let constant = -coefficients.dot_product(&point_on_plane);
let hyperplane = Hyperplane {
coefficients,
constant,
};
let hyperplane_id = insert_hyperplane(
ids_table,
hyperplane_table,
inverse_hyperplane_table,
&hyperplane,
)?;
let below_list = parents.append(hyperplane_id, Bounded::Below);
let above_list = parents.append(hyperplane_id, Bounded::Above);
let parents_bytes = parents.encode();
let below_list_bytes = below_list.encode();
let above_list_bytes = above_list.encode();
hypertree_table.remove(parents_bytes)?;
hypertree_table.insert(below_list_bytes, below_id)?;
hypertree_table.insert(above_list_bytes, above_id)?;
let mut below = Vec::new();
let mut above = Vec::new();
for res in bucket_multimap.remove_all(source_bucket)? {
let id = res?.value();
let vector_bytes = vector_table.get(id)?.expect("Vector exists");
let vector = Vector::decode(&vector_bytes.value())?;
if vector.direction_from(&hyperplane) == Bounded::Below {
below.push(id);
} else {
above.push(id);
}
}
for id in below {
bucket_multimap.insert(below_id, id)?;
}
for id in above {
bucket_multimap.insert(above_id, id)?;
}
Ok(Some(HyperplaneNode {
parents,
hyperplane_id,
below_id,
above_id,
}))
}
fn insert_hyperplane<'db, 'txn, const N: usize>(
ids_table: &mut Table<'db, 'txn, &'static str, u128>,
hyperplane_table: &mut Table<'db, 'txn, HyperplaneId, HyperplaneBytes<N>>,
inverse_hyperplane_table: &mut Table<'db, 'txn, HyperplaneBytes<N>, HyperplaneId>,
hyperplane: &Hyperplane<N>,
) -> Result<HyperplaneId, TreeError> {
let encoded = hyperplane.encode();
if let Some(id) = inverse_hyperplane_table.get(&encoded)? {
return Ok(id.value());
}
let id = HyperplaneId(next_id(ids_table, hyperplane_table_def::<N>())?);
hyperplane_table.insert(id, &encoded)?;
inverse_hyperplane_table.insert(encoded, id)?;
Ok(id)
}
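/// Runs a similarity query: walks the tree to a single candidate bucket,
/// scores every vector in it by squared euclidean distance, then filters by
/// the optional threshold and sorts according to `similarity_style`.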
#[allow(clippy::too_many_arguments)]
fn get_similar_vectors<'txn, const N: usize, T>(
hypertree_table: &ReadOnlyTable<'txn, HypertreeBytes<N>, BucketId>,
hyperplane_table: &ReadOnlyTable<'txn, HyperplaneId, HyperplaneBytes<N>>,
bucket_multimap: &ReadOnlyMultimapTable<'txn, BucketId, InternalVectorId>,
vector_table: &T,
query_vector: &Vector<N>,
threshold: Option<f32>,
limit: usize,
similarity_style: SimilarityStyle,
) -> Result<Vec<(f32, InternalVectorId, Vector<N>)>, TreeError>
where
T: ReadableTable<InternalVectorId, VectorBytes<'static, N>>,
{
let Some(hyperplane_list) = get_first_list(hypertree_table)? else {
return Ok(vec![]);
};
let bucket_opt = find_similarity_bucket(
hypertree_table,
hyperplane_table,
hyperplane_list,
query_vector,
|bound| similarity_style.for_bound(bound),
)??;
let Some((_, bucket_id)) = bucket_opt else {
return Ok(vec![]);
};
    let size = bucket_multimap.get(bucket_id)?.count();
    tracing::debug!("candidates from bucket {}, size {}", bucket_id, size);
    // `count` consumes the multimap iterator, so fetch the bucket again for
    // the actual scan.
    let bucket = bucket_multimap.get(bucket_id)?;
let mut candidates = bucket
.filter_map(|res| {
let vector_id = match res {
Ok(vector_id) => vector_id.value(),
Err(e) => return Some(Err(TreeError::from(e))),
};
let new_vector = match vector_table.get(vector_id).transpose()? {
Ok(vector_bytes) => match Vector::decode(&vector_bytes.value()) {
Ok(vector) => vector,
Err(e) => return Some(Err(TreeError::from(e))),
},
Err(e) => return Some(Err(TreeError::from(e))),
};
            let distance = query_vector.squared_euclidean_distance(&new_vector);
            if distance.is_finite() {
                if let Some(threshold) = threshold {
                    match similarity_style {
                        SimilarityStyle::Similar | SimilarityStyle::ClosestDissimilar
                            if distance < threshold =>
                        {
                            Some(Ok((distance, vector_id, new_vector)))
                        }
                        SimilarityStyle::Dissimilar | SimilarityStyle::FurthestSimilar
                            if distance > threshold =>
                        {
                            Some(Ok((distance, vector_id, new_vector)))
                        }
                        _ => None,
                    }
                } else {
                    Some(Ok((distance, vector_id, new_vector)))
                }
            } else {
                None
            }
})
.collect::<Result<Vec<_>, _>>()?;
match similarity_style {
SimilarityStyle::Similar | SimilarityStyle::ClosestDissimilar => {
candidates.sort_by(|a, b| a.0.partial_cmp(&b.0).expect("Numbers are finite"));
}
SimilarityStyle::Dissimilar | SimilarityStyle::FurthestSimilar => {
candidates.sort_by(|a, b| b.0.partial_cmp(&a.0).expect("Numbers are finite"));
}
}
candidates.dedup_by_key(|a| a.1);
candidates.truncate(limit);
Ok(candidates)
}
fn get_first_list<const N: usize, T>(hypertree_table: &T) -> Result<Option<Hypertree<N>>, TreeError>
where
T: ReadableTable<HypertreeBytes<'static, N>, BucketId>,
{
let Some(res) = hypertree_table.iter()?.next() else {
return Ok(None);
};
let (k, _) = res?;
let list = Hypertree::decode(&k.value())?;
Ok(Some(list))
}
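/// Walks a hyperplane list from the root, following the query vector's side
/// of each hyperplane (optionally flipped by `tweak` for dissimilarity
/// searches). When the current list disagrees with the desired direction, a
/// prefix range scan locates a sibling subtree to continue down instead.
/// Returns the leaf list together with its bucket id, if one exists.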
fn find_similarity_bucket<const N: usize, T, U, F>(
hypertree_table: &T,
hyperplane_table: &U,
mut hyperplane_list: Hypertree<N>,
vector: &Vector<N>,
tweak: F,
) -> Result<Result<Option<(Hypertree<N>, BucketId)>, DecodeError>, StorageError>
where
T: ReadableTable<HypertreeBytes<'static, N>, BucketId>,
U: ReadableTable<HyperplaneId, HyperplaneBytes<'static, N>>,
F: Fn(Bounded) -> Bounded,
{
let mut depth = 0;
while let Some(hyperplane_id) = hyperplane_list.get(depth) {
let Some(hyperplane_bytes) = hyperplane_table.get(hyperplane_id)? else {
return Ok(Ok(None));
};
let hyperplane_bytes = hyperplane_bytes.value();
let hyperplane = match Hyperplane::decode(&hyperplane_bytes) {
Ok(hyperplane) => hyperplane,
Err(e) => return Ok(Err(e)),
};
let direction = (tweak)(vector.direction_from(&hyperplane));
if hyperplane_list.matches(depth, direction) {
depth += 1;
continue;
}
match next_list(hypertree_table, &hyperplane_list, depth, direction)? {
Ok(Some(list)) => {
hyperplane_list = list;
depth += 1;
continue;
}
Ok(None) => {
break;
}
Err(e) => return Ok(Err(e)),
}
}
hypertree_table
.get(hyperplane_list.encode())
.map(|opt| Ok(opt.map(|guard| (hyperplane_list, guard.value()))))
}
fn next_list<'table, const N: usize, T, V>(
table: &'table T,
hyperplane_list: &Hypertree<N>,
depth: usize,
direction: Bounded,
) -> Result<Result<Option<Hypertree<N>>, DecodeError>, StorageError>
where
T: ReadableTable<HypertreeBytes<'static, N>, V> + 'table,
V: RedbValue + 'static,
{
let mut range = scan(table, hyperplane_list, depth, direction)?;
let Some(res) = range.next() else {
return Ok(Ok(None));
};
let (hyperplane_list_bytes, _) = res?;
Ok(Hypertree::decode(&hyperplane_list_bytes.value()).map(Some))
}
fn scan<'table, const N: usize, T, V>(
table: &'table T,
hyperplane_list: &Hypertree<N>,
depth: usize,
direction: Bounded,
) -> Result<Range<'table, HypertreeBytes<'static, N>, V>, StorageError>
where
T: ReadableTable<HypertreeBytes<'static, N>, V> + 'table,
V: RedbValue,
{
let range = hyperplane_list.to_range(depth, direction);
table.range(range)
}
impl<const N: usize> HypertreeRepo<N> {
pub(crate) fn open<P: AsRef<std::path::Path>>(
tree_dir: P,
max_bucket_size: usize,
) -> Result<Self, TreeError> {
std::fs::create_dir_all(tree_dir.as_ref())?;
let mut database = Database::create(tree_dir.as_ref().join("hypertree.redb"))?;
let mut ingest_queue = Database::create(tree_dir.as_ref().join("ingest.redb"))?;
database.check_integrity()?;
ingest_queue.check_integrity()?;
database.compact()?;
ingest_queue.compact()?;
let txn = database.begin_write()?;
txn.open_table(REBUILD_TABLE)?;
txn.commit()?;
Ok(Self {
database,
ingest_queue,
rebuilding: AtomicBool::new(false),
cleanup: AtomicBool::new(false),
max_bucket_size,
_size: std::marker::PhantomData,
})
}
#[tracing::instrument(level = "trace", skip(self))]
pub(super) fn find_similarities(
&self,
query_vector: &Vector<N>,
threshold: Option<f32>,
limit: usize,
similarity_style: SimilarityStyle,
) -> Result<Vec<(f32, InternalVectorId, Vector<N>)>, TreeError> {
let txn = self.database.begin_read()?;
let bucket_multimap = txn.open_multimap_table(BUCKET_MULTIMAP)?;
let hyperplane_table = txn.open_table(hyperplane_table_def())?;
let hypertree_table = txn.open_table(hypertree_table_def())?;
let internal_vector_table = txn.open_table(internal_vector_table())?;
get_similar_vectors(
&hypertree_table,
&hyperplane_table,
&bucket_multimap,
&internal_vector_table,
query_vector,
threshold,
limit,
similarity_style,
)
}
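    /// Adds vectors to the index, or to the ingest queue if a rebuild is
    /// currently running; queued vectors are folded back in by `cleanup`.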
#[tracing::instrument(level = "trace", skip_all)]
pub(super) fn add_many_to_index(
&self,
vectors: &[(InternalVectorId, Vector<N>)],
) -> Result<(), TreeError> {
if self.rebuilding.load(std::sync::atomic::Ordering::Acquire) {
let txn = self.ingest_queue.begin_write()?;
let mut queue_table = txn.open_table(queue_table())?;
for (vector_id, vector) in vectors {
queue_table.insert(*vector_id, vector.encode())?;
}
drop(queue_table);
txn.commit()?;
return Ok(());
}
let txn = self.database.begin_write()?;
let mut ids_table = txn.open_table(IDS_TABLE)?;
let mut hyperplane_table = txn.open_table(hyperplane_table_def())?;
let mut inverse_hyperplane_table = txn.open_table(inverse_hyperplane_table_def())?;
let mut hypertree_table = txn.open_table(hypertree_table_def())?;
let mut bucket_multimap = txn.open_multimap_table(BUCKET_MULTIMAP)?;
let mut internal_vector_table = txn.open_table(internal_vector_table())?;
for (vector_id, vector) in vectors {
do_add_to_index(
&mut ids_table,
&mut internal_vector_table,
&mut hyperplane_table,
&mut inverse_hyperplane_table,
&mut hypertree_table,
&mut bucket_multimap,
*vector_id,
vector,
self.max_bucket_size,
)?;
}
drop(internal_vector_table);
drop(bucket_multimap);
drop(hypertree_table);
drop(inverse_hyperplane_table);
drop(hyperplane_table);
drop(ids_table);
txn.commit()?;
Ok(())
}
#[tracing::instrument(level = "trace", skip(self))]
fn force_rebuild_hypertree(&self) -> Result<(), TreeError> {
let start = std::time::Instant::now();
let txn = self.database.begin_write()?;
let mut bucket_multimap = txn.open_multimap_table(BUCKET_MULTIMAP)?;
let mut hyperplane_table = txn.open_table(hyperplane_table_def())?;
let mut hypertree_table = txn.open_table(hypertree_table_def())?;
let mut ids_table = txn.open_table(IDS_TABLE)?;
let mut internal_vector_table = txn.open_table(internal_vector_table())?;
let mut inverse_hyperplane_table = txn.open_table(inverse_hyperplane_table_def())?;
let mut rebuild_table = txn.open_table(REBUILD_TABLE)?;
let vector_len = internal_vector_table.len()?;
if vector_len == 0 {
tracing::trace!("Not rebuilding - no vectors");
return Ok(());
}
if vector_len < u64::try_from(self.max_bucket_size).expect("bucket size is reasonable") / 2
{
tracing::trace!("Not rebuilding - small vector count");
return Ok(());
}
tracing::debug!("Rebuilding hypertree: {vector_len} vectors");
rebuild_table.insert(REBUILD_KEY, vector_len)?;
ids_table.drain::<&'static str>(..)?;
hyperplane_table.drain::<HyperplaneId>(..)?;
inverse_hyperplane_table.drain::<HyperplaneBytes<'static, N>>(..)?;
hypertree_table.drain::<HypertreeBytes<'static, N>>(..)?;
let buckets = bucket_multimap
.iter()?
.map(|res| res.map(|(k, _)| k.value()))
.collect::<Result<Vec<_>, _>>()?;
for bucket in buckets {
bucket_multimap.remove_all(bucket)?;
}
do_build_hypertree(
&mut bucket_multimap,
&mut hyperplane_table,
&mut hypertree_table,
&mut ids_table,
&mut internal_vector_table,
&mut inverse_hyperplane_table,
self.max_bucket_size,
)?;
drop(rebuild_table);
drop(inverse_hyperplane_table);
drop(internal_vector_table);
drop(ids_table);
drop(hypertree_table);
drop(hyperplane_table);
drop(bucket_multimap);
tracing::debug!(
"Rebuilding hypertree: {vector_len} vectors - took {:?}",
start.elapsed()
);
txn.commit()?;
Ok(())
}
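    /// Rebuilds the tree if no rebuild or cleanup is already running and the
    /// vector count has at least roughly doubled since the last recorded
    /// rebuild. Returns `Ok(true)` when a rebuild was attempted.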
#[tracing::instrument(level = "trace", skip(self))]
pub(super) fn rebuild_hypertree(&self) -> Result<bool, TreeError> {
if self
.rebuilding
.swap(true, std::sync::atomic::Ordering::AcqRel)
{
return Ok(false);
}
let rebuild_guard = SetOnDrop(&self.rebuilding);
if self.cleanup.load(std::sync::atomic::Ordering::Acquire) {
return Ok(false);
}
let txn = self.database.begin_read()?;
let rebuild_table = txn.open_table(REBUILD_TABLE)?;
let internal_vector_table = txn.open_table(internal_vector_table::<N>())?;
let vector_len = internal_vector_table.len()?;
if let Some(previous_size) = rebuild_table.get(REBUILD_KEY)? {
if vector_len / 2 < previous_size.value() {
return Ok(false);
}
}
self.force_rebuild_hypertree()?;
drop(rebuild_guard);
Ok(true)
}
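    /// Drains the ingest queue in chunks of 1024 vectors and folds them back
    /// into the main index. The cleanup flag is held for the duration so a
    /// concurrent rebuild will back off.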
#[tracing::instrument(level = "trace", skip(self))]
pub(crate) fn cleanup(&self) -> Result<(), TreeError> {
let cleanup_guard = SetOnDrop(&self.cleanup);
self.cleanup
.store(true, std::sync::atomic::Ordering::Release);
let txn = self.ingest_queue.begin_write()?;
let mut queue_table = txn.open_table(queue_table())?;
let mut chunk = Vec::with_capacity(1024);
for res in queue_table.drain::<InternalVectorId>(..)? {
let (id, vector_bytes) = res?;
let vector = Vector::decode(&vector_bytes.value())?;
chunk.push((id.value(), vector));
if chunk.len() < 1024 {
continue;
}
self.add_many_to_index(&chunk)?;
chunk.clear();
}
if !chunk.is_empty() {
self.add_many_to_index(&chunk)?;
}
drop(queue_table);
txn.commit()?;
drop(cleanup_guard);
Ok(())
}
}
impl<const N: usize> Hyperplane<N> {
const BYTES_LEN: usize = (N + 1) * crate::F32_BYTE_ARRAY_SIZE;
const CONSTANT_OFFSET: usize = N * crate::F32_BYTE_ARRAY_SIZE;
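    // Encoded as N big-endian f32 coefficients followed by the constant term.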
fn encode(&self) -> HyperplaneBytes<'static, N> {
let vec = self
.coefficients
.0
.iter()
.flat_map(|f| f.to_be_bytes())
.chain(self.constant.to_be_bytes())
.collect();
HyperplaneBytes {
hyperplane_bytes: std::borrow::Cow::Owned(vec),
_phantom: std::marker::PhantomData,
}
}
fn decode(
HyperplaneBytes {
hyperplane_bytes,
_phantom,
}: &HyperplaneBytes<N>,
) -> Result<Self, DecodeError> {
if hyperplane_bytes.is_empty() {
return Err(DecodeError::Empty);
}
if hyperplane_bytes.len() != Self::BYTES_LEN {
return Err(DecodeError::InvalidLength);
}
let coefficients = Vector::<N>::decode(&VectorBytes {
vector_bytes: std::borrow::Cow::Borrowed(&hyperplane_bytes[..Self::CONSTANT_OFFSET]),
_phantom: std::marker::PhantomData,
})?;
let constant = f32::from_be_bytes(
hyperplane_bytes[Self::CONSTANT_OFFSET..]
.try_into()
.expect("f32 byte array size is correct"),
);
Ok(Self {
coefficients,
constant,
})
}
}
impl<const N: usize> Hypertree<N> {
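    // A hypertree key encodes the path from the root: for each level, a
    // 16-byte big-endian hyperplane id followed by one `BoundedByte` byte.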
const SEGMENT_LEN: usize = 17; // u128 + u8
fn append(&self, hyperplane_id: HyperplaneId, bound: Bounded) -> Self {
let mut hyperplanes = self.hyperplanes.clone();
hyperplanes.push((hyperplane_id, bound));
Hypertree {
hyperplanes,
_size: std::marker::PhantomData,
}
}
fn matches(&self, depth: usize, direction: Bounded) -> bool {
self.hyperplanes
.get(depth)
.map(|(_, bounded)| direction == *bounded)
.unwrap_or(false)
}
fn get(&self, depth: usize) -> Option<HyperplaneId> {
self.hyperplanes.get(depth).map(|(h, _)| *h)
}
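    /// Builds the key range covering every list that shares this list's
    /// prefix up to `depth` and takes `direction` at `depth`. The bound bytes
    /// from `Bounded::to_range_bounds` bracket exactly one direction byte.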
fn to_range(&self, depth: usize, direction: Bounded) -> HypertreeBytesRange<'static, N> {
let bound_capacity = (depth + 1) * Self::SEGMENT_LEN;
let mut lower = Vec::with_capacity(bound_capacity);
let mut upper = Vec::with_capacity(bound_capacity);
for (i, (h, b)) in self.hyperplanes.iter().enumerate() {
match i.cmp(&depth) {
std::cmp::Ordering::Equal => {
let bytes = h.to_be_bytes();
lower.extend_from_slice(&bytes[..]);
upper.extend_from_slice(&bytes[..]);
let (lower_bound, upper_bound) = direction.to_range_bounds();
lower.push(lower_bound.to_byte());
upper.push(upper_bound.to_byte());
}
                std::cmp::Ordering::Less => {
                    let bytes = h.to_be_bytes();
                    lower.extend_from_slice(&bytes);
                    upper.extend_from_slice(&bytes);
                    // Push the bound byte onto both bounds so they share the
                    // full prefix; omitting it from `upper` can produce an
                    // empty or over-broad range at depths past the first.
                    lower.push(b.encode().to_byte());
                    upper.push(b.encode().to_byte());
                }
std::cmp::Ordering::Greater => {
break;
}
}
}
HypertreeBytesRange {
lower: HypertreeBytes {
hyperplanes: std::borrow::Cow::Owned(lower),
_size: std::marker::PhantomData,
},
upper: HypertreeBytes {
hyperplanes: std::borrow::Cow::Owned(upper),
_size: std::marker::PhantomData,
},
}
}
fn encode(&self) -> HypertreeBytes<'static, N> {
let capacity = self.hyperplanes.len() * Self::SEGMENT_LEN;
let mut bytes = Vec::with_capacity(capacity);
for (h, b) in &self.hyperplanes {
bytes.extend(h.to_be_bytes());
bytes.push(b.encode().to_byte());
}
HypertreeBytes {
hyperplanes: std::borrow::Cow::Owned(bytes),
_size: std::marker::PhantomData,
}
}
fn decode(
HypertreeBytes { hyperplanes, _size }: &HypertreeBytes<N>,
) -> Result<Self, DecodeError> {
if hyperplanes.len() % Self::SEGMENT_LEN != 0 {
return Err(DecodeError::InvalidLength);
}
Ok(Hypertree {
hyperplanes: hyperplanes
.chunks_exact(Self::SEGMENT_LEN)
.map(|chunk| {
                    let mut id_bytes = [0u8; 16];
                    id_bytes.copy_from_slice(&chunk[..16]);
let hyperplane_id = HyperplaneId::from_be_bytes(id_bytes);
let bounded = Bounded::decode(BoundedByte::from_byte(chunk[16]))?;
Ok((hyperplane_id, bounded))
})
.collect::<Result<_, _>>()?,
_size: std::marker::PhantomData,
})
}
}
impl HyperplaneId {
fn from_be_bytes(bytes: [u8; 16]) -> Self {
Self(u128::from_be_bytes(bytes))
}
fn to_be_bytes(self) -> [u8; 16] {
self.0.to_be_bytes()
}
}
impl Bounded {
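    // Returns (inclusive lower, exclusive upper) bytes bracketing exactly one
    // direction: Below scans [Below, Above), Above scans [Above, AboveBound).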
const fn to_range_bounds(self) -> (BoundedByte, BoundedByte) {
match self {
Self::Below => (BoundedByte::Below, BoundedByte::Above),
Self::Above => (BoundedByte::Above, BoundedByte::AboveBound),
}
}
const fn encode(self) -> BoundedByte {
match self {
Self::Below => BoundedByte::Below,
Self::Above => BoundedByte::Above,
}
}
const fn decode(bounded_byte: BoundedByte) -> Result<Self, DecodeError> {
match bounded_byte {
BoundedByte::BelowBound => Err(DecodeError::Bound),
BoundedByte::Below => Ok(Bounded::Below),
BoundedByte::Above => Ok(Bounded::Above),
BoundedByte::AboveBound => Err(DecodeError::Bound),
}
}
}
impl BoundedByte {
fn from_byte(byte: u8) -> Self {
match byte {
0 => BoundedByte::BelowBound,
1 => BoundedByte::Below,
2 => BoundedByte::Above,
3 => BoundedByte::AboveBound,
_ => panic!("Invalid byte value for BoundedByte"),
}
}
fn to_byte(self) -> u8 {
match self {
Self::BelowBound => 0,
Self::Below => 1,
Self::Above => 2,
Self::AboveBound => 3,
}
}
}
impl<const N: usize> Vector<N> {
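    // Which side of the hyperplane this vector falls on; points exactly on
    // the plane count as Above.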
fn direction_from(&self, hyperplane: &Hyperplane<N>) -> Bounded {
if hyperplane.coefficients.dot_product(self) + hyperplane.constant >= 0.0 {
Bounded::Above
} else {
Bounded::Below
}
}
}
impl SimilarityStyle {
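    // Dissimilarity searches walk the opposite side of every hyperplane, so
    // they land in the bucket "farthest" from the query.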
const fn for_bound(self, bound: Bounded) -> Bounded {
match self {
Self::Similar | Self::FurthestSimilar => bound,
Self::Dissimilar | Self::ClosestDissimilar => match bound {
Bounded::Below => Bounded::Above,
Bounded::Above => Bounded::Below,
},
}
}
}
impl<'a, const N: usize> std::ops::RangeBounds<HypertreeBytes<'a, N>>
for HypertreeBytesRange<'a, N>
{
fn start_bound(&self) -> std::ops::Bound<&HypertreeBytes<'a, N>> {
std::ops::Bound::Included(&self.lower)
}
fn end_bound(&self) -> std::ops::Bound<&HypertreeBytes<'a, N>> {
std::ops::Bound::Excluded(&self.upper)
}
}
impl<const N: usize> RedbValue for HyperplaneBytes<'static, N> {
type SelfType<'a> = HyperplaneBytes<'a, N>;
type AsBytes<'a> = &'a [u8];
fn fixed_width() -> Option<usize> {
None
}
fn from_bytes<'a>(data: &'a [u8]) -> Self::SelfType<'a>
where
Self: 'a,
{
assert_eq!(
data.len(),
hyperplane_byte_length(N),
"Byte length is not hyperplane byte length"
);
HyperplaneBytes {
hyperplane_bytes: std::borrow::Cow::Borrowed(data),
_phantom: std::marker::PhantomData,
}
}
fn as_bytes<'a, 'b: 'a>(value: &'a Self::SelfType<'b>) -> Self::AsBytes<'a>
where
Self: 'a,
Self: 'b,
{
&value.hyperplane_bytes
}
fn type_name() -> TypeName {
TypeName::new("vectordb::HyperplaneBytes")
}
}
impl<const N: usize> RedbKey for HyperplaneBytes<'static, N> {
fn compare(data1: &[u8], data2: &[u8]) -> std::cmp::Ordering {
data1.cmp(data2)
}
}
impl<const N: usize> RedbValue for HypertreeBytes<'static, N> {
type SelfType<'a> = HypertreeBytes<'a, N>;
type AsBytes<'a> = &'a [u8];
fn fixed_width() -> Option<usize> {
None
}
fn from_bytes<'a>(data: &'a [u8]) -> Self::SelfType<'a>
where
Self: 'a,
{
assert_eq!(
data.len() % Hypertree::<N>::SEGMENT_LEN,
0,
"Byte length is not Hyperplane List byte length"
);
HypertreeBytes {
hyperplanes: std::borrow::Cow::Borrowed(data),
_size: std::marker::PhantomData,
}
}
fn as_bytes<'a, 'b: 'a>(value: &'a Self::SelfType<'b>) -> Self::AsBytes<'a>
where
Self: 'a,
Self: 'b,
{
&value.hyperplanes
}
fn type_name() -> TypeName {
TypeName::new("vectordb::HypertreeBytes")
}
}
impl<const N: usize> RedbKey for HypertreeBytes<'static, N> {
fn compare(data1: &[u8], data2: &[u8]) -> std::cmp::Ordering {
data1.cmp(data2)
}
}
impl RedbValue for HyperplaneId {
type SelfType<'a> = HyperplaneId;
type AsBytes<'a> = <u128 as RedbValue>::AsBytes<'a>;
fn fixed_width() -> Option<usize> {
u128::fixed_width()
}
fn from_bytes<'a>(data: &'a [u8]) -> Self::SelfType<'a>
where
Self: 'a,
{
Self(u128::from_bytes(data))
}
fn as_bytes<'a, 'b: 'a>(value: &'a Self::SelfType<'b>) -> Self::AsBytes<'a>
where
Self: 'a,
Self: 'b,
{
u128::as_bytes(&value.0)
}
fn type_name() -> TypeName {
TypeName::new("vectordb::HyperplaneId")
}
}
impl RedbKey for HyperplaneId {
fn compare(data1: &[u8], data2: &[u8]) -> std::cmp::Ordering {
u128::compare(data1, data2)
}
}
impl RedbValue for BucketId {
type SelfType<'a> = BucketId;
type AsBytes<'a> = <u128 as RedbValue>::AsBytes<'a>;
fn fixed_width() -> Option<usize> {
u128::fixed_width()
}
fn from_bytes<'a>(data: &'a [u8]) -> Self::SelfType<'a>
where
Self: 'a,
{
Self(u128::from_bytes(data))
}
fn as_bytes<'a, 'b: 'a>(value: &'a Self::SelfType<'b>) -> Self::AsBytes<'a>
where
Self: 'a,
Self: 'b,
{
u128::as_bytes(&value.0)
}
fn type_name() -> TypeName {
TypeName::new("vectordb::BucketId")
}
}
impl RedbKey for BucketId {
fn compare(data1: &[u8], data2: &[u8]) -> std::cmp::Ordering {
u128::compare(data1, data2)
}
}
impl std::fmt::Display for HyperplaneId {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        std::fmt::Display::fmt(&self.0, f)
}
}
impl std::fmt::Display for BucketId {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        std::fmt::Display::fmt(&self.0, f)
}
}
impl<'a> Drop for SetOnDrop<'a> {
fn drop(&mut self) {
self.0.store(false, std::sync::atomic::Ordering::Release);
}
}