blob: 7525026b218b45be5215aad5f014f6b7d48d81b1 [file] [log] [blame]
// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.
package iceberg
import (
"fmt"
"math"
"strings"
"github.com/google/uuid"
)
// BooleanExprVisitor is an interface for recursively visiting the nodes of a
// boolean expression. T is the type of the result produced for each node; the
// results computed for child nodes are passed to the method for their parent.
type BooleanExprVisitor[T any] interface {
// VisitTrue produces the result for an AlwaysTrue node.
VisitTrue() T
// VisitFalse produces the result for an AlwaysFalse node.
VisitFalse() T
// VisitNot receives the already computed result of a NotExpr's child.
VisitNot(childResult T) T
// VisitAnd receives the already computed results of an AndExpr's operands.
VisitAnd(left, right T) T
// VisitOr receives the already computed results of an OrExpr's operands.
VisitOr(left, right T) T
// VisitUnbound produces the result for a predicate that has not been bound.
VisitUnbound(UnboundPredicate) T
// VisitBound produces the result for a bound predicate.
VisitBound(BoundPredicate) T
}
// BoundBooleanExprVisitor builds on BooleanExprVisitor by adding interface
// methods for visiting bound expressions. Because literals are cast during
// binding, you can assume that the BoundTerm and the Literal passed
// to a method have the same type.
//
// VisitBoundPredicate dispatches a BoundPredicate to exactly one of these
// methods based on the predicate's operation.
type BoundBooleanExprVisitor[T any] interface {
BooleanExprVisitor[T]
// Set membership predicates: the term and the set of literals to test against.
VisitIn(BoundTerm, Set[Literal]) T
VisitNotIn(BoundTerm, Set[Literal]) T
// Unary predicates: only the term is supplied.
VisitIsNan(BoundTerm) T
VisitNotNan(BoundTerm) T
VisitIsNull(BoundTerm) T
VisitNotNull(BoundTerm) T
// Literal predicates: the term and the single literal it is compared with.
VisitEqual(BoundTerm, Literal) T
VisitNotEqual(BoundTerm, Literal) T
VisitGreaterEqual(BoundTerm, Literal) T
VisitGreater(BoundTerm, Literal) T
VisitLessEqual(BoundTerm, Literal) T
VisitLess(BoundTerm, Literal) T
VisitStartsWith(BoundTerm, Literal) T
VisitNotStartsWith(BoundTerm, Literal) T
}
// VisitExpr is a convenience function to use a given visitor to visit all parts of
// a boolean expression. Children are visited before their parent (left operand
// before right operand) and the values returned for the children are passed to the
// parent's method, effectively "bubbling up" the results.
//
// Visitors signal failure by panicking. Any panic raised while visiting is
// recovered here and reported through err, in which case res is the zero value
// of T:
//   - a panic carrying an error is returned unchanged,
//   - a panic carrying a string is wrapped into an error,
//   - any other panic value is formatted into an error so that it is never
//     silently discarded.
func VisitExpr[T any](expr BooleanExpression, visitor BooleanExprVisitor[T]) (res T, err error) {
	defer func() {
		r := recover()
		if r == nil {
			return
		}

		switch e := r.(type) {
		case string:
			err = fmt.Errorf("error encountered during visitExpr: %s", e)
		case error:
			err = e
		default:
			// Without this branch an unexpected panic value would be
			// swallowed and the caller would see (zero value, nil).
			err = fmt.Errorf("error encountered during visitExpr: %v", e)
		}
	}()

	return visitBoolExpr(expr, visitor), err
}
// visitBoolExpr walks e depth-first. The operands of Not/And/Or nodes are
// visited first (left before right) and their results are handed to the
// matching visitor method; leaf nodes are passed to the visitor directly.
//
// It panics with an error wrapping ErrNotImplemented when e is not one of the
// recognised node kinds.
func visitBoolExpr[T any](e BooleanExpression, visitor BooleanExprVisitor[T]) T {
	switch node := e.(type) {
	case AlwaysFalse:
		return visitor.VisitFalse()
	case AlwaysTrue:
		return visitor.VisitTrue()
	case AndExpr:
		lhs := visitBoolExpr(node.left, visitor)
		rhs := visitBoolExpr(node.right, visitor)
		return visitor.VisitAnd(lhs, rhs)
	case OrExpr:
		lhs := visitBoolExpr(node.left, visitor)
		rhs := visitBoolExpr(node.right, visitor)
		return visitor.VisitOr(lhs, rhs)
	case NotExpr:
		return visitor.VisitNot(visitBoolExpr(node.child, visitor))
	case UnboundPredicate:
		return visitor.VisitUnbound(node)
	case BoundPredicate:
		return visitor.VisitBound(node)
	default:
		panic(fmt.Errorf("%w: VisitBooleanExpression type %s", ErrNotImplemented, e))
	}
}
// VisitBoundPredicate dispatches a bound predicate to the visitor method that
// corresponds to the predicate's operation. It exists so that the VisitBound
// method of a BoundBooleanExprVisitor can be implemented by simply calling
// iceberg.VisitBoundPredicate(pred, this).
//
// Set operations require e to be a BoundSetPredicate and literal operations
// require it to be a BoundLiteralPredicate; the unchecked type assertion panics
// otherwise. An operation without a matching visitor method panics with an
// error wrapping ErrNotImplemented.
func VisitBoundPredicate[T any](e BoundPredicate, visitor BoundBooleanExprVisitor[T]) T {
	// The operand accessors are resolved lazily so that the type assertion
	// only happens for operations that actually need that operand.
	literalSet := func() Set[Literal] { return e.(BoundSetPredicate).Literals() }
	literal := func() Literal { return e.(BoundLiteralPredicate).Literal() }

	switch e.Op() {
	// set membership
	case OpIn:
		return visitor.VisitIn(e.Term(), literalSet())
	case OpNotIn:
		return visitor.VisitNotIn(e.Term(), literalSet())

	// unary checks
	case OpIsNan:
		return visitor.VisitIsNan(e.Term())
	case OpNotNan:
		return visitor.VisitNotNan(e.Term())
	case OpIsNull:
		return visitor.VisitIsNull(e.Term())
	case OpNotNull:
		return visitor.VisitNotNull(e.Term())

	// comparisons against a single literal
	case OpEQ:
		return visitor.VisitEqual(e.Term(), literal())
	case OpNEQ:
		return visitor.VisitNotEqual(e.Term(), literal())
	case OpGTEQ:
		return visitor.VisitGreaterEqual(e.Term(), literal())
	case OpGT:
		return visitor.VisitGreater(e.Term(), literal())
	case OpLTEQ:
		return visitor.VisitLessEqual(e.Term(), literal())
	case OpLT:
		return visitor.VisitLess(e.Term(), literal())
	case OpStartsWith:
		return visitor.VisitStartsWith(e.Term(), literal())
	case OpNotStartsWith:
		return visitor.VisitNotStartsWith(e.Term(), literal())

	default:
		panic(fmt.Errorf("%w: unhandled bound predicate type: %s", ErrNotImplemented, e))
	}
}
// BindExpr binds every unbound predicate found in expr against the schema s,
// honouring caseSensitive for the binding. The result is typed as a
// BooleanExpression because binding may simplify the expression all the way
// down to AlwaysTrue/AlwaysFalse. Any failure raised while binding is returned
// as the error.
func BindExpr(s *Schema, expr BooleanExpression, caseSensitive bool) (BooleanExpression, error) {
	binder := &bindVisitor{
		schema:        s,
		caseSensitive: caseSensitive,
	}

	return VisitExpr[BooleanExpression](expr, binder)
}
// bindVisitor is the visitor BindExpr uses to rebuild an expression with every
// unbound predicate bound against schema.
type bindVisitor struct {
	// schema is passed to each predicate's Bind call.
	schema *Schema
	// caseSensitive is forwarded unchanged to each predicate's Bind call.
	caseSensitive bool
}

