blob: 7511886d3b8eca7eec6d7908d599d7e08a7f279c [file] [log] [blame]
// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.
#pragma once
#include <cstddef>
#include <new>
#include <type_traits>
#include "core/arena.h"
#include "core/column/column_decimal.h"
#include "core/column/column_vector.h"
#include "core/data_type/data_type_decimal.h"
#include "core/data_type/data_type_number.h"
#include "core/string_buffer.hpp"
#include "core/types.h"
#include "exprs/aggregate/aggregate_function.h"
#include "util/io_helper.h"
namespace doris {
#include "common/compile_check_begin.h"
template <PrimitiveType T>
struct AggregateFunctionProductData {
typename PrimitiveTypeTraits<T>::CppType product {};
void add_impl(typename PrimitiveTypeTraits<T>::CppType value,
typename PrimitiveTypeTraits<T>::CppType& product_ref) {
if constexpr (std::is_integral_v<typename PrimitiveTypeTraits<T>::CppType>) {
typename PrimitiveTypeTraits<T>::CppType new_product;
if (__builtin_expect(common::mul_overflow(product_ref, value, new_product), false)) {
// if overflow, set product to infinity to keep the same behavior with double type
throw Exception(ErrorCode::INTERNAL_ERROR,
"Product overflow for type {} and value {} * {}", T, value,
product_ref);
} else {
product_ref = new_product;
}
} else {
// which type is float or double
product_ref *= value;
}
}
void add(typename PrimitiveTypeTraits<T>::CppType value,
typename PrimitiveTypeTraits<T>::CppType) {
add_impl(value, product);
VLOG_DEBUG << "product: " << product;
}
void merge(const AggregateFunctionProductData& other,
typename PrimitiveTypeTraits<T>::CppType) {
add_impl(other.product, product);
VLOG_DEBUG << "product: " << product;
}
void write(BufferWritable& buffer) const { buffer.write_binary(product); }
void read(BufferReadable& buffer) { buffer.read_binary(product); }
typename PrimitiveTypeTraits<T>::CppType get() const { return product; }
void reset(typename PrimitiveTypeTraits<T>::CppType value) { product = std::move(value); }
};
template <>
struct AggregateFunctionProductData<TYPE_DECIMALV2> {
Decimal128V2 product {};
void add(Decimal128V2 value, Decimal128V2) {
DecimalV2Value decimal_product(product);
DecimalV2Value decimal_value(value);
DecimalV2Value ret = decimal_product * decimal_value;
memcpy(&product, &ret, sizeof(Decimal128V2));
}
void merge(const AggregateFunctionProductData& other, Decimal128V2) {
DecimalV2Value decimal_product(product);
DecimalV2Value decimal_value(other.product);
DecimalV2Value ret = decimal_product * decimal_value;
memcpy(&product, &ret, sizeof(Decimal128V2));
}
void write(BufferWritable& buffer) const { buffer.write_binary(product); }
void read(BufferReadable& buffer) { buffer.read_binary(product); }
Decimal128V2 get() const { return product; }
void reset(Decimal128V2 value) { product = std::move(value); }
};
/// DECIMAL128I / DECIMAL256 running-product state holding a scaled integer.
/// Multiplying two values that share a scale doubles the scale, so every step
/// divides by `multiplier` (the scale factor supplied by AggregateFunctionProduct)
/// to bring the product back to the original scale.
/// NOTE(review): the intermediate `product * factor` is not overflow-checked here.
template <PrimitiveType T>
requires(T == TYPE_DECIMAL128I || T == TYPE_DECIMAL256)
struct AggregateFunctionProductData<T> {
    using CppType = typename PrimitiveTypeTraits<T>::CppType;

    CppType product {};

    /// Folds one input value into the product, then rescales.
    template <typename NestedType>
    void add(Decimal<NestedType> value, CppType multiplier) {
        multiply_and_rescale(value, multiplier);
    }

    /// Combines another partial product into this one, then rescales.
    void merge(const AggregateFunctionProductData& other, CppType multiplier) {
        multiply_and_rescale(other.product, multiplier);
    }

