// blob: 970eae5b501aa87607986f529254fc410fc165aa [file] [log] [blame]
// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.
#include "vec/aggregate_functions/aggregate_function_ai_agg.h"
#include <gmock/gmock-matchers.h>
#include <gtest/gtest.h>
#include <memory>
#include <string>
#include <vector>
#include "http/http_client.h"
#include "runtime/query_context.h"
#include "testutil/column_helper.h"
#include "testutil/mock/mock_runtime_state.h"
#include "vec/aggregate_functions/aggregate_function_simple_factory.h"
#include "vec/columns/column_string.h"
#include "vec/common/arena.h"
#include "vec/data_types/data_type_string.h"
namespace doris::vectorized {
// Forward declaration of the routine that registers "ai_agg" with an aggregate function
// factory. The fixture's SetUp() calls it directly; the definition is not part of this
// file (presumably it lives in the ai_agg implementation unit -- confirm).
void register_aggregate_function_ai_agg(AggregateFunctionSimpleFactory& factory);
// HttpClient subclass that lets a test look at the header list kept by the base class.
class MockHttpClient : public HttpClient {
public:
    // Hands back the `_header_list` pointer inherited from HttpClient.
    curl_slist* get() { return _header_list; }

private:
    // Neither member below is referenced by any code visible in this file.
    std::unordered_map<std::string, std::string> _headers;
    std::string _content_type;
};
// Fixture shared by all ai_agg tests.
//
// SetUp() builds a mock runtime state, registers "ai_agg" in a factory owned by the
// fixture, resolves the aggregate function for three string arguments
// (resource, text, task) and attaches the mock query context to it.
class AggregateFunctionAIAggTest : public ::testing::Test {
public:
    void SetUp() override {
        _runtime_state = std::make_unique<MockRuntimeState>();
        _query_ctx = _runtime_state->_query_ctx_uptr;

        // The factory is a fixture member rather than a local: `_factory` keeps pointing
        // at it after SetUp() returns, so a local here would leave a dangling pointer.
        register_aggregate_function_ai_agg(_factory_storage);
        _factory = &_factory_storage;

        _data_types = {std::make_shared<DataTypeString>(), std::make_shared<DataTypeString>(),
                       std::make_shared<DataTypeString>()};
        _agg_function = _factory->get("ai_agg", _data_types, nullptr, false, -1);
        ASSERT_TRUE(_agg_function != nullptr);
        _agg_function->set_query_context(_query_ctx.get());
    }

