blob: ceffc18afecc33c99ffe0b8cf1a6bad1f5607775 [file] [log] [blame]
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include <DataTypes/DataTypeAggregateFunction.h>
#include <DataTypes/DataTypeNullable.h>
#include <DataTypes/DataTypesNumber.h>
#include <Functions/FunctionHelpers.h>
#include <Interpreters/ActionsDAG.h>
#include <Parser/AggregateFunctionParser.h>
#include <Parser/aggregate_function_parser/PercentileParserBase.h>
#include <substrait/algebra.pb.h>
#include <Common/CHUtil.h>
namespace DB
{
namespace ErrorCodes
{
extern const int BAD_ARGUMENTS;
extern const int NUMBER_OF_ARGUMENTS_DOESNT_MATCH;
}
}
namespace local_engine
{
using namespace DB;
void PercentileParserBase::assertArgumentsSize(substrait::AggregationPhase phase, size_t size, size_t expect) const
{
if (size != expect)
throw Exception(
DB::ErrorCodes::NUMBER_OF_ARGUMENTS_DOESNT_MATCH,
"Function {} in phase {} requires exactly {} arguments but got {} arguments",
getName(),
magic_enum::enum_name(phase),
expect,
size);
}
const substrait::Expression::Literal &
PercentileParserBase::assertAndGetLiteral(substrait::AggregationPhase phase, const substrait::Expression & expr) const
{
if (!expr.has_literal())
throw Exception(
DB::ErrorCodes::BAD_ARGUMENTS,
"The argument of function {} in phase {} must be literal, but is {}",
getName(),
magic_enum::enum_name(phase),
expr.DebugString());
return expr.literal();
}
String PercentileParserBase::getCHFunctionName(const CommonFunctionInfo & func_info) const
{
const auto & output_type = func_info.output_type;
return output_type.has_list() ? getCHPluralName() : getCHSingularName();
}
String PercentileParserBase::getCHFunctionName(DB::DataTypes & types) const
{
/// Always invoked during second stage
assertArgumentsSize(substrait::AGGREGATION_PHASE_INTERMEDIATE_TO_RESULT, types.size(), expectedTupleElementsNumberInSecondStage());
auto type = removeNullable(types[PERCENTAGE_INDEX]);
types.resize(1);
if (getName() == "percentile")
{
/// Corresponding CH function requires two arguments: quantileExactWeightedInterpolated(xxx)(col, weight)
types.push_back(std::make_shared<DataTypeUInt64>());
}
return isArray(type) ? getCHPluralName() : getCHSingularName();
}
DB::Array PercentileParserBase::parseFunctionParameters(
const CommonFunctionInfo & func_info, DB::ActionsDAG::NodeRawConstPtrs & arg_nodes, DB::ActionsDAG & actions_dag) const
{
if (func_info.phase == substrait::AGGREGATION_PHASE_INITIAL_TO_INTERMEDIATE
|| func_info.phase == substrait::AGGREGATION_PHASE_INITIAL_TO_RESULT || func_info.phase == substrait::AGGREGATION_PHASE_UNSPECIFIED)
{
Array params;
const auto & arguments = func_info.arguments;
assertArgumentsSize(func_info.phase, arguments.size(), expectedArgumentsNumberInFirstStage());
auto param_indexes = getArgumentsThatAreParameters();
for (auto idx : param_indexes)
{
const auto & expr = arguments[idx].value();
const auto & literal = assertAndGetLiteral(func_info.phase, expr);
auto [type, field] = parseLiteral(literal);
if (idx == PERCENTAGE_INDEX && isArray(removeNullable(type)))
{
/// Multiple percentages for quantilesXXX
const Array & percentags = field.safeGet<Array>();
for (const auto & percentage : percentags)
params.emplace_back(percentage);
}
else
{
params.emplace_back(std::move(field));
}
}
/// Collect arguments in substrait plan that are not CH parameters as CH arguments
ActionsDAG::NodeRawConstPtrs new_arg_nodes;
for (size_t i = 0; i < arg_nodes.size(); ++i)
{
if (std::find(param_indexes.begin(), param_indexes.end(), i) == param_indexes.end())
{
if (getName() == "percentile" && i == 2)
{
/// In spark percentile(col, percentage, weight), the last argument weight is a signed integer
/// But CH requires weight as an unsigned integer
DataTypePtr dst_type = std::make_shared<DataTypeUInt64>();
if (arg_nodes[i]->result_type->isNullable())
dst_type = std::make_shared<DataTypeNullable>(dst_type);
arg_nodes[i] = ActionsDAGUtil::convertNodeTypeIfNeeded(actions_dag, arg_nodes[i], dst_type);
}
new_arg_nodes.emplace_back(arg_nodes[i]);
}
}
new_arg_nodes.swap(arg_nodes);
return params;
}
else
{
assertArgumentsSize(func_info.phase, arg_nodes.size(), 1);
const auto & result_type = arg_nodes[0]->result_type;
const auto * aggregate_function_type = DB::checkAndGetDataType<DB::DataTypeAggregateFunction>(result_type.get());
if (!aggregate_function_type)
throw Exception(
DB::ErrorCodes::BAD_ARGUMENTS,
"The first argument type of function {} in phase {} must be AggregateFunction, but is {}",
getName(),
magic_enum::enum_name(func_info.phase),
result_type->getName());
return aggregate_function_type->getParameters();
}
}
}