// VisitTrue returns AlwaysTrue; constants need no binding.
func (b *bindVisitor) VisitTrue() BooleanExpression {
	return AlwaysTrue{}
}

// VisitFalse returns AlwaysFalse; constants need no binding.
func (b *bindVisitor) VisitFalse() BooleanExpression {
	return AlwaysFalse{}
}

// VisitNot rebuilds the negation around the already bound child.
func (b *bindVisitor) VisitNot(child BooleanExpression) BooleanExpression {
	negated := NewNot(child)
	return negated
}

// VisitAnd rebuilds the conjunction from the already bound operands.
func (b *bindVisitor) VisitAnd(left, right BooleanExpression) BooleanExpression {
	conjunction := NewAnd(left, right)
	return conjunction
}

// VisitOr rebuilds the disjunction from the already bound operands.
func (b *bindVisitor) VisitOr(left, right BooleanExpression) BooleanExpression {
	disjunction := NewOr(left, right)
	return disjunction
}

// VisitUnbound binds pred against the visitor's schema. A binding failure is
// raised as a panic carrying the error.
func (b *bindVisitor) VisitUnbound(pred UnboundPredicate) BooleanExpression {
	bound, err := pred.Bind(b.schema, b.caseSensitive)
	if err == nil {
		return bound
	}

	panic(err)
}

