blob: a709f332b27e9cc8704e1065ff6a51b41e779d2a [file] [log] [blame]
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "SparkFunctionCheckDecimalOverflow.h"
#include <Columns/ColumnDecimal.h>
#include <Columns/ColumnNullable.h>
#include <Columns/ColumnsNumber.h>
#include <Core/DecimalFunctions.h>
#include <Core/callOnTypeIndex.h>
#include <DataTypes/DataTypeNullable.h>
#include <DataTypes/DataTypesDecimal.h>
#include <Functions/FunctionFactory.h>
#include <Functions/FunctionHelpers.h>
#include <Functions/IFunction.h>
#include "Columns/ColumnsCommon.h"
#include <iostream>
namespace DB
{
/// Error codes referenced by this translation unit. They are only declared here
/// (`extern`); the definitions live elsewhere in the code base.
namespace ErrorCodes
{
extern const int NUMBER_OF_ARGUMENTS_DOESNT_MATCH;
extern const int ILLEGAL_TYPE_OF_ARGUMENT;
extern const int ILLEGAL_COLUMN;
extern const int TYPE_MISMATCH;
extern const int NOT_IMPLEMENTED;
/// Used by the throwing flavour of the function when a value does not fit the target precision.
/// Declared explicitly so this file does not depend on another header happening to declare it.
extern const int DECIMAL_OVERFLOW;
}
}
namespace local_engine
{
using namespace DB;
namespace
{
/// Name tag carrying the SQL-visible name of the function flavour that throws on overflow.
struct CheckDecimalOverflowSpark
{
    static constexpr const char * name = "checkDecimalOverflowSpark";
};
/// Name tag carrying the SQL-visible name of the function flavour that yields NULL on overflow.
struct CheckDecimalOverflowSparkOrNull
{
    static constexpr const char * name = "checkDecimalOverflowSparkOrNull";
};
/// Selects how a value that does not fit the requested precision is reported.
enum class CheckExceptionMode : uint8_t
{
    Throw = 0, /// Raise an exception when a value overflows.
    Null = 1, /// Return a ColumnNullable with NULL in place of every overflowing value.
};
/// Relation between the target scale and the source scale of a decimal-to-decimal conversion.
enum class ScaleDirection : int8_t
{
    Down = -1, /// Target scale is smaller than the source scale.
    None = 0, /// Both scales are equal.
    Up = 1, /// Target scale is larger than the source scale.
};
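/// Converts a decimal or native numeric value to Decimal(precision, scale) and checks that it fits the target precision.
/// Takes exactly three arguments: the value, the target precision and the target scale; the last two must be constant integers.
/// When a value does not fit, the function either throws or returns NULL for it, depending on `mode`.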
template <typename Name, CheckExceptionMode mode>
class FunctionCheckDecimalOverflow : public IFunction
{
public:
/// SQL-visible function name, taken from the `Name` tag type.
static constexpr auto name = Name::name;
/// Overflow reporting policy this instantiation was built with.
static constexpr auto exception_mode = mode;

/// Factory entry point; the context argument is not used.
static FunctionPtr create(ContextPtr /*context*/)
{
    return std::make_shared<FunctionCheckDecimalOverflow>();
}

String getName() const override
{
    return name;
}

/// The function is always called with three arguments: value, precision, scale.
size_t getNumberOfArguments() const override
{
    return 3;
}

bool isSuitableForShortCircuitArgumentsExecution(const DataTypesWithConstInfo & /*arguments*/) const override
{
    return false;
}

bool useDefaultImplementationForConstants() const override
{
    return true;
}

/// Precision (index 1) and scale (index 2) must be constants.
ColumnNumbers getArgumentsThatAreAlwaysConstant() const override
{
    return ColumnNumbers{1, 2};
}
/// Validates the argument types and derives the result type.
/// The first argument must be a decimal or a native number, the second and third must be integers.
/// The result is Decimal(precision, scale) built from the second and third arguments, wrapped in
/// Nullable when isReturnTypeNullable() says so.
/// Throws ILLEGAL_TYPE_OF_ARGUMENT when the argument types do not match.
DataTypePtr getReturnTypeImpl(const ColumnsWithTypeAndName & arguments) const override
{
    const auto & value_type = arguments[0].type;
    const auto & precision_type = arguments[1].type;
    const auto & scale_type = arguments[2].type;

    const bool value_supported = isDecimal(value_type) || isNativeNumber(value_type);
    const bool params_are_integers = isInteger(precision_type) && isInteger(scale_type);
    if (!(value_supported && params_are_integers))
    {
        throw Exception(
            ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT,
            "Illegal type {} {} {} of argument of function {}",
            value_type->getName(),
            precision_type->getName(),
            scale_type->getName(),
            getName());
    }

    const UInt32 target_precision = extractArgument(arguments[1]);
    const UInt32 target_scale = extractArgument(arguments[2]);
    auto decimal_type = createDecimal<DataTypeDecimal>(target_precision, target_scale);

    if (!isReturnTypeNullable(arguments[0]))
        return decimal_type;
    return std::make_shared<DataTypeNullable>(decimal_type);
}
/// Tells whether the result must be Nullable: always in Null mode, otherwise only when
/// the type of `arg` is itself nullable.
bool isReturnTypeNullable(const ColumnWithTypeAndName & arg) const
{
    if constexpr (exception_mode == CheckExceptionMode::Null)
        return true;
    else
        return arg.type->isNullable();
}
/// Converts the first argument to Decimal(precision, scale) taken from the second and third arguments.
/// The target decimal width is chosen from the requested precision, then the call is dispatched
/// on the source type id. Throws ILLEGAL_COLUMN when no conversion produced a result column.
ColumnPtr executeImpl(const ColumnsWithTypeAndName & arguments, const DataTypePtr & /*result_type*/, size_t input_rows_count) const override
{
    const UInt32 to_precision = extractArgument(arguments[1]);
    const UInt32 to_scale = extractArgument(arguments[2]);
    const auto & source = arguments[0];
    ColumnPtr result;

    /// Invoked by callOnIndexAndDataType with a (source type, target type) pair.
    auto convert = [&](const auto & type_pair) -> bool
    {
        using TypePair = std::decay_t<decltype(type_pair)>;
        using FromDataType = typename TypePair::LeftType;
        using ToDataType = typename TypePair::RightType;

        if constexpr (IsDataTypeDecimal<FromDataType> || IsDataTypeNumber<FromDataType>)
        {
            using FromFieldType = typename FromDataType::FieldType;
            using SourceColumn = ColumnVectorOrDecimal<FromFieldType>;

            /// Fast path: a decimal that already has the requested precision and scale is passed through.
            if constexpr (IsDataTypeDecimal<FromDataType>)
            {
                const auto from_precision = getDecimalPrecision(*source.type);
                const auto from_scale = getDecimalScale(*source.type);
                const bool same_shape = (from_precision == to_precision) && (from_scale == to_scale);
                if (same_shape)
                {
                    if (!isReturnTypeNullable(arguments[0]))
                        result = source.column;
                    else
                        result = makeNullable(source.column);
                    return true;
                }
            }

            const SourceColumn * typed_column = checkAndGetColumn<SourceColumn>(source.column.get());
            if (typed_column)
            {
                executeInternal<FromDataType, ToDataType>(*typed_column, result, input_rows_count, to_precision, to_scale);
                return true;
            }
        }

        throw Exception(ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT, "Illegal column while execute function {}", getName());
    };

    const auto source_type_id = source.type->getTypeId();
    if (to_precision <= DecimalUtils::max_precision<Decimal32>)
    {
        callOnIndexAndDataType<DataTypeDecimal<Decimal32>>(source_type_id, convert);
    }
    else if (to_precision <= DecimalUtils::max_precision<Decimal64>)
    {
        callOnIndexAndDataType<DataTypeDecimal<Decimal64>>(source_type_id, convert);
    }
    else if (to_precision <= DecimalUtils::max_precision<Decimal128>)
    {
        callOnIndexAndDataType<DataTypeDecimal<Decimal128>>(source_type_id, convert);
    }
    else
    {
        callOnIndexAndDataType<DataTypeDecimal<Decimal256>>(source_type_id, convert);
    }

    if (result)
        return result;
    throw Exception(ErrorCodes::ILLEGAL_COLUMN, "Wrong call for {} with {}", getName(), source.type->getName());
}
private:
/// Converts every value of `src_col` to the target decimal type and stores the outcome in `dst_col`.
///
/// @param src_col       source column (decimal or plain number).
/// @param dst_col       receives the result: a plain decimal column in Throw mode,
///                      a ColumnNullable (values + null map) in Null mode.
/// @param rows          number of rows to convert.
/// @param to_precision  target precision; values must lie strictly between -10^precision and 10^precision.
/// @param to_scale      target scale.
///
/// In Throw mode an Exception with code DECIMAL_OVERFLOW is raised if any row did not fit.
template <typename FromDataType, typename ToDataType>
requires(IsDataTypeDecimal<ToDataType> && (IsDataTypeDecimal<FromDataType> || IsDataTypeNumber<FromDataType>))
static void
executeInternal(const FromDataType::ColumnType & src_col, ColumnPtr & dst_col, size_t rows, UInt32 to_precision, UInt32 to_scale)
{
    using ToFieldType = typename ToDataType::FieldType;
    using ToColumnType = typename ToDataType::ColumnType;
    using FromFieldType = typename FromDataType::FieldType;
    /// For decimal sources the arithmetic is done in the larger of the two decimal representations,
    /// for plain numbers in the target's representation.
    using MaxFieldType = std::conditional_t<
        is_decimal<FromFieldType>,
        std::conditional_t<(sizeof(FromFieldType) > sizeof(ToFieldType)), FromFieldType, ToFieldType>,
        ToFieldType>;
    using MaxNativeType = typename MaxFieldType::NativeType;

    /// Loop-invariant bound, computed once outside the per-row loops.
    const MaxNativeType pow10_to_precision = DecimalUtils::scaleMultiplier<MaxNativeType>(to_precision);

    const auto & src_data = src_col.getData();
    auto res_data_col = ToColumnType::create(rows, to_scale);
    auto & res_data = res_data_col->getData();
    /// One flag per row; a non-zero entry marks a value that did not fit.
    auto res_nullmap_col = ColumnUInt8::create(rows, 0);
    auto & res_nullmap_data = res_nullmap_col->getData();

    if constexpr (IsDataTypeDecimal<FromDataType>)
    {
        /// The scale direction is decided once; each branch runs its own loop with the
        /// direction fixed at compile time.
        const UInt32 from_scale = src_col.getScale();
        if (to_scale > from_scale)
        {
            const MaxNativeType scale_multiplier = DecimalUtils::scaleMultiplier<MaxNativeType>(to_scale - from_scale);
            for (size_t i = 0; i < rows; ++i)
                res_nullmap_data[i] = !convertDecimalToDecimalImpl<ScaleDirection::Up, FromDataType, ToDataType, MaxNativeType>(
                    src_data[i], scale_multiplier, pow10_to_precision, res_data[i]);
        }
        else if (to_scale < from_scale)
        {
            const MaxNativeType scale_multiplier = DecimalUtils::scaleMultiplier<MaxNativeType>(from_scale - to_scale);
            for (size_t i = 0; i < rows; ++i)
                res_nullmap_data[i] = !convertDecimalToDecimalImpl<ScaleDirection::Down, FromDataType, ToDataType, MaxNativeType>(
                    src_data[i], scale_multiplier, pow10_to_precision, res_data[i]);
        }
        else
        {
            const MaxNativeType scale_multiplier = 1;
            for (size_t i = 0; i < rows; ++i)
                res_nullmap_data[i] = !convertDecimalToDecimalImpl<ScaleDirection::None, FromDataType, ToDataType, MaxNativeType>(
                    src_data[i], scale_multiplier, pow10_to_precision, res_data[i]);
        }
    }
    else
    {
        /// Plain numbers have an implicit scale of zero, so they are always scaled up by 10^to_scale.
        const MaxNativeType scale_multiplier = DecimalUtils::scaleMultiplier<MaxNativeType>(to_scale);
        for (size_t i = 0; i < rows; ++i)
            res_nullmap_data[i]
                = !convertNumberToDecimalImpl<FromDataType, ToDataType>(src_data[i], scale_multiplier, pow10_to_precision, res_data[i]);
    }

    if constexpr (exception_mode == CheckExceptionMode::Throw)
    {
        /// Any set flag means at least one row overflowed.
        if (!memoryIsZero(res_nullmap_data.data(), 0, rows))
            throw Exception(ErrorCodes::DECIMAL_OVERFLOW, "Decimal value is overflow.");
        dst_col = std::move(res_data_col);
    }
    else
        dst_col = ColumnNullable::create(std::move(res_data_col), std::move(res_nullmap_col));
}
/// Converts one plain number to the target decimal representation.
///
/// @param from                source value (floating point or integer).
/// @param scale_multiplier    10^target_scale.
/// @param pow10_to_precision  10^target_precision; the scaled value must lie strictly between
///                            -pow10_to_precision and pow10_to_precision.
/// @param to                  receives the scaled value, or zero when the value does not fit.
/// @return true when the value fits, false on overflow (or a non-finite floating point input).
template <typename FromDataType, typename ToDataType>
requires(IsDataTypeNumber<FromDataType> && IsDataTypeDecimal<ToDataType>)
static ALWAYS_INLINE bool convertNumberToDecimalImpl(
    const typename FromDataType::FieldType & from,
    const typename ToDataType::FieldType::NativeType & scale_multiplier,
    const typename ToDataType::FieldType::NativeType & pow10_to_precision,
    typename ToDataType::FieldType & to)
{
    using FromFieldType = typename FromDataType::FieldType;
    using ToNativeType = typename ToDataType::FieldType::NativeType;
    bool ok = false;
    if constexpr (std::is_floating_point_v<FromFieldType>)
    {
        /// float to decimal
        auto converted = from * static_cast<FromFieldType>(scale_multiplier);
        auto float_pow10_to_precision = static_cast<FromFieldType>(pow10_to_precision);
        ok = isFinite(from) && converted < float_pow10_to_precision && converted > -float_pow10_to_precision;
        to = ok ? static_cast<ToNativeType>(converted) : static_cast<ToNativeType>(0);
    }
    else if constexpr (std::is_unsigned_v<FromFieldType> && sizeof(FromFieldType) >= sizeof(ToNativeType))
    {
        /// Unsigned integer that is at least as wide as the target's native type.
        /// Casting it straight to the signed native type could wrap a large value into a small or
        /// negative one that passes the range check, and comparing an unsigned intermediate against
        /// -pow10_to_precision would happen in the unsigned domain. So first reject everything that
        /// is not below 10^precision, compared in the source's own unsigned type (the bound is
        /// positive, so the cast keeps its value). What remains fits the target native type and is
        /// non-negative, so only the upper bound needs checking after scaling.
        ToNativeType converted = 0;
        ok = from < static_cast<FromFieldType>(pow10_to_precision)
            && !common::mulOverflow(static_cast<ToNativeType>(from), scale_multiplier, converted)
            && converted < pow10_to_precision;
        to = ok ? converted : static_cast<ToNativeType>(0);
    }
    else
    {
        /// Signed integer, or an unsigned integer narrower than the target's native type:
        /// the value is representable in MaxNativeType, so the cast below cannot change it.
        using MaxNativeType = std::conditional_t<(sizeof(FromFieldType) > sizeof(ToNativeType)), FromFieldType, ToNativeType>;
        MaxNativeType converted = 0;
        ok = !common::mulOverflow(static_cast<MaxNativeType>(from), static_cast<MaxNativeType>(scale_multiplier), converted)
            && converted < pow10_to_precision && converted > -pow10_to_precision;
        to = ok ? static_cast<ToNativeType>(converted) : static_cast<ToNativeType>(0);
    }
    return ok;
}
/// Rescales one decimal value to the target scale and checks it against the target precision.
///
/// @tparam scale_direction    whether the value is multiplied (Up), divided (Down) or kept (None).
/// @param from                source decimal value.
/// @param scale_multiplier    power of ten that separates the two scales.
/// @param pow10_to_precision  10^target_precision; the rescaled value must lie strictly between
///                            -pow10_to_precision and pow10_to_precision.
/// @param to                  receives the rescaled value, or zero when it does not fit.
/// @return true when the rescaled value fits, false otherwise.
template <ScaleDirection scale_direction, typename FromDataType, typename ToDataType, typename MaxNativeType>
requires(IsDataTypeDecimal<FromDataType> && IsDataTypeDecimal<ToDataType>)
static ALWAYS_INLINE bool convertDecimalToDecimalImpl(
    const typename FromDataType::FieldType & from,
    const MaxNativeType & scale_multiplier,
    const MaxNativeType & pow10_to_precision,
    typename ToDataType::FieldType & to)
{
    using ToNativeType = typename ToDataType::FieldType::NativeType;

    MaxNativeType rescaled;
    /// Only the multiplication can fail before the range check; the other directions always produce a value.
    bool representable = true;
    if constexpr (scale_direction == ScaleDirection::Up)
        representable = !common::mulOverflow(static_cast<MaxNativeType>(from.value), scale_multiplier, rescaled);
    else if constexpr (scale_direction == ScaleDirection::None)
        rescaled = from.value;
    else
        rescaled = from.value / scale_multiplier;

    /// `rescaled` is read only when the step above succeeded (short-circuit evaluation).
    const bool fits = representable && rescaled < pow10_to_precision && rescaled > -pow10_to_precision;
    if (fits)
        to = static_cast<ToNativeType>(rescaled);
    else
        to = static_cast<ToNativeType>(0);
    return fits;
}
};
/// Flavour named "checkDecimalOverflowSpark": throws when a value does not fit the target precision.
using FunctionCheckDecimalOverflowThrow = FunctionCheckDecimalOverflow<CheckDecimalOverflowSpark, CheckExceptionMode::Throw>;
/// Flavour named "checkDecimalOverflowSparkOrNull": returns NULL for a value that does not fit.
using FunctionCheckDecimalOverflowOrNull = FunctionCheckDecimalOverflow<CheckDecimalOverflowSparkOrNull, CheckExceptionMode::Null>;
}
/// Registers both flavours of the overflow check with the function factory,
/// each with a short description attached as its documentation.
REGISTER_FUNCTION(CheckDecimalOverflowSpark)
{
/// Throwing flavour.
factory.registerFunction<FunctionCheckDecimalOverflowThrow>(FunctionDocumentation{.description = R"(
Check decimal precision is overflow. If overflow throws exception.
)"});
/// NULL-returning flavour.
factory.registerFunction<FunctionCheckDecimalOverflowOrNull>(FunctionDocumentation{.description = R"(
Check decimal precision is overflow. If overflow return `NULL`.
)"});
}
}