blob: 79f3a3784a3af908c71e27e9109697ca114c6035 [file] [log] [blame]
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include <Functions/FunctionsMiscellaneous.h>
#include <Interpreters/ExpressionActions.h>
#include <Interpreters/ExpressionActionsSettings.h>
#include <Parser/ExpressionParser.h>
#include <Parser/FunctionParser.h>
#include <Parser/TypeParser.h>
#include <Poco/Logger.h>
#include <Common/Exception.h>
#include <Common/logger_useful.h>
#include <unordered_set>
/// Error codes are defined outside this file; only the code used by the exceptions thrown below is declared here.
namespace DB::ErrorCodes
{
extern const int LOGICAL_ERROR;
}
namespace local_engine
{
/// Collects the lambda variables declared by a substrait lambda function.
///
/// Every argument of `substrait_func` that is itself a scalar function whose signature name is
/// `namedlambdavariable` contributes one (name, type) pair: the name comes from the literal in the
/// variable's first argument, the type from the variable's output type. A name is reported only
/// once, at the position of its first occurrence.
DB::NamesAndTypesList collectLambdaArguments(ParserContextPtr parser_context_, const substrait::Expression_ScalarFunction & substrait_func)
{
    std::unordered_set<String> seen_names;
    DB::NamesAndTypesList result;

    for (const auto & argument : substrait_func.arguments())
    {
        const auto & expression = argument.value();
        if (!expression.has_scalar_function())
            continue;

        const auto & variable_func = expression.scalar_function();
        const bool is_lambda_variable
            = parser_context_->getFunctionNameInSignature(variable_func.function_reference()) == "namedlambdavariable";
        if (!is_lambda_variable)
            continue;

        auto [_, name_field] = LiteralParser::parse(variable_func.arguments()[0].value().literal());
        String variable_name = name_field.safeGet<String>();

        /// `insert` reports whether the name was new, so repeated variables are skipped here.
        if (!seen_names.insert(variable_name).second)
            continue;

        result.emplace_back(variable_name, TypeParser::parseType(variable_func.output_type()));
    }

    return result;
}
/// Parser for the substrait `lambdafunction` scalar function.
///
/// Argument 0 of the substrait function is the lambda body; the remaining arguments are the lambda
/// variables the body needs. The parser builds a separate actions dag for the body and wraps it
/// into a `DB::FunctionCaptureOverloadResolver` node that is added to the caller's actions dag.
/// Refer to `PlannerActionsVisitorImpl::visitLambda` for how to build a lambda function node.
class FunctionParserLambda : public FunctionParser
{
public:
    static constexpr auto name = "lambdafunction";

    explicit FunctionParserLambda(ParserContextPtr parser_context_) : FunctionParser(parser_context_) { }
    ~FunctionParserLambda() override = default;

    String getName() const override { return name; }

    /// A lambda has no single ClickHouse function counterpart, so this always throws LOGICAL_ERROR.
    String getCHFunctionName(const substrait::Expression_ScalarFunction & scalar_function) const override
    {
        throw DB::Exception(DB::ErrorCodes::LOGICAL_ERROR, "getCHFunctionName is not implemented for LambdaFunction");
    }

