vectordb/src/lib.rs

mod hypertree;
mod pool;
mod reaper;
mod thread;
use std::{collections::HashMap, marker::PhantomData, path::Path};
use hypertree::{HypertreeMeta, SimilarityStyle};
use pool::Pool;
use reaper::Reaper;
use redb::{
CommitError, CompactionError, Database, DatabaseError, MultimapTable, MultimapTableDefinition,
ReadTransaction, ReadableMultimapTable, ReadableTable, RedbKey, RedbValue, StorageError, Table,
TableDefinition, TableError, TableHandle, TransactionError, TypeName, WriteTransaction,
};
use thread::Thread;
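/// A cheaply clonable handle to an on-disk vector database of fixed
/// dimension `N`; all clones share the same worker pools behind an `Arc`.
///
/// A minimal usage sketch (the crate name, path, and hypertree count below
/// are illustrative assumptions, not prescribed defaults):
///
/// ```no_run
/// # fn demo() -> Result<(), vectordb::TreeError> {
/// let db = vectordb::VectorDb::<3>::open("/tmp/vectors", 8)?;
/// let _id = db.insert_vector_blocking([0.1_f32, 0.2, 0.3].into())?;
/// // Up to ten matches, with no similarity threshold applied.
/// let _hits = db.find_similarities_blocking([0.1_f32, 0.2, 0.3].into(), None, 10)?;
/// # Ok(())
/// # }
/// ```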
#[derive(Clone, Debug)]
pub struct VectorDb<const N: usize> {
inner: std::sync::Arc<VectorMeta<N>>,
_reaper: Reaper,
}
#[derive(Debug)]
struct VectorMeta<const N: usize> {
reader: Pool<VectorRepo<N>, ReaderCommand<N>>,
insert_vector: Pool<VectorRepo<N>, InsertVectorsCommand<N>>,
_maintenance: Thread,
}
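// Commands serviced by the reader pool. Each carries a oneshot responder for
// the result plus the caller's tracing span, so work done on pool threads
// stays attributed to the originating request.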
enum ReaderCommand<const N: usize> {
FindSimilarities {
vector: std::sync::Arc<Vector<N>>,
threshold: Option<f32>,
limit: usize,
similarity_style: SimilarityStyle,
responder: oneshot::Sender<Result<Vec<VectorId>, TreeError>>,
parent: tracing::Span,
},
GetVectors {
vector_ids: Vec<VectorId>,
responder: oneshot::Sender<Result<HashMap<VectorId, Vector<N>>, TreeError>>,
parent: tracing::Span,
},
}
struct InsertVectorsCommand<const N: usize> {
vectors: Vec<Vector<N>>,
responder: oneshot::Sender<Result<Vec<VectorId>, TreeError>>,
parent: tracing::Span,
}
#[derive(Debug)]
struct VectorRepo<const N: usize> {
database: Database,
hypertrees: Vec<HypertreeMeta<N>>,
maintenance_sender: flume::Sender<()>,
last_rebuild: std::sync::atomic::AtomicU64,
}
#[derive(Clone, Debug, PartialEq, PartialOrd)]
pub struct Vector<const N: usize>([f32; N]);
#[derive(Clone, Debug, PartialEq, Eq, PartialOrd, Ord)]
struct VectorBytes<'a, const N: usize> {
vector_bytes: std::borrow::Cow<'a, [u8]>,
_phantom: PhantomData<Vector<N>>,
}
#[derive(Clone, Copy, Debug, PartialEq, Eq, PartialOrd, Ord, Hash)]
#[repr(transparent)]
pub struct VectorId(u128);
#[derive(Clone, Copy, Debug, PartialEq, Eq, PartialOrd, Ord, Hash)]
#[repr(transparent)]
struct InternalVectorId(u128);
#[derive(Debug)]
pub enum DecodeError {
Empty,
InvalidLength,
Bound,
}
#[derive(Debug)]
pub enum TreeError {
Commit(CommitError),
Compact(CompactionError),
Database(DatabaseError),
Decode(DecodeError),
Io(std::io::Error),
Storage(StorageError),
Table(TableError),
Transaction(TransactionError),
Canceled,
Closed,
}
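// On-disk schema. `IDS_TABLE` stores a monotonic counter per table name (see
// `next_id`). User-facing `VectorId`s map many-to-one onto `InternalVectorId`s:
// `VECTOR_INSERT_TABLE` records the forward mapping, the inverse multimap the
// reverse, and the two `const fn` definitions below key the encoded vector
// bytes in both directions so duplicate vectors are stored only once.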
const IDS_TABLE: TableDefinition<'static, &'static str, u128> =
TableDefinition::new("vectordb::ids");
const VECTOR_INSERT_TABLE: TableDefinition<'static, VectorId, InternalVectorId> =
TableDefinition::new("vectordb::vector_insert_table");
const INVERSE_VECTOR_INSERT_MULTIMAP: MultimapTableDefinition<'static, InternalVectorId, VectorId> =
MultimapTableDefinition::new("vectordb::inverse_vector_insert_multimap");
const fn inverse_vector_table_definition<const N: usize>(
) -> TableDefinition<'static, VectorBytes<'static, N>, InternalVectorId> {
TableDefinition::new("vectordb::inverse_vector_table")
}
const fn vector_table_definition<const N: usize>(
) -> TableDefinition<'static, InternalVectorId, VectorBytes<'static, N>> {
TableDefinition::new("vectordb::vector_table")
}
impl<const N: usize> VectorDb<N> {
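/// Opens (or creates) the database rooted at `db_directory`, verifying
/// integrity and compacting before loading `target_hypertree_count` index
/// trees; more trees generally improve recall at the cost of disk space and
/// maintenance work.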
#[tracing::instrument(level = "debug", skip(db_directory), fields(directory = ?db_directory.as_ref()))]
pub fn open<P: AsRef<Path>>(
db_directory: P,
target_hypertree_count: usize,
) -> Result<Self, TreeError> {
let reaper = Reaper::new();
let (tx, rx) = flume::bounded(1);
let repo = VectorRepo::open(db_directory, target_hypertree_count, reaper.clone(), tx)?;
let meta = VectorMeta::build(repo, reaper.clone(), rx);
Ok(Self {
inner: std::sync::Arc::new(meta),
_reaper: reaper,
})
}
#[tracing::instrument(level = "debug", skip(self))]
pub fn insert_vector_blocking(&self, vector: Vector<N>) -> Result<VectorId, TreeError> {
let vec = self.inner.insert_vectors_blocking(vec![vector])?;
Ok(vec[0])
}
#[tracing::instrument(level = "debug", skip(self))]
pub async fn insert_vector(&self, vector: Vector<N>) -> Result<VectorId, TreeError> {
let vec = self.inner.insert_vectors(vec![vector]).await?;
Ok(vec[0])
}
#[tracing::instrument(level = "debug", skip(self, vectors))]
pub fn insert_vectors_blocking(
&self,
vectors: Vec<Vector<N>>,
) -> Result<Vec<VectorId>, TreeError> {
self.inner.insert_vectors_blocking(vectors)
}
#[tracing::instrument(level = "debug", skip(self, vectors))]
pub async fn insert_vectors(
&self,
vectors: Vec<Vector<N>>,
) -> Result<Vec<VectorId>, TreeError> {
self.inner.insert_vectors(vectors).await
}
#[tracing::instrument(level = "debug", skip(self))]
pub fn find_similarities_blocking(
&self,
vector: Vector<N>,
threshold: Option<f32>,
limit: usize,
) -> Result<Vec<VectorId>, TreeError> {
self.inner
.find_similarities_blocking(vector, threshold, limit, SimilarityStyle::Similar)
}
#[tracing::instrument(level = "debug", skip(self))]
pub async fn find_similarities(
&self,
vector: Vector<N>,
threshold: Option<f32>,
limit: usize,
) -> Result<Vec<VectorId>, TreeError> {
self.inner
.find_similarities(vector, threshold, limit, SimilarityStyle::Similar)
.await
}
#[tracing::instrument(level = "debug", skip(self))]
pub fn find_furthest_similarities_blocking(
&self,
vector: Vector<N>,
threshold: Option<f32>,
limit: usize,
) -> Result<Vec<VectorId>, TreeError> {
self.inner.find_similarities_blocking(
vector,
threshold,
limit,
SimilarityStyle::FurthestSimilar,
)
}
#[tracing::instrument(level = "debug", skip(self))]
pub async fn find_furthest_similarities(
&self,
vector: Vector<N>,
threshold: Option<f32>,
limit: usize,
) -> Result<Vec<VectorId>, TreeError> {
self.inner
.find_similarities(vector, threshold, limit, SimilarityStyle::FurthestSimilar)
.await
}
#[tracing::instrument(level = "debug", skip(self))]
pub fn find_dissimilarities_blocking(
&self,
vector: Vector<N>,
threshold: Option<f32>,
limit: usize,
) -> Result<Vec<VectorId>, TreeError> {
self.inner
.find_similarities_blocking(vector, threshold, limit, SimilarityStyle::Dissimilar)
}
#[tracing::instrument(level = "debug", skip(self))]
pub async fn find_dissimilarities(
&self,
vector: Vector<N>,
threshold: Option<f32>,
limit: usize,
) -> Result<Vec<VectorId>, TreeError> {
self.inner
.find_similarities(vector, threshold, limit, SimilarityStyle::Dissimilar)
.await
}
#[tracing::instrument(level = "debug", skip(self))]
pub fn find_closest_dissimilarities_blocking(
&self,
vector: Vector<N>,
threshold: Option<f32>,
limit: usize,
) -> Result<Vec<VectorId>, TreeError> {
self.inner.find_similarities_blocking(
vector,
threshold,
limit,
SimilarityStyle::ClosestDissimilar,
)
}
#[tracing::instrument(level = "debug", skip(self))]
pub async fn find_closest_dissimilarities(
&self,
vector: Vector<N>,
threshold: Option<f32>,
limit: usize,
) -> Result<Vec<VectorId>, TreeError> {
self.inner
.find_similarities(vector, threshold, limit, SimilarityStyle::ClosestDissimilar)
.await
}
#[tracing::instrument(level = "debug", skip(self))]
pub async fn get_vector(&self, vector_id: VectorId) -> Result<Option<Vector<N>>, TreeError> {
let mut hm = self.inner.get_vectors(vec![vector_id]).await?;
Ok(hm.remove(&vector_id))
}
#[tracing::instrument(level = "debug", skip(self))]
pub fn get_vector_blocking(&self, vector_id: VectorId) -> Result<Option<Vector<N>>, TreeError> {
let mut hm = self.inner.get_vectors_blocking(vec![vector_id])?;
Ok(hm.remove(&vector_id))
}
#[tracing::instrument(level = "debug", skip(self, vector_ids))]
pub async fn get_vectors(
&self,
vector_ids: Vec<VectorId>,
) -> Result<HashMap<VectorId, Vector<N>>, TreeError> {
self.inner.get_vectors(vector_ids).await
}
#[tracing::instrument(level = "debug", skip(self, vector_ids))]
pub fn get_vectors_blocking(
&self,
vector_ids: Vec<VectorId>,
) -> Result<HashMap<VectorId, Vector<N>>, TreeError> {
self.inner.get_vectors_blocking(vector_ids)
}
}
impl<const N: usize> VectorMeta<N> {
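/// Wires a repo to its workers: a reader pool sized between 2 and 16x the
/// available parallelism, an insert pool left at the pool builder's default
/// sizing, and a maintenance thread that services one hypertree (in
/// round-robin order) each time the inserter signals growth.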
fn build(
repo: VectorRepo<N>,
reaper: Reaper,
maintenance_receiver: flume::Receiver<()>,
) -> Self {
let state = std::sync::Arc::new(repo);
let parallelism = std::thread::available_parallelism()
.map(usize::from)
.unwrap_or(1);
let reader = Pool::builder(std::sync::Arc::clone(&state), reaper.clone(), reader_runner)
.with_lower_limit(2)
.with_upper_limit(16 * parallelism)
.with_name(String::from("vectordb-reader"))
.finish();
let insert_vector =
Pool::builder(std::sync::Arc::clone(&state), reaper, insert_vectors_runner)
.with_name(String::from("vectordb-inserter"))
.finish();
let maintenance =
Thread::build(String::from("vectordb-maintenance")).spawn(move |stopper| {
let mut index = 0;
with_stopper(maintenance_receiver, stopper, |()| {
if let Err(e) = state.hypertrees[index].perform_maintenance() {
tracing::warn!("Error peforming maintenance: {e:?}");
}
index += 1;
if index >= state.hypertrees.len() {
index = 0;
}
})
});
Self {
reader,
insert_vector,
_maintenance: maintenance,
}
}
async fn insert_vectors(&self, vectors: Vec<Vector<N>>) -> Result<Vec<VectorId>, TreeError> {
let (responder, rx) = oneshot::channel();
if self
.insert_vector
.send_async(InsertVectorsCommand {
vectors,
responder,
parent: tracing::Span::current(),
})
.await
.is_err()
{
return Err(TreeError::Closed);
}
rx.await?
}
fn insert_vectors_blocking(&self, vectors: Vec<Vector<N>>) -> Result<Vec<VectorId>, TreeError> {
let (responder, rx) = oneshot::channel();
if self
.insert_vector
.send_blocking(InsertVectorsCommand {
vectors,
responder,
parent: tracing::Span::current(),
})
.is_err()
{
return Err(TreeError::Closed);
}
rx.recv()?
}
async fn get_vectors(
&self,
vector_ids: Vec<VectorId>,
) -> Result<HashMap<VectorId, Vector<N>>, TreeError> {
let (responder, rx) = oneshot::channel();
if self
.reader
.send_async(ReaderCommand::GetVectors {
vector_ids,
responder,
parent: tracing::Span::current(),
})
.await
.is_err()
{
return Err(TreeError::Closed);
}
rx.await?
}
fn get_vectors_blocking(
&self,
vector_ids: Vec<VectorId>,
) -> Result<HashMap<VectorId, Vector<N>>, TreeError> {
let (responder, rx) = oneshot::channel();
if self
.reader
.send_blocking(ReaderCommand::GetVectors {
vector_ids,
responder,
parent: tracing::Span::current(),
})
.is_err()
{
return Err(TreeError::Closed);
}
rx.recv()?
}
async fn find_similarities(
&self,
vector: Vector<N>,
threshold: Option<f32>,
limit: usize,
similarity_style: SimilarityStyle,
) -> Result<Vec<VectorId>, TreeError> {
let (responder, rx) = oneshot::channel();
if self
.reader
.send_async(ReaderCommand::FindSimilarities {
vector: std::sync::Arc::new(vector),
threshold,
limit,
similarity_style,
responder,
parent: tracing::Span::current(),
})
.await
.is_err()
{
return Err(TreeError::Closed);
}
rx.await?
}
fn find_similarities_blocking(
&self,
vector: Vector<N>,
threshold: Option<f32>,
limit: usize,
similarity_style: SimilarityStyle,
) -> Result<Vec<VectorId>, TreeError> {
let (responder, rx) = oneshot::channel();
if self
.reader
.send_blocking(ReaderCommand::FindSimilarities {
vector: std::sync::Arc::new(vector),
threshold,
limit,
similarity_style,
responder,
parent: tracing::Span::current(),
})
.is_err()
{
return Err(TreeError::Closed);
}
rx.recv()?
}
}
impl<const N: usize> VectorRepo<N> {
fn open<P: AsRef<Path>>(
db_directory: P,
target_hypertree_count: usize,
reaper: Reaper,
maintenance_sender: flume::Sender<()>,
) -> Result<Self, TreeError> {
std::fs::create_dir_all(db_directory.as_ref())?;
let mut database = Database::create(db_directory.as_ref().join("vectordb.redb"))?;
database.check_integrity()?;
database.compact()?;
let txn = database.begin_write()?;
txn.open_table(vector_table_definition::<N>())?;
txn.commit()?;
let max_bucket_size = 128;
let hypertrees = (0..target_hypertree_count)
.map(|i| {
let tree = hypertree::HypertreeRepo::open(
db_directory
.as_ref()
.join("hypertrees")
.join(format!("tree-{i}")),
max_bucket_size,
)?;
tree.cleanup()?;
Ok(HypertreeMeta::build(tree, reaper.clone()))
})
.collect::<Result<Vec<_>, TreeError>>()?;
Ok(VectorRepo {
database,
hypertrees,
last_rebuild: std::sync::atomic::AtomicU64::new(0),
maintenance_sender,
})
}
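// Writes the batch in a single redb transaction; the commit happens only
// after every hypertree has acknowledged indexing the new entries.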
fn insert_many_vectors(&self, vectors: Vec<Vector<N>>) -> Result<Vec<VectorId>, TreeError> {
let mut txn = self.database.begin_write()?;
let (ids, internal_ids) = insert_many_vectors(
&mut txn,
&vectors,
&self.last_rebuild,
&self.maintenance_sender,
)?;
let vectors = std::sync::Arc::from(
internal_ids
.into_iter()
.zip(vectors.into_iter())
// Inputs that deduplicated onto an existing internal id carry `None`;
// skip them so each remaining id stays paired with its own vector.
.filter_map(|(internal_id, vector)| internal_id.map(|id| (id, vector)))
.collect::<Vec<_>>(),
);
self.hypertrees
.iter()
.map(|hypertree| hypertree.add_many_to_index(std::sync::Arc::clone(&vectors)))
.collect::<Result<Vec<_>, _>>()?
.into_iter()
.try_for_each(|receiver| receiver.recv()?)?;
txn.commit()?;
Ok(ids)
}
fn get_vectors(
&self,
vector_ids: &[VectorId],
) -> Result<HashMap<VectorId, Vector<N>>, TreeError> {
let txn = self.database.begin_read()?;
get_vectors(&txn, vector_ids)
}
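// Fans the query out to every hypertree, merges the candidates, orders them
// by score for the requested style, deduplicates by internal id, and maps
// the survivors back to user-facing `VectorId`s via the inverse multimap.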
fn find_similarities(
&self,
vector: std::sync::Arc<Vector<N>>,
threshold: Option<f32>,
limit: usize,
similarity_style: SimilarityStyle,
) -> Result<Vec<VectorId>, TreeError> {
let txn = self.database.begin_read()?;
let inverse_vector_insert_multimap =
txn.open_multimap_table(INVERSE_VECTOR_INSERT_MULTIMAP)?;
let mut candidates = self
.hypertrees
.iter()
.map(|hypertree| {
hypertree.find_similarities(
std::sync::Arc::clone(&vector),
threshold,
limit,
similarity_style,
)
})
.collect::<Result<Vec<_>, _>>()?
.into_iter()
.map(|receiver| receiver.recv()?)
.collect::<Result<Vec<_>, _>>()?
.into_iter()
.flatten()
.collect::<Vec<_>>();
match similarity_style {
SimilarityStyle::Dissimilar | SimilarityStyle::FurthestSimilar => {
candidates.sort_by(|a, b| b.0.partial_cmp(&a.0).expect("Scores are not NaN"));
}
SimilarityStyle::Similar | SimilarityStyle::ClosestDissimilar => {
candidates.sort_by(|a, b| a.0.partial_cmp(&b.0).expect("Scores are not NaN"));
}
}
candidates.dedup_by_key(|a| a.1);
candidates.truncate(limit);
let vectors = candidates
.into_iter()
.map(|(_, internal_vector_id, _)| {
inverse_vector_insert_multimap
.get(internal_vector_id)
.and_then(|iter| {
iter.map(|res| res.map(|value| value.value()))
.collect::<Result<Vec<_>, _>>()
})
})
.collect::<Result<Vec<_>, _>>()?
.into_iter()
.flatten()
.collect();
Ok(vectors)
}
}
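// Pool runners. Each worker drains its command channel, runs the repo call
// under `catch_unwind` so a panicking operation takes down neither the worker
// nor the caller (the dropped responder surfaces to the caller as
// `TreeError::Canceled`), and enters the caller's span for the duration.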
fn reader_runner<const N: usize>(
repo: &VectorRepo<N>,
rx: flume::Receiver<ReaderCommand<N>>,
stopper: flume::Receiver<()>,
) {
with_stopper(rx, stopper, |message| match message {
ReaderCommand::FindSimilarities {
vector,
threshold,
limit,
similarity_style,
responder,
parent,
} => {
let span = tracing::debug_span!(parent: parent, "FindSimilarities");
let guard = span.enter();
let res = std::panic::catch_unwind(|| {
repo.find_similarities(vector, threshold, limit, similarity_style)
});
match res {
Ok(res) => {
if responder.send(res).is_err() {
tracing::warn!("Requester disconnected");
}
}
Err(e) => {
tracing::error!("operation panicked: {e:?}");
}
}
drop(guard);
}
ReaderCommand::GetVectors {
vector_ids,
responder,
parent,
} => {
let span = tracing::debug_span!(parent: parent, "GetVectors");
let guard = span.enter();
let res = std::panic::catch_unwind(|| repo.get_vectors(&vector_ids));
match res {
Ok(res) => {
if responder.send(res).is_err() {
tracing::warn!("Requester disconnected");
}
}
Err(e) => {
tracing::error!("operation panicked: {e:?}");
}
}
drop(guard);
}
})
}
fn insert_vectors_runner<const N: usize>(
repo: &VectorRepo<N>,
rx: flume::Receiver<InsertVectorsCommand<N>>,
stopper: flume::Receiver<()>,
) {
with_stopper(rx, stopper, |command| {
let InsertVectorsCommand {
vectors,
responder,
parent,
} = command;
let span = tracing::debug_span!(parent: parent, "InsertVectors");
let guard = span.enter();
let res = std::panic::catch_unwind(|| repo.insert_many_vectors(vectors));
match res {
Ok(res) => {
if responder.send(res).is_err() {
tracing::warn!("Requester disconnected");
}
}
Err(e) => {
tracing::error!("operation panicked: {e:?}");
}
}
drop(guard);
})
}
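// Runs `cb` for each message until either the stopper fires or the message
// channel closes; `flume::Selector` races both receivers so shutdown is
// observed even while the worker is idle.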
fn with_stopper<Msg, Callback>(
rx: flume::Receiver<Msg>,
stopper: flume::Receiver<()>,
mut cb: Callback,
) where
Callback: FnMut(Msg),
{
let stopping = std::sync::atomic::AtomicBool::new(false);
while !stopping.load(std::sync::atomic::Ordering::Acquire) {
flume::Selector::new()
.recv(&stopper, |_| {
stopping.store(true, std::sync::atomic::Ordering::Release)
})
.recv(&rx, |res| match res {
Ok(msg) => (cb)(msg),
Err(_) => stopping.store(true, std::sync::atomic::Ordering::Release),
})
.wait();
}
}
fn insert_many_vectors<const N: usize>(
transaction: &mut WriteTransaction<'_>,
vectors: &[Vector<N>],
len: &std::sync::atomic::AtomicU64,
maintenance_sender: &flume::Sender<()>,
) -> Result<(Vec<VectorId>, Vec<Option<InternalVectorId>>), TreeError> {
let mut ids_table = transaction.open_table(IDS_TABLE)?;
let mut inverse_vector_insert_multimap =
transaction.open_multimap_table(INVERSE_VECTOR_INSERT_MULTIMAP)?;
let mut inverse_vector_table = transaction.open_table(inverse_vector_table_definition())?;
let mut vector_insert_table = transaction.open_table(VECTOR_INSERT_TABLE)?;
let mut vector_table = transaction.open_table(vector_table_definition())?;
let (vectors, internal_vectors) = vectors
.iter()
.map(|vector| {
do_insert_vector(
&mut ids_table,
&mut inverse_vector_insert_multimap,
&mut inverse_vector_table,
&mut vector_insert_table,
&mut vector_table,
vector,
)
})
.collect::<Result<Vec<_>, _>>()?
.into_iter()
.unzip::<_, _, _, Vec<Option<InternalVectorId>>>();
// `internal_vectors` keeps one `Option` per input so the caller can tell
// which inputs produced a newly stored vector that still needs indexing.
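// Nudge the maintenance thread once the vector table has more than doubled
// since the last recorded rebuild; the compare_exchange ensures a single
// writer both advances the watermark and sends the non-blocking signal.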
let vectors_len = vector_table.len()?;
let previous_len = len.load(std::sync::atomic::Ordering::Acquire);
if vectors_len > 0
&& vectors_len / 2 > previous_len
&& len
.compare_exchange(
previous_len,
vectors_len,
std::sync::atomic::Ordering::AcqRel,
std::sync::atomic::Ordering::Relaxed,
)
.is_ok()
{
let _ = maintenance_sender.try_send(());
}
Ok((vectors, internal_vectors))
}
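// Inserts one vector, deduplicating by content: if identical bytes are
// already stored, the fresh `VectorId` is aliased onto the existing internal
// id and `None` is returned so the caller skips re-indexing it.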
fn do_insert_vector<'db, 'txn, const N: usize>(
ids_table: &mut Table<'db, 'txn, &'static str, u128>,
inverse_vector_insert_multimap: &mut MultimapTable<'db, 'txn, InternalVectorId, VectorId>,
inverse_vector_table: &mut Table<'db, 'txn, VectorBytes<N>, InternalVectorId>,
vector_insert_table: &mut Table<'db, 'txn, VectorId, InternalVectorId>,
vector_table: &mut Table<'db, 'txn, InternalVectorId, VectorBytes<N>>,
vector: &Vector<N>,
) -> Result<(VectorId, Option<InternalVectorId>), TreeError> {
let vector_id = VectorId(next_id(ids_table, VECTOR_INSERT_TABLE)?);
let vector_bytes = vector.encode();
let internal_vector_id_opt = inverse_vector_table
.get(&vector_bytes)?
.map(|id| id.value());
let internal_vector_id = if let Some(internal_vector_id) = internal_vector_id_opt {
vector_insert_table.insert(vector_id, internal_vector_id)?;
inverse_vector_insert_multimap.insert(internal_vector_id, vector_id)?;
None
} else {
let internal_vector_id =
InternalVectorId(next_id(ids_table, vector_table_definition::<N>())?);
vector_table.insert(internal_vector_id, &vector_bytes)?;
inverse_vector_table.insert(&vector_bytes, internal_vector_id)?;
vector_insert_table.insert(vector_id, internal_vector_id)?;
inverse_vector_insert_multimap.insert(internal_vector_id, vector_id)?;
Some(internal_vector_id)
};
Ok((vector_id, internal_vector_id))
}
fn get_vectors<const N: usize>(
transaction: &ReadTransaction<'_>,
vector_ids: &[VectorId],
) -> Result<HashMap<VectorId, Vector<N>>, TreeError> {
let vector_insert_table = transaction.open_table(VECTOR_INSERT_TABLE)?;
let vector_table = transaction.open_table(vector_table_definition())?;
let mut out = HashMap::with_capacity(vector_ids.len());
for vector_id in vector_ids {
let Some(internal_vector_id) = vector_insert_table.get(*vector_id)? else {
continue;
};
let Some(vector) = do_get_vector(&vector_table, internal_vector_id.value())? else {
continue;
};
out.insert(*vector_id, vector);
}
Ok(out)
}
fn do_get_vector<const N: usize, T>(
vector_table: &T,
internal_vector_id: InternalVectorId,
) -> Result<Option<Vector<N>>, TreeError>
where
T: ReadableTable<InternalVectorId, VectorBytes<'static, N>>,
{
let Some(vector_bytes) = vector_table.get(internal_vector_id)? else {
return Ok(None);
};
let vector = Vector::decode(&vector_bytes.value())?;
Ok(Some(vector))
}
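// Allocates the next id from `handle`'s named counter in `IDS_TABLE`. The
// stored value is the next unissued id; `insert` returns the previously
// stored value, and the loop retries with that value if it ever differs from
// the one just read, so the id handed out is always the one reserved.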
fn next_id<H>(
table: &mut Table<'_, '_, &'static str, u128>,
handle: H,
) -> Result<u128, StorageError>
where
H: TableHandle,
{
let mut next_id = table.get(handle.name())?.map(|id| id.value()).unwrap_or(0);
while let Some(id) = table.insert(handle.name(), next_id + 1)? {
let id = id.value();
if id == next_id {
break;
}
next_id = id;
}
Ok(next_id)
}
type F32ByteArray = [u8; 4];
const F32_BYTE_ARRAY_SIZE: usize = std::mem::size_of::<F32ByteArray>();
const fn vector_byte_length(n: usize) -> usize {
n * crate::F32_BYTE_ARRAY_SIZE
}
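// Vectors are persisted as `N` big-endian f32s, four bytes per component;
// `VectorBytes` wraps the encoded form so it can serve directly as a redb
// key (compared bytewise) and value without re-parsing.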
impl<const N: usize> Vector<N> {
const BYTES_LEN: usize = N * F32_BYTE_ARRAY_SIZE;
fn encode(&self) -> VectorBytes<'static, N> {
VectorBytes {
vector_bytes: std::borrow::Cow::Owned(
self.0.iter().flat_map(|f| f.to_be_bytes()).collect(),
),
_phantom: PhantomData,
}
}
fn decode(
VectorBytes {
vector_bytes,
_phantom,
}: &VectorBytes<N>,
) -> Result<Self, DecodeError> {
if vector_bytes.is_empty() {
return Err(DecodeError::Empty);
}
if vector_bytes.len() != Self::BYTES_LEN {
return Err(DecodeError::InvalidLength);
}
let mut index = 0;
// TODO: array.zip
let array = [(); N].map(|()| {
let f = f32::from_be_bytes(
vector_bytes[index..(index + F32_BYTE_ARRAY_SIZE)]
.try_into()
.expect("f32 byte array size is correct"),
);
index += F32_BYTE_ARRAY_SIZE;
f
});
Ok(Self(array))
}
fn average(&self, rhs: &Self) -> Self {
let mut index = 0;
// TODO: array.zip
Vector([(); N].map(|_| {
let avg = (self.0[index] + rhs.0[index]) / 2.0;
index += 1;
avg
}))
}
fn dot_product(&self, rhs: &Self) -> f32 {
self.0.iter().zip(rhs.0.iter()).map(|(a, b)| a * b).sum()
}
pub fn squared_euclidean_distance(&self, rhs: &Self) -> f32 {
self.0
.iter()
.zip(rhs.0.iter())
.map(|(a, b)| (a - b).powi(2))
.sum()
}
}
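// redb trait plumbing: `VectorBytes` is stored as its raw bytes and compared
// bytewise; the id newtypes delegate everything to `u128`.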
impl<const N: usize> RedbValue for VectorBytes<'static, N> {
type SelfType<'a> = VectorBytes<'a, N>;
type AsBytes<'a> = &'a [u8];
fn fixed_width() -> Option<usize> {
None
}
fn from_bytes<'a>(data: &'a [u8]) -> Self::SelfType<'a>
where
Self: 'a,
{
assert_eq!(
data.len(),
vector_byte_length(N),
"Byte length is not vector byte length"
);
VectorBytes {
vector_bytes: std::borrow::Cow::Borrowed(data),
_phantom: std::marker::PhantomData,
}
}
fn as_bytes<'a, 'b: 'a>(value: &'a Self::SelfType<'b>) -> Self::AsBytes<'a>
where
Self: 'a,
Self: 'b,
{
&value.vector_bytes
}
fn type_name() -> TypeName {
TypeName::new("vectordb::VectorBytes")
}
}
impl<const N: usize> RedbKey for VectorBytes<'static, N> {
fn compare(data1: &[u8], data2: &[u8]) -> std::cmp::Ordering {
data1.cmp(data2)
}
}
impl RedbValue for VectorId {
type SelfType<'a> = VectorId;
type AsBytes<'a> = <u128 as RedbValue>::AsBytes<'a>;
fn fixed_width() -> Option<usize> {
u128::fixed_width()
}
fn from_bytes<'a>(data: &'a [u8]) -> Self::SelfType<'a>
where
Self: 'a,
{
Self(u128::from_bytes(data))
}
fn as_bytes<'a, 'b: 'a>(value: &'a Self::SelfType<'b>) -> Self::AsBytes<'a>
where
Self: 'a,
Self: 'b,
{
u128::as_bytes(&value.0)
}
fn type_name() -> TypeName {
TypeName::new("vectordb::VectorId")
}
}
impl RedbKey for VectorId {
fn compare(data1: &[u8], data2: &[u8]) -> std::cmp::Ordering {
u128::compare(data1, data2)
}
}
impl RedbValue for InternalVectorId {
type SelfType<'a> = InternalVectorId;
type AsBytes<'a> = <u128 as RedbValue>::AsBytes<'a>;
fn fixed_width() -> Option<usize> {
u128::fixed_width()
}
fn from_bytes<'a>(data: &'a [u8]) -> Self::SelfType<'a>
where
Self: 'a,
{
Self(u128::from_bytes(data))
}
fn as_bytes<'a, 'b: 'a>(value: &'a Self::SelfType<'b>) -> Self::AsBytes<'a>
where
Self: 'a,
Self: 'b,
{
u128::as_bytes(&value.0)
}
fn type_name() -> TypeName {
TypeName::new("vectordb::InternalVectorId")
}
}
impl RedbKey for InternalVectorId {
fn compare(data1: &[u8], data2: &[u8]) -> std::cmp::Ordering {
u128::compare(data1, data2)
}
}
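// Element-wise arithmetic for `Vector`. Owned and borrowed operand variants
// are spelled out separately, and the index-counter loops stand in for the
// unstable `array::zip` (see the TODOs).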
impl<const N: usize> From<[f32; N]> for Vector<N> {
fn from(value: [f32; N]) -> Self {
Self(value)
}
}
impl<const N: usize> std::ops::Mul<f32> for Vector<N> {
type Output = Vector<N>;
fn mul(self, rhs: f32) -> Self::Output {
Vector(self.0.map(|a| a * rhs))
}
}
impl<const N: usize> std::ops::Div<f32> for Vector<N> {
type Output = Vector<N>;
fn div(self, rhs: f32) -> Self::Output {
Vector(self.0.map(|a| a / rhs))
}
}
impl<const N: usize> std::ops::Add<Vector<N>> for Vector<N> {
type Output = Vector<N>;
fn add(self, rhs: Self) -> Self::Output {
let mut index = 0;
// TODO: array.zip
Vector([(); N].map(|()| {
let sum = self.0[index] + rhs.0[index];
index += 1;
sum
}))
}
}
impl<'a, const N: usize> std::ops::Add<&'a Vector<N>> for Vector<N> {
type Output = Vector<N>;
fn add(self, rhs: &'a Self) -> Self::Output {
let mut index = 0;
// TODO: array.zip
Vector([(); N].map(|()| {
let sum = self.0[index] + rhs.0[index];
index += 1;
sum
}))
}
}
impl<'a, const N: usize> std::ops::Add<&'a Vector<N>> for &'a Vector<N> {
type Output = Vector<N>;
fn add(self, rhs: &'a Vector<N>) -> Self::Output {
let mut index = 0;
// TODO: array.zip
Vector([(); N].map(|()| {
let sum = self.0[index] + rhs.0[index];
index += 1;
sum
}))
}
}
impl<const N: usize> std::ops::Sub<Vector<N>> for Vector<N> {
type Output = Vector<N>;
fn sub(self, rhs: Self) -> Self::Output {
let mut index = 0;
// TODO: array.zip
Vector([(); N].map(|()| {
let diff = self.0[index] - rhs.0[index];
index += 1;
diff
}))
}
}
impl<'a, const N: usize> std::ops::Sub<&'a Vector<N>> for Vector<N> {
type Output = Vector<N>;
fn sub(self, rhs: &'a Self) -> Self::Output {
let mut index = 0;
// TODO: array.zip
Vector([(); N].map(|()| {
let diff = self.0[index] - rhs.0[index];
index += 1;
diff
}))
}
}
impl<'a, const N: usize> std::ops::Sub<&'a Vector<N>> for &'a Vector<N> {
type Output = Vector<N>;
fn sub(self, rhs: &'a Vector<N>) -> Self::Output {
let mut index = 0;
// TODO: array.zip
Vector([(); N].map(|()| {
let diff = self.0[index] - rhs.0[index];
index += 1;
diff
}))
}
}
impl From<oneshot::RecvError> for TreeError {
fn from(_: oneshot::RecvError) -> Self {
Self::Canceled
}
}
impl From<CommitError> for TreeError {
fn from(value: CommitError) -> Self {
Self::Commit(value)
}
}
impl From<CompactionError> for TreeError {
fn from(value: CompactionError) -> Self {
Self::Compact(value)
}
}
impl From<DatabaseError> for TreeError {
fn from(value: DatabaseError) -> Self {
Self::Database(value)
}
}
impl From<DecodeError> for TreeError {
fn from(value: DecodeError) -> Self {
Self::Decode(value)
}
}
impl From<std::io::Error> for TreeError {
fn from(value: std::io::Error) -> Self {
Self::Io(value)
}
}
impl From<StorageError> for TreeError {
fn from(value: StorageError) -> Self {
Self::Storage(value)
}
}
impl From<TableError> for TreeError {
fn from(value: TableError) -> Self {
Self::Table(value)
}
}
impl From<TransactionError> for TreeError {
fn from(value: TransactionError) -> Self {
Self::Transaction(value)
}
}
impl std::fmt::Display for VectorId {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
u128::fmt(&self.0, f)
}
}
impl std::fmt::Display for DecodeError {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
match self {
Self::Empty => write!(f, "No bytes present"),
Self::InvalidLength => write!(f, "Invalid number of bytes for type"),
Self::Bound => write!(f, "Direction bound is invalid"),
}
}
}
impl std::fmt::Display for TreeError {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
match self {
Self::Commit(_) => write!(f, "Error committing transaction"),
Self::Compact(_) => write!(f, "Error compacting database"),
Self::Database(_) => write!(f, "Error opening database"),
Self::Decode(_) => write!(f, "Error decoding bytes"),
Self::Io(_) => write!(f, "Error creating db directory"),
Self::Storage(_) => write!(f, "Error in storage"),
Self::Table(_) => write!(f, "Error opening table"),
Self::Transaction(_) => write!(f, "Error opening transaction"),
Self::Canceled => write!(f, "Operation panicked"),
Self::Closed => write!(f, "Database is closing"),
}
}
}
impl std::error::Error for DecodeError {}
impl std::error::Error for TreeError {
fn source(&self) -> Option<&(dyn std::error::Error + 'static)> {
match self {
Self::Commit(e) => Some(e),
Self::Compact(e) => Some(e),
Self::Database(e) => Some(e),
Self::Decode(e) => Some(e),
Self::Io(e) => Some(e),
Self::Storage(e) => Some(e),
Self::Table(e) => Some(e),
Self::Transaction(e) => Some(e),
Self::Canceled => None,
Self::Closed => None,
}
}
}