blob: 19955c4dabe7578d39fba5f09f83e0c33272e77e [file] [log] [blame]
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include <memory>
#include <ranges>
#include <string>
#include <string_view>
#include <type_traits>
#include <Columns/ColumnConst.h>
#include <Columns/ColumnNullable.h>
#include <Columns/ColumnString.h>
#include <Columns/ColumnVector.h>
#include <Core/Field.h>
#include <DataTypes/DataTypeArray.h>
#include <DataTypes/DataTypeMap.h>
#include <DataTypes/DataTypeNullable.h>
#include <DataTypes/DataTypeString.h>
#include <DataTypes/DataTypesNumber.h>
#include <DataTypes/IDataType.h>
#include <Functions/FunctionFactory.h>
#include <Functions/FunctionHelpers.h>
#include <Functions/IFunction.h>
#include <Functions/IFunctionAdaptors.h>
#include <Functions/Regexps.h>
#include <Common/Exception.h>
#include <Common/OptimizedRegularExpression.h>
#include <Poco/Logger.h>
#include <Common/logger_useful.h>
namespace DB
{
// Error codes referenced by this file. They are only declared here (`extern`);
// their values are defined in another translation unit.
namespace ErrorCodes
{
extern const int NUMBER_OF_ARGUMENTS_DOESNT_MATCH;
extern const int ILLEGAL_COLUMN;
}
}
namespace local_engine
{
/// Splits a character range into tokens separated by a literal (non-regex) delimiter.
///
/// Usage: call reset() with the range to scan, then call next() until it returns false.
/// After every successful next(), getDelimiterBegin()/getDelimiterEnd() describe the
/// delimiter occurrence that ended the token, or are both nullptr when the token ran to
/// the end of the range without meeting a delimiter.
///
/// A range that ends exactly on a delimiter does not produce a trailing empty token.
class TrivialCharSplitter
{
public:
    using Pos = const char *;

    /// @param delimiter_ literal byte sequence that separates tokens.
    explicit TrivialCharSplitter(const std::string & delimiter_) : delimiter(delimiter_) { }

    /// Starts scanning the range [str_begin_, str_end_) and forgets any previous delimiter position.
    void reset(Pos str_begin_, Pos str_end_)
    {
        str_begin = str_begin_;
        str_end = str_end_;
        str_cursor = str_begin;
        delimiter_begin = nullptr;
        delimiter_end = nullptr;
    }

    Pos getDelimiterBegin() const { return delimiter_begin; }
    Pos getDelimiterEnd() const { return delimiter_end; }

    /// Extracts the next token as [token_begin, token_end).
    /// @return false when the range is exhausted (or reset() was never called); the
    ///         output parameters are left untouched in that case.
    bool next(Pos & token_begin, Pos & token_end)
    {
        if (str_cursor >= str_end)
            return false;

        token_begin = str_cursor;

        if (delimiter.empty())
        {
            // Searching for an empty needle "succeeds" at the cursor with zero width, which
            // would never advance the cursor and loop forever. Emit one character per token
            // instead, with a zero-width delimiter right behind it.
            str_cursor = skipOneCharacter(str_cursor, str_end);
            token_end = str_cursor;
            delimiter_begin = str_cursor;
            delimiter_end = str_cursor;
            return true;
        }

        const std::string_view remaining(str_cursor, static_cast<size_t>(str_end - str_cursor));
        const auto found_at = remaining.find(delimiter);
        if (found_at == std::string_view::npos)
        {
            // No further delimiter: the rest of the range is the last token.
            token_end = str_end;
            str_cursor = str_end;
            delimiter_begin = nullptr;
            delimiter_end = nullptr;
        }
        else
        {
            delimiter_begin = str_cursor + found_at;
            token_end = delimiter_begin;
            str_cursor = delimiter_begin + delimiter.size();
            delimiter_end = str_cursor;
        }
        return true;
    }

private:
    /// Returns the position after the character starting at `pos`: one byte plus any
    /// UTF-8 continuation bytes (0b10xxxxxx) that follow it, never going past `end`.
    static Pos skipOneCharacter(Pos pos, Pos end)
    {
        ++pos;
        while (pos < end && (static_cast<unsigned char>(*pos) & 0xC0) == 0x80)
            ++pos;
        return pos;
    }

