blob: 94bfcf076c05c73275924819d995e1b27bcd3b9c [file] [log] [blame]
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
# pylint:disable=redefined-outer-name
from typing import Any, List, Set
import pytest
from pyiceberg.conversions import to_bytes
from pyiceberg.expressions import (
AlwaysFalse,
AlwaysTrue,
And,
BooleanExpression,
BoundEqualTo,
BoundGreaterThan,
BoundGreaterThanOrEqual,
BoundIn,
BoundIsNaN,
BoundIsNull,
BoundLessThan,
BoundLessThanOrEqual,
BoundNotEqualTo,
BoundNotIn,
BoundNotNaN,
BoundNotNull,
BoundNotStartsWith,
BoundPredicate,
BoundReference,
BoundStartsWith,
BoundTerm,
EqualTo,
GreaterThan,
GreaterThanOrEqual,
In,
IsNaN,
IsNull,
LessThan,
LessThanOrEqual,
Not,
NotEqualTo,
NotIn,
NotNaN,
NotNull,
NotStartsWith,
Or,
Reference,
StartsWith,
UnboundPredicate,
)
from pyiceberg.expressions.literals import Literal, literal
from pyiceberg.expressions.visitors import (
BindVisitor,
BooleanExpressionVisitor,
BoundBooleanExpressionVisitor,
_ManifestEvalVisitor,
expression_evaluator,
expression_to_plain_format,
rewrite_not,
rewrite_to_dnf,
visit,
visit_bound_predicate,
)
from pyiceberg.manifest import ManifestFile, PartitionFieldSummary
from pyiceberg.schema import Accessor, Schema
from pyiceberg.typedef import Record
from pyiceberg.types import (
DoubleType,
FloatType,
IcebergType,
IntegerType,
NestedField,
PrimitiveType,
StringType,
)
class ExampleVisitor(BooleanExpressionVisitor[List[str]]):
"""A test implementation of a BooleanExpressionVisitor
As this visitor visits each node, it appends an element to a `visit_history` list. This enables testing that a given expression is
visited in an expected order by the `visit` method.
"""
def __init__(self) -> None:
self.visit_history: List[str] = []
def visit_true(self) -> List[str]:
self.visit_history.append("TRUE")
return self.visit_history
def visit_false(self) -> List[str]:
self.visit_history.append("FALSE")
return self.visit_history
def visit_not(self, child_result: List[str]) -> List[str]:
self.visit_history.append("NOT")
return self.visit_history
def visit_and(self, left_result: List[str], right_result: List[str]) -> List[str]:
self.visit_history.append("AND")
return self.visit_history
def visit_or(self, left_result: List[str], right_result: List[str]) -> List[str]:
self.visit_history.append("OR")
return self.visit_history
def visit_unbound_predicate(self, predicate: UnboundPredicate[Any]) -> List[str]:
self.visit_history.append(str(predicate.__class__.__name__).upper())
return self.visit_history
def visit_bound_predicate(self, predicate: BoundPredicate[Any]) -> List[str]:
self.visit_history.append(str(predicate.__class__.__name__).upper())
return self.visit_history
class FooBoundBooleanExpressionVisitor(BoundBooleanExpressionVisitor[List[str]]):
"""A test implementation of a BoundBooleanExpressionVisitor
As this visitor visits each node, it appends an element to a `visit_history` list. This enables testing that a given bound expression is
visited in an expected order by the `visit` method.
"""
def __init__(self) -> None:
self.visit_history: List[str] = []
def visit_in(self, term: BoundTerm[Any], literals: Set[Any]) -> List[str]:
self.visit_history.append("IN")
return self.visit_history
def visit_not_in(self, term: BoundTerm[Any], literals: Set[Any]) -> List[str]:
self.visit_history.append("NOT_IN")
return self.visit_history
def visit_is_nan(self, term: BoundTerm[Any]) -> List[str]:
self.visit_history.append("IS_NAN")
return self.visit_history
def visit_not_nan(self, term: BoundTerm[Any]) -> List[str]:
self.visit_history.append("NOT_NAN")
return self.visit_history
def visit_is_null(self, term: BoundTerm[Any]) -> List[str]:
self.visit_history.append("IS_NULL")
return self.visit_history
def visit_not_null(self, term: BoundTerm[Any]) -> List[str]:
self.visit_history.append("NOT_NULL")
return self.visit_history
def visit_equal(self, term: BoundTerm[Any], literal: Literal[Any]) -> List[str]: # pylint: disable=redefined-outer-name
self.visit_history.append("EQUAL")
return self.visit_history
def visit_not_equal(self, term: BoundTerm[Any], literal: Literal[Any]) -> List[str]: # pylint: disable=redefined-outer-name
self.visit_history.append("NOT_EQUAL")
return self.visit_history
def visit_greater_than_or_equal(self, term: BoundTerm[Any], literal: Literal[Any]) -> List[str]: # pylint: disable=redefined-outer-name
self.visit_history.append("GREATER_THAN_OR_EQUAL")
return self.visit_history
def visit_greater_than(self, term: BoundTerm[Any], literal: Literal[Any]) -> List[str]: # pylint: disable=redefined-outer-name
self.visit_history.append("GREATER_THAN")
return self.visit_history
def visit_less_than(self, term: BoundTerm[Any], literal: Literal[Any]) -> List[str]: # pylint: disable=redefined-outer-name
self.visit_history.append("LESS_THAN")
return self.visit_history
def visit_less_than_or_equal(self, term: BoundTerm[Any], literal: Literal[Any]) -> List[str]: # pylint: disable=redefined-outer-name
self.visit_history.append("LESS_THAN_OR_EQUAL")
return self.visit_history
def visit_true(self) -> List[str]:
self.visit_history.append("TRUE")
return self.visit_history
def visit_false(self) -> List[str]:
self.visit_history.append("FALSE")
return self.visit_history
def visit_not(self, child_result: List[str]) -> List[str]:
self.visit_history.append("NOT")
return self.visit_history
def visit_and(self, left_result: List[str], right_result: List[str]) -> List[str]:
self.visit_history.append("AND")
return self.visit_history
def visit_or(self, left_result: List[str], right_result: List[str]) -> List[str]:
self.visit_history.append("OR")
return self.visit_history
def visit_starts_with(self, term: BoundTerm[Any], literal: Literal[Any]) -> List[str]:
self.visit_history.append("STARTS_WITH")
return self.visit_history
def visit_not_starts_with(self, term: BoundTerm[Any], literal: Literal[Any]) -> List[str]:
self.visit_history.append("NOT_STARTS_WITH")
return self.visit_history
def test_boolean_expression_visitor() -> None:
"""Test post-order traversal of boolean expression visit method"""
expr = And(
Or(Not(EqualTo("a", 1)), Not(NotEqualTo("b", 0)), EqualTo("a", 1), NotEqualTo("b", 0)),
Not(EqualTo("a", 1)),
NotEqualTo("b", 0),
)
visitor = ExampleVisitor()
result = visit(expr, visitor=visitor)
assert result == [
"EQUALTO",
"NOT",
"NOTEQUALTO",
"NOT",
"OR",
"EQUALTO",
"OR",
"NOTEQUALTO",
"OR",
"EQUALTO",
"NOT",
"AND",
"NOTEQUALTO",
"AND",
]
def test_boolean_expression_visit_raise_not_implemented_error() -> None:
"""Test raise NotImplementedError when visiting an unsupported object type"""
visitor = ExampleVisitor()
with pytest.raises(NotImplementedError) as exc_info:
visit("foo", visitor=visitor) # type: ignore
assert str(exc_info.value) == "Cannot visit unsupported expression: foo"
def test_bind_visitor_already_bound(table_schema_simple: Schema) -> None:
bound = BoundEqualTo[str](
term=BoundReference(table_schema_simple.find_field(1), table_schema_simple.accessor_for_field(1)),
literal=literal("hello"),
)
with pytest.raises(TypeError) as exc_info:
visit(bound, visitor=BindVisitor(schema=table_schema_simple, case_sensitive=True))
assert (
"Found already bound predicate: BoundEqualTo(term=BoundReference(field=NestedField(field_id=1, name='foo', field_type=StringType(), required=False), accessor=Accessor(position=0,inner=None)), literal=literal('hello'))"
== str(exc_info.value)
)
def test_visit_bound_visitor_unknown_predicate() -> None:
with pytest.raises(TypeError) as exc_info:
visit_bound_predicate({"something"}, FooBoundBooleanExpressionVisitor()) # type: ignore
assert "Unknown predicate: {'something'}" == str(exc_info.value)
def test_always_true_expression_binding(table_schema_simple: Schema) -> None:
"""Test that visiting an always-true expression returns always-true"""
unbound_expression = AlwaysTrue()
bound_expression = visit(unbound_expression, visitor=BindVisitor(schema=table_schema_simple, case_sensitive=True))
assert bound_expression == AlwaysTrue()
def test_always_false_expression_binding(table_schema_simple: Schema) -> None:
"""Test that visiting an always-false expression returns always-false"""
unbound_expression = AlwaysFalse()
bound_expression = visit(unbound_expression, visitor=BindVisitor(schema=table_schema_simple, case_sensitive=True))
assert bound_expression == AlwaysFalse()
def test_always_false_and_always_true_expression_binding(table_schema_simple: Schema) -> None:
"""Test that visiting both an always-true AND always-false expression returns always-false"""
unbound_expression = And(AlwaysTrue(), AlwaysFalse())
bound_expression = visit(unbound_expression, visitor=BindVisitor(schema=table_schema_simple, case_sensitive=True))
assert bound_expression == AlwaysFalse()
def test_always_false_or_always_true_expression_binding(table_schema_simple: Schema) -> None:
"""Test that visiting always-true OR always-false expression returns always-true"""
unbound_expression = Or(AlwaysTrue(), AlwaysFalse())
bound_expression = visit(unbound_expression, visitor=BindVisitor(schema=table_schema_simple, case_sensitive=True))
assert bound_expression == AlwaysTrue()
@pytest.mark.parametrize(
"unbound_and_expression,expected_bound_expression",
[
(
And(
In(Reference("foo"), {"foo", "bar"}),
In(Reference("bar"), {1, 2, 3}),
),
And(
BoundIn(
BoundReference(
field=NestedField(field_id=1, name="foo", field_type=StringType(), required=False),
accessor=Accessor(position=0, inner=None),
),
{literal("foo"), literal("bar")},
),
BoundIn[int](
BoundReference(
field=NestedField(field_id=2, name="bar", field_type=IntegerType(), required=True),
accessor=Accessor(position=1, inner=None),
),
{literal(1), literal(2), literal(3)},
),
),
),
(
And(
In(Reference("foo"), ("bar", "baz")),
In(
Reference("bar"),
(1,),
),
In(
Reference("foo"),
("baz",),
),
),
And(
And(
BoundIn(
BoundReference(
field=NestedField(field_id=1, name="foo", field_type=StringType(), required=False),
accessor=Accessor(position=0, inner=None),
),
{literal("bar"), literal("baz")},
),
BoundEqualTo[int](
BoundReference(
field=NestedField(field_id=2, name="bar", field_type=IntegerType(), required=True),
accessor=Accessor(position=1, inner=None),
),
literal(1),
),
),
BoundEqualTo(
BoundReference(
field=NestedField(field_id=1, name="foo", field_type=StringType(), required=False),
accessor=Accessor(position=0, inner=None),
),
literal("baz"),
),
),
),
],
)
def test_and_expression_binding(
unbound_and_expression: UnboundPredicate[Any], expected_bound_expression: BoundPredicate[Any], table_schema_simple: Schema
) -> None:
"""Test that visiting an unbound AND expression with a bind-visitor returns the expected bound expression"""
bound_expression = visit(unbound_and_expression, visitor=BindVisitor(schema=table_schema_simple, case_sensitive=True))
assert bound_expression == expected_bound_expression
@pytest.mark.parametrize(
"unbound_or_expression,expected_bound_expression",
[
(
Or(
In(Reference("foo"), ("foo", "bar")),
In(Reference("bar"), (1, 2, 3)),
),
Or(
BoundIn(
BoundReference(
field=NestedField(field_id=1, name="foo", field_type=StringType(), required=False),
accessor=Accessor(position=0, inner=None),
),
{literal("foo"), literal("bar")},
),
BoundIn[int](
BoundReference(
field=NestedField(field_id=2, name="bar", field_type=IntegerType(), required=True),
accessor=Accessor(position=1, inner=None),
),
{literal(1), literal(2), literal(3)},
),
),
),
(
Or(
In(Reference("foo"), ("bar", "baz")),
In(
Reference("foo"),
("bar",),
),
In(
Reference("foo"),
("baz",),
),
),
Or(
Or(
BoundIn(
BoundReference(
field=NestedField(field_id=1, name="foo", field_type=StringType(), required=False),
accessor=Accessor(position=0, inner=None),
),
{literal("bar"), literal("baz")},
),
BoundIn(
BoundReference(
field=NestedField(field_id=1, name="foo", field_type=StringType(), required=False),
accessor=Accessor(position=0, inner=None),
),
{literal("bar")},
),
),
BoundIn(
BoundReference(
field=NestedField(field_id=1, name="foo", field_type=StringType(), required=False),
accessor=Accessor(position=0, inner=None),
),
{literal("baz")},
),
),
),
(
Or(
AlwaysTrue(),
AlwaysFalse(),
),
AlwaysTrue(),
),
(
Or(
AlwaysTrue(),
AlwaysTrue(),
),
AlwaysTrue(),
),
(
Or(
AlwaysFalse(),
AlwaysFalse(),
),
AlwaysFalse(),
),
],
)
def test_or_expression_binding(
unbound_or_expression: UnboundPredicate[Any], expected_bound_expression: BoundPredicate[Any], table_schema_simple: Schema
) -> None:
"""Test that visiting an unbound OR expression with a bind-visitor returns the expected bound expression"""
bound_expression = visit(unbound_or_expression, visitor=BindVisitor(schema=table_schema_simple, case_sensitive=True))
assert bound_expression == expected_bound_expression
@pytest.mark.parametrize(
"unbound_in_expression,expected_bound_expression",
[
(
In(Reference("foo"), ("foo", "bar")),
BoundIn(
BoundReference(
field=NestedField(field_id=1, name="foo", field_type=StringType(), required=False),
accessor=Accessor(position=0, inner=None),
),
{literal("foo"), literal("bar")},
),
),
(
In(Reference("foo"), ("bar", "baz")),
BoundIn(
BoundReference(
field=NestedField(field_id=1, name="foo", field_type=StringType(), required=False),
accessor=Accessor(position=0, inner=None),
),
{literal("bar"), literal("baz")},
),
),
(
In(
Reference("foo"),
("bar",),
),
BoundEqualTo(
BoundReference(
field=NestedField(field_id=1, name="foo", field_type=StringType(), required=False),
accessor=Accessor(position=0, inner=None),
),
literal("bar"),
),
),
],
)
def test_in_expression_binding(
unbound_in_expression: UnboundPredicate[Any], expected_bound_expression: BoundPredicate[Any], table_schema_simple: Schema
) -> None:
"""Test that visiting an unbound IN expression with a bind-visitor returns the expected bound expression"""
bound_expression = visit(unbound_in_expression, visitor=BindVisitor(schema=table_schema_simple, case_sensitive=True))
assert bound_expression == expected_bound_expression
@pytest.mark.parametrize(
"unbound_not_expression,expected_bound_expression",
[
(
Not(In(Reference("foo"), ("foo", "bar"))),
Not(
BoundIn(
BoundReference(
field=NestedField(field_id=1, name="foo", field_type=StringType(), required=False),
accessor=Accessor(position=0, inner=None),
),
{literal("foo"), literal("bar")},
)
),
),
(
Not(
Or(
In(Reference("foo"), ("foo", "bar")),
In(Reference("foo"), ("foo", "bar", "baz")),
)
),
Not(
Or(
BoundIn(
BoundReference(
field=NestedField(field_id=1, name="foo", field_type=StringType(), required=False),
accessor=Accessor(position=0, inner=None),
),
{literal("foo"), literal("bar")},
),
BoundIn(
BoundReference(
field=NestedField(field_id=1, name="foo", field_type=StringType(), required=False),
accessor=Accessor(position=0, inner=None),
),
{literal("foo"), literal("bar"), literal("baz")},
),
),
),
),
],
)
def test_not_expression_binding(
unbound_not_expression: UnboundPredicate[Any], expected_bound_expression: BoundPredicate[Any], table_schema_simple: Schema
) -> None:
"""Test that visiting an unbound NOT expression with a bind-visitor returns the expected bound expression"""
bound_expression = visit(unbound_not_expression, visitor=BindVisitor(schema=table_schema_simple, case_sensitive=True))
assert bound_expression == expected_bound_expression
def test_bound_boolean_expression_visitor_and_in() -> None:
"""Test visiting an And and In expression with a bound boolean expression visitor"""
bound_expression = And(
BoundIn(
term=BoundReference(
field=NestedField(field_id=1, name="foo", field_type=StringType(), required=False),
accessor=Accessor(position=0, inner=None),
),
literals={literal("foo"), literal("bar")},
),
BoundIn(
term=BoundReference(
field=NestedField(field_id=2, name="bar", field_type=StringType(), required=False),
accessor=Accessor(position=1, inner=None),
),
literals={literal("baz"), literal("qux")},
),
)
visitor = FooBoundBooleanExpressionVisitor()
result = visit(bound_expression, visitor=visitor)
assert result == ["IN", "IN", "AND"]
def test_bound_boolean_expression_visitor_or() -> None:
"""Test visiting an Or expression with a bound boolean expression visitor"""
bound_expression = Or(
Not(
BoundIn(
BoundReference(
field=NestedField(field_id=1, name="foo", field_type=StringType(), required=False),
accessor=Accessor(position=0, inner=None),
),
{literal("foo"), literal("bar")},
)
),
Not(
BoundIn(
BoundReference(
field=NestedField(field_id=2, name="bar", field_type=StringType(), required=False),
accessor=Accessor(position=1, inner=None),
),
{literal("baz"), literal("qux")},
)
),
)
visitor = FooBoundBooleanExpressionVisitor()
result = visit(bound_expression, visitor=visitor)
assert result == ["IN", "NOT", "IN", "NOT", "OR"]
def test_bound_boolean_expression_visitor_equal() -> None:
bound_expression = BoundEqualTo(
term=BoundReference(
field=NestedField(field_id=2, name="bar", field_type=StringType(), required=False),
accessor=Accessor(position=1, inner=None),
),
literal=literal("foo"),
)
visitor = FooBoundBooleanExpressionVisitor()
result = visit(bound_expression, visitor=visitor)
assert result == ["EQUAL"]
def test_bound_boolean_expression_visitor_not_equal() -> None:
bound_expression = BoundNotEqualTo(
term=BoundReference(
field=NestedField(field_id=1, name="foo", field_type=StringType(), required=False),
accessor=Accessor(position=0, inner=None),
),
literal=literal("foo"),
)
visitor = FooBoundBooleanExpressionVisitor()
result = visit(bound_expression, visitor=visitor)
assert result == ["NOT_EQUAL"]
def test_bound_boolean_expression_visitor_always_true() -> None:
bound_expression = AlwaysTrue()
visitor = FooBoundBooleanExpressionVisitor()
result = visit(bound_expression, visitor=visitor)
assert result == ["TRUE"]
def test_bound_boolean_expression_visitor_always_false() -> None:
bound_expression = AlwaysFalse()
visitor = FooBoundBooleanExpressionVisitor()
result = visit(bound_expression, visitor=visitor)
assert result == ["FALSE"]
def test_bound_boolean_expression_visitor_in() -> None:
bound_expression = BoundIn(
term=BoundReference(
field=NestedField(field_id=1, name="foo", field_type=StringType(), required=False),
accessor=Accessor(position=0, inner=None),
),
literals={literal("foo"), literal("bar")},
)
visitor = FooBoundBooleanExpressionVisitor()
result = visit(bound_expression, visitor=visitor)
assert result == ["IN"]
def test_bound_boolean_expression_visitor_not_in() -> None:
bound_expression = BoundNotIn(
term=BoundReference(
field=NestedField(field_id=1, name="foo", field_type=StringType(), required=False),
accessor=Accessor(position=0, inner=None),
),
literals={literal("foo"), literal("bar")},
)
visitor = FooBoundBooleanExpressionVisitor()
result = visit(bound_expression, visitor=visitor)
assert result == ["NOT_IN"]
def test_bound_boolean_expression_visitor_is_nan() -> None:
bound_expression = BoundIsNaN(
term=BoundReference(
field=NestedField(field_id=3, name="baz", field_type=FloatType(), required=False),
accessor=Accessor(position=0, inner=None),
),
)
visitor = FooBoundBooleanExpressionVisitor()
result = visit(bound_expression, visitor=visitor)
assert result == ["IS_NAN"]
def test_bound_boolean_expression_visitor_not_nan() -> None:
bound_expression = BoundNotNaN(
term=BoundReference(
field=NestedField(field_id=3, name="baz", field_type=FloatType(), required=False),
accessor=Accessor(position=0, inner=None),
),
)
visitor = FooBoundBooleanExpressionVisitor()
result = visit(bound_expression, visitor=visitor)
assert result == ["NOT_NAN"]
def test_bound_boolean_expression_visitor_is_null() -> None:
bound_expression = BoundIsNull(
term=BoundReference(
field=NestedField(field_id=1, name="foo", field_type=StringType(), required=False),
accessor=Accessor(position=0, inner=None),
),
)
visitor = FooBoundBooleanExpressionVisitor()
result = visit(bound_expression, visitor=visitor)
assert result == ["IS_NULL"]
def test_bound_boolean_expression_visitor_not_null() -> None:
bound_expression = BoundNotNull(
term=BoundReference(
field=NestedField(field_id=1, name="foo", field_type=StringType(), required=False),
accessor=Accessor(position=0, inner=None),
),
)
visitor = FooBoundBooleanExpressionVisitor()
result = visit(bound_expression, visitor=visitor)
assert result == ["NOT_NULL"]
def test_bound_boolean_expression_visitor_greater_than() -> None:
bound_expression = BoundGreaterThan(
term=BoundReference(
field=NestedField(field_id=1, name="foo", field_type=StringType(), required=False),
accessor=Accessor(position=0, inner=None),
),
literal=literal("foo"),
)
visitor = FooBoundBooleanExpressionVisitor()
result = visit(bound_expression, visitor=visitor)
assert result == ["GREATER_THAN"]
def test_bound_boolean_expression_visitor_greater_than_or_equal() -> None:
bound_expression = BoundGreaterThanOrEqual(
term=BoundReference(
field=NestedField(field_id=1, name="foo", field_type=StringType(), required=False),
accessor=Accessor(position=0, inner=None),
),
literal=literal("foo"),
)
visitor = FooBoundBooleanExpressionVisitor()
result = visit(bound_expression, visitor=visitor)
assert result == ["GREATER_THAN_OR_EQUAL"]
def test_bound_boolean_expression_visitor_less_than() -> None:
bound_expression = BoundLessThan(
term=BoundReference(
field=NestedField(field_id=1, name="foo", field_type=StringType(), required=False),
accessor=Accessor(position=0, inner=None),
),
literal=literal("foo"),
)
visitor = FooBoundBooleanExpressionVisitor()
result = visit(bound_expression, visitor=visitor)
assert result == ["LESS_THAN"]
def test_bound_boolean_expression_visitor_less_than_or_equal() -> None:
bound_expression = BoundLessThanOrEqual(
term=BoundReference(
field=NestedField(field_id=1, name="foo", field_type=StringType(), required=False),
accessor=Accessor(position=0, inner=None),
),
literal=literal("foo"),
)
visitor = FooBoundBooleanExpressionVisitor()
result = visit(bound_expression, visitor=visitor)
assert result == ["LESS_THAN_OR_EQUAL"]
def test_bound_boolean_expression_visitor_raise_on_unbound_predicate() -> None:
bound_expression = LessThanOrEqual(
term=Reference("foo"),
literal="foo",
)
visitor = FooBoundBooleanExpressionVisitor()
with pytest.raises(TypeError) as exc_info:
visit(bound_expression, visitor=visitor)
assert "Not a bound predicate" in str(exc_info.value)
def test_bound_boolean_expression_visitor_starts_with() -> None:
bound_expression = BoundStartsWith(
term=BoundReference(
field=NestedField(field_id=1, name="foo", field_type=StringType(), required=False),
accessor=Accessor(position=0, inner=None),
),
literal=literal("foo"),
)
visitor = FooBoundBooleanExpressionVisitor()
result = visit(bound_expression, visitor=visitor)
assert result == ["STARTS_WITH"]
def test_bound_boolean_expression_visitor_not_starts_with() -> None:
bound_expression = BoundNotStartsWith(
term=BoundReference(
field=NestedField(field_id=1, name="foo", field_type=StringType(), required=False),
accessor=Accessor(position=0, inner=None),
),
literal=literal("foo"),
)
visitor = FooBoundBooleanExpressionVisitor()
result = visit(bound_expression, visitor=visitor)
assert result == ["NOT_STARTS_WITH"]
def _to_byte_buffer(field_type: IcebergType, val: Any) -> bytes:
if not isinstance(field_type, PrimitiveType):
raise ValueError(f"Expected a PrimitiveType, got: {type(field_type)}")
return to_bytes(field_type, val)
def _to_manifest_file(*partitions: PartitionFieldSummary) -> ManifestFile:
"""Helper to create a ManifestFile"""
return ManifestFile(manifest_path="", manifest_length=0, partition_spec_id=0, partitions=partitions)
INT_MIN_VALUE = 30
INT_MAX_VALUE = 79
INT_MIN = _to_byte_buffer(IntegerType(), INT_MIN_VALUE)
INT_MAX = _to_byte_buffer(IntegerType(), INT_MAX_VALUE)
STRING_MIN = _to_byte_buffer(StringType(), "a")
STRING_MAX = _to_byte_buffer(StringType(), "z")
@pytest.fixture
def schema() -> Schema:
return Schema(
NestedField(1, "id", IntegerType(), required=True),
NestedField(2, "all_nulls_missing_nan", StringType(), required=False),
NestedField(3, "some_nulls", StringType(), required=False),
NestedField(4, "no_nulls", StringType(), required=False),
NestedField(5, "float", FloatType(), required=False),
NestedField(6, "all_nulls_double", DoubleType(), required=False),
NestedField(7, "all_nulls_no_nans", FloatType(), required=False),
NestedField(8, "all_nans", DoubleType(), required=False),
NestedField(9, "both_nan_and_null", FloatType(), required=False),
NestedField(10, "no_nan_or_null", DoubleType(), required=False),
NestedField(11, "all_nulls_missing_nan_float", FloatType(), required=False),
NestedField(12, "all_same_value_or_null", StringType(), required=False),
NestedField(13, "no_nulls_same_value_a", StringType(), required=False),
)
@pytest.fixture
def manifest_no_stats() -> ManifestFile:
return _to_manifest_file()
@pytest.fixture
def manifest() -> ManifestFile:
return _to_manifest_file(
# id
PartitionFieldSummary(
contains_null=False,
contains_nan=None,
lower_bound=INT_MIN,
upper_bound=INT_MAX,
),
# all_nulls_missing_nan
PartitionFieldSummary(
contains_null=True,
contains_nan=None,
lower_bound=None,
upper_bound=None,
),
# some_nulls
PartitionFieldSummary(
contains_null=True,
contains_nan=None,
lower_bound=STRING_MIN,
upper_bound=STRING_MAX,
),
# no_nulls
PartitionFieldSummary(
contains_null=False,
contains_nan=None,
lower_bound=STRING_MIN,
upper_bound=STRING_MAX,
),
# float
PartitionFieldSummary(
contains_null=True,
contains_nan=None,
lower_bound=_to_byte_buffer(FloatType(), 0.0),
upper_bound=_to_byte_buffer(FloatType(), 20.0),
),
# all_nulls_double
PartitionFieldSummary(contains_null=True, contains_nan=None, lower_bound=None, upper_bound=None),
# all_nulls_no_nans
PartitionFieldSummary(
contains_null=True,
contains_nan=False,
lower_bound=None,
upper_bound=None,
),
# all_nans
PartitionFieldSummary(
contains_null=False,
contains_nan=True,
lower_bound=None,
upper_bound=None,
),
# both_nan_and_null
PartitionFieldSummary(
contains_null=True,
contains_nan=True,
lower_bound=None,
upper_bound=None,
),
# no_nan_or_null
PartitionFieldSummary(
contains_null=False,
contains_nan=False,
lower_bound=_to_byte_buffer(FloatType(), 0.0),
upper_bound=_to_byte_buffer(FloatType(), 20.0),
),
# all_nulls_missing_nan_float
PartitionFieldSummary(contains_null=True, contains_nan=None, lower_bound=None, upper_bound=None),
# all_same_value_or_null
PartitionFieldSummary(
contains_null=True,
contains_nan=None,
lower_bound=STRING_MIN,
upper_bound=STRING_MIN,
),
# no_nulls_same_value_a
PartitionFieldSummary(
contains_null=False,
contains_nan=None,
lower_bound=STRING_MIN,
upper_bound=STRING_MIN,
),
)
def test_all_nulls(schema: Schema, manifest: ManifestFile) -> None:
assert not _ManifestEvalVisitor(schema, NotNull(Reference("all_nulls_missing_nan")), case_sensitive=True).eval(
manifest
), "Should skip: all nulls column with non-floating type contains all null"
assert _ManifestEvalVisitor(schema, NotNull(Reference("all_nulls_missing_nan_float")), case_sensitive=True).eval(
manifest
), "Should read: no NaN information may indicate presence of NaN value"
assert _ManifestEvalVisitor(schema, NotNull(Reference("some_nulls")), case_sensitive=True).eval(
manifest
), "Should read: column with some nulls contains a non-null value"
assert _ManifestEvalVisitor(schema, NotNull(Reference("no_nulls")), case_sensitive=True).eval(
manifest
), "Should read: non-null column contains a non-null value"
def test_no_nulls(schema: Schema, manifest: ManifestFile) -> None:
assert _ManifestEvalVisitor(schema, IsNull(Reference("all_nulls_missing_nan")), case_sensitive=True).eval(
manifest
), "Should read: at least one null value in all null column"
assert _ManifestEvalVisitor(schema, IsNull(Reference("some_nulls")), case_sensitive=True).eval(
manifest
), "Should read: column with some nulls contains a null value"
assert not _ManifestEvalVisitor(schema, IsNull(Reference("no_nulls")), case_sensitive=True).eval(
manifest
), "Should skip: non-null column contains no null values"
assert _ManifestEvalVisitor(schema, IsNull(Reference("both_nan_and_null")), case_sensitive=True).eval(
manifest
), "Should read: both_nan_and_null column contains no null values"
def test_is_nan(schema: Schema, manifest: ManifestFile) -> None:
assert _ManifestEvalVisitor(schema, IsNaN(Reference("float")), case_sensitive=True).eval(
manifest
), "Should read: no information on if there are nan value in float column"
assert _ManifestEvalVisitor(schema, IsNaN(Reference("all_nulls_double")), case_sensitive=True).eval(
manifest
), "Should read: no NaN information may indicate presence of NaN value"
assert _ManifestEvalVisitor(schema, IsNaN(Reference("all_nulls_missing_nan_float")), case_sensitive=True).eval(
manifest
), "Should read: no NaN information may indicate presence of NaN value"
assert not _ManifestEvalVisitor(schema, IsNaN(Reference("all_nulls_no_nans")), case_sensitive=True).eval(
manifest
), "Should skip: no nan column doesn't contain nan value"
assert _ManifestEvalVisitor(schema, IsNaN(Reference("all_nans")), case_sensitive=True).eval(
manifest
), "Should read: all_nans column contains nan value"
assert _ManifestEvalVisitor(schema, IsNaN(Reference("both_nan_and_null")), case_sensitive=True).eval(
manifest
), "Should read: both_nan_and_null column contains nan value"
assert not _ManifestEvalVisitor(schema, IsNaN(Reference("no_nan_or_null")), case_sensitive=True).eval(
manifest
), "Should skip: no_nan_or_null column doesn't contain nan value"
def test_not_nan(schema: Schema, manifest: ManifestFile) -> None:
assert _ManifestEvalVisitor(schema, NotNaN(Reference("float")), case_sensitive=True).eval(
manifest
), "Should read: no information on if there are nan value in float column"
assert _ManifestEvalVisitor(schema, NotNaN(Reference("all_nulls_double")), case_sensitive=True).eval(
manifest
), "Should read: all null column contains non nan value"
assert _ManifestEvalVisitor(schema, NotNaN(Reference("all_nulls_no_nans")), case_sensitive=True).eval(
manifest
), "Should read: no_nans column contains non nan value"
assert not _ManifestEvalVisitor(schema, NotNaN(Reference("all_nans")), case_sensitive=True).eval(
manifest
), "Should skip: all nans column doesn't contain non nan value"
assert _ManifestEvalVisitor(schema, NotNaN(Reference("both_nan_and_null")), case_sensitive=True).eval(
manifest
), "Should read: both_nan_and_null nans column contains non nan value"
assert _ManifestEvalVisitor(schema, NotNaN(Reference("no_nan_or_null")), case_sensitive=True).eval(
manifest
), "Should read: no_nan_or_null column contains non nan value"
def test_missing_stats(schema: Schema, manifest_no_stats: ManifestFile) -> None:
expressions: List[BooleanExpression] = [
LessThan(Reference("id"), 5),
LessThanOrEqual(Reference("id"), 30),
EqualTo(Reference("id"), 70),
GreaterThan(Reference("id"), 78),
GreaterThanOrEqual(Reference("id"), 90),
NotEqualTo(Reference("id"), 101),
IsNull(Reference("id")),
NotNull(Reference("id")),
IsNaN(Reference("float")),
NotNaN(Reference("float")),
]
for expr in expressions:
assert _ManifestEvalVisitor(schema, expr, case_sensitive=True).eval(
manifest_no_stats
), f"Should read when missing stats for expr: {expr}"
def test_not(schema: Schema, manifest: ManifestFile) -> None:
assert _ManifestEvalVisitor(schema, Not(LessThan(Reference("id"), INT_MIN_VALUE - 25)), case_sensitive=True).eval(
manifest
), "Should read: not(false)"
assert not _ManifestEvalVisitor(schema, Not(GreaterThan(Reference("id"), INT_MIN_VALUE - 25)), case_sensitive=True).eval(
manifest
), "Should skip: not(true)"
def test_and(schema: Schema, manifest: ManifestFile) -> None:
assert not _ManifestEvalVisitor(
schema,
And(
LessThan(Reference("id"), INT_MIN_VALUE - 25),
GreaterThanOrEqual(Reference("id"), INT_MIN_VALUE - 30),
),
case_sensitive=True,
).eval(manifest), "Should skip: and(false, true)"
assert not _ManifestEvalVisitor(
schema,
And(
LessThan(Reference("id"), INT_MIN_VALUE - 25),
GreaterThanOrEqual(Reference("id"), INT_MAX_VALUE + 1),
),
case_sensitive=True,
).eval(manifest), "Should skip: and(false, false)"
assert _ManifestEvalVisitor(
schema,
And(
GreaterThan(Reference("id"), INT_MIN_VALUE - 25),
LessThanOrEqual(Reference("id"), INT_MIN_VALUE),
),
case_sensitive=True,
).eval(manifest), "Should read: and(true, true)"
def test_or(schema: Schema, manifest: ManifestFile) -> None:
assert not _ManifestEvalVisitor(
schema,
Or(
LessThan(Reference("id"), INT_MIN_VALUE - 25),
GreaterThanOrEqual(Reference("id"), INT_MAX_VALUE + 1),
),
case_sensitive=True,
).eval(manifest), "Should skip: or(false, false)"
assert _ManifestEvalVisitor(
schema,
Or(
LessThan(Reference("id"), INT_MIN_VALUE - 25),
GreaterThanOrEqual(Reference("id"), INT_MAX_VALUE - 19),
),
case_sensitive=True,
).eval(manifest), "Should read: or(false, true)"
def test_integer_lt(schema: Schema, manifest: ManifestFile) -> None:
assert not _ManifestEvalVisitor(schema, LessThan(Reference("id"), INT_MIN_VALUE - 25), case_sensitive=True).eval(
manifest
), "Should not read: id range below lower bound (5 < 30)"
assert not _ManifestEvalVisitor(schema, LessThan(Reference("id"), INT_MIN_VALUE), case_sensitive=True).eval(
manifest
), "Should not read: id range below lower bound (30 is not < 30)"
assert _ManifestEvalVisitor(schema, LessThan(Reference("id"), INT_MIN_VALUE + 1), case_sensitive=True).eval(
manifest
), "Should read: one possible id"
assert _ManifestEvalVisitor(schema, LessThan(Reference("id"), INT_MAX_VALUE), case_sensitive=True).eval(
manifest
), "Should read: may possible ids"
def test_integer_lt_eq(schema: Schema, manifest: ManifestFile) -> None:
assert not _ManifestEvalVisitor(schema, LessThanOrEqual(Reference("id"), INT_MIN_VALUE - 25), case_sensitive=True).eval(
manifest
), "Should not read: id range below lower bound (5 < 30)"
assert not _ManifestEvalVisitor(schema, LessThanOrEqual(Reference("id"), INT_MIN_VALUE - 1), case_sensitive=True).eval(
manifest
), "Should not read: id range below lower bound (29 < 30)"
assert _ManifestEvalVisitor(schema, LessThanOrEqual(Reference("id"), INT_MIN_VALUE), case_sensitive=True).eval(
manifest
), "Should read: one possible id"
assert _ManifestEvalVisitor(schema, LessThanOrEqual(Reference("id"), INT_MAX_VALUE), case_sensitive=True).eval(
manifest
), "Should read: many possible ids"
def test_integer_gt(schema: Schema, manifest: ManifestFile) -> None:
assert not _ManifestEvalVisitor(schema, GreaterThan(Reference("id"), INT_MAX_VALUE + 6), case_sensitive=True).eval(
manifest
), "Should not read: id range above upper bound (85 < 79)"
assert not _ManifestEvalVisitor(schema, GreaterThan(Reference("id"), INT_MAX_VALUE), case_sensitive=True).eval(
manifest
), "Should not read: id range above upper bound (79 is not > 79)"
assert _ManifestEvalVisitor(schema, GreaterThan(Reference("id"), INT_MAX_VALUE - 1), case_sensitive=True).eval(
manifest
), "Should read: one possible id"
assert _ManifestEvalVisitor(schema, GreaterThan(Reference("id"), INT_MAX_VALUE - 4), case_sensitive=True).eval(
manifest
), "Should read: may possible ids"
def test_integer_gt_eq(schema: Schema, manifest: ManifestFile) -> None:
assert not _ManifestEvalVisitor(schema, GreaterThanOrEqual(Reference("id"), INT_MAX_VALUE + 6), case_sensitive=True).eval(
manifest
), "Should not read: id range above upper bound (85 < 79)"
assert not _ManifestEvalVisitor(schema, GreaterThanOrEqual(Reference("id"), INT_MAX_VALUE + 1), case_sensitive=True).eval(
manifest
), "Should not read: id range above upper bound (80 > 79)"
assert _ManifestEvalVisitor(schema, GreaterThanOrEqual(Reference("id"), INT_MAX_VALUE), case_sensitive=True).eval(
manifest
), "Should read: one possible id"
assert _ManifestEvalVisitor(schema, GreaterThanOrEqual(Reference("id"), INT_MAX_VALUE), case_sensitive=True).eval(
manifest
), "Should read: may possible ids"
def test_integer_eq(schema: Schema, manifest: ManifestFile) -> None:
assert not _ManifestEvalVisitor(schema, EqualTo(Reference("id"), INT_MIN_VALUE - 25), case_sensitive=True).eval(
manifest
), "Should not read: id below lower bound"
assert not _ManifestEvalVisitor(schema, EqualTo(Reference("id"), INT_MIN_VALUE - 1), case_sensitive=True).eval(
manifest
), "Should not read: id below lower bound"
assert _ManifestEvalVisitor(schema, EqualTo(Reference("id"), INT_MIN_VALUE), case_sensitive=True).eval(
manifest
), "Should read: id equal to lower bound"
assert _ManifestEvalVisitor(schema, EqualTo(Reference("id"), INT_MAX_VALUE - 4), case_sensitive=True).eval(
manifest
), "Should read: id between lower and upper bounds"
assert _ManifestEvalVisitor(schema, EqualTo(Reference("id"), INT_MAX_VALUE), case_sensitive=True).eval(
manifest
), "Should read: id equal to upper bound"
assert not _ManifestEvalVisitor(schema, EqualTo(Reference("id"), INT_MAX_VALUE + 1), case_sensitive=True).eval(
manifest
), "Should not read: id above upper bound"
assert not _ManifestEvalVisitor(schema, EqualTo(Reference("id"), INT_MAX_VALUE + 6), case_sensitive=True).eval(
manifest
), "Should not read: id above upper bound"
def test_integer_not_eq(schema: Schema, manifest: ManifestFile) -> None:
assert _ManifestEvalVisitor(schema, NotEqualTo(Reference("id"), INT_MIN_VALUE - 25), case_sensitive=True).eval(
manifest
), "Should read: id below lower bound"
assert _ManifestEvalVisitor(schema, NotEqualTo(Reference("id"), INT_MIN_VALUE - 1), case_sensitive=True).eval(
manifest
), "Should read: id below lower bound"
assert _ManifestEvalVisitor(schema, NotEqualTo(Reference("id"), INT_MIN_VALUE), case_sensitive=True).eval(
manifest
), "Should read: id equal to lower bound"
assert _ManifestEvalVisitor(schema, NotEqualTo(Reference("id"), INT_MAX_VALUE - 4), case_sensitive=True).eval(
manifest
), "Should read: id between lower and upper bounds"
assert _ManifestEvalVisitor(schema, NotEqualTo(Reference("id"), INT_MAX_VALUE), case_sensitive=True).eval(
manifest
), "Should read: id equal to upper bound"
assert _ManifestEvalVisitor(schema, NotEqualTo(Reference("id"), INT_MAX_VALUE + 1), case_sensitive=True).eval(
manifest
), "Should read: id above upper bound"
assert _ManifestEvalVisitor(schema, NotEqualTo(Reference("id"), INT_MAX_VALUE + 6), case_sensitive=True).eval(
manifest
), "Should read: id above upper bound"
def test_integer_not_eq_rewritten(schema: Schema, manifest: ManifestFile) -> None:
assert _ManifestEvalVisitor(schema, Not(EqualTo(Reference("id"), INT_MIN_VALUE - 25)), case_sensitive=True).eval(
manifest
), "Should read: id below lower bound"
assert _ManifestEvalVisitor(schema, Not(EqualTo(Reference("id"), INT_MIN_VALUE - 1)), case_sensitive=True).eval(
manifest
), "Should read: id below lower bound"
assert _ManifestEvalVisitor(schema, Not(EqualTo(Reference("id"), INT_MIN_VALUE)), case_sensitive=True).eval(
manifest
), "Should read: id equal to lower bound"
assert _ManifestEvalVisitor(schema, Not(EqualTo(Reference("id"), INT_MAX_VALUE - 4)), case_sensitive=True).eval(
manifest
), "Should read: id between lower and upper bounds"
assert _ManifestEvalVisitor(schema, Not(EqualTo(Reference("id"), INT_MAX_VALUE)), case_sensitive=True).eval(
manifest
), "Should read: id equal to upper bound"
assert _ManifestEvalVisitor(schema, Not(EqualTo(Reference("id"), INT_MAX_VALUE + 1)), case_sensitive=True).eval(
manifest
), "Should read: id above upper bound"
assert _ManifestEvalVisitor(schema, Not(EqualTo(Reference("id"), INT_MAX_VALUE + 6)), case_sensitive=True).eval(
manifest
), "Should read: id above upper bound"
def test_integer_not_eq_rewritten_case_insensitive(schema: Schema, manifest: ManifestFile) -> None:
assert _ManifestEvalVisitor(schema, Not(EqualTo(Reference("ID"), INT_MIN_VALUE - 25)), case_sensitive=False).eval(
manifest
), "Should read: id below lower bound"
assert _ManifestEvalVisitor(schema, Not(EqualTo(Reference("ID"), INT_MIN_VALUE - 1)), case_sensitive=False).eval(
manifest
), "Should read: id below lower bound"
assert _ManifestEvalVisitor(schema, Not(EqualTo(Reference("ID"), INT_MIN_VALUE)), case_sensitive=False).eval(
manifest
), "Should read: id equal to lower bound"
assert _ManifestEvalVisitor(schema, Not(EqualTo(Reference("ID"), INT_MAX_VALUE - 4)), case_sensitive=False).eval(
manifest
), "Should read: id between lower and upper bounds"
assert _ManifestEvalVisitor(schema, Not(EqualTo(Reference("ID"), INT_MAX_VALUE)), case_sensitive=False).eval(
manifest
), "Should read: id equal to upper bound"
assert _ManifestEvalVisitor(schema, Not(EqualTo(Reference("ID"), INT_MAX_VALUE + 1)), case_sensitive=False).eval(
manifest
), "Should read: id above upper bound"
assert _ManifestEvalVisitor(schema, Not(EqualTo(Reference("ID"), INT_MAX_VALUE + 6)), case_sensitive=False).eval(
manifest
), "Should read: id above upper bound"
def test_integer_in(schema: Schema, manifest: ManifestFile) -> None:
assert not _ManifestEvalVisitor(
schema, In(Reference("id"), (INT_MIN_VALUE - 25, INT_MIN_VALUE - 24)), case_sensitive=True
).eval(manifest), "Should not read: id below lower bound (5 < 30, 6 < 30)"
assert not _ManifestEvalVisitor(
schema, In(Reference("id"), (INT_MIN_VALUE - 2, INT_MIN_VALUE - 1)), case_sensitive=True
).eval(manifest), "Should not read: id below lower bound (28 < 30, 29 < 30)"
assert _ManifestEvalVisitor(schema, In(Reference("id"), (INT_MIN_VALUE - 1, INT_MIN_VALUE)), case_sensitive=True).eval(
manifest
), "Should read: id equal to lower bound (30 == 30)"
assert _ManifestEvalVisitor(schema, In(Reference("id"), (INT_MAX_VALUE - 4, INT_MAX_VALUE - 3)), case_sensitive=True).eval(
manifest
), "Should read: id between lower and upper bounds (30 < 75 < 79, 30 < 76 < 79)"
assert _ManifestEvalVisitor(schema, In(Reference("id"), (INT_MAX_VALUE, INT_MAX_VALUE + 1)), case_sensitive=True).eval(
manifest
), "Should read: id equal to upper bound (79 == 79)"
assert not _ManifestEvalVisitor(
schema, In(Reference("id"), (INT_MAX_VALUE + 1, INT_MAX_VALUE + 2)), case_sensitive=True
).eval(manifest), "Should not read: id above upper bound (80 > 79, 81 > 79)"
assert not _ManifestEvalVisitor(
schema, In(Reference("id"), (INT_MAX_VALUE + 6, INT_MAX_VALUE + 7)), case_sensitive=True
).eval(manifest), "Should not read: id above upper bound (85 > 79, 86 > 79)"
assert not _ManifestEvalVisitor(schema, In(Reference("all_nulls_missing_nan"), ("abc", "def")), case_sensitive=True).eval(
manifest
), "Should skip: in on all nulls column"
assert _ManifestEvalVisitor(schema, In(Reference("some_nulls"), ("abc", "def")), case_sensitive=True).eval(
manifest
), "Should read: in on some nulls column"
assert _ManifestEvalVisitor(schema, In(Reference("no_nulls"), ("abc", "def")), case_sensitive=True).eval(
manifest
), "Should read: in on no nulls column"
def test_integer_not_in(schema: Schema, manifest: ManifestFile) -> None:
assert _ManifestEvalVisitor(
schema, NotIn(Reference("id"), (INT_MIN_VALUE - 25, INT_MIN_VALUE - 24)), case_sensitive=True
).eval(manifest), "Should read: id below lower bound (5 < 30, 6 < 30)"
assert _ManifestEvalVisitor(schema, NotIn(Reference("id"), (INT_MIN_VALUE - 2, INT_MIN_VALUE - 1)), case_sensitive=True).eval(
manifest
), "Should read: id below lower bound (28 < 30, 29 < 30)"
assert _ManifestEvalVisitor(schema, NotIn(Reference("id"), (INT_MIN_VALUE - 1, INT_MIN_VALUE)), case_sensitive=True).eval(
manifest
), "Should read: id equal to lower bound (30 == 30)"
assert _ManifestEvalVisitor(schema, NotIn(Reference("id"), (INT_MAX_VALUE - 4, INT_MAX_VALUE - 3)), case_sensitive=True).eval(
manifest
), "Should read: id between lower and upper bounds (30 < 75 < 79, 30 < 76 < 79)"
assert _ManifestEvalVisitor(schema, NotIn(Reference("id"), (INT_MAX_VALUE, INT_MAX_VALUE + 1)), case_sensitive=True).eval(
manifest
), "Should read: id equal to upper bound (79 == 79)"
assert _ManifestEvalVisitor(schema, NotIn(Reference("id"), (INT_MAX_VALUE + 1, INT_MAX_VALUE + 2)), case_sensitive=True).eval(
manifest
), "Should read: id above upper bound (80 > 79, 81 > 79)"
assert _ManifestEvalVisitor(schema, NotIn(Reference("id"), (INT_MAX_VALUE + 6, INT_MAX_VALUE + 7)), case_sensitive=True).eval(
manifest
), "Should read: id above upper bound (85 > 79, 86 > 79)"
assert _ManifestEvalVisitor(schema, NotIn(Reference("all_nulls_missing_nan"), ("abc", "def")), case_sensitive=True).eval(
manifest
), "Should read: notIn on no nulls column"
assert _ManifestEvalVisitor(schema, NotIn(Reference("some_nulls"), ("abc", "def")), case_sensitive=True).eval(
manifest
), "Should read: in on some nulls column"
assert _ManifestEvalVisitor(schema, NotIn(Reference("no_nulls"), ("abc", "def")), case_sensitive=True).eval(
manifest
), "Should read: in on no nulls column"
def test_string_starts_with(schema: Schema, manifest: ManifestFile) -> None:
assert _ManifestEvalVisitor(schema, StartsWith(Reference("some_nulls"), "a"), case_sensitive=False).eval(
manifest
), "Should read: range matches"
assert _ManifestEvalVisitor(schema, StartsWith(Reference("some_nulls"), "aa"), case_sensitive=False).eval(
manifest
), "Should read: range matches"
assert _ManifestEvalVisitor(schema, StartsWith(Reference("some_nulls"), "dddd"), case_sensitive=False).eval(
manifest
), "Should read: range matches"
assert _ManifestEvalVisitor(schema, StartsWith(Reference("some_nulls"), "z"), case_sensitive=False).eval(
manifest
), "Should read: range matches"
assert _ManifestEvalVisitor(schema, StartsWith(Reference("no_nulls"), "a"), case_sensitive=False).eval(
manifest
), "Should read: range matches"
assert not _ManifestEvalVisitor(schema, StartsWith(Reference("some_nulls"), "zzzz"), case_sensitive=False).eval(
manifest
), "Should skip: range doesn't match"
assert not _ManifestEvalVisitor(schema, StartsWith(Reference("some_nulls"), "1"), case_sensitive=False).eval(
manifest
), "Should skip: range doesn't match"
def test_string_not_starts_with(schema: Schema, manifest: ManifestFile) -> None:
assert _ManifestEvalVisitor(schema, NotStartsWith(Reference("some_nulls"), "a"), case_sensitive=False).eval(
manifest
), "Should read: range matches"
assert _ManifestEvalVisitor(schema, NotStartsWith(Reference("some_nulls"), "aa"), case_sensitive=False).eval(
manifest
), "Should read: range matches"
assert _ManifestEvalVisitor(schema, NotStartsWith(Reference("some_nulls"), "dddd"), case_sensitive=False).eval(
manifest
), "Should read: range matches"
assert _ManifestEvalVisitor(schema, NotStartsWith(Reference("some_nulls"), "z"), case_sensitive=False).eval(
manifest
), "Should read: range matches"
assert _ManifestEvalVisitor(schema, NotStartsWith(Reference("no_nulls"), "a"), case_sensitive=False).eval(
manifest
), "Should read: range matches"
assert _ManifestEvalVisitor(schema, NotStartsWith(Reference("some_nulls"), "zzzz"), case_sensitive=False).eval(
manifest
), "Should read: range matches"
assert _ManifestEvalVisitor(schema, NotStartsWith(Reference("some_nulls"), "1"), case_sensitive=False).eval(
manifest
), "Should read: range matches"
assert _ManifestEvalVisitor(schema, NotStartsWith(Reference("all_same_value_or_null"), "a"), case_sensitive=False).eval(
manifest
), "Should read: range matches"
assert _ManifestEvalVisitor(schema, NotStartsWith(Reference("all_same_value_or_null"), "aa"), case_sensitive=False).eval(
manifest
), "Should read: range matches"
assert _ManifestEvalVisitor(schema, NotStartsWith(Reference("all_same_value_or_null"), "A"), case_sensitive=False).eval(
manifest
), "Should read: range matches"
# Iceberg does not implement SQL's 3-way boolean logic, so the choice of an all null column
# matching is
# by definition in order to surface more values to the query engine to allow it to make its own
# decision.
assert _ManifestEvalVisitor(schema, NotStartsWith(Reference("all_nulls_missing_nan"), "A"), case_sensitive=False).eval(
manifest
), "Should read: range matches"
assert not _ManifestEvalVisitor(schema, NotStartsWith(Reference("no_nulls_same_value_a"), "a"), case_sensitive=False).eval(
manifest
), "Should not read: all values start with the prefix"
def test_rewrite_not_equal_to() -> None:
assert rewrite_not(Not(EqualTo(Reference("x"), 34.56))) == NotEqualTo(Reference("x"), 34.56)
def test_rewrite_not_not_equal_to() -> None:
assert rewrite_not(Not(NotEqualTo(Reference("x"), 34.56))) == EqualTo(Reference("x"), 34.56)
def test_rewrite_not_in() -> None:
assert rewrite_not(Not(In(Reference("x"), (34.56,)))) == NotIn(Reference("x"), (34.56,))
def test_rewrite_and() -> None:
assert rewrite_not(
Not(
And(
EqualTo(Reference("x"), 34.56),
EqualTo(Reference("y"), 34.56),
)
)
) == Or(
NotEqualTo(term=Reference(name="x"), literal=34.56),
NotEqualTo(term=Reference(name="y"), literal=34.56),
)
def test_rewrite_or() -> None:
assert rewrite_not(
Not(
Or(
EqualTo(Reference("x"), 34.56),
EqualTo(Reference("y"), 34.56),
)
)
) == And(
NotEqualTo(term=Reference(name="x"), literal=34.56),
NotEqualTo(term=Reference(name="y"), literal=34.56),
)
def test_rewrite_always_false() -> None:
assert rewrite_not(Not(AlwaysFalse())) == AlwaysTrue()
def test_rewrite_always_true() -> None:
assert rewrite_not(Not(AlwaysTrue())) == AlwaysFalse()
def test_rewrite_bound() -> None:
schema = Schema(NestedField(2, "a", IntegerType(), required=False), schema_id=1)
assert rewrite_not(IsNull(Reference("a")).bind(schema)) == BoundIsNull(
term=BoundReference(
field=NestedField(field_id=2, name="a", field_type=IntegerType(), required=False),
accessor=Accessor(position=0, inner=None),
)
)
def test_to_dnf() -> None:
expr = Or(Not(EqualTo("P", "a")), And(EqualTo("Q", "b"), Not(Or(Not(EqualTo("R", "c")), EqualTo("S", "d")))))
assert rewrite_to_dnf(expr) == (NotEqualTo("P", "a"), And(EqualTo("Q", "b"), And(EqualTo("R", "c"), NotEqualTo("S", "d"))))
def test_to_dnf_nested_or() -> None:
expr = Or(EqualTo("P", "a"), And(EqualTo("Q", "b"), Or(EqualTo("R", "c"), EqualTo("S", "d"))))
assert rewrite_to_dnf(expr) == (
EqualTo("P", "a"),
And(EqualTo("Q", "b"), EqualTo("R", "c")),
And(EqualTo("Q", "b"), EqualTo("S", "d")),
)
def test_to_dnf_double_distribution() -> None:
expr = And(Or(EqualTo("P", "a"), EqualTo("Q", "b")), Or(EqualTo("R", "c"), EqualTo("S", "d")))
assert rewrite_to_dnf(expr) == (
And(
left=EqualTo(term=Reference(name="P"), literal=literal("a")),
right=EqualTo(term=Reference(name="R"), literal=literal("c")),
),
And(
left=EqualTo(term=Reference(name="P"), literal=literal("a")),
right=EqualTo(term=Reference(name="S"), literal=literal("d")),
),
And(
left=EqualTo(term=Reference(name="Q"), literal=literal("b")),
right=EqualTo(term=Reference(name="R"), literal=literal("c")),
),
And(
left=EqualTo(term=Reference(name="Q"), literal=literal("b")),
right=EqualTo(term=Reference(name="S"), literal=literal("d")),
),
)
def test_to_dnf_double_negation() -> None:
expr = rewrite_to_dnf(Not(Not(Not(Not(Not(Not(EqualTo("P", "a"))))))))
assert expr == (EqualTo("P", "a"),)
def test_to_dnf_and() -> None:
expr = And(Not(EqualTo("Q", "b")), EqualTo("R", "c"))
assert rewrite_to_dnf(expr) == (And(NotEqualTo("Q", "b"), EqualTo("R", "c")),)
def test_to_dnf_not_and() -> None:
expr = Not(And(Not(EqualTo("Q", "b")), EqualTo("R", "c")))
assert rewrite_to_dnf(expr) == (EqualTo("Q", "b"), NotEqualTo("R", "c"))
def test_dnf_to_dask(table_schema_simple: Schema) -> None:
expr = (
BoundGreaterThan[str](
term=BoundReference(table_schema_simple.find_field(1), table_schema_simple.accessor_for_field(1)),
literal=literal("hello"),
),
And(
BoundIn[int](
term=BoundReference(table_schema_simple.find_field(2), table_schema_simple.accessor_for_field(2)),
literals={literal(1), literal(2), literal(3)},
),
BoundEqualTo[bool](
term=BoundReference(table_schema_simple.find_field(3), table_schema_simple.accessor_for_field(3)),
literal=literal(True),
),
),
)
assert expression_to_plain_format(expr) == [[("foo", ">", "hello")], [("bar", "in", {1, 2, 3}), ("baz", "==", True)]]
def test_expression_evaluator_null() -> None:
struct = Record(a=None)
schema = Schema(NestedField(1, "a", IntegerType(), required=False), schema_id=1)
assert expression_evaluator(schema, In("a", {1, 2, 3}), case_sensitive=True)(struct) is False
assert expression_evaluator(schema, NotIn("a", {1, 2, 3}), case_sensitive=True)(struct) is True
assert expression_evaluator(schema, IsNaN("a"), case_sensitive=True)(struct) is False
assert expression_evaluator(schema, NotNaN("a"), case_sensitive=True)(struct) is True
assert expression_evaluator(schema, IsNull("a"), case_sensitive=True)(struct) is True
assert expression_evaluator(schema, NotNull("a"), case_sensitive=True)(struct) is False
assert expression_evaluator(schema, EqualTo("a", 1), case_sensitive=True)(struct) is False
assert expression_evaluator(schema, NotEqualTo("a", 1), case_sensitive=True)(struct) is True
assert expression_evaluator(schema, GreaterThanOrEqual("a", 1), case_sensitive=True)(struct) is False
assert expression_evaluator(schema, GreaterThan("a", 1), case_sensitive=True)(struct) is False
assert expression_evaluator(schema, LessThanOrEqual("a", 1), case_sensitive=True)(struct) is False
assert expression_evaluator(schema, LessThan("a", 1), case_sensitive=True)(struct) is False
assert expression_evaluator(schema, StartsWith("a", 1), case_sensitive=True)(struct) is False
assert expression_evaluator(schema, NotStartsWith("a", 1), case_sensitive=True)(struct) is True