    void write(BufferWritable& buffer) const { buffer.write_binary(product); }
    void read(BufferReadable& buffer) { buffer.read_binary(product); }
    CppType get() const { return product; }
    void reset(CppType value) { product = std::move(value); }

private:
    // product := product * factor / multiplier
    template <typename Factor>
    void multiply_and_rescale(Factor factor, CppType multiplier) {
        product *= factor;
        product /= multiplier;
    }
};
// Primary template is only declared. The single definition is the partial
// specialization constrained by is_valid_product_types, so an unsupported
// (T, TResult) pair yields an incomplete-type compile error rather than a
// silently wrong instantiation.
template <PrimitiveType T, PrimitiveType TResult, typename Data>
class AggregateFunctionProduct;
/// True when product() over argument type T may produce result type TResult:
///  - DecimalV3 into the same or a wider DecimalV3,
///  - DecimalV2 into DecimalV2,
///  - float/double into float/double,
///  - integer or boolean into an integer.
/// Declared `inline constexpr` (not `static`) so the header defines a single entity
/// rather than one internal-linkage copy per translation unit.
template <PrimitiveType T, PrimitiveType TResult>
inline constexpr bool is_valid_product_types =
        is_same_or_wider_decimalv3(T, TResult) ||
        (is_decimalv2(T) && is_decimalv2(TResult)) ||
        (is_float_or_double(T) && is_float_or_double(TResult)) ||
        (is_int_or_bool(T) && is_int(TResult));
/// product(x) aggregate: multiplies the input values of column type T and returns
/// the result as TResult.
///
/// For decimal inputs the state holds a scaled integer. `multiplier` is the scale
/// factor for the argument's scale; it is used both as the identity value in
/// reset() and as the rescaling divisor passed to Data::add()/merge(). For
/// non-decimal inputs `multiplier` is unused (value-initialized).
template <PrimitiveType T, PrimitiveType TResult, typename Data>
requires(is_valid_product_types<T, TResult>)
class AggregateFunctionProduct<T, TResult, Data> final
        : public IAggregateFunctionDataHelper<Data, AggregateFunctionProduct<T, TResult, Data>>,
          UnaryExpression,
          NullableAggregateFunction {
public:
    using ResultDataType = typename PrimitiveTypeTraits<TResult>::DataType;
    using ColVecType = typename PrimitiveTypeTraits<T>::ColumnType;
    using ColVecResult = typename PrimitiveTypeTraits<TResult>::ColumnType;

    std::string get_name() const override { return "product"; }

    AggregateFunctionProduct(const DataTypes& argument_types_)
            : IAggregateFunctionDataHelper<Data, AggregateFunctionProduct<T, TResult, Data>>(
                      argument_types_),
              scale(get_decimal_scale(*argument_types_[0])) {
        if constexpr (is_decimal(T)) {
            // `scale` is declared (and therefore initialized) before `multiplier`.
            multiplier.value = ResultDataType::get_scale_multiplier(scale);
        }
    }

    /// Decimal results use maximum precision with the argument's scale.
    DataTypePtr get_return_type() const override {
        if constexpr (is_decimal(T)) {
            return std::make_shared<ResultDataType>(ResultDataType::max_precision(), scale);
        } else {
            return std::make_shared<ResultDataType>();
        }
    }

    /// Constructs the state and seeds it with the multiplicative identity.
    /// Data value-initializes its product to zero, which would absorb every later
    /// multiplication, so reset() must run before the first add()/merge().
    void create(AggregateDataPtr __restrict place) const override {
        new (place) Data;
        reset(place);
    }

    /// Multiplies the row's value into the state. The column type is not re-checked
    /// in release builds (TypeCheckOnRelease::DISABLE).
    void add(AggregateDataPtr __restrict place, const IColumn** columns, ssize_t row_num,
             Arena&) const override {
        const auto& column =
                assert_cast<const ColVecType&, TypeCheckOnRelease::DISABLE>(*columns[0]);
        this->data(place).add(
                typename PrimitiveTypeTraits<TResult>::CppType(column.get_data()[row_num]),
                multiplier);
    }

    /// Resets the state to the identity: `multiplier` (scaled 1) for decimals, 1 otherwise.
    void reset(AggregateDataPtr place) const override {
        if constexpr (is_decimal(T)) {
            this->data(place).reset(multiplier);
        } else {
            this->data(place).reset(1);
        }
    }

    void merge(AggregateDataPtr __restrict place, ConstAggregateDataPtr rhs,
               Arena&) const override {
        this->data(place).merge(this->data(rhs), multiplier);
    }

    void serialize(ConstAggregateDataPtr __restrict place, BufferWritable& buf) const override {
        this->data(place).write(buf);
    }

    void deserialize(AggregateDataPtr __restrict place, BufferReadable& buf,
                     Arena&) const override {
        this->data(place).read(buf);
    }

    void insert_result_into(ConstAggregateDataPtr __restrict place, IColumn& to) const override {
        auto& column = assert_cast<ColVecResult&>(to);
        column.get_data().push_back(this->data(place).get());
    }

private:
    UInt32 scale;
    // Value-initialized so non-decimal instantiations never copy an indeterminate
    // value when passing it to Data::add()/merge().
    typename PrimitiveTypeTraits<TResult>::CppType multiplier {};
};
} // namespace doris
#include "common/compile_check_end.h"