| /* |
| * Licensed to the Apache Software Foundation (ASF) under one or more |
| * contributor license agreements. See the NOTICE file distributed with |
| * this work for additional information regarding copyright ownership. |
| * The ASF licenses this file to You under the Apache License, Version 2.0 |
| * (the "License"); you may not use this file except in compliance with |
| * the License. You may obtain a copy of the License at |
| * |
| * http://www.apache.org/licenses/LICENSE-2.0 |
| * |
| * Unless required by applicable law or agreed to in writing, software |
| * distributed under the License is distributed on an "AS IS" BASIS, |
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| * See the License for the specific language governing permissions and |
| * limitations under the License. |
| */ |
| #include <Functions/FunctionsMiscellaneous.h> |
| #include <Interpreters/ExpressionActions.h> |
| #include <Interpreters/ExpressionActionsSettings.h> |
| #include <Parser/ExpressionParser.h> |
| #include <Parser/FunctionParser.h> |
| #include <Parser/TypeParser.h> |
| #include <Poco/Logger.h> |
| #include <Common/Exception.h> |
| #include <Common/logger_useful.h> |
| |
| #include <unordered_set> |
| |
/// Forward declaration of the externally defined `LOGICAL_ERROR` code, which is the
/// error code attached to every exception thrown in this file.
namespace DB
{
namespace ErrorCodes
{
extern const int LOGICAL_ERROR;
}
}
| |
| namespace local_engine |
| { |
| DB::NamesAndTypesList collectLambdaArguments(ParserContextPtr parser_context_, const substrait::Expression_ScalarFunction & substrait_func) |
| { |
| DB::NamesAndTypesList lambda_arguments; |
| std::unordered_set<String> collected_names; |
| |
| for (const auto & arg : substrait_func.arguments()) |
| { |
| if (arg.value().has_scalar_function() |
| && parser_context_->getFunctionNameInSignature(arg.value().scalar_function().function_reference()) == "namedlambdavariable") |
| { |
| auto [_, col_name_field] = LiteralParser::parse(arg.value().scalar_function().arguments()[0].value().literal()); |
| String col_name = col_name_field.safeGet<String>(); |
| if (collected_names.contains(col_name)) |
| continue; |
| collected_names.insert(col_name); |
| auto type = TypeParser::parseType(arg.value().scalar_function().output_type()); |
| lambda_arguments.emplace_back(col_name, type); |
| } |
| } |
| return lambda_arguments; |
| } |
| |
| /// Refer to `PlannerActionsVisitorImpl::visitLambda` for how to build a lambda function node. |
| class FunctionParserLambda : public FunctionParser |
| { |
| public: |
| static constexpr auto name = "lambdafunction"; |
| explicit FunctionParserLambda(ParserContextPtr parser_context_) : FunctionParser(parser_context_) { } |
| ~FunctionParserLambda() override = default; |
| |
| String getName() const override { return name; } |
| |
| String getCHFunctionName(const substrait::Expression_ScalarFunction & scalar_function) const override |
| { |
| throw DB::Exception(DB::ErrorCodes::LOGICAL_ERROR, "getCHFunctionName is not implemented for LambdaFunction"); |
| } |
| |
| const DB::ActionsDAG::Node * |
| parse(const substrait::Expression_ScalarFunction & substrait_func, DB::ActionsDAG & actions_dag) const override |
| { |
| /// Some special cases, for example, `transform(arr, x -> concat(arr, array(x)))` refers to |
| /// a column `arr` out of it directly. We need a `arr` as an input column for `lambda_actions_dag` |
| DB::NamesAndTypesList parent_header; |
| for (const auto * output_node : actions_dag.getOutputs()) |
| parent_header.emplace_back(output_node->result_name, output_node->result_type); |
| DB::ActionsDAG lambda_actions_dag{parent_header}; |
| |
| /// The first argument is the lambda function body, followings are the lambda arguments which is |
| /// needed by the lambda function body. |
| /// There could be a nested lambda function in the lambda function body, and it refer a variable from |
| /// this outside lambda function's arguments. For an example, transform(number, x -> transform(letter, y -> struct(x, y))). |
| /// Before parsing the lambda function body, we add lambda function arguments int actions dag at first. |
| for (size_t i = 1; i < substrait_func.arguments().size(); ++i) |
| (void)parseExpression(lambda_actions_dag, substrait_func.arguments()[i].value()); |
| const auto & substrait_lambda_body = substrait_func.arguments()[0].value(); |
| const auto * lambda_body_node = parseExpression(lambda_actions_dag, substrait_lambda_body); |
| lambda_actions_dag.getOutputs().push_back(lambda_body_node); |
| lambda_actions_dag.removeUnusedActions(DB::Names(1, lambda_body_node->result_name)); |
| |
| DB::Names captured_column_names; |
| DB::Names required_column_names = lambda_actions_dag.getRequiredColumnsNames(); |
| DB::ActionsDAG::NodeRawConstPtrs lambda_children; |
| auto lambda_function_args = collectLambdaArguments(parser_context, substrait_func); |
| const auto & lambda_actions_inputs = lambda_actions_dag.getInputs(); |
| |
| std::unordered_map<String, const DB::ActionsDAG::Node *> parent_nodes; |
| for (const auto & node : actions_dag.getNodes()) |
| parent_nodes[node.result_name] = &node; |
| for (const auto & required_column_name : required_column_names) |
| { |
| if (std::find_if( |
| lambda_function_args.begin(), |
| lambda_function_args.end(), |
| [&required_column_name](const DB::NameAndTypePair & name_type) { return name_type.name == required_column_name; }) |
| == lambda_function_args.end()) |
| { |
| auto it = std::find_if( |
| lambda_actions_inputs.begin(), |
| lambda_actions_inputs.end(), |
| [&required_column_name](const auto & node) { return node->result_name == required_column_name; }); |
| if (it == lambda_actions_inputs.end()) |
| throw DB::Exception(DB::ErrorCodes::LOGICAL_ERROR, "Required column not found: {}", required_column_name); |
| auto parent_node_it = parent_nodes.find(required_column_name); |
| if (parent_node_it == parent_nodes.end()) |
| { |
| throw DB::Exception( |
| DB::ErrorCodes::LOGICAL_ERROR, |
| "Not found column {} in actions dag:\n{}", |
| required_column_name, |
| actions_dag.dumpDAG()); |
| } |
| /// The nodes must be the ones in `actions_dag`, otherwise `ActionsDAG::evaluatePartialResult` will fail. Because nodes may have the |
| /// same name but their addresses are different. |
| lambda_children.push_back(parent_node_it->second); |
| captured_column_names.push_back(required_column_name); |
| } |
| } |
| auto expression_actions_settings = DB::ExpressionActionsSettings{getContext(), DB::CompileExpressions::yes}; |
| auto function_capture = std::make_shared<DB::FunctionCaptureOverloadResolver>( |
| std::move(lambda_actions_dag), |
| expression_actions_settings, |
| captured_column_names, |
| lambda_function_args, |
| lambda_body_node->result_type, |
| lambda_body_node->result_name, |
| false); |
| |
| const auto * result = &actions_dag.addFunction(function_capture, lambda_children, lambda_body_node->result_name); |
| return result; |
| } |
| |
| protected: |
| DB::ActionsDAG::NodeRawConstPtrs |
| parseFunctionArguments(const substrait::Expression_ScalarFunction & substrait_func, DB::ActionsDAG & actions_dag) const override |
| { |
| throw DB::Exception(DB::ErrorCodes::LOGICAL_ERROR, "parseFunctionArguments is not implemented for LambdaFunction"); |
| } |
| |
| const DB::ActionsDAG::Node * convertNodeTypeIfNeeded( |
| const substrait::Expression_ScalarFunction & substrait_func, |
| const DB::ActionsDAG::Node * func_node, |
| DB::ActionsDAG & actions_dag) const override |
| { |
| throw DB::Exception(DB::ErrorCodes::LOGICAL_ERROR, "convertNodeTypeIfNeeded is not implemented for NamedLambdaVariable"); |
| } |
| }; |
| |
| static FunctionParserRegister<FunctionParserLambda> register_lambda_function; |
| |
| |
| class NamedLambdaVariable : public FunctionParser |
| { |
| public: |
| static constexpr auto name = "namedlambdavariable"; |
| explicit NamedLambdaVariable(ParserContextPtr parser_context_) : FunctionParser(parser_context_) { } |
| ~NamedLambdaVariable() override = default; |
| |
| String getName() const override { return name; } |
| |
| String getCHFunctionName(const substrait::Expression_ScalarFunction & scalar_function) const override |
| { |
| throw DB::Exception(DB::ErrorCodes::LOGICAL_ERROR, "getCHFunctionName is not implemented for NamedLambdaVariable"); |
| } |
| |
| const DB::ActionsDAG::Node * |
| parse(const substrait::Expression_ScalarFunction & substrait_func, DB::ActionsDAG & actions_dag) const override |
| { |
| auto [_, col_name_field] = parseLiteral(substrait_func.arguments()[0].value().literal()); |
| String col_name = col_name_field.safeGet<String>(); |
| |
| auto type = TypeParser::parseType(substrait_func.output_type()); |
| const auto & inputs = actions_dag.getInputs(); |
| auto it = std::find_if(inputs.begin(), inputs.end(), [&col_name](const auto * node) { return node->result_name == col_name; }); |
| if (it == inputs.end()) |
| return &(actions_dag.addInput(col_name, type)); |
| return *it; |
| } |
| |
| protected: |
| DB::ActionsDAG::NodeRawConstPtrs |
| parseFunctionArguments(const substrait::Expression_ScalarFunction & substrait_func, DB::ActionsDAG & actions_dag) const override |
| { |
| throw DB::Exception(DB::ErrorCodes::LOGICAL_ERROR, "parseFunctionArguments is not implemented for NamedLambdaVariable"); |
| } |
| |
| const DB::ActionsDAG::Node * convertNodeTypeIfNeeded( |
| const substrait::Expression_ScalarFunction & substrait_func, |
| const DB::ActionsDAG::Node * func_node, |
| DB::ActionsDAG & actions_dag) const override |
| { |
| throw DB::Exception(DB::ErrorCodes::LOGICAL_ERROR, "convertNodeTypeIfNeeded is not implemented for NamedLambdaVariable"); |
| } |
| }; |
| |
| static FunctionParserRegister<NamedLambdaVariable> register_named_lambda_variable; |
| |
| } |