From dabca33a0246eb35775602a16fd2e3af734e022c Mon Sep 17 00:00:00 2001 From: cutecutecat Date: Mon, 23 Dec 2024 19:34:52 +0800 Subject: [PATCH 1/3] chore: type check and external test Signed-off-by: cutecutecat --- scripts/train.py | 2 +- src/vchordrq/algorithm/build.rs | 28 ++++++++++-- tests/logic/external_build.slt | 79 +++++++++++++++++++++++++++++++++ 3 files changed, 104 insertions(+), 5 deletions(-) create mode 100644 tests/logic/external_build.slt diff --git a/scripts/train.py b/scripts/train.py index 99b952c..6d3c44b 100644 --- a/scripts/train.py +++ b/scripts/train.py @@ -170,7 +170,7 @@ def kmeans_cluster( verbose=True, niter=niter, seed=SEED, - spherical=metric == "cos", + spherical=metric != "l2", ) child_kmeans.train(child_train) centroids.append(child_kmeans.centroids) diff --git a/src/vchordrq/algorithm/build.rs b/src/vchordrq/algorithm/build.rs index 7c13935..f4eead0 100644 --- a/src/vchordrq/algorithm/build.rs +++ b/src/vchordrq/algorithm/build.rs @@ -203,14 +203,34 @@ impl Structure { ) -> Vec { use std::collections::BTreeMap; let VchordrqExternalBuildOptions { table } = external_build; - let query = format!("SELECT id, parent, vector FROM {table};"); + let dump_query = format!("SELECT id, parent, vector FROM {table};"); + let table_name = table.split('.').last().unwrap().to_string(); + let type_check_query = format!( + "SELECT COUNT(*)::INTEGER + FROM pg_catalog.pg_extension e + LEFT JOIN pg_catalog.pg_namespace n ON n.oid = e.extnamespace + LEFT JOIN information_schema.columns i ON i.udt_schema = n.nspname + WHERE e.extname = 'vector' AND i.udt_name = 'vector' + AND i.table_name = '{table_name}' AND i.column_name = 'vector';" + ); let mut parents = BTreeMap::new(); let mut vectors = BTreeMap::new(); pgrx::spi::Spi::connect(|client| { use crate::datatype::memory_pgvector_vector::PgvectorVectorOutput; use base::vector::VectorBorrowed; use pgrx::pg_sys::panic::ErrorReportable; - let table = client.select(&query, None, None).unwrap_or_report(); + 
// Check the column of centroid table named `vector`, which type should be pgvector::vector + let type_check = client + .select(&type_check_query, None, None) + .unwrap_or_report(); + let count: Result, _> = type_check.first().get_by_name("count"); + if count != Ok(Some(1)) { + pgrx::warning!("{:?}", count); + pgrx::error!( + "extern build: `vector` column should be pgvector::vector type at the centroid table" + ); + } + let table = client.select(&dump_query, None, None).unwrap_or_report(); for row in table { let id: Option = row.get_by_name("id").unwrap(); let parent: Option = row.get_by_name("parent").unwrap(); @@ -220,7 +240,7 @@ impl Structure { let pop = parents.insert(id, parent); if pop.is_some() { pgrx::error!( - "external build: there are at least two lines have same id, id = {id}" + "extern build: there are at least two lines have same id, id = {id}" ); } if vector_options.dims != vector.as_borrowed().dims() { @@ -263,7 +283,7 @@ impl Structure { parent.push(id); } else { pgrx::error!( - "external build: parent does not exist, id = {id}, parent = {parent}" + "extern build: parent does not exist, id = {id}, parent = {parent}" ); } } else { diff --git a/tests/logic/external_build.slt b/tests/logic/external_build.slt new file mode 100644 index 0000000..09cf35c --- /dev/null +++ b/tests/logic/external_build.slt @@ -0,0 +1,79 @@ +statement ok +CREATE TABLE t (val0 vector(3), val1 halfvec(3)); + +statement ok +INSERT INTO t (val0, val1) +SELECT + ARRAY[random(), random(), random()]::real[]::vector, + ARRAY[random(), random(), random()]::real[]::halfvec +FROM generate_series(1, 100); + +statement ok +CREATE TABLE vector_centroid (id integer, parent integer, vector vector(3)); + +statement ok +INSERT INTO vector_centroid (id, vector) VALUES + (0, '[1.0, 0.0, 0.0]'), + (1, '[0.0, 1.0, 0.0]'), + (2, '[0.0, 0.0, 1.0]'); + +statement ok +CREATE TABLE bad_type_centroid (id integer, parent integer, vector halfvec(3)); + +statement ok +INSERT INTO bad_type_centroid (id, 
vector) VALUES + (0, '[1.0, 0.0, 0.0]'), + (1, '[0.0, 1.0, 0.0]'), + (2, '[0.0, 0.0, 1.0]'); + +statement ok +CREATE TABLE bad_duplicate_id (id integer, parent integer, vector vector(3)); + +statement ok +INSERT INTO bad_duplicate_id (id, vector) VALUES + (1, '[1.0, 0.0, 0.0]'), + (1, '[0.0, 1.0, 0.0]'), + (2, '[0.0, 0.0, 1.0]'); + +# external build for vector column + +statement ok +CREATE INDEX ON t USING vchordrq (val0 vector_l2_ops) +WITH (options = $$ +residual_quantization = true +[build.external] +table = 'public.vector_centroid' +$$); + +# external build for halfvec column + +statement ok +CREATE INDEX ON t USING vchordrq (val1 halfvec_l2_ops) +WITH (options = $$ +residual_quantization = true +[build.external] +table = 'public.vector_centroid' +$$); + +# failed: bad vector bad_type + +statement error extern build: `vector` column should be pgvector::vector type at the centroid table +CREATE INDEX ON t USING vchordrq (val0 vector_l2_ops) +WITH (options = $$ +residual_quantization = true +[build.external] +table = 'public.bad_type_centroid' +$$); + +# failed: duplicate id + +statement error extern build: there are at least two lines have same id, id = 1 +CREATE INDEX ON t USING vchordrq (val0 vector_l2_ops) +WITH (options = $$ +residual_quantization = true +[build.external] +table = 'public.bad_duplicate_id' +$$); + +statement ok +DROP TABLE t, vector_centroid, bad_type_centroid, bad_duplicate_id; \ No newline at end of file From 144363f007bebf9ff44cf0b5aad2fd8a8dde41d1 Mon Sep 17 00:00:00 2001 From: cutecutecat Date: Tue, 24 Dec 2024 15:49:42 +0800 Subject: [PATCH 2/3] fix by comments Signed-off-by: cutecutecat --- src/vchordrq/algorithm/build.rs | 64 ++++++++++++++++++--------------- tests/logic/external_build.slt | 29 ++++++++++++--- 2 files changed, 59 insertions(+), 34 deletions(-) diff --git a/src/vchordrq/algorithm/build.rs b/src/vchordrq/algorithm/build.rs index f4eead0..7f40aa8 100644 --- a/src/vchordrq/algorithm/build.rs +++ 
b/src/vchordrq/algorithm/build.rs @@ -203,48 +203,54 @@ impl Structure { ) -> Vec { use std::collections::BTreeMap; let VchordrqExternalBuildOptions { table } = external_build; - let dump_query = format!("SELECT id, parent, vector FROM {table};"); let table_name = table.split('.').last().unwrap().to_string(); - let type_check_query = format!( - "SELECT COUNT(*)::INTEGER - FROM pg_catalog.pg_extension e - LEFT JOIN pg_catalog.pg_namespace n ON n.oid = e.extnamespace - LEFT JOIN information_schema.columns i ON i.udt_schema = n.nspname - WHERE e.extname = 'vector' AND i.udt_name = 'vector' - AND i.table_name = '{table_name}' AND i.column_name = 'vector';" - ); let mut parents = BTreeMap::new(); let mut vectors = BTreeMap::new(); pgrx::spi::Spi::connect(|client| { use crate::datatype::memory_pgvector_vector::PgvectorVectorOutput; use base::vector::VectorBorrowed; use pgrx::pg_sys::panic::ErrorReportable; - // Check the column of centroid table named `vector`, which type should be pgvector::vector - let type_check = client - .select(&type_check_query, None, None) - .unwrap_or_report(); - let count: Result, _> = type_check.first().get_by_name("count"); - if count != Ok(Some(1)) { - pgrx::warning!("{:?}", count); + // Get the schema of pgvector + let schema_query = format!( + "SELECT n.nspname::TEXT + FROM pg_catalog.pg_extension e + LEFT JOIN pg_catalog.pg_namespace n ON n.oid = e.extnamespace + LEFT JOIN information_schema.columns i ON i.udt_schema = n.nspname + WHERE e.extname = 'vector' AND i.table_name = '{table_name}' AND i.column_name = 'vector';"); + let nspname: Vec = client + .select(&schema_query, None, None) + .unwrap_or_report() + .map(|data| { + data.get_by_name("nspname") + .expect("external build: cannot get schema of pgvector") + .expect("external build: cannot get schema of pgvector") + }) + .collect(); + // Check the `vector` column is pgvector-based type + let pgvector_schema = if let [schema] = &nspname[..] 
{ + schema.clone() + } else { pgrx::error!( - "extern build: `vector` column should be pgvector::vector type at the centroid table" + "external build: `vector` column should be a pgvector type at the external table" ); - } - let table = client.select(&dump_query, None, None).unwrap_or_report(); - for row in table { + }; + let dump_query = + format!("SELECT id, parent, vector::{pgvector_schema}.vector FROM {table};"); + let centroids = client.select(&dump_query, None, None).unwrap_or_report(); + for row in centroids { let id: Option = row.get_by_name("id").unwrap(); let parent: Option = row.get_by_name("parent").unwrap(); let vector: Option = row.get_by_name("vector").unwrap(); - let id = id.expect("extern build: id could not be NULL"); - let vector = vector.expect("extern build: vector could not be NULL"); + let id = id.expect("external build: id could not be NULL"); + let vector = vector.expect("external build: vector could not be NULL"); let pop = parents.insert(id, parent); if pop.is_some() { pgrx::error!( - "extern build: there are at least two lines have same id, id = {id}" + "external build: there are at least two lines have same id, id = {id}" ); } if vector_options.dims != vector.as_borrowed().dims() { - pgrx::error!("extern build: incorrect dimension, id = {id}"); + pgrx::error!("external build: incorrect dimension, id = {id}"); } vectors.insert(id, crate::projection::project(vector.as_borrowed().slice())); } @@ -283,7 +289,7 @@ impl Structure { parent.push(id); } else { pgrx::error!( - "extern build: parent does not exist, id = {id}, parent = {parent}" + "external build: parent does not exist, id = {id}, parent = {parent}" ); } } else { @@ -295,7 +301,7 @@ impl Structure { } } let Some(root) = root else { - pgrx::error!("extern build: there are no root"); + pgrx::error!("external build: there are no root"); }; let mut heights = BTreeMap::<_, _>::new(); fn dfs_for_heights( @@ -304,7 +310,7 @@ impl Structure { u: i32, ) { if heights.contains_key(&u) { - 
pgrx::error!("extern build: detect a cycle, id = {u}"); + pgrx::error!("external build: detect a cycle, id = {u}"); } heights.insert(u, None); let mut height = None; @@ -313,7 +319,7 @@ impl Structure { let new = heights[&v].unwrap() + 1; if let Some(height) = height { if height != new { - pgrx::error!("extern build: two heights, id = {u}"); + pgrx::error!("external build: two heights, id = {u}"); } } else { height = Some(new); @@ -331,7 +337,7 @@ impl Structure { .collect::>(); if !(1..=8).contains(&(heights[&root] - 1)) { pgrx::error!( - "extern build: unexpected tree height, height = {}", + "external build: unexpected tree height, height = {}", heights[&root] ); } diff --git a/tests/logic/external_build.slt b/tests/logic/external_build.slt index 09cf35c..68aefa8 100644 --- a/tests/logic/external_build.slt +++ b/tests/logic/external_build.slt @@ -18,14 +18,23 @@ INSERT INTO vector_centroid (id, vector) VALUES (2, '[0.0, 0.0, 1.0]'); statement ok -CREATE TABLE bad_type_centroid (id integer, parent integer, vector halfvec(3)); +CREATE TABLE halfvec_centroid (id integer, parent integer, vector halfvec(3)); statement ok -INSERT INTO bad_type_centroid (id, vector) VALUES +INSERT INTO halfvec_centroid (id, vector) VALUES (0, '[1.0, 0.0, 0.0]'), (1, '[0.0, 1.0, 0.0]'), (2, '[0.0, 0.0, 1.0]'); +statement ok +CREATE TABLE bad_type_centroid (id integer, parent integer, vector real[]); + +statement ok +INSERT INTO bad_type_centroid (id, vector) VALUES + (0, '{1.0, 0.0, 0.0}'), + (1, '{0.0, 1.0, 0.0}'), + (2, '{0.0, 0.0, 1.0}'); + statement ok CREATE TABLE bad_duplicate_id (id integer, parent integer, vector vector(3)); @@ -55,9 +64,19 @@ residual_quantization = true table = 'public.vector_centroid' $$); -# failed: bad vector bad_type +# external build for halfvec column by a halfvec table + +statement ok +CREATE INDEX ON t USING vchordrq (val1 halfvec_l2_ops) +WITH (options = $$ +residual_quantization = true +[build.external] +table = 'public.halfvec_centroid' +$$); + +# 
failed: bad vector data type -statement error extern build: `vector` column should be pgvector::vector type at the centroid table +statement error external build: `vector` column should be a pgvector type at the external table CREATE INDEX ON t USING vchordrq (val0 vector_l2_ops) WITH (options = $$ residual_quantization = true @@ -67,7 +86,7 @@ $$); # failed: duplicate id -statement error extern build: there are at least two lines have same id, id = 1 +statement error external build: there are at least two lines have same id, id = 1 CREATE INDEX ON t USING vchordrq (val0 vector_l2_ops) WITH (options = $$ residual_quantization = true From 30da4f9eb7a02e5bd245981590c61e5e6e1ae5da Mon Sep 17 00:00:00 2001 From: cutecutecat Date: Tue, 24 Dec 2024 17:36:53 +0800 Subject: [PATCH 3/3] fix Signed-off-by: cutecutecat --- src/vchordrq/algorithm/build.rs | 30 ++++++++---------------------- tests/logic/external_build.slt | 27 +++++++++++++++++++++++---- 2 files changed, 31 insertions(+), 26 deletions(-) diff --git a/src/vchordrq/algorithm/build.rs b/src/vchordrq/algorithm/build.rs index 7f40aa8..06f835e 100644 --- a/src/vchordrq/algorithm/build.rs +++ b/src/vchordrq/algorithm/build.rs @@ -203,37 +203,23 @@ impl Structure { ) -> Vec { use std::collections::BTreeMap; let VchordrqExternalBuildOptions { table } = external_build; - let table_name = table.split('.').last().unwrap().to_string(); let mut parents = BTreeMap::new(); let mut vectors = BTreeMap::new(); pgrx::spi::Spi::connect(|client| { use crate::datatype::memory_pgvector_vector::PgvectorVectorOutput; use base::vector::VectorBorrowed; use pgrx::pg_sys::panic::ErrorReportable; - // Get the schema of pgvector - let schema_query = format!( - "SELECT n.nspname::TEXT + let schema_query = "SELECT n.nspname::TEXT FROM pg_catalog.pg_extension e LEFT JOIN pg_catalog.pg_namespace n ON n.oid = e.extnamespace - LEFT JOIN information_schema.columns i ON i.udt_schema = n.nspname - WHERE e.extname = 'vector' AND i.table_name = 
'{table_name}' AND i.column_name = 'vector';"); - let nspname: Vec = client - .select(&schema_query, None, None) + WHERE e.extname = 'vector';"; + let pgvector_schema: String = client + .select(schema_query, None, None) .unwrap_or_report() - .map(|data| { - data.get_by_name("nspname") - .expect("external build: cannot get schema of pgvector") - .expect("external build: cannot get schema of pgvector") - }) - .collect(); - // Check the `vector` column is pgvector-based type - let pgvector_schema = if let [schema] = &nspname[..] { - schema.clone() - } else { - pgrx::error!( - "external build: `vector` column should be a pgvector type at the external table" - ); - }; + .first() + .get_by_name("nspname") + .expect("external build: cannot get schema of pgvector") + .expect("external build: cannot get schema of pgvector"); let dump_query = format!("SELECT id, parent, vector::{pgvector_schema}.vector FROM {table};"); let centroids = client.select(&dump_query, None, None).unwrap_or_report(); diff --git a/tests/logic/external_build.slt b/tests/logic/external_build.slt index 68aefa8..81dc35d 100644 --- a/tests/logic/external_build.slt +++ b/tests/logic/external_build.slt @@ -27,14 +27,23 @@ INSERT INTO halfvec_centroid (id, vector) VALUES (2, '[0.0, 0.0, 1.0]'); statement ok -CREATE TABLE bad_type_centroid (id integer, parent integer, vector real[]); +CREATE TABLE real_centroid (id integer, parent integer, vector real[]); statement ok -INSERT INTO bad_type_centroid (id, vector) VALUES +INSERT INTO real_centroid (id, vector) VALUES (0, '{1.0, 0.0, 0.0}'), (1, '{0.0, 1.0, 0.0}'), (2, '{0.0, 0.0, 1.0}'); +statement ok +CREATE TABLE bad_type_centroid (id integer, parent integer, vector integer); + +statement ok +INSERT INTO bad_type_centroid (id, vector) VALUES + (0, 0), + (1, 0), + (2, 0); + statement ok CREATE TABLE bad_duplicate_id (id integer, parent integer, vector vector(3)); @@ -74,9 +83,19 @@ residual_quantization = true table = 'public.halfvec_centroid' $$); +# external 
build for vector column by a real[] table + +statement ok +CREATE INDEX ON t USING vchordrq (val0 vector_l2_ops) +WITH (options = $$ +residual_quantization = true +[build.external] +table = 'public.real_centroid' +$$); + # failed: bad vector data type -statement error external build: `vector` column should be a pgvector type at the external table +statement error cannot cast type integer to (.*)vector CREATE INDEX ON t USING vchordrq (val0 vector_l2_ops) WITH (options = $$ residual_quantization = true @@ -95,4 +114,4 @@ table = 'public.bad_duplicate_id' $$); statement ok -DROP TABLE t, vector_centroid, bad_type_centroid, bad_duplicate_id; \ No newline at end of file +DROP TABLE t, vector_centroid, halfvec_centroid, real_centroid, bad_type_centroid, bad_duplicate_id; \ No newline at end of file