feat: unify vchordrq and vchordrqfscan
Signed-off-by: usamoi <usamoi@outlook.com>
usamoi committed Jan 16, 2025
1 parent e26e5e3 commit 80d3492
Showing 67 changed files with 4,093 additions and 4,770 deletions.
269 changes: 71 additions & 198 deletions Cargo.lock

Large diffs are not rendered by default.

10 changes: 4 additions & 6 deletions Cargo.toml
@@ -20,17 +20,13 @@ pg16 = ["pgrx/pg16", "pgrx-catalog/pg16"]
pg17 = ["pgrx/pg17", "pgrx-catalog/pg17"]

[dependencies]
algorithm = { path = "./crates/algorithm" }
always_equal = { path = "./crates/always_equal" }
distance = { path = "./crates/distance" }
rabitq = { path = "./crates/rabitq" }
random_orthogonal_matrix = { path = "./crates/random_orthogonal_matrix" }
simd = { path = "./crates/simd" }
vector = { path = "./crates/vector" }

# lock rkyv version forever so that data is always compatible
rkyv = { version = "=0.7.45", features = ["validation"] }

half.workspace = true
log = "0.4.22"
paste = "1"
@@ -41,9 +37,11 @@ rayon = "1.10.0"
serde.workspace = true
toml = "0.8.19"
validator = { version = "0.19.0", features = ["derive"] }
zerocopy = "0.8.14"
zerocopy-derive = "0.8.14"

[patch.crates-io]
half = { git = "https://github.com/tensorchord/half-rs.git" }
half = { git = "https://github.com/usamoi/half-rs.git" }

[lints]
workspace = true
@@ -57,7 +55,7 @@ version = "0.0.0"
edition = "2021"

[workspace.dependencies]
half = { version = "2.4.1", features = ["rkyv", "serde"] }
half = { version = "2.4.1", features = ["serde", "zerocopy"] }
rand = "0.8.5"
serde = "1"

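Note: the pinned rkyv dependency (kept "so that data is always compatible") is replaced by zerocopy and zerocopy-derive, which suggests persisted tuples are now plain fixed-layout structs reinterpreted as bytes rather than rkyv archives. Below is a minimal sketch of that style, not code from this commit; the struct, its field names, and the use of zerocopy 0.8's read_from_bytes are assumptions for illustration.

use zerocopy::{FromBytes, IntoBytes};

// Hypothetical record layout; field names are illustrative only.
#[derive(
    zerocopy_derive::FromBytes,
    zerocopy_derive::IntoBytes,
    zerocopy_derive::KnownLayout,
    zerocopy_derive::Immutable,
    Debug,
    PartialEq,
)]
#[repr(C)]
struct ExampleFactors {
    dis_u_2: f32,
    factor_ppc: f32,
    factor_ip: f32,
    factor_err: f32,
}

fn main() {
    let x = ExampleFactors {
        dis_u_2: 1.0,
        factor_ppc: 2.0,
        factor_ip: 3.0,
        factor_err: 4.0,
    };
    // Writing: the struct is viewed directly as its in-memory bytes.
    let bytes: &[u8] = x.as_bytes();
    // Reading: the bytes are reinterpreted without a deserialization pass
    // (assumes zerocopy 0.8's `read_from_bytes` constructor).
    if let Ok(y) = ExampleFactors::read_from_bytes(bytes) {
        assert_eq!(x, y);
    }
}

Under this scheme, stored data stays readable as long as the struct layout itself is stable, which would make a hard version pin like the old rkyv one unnecessary.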
7 changes: 0 additions & 7 deletions crates/algorithm/Cargo.toml

This file was deleted.

72 changes: 6 additions & 66 deletions crates/rabitq/src/binary.rs
@@ -1,67 +1,9 @@
use distance::Distance;
use simd::Floating;

#[derive(Debug, Clone)]
pub struct Code {
pub dis_u_2: f32,
pub factor_ppc: f32,
pub factor_ip: f32,
pub factor_err: f32,
pub signs: Vec<u8>,
}

impl Code {
pub fn t(&self) -> Vec<u64> {
use crate::utils::InfiniteByteChunks;
let mut result = Vec::new();
for x in InfiniteByteChunks::<_, 64>::new(self.signs.iter().copied())
.take(self.signs.len().div_ceil(64))
{
let mut r = 0_u64;
for i in 0..64 {
r |= (x[i] as u64) << i;
}
result.push(r);
}
result
}
}

pub fn code(dims: u32, vector: &[f32]) -> Code {
let sum_of_abs_x = f32::reduce_sum_of_abs_x(vector);
let sum_of_x_2 = f32::reduce_sum_of_x2(vector);
let dis_u = sum_of_x_2.sqrt();
let x0 = sum_of_abs_x / (sum_of_x_2 * (dims as f32)).sqrt();
let x_x0 = dis_u / x0;
let fac_norm = (dims as f32).sqrt();
let max_x1 = 1.0f32 / (dims as f32 - 1.0).sqrt();
let factor_err = 2.0f32 * max_x1 * (x_x0 * x_x0 - dis_u * dis_u).sqrt();
let factor_ip = -2.0f32 / fac_norm * x_x0;
let cnt_pos = vector
.iter()
.map(|x| x.is_sign_positive() as i32)
.sum::<i32>();
let cnt_neg = vector
.iter()
.map(|x| x.is_sign_negative() as i32)
.sum::<i32>();
let factor_ppc = factor_ip * (cnt_pos - cnt_neg) as f32;
let mut signs = Vec::new();
for i in 0..dims {
signs.push(vector[i as usize].is_sign_positive() as u8);
}
Code {
dis_u_2: sum_of_x_2,
factor_ppc,
factor_ip,
factor_err,
signs,
}
}

pub type Lut = (f32, f32, f32, f32, (Vec<u64>, Vec<u64>, Vec<u64>, Vec<u64>));

pub fn preprocess(vector: &[f32]) -> Lut {
pub fn preprocess(
vector: &[f32],
) -> (f32, f32, f32, f32, (Vec<u64>, Vec<u64>, Vec<u64>, Vec<u64>)) {
let dis_v_2 = f32::reduce_sum_of_x2(vector);
let (k, b, qvector) = simd::quantize::quantize(vector, 15.0);
let qvector_sum = if vector.len() <= 4369 {
@@ -73,8 +15,7 @@ pub fn preprocess(vector: &[f32]) -> Lut {
}

pub fn process_lowerbound_l2(
_: u32,
lut: &Lut,
lut: &(f32, f32, f32, f32, (Vec<u64>, Vec<u64>, Vec<u64>, Vec<u64>)),
(dis_u_2, factor_ppc, factor_ip, factor_err, t): (f32, f32, f32, f32, &[u64]),
epsilon: f32,
) -> Distance {
@@ -87,8 +28,7 @@ pub fn process_lowerbound_l2(
}

pub fn process_lowerbound_dot(
_: u32,
lut: &Lut,
lut: &(f32, f32, f32, f32, (Vec<u64>, Vec<u64>, Vec<u64>, Vec<u64>)),
(_, factor_ppc, factor_ip, factor_err, t): (f32, f32, f32, f32, &[u64]),
epsilon: f32,
) -> Distance {
@@ -99,7 +39,7 @@ pub fn process_lowerbound_dot(
Distance::from_f32(rough - epsilon * err)
}

fn binarize(vector: &[u8]) -> (Vec<u64>, Vec<u64>, Vec<u64>, Vec<u64>) {
pub fn binarize(vector: &[u8]) -> (Vec<u64>, Vec<u64>, Vec<u64>, Vec<u64>) {
let n = vector.len();
let mut t0 = vec![0u64; n.div_ceil(64)];
let mut t1 = vec![0u64; n.div_ceil(64)];
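Note: the body of binarize is truncated above. As a point of reference only, here is a minimal sketch (assumed, not taken from this commit) of the bit-plane split its signature implies: each 4-bit quantized query value contributes one bit to each of four planes, so the lower-bound routines can combine four popcount-style binary dot products weighted by 1, 2, 4 and 8.

// Assumed reconstruction for illustration; the committed binarize body is not
// shown in the diff above. Bit j of each 4-bit value v goes into plane t_j at
// position i, packed 64 positions per u64 word.
fn binarize_sketch(vector: &[u8]) -> (Vec<u64>, Vec<u64>, Vec<u64>, Vec<u64>) {
    let n = vector.len();
    let mut t0 = vec![0u64; n.div_ceil(64)];
    let mut t1 = vec![0u64; n.div_ceil(64)];
    let mut t2 = vec![0u64; n.div_ceil(64)];
    let mut t3 = vec![0u64; n.div_ceil(64)];
    for (i, &v) in vector.iter().enumerate() {
        t0[i / 64] |= ((v & 1) as u64) << (i % 64);
        t1[i / 64] |= (((v >> 1) & 1) as u64) << (i % 64);
        t2[i / 64] |= (((v >> 2) & 1) as u64) << (i % 64);
        t3[i / 64] |= (((v >> 3) & 1) as u64) << (i % 64);
    }
    (t0, t1, t2, t3)
}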
160 changes: 42 additions & 118 deletions crates/rabitq/src/block.rs
@@ -1,85 +1,7 @@
use distance::Distance;
use simd::Floating;

#[derive(Debug, Clone)]
pub struct Code {
pub dis_u_2: f32,
pub factor_ppc: f32,
pub factor_ip: f32,
pub factor_err: f32,
pub signs: Vec<u8>,
}

pub fn code(dims: u32, vector: &[f32]) -> Code {
let sum_of_abs_x = f32::reduce_sum_of_abs_x(vector);
let sum_of_x_2 = f32::reduce_sum_of_x2(vector);
let dis_u = sum_of_x_2.sqrt();
let x0 = sum_of_abs_x / (sum_of_x_2 * (dims as f32)).sqrt();
let x_x0 = dis_u / x0;
let fac_norm = (dims as f32).sqrt();
let max_x1 = 1.0f32 / (dims as f32 - 1.0).sqrt();
let factor_err = 2.0f32 * max_x1 * (x_x0 * x_x0 - dis_u * dis_u).sqrt();
let factor_ip = -2.0f32 / fac_norm * x_x0;
let cnt_pos = vector
.iter()
.map(|x| x.is_sign_positive() as i32)
.sum::<i32>();
let cnt_neg = vector
.iter()
.map(|x| x.is_sign_negative() as i32)
.sum::<i32>();
let factor_ppc = factor_ip * (cnt_pos - cnt_neg) as f32;
let mut signs = Vec::new();
for i in 0..dims {
signs.push(vector[i as usize].is_sign_positive() as u8);
}
Code {
dis_u_2: sum_of_x_2,
factor_ppc,
factor_ip,
factor_err,
signs,
}
}

pub fn dummy_code(dims: u32) -> Code {
Code {
dis_u_2: 0.0,
factor_ppc: 0.0,
factor_ip: 0.0,
factor_err: 0.0,
signs: vec![0; dims as _],
}
}

pub struct PackedCodes {
pub dis_u_2: [f32; 32],
pub factor_ppc: [f32; 32],
pub factor_ip: [f32; 32],
pub factor_err: [f32; 32],
pub t: Vec<u8>,
}

pub fn pack_codes(dims: u32, codes: [Code; 32]) -> PackedCodes {
use crate::utils::InfiniteByteChunks;
PackedCodes {
dis_u_2: std::array::from_fn(|i| codes[i].dis_u_2),
factor_ppc: std::array::from_fn(|i| codes[i].factor_ppc),
factor_ip: std::array::from_fn(|i| codes[i].factor_ip),
factor_err: std::array::from_fn(|i| codes[i].factor_err),
t: {
let signs = codes.map(|code| {
InfiniteByteChunks::new(code.signs.into_iter())
.map(|[b0, b1, b2, b3]| b0 | b1 << 1 | b2 << 2 | b3 << 3)
.take(dims.div_ceil(4) as usize)
.collect::<Vec<_>>()
});
simd::fast_scan::pack(dims.div_ceil(4), signs).collect()
},
}
}

pub fn fscan_preprocess(vector: &[f32]) -> (f32, f32, f32, f32, Vec<u8>) {
pub fn preprocess(vector: &[f32]) -> (f32, f32, f32, f32, Vec<[u64; 2]>) {
let dis_v_2 = f32::reduce_sum_of_x2(vector);
let (k, b, qvector) = simd::quantize::quantize(vector, 15.0);
let qvector_sum = if vector.len() <= 4369 {
@@ -90,20 +12,19 @@ pub fn fscan_preprocess(vector: &[f32]) -> (f32, f32, f32, f32, Vec<u8>) {
(dis_v_2, b, k, qvector_sum, compress(qvector))
}

pub fn fscan_process_lowerbound_l2(
dims: u32,
lut: &(f32, f32, f32, f32, Vec<u8>),
pub fn process_lowerbound_l2(
lut: &(f32, f32, f32, f32, Vec<[u64; 2]>),
(dis_u_2, factor_ppc, factor_ip, factor_err, t): (
&[f32; 32],
&[f32; 32],
&[f32; 32],
&[f32; 32],
&[u8],
&[[u64; 2]],
),
epsilon: f32,
) -> [Distance; 32] {
let &(dis_v_2, b, k, qvector_sum, ref s) = lut;
let r = simd::fast_scan::fast_scan(dims.div_ceil(4), t, s);
let r = simd::fast_scan::fast_scan(t, s);
std::array::from_fn(|i| {
let rough = dis_u_2[i]
+ dis_v_2
@@ -114,20 +35,19 @@ pub fn fscan_process_lowerbound_l2(
})
}

pub fn fscan_process_lowerbound_dot(
dims: u32,
lut: &(f32, f32, f32, f32, Vec<u8>),
pub fn process_lowerbound_dot(
lut: &(f32, f32, f32, f32, Vec<[u64; 2]>),
(_, factor_ppc, factor_ip, factor_err, t): (
&[f32; 32],
&[f32; 32],
&[f32; 32],
&[f32; 32],
&[u8],
&[[u64; 2]],
),
epsilon: f32,
) -> [Distance; 32] {
let &(dis_v_2, b, k, qvector_sum, ref s) = lut;
let r = simd::fast_scan::fast_scan(dims.div_ceil(4), t, s);
let r = simd::fast_scan::fast_scan(t, s);
std::array::from_fn(|i| {
let rough =
0.5 * b * factor_ppc[i] + 0.5 * ((2.0 * r[i] as f32) - qvector_sum) * factor_ip[i] * k;
@@ -136,37 +56,41 @@ pub fn fscan_process_lowerbound_dot(
})
}

fn compress(mut qvector: Vec<u8>) -> Vec<u8> {
let dims = qvector.len() as u32;
let width = dims.div_ceil(4);
qvector.resize(qvector.len().next_multiple_of(4), 0);
let mut t = vec![0u8; width as usize * 16];
for i in 0..width as usize {
pub fn compress(mut vector: Vec<u8>) -> Vec<[u64; 2]> {
let width = vector.len().div_ceil(4);
vector.resize(width * 4, 0);
let mut result = vec![[0u64, 0u64]; width];
for i in 0..width {
unsafe {
// this hint is used to skip bound checks
std::hint::assert_unchecked(4 * i + 3 < qvector.len());
std::hint::assert_unchecked(16 * i + 15 < t.len());
std::hint::assert_unchecked(4 * i + 3 < vector.len());
}
let t0 = qvector[4 * i + 0];
let t1 = qvector[4 * i + 1];
let t2 = qvector[4 * i + 2];
let t3 = qvector[4 * i + 3];
t[16 * i + 0b0000] = 0;
t[16 * i + 0b0001] = t0;
t[16 * i + 0b0010] = t1;
t[16 * i + 0b0011] = t1 + t0;
t[16 * i + 0b0100] = t2;
t[16 * i + 0b0101] = t2 + t0;
t[16 * i + 0b0110] = t2 + t1;
t[16 * i + 0b0111] = t2 + t1 + t0;
t[16 * i + 0b1000] = t3;
t[16 * i + 0b1001] = t3 + t0;
t[16 * i + 0b1010] = t3 + t1;
t[16 * i + 0b1011] = t3 + t1 + t0;
t[16 * i + 0b1100] = t3 + t2;
t[16 * i + 0b1101] = t3 + t2 + t0;
t[16 * i + 0b1110] = t3 + t2 + t1;
t[16 * i + 0b1111] = t3 + t2 + t1 + t0;
let t_0 = vector[4 * i + 0];
let t_1 = vector[4 * i + 1];
let t_2 = vector[4 * i + 2];
let t_3 = vector[4 * i + 3];
result[i] = [
u64::from_le_bytes([
0,
t_0,
t_1,
t_1 + t_0,
t_2,
t_2 + t_0,
t_2 + t_1,
t_2 + t_1 + t_0,
]),
u64::from_le_bytes([
t_3,
t_3 + t_0,
t_3 + t_1,
t_3 + t_1 + t_0,
t_3 + t_2,
t_3 + t_2 + t_0,
t_3 + t_2 + t_1,
t_3 + t_2 + t_1 + t_0,
]),
];
}
t
result
}
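Note: a hypothetical unit test (not part of this commit) that pins down what the new compress returns. For one group of four quantized values, byte s (0..16) of the 16-byte [u64; 2] pair is the sum of the values selected by the set bits of s, with the low word covering selectors 0..=7 and the high word covering 8..=15, presumably so fast_scan can gather these partial sums with nibble lookups.

// Hypothetical test placed next to `compress`; the test name and values are
// illustrative only.
#[test]
fn compress_builds_nibble_lut() {
    // One group of four quantized values: t_0 = 1, t_1 = 2, t_2 = 3, t_3 = 4.
    let lut = compress(vec![1, 2, 3, 4]);
    assert_eq!(lut.len(), 1);
    let [lo, hi] = lut[0];
    // Low word: sums over subsets of {t_0, t_1, t_2}, indexed by selector bits.
    assert_eq!(lo.to_le_bytes(), [0, 1, 2, 3, 3, 4, 5, 6]);
    // High word: the same subsets with t_3 added.
    assert_eq!(hi.to_le_bytes(), [4, 5, 6, 7, 7, 8, 9, 10]);
}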