blob: 3a587d97744e887d6020b6592fb71cb0499ef73f [file]
// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.
#include <fmt/format.h>
#include <algorithm>
#include <cstddef>
#include "common/status.h"
#include "core/assert_cast.h"
#include "core/block/block.h"
#include "core/block/column_numbers.h"
#include "core/column/column_const.h"
#include "core/column/column_string.h"
#include "core/column/column_vector.h"
#include "core/data_type/data_type_string.h"
#include "core/memcpy_small.h"
#include "core/string_ref.h"
#include "exprs/function/function.h"
#include "exprs/function/function_helpers.h"
#include "exprs/function/simple_function_factory.h"
#include "exprs/function_context.h"
namespace doris {
#include "common/compile_check_avoid_begin.h"
// Forward declaration so that FunctionMask can name both instantiations of
// FunctionMaskPartial as friends before the template itself is defined.
template <bool Reverse>
class FunctionMaskPartial;
// Function `mask`: produces a string column in which every ASCII upper-case
// letter, lower-case letter and digit of the first argument is replaced by a
// mask character, while every other byte is copied through unchanged.
//
// Arguments 2, 3 and 4 are optional and override the mask character used for
// upper-case letters, lower-case letters and digits respectively. Only the
// first byte of row 0 of each of them is consulted; an empty value leaves the
// corresponding default ('X', 'x', 'n') in place.
class FunctionMask : public IFunction {
public:
    static constexpr auto name = "mask";
    static constexpr unsigned char DEFAULT_UPPER_MASK = 'X';
    static constexpr unsigned char DEFAULT_LOWER_MASK = 'x';
    static constexpr unsigned char DEFAULT_NUMBER_MASK = 'n';

    static FunctionPtr create() { return std::make_shared<FunctionMask>(); }

    String get_name() const override { return name; }

    // The result is always a string, whatever the argument types are.
    DataTypePtr get_return_type_impl(const DataTypes& arguments) const override {
        return std::make_shared<DataTypeString>();
    }

    // The function is variadic, so no fixed argument count is reported.
    bool is_variadic() const override { return true; }
    size_t get_number_of_arguments() const override { return 0; }

    // The three optional mask-character arguments are declared always-constant.
    ColumnNumbers get_arguments_that_are_always_constant() const override { return {1, 2, 3}; }

    // Masks the string column at arguments[0] and stores the new column at
    // position `result` of `block`.
    // Returns InvalidArgument when more than four arguments are supplied.
    Status execute_impl(FunctionContext* context, Block& block, const ColumnNumbers& arguments,
                        uint32_t result, size_t input_rows_count) const override {
        DCHECK_GE(arguments.size(), 1);
        DCHECK_LE(arguments.size(), 4);

        auto masked = ColumnString::create();
        const auto& input =
                assert_cast<const ColumnString&>(*block.get_by_position(arguments[0]).column);

        // Replacement characters in argument order: upper, lower, number.
        char replacements[] = {DEFAULT_UPPER_MASK, DEFAULT_LOWER_MASK, DEFAULT_NUMBER_MASK};

        // Argument `pos` (1..3), when present and non-empty, overrides
        // replacements[pos - 1] with its first byte.
        const size_t last_override = std::min<size_t>(arguments.size(), 4);
        for (size_t pos = 1; pos < last_override; ++pos) {
            const auto first_value = block.get_by_position(arguments[pos]).column->get_data_at(0);
            if (first_value.size > 0) {
                replacements[pos - 1] = *first_value.data;
            }
        }

        if (arguments.size() > 4) {
            return Status::InvalidArgument(
                    fmt::format("too many arguments for function {}", get_name()));
        }

        vector_mask(input, *masked, replacements[0], replacements[1], replacements[2]);
        block.get_by_position(result).column = std::move(masked);
        return Status::OK();
    }

    // FunctionMaskPartial calls the private helpers below.
    friend class FunctionMaskPartial<true>;
    friend class FunctionMaskPartial<false>;

private:
    // Fills `result` with a masked copy of `source`: the offsets are copied
    // verbatim and the whole character buffer is masked in a single pass.
    static void vector_mask(const ColumnString& source, ColumnString& result, const char upper,
                            const char lower, const char number) {
        const auto& in_chars = source.get_chars();
        const auto& in_offsets = source.get_offsets();
        auto& out_chars = result.get_chars();
        auto& out_offsets = result.get_offsets();

        out_chars.resize(in_chars.size());
        out_offsets.resize(in_offsets.size());

        memcpy_small_allow_read_write_overflow15(
                out_offsets.data(), in_offsets.data(),
                in_offsets.size() * sizeof(ColumnString::Offset));

        mask(in_chars.data(), in_chars.size(), upper, lower, number, out_chars.data());
    }

    // Writes `size` bytes to `res`: 'A'..'Z' become `upper`, 'a'..'z' become
    // `lower`, '0'..'9' become `number`, and any other byte is copied from `src`.
    static void mask(const unsigned char* __restrict src, const size_t size,
                     const unsigned char upper, const unsigned char lower,
                     const unsigned char number, unsigned char* __restrict res) {
        for (size_t pos = 0; pos < size; ++pos) {
            const unsigned char ch = src[pos];
            unsigned char out = ch;
            if ('A' <= ch && ch <= 'Z') {
                out = upper;
            } else if ('a' <= ch && ch <= 'z') {
                out = lower;
            } else if ('0' <= ch && ch <= '9') {
                out = number;
            }
            res[pos] = out;
        }
    }
};
// Functions `mask_first_n` (Reverse == false) and `mask_last_n`
// (Reverse == true): mask only the first / last `n` bytes of every input
// string with the default mask characters of FunctionMask and copy the
// remaining bytes unchanged. When the second argument is omitted the whole
// string is masked.
template <bool Reverse>
class FunctionMaskPartial : public IFunction {
public:
    static constexpr auto name = Reverse ? "mask_last_n" : "mask_first_n";