    std::string delimiter;
    // Default-initialised so that next() before reset() reports "no tokens" rather than
    // comparing indeterminate pointers.
    Pos str_begin = nullptr;
    Pos str_end = nullptr;
    Pos str_cursor = nullptr;
    Pos delimiter_begin = nullptr;
    Pos delimiter_end = nullptr;
};
/// Splits a character range into tokens separated by matches of a regular expression.
///
/// Usage: call reset() with the range to scan, then call next() until it returns false.
/// After every successful next(), getDelimiterBegin()/getDelimiterEnd() describe the
/// match that ended the token, or are both nullptr when no further match was found and
/// the token ran to the end of the range.
///
/// An empty delimiter string compiles no regex and yields one character per token.
struct RegularSplitter
{
public:
    using Pos = const char *;

    /// @param delimiter_ regular expression that separates tokens; may be empty.
    explicit RegularSplitter(const String & delimiter_) : delimiter(delimiter_)
    {
        if (!delimiter.empty())
            re = std::make_shared<OptimizedRegularExpression>(DB::Regexps::createRegexp<false, false, false>(delimiter));
    }

    /// Starts scanning the range [str_begin_, str_end_) and forgets any previous delimiter position.
    void reset(Pos str_begin_, Pos str_end_)
    {
        str_begin = str_begin_;
        str_end = str_end_;
        str_cursor = str_begin;
        delimiter_begin = nullptr;
        delimiter_end = nullptr;
    }

    Pos getDelimiterBegin() const { return delimiter_begin; }
    Pos getDelimiterEnd() const { return delimiter_end; }

    /// Extracts the next token as [token_begin, token_end).
    /// @return false when the range is exhausted (or reset() was never called); the
    ///         output parameters are left untouched in that case.
    bool next(Pos & token_begin, Pos & token_end)
    {
        if (str_cursor >= str_end)
            return false;

        token_begin = str_cursor;

        // Empty delimiter: every character is its own token.
        if (!re)
        {
            consumeOneCharacter(token_end);
            return true;
        }

        if (!re->match(str_cursor, str_end - str_cursor, matches))
        {
            // No further delimiter: the rest of the range is the last token.
            token_end = str_end;
            str_cursor = str_end;
            delimiter_begin = nullptr;
            delimiter_end = nullptr;
            return true;
        }

        const auto & whole_match = matches[0];
        if (whole_match.offset == 0 && whole_match.length == 0)
        {
            // The pattern matched the empty string right at the cursor (e.g. "x*").
            // Accepting it as-is would leave the cursor where it is and loop forever,
            // so guarantee forward progress by consuming one character as the token.
            consumeOneCharacter(token_end);
            return true;
        }

        token_end = str_cursor + whole_match.offset;
        delimiter_begin = token_end;
        str_cursor = token_end + whole_match.length;
        delimiter_end = str_cursor;
        return true;
    }

private:
    /// Advances the cursor past one character — one byte plus any UTF-8 continuation
    /// bytes (0b10xxxxxx) that follow it — and records a zero-width delimiter right
    /// behind it. Stepping by a single byte would cut multi-byte characters in half.
    void consumeOneCharacter(Pos & token_end)
    {
        ++str_cursor;
        while (str_cursor < str_end && (static_cast<unsigned char>(*str_cursor) & 0xC0) == 0x80)
            ++str_cursor;
        token_end = str_cursor;
        delimiter_begin = str_cursor;
        delimiter_end = str_cursor;
    }

