// blob: 16f331aa6db60d152c55f6f5edf3e763257e0b49 [file] [log] [blame]
// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.
// This file is copied from
// https://github.com/ClickHouse/ClickHouse/blob/master/src/AggregateFunctions/AggregateFunctionFactory.h
// and modified by Doris
#pragma once
#include <functional>
#include <memory>
#include <string>
#include <string_view>
#include <unordered_map>
#include <utility>
#include <vector>
#include "agent/be_exec_version_manager.h"
#include "core/assert_cast.h"
#include "core/data_type/data_type.h"
#include "core/data_type/data_type_array.h"
#include "core/data_type/data_type_nullable.h"
#include "exprs/aggregate/aggregate_function.h"
namespace doris {
#include "common/compile_check_begin.h"
// Shared handle to an immutable data type description.
using DataTypePtr = std::shared_ptr<const IDataType>;
// Ordered list of data type handles.
using DataTypes = std::vector<DataTypePtr>;
// Callable that builds an aggregate function. The factory's get() invokes it with,
// in order: the resolved function name, the argument types, the result type,
// whether the result is nullable, and the function attributes.
using AggregateFunctionCreator = std::function<AggregateFunctionPtr(
        const std::string& /*name*/, const DataTypes& /*argument_types*/,
        const DataTypePtr& /*result_type*/, bool /*result_is_nullable*/,
        const AggregateFunctionAttr& /*attr*/)>;
// Concatenates the names of all `types`, in order and without any separator.
// Returns an empty string when `types` is empty.
inline std::string types_name(const DataTypes& types) {
    std::string joined;
    for (const auto& data_type : types) {
        joined.append(data_type->get_name());
    }
    return joined;
}
// Prefix used to form function names of the shape "multi_distinct_<function>"; the
// factory prepends it to base names when registering combinators and aliases.
// `inline` makes this a single entity shared by every translation unit that includes
// the header. Without it the constant has internal linkage, and the inline member
// functions that odr-use it would refer to a different object in each translation unit.
inline constexpr std::string DISTINCT_FUNCTION_PREFIX = "multi_distinct_";
class AggregateFunctionSimpleFactory {
public:
using Creator = AggregateFunctionCreator;
private:
using AggregateFunctions = std::unordered_map<std::string, Creator>;
constexpr static std::string_view combiner_names[] = {"_foreach", "_foreachv2"};
AggregateFunctions aggregate_functions;
AggregateFunctions nullable_aggregate_functions;
std::unordered_map<std::string, std::string> function_alias;
public:
static bool is_foreach(const std::string& name) {
constexpr std::string_view suffix = "_foreach";
if (name.length() < suffix.length()) {
return false;
}
return name.substr(name.length() - suffix.length()) == suffix;
}
static bool is_foreachv2(const std::string& name) {
constexpr std::string_view suffix = "_foreachv2";
if (name.length() < suffix.length()) {
return false;
}
return name.substr(name.length() - suffix.length()) == suffix;
}
static bool result_nullable_by_foreach(DataTypePtr& data_type) {
// The return value of the 'foreach' function is 'null' or 'array<type>'.
// The internal function's nullable should depend on whether 'type' is nullable
DCHECK(data_type->is_nullable());
return assert_cast<const DataTypeArray*>(remove_nullable(data_type).get())
->get_nested_type()
->is_nullable();
}
void register_distinct_function_combinator(const Creator& creator, const std::string& prefix,
bool nullable = false) {
auto& functions = nullable ? nullable_aggregate_functions : aggregate_functions;
std::vector<std::string> need_insert;
for (const auto& entity : aggregate_functions) {
std::string target_value = prefix + entity.first;
if (functions.find(target_value) == functions.end()) {
need_insert.emplace_back(std::move(target_value));
}
}
for (const auto& function_name : need_insert) {
register_function(function_name, creator, nullable);
}
}
void register_foreach_function_combinator(const Creator& creator, const std::string& suffix,
bool nullable = false) {
auto& functions = nullable ? nullable_aggregate_functions : aggregate_functions;
std::vector<std::string> need_insert;
for (const auto& entity : aggregate_functions) {
std::string target_value = entity.first + suffix;
if (functions.find(target_value) == functions.end()) {
need_insert.emplace_back(std::move(target_value));
}
}
for (const auto& function_name : need_insert) {
register_function(function_name, creator, nullable);
}
}
AggregateFunctionPtr get(const std::string& name, const DataTypes& argument_types,
const DataTypePtr& result_type, const bool result_is_nullable,
int be_version, AggregateFunctionAttr attr = {}) {
bool nullable = false;
for (const auto& type : argument_types) {
if (type->is_nullable()) {
nullable = true;
}
}
std::string name_str = name;
temporary_function_update(be_version, name_str);
if (function_alias.contains(name)) {
name_str = function_alias[name];
}
if (nullable) {
return nullable_aggregate_functions.find(name_str) == nullable_aggregate_functions.end()
? nullptr
: nullable_aggregate_functions[name_str](name_str, argument_types,
result_type, result_is_nullable,
attr);
} else {
return aggregate_functions.find(name_str) == aggregate_functions.end()
? nullptr
: aggregate_functions[name_str](name_str, argument_types, result_type,
result_is_nullable, attr);
}
}
void register_function(const std::string& name, const Creator& creator, bool nullable = false) {
if (nullable) {
nullable_aggregate_functions[name] = creator;
} else {
aggregate_functions[name] = creator;
}
}
void register_function_both(const std::string& name, const Creator& creator) {
register_function(name, creator, false);
register_function(name, creator, true);
}
void register_alias(const std::string& name, const std::string& alias) {
function_alias[alias] = name;
for (const auto& s : combiner_names) {
function_alias[alias + std::string(s)] = name + std::string(s);
}
function_alias[DISTINCT_FUNCTION_PREFIX + alias] = DISTINCT_FUNCTION_PREFIX + name;
}
void register_alternative_function(const std::string& name, const Creator& creator,
bool nullable, int old_be_exec_version) {
auto new_name = name + BeExecVersionManager::get_function_suffix(old_be_exec_version);
register_function(new_name, creator, nullable);
BeExecVersionManager::registe_old_function_compatibility(old_be_exec_version, name);
}
void temporary_function_update(int fe_version_now, std::string& name) {
int old_version = BeExecVersionManager::get_function_compatibility(fe_version_now, name);
if (!old_version) {
return;
}
name = name + BeExecVersionManager::get_function_suffix(old_version);
}
static AggregateFunctionSimpleFactory& instance();
};
}; // namespace doris
#include "common/compile_check_end.h"