From 9413f02c2d51ea79fd5f27d74ba4cc16656083e8 Mon Sep 17 00:00:00 2001 From: cutecutecat Date: Mon, 23 Dec 2024 19:34:52 +0800 Subject: [PATCH] chore: type check and external test Signed-off-by: cutecutecat --- scripts/train.py | 2 +- src/vchordrq/algorithm/build.rs | 28 ++++++++++-- tests/logic/external_build.slt | 79 +++++++++++++++++++++++++++++++++ 3 files changed, 104 insertions(+), 5 deletions(-) create mode 100644 tests/logic/external_build.slt diff --git a/scripts/train.py b/scripts/train.py index 99b952c..6d3c44b 100644 --- a/scripts/train.py +++ b/scripts/train.py @@ -170,7 +170,7 @@ def kmeans_cluster( verbose=True, niter=niter, seed=SEED, - spherical=metric == "cos", + spherical=metric != "l2", ) child_kmeans.train(child_train) centroids.append(child_kmeans.centroids) diff --git a/src/vchordrq/algorithm/build.rs b/src/vchordrq/algorithm/build.rs index 7c13935..f4eead0 100644 --- a/src/vchordrq/algorithm/build.rs +++ b/src/vchordrq/algorithm/build.rs @@ -203,14 +203,34 @@ impl Structure { ) -> Vec { use std::collections::BTreeMap; let VchordrqExternalBuildOptions { table } = external_build; - let query = format!("SELECT id, parent, vector FROM {table};"); + let dump_query = format!("SELECT id, parent, vector FROM {table};"); + let table_name = table.split('.').last().unwrap().to_string(); + let type_check_query = format!( + "SELECT COUNT(*)::INTEGER + FROM pg_catalog.pg_extension e + LEFT JOIN pg_catalog.pg_namespace n ON n.oid = e.extnamespace + LEFT JOIN information_schema.columns i ON i.udt_schema = n.nspname + WHERE e.extname = 'vector' AND i.udt_name = 'vector' + AND i.table_name = '{table_name}' AND i.column_name = 'vector';" + ); let mut parents = BTreeMap::new(); let mut vectors = BTreeMap::new(); pgrx::spi::Spi::connect(|client| { use crate::datatype::memory_pgvector_vector::PgvectorVectorOutput; use base::vector::VectorBorrowed; use pgrx::pg_sys::panic::ErrorReportable; - let table = client.select(&query, None, None).unwrap_or_report(); + // 
Check the column of centroid table named `vector`, which type should be pgvector::vector + let type_check = client + .select(&type_check_query, None, None) + .unwrap_or_report(); + let count: Result<Option<i32>, _> = type_check.first().get_by_name("count"); + if count != Ok(Some(1)) { + pgrx::warning!("{:?}", count); + pgrx::error!( + "extern build: `vector` column should be pgvector::vector type at the centroid table" + ); + } + let table = client.select(&dump_query, None, None).unwrap_or_report(); for row in table { let id: Option = row.get_by_name("id").unwrap(); let parent: Option = row.get_by_name("parent").unwrap(); @@ -220,7 +240,7 @@ impl Structure { let pop = parents.insert(id, parent); if pop.is_some() { pgrx::error!( - "external build: there are at least two lines have same id, id = {id}" + "extern build: there are at least two lines have same id, id = {id}" ); } if vector_options.dims != vector.as_borrowed().dims() { @@ -263,7 +283,7 @@ impl Structure { parent.push(id); } else { pgrx::error!( - "external build: parent does not exist, id = {id}, parent = {parent}" + "extern build: parent does not exist, id = {id}, parent = {parent}" ); } } else { diff --git a/tests/logic/external_build.slt b/tests/logic/external_build.slt new file mode 100644 index 0000000..686de53 --- /dev/null +++ b/tests/logic/external_build.slt @@ -0,0 +1,79 @@ +statement ok +CREATE TABLE t (val0 vector(3), val1 halfvec(3)); + +statement ok +INSERT INTO t (val0, val1) +SELECT + ARRAY[random(), random(), random()]::real[]::vector, + ARRAY[random(), random(), random()]::real[]::halfvec +FROM generate_series(1, 100); + +statement ok +CREATE TABLE vector_centroid (id integer, parent integer, vector vector(3)); + +statement ok +INSERT INTO vector_centroid (id, vector) VALUES + (0, '[1.0, 0.0, 0.0]'), + (1, '[0.0, 1.0, 0.0]'), + (2, '[0.0, 0.0, 1.0]'); + +statement ok +CREATE TABLE bad_type_centroid (id integer, parent integer, vector halfvec(3)); + +statement ok +INSERT INTO bad_type_centroid (id, 
vector) VALUES + (0, '[1.0, 0.0, 0.0]'), + (1, '[0.0, 1.0, 0.0]'), + (2, '[0.0, 0.0, 1.0]'); + +statement ok +CREATE TABLE bad_duplicate_id (id integer, parent integer, vector vector(3)); + +statement ok +INSERT INTO bad_duplicate_id (id, vector) VALUES + (1, '[1.0, 0.0, 0.0]'), + (1, '[0.0, 1.0, 0.0]'), + (2, '[0.0, 0.0, 1.0]'); + +# external build for vector column + +statement ok +CREATE INDEX ON t USING vchordrq (val0 vector_l2_ops) +WITH (options = $$ +residual_quantization = true +[build.external] +table = 'vector_centroid' +$$); + +# external build for halfvec column + +statement ok +CREATE INDEX ON t USING vchordrq (val1 halfvec_l2_ops) +WITH (options = $$ +residual_quantization = true +[build.external] +table = 'vector_centroid' +$$); + +# failed: bad vector bad_type + +statement error should be pgvector::vector type +CREATE INDEX ON t USING vchordrq (val0 vector_l2_ops) +WITH (options = $$ +residual_quantization = true +[build.external] +table = 'bad_type_centroid' +$$); + +# failed: duplicate id + +statement error at least two lines have same id +CREATE INDEX ON t USING vchordrq (val0 vector_l2_ops) +WITH (options = $$ +residual_quantization = true +[build.external] +table = 'bad_duplicate_id' +$$); + +statement ok +DROP TABLE t, vector_centroid, bad_type_centroid, bad_duplicate_id; \ No newline at end of file