Skip to content

Commit

Permalink
test: add migrate case (#204)
Browse files Browse the repository at this point in the history
  • Loading branch information
elstic authored Aug 5, 2024
1 parent 6314666 commit 3515d66
Show file tree
Hide file tree
Showing 3 changed files with 115 additions and 22 deletions.
24 changes: 5 additions & 19 deletions tests/base/collection_wrapper.py
Original file line number Diff line number Diff line change
Expand Up @@ -162,13 +162,13 @@ def flush(self, check_task=None, check_items=None, **kwargs):

@trace()
def search(self, data, anns_field, param, limit, expr=None,
partition_names=None, output_fields=None, timeout=None, round_decimal=-1,
check_task=None, check_items=None, **kwargs):
partition_names=None, output_fields=None, timeout=None # round_decimal=-1,
, check_task=None, check_items=None, **kwargs):
timeout = TIMEOUT if timeout is None else timeout

func_name = sys._getframe().f_code.co_name
res, check = api_request([self.collection.search, data, anns_field, param, limit,
expr, partition_names, output_fields, timeout, round_decimal], **kwargs)
expr, partition_names, output_fields, timeout], **kwargs)
check_result = ResponseChecker(res, func_name, check_task, check_items, check,
data=data, anns_field=anns_field, param=param, limit=limit,
expr=expr, partition_names=partition_names,
Expand All @@ -181,9 +181,10 @@ def hybrid_search(self, reqs, rerank, limit,
output_fields=None, timeout=None, round_decimal=-1,
check_task=None, check_items=None, **kwargs):
timeout = TIMEOUT if timeout is None else timeout
partition_names = None
func_name = sys._getframe().f_code.co_name
res, check = api_request([self.collection.hybrid_search, reqs, rerank, limit,
output_fields, timeout, round_decimal], **kwargs)
partition_names, output_fields, timeout], **kwargs)
check_result = ResponseChecker(res, func_name, check_task, check_items, check,
reqs=reqs, rerank=rerank, limit=limit,
output_fields=output_fields,
Expand All @@ -206,21 +207,6 @@ def search_iterator(self, data, anns_field, param, batch_size, limit=-1, expr=No
timeout=timeout, **kwargs).run()
return res, check_result

@trace()
def hybrid_search(self, reqs, rerank, limit, partition_names=None, output_fields=None, timeout=None, round_decimal=-1,
check_task=None, check_items=None, **kwargs):
timeout = TIMEOUT if timeout is None else timeout

func_name = sys._getframe().f_code.co_name
res, check = api_request([self.collection.hybrid_search, reqs, rerank, limit,
partition_names, output_fields, timeout, round_decimal], **kwargs)
check_result = ResponseChecker(res, func_name, check_task, check_items, check,
reqs=reqs, rerank=rerank, limit=limit,
partition_names=partition_names,
output_fields=output_fields,
timeout=timeout, **kwargs).run()
return res, check_result

@trace()
def query(self, expr, output_fields=None, partition_names=None, timeout=None, check_task=None, check_items=None,
**kwargs):
Expand Down
108 changes: 108 additions & 0 deletions tests/milvus_lite/test_milvus_lite_migrate.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,108 @@
import json
import random
import time
from multiprocessing import Process
import numpy as np
from pymilvus import (
connections,
utility,
FieldSchema, CollectionSchema, DataType,
Collection,
)
from pymilvus.client.types import LoadState
import pytest
import os
from base.client_base import TestcaseBase
from common.common_type import CaseLabel, CheckTasks
from base.high_level_api_wrapper import HighLevelApiWrapper
from common import common_type as ct


# Shared wrapper instance and naming defaults for the milvus-lite migrate cases.
client_w = HighLevelApiWrapper()
prefix = "milvus_lite_migrate"
default_dim = ct.default_dim
default_primary_key_field_name = "id"
default_vector_field_name = "vector"
default_float_field_name = ct.default_float_field_name
default_bool_field_name = ct.default_bool_field_name
# default_nb = ct.default_nb
default_limit = ct.default_limit
default_nq = ct.default_nq
default_string_field_name = ct.default_string_field_name
default_int32_array_field_name = ct.default_int32_array_field_name
# Collection name and local milvus-lite database file shared by both cases below.
c_name = "coll_migrate_1"
d_file = "local_migrate_case_test.db"  # plain string: the original f-prefix had no placeholders


class TestMilvusLiteMigrate(TestcaseBase):
    """
    Milvus-lite dump/migrate test cases:
    1. create a collection and prepare data, then run a normal search
    2. dump the collection to a json file via the ``milvus-lite dump`` CLI
    3. verify that the keys in the generated json file match the schema
    """

    @pytest.mark.tags(CaseLabel.L2)
    def test_milvus_lite_migrate_json_file(self):
        """Populate a local milvus-lite db file with a searchable collection for the dump case."""
        num_entities, dim = 3000, 8
        # Connect against a local milvus-lite db file so the dump CLI can read it later.
        connections.connect("default", uri=d_file)
        fields = [
            FieldSchema(name="random", dtype=DataType.DOUBLE),
            FieldSchema(name="pk", dtype=DataType.VARCHAR, is_primary=True, auto_id=False, max_length=100),
            FieldSchema(name="embeddings", dtype=DataType.FLOAT_VECTOR, dim=dim),
        ]
        schema = CollectionSchema(fields, "collection is the simplest demo to introduce the APIs")
        hello_milvus = Collection(c_name, schema)
        rng = np.random.default_rng(seed=19530)  # fixed seed keeps the generated data deterministic
        entities = [
            # provide the pk field because `auto_id` is set to False
            rng.random(num_entities).tolist(),  # field random, only supports list
            [str(i) for i in range(num_entities)],
            rng.random((num_entities, dim)),  # field embeddings, supports numpy.ndarray and list
        ]
        insert_result = hello_milvus.insert(entities)
        index = {
            "index_type": "FLAT",
            "metric_type": "L2",
            "params": {"nlist": 128},
        }

        hello_milvus.create_index("embeddings", index)
        hello_milvus.load()
        assert utility.load_state(c_name) == LoadState.Loaded
        vectors_to_search = entities[-1][-2:]
        search_params = {
            "metric_type": "L2",
            "params": {"nprobe": 10},
        }
        result = hello_milvus.search(vectors_to_search, "embeddings", search_params, limit=3,
                                     output_fields=["random"])

        for hits in result:
            for hit in hits:
                print(f"hit: {hit}, random field: {hit.entity.get('random')}")

        # Exercise the query/search variants; results are smoke-checked only, not asserted.
        result = hello_milvus.query(expr="random > 0.5", output_fields=["random", "embeddings"])
        r1 = hello_milvus.query(expr="random > 0.5", limit=4, output_fields=["random"])
        r2 = hello_milvus.query(expr="random > 0.5", offset=1, limit=3, output_fields=["random"])
        result = hello_milvus.search(vectors_to_search, "embeddings", search_params, limit=3,
                                     expr="random > 0.5", output_fields=["random"])

        ids = insert_result.primary_keys
        expr = f'pk in ["{ids[0]}" , "{ids[1]}"]'
        result = hello_milvus.query(expr=expr, output_fields=["random", "embeddings"])

    @pytest.mark.tags(CaseLabel.L2)
    def test_check_json_file_key(self):
        """
        pytest test_milvus_lite_migrate.py::TestMilvusLiteMigrate::test_check_json_file_key -s
        To check whether the data export is working properly using the above command, start a new command line
        """
        command = f"milvus-lite dump -d ./{d_file} -c {c_name} -p ./data_json"
        os.system(command)
        time.sleep(3)  # give the dump CLI a moment to flush its output files
        # os.listdir avoids shelling out with `ls`, stripping newlines, and shadowing
        # the `dir` builtin as the original code did.
        sub_dirs = os.listdir("data_json")
        assert sub_dirs, "milvus-lite dump produced no output directory"
        dump_dir = sub_dirs[0]
        with open(f"data_json/{dump_dir}/1.json") as user_file:
            parsed_json = json.load(user_file)  # json.load replaces read() + json.loads
        # os.system waits for the command to finish; the original os.popen call did not,
        # so cleanup could still be running when the test returned.
        os.system("rm -rf data_json")
        keys = parsed_json["rows"][0].keys()
        assert list(keys) == ["random", "pk", "embeddings"]
5 changes: 2 additions & 3 deletions tests/milvus_lite/test_milvus_lite_search.py
Original file line number Diff line number Diff line change
Expand Up @@ -1663,7 +1663,7 @@ def test_hybrid_search_different_limit_round_decimal(self, primary_field, limit)
# search to get the base line of hybrid_search
search_res = collection_w.search(vectors[:1], vector_name_list[i],
default_search_params, limit,
default_expr, round_decimal=5,
default_expr,
check_task=CheckTasks.check_search_results,
check_items={"nq": 1,
"ids": insert_ids,
Expand All @@ -1674,10 +1674,9 @@ def test_hybrid_search_different_limit_round_decimal(self, primary_field, limit)
search_res_dict[ids[j]] = distance_array[j]
search_res_dict_array.append(search_res_dict)
# 4. calculate hybrid search base line
ids_answer, score_answer = cf.get_hybrid_search_base_results(search_res_dict_array, weights, metrics, 5)
ids_answer, score_answer = cf.get_hybrid_search_base_results(search_res_dict_array, weights, metrics)
# 5. hybrid search
hybrid_res = collection_w.hybrid_search(req_list, WeightedRanker(*weights), limit,
round_decimal=5,
check_task=CheckTasks.check_search_results,
check_items={"nq": 1,
"ids": insert_ids,
Expand Down

0 comments on commit 3515d66

Please sign in to comment.