blob: f99d7162dbbe6d1e75608a48437c76018964b9e8 [file] [log] [blame]
// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.
#include <gtest/gtest.h>
#include <algorithm>
#include <string>
#include <vector>
#include "vec/common/assert_cast.h"
#include "vec/core/columns_with_type_and_name.h"
#include "vec/data_types/data_type_decimal.h"
#include "vec/data_types/data_type_number.h"
#include "vec/functions/dictionary.h"
#include "vec/functions/ip_address_dictionary.h"
#include "vec/runtime/ip_address_cidr.h"
namespace doris::vectorized {
// One dictionary entry: a parsed "address/prefix" key together with the index of
// the source row it was loaded from.
struct IPRecord {
    IPAddressCIDR ip_with_cidr;
    size_t row;

    // Returns the record's address as a 128-bit IPv6 value. A native IPv6 address
    // is copied byte for byte; an IPv4 address goes through ipv4_to_ipv6().
    IPv6 to_ipv6() const {
        const auto* v6_bytes = ip_with_cidr._address.as_v6();
        if (v6_bytes == nullptr) {
            return ipv4_to_ipv6(ip_with_cidr._address.as_v4());
        }
        IPv6 converted;
        memcpy(reinterpret_cast<UInt8*>(&converted), v6_bytes, sizeof(IPv6));
        return converted;
    }

