blob: 5703ec193b83c57ab30f0ae0b6bf04d088904f4c [file] [log] [blame]
// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.
#include <gtest/gtest-message.h>
#include <gtest/gtest-test-part.h>
#include <gtest/gtest.h>
#include <stdint.h>
#include <memory>
#include <random>
#include <string>
#include "gtest/gtest_pred_impl.h"
#include "runtime/primitive_type.h"
#include "vec/aggregate_functions/aggregate_function.h"
#include "vec/aggregate_functions/aggregate_function_simple_factory.h"
#include "vec/aggregate_functions/aggregate_function_topn.h"
#include "vec/columns/column.h"
#include "vec/columns/column_nullable.h"
#include "vec/columns/column_string.h"
#include "vec/columns/column_vector.h"
#include "vec/core/field.h"
#include "vec/data_types/data_type_date.h"
#include "vec/data_types/data_type_date_time.h"
#include "vec/data_types/data_type_nullable.h"
#include "vec/data_types/data_type_number.h"
#include "vec/data_types/data_type_string.h"
#include "vec/exprs/vectorized_agg_fn.h"
// Number of rows generated and fed to the aggregate functions by the
// batch-oriented tests in this file.
constexpr int agg_test_batch_size = 4096;
namespace doris::vectorized {
// Forward declarations of the aggregate-function registration hooks used by the
// tests in this file. Their definitions are not part of this file; each test
// calls the relevant hook on a local AggregateFunctionSimpleFactory before
// looking a function up by name with factory.get(...).
// Called before looking up "sum".
void register_aggregate_function_sum(AggregateFunctionSimpleFactory& factory);
// Called before looking up "topn".
void register_aggregate_function_topn(AggregateFunctionSimpleFactory& factory);
// Called before looking up "group_bit_or".
void register_aggregate_function_bit(AggregateFunctionSimpleFactory& factory);
// Called before looking up "min" and "max".
void register_aggregate_function_minmax(AggregateFunctionSimpleFactory& factory);
// Called before looking up "replace_load" and "replace_if_not_null_load".
void register_aggregate_function_replace_reader_load(AggregateFunctionSimpleFactory& factory);
// Feeds the integers [0, agg_test_batch_size) into the "sum" aggregate over an
// Int32 column, one row at a time, and compares the aggregate state with a
// locally computed total.
TEST(AggTest, basic_test) {
    Arena arena;

    // Input column holding 0, 1, ..., agg_test_batch_size - 1.
    auto int32_column = ColumnInt32::create();
    for (int value = 0; value < agg_test_batch_size; ++value) {
        int32_column->insert(Field::create_field<TYPE_INT>(cast_to_nearest_field_type(value)));
    }

    // Obtain the function through the factory interface.
    AggregateFunctionSimpleFactory factory;
    register_aggregate_function_sum(factory);
    DataTypes argument_types = {std::make_shared<DataTypeInt32>()};
    auto sum_function = factory.get("sum", argument_types, nullptr, false, -1);

    // The test owns the raw storage for the aggregate state.
    std::unique_ptr<char[]> state_buffer(new char[sum_function->size_of_data()]);
    AggregateDataPtr place = state_buffer.get();
    sum_function->create(place);

    const IColumn* columns[1] = {int32_column.get()};
    for (int row = 0; row < agg_test_batch_size; ++row) {
        sum_function->add(place, columns, row, arena);
    }

    // 0 + 1 + ... + (N - 1) == N * (N - 1) / 2.
    const int expected_sum = agg_test_batch_size * (agg_test_batch_size - 1) / 2;

    // The state is inspected by reading the first 32 bits of the state memory.
    EXPECT_EQ(expected_sum, *reinterpret_cast<int32_t*>(place));

    sum_function->destroy(place);
}
// Runs the "topn" aggregate over (string, int32) rows. Row i carries the key
// to_string(agg_test_batch_size / (i + 1)) and the constant 10 in the second
// column. The serialized state is compared with a fixed string that pairs each
// key with the number of rows that produced it (e.g. 2048 rows yield "1").
TEST(AggTest, topn_test) {
    Arena arena;

    MutableColumns input_columns(2);
    input_columns[0] = ColumnString::create();
    input_columns[1] = ColumnInt32::create();

    // Kept non-const: its bytes are handed to insert_data() through a char*.
    int top_count = 10;
    for (int row = 0; row < agg_test_batch_size; ++row) {
        const std::string key = std::to_string(agg_test_batch_size / (row + 1));
        input_columns[0]->insert_data(key.c_str(), key.length());
        input_columns[1]->insert_data(reinterpret_cast<char*>(&top_count), sizeof(top_count));
    }

    AggregateFunctionSimpleFactory factory;
    register_aggregate_function_topn(factory);
    DataTypes argument_types = {std::make_shared<DataTypeString>(),
                                std::make_shared<DataTypeInt32>()};
    auto topn_function = factory.get("topn", argument_types, nullptr, false, -1);

    std::unique_ptr<char[]> state_buffer(new char[topn_function->size_of_data()]);
    AggregateDataPtr place = state_buffer.get();
    topn_function->create(place);

    // Declared as pointers-to-const up front so add() can take the array as is.
    const IColumn* columns[2] = {input_columns[0].get(), input_columns[1].get()};
    for (int row = 0; row < agg_test_batch_size; ++row) {
        topn_function->add(place, columns, row, arena);
    }

    const std::string expected =
            R"({"1":2048,"2":683,"3":341,"4":205,"5":137,"6":97,"7":73,"8":57,"9":46,"10":37})";
    const std::string actual =
            reinterpret_cast<AggregateFunctionTopNData<TYPE_STRING>*>(place)->get();
    EXPECT_EQ(actual, expected);

    topn_function->destroy(place);
}
// Checks the state size and alignment reported by "group_bit_or" over a
// nullable Int8 argument, once requested with is_window_function = true and
// once with is_window_function = false.
TEST(AggTest, window_function_test) {
    AggregateFunctionSimpleFactory factory;
    register_aggregate_function_bit(factory);
    DataTypes argument_types = {make_nullable(std::make_shared<DataTypeInt8>())};

    // Requested as a window function.
    auto window_variant = factory.get("group_bit_or", argument_types, nullptr, true, -1,
                                      {.is_window_function = true, .column_names = {}});
    EXPECT_EQ(window_variant->size_of_data(), 6);
    EXPECT_EQ(window_variant->align_of_data(), 4);

    // Requested as a plain aggregate.
    auto plain_variant = factory.get("group_bit_or", argument_types, nullptr, true, -1,
                                     {.is_window_function = false, .column_names = {}});
    EXPECT_EQ(plain_variant->size_of_data(), 2);
    EXPECT_EQ(plain_variant->align_of_data(), 1);
}
// Packs the states of three window-function aggregates ("sum", "group_bit_or",
// "topn") back to back into one aligned allocation and creates each state at
// its computed offset. The function order is shuffled five times so that
// different size/alignment sequences go through the layout computation.
//
// Compared with a naive version of this test:
//  * the number of functions is taken from the vector rather than hard-coded,
//    which also keeps the loop comparisons unsigned;
//  * every created state is destroyed again, so states that own heap memory
//    do not leak;
//  * a failed allocation stops the round instead of creating states on null.
TEST(AggTest, window_function_test2) {
    AggregateFunctionSimpleFactory factory;
    register_aggregate_function_sum(factory);
    register_aggregate_function_topn(factory);
    register_aggregate_function_bit(factory);

    const bool is_window_function = true;
    auto agg_function_sum =
            factory.get("sum", {std::make_shared<DataTypeInt32>()}, nullptr, true, -1,
                        {.is_window_function = is_window_function, .column_names = {}});
    auto agg_function_group =
            factory.get("group_bit_or", {make_nullable(std::make_shared<DataTypeInt8>())}, nullptr,
                        true, -1, {.is_window_function = is_window_function, .column_names = {}});
    auto agg_function_topn = factory.get(
            "topn", {std::make_shared<DataTypeString>(), std::make_shared<DataTypeInt32>()},
            nullptr, true, -1, {.is_window_function = is_window_function, .column_names = {}});

    std::vector<vectorized::AggregateFunctionPtr> agg_functions;
    agg_functions.push_back(agg_function_sum);
    agg_functions.push_back(agg_function_group);
    agg_functions.push_back(agg_function_topn);

    // Computes the packed layout for the current order of `agg_functions`,
    // allocates it, then creates and destroys every state in place.
    auto layout_create_and_destroy = [&]() {
        const size_t agg_functions_size = agg_functions.size();

        /// The offset of the n-th function's state inside the allocation.
        std::vector<size_t> offsets_of_aggregate_states(agg_functions_size);
        /// The total size of the row holding all states.
        size_t total_size_of_aggregate_states = 0;
        /// The maximum alignment required by any of the functions.
        size_t align_aggregate_states = 1;

        for (size_t i = 0; i < agg_functions_size; ++i) {
            offsets_of_aggregate_states[i] = total_size_of_aggregate_states;
            // The allocation as a whole is aligned to the strictest requirement.
            align_aggregate_states =
                    std::max(align_aggregate_states, agg_functions[i]->align_of_data());
            total_size_of_aggregate_states += agg_functions[i]->size_of_data();

            // If this is not the last state, pad so that the next state starts aligned.
            if (i + 1 < agg_functions_size) {
                const size_t alignment_of_next_state = agg_functions[i + 1]->align_of_data();
                if ((alignment_of_next_state & (alignment_of_next_state - 1)) != 0) {
                    DCHECK(false) << "Logical error: align_of_data is not 2^N "
                                  << alignment_of_next_state;
                }
                /// Round the running total up to a multiple of the next state's alignment.
                total_size_of_aggregate_states =
                        (total_size_of_aggregate_states + alignment_of_next_state - 1) /
                        alignment_of_next_state * alignment_of_next_state;
            }
        }

        vectorized::Arena agg_arena_pool;
        vectorized::AggregateDataPtr fn_place_ptr = agg_arena_pool.aligned_alloc(
                total_size_of_aggregate_states, align_aggregate_states);
        // Fatal for this round: creating states on a null base pointer would crash.
        ASSERT_TRUE(fn_place_ptr != nullptr);

        for (size_t i = 0; i < agg_functions_size; ++i) {
            agg_functions[i]->create(fn_place_ptr + offsets_of_aggregate_states[i]);
        }
        // Release whatever the states own before the arena backing them goes away.
        for (size_t i = 0; i < agg_functions_size; ++i) {
            agg_functions[i]->destroy(fn_place_ptr + offsets_of_aggregate_states[i]);
        }
    };

    std::random_device rd;
    std::mt19937 g(rd());
    for (int i = 0; i < 5; ++i) {
        std::shuffle(agg_functions.begin(), agg_functions.end(), g);
        layout_create_and_destroy();
    }
}
// Runs "min" and "max" over a ColumnDateTime filled with five raw 64-bit
// values (written so their decimal digits read as YYYYMMDDHHmmss) and checks
// that the smallest and largest raw values come back.
TEST(AggTest, datetime_min_max_test) {
    Arena arena;

    const std::vector<int64_t> raw_values = {
            20230101120000LL, // 2023-01-01 12:00:00
            20230615093000LL, // 2023-06-15 09:30:00
            20220315180000LL, // 2022-03-15 18:00:00
            20240801220000LL, // 2024-08-01 22:00:00
            20230301060000LL  // 2023-03-01 06:00:00
    };

    auto input_column = ColumnDateTime::create();
    for (const int64_t raw : raw_values) {
        input_column->insert_data(reinterpret_cast<const char*>(&raw), sizeof(int64_t));
    }

    // Builds a fresh factory, aggregates every input row with the named
    // function and returns the single value it writes into a result column.
    const auto aggregate_all_rows = [&](const char* function_name) -> int64_t {
        AggregateFunctionSimpleFactory factory;
        register_aggregate_function_minmax(factory);
        DataTypePtr datetime_type = std::make_shared<DataTypeDateTime>();
        DataTypes argument_types = {datetime_type};
        auto function = factory.get(function_name, argument_types, datetime_type, false, -1);

        std::unique_ptr<char[]> state_buffer(new char[function->size_of_data()]);
        AggregateDataPtr place = state_buffer.get();
        function->create(place);

        const IColumn* columns[1] = {input_column.get()};
        const size_t row_count = raw_values.size();
        for (size_t row = 0; row < row_count; ++row) {
            function->add(place, columns, row, arena);
        }

        // Read the result back by inserting it into a column of the same type.
        auto result_column = ColumnDateTime::create();
        function->insert_result_into(place, *result_column);
        const int64_t aggregated = result_column->get_element(0);

        function->destroy(place);
        return aggregated;
    };

    EXPECT_EQ(aggregate_all_rows("min"), 20220315180000LL); // Minimum datetime
    EXPECT_EQ(aggregate_all_rows("max"), 20240801220000LL); // Maximum datetime
}
// Runs "replace_load" over a ColumnDate filled with five raw 8-byte values and
// expects the value of the last row added to be returned.
TEST(AggTest, date_replace_test) {
    Arena arena;

    const std::vector<int64_t> raw_dates = {100, 200, 300, 400, 500};
    auto input_column = ColumnDate::create();
    for (const int64_t raw : raw_dates) {
        input_column->insert_data(reinterpret_cast<const char*>(&raw), sizeof(int64_t));
    }

    AggregateFunctionSimpleFactory factory;
    register_aggregate_function_replace_reader_load(factory);
    DataTypePtr date_type = std::make_shared<DataTypeDate>();
    DataTypes argument_types = {date_type};
    auto replace_function = factory.get("replace_load", argument_types, date_type, false, -1);

    std::unique_ptr<char[]> state_buffer(new char[replace_function->size_of_data()]);
    AggregateDataPtr place = state_buffer.get();
    replace_function->create(place);

    const IColumn* columns[1] = {input_column.get()};
    const size_t row_count = raw_dates.size();
    for (size_t row = 0; row < row_count; ++row) {
        replace_function->add(place, columns, row, arena);
    }

    // The test expects REPLACE to keep the last value it was given.
    auto result_column = ColumnDate::create();
    replace_function->insert_result_into(place, *result_column);
    const int64_t last_value = result_column->get_element(0);
    EXPECT_EQ(last_value, 500LL);

    replace_function->destroy(place);
}
// Runs "replace_if_not_null_load" over a nullable DateTime column whose second
// row is flagged null. The final row is not null, so the test expects that
// row's raw value to be returned.
TEST(AggTest, datetime_replace_if_not_null_test) {
    Arena arena;

    // Raw 64-bit values (digits read as YYYYMMDDHHmmss) and their null flags.
    const std::vector<int64_t> raw_values = {20230101120000LL, 20230615093000LL,
                                             20220315180000LL, 20240801220000LL,
                                             20230301060000LL};
    const std::vector<uint8_t> null_flags = {0, 1, 0, 0, 0}; // Second value is null

    auto nested_input = ColumnDateTime::create();
    auto null_map = ColumnUInt8::create();
    for (const int64_t raw : raw_values) {
        nested_input->insert_data(reinterpret_cast<const char*>(&raw), sizeof(int64_t));
    }
    for (const uint8_t flag : null_flags) {
        null_map->insert_data(reinterpret_cast<const char*>(&flag), sizeof(uint8_t));
    }
    auto nullable_input = ColumnNullable::create(std::move(nested_input), std::move(null_map));

    AggregateFunctionSimpleFactory factory;
    register_aggregate_function_replace_reader_load(factory);
    DataTypePtr nullable_type = make_nullable(std::make_shared<DataTypeDateTime>());
    DataTypes argument_types = {nullable_type};
    auto replace_function =
            factory.get("replace_if_not_null_load", argument_types, nullable_type, true, -1);

    std::unique_ptr<char[]> state_buffer(new char[replace_function->size_of_data()]);
    AggregateDataPtr place = state_buffer.get();
    replace_function->create(place);

    const IColumn* columns[1] = {nullable_input.get()};
    const size_t row_count = raw_values.size();
    for (size_t row = 0; row < row_count; ++row) {
        replace_function->add(place, columns, row, arena);
    }

    // The input is nullable, so the result is collected in a nullable column.
    auto result_column = ColumnNullable::create(ColumnDateTime::create(), ColumnUInt8::create());
    replace_function->insert_result_into(place, *result_column);

    // Read the value out of the nested (non-nullable) part of the result.
    const auto& nested_result =
            assert_cast<const ColumnDateTime&>(result_column->get_nested_column());
    const int64_t last_non_null = nested_result.get_element(0);
    EXPECT_EQ(last_non_null, 20230301060000LL); // Last value

    replace_function->destroy(place);
}
} // namespace doris::vectorized