    // Clears the static context pointer of the aggregate data so it cannot carry over
    // into the next test (presumably it is populated via set_query_context -- confirm).
    void TearDown() override { AggregateFunctionAIAggData::_ctx = nullptr; }

protected:
    std::unique_ptr<MockRuntimeState> _runtime_state;
    std::shared_ptr<QueryContext> _query_ctx;
    // Owns the factory for the lifetime of the fixture; `_factory` points at it.
    AggregateFunctionSimpleFactory _factory_storage;
    AggregateFunctionSimpleFactory* _factory = nullptr;
    DataTypes _data_types;
    AggregateFunctionPtr _agg_function;
    Arena _arena;
};
// A single add() must mark the state as initialized, record the task and store the text.
TEST_F(AggregateFunctionAIAggTest, add_test) {
    // The aggregate state lives in a raw buffer whose size the function reports itself.
    std::unique_ptr<char[]> state_buf(new char[_agg_function->size_of_data()]);
    AggregateDataPtr state = state_buf.get();
    _agg_function->create(state);

    // One row in each argument column: (resource, text, task).
    auto resources = ColumnString::create();
    resources->insert_data("mock_resource", 13);
    auto inputs = ColumnString::create();
    inputs->insert_data("Hello world", 11);
    auto tasks = ColumnString::create();
    tasks->insert_data("summarize this text", 19);

    const IColumn* args[3] = {resources.get(), inputs.get(), tasks.get()};
    _agg_function->add(state, args, 0, _arena);

    const auto* agg_data = reinterpret_cast<const AggregateFunctionAIAggData*>(state);
    EXPECT_TRUE(agg_data->inited);
    EXPECT_EQ(agg_data->get_task(), "summarize this text");
    EXPECT_EQ(agg_data->data.size(), 11); // length of "Hello world"

    _agg_function->destroy(state);
}
// Several add() calls on one state must keep the task and join the texts with '\n'.
TEST_F(AggregateFunctionAIAggTest, multiple_add_test) {
    const std::vector<std::string> rows = {"First text", "Second text", "Third text"};

    auto resources = ColumnString::create();
    auto inputs = ColumnString::create();
    auto tasks = ColumnString::create();
    for (const std::string& row : rows) {
        resources->insert_data("mock_resource", 13);
        inputs->insert_data(row.c_str(), row.size());
        tasks->insert_data("summarize", 9);
    }

    std::unique_ptr<char[]> state_buf(new char[_agg_function->size_of_data()]);
    AggregateDataPtr state = state_buf.get();
    _agg_function->create(state);

    // Feed the rows one at a time, in column order.
    const IColumn* args[3] = {resources.get(), inputs.get(), tasks.get()};
    size_t row_idx = 0;
    while (row_idx < rows.size()) {
        _agg_function->add(state, args, row_idx, _arena);
        ++row_idx;
    }

    const auto* agg_data = reinterpret_cast<const AggregateFunctionAIAggData*>(state);
    EXPECT_TRUE(agg_data->inited);
    EXPECT_EQ(agg_data->get_task(), "summarize");

    const std::string collected(reinterpret_cast<const char*>(agg_data->data.data()),
                                agg_data->data.size());
    const std::string wanted = "First text\nSecond text\nThird text";
    EXPECT_EQ(collected, wanted);

    _agg_function->destroy(state);
}
// Merging a second state into the first must append its text after a '\n' separator.
TEST_F(AggregateFunctionAIAggTest, merge_test) {
    // Argument columns for the left-hand (destination) state.
    auto lhs_resources = ColumnString::create();
    auto lhs_inputs = ColumnString::create();
    auto lhs_tasks = ColumnString::create();
    lhs_resources->insert_data("mock_resource", 13);
    lhs_inputs->insert_data("First part", 10);
    lhs_tasks->insert_data("analyze", 7);

    // Argument columns for the right-hand (source) state.
    auto rhs_resources = ColumnString::create();
    auto rhs_inputs = ColumnString::create();
    auto rhs_tasks = ColumnString::create();
    rhs_resources->insert_data("mock_resource", 13);
    rhs_inputs->insert_data("Second part", 11);
    rhs_tasks->insert_data("analyze", 7);

    std::unique_ptr<char[]> lhs_buf(new char[_agg_function->size_of_data()]);
    std::unique_ptr<char[]> rhs_buf(new char[_agg_function->size_of_data()]);
    AggregateDataPtr lhs = lhs_buf.get();
    AggregateDataPtr rhs = rhs_buf.get();
    _agg_function->create(lhs);
    _agg_function->create(rhs);

    const IColumn* lhs_args[3] = {lhs_resources.get(), lhs_inputs.get(), lhs_tasks.get()};
    _agg_function->add(lhs, lhs_args, 0, _arena);
    const IColumn* rhs_args[3] = {rhs_resources.get(), rhs_inputs.get(), rhs_tasks.get()};
    _agg_function->add(rhs, rhs_args, 0, _arena);

    _agg_function->merge(lhs, rhs, _arena);

    const auto* merged = reinterpret_cast<const AggregateFunctionAIAggData*>(lhs);
    const std::string collected(reinterpret_cast<const char*>(merged->data.data()),
                                merged->data.size());
    const std::string wanted = "First part\nSecond part";
    EXPECT_EQ(collected, wanted);

    _agg_function->destroy(lhs);
    _agg_function->destroy(rhs);
}
// Round trip: a state serialized to a column and merged back into a fresh state must
// carry the same `inited` flag and the same accumulated text.
TEST_F(AggregateFunctionAIAggTest, serialize_deserialize_test) {
    std::unique_ptr<char[]> memory1(new char[_agg_function->size_of_data()]);
    AggregateDataPtr place1 = memory1.get();
    _agg_function->create(place1);

    auto resource_col = ColumnString::create();
    auto text_col = ColumnString::create();
    auto task_col = ColumnString::create();
    // Take the length from the string itself: the literal is 27 characters long, so a
    // hand-written length of 28 would also insert its trailing NUL byte as data.
    const std::string text = "Test data for serialization";
    resource_col->insert_data("mock_resource", 13);
    text_col->insert_data(text.data(), text.size());
    task_col->insert_data("process", 7);

    const IColumn* columns[3] = {resource_col.get(), text_col.get(), task_col.get()};
    _agg_function->add(place1, columns, 0, _arena);

    auto serialize_column = _agg_function->create_serialize_column();
    _agg_function->serialize_without_key_to_column(place1, *serialize_column);

    std::unique_ptr<char[]> memory2(new char[_agg_function->size_of_data()]);
    AggregateDataPtr place2 = memory2.get();
    _agg_function->create(place2);
    _agg_function->deserialize_and_merge_from_column(place2, *serialize_column, _arena);

    const auto& data1 = *reinterpret_cast<const AggregateFunctionAIAggData*>(place1);
    const auto& data2 = *reinterpret_cast<const AggregateFunctionAIAggData*>(place2);
    EXPECT_EQ(data1.inited, data2.inited);

    std::string str1(reinterpret_cast<const char*>(data1.data.data()), data1.data.size());
    std::string str2(reinterpret_cast<const char*>(data2.data.data()), data2.data.size());
    // Pin the source content too, so the comparison cannot pass with two empty buffers.
    EXPECT_EQ(str1, text);
    EXPECT_EQ(str1, str2);

    _agg_function->destroy(place1);
    _agg_function->destroy(place2);
}
// reset() must bring a populated state back to: not inited, no data, empty task.
TEST_F(AggregateFunctionAIAggTest, reset_test) {
    std::unique_ptr<char[]> state_buf(new char[_agg_function->size_of_data()]);
    AggregateDataPtr state = state_buf.get();
    _agg_function->create(state);

    auto resources = ColumnString::create();
    resources->insert_data("mock_resource", 13);
    auto inputs = ColumnString::create();
    inputs->insert_data("Some text", 9);
    auto tasks = ColumnString::create();
    tasks->insert_data("task", 4);

    const IColumn* args[3] = {resources.get(), inputs.get(), tasks.get()};
    _agg_function->add(state, args, 0, _arena);

    // The same object is inspected before and after the reset.
    const auto* agg_data = reinterpret_cast<const AggregateFunctionAIAggData*>(state);
    EXPECT_TRUE(agg_data->inited);
    EXPECT_FALSE(agg_data->data.empty());

    _agg_function->reset(state);

    EXPECT_FALSE(agg_data->inited);
    EXPECT_TRUE(agg_data->data.empty());
    EXPECT_TRUE(agg_data->get_task().empty());

    _agg_function->destroy(state);
}
// Adding an empty text must still initialize the state, and producing a result from a
// state without data is expected to raise INVALID_ARGUMENT ("data is empty").
TEST_F(AggregateFunctionAIAggTest, empty_data_test) {
    std::unique_ptr<char[]> memory(new char[_agg_function->size_of_data()]);
    AggregateDataPtr place = memory.get();
    _agg_function->create(place);

    auto resource_col = ColumnString::create();
    auto text_col = ColumnString::create();
    auto task_col = ColumnString::create();
    resource_col->insert_data("mock_resource", 13);
    text_col->insert_data("", 0);
    task_col->insert_data("process", 7);

    const IColumn* columns[3] = {resource_col.get(), text_col.get(), task_col.get()};
    _agg_function->add(place, columns, 0, _arena);

    const auto& data = *reinterpret_cast<const AggregateFunctionAIAggData*>(place);
    EXPECT_TRUE(data.inited);
    EXPECT_EQ(data.data.size(), 0);

    try {
        ColumnString to;
        _agg_function->insert_result_into(place, to);
        // Without this marker the test would pass silently if nothing were thrown.
        // ADD_FAILURE() is non-fatal, so the state is still destroyed below.
        ADD_FAILURE() << "Expected exception when producing a result from empty data";
    } catch (const Exception& e) {
        EXPECT_EQ(e.code(), ErrorCode::INVALID_ARGUMENT);
        EXPECT_THAT(e.to_string().c_str(), ::testing::EndsWith("data is empty"));
    }

    _agg_function->destroy(place);
}
// Merging a state that never received a row must leave the destination text untouched.
TEST_F(AggregateFunctionAIAggTest, merge_empty_test) {
    auto resources = ColumnString::create();
    resources->insert_data("mock_resource", 13);
    auto inputs = ColumnString::create();
    inputs->insert_data("Test data", 9);
    auto tasks = ColumnString::create();
    tasks->insert_data("analyze", 7);

    std::unique_ptr<char[]> filled_buf(new char[_agg_function->size_of_data()]);
    std::unique_ptr<char[]> blank_buf(new char[_agg_function->size_of_data()]);
    AggregateDataPtr filled = filled_buf.get();
    AggregateDataPtr blank = blank_buf.get();
    _agg_function->create(filled);
    _agg_function->create(blank);

    // Only the destination state gets a row; the source stays as created.
    const IColumn* args[3] = {resources.get(), inputs.get(), tasks.get()};
    _agg_function->add(filled, args, 0, _arena);
    _agg_function->merge(filled, blank, _arena);

    const auto* merged = reinterpret_cast<const AggregateFunctionAIAggData*>(filled);
    const std::string collected(reinterpret_cast<const char*>(merged->data.data()),
                                merged->data.size());
    EXPECT_EQ(collected, "Test data");

    _agg_function->destroy(filled);
    _agg_function->destroy(blank);
}
// The function must report a String return type and the name "ai_agg".
TEST_F(AggregateFunctionAIAggTest, return_type_and_name_test) {
    auto return_type = _agg_function->get_return_type();
    // Fatal assertion: the next check dereferences the pointer, so a null return type
    // must stop the test here instead of crashing it.
    ASSERT_TRUE(return_type != nullptr);
    EXPECT_EQ(return_type->get_name(), "String");
    EXPECT_EQ(_agg_function->get_name(), "ai_agg");
}
// add_batch_single_place() over three rows must give the same state as three add() calls:
// task recorded once, texts joined with '\n'.
TEST_F(AggregateFunctionAIAggTest, add_batch_single_place_test) {
    const std::vector<std::string> rows = {"First batch text", "Second batch text",
                                           "Third batch text"};
    constexpr size_t row_count = 3;

    auto resources = ColumnString::create();
    auto inputs = ColumnString::create();
    auto tasks = ColumnString::create();
    for (const std::string& row : rows) {
        resources->insert_data("mock_resource", 13);
        inputs->insert_data(row.c_str(), row.size());
        tasks->insert_data("summarize", 9);
    }

    std::unique_ptr<char[]> state_buf(new char[_agg_function->size_of_data()]);
    AggregateDataPtr state = state_buf.get();
    _agg_function->create(state);

    const IColumn* args[3] = {resources.get(), inputs.get(), tasks.get()};
    _agg_function->add_batch_single_place(row_count, state, args, _arena);

    const auto* agg_data = reinterpret_cast<const AggregateFunctionAIAggData*>(state);
    EXPECT_TRUE(agg_data->inited);
    EXPECT_EQ(agg_data->get_task(), "summarize");

    const std::string collected(reinterpret_cast<const char*>(agg_data->data.data()),
                                agg_data->data.size());
    const std::string wanted = "First batch text\nSecond batch text\nThird batch text";
    EXPECT_EQ(collected, wanted);

    _agg_function->destroy(state);
}
// Two consecutive batch adds into one state must accumulate all four texts in order.
TEST_F(AggregateFunctionAIAggTest, add_batch_single_place_multiple_calls_test) {
    constexpr size_t rows_per_batch = 2;

    // First batch of argument columns.
    const std::vector<std::string> head_rows = {"Initial text 1", "Initial text 2"};
    auto head_resources = ColumnString::create();
    auto head_inputs = ColumnString::create();
    auto head_tasks = ColumnString::create();
    for (const std::string& row : head_rows) {
        head_resources->insert_data("mock_resource", 13);
        head_inputs->insert_data(row.c_str(), row.size());
        head_tasks->insert_data("analyze", 7);
    }

    std::unique_ptr<char[]> state_buf(new char[_agg_function->size_of_data()]);
    AggregateDataPtr state = state_buf.get();
    _agg_function->create(state);

    const IColumn* head_args[3] = {head_resources.get(), head_inputs.get(), head_tasks.get()};
    _agg_function->add_batch_single_place(rows_per_batch, state, head_args, _arena);

    // Second batch, built from separate columns and added to the same state.
    const std::vector<std::string> tail_rows = {"Additional text 1", "Additional text 2"};
    auto tail_resources = ColumnString::create();
    auto tail_inputs = ColumnString::create();
    auto tail_tasks = ColumnString::create();
    for (const std::string& row : tail_rows) {
        tail_resources->insert_data("mock_resource", 13);
        tail_inputs->insert_data(row.c_str(), row.size());
        tail_tasks->insert_data("analyze", 7);
    }

    const IColumn* tail_args[3] = {tail_resources.get(), tail_inputs.get(), tail_tasks.get()};
    _agg_function->add_batch_single_place(rows_per_batch, state, tail_args, _arena);

    const auto* agg_data = reinterpret_cast<const AggregateFunctionAIAggData*>(state);
    EXPECT_TRUE(agg_data->inited);
    EXPECT_EQ(agg_data->get_task(), "analyze");

    const std::string collected(reinterpret_cast<const char*>(agg_data->data.data()),
                                agg_data->data.size());
    const std::string wanted =
            "Initial text 1\nInitial text 2\nAdditional text 1\nAdditional text 2";
    EXPECT_EQ(collected, wanted);

    _agg_function->destroy(state);
}
// End-to-end path with the "mock_resource" resource: add one row, then ask the function
// to produce its result into a string column.
TEST_F(AggregateFunctionAIAggTest, mock_resource_send_request_test) {
    std::vector<std::string> resources = {"mock_resource"};
    std::vector<std::string> texts = {"test input"};
    std::vector<std::string> task = {"summarize"};
    auto col_resource = ColumnHelper::create_column<DataTypeString>(resources);
    auto col_text = ColumnHelper::create_column<DataTypeString>(texts);
    auto col_task = ColumnHelper::create_column<DataTypeString>(task);

    std::unique_ptr<char[]> memory(new char[_agg_function->size_of_data()]);
    AggregateDataPtr place = memory.get();
    _agg_function->create(place);

    const IColumn* columns[3] = {col_resource.get(), col_text.get(), col_task.get()};
    _agg_function->add(place, columns, 0, _arena);

    ColumnString result_column;
    _agg_function->insert_result_into(place, result_column);

    // One aggregate state must yield exactly one result row. The check is non-fatal and
    // the read below is guarded, so the state is always destroyed.
    EXPECT_EQ(result_column.size(), 1);
    if (result_column.size() > 0) {
        // NOTE(review): the response text comes from the mock resource, which is not
        // visible in this file, so the row is only materialized, not compared.
        StringRef result_ref = result_column.get_data_at(0);
        std::string result(result_ref.data, result_ref.size);
    }

    _agg_function->destroy(place);
}
// With a freshly created mock query context attached, add() is expected to fail with
// INTERNAL_ERROR and a message mentioning the missing AI resources metadata.
TEST_F(AggregateFunctionAIAggTest, missing_ai_resources_metadata_test) {
    // Replace the context installed by SetUp() with a new MockQueryContext.
    auto bare_ctx = MockQueryContext::create();
    _agg_function->set_query_context(bare_ctx.get());

    std::vector<std::string> resource_rows = {"resource_name"};
    std::vector<std::string> text_rows = {"test input"};
    std::vector<std::string> task_rows = {"summarize"};
    auto resources = ColumnHelper::create_column<DataTypeString>(resource_rows);
    auto inputs = ColumnHelper::create_column<DataTypeString>(text_rows);
    auto tasks = ColumnHelper::create_column<DataTypeString>(task_rows);

    std::unique_ptr<char[]> state_buf(new char[_agg_function->size_of_data()]);
    AggregateDataPtr state = state_buf.get();
    _agg_function->create(state);

    const IColumn* args[3] = {resources.get(), inputs.get(), tasks.get()};
    try {
        _agg_function->add(state, args, 0, _arena);
        FAIL() << "Expected exception for missing AI resources";
    } catch (const Exception& err) {
        EXPECT_EQ(err.code(), ErrorCode::INTERNAL_ERROR);
        EXPECT_NE(err.to_string().find("AI resources metadata missing"), std::string::npos);
    }

    _agg_function->destroy(state);
}
} // namespace doris::vectorized