blob: 8483f8fe0efa75823577c50e55a447870579ff56 [file] [log] [blame]
// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.
// This file is copied from
// https://github.com/ClickHouse/ClickHouse/blob/master/src/Core/DecimalComparison.h
// and modified by Doris
#pragma once
#include "vec/columns/column_const.h"
#include "vec/columns/column_vector.h"
#include "vec/common/arithmetic_overflow.h"
#include "vec/core/accurate_comparison.h"
#include "vec/core/block.h"
#include "vec/core/call_on_type_index.h"
#include "vec/core/types.h"
#include "vec/data_types/data_type_decimal.h"
#include "vec/functions/function_helpers.h" /// todo core should not depend on function"
namespace doris::vectorized {
inline bool allow_decimal_comparison(const DataTypePtr& left_type, const DataTypePtr& right_type) {
if (is_decimal(left_type->get_primitive_type())) {
if (is_decimal(right_type->get_primitive_type()) ||
is_int_or_bool(right_type->get_primitive_type()))
return true;
} else if (is_int_or_bool(left_type->get_primitive_type()) &&
is_decimal(right_type->get_primitive_type()))
return true;
return false;
}
template <size_t>
struct ConstructDecInt {
static constexpr PrimitiveType Type = TYPE_INT;
};
template <>
struct ConstructDecInt<8> {
static constexpr PrimitiveType Type = TYPE_BIGINT;
};
template <>
struct ConstructDecInt<16> {
static constexpr PrimitiveType Type = TYPE_LARGEINT;
};
template <>
struct ConstructDecInt<32> {
static constexpr PrimitiveType Type = TYPE_DECIMAL256;
};
template <PrimitiveType T, PrimitiveType U>
struct DecCompareInt {
static constexpr PrimitiveType Type =
ConstructDecInt < (!is_decimal(U) ||
sizeof(typename PrimitiveTypeTraits<T>::ColumnItemType) >
sizeof(typename PrimitiveTypeTraits<U>::ColumnItemType))
? sizeof(typename PrimitiveTypeTraits<T>::ColumnItemType)
: sizeof(typename PrimitiveTypeTraits<U>::ColumnItemType) > ::Type;
};
///
template <PrimitiveType A, PrimitiveType B, template <PrimitiveType> typename Operation,
bool _check_overflow = true, bool _actual = is_decimal(A) || is_decimal(B)>
class DecimalComparison {
public:
static constexpr PrimitiveType CompareIntPType = DecCompareInt<A, B>::Type;
using CompareInt = typename PrimitiveTypeTraits<CompareIntPType>::CppNativeType;
using Op = Operation<CompareIntPType>;
using ColVecA = typename PrimitiveTypeTraits<A>::ColumnType;
using ColVecB = typename PrimitiveTypeTraits<B>::ColumnType;
using ArrayA = typename ColVecA::Container;
using ArrayB = typename ColVecB::Container;
DecimalComparison(Block& block, uint32_t result, const ColumnWithTypeAndName& col_left,
const ColumnWithTypeAndName& col_right) {
if (!apply(block, result, col_left, col_right)) {
throw Exception(Status::FatalError("Wrong decimal comparison with {} and {}",
col_left.type->get_name(),
col_right.type->get_name()));
}
}
static bool apply(Block& block, uint32_t result [[maybe_unused]],
const ColumnWithTypeAndName& col_left,
const ColumnWithTypeAndName& col_right) {
if constexpr (_actual) {
ColumnPtr c_res;
Shift shift = getScales<A, B>(col_left.type, col_right.type);
c_res = apply_with_scale(col_left.column, col_right.column, shift);
if (c_res) {
block.replace_by_position(result, std::move(c_res));
}
return true;
}
return false;
}
static bool compare(typename PrimitiveTypeTraits<A>::ColumnItemType a,
typename PrimitiveTypeTraits<B>::ColumnItemType b, UInt32 scale_a,
UInt32 scale_b) {
static const UInt32 max_scale = max_decimal_precision<TYPE_DECIMAL256>();
if (scale_a > max_scale || scale_b > max_scale) {
throw Exception(Status::FatalError("Bad scale of decimal field"));
}
Shift shift;
if (scale_a < scale_b) {
shift.a = typename PrimitiveTypeTraits<B>::DataType(max_decimal_precision<B>(), scale_b)
.get_scale_multiplier(scale_b - scale_a);
}
if (scale_a > scale_b) {
shift.b = typename PrimitiveTypeTraits<A>::DataType(max_decimal_precision<A>(), scale_a)
.get_scale_multiplier(scale_a - scale_b);
}
return apply_with_scale(a, b, shift);
}
private:
struct Shift {
CompareInt a = 1;
CompareInt b = 1;
bool none() const { return a == 1 && b == 1; }
bool left() const { return a != 1; }
bool right() const { return b != 1; }
};
template <typename T, typename U>
static auto apply_with_scale(T a, U b, const Shift& shift) {
if (shift.left())
return apply<true, false>(a, b, shift.a);
else if (shift.right())
return apply<false, true>(a, b, shift.b);
return apply<false, false>(a, b, 1);
}
template <PrimitiveType T, PrimitiveType U>
requires(is_decimal(T) && is_decimal(U))
static Shift getScales(const DataTypePtr& left_type, const DataTypePtr& right_type) {
const typename PrimitiveTypeTraits<T>::DataType* decimal0 = check_decimal<T>(*left_type);
const typename PrimitiveTypeTraits<U>::DataType* decimal1 = check_decimal<U>(*right_type);
Shift shift;
if (decimal0 && decimal1) {
constexpr PrimitiveType Type =
sizeof(typename PrimitiveTypeTraits<T>::ColumnItemType) >=
sizeof(typename PrimitiveTypeTraits<U>::ColumnItemType)
? T
: U;
auto type_ptr = decimal_result_type(*decimal0, *decimal1, false, false, false);
const DataTypeDecimal<Type>* result_type = check_decimal<Type>(*type_ptr);
shift.a = result_type->scale_factor_for(*decimal0);
shift.b = result_type->scale_factor_for(*decimal1);
} else if (decimal0) {
shift.b = decimal0->get_scale_multiplier();
} else if (decimal1) {
shift.a = decimal1->get_scale_multiplier();
}
return shift;
}
template <PrimitiveType T, PrimitiveType U>
requires(is_decimal(T) && !is_decimal(U))
static Shift getScales(const DataTypePtr& left_type, const DataTypePtr&) {
Shift shift;
const typename PrimitiveTypeTraits<T>::DataTypeType* decimal0 =
check_decimal<T>(*left_type);
if (decimal0) {
shift.b = decimal0->get_scale_multiplier();
}
return shift;
}
template <PrimitiveType T, PrimitiveType U>
requires(!is_decimal(T) && is_decimal(U))
static Shift getScales(const DataTypePtr&, const DataTypePtr& right_type) {
Shift shift;
const typename PrimitiveTypeTraits<U>::DataType* decimal1 = check_decimal<U>(*right_type);
if (decimal1) {
shift.a = decimal1->get_scale_multiplier();
}
return shift;
}
template <bool scale_left, bool scale_right>
static ColumnPtr apply(const ColumnPtr& c0, const ColumnPtr& c1, CompareInt scale) {
if constexpr (_actual) {
bool c0_is_const = is_column_const(*c0);
bool c1_is_const = is_column_const(*c1);
if (c0_is_const && c1_is_const) {
const ColumnConst* c0_const = check_and_get_column_const<ColVecA>(c0.get());
const ColumnConst* c1_const = check_and_get_column_const<ColVecB>(c1.get());
typename PrimitiveTypeTraits<A>::ColumnItemType a = c0_const->template get_value<
typename PrimitiveTypeTraits<A>::ColumnItemType>();
typename PrimitiveTypeTraits<B>::ColumnItemType b = c1_const->template get_value<
typename PrimitiveTypeTraits<B>::ColumnItemType>();
UInt8 res = apply<scale_left, scale_right>(a, b, scale);
return DataTypeUInt8().create_column_const(c0->size(), to_field<TYPE_BOOLEAN>(res));
}
auto c_res = ColumnUInt8::create(c0->size());
ColumnUInt8::Container& vec_res = c_res->get_data();
if (c0_is_const) {
const ColumnConst* c0_const = check_and_get_column_const<ColVecA>(c0.get());
typename PrimitiveTypeTraits<A>::ColumnItemType a = c0_const->template get_value<
typename PrimitiveTypeTraits<A>::ColumnItemType>();
if (const ColVecB* c1_vec = check_and_get_column<ColVecB>(c1.get()))
constant_vector<scale_left, scale_right>(a, c1_vec->get_data(), vec_res, scale);
else {
throw Exception(Status::FatalError("Wrong column in Decimal comparison"));
}
} else if (c1_is_const) {
const ColumnConst* c1_const = check_and_get_column_const<ColVecB>(c1.get());
typename PrimitiveTypeTraits<B>::ColumnItemType b = c1_const->template get_value<
typename PrimitiveTypeTraits<B>::ColumnItemType>();
if (const ColVecA* c0_vec = check_and_get_column<ColVecA>(c0.get()))
vector_constant<scale_left, scale_right>(c0_vec->get_data(), b, vec_res, scale);
else {
throw Exception(Status::FatalError("Wrong column in Decimal comparison"));
}
} else {
if (const ColVecA* c0_vec = check_and_get_column<ColVecA>(c0.get())) {
if (const ColVecB* c1_vec = check_and_get_column<ColVecB>(c1.get()))
vector_vector<scale_left, scale_right>(c0_vec->get_data(),
c1_vec->get_data(), vec_res, scale);
else {
throw Exception(Status::FatalError("Wrong column in Decimal comparison"));
}
} else {
throw Exception(Status::FatalError("Wrong column in Decimal comparison"));
}
}
return c_res;
} else {
return ColumnUInt8::create();
}
}
template <bool scale_left, bool scale_right>
static UInt8 apply(typename PrimitiveTypeTraits<A>::ColumnItemType a,
typename PrimitiveTypeTraits<B>::ColumnItemType b,
CompareInt scale [[maybe_unused]]) {
CompareInt x = a;
CompareInt y = b;
if constexpr (_check_overflow) {
bool overflow = false;
if constexpr (sizeof(typename PrimitiveTypeTraits<A>::ColumnItemType) >
sizeof(CompareInt))
overflow |= (typename PrimitiveTypeTraits<A>::ColumnItemType(x) != a);
if constexpr (sizeof(typename PrimitiveTypeTraits<B>::ColumnItemType) >
sizeof(CompareInt))
overflow |= (typename PrimitiveTypeTraits<B>::ColumnItemType(y) != b);
if constexpr (IsUnsignedV<typename PrimitiveTypeTraits<A>::ColumnItemType>)
overflow |= (x < 0);
if constexpr (IsUnsignedV<typename PrimitiveTypeTraits<B>::ColumnItemType>)
overflow |= (y < 0);
if constexpr (scale_left) overflow |= common::mul_overflow(x, scale, x);
if constexpr (scale_right) overflow |= common::mul_overflow(y, scale, y);
if (overflow) {
throw Exception(Status::FatalError("Can't compare"));
}
} else {
if constexpr (scale_left) x *= scale;
if constexpr (scale_right) y *= scale;
}
return Op::apply(x, y);
}
template <bool scale_left, bool scale_right>
static void NO_INLINE vector_vector(const ArrayA& a, const ArrayB& b, PaddedPODArray<UInt8>& c,
CompareInt scale) {
size_t size = a.size();
const typename PrimitiveTypeTraits<A>::ColumnItemType* a_pos = a.data();
const typename PrimitiveTypeTraits<B>::ColumnItemType* b_pos = b.data();
UInt8* c_pos = c.data();
const typename PrimitiveTypeTraits<A>::ColumnItemType* a_end = a_pos + size;
while (a_pos < a_end) {
*c_pos = apply<scale_left, scale_right>(*a_pos, *b_pos, scale);
++a_pos;
++b_pos;
++c_pos;
}
}
template <bool scale_left, bool scale_right>
static void NO_INLINE vector_constant(const ArrayA& a,
typename PrimitiveTypeTraits<B>::ColumnItemType b,
PaddedPODArray<UInt8>& c, CompareInt scale) {
size_t size = a.size();
const typename PrimitiveTypeTraits<A>::ColumnItemType* a_pos = a.data();
UInt8* c_pos = c.data();
const typename PrimitiveTypeTraits<A>::ColumnItemType* a_end = a_pos + size;
while (a_pos < a_end) {
*c_pos = apply<scale_left, scale_right>(*a_pos, b, scale);
++a_pos;
++c_pos;
}
}
template <bool scale_left, bool scale_right>
static void NO_INLINE constant_vector(typename PrimitiveTypeTraits<A>::ColumnItemType a,
const ArrayB& b, PaddedPODArray<UInt8>& c,
CompareInt scale) {
size_t size = b.size();
const typename PrimitiveTypeTraits<B>::ColumnItemType* b_pos = b.data();
UInt8* c_pos = c.data();
const typename PrimitiveTypeTraits<B>::ColumnItemType* b_end = b_pos + size;
while (b_pos < b_end) {
*c_pos = apply<scale_left, scale_right>(a, *b_pos, scale);
++b_pos;
++c_pos;
}
}
};
} // namespace doris::vectorized