| /* |
| * Licensed to the Apache Software Foundation (ASF) under one or more |
| * contributor license agreements. See the NOTICE file distributed with |
| * this work for additional information regarding copyright ownership. |
| * The ASF licenses this file to You under the Apache License, Version 2.0 |
| * (the "License"); you may not use this file except in compliance with |
| * the License. You may obtain a copy of the License at |
| * |
| * http://www.apache.org/licenses/LICENSE-2.0 |
| * |
| * Unless required by applicable law or agreed to in writing, software |
| * distributed under the License is distributed on an "AS IS" BASIS, |
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| * See the License for the specific language governing permissions and |
| * limitations under the License. |
| */ |
| #pragma once |
| |
| #include <Functions/FunctionsRound.h> |
| |
/// Error codes thrown by this header. Each one used below is declared here
/// explicitly, so the header does not depend on whichever declarations the
/// included rounding header happens to provide.
namespace DB::ErrorCodes
{
extern const int NUMBER_OF_ARGUMENTS_DOESNT_MATCH;
extern const int ILLEGAL_TYPE_OF_ARGUMENT;
extern const int ILLEGAL_COLUMN;
extern const int ARGUMENT_OUT_OF_BOUND;
}
| |
| namespace local_engine |
| { |
| |
| template <typename T, DB::Vectorize vectorize> |
| class BaseFloatRoundingHalfUpComputation; |
| |
| #ifdef __SSE4_1__ |
| |
| /// vectorized implementation for x86 |
| |
| template <> |
| class BaseFloatRoundingHalfUpComputation<Float32, DB::Vectorize::Yes> |
| { |
| public: |
| using ScalarType = Float32; |
| using VectorType = __m128; |
| static const size_t data_count = 4; |
| |
| static VectorType load(const ScalarType * in) { return _mm_loadu_ps(in); } |
| static VectorType load1(const ScalarType in) { return _mm_load1_ps(&in); } |
| static void store(ScalarType * out, VectorType val) { _mm_storeu_ps(out, val); } |
| static VectorType multiply(VectorType val, VectorType scale) { return _mm_mul_ps(val, scale); } |
| static VectorType divide(VectorType val, VectorType scale) { return _mm_div_ps(val, scale); } |
| template <DB::RoundingMode mode> |
| static VectorType apply(VectorType val) |
| { |
| ScalarType tempFloatsIn[data_count]; |
| ScalarType tempFloatsOut[data_count]; |
| store(tempFloatsIn, val); |
| for (size_t i = 0; i < data_count; ++i) |
| tempFloatsOut[i] = std::roundf(tempFloatsIn[i]); |
| |
| return load(tempFloatsOut); |
| } |
| |
| static VectorType prepare(size_t scale) { return load1(scale); } |
| }; |
| |
| template <> |
| class BaseFloatRoundingHalfUpComputation<Float64, DB::Vectorize::Yes> |
| { |
| public: |
| using ScalarType = Float64; |
| using VectorType = __m128d; |
| static const size_t data_count = 2; |
| |
| static VectorType load(const ScalarType * in) { return _mm_loadu_pd(in); } |
| static VectorType load1(const ScalarType in) { return _mm_load1_pd(&in); } |
| static void store(ScalarType * out, VectorType val) { _mm_storeu_pd(out, val); } |
| static VectorType multiply(VectorType val, VectorType scale) { return _mm_mul_pd(val, scale); } |
| static VectorType divide(VectorType val, VectorType scale) { return _mm_div_pd(val, scale); } |
| template <DB::RoundingMode mode> |
| static VectorType apply(VectorType val) |
| { |
| ScalarType tempFloatsIn[data_count]; |
| ScalarType tempFloatsOut[data_count]; |
| store(tempFloatsIn, val); |
| for (size_t i = 0; i < data_count; ++i) |
| tempFloatsOut[i] = std::round(tempFloatsIn[i]); |
| |
| return load(tempFloatsOut); |
| } |
| |
| static VectorType prepare(size_t scale) { return load1(scale); } |
| }; |
| |
| /// end __SSE4_1__ |
| #endif |
| |
| /// Sequential implementation for ARM. Also used for scalar arguments |
| |
| template <typename T> |
| class BaseFloatRoundingHalfUpComputation<T, DB::Vectorize::No> |
| { |
| public: |
| using ScalarType = T; |
| using VectorType = T; |
| static const size_t data_count = 1; |
| |
| static VectorType load(const ScalarType * in) { return *in; } |
| static VectorType load1(const ScalarType in) { return in; } |
| static VectorType store(ScalarType * out, ScalarType val) { return *out = val;} |
| static VectorType multiply(VectorType val, VectorType scale) { return val * scale; } |
| static VectorType divide(VectorType val, VectorType scale) { return val / scale; } |
| template <DB::RoundingMode mode> |
| static VectorType apply(VectorType val) |
| { |
| if constexpr (std::is_same_v<ScalarType, Float32>) |
| { |
| return std::roundf(val); |
| } |
| else |
| { |
| return std::round(val); |
| } |
| } |
| |
| static VectorType prepare(size_t scale) |
| { |
| return load1(scale); |
| } |
| }; |
| |
| template <> |
| class BaseFloatRoundingHalfUpComputation<BFloat16, DB::Vectorize::No> |
| { |
| public: |
| using ScalarType = BFloat16; |
| using VectorType = BFloat16; |
| static const size_t data_count = 1; |
| |
| static VectorType load(const ScalarType * in) { return *in; } |
| static VectorType load1(const ScalarType in) { return in; } |
| static VectorType store(ScalarType * out, ScalarType val) { return *out = val;} |
| static VectorType multiply(VectorType val, VectorType scale) { return val * scale; } |
| static VectorType divide(VectorType val, VectorType scale) { return val / scale; } |
| template <DB::RoundingMode mode> |
| static VectorType apply(VectorType val) |
| { |
| return BFloat16(std::roundf(static_cast<Float32>(val))); |
| } |
| |
| static VectorType prepare(size_t scale) |
| { |
| return load1(BFloat16(static_cast<Float32>(scale))); |
| } |
| }; |
| |
| |
| /** Implementation of low-level round-off functions for floating-point values. |
| */ |
| template <typename T, DB::RoundingMode rounding_mode, DB::ScaleMode scale_mode, DB::Vectorize vectorize> |
| class FloatRoundingHalfUpComputation : public BaseFloatRoundingHalfUpComputation<T, vectorize> |
| { |
| using Base = BaseFloatRoundingHalfUpComputation<T, vectorize>; |
| |
| public: |
| static inline void compute(const T * __restrict in, const typename Base::VectorType & scale, T * __restrict out) |
| { |
| auto val = Base::load(in); |
| |
| if (scale_mode == DB::ScaleMode::Positive) |
| val = Base::multiply(val, scale); |
| else if (scale_mode == DB::ScaleMode::Negative) |
| val = Base::divide(val, scale); |
| |
| val = Base::template apply<rounding_mode>(val); |
| |
| if (scale_mode == DB::ScaleMode::Positive) |
| val = Base::divide(val, scale); |
| else if (scale_mode == DB::ScaleMode::Negative) |
| val = Base::multiply(val, scale); |
| |
| Base::store(out, val); |
| } |
| }; |
| |
| |
| /** Implementing high-level rounding functions. |
| */ |
| template <typename T, DB::RoundingMode rounding_mode, DB::ScaleMode scale_mode> |
| struct FloatRoundingHalfUpImpl |
| { |
| private: |
| static_assert(!DB::is_decimal<T>); |
| |
| template <DB::Vectorize vectorize = |
| #ifdef __SSE4_1__ |
| std::is_same_v<T, BFloat16> ? DB::Vectorize::No : DB::Vectorize::Yes |
| #else |
| DB::Vectorize::No |
| #endif |
| > |
| using Op = FloatRoundingHalfUpComputation<T, rounding_mode, scale_mode, vectorize>; |
| using Data = std::array<T, Op<>::data_count>; |
| using ColumnType = DB::ColumnVector<T>; |
| using Container = typename ColumnType::Container; |
| |
| public: |
| static NO_INLINE void apply(const Container & in, size_t scale, Container & out) |
| { |
| auto mm_scale = Op<>::prepare(scale); |
| |
| const size_t data_count = std::tuple_size<Data>(); |
| |
| const T * end_in = in.data() + in.size(); |
| const T * limit = in.data() + in.size() / data_count * data_count; |
| |
| const T * __restrict p_in = in.data(); |
| T * __restrict p_out = out.data(); |
| |
| while (p_in < limit) |
| { |
| Op<>::compute(p_in, mm_scale, p_out); |
| p_in += data_count; |
| p_out += data_count; |
| } |
| |
| if (p_in < end_in) |
| { |
| Data tmp_src{{}}; |
| Data tmp_dst; |
| |
| size_t tail_size_bytes = (end_in - p_in) * sizeof(*p_in); |
| |
| memcpy(&tmp_src, p_in, tail_size_bytes); |
| Op<>::compute(reinterpret_cast<T *>(&tmp_src), mm_scale, reinterpret_cast<T *>(&tmp_dst)); |
| memcpy(p_out, &tmp_dst, tail_size_bytes); |
| } |
| } |
| }; |
| |
| |
| /** Select the appropriate processing algorithm depending on the scale. |
| */ |
| template <typename T, DB::RoundingMode rounding_mode, DB::TieBreakingMode tie_breaking_mode> |
| struct DispatcherRoundingHalfUp |
| { |
| template <DB::ScaleMode scale_mode> |
| using FunctionRoundingImpl = std::conditional_t< |
| std::is_floating_point_v<T> || std::is_same_v<T, BFloat16>, |
| FloatRoundingHalfUpImpl<T, rounding_mode, scale_mode>, |
| DB::IntegerRoundingImpl<T, rounding_mode, scale_mode, tie_breaking_mode>>; |
| |
| static DB::ColumnPtr apply(const DB::IColumn * col_general, DB::Scale scale_arg) |
| { |
| const auto * const col = checkAndGetColumn<DB::ColumnVector<T>>(col_general); |
| auto col_res = DB::ColumnVector<T>::create(); |
| |
| typename DB::ColumnVector<T>::Container & vec_res = col_res->getData(); |
| vec_res.resize_exact(col->getData().size()); |
| |
| if (!vec_res.empty()) |
| { |
| if (scale_arg == 0) |
| { |
| size_t scale = 1; |
| FunctionRoundingImpl<DB::ScaleMode::Zero>::apply(col->getData(), scale, vec_res); |
| } |
| else if (scale_arg > 0) |
| { |
| size_t scale = intExp10(scale_arg); |
| FunctionRoundingImpl<DB::ScaleMode::Positive>::apply(col->getData(), scale, vec_res); |
| } |
| else |
| { |
| size_t scale = intExp10(-scale_arg); |
| FunctionRoundingImpl<DB::ScaleMode::Negative>::apply(col->getData(), scale, vec_res); |
| } |
| } |
| |
| return col_res; |
| } |
| }; |
| |
| template <DB::is_decimal T, DB::RoundingMode rounding_mode, DB::TieBreakingMode tie_breaking_mode> |
| struct DispatcherRoundingHalfUp<T, rounding_mode, tie_breaking_mode> |
| { |
| public: |
| static DB::ColumnPtr apply(const DB::IColumn * col_general, DB::Scale scale_arg) |
| { |
| const auto * const col = checkAndGetColumn<DB::ColumnDecimal<T>>(col_general); |
| const typename DB::ColumnDecimal<T>::Container & vec_src = col->getData(); |
| |
| auto col_res = DB::ColumnDecimal<T>::create(vec_src.size(), col->getScale()); |
| auto & vec_res = col_res->getData(); |
| |
| if (!vec_res.empty()) |
| DB::DecimalRoundingImpl<T, rounding_mode, tie_breaking_mode>::apply(col->getData(), col->getScale(), vec_res, scale_arg); |
| |
| return col_res; |
| } |
| }; |
| |
| /** A template for functions that round the value of an input parameter of type |
| * (U)Int8/16/32/64, Float32/64 or Decimal32/64/128, and accept an additional optional parameter (default is 0). |
| */ |
| template <typename Name, DB::RoundingMode rounding_mode, DB::TieBreakingMode tie_breaking_mode> |
| class FunctionRoundingHalfUp : public DB::IFunction |
| { |
| public: |
| static constexpr auto name = "roundHalfUp"; |
| static DB::FunctionPtr create(DB::ContextPtr) { return std::make_shared<FunctionRoundingHalfUp>(); } |
| |
| String getName() const override { return name; } |
| |
| bool isVariadic() const override { return true; } |
| size_t getNumberOfArguments() const override { return 0; } |
| bool isSuitableForShortCircuitArgumentsExecution(const DB::DataTypesWithConstInfo & /*arguments*/) const override { return false; } |
| |
| /// Get result types by argument types. If the function does not apply to these arguments, throw an exception. |
| DB::DataTypePtr getReturnTypeImpl(const DB::DataTypes & arguments) const override |
| { |
| if ((arguments.empty()) || (arguments.size() > 2)) |
| throw DB::Exception( |
| DB::ErrorCodes::NUMBER_OF_ARGUMENTS_DOESNT_MATCH, |
| "Number of arguments for function {} doesn't match: passed {}, should be 1 or 2.", |
| getName(), |
| arguments.size()); |
| |
| for (const auto & type : arguments) |
| if (!isNumber(type)) |
| throw DB::Exception( |
| DB::ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT, "Illegal type {} of argument of function {}", arguments[0]->getName(), getName()); |
| |
| return arguments[0]; |
| } |
| |
| static DB::Scale getScaleArg(const DB::ColumnsWithTypeAndName & arguments) |
| { |
| if (arguments.size() == 2) |
| { |
| const DB::IColumn & scale_column = *arguments[1].column; |
| if (!isColumnConst(scale_column)) |
| throw DB::Exception(DB::ErrorCodes::ILLEGAL_COLUMN, "DB::Scale argument for rounding functions must be constant"); |
| |
| DB::Field scale_field = assert_cast<const DB::ColumnConst &>(scale_column).getField(); |
| if (scale_field.getType() != DB::Field::Types::UInt64 && scale_field.getType() != DB::Field::Types::Int64) |
| throw DB::Exception(DB::ErrorCodes::ILLEGAL_COLUMN, "DB::Scale argument for rounding functions must have integer type"); |
| |
| Int64 scale64 = scale_field.safeGet<Int64>(); |
| if (scale64 > std::numeric_limits<DB::Scale>::max() || scale64 < std::numeric_limits<DB::Scale>::min()) |
| throw DB::Exception(DB::ErrorCodes::ARGUMENT_OUT_OF_BOUND, "DB::Scale argument for rounding function is too large"); |
| |
| return scale64; |
| } |
| return 0; |
| } |
| |
| bool useDefaultImplementationForConstants() const override { return true; } |
| DB::ColumnNumbers getArgumentsThatAreAlwaysConstant() const override { return {1}; } |
| |
| DB::ColumnPtr executeImpl(const DB::ColumnsWithTypeAndName & arguments, const DB::DataTypePtr &, size_t /*input_rows_count*/) const override |
| { |
| const DB::ColumnWithTypeAndName & column = arguments[0]; |
| DB::Scale scale_arg = getScaleArg(arguments); |
| |
| DB::ColumnPtr res; |
| auto call = [&](const auto & types) -> bool |
| { |
| using Types = std::decay_t<decltype(types)>; |
| using DataType = typename Types::LeftType; |
| |
| if constexpr (DB::IsDataTypeNumber<DataType> || DB::IsDataTypeDecimal<DataType>) |
| { |
| using FieldType = typename DataType::FieldType; |
| res = DispatcherRoundingHalfUp<FieldType, rounding_mode, tie_breaking_mode>::apply(column.column.get(), scale_arg); |
| return true; |
| } |
| return false; |
| }; |
| |
| if (!callOnIndexAndDataType<void>(column.type->getTypeId(), call)) |
| throw DB::Exception(DB::ErrorCodes::ILLEGAL_COLUMN, "Illegal column {} of argument of function {}", column.name, getName()); |
| |
| return res; |
| } |
| |
| bool hasInformationAboutMonotonicity() const override { return true; } |
| |
| Monotonicity getMonotonicityForRange(const DB::IDataType &, const DB::Field &, const DB::Field &) const override |
| { |
| return {.is_monotonic = true, .is_always_monotonic = true}; |
| } |
| }; |
| |
| |
/// Name policy for the roundHalfUp function.
struct NameRoundHalfUp
{
    static constexpr const char * name = "roundHalfUp";
};
| |
| using FunctionRoundHalfUp = FunctionRoundingHalfUp<NameRoundHalfUp, DB::RoundingMode::Round, DB::TieBreakingMode::Auto>; |
| |
| } |