blob: 5d7059ae6c90bf67e7aa4808070cdb3e8c09d82a [file] [log] [blame]
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include <stack>
#include <DataTypes/DataTypeString.h>
#include <IO/ReadBufferFromString.h>
#include <Parser/FunctionParser.h>
#include <Poco/Logger.h>
#include <Common/logger_useful.h>
#include <Common/re2.h>
namespace DB
{
namespace ErrorCodes
{
extern const int NUMBER_OF_ARGUMENTS_DOESNT_MATCH;
extern const int ILLEGAL_TYPE_OF_ARGUMENT;
}
}
using namespace DB;
namespace local_engine
{
class FunctionParserRegexpExtract : public FunctionParser
{
public:
explicit FunctionParserRegexpExtract(ParserContextPtr parser_context_) : FunctionParser(parser_context_) { }
~FunctionParserRegexpExtract() override = default;
static constexpr auto name = "regexp_extract";
String getName() const override { return name; }
const ActionsDAG::Node * parse(const substrait::Expression_ScalarFunction & substrait_func, ActionsDAG & actions_dag) const override
{
const auto & args = substrait_func.arguments();
if (args.size() != 3)
throw Exception(ErrorCodes::NUMBER_OF_ARGUMENTS_DOESNT_MATCH, "Function {} requires 3 arguments", getName());
if (args[1].value().has_literal())
{
const auto & literal_expr = args[1].value().literal();
if (literal_expr.has_string())
{
std::string expr_str = literal_expr.string();
/// FIXEDME: This only works for RE2
expr_str = adaptPatternForRE2(expr_str);
String sparkRegexp = adjustSparkRegexpRule(expr_str);
const auto * regex_expr_node = addColumnToActionsDAG(actions_dag, std::make_shared<DataTypeString>(), sparkRegexp);
auto parsed_args = parseFunctionArguments(substrait_func, actions_dag);
parsed_args[1] = regex_expr_node;
const auto * result_node = toFunctionNode(actions_dag, "regexpExtract", parsed_args);
return convertNodeTypeIfNeeded(substrait_func, result_node, actions_dag);
}
else
throw Exception(ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT, "Function {} 2nd argument's type must be const string", getName());
}
else
throw Exception(ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT, "Function {} 2nd argument's type must be const", getName());
}
private:
String adaptPatternForRE2(const String & pattern_) const
{
LOG_DEBUG(getLogger("FunctionParserRegexpExtract"), "xxx original pattern: {}", pattern_);
String res = pattern_;
// adaptation for $, see issue #8325. equal two cases in re2: $ and \n$, but not include strings which contains \n in middle.
static const std::string replaced_str = "($|\\\\n$)";
static const re2::RE2 replace_dollar_pattern("([^\\\\])(\\$)");
re2::RE2::GlobalReplace(&res, replace_dollar_pattern, "\\1" + replaced_str);
LOG_DEBUG(getLogger("FunctionParserRegexpExtract"), "xxx adaption for $: {}", res);
// adaption for `.` . Need to remove flag s.
res = "(?-s)" + res;
return res;
}
String adjustSparkRegexpRule(String & str) const
{
const auto left_bracket_pos = str.find('[');
const auto right_bracket_pos = str.find(']');
if (left_bracket_pos == str.npos || right_bracket_pos == str.npos || left_bracket_pos >= right_bracket_pos)
return str;
auto throw_message = [this, &str]() -> void {
throw Exception(
ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT, "The value of parameter(s) 'regexp' in `{}` is invalid: '{}'", getName(), str);
};
ReadBufferFromString buf(str);
std::stack<String> strs;
strs.emplace("");
bool nead_right_bracket = false;
while (!buf.eof())
{
if (*buf.position() == '[')
{
strs.emplace("");
}
else if (*buf.position() == ']')
{
if (strs.size() == 1)
{
// "ab]c"
strs.top().append("]");
}
else
{
String back = strs.top();
strs.pop();
if (strs.size() == 1)
{
// "abc[abc]abc"
strs.top().append("[").append(back).append("]");
nead_right_bracket = false;
}
else
{
// "abc[a[abc]c]abc"
strs.top().append(back);
nead_right_bracket = true;
}
}
}
else
{
strs.top() += *buf.position();
}
++buf.position();
}
if (nead_right_bracket && strs.size() != 1)
throw_message();
while (strs.size() != 1)
{
String back = strs.top();
strs.pop();
strs.top().append("[").append(back);
}
return strs.top();
}
};
static FunctionParserRegister<FunctionParserRegexpExtract> register_regexp_extract;
}