blob: 7da847d8d6361d55371c2c7cd4b9846109de303e [file] [log] [blame]
// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.
#include "vec/exprs/vectorized_agg_fn.h"
#include <fmt/format.h>
#include <fmt/ranges.h> // IWYU pragma: keep
#include <gen_cpp/Exprs_types.h>
#include <gen_cpp/PlanNodes_types.h>
#include <glog/logging.h>
#include <memory>
#include <ostream>
#include <string_view>
#include "common/config.h"
#include "common/object_pool.h"
#include "vec/aggregate_functions/aggregate_function_ai_agg.h"
#include "vec/aggregate_functions/aggregate_function_java_udaf.h"
#include "vec/aggregate_functions/aggregate_function_rpc.h"
#include "vec/aggregate_functions/aggregate_function_simple_factory.h"
#include "vec/aggregate_functions/aggregate_function_sort.h"
#include "vec/aggregate_functions/aggregate_function_state_merge.h"
#include "vec/aggregate_functions/aggregate_function_state_union.h"
#include "vec/core/block.h"
#include "vec/core/column_with_type_and_name.h"
#include "vec/core/materialize_block.h"
#include "vec/data_types/data_type_agg_state.h"
#include "vec/data_types/data_type_factory.hpp"
#include "vec/exprs/vexpr.h"
#include "vec/exprs/vexpr_context.h"
#include "vec/utils/util.hpp"
// Minimum BE execution version for which AggFnEvaluator::prepare() verifies the
// aggregate function's result type; for older versions that verification is skipped.
static constexpr int64_t BE_VERSION_THAT_SUPPORT_NULLABLE_CHECK = 8;
namespace doris {
class RowDescriptor;
namespace vectorized {
class Arena;
class BufferWritable;
class IColumn;
} // namespace vectorized
} // namespace doris
namespace doris::vectorized {
#include "common/compile_check_begin.h"
// Creates an agg-state combinator of type FunctionType on top of the nested
// aggregate function carried by the first argument type.
//
// argument_types[0] is cast to DataTypeAggState with assert_cast, so the caller
// must have checked beforehand that it exists and really is an agg-state type.
template <class FunctionType>
AggregateFunctionPtr get_agg_state_function(const DataTypes& argument_types,
                                            DataTypePtr return_type) {
    const auto* agg_state_type = assert_cast<const DataTypeAggState*>(argument_types[0].get());
    return FunctionType::create(agg_state_type->get_nested_function(), argument_types,
                                return_type);
}
// Builds an evaluator from the thrift node describing the aggregate function.
//
// The return data type is created from desc.fn.ret_type; it is nullable unless the
// node explicitly carries is_nullable == false. When the aggregate expression
// provides param_types, each of them is converted into a data type and appended to
// _argument_types_with_sort in the same order.
AggFnEvaluator::AggFnEvaluator(const TExprNode& desc, const bool without_key,
                               const bool is_window_function)
        : _fn(desc.fn),
          _is_merge(desc.agg_expr.is_merge_agg),
          _without_key(without_key),
          _is_window_function(is_window_function),
          _data_type(DataTypeFactory::instance().create_data_type(
                  desc.fn.ret_type, !desc.__isset.is_nullable || desc.is_nullable)) {
    if (!desc.agg_expr.__isset.param_types) {
        // No declared parameter types: nothing more to record.
        return;
    }
    for (const auto& thrift_type : desc.agg_expr.param_types) {
        _argument_types_with_sort.push_back(
                DataTypeFactory::instance().create_data_type(thrift_type));
    }
}
// Creates an AggFnEvaluator from a thrift expression tree and registers it in `pool`.
//
// desc.nodes[0] describes the aggregate function itself; every child of that node is
// converted into an input expression context. The argument types recorded by the
// constructor are laid out as [real arguments..., order by arguments...], where the
// number of order by arguments equals sort_info.ordering_exprs.size().
//
// pool:      receives the newly created evaluator.
// desc:      thrift expression whose first node is the aggregate function.
// sort_info: ordering expressions and their directions for sorted aggregation.
// result:    out-parameter, set to the created evaluator.
//
// Returns OK on success, the error of VExpr::create_tree_from_thrift if a child
// expression cannot be built, or InternalError when `desc` / `sort_info` are
// inconsistent (no nodes, more order by expressions than arguments, or ordering
// flag lists shorter than the ordering expression list).
Status AggFnEvaluator::create(ObjectPool* pool, const TExpr& desc, const TSortInfo& sort_info,
                              const bool without_key, const bool is_window_function,
                              AggFnEvaluator** result) {
    if (desc.nodes.empty()) {
        return Status::InternalError("Aggregate function expr must contain at least one node");
    }
    *result =
            pool->add(AggFnEvaluator::create_unique(desc.nodes[0], without_key, is_window_function)
                              .release());
    auto& agg_fn_evaluator = *result;
    int node_idx = 0;
    for (int i = 0; i < desc.nodes[0].num_children; ++i) {
        ++node_idx;
        VExprSPtr expr;
        VExprContextSPtr ctx;
        RETURN_IF_ERROR(VExpr::create_tree_from_thrift(desc.nodes, &node_idx, expr, ctx));
        agg_fn_evaluator->_input_exprs_ctxs.push_back(ctx);
    }

    const size_t sort_size = sort_info.ordering_exprs.size();
    const size_t arguments_size = agg_fn_evaluator->_argument_types_with_sort.size();
    // Guard the unsigned subtraction below: without this check it would wrap around and
    // the following loops would index far outside of the argument type list.
    if (sort_size > arguments_size) {
        return Status::InternalError(
                "Agg function has {} order by arguments but only {} arguments in total",
                sort_size, arguments_size);
    }
    if (sort_info.is_asc_order.size() < sort_size || sort_info.nulls_first.size() < sort_size) {
        return Status::InternalError(
                "Sort info of agg function is inconsistent, ordering_exprs={}, is_asc_order={}, "
                "nulls_first={}",
                sort_size, sort_info.is_asc_order.size(), sort_info.nulls_first.size());
    }
    const size_t real_arguments_size = arguments_size - sort_size;

    // Child arguments contains [real arguments, order by arguments], we pass the arguments
    // to the order by functions
    for (size_t i = 0; i < sort_size; ++i) {
        agg_fn_evaluator->_sort_description.emplace_back(real_arguments_size + i,
                                                         sort_info.is_asc_order[i] ? 1 : -1,
                                                         sort_info.nulls_first[i] ? -1 : 1);
    }

    // Pass the real arguments to get functions
    for (size_t i = 0; i < real_arguments_size; ++i) {
        agg_fn_evaluator->_real_argument_types.emplace_back(
                agg_fn_evaluator->_argument_types_with_sort[i]);
    }
    return Status::OK();
}
// Prepares the evaluator for execution.
//
// Steps, in order:
//  1. remember the intermediate/output slot descriptors and prepare the input
//     expression contexts;
//  2. pick the argument types: _real_argument_types when it is non-empty, otherwise
//     the data types of the input expression roots;
//  3. instantiate _function according to _fn.binary_type (Java UDAF, RPC UDAF,
//     agg-state union/merge, or a function looked up in the simple factory);
//  4. wrap the function in a sorted aggregate when a sort description exists;
//  5. hand the query context to the function when its name is "ai_agg";
//  6. verify the result type unless the function is a foreach/foreachv2 combinator or
//     the BE exec version is below BE_VERSION_THAT_SUPPORT_NULLABLE_CHECK;
//  7. build _expr_name from the function name and the child expression names.
//
// Returns the first error encountered, or OK.
Status AggFnEvaluator::prepare(RuntimeState* state, const RowDescriptor& desc,
                               const SlotDescriptor* intermediate_slot_desc,
                               const SlotDescriptor* output_slot_desc) {
    DCHECK(intermediate_slot_desc != nullptr);
    DCHECK(_intermediate_slot_desc == nullptr);
    _output_slot_desc = output_slot_desc;
    _intermediate_slot_desc = intermediate_slot_desc;

    RETURN_IF_ERROR(VExpr::prepare(_input_exprs_ctxs, state, desc));

    // A single pass over the inputs collects the argument types, the names used for
    // _expr_name, and the names of the non-constant, named inputs.
    DataTypes tmp_argument_types;
    tmp_argument_types.reserve(_input_exprs_ctxs.size());
    std::vector<std::string_view> child_expr_name;
    std::vector<std::string> column_names;
    for (const auto& input_ctx : _input_exprs_ctxs) {
        const auto& root = input_ctx->root();
        tmp_argument_types.emplace_back(root->data_type());
        child_expr_name.emplace_back(root->expr_name());
        if (!root->expr_name().empty() && !root->is_constant()) {
            column_names.emplace_back(root->expr_name());
        }
    }

    const DataTypes& argument_types =
            _real_argument_types.empty() ? tmp_argument_types : _real_argument_types;

    if (_fn.binary_type == TFunctionBinaryType::JAVA_UDF) {
        if (!config::enable_java_support) {
            return Status::InternalError(
                    "Java UDAF is not enabled, you can change be config enable_java_support to "
                    "true and restart be.");
        }
        _function = AggregateJavaUdaf::create(_fn, argument_types, _data_type);
        RETURN_IF_ERROR(static_cast<AggregateJavaUdaf*>(_function.get())->check_udaf(_fn));
    } else if (_fn.binary_type == TFunctionBinaryType::RPC) {
        _function = AggregateRpcUdaf::create(_fn, argument_types, _data_type);
    } else if (_fn.binary_type == TFunctionBinaryType::AGG_STATE) {
        // The single argument must be a non-nullable agg_state value.
        if (argument_types.size() != 1) {
            return Status::InternalError("Agg state Function must input 1 argument but get {}",
                                         argument_types.size());
        }
        if (argument_types[0]->is_nullable()) {
            return Status::InternalError("Agg state function input type must be not nullable");
        }
        if (argument_types[0]->get_primitive_type() != PrimitiveType::TYPE_AGG_STATE) {
            return Status::InternalError(
                    "Agg state function input type must be agg_state but get {}",
                    argument_types[0]->get_family_name());
        }

        // Only cast after the checks above have confirmed the argument's type.
        const auto* agg_state_type =
                assert_cast<const DataTypeAggState*>(argument_types[0].get());
        const std::string type_function_name = agg_state_type->get_function_name();

        if (type_function_name + AGG_UNION_SUFFIX == _fn.name.function_name) {
            if (_data_type->is_nullable()) {
                return Status::InternalError(
                        "Union function return type must be not nullable, real={}",
                        _data_type->get_name());
            }
            if (_data_type->get_primitive_type() != PrimitiveType::TYPE_AGG_STATE) {
                return Status::InternalError(
                        "Union function return type must be AGG_STATE, real={}",
                        _data_type->get_name());
            }
            _function = get_agg_state_function<AggregateStateUnion>(argument_types, _data_type);
        } else if (type_function_name + AGG_MERGE_SUFFIX == _fn.name.function_name) {
            // A merge must return exactly what the nested function returns.
            const auto expected_type = agg_state_type->get_nested_function()->get_return_type();
            if (!expected_type->equals(*_data_type)) {
                return Status::InternalError("{}'s expect return type is {}, but input {}",
                                             argument_types[0]->get_name(),
                                             expected_type->get_name(), _data_type->get_name());
            }
            _function = get_agg_state_function<AggregateStateMerge>(argument_types, _data_type);
        } else {
            return Status::InternalError("{} not match function {}", argument_types[0]->get_name(),
                                         _fn.name.function_name);
        }
    } else {
        const bool is_foreach_v1 =
                AggregateFunctionSimpleFactory::is_foreach(_fn.name.function_name);
        const bool is_foreach =
                is_foreach_v1 ||
                AggregateFunctionSimpleFactory::is_foreachv2(_fn.name.function_name);
        // Here, only foreachv1 needs special treatment, and v2 can follow the normal code logic.
        const bool result_is_nullable =
                is_foreach_v1
                        ? AggregateFunctionSimpleFactory::result_nullable_by_foreach(_data_type)
                        : _data_type->is_nullable();
        _function = AggregateFunctionSimpleFactory::instance().get(
                _fn.name.function_name, argument_types, _data_type, result_is_nullable,
                state->be_exec_version(),
                {.is_window_function = _is_window_function,
                 .is_foreach = is_foreach,
                 .column_names = std::move(column_names)});
    }

    if (_function == nullptr) {
        return Status::InternalError("Agg Function {} is not implemented", _fn.signature);
    }

    if (!_sort_description.empty()) {
        _function = transform_to_sort_agg_function(_function, _argument_types_with_sort,
                                                   _sort_description, state);
    }

    if (_fn.name.function_name == "ai_agg") {
        _function->set_query_context(state->get_query_ctx());
    }

    // Foreachv2, like foreachv1, does not check the return type,
    // because its return type is related to the internal agg.
    const bool is_foreach_combinator =
            AggregateFunctionSimpleFactory::is_foreach(_fn.name.function_name) ||
            AggregateFunctionSimpleFactory::is_foreachv2(_fn.name.function_name);
    if (!is_foreach_combinator &&
        state->be_exec_version() >= BE_VERSION_THAT_SUPPORT_NULLABLE_CHECK) {
        RETURN_IF_ERROR(_function->verify_result_type(_without_key, argument_types, _data_type));
    }

    _expr_name = fmt::format("{}({})", _fn.name.function_name, child_expr_name);
    return Status::OK();
}
// Opens the input expression contexts by delegating to VExpr::open and returns
// the status it reports.
Status AggFnEvaluator::open(RuntimeState* state) {
return VExpr::open(_input_exprs_ctxs, state);
}
// Forwards to the wrapped aggregate function's create() for `place`.
// _function must already be set (it is assigned in prepare()).
void AggFnEvaluator::create(AggregateDataPtr place) {
_function->create(place);
}
// Forwards to the wrapped aggregate function's destroy() for `place`.
void AggFnEvaluator::destroy(AggregateDataPtr place) {
_function->destroy(place);
}
// Evaluates the input expressions on `block` and feeds all of its rows into the
// single aggregate state `place`.
// Returns the error from argument evaluation, otherwise OK.
Status AggFnEvaluator::execute_single_add(Block* block, AggregateDataPtr place, Arena& arena) {
    RETURN_IF_ERROR(_calc_argument_columns(block));
    // The row count is read only after the argument columns have been computed.
    const auto row_count = block->rows();
    _function->add_batch_single_place(row_count, place, _agg_columns.data(), arena);
    return Status::OK();
}
// Evaluates the input expressions on `block` and passes the block's rows, together
// with `places`, `offset` and `agg_many`, to the aggregate function's add_batch().
// Returns the error from argument evaluation, otherwise OK.
Status AggFnEvaluator::execute_batch_add(Block* block, size_t offset, AggregateDataPtr* places,
                                         Arena& arena, bool agg_many) {
    RETURN_IF_ERROR(_calc_argument_columns(block));
    // The row count is read only after the argument columns have been computed.
    const auto row_count = block->rows();
    _function->add_batch(row_count, places, offset, _agg_columns.data(), arena, agg_many);
    return Status::OK();
}
// Evaluates the input expressions on `block` and passes the block's rows, together
// with `places` and `offset`, to the aggregate function's add_batch_selected().
// Returns the error from argument evaluation, otherwise OK.
Status AggFnEvaluator::execute_batch_add_selected(Block* block, size_t offset,
                                                  AggregateDataPtr* places, Arena& arena) {
    RETURN_IF_ERROR(_calc_argument_columns(block));
    // The row count is read only after the argument columns have been computed.
    const auto row_count = block->rows();
    _function->add_batch_selected(row_count, places, offset, _agg_columns.data(), arena);
    return Status::OK();
}
// Evaluates the input expressions on `block` and lets the aggregate function
// serialize `num_rows` rows of the resulting argument columns into `buf`.
// Returns the error from argument evaluation, otherwise OK.
Status AggFnEvaluator::streaming_agg_serialize(Block* block, BufferWritable& buf,
                                               const size_t num_rows, Arena& arena) {
    RETURN_IF_ERROR(_calc_argument_columns(block));
    auto argument_columns = _agg_columns.data();
    _function->streaming_agg_serialize(argument_columns, buf, num_rows, arena);
    return Status::OK();
}
// Evaluates the input expressions on `block` and lets the aggregate function
// serialize `num_rows` rows of the resulting argument columns into column `dst`.
// Returns the error from argument evaluation, otherwise OK.
Status AggFnEvaluator::streaming_agg_serialize_to_column(Block* block, MutableColumnPtr& dst,
                                                         const size_t num_rows, Arena& arena) {
    RETURN_IF_ERROR(_calc_argument_columns(block));
    auto argument_columns = _agg_columns.data();
    _function->streaming_agg_serialize_to_column(argument_columns, dst, num_rows, arena);
    return Status::OK();
}
// Forwards to the aggregate function's insert_result_into() with `place` and the
// column pointed to by `column`. `column` is dereferenced, so it must not be null.
void AggFnEvaluator::insert_result_info(AggregateDataPtr place, IColumn* column) {
_function->insert_result_into(place, *column);
}
// Forwards to the aggregate function's insert_result_into_vec() with `places`,
// `offset`, the column pointed to by `column`, and `num_rows`.
// `column` is dereferenced, so it must not be null.
void AggFnEvaluator::insert_result_info_vec(const std::vector<AggregateDataPtr>& places,
size_t offset, IColumn* column, const size_t num_rows) {
_function->insert_result_into_vec(places, offset, *column, num_rows);
}
// Forwards to the wrapped aggregate function's reset() for `place`.
void AggFnEvaluator::reset(AggregateDataPtr place) {
_function->reset(place);
}
// Renders a list of evaluators as "[e1 e2 ...]", where each entry is the
// evaluator's own debug_string() and entries are separated by a single space.
// Every pointer in `exprs` is dereferenced, so none of them may be null.
std::string AggFnEvaluator::debug_string(const std::vector<AggFnEvaluator*>& exprs) {
    std::stringstream out;
    out << "[";
    // Empty before the first entry, a single space before every later one.
    const char* separator = "";
    for (const AggFnEvaluator* evaluator : exprs) {
        out << separator << evaluator->debug_string();
        separator = " ";
    }
    out << "]";
    return out.str();
}
// Returns "AggFnEvaluator(<signature>)", where <signature> is _fn.signature.
std::string AggFnEvaluator::debug_string() const {
    std::stringstream out;
    out << "AggFnEvaluator(" << _fn.signature << ")";
    return out.str();
}
// Evaluates every input expression on `block` and stores raw pointers to the
// resulting columns in _agg_columns (one entry per input expression, same order).
//
// All expressions are executed first, then the result columns are materialized in
// place, and only afterwards are the column pointers collected from the block.
// Returns the first expression error, otherwise OK. Time spent is recorded in
// _expr_timer.
Status AggFnEvaluator::_calc_argument_columns(Block* block) {
    SCOPED_TIMER(_expr_timer);
    const size_t num_inputs = _input_exprs_ctxs.size();
    _agg_columns.resize(num_inputs);

    std::vector<int> column_ids(num_inputs);
    for (size_t idx = 0; idx < num_inputs; ++idx) {
        int result_column_id = -1;
        RETURN_IF_ERROR(_input_exprs_ctxs[idx]->execute(block, &result_column_id));
        column_ids[idx] = result_column_id;
    }

    materialize_block_inplace(*block, column_ids.data(), column_ids.data() + num_inputs);

    for (size_t idx = 0; idx < num_inputs; ++idx) {
        _agg_columns[idx] = block->get_by_position(column_ids[idx]).column.get();
    }
    return Status::OK();
}
// Creates a new evaluator from this one via create_unique(*this, state), registers
// the released pointer with `pool`, and returns what pool->add() returns.
AggFnEvaluator* AggFnEvaluator::clone(RuntimeState* state, ObjectPool* pool) {
return pool->add(AggFnEvaluator::create_unique(*this, state).release());
}
// Copy-style constructor taking the source evaluator and a runtime state.
//
// All descriptive members (function descriptor, flags, argument types, slot
// descriptors, sort description, data type, expression name, argument column
// pointers) are copied from `evaluator`. The aggregate function object is shared
// with `evaluator`, except for Java UDAFs, for which a fresh function instance is
// created and validated with check_udaf(); a validation failure is raised through
// THROW_IF_ERROR. Each input expression context is cloned for `state`; a failing
// clone is logged as a warning and leaves the corresponding context as it was after
// the resize.
AggFnEvaluator::AggFnEvaluator(AggFnEvaluator& evaluator, RuntimeState* state)
        : _fn(evaluator._fn),
          _is_merge(evaluator._is_merge),
          _without_key(evaluator._without_key),
          _is_window_function(evaluator._is_window_function),
          _argument_types_with_sort(evaluator._argument_types_with_sort),
          _real_argument_types(evaluator._real_argument_types),
          _intermediate_slot_desc(evaluator._intermediate_slot_desc),
          _output_slot_desc(evaluator._output_slot_desc),
          _sort_description(evaluator._sort_description),
          _data_type(evaluator._data_type),
          _function(evaluator._function),
          _expr_name(evaluator._expr_name),
          _agg_columns(evaluator._agg_columns) {
    if (evaluator._fn.binary_type == TFunctionBinaryType::JAVA_UDF) {
        DataTypes tmp_argument_types;
        tmp_argument_types.reserve(evaluator._input_exprs_ctxs.size());
        // prepare for argument
        for (auto& _input_exprs_ctx : evaluator._input_exprs_ctxs) {
            auto data_type = _input_exprs_ctx->root()->data_type();
            tmp_argument_types.emplace_back(data_type);
        }
        const DataTypes& argument_types =
                _real_argument_types.empty() ? tmp_argument_types : _real_argument_types;
        _function = AggregateJavaUdaf::create(evaluator._fn, argument_types, evaluator._data_type);
        THROW_IF_ERROR(static_cast<AggregateJavaUdaf*>(_function.get())->check_udaf(evaluator._fn));
    }
    DCHECK(_function != nullptr);
    _input_exprs_ctxs.resize(evaluator._input_exprs_ctxs.size());
    for (size_t i = 0; i < _input_exprs_ctxs.size(); i++) {
        // A descriptive prefix makes a failed clone attributable in the log; the
        // previous empty prefix produced a warning without any context.
        WARN_IF_ERROR(evaluator._input_exprs_ctxs[i]->clone(state, _input_exprs_ctxs[i]),
                      "AggFnEvaluator failed to clone input expr context");
    }
}
// Checks that the non-key output columns of an aggregation node match the return
// types of the aggregate functions.
//
// Output column `key_size + j` is compared with the return type of agg_fn[j]. A
// column is accepted when its type equals the function's return type, or when the
// column is nullable, the function's return type is not, and the column's type
// with the nullable wrapper removed equals the return type.
//
// key_size:        number of leading key columns in the output row to skip.
// agg_fn:          evaluators, in the order of the non-key output columns.
// output_row_desc: descriptor of the aggregation node's output row.
//
// Returns InternalError when there are more non-key output columns than
// evaluators or when a column type does not match; OK otherwise.
Status AggFnEvaluator::check_agg_fn_output(uint32_t key_size,
                                           const std::vector<vectorized::AggFnEvaluator*>& agg_fn,
                                           const RowDescriptor& output_row_desc) {
    auto name_and_types = VectorizedUtils::create_name_and_data_types(output_row_desc);
    // Every non-key output column needs a matching evaluator; without this check the
    // loop below would index past the end of agg_fn.
    if (name_and_types.size() > key_size && name_and_types.size() - key_size > agg_fn.size()) {
        return Status::InternalError(
                "agg node output has more columns than agg functions, key_size={}, "
                "output columns={}, agg functions={}",
                key_size, name_and_types.size(), agg_fn.size());
    }
    for (uint32_t i = key_size, j = 0; i < name_and_types.size(); i++, j++) {
        auto&& [name, column_type] = name_and_types[i];
        auto agg_return_type = agg_fn[j]->function()->get_return_type();
        if (!column_type->equals(*agg_return_type)) {
            // The only tolerated difference is a nullable output column wrapping a
            // non-nullable function result of the same nested type.
            if (!column_type->is_nullable() || agg_return_type->is_nullable() ||
                !remove_nullable(column_type)->equals(*agg_return_type)) {
                return Status::InternalError(
                        "column_type not match data_types in agg node, column_type={}, "
                        "data_types={},column name={}",
                        column_type->get_name(), agg_return_type->get_name(), name);
            }
        }
    }
    return Status::OK();
}
bool AggFnEvaluator::is_blockable() const {
return _function->is_blockable() ||
std::any_of(_input_exprs_ctxs.begin(), _input_exprs_ctxs.end(),
[](VExprContextSPtr ctx) { return ctx->root()->is_blockable(); });
}
#include "common/compile_check_end.h"
} // namespace doris::vectorized