| /* |
| * Licensed to the Apache Software Foundation (ASF) under one or more |
| * contributor license agreements. See the NOTICE file distributed with |
| * this work for additional information regarding copyright ownership. |
| * The ASF licenses this file to You under the Apache License, Version 2.0 |
| * (the "License"); you may not use this file except in compliance with |
| * the License. You may obtain a copy of the License at |
| * |
| * http://www.apache.org/licenses/LICENSE-2.0 |
| * |
| * Unless required by applicable law or agreed to in writing, software |
| * distributed under the License is distributed on an "AS IS" BASIS, |
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| * See the License for the specific language governing permissions and |
| * limitations under the License. |
| */ |
| #include <Columns/ColumnNullable.h> |
| #include <Core/DecimalFunctions.h> |
| #include <Core/callOnTypeIndex.h> |
| #include <DataTypes/DataTypeNullable.h> |
| #include <DataTypes/DataTypesDecimal.h> |
| #include <Functions/FunctionFactory.h> |
| #include <Functions/FunctionHelpers.h> |
| #include <Functions/IFunction.h> |
| #include <Functions/SparkFunctionCheckDecimalOverflow.h> |
| |
| namespace DB |
| { |
| namespace ErrorCodes |
| { |
| extern const int DECIMAL_OVERFLOW; |
| extern const int ILLEGAL_TYPE_OF_ARGUMENT; |
| extern const int ILLEGAL_COLUMN; |
| extern const int NUMBER_OF_ARGUMENTS_DOESNT_MATCH; |
| extern const int TYPE_MISMATCH; |
| } |
| } |
| |
| |
| namespace local_engine |
| { |
| using namespace DB; |
| |
| struct NameMakeDecimal |
| { |
| static constexpr auto name = "makeDecimalSpark"; |
| }; |
| struct NameMakeDecimalOrNull |
| { |
| static constexpr auto name = "makeDecimalSparkOrNull"; |
| }; |
| |
| enum class ConvertExceptionMode |
| { |
| Throw, /// Throw exception if value cannot be parsed. |
| Null /// Return ColumnNullable with NULLs when value cannot be parsed. |
| }; |
| |
| namespace |
| { |
| /// Create decimal with nested value, precision and scale. Required 3 arguments. |
| /// If overflow, throw exceptions by default. Else use 'orNull' function will return null. |
| template <typename Name, ConvertExceptionMode mode> |
| class FunctionMakeDecimal : public IFunction |
| { |
| public: |
| static constexpr auto name = Name::name; |
| static constexpr auto exception_mode = mode; |
| |
| static FunctionPtr create(ContextPtr) { return std::make_shared<FunctionMakeDecimal>(); } |
| |
| String getName() const override { return name; } |
| size_t getNumberOfArguments() const override { return 3; } |
| bool isSuitableForShortCircuitArgumentsExecution(const DataTypesWithConstInfo & /*arguments*/) const override { return false; } |
| bool useDefaultImplementationForConstants() const override { return true; } |
| ColumnNumbers getArgumentsThatAreAlwaysConstant() const override { return {1, 2}; } |
| |
| DataTypePtr getReturnTypeImpl(const ColumnsWithTypeAndName & arguments) const override |
| { |
| if (!isInteger(arguments[0].type) || !isInteger(arguments[1].type) || !isInteger(arguments[2].type)) |
| throw Exception( |
| ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT, |
| "Cannot format {} {} {} as decimal", |
| arguments[0].type->getName(), |
| arguments[1].type->getName(), |
| arguments[2].type->getName()); |
| |
| DataTypePtr res = createDecimal<DataTypeDecimal>(extractArgument(arguments[1]), extractArgument(arguments[2])); |
| if constexpr (exception_mode == ConvertExceptionMode::Null) |
| return std::make_shared<DataTypeNullable>(res); |
| else |
| return res; |
| } |
| |
| ColumnPtr |
| executeImpl(const ColumnsWithTypeAndName & arguments, const DataTypePtr & result_type, size_t input_rows_count) const override |
| { |
| const auto & unscale_column = arguments[0]; |
| if (!unscale_column.column) |
| throw Exception(ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT, "Illegal column while execute function {}", getName()); |
| |
| auto precision_value = extractArgument(arguments[1]); |
| auto scale_value = extractArgument(arguments[2]); |
| |
| if (precision_value <= DecimalUtils::max_precision<Decimal32>) |
| return executeInternal<DataTypeDecimal<Decimal32>>(arguments, result_type, input_rows_count, precision_value, scale_value); |
| else if (precision_value <= DecimalUtils::max_precision<Decimal64>) |
| return executeInternal<DataTypeDecimal<Decimal64>>(arguments, result_type, input_rows_count, precision_value, scale_value); |
| else if (precision_value <= DecimalUtils::max_precision<Decimal128>) |
| return executeInternal<DataTypeDecimal<Decimal128>>(arguments, result_type, input_rows_count, precision_value, scale_value); |
| else |
| return executeInternal<DataTypeDecimal<Decimal256>>(arguments, result_type, input_rows_count, precision_value, scale_value); |
| } |
| |
| private: |
| template <typename DataType> |
| requires(IsDataTypeDecimal<DataType>) |
| static ColumnPtr executeInternal( |
| const ColumnsWithTypeAndName & arguments, |
| const DataTypePtr & result_type, |
| size_t input_rows_count, |
| UInt32 precision_value, |
| UInt32 scale) |
| { |
| auto src_column = arguments[0]; |
| ColumnPtr result_column; |
| |
| auto call = [&](const auto & types) -> bool //-V657 |
| { |
| using Types = std::decay_t<decltype(types)>; |
| using FromDataType = typename Types::LeftType; |
| using ToDataType = typename Types::RightType; |
| |
| if constexpr (IsDataTypeNumber<FromDataType>) |
| { |
| ColumnUInt8::MutablePtr col_null_map_to; |
| ColumnUInt8::Container * vec_null_map_to [[maybe_unused]] = nullptr; |
| if constexpr (exception_mode == ConvertExceptionMode::Null) |
| { |
| col_null_map_to = ColumnUInt8::create(input_rows_count, false); |
| vec_null_map_to = &col_null_map_to->getData(); |
| } |
| |
| using ToFieldType = typename ToDataType::FieldType; |
| using ToNativeType = typename ToFieldType::NativeType; |
| using ToColumnType = typename ToDataType::ColumnType; |
| using FromFieldType = typename FromDataType::FieldType; |
| typename ToColumnType::MutablePtr col_to = ToColumnType::create(input_rows_count, scale); |
| |
| const auto & vector = typeid_cast<const ColumnVector<FromFieldType> *>(arguments[0].column.get()); |
| auto & vec_to = col_to->getData(); |
| auto & datas = vector->getData(); |
| vec_to.resize_exact(input_rows_count); |
| |
| for (size_t i = 0; i < input_rows_count; ++i) |
| { |
| ToNativeType result; |
| bool convert_result |
| = convertDecimalsFromIntegerImpl<FromFieldType, ToNativeType>(datas[i], result, precision_value); |
| |
| if (convert_result) |
| vec_to[i] = static_cast<ToFieldType>(result); |
| else |
| { |
| if constexpr (exception_mode == ConvertExceptionMode::Null) |
| { |
| vec_to[i] = static_cast<ToFieldType>(0); |
| (*vec_null_map_to)[i] = 1; |
| } |
| else |
| throw Exception( |
| ErrorCodes::ILLEGAL_COLUMN, |
| "Cannot parse {} as {}", |
| src_column.type->getName(), |
| result_type->getName()); |
| } |
| } |
| |
| if constexpr (exception_mode == ConvertExceptionMode::Null) |
| result_column = ColumnNullable::create(std::move(col_to), std::move(col_null_map_to)); |
| else |
| result_column = std::move(col_to); |
| |
| return true; |
| } |
| else |
| return false; |
| }; |
| |
| bool r = callOnIndexAndDataType<DataType>(src_column.type->getTypeId(), call); |
| |
| if (!r) |
| throw Exception( |
| ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT, "Illegal type {} of argument of function {}", src_column.type->getName(), name); |
| |
| return result_column; |
| } |
| |
| template <typename FromNativeType, typename ToNativeType> |
| static bool convertDecimalsFromIntegerImpl(FromNativeType from, ToNativeType & result, UInt32 precision_value) |
| { |
| Field convert_to = convertNumericTypeImpl<FromNativeType, ToNativeType>(from); |
| if (convert_to.isNull()) |
| { |
| if constexpr (ConvertExceptionMode::Throw == exception_mode) |
| throw Exception(ErrorCodes::DECIMAL_OVERFLOW, "Convert overflow"); |
| else |
| return false; |
| } |
| result = static_cast<ToNativeType>(convert_to.safeGet<ToNativeType>()); |
| |
| ToNativeType pow10 = intExp10OfSize<ToNativeType>(precision_value); |
| if ((result < 0 && result <= -pow10) || result >= pow10) |
| { |
| if constexpr (ConvertExceptionMode::Throw == exception_mode) |
| throw Exception(ErrorCodes::DECIMAL_OVERFLOW, "Convert overflow"); |
| else |
| return false; |
| } |
| |
| return true; |
| } |
| }; |
| |
| using FunctionMakeDecimalThrow = FunctionMakeDecimal<NameMakeDecimal, ConvertExceptionMode::Throw>; |
| using FunctionMakeDecimalOrNull = FunctionMakeDecimal<NameMakeDecimalOrNull, ConvertExceptionMode::Null>; |
| } |
| |
| REGISTER_FUNCTION(MakeDecimalSpark) |
| { |
| factory.registerFunction<FunctionMakeDecimalThrow>(FunctionDocumentation{.description = R"( |
| Create a decimal value by use nested type. If overflow throws exception. |
| )"}); |
| factory.registerFunction<FunctionMakeDecimalOrNull>(FunctionDocumentation{.description = R"( |
| Create a decimal value by use nested type. If overflow return `NULL`. |
| )"}); |
| } |
| } |