| // Licensed to the Apache Software Foundation (ASF) under one |
| // or more contributor license agreements. See the NOTICE file |
| // distributed with this work for additional information |
| // regarding copyright ownership. The ASF licenses this file |
| // to you under the Apache License, Version 2.0 (the |
| // "License"); you may not use this file except in compliance |
| // with the License. You may obtain a copy of the License at |
| // |
| // http://www.apache.org/licenses/LICENSE-2.0 |
| // |
| // Unless required by applicable law or agreed to in writing, |
| // software distributed under the License is distributed on an |
| // "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY |
| // KIND, either express or implied. See the License for the |
| // specific language governing permissions and limitations |
| // under the License. |
| |
| #include "gandiva/tree_expr_builder.h" |
| |
| #include <iostream> |
| #include <utility> |
| |
| #include "gandiva/decimal_type_util.h" |
| #include "gandiva/function_registry_common.h" |
| #include "gandiva/gandiva_aliases.h" |
| #include "gandiva/node.h" |
| |
| namespace gandiva { |
| |
| #define MAKE_LITERAL(atype, ctype) \ |
| NodePtr TreeExprBuilder::MakeLiteral(ctype value) { \ |
| return std::make_shared<LiteralNode>(atype, LiteralHolder(value), false); \ |
| } |
| |
| MAKE_LITERAL(arrow::boolean(), bool) |
| MAKE_LITERAL(arrow::int8(), int8_t) |
| MAKE_LITERAL(arrow::int16(), int16_t) |
| MAKE_LITERAL(arrow::int32(), int32_t) |
| MAKE_LITERAL(arrow::int64(), int64_t) |
| MAKE_LITERAL(arrow::uint8(), uint8_t) |
| MAKE_LITERAL(arrow::uint16(), uint16_t) |
| MAKE_LITERAL(arrow::uint32(), uint32_t) |
| MAKE_LITERAL(arrow::uint64(), uint64_t) |
| MAKE_LITERAL(arrow::float32(), float) |
| MAKE_LITERAL(arrow::float64(), double) |
| |
| NodePtr TreeExprBuilder::MakeStringLiteral(const std::string& value) { |
| return std::make_shared<LiteralNode>(arrow::utf8(), LiteralHolder(value), false); |
| } |
| |
| NodePtr TreeExprBuilder::MakeBinaryLiteral(const std::string& value) { |
| return std::make_shared<LiteralNode>(arrow::binary(), LiteralHolder(value), false); |
| } |
| |
| NodePtr TreeExprBuilder::MakeDecimalLiteral(const DecimalScalar128& value) { |
| return std::make_shared<LiteralNode>( |
| arrow::decimal128(value.precision(), value.scale()), LiteralHolder(value), false); |
| } |
| |
| NodePtr TreeExprBuilder::MakeNull(DataTypePtr data_type) { |
| static const std::string empty; |
| |
| if (data_type == nullptr) { |
| return nullptr; |
| } |
| |
| switch (data_type->id()) { |
| case arrow::Type::BOOL: |
| return std::make_shared<LiteralNode>(data_type, LiteralHolder(false), true); |
| case arrow::Type::INT8: |
| return std::make_shared<LiteralNode>(data_type, LiteralHolder((int8_t)0), true); |
| case arrow::Type::INT16: |
| return std::make_shared<LiteralNode>(data_type, LiteralHolder((int16_t)0), true); |
| return std::make_shared<LiteralNode>(data_type, LiteralHolder((int32_t)0), true); |
| case arrow::Type::UINT8: |
| return std::make_shared<LiteralNode>(data_type, LiteralHolder((uint8_t)0), true); |
| case arrow::Type::UINT16: |
| return std::make_shared<LiteralNode>(data_type, LiteralHolder((uint16_t)0), true); |
| case arrow::Type::UINT32: |
| return std::make_shared<LiteralNode>(data_type, LiteralHolder((uint32_t)0), true); |
| case arrow::Type::UINT64: |
| return std::make_shared<LiteralNode>(data_type, LiteralHolder((uint64_t)0), true); |
| case arrow::Type::FLOAT: |
| return std::make_shared<LiteralNode>(data_type, |
| LiteralHolder(static_cast<float>(0)), true); |
| case arrow::Type::DOUBLE: |
| return std::make_shared<LiteralNode>(data_type, |
| LiteralHolder(static_cast<double>(0)), true); |
| case arrow::Type::STRING: |
| case arrow::Type::BINARY: |
| return std::make_shared<LiteralNode>(data_type, LiteralHolder(empty), true); |
| case arrow::Type::INT32: |
| case arrow::Type::DATE32: |
| case arrow::Type::TIME32: |
| case arrow::Type::INTERVAL_MONTHS: |
| return std::make_shared<LiteralNode>(data_type, LiteralHolder((int32_t)0), true); |
| case arrow::Type::INT64: |
| case arrow::Type::DATE64: |
| case arrow::Type::TIME64: |
| case arrow::Type::TIMESTAMP: |
| case arrow::Type::INTERVAL_DAY_TIME: |
| return std::make_shared<LiteralNode>(data_type, LiteralHolder((int64_t)0), true); |
| case arrow::Type::DECIMAL: { |
| std::shared_ptr<arrow::DecimalType> decimal_type = |
| arrow::internal::checked_pointer_cast<arrow::DecimalType>(data_type); |
| DecimalScalar128 literal(decimal_type->precision(), decimal_type->scale()); |
| return std::make_shared<LiteralNode>(data_type, LiteralHolder(literal), true); |
| } |
| default: |
| return nullptr; |
| } |
| } |
| |
| NodePtr TreeExprBuilder::MakeField(FieldPtr field) { |
| return NodePtr(new FieldNode(field)); |
| } |
| |
| NodePtr TreeExprBuilder::MakeFunction(const std::string& name, const NodeVector& params, |
| DataTypePtr result_type) { |
| if (result_type == nullptr) { |
| return nullptr; |
| } |
| return std::make_shared<FunctionNode>(name, params, result_type); |
| } |
| |
| NodePtr TreeExprBuilder::MakeIf(NodePtr condition, NodePtr then_node, NodePtr else_node, |
| DataTypePtr result_type) { |
| if (condition == nullptr || then_node == nullptr || else_node == nullptr || |
| result_type == nullptr) { |
| return nullptr; |
| } |
| return std::make_shared<IfNode>(condition, then_node, else_node, result_type); |
| } |
| |
| NodePtr TreeExprBuilder::MakeAnd(const NodeVector& children) { |
| return std::make_shared<BooleanNode>(BooleanNode::AND, children); |
| } |
| |
| NodePtr TreeExprBuilder::MakeOr(const NodeVector& children) { |
| return std::make_shared<BooleanNode>(BooleanNode::OR, children); |
| } |
| |
// Debug switch, off by default. Change the initializer to true to have the
// expression/condition builders below print the tree they are given to std::cout.
static bool print_expr{false};
| |
| ExpressionPtr TreeExprBuilder::MakeExpression(NodePtr root_node, FieldPtr result_field) { |
| if (result_field == nullptr) { |
| return nullptr; |
| } |
| if (print_expr) { |
| std::cout << "Expression: " << root_node->ToString() << "\n"; |
| } |
| return ExpressionPtr(new Expression(root_node, result_field)); |
| } |
| |
| ExpressionPtr TreeExprBuilder::MakeExpression(const std::string& function, |
| const FieldVector& in_fields, |
| FieldPtr out_field) { |
| if (out_field == nullptr) { |
| return nullptr; |
| } |
| std::vector<NodePtr> field_nodes; |
| for (auto& field : in_fields) { |
| auto node = MakeField(field); |
| field_nodes.push_back(node); |
| } |
| auto func_node = MakeFunction(function, field_nodes, out_field->type()); |
| return MakeExpression(func_node, out_field); |
| } |
| |
| ConditionPtr TreeExprBuilder::MakeCondition(NodePtr root_node) { |
| if (root_node == nullptr) { |
| return nullptr; |
| } |
| if (print_expr) { |
| std::cout << "Condition: " << root_node->ToString() << "\n"; |
| } |
| |
| return ConditionPtr(new Condition(root_node)); |
| } |
| |
| ConditionPtr TreeExprBuilder::MakeCondition(const std::string& function, |
| const FieldVector& in_fields) { |
| std::vector<NodePtr> field_nodes; |
| for (auto& field : in_fields) { |
| auto node = MakeField(field); |
| field_nodes.push_back(node); |
| } |
| |
| auto func_node = MakeFunction(function, field_nodes, arrow::boolean()); |
| return ConditionPtr(new Condition(func_node)); |
| } |
| |
| NodePtr TreeExprBuilder::MakeInExpressionDecimal( |
| NodePtr node, std::unordered_set<gandiva::DecimalScalar128>& constants) { |
| int32_t precision = 0; |
| int32_t scale = 0; |
| if (!constants.empty()) { |
| precision = constants.begin()->precision(); |
| scale = constants.begin()->scale(); |
| } |
| return std::make_shared<InExpressionNode<gandiva::DecimalScalar128>>(node, constants, |
| precision, scale); |
| } |
| |
| #define MAKE_IN(NAME, ctype, type) \ |
| NodePtr TreeExprBuilder::MakeInExpression##NAME( \ |
| NodePtr node, const std::unordered_set<ctype>& values) { \ |
| return std::make_shared<InExpressionNode<ctype>>(node, values, type); \ |
| } |
| |
| MAKE_IN(Int32, int32_t, int32()); |
| MAKE_IN(Int64, int64_t, int64()); |
| MAKE_IN(Date32, int32_t, date32()); |
| MAKE_IN(Date64, int64_t, date64()); |
| MAKE_IN(TimeStamp, int64_t, timestamp()); |
| MAKE_IN(Time32, int32_t, time32()); |
| MAKE_IN(Time64, int64_t, time64()); |
| MAKE_IN(Float, float, float32()); |
| MAKE_IN(Double, double, float64()); |
| MAKE_IN(String, std::string, utf8()); |
| MAKE_IN(Binary, std::string, binary()); |
| |
| } // namespace gandiva |