// VisitBound rejects predicates that are already bound by panicking with an
// error wrapping ErrInvalidArgument.
func (b *bindVisitor) VisitBound(pred BoundPredicate) BooleanExpression {
	panic(fmt.Errorf("%w: found already bound predicate: %s", ErrInvalidArgument, pred))
}
// ExpressionEvaluator binds the given expression to the schema s and returns a
// function that evaluates the bound expression against a struct-like value
// which operates like and matches that schema. A binding failure is returned
// as the error and no function is produced.
//
// NOTE(review): every call of the returned function goes through the same
// exprEvaluator value, whose current row is overwritten by Eval; confirm
// callers do not invoke it concurrently.
func ExpressionEvaluator(s *Schema, unbound BooleanExpression, caseSensitive bool) (func(structLike) (bool, error), error) {
	boundExpr, bindErr := BindExpr(s, unbound, caseSensitive)
	if bindErr != nil {
		return nil, bindErr
	}

	evaluator := &exprEvaluator{bound: boundExpr}

	return evaluator.Eval, nil
}
// exprEvaluator evaluates a bound boolean expression against one struct-like
// value at a time. It implements the visitor methods that VisitExpr and
// VisitBoundPredicate call.
type exprEvaluator struct {
// bound is the expression that Eval walks on every call.
bound BooleanExpression
// st is the value currently being evaluated; Eval overwrites it on each call.
st structLike
}
// Eval records st as the value under evaluation and then walks the bound
// expression with the evaluator itself as the visitor. Panics raised during
// the walk are returned as the error by VisitExpr.
func (e *exprEvaluator) Eval(st structLike) (bool, error) {
	e.st = st

	return VisitExpr[bool](e.bound, e)
}
// VisitUnbound always panics: only bound predicates can be evaluated against
// a value.
func (*exprEvaluator) VisitUnbound(_ UnboundPredicate) bool {
	panic("found unbound predicate when evaluating expression")
}
// VisitBound evaluates a bound predicate by dispatching it, based on its
// operation, to the matching Visit* method of this evaluator.
func (e *exprEvaluator) VisitBound(pred BoundPredicate) bool {
	return VisitBoundPredicate[bool](pred, e)
}
// VisitTrue evaluates the constant true expression.
func (*exprEvaluator) VisitTrue() bool {
	return true
}

// VisitFalse evaluates the constant false expression.
func (*exprEvaluator) VisitFalse() bool {
	return false
}

// VisitNot inverts the already evaluated child result.
func (*exprEvaluator) VisitNot(child bool) bool {
	return !child
}

// VisitAnd is true only when both already evaluated operands are true.
func (*exprEvaluator) VisitAnd(left, right bool) bool {
	if !left {
		return false
	}

	return right
}

// VisitOr is true when at least one already evaluated operand is true.
func (*exprEvaluator) VisitOr(left, right bool) bool {
	if left {
		return true
	}

	return right
}
// VisitIn reports whether the value the term yields for the current struct is
// a member of literals. A null (invalid) value is never a member.
func (e *exprEvaluator) VisitIn(term BoundTerm, literals Set[Literal]) bool {
	if value := term.evalToLiteral(e.st); value.Valid {
		return literals.Contains(value.Val)
	}

	return false
}
// VisitNotIn is the exact complement of VisitIn, so a null value (which is
// never "in") is reported as "not in".
func (e *exprEvaluator) VisitNotIn(term BoundTerm, literals Set[Literal]) bool {
	member := e.VisitIn(term, literals)

	return !member
}
// VisitIsNan reports whether the term evaluates to NaN for the current struct.
// Only float32 and float64 terms can be NaN; a null value or a term of any
// other type yields false.
func (e *exprEvaluator) VisitIsNan(term BoundTerm) bool {
	switch term.Type().(type) {
	case Float32Type:
		v := term.(bound[float32]).eval(e.st)
		return v.Valid && math.IsNaN(float64(v.Val))
	case Float64Type:
		v := term.(bound[float64]).eval(e.st)
		return v.Valid && math.IsNaN(v.Val)
	default:
		return false
	}
}
// VisitNotNan is the exact complement of VisitIsNan, so null values and
// non-float terms are reported as "not NaN".
func (e *exprEvaluator) VisitNotNan(term BoundTerm) bool {
	isNan := e.VisitIsNan(term)

	return !isNan
}