    /// Builds the lambda capture node for `substrait_func` inside `actions_dag` and returns it.
    ///
    /// Throws LOGICAL_ERROR when the function has no arguments (there is no lambda body), or when a
    /// column required by the body is neither a lambda argument nor a node of `actions_dag`.
    const DB::ActionsDAG::Node *
    parse(const substrait::Expression_ScalarFunction & substrait_func, DB::ActionsDAG & actions_dag) const override
    {
        const auto & substrait_args = substrait_func.arguments();
        /// The lambda body is read from argument 0 below; reject a malformed plan instead of indexing out of range.
        if (substrait_args.empty())
            throw DB::Exception(DB::ErrorCodes::LOGICAL_ERROR, "Lambda function requires at least one argument (the lambda body)");

        /// Some special cases, for example, `transform(arr, x -> concat(arr, array(x)))` refers to
        /// a column `arr` out of it directly. We need a `arr` as an input column for `lambda_actions_dag`
        DB::NamesAndTypesList parent_header;
        for (const auto * output_node : actions_dag.getOutputs())
            parent_header.emplace_back(output_node->result_name, output_node->result_type);
        DB::ActionsDAG lambda_actions_dag{parent_header};

        /// The first argument is the lambda function body, the following ones are the lambda arguments
        /// which are needed by the lambda function body.
        /// There could be a nested lambda function in the lambda function body, and it may refer to a variable from
        /// this outer lambda function's arguments. For example, transform(number, x -> transform(letter, y -> struct(x, y))).
        /// Before parsing the lambda function body, we add the lambda function arguments into the actions dag first.
        for (int i = 1; i < substrait_args.size(); ++i)
            (void)parseExpression(lambda_actions_dag, substrait_args[i].value());

        const auto & substrait_lambda_body = substrait_args[0].value();
        const auto * lambda_body_node = parseExpression(lambda_actions_dag, substrait_lambda_body);
        lambda_actions_dag.getOutputs().push_back(lambda_body_node);
        /// Keep only what is needed to compute the lambda body.
        lambda_actions_dag.removeUnusedActions(DB::Names(1, lambda_body_node->result_name));

        DB::Names captured_column_names;
        DB::Names required_column_names = lambda_actions_dag.getRequiredColumnsNames();
        DB::ActionsDAG::NodeRawConstPtrs lambda_children;
        auto lambda_function_args = collectLambdaArguments(parser_context, substrait_func);
        const auto & lambda_actions_inputs = lambda_actions_dag.getInputs();

        /// Name -> node lookup over the caller's dag. When several nodes share a name, the last one wins.
        std::unordered_map<String, const DB::ActionsDAG::Node *> parent_nodes;
        for (const auto & node : actions_dag.getNodes())
            parent_nodes[node.result_name] = &node;

        const auto is_lambda_argument = [&lambda_function_args](const String & column_name)
        {
            return std::find_if(
                       lambda_function_args.begin(),
                       lambda_function_args.end(),
                       [&column_name](const DB::NameAndTypePair & name_type) { return name_type.name == column_name; })
                != lambda_function_args.end();
        };

        /// Every required column that is not a lambda argument is captured from the caller's dag.
        for (const auto & required_column_name : required_column_names)
        {
            if (is_lambda_argument(required_column_name))
                continue;

            auto it = std::find_if(
                lambda_actions_inputs.begin(),
                lambda_actions_inputs.end(),
                [&required_column_name](const auto & node) { return node->result_name == required_column_name; });
            if (it == lambda_actions_inputs.end())
                throw DB::Exception(DB::ErrorCodes::LOGICAL_ERROR, "Required column not found: {}", required_column_name);

            auto parent_node_it = parent_nodes.find(required_column_name);
            if (parent_node_it == parent_nodes.end())
            {
                throw DB::Exception(
                    DB::ErrorCodes::LOGICAL_ERROR,
                    "Not found column {} in actions dag:\n{}",
                    required_column_name,
                    actions_dag.dumpDAG());
            }

            /// The nodes must be the ones in `actions_dag`, otherwise `ActionsDAG::evaluatePartialResult` will fail. Because nodes may have the
            /// same name but their addresses are different.
            lambda_children.push_back(parent_node_it->second);
            captured_column_names.push_back(required_column_name);
        }

        auto expression_actions_settings = DB::ExpressionActionsSettings{getContext(), DB::CompileExpressions::yes};
        auto function_capture = std::make_shared<DB::FunctionCaptureOverloadResolver>(
            std::move(lambda_actions_dag),
            expression_actions_settings,
            captured_column_names,
            lambda_function_args,
            lambda_body_node->result_type,
            lambda_body_node->result_name,
            false);

        const auto * result = &actions_dag.addFunction(function_capture, lambda_children, lambda_body_node->result_name);
        return result;
    }

protected:
    /// Arguments are handled directly in `parse`, so this always throws LOGICAL_ERROR.
    DB::ActionsDAG::NodeRawConstPtrs
    parseFunctionArguments(const substrait::Expression_ScalarFunction & substrait_func, DB::ActionsDAG & actions_dag) const override
    {
        throw DB::Exception(DB::ErrorCodes::LOGICAL_ERROR, "parseFunctionArguments is not implemented for LambdaFunction");
    }

