Skip to content

Commit

Permalink
feat: scalar8 & indexing on halfvec (#131)
Browse files Browse the repository at this point in the history
closes #91 (indexing on halfvec)
closes #118 (scalar8)

---------

Signed-off-by: usamoi <usamoi@outlook.com>
  • Loading branch information
usamoi authored Dec 10, 2024
1 parent a60c262 commit 5175707
Show file tree
Hide file tree
Showing 45 changed files with 2,137 additions and 475 deletions.
116 changes: 12 additions & 104 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

13 changes: 7 additions & 6 deletions Cargo.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[package]
name = "vchord"
version = "0.1.0"
version = "0.0.0"
edition = "2021"

[lib]
Expand All @@ -20,17 +20,15 @@ pg16 = ["pgrx/pg16", "pgrx-catalog/pg16"]
pg17 = ["pgrx/pg17", "pgrx-catalog/pg17"]

[dependencies]
base = { git = "https://github.com/tensorchord/pgvecto.rs.git", branch = "rabbithole-2" }
detect = { git = "https://github.com/tensorchord/pgvecto.rs.git", branch = "rabbithole-2" }
quantization = { git = "https://github.com/tensorchord/pgvecto.rs.git", branch = "rabbithole-2" }
base = { git = "https://github.com/tensorchord/pgvecto.rs.git", rev = "c911e8a476effaf05cd4b1a037826b100177bdad" }

# lock the nalgebra version forever so that the QR decomposition never changes for the same input
nalgebra = { version = "=0.33.0", default-features = false }
nalgebra = "=0.33.0"

# lock rkyv version forever so that data is always compatible
rkyv = { version = "=0.7.45", features = ["validation"] }

half = "2.4.1"
half = { version = "2.4.1", features = ["rkyv"] }
log = "0.4.22"
paste = "1"
pgrx = { version = "=0.12.8", default-features = false, features = ["cshim"] }
Expand All @@ -43,6 +41,9 @@ serde = "1"
toml = "0.8.19"
validator = { version = "0.18.1", features = ["derive"] }

[patch.crates-io]
half = { git = "https://github.com/tensorchord/half-rs.git" }

[lints]
rust.unsafe_op_in_unsafe_fn = "deny"
rust.unused_lifetimes = "warn"
Expand Down
75 changes: 75 additions & 0 deletions src/datatype/binary_scalar8.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,75 @@
use super::memory_scalar8::{Scalar8Input, Scalar8Output};
use crate::types::scalar8::Scalar8Borrowed;
use base::vector::VectorBorrowed;
use pgrx::datum::Internal;
use pgrx::pg_sys::Oid;

/// Binary "send" function for the scalar8 type.
///
/// Emits, in order and all big-endian: `dims`, `sum_of_x2`, `k`, `b`,
/// `sum_of_code`, followed by the raw code bytes. This is the same field
/// order that `_vchord_scalar8_recv` reads back.
#[pgrx::pg_extern(immutable, strict, parallel_safe)]
fn _vchord_scalar8_send(vector: Scalar8Input<'_>) -> Vec<u8> {
    let borrowed = vector.as_borrowed();
    let mut out: Vec<u8> = Vec::new();

    // Fixed-size header fields.
    out.extend_from_slice(&borrowed.dims().to_be_bytes());
    out.extend_from_slice(&borrowed.sum_of_x2().to_be_bytes());
    out.extend_from_slice(&borrowed.k().to_be_bytes());
    out.extend_from_slice(&borrowed.b().to_be_bytes());
    out.extend_from_slice(&borrowed.sum_of_code().to_be_bytes());

    // Each code is a single byte, so its big-endian encoding is the byte
    // itself and the whole slice can be appended as-is.
    out.extend_from_slice(borrowed.code());

    out
}

/// Binary "recv" function for the scalar8 type.
///
/// Reads from the `StringInfoData` behind `internal`, in order and all
/// big-endian: `dims` (u32), `sum_of_x2`, `k`, `b`, `sum_of_code` (f32 each),
/// followed by `dims` code bytes. `oid` and `typmod` are ignored.
///
/// # Panics
///
/// Panics if `internal` holds no `StringInfoData`, or if the message is
/// shorter than the header or than the declared `dims`.
///
/// # Errors
///
/// Raises a PostgreSQL error ("detect data corruption") when
/// `Scalar8Borrowed::new_checked` rejects the decoded fields.
#[pgrx::pg_extern(immutable, strict, parallel_safe)]
fn _vchord_scalar8_recv(internal: Internal, oid: Oid, typmod: i32) -> Scalar8Output {
    /// Consumes four bytes at the cursor of `buf`, panicking if fewer remain.
    fn take_4(buf: &mut pgrx::pg_sys::StringInfoData) -> [u8; 4] {
        // Written with a subtraction so the check cannot overflow, and so a
        // negative cursor is rejected instead of reading before the buffer.
        assert!(buf.cursor >= 0 && buf.cursor <= buf.len && buf.len - buf.cursor >= 4);
        // SAFETY: the assertion guarantees `cursor..cursor + 4` lies within
        // `0..len`; this relies on `data` pointing to at least `len` readable
        // bytes, which is the caller-provided StringInfo contract.
        let raw = unsafe { buf.data.add(buf.cursor as usize).cast::<[u8; 4]>().read() };
        buf.cursor += 4;
        raw
    }

    let _ = (oid, typmod);
    let buf = unsafe { internal.get_mut::<pgrx::pg_sys::StringInfoData>().unwrap() };

    let dims = u32::from_be_bytes(take_4(buf));
    let sum_of_x2 = f32::from_be_bytes(take_4(buf));
    let k = f32::from_be_bytes(take_4(buf));
    let b = f32::from_be_bytes(take_4(buf));
    let sum_of_code = f32::from_be_bytes(take_4(buf));

    let code = {
        let count = dims as usize;
        assert!(buf.cursor >= 0 && buf.cursor <= buf.len);
        let remaining = (buf.len - buf.cursor) as usize;
        // `dims` comes straight off the wire: validate it against the bytes
        // actually present BEFORE allocating, so a malformed message cannot
        // request a multi-gigabyte buffer.
        assert!(count <= remaining);
        let mut result = Vec::with_capacity(count);
        if count > 0 {
            // SAFETY: `count <= len - cursor` was asserted above, so the
            // range `cursor..cursor + count` is inside the received buffer.
            let bytes = unsafe {
                std::slice::from_raw_parts(buf.data.add(buf.cursor as usize).cast::<u8>(), count)
            };
            result.extend_from_slice(bytes);
            // `count <= remaining <= i32::MAX`, so this cast cannot truncate.
            buf.cursor += count as i32;
        }
        result
    };

    if let Some(x) = Scalar8Borrowed::new_checked(sum_of_x2, k, b, sum_of_code, &code) {
        Scalar8Output::new(x)
    } else {
        pgrx::error!("detect data corruption");
    }
}
26 changes: 26 additions & 0 deletions src/datatype/functions_scalar8.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
use crate::datatype::memory_pgvector_halfvec::PgvectorHalfvecInput;
use crate::datatype::memory_pgvector_vector::PgvectorVectorInput;
use crate::datatype::memory_scalar8::Scalar8Output;
use crate::types::scalar8::Scalar8Borrowed;
use base::simd::ScalarLike;
use half::f16;

/// Quantizes a pgvector `vector` value into a scalar8 value.
///
/// The result carries the sum of squares of the input elements, the `k` and
/// `b` parameters returned by `base::simd::quantize::quantize`, the sum of the
/// produced codes (as `f32`), and the codes themselves.
#[pgrx::pg_extern(sql = "")]
fn _vchord_vector_quantize_to_scalar8(vector: PgvectorVectorInput) -> Scalar8Output {
    let borrowed = vector.as_borrowed();
    let elements = borrowed.slice();

    let sum_of_x2 = f32::reduce_sum_of_x2(elements);

    // NOTE(review): 255.0 is presumably the top of the u8 code range; confirm
    // against `base::simd::quantize::quantize`.
    let as_f32 = f32::vector_to_f32_borrowed(elements);
    let (k, b, code) = base::simd::quantize::quantize(as_f32.as_ref(), 255.0);

    let code_total = base::simd::u8::reduce_sum_of_x_as_u32(&code);
    let sum_of_code = code_total as f32;

    let quantized = Scalar8Borrowed::new(sum_of_x2, k, b, sum_of_code, &code);
    Scalar8Output::new(quantized)
}

/// Quantizes a pgvector `halfvec` value into a scalar8 value.
///
/// The `f16` elements are converted to `f32` before quantization. The result
/// carries the sum of squares of the input elements, the `k` and `b`
/// parameters returned by `base::simd::quantize::quantize`, the sum of the
/// produced codes (as `f32`), and the codes themselves.
#[pgrx::pg_extern(sql = "")]
fn _vchord_halfvec_quantize_to_scalar8(vector: PgvectorHalfvecInput) -> Scalar8Output {
    let borrowed = vector.as_borrowed();
    let elements = borrowed.slice();

    let sum_of_x2 = f16::reduce_sum_of_x2(elements);

    // NOTE(review): 255.0 is presumably the top of the u8 code range; confirm
    // against `base::simd::quantize::quantize`.
    let as_f32 = f16::vector_to_f32_borrowed(elements);
    let (k, b, code) = base::simd::quantize::quantize(as_f32.as_ref(), 255.0);

    let code_total = base::simd::u8::reduce_sum_of_x_as_u32(&code);
    let sum_of_code = code_total as f32;

    let quantized = Scalar8Borrowed::new(sum_of_x2, k, b, sum_of_code, &code);
    Scalar8Output::new(quantized)
}
Loading

0 comments on commit 5175707

Please sign in to comment.