    // Returns the prefix length to use together with to_ipv6(). IPv4 prefixes are
    // shifted by 96 (= 128 - 32) so they apply to the 128-bit form of the address.
    UInt8 prefix() const {
        const bool is_v4 = ip_with_cidr._address.as_v6() == nullptr;
        if (is_v4) {
            return ip_with_cidr._prefix + 96;
        }
        return ip_with_cidr._prefix;
    }
};
class MockIPAddressDictionary : public IDictionary {
public:
MockIPAddressDictionary(std::string name, std::vector<DictionaryAttribute> attributes)
: IDictionary(std::move(name), std::move(attributes)) {}
ColumnPtr get_column(const std::string& attribute_name, const DataTypePtr& attribute_type,
const ColumnPtr& key_column, const DataTypePtr& key_type) const override {
MutableColumnPtr res_column = attribute_type->create_column();
const auto& attribute = _values_data[attribute_index(attribute_name)];
if (key_type->get_primitive_type() == TYPE_IPV6) {
const auto* ipv6_column = assert_cast<const ColumnIPv6*>(key_column.get());
std::visit(
[&](auto&& arg) {
using ValueDataType = std::decay_t<decltype(arg)>;
using AttributeRealColumnType = ValueDataType::OutputColumnType;
auto* res_real_column =
assert_cast<AttributeRealColumnType*>(res_column.get());
const auto* attributes_column = arg.get();
for (size_t i = 0; i < ipv6_column->size(); i++) {
IPv6 ipv6 = ipv6_column->get_element(i);
auto it = lookupIP(ipv6);
if (it == ip_not_found()) {
res_column->insert_default();
} else {
const auto idx = it->row;
res_real_column->insert_value(attributes_column->get_element(idx));
}
}
},
attribute);
} else {
const auto* ipv4_column = assert_cast<const ColumnIPv4*>(key_column.get());
std::visit(
[&](auto&& arg) {
using ValueDataType = std::decay_t<decltype(arg)>;
using AttributeRealColumnType = ValueDataType::OutputColumnType;
auto* res_real_column =
assert_cast<AttributeRealColumnType*>(res_column.get());
const auto* attributes_column = arg.get();
for (size_t i = 0; i < ipv4_column->size(); i++) {
IPv4 ipv4 = ipv4_column->get_element(i);
IPv6 ipv6 = ipv4_to_ipv6(ipv4);
auto it = lookupIP(ipv6);
if (it == ip_not_found()) {
res_column->insert_default();
} else {
const auto idx = it->row;
res_real_column->insert_value(attributes_column->get_element(idx));
}
}
},
attribute);
}
return res_column;
}
static DictionaryPtr create_ip_trie_dict(const std::string& name, ColumnPtr& key_column,
ColumnsWithTypeAndName& attribute_data) {
std::vector<DictionaryAttribute> attributes;
std::vector<ColumnPtr> attributes_column;
for (const auto& att : attribute_data) {
attributes.push_back({att.name, att.type});
attributes_column.push_back(att.column);
}
auto dict = std::make_shared<MockIPAddressDictionary>(name, attributes);
dict->load_data(key_column, attributes_column);
return dict;
}
void load_data(ColumnPtr& key_column, std::vector<ColumnPtr>& attributes_column) {
const auto* str_column = assert_cast<const ColumnString*>(key_column.get());
for (size_t i = 0; i < str_column->size(); i++) {
auto ip_str = str_column->get_element(i);
ip_records.push_back(IPRecord {parse_ip_with_cidr(ip_str), i});
}
std::sort(ip_records.begin(), ip_records.end(),
[&](const IPRecord& a, const IPRecord& b) { return a.prefix() > b.prefix(); });
load_values(attributes_column);
}
using RowIdxConstIter = std::vector<IPRecord>::const_iterator;
RowIdxConstIter lookupIP(IPv6 target) const {
for (auto it = ip_records.begin(); it != ip_records.end(); it++) {
IPRecord ip = *it;
auto ipv6 = ip.to_ipv6();
if (match_ipv6_subnet(reinterpret_cast<const UInt8*>(&target),
reinterpret_cast<const UInt8*>(&ipv6), ip.prefix())) {
return it;
}
}
return ip_not_found();
}
RowIdxConstIter ip_not_found() const { return ip_records.end(); }
std::vector<IPRecord> ip_records;
};
// Validates the inputs and builds a MockIPAddressDictionary from them.
//
// `key_data` must have a string type; each entry is handed to the dictionary as a
// key. `attribute_data` supplies the attribute columns; none of them may be
// nullable, either by type or by column.
//
// Throws doris::Exception with ErrorCode::INVALID_ARGUMENT when the key type is not
// a string type or when any attribute is nullable.
inline DictionaryPtr create_mock_ip_trie_dict_from_column(const std::string& name,
                                                          ColumnWithTypeAndName key_data,
                                                          ColumnsWithTypeAndName attribute_data) {
    auto key_column = key_data.column;
    const auto& key_type = key_data.type;
    if (!is_string_type(key_type->get_primitive_type())) {
        throw doris::Exception(
                ErrorCode::INVALID_ARGUMENT,
                "IPAddressDictionary only support string in key , input key type is {} ",
                key_type->get_name());
    }
    // Iterate by reference: copying each entry would copy its column pointer,
    // type pointer and name for no benefit.
    for (const auto& attribute : attribute_data) {
        if (attribute.type->is_nullable() || attribute.column->is_nullable()) {
            throw doris::Exception(
                    ErrorCode::INVALID_ARGUMENT,
                    "IPAddressDictionary only support non-nullable attribute , input attribute "
                    "is {} ",
                    attribute.type->get_name());
        }
    }
    return MockIPAddressDictionary::create_ip_trie_dict(name, key_column, attribute_data);
}
template <typename IPType, bool output>
void test_for_ip_type(std::vector<std::string> ips, std::vector<std::string> ip_string) {
static_assert(std::is_same_v<IPType, DataTypeIPv4> || std::is_same_v<IPType, DataTypeIPv6>,
"IPType must be either DataTypeIPv4 or DataTypeIPv6");
std::cout << "input data size\t" << ips.size() << "\t" << ip_string.size() << "\n";
auto input_key_column = DataTypeString::ColumnType::create();
auto intput_key_data = std::make_shared<DataTypeString>();
auto value_column = DataTypeInt64::ColumnType::create();
auto value_type = std::make_shared<DataTypeInt64>();
for (int i = 0; i < ips.size(); i++) {
input_key_column->insert_value(ips[i]);
value_column->insert_value(i);
}
auto mock_ip_dict = create_mock_ip_trie_dict_from_column(
"mock ip dict", ColumnWithTypeAndName {input_key_column->clone(), intput_key_data, ""},
ColumnsWithTypeAndName {
ColumnWithTypeAndName {value_column->clone(), value_type, "row"},
});
auto ip_dict = create_ip_trie_dict_from_column(
"ip dict", ColumnWithTypeAndName {input_key_column->clone(), intput_key_data, ""},
ColumnsWithTypeAndName {
ColumnWithTypeAndName {value_column->clone(), value_type, "row"},
});
std::string attribute_name = "row";
DataTypePtr attribute_type = value_type;
{
auto key_type = std::make_shared<IPType>();
auto ipv_column = IPType::ColumnType::create();
for (const auto& ip : ip_string) {
if constexpr (std::is_same_v<IPType, DataTypeIPv4>) {
IPv4 ipv4;
EXPECT_TRUE(IPv4Value::from_string(ipv4, ip));
ipv_column->insert_value(ipv4);
} else {
IPv6 ipv6;
EXPECT_TRUE(IPv6Value::from_string(ipv6, ip));
ipv_column->insert_value(ipv6);
}
}
ColumnPtr key_column = ipv_column->clone();
auto mock_result =
mock_ip_dict->get_column(attribute_name, attribute_type, key_column, key_type);
auto result = ip_dict->get_column(attribute_name, attribute_type, key_column, key_type);
const auto* real_mock_result = assert_cast<const ColumnInt64*>(mock_result.get());
const auto* real_result = assert_cast<const ColumnInt64*>(remove_nullable(result).get());
for (int i = 0; i < ip_string.size(); i++) {
if constexpr (output) {
std::cout << ip_string[i] << "\t" << ips[real_mock_result->get_element(i)] << "\t"
<< ips[real_result->get_element(i)] << "\n";
}
EXPECT_EQ(ips[real_mock_result->get_element(i)], ips[real_result->get_element(i)]);
}
}
}
} // namespace doris::vectorized