// VisitIsNull reports whether the term is null for the current struct.
func (e *exprEvaluator) VisitIsNull(term BoundTerm) bool {
	isNull := term.evalIsNull(e.st)

	return isNull
}

// VisitNotNull reports whether the term is non-null for the current struct.
func (e *exprEvaluator) VisitNotNull(term BoundTerm) bool {
	isNull := term.evalIsNull(e.st)

	return !isNull
}
// nullsFirstCmp orders two optional values with nulls sorting before every
// non-null value: two nulls are equal (0), a null is less than a value (-1),
// a value is greater than a null (1). cmp is only consulted when both values
// are present.
func nullsFirstCmp[T LiteralType](cmp Comparator[T], v1, v2 Optional[T]) int {
	switch {
	case v1.Valid && v2.Valid:
		return cmp(v1.Val, v2.Val)
	case v1.Valid:
		// only v2 is null
		return 1
	case v2.Valid:
		// only v1 is null
		return -1
	default:
		// both are null
		return 0
	}
}
// typedCmp evaluates term against st as a value of type T and compares it with
// lit using the literal's own comparator, with nulls ordered first (see
// nullsFirstCmp). The result is negative, zero or positive when the term's
// value is respectively less than, equal to or greater than the literal.
//
// A nil lit is treated as a null right-hand side. term must implement
// bound[T] and a non-nil lit must implement TypedLiteral[T]; the unchecked
// type assertions panic otherwise.
func typedCmp[T LiteralType](st structLike, term BoundTerm, lit Literal) int {
	v := term.(bound[T]).eval(st)

	var (
		l   Optional[T]
		cmp Comparator[T]
	)

	// The assertion has to happen inside the nil check: asserting a nil
	// interface panics, which would make the null handling unreachable.
	if lit != nil {
		rhs := lit.(TypedLiteral[T])
		l.Valid = true
		l.Val = rhs.Value()
		cmp = rhs.Comparator()
	}

	// With a null literal cmp stays nil, which is fine because
	// nullsFirstCmp only calls the comparator when both sides are valid.
	return nullsFirstCmp(cmp, v, l)
}
// doCmp compares the value term yields for st with lit and returns the
// ordering as a negative, zero or positive number. It selects the typed
// comparison from the term's type and panics with ErrType when that type has
// no comparison defined here.
func doCmp(st structLike, term BoundTerm, lit Literal) int {
	// Binding already cast and converted the literal to match the term, so
	// the term's type alone decides which Go type both sides are asserted to.
	switch term.Type().(type) {
	case BooleanType:
		return typedCmp[bool](st, term, lit)

	// numeric types
	case Int32Type:
		return typedCmp[int32](st, term, lit)
	case Int64Type:
		return typedCmp[int64](st, term, lit)
	case Float32Type:
		return typedCmp[float32](st, term, lit)
	case Float64Type:
		return typedCmp[float64](st, term, lit)

	// temporal types; both timestamp flavours share one representation
	case DateType:
		return typedCmp[Date](st, term, lit)
	case TimeType:
		return typedCmp[Time](st, term, lit)
	case TimestampType:
		return typedCmp[Timestamp](st, term, lit)
	case TimestampTzType:
		return typedCmp[Timestamp](st, term, lit)

	// byte-oriented types
	case BinaryType:
		return typedCmp[[]byte](st, term, lit)
	case FixedType:
		return typedCmp[[]byte](st, term, lit)

	case StringType:
		return typedCmp[string](st, term, lit)
	case UUIDType:
		return typedCmp[uuid.UUID](st, term, lit)
	case DecimalType:
		return typedCmp[Decimal](st, term, lit)

	default:
		panic(ErrType)
	}
}
// VisitEqual reports whether the term's value equals the literal.
func (e *exprEvaluator) VisitEqual(term BoundTerm, lit Literal) bool {
	ordering := doCmp(e.st, term, lit)

	return ordering == 0
}

// VisitNotEqual reports whether the term's value differs from the literal.
func (e *exprEvaluator) VisitNotEqual(term BoundTerm, lit Literal) bool {
	ordering := doCmp(e.st, term, lit)

	return ordering != 0
}

