// blob: 68136713f59c6d0ecc9273cc7eeb9cc5f00f1017 [file] [log] [blame]
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include <Columns/ColumnConst.h>
#include <Columns/ColumnString.h>
#include <DataTypes/DataTypeArray.h>
#include <DataTypes/DataTypeString.h>
#include <Functions/FunctionFactory.h>
#include <Functions/FunctionHelpers.h>
#include <Functions/IFunction.h>
#include <Functions/Regexps.h>
#include <Interpreters/Context.h>
#include <Common/FunctionDocumentation.h>
namespace DB
{
/// Error codes referenced in this file. They are only declared here (`extern`);
/// the definitions are expected to come from another translation unit.
namespace ErrorCodes
{
/// Thrown when the function is called with neither two nor three arguments.
extern const int NUMBER_OF_ARGUMENTS_DOESNT_MATCH;
/// Thrown when the haystack is not a String column or the pattern is not a constant column.
extern const int ILLEGAL_COLUMN;
/// Thrown when the requested regex group index does not exist in the pattern.
extern const int INDEX_OF_POSITIONAL_ARGUMENT_IS_OUT_OF_RANGE;
}
}
/// Make the names of namespace DB (columns, data types, function interfaces, ErrorCodes) usable unqualified below.
using namespace DB;
namespace local_engine
{
/// Regular expression type used by the functions in this file; a plain alias of OptimizedRegularExpression.
using SparkRegexp = OptimizedRegularExpression;
namespace
{
class FunctionRegexpExtractAllSpark : public IFunction
{
public:
using Pos = const char *;
static constexpr auto name = "regexpExtractAllSpark";
static FunctionPtr create(ContextPtr) { return std::make_shared<FunctionRegexpExtractAllSpark>(); }
String getName() const override { return name; }
bool isVariadic() const override { return true; }
size_t getNumberOfArguments() const override { return 0; }
bool useDefaultImplementationForConstants() const override { return true; }
ColumnNumbers getArgumentsThatAreAlwaysConstant() const override { return {1}; }
bool isSuitableForShortCircuitArgumentsExecution(const DataTypesWithConstInfo & /*arguments*/) const override { return true; }
DataTypePtr getReturnTypeImpl(const ColumnsWithTypeAndName & arguments) const override
{
if (arguments.size() != 2 && arguments.size() != 3)
throw Exception(
ErrorCodes::NUMBER_OF_ARGUMENTS_DOESNT_MATCH,
"Number of arguments for function {} doesn't match: passed {}",
getName(),
arguments.size());
FunctionArgumentDescriptors args{
{"haystack", static_cast<FunctionArgumentDescriptor::TypeValidator>(&isString), nullptr, "String"},
{"pattern", static_cast<FunctionArgumentDescriptor::TypeValidator>(&isString), isColumnConst, "const String"},
};
if (arguments.size() == 3)
args.emplace_back(FunctionArgumentDescriptor{"index", static_cast<FunctionArgumentDescriptor::TypeValidator>(&isInteger), nullptr, "Integer"});
validateFunctionArguments(*this, arguments, args);
return std::make_shared<DataTypeArray>(std::make_shared<DataTypeString>());
}
ColumnPtr executeImpl(const ColumnsWithTypeAndName & arguments, const DataTypePtr &, size_t /*input_rows_count*/) const override
{
const ColumnPtr column = arguments[0].column;
const ColumnPtr column_pattern = arguments[1].column;
const ColumnPtr column_index = arguments.size() > 2 ? arguments[2].column : nullptr;
/// Check if the second argument is const column
const ColumnConst * col_pattern = typeid_cast<const ColumnConst *>(column_pattern.get());
if (!col_pattern)
throw Exception(ErrorCodes::ILLEGAL_COLUMN, "Second argument of function {} must be constant string", getName());
/// Check if the first argument is string column(const or not)
const ColumnString * col = nullptr;
const ColumnConst * col_const = typeid_cast<const ColumnConst *>(column.get());
if (col_const)
col = typeid_cast<const ColumnString *>(&col_const->getDataColumn());
else
col = typeid_cast<const ColumnString *>(column.get());
if (!col)
throw Exception(
ErrorCodes::ILLEGAL_COLUMN, "Illegal column {} of argument of function {}", arguments[0].column->getName(), getName());
auto col_res = ColumnArray::create(ColumnString::create());
ColumnString & res_strings = typeid_cast<ColumnString &>(col_res->getData());
ColumnArray::Offsets & res_offsets = col_res->getOffsets();
ColumnString::Chars & res_strings_chars = res_strings.getChars();
ColumnString::Offsets & res_strings_offsets = res_strings.getOffsets();
if (col_const)
constantVector(
col_const->getValue<String>(),
col_pattern->getValue<String>(),
column_index,
res_offsets,
res_strings_chars,
res_strings_offsets);
else if (!column_index || isColumnConst(*column_index))
{
const auto * col_const_index = typeid_cast<const ColumnConst *>(column_index.get());
ssize_t index = !col_const_index ? 1 : col_const_index->getInt(0);
vectorConstant(
col->getChars(),
col->getOffsets(),
col_pattern->getValue<String>(),
index,
res_offsets,
res_strings_chars,
res_strings_offsets);
}
else
vectorVector(
col->getChars(),
col->getOffsets(),
col_pattern->getValue<String>(),
column_index,
res_offsets,
res_strings_chars,
res_strings_offsets);
return col_res;
}
private:
static void saveMatchs(
Pos start,
Pos end,
const SparkRegexp & regexp,
OptimizedRegularExpression::MatchVec & matches,
size_t match_index,
ColumnArray::Offsets & res_offsets,
ColumnString::Chars & res_strings_chars,
ColumnString::Offsets & res_strings_offsets,
size_t & res_offset,
size_t & res_strings_offset)
{
size_t i = 0;
Pos pos = start;
while (pos < end)
{
regexp.match(pos, end - pos, matches, static_cast<unsigned>(match_index + 1));
if (match_index >= matches.size())
break;
/// Append matched segment into res_strings_chars
const auto & match = matches[match_index];
if (match.offset != std::string::npos)
{
res_strings_chars.resize_exact(res_strings_offset + match.length + 1);
memcpySmallAllowReadWriteOverflow15(&res_strings_chars[res_strings_offset], pos + match.offset, match.length);
res_strings_offset += match.length;
}
else
res_strings_chars.resize_exact(res_strings_offset + 1);
/// Update offsets of Column:String
res_strings_chars[res_strings_offset] = 0;
++res_strings_offset;
res_strings_offsets.push_back(res_strings_offset);
++i;
pos += matches[0].offset + matches[0].length;
}
/// Update offsets of Column:Array(String)
res_offset += i;
res_offsets.push_back(res_offset);
}
static void vectorConstant(
const ColumnString::Chars & data,
const ColumnString::Offsets & offsets,
const std::string & pattern,
ssize_t index,
ColumnArray::Offsets & res_offsets,
ColumnString::Chars & res_strings_chars,
ColumnString::Offsets & res_strings_offsets)
{
const SparkRegexp regexp = Regexps::createRegexp<false, false, false>(pattern);
unsigned capture = regexp.getNumberOfSubpatterns();
if (index < 0 || index >= capture + 1)
throw Exception(
ErrorCodes::INDEX_OF_POSITIONAL_ARGUMENT_IS_OUT_OF_RANGE,
"Index value {} is out of range, should be in [0, {})",
index,
capture + 1);
OptimizedRegularExpression::MatchVec matches;
matches.reserve(index + 1);
res_offsets.reserve_exact(offsets.size());
res_strings_chars.reserve_exact(data.size() / 3);
res_strings_offsets.reserve_exact(offsets.size() * 2);
size_t res_offset = 0;
size_t res_strings_offset = 0;
size_t prev_offset = 0;
for (size_t cur_offset : offsets)
{
Pos start = reinterpret_cast<const char *>(&data[prev_offset]);
Pos end = start + (cur_offset - prev_offset - 1);
saveMatchs(
start,
end,
regexp,
matches,
index,
res_offsets,
res_strings_chars,
res_strings_offsets,
res_offset,
res_strings_offset);
prev_offset = cur_offset;
}
}
static void vectorVector(
const ColumnString::Chars & data,
const ColumnString::Offsets & offsets,
const std::string & pattern,
const ColumnPtr & column_index,
ColumnArray::Offsets & res_offsets,
ColumnString::Chars & res_strings_chars,
ColumnString::Offsets & res_strings_offsets)
{
const SparkRegexp regexp = Regexps::createRegexp<false, false, false>(pattern);
unsigned capture = regexp.getNumberOfSubpatterns();
OptimizedRegularExpression::MatchVec matches;
matches.reserve(capture + 1);
res_offsets.reserve_exact(offsets.size());
res_strings_chars.reserve_exact(data.size() / 3);
res_strings_offsets.reserve_exact(offsets.size() * 2);
size_t res_offset = 0;
size_t res_strings_offset = 0;
size_t prev_offset = 0;
for (size_t i = 0; i < offsets.size(); ++i)
{
ssize_t index = column_index->getInt(i);
if (index < 0 || index >= capture + 1)
throw Exception(
ErrorCodes::INDEX_OF_POSITIONAL_ARGUMENT_IS_OUT_OF_RANGE,
"Index value {} is out of range, should be in [0, {})",
index,
capture + 1);
size_t cur_offset = offsets[i];
Pos start = reinterpret_cast<const char *>(&data[prev_offset]);
Pos end = start + (cur_offset - prev_offset - 1);
saveMatchs(
start,
end,
regexp,
matches,
index,
res_offsets,
res_strings_chars,
res_strings_offsets,
res_offset,
res_strings_offset);
prev_offset = cur_offset;
}
}
static void constantVector(
const std::string & str,
const std::string & pattern,
const ColumnPtr & column_index,
ColumnArray::Offsets & res_offsets,
ColumnString::Chars & res_strings_chars,
ColumnString::Offsets & res_strings_offsets)
{
const SparkRegexp regexp = Regexps::createRegexp<false, false, false>(pattern);
unsigned capture = regexp.getNumberOfSubpatterns();
/// Copy data into padded array to be able to use memcpySmallAllowReadWriteOverflow15.
ColumnString::Chars padded_str;
padded_str.insert(str.begin(), str.end());
Pos start = reinterpret_cast<Pos>(padded_str.data());
Pos pos = start;
Pos end = start + padded_str.size();
Pos prev_pos = nullptr;
std::vector<OptimizedRegularExpression::MatchVec> matches_groups;
while (pos < end)
{
OptimizedRegularExpression::MatchVec matches;
matches.reserve(capture + 1);
regexp.match(pos, end - pos, matches, static_cast<unsigned>(capture + 1));
if (capture + 1 > matches.size())
break;
/// Make all the offsets based on start
for (auto & match : matches)
if (match.offset != std::string::npos)
match.offset += pos - start;
prev_pos = pos;
pos = start + matches[0].offset + matches[0].length;
/// Avoid dead loop caused by empty captured string
if (pos == prev_pos)
++pos;
matches_groups.emplace_back(std::move(matches));
}
size_t rows = column_index->size();
res_offsets.reserve_exact(rows);
res_strings_chars.reserve_exact(rows * str.size() / 3);
res_strings_offsets.reserve_exact(rows * 2);
size_t res_offset = 0;
size_t res_strings_offset = 0;
for (size_t row_i = 0; row_i < rows; ++row_i)
{
ssize_t index = column_index->getInt(row_i);
if (index < 0 || index >= capture + 1)
throw Exception(
ErrorCodes::INDEX_OF_POSITIONAL_ARGUMENT_IS_OUT_OF_RANGE,
"Index value {} is out of range, should be in [0, {})",
index,
capture + 1);
for (auto & matches : matches_groups)
{
const auto & match = matches[index];
/// Append matched segment into res_strings_chars
if (match.offset != std::string::npos)
{
res_strings_chars.resize_exact(res_strings_offset + match.length + 1);
memcpySmallAllowReadWriteOverflow15(&res_strings_chars[res_strings_offset], start + match.offset, match.length);
res_strings_offset += match.length;
}
else
res_strings_chars.resize_exact(res_strings_offset + 1);
/// Update offsets of Column:String
res_strings_chars[res_strings_offset] = 0;
++res_strings_offset;
res_strings_offsets.push_back(res_strings_offset);
}
/// Update offsets of Column:Array(String)
res_offset += matches_groups.size();
res_offsets.push_back(res_offset);
}
}
};
}
/// Registers regexpExtractAllSpark in the function factory together with its description.
REGISTER_FUNCTION(RegexpExtractAllSpark)
{
    FunctionDocumentation documentation{
        .description
        = R"(Extracts all the fragments of a string that matches the regexp pattern and corresponds to the regex group index.)"};

    factory.registerFunction<FunctionRegexpExtractAllSpark>(std::move(documentation));
}
}