blob: 6035815d67b3702da23df447c8dcf65a96498104 [file] [log] [blame]
// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.
// This file is copied from
// https://github.com/ClickHouse/ClickHouse/blob/master/src/AggregateFunctions/Helpers.h
// and modified by Doris
#pragma once
#include "vec/aggregate_functions/aggregate_function.h"
#include "vec/aggregate_functions/aggregate_function_null.h"
#include "vec/data_types/data_type.h"
#include "vec/utils/template_helpers.hpp"
/** If the serialized type is not the default type(string),
* aggregation function need to override these functions:
* 1. serialize_to_column
* 2. streaming_agg_serialize_to_column
* 3. deserialize_and_merge_vec
* 4. deserialize_and_merge_vec_selected
* 5. serialize_without_key_to_column
* 6. deserialize_and_merge_from_column
*/
#define CHECK_AGG_FUNCTION_SERIALIZED_TYPE(FunctionTemplate) \
do { \
constexpr bool _is_new_serialized_type = \
!std::is_same_v<decltype(&FunctionTemplate::get_serialized_type), \
decltype(&IAggregateFunction::get_serialized_type)>; \
if constexpr (_is_new_serialized_type) { \
static_assert(!std::is_same_v<decltype(&FunctionTemplate::serialize_to_column), \
decltype(&IAggregateFunctionHelper< \
FunctionTemplate>::serialize_to_column)>, \
"need to override serialize_to_column"); \
static_assert( \
!std::is_same_v< \
decltype(&FunctionTemplate::streaming_agg_serialize_to_column), \
decltype(&IAggregateFunction::streaming_agg_serialize_to_column)>, \
"need to override " \
"streaming_agg_serialize_to_column"); \
static_assert(!std::is_same_v<decltype(&FunctionTemplate::deserialize_and_merge_vec), \
decltype(&IAggregateFunctionHelper< \
FunctionTemplate>::deserialize_and_merge_vec)>, \
"need to override deserialize_and_merge_vec"); \
static_assert( \
!std::is_same_v< \
decltype(&FunctionTemplate::deserialize_and_merge_vec_selected), \
decltype(&IAggregateFunctionHelper< \
FunctionTemplate>::deserialize_and_merge_vec_selected)>, \
"need to override " \
"deserialize_and_merge_vec_selected"); \
static_assert( \
!std::is_same_v<decltype(&FunctionTemplate::serialize_without_key_to_column), \
decltype(&IAggregateFunctionHelper< \
FunctionTemplate>::serialize_without_key_to_column)>, \
"need to override serialize_without_key_to_column"); \
static_assert(!std::is_same_v< \
decltype(&FunctionTemplate::deserialize_and_merge_from_column), \
decltype(&IAggregateFunctionHelper< \
FunctionTemplate>::deserialize_and_merge_from_column)>, \
"need to override " \
"deserialize_and_merge_from_column"); \
} \
} while (false)
namespace doris::vectorized {
#include "common/compile_check_begin.h"
struct creator_without_type {
template <bool multi_arguments, bool f, typename T>
using NullableT = std::conditional_t<multi_arguments, AggregateFunctionNullVariadicInline<T, f>,
AggregateFunctionNullUnaryInline<T, f>>;
template <typename AggregateFunctionTemplate>
static AggregateFunctionPtr creator(const std::string& name, const DataTypes& argument_types,
const bool result_is_nullable,
const AggregateFunctionAttr& attr) {
CHECK_AGG_FUNCTION_SERIALIZED_TYPE(AggregateFunctionTemplate);
return create<AggregateFunctionTemplate>(argument_types, result_is_nullable);
}
template <typename AggregateFunctionTemplate, typename... TArgs>
static AggregateFunctionPtr create(const DataTypes& argument_types_,
const bool result_is_nullable, TArgs&&... args) {
std::unique_ptr<IAggregateFunction> result(std::make_unique<AggregateFunctionTemplate>(
std::forward<TArgs>(args)..., remove_nullable(argument_types_)));
if (have_nullable(argument_types_)) {
std::visit(
[&](auto multi_arguments, auto result_is_nullable) {
result.reset(new NullableT<multi_arguments, result_is_nullable,
AggregateFunctionTemplate>(result.release(),
argument_types_));
},
make_bool_variant(argument_types_.size() > 1),
make_bool_variant(result_is_nullable));
}
CHECK_AGG_FUNCTION_SERIALIZED_TYPE(AggregateFunctionTemplate);
return AggregateFunctionPtr(result.release());
}
/// AggregateFunctionTemplate will handle the nullable arguments, no need to use
/// AggregateFunctionNullVariadicInline/AggregateFunctionNullUnaryInline
template <typename AggregateFunctionTemplate, typename... TArgs>
static AggregateFunctionPtr create_ignore_nullable(const DataTypes& argument_types_,
const bool /*result_is_nullable*/,
TArgs&&... args) {
std::unique_ptr<IAggregateFunction> result = std::make_unique<AggregateFunctionTemplate>(
std::forward<TArgs>(args)..., argument_types_);
CHECK_AGG_FUNCTION_SERIALIZED_TYPE(AggregateFunctionTemplate);
return AggregateFunctionPtr(result.release());
}
};
template <template <PrimitiveType> class AggregateFunctionTemplate>
struct CurryDirect {
template <PrimitiveType type>
using T = AggregateFunctionTemplate<type>;
};
template <template <typename> class AggregateFunctionTemplate, template <PrimitiveType> class Data>
struct CurryData {
template <PrimitiveType Type>
using T = AggregateFunctionTemplate<Data<Type>>;
};
template <template <typename> class AggregateFunctionTemplate, template <typename> class Data,
template <PrimitiveType> class Impl>
struct CurryDataImpl {
template <PrimitiveType Type>
using T = AggregateFunctionTemplate<Data<Impl<Type>>>;
};
template <template <PrimitiveType, typename> class AggregateFunctionTemplate,
template <PrimitiveType> class Data>
struct CurryDirectAndData {
template <PrimitiveType Type>
using T = AggregateFunctionTemplate<Type, Data<Type>>;
};
template <bool allow_integer, bool allow_float, bool allow_decimal, int define_index = 0>
struct creator_with_type_base {
template <typename Class, typename... TArgs>
static AggregateFunctionPtr create_base(const DataTypes& argument_types,
const bool result_is_nullable, TArgs&&... args) {
if constexpr (allow_integer) {
switch (argument_types[define_index]->get_primitive_type()) {
case PrimitiveType::TYPE_BOOLEAN:
return creator_without_type::create<typename Class::template T<TYPE_BOOLEAN>>(
argument_types, result_is_nullable, std::forward<TArgs>(args)...);
case PrimitiveType::TYPE_TINYINT:
return creator_without_type::create<typename Class::template T<TYPE_TINYINT>>(
argument_types, result_is_nullable, std::forward<TArgs>(args)...);
case PrimitiveType::TYPE_SMALLINT:
return creator_without_type::create<typename Class::template T<TYPE_SMALLINT>>(
argument_types, result_is_nullable, std::forward<TArgs>(args)...);
case PrimitiveType::TYPE_INT:
return creator_without_type::create<typename Class::template T<TYPE_INT>>(
argument_types, result_is_nullable, std::forward<TArgs>(args)...);
case PrimitiveType::TYPE_BIGINT:
return creator_without_type::create<typename Class::template T<TYPE_BIGINT>>(
argument_types, result_is_nullable, std::forward<TArgs>(args)...);
case PrimitiveType::TYPE_LARGEINT:
return creator_without_type::create<typename Class::template T<TYPE_LARGEINT>>(
argument_types, result_is_nullable, std::forward<TArgs>(args)...);
default:
break;
}
}
if constexpr (allow_float) {
switch (argument_types[define_index]->get_primitive_type()) {
case PrimitiveType::TYPE_FLOAT:
return creator_without_type::create<typename Class::template T<TYPE_FLOAT>>(
argument_types, result_is_nullable, std::forward<TArgs>(args)...);
case PrimitiveType::TYPE_DOUBLE:
return creator_without_type::create<typename Class::template T<TYPE_DOUBLE>>(
argument_types, result_is_nullable, std::forward<TArgs>(args)...);
default:
break;
}
}
if constexpr (allow_decimal) {
switch (argument_types[define_index]->get_primitive_type()) {
case PrimitiveType::TYPE_DECIMAL32:
return creator_without_type::create<typename Class::template T<TYPE_DECIMAL32>>(
argument_types, result_is_nullable, std::forward<TArgs>(args)...);
case PrimitiveType::TYPE_DECIMAL64:
return creator_without_type::create<typename Class::template T<TYPE_DECIMAL64>>(
argument_types, result_is_nullable, std::forward<TArgs>(args)...);
case PrimitiveType::TYPE_DECIMALV2:
return creator_without_type::create<typename Class::template T<TYPE_DECIMALV2>>(
argument_types, result_is_nullable, std::forward<TArgs>(args)...);
case PrimitiveType::TYPE_DECIMAL128I:
return creator_without_type::create<typename Class::template T<TYPE_DECIMAL128I>>(
argument_types, result_is_nullable, std::forward<TArgs>(args)...);
case PrimitiveType::TYPE_DECIMAL256:
return creator_without_type::create<typename Class::template T<TYPE_DECIMAL256>>(
argument_types, result_is_nullable, std::forward<TArgs>(args)...);
default:
break;
}
}
return nullptr;
}
template <template <PrimitiveType> class AggregateFunctionTemplate>
static AggregateFunctionPtr creator(const std::string& name, const DataTypes& argument_types,
const bool result_is_nullable,
const AggregateFunctionAttr& attr) {
return create_base<CurryDirect<AggregateFunctionTemplate>>(argument_types,
result_is_nullable);
}
template <template <PrimitiveType> class AggregateFunctionTemplate, typename... TArgs>
static AggregateFunctionPtr create(TArgs&&... args) {
return create_base<CurryDirect<AggregateFunctionTemplate>>(std::forward<TArgs>(args)...);
}
template <template <typename> class AggregateFunctionTemplate,
template <PrimitiveType> class Data>
static AggregateFunctionPtr creator(const std::string& name, const DataTypes& argument_types,
const bool result_is_nullable,
const AggregateFunctionAttr& attr) {
return create_base<CurryData<AggregateFunctionTemplate, Data>>(argument_types,
result_is_nullable);
}
template <template <typename> class AggregateFunctionTemplate,
template <PrimitiveType> class Data, typename... TArgs>
static AggregateFunctionPtr create(TArgs&&... args) {
return create_base<CurryData<AggregateFunctionTemplate, Data>>(
std::forward<TArgs>(args)...);
}
template <template <typename> class AggregateFunctionTemplate, template <typename> class Data,
template <PrimitiveType> class Impl>
static AggregateFunctionPtr creator(const std::string& name, const DataTypes& argument_types,
const bool result_is_nullable,
const AggregateFunctionAttr& attr) {
return create_base<CurryDataImpl<AggregateFunctionTemplate, Data, Impl>>(
argument_types, result_is_nullable);
}
template <template <typename> class AggregateFunctionTemplate, template <typename> class Data,
template <PrimitiveType> class Impl, typename... TArgs>
static AggregateFunctionPtr create(TArgs&&... args) {
return create_base<CurryDataImpl<AggregateFunctionTemplate, Data, Impl>>(
std::forward<TArgs>(args)...);
}
template <template <PrimitiveType, typename> class AggregateFunctionTemplate,
template <PrimitiveType> class Data>
static AggregateFunctionPtr creator(const std::string& name, const DataTypes& argument_types,
const bool result_is_nullable,
const AggregateFunctionAttr& attr) {
return create_base<CurryDirectAndData<AggregateFunctionTemplate, Data>>(argument_types,
result_is_nullable);
}
template <template <PrimitiveType, typename> class AggregateFunctionTemplate,
template <PrimitiveType> class Data, typename... TArgs>
static AggregateFunctionPtr create(TArgs&&... args) {
return create_base<CurryDirectAndData<AggregateFunctionTemplate, Data>>(
std::forward<TArgs>(args)...);
}
};
using creator_with_integer_type = creator_with_type_base<true, false, false>;
using creator_with_numeric_type = creator_with_type_base<true, true, false>;
using creator_with_decimal_type = creator_with_type_base<false, false, true>;
using creator_with_type = creator_with_type_base<true, true, true>;
} // namespace doris::vectorized
#include "common/compile_check_end.h"