// VisitGreater reports whether the term's value orders after the literal.
func (e *exprEvaluator) VisitGreater(term BoundTerm, lit Literal) bool {
	ordering := doCmp(e.st, term, lit)

	return ordering > 0
}

// VisitGreaterEqual reports whether the term's value orders after or equal to
// the literal.
func (e *exprEvaluator) VisitGreaterEqual(term BoundTerm, lit Literal) bool {
	ordering := doCmp(e.st, term, lit)

	return ordering >= 0
}

// VisitLess reports whether the term's value orders before the literal.
func (e *exprEvaluator) VisitLess(term BoundTerm, lit Literal) bool {
	ordering := doCmp(e.st, term, lit)

	return ordering < 0
}

// VisitLessEqual reports whether the term's value orders before or equal to
// the literal.
func (e *exprEvaluator) VisitLessEqual(term BoundTerm, lit Literal) bool {
	ordering := doCmp(e.st, term, lit)

	return ordering <= 0
}
// VisitStartsWith reports whether the value the term yields for the current
// struct begins with the literal's value. String literals are matched against
// string terms and binary literals against binary terms (byte-wise). A null
// value never matches.
//
// A literal of any other type panics with ErrType, consistent with doCmp;
// falling through would compare two empty strings and wrongly report a match.
// The term must implement bound[string] or bound[[]byte] to match the
// literal; the unchecked type assertion panics otherwise.
func (e *exprEvaluator) VisitStartsWith(term BoundTerm, lit Literal) bool {
	switch l := lit.(type) {
	case TypedLiteral[string]:
		val := term.(bound[string]).eval(e.st)
		if !val.Valid {
			return false
		}

		// Use the switch variable instead of re-asserting to one concrete
		// literal type, so every TypedLiteral[string] is accepted.
		return strings.HasPrefix(val.Val, l.Value())
	case TypedLiteral[[]byte]:
		val := term.(bound[[]byte]).eval(e.st)
		if !val.Valid {
			return false
		}

		return strings.HasPrefix(string(val.Val), string(l.Value()))
	}

	panic(ErrType)
}
// VisitNotStartsWith is the exact complement of VisitStartsWith, so a null
// value (which never matches a prefix) is reported as not starting with it.
func (e *exprEvaluator) VisitNotStartsWith(term BoundTerm, lit Literal) bool {
	hasPrefix := e.VisitStartsWith(term, lit)

	return !hasPrefix
}
// RewriteNotExpr returns an expression equivalent to expr that contains no
// "Not" nodes. Projections assume the expression tree is free of them.
//
// Every Not node is replaced by the result of calling `Negate` on its
// (already rewritten) child.
func RewriteNotExpr(expr BooleanExpression) (BooleanExpression, error) {
	var rewriter rewriteNotVisitor

	return VisitExpr[BooleanExpression](expr, rewriter)
}
// rewriteNotVisitor is the stateless visitor behind RewriteNotExpr. It rebuilds
// the expression unchanged except for Not nodes, which are replaced by the
// negation of their child.
type rewriteNotVisitor struct{}

// VisitTrue keeps the constant true expression.
func (rewriteNotVisitor) VisitTrue() BooleanExpression {
	return AlwaysTrue{}
}

// VisitFalse keeps the constant false expression.
func (rewriteNotVisitor) VisitFalse() BooleanExpression {
	return AlwaysFalse{}
}

// VisitNot removes the Not node by returning the negated child instead.
func (rewriteNotVisitor) VisitNot(child BooleanExpression) BooleanExpression {
	negated := child.Negate()

	return negated
}

// VisitAnd rebuilds the conjunction from the rewritten operands.
func (rewriteNotVisitor) VisitAnd(left, right BooleanExpression) BooleanExpression {
	conjunction := NewAnd(left, right)

	return conjunction
}

// VisitOr rebuilds the disjunction from the rewritten operands.
func (rewriteNotVisitor) VisitOr(left, right BooleanExpression) BooleanExpression {
	disjunction := NewOr(left, right)

	return disjunction
}

// VisitUnbound returns the unbound predicate untouched.
func (rewriteNotVisitor) VisitUnbound(pred UnboundPredicate) BooleanExpression {
	return pred
}

// VisitBound returns the bound predicate untouched.
func (rewriteNotVisitor) VisitBound(pred BoundPredicate) BooleanExpression {
	return pred
}