blob: 696f5352f200e5a4242bea1e8c73648392b2dad9 [file] [log] [blame]
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "AggregateFunctionParser.h"
#include <utility>
#include <AggregateFunctions/AggregateFunctionFactory.h>
#include <DataTypes/DataTypeAggregateFunction.h>
#include <DataTypes/DataTypeNullable.h>
#include <DataTypes/DataTypeTuple.h>
#include <DataTypes/DataTypesNumber.h>
#include <Functions/FunctionHelpers.h>
#include <Parser/ExpressionParser.h>
#include <Parser/RelParsers/RelParser.h>
#include <Parser/TypeParser.h>
#include <Common/CHUtil.h>
#include <Common/Exception.h>
#include <Common/logger_useful.h>
// Error codes defined elsewhere in ClickHouse's ErrorCodes registry (declared `extern` here)
// and used by the exceptions thrown in this file.
namespace DB
{
namespace ErrorCodes
{
extern const int BAD_ARGUMENTS;
extern const int LOGICAL_ERROR;
extern const int UNKNOWN_FUNCTION;
}
}
namespace local_engine
{
using namespace DB;
/// Binds the parser to `parser_context_` and creates the ExpressionParser (sharing the same
/// context) to which all expression-level helpers of this class forward.
AggregateFunctionParser::AggregateFunctionParser(ParserContextPtr parser_context_)
    : parser_context(parser_context_)
{
    // Built in the body (not the init list) so it does not depend on member declaration order.
    expression_parser = std::make_unique<ExpressionParser>(parser_context_);
}
/// Defined out of line so members such as the ExpressionParser pointer are destroyed where the
/// ExpressionParser type is complete (this file includes Parser/ExpressionParser.h).
AggregateFunctionParser::~AggregateFunctionParser() = default;
/// Returns a fresh name derived from `name`; generation is delegated to ExpressionParser::getUniqueName
/// (presumably unique within the parser context — confirm in ExpressionParser).
String AggregateFunctionParser::getUniqueName(const String & name) const
{
    auto & parser = *expression_parser;
    return parser.getUniqueName(name);
}
/// Adds a constant column of `type` holding `field` to `actions_dag` and returns the new node.
/// Thin forwarder to ExpressionParser::addConstColumn.
const DB::ActionsDAG::Node *
AggregateFunctionParser::addColumnToActionsDAG(DB::ActionsDAG & actions_dag, const DB::DataTypePtr & type, const DB::Field & field) const
{
    auto & parser = *expression_parser;
    const DB::ActionsDAG::Node * constant_node = parser.addConstColumn(actions_dag, type, field);
    return constant_node;
}
/// Adds a `func_name(args...)` function node to `action_dag` via ExpressionParser::toFunctionNode.
/// No result name is passed, so the name is whatever ExpressionParser chooses by default.
const DB::ActionsDAG::Node * AggregateFunctionParser::toFunctionNode(
    DB::ActionsDAG & action_dag, const String & func_name, const DB::ActionsDAG::NodeRawConstPtrs & args) const
{
    auto & parser = *expression_parser;
    const DB::ActionsDAG::Node * function_node = parser.toFunctionNode(action_dag, func_name, args);
    return function_node;
}
/// Adds a `func_name(args...)` function node named `result_name` to `action_dag`.
/// Note the argument order differs from ExpressionParser::toFunctionNode, which takes the name last.
const DB::ActionsDAG::Node * AggregateFunctionParser::toFunctionNode(
    DB::ActionsDAG & action_dag, const String & func_name, const String & result_name, const DB::ActionsDAG::NodeRawConstPtrs & args) const
{
    auto & parser = *expression_parser;
    const DB::ActionsDAG::Node * function_node = parser.toFunctionNode(action_dag, func_name, args, result_name);
    return function_node;
}
/// Parses the substrait expression `rel` into `actions_dag` and returns the node that yields its value.
/// Forwards to ExpressionParser::parseExpression.
const DB::ActionsDAG::Node * AggregateFunctionParser::parseExpression(DB::ActionsDAG & actions_dag, const substrait::Expression & rel) const
{
    auto & parser = *expression_parser;
    const DB::ActionsDAG::Node * parsed_node = parser.parseExpression(actions_dag, rel);
    return parsed_node;
}
/// Decodes a substrait literal into its ClickHouse data type and value.
/// This uses the static LiteralParser and does not touch this parser's state.
std::pair<DataTypePtr, Field> AggregateFunctionParser::parseLiteral(const substrait::Expression_Literal & literal) const
{
    std::pair<DataTypePtr, Field> type_and_value = LiteralParser::parse(literal);
    return type_and_value;
}
/// Parses every argument of an aggregate function into nodes of `actions_dag`.
///
/// When the function's declared output type is Nullable and the phase is a first-stage aggregation
/// (INITIAL_TO_INTERMEDIATE or INITIAL_TO_RESULT), each non-nullable argument is wrapped in
/// `toNullable(...)`, and the wrapper is added to the DAG outputs.
/// When the function carries a filter, the parsed filter expression is appended as the last
/// argument, which is the extra condition argument the `If` combinator takes.
///
/// @return argument nodes in call order, with the filter (if any) last.
DB::ActionsDAG::NodeRawConstPtrs
AggregateFunctionParser::parseFunctionArguments(const CommonFunctionInfo & func_info, DB::ActionsDAG & actions_dag) const
{
    // Whether inputs must be made nullable depends only on the function, not on the individual
    // argument, so decide it once instead of re-parsing the output type for every argument.
    // The cheap phase test runs first so the output type is only parsed when it can matter.
    const bool is_first_stage = func_info.phase == substrait::AGGREGATION_PHASE_INITIAL_TO_INTERMEDIATE
        || func_info.phase == substrait::AGGREGATION_PHASE_INITIAL_TO_RESULT;
    const bool force_nullable_inputs
        = is_first_stage && DB::WhichDataType(TypeParser::parseType(func_info.output_type)).isNullable();

    DB::ActionsDAG::NodeRawConstPtrs collected_args;
    for (const auto & arg : func_info.arguments)
    {
        // Bind by const reference to avoid copying the argument's expression tree.
        const auto & arg_value = arg.value();
        const DB::ActionsDAG::Node * arg_node = parseExpression(actions_dag, arg_value);
        // If the aggregate result is required to be nullable, make all inputs nullable at the first stage.
        if (force_nullable_inputs && !arg_node->result_type->isNullable())
        {
            DB::ActionsDAG::NodeRawConstPtrs args;
            args.emplace_back(arg_node);
            const auto * node = toFunctionNode(actions_dag, "toNullable", args);
            actions_dag.addOrReplaceInOutputs(*node);
            arg_node = node;
        }
        collected_args.push_back(arg_node);
    }
    if (func_info.has_filter)
    {
        // With the `If` combinator, the function takes one more argument: the filter condition.
        const auto * action_node = parseExpression(actions_dag, func_info.filter);
        collected_args.emplace_back(action_node);
    }
    return collected_args;
}
/// Chooses the ClickHouse aggregate function name and argument types that implement `ch_func_name`
/// in the aggregation phase given by `func_info`.
///
/// - First-stage phases (INITIAL_TO_INTERMEDIATE / INITIAL_TO_RESULT): the argument types are kept.
///   If the function has a filter, the `If` combinator is appended to the name.
/// - Any other phase: exactly one argument is accepted. The `PartialMerge` combinator is appended,
///   and when that argument is not already a suitable AggregateFunction state type, it is replaced
///   with the state type of `ch_func_name`.
///
/// @return {combinator function name, combinator argument types}.
/// @throws DB::Exception BAD_ARGUMENTS when a non-first phase has other than one argument, or when an
///         aggregate function cannot be resolved; LOGICAL_ERROR when a filter is requested in a non-first phase.
std::pair<String, DB::DataTypes> AggregateFunctionParser::tryApplyCHCombinator(
    const CommonFunctionInfo & func_info, const String & ch_func_name, const DB::DataTypes & argument_types) const
{
    // Parameter renamed to `types` so it no longer shadows the enclosing `argument_types`.
    auto get_aggregate_function
        = [](const String & name, const DB::DataTypes & types, const DB::Array & parameters) -> DB::AggregateFunctionPtr
    {
        DB::AggregateFunctionProperties properties;
        auto func = RelParser::getAggregateFunction(name, types, properties, parameters);
        if (!func)
            throw Exception(DB::ErrorCodes::BAD_ARGUMENTS, "Unknown aggregate function {}", name);
        return func;
    };
    String combinator_function_name = ch_func_name;
    DB::DataTypes combinator_argument_types = argument_types;
    if (func_info.phase != substrait::AggregationPhase::AGGREGATION_PHASE_INITIAL_TO_INTERMEDIATE
        && func_info.phase != substrait::AggregationPhase::AGGREGATION_PHASE_INITIAL_TO_RESULT)
    {
        if (argument_types.size() != 1)
            throw Exception(DB::ErrorCodes::BAD_ARGUMENTS, "Only support one argument aggregate function in phase {}", func_info.phase);
        // Add a check here for safety.
        if (func_info.has_filter)
            throw DB::Exception(DB::ErrorCodes::LOGICAL_ERROR, "Apply filter in phase {} not supported", func_info.phase);
        const auto * aggr_func_type = DB::checkAndGetDataType<DB::DataTypeAggregateFunction>(argument_types[0].get());
        if (!aggr_func_type)
        {
            // FIXME. This should be fixed. It's the case of count(distinct(xxx)) with other aggregate functions.
            // Gluten breaks the rule that intermediate result should have a special format name here.
            LOG_INFO(logger, "Intermediate aggregate function data is expected in phase {} for {}", func_info.phase, ch_func_name);
            // `aggr_func_type` is null on this path, so there are no state parameters to reuse. The previous
            // code called `aggr_func_type->getParameters()` here, which dereferenced a null pointer.
            // NOTE(review): the previous code also unpacked a (Nullable) Tuple argument into its element types,
            // but overwrote the result before using it. That dead store was removed; confirm whether
            // the tuple elements were meant to be passed below instead of `argument_types`.
            auto agg_function = get_aggregate_function(ch_func_name, argument_types, DB::Array{});
            combinator_argument_types = {agg_function->getStateType()};
        }
        else
        {
            // Special case for handling the intermediate result from aggregate functions with a filter.
            // It's safe to use AggregateFunctionxxx to parse the intermediate result of AggregateFunctionxxxIf,
            // since they have the same binary representation.
            // Reproduce this case with
            //   select
            //     count(a), count(b), count(1), count(distinct(a)), count(distinct(b))
            //   from values (1, null), (2,2) as data(a,b)
            // with `first_value` enabled.
            const String state_function_name = aggr_func_type->getFunction()->getName();
            if (endsWith(state_function_name, "If") && ch_func_name != state_function_name)
            {
                // Drop the trailing condition argument of the `...If` state to get the plain function's arguments.
                const auto & original_args_types = aggr_func_type->getArgumentsDataTypes();
                combinator_argument_types = DataTypes(original_args_types.begin(), std::prev(original_args_types.end()));
                auto agg_function = get_aggregate_function(ch_func_name, combinator_argument_types, aggr_func_type->getParameters());
                combinator_argument_types = {agg_function->getStateType()};
            }
        }
        combinator_function_name += "PartialMerge";
    }
    else if (func_info.has_filter)
    {
        // Apply the `If` aggregate function combinator on the original aggregate function.
        combinator_function_name += "If";
    }
    return {combinator_function_name, combinator_argument_types};
}
/// Post-processes the node produced for an aggregate function, in three steps:
///   1. If TypeParser::isTypeMatched reports a mismatch with `func_info.output_type`, the node is
///      converted to that type (`with_nullability` presumably controls whether nullability takes
///      part in the comparison — confirm in TypeParser).
///   2. convertNanToNullIfNeed is applied.
///   3. For decimal output types, the value is wrapped in
///      `checkDecimalOverflowSparkOrNull(value, precision, scale)`.
/// Each replacement node keeps the current result name and is placed in the DAG outputs.
/// @return the final node.
const DB::ActionsDAG::Node * AggregateFunctionParser::convertNodeTypeIfNeeded(
    const CommonFunctionInfo & func_info, const DB::ActionsDAG::Node * func_node, DB::ActionsDAG & actions_dag, bool with_nullability) const
{
    const auto & spark_type = func_info.output_type;

    if (!TypeParser::isTypeMatched(spark_type, func_node->result_type, !with_nullability))
    {
        const DB::ActionsDAG::Node * converted
            = ActionsDAGUtil::convertNodeType(actions_dag, func_node, TypeParser::parseType(spark_type), func_node->result_name);
        actions_dag.addOrReplaceInOutputs(*converted);
        func_node = converted;
    }

    func_node = convertNanToNullIfNeed(func_info, func_node, actions_dag);

    if (!spark_type.has_decimal())
        return func_node;

    // The constants are created in the order precision, then scale, so generated names stay the same.
    const auto * precision_node = addColumnToActionsDAG(actions_dag, std::make_shared<DataTypeInt32>(), spark_type.decimal().precision());
    const auto * scale_node = addColumnToActionsDAG(actions_dag, std::make_shared<DataTypeInt32>(), spark_type.decimal().scale());
    DB::ActionsDAG::NodeRawConstPtrs overflow_check_args{func_node, precision_node, scale_node};
    func_node = toFunctionNode(actions_dag, "checkDecimalOverflowSparkOrNull", func_node->result_name, overflow_check_args);
    actions_dag.addOrReplaceInOutputs(*func_node);
    return func_node;
}
/// For `corr` with a Nullable result, rewrites the node as `if(isNaN(x), NULL, x)` so that a NaN
/// result is reported as NULL. The rewritten node keeps the original result name and replaces it
/// in the DAG outputs. Any other function or result type is returned unchanged.
const DB::ActionsDAG::Node * AggregateFunctionParser::convertNanToNullIfNeed(
    const CommonFunctionInfo & func_info, const DB::ActionsDAG::Node * func_node, DB::ActionsDAG & actions_dag) const
{
    const bool is_nullable_corr = getCHFunctionName(func_info) == "corr" && func_node->result_type->isNullable();
    if (!is_nullable_corr)
        return func_node;

    const auto * nan_check = toFunctionNode(actions_dag, "isNaN", getUniqueName("isNaN"), {func_node});

    // One-row column holding the type's default value; for a Nullable type that is NULL.
    auto null_column = func_node->result_type->createColumn();
    null_column->insertDefault();
    const DB::ActionsDAG::Node * null_literal
        = &actions_dag.addColumn(DB::ColumnWithTypeAndName(std::move(null_column), func_node->result_type, getUniqueName("null")));

    DB::ActionsDAG::NodeRawConstPtrs if_args{nan_check, null_literal, func_node};
    func_node = toFunctionNode(actions_dag, "if", func_node->result_name, if_args);
    actions_dag.addOrReplaceInOutputs(*func_node);
    return func_node;
}
/// Returns the process-wide parser registry. It is a function-local static, constructed on first use.
AggregateFunctionParserFactory & AggregateFunctionParserFactory::instance()
{
    static AggregateFunctionParserFactory singleton;
    return singleton;
}
/// Registers `value`, the parser creator invoked by tryGet, under the aggregate function `name`.
/// @throws DB::Exception LOGICAL_ERROR if a parser is already registered under `name`; the
///         existing registration is kept.
void AggregateFunctionParserFactory::registerAggregateFunctionParser(const String & name, Value value)
{
    // `value` is a by-value sink parameter: move it into the map rather than copying it a second time.
    const bool inserted = parsers.emplace(name, std::move(value)).second;
    if (!inserted)
        throw DB::Exception(DB::ErrorCodes::LOGICAL_ERROR, "Aggregate function parser {} is already registered", name);
}
/// Like tryGet, but a missing registration is an error.
/// @throws DB::Exception UNKNOWN_FUNCTION when no parser is registered under `name`.
AggregateFunctionParserPtr AggregateFunctionParserFactory::get(const String & name, ParserContextPtr parser_context) const
{
    if (auto parser = tryGet(name, parser_context))
        return parser;
    throw DB::Exception(DB::ErrorCodes::UNKNOWN_FUNCTION, "Unknown aggregate function {}", name);
}
/// Creates the parser registered under `name` by invoking its creator with `parser_context`.
/// @return the new parser, or nullptr when nothing is registered under `name`.
AggregateFunctionParserPtr AggregateFunctionParserFactory::tryGet(const String & name, ParserContextPtr parser_context) const
{
    if (auto found = parsers.find(name); found != parsers.end())
        return found->second(parser_context);
    return nullptr;
}
}