    static FunctionPtr create() { return std::make_shared<FunctionMaskPartial>(); }

    String get_name() const override { return name; }

    // The result is always a string, whatever the argument types are.
    DataTypePtr get_return_type_impl(const DataTypes& arguments) const override {
        return std::make_shared<DataTypeString>();
    }

    // The function is variadic, so no fixed argument count is reported.
    bool is_variadic() const override { return true; }
    size_t get_number_of_arguments() const override { return 0; }

    // Masks the string column at arguments[0] and stores the new column at
    // position `result` of `block`. arguments[1], when present, is read as an
    // Int32 column holding `n`, either constant or one value per row.
    // Returns InvalidArgument when any consulted `n` is negative.
    Status execute_impl(FunctionContext* context, Block& block, const ColumnNumbers& arguments,
                        uint32_t result, size_t input_rows_count) const override {
        auto masked = ColumnString::create();
        auto full_input =
                block.get_by_position(arguments[0]).column->convert_to_full_column_if_const();
        const auto& input = assert_cast<const ColumnString&>(*full_input);

        // Without a 2nd argument every byte of every row is masked.
        if (arguments.size() == 1) {
            FunctionMask::vector_mask(input, *masked, FunctionMask::DEFAULT_UPPER_MASK,
                                      FunctionMask::DEFAULT_LOWER_MASK,
                                      FunctionMask::DEFAULT_NUMBER_MASK);
            block.get_by_position(result).column = std::move(masked);
            return Status::OK();
        }

        const auto& [count_col, count_is_const] =
                unpack_if_const(block.get_by_position(arguments[1]).column);
        const auto& counts = assert_cast<const ColumnInt32&>(*count_col);

        RETURN_IF_ERROR(count_is_const ? vector<true>(input, counts, *masked)
                                       : vector<false>(input, counts, *masked));

        block.get_by_position(result).column = std::move(masked);
        return Status::OK();
    }

private:
    // Row-by-row partial masking of `src` into `result`.
    // `is_const` selects how `col_n` is indexed via index_check_const.
    // Stops at the first negative `n` and returns InvalidArgument; `result`
    // is left partially written in that case.
    template <bool is_const>
    static Status vector(const ColumnString& src, const ColumnInt32& col_n, ColumnString& result) {
        const auto& in_chars = src.get_chars();
        const auto& in_offsets = src.get_offsets();
        auto& out_chars = result.get_chars();
        auto& out_offsets = result.get_offsets();

        // The output has exactly the layout of the input: same total size, same offsets.
        out_chars.resize(in_chars.size());
        out_offsets.resize(in_offsets.size());
        memcpy_small_allow_read_write_overflow15(
                out_offsets.data(), in_offsets.data(),
                in_offsets.size() * sizeof(ColumnString::Offset));

        const auto row_count = src.size();
        const auto* chars = in_chars.data();
        const auto* offsets = in_offsets.data();
        auto* res = out_chars.data();
        const auto& counts = col_n.get_data();

        // Signed index on purpose: row - 1 must evaluate to -1 for the first row.
        for (ssize_t row = 0; row != row_count; ++row) {
            // NOTE(review): for row 0 this reads offsets[-1]; presumably the offsets
            // container guarantees that element is readable and zero - confirm.
            const auto row_begin = offsets[row - 1];
            const int row_len = offsets[row] - row_begin;

            const int n = counts[index_check_const<is_const>(row)];
            if (n < 0) [[unlikely]] {
                return Status::InvalidArgument(
                        "function {} only accept non-negative input for 2nd argument but got {}",
                        name, n);
            }

            // Split the row into the part to mask and the part to copy as-is.
            const int masked_len = std::min(n, row_len);
            const int plain_len = row_len - masked_len;

            // mask_last_n: [plain | masked]; mask_first_n: [masked | plain].
            const auto plain_begin = Reverse ? row_begin : row_begin + masked_len;
            const auto masked_begin = Reverse ? row_begin + plain_len : row_begin;

            if (plain_len > 0) {
                memcpy(&res[plain_begin], &chars[plain_begin], plain_len);
            }
            FunctionMask::mask(&chars[masked_begin], masked_len, FunctionMask::DEFAULT_UPPER_MASK,
                               FunctionMask::DEFAULT_LOWER_MASK, FunctionMask::DEFAULT_NUMBER_MASK,
                               &res[masked_begin]);
        }
        return Status::OK();
    }
};
// Registers `mask`, `mask_last_n` and `mask_first_n` with the given factory.
void register_function_string_mask(SimpleFunctionFactory& factory) {
    using FunctionMaskLastN = FunctionMaskPartial<true>;
    using FunctionMaskFirstN = FunctionMaskPartial<false>;

    factory.register_function<FunctionMask>();
    factory.register_function<FunctionMaskLastN>();
    factory.register_function<FunctionMaskFirstN>();
}
#include "common/compile_check_avoid_end.h"
} // namespace doris