| // Licensed to the Apache Software Foundation (ASF) under one |
| // or more contributor license agreements. See the NOTICE file |
| // distributed with this work for additional information |
| // regarding copyright ownership. The ASF licenses this file |
| // to you under the Apache License, Version 2.0 (the |
| // "License"); you may not use this file except in compliance |
| // with the License. You may obtain a copy of the License at |
| // |
| // http://www.apache.org/licenses/LICENSE-2.0 |
| // |
| // Unless required by applicable law or agreed to in writing, |
| // software distributed under the License is distributed on an |
| // "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY |
| // KIND, either express or implied. See the License for the |
| // specific language governing permissions and limitations |
| // under the License. |
| |
| #pragma once |
| |
| #include <cstddef> |
| #include <type_traits> |
| |
| #include "core/arena.h" |
| #include "core/column/column_decimal.h" |
| #include "core/column/column_vector.h" |
| #include "core/data_type/data_type_decimal.h" |
| #include "core/data_type/data_type_number.h" |
| #include "core/string_buffer.hpp" |
| #include "core/types.h" |
| #include "exprs/aggregate/aggregate_function.h" |
| #include "util/io_helper.h" |
| |
| namespace doris { |
| #include "common/compile_check_begin.h" |
| |
| template <PrimitiveType T> |
| struct AggregateFunctionProductData { |
| typename PrimitiveTypeTraits<T>::CppType product {}; |
| |
| void add_impl(typename PrimitiveTypeTraits<T>::CppType value, |
| typename PrimitiveTypeTraits<T>::CppType& product_ref) { |
| if constexpr (std::is_integral_v<typename PrimitiveTypeTraits<T>::CppType>) { |
| typename PrimitiveTypeTraits<T>::CppType new_product; |
| if (__builtin_expect(common::mul_overflow(product_ref, value, new_product), false)) { |
| // if overflow, set product to infinity to keep the same behavior with double type |
| throw Exception(ErrorCode::INTERNAL_ERROR, |
| "Product overflow for type {} and value {} * {}", T, value, |
| product_ref); |
| } else { |
| product_ref = new_product; |
| } |
| } else { |
| // which type is float or double |
| product_ref *= value; |
| } |
| } |
| |
| void add(typename PrimitiveTypeTraits<T>::CppType value, |
| typename PrimitiveTypeTraits<T>::CppType) { |
| add_impl(value, product); |
| VLOG_DEBUG << "product: " << product; |
| } |
| |
| void merge(const AggregateFunctionProductData& other, |
| typename PrimitiveTypeTraits<T>::CppType) { |
| add_impl(other.product, product); |
| VLOG_DEBUG << "product: " << product; |
| } |
| |
| void write(BufferWritable& buffer) const { buffer.write_binary(product); } |
| |
| void read(BufferReadable& buffer) { buffer.read_binary(product); } |
| |
| typename PrimitiveTypeTraits<T>::CppType get() const { return product; } |
| |
| void reset(typename PrimitiveTypeTraits<T>::CppType value) { product = std::move(value); } |
| }; |
| |
| template <> |
| struct AggregateFunctionProductData<TYPE_DECIMALV2> { |
| Decimal128V2 product {}; |
| |
| void add(Decimal128V2 value, Decimal128V2) { |
| DecimalV2Value decimal_product(product); |
| DecimalV2Value decimal_value(value); |
| DecimalV2Value ret = decimal_product * decimal_value; |
| memcpy(&product, &ret, sizeof(Decimal128V2)); |
| } |
| |
| void merge(const AggregateFunctionProductData& other, Decimal128V2) { |
| DecimalV2Value decimal_product(product); |
| DecimalV2Value decimal_value(other.product); |
| DecimalV2Value ret = decimal_product * decimal_value; |
| memcpy(&product, &ret, sizeof(Decimal128V2)); |
| } |
| |
| void write(BufferWritable& buffer) const { buffer.write_binary(product); } |
| |
| void read(BufferReadable& buffer) { buffer.read_binary(product); } |
| |
| Decimal128V2 get() const { return product; } |
| |
| void reset(Decimal128V2 value) { product = std::move(value); } |
| }; |
| |
| template <PrimitiveType T> |
| requires(T == TYPE_DECIMAL128I || T == TYPE_DECIMAL256) |
| struct AggregateFunctionProductData<T> { |
| typename PrimitiveTypeTraits<T>::CppType product {}; |
| |
| template <typename NestedType> |
| void add(Decimal<NestedType> value, typename PrimitiveTypeTraits<T>::CppType multiplier) { |
| product *= value; |
| product /= multiplier; |
| } |
| |
| void merge(const AggregateFunctionProductData& other, |
| typename PrimitiveTypeTraits<T>::CppType multiplier) { |
| product *= other.product; |
| product /= multiplier; |
| } |
| |
| void write(BufferWritable& buffer) const { buffer.write_binary(product); } |
| |
| void read(BufferReadable& buffer) { buffer.read_binary(product); } |
| |
| typename PrimitiveTypeTraits<T>::CppType get() const { return product; } |
| |
| void reset(typename PrimitiveTypeTraits<T>::CppType value) { product = std::move(value); } |
| }; |
| |
| template <PrimitiveType T, PrimitiveType TResult, typename Data> |
| class AggregateFunctionProduct; |
| |
| template <PrimitiveType T, PrimitiveType TResult> |
| constexpr static bool is_valid_product_types = |
| (is_same_or_wider_decimalv3(T, TResult) || (is_decimalv2(T) && is_decimalv2(TResult)) || |
| (is_float_or_double(T) && is_float_or_double(TResult)) || |
| (is_int_or_bool(T) && is_int(TResult))); |
| template <PrimitiveType T, PrimitiveType TResult, typename Data> |
| requires(is_valid_product_types<T, TResult>) |
| class AggregateFunctionProduct<T, TResult, Data> final |
| : public IAggregateFunctionDataHelper<Data, AggregateFunctionProduct<T, TResult, Data>>, |
| UnaryExpression, |
| NullableAggregateFunction { |
| public: |
| using ResultDataType = typename PrimitiveTypeTraits<TResult>::DataType; |
| using ColVecType = typename PrimitiveTypeTraits<T>::ColumnType; |
| using ColVecResult = typename PrimitiveTypeTraits<TResult>::ColumnType; |
| |
| std::string get_name() const override { return "product"; } |
| |
| AggregateFunctionProduct(const DataTypes& argument_types_) |
| : IAggregateFunctionDataHelper<Data, AggregateFunctionProduct<T, TResult, Data>>( |
| argument_types_), |
| scale(get_decimal_scale(*argument_types_[0])) { |
| if constexpr (is_decimal(T)) { |
| multiplier.value = |
| ResultDataType::get_scale_multiplier(get_decimal_scale(*argument_types_[0])); |
| } |
| } |
| |
| DataTypePtr get_return_type() const override { |
| if constexpr (is_decimal(T)) { |
| return std::make_shared<ResultDataType>(ResultDataType::max_precision(), scale); |
| } else { |
| return std::make_shared<ResultDataType>(); |
| } |
| } |
| |
| void add(AggregateDataPtr __restrict place, const IColumn** columns, ssize_t row_num, |
| Arena&) const override { |
| const auto& column = |
| assert_cast<const ColVecType&, TypeCheckOnRelease::DISABLE>(*columns[0]); |
| this->data(place).add( |
| typename PrimitiveTypeTraits<TResult>::CppType(column.get_data()[row_num]), |
| multiplier); |
| } |
| |
| void reset(AggregateDataPtr place) const override { |
| if constexpr (is_decimal(T)) { |
| this->data(place).reset(multiplier); |
| } else { |
| this->data(place).reset(1); |
| } |
| } |
| |
| void merge(AggregateDataPtr __restrict place, ConstAggregateDataPtr rhs, |
| Arena&) const override { |
| this->data(place).merge(this->data(rhs), multiplier); |
| } |
| |
| void serialize(ConstAggregateDataPtr __restrict place, BufferWritable& buf) const override { |
| this->data(place).write(buf); |
| } |
| |
| void deserialize(AggregateDataPtr __restrict place, BufferReadable& buf, |
| Arena&) const override { |
| this->data(place).read(buf); |
| } |
| |
| void insert_result_into(ConstAggregateDataPtr __restrict place, IColumn& to) const override { |
| auto& column = assert_cast<ColVecResult&>(to); |
| column.get_data().push_back(this->data(place).get()); |
| } |
| |
| private: |
| UInt32 scale; |
| typename PrimitiveTypeTraits<TResult>::CppType multiplier; |
| }; |
| |
| } // namespace doris |
| |
| #include "common/compile_check_end.h" |