| /* |
| * Licensed to the Apache Software Foundation (ASF) under one |
| * or more contributor license agreements. See the NOTICE file |
| * distributed with this work for additional information |
| * regarding copyright ownership. The ASF licenses this file |
| * to you under the Apache License, Version 2.0 (the |
| * "License"); you may not use this file except in compliance |
| * with the License. You may obtain a copy of the License at |
| * |
| * http://www.apache.org/licenses/LICENSE-2.0 |
| * |
| * Unless required by applicable law or agreed to in writing, |
| * software distributed under the License is distributed on an |
| * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY |
| * KIND, either express or implied. See the License for the |
| * specific language governing permissions and limitations |
| * under the License. |
| */ |
| |
| #pragma once |
| |
| /// \file iceberg/expression/expression_visitor.h |
| /// Visitor pattern implementation for traversing Iceberg expression trees. |
| |
| #include <concepts> |
| #include <memory> |
| #include <typeinfo> |
| |
| #include "iceberg/expression/aggregate.h" |
| #include "iceberg/expression/expression.h" |
| #include "iceberg/expression/literal.h" |
| #include "iceberg/expression/predicate.h" |
| #include "iceberg/expression/term.h" |
| #include "iceberg/iceberg_export.h" |
| #include "iceberg/result.h" |
| #include "iceberg/util/checked_cast.h" |
| #include "iceberg/util/macros.h" |
| |
| namespace iceberg { |
| |
| /// \brief Base visitor for traversing expression trees. |
| /// |
| /// This visitor traverses expression trees in postorder traversal and calls appropriate |
| /// visitor methods for each node. Subclasses can override specific methods to implement |
| /// custom behavior. |
| /// |
| /// \tparam R The return type produced by visitor methods |
| template <typename R> |
| class ICEBERG_EXPORT ExpressionVisitor { |
| using ParamType = std::conditional_t<std::is_fundamental_v<R>, R, const R&>; |
| |
| public: |
| virtual ~ExpressionVisitor() = default; |
| |
| /// \brief Visit a True expression (always evaluates to true). |
| virtual Result<R> AlwaysTrue() = 0; |
| |
| /// \brief Visit a False expression (always evaluates to false). |
| virtual Result<R> AlwaysFalse() = 0; |
| |
| /// \brief Visit a Not expression. |
| /// \param child_result The result from visiting the child expression |
| virtual Result<R> Not(ParamType child_result) = 0; |
| |
| /// \brief Visit an And expression. |
| /// \param left_result The result from visiting the left child |
| /// \param right_result The result from visiting the right child |
| virtual Result<R> And(ParamType left_result, ParamType right_result) = 0; |
| |
| /// \brief Visit an Or expression. |
| /// \param left_result The result from visiting the left child |
| /// \param right_result The result from visiting the right child |
| virtual Result<R> Or(ParamType left_result, ParamType right_result) = 0; |
| |
| /// \brief Visit a bound predicate. |
| /// \param pred The bound predicate to visit |
| virtual Result<R> Predicate(const std::shared_ptr<BoundPredicate>& pred) = 0; |
| |
| /// \brief Visit an unbound predicate. |
| /// \param pred The unbound predicate to visit |
| virtual Result<R> Predicate(const std::shared_ptr<UnboundPredicate>& pred) = 0; |
| |
| /// \brief Visit a bound aggregate. |
| /// \param aggregate The bound aggregate to visit. |
| virtual Result<R> Aggregate(const std::shared_ptr<BoundAggregate>& aggregate) { |
| ICEBERG_DCHECK(aggregate != nullptr, "Bound aggregate cannot be null"); |
| return NotSupported("Visitor {} does not support bound aggregate", |
| typeid(*this).name()); |
| } |
| |
| /// \brief Visit an unbound aggregate. |
| /// \param aggregate The unbound aggregate to visit. |
| virtual Result<R> Aggregate(const std::shared_ptr<UnboundAggregate>& aggregate) { |
| ICEBERG_DCHECK(aggregate != nullptr, "Unbound aggregate cannot be null"); |
| return NotSupported("Visitor {} does not support unbound aggregate", |
| typeid(*this).name()); |
| } |
| }; |
| |
| /// \brief Visitor for bound expressions. |
| /// |
| /// This visitor is for traversing bound expression trees. |
| /// |
| /// \tparam R The return type produced by visitor methods |
| template <typename R> |
| class ICEBERG_EXPORT BoundVisitor : public ExpressionVisitor<R> { |
| public: |
| ~BoundVisitor() override = default; |
| |
| /// \brief Visit an IS_NULL bound expression. |
| /// \param expr The bound expression being tested |
| virtual Result<R> IsNull(const std::shared_ptr<Bound>& expr) = 0; |
| |
| /// \brief Visit a NOT_NULL bound expression. |
| /// \param expr The bound expression being tested |
| virtual Result<R> NotNull(const std::shared_ptr<Bound>& expr) = 0; |
| |
| /// \brief Visit an IS_NAN bound expression. |
| /// \param expr The bound expression being tested |
| virtual Result<R> IsNaN(const std::shared_ptr<Bound>& expr) { |
| return NotSupported("IsNaN operation is not supported by this visitor"); |
| } |
| |
| /// \brief Visit a NOT_NAN bound expression. |
| /// \param expr The bound expression being tested |
| virtual Result<R> NotNaN(const std::shared_ptr<Bound>& expr) { |
| return NotSupported("NotNaN operation is not supported by this visitor"); |
| } |
| |
| /// \brief Visit a less-than bound expression. |
| /// \param expr The bound expression being tested |
| /// \param lit The literal value to compare against |
| virtual Result<R> Lt(const std::shared_ptr<Bound>& expr, const Literal& lit) = 0; |
| |
| /// \brief Visit a less-than-or-equal bound expression. |
| /// \param expr The bound expression being tested |
| /// \param lit The literal value to compare against |
| virtual Result<R> LtEq(const std::shared_ptr<Bound>& expr, const Literal& lit) = 0; |
| |
| /// \brief Visit a greater-than bound expression. |
| /// \param expr The bound expression being tested |
| /// \param lit The literal value to compare against |
| virtual Result<R> Gt(const std::shared_ptr<Bound>& expr, const Literal& lit) = 0; |
| |
| /// \brief Visit a greater-than-or-equal bound expression. |
| /// \param expr The bound expression being tested |
| /// \param lit The literal value to compare against |
| virtual Result<R> GtEq(const std::shared_ptr<Bound>& expr, const Literal& lit) = 0; |
| |
| /// \brief Visit an equality bound expression. |
| /// \param expr The bound expression being tested |
| /// \param lit The literal value to compare against |
| virtual Result<R> Eq(const std::shared_ptr<Bound>& expr, const Literal& lit) = 0; |
| |
| /// \brief Visit a not-equal bound expression. |
| /// \param expr The bound expression being tested |
| /// \param lit The literal value to compare against |
| virtual Result<R> NotEq(const std::shared_ptr<Bound>& expr, const Literal& lit) = 0; |
| |
| /// \brief Visit a starts-with bound expression. |
| /// \param expr The bound expression being tested |
| /// \param lit The literal value to check for prefix match |
| virtual Result<R> StartsWith([[maybe_unused]] const std::shared_ptr<Bound>& expr, |
| [[maybe_unused]] const Literal& lit) { |
| return NotSupported("StartsWith operation is not supported by this visitor"); |
| } |
| |
| /// \brief Visit a not-starts-with bound expression. |
| /// \param expr The bound expression being tested |
| /// \param lit The literal value to check for prefix match |
| virtual Result<R> NotStartsWith([[maybe_unused]] const std::shared_ptr<Bound>& expr, |
| [[maybe_unused]] const Literal& lit) { |
| return NotSupported("NotStartsWith operation is not supported by this visitor"); |
| } |
| |
| /// \brief Visit an IN set bound expression. |
| /// \param expr The bound expression being tested |
| /// \param literal_set The set of literal values to test membership |
| virtual Result<R> In( |
| [[maybe_unused]] const std::shared_ptr<Bound>& expr, |
| [[maybe_unused]] const BoundSetPredicate::LiteralSet& literal_set) { |
| return NotSupported("In operation is not supported by this visitor"); |
| } |
| |
| /// \brief Visit a NOT_IN set bound expression. |
| /// \param expr The bound expression being tested |
| /// \param literal_set The set of literal values to test membership |
| virtual Result<R> NotIn( |
| [[maybe_unused]] const std::shared_ptr<Bound>& expr, |
| [[maybe_unused]] const BoundSetPredicate::LiteralSet& literal_set) { |
| return NotSupported("NotIn operation is not supported by this visitor"); |
| } |
| |
| /// \brief Visit a bound predicate. |
| /// |
| /// This method dispatches to specific visitor methods based on the predicate |
| /// type and operation. |
| /// |
| /// \param pred The bound predicate to visit |
| Result<R> Predicate(const std::shared_ptr<BoundPredicate>& pred) override { |
| ICEBERG_DCHECK(pred != nullptr, "BoundPredicate cannot be null"); |
| |
| switch (pred->kind()) { |
| case BoundPredicate::Kind::kUnary: { |
| switch (pred->op()) { |
| case Expression::Operation::kIsNull: |
| return IsNull(pred->term()); |
| case Expression::Operation::kNotNull: |
| return NotNull(pred->term()); |
| case Expression::Operation::kIsNan: |
| return IsNaN(pred->term()); |
| case Expression::Operation::kNotNan: |
| return NotNaN(pred->term()); |
| default: |
| return InvalidExpression("Invalid operation for BoundUnaryPredicate: {}", |
| ToString(pred->op())); |
| } |
| } |
| case BoundPredicate::Kind::kLiteral: { |
| const auto& literal_pred = |
| internal::checked_cast<const BoundLiteralPredicate&>(*pred); |
| switch (pred->op()) { |
| case Expression::Operation::kLt: |
| return Lt(pred->term(), literal_pred.literal()); |
| case Expression::Operation::kLtEq: |
| return LtEq(pred->term(), literal_pred.literal()); |
| case Expression::Operation::kGt: |
| return Gt(pred->term(), literal_pred.literal()); |
| case Expression::Operation::kGtEq: |
| return GtEq(pred->term(), literal_pred.literal()); |
| case Expression::Operation::kEq: |
| return Eq(pred->term(), literal_pred.literal()); |
| case Expression::Operation::kNotEq: |
| return NotEq(pred->term(), literal_pred.literal()); |
| case Expression::Operation::kStartsWith: |
| return StartsWith(pred->term(), literal_pred.literal()); |
| case Expression::Operation::kNotStartsWith: |
| return NotStartsWith(pred->term(), literal_pred.literal()); |
| default: |
| return InvalidExpression("Invalid operation for BoundLiteralPredicate: {}", |
| ToString(pred->op())); |
| } |
| } |
| case BoundPredicate::Kind::kSet: { |
| const auto& set_pred = internal::checked_cast<const BoundSetPredicate&>(*pred); |
| switch (pred->op()) { |
| case Expression::Operation::kIn: |
| return In(pred->term(), set_pred.literal_set()); |
| case Expression::Operation::kNotIn: |
| return NotIn(pred->term(), set_pred.literal_set()); |
| default: |
| return InvalidExpression("Invalid operation for BoundSetPredicate: {}", |
| ToString(pred->op())); |
| } |
| } |
| } |
| |
| return InvalidExpression("Unsupported bound predicate: {}", pred->ToString()); |
| } |
| |
| /// \brief Visit an unbound predicate. |
| /// |
| /// \param pred The unbound predicate |
| Result<R> Predicate(const std::shared_ptr<UnboundPredicate>& pred) override { |
| ICEBERG_DCHECK(pred != nullptr, "UnboundPredicate cannot be null"); |
| return NotSupported("Not a bound predicate: {}", pred->ToString()); |
| } |
| }; |
| |
| /// \brief Traverse an expression tree with a visitor. |
| /// |
| /// This function traverses the given expression tree in postorder traversal and calls |
| /// appropriate visitor methods for each node. Results from child nodes are passed to |
| /// parent nodes. |
| /// |
| /// \tparam R The return type produced by the visitor |
| /// \tparam V The visitor type (must derive from ExpressionVisitor<R>) |
| /// \param expr The expression to traverse |
| /// \param visitor The visitor to use for traversal |
| /// \return The result produced by the visitor for the root expression node |
| template <typename R, typename V> |
| requires std::derived_from<V, ExpressionVisitor<R>> |
| Result<R> Visit(const std::shared_ptr<Expression>& expr, V& visitor) { |
| ICEBERG_DCHECK(expr != nullptr, "Expression cannot be null"); |
| |
| if (expr->is_bound_predicate()) { |
| return visitor.Predicate(std::dynamic_pointer_cast<BoundPredicate>(expr)); |
| } |
| |
| if (expr->is_unbound_predicate()) { |
| return visitor.Predicate(std::dynamic_pointer_cast<UnboundPredicate>(expr)); |
| } |
| |
| if (expr->is_bound_aggregate()) { |
| return visitor.Aggregate(std::dynamic_pointer_cast<BoundAggregate>(expr)); |
| } |
| |
| if (expr->is_unbound_aggregate()) { |
| return visitor.Aggregate(std::dynamic_pointer_cast<UnboundAggregate>(expr)); |
| } |
| |
| switch (expr->op()) { |
| case Expression::Operation::kTrue: |
| return visitor.AlwaysTrue(); |
| case Expression::Operation::kFalse: |
| return visitor.AlwaysFalse(); |
| case Expression::Operation::kNot: { |
| const auto& not_expr = internal::checked_pointer_cast<Not>(expr); |
| ICEBERG_ASSIGN_OR_RAISE(auto child_result, |
| (Visit<R, V>(not_expr->child(), visitor))); |
| return visitor.Not(std::move(child_result)); |
| } |
| case Expression::Operation::kAnd: { |
| const auto& and_expr = internal::checked_pointer_cast<And>(expr); |
| ICEBERG_ASSIGN_OR_RAISE(auto left_result, (Visit<R, V>(and_expr->left(), visitor))); |
| ICEBERG_ASSIGN_OR_RAISE(auto right_result, |
| (Visit<R, V>(and_expr->right(), visitor))); |
| return visitor.And(std::move(left_result), std::move(right_result)); |
| } |
| case Expression::Operation::kOr: { |
| const auto& or_expr = internal::checked_pointer_cast<Or>(expr); |
| ICEBERG_ASSIGN_OR_RAISE(auto left_result, (Visit<R, V>(or_expr->left(), visitor))); |
| ICEBERG_ASSIGN_OR_RAISE(auto right_result, |
| (Visit<R, V>(or_expr->right(), visitor))); |
| return visitor.Or(std::move(left_result), std::move(right_result)); |
| } |
| default: |
| return InvalidExpression("Unknown expression operation: {}", expr->ToString()); |
| } |
| } |
| |
| } // namespace iceberg |