From 5175707c406687496c0796123439dab0e833ff94 Mon Sep 17 00:00:00 2001 From: usamoi Date: Tue, 10 Dec 2024 14:24:48 +0800 Subject: [PATCH] feat: scalar8 & indexing on halfvec (#131) closes #91 (indexing on halfvec) closes #118 (scalar8) --------- Signed-off-by: usamoi --- Cargo.lock | 116 +-------- Cargo.toml | 13 +- src/datatype/binary_scalar8.rs | 75 ++++++ src/datatype/functions_scalar8.rs | 26 ++ src/datatype/memory_pgvector_halfvec.rs | 204 +++++++++++++++ src/datatype/memory_pgvector_vector.rs | 56 ++-- src/datatype/memory_scalar8.rs | 215 ++++++++++++++++ src/datatype/mod.rs | 7 + src/datatype/operators_pgvector_halfvec.rs | 75 ++++++ src/datatype/operators_pgvector_vector.rs | 18 +- src/datatype/operators_scalar8.rs | 106 ++++++++ src/datatype/text_scalar8.rs | 138 ++++++++++ src/datatype/typmod.rs | 30 ++- src/lib.rs | 6 +- src/projection.rs | 2 +- src/sql/bootstrap.sql | 3 + src/sql/finalize.sql | 118 +++++++++ src/types/mod.rs | 1 + src/types/scalar8.rs | 286 +++++++++++++++++++++ src/utils/infinite_byte_chunks.rs | 20 ++ src/utils/k_means.rs | 14 +- src/utils/mod.rs | 1 + src/vchordrq/algorithm/build.rs | 32 +-- src/vchordrq/algorithm/insert.rs | 47 ++-- src/vchordrq/algorithm/prewarm.rs | 6 +- src/vchordrq/algorithm/rabitq.rs | 44 +--- src/vchordrq/algorithm/scan.rs | 44 ++-- src/vchordrq/algorithm/tuples.rs | 237 ++++++++++++++++- src/vchordrq/algorithm/vacuum.rs | 10 +- src/vchordrq/algorithm/vectors.rs | 57 ++-- src/vchordrq/index/am.rs | 263 ++++++++++++------- src/vchordrq/index/am_options.rs | 27 +- src/vchordrq/index/am_scan.rs | 74 ++++-- src/vchordrq/index/functions.rs | 10 +- src/vchordrq/index/opclass.rs | 15 ++ src/vchordrq/types.rs | 47 ++++ src/vchordrqfscan/algorithm/build.rs | 6 +- src/vchordrqfscan/algorithm/insert.rs | 2 +- src/vchordrqfscan/algorithm/rabitq.rs | 16 +- src/vchordrqfscan/algorithm/scan.rs | 2 +- src/vchordrqfscan/index/am.rs | 20 +- src/vchordrqfscan/index/am_options.rs | 12 +- src/vchordrqfscan/index/am_scan.rs | 11 +- src/vchordrqfscan/types.rs | 41 +++ tests/logic/distance.slt | 59 +++++ 45 files changed, 2137 insertions(+), 475 deletions(-) create mode 100644 src/datatype/binary_scalar8.rs create mode 100644 src/datatype/functions_scalar8.rs create mode 100644 src/datatype/memory_pgvector_halfvec.rs create mode 100644 src/datatype/memory_scalar8.rs create mode 100644 src/datatype/operators_pgvector_halfvec.rs create mode 100644 src/datatype/operators_scalar8.rs create mode 100644 src/datatype/text_scalar8.rs create mode 100644 src/types/mod.rs create mode 100644 src/types/scalar8.rs create mode 100644 src/utils/infinite_byte_chunks.rs create mode 100644 tests/logic/distance.slt diff --git a/Cargo.lock b/Cargo.lock index 5fe8ddb..235b782 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1,6 +1,6 @@ # This file is automatically @generated by Cargo. # It is not intended for manual editing. -version = 3 +version = 4 [[package]] name = "ahash" @@ -66,15 +66,13 @@ checksum = "ace50bade8e6234aa140d9a2f552bbee1db4d353f69b8217bc503490fc1a9f26" [[package]] name = "base" version = "0.0.0" -source = "git+https://github.com/tensorchord/pgvecto.rs.git?branch=rabbithole-2#29df8b43d45861fc741034bc7f8f304ca18822ca" +source = "git+https://github.com/tensorchord/pgvecto.rs.git?rev=c911e8a476effaf05cd4b1a037826b100177bdad#c911e8a476effaf05cd4b1a037826b100177bdad" dependencies = [ "base_macros", "detect", "half 2.4.1", "libc", - "log", "rand", - "rayon", "serde", "thiserror", "toml", @@ -84,7 +82,7 @@ dependencies = [ [[package]] name = "base_macros" version = "0.0.0" -source = "git+https://github.com/tensorchord/pgvecto.rs.git?branch=rabbithole-2#29df8b43d45861fc741034bc7f8f304ca18822ca" +source = "git+https://github.com/tensorchord/pgvecto.rs.git?rev=c911e8a476effaf05cd4b1a037826b100177bdad#c911e8a476effaf05cd4b1a037826b100177bdad" dependencies = [ "proc-macro2", "quote", @@ -152,9 +150,9 @@ dependencies = [ [[package]] name = "bytemuck" -version = "1.19.0" +version = "1.20.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "8334215b81e418a0a7bdb8ef0849474f40bb10c8b71f1c4ed315cff49f32494d" +checksum = "8b37c88a63ffd85d15b406896cc343916d7cf57838a847b3a6f2ca5d39a5695a" [[package]] name = "byteorder" @@ -223,20 +221,6 @@ dependencies = [ "libloading", ] -[[package]] -name = "common" -version = "0.0.0" -source = "git+https://github.com/tensorchord/pgvecto.rs.git?branch=rabbithole-2#29df8b43d45861fc741034bc7f8f304ca18822ca" -dependencies = [ - "base", - "log", - "memmap2", - "rand", - "rustix", - "serde", - "serde_json", -] - [[package]] name = "convert_case" version = "0.6.0" @@ -315,7 +299,7 @@ dependencies = [ [[package]] name = "detect" version = "0.0.0" -source = "git+https://github.com/tensorchord/pgvecto.rs.git?branch=rabbithole-2#29df8b43d45861fc741034bc7f8f304ca18822ca" +source = "git+https://github.com/tensorchord/pgvecto.rs.git?rev=c911e8a476effaf05cd4b1a037826b100177bdad#c911e8a476effaf05cd4b1a037826b100177bdad" dependencies = [ "detect_macros", ] @@ -323,7 +307,7 @@ dependencies = [ [[package]] name = "detect_macros" version = "0.0.0" -source = "git+https://github.com/tensorchord/pgvecto.rs.git?branch=rabbithole-2#29df8b43d45861fc741034bc7f8f304ca18822ca" +source = "git+https://github.com/tensorchord/pgvecto.rs.git?rev=c911e8a476effaf05cd4b1a037826b100177bdad#c911e8a476effaf05cd4b1a037826b100177bdad" dependencies = [ "proc-macro2", "quote", @@ -362,16 +346,6 @@ version = "1.0.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "5443807d6dff69373d433ab9ef5378ad8df50ca6298caf15de6e52e24aaf54d5" -[[package]] -name = "errno" -version = "0.3.9" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "534c5cf6194dfab3db3242765c03bbe257cf92f22b38f6bc0c58d59108a820ba" -dependencies = [ - "libc", - "windows-sys 0.52.0", -] - [[package]] name = "eyre" version = "0.6.12" @@ -435,13 +409,13 @@ checksum = "1b43ede17f21864e81be2fa654110bf1e793774238d86ef8555c37e6519c0403" [[package]] name = "half" version = "2.4.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "6dd08c532ae367adf81c312a4580bc67f1d0fe8bc9c460520283f4c0ff277888" +source = "git+https://github.com/tensorchord/half-rs.git#5b7fedc636c0eb1624763a40840c5cbf54cffd02" dependencies = [ "cfg-if", "crunchy", "rand", "rand_distr", + "rkyv", "serde", ] @@ -558,18 +532,6 @@ version = "1.0.11" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "49f1f14873335454500d59611f1cf4a4b0f786f9ac11f4312a78e4cf2566695b" -[[package]] -name = "k_means" -version = "0.0.0" -source = "git+https://github.com/tensorchord/pgvecto.rs.git?branch=rabbithole-2#29df8b43d45861fc741034bc7f8f304ca18822ca" -dependencies = [ - "base", - "common", - "half 2.4.1", - "rand", - "smawk", -] - [[package]] name = "libc" version = "0.2.161" @@ -592,12 +554,6 @@ version = "0.2.11" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "8355be11b20d696c8f18f6cc018c4e372165b1fa8126cef092399c9951984ffa" -[[package]] -name = "linux-raw-sys" -version = "0.4.14" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "78b3ae25bc7c8c38cec158d1f2757ee79e9b3740fbc7ccf0e59e4b08d793fa89" - [[package]] name = "log" version = "0.4.22" @@ -620,15 +576,6 @@ version = "2.7.4" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "78ca9ab1a0babb1e7d5695e3530886289c18cf2f87ec19a575a0abdce112e3a3" -[[package]] -name = "memmap2" -version = "0.9.5" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "fd3f7eed9d3848f8b98834af67102b720745c4ec028fcd0aa0239277e7de374f" -dependencies = [ - "libc", -] - [[package]] name = "minimal-lexical" version = "0.2.1" @@ -955,24 +902,6 @@ dependencies = [ "syn 1.0.109", ] -[[package]] -name = "quantization" -version = "0.0.0" -source = "git+https://github.com/tensorchord/pgvecto.rs.git?branch=rabbithole-2#29df8b43d45861fc741034bc7f8f304ca18822ca" -dependencies = [ - "base", - "common", - "detect", - "k_means", - "log", - "nalgebra", - "rand", - "rand_chacha", - "rand_distr", - "serde", - "serde_json", -] - [[package]] name = "quote" version = "1.0.37" @@ -1136,19 +1065,6 @@ dependencies = [ "semver", ] -[[package]] -name = "rustix" -version = "0.38.40" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "99e4ea3e1cdc4b559b8e5650f9c8e5998e3e5c1343b4eaf034565f32318d63c0" -dependencies = [ - "bitflags", - "errno", - "libc", - "linux-raw-sys", - "windows-sys 0.52.0", -] - [[package]] name = "ryu" version = "1.0.18" @@ -1273,12 +1189,6 @@ version = "0.1.5" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "e3a9fe34e3e7a50316060351f37187a3f546bce95496156754b601a5fa71b76e" -[[package]] -name = "smawk" -version = "0.3.2" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "b7c388c1b5e93756d0c740965c41e8822f866621d41acbdf6336a6a168f8840c" - [[package]] name = "sptr" version = "0.3.2" @@ -1516,17 +1426,15 @@ dependencies = [ [[package]] name = "vchord" -version = "0.1.0" +version = "0.0.0" dependencies = [ "base", - "detect", "half 2.4.1", "log", "nalgebra", "paste", "pgrx", "pgrx-catalog", - "quantization", "rand", "rand_chacha", "rand_distr", @@ -1561,9 +1469,9 @@ checksum = "9c8d87e72b64a3b4db28d11ce29237c246188f4f51057d65a7eab63b7987e423" [[package]] name = "wide" -version = "0.7.28" +version = "0.7.30" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "b828f995bf1e9622031f8009f8481a85406ce1f4d4588ff746d872043e855690" +checksum = "58e6db2670d2be78525979e9a5f9c69d296fd7d670549fe9ebf70f8708cb5019" dependencies = [ "bytemuck", "safe_arch", diff --git a/Cargo.toml b/Cargo.toml index 000def6..66af418 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "vchord" -version = "0.1.0" +version = "0.0.0" edition = "2021" [lib] @@ -20,17 +20,15 @@ pg16 = ["pgrx/pg16", "pgrx-catalog/pg16"] pg17 = ["pgrx/pg17", "pgrx-catalog/pg17"] [dependencies] -base = { git = "https://github.com/tensorchord/pgvecto.rs.git", branch = "rabbithole-2" } -detect = { git = "https://github.com/tensorchord/pgvecto.rs.git", branch = "rabbithole-2" } -quantization = { git = "https://github.com/tensorchord/pgvecto.rs.git", branch = "rabbithole-2" } +base = { git = "https://github.com/tensorchord/pgvecto.rs.git", rev = "c911e8a476effaf05cd4b1a037826b100177bdad" } # lock algebra version forever so that the QR decomposition never changes for same input -nalgebra = { version = "=0.33.0", default-features = false } +nalgebra = "=0.33.0" # lock rkyv version forever so that data is always compatible rkyv = { version = "=0.7.45", features = ["validation"] } -half = "2.4.1" +half = { version = "2.4.1", features = ["rkyv"] } log = "0.4.22" paste = "1" pgrx = { version = "=0.12.8", default-features = false, features = ["cshim"] } @@ -43,6 +41,9 @@ serde = "1" toml = "0.8.19" validator = { version = "0.18.1", features = ["derive"] } +[patch.crates-io] +half = { git = "https://github.com/tensorchord/half-rs.git" } + [lints] rust.unsafe_op_in_unsafe_fn = "deny" rust.unused_lifetimes = "warn" diff --git a/src/datatype/binary_scalar8.rs b/src/datatype/binary_scalar8.rs new file mode 100644 index 0000000..4dfe844 --- /dev/null +++ b/src/datatype/binary_scalar8.rs @@ -0,0 +1,75 @@ +use super::memory_scalar8::{Scalar8Input, Scalar8Output}; +use crate::types::scalar8::Scalar8Borrowed; +use base::vector::VectorBorrowed; +use pgrx::datum::Internal; +use pgrx::pg_sys::Oid; + +#[pgrx::pg_extern(immutable, strict, parallel_safe)] +fn _vchord_scalar8_send(vector: Scalar8Input<'_>) -> Vec { + let vector = vector.as_borrowed(); + let mut stream = Vec::::new(); + stream.extend(vector.dims().to_be_bytes()); + stream.extend(vector.sum_of_x2().to_be_bytes()); + stream.extend(vector.k().to_be_bytes()); + stream.extend(vector.b().to_be_bytes()); + stream.extend(vector.sum_of_code().to_be_bytes()); + for &c in vector.code() { + stream.extend(c.to_be_bytes()); + } + stream +} + +#[pgrx::pg_extern(immutable, strict, parallel_safe)] +fn _vchord_scalar8_recv(internal: Internal, oid: Oid, typmod: i32) -> Scalar8Output { + let _ = (oid, typmod); + let buf = unsafe { internal.get_mut::().unwrap() }; + + let dims = { + assert!(buf.cursor < i32::MAX - 4 && buf.cursor + 4 <= buf.len); + let raw = unsafe { buf.data.add(buf.cursor as _).cast::<[u8; 4]>().read() }; + buf.cursor += 4; + u32::from_be_bytes(raw) + }; + let sum_of_x2 = { + assert!(buf.cursor < i32::MAX - 4 && buf.cursor + 4 <= buf.len); + let raw = unsafe { buf.data.add(buf.cursor as _).cast::<[u8; 4]>().read() }; + buf.cursor += 4; + f32::from_be_bytes(raw) + }; + let k = { + assert!(buf.cursor < i32::MAX - 4 && buf.cursor + 4 <= buf.len); + let raw = unsafe { buf.data.add(buf.cursor as _).cast::<[u8; 4]>().read() }; + buf.cursor += 4; + f32::from_be_bytes(raw) + }; + let b = { + assert!(buf.cursor < i32::MAX - 4 && buf.cursor + 4 <= buf.len); + let raw = unsafe { buf.data.add(buf.cursor as _).cast::<[u8; 4]>().read() }; + buf.cursor += 4; + f32::from_be_bytes(raw) + }; + let sum_of_code = { + assert!(buf.cursor < i32::MAX - 4 && buf.cursor + 4 <= buf.len); + let raw = unsafe { buf.data.add(buf.cursor as _).cast::<[u8; 4]>().read() }; + buf.cursor += 4; + f32::from_be_bytes(raw) + }; + let code = { + let mut result = Vec::with_capacity(dims as _); + for _ in 0..dims { + result.push({ + assert!(buf.cursor < i32::MAX - 1 && buf.cursor + 1 <= buf.len); + let raw = unsafe { buf.data.add(buf.cursor as _).cast::<[u8; 1]>().read() }; + buf.cursor += 1; + u8::from_be_bytes(raw) + }); + } + result + }; + + if let Some(x) = Scalar8Borrowed::new_checked(sum_of_x2, k, b, sum_of_code, &code) { + Scalar8Output::new(x) + } else { + pgrx::error!("detect data corruption"); + } +} diff --git a/src/datatype/functions_scalar8.rs b/src/datatype/functions_scalar8.rs new file mode 100644 index 0000000..fc1c221 --- /dev/null +++ b/src/datatype/functions_scalar8.rs @@ -0,0 +1,26 @@ +use crate::datatype::memory_pgvector_halfvec::PgvectorHalfvecInput; +use crate::datatype::memory_pgvector_vector::PgvectorVectorInput; +use crate::datatype::memory_scalar8::Scalar8Output; +use crate::types::scalar8::Scalar8Borrowed; +use base::simd::ScalarLike; +use half::f16; + +#[pgrx::pg_extern(sql = "")] +fn _vchord_vector_quantize_to_scalar8(vector: PgvectorVectorInput) -> Scalar8Output { + let vector = vector.as_borrowed(); + let sum_of_x2 = f32::reduce_sum_of_x2(vector.slice()); + let (k, b, code) = + base::simd::quantize::quantize(f32::vector_to_f32_borrowed(vector.slice()).as_ref(), 255.0); + let sum_of_code = base::simd::u8::reduce_sum_of_x_as_u32(&code) as f32; + Scalar8Output::new(Scalar8Borrowed::new(sum_of_x2, k, b, sum_of_code, &code)) +} + +#[pgrx::pg_extern(sql = "")] +fn _vchord_halfvec_quantize_to_scalar8(vector: PgvectorHalfvecInput) -> Scalar8Output { + let vector = vector.as_borrowed(); + let sum_of_x2 = f16::reduce_sum_of_x2(vector.slice()); + let (k, b, code) = + base::simd::quantize::quantize(f16::vector_to_f32_borrowed(vector.slice()).as_ref(), 255.0); + let sum_of_code = base::simd::u8::reduce_sum_of_x_as_u32(&code) as f32; + Scalar8Output::new(Scalar8Borrowed::new(sum_of_x2, k, b, sum_of_code, &code)) +} diff --git a/src/datatype/memory_pgvector_halfvec.rs b/src/datatype/memory_pgvector_halfvec.rs new file mode 100644 index 0000000..7265a1b --- /dev/null +++ b/src/datatype/memory_pgvector_halfvec.rs @@ -0,0 +1,204 @@ +use base::vector::*; +use half::f16; +use pgrx::datum::FromDatum; +use pgrx::datum::IntoDatum; +use pgrx::pg_sys::Datum; +use pgrx::pg_sys::Oid; +use pgrx::pgrx_sql_entity_graph::metadata::ArgumentError; +use pgrx::pgrx_sql_entity_graph::metadata::Returns; +use pgrx::pgrx_sql_entity_graph::metadata::ReturnsError; +use pgrx::pgrx_sql_entity_graph::metadata::SqlMapping; +use pgrx::pgrx_sql_entity_graph::metadata::SqlTranslatable; +use std::ops::Deref; +use std::ptr::NonNull; + +#[repr(C, align(8))] +pub struct PgvectorHalfvecHeader { + varlena: u32, + dims: u16, + unused: u16, + phantom: [f16; 0], +} + +impl PgvectorHalfvecHeader { + fn size_of(len: usize) -> usize { + if len > 65535 { + panic!("vector is too large"); + } + (size_of::() + size_of::() * len).next_multiple_of(8) + } + pub fn as_borrowed(&self) -> VectBorrowed<'_, f16> { + unsafe { + VectBorrowed::new_unchecked(std::slice::from_raw_parts( + self.phantom.as_ptr(), + self.dims as usize, + )) + } + } +} + +pub enum PgvectorHalfvecInput<'a> { + Owned(PgvectorHalfvecOutput), + Borrowed(&'a PgvectorHalfvecHeader), +} + +impl PgvectorHalfvecInput<'_> { + unsafe fn new(p: NonNull) -> Self { + let q = unsafe { + NonNull::new(pgrx::pg_sys::pg_detoast_datum(p.cast().as_ptr()).cast()).unwrap() + }; + if p != q { + PgvectorHalfvecInput::Owned(PgvectorHalfvecOutput(q)) + } else { + unsafe { PgvectorHalfvecInput::Borrowed(p.as_ref()) } + } + } +} + +impl Deref for PgvectorHalfvecInput<'_> { + type Target = PgvectorHalfvecHeader; + + fn deref(&self) -> &Self::Target { + match self { + PgvectorHalfvecInput::Owned(x) => x, + PgvectorHalfvecInput::Borrowed(x) => x, + } + } +} + +pub struct PgvectorHalfvecOutput(NonNull); + +impl PgvectorHalfvecOutput { + pub fn new(vector: VectBorrowed<'_, f16>) -> PgvectorHalfvecOutput { + unsafe { + let slice = vector.slice(); + let size = PgvectorHalfvecHeader::size_of(slice.len()); + + let ptr = pgrx::pg_sys::palloc0(size) as *mut PgvectorHalfvecHeader; + (&raw mut (*ptr).varlena).write((size << 2) as u32); + (&raw mut (*ptr).dims).write(vector.dims() as _); + (&raw mut (*ptr).unused).write(0); + std::ptr::copy_nonoverlapping(slice.as_ptr(), (*ptr).phantom.as_mut_ptr(), slice.len()); + PgvectorHalfvecOutput(NonNull::new(ptr).unwrap()) + } + } + pub fn into_raw(self) -> *mut PgvectorHalfvecHeader { + let result = self.0.as_ptr(); + std::mem::forget(self); + result + } +} + +impl Deref for PgvectorHalfvecOutput { + type Target = PgvectorHalfvecHeader; + + fn deref(&self) -> &Self::Target { + unsafe { self.0.as_ref() } + } +} + +impl Drop for PgvectorHalfvecOutput { + fn drop(&mut self) { + unsafe { + pgrx::pg_sys::pfree(self.0.as_ptr() as _); + } + } +} + +impl FromDatum for PgvectorHalfvecInput<'_> { + unsafe fn from_polymorphic_datum(datum: Datum, is_null: bool, _typoid: Oid) -> Option { + if is_null { + None + } else { + let ptr = NonNull::new(datum.cast_mut_ptr::()).unwrap(); + unsafe { Some(PgvectorHalfvecInput::new(ptr)) } + } + } +} + +impl IntoDatum for PgvectorHalfvecOutput { + fn into_datum(self) -> Option { + Some(Datum::from(self.into_raw() as *mut ())) + } + + fn type_oid() -> Oid { + Oid::INVALID + } + + fn is_compatible_with(_: Oid) -> bool { + true + } +} + +impl FromDatum for PgvectorHalfvecOutput { + unsafe fn from_polymorphic_datum(datum: Datum, is_null: bool, _typoid: Oid) -> Option { + if is_null { + None + } else { + let p = NonNull::new(datum.cast_mut_ptr::())?; + let q = + unsafe { NonNull::new(pgrx::pg_sys::pg_detoast_datum(p.cast().as_ptr()).cast())? }; + if p != q { + Some(PgvectorHalfvecOutput(q)) + } else { + let header = p.as_ptr(); + let vector = unsafe { (*header).as_borrowed() }; + Some(PgvectorHalfvecOutput::new(vector)) + } + } + } +} + +unsafe impl pgrx::datum::UnboxDatum for PgvectorHalfvecOutput { + type As<'src> = PgvectorHalfvecOutput; + #[inline] + unsafe fn unbox<'src>(d: pgrx::datum::Datum<'src>) -> Self::As<'src> + where + Self: 'src, + { + let p = NonNull::new(d.sans_lifetime().cast_mut_ptr::()).unwrap(); + let q = unsafe { + NonNull::new(pgrx::pg_sys::pg_detoast_datum(p.cast().as_ptr()).cast()).unwrap() + }; + if p != q { + PgvectorHalfvecOutput(q) + } else { + let header = p.as_ptr(); + let vector = unsafe { (*header).as_borrowed() }; + PgvectorHalfvecOutput::new(vector) + } + } +} + +unsafe impl SqlTranslatable for PgvectorHalfvecInput<'_> { + fn argument_sql() -> Result { + Ok(SqlMapping::As(String::from("halfvec"))) + } + fn return_sql() -> Result { + Ok(Returns::One(SqlMapping::As(String::from("halfvec")))) + } +} + +unsafe impl SqlTranslatable for PgvectorHalfvecOutput { + fn argument_sql() -> Result { + Ok(SqlMapping::As(String::from("halfvec"))) + } + fn return_sql() -> Result { + Ok(Returns::One(SqlMapping::As(String::from("halfvec")))) + } +} + +unsafe impl<'fcx> pgrx::callconv::ArgAbi<'fcx> for PgvectorHalfvecInput<'fcx> { + unsafe fn unbox_arg_unchecked(arg: pgrx::callconv::Arg<'_, 'fcx>) -> Self { + unsafe { arg.unbox_arg_using_from_datum().unwrap() } + } +} + +unsafe impl pgrx::callconv::BoxRet for PgvectorHalfvecOutput { + unsafe fn box_into<'fcx>( + self, + fcinfo: &mut pgrx::callconv::FcInfo<'fcx>, + ) -> pgrx::datum::Datum<'fcx> { + unsafe { fcinfo.return_raw_datum(Datum::from(self.into_raw() as *mut ())) } + } +} diff --git a/src/datatype/memory_pgvector_vector.rs b/src/datatype/memory_pgvector_vector.rs index 7166ace..d81492c 100644 --- a/src/datatype/memory_pgvector_vector.rs +++ b/src/datatype/memory_pgvector_vector.rs @@ -8,48 +8,31 @@ use pgrx::pgrx_sql_entity_graph::metadata::Returns; use pgrx::pgrx_sql_entity_graph::metadata::ReturnsError; use pgrx::pgrx_sql_entity_graph::metadata::SqlMapping; use pgrx::pgrx_sql_entity_graph::metadata::SqlTranslatable; -use std::alloc::Layout; use std::ops::Deref; use std::ptr::NonNull; -pub const HEADER_MAGIC: u16 = 0; - #[repr(C, align(8))] pub struct PgvectorVectorHeader { varlena: u32, dims: u16, - magic: u16, + unused: u16, phantom: [f32; 0], } impl PgvectorVectorHeader { - fn varlena(size: usize) -> u32 { - (size << 2) as u32 - } - fn layout(len: usize) -> Layout { - u16::try_from(len).expect("Vector is too large."); - let layout_alpha = Layout::new::(); - let layout_beta = Layout::array::(len).unwrap(); - let layout = layout_alpha.extend(layout_beta).unwrap().0; - layout.pad_to_align() - } - #[allow(dead_code)] - pub fn dims(&self) -> u32 { - self.dims as u32 - } - pub fn slice(&self) -> &[f32] { - unsafe { std::slice::from_raw_parts(self.phantom.as_ptr(), self.dims as usize) } + fn size_of(len: usize) -> usize { + if len > 65535 { + panic!("vector is too large"); + } + (size_of::() + size_of::() * len).next_multiple_of(8) } pub fn as_borrowed(&self) -> VectBorrowed<'_, f32> { - unsafe { VectBorrowed::new_unchecked(self.slice()) } - } -} - -impl Deref for PgvectorVectorHeader { - type Target = [f32]; - - fn deref(&self) -> &Self::Target { - self.slice() + unsafe { + VectBorrowed::new_unchecked(std::slice::from_raw_parts( + self.phantom.as_ptr(), + self.dims as usize, + )) + } } } @@ -88,15 +71,12 @@ impl PgvectorVectorOutput { pub fn new(vector: VectBorrowed<'_, f32>) -> PgvectorVectorOutput { unsafe { let slice = vector.slice(); - let layout = PgvectorVectorHeader::layout(slice.len()); - let dims = vector.dims(); - let internal_dims = dims as u16; - let ptr = pgrx::pg_sys::palloc(layout.size()) as *mut PgvectorVectorHeader; - ptr.cast::().add(layout.size() - 8).write_bytes(0, 8); - std::ptr::addr_of_mut!((*ptr).varlena) - .write(PgvectorVectorHeader::varlena(layout.size())); - std::ptr::addr_of_mut!((*ptr).magic).write(HEADER_MAGIC); - std::ptr::addr_of_mut!((*ptr).dims).write(internal_dims); + let size = PgvectorVectorHeader::size_of(slice.len()); + + let ptr = pgrx::pg_sys::palloc0(size) as *mut PgvectorVectorHeader; + (&raw mut (*ptr).varlena).write((size << 2) as u32); + (&raw mut (*ptr).dims).write(vector.dims() as _); + (&raw mut (*ptr).unused).write(0); std::ptr::copy_nonoverlapping(slice.as_ptr(), (*ptr).phantom.as_mut_ptr(), slice.len()); PgvectorVectorOutput(NonNull::new(ptr).unwrap()) } diff --git a/src/datatype/memory_scalar8.rs b/src/datatype/memory_scalar8.rs new file mode 100644 index 0000000..3a7dab4 --- /dev/null +++ b/src/datatype/memory_scalar8.rs @@ -0,0 +1,215 @@ +use crate::types::scalar8::Scalar8Borrowed; +use base::vector::*; +use pgrx::datum::FromDatum; +use pgrx::datum::IntoDatum; +use pgrx::pg_sys::Datum; +use pgrx::pg_sys::Oid; +use pgrx::pgrx_sql_entity_graph::metadata::ArgumentError; +use pgrx::pgrx_sql_entity_graph::metadata::Returns; +use pgrx::pgrx_sql_entity_graph::metadata::ReturnsError; +use pgrx::pgrx_sql_entity_graph::metadata::SqlMapping; +use pgrx::pgrx_sql_entity_graph::metadata::SqlTranslatable; +use std::ops::Deref; +use std::ptr::NonNull; + +#[repr(C, align(8))] +pub struct Scalar8Header { + varlena: u32, + dims: u16, + unused: u16, + sum_of_x2: f32, + k: f32, + b: f32, + sum_of_code: f32, + phantom: [u8; 0], +} + +impl Scalar8Header { + fn size_of(len: usize) -> usize { + if len > 65535 { + panic!("vector is too large"); + } + (size_of::() + size_of::() * len).next_multiple_of(8) + } + pub fn as_borrowed(&self) -> Scalar8Borrowed<'_> { + unsafe { + Scalar8Borrowed::new_unchecked( + self.sum_of_x2, + self.k, + self.b, + self.sum_of_code, + std::slice::from_raw_parts(self.phantom.as_ptr(), self.dims as usize), + ) + } + } +} + +pub enum Scalar8Input<'a> { + Owned(Scalar8Output), + Borrowed(&'a Scalar8Header), +} + +impl Scalar8Input<'_> { + unsafe fn new(p: NonNull) -> Self { + let q = unsafe { + NonNull::new(pgrx::pg_sys::pg_detoast_datum(p.cast().as_ptr()).cast()).unwrap() + }; + if p != q { + Scalar8Input::Owned(Scalar8Output(q)) + } else { + unsafe { Scalar8Input::Borrowed(p.as_ref()) } + } + } +} + +impl Deref for Scalar8Input<'_> { + type Target = Scalar8Header; + + fn deref(&self) -> &Self::Target { + match self { + Scalar8Input::Owned(x) => x, + Scalar8Input::Borrowed(x) => x, + } + } +} + +pub struct Scalar8Output(NonNull); + +impl Scalar8Output { + pub fn new(vector: Scalar8Borrowed<'_>) -> Scalar8Output { + unsafe { + let code = vector.code(); + let size = Scalar8Header::size_of(code.len()); + + let ptr = pgrx::pg_sys::palloc0(size) as *mut Scalar8Header; + (&raw mut (*ptr).varlena).write((size << 2) as u32); + (&raw mut (*ptr).dims).write(vector.dims() as _); + (&raw mut (*ptr).unused).write(0); + (&raw mut (*ptr).sum_of_x2).write(vector.sum_of_x2()); + (&raw mut (*ptr).k).write(vector.k()); + (&raw mut (*ptr).b).write(vector.b()); + (&raw mut (*ptr).sum_of_code).write(vector.sum_of_code()); + std::ptr::copy_nonoverlapping(code.as_ptr(), (*ptr).phantom.as_mut_ptr(), code.len()); + Scalar8Output(NonNull::new(ptr).unwrap()) + } + } + pub fn into_raw(self) -> *mut Scalar8Header { + let result = self.0.as_ptr(); + std::mem::forget(self); + result + } +} + +impl Deref for Scalar8Output { + type Target = Scalar8Header; + + fn deref(&self) -> &Self::Target { + unsafe { self.0.as_ref() } + } +} + +impl Drop for Scalar8Output { + fn drop(&mut self) { + unsafe { + pgrx::pg_sys::pfree(self.0.as_ptr() as _); + } + } +} + +impl FromDatum for Scalar8Input<'_> { + unsafe fn from_polymorphic_datum(datum: Datum, is_null: bool, _typoid: Oid) -> Option { + if is_null { + None + } else { + let ptr = NonNull::new(datum.cast_mut_ptr::()).unwrap(); + unsafe { Some(Scalar8Input::new(ptr)) } + } + } +} + +impl IntoDatum for Scalar8Output { + fn into_datum(self) -> Option { + Some(Datum::from(self.into_raw() as *mut ())) + } + + fn type_oid() -> Oid { + Oid::INVALID + } + + fn is_compatible_with(_: Oid) -> bool { + true + } +} + +impl FromDatum for Scalar8Output { + unsafe fn from_polymorphic_datum(datum: Datum, is_null: bool, _typoid: Oid) -> Option { + if is_null { + None + } else { + let p = NonNull::new(datum.cast_mut_ptr::())?; + let q = + unsafe { NonNull::new(pgrx::pg_sys::pg_detoast_datum(p.cast().as_ptr()).cast())? }; + if p != q { + Some(Scalar8Output(q)) + } else { + let header = p.as_ptr(); + let vector = unsafe { (*header).as_borrowed() }; + Some(Scalar8Output::new(vector)) + } + } + } +} + +unsafe impl pgrx::datum::UnboxDatum for Scalar8Output { + type As<'src> = Scalar8Output; + #[inline] + unsafe fn unbox<'src>(d: pgrx::datum::Datum<'src>) -> Self::As<'src> + where + Self: 'src, + { + let p = NonNull::new(d.sans_lifetime().cast_mut_ptr::()).unwrap(); + let q = unsafe { + NonNull::new(pgrx::pg_sys::pg_detoast_datum(p.cast().as_ptr()).cast()).unwrap() + }; + if p != q { + Scalar8Output(q) + } else { + let header = p.as_ptr(); + let vector = unsafe { (*header).as_borrowed() }; + Scalar8Output::new(vector) + } + } +} + +unsafe impl SqlTranslatable for Scalar8Input<'_> { + fn argument_sql() -> Result { + Ok(SqlMapping::As(String::from("scalar8"))) + } + fn return_sql() -> Result { + Ok(Returns::One(SqlMapping::As(String::from("scalar8")))) + } +} + +unsafe impl SqlTranslatable for Scalar8Output { + fn argument_sql() -> Result { + Ok(SqlMapping::As(String::from("scalar8"))) + } + fn return_sql() -> Result { + Ok(Returns::One(SqlMapping::As(String::from("scalar8")))) + } +} + +unsafe impl<'fcx> pgrx::callconv::ArgAbi<'fcx> for Scalar8Input<'fcx> { + unsafe fn unbox_arg_unchecked(arg: pgrx::callconv::Arg<'_, 'fcx>) -> Self { + unsafe { arg.unbox_arg_using_from_datum().unwrap() } + } +} + +unsafe impl pgrx::callconv::BoxRet for Scalar8Output { + unsafe fn box_into<'fcx>( + self, + fcinfo: &mut pgrx::callconv::FcInfo<'fcx>, + ) -> pgrx::datum::Datum<'fcx> { + unsafe { fcinfo.return_raw_datum(Datum::from(self.into_raw() as *mut ())) } + } +} diff --git a/src/datatype/mod.rs b/src/datatype/mod.rs index 9152f33..98b6650 100644 --- a/src/datatype/mod.rs +++ b/src/datatype/mod.rs @@ -1,3 +1,10 @@ +pub mod binary_scalar8; +pub mod functions_scalar8; +pub mod memory_pgvector_halfvec; pub mod memory_pgvector_vector; +pub mod memory_scalar8; +pub mod operators_pgvector_halfvec; pub mod operators_pgvector_vector; +pub mod operators_scalar8; +pub mod text_scalar8; pub mod typmod; diff --git a/src/datatype/operators_pgvector_halfvec.rs b/src/datatype/operators_pgvector_halfvec.rs new file mode 100644 index 0000000..2e707eb --- /dev/null +++ b/src/datatype/operators_pgvector_halfvec.rs @@ -0,0 +1,75 @@ +use crate::datatype::memory_pgvector_halfvec::*; +use base::vector::{VectBorrowed, VectorBorrowed}; +use std::num::NonZero; + +#[pgrx::pg_extern(immutable, strict, parallel_safe)] +fn _vchord_halfvec_sphere_l2_in( + lhs: PgvectorHalfvecInput<'_>, + rhs: pgrx::composite_type!("sphere_halfvec"), +) -> bool { + let center: PgvectorHalfvecOutput = match rhs.get_by_index(NonZero::new(1).unwrap()) { + Ok(Some(s)) => s, + Ok(None) => pgrx::error!("Bad input: empty center at sphere"), + Err(_) => unreachable!(), + }; + let radius: f32 = match rhs.get_by_index(NonZero::new(2).unwrap()) { + Ok(Some(s)) => s, + Ok(None) => pgrx::error!("Bad input: empty radius at sphere"), + Err(_) => unreachable!(), + }; + let lhs = lhs.as_borrowed(); + let center = center.as_borrowed(); + if lhs.dims() != center.dims() { + pgrx::error!("dimension is not matched"); + } + let d = VectBorrowed::operator_l2(lhs, center).to_f32().sqrt(); + d < radius +} + +#[pgrx::pg_extern(immutable, strict, parallel_safe)] +fn _vchord_halfvec_sphere_ip_in( + lhs: PgvectorHalfvecInput<'_>, + rhs: pgrx::composite_type!("sphere_halfvec"), +) -> bool { + let center: PgvectorHalfvecOutput = match rhs.get_by_index(NonZero::new(1).unwrap()) { + Ok(Some(s)) => s, + Ok(None) => pgrx::error!("Bad input: empty center at sphere"), + Err(_) => unreachable!(), + }; + let radius: f32 = match rhs.get_by_index(NonZero::new(2).unwrap()) { + Ok(Some(s)) => s, + Ok(None) => pgrx::error!("Bad input: empty radius at sphere"), + Err(_) => unreachable!(), + }; + let lhs = lhs.as_borrowed(); + let center = center.as_borrowed(); + if lhs.dims() != center.dims() { + pgrx::error!("dimension is not matched"); + } + let d = VectBorrowed::operator_dot(lhs, center).to_f32(); + d < radius +} + +#[pgrx::pg_extern(immutable, strict, parallel_safe)] +fn _vchord_halfvec_sphere_cosine_in( + lhs: PgvectorHalfvecInput<'_>, + rhs: pgrx::composite_type!("sphere_halfvec"), +) -> bool { + let center: PgvectorHalfvecOutput = match rhs.get_by_index(NonZero::new(1).unwrap()) { + Ok(Some(s)) => s, + Ok(None) => pgrx::error!("Bad input: empty center at sphere"), + Err(_) => unreachable!(), + }; + let radius: f32 = match rhs.get_by_index(NonZero::new(2).unwrap()) { + Ok(Some(s)) => s, + Ok(None) => pgrx::error!("Bad input: empty radius at sphere"), + Err(_) => unreachable!(), + }; + let lhs = lhs.as_borrowed(); + let center = center.as_borrowed(); + if lhs.dims() != center.dims() { + pgrx::error!("dimension is not matched"); + } + let d = VectBorrowed::operator_cos(lhs, center).to_f32(); + d < radius +} diff --git a/src/datatype/operators_pgvector_vector.rs b/src/datatype/operators_pgvector_vector.rs index 4a1f055..2308ab9 100644 --- a/src/datatype/operators_pgvector_vector.rs +++ b/src/datatype/operators_pgvector_vector.rs @@ -12,9 +12,6 @@ fn _vchord_vector_sphere_l2_in( Ok(None) => pgrx::error!("Bad input: empty center at sphere"), Err(_) => unreachable!(), }; - if lhs.dims() != center.dims() { - pgrx::error!("dimension is not matched"); - } let radius: f32 = match rhs.get_by_index(NonZero::new(2).unwrap()) { Ok(Some(s)) => s, Ok(None) => pgrx::error!("Bad input: empty radius at sphere"), @@ -22,6 +19,9 @@ fn _vchord_vector_sphere_l2_in( }; let lhs = lhs.as_borrowed(); let center = center.as_borrowed(); + if lhs.dims() != center.dims() { + pgrx::error!("dimension is not matched"); + } let d = VectBorrowed::operator_l2(lhs, center).to_f32().sqrt(); d < radius } @@ -36,9 +36,6 @@ fn _vchord_vector_sphere_ip_in( Ok(None) => pgrx::error!("Bad input: empty center at sphere"), Err(_) => unreachable!(), }; - if lhs.dims() != center.dims() { - pgrx::error!("dimension is not matched"); - } let radius: f32 = match rhs.get_by_index(NonZero::new(2).unwrap()) { Ok(Some(s)) => s, Ok(None) => pgrx::error!("Bad input: empty radius at sphere"), @@ -46,6 +43,9 @@ fn _vchord_vector_sphere_ip_in( }; let lhs = lhs.as_borrowed(); let center = center.as_borrowed(); + if lhs.dims() != center.dims() { + pgrx::error!("dimension is not matched"); + } let d = VectBorrowed::operator_dot(lhs, center).to_f32(); d < radius } @@ -60,9 +60,6 @@ fn _vchord_vector_sphere_cosine_in( Ok(None) => pgrx::error!("Bad input: empty center at sphere"), Err(_) => unreachable!(), }; - if lhs.dims() != center.dims() { - pgrx::error!("dimension is not matched"); - } let radius: f32 = match rhs.get_by_index(NonZero::new(2).unwrap()) { Ok(Some(s)) => s, Ok(None) => pgrx::error!("Bad input: empty radius at sphere"), @@ -70,6 +67,9 @@ fn _vchord_vector_sphere_cosine_in( }; let lhs = lhs.as_borrowed(); let center = center.as_borrowed(); + if lhs.dims() != center.dims() { + pgrx::error!("dimension is not matched"); + } let d = VectBorrowed::operator_cos(lhs, center).to_f32(); d < radius } diff --git a/src/datatype/operators_scalar8.rs b/src/datatype/operators_scalar8.rs new file mode 100644 index 0000000..db6a372 --- /dev/null +++ b/src/datatype/operators_scalar8.rs @@ -0,0 +1,106 @@ +use crate::datatype::memory_scalar8::{Scalar8Input, Scalar8Output}; +use crate::types::scalar8::Scalar8Borrowed; +use base::vector::*; +use std::num::NonZero; + +#[pgrx::pg_extern(immutable, strict, parallel_safe)] +fn _vchord_scalar8_operator_ip(lhs: Scalar8Input<'_>, rhs: Scalar8Input<'_>) -> f32 { + let lhs = lhs.as_borrowed(); + let rhs = rhs.as_borrowed(); + if lhs.dims() != rhs.dims() { + pgrx::error!("dimension is not matched"); + } + Scalar8Borrowed::operator_dot(lhs, rhs).to_f32() +} + +#[pgrx::pg_extern(immutable, strict, parallel_safe)] +fn _vchord_scalar8_operator_l2(lhs: Scalar8Input<'_>, rhs: Scalar8Input<'_>) -> f32 { + let lhs = lhs.as_borrowed(); + let rhs = rhs.as_borrowed(); + if lhs.dims() != rhs.dims() { + pgrx::error!("dimension is not matched"); + } + Scalar8Borrowed::operator_l2(lhs, rhs).to_f32().sqrt() +} + +#[pgrx::pg_extern(immutable, strict, parallel_safe)] +fn _vchord_scalar8_operator_cosine(lhs: Scalar8Input<'_>, rhs: Scalar8Input<'_>) -> f32 { + let lhs = lhs.as_borrowed(); + let rhs = rhs.as_borrowed(); + if lhs.dims() != rhs.dims() { + pgrx::error!("dimension is not matched"); + } + Scalar8Borrowed::operator_cos(lhs, rhs).to_f32() +} + +#[pgrx::pg_extern(immutable, strict, parallel_safe)] +fn _vchord_scalar8_sphere_ip_in( + lhs: Scalar8Input<'_>, + rhs: pgrx::composite_type!("sphere_scalar8"), +) -> bool { + let center: Scalar8Output = match rhs.get_by_index(NonZero::new(1).unwrap()) { + Ok(Some(s)) => s, + Ok(None) => pgrx::error!("Bad input: empty center at sphere"), + Err(_) => unreachable!(), + }; + let radius: f32 = match rhs.get_by_index(NonZero::new(2).unwrap()) { + Ok(Some(s)) => s, + Ok(None) => pgrx::error!("Bad input: empty radius at sphere"), + Err(_) => unreachable!(), + }; + let lhs = lhs.as_borrowed(); + let center = center.as_borrowed(); + if lhs.dims() != center.dims() { + pgrx::error!("dimension is not matched"); + } + let d = Scalar8Borrowed::operator_dot(lhs, center).to_f32(); + d < radius +} + +#[pgrx::pg_extern(immutable, strict, parallel_safe)] +fn _vchord_scalar8_sphere_l2_in( + lhs: Scalar8Input<'_>, + rhs: pgrx::composite_type!("sphere_scalar8"), +) -> bool { + let center: Scalar8Output = match rhs.get_by_index(NonZero::new(1).unwrap()) { + Ok(Some(s)) => s, + Ok(None) => pgrx::error!("Bad input: empty center at sphere"), + Err(_) => unreachable!(), + }; + let radius: f32 = match rhs.get_by_index(NonZero::new(2).unwrap()) { + Ok(Some(s)) => s, + Ok(None) => pgrx::error!("Bad input: empty radius at sphere"), + Err(_) => unreachable!(), + }; + let lhs = lhs.as_borrowed(); + let center = center.as_borrowed(); + if lhs.dims() != center.dims() { + pgrx::error!("dimension is not matched"); + } + let d = Scalar8Borrowed::operator_l2(lhs, center).to_f32().sqrt(); + d < radius +} + +#[pgrx::pg_extern(immutable, strict, parallel_safe)] +fn _vchord_scalar8_sphere_cosine_in( + lhs: Scalar8Input<'_>, + rhs: pgrx::composite_type!("sphere_scalar8"), +) -> bool { + let center: Scalar8Output = match rhs.get_by_index(NonZero::new(1).unwrap()) { + Ok(Some(s)) => s, + Ok(None) => pgrx::error!("Bad input: empty center at sphere"), + Err(_) => unreachable!(), + }; + let radius: f32 = match rhs.get_by_index(NonZero::new(2).unwrap()) { + Ok(Some(s)) => s, + Ok(None) => pgrx::error!("Bad input: empty radius at sphere"), + Err(_) => unreachable!(), + }; + let lhs = lhs.as_borrowed(); + let center = center.as_borrowed(); + if lhs.dims() != center.dims() { + pgrx::error!("dimension is not matched"); + } + let d = Scalar8Borrowed::operator_cos(lhs, center).to_f32(); + d < radius +} diff --git a/src/datatype/text_scalar8.rs b/src/datatype/text_scalar8.rs new file mode 100644 index 0000000..5de82da --- /dev/null +++ b/src/datatype/text_scalar8.rs @@ -0,0 +1,138 @@ +use super::memory_scalar8::Scalar8Output; +use crate::datatype::memory_scalar8::Scalar8Input; +use crate::types::scalar8::Scalar8Borrowed; +use pgrx::pg_sys::Oid; +use std::ffi::{CStr, CString}; + +#[pgrx::pg_extern(immutable, strict, parallel_safe)] +fn _vchord_scalar8_in(input: &CStr, oid: Oid, typmod: i32) -> Scalar8Output { + let _ = (oid, typmod); + let mut input = input.to_bytes().iter(); + let mut p0 = Vec::::new(); + let mut p1 = Vec::::new(); + { + loop { + let Some(c) = input.next().copied() else { + pgrx::error!("incorrect vector") + }; + match c { + b' ' => (), + b'(' => break, + _ => pgrx::error!("incorrect vector"), + } + } + } + { + let mut s = Option::::None; + loop { + let Some(c) = input.next().copied() else { + pgrx::error!("incorrect vector") + }; + s = match (s, c) { + (s, b' ') => s, + (None, c @ (b'0'..=b'9' | b'a'..=b'z' | b'A'..=b'Z' | b'.' | b'+' | b'-')) => { + Some(String::from(c as char)) + } + (Some(s), c @ (b'0'..=b'9' | b'a'..=b'z' | b'A'..=b'Z' | b'.' | b'+' | b'-')) => { + let mut x = s; + x.push(c as char); + Some(x) + } + (Some(s), b',') => { + p0.push(s.parse().expect("failed to parse number")); + None + } + (None, b',') => { + pgrx::error!("incorrect vector") + } + (Some(s), b')') => { + p0.push(s.parse().expect("failed to parse number")); + break; + } + (None, b')') => break, + _ => pgrx::error!("incorrect vector"), + }; + } + } + { + loop { + let Some(c) = input.next().copied() else { + pgrx::error!("incorrect vector") + }; + match c { + b' ' => (), + b'[' => break, + _ => pgrx::error!("incorrect vector"), + } + } + } + { + let mut s = Option::::None; + loop { + let Some(c) = input.next().copied() else { + pgrx::error!("incorrect vector") + }; + s = match (s, c) { + (s, b' ') => s, + (None, c @ (b'0'..=b'9' | b'a'..=b'z' | b'A'..=b'Z' | b'.' | b'+' | b'-')) => { + Some(String::from(c as char)) + } + (Some(s), c @ (b'0'..=b'9' | b'a'..=b'z' | b'A'..=b'Z' | b'.' | b'+' | b'-')) => { + let mut x = s; + x.push(c as char); + Some(x) + } + (Some(s), b',') => { + p1.push(s.parse().expect("failed to parse number")); + None + } + (None, b',') => { + pgrx::error!("incorrect vector") + } + (Some(s), b']') => { + p1.push(s.parse().expect("failed to parse number")); + break; + } + (None, b']') => break, + _ => pgrx::error!("incorrect vector"), + }; + } + } + if p0.len() != 4 { + pgrx::error!("incorrect vector"); + } + if p1.is_empty() { + pgrx::error!("vector must have at least 1 dimension"); + } + let sum_of_x2 = p0[0]; + let k = p0[1]; + let b = p0[2]; + let sum_of_code = p0[3]; + let code = p1; + if let Some(x) = Scalar8Borrowed::new_checked(sum_of_x2, k, b, sum_of_code, &code) { + Scalar8Output::new(x) + } else { + pgrx::error!("incorrect vector"); + } +} + +#[pgrx::pg_extern(immutable, strict, parallel_safe)] +fn _vchord_scalar8_out(vector: Scalar8Input<'_>) -> CString { + let vector = vector.as_borrowed(); + let mut buffer = String::new(); + buffer.push('('); + buffer.push_str(format!("{}", vector.sum_of_x2()).as_str()); + buffer.push_str(format!(", {}", vector.k()).as_str()); + buffer.push_str(format!(", {}", vector.b()).as_str()); + buffer.push_str(format!(", {}", vector.sum_of_code()).as_str()); + buffer.push(')'); + buffer.push('['); + if let Some(&x) = vector.code().first() { + buffer.push_str(format!("{}", x).as_str()); + } + for &x in vector.code().iter().skip(1) { + buffer.push_str(format!(", {}", x).as_str()); + } + buffer.push(']'); + CString::new(buffer).unwrap() +} diff --git a/src/datatype/typmod.rs b/src/datatype/typmod.rs index d67bb0b..fe90a6d 100644 --- a/src/datatype/typmod.rs +++ b/src/datatype/typmod.rs @@ -1,5 +1,6 @@ use serde::{Deserialize, Serialize}; -use std::num::NonZeroU32; +use std::ffi::{CStr, CString}; +use std::num::{NonZero, NonZeroU32}; #[derive(Debug, Clone, Copy, Serialize, Deserialize)] pub enum Typmod { @@ -42,3 +43,30 @@ impl Typmod { } } } + +#[pgrx::pg_extern(immutable, strict, parallel_safe)] +fn _vchord_typmod_in_65535(list: pgrx::datum::Array<&CStr>) -> i32 { + if list.is_empty() { + -1 + } else if list.len() == 1 { + let s = list.get(0).unwrap().unwrap().to_str().unwrap(); + let d = s.parse::().ok(); + if let Some(d @ 1..=65535) = d { + let typmod = Typmod::Dims(NonZero::new(d).unwrap()); + typmod.into_i32() + } else { + pgrx::error!("Modifier of the type is invalid.") + } + } else { + pgrx::error!("Modifier of the type is invalid.") + } +} + +#[pgrx::pg_extern(immutable, strict, parallel_safe)] +fn _vchord_typmod_out(typmod: i32) -> CString { + let typmod = Typmod::parse_from_i32(typmod).unwrap(); + match typmod.into_option_string() { + Some(s) => CString::new(format!("({})", s)).unwrap(), + None => CString::new("()").unwrap(), + } +} diff --git a/src/lib.rs b/src/lib.rs index 94741f5..4c55ac9 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -3,10 +3,14 @@ #![allow(clippy::needless_range_loop)] #![allow(clippy::too_many_arguments)] #![allow(clippy::type_complexity)] +#![allow(clippy::int_plus_one)] +#![allow(clippy::unused_unit)] +#![allow(clippy::infallible_destructuring_match)] mod datatype; mod postgres; mod projection; +mod types; mod upgrade; mod utils; mod vchordrq; @@ -21,7 +25,7 @@ unsafe extern "C" fn _PG_init() { if unsafe { pgrx::pg_sys::IsUnderPostmaster } { pgrx::error!("vchord must be loaded via shared_preload_libraries."); } - detect::init(); + base::simd::enable(); unsafe { vchordrq::init(); vchordrqfscan::init(); diff --git a/src/projection.rs b/src/projection.rs index 68926ac..4273180 100644 --- a/src/projection.rs +++ b/src/projection.rs @@ -67,7 +67,7 @@ pub fn prewarm(n: usize) { } pub fn project(vector: &[f32]) -> Vec { - use base::scalar::ScalarLike; + use base::simd::ScalarLike; let n = vector.len(); let matrix = MATRIXS[n].get_or_init(|| orthogonal_matrix(n)); (0..n) diff --git a/src/sql/bootstrap.sql b/src/sql/bootstrap.sql index 6cc59f1..acc6a55 100644 --- a/src/sql/bootstrap.sql +++ b/src/sql/bootstrap.sql @@ -1,3 +1,6 @@ -- List of shell types +CREATE TYPE scalar8; CREATE TYPE sphere_vector; +CREATE TYPE sphere_halfvec; +CREATE TYPE sphere_scalar8; diff --git a/src/sql/finalize.sql b/src/sql/finalize.sql index 32b356c..c00aab1 100644 --- a/src/sql/finalize.sql +++ b/src/sql/finalize.sql @@ -1,12 +1,55 @@ -- List of data types +CREATE TYPE scalar8 ( + INPUT = _vchord_scalar8_in, + OUTPUT = _vchord_scalar8_out, + RECEIVE = _vchord_scalar8_recv, + SEND = _vchord_scalar8_send, + TYPMOD_IN = _vchord_typmod_in_65535, + TYPMOD_OUT = _vchord_typmod_out, + STORAGE = EXTERNAL, + INTERNALLENGTH = VARIABLE, + ALIGNMENT = double +); + CREATE TYPE sphere_vector AS ( center vector, radius REAL ); +CREATE TYPE sphere_halfvec AS ( + center halfvec, + radius REAL +); + +CREATE TYPE sphere_scalar8 AS ( + center scalar8, + radius REAL +); + -- List of operators +CREATE OPERATOR <-> ( + PROCEDURE = _vchord_scalar8_operator_l2, + LEFTARG = scalar8, + RIGHTARG = scalar8, + COMMUTATOR = <-> +); + +CREATE OPERATOR <#> ( + PROCEDURE = _vchord_scalar8_operator_ip, + LEFTARG = scalar8, + RIGHTARG = scalar8, + COMMUTATOR = <#> +); + +CREATE OPERATOR <=> ( + PROCEDURE = _vchord_scalar8_operator_cosine, + LEFTARG = scalar8, + RIGHTARG = scalar8, + COMMUTATOR = <=> +); + CREATE OPERATOR <<->> ( PROCEDURE = _vchord_vector_sphere_l2_in, LEFTARG = vector, @@ -14,6 +57,20 @@ CREATE OPERATOR <<->> ( COMMUTATOR = <<->> ); +CREATE OPERATOR <<->> ( + PROCEDURE = _vchord_halfvec_sphere_l2_in, + LEFTARG = halfvec, + RIGHTARG = sphere_halfvec, + COMMUTATOR = <<->> +); + +CREATE OPERATOR <<->> ( + PROCEDURE = _vchord_scalar8_sphere_l2_in, + LEFTARG = scalar8, + RIGHTARG = sphere_scalar8, + COMMUTATOR = <<->> +); + CREATE OPERATOR <<#>> ( PROCEDURE = _vchord_vector_sphere_ip_in, LEFTARG = vector, @@ -21,6 +78,20 @@ CREATE OPERATOR <<#>> ( COMMUTATOR = <<#>> ); +CREATE OPERATOR <<#>> ( + PROCEDURE = _vchord_halfvec_sphere_ip_in, + LEFTARG = halfvec, + RIGHTARG = sphere_halfvec, + COMMUTATOR = <<#>> +); + +CREATE OPERATOR <<#>> ( + PROCEDURE = _vchord_scalar8_sphere_ip_in, + LEFTARG = scalar8, + RIGHTARG = sphere_scalar8, + COMMUTATOR = <<#>> +); + CREATE OPERATOR <<=>> ( PROCEDURE = _vchord_vector_sphere_cosine_in, LEFTARG = vector, @@ -28,11 +99,37 @@ CREATE OPERATOR <<=>> ( COMMUTATOR = <<=>> ); +CREATE OPERATOR <<=>> ( + PROCEDURE = _vchord_halfvec_sphere_cosine_in, + LEFTARG = halfvec, + RIGHTARG = sphere_halfvec, + COMMUTATOR = <<=>> +); + +CREATE OPERATOR <<=>> ( + PROCEDURE = _vchord_scalar8_sphere_cosine_in, + LEFTARG = scalar8, + RIGHTARG = sphere_scalar8, + COMMUTATOR = <<=>> +); + -- List of functions CREATE FUNCTION sphere(vector, real) RETURNS sphere_vector IMMUTABLE PARALLEL SAFE LANGUAGE sql AS 'SELECT ROW($1, $2)'; +CREATE FUNCTION sphere(halfvec, real) RETURNS sphere_halfvec +IMMUTABLE PARALLEL SAFE LANGUAGE sql AS 'SELECT ROW($1, $2)'; + +CREATE FUNCTION sphere(scalar8, real) RETURNS sphere_scalar8 +IMMUTABLE PARALLEL SAFE LANGUAGE sql AS 'SELECT ROW($1, $2)'; + +CREATE FUNCTION quantize_to_scalar8(vector) RETURNS scalar8 +IMMUTABLE STRICT PARALLEL SAFE LANGUAGE c AS 'MODULE_PATHNAME', '_vchord_vector_quantize_to_scalar8_wrapper'; + +CREATE FUNCTION quantize_to_scalar8(halfvec) RETURNS scalar8 +IMMUTABLE STRICT PARALLEL SAFE LANGUAGE c AS 'MODULE_PATHNAME', '_vchord_halfvec_quantize_to_scalar8_wrapper'; + CREATE FUNCTION vchordrq_amhandler(internal) RETURNS index_am_handler IMMUTABLE STRICT PARALLEL SAFE LANGUAGE c AS 'MODULE_PATHNAME', '_vchordrq_amhandler_wrapper'; @@ -56,6 +153,9 @@ CREATE ACCESS METHOD Vchordrqfscan TYPE INDEX HANDLER Vchordrqfscan_amhandler; CREATE OPERATOR FAMILY vector_l2_ops USING vchordrq; CREATE OPERATOR FAMILY vector_ip_ops USING vchordrq; CREATE OPERATOR FAMILY vector_cosine_ops USING vchordrq; +CREATE OPERATOR FAMILY halfvec_l2_ops USING vchordrq; +CREATE OPERATOR FAMILY halfvec_ip_ops USING vchordrq; +CREATE OPERATOR FAMILY halfvec_cosine_ops USING vchordrq; CREATE OPERATOR FAMILY vector_l2_ops USING Vchordrqfscan; CREATE OPERATOR FAMILY vector_ip_ops USING Vchordrqfscan; @@ -81,6 +181,24 @@ CREATE OPERATOR CLASS vector_cosine_ops OPERATOR 2 <<=>> (vector, sphere_vector) FOR SEARCH, FUNCTION 1 _vchordrq_support_vector_cosine_ops(); +CREATE OPERATOR CLASS halfvec_l2_ops + FOR TYPE halfvec USING vchordrq FAMILY halfvec_l2_ops AS + OPERATOR 1 <-> (halfvec, halfvec) FOR ORDER BY float_ops, + OPERATOR 2 <<->> (halfvec, sphere_halfvec) FOR SEARCH, + FUNCTION 1 _vchordrq_support_halfvec_l2_ops(); + +CREATE OPERATOR CLASS halfvec_ip_ops + FOR TYPE halfvec USING vchordrq FAMILY halfvec_ip_ops AS + OPERATOR 1 <#> (halfvec, halfvec) FOR ORDER BY float_ops, + OPERATOR 2 <<#>> (halfvec, sphere_halfvec) FOR SEARCH, + FUNCTION 1 _vchordrq_support_halfvec_ip_ops(); + +CREATE OPERATOR CLASS halfvec_cosine_ops + FOR TYPE halfvec USING vchordrq FAMILY halfvec_cosine_ops AS + OPERATOR 1 <=> (halfvec, halfvec) FOR ORDER BY float_ops, + OPERATOR 2 <<=>> (halfvec, sphere_halfvec) FOR SEARCH, + FUNCTION 1 _vchordrq_support_halfvec_cosine_ops(); + CREATE OPERATOR CLASS vector_l2_ops FOR TYPE vector USING Vchordrqfscan FAMILY vector_l2_ops AS OPERATOR 1 <-> (vector, vector) FOR ORDER BY float_ops, diff --git a/src/types/mod.rs b/src/types/mod.rs new file mode 100644 index 0000000..af08ee7 --- /dev/null +++ b/src/types/mod.rs @@ -0,0 +1 @@ +pub mod scalar8; diff --git a/src/types/scalar8.rs b/src/types/scalar8.rs new file mode 100644 index 0000000..55c0074 --- /dev/null +++ b/src/types/scalar8.rs @@ -0,0 +1,286 @@ +use base::distance::Distance; +use base::vector::{VectorBorrowed, VectorOwned}; +use serde::{Deserialize, Serialize}; +use std::ops::RangeBounds; + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct Scalar8Owned { + sum_of_x2: f32, + k: f32, + b: f32, + sum_of_code: f32, + code: Vec, +} + +impl Scalar8Owned { + #[allow(dead_code)] + #[inline(always)] + pub fn new(sum_of_x2: f32, k: f32, b: f32, sum_of_code: f32, code: Vec) -> Self { + Self::new_checked(sum_of_x2, k, b, sum_of_code, code).expect("invalid data") + } + + #[inline(always)] + pub fn new_checked( + sum_of_x2: f32, + k: f32, + b: f32, + sum_of_code: f32, + code: Vec, + ) -> Option { + if !(1..=65535).contains(&code.len()) { + return None; + } + Some(unsafe { Self::new_unchecked(sum_of_x2, k, b, sum_of_code, code) }) + } + + /// # Safety + /// + /// * `code.len()` must not be zero. + #[inline(always)] + pub unsafe fn new_unchecked( + sum_of_x2: f32, + k: f32, + b: f32, + sum_of_code: f32, + code: Vec, + ) -> Self { + Self { + sum_of_x2, + k, + b, + sum_of_code, + code, + } + } +} + +impl VectorOwned for Scalar8Owned { + type Borrowed<'a> = Scalar8Borrowed<'a>; + + #[inline(always)] + fn as_borrowed(&self) -> Scalar8Borrowed<'_> { + Scalar8Borrowed { + sum_of_x2: self.sum_of_x2, + k: self.k, + b: self.b, + sum_of_code: self.sum_of_code, + code: self.code.as_slice(), + } + } + + #[inline(always)] + fn zero(dims: u32) -> Self { + Self { + sum_of_x2: 0.0, + k: 0.0, + b: 0.0, + sum_of_code: 0.0, + code: vec![0; dims as usize], + } + } +} + +#[derive(Debug, Clone, Copy)] +pub struct Scalar8Borrowed<'a> { + sum_of_x2: f32, + k: f32, + b: f32, + sum_of_code: f32, + code: &'a [u8], +} + +impl<'a> Scalar8Borrowed<'a> { + #[inline(always)] + pub fn new(sum_of_x2: f32, k: f32, b: f32, sum_of_code: f32, code: &'a [u8]) -> Self { + Self::new_checked(sum_of_x2, k, b, sum_of_code, code).expect("invalid data") + } + + #[inline(always)] + pub fn new_checked( + sum_of_x2: f32, + k: f32, + b: f32, + sum_of_code: f32, + code: &'a [u8], + ) -> Option { + if !(1..=65535).contains(&code.len()) { + return None; + } + Some(unsafe { Self::new_unchecked(sum_of_x2, k, b, sum_of_code, code) }) + } + + /// # Safety + /// + /// * `code.len()` must not be zero. + #[inline(always)] + pub unsafe fn new_unchecked( + sum_of_x2: f32, + k: f32, + b: f32, + sum_of_code: f32, + code: &'a [u8], + ) -> Self { + Self { + sum_of_x2, + k, + b, + sum_of_code, + code, + } + } + + #[inline(always)] + pub fn sum_of_x2(&self) -> f32 { + self.sum_of_x2 + } + + #[inline(always)] + pub fn k(&self) -> f32 { + self.k + } + + #[inline(always)] + pub fn b(&self) -> f32 { + self.b + } + + #[inline(always)] + pub fn sum_of_code(&self) -> f32 { + self.sum_of_code + } + + #[inline(always)] + pub fn code(&self) -> &'a [u8] { + self.code + } +} + +impl VectorBorrowed for Scalar8Borrowed<'_> { + type Owned = Scalar8Owned; + + #[inline(always)] + fn dims(&self) -> u32 { + self.code.len() as u32 + } + + #[inline(always)] + fn own(&self) -> Scalar8Owned { + Scalar8Owned { + sum_of_x2: self.sum_of_x2, + k: self.k, + b: self.b, + sum_of_code: self.sum_of_code, + code: self.code.to_owned(), + } + } + + #[inline(always)] + fn norm(&self) -> f32 { + self.sum_of_x2.sqrt() + } + + #[inline(always)] + fn operator_dot(self, rhs: Self) -> Distance { + assert_eq!(self.code.len(), rhs.code.len()); + let xy = self.k * rhs.k * base::simd::u8::reduce_sum_of_xy(self.code, rhs.code) as f32 + + self.b * rhs.b * self.code.len() as f32 + + self.k * rhs.b * self.sum_of_code + + self.b * rhs.k * rhs.sum_of_code; + Distance::from(-xy) + } + + #[inline(always)] + fn operator_l2(self, rhs: Self) -> Distance { + assert_eq!(self.code.len(), rhs.code.len()); + let xy = self.k * rhs.k * base::simd::u8::reduce_sum_of_xy(self.code, rhs.code) as f32 + + self.b * rhs.b * self.code.len() as f32 + + self.k * rhs.b * self.sum_of_code + + self.b * rhs.k * rhs.sum_of_code; + let x2 = self.sum_of_x2; + let y2 = rhs.sum_of_x2; + Distance::from(x2 + y2 - 2.0 * xy) + } + + #[inline(always)] + fn operator_cos(self, rhs: Self) -> Distance { + assert_eq!(self.code.len(), rhs.code.len()); + let xy = self.k * rhs.k * base::simd::u8::reduce_sum_of_xy(self.code, rhs.code) as f32 + + self.b * rhs.b * self.code.len() as f32 + + self.k * rhs.b * self.sum_of_code + + self.b * rhs.k * rhs.sum_of_code; + let x2 = self.sum_of_x2; + let y2 = rhs.sum_of_x2; + Distance::from(1.0 - xy / (x2 * y2).sqrt()) + } + + #[inline(always)] + fn operator_hamming(self, _: Self) -> Distance { + unimplemented!() + } + + #[inline(always)] + fn operator_jaccard(self, _: Self) -> Distance { + unimplemented!() + } + + #[inline(always)] + fn function_normalize(&self) -> Scalar8Owned { + let l = self.sum_of_x2.sqrt(); + Scalar8Owned { + sum_of_x2: 1.0, + k: self.k / l, + b: self.b / l, + sum_of_code: self.sum_of_code, + code: self.code.to_owned(), + } + } + + fn operator_add(&self, _: Self) -> Self::Owned { + unimplemented!() + } + + fn operator_sub(&self, _: Self) -> Self::Owned { + unimplemented!() + } + + fn operator_mul(&self, _: Self) -> Self::Owned { + unimplemented!() + } + + fn operator_and(&self, _: Self) -> Self::Owned { + unimplemented!() + } + + fn operator_or(&self, _: Self) -> Self::Owned { + unimplemented!() + } + + fn operator_xor(&self, _: Self) -> Self::Owned { + unimplemented!() + } + + #[inline(always)] + fn subvector(&self, bounds: impl RangeBounds) -> Option { + let start_bound = bounds.start_bound().map(|x| *x as usize); + let end_bound = bounds.end_bound().map(|x| *x as usize); + let code = self.code.get((start_bound, end_bound))?; + if code.is_empty() { + return None; + } + Self::Owned::new_checked( + { + // recover it as much as possible + let mut result = 0.0; + for &x in code { + let y = self.k * (x as f32) + self.b; + result += y * y; + } + result + }, + self.k, + self.b, + base::simd::u8::reduce_sum_of_x_as_u32(code) as f32, + code.to_owned(), + ) + } +} diff --git a/src/utils/infinite_byte_chunks.rs b/src/utils/infinite_byte_chunks.rs new file mode 100644 index 0000000..c61b87e --- /dev/null +++ b/src/utils/infinite_byte_chunks.rs @@ -0,0 +1,20 @@ +#[derive(Debug, Clone)] +pub struct InfiniteByteChunks { + iter: I, +} + +impl InfiniteByteChunks { + pub fn new(iter: I) -> Self { + Self { iter } + } +} + +impl, const N: usize> Iterator for InfiniteByteChunks { + type Item = [u8; N]; + + fn next(&mut self) -> Option { + Some(std::array::from_fn::(|_| { + self.iter.next().unwrap_or(0) + })) + } +} diff --git a/src/utils/k_means.rs b/src/utils/k_means.rs index ac58a08..97a810f 100644 --- a/src/utils/k_means.rs +++ b/src/utils/k_means.rs @@ -1,7 +1,7 @@ #![allow(clippy::ptr_arg)] use super::parallelism::{ParallelIterator, Parallelism}; -use base::scalar::*; +use base::simd::*; use half::f16; use rand::rngs::StdRng; use rand::{Rng, SeedableRng}; @@ -121,7 +121,7 @@ fn rabitq_index( let mut a3 = Vec::new(); let mut a4 = Vec::new(); for vectors in centroids.chunks(32) { - use quantization::fast_scan::b4::pack; + use base::simd::fast_scan::b4::pack; let code_alphas = std::array::from_fn::<_, 32, _>(|i| { if let Some(vector) = vectors.get(i) { code_alpha(vector) @@ -197,7 +197,7 @@ fn rabitq_index( ), epsilon: f32, ) -> [Distance; 32] { - use quantization::fast_scan::b4::fast_scan_b4; + use base::simd::fast_scan::b4::fast_scan_b4; let &(dis_v_2, b, k, qvector_sum, ref s) = lut; let r = fast_scan_b4(dims.div_ceil(4), t, s); std::array::from_fn(|i| { @@ -210,17 +210,17 @@ fn rabitq_index( }) } use base::distance::Distance; - use quantization::quantize; + use base::simd::quantize; let lut = { let vector = &samples[i]; let dis_v_2 = f32::reduce_sum_of_x2(vector); let (k, b, qvector) = - quantize::quantize::<15>(f32::vector_to_f32_borrowed(vector).as_ref()); + quantize::quantize(f32::vector_to_f32_borrowed(vector).as_ref(), 15.0); let qvector_sum = if vector.len() <= 4369 { - quantize::reduce_sum_of_x_as_u16(&qvector) as f32 + u8::reduce_sum_of_x_as_u16(&qvector) as f32 } else { - quantize::reduce_sum_of_x_as_u32(&qvector) as f32 + u8::reduce_sum_of_x_as_u32(&qvector) as f32 }; (dis_v_2, b, k, qvector_sum, gen(qvector)) }; diff --git a/src/utils/mod.rs b/src/utils/mod.rs index 1b07dc6..2d9a3b7 100644 --- a/src/utils/mod.rs +++ b/src/utils/mod.rs @@ -1,2 +1,3 @@ +pub mod infinite_byte_chunks; pub mod k_means; pub mod parallelism; diff --git a/src/vchordrq/algorithm/build.rs b/src/vchordrq/algorithm/build.rs index c65bd85..d292628 100644 --- a/src/vchordrq/algorithm/build.rs +++ b/src/vchordrq/algorithm/build.rs @@ -2,25 +2,25 @@ use crate::postgres::BufferWriteGuard; use crate::postgres::Relation; use crate::vchordrq::algorithm::rabitq; use crate::vchordrq::algorithm::tuples::*; -use crate::vchordrq::algorithm::vectors; use crate::vchordrq::index::am_options::Opfamily; use crate::vchordrq::types::VchordrqBuildOptions; use crate::vchordrq::types::VchordrqExternalBuildOptions; use crate::vchordrq::types::VchordrqIndexingOptions; use crate::vchordrq::types::VchordrqInternalBuildOptions; +use crate::vchordrq::types::VectorOptions; use base::distance::DistanceKind; -use base::index::VectorOptions; -use base::scalar::ScalarLike; use base::search::Pointer; +use base::simd::ScalarLike; +use base::vector::VectorBorrowed; use rand::Rng; use rkyv::ser::serializers::AllocSerializer; use std::marker::PhantomData; use std::sync::Arc; -pub trait HeapRelation { +pub trait HeapRelation { fn traverse(&self, progress: bool, callback: F) where - F: FnMut((Pointer, Vec)); + F: FnMut((Pointer, V)); fn opfamily(&self) -> Opfamily; } @@ -28,7 +28,7 @@ pub trait Reporter { fn tuples_total(&mut self, tuples_total: u64); } -pub fn build( +pub fn build, R: Reporter>( vector_options: VectorOptions, vchordrq_options: VchordrqIndexingOptions, heap_relation: T, @@ -56,13 +56,14 @@ pub fn build( let mut samples = Vec::new(); let mut number_of_samples = 0_u32; heap_relation.traverse(false, |(_, vector)| { - assert_eq!(dims as usize, vector.len(), "invalid vector dimensions"); + let vector = vector.as_borrowed(); + assert_eq!(dims, vector.dims(), "invalid vector dimensions"); if number_of_samples < max_number_of_samples { - samples.push(vector); + samples.push(V::build_to_vecf32(vector)); number_of_samples += 1; } else { let index = rand.gen_range(0..max_number_of_samples) as usize; - samples[index] = vector; + samples[index] = V::build_to_vecf32(vector); } tuples_total += 1; }); @@ -74,21 +75,22 @@ pub fn build( }; let mut meta = Tape::create(&relation, false); assert_eq!(meta.first(), 0); - let mut vectors = Tape::create(&relation, true); + let mut vectors = Tape::>::create(&relation, true); let mut pointer_of_means = Vec::>::new(); for i in 0..structures.len() { let mut level = Vec::new(); for j in 0..structures[i].len() { - let slices = vectors::vector_split(&structures[i].means[j]); - let mut chain = None; + let vector = V::build_from_vecf32(&structures[i].means[j]); + let (metadata, slices) = V::vector_split(vector.as_borrowed()); + let mut chain = Err(metadata); for i in (0..slices.len()).rev() { - chain = Some(vectors.push(&VectorTuple { + chain = Ok(vectors.push(&VectorTuple { payload: None, slice: slices[i].to_vec(), chain, })); } - level.push(chain.unwrap()); + level.push(chain.ok().unwrap()); } pointer_of_means.push(level); } @@ -224,7 +226,7 @@ impl Structure { if vector_options.dims != vector.as_borrowed().dims() { pgrx::error!("extern build: incorrect dimension, id = {id}"); } - vectors.insert(id, crate::projection::project(vector.slice())); + vectors.insert(id, crate::projection::project(vector.as_borrowed().slice())); } }); let mut children = parents diff --git a/src/vchordrq/algorithm/insert.rs b/src/vchordrq/algorithm/insert.rs index c37f7c3..ee47e9a 100644 --- a/src/vchordrq/algorithm/insert.rs +++ b/src/vchordrq/algorithm/insert.rs @@ -1,23 +1,23 @@ use crate::postgres::Relation; -use crate::vchordrq::algorithm::rabitq; use crate::vchordrq::algorithm::rabitq::fscan_process_lowerbound; use crate::vchordrq::algorithm::tuples::*; use crate::vchordrq::algorithm::vectors; use base::always_equal::AlwaysEqual; use base::distance::Distance; use base::distance::DistanceKind; -use base::scalar::ScalarLike; use base::search::Pointer; +use base::vector::VectorBorrowed; use std::cmp::Reverse; use std::collections::BinaryHeap; -pub fn insert( +pub fn insert( relation: Relation, payload: Pointer, - vector: Vec, + vector: V, distance_kind: DistanceKind, in_building: bool, ) { + let vector = vector.as_borrowed(); let meta_guard = relation.read(0); let meta_tuple = meta_guard .get() @@ -26,25 +26,26 @@ pub fn insert( .expect("data corruption") .expect("data corruption"); let dims = meta_tuple.dims; - assert_eq!(dims as usize, vector.len(), "invalid vector dimensions"); - let vector = crate::projection::project(&vector); + assert_eq!(dims, vector.dims(), "invalid vector dimensions"); + let vector = V::random_projection(vector); + let vector = vector.as_borrowed(); let is_residual = meta_tuple.is_residual; let default_lut = if !is_residual { - Some(rabitq::fscan_preprocess(&vector)) + Some(V::rabitq_fscan_preprocess(vector)) } else { None }; let h0_vector = { - let slices = vectors::vector_split(&vector); - let mut chain = None; + let (metadata, slices) = V::vector_split(vector); + let mut chain = Err(metadata); for i in (0..slices.len()).rev() { - let tuple = rkyv::to_bytes::<_, 8192>(&VectorTuple { + let tuple = rkyv::to_bytes::<_, 8192>(&VectorTuple:: { slice: slices[i].to_vec(), payload: Some(payload.as_u64()), chain, }) .unwrap(); - chain = Some(append( + chain = Ok(append( relation.clone(), meta_tuple.vectors_first, &tuple, @@ -53,13 +54,13 @@ pub fn insert( true, )); } - chain.unwrap() + chain.ok().unwrap() }; let h0_payload = payload.as_u64(); let mut list = { - let Some((_, original)) = vectors::vector_dist( + let Some((_, original)) = vectors::vector_dist::( relation.clone(), - &vector, + vector, meta_tuple.mean, None, None, @@ -69,11 +70,14 @@ pub fn insert( }; (meta_tuple.first, original) }; - let make_list = |list: (u32, Option>)| { + let make_list = |list: (u32, Option)| { let mut results = Vec::new(); { let lut = if is_residual { - &rabitq::fscan_preprocess(&f32::vector_sub(&vector, list.1.as_ref().unwrap())) + &V::rabitq_fscan_preprocess( + V::residual(vector, list.1.as_ref().map(|x| x.as_borrowed()).unwrap()) + .as_borrowed(), + ) } else { default_lut.as_ref().unwrap() }; @@ -114,9 +118,9 @@ pub fn insert( { while !heap.is_empty() && heap.peek().map(|x| x.0) > cache.peek().map(|x| x.0) { let (_, AlwaysEqual(mean), AlwaysEqual(first)) = heap.pop().unwrap(); - let Some((Some(dis_u), original)) = vectors::vector_dist( + let Some((Some(dis_u), original)) = vectors::vector_dist::( relation.clone(), - &vector, + vector, mean, None, Some(distance_kind), @@ -134,9 +138,12 @@ pub fn insert( list = make_list(list); } let code = if is_residual { - rabitq::code(dims, &f32::vector_sub(&vector, list.1.as_ref().unwrap())) + V::rabitq_code( + dims, + V::residual(vector, list.1.as_ref().map(|x| x.as_borrowed()).unwrap()).as_borrowed(), + ) } else { - rabitq::code(dims, &vector) + V::rabitq_code(dims, vector) }; let tuple = rkyv::to_bytes::<_, 8192>(&Height0Tuple { mean: h0_vector, diff --git a/src/vchordrq/algorithm/prewarm.rs b/src/vchordrq/algorithm/prewarm.rs index 9244138..26bb986 100644 --- a/src/vchordrq/algorithm/prewarm.rs +++ b/src/vchordrq/algorithm/prewarm.rs @@ -3,7 +3,7 @@ use crate::vchordrq::algorithm::tuples::*; use crate::vchordrq::algorithm::vectors; use std::fmt::Write; -pub fn prewarm(relation: Relation, height: i32) -> String { +pub fn prewarm(relation: Relation, height: i32) -> String { let mut message = String::new(); let meta_guard = relation.read(0); let meta_tuple = meta_guard @@ -21,7 +21,7 @@ pub fn prewarm(relation: Relation, height: i32) -> String { let mut results = Vec::new(); let counter = 1_usize; { - vectors::vector_warm(relation.clone(), meta_tuple.mean); + vectors::vector_warm::(relation.clone(), meta_tuple.mean); results.push(meta_tuple.first); } writeln!(message, "number of tuples: {}", results.len()).unwrap(); @@ -44,7 +44,7 @@ pub fn prewarm(relation: Relation, height: i32) -> String { .map(rkyv::check_archived_root::) .expect("data corruption") .expect("data corruption"); - vectors::vector_warm(relation.clone(), h1_tuple.mean); + vectors::vector_warm::(relation.clone(), h1_tuple.mean); results.push(h1_tuple.first); } current = h1_guard.get().get_opaque().next; diff --git a/src/vchordrq/algorithm/rabitq.rs b/src/vchordrq/algorithm/rabitq.rs index b3746b8..b7b3858 100644 --- a/src/vchordrq/algorithm/rabitq.rs +++ b/src/vchordrq/algorithm/rabitq.rs @@ -1,5 +1,5 @@ use base::distance::{Distance, DistanceKind}; -use base::scalar::ScalarLike; +use base::simd::ScalarLike; #[derive(Debug, Clone)] pub struct Code { @@ -12,7 +12,7 @@ pub struct Code { impl Code { pub fn t(&self) -> Vec { - use quantization::utils::InfiniteByteChunks; + use crate::utils::infinite_byte_chunks::InfiniteByteChunks; let mut result = Vec::new(); for x in InfiniteByteChunks::<_, 64>::new(self.signs.iter().copied()) .take(self.signs.len().div_ceil(64)) @@ -62,13 +62,13 @@ pub fn code(dims: u32, vector: &[f32]) -> Code { pub type Lut = (f32, f32, f32, f32, (Vec, Vec, Vec, Vec)); pub fn fscan_preprocess(vector: &[f32]) -> Lut { - use quantization::quantize; + use base::simd::quantize; let dis_v_2 = f32::reduce_sum_of_x2(vector); - let (k, b, qvector) = quantize::quantize::<15>(vector); + let (k, b, qvector) = quantize::quantize(vector, 15.0); let qvector_sum = if vector.len() <= 4369 { - quantize::reduce_sum_of_x_as_u16(&qvector) as f32 + base::simd::u8::reduce_sum_of_x_as_u16(&qvector) as f32 } else { - quantize::reduce_sum_of_x_as_u32(&qvector) as f32 + base::simd::u8::reduce_sum_of_x_as_u32(&qvector) as f32 }; (dis_v_2, b, k, qvector_sum, binarize(&qvector)) } @@ -119,34 +119,10 @@ fn binarize(vector: &[u8]) -> (Vec, Vec, Vec, Vec) { (t0, t1, t2, t3) } -#[detect::multiversion(v2, fallback)] fn asymmetric_binary_dot_product(x: &[u64], y: &(Vec, Vec, Vec, Vec)) -> u32 { - assert_eq!(x.len(), y.0.len()); - assert_eq!(x.len(), y.1.len()); - assert_eq!(x.len(), y.2.len()); - assert_eq!(x.len(), y.3.len()); - let n = x.len(); - let (mut t0, mut t1, mut t2, mut t3) = (0, 0, 0, 0); - for i in 0..n { - t0 += (x[i] & y.0[i]).count_ones(); - } - for i in 0..n { - t1 += (x[i] & y.1[i]).count_ones(); - } - for i in 0..n { - t2 += (x[i] & y.2[i]).count_ones(); - } - for i in 0..n { - t3 += (x[i] & y.3[i]).count_ones(); - } + let t0 = base::simd::bit::sum_of_and(x, &y.0); + let t1 = base::simd::bit::sum_of_and(x, &y.1); + let t2 = base::simd::bit::sum_of_and(x, &y.2); + let t3 = base::simd::bit::sum_of_and(x, &y.3); (t0 << 0) + (t1 << 1) + (t2 << 2) + (t3 << 3) } - -pub fn distance(d: DistanceKind, lhs: &[f32], rhs: &[f32]) -> Distance { - match d { - DistanceKind::L2 => Distance::from_f32(f32::reduce_sum_of_d2(lhs, rhs)), - DistanceKind::Dot => Distance::from_f32(-f32::reduce_sum_of_xy(lhs, rhs)), - DistanceKind::Hamming => unimplemented!(), - DistanceKind::Jaccard => unimplemented!(), - } -} diff --git a/src/vchordrq/algorithm/scan.rs b/src/vchordrq/algorithm/scan.rs index 14ecba1..df4d93d 100644 --- a/src/vchordrq/algorithm/scan.rs +++ b/src/vchordrq/algorithm/scan.rs @@ -1,23 +1,23 @@ use crate::postgres::Relation; -use crate::vchordrq::algorithm::rabitq; use crate::vchordrq::algorithm::rabitq::fscan_process_lowerbound; use crate::vchordrq::algorithm::tuples::*; use crate::vchordrq::algorithm::vectors; use base::always_equal::AlwaysEqual; use base::distance::Distance; use base::distance::DistanceKind; -use base::scalar::ScalarLike; use base::search::Pointer; +use base::vector::VectorBorrowed; use std::cmp::Reverse; use std::collections::BinaryHeap; -pub fn scan( +pub fn scan( relation: Relation, - vector: Vec, + vector: V, distance_kind: DistanceKind, probes: Vec, epsilon: f32, ) -> impl Iterator { + let vector = vector.as_borrowed(); let meta_guard = relation.read(0); let meta_tuple = meta_guard .get() @@ -27,19 +27,19 @@ pub fn scan( .expect("data corruption"); let dims = meta_tuple.dims; let height_of_root = meta_tuple.height_of_root; - assert_eq!(dims as usize, vector.len(), "invalid vector dimensions"); + assert_eq!(dims, vector.dims(), "invalid vector dimensions"); assert_eq!(height_of_root as usize, 1 + probes.len(), "invalid probes"); - let vector = crate::projection::project(&vector); + let vector = V::random_projection(vector); let is_residual = meta_tuple.is_residual; let default_lut = if !is_residual { - Some(rabitq::fscan_preprocess(&vector)) + Some(V::rabitq_fscan_preprocess(vector.as_borrowed())) } else { None }; let mut lists: Vec<_> = vec![{ - let Some((_, original)) = vectors::vector_dist( + let Some((_, original)) = vectors::vector_dist::( relation.clone(), - &vector, + vector.as_borrowed(), meta_tuple.mean, None, None, @@ -49,11 +49,17 @@ pub fn scan( }; (meta_tuple.first, original) }]; - let make_lists = |lists: Vec<(u32, Option>)>, probes| { + let make_lists = |lists: Vec<(u32, Option)>, probes| { let mut results = Vec::new(); for list in lists { let lut = if is_residual { - &rabitq::fscan_preprocess(&f32::vector_sub(&vector, list.1.as_ref().unwrap())) + &V::rabitq_fscan_preprocess( + V::residual( + vector.as_borrowed(), + list.1.as_ref().map(|x| x.as_borrowed()).unwrap(), + ) + .as_borrowed(), + ) } else { default_lut.as_ref().unwrap() }; @@ -94,9 +100,9 @@ pub fn scan( std::iter::from_fn(|| { while !heap.is_empty() && heap.peek().map(|x| x.0) > cache.peek().map(|x| x.0) { let (_, AlwaysEqual(mean), AlwaysEqual(first)) = heap.pop().unwrap(); - let Some((Some(dis_u), original)) = vectors::vector_dist( + let Some((Some(dis_u), original)) = vectors::vector_dist::( relation.clone(), - &vector, + vector.as_borrowed(), mean, None, Some(distance_kind), @@ -119,7 +125,13 @@ pub fn scan( let mut results = Vec::new(); for list in lists { let lut = if is_residual { - &rabitq::fscan_preprocess(&f32::vector_sub(&vector, list.1.as_ref().unwrap())) + &V::rabitq_fscan_preprocess( + V::residual( + vector.as_borrowed(), + list.1.as_ref().map(|x| x.as_borrowed()).unwrap(), + ) + .as_borrowed(), + ) } else { default_lut.as_ref().unwrap() }; @@ -160,9 +172,9 @@ pub fn scan( std::iter::from_fn(move || { while !heap.is_empty() && heap.peek().map(|x| x.0) > cache.peek().map(|x| x.0) { let (_, AlwaysEqual(mean), AlwaysEqual(pay_u)) = heap.pop().unwrap(); - let Some((Some(dis_u), _)) = vectors::vector_dist( + let Some((Some(dis_u), _)) = vectors::vector_dist::( relation.clone(), - &vector, + vector.as_borrowed(), mean, Some(pay_u), Some(distance_kind), diff --git a/src/vchordrq/algorithm/tuples.rs b/src/vchordrq/algorithm/tuples.rs index cf6236f..40a795c 100644 --- a/src/vchordrq/algorithm/tuples.rs +++ b/src/vchordrq/algorithm/tuples.rs @@ -1,4 +1,233 @@ -use rkyv::{Archive, Deserialize, Serialize}; +use super::rabitq::{self, Code, Lut}; +use crate::vchordrq::types::OwnedVector; +use base::distance::DistanceKind; +use base::simd::ScalarLike; +use base::vector::{VectOwned, VectorOwned}; +use half::f16; +use rkyv::{Archive, ArchiveUnsized, CheckBytes, Deserialize, Serialize}; + +pub trait Vector: VectorOwned { + type Metadata: Copy + + Serialize< + rkyv::ser::serializers::CompositeSerializer< + rkyv::ser::serializers::AlignedSerializer, + rkyv::ser::serializers::FallbackScratch< + rkyv::ser::serializers::HeapScratch<8192>, + rkyv::ser::serializers::AllocScratch, + >, + rkyv::ser::serializers::SharedSerializeMap, + >, + > + for<'a> CheckBytes>; + type Element: Copy + + Serialize< + rkyv::ser::serializers::CompositeSerializer< + rkyv::ser::serializers::AlignedSerializer, + rkyv::ser::serializers::FallbackScratch< + rkyv::ser::serializers::HeapScratch<8192>, + rkyv::ser::serializers::AllocScratch, + >, + rkyv::ser::serializers::SharedSerializeMap, + >, + > + for<'a> CheckBytes> + + Archive; + + fn metadata_from_archived( + archived: &::Archived, + ) -> Self::Metadata; + + fn vector_split(vector: Self::Borrowed<'_>) -> (Self::Metadata, Vec<&[Self::Element]>); + fn vector_merge(metadata: Self::Metadata, slice: &[Self::Element]) -> Self; + fn from_owned(vector: OwnedVector) -> Self; + + type DistanceAccumulator; + fn distance_begin(distance_kind: DistanceKind) -> Self::DistanceAccumulator; + fn distance_next( + accumulator: &mut Self::DistanceAccumulator, + left: &[Self::Element], + right: &[Self::Element], + ); + fn distance_end( + accumulator: Self::DistanceAccumulator, + left: Self::Metadata, + right: Self::Metadata, + ) -> f32; + + fn random_projection(vector: Self::Borrowed<'_>) -> Self; + + fn residual(vector: Self::Borrowed<'_>, center: Self::Borrowed<'_>) -> Self; + + fn rabitq_fscan_preprocess(vector: Self::Borrowed<'_>) -> Lut; + + fn rabitq_code(dims: u32, vector: Self::Borrowed<'_>) -> Code; + + fn build_to_vecf32(vector: Self::Borrowed<'_>) -> Vec; + + fn build_from_vecf32(x: &[f32]) -> Self; +} + +impl Vector for VectOwned { + type Metadata = (); + + type Element = f32; + + fn metadata_from_archived(_: &::Archived) -> Self::Metadata { + () + } + + fn vector_split(vector: Self::Borrowed<'_>) -> ((), Vec<&[f32]>) { + let vector = vector.slice(); + ( + (), + match vector.len() { + 0..=960 => vec![vector], + 961..=1280 => vec![&vector[..640], &vector[640..]], + 1281.. => vector.chunks(1920).collect(), + }, + ) + } + + fn vector_merge((): Self::Metadata, slice: &[Self::Element]) -> Self { + VectOwned::new(slice.to_vec()) + } + + fn from_owned(vector: OwnedVector) -> Self { + match vector { + OwnedVector::Vecf32(x) => x, + _ => unreachable!(), + } + } + + type DistanceAccumulator = (DistanceKind, f32); + fn distance_begin(distance_kind: DistanceKind) -> Self::DistanceAccumulator { + (distance_kind, 0.0) + } + fn distance_next( + accumulator: &mut Self::DistanceAccumulator, + left: &[Self::Element], + right: &[Self::Element], + ) { + match accumulator.0 { + DistanceKind::L2 => accumulator.1 += f32::reduce_sum_of_d2(left, right), + DistanceKind::Dot => accumulator.1 += -f32::reduce_sum_of_xy(left, right), + DistanceKind::Hamming => unreachable!(), + DistanceKind::Jaccard => unreachable!(), + } + } + fn distance_end( + accumulator: Self::DistanceAccumulator, + (): Self::Metadata, + (): Self::Metadata, + ) -> f32 { + accumulator.1 + } + + fn random_projection(vector: Self::Borrowed<'_>) -> Self { + Self::new(crate::projection::project(vector.slice())) + } + + fn residual(vector: Self::Borrowed<'_>, center: Self::Borrowed<'_>) -> Self { + Self::new(ScalarLike::vector_sub(vector.slice(), center.slice())) + } + + fn rabitq_fscan_preprocess(vector: Self::Borrowed<'_>) -> Lut { + rabitq::fscan_preprocess(vector.slice()) + } + + fn rabitq_code(dims: u32, vector: Self::Borrowed<'_>) -> Code { + rabitq::code(dims, vector.slice()) + } + + fn build_to_vecf32(vector: Self::Borrowed<'_>) -> Vec { + vector.slice().to_vec() + } + + fn build_from_vecf32(x: &[f32]) -> Self { + Self::new(x.to_vec()) + } +} + +impl Vector for VectOwned { + type Metadata = (); + + type Element = f16; + + fn metadata_from_archived(_: &::Archived) -> Self::Metadata { + () + } + + fn vector_split(vector: Self::Borrowed<'_>) -> ((), Vec<&[f16]>) { + let vector = vector.slice(); + ( + (), + match vector.len() { + 0..=1920 => vec![vector], + 1921..=2560 => vec![&vector[..1280], &vector[1280..]], + 2561.. => vector.chunks(3840).collect(), + }, + ) + } + + fn vector_merge((): Self::Metadata, slice: &[Self::Element]) -> Self { + VectOwned::new(slice.to_vec()) + } + + fn from_owned(vector: OwnedVector) -> Self { + match vector { + OwnedVector::Vecf16(x) => x, + _ => unreachable!(), + } + } + + type DistanceAccumulator = (DistanceKind, f32); + fn distance_begin(distance_kind: DistanceKind) -> Self::DistanceAccumulator { + (distance_kind, 0.0) + } + fn distance_next( + accumulator: &mut Self::DistanceAccumulator, + left: &[Self::Element], + right: &[Self::Element], + ) { + match accumulator.0 { + DistanceKind::L2 => accumulator.1 += f16::reduce_sum_of_d2(left, right), + DistanceKind::Dot => accumulator.1 += -f16::reduce_sum_of_xy(left, right), + DistanceKind::Hamming => unreachable!(), + DistanceKind::Jaccard => unreachable!(), + } + } + fn distance_end( + accumulator: Self::DistanceAccumulator, + (): Self::Metadata, + (): Self::Metadata, + ) -> f32 { + accumulator.1 + } + + fn random_projection(vector: Self::Borrowed<'_>) -> Self { + Self::new(f16::vector_from_f32(&crate::projection::project( + &f16::vector_to_f32(vector.slice()), + ))) + } + + fn residual(vector: Self::Borrowed<'_>, center: Self::Borrowed<'_>) -> Self { + Self::new(ScalarLike::vector_sub(vector.slice(), center.slice())) + } + + fn rabitq_fscan_preprocess(vector: Self::Borrowed<'_>) -> Lut { + rabitq::fscan_preprocess(&f16::vector_to_f32(vector.slice())) + } + + fn rabitq_code(dims: u32, vector: Self::Borrowed<'_>) -> Code { + rabitq::code(dims, &f16::vector_to_f32(vector.slice())) + } + + fn build_to_vecf32(vector: Self::Borrowed<'_>) -> Vec { + f16::vector_to_f32(vector.slice()) + } + + fn build_from_vecf32(x: &[f32]) -> Self { + Self::new(f16::vector_from_f32(x)) + } +} #[derive(Clone, PartialEq, Archive, Serialize, Deserialize)] #[archive(check_bytes)] @@ -15,10 +244,10 @@ pub struct MetaTuple { #[derive(Clone, PartialEq, Archive, Serialize, Deserialize)] #[archive(check_bytes)] -pub struct VectorTuple { - pub slice: Vec, +pub struct VectorTuple { + pub slice: Vec, pub payload: Option, - pub chain: Option<(u32, u16)>, + pub chain: Result<(u32, u16), V::Metadata>, } #[derive(Clone, PartialEq, Archive, Serialize, Deserialize)] diff --git a/src/vchordrq/algorithm/vacuum.rs b/src/vchordrq/algorithm/vacuum.rs index 2737702..2b219c4 100644 --- a/src/vchordrq/algorithm/vacuum.rs +++ b/src/vchordrq/algorithm/vacuum.rs @@ -2,7 +2,7 @@ use crate::postgres::Relation; use crate::vchordrq::algorithm::tuples::*; use base::search::Pointer; -pub fn vacuum(relation: Relation, delay: impl Fn(), callback: impl Fn(Pointer) -> bool) { +pub fn vacuum(relation: Relation, delay: impl Fn(), callback: impl Fn(Pointer) -> bool) { // step 1: vacuum height_0_tuple { let meta_guard = relation.read(0); @@ -78,8 +78,8 @@ pub fn vacuum(relation: Relation, delay: impl Fn(), callback: impl Fn(Pointer) - let Some(vector_tuple) = read.get().get(i) else { continue; }; - let vector_tuple = rkyv::check_archived_root::(vector_tuple) - .expect("data corruption"); + let vector_tuple = + unsafe { rkyv::archived_root::>(vector_tuple) }; if let Some(payload) = vector_tuple.payload.as_ref().copied() { if callback(Pointer::new(payload)) { break 'flag true; @@ -95,8 +95,8 @@ pub fn vacuum(relation: Relation, delay: impl Fn(), callback: impl Fn(Pointer) - let Some(vector_tuple) = write.get().get(i) else { continue; }; - let vector_tuple = rkyv::check_archived_root::(vector_tuple) - .expect("data corruption"); + let vector_tuple = + unsafe { rkyv::archived_root::>(vector_tuple) }; if let Some(payload) = vector_tuple.payload.as_ref().copied() { if callback(Pointer::new(payload)) { write.get_mut().free(i); diff --git a/src/vchordrq/algorithm/vectors.rs b/src/vchordrq/algorithm/vectors.rs index c1f8627..6a23f74 100644 --- a/src/vchordrq/algorithm/vectors.rs +++ b/src/vchordrq/algorithm/vectors.rs @@ -1,34 +1,26 @@ +use super::tuples::Vector; use crate::postgres::Relation; -use crate::vchordrq::algorithm::rabitq::distance; use crate::vchordrq::algorithm::tuples::VectorTuple; use base::distance::Distance; use base::distance::DistanceKind; -pub fn vector_split(vector: &[f32]) -> Vec<&[f32]> { - match vector.len() { - 0..=960 => vec![vector], - 961..=1280 => vec![&vector[..640], &vector[640..]], - 1281.. => vector.chunks(1920).collect(), - } -} - -pub fn vector_dist( +pub fn vector_dist( relation: Relation, - vector: &[f32], + vector: V::Borrowed<'_>, mean: (u32, u16), payload: Option, for_distance: Option, for_original: bool, -) -> Option<(Option, Option>)> { +) -> Option<(Option, Option)> { if for_distance.is_none() && !for_original && payload.is_none() { return Some((None, None)); } - let slices = vector_split(vector); - let mut cursor = Some(mean); - let mut result = 0.0f32; + let (left_metadata, slices) = V::vector_split(vector); + let mut cursor = Ok(mean); + let mut result = for_distance.map(|x| V::distance_begin(x)); let mut original = Vec::new(); for i in 0..slices.len() { - let Some(mean) = cursor else { + let Ok(mean) = cursor else { // fails consistency check return None; }; @@ -37,40 +29,47 @@ pub fn vector_dist( // fails consistency check return None; }; - let vector_tuple = - rkyv::check_archived_root::(vector_tuple).expect("data corruption"); + let vector_tuple = unsafe { rkyv::archived_root::>(vector_tuple) }; if vector_tuple.payload != payload { // fails consistency check return None; } - if let Some(distance_kind) = for_distance { - result += distance(distance_kind, slices[i], &vector_tuple.slice).to_f32(); + if let Some(result) = result.as_mut() { + V::distance_next(result, slices[i], &vector_tuple.slice); } if for_original { original.extend_from_slice(&vector_tuple.slice); } - cursor = vector_tuple.chain.as_ref().cloned(); + cursor = match &vector_tuple.chain { + rkyv::result::ArchivedResult::Ok(x) => Ok(*x), + rkyv::result::ArchivedResult::Err(x) => Err(V::metadata_from_archived(x)), + }; } + let Err(right_metadata) = cursor else { + panic!("data corruption") + }; Some(( - for_distance.map(|_| Distance::from_f32(result)), - for_original.then_some(original), + result.map(|r| Distance::from_f32(V::distance_end(r, left_metadata, right_metadata))), + for_original.then(|| V::vector_merge(right_metadata, &original)), )) } -pub fn vector_warm(relation: Relation, mean: (u32, u16)) { - let mut cursor = Some(mean); - while let Some(mean) = cursor { +pub fn vector_warm(relation: Relation, mean: (u32, u16)) { + let mut cursor = Ok(mean); + while let Ok(mean) = cursor { let vector_guard = relation.read(mean.0); let Some(vector_tuple) = vector_guard.get().get(mean.1) else { // fails consistency check return; }; - let vector_tuple = - rkyv::check_archived_root::(vector_tuple).expect("data corruption"); + let vector_tuple = unsafe { rkyv::archived_root::>(vector_tuple) }; if vector_tuple.payload.is_some() { // fails consistency check return; } - cursor = vector_tuple.chain.as_ref().cloned(); + cursor = match &vector_tuple.chain { + rkyv::result::ArchivedResult::Ok(x) => Ok(*x), + rkyv::result::ArchivedResult::Err(x) => Err(V::metadata_from_archived(x)), + }; } } diff --git a/src/vchordrq/index/am.rs b/src/vchordrq/index/am.rs index d3c3e8b..e9182c7 100644 --- a/src/vchordrq/index/am.rs +++ b/src/vchordrq/index/am.rs @@ -1,11 +1,15 @@ use crate::postgres::Relation; use crate::vchordrq::algorithm; use crate::vchordrq::algorithm::build::{HeapRelation, Reporter}; +use crate::vchordrq::algorithm::tuples::Vector; use crate::vchordrq::index::am_options::{Opfamily, Reloption}; use crate::vchordrq::index::am_scan::Scanner; use crate::vchordrq::index::utils::{ctid_to_pointer, pointer_to_ctid}; use crate::vchordrq::index::{am_options, am_scan}; +use crate::vchordrq::types::VectorKind; use base::search::Pointer; +use base::vector::VectOwned; +use half::f16; use pgrx::datum::Internal; use pgrx::pg_sys::Datum; @@ -163,17 +167,17 @@ pub unsafe extern "C" fn ambuild( index_info: *mut pgrx::pg_sys::IndexInfo, opfamily: Opfamily, } - impl HeapRelation for Heap { + impl HeapRelation for Heap { fn traverse(&self, progress: bool, callback: F) where - F: FnMut((Pointer, Vec)), + F: FnMut((Pointer, V)), { pub struct State<'a, F> { pub this: &'a Heap, pub callback: F, } #[pgrx::pg_guard] - unsafe extern "C" fn call( + unsafe extern "C" fn call( _index: pgrx::pg_sys::Relation, ctid: pgrx::pg_sys::ItemPointer, values: *mut Datum, @@ -181,21 +185,14 @@ pub unsafe extern "C" fn ambuild( _tuple_is_alive: bool, state: *mut core::ffi::c_void, ) where - F: FnMut((Pointer, Vec)), + F: FnMut((Pointer, V)), { - use base::vector::OwnedVector; let state = unsafe { &mut *state.cast::>() }; let opfamily = state.this.opfamily; let vector = unsafe { opfamily.datum_to_vector(*values.add(0), *is_null.add(0)) }; let pointer = unsafe { ctid_to_pointer(ctid.read()) }; if let Some(vector) = vector { - let vector = match vector { - OwnedVector::Vecf32(x) => x, - OwnedVector::Vecf16(_) => unreachable!(), - OwnedVector::SVecf32(_) => unreachable!(), - OwnedVector::BVector(_) => unreachable!(), - }; - (state.callback)((pointer, vector.into_vec())); + (state.callback)((pointer, V::from_owned(vector))); } } let table_am = unsafe { &*(*self.heap).rd_tableam }; @@ -213,7 +210,7 @@ pub unsafe extern "C" fn ambuild( progress, 0, pgrx::pg_sys::InvalidBlockNumber, - Some(call::), + Some(call::), (&mut state) as *mut State as *mut _, std::ptr::null_mut(), ); @@ -246,13 +243,22 @@ pub unsafe extern "C" fn ambuild( }; let mut reporter = PgReporter {}; let index_relation = unsafe { Relation::new(index) }; - algorithm::build::build( - vector_options, - vchordrq_options, - heap_relation.clone(), - index_relation.clone(), - reporter.clone(), - ); + match opfamily.vector_kind() { + VectorKind::Vecf32 => algorithm::build::build::, Heap, _>( + vector_options, + vchordrq_options, + heap_relation.clone(), + index_relation.clone(), + reporter.clone(), + ), + VectorKind::Vecf16 => algorithm::build::build::, Heap, _>( + vector_options, + vchordrq_options, + heap_relation.clone(), + index_relation.clone(), + reporter.clone(), + ), + } if let Some(leader) = unsafe { VchordrqLeader::enter(heap, index, (*index_info).ii_Concurrent) } { unsafe { @@ -283,17 +289,42 @@ pub unsafe extern "C" fn ambuild( } else { let mut indtuples = 0; reporter.tuples_done(indtuples); - heap_relation.traverse(true, |(payload, vector)| { - algorithm::insert::insert( - index_relation.clone(), - payload, - vector, - opfamily.distance_kind(), - true, - ); - indtuples += 1; - reporter.tuples_done(indtuples); - }); + match opfamily.vector_kind() { + VectorKind::Vecf32 => { + HeapRelation::>::traverse( + &heap_relation, + true, + |(pointer, vector)| { + algorithm::insert::insert::>( + unsafe { Relation::new(index) }, + pointer, + vector, + opfamily.distance_kind(), + true, + ); + indtuples += 1; + reporter.tuples_done(indtuples); + }, + ); + } + VectorKind::Vecf16 => { + HeapRelation::>::traverse( + &heap_relation, + true, + |(pointer, vector)| { + algorithm::insert::insert::>( + unsafe { Relation::new(index) }, + pointer, + vector, + opfamily.distance_kind(), + true, + ); + indtuples += 1; + reporter.tuples_done(indtuples); + }, + ); + } + } } unsafe { pgrx::pgbox::PgBox::::alloc0().into_pg() } } @@ -540,17 +571,17 @@ unsafe fn parallel_build( opfamily: Opfamily, scan: *mut pgrx::pg_sys::TableScanDescData, } - impl HeapRelation for Heap { + impl HeapRelation for Heap { fn traverse(&self, progress: bool, callback: F) where - F: FnMut((Pointer, Vec)), + F: FnMut((Pointer, V)), { pub struct State<'a, F> { pub this: &'a Heap, pub callback: F, } #[pgrx::pg_guard] - unsafe extern "C" fn call( + unsafe extern "C" fn call( _index: pgrx::pg_sys::Relation, ctid: pgrx::pg_sys::ItemPointer, values: *mut Datum, @@ -558,21 +589,14 @@ unsafe fn parallel_build( _tuple_is_alive: bool, state: *mut core::ffi::c_void, ) where - F: FnMut((Pointer, Vec)), + F: FnMut((Pointer, V)), { - use base::vector::OwnedVector; let state = unsafe { &mut *state.cast::>() }; let opfamily = state.this.opfamily; let vector = unsafe { opfamily.datum_to_vector(*values.add(0), *is_null.add(0)) }; let pointer = unsafe { ctid_to_pointer(ctid.read()) }; if let Some(vector) = vector { - let vector = match vector { - OwnedVector::Vecf32(x) => x, - OwnedVector::Vecf16(_) => unreachable!(), - OwnedVector::SVecf32(_) => unreachable!(), - OwnedVector::BVector(_) => unreachable!(), - }; - (state.callback)((pointer, vector.into_vec())); + (state.callback)((pointer, V::from_owned(vector))); } } let table_am = unsafe { &*(*self.heap).rd_tableam }; @@ -590,7 +614,7 @@ unsafe fn parallel_build( progress, 0, pgrx::pg_sys::InvalidBlockNumber, - Some(call::), + Some(call::), (&mut state) as *mut State as *mut _, self.scan, ); @@ -612,28 +636,54 @@ unsafe fn parallel_build( opfamily, scan, }; - heap_relation.traverse(reporter.is_some(), |(payload, vector)| { - algorithm::insert::insert( - index_relation.clone(), - payload, - vector, - opfamily.distance_kind(), - true, - ); - unsafe { - let indtuples; - { - pgrx::pg_sys::SpinLockAcquire(&raw mut (*vchordrqshared).mutex); - (*vchordrqshared).indtuples += 1; - indtuples = (*vchordrqshared).indtuples; - pgrx::pg_sys::SpinLockRelease(&raw mut (*vchordrqshared).mutex); - } - if let Some(reporter) = reporter.as_mut() { - reporter.tuples_done(indtuples); - } + match opfamily.vector_kind() { + VectorKind::Vecf32 => { + HeapRelation::>::traverse(&heap_relation, true, |(pointer, vector)| { + algorithm::insert::insert::>( + index_relation.clone(), + pointer, + vector, + opfamily.distance_kind(), + true, + ); + unsafe { + let indtuples; + { + pgrx::pg_sys::SpinLockAcquire(&raw mut (*vchordrqshared).mutex); + (*vchordrqshared).indtuples += 1; + indtuples = (*vchordrqshared).indtuples; + pgrx::pg_sys::SpinLockRelease(&raw mut (*vchordrqshared).mutex); + } + if let Some(reporter) = reporter.as_mut() { + reporter.tuples_done(indtuples); + } + } + }); } - }); - + VectorKind::Vecf16 => { + HeapRelation::>::traverse(&heap_relation, true, |(pointer, vector)| { + algorithm::insert::insert::>( + index_relation.clone(), + pointer, + vector, + opfamily.distance_kind(), + true, + ); + unsafe { + let indtuples; + { + pgrx::pg_sys::SpinLockAcquire(&raw mut (*vchordrqshared).mutex); + (*vchordrqshared).indtuples += 1; + indtuples = (*vchordrqshared).indtuples; + pgrx::pg_sys::SpinLockRelease(&raw mut (*vchordrqshared).mutex); + } + if let Some(reporter) = reporter.as_mut() { + reporter.tuples_done(indtuples); + } + } + }); + } + } unsafe { pgrx::pg_sys::SpinLockAcquire(&raw mut (*vchordrqshared).mutex); (*vchordrqshared).nparticipantsdone += 1; @@ -658,23 +708,26 @@ pub unsafe extern "C" fn aminsert( _check_unique: pgrx::pg_sys::IndexUniqueCheck::Type, _index_info: *mut pgrx::pg_sys::IndexInfo, ) -> bool { - use base::vector::OwnedVector; let opfamily = unsafe { am_options::opfamily(index) }; let vector = unsafe { opfamily.datum_to_vector(*values.add(0), *is_null.add(0)) }; if let Some(vector) = vector { - let vector = match vector { - OwnedVector::Vecf32(x) => x, - OwnedVector::Vecf16(_) => unreachable!(), - OwnedVector::SVecf32(_) => unreachable!(), - OwnedVector::BVector(_) => unreachable!(), - }; let pointer = ctid_to_pointer(unsafe { heap_tid.read() }); - algorithm::insert::insert( - unsafe { Relation::new(index) }, - pointer, - vector.into_vec(), - opfamily.distance_kind(), - ); + match opfamily.vector_kind() { + VectorKind::Vecf32 => algorithm::insert::insert::>( + unsafe { Relation::new(index) }, + pointer, + VectOwned::::from_owned(vector), + opfamily.distance_kind(), + false, + ), + VectorKind::Vecf16 => algorithm::insert::insert::>( + unsafe { Relation::new(index) }, + pointer, + VectOwned::::from_owned(vector), + opfamily.distance_kind(), + false, + ), + } } false } @@ -691,24 +744,26 @@ pub unsafe extern "C" fn aminsert( _index_unchanged: bool, _index_info: *mut pgrx::pg_sys::IndexInfo, ) -> bool { - use base::vector::OwnedVector; let opfamily = unsafe { am_options::opfamily(index) }; let vector = unsafe { opfamily.datum_to_vector(*values.add(0), *is_null.add(0)) }; if let Some(vector) = vector { - let vector = match vector { - OwnedVector::Vecf32(x) => x, - OwnedVector::Vecf16(_) => unreachable!(), - OwnedVector::SVecf32(_) => unreachable!(), - OwnedVector::BVector(_) => unreachable!(), - }; let pointer = ctid_to_pointer(unsafe { heap_tid.read() }); - algorithm::insert::insert( - unsafe { Relation::new(index) }, - pointer, - vector.into_vec(), - opfamily.distance_kind(), - false, - ); + match opfamily.vector_kind() { + VectorKind::Vecf32 => algorithm::insert::insert::>( + unsafe { Relation::new(index) }, + pointer, + VectOwned::::from_owned(vector), + opfamily.distance_kind(), + false, + ), + VectorKind::Vecf16 => algorithm::insert::insert::>( + unsafe { Relation::new(index) }, + pointer, + VectOwned::::from_owned(vector), + opfamily.distance_kind(), + false, + ), + } } false } @@ -833,15 +888,25 @@ pub unsafe extern "C" fn ambulkdelete( pgrx::pg_sys::palloc0(size_of::()).cast() }; } + let opfamily = unsafe { am_options::opfamily((*info).index) }; let callback = callback.unwrap(); let callback = |p: Pointer| unsafe { callback(&mut pointer_to_ctid(p), callback_state) }; - algorithm::vacuum::vacuum( - unsafe { Relation::new((*info).index) }, - || unsafe { - pgrx::pg_sys::vacuum_delay_point(); - }, - callback, - ); + match opfamily.vector_kind() { + VectorKind::Vecf32 => algorithm::vacuum::vacuum::>( + unsafe { Relation::new((*info).index) }, + || unsafe { + pgrx::pg_sys::vacuum_delay_point(); + }, + callback, + ), + VectorKind::Vecf16 => algorithm::vacuum::vacuum::>( + unsafe { Relation::new((*info).index) }, + || unsafe { + pgrx::pg_sys::vacuum_delay_point(); + }, + callback, + ), + } stats } diff --git a/src/vchordrq/index/am_options.rs b/src/vchordrq/index/am_options.rs index 971273f..a357da4 100644 --- a/src/vchordrq/index/am_options.rs +++ b/src/vchordrq/index/am_options.rs @@ -1,10 +1,13 @@ +use crate::datatype::memory_pgvector_halfvec::PgvectorHalfvecInput; +use crate::datatype::memory_pgvector_halfvec::PgvectorHalfvecOutput; use crate::datatype::memory_pgvector_vector::PgvectorVectorInput; use crate::datatype::memory_pgvector_vector::PgvectorVectorOutput; use crate::datatype::typmod::Typmod; use crate::vchordrq::types::VchordrqIndexingOptions; +use crate::vchordrq::types::VectorOptions; +use crate::vchordrq::types::{BorrowedVector, OwnedVector, VectorKind}; use base::distance::*; -use base::index::*; -use base::vector::*; +use base::vector::VectorBorrowed; use pgrx::datum::FromDatum; use pgrx::heap_tuple::PgHeapTuple; use serde::Deserialize; @@ -26,7 +29,7 @@ impl Reloption { }]; unsafe fn options(&self) -> &CStr { unsafe { - let ptr = std::ptr::addr_of!(*self) + let ptr = (&raw const *self) .cast::() .offset(self.options as _); CStr::from_ptr(ptr) @@ -56,6 +59,9 @@ fn convert_name_to_vd(name: &str) -> Option<(VectorKind, PgDistanceKind)> { Some("vector_l2") => Some((VectorKind::Vecf32, PgDistanceKind::L2)), Some("vector_ip") => Some((VectorKind::Vecf32, PgDistanceKind::Dot)), Some("vector_cosine") => Some((VectorKind::Vecf32, PgDistanceKind::Cos)), + Some("halfvec_l2") => Some((VectorKind::Vecf16, PgDistanceKind::L2)), + Some("halfvec_ip") => Some((VectorKind::Vecf16, PgDistanceKind::Dot)), + Some("halfvec_cosine") => Some((VectorKind::Vecf16, PgDistanceKind::Cos)), _ => None, } } @@ -130,7 +136,10 @@ impl Opfamily { let vector = unsafe { PgvectorVectorInput::from_datum(datum, false).unwrap() }; self.preprocess(BorrowedVector::Vecf32(vector.as_borrowed())) } - _ => unreachable!(), + VectorKind::Vecf16 => { + let vector = unsafe { PgvectorHalfvecInput::from_datum(datum, false).unwrap() }; + self.preprocess(BorrowedVector::Vecf16(vector.as_borrowed())) + } }; Some(vector) } @@ -148,7 +157,10 @@ impl Opfamily { .get_by_index::(NonZero::new(1).unwrap()) .unwrap() .map(|vector| self.preprocess(BorrowedVector::Vecf32(vector.as_borrowed()))), - _ => unreachable!(), + VectorKind::Vecf16 => tuple + .get_by_index::(NonZero::new(1).unwrap()) + .unwrap() + .map(|vector| self.preprocess(BorrowedVector::Vecf16(vector.as_borrowed()))), }; let radius = tuple.get_by_index::(NonZero::new(2).unwrap()).unwrap(); (center, radius) @@ -161,8 +173,6 @@ impl Opfamily { (B::Vecf32(x), PgDistanceKind::Dot) => O::Vecf32(x.own()), (B::Vecf32(x), PgDistanceKind::Cos) => O::Vecf32(x.function_normalize()), (B::Vecf16(x), _) => O::Vecf16(x.own()), - (B::SVecf32(x), _) => O::SVecf32(x.own()), - (B::BVector(x), _) => O::BVector(x.own()), } } pub fn process(self, x: Distance) -> f32 { @@ -175,6 +185,9 @@ impl Opfamily { pub fn distance_kind(self) -> DistanceKind { self.pg_distance.to_distance() } + pub fn vector_kind(self) -> VectorKind { + self.vector + } } pub unsafe fn opfamily(index: pgrx::pg_sys::Relation) -> Opfamily { diff --git a/src/vchordrq/index/am_scan.rs b/src/vchordrq/index/am_scan.rs index e97d352..1b78ff0 100644 --- a/src/vchordrq/index/am_scan.rs +++ b/src/vchordrq/index/am_scan.rs @@ -1,12 +1,16 @@ use super::am_options::Opfamily; use crate::postgres::Relation; use crate::vchordrq::algorithm::scan::scan; +use crate::vchordrq::algorithm::tuples::Vector; use crate::vchordrq::gucs::executing::epsilon; use crate::vchordrq::gucs::executing::max_scan_tuples; use crate::vchordrq::gucs::executing::probes; +use crate::vchordrq::types::OwnedVector; +use crate::vchordrq::types::VectorKind; use base::distance::Distance; use base::search::*; -use base::vector::*; +use base::vector::VectOwned; +use half::f16; pub enum Scanner { Initial { @@ -34,7 +38,7 @@ pub fn scan_build( for orderby_vector in orderbys { if pair.is_none() { pair = orderby_vector; - } else if orderby_vector.is_some() && pair != orderby_vector { + } else if orderby_vector.is_some() { pgrx::error!("vector search with multiple vectors is not supported"); } } @@ -42,10 +46,6 @@ pub fn scan_build( if pair.is_none() { pair = sphere_vector; threshold = sphere_threshold; - } else if pair == sphere_vector { - if threshold.is_none() || sphere_threshold < threshold { - threshold = sphere_threshold; - } } else { recheck = true; break; @@ -74,28 +74,46 @@ pub fn scan_next(scanner: &mut Scanner, relation: Relation) -> Option<(Pointer, } = scanner { if let Some((vector, opfamily)) = vector.as_ref() { - let vbase = scan( - relation, - match vector { - OwnedVector::Vecf32(x) => x.slice().to_vec(), - OwnedVector::Vecf16(_) => unreachable!(), - OwnedVector::SVecf32(_) => unreachable!(), - OwnedVector::BVector(_) => unreachable!(), - }, - opfamily.distance_kind(), - probes(), - epsilon(), - ); - *scanner = Scanner::Vbase { - vbase: if let Some(max_scan_tuples) = max_scan_tuples() { - Box::new(vbase.take(max_scan_tuples as usize)) - } else { - Box::new(vbase) - }, - threshold: *threshold, - recheck: *recheck, - opfamily: *opfamily, - }; + match opfamily.vector_kind() { + VectorKind::Vecf32 => { + let vbase = scan::>( + relation, + VectOwned::::from_owned(vector.clone()), + opfamily.distance_kind(), + probes(), + epsilon(), + ); + *scanner = Scanner::Vbase { + vbase: if let Some(max_scan_tuples) = max_scan_tuples() { + Box::new(vbase.take(max_scan_tuples as usize)) + } else { + Box::new(vbase) + }, + threshold: *threshold, + recheck: *recheck, + opfamily: *opfamily, + }; + } + VectorKind::Vecf16 => { + let vbase = scan::>( + relation, + VectOwned::::from_owned(vector.clone()), + opfamily.distance_kind(), + probes(), + epsilon(), + ); + *scanner = Scanner::Vbase { + vbase: if let Some(max_scan_tuples) = max_scan_tuples() { + Box::new(vbase.take(max_scan_tuples as usize)) + } else { + Box::new(vbase) + }, + threshold: *threshold, + recheck: *recheck, + opfamily: *opfamily, + }; + } + } } else { *scanner = Scanner::Empty {}; } diff --git a/src/vchordrq/index/functions.rs b/src/vchordrq/index/functions.rs index 40df8b5..05f348f 100644 --- a/src/vchordrq/index/functions.rs +++ b/src/vchordrq/index/functions.rs @@ -1,5 +1,9 @@ +use super::am_options; use crate::postgres::Relation; use crate::vchordrq::algorithm::prewarm::prewarm; +use crate::vchordrq::types::VectorKind; +use base::vector::VectOwned; +use half::f16; use pgrx::pg_sys::Oid; use pgrx_catalog::{PgAm, PgClass}; @@ -18,7 +22,11 @@ fn _vchordrq_prewarm(indexrelid: Oid, height: i32) -> String { } let index = unsafe { pgrx::pg_sys::index_open(indexrelid, pgrx::pg_sys::ShareLock as _) }; let relation = unsafe { Relation::new(index) }; - let message = prewarm(relation, height); + let opfamily = unsafe { am_options::opfamily(index) }; + let message = match opfamily.vector_kind() { + VectorKind::Vecf32 => prewarm::>(relation, height), + VectorKind::Vecf16 => prewarm::>(relation, height), + }; unsafe { pgrx::pg_sys::index_close(index, pgrx::pg_sys::ShareLock as _); } diff --git a/src/vchordrq/index/opclass.rs b/src/vchordrq/index/opclass.rs index e71da44..a2dc861 100644 --- a/src/vchordrq/index/opclass.rs +++ b/src/vchordrq/index/opclass.rs @@ -12,3 +12,18 @@ fn _vchordrq_support_vector_ip_ops() -> String { fn _vchordrq_support_vector_cosine_ops() -> String { "vector_cosine_ops".to_string() } + +#[pgrx::pg_extern(immutable, strict, parallel_safe)] +fn _vchordrq_support_halfvec_l2_ops() -> String { + "halfvec_l2_ops".to_string() +} + +#[pgrx::pg_extern(immutable, strict, parallel_safe)] +fn _vchordrq_support_halfvec_ip_ops() -> String { + "halfvec_ip_ops".to_string() +} + +#[pgrx::pg_extern(immutable, strict, parallel_safe)] +fn _vchordrq_support_halfvec_cosine_ops() -> String { + "halfvec_cosine_ops".to_string() +} diff --git a/src/vchordrq/types.rs b/src/vchordrq/types.rs index 2301150..0e1bdc0 100644 --- a/src/vchordrq/types.rs +++ b/src/vchordrq/types.rs @@ -1,3 +1,6 @@ +use base::distance::DistanceKind; +use base::vector::{VectBorrowed, VectOwned}; +use half::f16; use serde::{Deserialize, Serialize}; use validator::{Validate, ValidationError, ValidationErrors}; @@ -95,3 +98,47 @@ impl VchordrqIndexingOptions { false } } + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub enum OwnedVector { + Vecf32(VectOwned), + Vecf16(VectOwned), +} + +#[derive(Debug, Clone, Copy)] +pub enum BorrowedVector<'a> { + Vecf32(VectBorrowed<'a, f32>), + Vecf16(VectBorrowed<'a, f16>), +} + +#[repr(u8)] +#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Serialize, Deserialize)] +pub enum VectorKind { + Vecf32, + Vecf16, +} + +#[derive(Debug, Clone, Serialize, Deserialize, Validate)] +#[serde(deny_unknown_fields)] +#[validate(schema(function = "Self::validate_self"))] +pub struct VectorOptions { + #[validate(range(min = 1, max = 1_048_575))] + #[serde(rename = "dimensions")] + pub dims: u32, + #[serde(rename = "vector")] + pub v: VectorKind, + #[serde(rename = "distance")] + pub d: DistanceKind, +} + +impl VectorOptions { + pub fn validate_self(&self) -> Result<(), ValidationError> { + match (self.v, self.d, self.dims) { + (VectorKind::Vecf32, DistanceKind::L2, 1..65536) => Ok(()), + (VectorKind::Vecf32, DistanceKind::Dot, 1..65536) => Ok(()), + (VectorKind::Vecf16, DistanceKind::L2, 1..65536) => Ok(()), + (VectorKind::Vecf16, DistanceKind::Dot, 1..65536) => Ok(()), + _ => Err(ValidationError::new("not valid vector options")), + } + } +} diff --git a/src/vchordrqfscan/algorithm/build.rs b/src/vchordrqfscan/algorithm/build.rs index 0528a5c..25a6611 100644 --- a/src/vchordrqfscan/algorithm/build.rs +++ b/src/vchordrqfscan/algorithm/build.rs @@ -7,10 +7,10 @@ use crate::vchordrqfscan::types::VchordrqfscanBuildOptions; use crate::vchordrqfscan::types::VchordrqfscanExternalBuildOptions; use crate::vchordrqfscan::types::VchordrqfscanIndexingOptions; use crate::vchordrqfscan::types::VchordrqfscanInternalBuildOptions; +use crate::vchordrqfscan::types::VectorOptions; use base::distance::DistanceKind; -use base::index::VectorOptions; -use base::scalar::ScalarLike; use base::search::Pointer; +use base::simd::ScalarLike; use rand::Rng; use rkyv::ser::serializers::AllocSerializer; use std::marker::PhantomData; @@ -268,7 +268,7 @@ impl Structure { if vector_options.dims != vector.as_borrowed().dims() { pgrx::error!("extern build: incorrect dimension, id = {id}"); } - vectors.insert(id, crate::projection::project(vector.slice())); + vectors.insert(id, crate::projection::project(vector.as_borrowed().slice())); } }); let mut children = parents diff --git a/src/vchordrqfscan/algorithm/insert.rs b/src/vchordrqfscan/algorithm/insert.rs index e89b110..4dfd432 100644 --- a/src/vchordrqfscan/algorithm/insert.rs +++ b/src/vchordrqfscan/algorithm/insert.rs @@ -6,8 +6,8 @@ use crate::vchordrqfscan::algorithm::tuples::*; use base::always_equal::AlwaysEqual; use base::distance::Distance; use base::distance::DistanceKind; -use base::scalar::ScalarLike; use base::search::Pointer; +use base::simd::ScalarLike; use std::cmp::Reverse; use std::collections::BinaryHeap; diff --git a/src/vchordrqfscan/algorithm/rabitq.rs b/src/vchordrqfscan/algorithm/rabitq.rs index cf72ca5..65c4996 100644 --- a/src/vchordrqfscan/algorithm/rabitq.rs +++ b/src/vchordrqfscan/algorithm/rabitq.rs @@ -1,6 +1,6 @@ +use crate::utils::infinite_byte_chunks::InfiniteByteChunks; use base::distance::{Distance, DistanceKind}; -use base::scalar::ScalarLike; -use quantization::utils::InfiniteByteChunks; +use base::simd::ScalarLike; #[derive(Debug, Clone)] pub struct Code { @@ -74,19 +74,19 @@ pub fn pack_codes(dims: u32, codes: [Code; 32]) -> PackedCodes { .take(dims.div_ceil(4) as usize) .collect::>() }); - quantization::fast_scan::b4::pack(dims.div_ceil(4), signs).collect() + base::simd::fast_scan::b4::pack(dims.div_ceil(4), signs).collect() }, } } pub fn fscan_preprocess(vector: &[f32]) -> (f32, f32, f32, f32, Vec) { - use quantization::quantize; + use base::simd::quantize; let dis_v_2 = f32::reduce_sum_of_x2(vector); - let (k, b, qvector) = quantize::quantize::<15>(vector); + let (k, b, qvector) = quantize::quantize(vector, 15.0); let qvector_sum = if vector.len() <= 4369 { - quantize::reduce_sum_of_x_as_u16(&qvector) as f32 + base::simd::u8::reduce_sum_of_x_as_u16(&qvector) as f32 } else { - quantize::reduce_sum_of_x_as_u32(&qvector) as f32 + base::simd::u8::reduce_sum_of_x_as_u32(&qvector) as f32 }; (dis_v_2, b, k, qvector_sum, compress(qvector)) } @@ -105,7 +105,7 @@ pub fn fscan_process_lowerbound( epsilon: f32, ) -> [Distance; 32] { let &(dis_v_2, b, k, qvector_sum, ref s) = lut; - let r = quantization::fast_scan::b4::fast_scan_b4(dims.div_ceil(4), t, s); + let r = base::simd::fast_scan::b4::fast_scan_b4(dims.div_ceil(4), t, s); match distance_kind { DistanceKind::L2 => std::array::from_fn(|i| { let rough = dis_u_2[i] diff --git a/src/vchordrqfscan/algorithm/scan.rs b/src/vchordrqfscan/algorithm/scan.rs index 63264f1..202949e 100644 --- a/src/vchordrqfscan/algorithm/scan.rs +++ b/src/vchordrqfscan/algorithm/scan.rs @@ -6,8 +6,8 @@ use crate::vchordrqfscan::algorithm::tuples::*; use base::always_equal::AlwaysEqual; use base::distance::Distance; use base::distance::DistanceKind; -use base::scalar::ScalarLike; use base::search::Pointer; +use base::simd::ScalarLike; use std::cmp::Reverse; use std::collections::BinaryHeap; diff --git a/src/vchordrqfscan/index/am.rs b/src/vchordrqfscan/index/am.rs index e234414..84a4c8b 100644 --- a/src/vchordrqfscan/index/am.rs +++ b/src/vchordrqfscan/index/am.rs @@ -183,7 +183,7 @@ pub unsafe extern "C" fn ambuild( ) where F: FnMut((Pointer, Vec)), { - use base::vector::OwnedVector; + use crate::vchordrqfscan::types::OwnedVector; let state = unsafe { &mut *state.cast::>() }; let opfamily = state.this.opfamily; let vector = unsafe { opfamily.datum_to_vector(*values.add(0), *is_null.add(0)) }; @@ -191,9 +191,6 @@ pub unsafe extern "C" fn ambuild( if let Some(vector) = vector { let vector = match vector { OwnedVector::Vecf32(x) => x, - OwnedVector::Vecf16(_) => unreachable!(), - OwnedVector::SVecf32(_) => unreachable!(), - OwnedVector::BVector(_) => unreachable!(), }; (state.callback)((pointer, vector.into_vec())); } @@ -572,7 +569,7 @@ unsafe fn parallel_build( ) where F: FnMut((Pointer, Vec)), { - use base::vector::OwnedVector; + use crate::vchordrqfscan::types::OwnedVector; let state = unsafe { &mut *state.cast::>() }; let opfamily = state.this.opfamily; let vector = unsafe { opfamily.datum_to_vector(*values.add(0), *is_null.add(0)) }; @@ -580,9 +577,6 @@ unsafe fn parallel_build( if let Some(vector) = vector { let vector = match vector { OwnedVector::Vecf32(x) => x, - OwnedVector::Vecf16(_) => unreachable!(), - OwnedVector::SVecf32(_) => unreachable!(), - OwnedVector::BVector(_) => unreachable!(), }; (state.callback)((pointer, vector.into_vec())); } @@ -669,15 +663,12 @@ pub unsafe extern "C" fn aminsert( _check_unique: pgrx::pg_sys::IndexUniqueCheck::Type, _index_info: *mut pgrx::pg_sys::IndexInfo, ) -> bool { - use base::vector::OwnedVector; + use crate::vchordrqfscan::types::OwnedVector; let opfamily = unsafe { am_options::opfamily(index) }; let vector = unsafe { opfamily.datum_to_vector(*values.add(0), *is_null.add(0)) }; if let Some(vector) = vector { let vector = match vector { OwnedVector::Vecf32(x) => x, - OwnedVector::Vecf16(_) => unreachable!(), - OwnedVector::SVecf32(_) => unreachable!(), - OwnedVector::BVector(_) => unreachable!(), }; let pointer = ctid_to_pointer(unsafe { heap_tid.read() }); algorithm::insert::insert( @@ -702,15 +693,12 @@ pub unsafe extern "C" fn aminsert( _index_unchanged: bool, _index_info: *mut pgrx::pg_sys::IndexInfo, ) -> bool { - use base::vector::OwnedVector; + use crate::vchordrqfscan::types::OwnedVector; let opfamily = unsafe { am_options::opfamily(index) }; let vector = unsafe { opfamily.datum_to_vector(*values.add(0), *is_null.add(0)) }; if let Some(vector) = vector { let vector = match vector { OwnedVector::Vecf32(x) => x, - OwnedVector::Vecf16(_) => unreachable!(), - OwnedVector::SVecf32(_) => unreachable!(), - OwnedVector::BVector(_) => unreachable!(), }; let pointer = ctid_to_pointer(unsafe { heap_tid.read() }); algorithm::insert::insert( diff --git a/src/vchordrqfscan/index/am_options.rs b/src/vchordrqfscan/index/am_options.rs index 51a1009..b49b7a2 100644 --- a/src/vchordrqfscan/index/am_options.rs +++ b/src/vchordrqfscan/index/am_options.rs @@ -1,10 +1,9 @@ use crate::datatype::memory_pgvector_vector::PgvectorVectorInput; use crate::datatype::memory_pgvector_vector::PgvectorVectorOutput; use crate::datatype::typmod::Typmod; -use crate::vchordrqfscan::types::VchordrqfscanIndexingOptions; +use crate::vchordrqfscan::types::*; use base::distance::*; -use base::index::*; -use base::vector::*; +use base::vector::VectorBorrowed; use pgrx::datum::FromDatum; use pgrx::heap_tuple::PgHeapTuple; use serde::Deserialize; @@ -26,7 +25,7 @@ impl Reloption { }]; unsafe fn options(&self) -> &CStr { unsafe { - let ptr = std::ptr::addr_of!(*self) + let ptr = (&raw const *self) .cast::() .offset(self.options as _); CStr::from_ptr(ptr) @@ -132,7 +131,6 @@ impl Opfamily { let vector = unsafe { PgvectorVectorInput::from_datum(datum, false).unwrap() }; self.preprocess(BorrowedVector::Vecf32(vector.as_borrowed())) } - _ => unreachable!(), }; Some(vector) } @@ -150,7 +148,6 @@ impl Opfamily { .get_by_index::(NonZero::new(1).unwrap()) .unwrap() .map(|vector| self.preprocess(BorrowedVector::Vecf32(vector.as_borrowed()))), - _ => unreachable!(), }; let radius = tuple.get_by_index::(NonZero::new(2).unwrap()).unwrap(); (center, radius) @@ -162,9 +159,6 @@ impl Opfamily { (B::Vecf32(x), PgDistanceKind::L2) => O::Vecf32(x.own()), (B::Vecf32(x), PgDistanceKind::Dot) => O::Vecf32(x.own()), (B::Vecf32(x), PgDistanceKind::Cos) => O::Vecf32(x.function_normalize()), - (B::Vecf16(x), _) => O::Vecf16(x.own()), - (B::SVecf32(x), _) => O::SVecf32(x.own()), - (B::BVector(x), _) => O::BVector(x.own()), } } pub fn process(self, x: Distance) -> f32 { diff --git a/src/vchordrqfscan/index/am_scan.rs b/src/vchordrqfscan/index/am_scan.rs index 7396bd1..b07edb7 100644 --- a/src/vchordrqfscan/index/am_scan.rs +++ b/src/vchordrqfscan/index/am_scan.rs @@ -4,9 +4,9 @@ use crate::vchordrqfscan::algorithm::scan::scan; use crate::vchordrqfscan::gucs::executing::epsilon; use crate::vchordrqfscan::gucs::executing::max_scan_tuples; use crate::vchordrqfscan::gucs::executing::probes; +use crate::vchordrqfscan::types::OwnedVector; use base::distance::Distance; use base::search::*; -use base::vector::*; pub enum Scanner { Initial { @@ -34,7 +34,7 @@ pub fn scan_build( for orderby_vector in orderbys { if pair.is_none() { pair = orderby_vector; - } else if orderby_vector.is_some() && pair != orderby_vector { + } else if orderby_vector.is_some() { pgrx::error!("vector search with multiple vectors is not supported"); } } @@ -42,10 +42,6 @@ pub fn scan_build( if pair.is_none() { pair = sphere_vector; threshold = sphere_threshold; - } else if pair == sphere_vector { - if threshold.is_none() || sphere_threshold < threshold { - threshold = sphere_threshold; - } } else { recheck = true; break; @@ -78,9 +74,6 @@ pub fn scan_next(scanner: &mut Scanner, relation: Relation) -> Option<(Pointer, relation, match vector { OwnedVector::Vecf32(x) => x.slice().to_vec(), - OwnedVector::Vecf16(_) => unreachable!(), - OwnedVector::SVecf32(_) => unreachable!(), - OwnedVector::BVector(_) => unreachable!(), }, opfamily.distance_kind(), probes(), diff --git a/src/vchordrqfscan/types.rs b/src/vchordrqfscan/types.rs index 0fbe82e..1180e64 100644 --- a/src/vchordrqfscan/types.rs +++ b/src/vchordrqfscan/types.rs @@ -1,3 +1,5 @@ +use base::distance::DistanceKind; +use base::vector::{VectBorrowed, VectOwned}; use serde::{Deserialize, Serialize}; use validator::{Validate, ValidationError, ValidationErrors}; @@ -95,3 +97,42 @@ impl VchordrqfscanIndexingOptions { false } } + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub enum OwnedVector { + Vecf32(VectOwned), +} + +#[derive(Debug, Clone, Copy)] +pub enum BorrowedVector<'a> { + Vecf32(VectBorrowed<'a, f32>), +} + +#[repr(u8)] +#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Serialize, Deserialize)] +pub enum VectorKind { + Vecf32, +} + +#[derive(Debug, Clone, Serialize, Deserialize, Validate)] +#[serde(deny_unknown_fields)] +#[validate(schema(function = "Self::validate_self"))] +pub struct VectorOptions { + #[validate(range(min = 1, max = 1_048_575))] + #[serde(rename = "dimensions")] + pub dims: u32, + #[serde(rename = "vector")] + pub v: VectorKind, + #[serde(rename = "distance")] + pub d: DistanceKind, +} + +impl VectorOptions { + pub fn validate_self(&self) -> Result<(), ValidationError> { + match (self.v, self.d, self.dims) { + (VectorKind::Vecf32, DistanceKind::L2, 1..65536) => Ok(()), + (VectorKind::Vecf32, DistanceKind::Dot, 1..65536) => Ok(()), + _ => Err(ValidationError::new("not valid vector options")), + } + } +} diff --git a/tests/logic/distance.slt b/tests/logic/distance.slt new file mode 100644 index 0000000..e03f32f --- /dev/null +++ b/tests/logic/distance.slt @@ -0,0 +1,59 @@ +query I +SELECT round(('[1,2,3]'::vector <-> '[2,3,4]'::vector):: numeric, 3); +---- +1.732 + +query I +SELECT round(('[1,2,3]'::vector <#> '[2,3,4]'::vector):: numeric, 3); +---- +-20.000 + +query I +SELECT round(('[1,2,3]'::vector <=> '[2,3,4]'::vector):: numeric, 3); +---- +0.007 + +query I +SELECT round(('[1,2,3]'::halfvec <-> '[2,3,4]'::halfvec):: numeric, 3); +---- +1.732 + +query I +SELECT round(('[1,2,3]'::halfvec <#> '[2,3,4]'::halfvec):: numeric, 3); +---- +-20.000 + +query I +SELECT round(('[1,2,3]'::halfvec <=> '[2,3,4]'::halfvec):: numeric, 3); +---- +0.007 + +query I +SELECT round((quantize_to_scalar8('[1,2,3]'::vector) <-> quantize_to_scalar8('[2,3,4]'::vector)):: numeric, 1); +---- +1.7 + +query I +SELECT round((quantize_to_scalar8('[1,2,3]'::vector) <#> quantize_to_scalar8('[2,3,4]'::vector)):: numeric, 1); +---- +-20.0 + +query I +SELECT round((quantize_to_scalar8('[1,2,3]'::vector) <=> quantize_to_scalar8('[2,3,4]'::vector)):: numeric, 2); +---- +0.01 + +query I +SELECT round((quantize_to_scalar8('[1,2,3]'::halfvec) <-> quantize_to_scalar8('[2,3,4]'::halfvec)):: numeric, 1); +---- +1.7 + +query I +SELECT round((quantize_to_scalar8('[1,2,3]'::halfvec) <#> quantize_to_scalar8('[2,3,4]'::halfvec)):: numeric, 1); +---- +-20.0 + +query I +SELECT round((quantize_to_scalar8('[1,2,3]'::halfvec) <=> quantize_to_scalar8('[2,3,4]'::halfvec)):: numeric, 2); +---- +0.01