blob: 4233f56b2c13cc2cfb1502b521b8d2d2013e8539 [file] [log] [blame]
// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.
// This file is copied from
// https://github.com/ClickHouse/ClickHouse/blob/master/src/Functions/IFunction.cpp
// and modified by Doris
#include "vec/functions/function.h"
#include <algorithm>
#include <memory>
#include <numeric>
#include "common/status.h"
#include "runtime/define_primitive_type.h"
#include "runtime/primitive_type.h"
#include "vec/aggregate_functions/aggregate_function.h"
#include "vec/columns/column.h"
#include "vec/columns/column_const.h"
#include "vec/columns/column_nullable.h"
#include "vec/columns/column_vector.h"
#include "vec/common/assert_cast.h"
#include "vec/core/field.h"
#include "vec/data_types/data_type_array.h"
#include "vec/data_types/data_type_nothing.h"
#include "vec/data_types/data_type_nullable.h"
#include "vec/functions/function_helpers.h"
#include "vec/utils/util.hpp"
namespace doris::vectorized {
#include "common/compile_check_begin.h"
ColumnPtr wrap_in_nullable(const ColumnPtr& src, const Block& block, const ColumnNumbers& args,
size_t input_rows_count) {
ColumnPtr result_null_map_column;
/// If result is already nullable.
ColumnPtr src_not_nullable = src;
MutableColumnPtr mutable_result_null_map_column;
if (auto nullable = check_and_get_column_ptr<ColumnNullable>(src)) {
src_not_nullable = nullable->get_nested_column_ptr();
result_null_map_column = nullable->get_null_map_column_ptr();
}
for (const auto& arg : args) {
const ColumnWithTypeAndName& elem = block.get_by_position(arg);
if (!elem.type->is_nullable() || is_column_const(*elem.column)) {
continue;
}
if (auto nullable = cast_to_column<ColumnNullable>(elem.column); nullable->has_null()) {
const ColumnPtr& null_map_column = nullable->get_null_map_column_ptr();
if (!result_null_map_column) { // NOLINT(bugprone-use-after-move)
result_null_map_column = null_map_column->clone_resized(input_rows_count);
continue;
}
if (!mutable_result_null_map_column) {
mutable_result_null_map_column =
std::move(result_null_map_column)->assume_mutable();
}
NullMap& result_null_map =
assert_cast<ColumnUInt8&>(*mutable_result_null_map_column).get_data();
const NullMap& src_null_map =
assert_cast<const ColumnUInt8&>(*null_map_column).get_data();
VectorizedUtils::update_null_map(result_null_map, src_null_map);
}
}
if (!result_null_map_column) {
if (is_column_const(*src)) {
return ColumnConst::create(
make_nullable(assert_cast<const ColumnConst&>(*src).get_data_column_ptr(),
false),
input_rows_count);
}
return ColumnNullable::create(src, ColumnUInt8::create(input_rows_count, 0));
}
return ColumnNullable::create(src_not_nullable, result_null_map_column);
}
bool have_null_column(const Block& block, const ColumnNumbers& args) {
return std::ranges::any_of(args, [&block](const auto& elem) {
return block.get_by_position(elem).type->is_nullable();
});
}
bool have_null_column(const ColumnsWithTypeAndName& args) {
return std::ranges::any_of(args, [](const auto& elem) { return elem.type->is_nullable(); });
}
inline Status PreparedFunctionImpl::_execute_skipped_constant_deal(FunctionContext* context,
Block& block,
const ColumnNumbers& args,
uint32_t result,
size_t input_rows_count) const {
bool executed = false;
RETURN_IF_ERROR(default_implementation_for_nulls(context, block, args, result, input_rows_count,
&executed));
if (executed) {
return Status::OK();
}
return execute_impl(context, block, args, result, input_rows_count);
}
Status PreparedFunctionImpl::default_implementation_for_constant_arguments(
FunctionContext* context, Block& block, const ColumnNumbers& args, uint32_t result,
size_t input_rows_count, bool* executed) const {
*executed = false;
ColumnNumbers args_expect_const = get_arguments_that_are_always_constant();
// Check that these arguments are really constant.
for (auto arg_num : args_expect_const) {
if (arg_num < args.size() &&
!is_column_const(*block.get_by_position(args[arg_num]).column)) {
return Status::InvalidArgument("Argument at index {} for function {} must be constant",
arg_num, get_name());
}
}
if (args.empty() || !use_default_implementation_for_constants() ||
!VectorizedUtils::all_arguments_are_constant(block, args)) {
return Status::OK();
}
// now all columns are const.
Block temporary_block;
int arguments_size = (int)args.size();
for (size_t arg_num = 0; arg_num < arguments_size; ++arg_num) {
const ColumnWithTypeAndName& column = block.get_by_position(args[arg_num]);
// Columns in const_list --> column_const, others --> nested_column
// that's because some functions supposes some specific columns always constant.
// If we unpack it, there will be unnecessary cost of virtual judge.
if (args_expect_const.end() !=
std::find(args_expect_const.begin(), args_expect_const.end(), arg_num)) {
temporary_block.insert({column.column, column.type, column.name});
} else {
temporary_block.insert(
{assert_cast<const ColumnConst*>(column.column.get())->get_data_column_ptr(),
column.type, column.name});
}
}
temporary_block.insert(block.get_by_position(result));
ColumnNumbers temporary_argument_numbers(arguments_size);
for (int i = 0; i < arguments_size; ++i) {
temporary_argument_numbers[i] = i;
}
RETURN_IF_ERROR(_execute_skipped_constant_deal(context, temporary_block,
temporary_argument_numbers, arguments_size,
temporary_block.rows()));
ColumnPtr result_column;
/// extremely rare case, when we have function with completely const arguments
/// but some of them produced by non is_deterministic function
if (temporary_block.get_by_position(arguments_size).column->size() > 1) {
result_column = temporary_block.get_by_position(arguments_size).column->clone_resized(1);
} else {
result_column = temporary_block.get_by_position(arguments_size).column;
}
// We shuold handle the case where the result column is also a ColumnConst.
block.get_by_position(result).column = ColumnConst::create(result_column, input_rows_count);
*executed = true;
return Status::OK();
}
Status PreparedFunctionImpl::default_implementation_for_nulls(
FunctionContext* context, Block& block, const ColumnNumbers& args, uint32_t result,
size_t input_rows_count, bool* executed) const {
*executed = false;
if (args.empty() || !use_default_implementation_for_nulls()) {
return Status::OK();
}
if (std::ranges::any_of(args, [&block](const auto& elem) {
return block.get_by_position(elem).column->only_null();
})) {
block.get_by_position(result).column =
block.get_by_position(result).type->create_column_const(input_rows_count, Field());
*executed = true;
return Status::OK();
}
if (have_null_column(block, args)) {
bool need_to_default = need_replace_null_data_to_default();
// extract nested column from nulls
ColumnNumbers new_args;
Block new_block;
for (int i = 0; i < args.size(); ++i) {
uint32_t arg = args[i];
new_args.push_back(i);
new_block.insert(block.get_by_position(arg).unnest_nullable(need_to_default));
}
new_block.insert(block.get_by_position(result));
int new_result = new_block.columns() - 1;
RETURN_IF_ERROR(default_execute(context, new_block, new_args, new_result, block.rows()));
// After run with nested, wrap them in null. Before this, block.get_by_position(result).type
// is not compatible with get_by_position(result).column
block.get_by_position(result).column = wrap_in_nullable(
new_block.get_by_position(new_result).column, block, args, input_rows_count);
*executed = true;
return Status::OK();
}
return Status::OK();
}
Status PreparedFunctionImpl::default_execute(FunctionContext* context, Block& block,
const ColumnNumbers& args, uint32_t result,
size_t input_rows_count) const {
bool executed = false;
RETURN_IF_ERROR(default_implementation_for_constant_arguments(context, block, args, result,
input_rows_count, &executed));
if (executed) {
return Status::OK();
}
return _execute_skipped_constant_deal(context, block, args, result, input_rows_count);
}
Status PreparedFunctionImpl::execute(FunctionContext* context, Block& block,
const ColumnNumbers& args, uint32_t result,
size_t input_rows_count) const {
return default_execute(context, block, args, result, input_rows_count);
}
void FunctionBuilderImpl::check_number_of_arguments(size_t number_of_arguments) const {
if (is_variadic()) {
return;
}
size_t expected_number_of_arguments = get_number_of_arguments();
DCHECK_EQ(number_of_arguments, expected_number_of_arguments) << fmt::format(
"Number of arguments for function {} doesn't match: passed {} , should be {}",
get_name(), number_of_arguments, expected_number_of_arguments);
if (number_of_arguments != expected_number_of_arguments) {
throw Exception(
ErrorCode::INVALID_ARGUMENT,
"Number of arguments for function {} doesn't match: passed {} , should be {}",
get_name(), number_of_arguments, expected_number_of_arguments);
}
}
DataTypePtr FunctionBuilderImpl::get_return_type(const ColumnsWithTypeAndName& arguments) const {
check_number_of_arguments(arguments.size());
if (!arguments.empty() && use_default_implementation_for_nulls()) {
if (have_null_column(arguments)) {
ColumnNumbers numbers(arguments.size());
std::iota(numbers.begin(), numbers.end(), 0);
auto [nested_block, _] =
create_block_with_nested_columns(Block(arguments), numbers, false);
auto return_type = get_return_type_impl(
ColumnsWithTypeAndName(nested_block.begin(), nested_block.end()));
if (!return_type) {
return nullptr;
}
return make_nullable(return_type);
}
}
return get_return_type_impl(arguments);
}
bool FunctionBuilderImpl::is_date_or_datetime_or_decimal(
const DataTypePtr& return_type, const DataTypePtr& func_return_type) const {
return (is_date_or_datetime(return_type->get_primitive_type()) &&
is_date_or_datetime(func_return_type->get_primitive_type())) ||
(is_date_v2_or_datetime_v2(return_type->get_primitive_type()) &&
is_date_v2_or_datetime_v2(func_return_type->get_primitive_type())) ||
// For some date functions such as str_to_date(string, string), return_type will
// be datetimev2 if users enable datev2 but get_return_type(arguments) will still
// return datetime. We need keep backward compatibility here.
(is_date_v2_or_datetime_v2(return_type->get_primitive_type()) &&
is_date_or_datetime(func_return_type->get_primitive_type())) ||
(is_date_or_datetime(return_type->get_primitive_type()) &&
is_date_v2_or_datetime_v2(func_return_type->get_primitive_type())) ||
(is_decimal(return_type->get_primitive_type()) &&
is_decimal(func_return_type->get_primitive_type())) ||
(is_time_type(return_type->get_primitive_type()) &&
is_time_type(func_return_type->get_primitive_type()));
}
bool contains_date_or_datetime_or_decimal(const DataTypePtr& type) {
auto type_ptr = type->is_nullable() ? ((DataTypeNullable*)type.get())->get_nested_type() : type;
switch (type_ptr->get_primitive_type()) {
case TYPE_ARRAY: {
const auto* array_type = assert_cast<const DataTypeArray*>(type_ptr.get());
return contains_date_or_datetime_or_decimal(array_type->get_nested_type());
}
case TYPE_MAP: {
const auto* map_type = assert_cast<const DataTypeMap*>(type_ptr.get());
return contains_date_or_datetime_or_decimal(map_type->get_key_type()) ||
contains_date_or_datetime_or_decimal(map_type->get_value_type());
}
case TYPE_STRUCT: {
const auto* struct_type = assert_cast<const DataTypeStruct*>(type_ptr.get());
const auto& elements = struct_type->get_elements();
return std::ranges::any_of(elements, [](const DataTypePtr& element) {
return contains_date_or_datetime_or_decimal(element);
});
}
default:
// For scalar types, check if it's date/datetime/decimal
return is_date_or_datetime(type_ptr->get_primitive_type()) ||
is_date_v2_or_datetime_v2(type_ptr->get_primitive_type()) ||
is_decimal(type_ptr->get_primitive_type()) ||
is_time_type(type_ptr->get_primitive_type());
}
}
// make sure array/map/struct and nested array/map/struct can be check
bool FunctionBuilderImpl::is_nested_type_date_or_datetime_or_decimal(
const DataTypePtr& return_type, const DataTypePtr& func_return_type) const {
auto return_type_ptr = return_type->is_nullable()
? ((DataTypeNullable*)return_type.get())->get_nested_type()
: return_type;
auto func_return_type_ptr =
func_return_type->is_nullable()
? ((DataTypeNullable*)func_return_type.get())->get_nested_type()
: func_return_type;
// make sure that map/struct/array also need to check
if (return_type_ptr->get_primitive_type() != func_return_type_ptr->get_primitive_type()) {
return false;
}
// Check if this type contains date/datetime/decimal types
if (!contains_date_or_datetime_or_decimal(return_type_ptr)) {
// If no date/datetime/decimal types, just pass through
return true;
}
// If contains date/datetime/decimal types, recursively check each element
switch (return_type_ptr->get_primitive_type()) {
case TYPE_ARRAY: {
auto nested_return_type = remove_nullable(
(assert_cast<const DataTypeArray*>(return_type_ptr.get()))->get_nested_type());
auto nested_func_type = remove_nullable(
(assert_cast<const DataTypeArray*>(func_return_type_ptr.get()))->get_nested_type());
return is_nested_type_date_or_datetime_or_decimal(nested_return_type, nested_func_type);
}
case TYPE_MAP: {
const auto* return_map = assert_cast<const DataTypeMap*>(return_type_ptr.get());
const auto* func_map = assert_cast<const DataTypeMap*>(func_return_type_ptr.get());
auto key_return = remove_nullable(return_map->get_key_type());
auto key_func = remove_nullable(func_map->get_key_type());
auto value_return = remove_nullable(return_map->get_value_type());
auto value_func = remove_nullable(func_map->get_value_type());
return is_nested_type_date_or_datetime_or_decimal(key_return, key_func) &&
is_nested_type_date_or_datetime_or_decimal(value_return, value_func);
}
case TYPE_STRUCT: {
const auto* return_struct = assert_cast<const DataTypeStruct*>(return_type_ptr.get());
const auto* func_struct = assert_cast<const DataTypeStruct*>(func_return_type_ptr.get());
auto return_elements = return_struct->get_elements();
auto func_elements = func_struct->get_elements();
if (return_elements.size() != func_elements.size()) {
return false;
}
for (size_t i = 0; i < return_elements.size(); i++) {
auto elem_return = remove_nullable(return_elements[i]);
auto elem_func = remove_nullable(func_elements[i]);
if (!is_nested_type_date_or_datetime_or_decimal(elem_return, elem_func)) {
return false;
}
}
return true;
}
default:
return is_date_or_datetime_or_decimal(return_type_ptr, func_return_type_ptr);
}
}
#include "common/compile_check_end.h"
} // namespace doris::vectorized