| /* |
| * Licensed to the Apache Software Foundation (ASF) under one or more |
| * contributor license agreements. See the NOTICE file distributed with |
| * this work for additional information regarding copyright ownership. |
| * The ASF licenses this file to You under the Apache License, Version 2.0 |
| * (the "License"); you may not use this file except in compliance with |
| * the License. You may obtain a copy of the License at |
| * |
| * http://www.apache.org/licenses/LICENSE-2.0 |
| * |
| * Unless required by applicable law or agreed to in writing, software |
| * distributed under the License is distributed on an "AS IS" BASIS, |
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| * See the License for the specific language governing permissions and |
| * limitations under the License. |
| */ |
| #pragma once |
| #include <Core/Field.h> |
| #include <DataTypes/IDataType.h> |
| #include <Functions/FunctionFactory.h> |
| #include <Interpreters/ActionsDAG.h> |
| #include <Parser/ParserContext.h> |
| #include <Parser/SerializedPlanParser.h> |
| #include <base/types.h> |
| #include <substrait/algebra.pb.h> |
| #include <Poco/Logger.h> |
| #include <Common/IFactoryWithAliases.h> |
| |
| namespace local_engine |
| { |
| class ExpressionParser; |
| class AggregateFunctionParser |
| { |
| public: |
| /// CommonFunctionInfo is commmon representation for different function types, |
| struct CommonFunctionInfo |
| { |
| /// basic common function informations |
| using Arguments = google::protobuf::RepeatedPtrField<substrait::FunctionArgument>; |
| using SortFields = google::protobuf::RepeatedPtrField<substrait::SortField>; |
| Int32 function_ref; |
| Arguments arguments; |
| substrait::Type output_type; |
| |
| /// Following is for aggregate and window functions. |
| substrait::AggregationPhase phase; |
| SortFields sort_fields; |
| // only be used in aggregate functions at present. |
| substrait::Expression filter; |
| bool is_in_window = false; |
| bool is_aggregate_function = false; |
| bool has_filter = false; |
| |
| CommonFunctionInfo() { function_ref = -1; } |
| |
| CommonFunctionInfo(const substrait::WindowRel::Measure & win_measure) |
| : function_ref(win_measure.measure().function_reference()) |
| , arguments(win_measure.measure().arguments()) |
| , output_type(win_measure.measure().output_type()) |
| , phase(win_measure.measure().phase()) |
| , sort_fields(win_measure.measure().sorts()) |
| { |
| is_in_window = true; |
| is_aggregate_function = true; |
| } |
| |
| CommonFunctionInfo(const substrait::AggregateRel::Measure & agg_measure) |
| : function_ref(agg_measure.measure().function_reference()) |
| , arguments(agg_measure.measure().arguments()) |
| , output_type(agg_measure.measure().output_type()) |
| , phase(agg_measure.measure().phase()) |
| , sort_fields(agg_measure.measure().sorts()) |
| , filter(agg_measure.filter()) |
| { |
| has_filter = agg_measure.has_filter(); |
| is_aggregate_function = true; |
| } |
| }; |
| |
| AggregateFunctionParser(ParserContextPtr parser_context_); |
| |
| virtual ~AggregateFunctionParser(); |
| |
| virtual String getName() const = 0; |
| |
| /// In some special cases, different arguments size or different arguments types may refer to different |
| /// CH function implementation. |
| virtual String getCHFunctionName(const CommonFunctionInfo & func_info) const = 0; |
| |
| /// In most cases, arguments size and types are enough to determine the CH function implementation. |
| /// It is only be used in TypeParser::buildBlockFromNamedStruct |
| /// Users are allowed to modify arg types to make it fit for AggregateFunctionFactory::instance().get(...) in TypeParser::buildBlockFromNamedStruct |
| virtual String getCHFunctionName(DB::DataTypes & args) const = 0; |
| |
| /// Do some preprojections for the function arguments, and return the necessary arguments for the CH function. |
| virtual DB::ActionsDAG::NodeRawConstPtrs |
| parseFunctionArguments(const CommonFunctionInfo & func_info, DB::ActionsDAG & actions_dag) const; |
| |
| // `PartialMerge` is applied on the merging stages. |
| // `If` is applied when the aggreate function has a filter. This should only happen on the 1st stage. |
| // If no combinator is applied, return (ch_func_name,arg_column_types) |
| virtual std::pair<String, DB::DataTypes> |
| tryApplyCHCombinator(const CommonFunctionInfo & func_info, const String & ch_func_name, const DB::DataTypes & arg_column_types) const; |
| |
| /// Make a postprojection for the function result. |
| virtual const DB::ActionsDAG::Node * convertNodeTypeIfNeeded( |
| const CommonFunctionInfo & func_info, |
| const DB::ActionsDAG::Node * func_node, |
| DB::ActionsDAG & actions_dag, |
| bool with_nullability) const; |
| |
| /// Parameters are only used in aggregate functions at present. e.g. percentiles(0.5)(x). |
| /// 0.5 is the parameter of percentiles function. |
| virtual DB::Array parseFunctionParameters( |
| const CommonFunctionInfo & /*func_info*/, DB::ActionsDAG::NodeRawConstPtrs & /*arg_nodes*/, DB::ActionsDAG & /*actions_dag*/) const |
| { |
| return DB::Array(); |
| } |
| |
| /// Return the default parameters of the function. It's useful for creating a default function instance. |
| virtual DB::Array getDefaultFunctionParameters() const { return DB::Array(); } |
| |
| protected: |
| DB::ContextPtr getContext() const { return parser_context->queryContext(); } |
| |
| String getUniqueName(const String & name) const; |
| |
| const DB::ActionsDAG::Node * |
| addColumnToActionsDAG(DB::ActionsDAG & actions_dag, const DB::DataTypePtr & type, const DB::Field & field) const; |
| |
| const DB::ActionsDAG::Node * |
| toFunctionNode(DB::ActionsDAG & action_dag, const String & func_name, const DB::ActionsDAG::NodeRawConstPtrs & args) const; |
| |
| const DB::ActionsDAG::Node * toFunctionNode( |
| DB::ActionsDAG & action_dag, |
| const String & func_name, |
| const String & result_name, |
| const DB::ActionsDAG::NodeRawConstPtrs & args) const; |
| |
| const DB::ActionsDAG::Node * parseExpression(DB::ActionsDAG & actions_dag, const substrait::Expression & rel) const; |
| |
| std::pair<DB::DataTypePtr, DB::Field> parseLiteral(const substrait::Expression_Literal & literal) const; |
| |
| const DB::ActionsDAG::Node * convertNanToNullIfNeed( |
| const CommonFunctionInfo & func_info, const DB::ActionsDAG::Node * func_node, DB::ActionsDAG & actions_dag) const; |
| |
| ParserContextPtr parser_context; |
| std::unique_ptr<ExpressionParser> expression_parser; |
| Poco::Logger * logger = &Poco::Logger::get("AggregateFunctionParserFactory"); |
| }; |
| |
/// Shared-ownership pointer to an aggregate function parser.
using AggregateFunctionParserPtr = std::shared_ptr<AggregateFunctionParser>;
/// Callable that produces a parser from a parser context; it is the template argument of
/// AggregateFunctionParserFactory's base class.
using AggregateFunctionParserCreator = std::function<AggregateFunctionParserPtr(ParserContextPtr)>;
| |
| class AggregateFunctionParserFactory : public DB::IFactoryWithAliases<AggregateFunctionParserCreator> |
| { |
| public: |
| using Parsers = std::unordered_map<String, Value>; |
| |
| static AggregateFunctionParserFactory & instance(); |
| |
| void registerAggregateFunctionParser(const String & name, Value value); |
| |
| template <typename Parser> |
| void registerAggregateFunctionParser(const String & name) |
| { |
| auto creator |
| = [](ParserContextPtr parser_context) -> AggregateFunctionParserPtr { return std::make_shared<Parser>(parser_context); }; |
| registerAggregateFunctionParser(name, creator); |
| } |
| |
| AggregateFunctionParserPtr get(const String & name, ParserContextPtr parser_context) const; |
| AggregateFunctionParserPtr tryGet(const String & name, ParserContextPtr parser_context) const; |
| |
| const Parsers & getMap() const override { return parsers; } |
| |
| private: |
| Parsers parsers; |
| |
| /// Always empty |
| Parsers case_insensitive_parsers; |
| |
| const Parsers & getCaseInsensitiveMap() const override { return case_insensitive_parsers; } |
| |
| String getFactoryName() const override { return "AggregateFunctionParserFactory"; } |
| }; |
| |
| template <typename Parser> |
| struct AggregateFunctionParserRegister |
| { |
| AggregateFunctionParserRegister() { AggregateFunctionParserFactory::instance().registerAggregateFunctionParser<Parser>(Parser::name); } |
| }; |
| } |