    String delimiter;
    DB::Regexps::RegexpPtr re;
    OptimizedRegularExpression::MatchVec matches;
    // Default-initialised so that next() before reset() reports "no tokens" rather than
    // comparing indeterminate pointers.
    Pos str_begin = nullptr;
    Pos str_end = nullptr;
    Pos str_cursor = nullptr;
    Pos delimiter_begin = nullptr;
    Pos delimiter_end = nullptr;
};
/// Column validator for the argument descriptors: reports whether `col` is a constant column.
static bool isConstColumn(const DB::IColumn & col)
{
    const bool is_constant = col.isConst();
    return is_constant;
}
/// Implementation of `spark_str_to_map(text, pair_delimiter, key_value_delimiter)`.
///
/// Every row of `text` is split into pairs by `pair_delimiter`; each pair is split at the
/// first occurrence of `key_value_delimiter` into a key and a value. A pair without a
/// key/value delimiter produces a NULL value, and an empty pair produces an empty key with
/// a NULL value. The result is Map(String, Nullable(String)), wrapped in Nullable when the
/// first argument is Nullable.
///
/// @tparam PairGenerator splitter type used for pairs (constructed from the delimiter string,
///         driven through reset()/next()).
/// @tparam KVGenerator splitter type used inside a pair; additionally queried through
///         getDelimiterBegin()/getDelimiterEnd().
template <typename PairGenerator, typename KVGenerator>
class SparkFunctionStrToMap : public DB::IFunction
{
public:
    using Pos = const char *;
    static constexpr auto name = "spark_str_to_map";

    static DB::FunctionPtr create(const DB::ContextPtr) { return std::make_shared<SparkFunctionStrToMap<PairGenerator, KVGenerator>>(); }

    String getName() const override { return name; }
    bool isVariadic() const override { return false; }
    size_t getNumberOfArguments() const override { return 3; }
    bool useDefaultImplementationForConstants() const override { return true; }
    /// Both delimiters (arguments 1 and 2) are declared as always-constant.
    DB::ColumnNumbers getArgumentsThatAreAlwaysConstant() const override { return {1, 2}; }
    bool isSuitableForShortCircuitArgumentsExecution(const DB::DataTypesWithConstInfo & /*arguments*/) const override { return true; }

    /// Validates the three arguments (string, constant string, constant string) and returns
    /// Map(String, Nullable(String)), made Nullable when the first argument's type is Nullable.
    DB::DataTypePtr getReturnTypeImpl(const DB::ColumnsWithTypeAndName & arguments) const override
    {
        DB::FunctionArgumentDescriptors mandatory_args{
            {"string", static_cast<DB::FunctionArgumentDescriptor::TypeValidator>(&DB::isStringOrFixedString), nullptr, "String"},
            {"pair_delimiter",
             static_cast<DB::FunctionArgumentDescriptor::TypeValidator>(&DB::isStringOrFixedString),
             &isConstColumn,
             "String"},
            {"key_value_delimiter",
             static_cast<DB::FunctionArgumentDescriptor::TypeValidator>(&DB::isStringOrFixedString),
             &isConstColumn,
             "String"},
        };
        validateFunctionArguments(*this, arguments, mandatory_args);
        auto map_typ = std::make_shared<DB::DataTypeMap>(
            std::make_shared<DB::DataTypeString>(), makeNullable(std::make_shared<DB::DataTypeString>()));
        if (arguments[0].type->isNullable())
            return std::make_shared<DB::DataTypeNullable>(map_typ);
        else
            return map_typ;
    }

