| // Licensed to the Apache Software Foundation (ASF) under one |
| // or more contributor license agreements. See the NOTICE file |
| // distributed with this work for additional information |
| // regarding copyright ownership. The ASF licenses this file |
| // to you under the Apache License, Version 2.0 (the |
| // "License"); you may not use this file except in compliance |
| // with the License. You may obtain a copy of the License at |
| // |
| // http://www.apache.org/licenses/LICENSE-2.0 |
| // |
| // Unless required by applicable law or agreed to in writing, |
| // software distributed under the License is distributed on an |
| // "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY |
| // KIND, either express or implied. See the License for the |
| // specific language governing permissions and limitations |
| // under the License. |
| |
| #include <gtest/gtest-message.h> |
| #include <gtest/gtest-test-part.h> |
| #include <stddef.h> |
| #include <stdint.h> |
| |
| #include <memory> |
| #include <ostream> |
| #include <string> |
| |
| #include "agg_function_test.h" |
| #include "common/logging.h" |
| #include "gtest/gtest_pred_impl.h" |
| #include "vec/aggregate_functions/aggregate_function.h" |
| #include "vec/aggregate_functions/aggregate_function_simple_factory.h" |
| #include "vec/columns/column_array.h" |
| #include "vec/columns/column_string.h" |
| #include "vec/columns/column_vector.h" |
| #include "vec/common/arena.h" |
| #include "vec/common/string_buffer.hpp" |
| #include "vec/core/types.h" |
| #include "vec/data_types/data_type.h" |
| #include "vec/data_types/data_type_array.h" |
| #include "vec/data_types/data_type_date_or_datetime_v2.h" |
| #include "vec/data_types/data_type_decimal.h" |
| #include "vec/data_types/data_type_nullable.h" |
| #include "vec/data_types/data_type_number.h" |
| #include "vec/data_types/data_type_string.h" |
| |
| namespace doris { |
| namespace vectorized { |
| class IColumn; |
| } // namespace vectorized |
| } // namespace doris |
| |
| namespace doris::vectorized { |
| |
| void register_aggregate_function_collect_list(AggregateFunctionSimpleFactory& factory); |
| |
| class VAggCollectTest : public testing::Test { |
| public: |
| void SetUp() { |
| AggregateFunctionSimpleFactory factory = AggregateFunctionSimpleFactory::instance(); |
| register_aggregate_function_collect_list(factory); |
| } |
| |
| void TearDown() {} |
| |
| bool is_distinct(const std::string& fn_name) { return fn_name == "collect_set"; } |
| |
| template <typename DataType> |
| void agg_collect_add_elements(AggregateFunctionPtr agg_function, AggregateDataPtr place, |
| size_t input_nums, bool support_complex = false) { |
| using FieldType = typename DataType::FieldType; |
| MutableColumnPtr input_col; |
| if (support_complex) { |
| auto type = |
| std::make_shared<DataTypeArray>(make_nullable(std::make_shared<DataType>())); |
| input_col = type->create_column(); |
| } else { |
| auto type = std::make_shared<DataType>(); |
| input_col = type->create_column(); |
| } |
| for (size_t i = 0; i < input_nums; ++i) { |
| for (size_t j = 0; j < _repeated_times; ++j) { |
| if (support_complex) { |
| if constexpr (std::is_same_v<DataType, DataTypeString>) { |
| Array vec1 = {Field::create_field<TYPE_STRING>( |
| String("item0" + std::to_string(i))), |
| Field::create_field<TYPE_STRING>( |
| String("item1" + std::to_string(i)))}; |
| input_col->insert(Field::create_field<TYPE_ARRAY>(vec1)); |
| } else { |
| input_col->insert_default(); |
| } |
| continue; |
| } |
| if constexpr (std::is_same_v<DataType, DataTypeString>) { |
| auto item = std::string("item") + std::to_string(i); |
| input_col->insert_data(item.c_str(), item.size()); |
| } else { |
| auto item = FieldType(static_cast<uint64_t>(i)); |
| input_col->insert_data(reinterpret_cast<const char*>(&item), 0); |
| } |
| } |
| } |
| EXPECT_EQ(input_col->size(), input_nums * _repeated_times); |
| |
| const IColumn* column[1] = {input_col.get()}; |
| for (int i = 0; i < input_col->size(); i++) { |
| agg_function->add(place, column, i, _agg_arena_pool); |
| } |
| } |
| |
| template <typename DataType> |
| void test_agg_collect(const std::string& fn_name, size_t input_nums = 0, |
| bool support_complex = false) { |
| DataTypes data_types = {(DataTypePtr)std::make_shared<DataType>()}; |
| if (support_complex) { |
| data_types = { |
| (DataTypePtr)std::make_shared<DataTypeArray>(make_nullable(data_types[0]))}; |
| } |
| LOG(INFO) << "test_agg_collect for " << fn_name << "(" << data_types[0]->get_name() << ")"; |
| AggregateFunctionSimpleFactory factory = AggregateFunctionSimpleFactory::instance(); |
| auto agg_function = factory.get(fn_name, data_types, nullptr, false, -1); |
| EXPECT_NE(agg_function, nullptr); |
| |
| std::unique_ptr<char[]> memory(new char[agg_function->size_of_data()]); |
| AggregateDataPtr place = memory.get(); |
| agg_function->create(place); |
| |
| agg_collect_add_elements<DataType>(agg_function, place, input_nums, support_complex); |
| |
| ColumnString buf; |
| VectorBufferWriter buf_writer(buf); |
| agg_function->serialize(place, buf_writer); |
| buf_writer.commit(); |
| VectorBufferReader buf_reader(buf.get_data_at(0)); |
| agg_function->deserialize(place, buf_reader, _agg_arena_pool); |
| |
| std::unique_ptr<char[]> memory2(new char[agg_function->size_of_data()]); |
| AggregateDataPtr place2 = memory2.get(); |
| agg_function->create(place2); |
| |
| agg_collect_add_elements<DataType>(agg_function, place2, input_nums, support_complex); |
| |
| agg_function->merge(place, place2, _agg_arena_pool); |
| auto column_result = |
| ColumnArray::create(std::move(make_nullable(data_types[0]->create_column()))); |
| agg_function->insert_result_into(place, column_result->assume_mutable_ref()); |
| EXPECT_EQ(column_result->size(), 1); |
| EXPECT_EQ(column_result->get_offsets()[0], |
| is_distinct(fn_name) ? input_nums : 2 * input_nums * _repeated_times); |
| |
| auto column_result2 = |
| ColumnArray::create(std::move(make_nullable(data_types[0]->create_column()))); |
| agg_function->insert_result_into(place2, column_result2->assume_mutable_ref()); |
| EXPECT_EQ(column_result2->size(), 1); |
| EXPECT_EQ(column_result2->get_offsets()[0], |
| is_distinct(fn_name) ? input_nums : input_nums * _repeated_times); |
| |
| agg_function->destroy(place); |
| agg_function->destroy(place2); |
| } |
| |
| private: |
| const size_t _repeated_times = 2; |
| vectorized::Arena _agg_arena_pool; |
| }; |
| |
| TEST_F(VAggCollectTest, test_empty) { |
| test_agg_collect<DataTypeInt8>("collect_list"); |
| test_agg_collect<DataTypeInt8>("collect_set"); |
| test_agg_collect<DataTypeInt16>("collect_list"); |
| test_agg_collect<DataTypeInt16>("collect_set"); |
| test_agg_collect<DataTypeInt32>("collect_list"); |
| test_agg_collect<DataTypeInt32>("collect_set"); |
| test_agg_collect<DataTypeInt64>("collect_list"); |
| test_agg_collect<DataTypeInt64>("collect_set"); |
| test_agg_collect<DataTypeInt128>("collect_list"); |
| test_agg_collect<DataTypeInt128>("collect_set"); |
| |
| test_agg_collect<DataTypeDecimalV2>("collect_list"); |
| test_agg_collect<DataTypeDecimalV2>("collect_set"); |
| |
| test_agg_collect<DataTypeDateV2>("collect_list"); |
| test_agg_collect<DataTypeDateV2>("collect_set"); |
| |
| test_agg_collect<DataTypeString>("collect_list"); |
| test_agg_collect<DataTypeString>("collect_set"); |
| } |
| |
| TEST_F(VAggCollectTest, test_with_data) { |
| test_agg_collect<DataTypeInt32>("collect_list", 7); |
| test_agg_collect<DataTypeInt32>("collect_set", 9); |
| test_agg_collect<DataTypeInt128>("collect_list", 20); |
| test_agg_collect<DataTypeInt128>("collect_set", 30); |
| |
| test_agg_collect<DataTypeDecimalV2>("collect_list", 10); |
| test_agg_collect<DataTypeDecimalV2>("collect_set", 11); |
| |
| test_agg_collect<DataTypeDateTimeV2>("collect_list", 5); |
| test_agg_collect<DataTypeDateTimeV2>("collect_set", 6); |
| |
| test_agg_collect<DataTypeString>("collect_list", 10); |
| test_agg_collect<DataTypeString>("collect_set", 5); |
| } |
| |
| TEST_F(VAggCollectTest, test_complex_data_type) { |
| test_agg_collect<DataTypeInt8>("collect_list", 7, true); |
| test_agg_collect<DataTypeInt128>("array_agg", 9, true); |
| |
| test_agg_collect<DataTypeDateTimeV2>("collect_list", 5, true); |
| test_agg_collect<DataTypeDateTimeV2>("array_agg", 6, true); |
| |
| test_agg_collect<DataTypeString>("collect_list", 10, true); |
| test_agg_collect<DataTypeString>("array_agg", 5, true); |
| } |
| |
| struct AggregateFunctionCollectTest : public AggregateFunctiontest {}; |
| |
| TEST_F(AggregateFunctionCollectTest, test_collect_list_aint64) { |
| create_agg("collect_list", false, {std::make_shared<DataTypeInt64>()}, |
| std::make_shared<DataTypeInt64>()); |
| |
| auto data_type = std::make_shared<DataTypeInt64>(); |
| auto array_data_type = std::make_shared<DataTypeArray>(make_nullable(data_type)); |
| |
| auto off_column = ColumnOffset64::create(); |
| auto data_column = ColumnInt64::create(); |
| std::vector<ColumnArray::Offset64> offs = {0, 3}; |
| std::vector<int64_t> vals = {1, 2, 3}; |
| for (size_t i = 1; i < offs.size(); ++i) { |
| off_column->insert_data((const char*)(&offs[i]), 0); |
| } |
| for (auto& v : vals) { |
| data_column->insert_data((const char*)(&v), 0); |
| } |
| auto array_column = |
| ColumnArray::create(make_nullable(data_column->clone()), std::move(off_column)); |
| |
| execute(Block({ColumnHelper::create_column_with_name<DataTypeInt64>({1, 2, 3})}), |
| ColumnWithTypeAndName(std::move(array_column), array_data_type, "column")); |
| } |
| |
| TEST_F(AggregateFunctionCollectTest, test_collect_list_aint64_with_max_size) { |
| create_agg("collect_list", false, |
| {std::make_shared<DataTypeInt64>(), std::make_shared<DataTypeInt32>()}, |
| std::make_shared<DataTypeInt64>()); |
| |
| auto data_type = std::make_shared<DataTypeInt64>(); |
| auto array_data_type = std::make_shared<DataTypeArray>(make_nullable(data_type)); |
| |
| auto off_column = ColumnOffset64::create(); |
| auto data_column = ColumnInt64::create(); |
| std::vector<ColumnArray::Offset64> offs = {0, 3}; |
| std::vector<int64_t> vals = {1, 2, 3}; |
| for (size_t i = 1; i < offs.size(); ++i) { |
| off_column->insert_data((const char*)(&offs[i]), 0); |
| } |
| for (auto& v : vals) { |
| data_column->insert_data((const char*)(&v), 0); |
| } |
| auto array_column = |
| ColumnArray::create(make_nullable(data_column->clone()), std::move(off_column)); |
| |
| execute(Block({ColumnHelper::create_column_with_name<DataTypeInt64>({1, 2, 3, 4}), |
| ColumnHelper::create_column_with_name<DataTypeInt32>({3, 3, 3, 3})}), |
| ColumnWithTypeAndName(std::move(array_column), array_data_type, "column")); |
| } |
| |
| TEST_F(AggregateFunctionCollectTest, test_collect_set_aint64) { |
| create_agg("collect_set", false, {std::make_shared<DataTypeInt64>()}, |
| std::make_shared<DataTypeInt64>()); |
| |
| auto data_type = std::make_shared<DataTypeInt64>(); |
| auto array_data_type = std::make_shared<DataTypeArray>(make_nullable(data_type)); |
| |
| auto off_column = ColumnOffset64::create(); |
| auto data_column = ColumnInt64::create(); |
| std::vector<ColumnArray::Offset64> offs = {0, 3}; |
| std::vector<int64_t> vals = {2, 1, 3}; |
| for (size_t i = 1; i < offs.size(); ++i) { |
| off_column->insert_data((const char*)(&offs[i]), 0); |
| } |
| for (auto& v : vals) { |
| data_column->insert_data((const char*)(&v), 0); |
| } |
| auto array_column = |
| ColumnArray::create(make_nullable(data_column->clone()), std::move(off_column)); |
| |
| execute(Block({ColumnHelper::create_column_with_name<DataTypeInt64>({1, 2, 3})}), |
| ColumnWithTypeAndName(std::move(array_column), array_data_type, "column")); |
| } |
| |
| TEST_F(AggregateFunctionCollectTest, test_collect_set_aint64_with_max_size) { |
| create_agg("collect_set", false, |
| {std::make_shared<DataTypeInt64>(), std::make_shared<DataTypeInt32>()}, |
| std::make_shared<DataTypeInt64>()); |
| |
| auto data_type = std::make_shared<DataTypeInt64>(); |
| auto array_data_type = std::make_shared<DataTypeArray>(make_nullable(data_type)); |
| |
| auto off_column = ColumnOffset64::create(); |
| auto data_column = ColumnInt64::create(); |
| std::vector<ColumnArray::Offset64> offs = {0, 3}; |
| std::vector<int64_t> vals = {2, 1, 3}; |
| for (size_t i = 1; i < offs.size(); ++i) { |
| off_column->insert_data((const char*)(&offs[i]), 0); |
| } |
| for (auto& v : vals) { |
| data_column->insert_data((const char*)(&v), 0); |
| } |
| auto array_column = |
| ColumnArray::create(make_nullable(data_column->clone()), std::move(off_column)); |
| |
| execute(Block({ColumnHelper::create_column_with_name<DataTypeInt64>({1, 2, 3, 4, 3}), |
| ColumnHelper::create_column_with_name<DataTypeInt32>({3, 3, 3, 3, 3})}), |
| ColumnWithTypeAndName(std::move(array_column), array_data_type, "column")); |
| } |
| |
| } // namespace doris::vectorized |