| /* |
| * Licensed to the Apache Software Foundation (ASF) under one or more |
| * contributor license agreements. See the NOTICE file distributed with |
| * this work for additional information regarding copyright ownership. |
| * The ASF licenses this file to You under the Apache License, Version 2.0 |
| * (the "License"); you may not use this file except in compliance with |
| * the License. You may obtain a copy of the License at |
| * |
| * http://www.apache.org/licenses/LICENSE-2.0 |
| * |
| * Unless required by applicable law or agreed to in writing, software |
| * distributed under the License is distributed on an "AS IS" BASIS, |
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| * See the License for the specific language governing permissions and |
| * limitations under the License. |
| */ |
| #include <Core/Settings.h> |
| #include <DataTypes/DataTypeNullable.h> |
| #include <DataTypes/DataTypesDecimal.h> |
| #include <Functions/FunctionHelpers.h> |
| #include <Parser/FunctionParser.h> |
| #include <Parser/TypeParser.h> |
| #include <Common/BlockTypeUtils.h> |
| #include <Common/GlutenSettings.h> |
| |
| namespace DB::ErrorCodes |
| { |
| extern const int BAD_ARGUMENTS; |
| extern const int NUMBER_OF_ARGUMENTS_DOESNT_MATCH; |
| } |
| |
| namespace local_engine |
| { |
| using namespace DB; |
| class DecimalType |
| { |
| static constexpr Int32 spark_max_precision = 38; |
| static constexpr Int32 spark_max_scale = 38; |
| static constexpr Int32 minimum_adjusted_scale = 6; |
| |
| static constexpr Int32 chickhouse_max_precision = DB::DataTypeDecimal256::maxPrecision(); |
| static constexpr Int32 chickhouse_max_scale = DB::DataTypeDecimal128::maxPrecision(); |
| |
| public: |
| Int32 precision; |
| Int32 scale; |
| |
| private: |
| static DecimalType bounded_to_click_house(const Int32 precision, const Int32 scale) |
| { |
| return DecimalType(std::min(precision, chickhouse_max_precision), std::min(scale, chickhouse_max_scale)); |
| } |
| |
| public: |
| static DecimalType evalAddSubstractDecimalType(const Int32 p1, const Int32 s1, const Int32 p2, const Int32 s2) |
| { |
| const Int32 scale = s1; |
| const Int32 precision = scale + std::max(p1 - s1, p2 - s2) + 1; |
| return bounded_to_click_house(precision, scale); |
| } |
| |
| static DecimalType evalDividetDecimalType(const Int32 p1, const Int32 s1, const Int32 p2, const Int32 s2) |
| { |
| const Int32 scale = std::max(minimum_adjusted_scale, s1 + p2 + 1); |
| const Int32 precision = p1 - s1 + s2 + scale; |
| return bounded_to_click_house(precision, scale); |
| } |
| |
| static DecimalType evalModuloDecimalType(const Int32 p1, const Int32 s1, const Int32 p2, const Int32 s2) |
| { |
| const Int32 scale = std::max(s1, s2); |
| const Int32 precision = std::min(p1 - s1, p2 - s2) + scale; |
| return bounded_to_click_house(precision, scale); |
| } |
| |
| static DecimalType evalMultiplyDecimalType(const Int32 p1, const Int32 s1, const Int32 p2, const Int32 s2) |
| { |
| const Int32 scale = s1; |
| const Int32 precision = p1 + p2 + 1; |
| return bounded_to_click_house(precision, scale); |
| } |
| }; |
| |
| class FunctionParserBinaryArithmetic : public FunctionParser |
| { |
| protected: |
| ActionsDAG::NodeRawConstPtrs convertBinaryArithmeticFunDecimalArgs( |
| ActionsDAG & actions_dag, |
| const ActionsDAG::NodeRawConstPtrs & args, |
| const DecimalType & eval_type, |
| const substrait::Expression_ScalarFunction & arithmeticFun) const |
| { |
| const Int32 precision = eval_type.precision; |
| const Int32 scale = eval_type.scale; |
| |
| ActionsDAG::NodeRawConstPtrs new_args; |
| new_args.reserve(args.size()); |
| |
| ActionsDAG::NodeRawConstPtrs cast_args; |
| cast_args.reserve(2); |
| cast_args.emplace_back(args[0]); |
| DataTypePtr ch_type = createDecimal<DataTypeDecimal>(precision, scale); |
| ch_type = wrapNullableType(arithmeticFun.output_type().decimal().nullability(), ch_type); |
| const String type_name = ch_type->getName(); |
| const DataTypePtr str_type = std::make_shared<DataTypeString>(); |
| const ActionsDAG::Node * type_node |
| = &actions_dag.addColumn(ColumnWithTypeAndName(str_type->createColumnConst(1, type_name), str_type, getUniqueName(type_name))); |
| cast_args.emplace_back(type_node); |
| const ActionsDAG::Node * cast_node = toFunctionNode(actions_dag, "CAST", cast_args); |
| actions_dag.addOrReplaceInOutputs(*cast_node); |
| new_args.emplace_back(cast_node); |
| new_args.emplace_back(args[1]); |
| return new_args; |
| } |
| |
| DecimalType getDecimalType(const DataTypePtr & left, const DataTypePtr & right) const |
| { |
| assert(isDecimal(left) && isDecimal(right)); |
| const Int32 p1 = getDecimalPrecision(*left); |
| const Int32 s1 = getDecimalScale(*left); |
| const Int32 p2 = getDecimalPrecision(*right); |
| const Int32 s2 = getDecimalScale(*right); |
| return internalEvalType(p1, s1, p2, s2); |
| } |
| |
| virtual DecimalType internalEvalType(Int32 p1, Int32 s1, Int32 p2, Int32 s2) const = 0; |
| |
| const ActionsDAG::Node * |
| checkDecimalOverflow(ActionsDAG & actions_dag, const ActionsDAG::Node * func_node, Int32 precision, Int32 scale) const |
| { |
| //TODO: checkDecimalOverflowSpark throw exception per configuration |
| const DB::ActionsDAG::NodeRawConstPtrs overflow_args |
| = {func_node, |
| expression_parser->addConstColumn(actions_dag, std::make_shared<DataTypeInt32>(), precision), |
| expression_parser->addConstColumn(actions_dag, std::make_shared<DataTypeInt32>(), scale)}; |
| return toFunctionNode(actions_dag, "checkDecimalOverflowSparkOrNull", overflow_args); |
| } |
| |
| virtual const DB::ActionsDAG::Node * createFunctionNode( |
| DB::ActionsDAG & actions_dag, |
| const String & func_name, |
| const DB::ActionsDAG::NodeRawConstPtrs & args, |
| DataTypePtr result_type) const |
| { |
| return toFunctionNode(actions_dag, func_name, args); |
| } |
| |
| public: |
| explicit FunctionParserBinaryArithmetic(ParserContextPtr parser_context_) : FunctionParser(parser_context_) { } |
| const ActionsDAG::Node * parse(const substrait::Expression_ScalarFunction & substrait_func, ActionsDAG & actions_dag) const override |
| { |
| const auto ch_func_name = getCHFunctionName(substrait_func); |
| auto parsed_args = parseFunctionArguments(substrait_func, actions_dag); |
| |
| if (parsed_args.size() != 2) |
| throw Exception(DB::ErrorCodes::NUMBER_OF_ARGUMENTS_DOESNT_MATCH, "Function {} requires exactly two arguments", getName()); |
| |
| const auto left_type = DB::removeNullable(parsed_args[0]->result_type); |
| const auto right_type = DB::removeNullable(parsed_args[1]->result_type); |
| const auto result_type = removeNullable(TypeParser::parseType(substrait_func.output_type())); |
| const auto * func_node = createFunctionNode(actions_dag, ch_func_name, parsed_args, result_type); |
| return convertNodeTypeIfNeeded(substrait_func, func_node, actions_dag); |
| } |
| }; |
| |
| class FunctionParserPlus final : public FunctionParserBinaryArithmetic |
| { |
| public: |
| explicit FunctionParserPlus(ParserContextPtr parser_context_) : FunctionParserBinaryArithmetic(parser_context_) { } |
| |
| static constexpr auto name = "add"; |
| String getName() const override { return name; } |
| String getCHFunctionName(const substrait::Expression_ScalarFunction & substrait_func) const override { return "plus"; } |
| |
| protected: |
| DecimalType internalEvalType(const Int32 p1, const Int32 s1, const Int32 p2, const Int32 s2) const override |
| { |
| return DecimalType::evalAddSubstractDecimalType(p1, s1, p2, s2); |
| } |
| |
| const DB::ActionsDAG::Node * createFunctionNode( |
| DB::ActionsDAG & actions_dag, |
| const String & func_name, |
| const DB::ActionsDAG::NodeRawConstPtrs & new_args, |
| DataTypePtr result_type) const override |
| { |
| const auto * left_arg = new_args[0]; |
| const auto * right_arg = new_args[1]; |
| |
| if (isDecimal(removeNullable(left_arg->result_type)) && isDecimal(removeNullable(right_arg->result_type))) |
| { |
| const ActionsDAG::Node * type_node = &actions_dag.addColumn(ColumnWithTypeAndName( |
| result_type->createColumnConstWithDefaultValue(1), result_type, getUniqueName(result_type->getName()))); |
| |
| const auto & settings = parser_context->queryContext()->getSettingsRef(); |
| auto function_name = settings.has("arithmetic.decimal.mode") && settingsEqual(settings, "arithmetic.decimal.mode", "EFFECT") |
| ? "sparkDecimalPlusEffect" |
| : "sparkDecimalPlus"; |
| |
| return toFunctionNode(actions_dag, function_name, {left_arg, right_arg, type_node}); |
| } |
| |
| return toFunctionNode(actions_dag, "plus", {left_arg, right_arg}); |
| } |
| }; |
| |
| class FunctionParserMinus final : public FunctionParserBinaryArithmetic |
| { |
| public: |
| explicit FunctionParserMinus(ParserContextPtr parser_context_) : FunctionParserBinaryArithmetic(parser_context_) { } |
| |
| static constexpr auto name = "subtract"; |
| String getName() const override { return name; } |
| String getCHFunctionName(const substrait::Expression_ScalarFunction & substrait_func) const override { return "minus"; } |
| |
| protected: |
| DecimalType internalEvalType(const Int32 p1, const Int32 s1, const Int32 p2, const Int32 s2) const override |
| { |
| return DecimalType::evalAddSubstractDecimalType(p1, s1, p2, s2); |
| } |
| |
| const DB::ActionsDAG::Node * createFunctionNode( |
| DB::ActionsDAG & actions_dag, |
| const String & func_name, |
| const DB::ActionsDAG::NodeRawConstPtrs & new_args, |
| DataTypePtr result_type) const override |
| { |
| const auto * left_arg = new_args[0]; |
| const auto * right_arg = new_args[1]; |
| |
| if (isDecimal(removeNullable(left_arg->result_type)) && isDecimal(removeNullable(right_arg->result_type))) |
| { |
| const ActionsDAG::Node * type_node = &actions_dag.addColumn(ColumnWithTypeAndName( |
| result_type->createColumnConstWithDefaultValue(1), result_type, getUniqueName(result_type->getName()))); |
| |
| const auto & settings = parser_context->queryContext()->getSettingsRef(); |
| auto function_name = settings.has("arithmetic.decimal.mode") && settingsEqual(settings, "arithmetic.decimal.mode", "EFFECT") |
| ? "sparkDecimalMinusEffect" |
| : "sparkDecimalMinus"; |
| |
| return toFunctionNode(actions_dag, function_name, {left_arg, right_arg, type_node}); |
| } |
| |
| return toFunctionNode(actions_dag, "minus", {left_arg, right_arg}); |
| } |
| }; |
| |
| class FunctionParserMultiply final : public FunctionParserBinaryArithmetic |
| { |
| public: |
| explicit FunctionParserMultiply(ParserContextPtr parser_context_) : FunctionParserBinaryArithmetic(parser_context_) { } |
| static constexpr auto name = "multiply"; |
| String getName() const override { return name; } |
| String getCHFunctionName(const substrait::Expression_ScalarFunction & substrait_func) const override { return "multiply"; } |
| |
| protected: |
| DecimalType internalEvalType(const Int32 p1, const Int32 s1, const Int32 p2, const Int32 s2) const override |
| { |
| return DecimalType::evalMultiplyDecimalType(p1, s1, p2, s2); |
| } |
| |
| const DB::ActionsDAG::Node * createFunctionNode( |
| DB::ActionsDAG & actions_dag, |
| const String & func_name, |
| const DB::ActionsDAG::NodeRawConstPtrs & new_args, |
| DataTypePtr result_type) const override |
| { |
| const auto * left_arg = new_args[0]; |
| const auto * right_arg = new_args[1]; |
| |
| if (isDecimal(removeNullable(left_arg->result_type)) && isDecimal(removeNullable(right_arg->result_type))) |
| { |
| const ActionsDAG::Node * type_node = &actions_dag.addColumn(ColumnWithTypeAndName( |
| result_type->createColumnConstWithDefaultValue(1), result_type, getUniqueName(result_type->getName()))); |
| |
| const auto & settings = parser_context->queryContext()->getSettingsRef(); |
| auto function_name = settings.has("arithmetic.decimal.mode") && settingsEqual(settings, "arithmetic.decimal.mode", "EFFECT") |
| ? "sparkDecimalMultiplyEffect" |
| : "sparkDecimalMultiply"; |
| |
| return toFunctionNode(actions_dag, function_name, {left_arg, right_arg, type_node}); |
| } |
| |
| return toFunctionNode(actions_dag, "multiply", {left_arg, right_arg}); |
| } |
| }; |
| |
| class FunctionParserModulo final : public FunctionParserBinaryArithmetic |
| { |
| public: |
| explicit FunctionParserModulo(ParserContextPtr parser_context_) : FunctionParserBinaryArithmetic(parser_context_) { } |
| static constexpr auto name = "modulus"; |
| String getName() const override { return name; } |
| String getCHFunctionName(const substrait::Expression_ScalarFunction & substrait_func) const override { return "modulo"; } |
| |
| protected: |
| DecimalType internalEvalType(const Int32 p1, const Int32 s1, const Int32 p2, const Int32 s2) const override |
| { |
| return DecimalType::evalModuloDecimalType(p1, s1, p2, s2); |
| } |
| |
| const DB::ActionsDAG::Node * createFunctionNode( |
| DB::ActionsDAG & actions_dag, |
| const String & func_name, |
| const DB::ActionsDAG::NodeRawConstPtrs & new_args, |
| DataTypePtr result_type) const override |
| { |
| const auto * left_arg = new_args[0]; |
| const auto * right_arg = new_args[1]; |
| |
| if (isDecimal(removeNullable(left_arg->result_type)) || isDecimal(removeNullable(right_arg->result_type))) |
| { |
| const ActionsDAG::Node * type_node = &actions_dag.addColumn(ColumnWithTypeAndName( |
| result_type->createColumnConstWithDefaultValue(1), result_type, getUniqueName(result_type->getName()))); |
| |
| const auto & settings = parser_context->queryContext()->getSettingsRef(); |
| auto function_name = settings.has("arithmetic.decimal.mode") && settingsEqual(settings, "arithmetic.decimal.mode", "EFFECT") |
| ? "sparkDecimalModuloEffect" |
| : "sparkDecimalModulo"; |
| ; |
| return toFunctionNode(actions_dag, function_name, {left_arg, right_arg, type_node}); |
| } |
| |
| return toFunctionNode(actions_dag, "spark_modulo", {left_arg, right_arg}); |
| } |
| }; |
| |
| class FunctionParserDivide final : public FunctionParserBinaryArithmetic |
| { |
| public: |
| explicit FunctionParserDivide(ParserContextPtr parser_context_) : FunctionParserBinaryArithmetic(parser_context_) { } |
| static constexpr auto name = "divide"; |
| String getName() const override { return name; } |
| String getCHFunctionName(const substrait::Expression_ScalarFunction & substrait_func) const override { return "divide"; } |
| |
| protected: |
| DecimalType internalEvalType(const Int32 p1, const Int32 s1, const Int32 p2, const Int32 s2) const override |
| { |
| return DecimalType::evalDividetDecimalType(p1, s1, p2, s2); |
| } |
| |
| const DB::ActionsDAG::Node * createFunctionNode( |
| DB::ActionsDAG & actions_dag, |
| const String & func_name, |
| const DB::ActionsDAG::NodeRawConstPtrs & new_args, |
| DataTypePtr result_type) const override |
| { |
| assert(func_name == name); |
| const auto * left_arg = new_args[0]; |
| const auto * right_arg = new_args[1]; |
| |
| if (isDecimal(removeNullable(left_arg->result_type)) || isDecimal(removeNullable(right_arg->result_type))) |
| { |
| const ActionsDAG::Node * type_node = &actions_dag.addColumn(ColumnWithTypeAndName( |
| result_type->createColumnConstWithDefaultValue(1), result_type, getUniqueName(result_type->getName()))); |
| |
| const auto & settings = parser_context->queryContext()->getSettingsRef(); |
| auto function_name = settings.has("arithmetic.decimal.mode") && settingsEqual(settings, "arithmetic.decimal.mode", "EFFECT") |
| ? "sparkDecimalDivideEffect" |
| : "sparkDecimalDivide"; |
| ; |
| return toFunctionNode(actions_dag, function_name, {left_arg, right_arg, type_node}); |
| } |
| |
| return toFunctionNode(actions_dag, "sparkDivide", {left_arg, right_arg}); |
| } |
| }; |
| |
| static FunctionParserRegister<FunctionParserPlus> register_plus; |
| static FunctionParserRegister<FunctionParserMinus> register_minus; |
| static FunctionParserRegister<FunctionParserMultiply> register_mltiply; |
| static FunctionParserRegister<FunctionParserDivide> register_divide; |
| static FunctionParserRegister<FunctionParserModulo> register_modulo; |
| |
| } |