blob: d214858365163980041d5b525fdd9d472da0e84a [file] [log] [blame]
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include <DataTypes/DataTypeAggregateFunction.h>
#include <DataTypes/DataTypeNullable.h>
#include <DataTypes/DataTypesNumber.h>
#include <Functions/FunctionHelpers.h>
#include <Parser/AggregateFunctionParser.h>
#include <Poco/Logger.h>
#include <Common/logger_useful.h>
namespace DB
{
namespace ErrorCodes
{
extern const int NUMBER_OF_ARGUMENTS_DOESNT_MATCH;
extern const int BAD_ARGUMENTS;
}
}
namespace local_engine
{
using namespace DB;
struct ApproxCountDistinctNameStruct
{
static constexpr auto spark_name = "approx_count_distinct";
static constexpr auto ch_name = "uniqHLLPP";
};
/// Spark approx_count_distinct(expr, relative_sd) = CH uniqHLLPP(relative_sd)(if(isNull(expr), null, sparkXxHash64(expr)))
/// Spark approx_count_distinct(expr) = CH uniqHLLPP(0.05)(if(isNull(expr), null, sparkXxHash64(expr)))
template <typename NameStruct>
class AggregateFunctionParserApproxCountDistinct final : public AggregateFunctionParser
{
public:
static constexpr auto name = NameStruct::spark_name;
AggregateFunctionParserApproxCountDistinct(ParserContextPtr parser_context_) : AggregateFunctionParser(parser_context_) { }
~AggregateFunctionParserApproxCountDistinct() override = default;
String getName() const override { return NameStruct::spark_name; }
String getCHFunctionName(const CommonFunctionInfo &) const override { return NameStruct::ch_name; }
String getCHFunctionName(DataTypes & types) const override
{
/// Always invoked during second stage, the first argument is expr, the second argument is relative_sd.
/// 1. Remove the second argument because types are used to create the aggregate function.
/// 2. Replace the first argument type with UInt64 or Nullable(UInt64) because uniqHLLPP requres it.
types.resize(1);
const auto old_type = types[0];
types[0] = std::make_shared<DataTypeUInt64>();
if (old_type->isNullable())
types[0] = std::make_shared<DataTypeNullable>(types[0]);
return NameStruct::ch_name;
}
Array parseFunctionParameters(
const CommonFunctionInfo & func_info, ActionsDAG::NodeRawConstPtrs & arg_nodes, ActionsDAG & actions_dag) const override
{
if (func_info.phase == substrait::AGGREGATION_PHASE_INITIAL_TO_INTERMEDIATE
|| func_info.phase == substrait::AGGREGATION_PHASE_INITIAL_TO_RESULT
|| func_info.phase == substrait::AGGREGATION_PHASE_UNSPECIFIED)
{
const auto & arguments = func_info.arguments;
const size_t num_args = arguments.size();
const size_t num_nodes = arg_nodes.size();
if (num_args != num_nodes || num_args > 2 || num_args < 1 || num_nodes > 2 || num_nodes < 1)
throw Exception(
ErrorCodes::NUMBER_OF_ARGUMENTS_DOESNT_MATCH,
"Function {} takes 1 or 2 arguments in phase {}",
getName(),
magic_enum::enum_name(func_info.phase));
Array params;
if (num_args == 2)
{
const auto & relative_sd_arg = arguments[1].value();
if (relative_sd_arg.has_literal())
{
auto [_, field] = parseLiteral(relative_sd_arg.literal());
params.push_back(std::move(field));
}
else
throw Exception(ErrorCodes::BAD_ARGUMENTS, "Second argument of function {} must be literal", getName());
}
else
{
params.push_back(0.05);
}
const auto & expr_arg = arg_nodes[0];
const auto * is_null_node = toFunctionNode(actions_dag, "isNull", {expr_arg});
const auto * hash_node = toFunctionNode(actions_dag, "sparkXxHash64", {expr_arg});
const auto * null_node
= addColumnToActionsDAG(actions_dag, std::make_shared<DataTypeNullable>(std::make_shared<DataTypeUInt64>()), {});
const auto * if_node = toFunctionNode(actions_dag, "if", {is_null_node, null_node, hash_node});
/// Replace the first argument expr with if(isNull(expr), null, sparkXxHash64(expr))
arg_nodes[0] = if_node;
arg_nodes.resize(1);
return params;
}
else
{
if (arg_nodes.size() != 1)
throw Exception(
ErrorCodes::NUMBER_OF_ARGUMENTS_DOESNT_MATCH,
"Function {} takes 1 argument in phase {}",
getName(),
magic_enum::enum_name(func_info.phase));
const auto & result_type = arg_nodes[0]->result_type;
const auto * aggregate_function_type = checkAndGetDataType<DataTypeAggregateFunction>(result_type.get());
if (!aggregate_function_type)
throw Exception(
ErrorCodes::BAD_ARGUMENTS,
"The first argument type of function {} in phase {} must be AggregateFunction, but is {}",
getName(),
magic_enum::enum_name(func_info.phase),
result_type->getName());
return aggregate_function_type->getParameters();
}
}
Array getDefaultFunctionParameters() const override { return {0.05}; }
};
static const AggregateFunctionParserRegister<AggregateFunctionParserApproxCountDistinct<ApproxCountDistinctNameStruct>> registerer_approx_count_distinct;
}