    /// Builds the result column row by row.
    /// @throws DB::Exception (ILLEGAL_COLUMN) when the first argument is not backed by a ColumnString.
    DB::ColumnPtr executeImpl(
        const DB::ColumnsWithTypeAndName & arguments, const DB::DataTypePtr & result_type, size_t /*input_rows_count*/) const override
    {
        String pair_delim = (*arguments[1].column)[0].safeGet<String>();
        String kv_delim = (*arguments[2].column)[0].safeGet<String>();
        PairGenerator pair_generator(pair_delim);
        KVGenerator kv_generator(kv_delim);

        auto col_map = result_type->createColumn();

        const DB::ColumnString * col_str = nullptr;
        const DB::ColumnUInt8 * null_map = nullptr;
        if (arguments[0].column->isNullable())
        {
            const auto * col_null = DB::checkAndGetColumn<DB::ColumnNullable>(arguments[0].column.get());
            col_str = DB::checkAndGetColumn<DB::ColumnString>(col_null->getNestedColumnPtr().get());
            null_map = &(col_null->getNullMapColumn());
        }
        else
        {
            col_str = DB::checkAndGetColumn<DB::ColumnString>(arguments[0].column.get());
        }

        // The type validator also accepts FixedString, but only ColumnString is handled below.
        // Fail with a proper error instead of dereferencing a null pointer.
        if (!col_str)
            throw DB::Exception(
                DB::ErrorCodes::ILLEGAL_COLUMN,
                "Illegal column {} of first argument of function {}",
                arguments[0].column->getName(),
                getName());

        // Looked up once; the trace statements below run for every row, pair and key.
        auto logger = getLogger("SparkFunctionStrToMap");

        const auto & strs = col_str->getChars();
        const auto & offsets = col_str->getOffsets();
        const Pos chars_begin = reinterpret_cast<Pos>(strs.data());
        DB::ColumnString::Offset prev_offset = 0;
        for (size_t i = 0, n = offsets.size(); i < n; ++i)
        {
            // The null flag of the *current* row decides whether a NULL map is emitted.
            if (null_map && null_map->getData()[i] != 0)
            {
                col_map->insertDefault();
            }
            else
            {
                DB::Map map;
                // NOTE(review): the whole span [prev_offset, offsets[i]) is scanned and every
                // token touching str_end is shortened by one byte (see viewWithoutRowTail).
                // This presumably accounts for a terminating zero byte stored after each
                // value in ColumnString — confirm against the ClickHouse version in use.
                Pos str_begin = chars_begin + prev_offset;
                Pos str_end = chars_begin + offsets[i];
                LOG_TRACE(
                    logger,
                    "str_begin: {}, str_end: {}. {}",
                    0,
                    str_end - str_begin,
                    std::string_view(str_begin, str_end - str_begin));

                pair_generator.reset(str_begin, str_end);
                Pos pair_begin;
                Pos pair_end;
                while (pair_generator.next(pair_begin, pair_end))
                {
                    LOG_TRACE(
                        logger,
                        "pair_begin: {}, pair_end: {}, {}",
                        pair_begin - str_begin,
                        pair_end - str_begin,
                        std::string_view(pair_begin, pair_end - pair_begin));

                    kv_generator.reset(pair_begin, pair_end);
                    Pos key_begin;
                    Pos key_end;
                    if (kv_generator.next(key_begin, key_end))
                    {
                        DB::Tuple tuple(2);
                        tuple[0] = viewWithoutRowTail(key_begin, key_end, str_end);

                        auto delimiter_begin = kv_generator.getDelimiterBegin();
                        auto delimiter_end = kv_generator.getDelimiterEnd();
                        // The delimiter pointers are null when no delimiter was found; log -1
                        // rather than doing pointer arithmetic on a null pointer.
                        LOG_TRACE(
                            logger,
                            "key_begin: {}, key_end: {}, delim_begin: {}, delim_end: {}, key:{}",
                            key_begin - str_begin,
                            key_end - str_begin,
                            delimiter_begin ? delimiter_begin - str_begin : -1,
                            delimiter_end ? delimiter_end - str_begin : -1,
                            std::string_view(key_begin, key_end - key_begin));

                        if (delimiter_begin && delimiter_begin != str_end)
                        {
                            // Everything after the first key/value delimiter is the value.
                            tuple[1] = viewWithoutRowTail(delimiter_end, pair_end, str_end);
                        }
                        else
                        {
                            // No delimiter found: the value is NULL.
                            tuple[1] = DB::Null();
                        }
                        map.emplace_back(std::move(tuple));
                    }
                    else if (pair_begin == pair_end)
                    {
                        // Empty pair: the key is an empty string and the value is NULL.
                        DB::Tuple tuple(2);
                        tuple[0] = std::string();
                        tuple[1] = DB::Null();
                        map.emplace_back(std::move(tuple));
                    }
                    else
                    {
                        LOG_WARNING(logger, "Split key value failed.");
                    }
                }
                col_map->insert(std::move(map));
            }
            prev_offset = offsets[i];
        }
        return col_map;
    }

private:
    /// Returns [begin, end) as a view, dropping the final byte when the range reaches
    /// `row_end`. An empty range stays empty, so the length can never underflow (which the
    /// plain `end - begin - 1` computation does when a delimiter consumes the row's last byte).
    static std::string_view viewWithoutRowTail(Pos begin, Pos end, Pos row_end)
    {
        if (end == row_end && end > begin)
            --end;
        return std::string_view(begin, static_cast<size_t>(end - begin));
    }
};
/// Overload resolver for `spark_str_to_map`.
///
/// For each of the two delimiter arguments it decides whether the delimiter can be handled
/// as a plain literal (TrivialCharSplitter) or has to be treated as a regular expression
/// (RegularSplitter), and instantiates SparkFunctionStrToMap with the matching splitters.
class SparkFunctionStrToMapOverloadResolver : public DB::IFunctionOverloadResolver
{
public:
    static constexpr auto name = "spark_str_to_map";