    /// Not applicable to lambda functions; always throws LOGICAL_ERROR.
    const DB::ActionsDAG::Node * convertNodeTypeIfNeeded(
        const substrait::Expression_ScalarFunction & substrait_func,
        const DB::ActionsDAG::Node * func_node,
        DB::ActionsDAG & actions_dag) const override
    {
        throw DB::Exception(DB::ErrorCodes::LOGICAL_ERROR, "convertNodeTypeIfNeeded is not implemented for LambdaFunction");
    }
};

static FunctionParserRegister<FunctionParserLambda> register_lambda_function;
/// Parser for the substrait `namedlambdavariable` scalar function.
///
/// The variable name is taken from the literal in the first argument and the variable type from
/// the function's output type. The variable is represented as an input of the actions dag.
class NamedLambdaVariable : public FunctionParser
{
public:
    static constexpr auto name = "namedlambdavariable";

    explicit NamedLambdaVariable(ParserContextPtr parser_context_) : FunctionParser(parser_context_) { }
    ~NamedLambdaVariable() override = default;

    String getName() const override { return name; }

    /// A lambda variable does not map to a ClickHouse function; always throws LOGICAL_ERROR.
    String getCHFunctionName(const substrait::Expression_ScalarFunction & scalar_function) const override
    {
        throw DB::Exception(DB::ErrorCodes::LOGICAL_ERROR, "getCHFunctionName is not implemented for NamedLambdaVariable");
    }

    /// Returns the input node of `actions_dag` named after the variable, adding a new input
    /// with the variable's type when no input with that name exists yet.
    const DB::ActionsDAG::Node *
    parse(const substrait::Expression_ScalarFunction & substrait_func, DB::ActionsDAG & actions_dag) const override
    {
        auto [_, name_field] = parseLiteral(substrait_func.arguments()[0].value().literal());
        String variable_name = name_field.safeGet<String>();
        auto variable_type = TypeParser::parseType(substrait_func.output_type());

        /// Reuse the first existing input carrying this name.
        for (const auto * input_node : actions_dag.getInputs())
        {
            if (input_node->result_name == variable_name)
                return input_node;
        }

        return &(actions_dag.addInput(variable_name, variable_type));
    }

protected:
    /// Arguments are handled directly in `parse`; always throws LOGICAL_ERROR.
    DB::ActionsDAG::NodeRawConstPtrs
    parseFunctionArguments(const substrait::Expression_ScalarFunction & substrait_func, DB::ActionsDAG & actions_dag) const override
    {
        throw DB::Exception(DB::ErrorCodes::LOGICAL_ERROR, "parseFunctionArguments is not implemented for NamedLambdaVariable");
    }

    /// Not applicable to lambda variables; always throws LOGICAL_ERROR.
    const DB::ActionsDAG::Node * convertNodeTypeIfNeeded(
        const substrait::Expression_ScalarFunction & substrait_func,
        const DB::ActionsDAG::Node * func_node,
        DB::ActionsDAG & actions_dag) const override
    {
        throw DB::Exception(DB::ErrorCodes::LOGICAL_ERROR, "convertNodeTypeIfNeeded is not implemented for NamedLambdaVariable");
    }
};

static FunctionParserRegister<NamedLambdaVariable> register_named_lambda_variable;
}