Skip to content

Commit

Permalink
Add truncate function (#2432)
Browse files Browse the repository at this point in the history
### What problem does this PR solve?
Add the `trunc` scalar function.

### Type of change
- [x] New Feature (non-breaking change which adds functionality)
  • Loading branch information
kche0169 authored Jan 13, 2025
1 parent edf6434 commit d14df22
Show file tree
Hide file tree
Showing 9 changed files with 328 additions and 3 deletions.
2 changes: 2 additions & 0 deletions python/infinity_embedded/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,8 @@ def binary_exp_to_paser_exp(binary_expr_key) -> str:
return "/"
elif binary_expr_key == "mod":
return "%"
elif binary_expr_key == "trunc":
return binary_expr_key
else:
raise InfinityException(ErrorCode.INVALID_EXPRESSION, f"unknown binary expression: {binary_expr_key}")

Expand Down
4 changes: 3 additions & 1 deletion python/infinity_http.py
Original file line number Diff line number Diff line change
Expand Up @@ -921,7 +921,9 @@ def to_result(self):
if k not in df_dict:
df_dict[k] = ()
tup = df_dict[k]
if res[k].isdigit() or is_float(res[k]):
if isinstance(res[k], str) and len(res[k]) > 0 and res[k][0] == " ":
new_tup = tup + (res[k],)
elif res[k].isdigit() or is_float(res[k]):
new_tup = tup + (eval(res[k]),)
elif is_list(res[k]):
new_tup = tup + (ast.literal_eval(res[k]),)
Expand Down
2 changes: 2 additions & 0 deletions python/test_pysdk/common/common_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,8 @@ def function_return_type(function_name, param_type) :
return dtype('float64')
elif function_name == "filter_text" or function_name == "filter_fulltext" or function_name == "or" or function_name == "and" or function_name == "not":
return dtype('bool')
elif function_name == "trunc":
return dtype('str_')
else:
return param_type

Expand Down
21 changes: 21 additions & 0 deletions python/test_pysdk/test_select.py
Original file line number Diff line number Diff line change
Expand Up @@ -1023,3 +1023,24 @@ def test_select_round(self, suffix):

res = db_obj.drop_table("test_select_round" + suffix)
assert res.error_code == ErrorCode.OK

def test_select_truncate(self, suffix):
    """Check trunc(column, 2) on a double column and a float column.

    Inserts four rows into a fresh table, selects trunc(c1, 2) and trunc(c2, 2),
    and compares the resulting DataFrame against the expected string values.
    The expected strings start with a single space because that is how the
    function renders its result.
    """
    table_name = "test_select_truncate" + suffix

    db_obj = self.infinity_obj.get_database("default_db")
    db_obj.drop_table(table_name, ConflictType.Ignore)
    db_obj.create_table(table_name,
                        {"c1": {"type": "double"},
                         "c2": {"type": "float"}}, ConflictType.Error)
    table_obj = db_obj.get_table(table_name)
    table_obj.insert(
        [{"c1": "2.123", "c2": "2.123"},
         {"c1": "-2.123", "c2": "-2.123"},
         {"c1": "2", "c2": "2"},
         {"c1": "2.1", "c2": "2.1"}])

    # Only the data frame is checked; the extra result is intentionally ignored.
    res, _ = table_obj.output(["trunc(c1, 2)", "trunc(c2, 2)"]).to_df()
    expected = pd.DataFrame({'(c1 trunc 2)': (" 2.12", " -2.12", " 2.00", " 2.10"),
                             '(c2 trunc 2)': (" 2.12", " -2.12", " 2.00", " 2.10")}
                            ).astype({'(c1 trunc 2)': dtype('str_'), '(c2 trunc 2)': dtype('str_')})
    pd.testing.assert_frame_equal(res, expected)

    res = db_obj.drop_table(table_name)
    assert res.error_code == ErrorCode.OK

6 changes: 6 additions & 0 deletions src/common/stl.cppm
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@ module;
#include <forward_list>
#include <functional>
#include <iomanip>
#include <ios>
#include <iostream>
#include <iterator>
#include <list>
Expand Down Expand Up @@ -160,6 +161,11 @@ using std::stable_sort;
using std::tie;
using std::transform;
using std::unique;
using std::setprecision;
using std::fixed;

using std::string;
using std::stringstream;

namespace ranges {

Expand Down
4 changes: 2 additions & 2 deletions src/function/builtin_functions.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,7 @@ import default_values;
import special_function;
import internal_types;
import data_type;

import trunc;
import logical_type;

namespace infinity {
Expand Down Expand Up @@ -118,7 +118,7 @@ void BuiltinFunctions::RegisterScalarFunction() {
RegisterIsnanFunction(catalog_ptr_);
RegisterIsinfFunction(catalog_ptr_);
RegisterIsfiniteFunction(catalog_ptr_);

RegisterTruncFunction(catalog_ptr_);
// register comparison operator
RegisterEqualsFunction(catalog_ptr_);
RegisterInEqualFunction(catalog_ptr_);
Expand Down
113 changes: 113 additions & 0 deletions src/function/scalar/truncate.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,113 @@
// Copyright(C) 2025 InfiniFlow, Inc. All rights reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// https://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

module;
#include <cstdio>
module trunc;
import stl;
import catalog;
import status;
import logical_type;
import infinity_exception;
import scalar_function;
import scalar_function_set;
import third_party;
import internal_types;
import data_type;
import column_vector;

namespace infinity {

struct TruncFunction {
template <typename TA, typename TB, typename TC, typename TD>
static inline void Run(TA left, TB right, TC &result, TD result_ptr) {
Status status = Status::NotSupport("Not implemented");
RecoverableError(status);
}

};

/// trunc(DOUBLE, BIGINT) -> VARCHAR.
///
/// Formats `left` in fixed-point notation with `right` fractional digits
/// (capped at 17) and appends the text, prefixed with one space, to `result_ptr`.
///
/// @param left        value to format
/// @param right       requested number of fractional digits; negative values are
///                    rejected with Status::InvalidDataType()
/// @param result      varchar slot filled through ColumnVector::AppendVarcharInner
/// @param result_ptr  column vector that receives the string
///
/// NOTE(review): "%.*f" rounds to nearest rather than truncating toward zero
/// (e.g. 2.129 with 2 digits prints "2.13") — confirm this is the intended
/// semantics for a function named trunc.
template <>
inline void TruncFunction::Run(DoubleT left, BigIntT right, VarcharT &result, ColumnVector *result_ptr) {
    constexpr int MaxRight = 17;
    // Worst case for a finite double in "%f" form:
    // leading ' ' + sign + 309 integer digits (DBL_MAX) + '.' + MaxRight digits + NUL.
    constexpr int BufferSize = 1 + 1 + 309 + 1 + MaxRight + 1;
    // Bytes available to snprintf after the leading space (including the NUL).
    constexpr int FormatCapacity = BufferSize - 1;

    // `right` is an integer, so only the sign needs validating.
    if (right < static_cast<BigIntT>(0)) {
        Status status = Status::InvalidDataType();
        RecoverableError(status);
        return;
    }

    const int precision = static_cast<int>(right > MaxRight ? MaxRight : right);

    char buffer[BufferSize];
    // The leading space is part of the produced value; clients appear to use it
    // to tell formatted strings apart from numeric results.
    buffer[0] = ' ';

    const int len = std::snprintf(buffer + 1, FormatCapacity, "%.*f", precision, left);
    // snprintf returns the length it *wanted* to write; reject both encoding
    // errors and truncation so the string below never reads past the buffer.
    if (len < 0 || len >= FormatCapacity) {
        Status status = Status::InvalidDataType();
        RecoverableError(status);
        return;
    }

    std::string truncated_str(buffer, len + 1);
    result_ptr->AppendVarcharInner(truncated_str, result);
}

/// trunc(FLOAT, BIGINT) -> VARCHAR.
///
/// Formats `left` in fixed-point notation with `right` fractional digits
/// (capped at 7) and appends the text, prefixed with one space, to `result_ptr`.
///
/// @param left        value to format
/// @param right       requested number of fractional digits; negative values are
///                    rejected with Status::InvalidDataType()
/// @param result      varchar slot filled through ColumnVector::AppendVarcharInner
/// @param result_ptr  column vector that receives the string
///
/// NOTE(review): "%.*f" rounds to nearest rather than truncating toward zero —
/// confirm this is the intended semantics for a function named trunc.
template <>
inline void TruncFunction::Run(FloatT left, BigIntT right, VarcharT &result, ColumnVector *result_ptr) {
    constexpr int MaxRight = 7;
    // Worst case for a finite float in "%f" form:
    // leading ' ' + sign + 39 integer digits (FLT_MAX) + '.' + MaxRight digits + NUL.
    constexpr int BufferSize = 1 + 1 + 39 + 1 + MaxRight + 1;
    // Bytes available to snprintf after the leading space (including the NUL).
    constexpr int FormatCapacity = BufferSize - 1;

    // `right` is an integer, so only the sign needs validating.
    if (right < static_cast<BigIntT>(0)) {
        Status status = Status::InvalidDataType();
        RecoverableError(status);
        return;
    }

    const int precision = static_cast<int>(right > MaxRight ? MaxRight : right);

    char buffer[BufferSize];
    // The leading space is part of the produced value; clients appear to use it
    // to tell formatted strings apart from numeric results.
    buffer[0] = ' ';

    // A float passed through "..." is promoted to double; make that explicit.
    const int len = std::snprintf(buffer + 1, FormatCapacity, "%.*f", precision, static_cast<double>(left));
    // snprintf returns the length it *wanted* to write; reject both encoding
    // errors and truncation so the string below never reads past the buffer.
    if (len < 0 || len >= FormatCapacity) {
        Status status = Status::InvalidDataType();
        RecoverableError(status);
        return;
    }

    std::string truncated_str(buffer, len + 1);
    result_ptr->AppendVarcharInner(truncated_str, result);
}


void RegisterTruncFunction(const UniquePtr<Catalog> &catalog_ptr) {
String func_name = "trunc";

SharedPtr<ScalarFunctionSet> function_set_ptr = MakeShared<ScalarFunctionSet>(func_name);

ScalarFunction truncate_double_bigint(func_name,
{DataType(LogicalType::kDouble), DataType(LogicalType::kBigInt)},
DataType(LogicalType::kVarchar),
&ScalarFunction::BinaryFunctionToVarlen<DoubleT, BigIntT, VarcharT, TruncFunction>);
function_set_ptr->AddFunction(truncate_double_bigint);

ScalarFunction truncate_float_bigint(func_name,
{DataType(LogicalType::kFloat), DataType(LogicalType::kBigInt)},
DataType(LogicalType::kVarchar),
&ScalarFunction::BinaryFunctionToVarlen<FloatT, BigIntT, VarcharT, TruncFunction>);
function_set_ptr->AddFunction(truncate_float_bigint);


Catalog::AddFunctionSet(catalog_ptr.get(), function_set_ptr);
}

} // namespace infinity
12 changes: 12 additions & 0 deletions src/function/scalar/truncate.cppm
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
module;

export module trunc;

import stl;

namespace infinity {

// Forward declaration: this interface only names Catalog through a UniquePtr reference.
class Catalog;
// Registers the "trunc" scalar function with the catalog referenced by catalog_ptr.
export void RegisterTruncFunction(const UniquePtr<Catalog> &catalog_ptr);

}
167 changes: 167 additions & 0 deletions src/unit_test/function/scalar/truncate_functions.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,167 @@
// Copyright(C) 2025 InfiniFlow, Inc. All rights reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// https://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.


#include "gtest/gtest.h"

import stl;
import base_test;
import infinity_exception;
import infinity_context;

import catalog;
import logger;

import default_values;
import value;

import base_expression;
import column_expression;
import column_vector;
import data_block;

import function_set;
import function;

import global_resource_usage;

import data_type;
import internal_types;
import logical_type;

import scalar_function;
import scalar_function_set;

import trunc;
import third_party;

using namespace infinity;

// Parameterized fixture for the trunc() tests; it adds nothing on top of BaseTestParamStr.
class TruncateFunctionsTest : public BaseTestParamStr {};

// Instantiate the suite with BaseTestParamStr::NULL_CONFIG_PATH as its only parameter value.
INSTANTIATE_TEST_SUITE_P(TestWithDifferentParams, TruncateFunctionsTest, ::testing::Values(BaseTestParamStr::NULL_CONFIG_PATH));

// Registers trunc() in a fresh catalog, resolves the (Float, BigInt) and
// (Double, BigInt) overloads, and runs each over a full vector of rows where
// row i holds value i and digit count i.
//
// NOTE: only the logical type of each result row is asserted; the produced
// text itself is not checked here.
TEST_P(TruncateFunctionsTest, truncate_func) {
    using namespace infinity;

    UniquePtr<Catalog> catalog_ptr = MakeUnique<Catalog>();

    RegisterTruncFunction(catalog_ptr);

    String op = "trunc";

    SharedPtr<FunctionSet> function_set = Catalog::GetFunctionSetByName(catalog_ptr.get(), op);
    EXPECT_EQ(function_set->type_, FunctionType::kScalar);
    SharedPtr<ScalarFunctionSet> scalar_function_set = std::static_pointer_cast<ScalarFunctionSet>(function_set);

    // trunc(Float, BigInt) -> Varchar
    {
        Vector<SharedPtr<BaseExpression>> inputs;

        DataType data_type1(LogicalType::kFloat);
        DataType data_type2(LogicalType::kBigInt);
        SharedPtr<DataType> result_type = MakeShared<DataType>(LogicalType::kVarchar);
        SharedPtr<ColumnExpression> col1_expr_ptr = MakeShared<ColumnExpression>(data_type1, "t1", 1, "c1", 0, 0);
        SharedPtr<ColumnExpression> col2_expr_ptr = MakeShared<ColumnExpression>(data_type2, "t1", 1, "c2", 1, 0);

        inputs.emplace_back(col1_expr_ptr);
        inputs.emplace_back(col2_expr_ptr);

        ScalarFunction func = scalar_function_set->GetMostMatchFunction(inputs);
        EXPECT_STREQ("trunc(Float, BigInt)->Varchar", func.ToString().c_str());

        Vector<SharedPtr<DataType>> column_types;
        column_types.emplace_back(MakeShared<DataType>(data_type1));
        column_types.emplace_back(MakeShared<DataType>(data_type2));

        SizeT row_count = DEFAULT_VECTOR_SIZE;

        DataBlock data_block;
        data_block.Init(column_types);

        for (SizeT i = 0; i < row_count; ++i) {
            data_block.AppendValue(0, Value::MakeFloat(static_cast<f32>(i)));
            data_block.AppendValue(1, Value::MakeBigInt(static_cast<i64>(i)));
        }
        data_block.Finalize();

        // Sanity-check that the inputs round-trip through the data block.
        for (SizeT i = 0; i < row_count; ++i) {
            Value v1 = data_block.GetValue(0, i);
            Value v2 = data_block.GetValue(1, i);
            EXPECT_EQ(v1.type_.type(), LogicalType::kFloat);
            EXPECT_EQ(v2.type_.type(), LogicalType::kBigInt);
            EXPECT_FLOAT_EQ(v1.value_.float32, static_cast<f32>(i));
            EXPECT_EQ(v2.value_.big_int, static_cast<i64>(i));
        }

        SharedPtr<ColumnVector> result = MakeShared<ColumnVector>(result_type);
        result->Initialize();
        func.function_(data_block, result);

        for (SizeT i = 0; i < row_count; ++i) {
            Value v = result->GetValue(i);
            EXPECT_EQ(v.type_.type(), LogicalType::kVarchar);
        }
    }

    // trunc(Double, BigInt) -> Varchar
    {
        Vector<SharedPtr<BaseExpression>> inputs;

        DataType data_type1(LogicalType::kDouble);
        DataType data_type2(LogicalType::kBigInt);
        SharedPtr<DataType> result_type = MakeShared<DataType>(LogicalType::kVarchar);
        SharedPtr<ColumnExpression> col1_expr_ptr = MakeShared<ColumnExpression>(data_type1, "t1", 1, "c1", 0, 0);
        SharedPtr<ColumnExpression> col2_expr_ptr = MakeShared<ColumnExpression>(data_type2, "t1", 1, "c2", 1, 0);

        inputs.emplace_back(col1_expr_ptr);
        inputs.emplace_back(col2_expr_ptr);

        ScalarFunction func = scalar_function_set->GetMostMatchFunction(inputs);
        EXPECT_STREQ("trunc(Double, BigInt)->Varchar", func.ToString().c_str());

        Vector<SharedPtr<DataType>> column_types;
        column_types.emplace_back(MakeShared<DataType>(data_type1));
        column_types.emplace_back(MakeShared<DataType>(data_type2));

        SizeT row_count = DEFAULT_VECTOR_SIZE;

        DataBlock data_block;
        data_block.Init(column_types);

        for (SizeT i = 0; i < row_count; ++i) {
            data_block.AppendValue(0, Value::MakeDouble(static_cast<f64>(i)));
            data_block.AppendValue(1, Value::MakeBigInt(static_cast<i64>(i)));
        }
        data_block.Finalize();

        // Sanity-check that the inputs round-trip through the data block.
        for (SizeT i = 0; i < row_count; ++i) {
            Value v1 = data_block.GetValue(0, i);
            Value v2 = data_block.GetValue(1, i);
            EXPECT_EQ(v1.type_.type(), LogicalType::kDouble);
            EXPECT_EQ(v2.type_.type(), LogicalType::kBigInt);
            // Compare at double precision; EXPECT_FLOAT_EQ would compare as float.
            EXPECT_DOUBLE_EQ(v1.value_.float64, static_cast<f64>(i));
            EXPECT_EQ(v2.value_.big_int, static_cast<i64>(i));
        }

        SharedPtr<ColumnVector> result = MakeShared<ColumnVector>(result_type);
        result->Initialize();
        func.function_(data_block, result);

        for (SizeT i = 0; i < row_count; ++i) {
            Value v = result->GetValue(i);
            EXPECT_EQ(v.type_.type(), LogicalType::kVarchar);
        }
    }
}

0 comments on commit d14df22

Please sign in to comment.