    static DB::FunctionOverloadResolverPtr create(const DB::ContextPtr context)
    {
        return std::make_shared<SparkFunctionStrToMapOverloadResolver>(context);
    }

    /// Keeps the context and pre-builds the all-literal variant, which also serves
    /// return-type deduction.
    explicit SparkFunctionStrToMapOverloadResolver(DB::ContextPtr context_)
        : context(context_), trivial_function(SparkFunctionStrToMap<TrivialCharSplitter, TrivialCharSplitter>::create(context))
    {
    }

    String getName() const override { return name; }
    size_t getNumberOfArguments() const override { return 3; }
    bool isVariadic() const override { return false; }

    /// Selects the implementation matching the kind of both delimiters and wraps it in a
    /// FunctionToFunctionBaseAdaptor together with the argument types and `return_type`.
    DB::FunctionBasePtr buildImpl(const DB::ColumnsWithTypeAndName & arguments, const DB::DataTypePtr & return_type) const override
    {
        // Either delimiter may be a regular expression; classify both up front.
        const bool pair_is_literal = patternIsTrivialChar(arguments[1]);
        const bool kv_is_literal = patternIsTrivialChar(arguments[2]);

        DB::FunctionPtr implementation;
        if (pair_is_literal)
        {
            implementation
                = kv_is_literal ? trivial_function : SparkFunctionStrToMap<TrivialCharSplitter, RegularSplitter>::create(context);
        }
        else
        {
            implementation = kv_is_literal ? SparkFunctionStrToMap<RegularSplitter, TrivialCharSplitter>::create(context)
                                           : SparkFunctionStrToMap<RegularSplitter, RegularSplitter>::create(context);
        }

        DB::DataTypes argument_types;
        argument_types.reserve(arguments.size());
        for (const auto & argument : arguments)
            argument_types.push_back(argument.type);

        return std::make_unique<DB::FunctionToFunctionBaseAdaptor>(implementation, std::move(argument_types), return_type);
    }

    /// The return type does not depend on the splitter kinds, so the all-literal variant answers.
    DB::DataTypePtr getReturnTypeImpl(const DB::ColumnsWithTypeAndName & arguments) const override
    {
        return trivial_function->getReturnTypeImpl(arguments);
    }

private:
    DB::ContextPtr context;
    DB::FunctionPtr trivial_function;

    /// Returns true when `argument` is a constant, non-empty string whose regex analysis
    /// reports it as trivial with a required substring equal to the whole pattern.
    bool patternIsTrivialChar(const DB::ColumnWithTypeAndName & argument) const
    {
        const DB::ColumnConst * const_column = checkAndGetColumnConstStringOrFixedString(argument.column.get());
        if (const_column == nullptr)
            return false;

        const String delimiter = const_column->getValue<String>();
        if (delimiter.empty())
            return false;

        OptimizedRegularExpression compiled = DB::Regexps::createRegexp<false, false, false>(delimiter);
        std::string literal_part;
        bool trivial;
        bool literal_is_prefix;
        compiled.getAnalyzeResult(literal_part, trivial, literal_is_prefix);

        if (!trivial)
            return false;
        return literal_part == delimiter;
    }
};
/// Registers SparkFunctionStrToMapOverloadResolver with the function factory.
/// `factory` is not declared in this file; it is presumably supplied by the
/// REGISTER_FUNCTION macro.
REGISTER_FUNCTION(SparkStrToMap)
{
factory.registerFunction<SparkFunctionStrToMapOverloadResolver>();
}
}