| # Licensed to the Apache Software Foundation (ASF) under one |
| # or more contributor license agreements. See the NOTICE file |
| # distributed with this work for additional information |
| # regarding copyright ownership. The ASF licenses this file |
| # to you under the Apache License, Version 2.0 (the |
| # "License"); you may not use this file except in compliance |
| # with the License. You may obtain a copy of the License at |
| # |
| # http://www.apache.org/licenses/LICENSE-2.0 |
| # |
| # Unless required by applicable law or agreed to in writing, |
| # software distributed under the License is distributed on an |
| # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY |
| # KIND, either express or implied. See the License for the |
| # specific language governing permissions and limitations |
| # under the License. |
| |
| from textwrap import dedent |
| from typing import Any, Dict, Optional |
| |
| import pytest |
| |
| from pyiceberg import schema |
| from pyiceberg.expressions.base import Accessor |
| from pyiceberg.files import StructProtocol |
| from pyiceberg.schema import Schema, build_position_accessors |
| from pyiceberg.types import ( |
| BooleanType, |
| FloatType, |
| IntegerType, |
| ListType, |
| MapType, |
| NestedField, |
| StringType, |
| StructType, |
| ) |
| |
| |
def test_schema_str(table_schema_simple: Schema):
    """Check the human-readable string rendering of the simple table schema"""
    # NOTE(review): the indentation inside this literal was not recoverable from the
    # source view; the two-space indent on the field lines is an assumption - confirm
    # against Schema.__str__.
    expected = dedent(
        """\
        table {
          1: foo: optional string
          2: bar: required int
          3: baz: optional boolean
        }"""
    )
    rendered = str(table_schema_simple)
    assert rendered == expected
| |
| |
@pytest.mark.parametrize(
    ("schema_repr", "expected_repr"),
    [
        (
            schema.Schema(NestedField(1, "foo", StringType()), schema_id=1),
            "Schema(fields=("
            "NestedField(field_id=1, name='foo', field_type=StringType(), required=True),"
            "), schema_id=1, identifier_field_ids=[])",
        ),
        (
            schema.Schema(
                NestedField(1, "foo", StringType()),
                NestedField(2, "bar", IntegerType(), required=False),
                schema_id=1,
            ),
            "Schema(fields=("
            "NestedField(field_id=1, name='foo', field_type=StringType(), required=True), "
            "NestedField(field_id=2, name='bar', field_type=IntegerType(), required=False)"
            "), schema_id=1, identifier_field_ids=[])",
        ),
    ],
)
def test_schema_repr(schema_repr: Schema, expected_repr: str):
    """Check that repr() of a schema matches the expected constructor-like string"""
    actual_repr = repr(schema_repr)
    assert actual_repr == expected_repr
| |
| |
def test_schema_raise_on_duplicate_names():
    """Test that building a schema with two fields sharing a name raises a ValueError"""
    # Fields 3 and 4 deliberately share the name "baz"; the error message is expected
    # to name the duplicated column and both conflicting field IDs.
    with pytest.raises(ValueError, match="Invalid schema, multiple fields for name baz: 3 and 4"):
        schema.Schema(
            NestedField(field_id=1, name="foo", field_type=StringType(), required=False),
            NestedField(field_id=2, name="bar", field_type=IntegerType(), required=True),
            NestedField(field_id=3, name="baz", field_type=BooleanType(), required=False),
            NestedField(field_id=4, name="baz", field_type=BooleanType(), required=False),
            schema_id=1,
            identifier_field_ids=[1],
        )
| |
| |
def test_schema_index_by_id_visitor(table_schema_nested):
    """Check that index_by_id maps every field ID of the nested schema to its NestedField"""
    # Leaf fields and nested types that show up both inside a parent field's type
    # and as index entries of their own are built once and reused below.
    latitude = NestedField(field_id=13, name="latitude", field_type=FloatType(), required=False)
    longitude = NestedField(field_id=14, name="longitude", field_type=FloatType(), required=False)
    person_name = NestedField(field_id=16, name="name", field_type=StringType(), required=False)
    person_age = NestedField(field_id=17, name="age", field_type=IntegerType(), required=True)
    coordinates = StructType(latitude, longitude)
    inner_map = MapType(key_id=9, key_type=StringType(), value_id=10, value_type=IntegerType(), value_required=True)

    expected_fields = [
        NestedField(field_id=1, name="foo", field_type=StringType(), required=False),
        NestedField(field_id=2, name="bar", field_type=IntegerType(), required=True),
        NestedField(field_id=3, name="baz", field_type=BooleanType(), required=False),
        NestedField(
            field_id=4,
            name="qux",
            field_type=ListType(element_id=5, element_type=StringType(), element_required=True),
            required=True,
        ),
        NestedField(field_id=5, name="element", field_type=StringType(), required=True),
        NestedField(
            field_id=6,
            name="quux",
            field_type=MapType(
                key_id=7,
                key_type=StringType(),
                value_id=8,
                value_type=inner_map,
                value_required=True,
            ),
            required=True,
        ),
        NestedField(field_id=7, name="key", field_type=StringType(), required=True),
        NestedField(field_id=8, name="value", field_type=inner_map, required=True),
        NestedField(field_id=9, name="key", field_type=StringType(), required=True),
        NestedField(field_id=10, name="value", field_type=IntegerType(), required=True),
        NestedField(
            field_id=11,
            name="location",
            field_type=ListType(element_id=12, element_type=coordinates, element_required=True),
            required=True,
        ),
        NestedField(field_id=12, name="element", field_type=coordinates, required=True),
        latitude,
        longitude,
        NestedField(
            field_id=15,
            name="person",
            field_type=StructType(person_name, person_age),
            required=False,
        ),
        person_name,
        person_age,
    ]

    # Every expected entry is keyed by the field's own ID.
    expected_index = {field.field_id: field for field in expected_fields}

    index = schema.index_by_id(table_schema_nested)
    assert index == expected_index
| |
| |
def test_schema_index_by_name_visitor(table_schema_nested):
    """Check that index_by_name maps every dotted column name of the nested schema to its field ID"""
    # Full dotted names listed in field-ID order: the name at 1-based position N has ID N.
    full_names = (
        "foo",
        "bar",
        "baz",
        "qux",
        "qux.element",
        "quux",
        "quux.key",
        "quux.value",
        "quux.value.key",
        "quux.value.value",
        "location",
        "location.element",
        "location.element.latitude",
        "location.element.longitude",
        "person",
        "person.name",
        "person.age",
    )
    expected = {name: field_id for field_id, name in enumerate(full_names, start=1)}

    # The struct fields inside the "location" list are also expected under names
    # that omit the ".element" segment.
    expected["location.latitude"] = 13
    expected["location.longitude"] = 14

    index = schema.index_by_name(table_schema_nested)
    assert index == expected
| |
| |
def test_schema_find_column_name(table_schema_nested):
    """Check that find_column_name resolves field IDs 1-14 of the nested schema to their full names"""
    expected_names = {
        1: "foo",
        2: "bar",
        3: "baz",
        4: "qux",
        5: "qux.element",
        6: "quux",
        7: "quux.key",
        8: "quux.value",
        9: "quux.value.key",
        10: "quux.value.value",
        11: "location",
        12: "location.element",
        13: "location.element.latitude",
        14: "location.element.longitude",
    }
    for field_id, column_name in expected_names.items():
        assert table_schema_nested.find_column_name(field_id) == column_name
| |
| |
def test_schema_find_column_name_on_id_not_found(table_schema_nested):
    """Test that find_column_name returns None when the field ID cannot be found"""
    # 99 is used as a field ID that is not present in the nested schema fixture;
    # the lookup is expected to yield None rather than raise.
    assert table_schema_nested.find_column_name(99) is None
| |
| |
def test_schema_find_column_name_by_id(table_schema_simple):
    """Check that find_column_name resolves the three field IDs of the simple schema"""
    # Column names in field-ID order; IDs start at 1.
    for field_id, column_name in enumerate(("foo", "bar", "baz"), start=1):
        assert table_schema_simple.find_column_name(field_id) == column_name
| |
| |
def test_schema_find_field_by_id(table_schema_simple):
    """Look up each column of the simple schema by field ID and check its ID, type and required flag"""
    fields_by_id = schema.index_by_id(table_schema_simple)

    # (field ID, expected field type, expected value of `required`)
    expectations = (
        (1, StringType(), False),
        (2, IntegerType(), True),
        (3, BooleanType(), False),
    )

    for field_id, expected_type, expected_required in expectations:
        column = fields_by_id[field_id]
        assert isinstance(column, NestedField)
        assert column.field_id == field_id
        assert column.field_type == expected_type
        # Identity check: `required` must be the actual bool, not merely truthy/falsy.
        assert column.required is expected_required
| |
| |
def test_schema_find_field_by_id_raise_on_unknown_field(table_schema_simple):
    """Test that looking up an unknown field ID in the ID index raises a KeyError"""
    index = schema.index_by_id(table_schema_simple)
    # Expect the specific lookup failure instead of any Exception, so that an
    # unrelated error cannot make this test pass by accident.
    # NOTE(review): assumes index_by_id returns a plain dict-like mapping - confirm.
    with pytest.raises(KeyError) as exc_info:
        _ = index[4]
    # The exception's string form is expected to be the missing key.
    assert str(exc_info.value) == "4"
| |
| |
def test_schema_find_field_type_by_id(table_schema_simple):
    """Check that the ID index of the simple schema holds the full expected NestedField for each ID"""
    fields_by_id = schema.index_by_id(table_schema_simple)
    expected_fields = (
        NestedField(field_id=1, name="foo", field_type=StringType(), required=False),
        NestedField(field_id=2, name="bar", field_type=IntegerType(), required=True),
        NestedField(field_id=3, name="baz", field_type=BooleanType(), required=False),
    )
    for expected in expected_fields:
        assert fields_by_id[expected.field_id] == expected
| |
| |
def test_index_by_id_schema_visitor(table_schema_nested):
    """Check the complete field-ID index that index_by_id produces for the nested schema"""
    # Nested types that are referenced by a parent field and also by their own
    # "element"/"value" entry are defined once here.
    value_map = MapType(key_id=9, key_type=StringType(), value_id=10, value_type=IntegerType(), value_required=True)
    location_struct = StructType(
        NestedField(field_id=13, name="latitude", field_type=FloatType(), required=False),
        NestedField(field_id=14, name="longitude", field_type=FloatType(), required=False),
    )
    person_struct = StructType(
        NestedField(field_id=16, name="name", field_type=StringType(), required=False),
        NestedField(field_id=17, name="age", field_type=IntegerType(), required=True),
    )

    expected = {
        1: NestedField(field_id=1, name="foo", field_type=StringType(), required=False),
        2: NestedField(field_id=2, name="bar", field_type=IntegerType(), required=True),
        3: NestedField(field_id=3, name="baz", field_type=BooleanType(), required=False),
        4: NestedField(
            field_id=4,
            name="qux",
            field_type=ListType(element_id=5, element_type=StringType(), element_required=True),
            required=True,
        ),
        5: NestedField(field_id=5, name="element", field_type=StringType(), required=True),
        6: NestedField(
            field_id=6,
            name="quux",
            field_type=MapType(
                key_id=7,
                key_type=StringType(),
                value_id=8,
                value_type=value_map,
                value_required=True,
            ),
            required=True,
        ),
        7: NestedField(field_id=7, name="key", field_type=StringType(), required=True),
        8: NestedField(field_id=8, name="value", field_type=value_map, required=True),
        9: NestedField(field_id=9, name="key", field_type=StringType(), required=True),
        10: NestedField(field_id=10, name="value", field_type=IntegerType(), required=True),
        11: NestedField(
            field_id=11,
            name="location",
            field_type=ListType(element_id=12, element_type=location_struct, element_required=True),
            required=True,
        ),
        12: NestedField(field_id=12, name="element", field_type=location_struct, required=True),
        13: NestedField(field_id=13, name="latitude", field_type=FloatType(), required=False),
        14: NestedField(field_id=14, name="longitude", field_type=FloatType(), required=False),
        15: NestedField(field_id=15, name="person", field_type=person_struct, required=False),
        16: NestedField(field_id=16, name="name", field_type=StringType(), required=False),
        17: NestedField(field_id=17, name="age", field_type=IntegerType(), required=True),
    }

    actual = schema.index_by_id(table_schema_nested)
    assert actual == expected
| |
| |
def test_index_by_id_schema_visitor_raise_on_unregistered_type():
    """Check that index_by_id raises NotImplementedError when handed a plain string"""
    not_a_schema = "foo"
    with pytest.raises(NotImplementedError) as exc_info:
        schema.index_by_id(not_a_schema)

    message = str(exc_info.value)
    assert "Cannot visit non-type: foo" in message
| |
| |
def test_schema_find_field(table_schema_simple):
    """Check that find_field gives the same field by ID, by name, and by upper-cased name when case-insensitive"""
    expected_fields = (
        NestedField(field_id=1, name="foo", field_type=StringType(), required=False),
        NestedField(field_id=2, name="bar", field_type=IntegerType(), required=True),
        NestedField(field_id=3, name="baz", field_type=BooleanType(), required=False),
    )

    for expected, name in zip(expected_fields, ("foo", "bar", "baz")):
        by_id = table_schema_simple.find_field(expected.field_id)
        by_name = table_schema_simple.find_field(name)
        # "FOO" / "BAR" / "BAZ" must resolve once case sensitivity is switched off.
        by_upper_name = table_schema_simple.find_field(name.upper(), case_sensitive=False)
        assert by_id == by_name == by_upper_name == expected
| |
| |
def test_schema_find_type(table_schema_simple):
    """Check that find_type gives the same type by ID, by name, and by upper-cased name when case-insensitive"""
    # (field ID, column name, expected type)
    cases = (
        (1, "foo", StringType()),
        (2, "bar", IntegerType()),
        (3, "baz", BooleanType()),
    )

    for field_id, name, expected_type in cases:
        by_id = table_schema_simple.find_type(field_id)
        by_name = table_schema_simple.find_type(name)
        # "FOO" / "BAR" / "BAZ" must resolve once case sensitivity is switched off.
        by_upper_name = table_schema_simple.find_type(name.upper(), case_sensitive=False)
        assert by_id == by_name == by_upper_name == expected_type
| |
| |
def test_build_position_accessors(table_schema_nested):
    """Check the accessors that build_position_accessors produces for the nested schema"""
    # Field IDs expected to get a flat accessor, listed in position order (position 0 first).
    flat_field_ids = (1, 2, 3, 4, 6, 11)
    expected = {
        field_id: Accessor(position=position, inner=None) for position, field_id in enumerate(flat_field_ids)
    }

    # Field IDs 16 and 17 are expected at outer position 6, wrapping an inner
    # accessor for positions 0 and 1 respectively.
    expected[16] = Accessor(position=6, inner=Accessor(position=0, inner=None))
    expected[17] = Accessor(position=6, inner=Accessor(position=1, inner=None))

    accessors = build_position_accessors(table_schema_nested)
    assert accessors == expected
| |
| |
def test_build_position_accessors_with_struct(table_schema_nested: Schema):
    """Check that the accessor for field ID 16 reads a value out of a nested struct container"""

    class _DictBackedStruct(StructProtocol):
        """Minimal struct for this test: values are stored in a dict keyed by position"""

        def __init__(self, pos: Optional[Dict[int, Any]] = None):
            # A missing or empty mapping falls back to a fresh empty dict.
            self._pos: Dict[int, Any] = pos if pos else {}

        def set(self, pos: int, value) -> None:
            """Writing is not needed by this test, so this does nothing"""
            return None

        def get(self, pos: int) -> Any:
            """Return the value stored at the given position"""
            return self._pos[pos]

    # Outer position 6 holds an inner struct whose position 0 carries the value.
    inner_struct = _DictBackedStruct({0: "name"})
    container = _DictBackedStruct({6: inner_struct})

    accessors = build_position_accessors(table_schema_nested)
    inner_accessor = accessors.get(16)
    assert inner_accessor
    assert inner_accessor.get(container) == "name"
| |
| |
def test_serialize_schema(table_schema_simple: Schema):
    """Check the JSON string produced by serializing the simple table schema"""
    # One JSON document, split across adjacent literals purely for readability.
    expected = (
        '{"fields": ['
        '{"id": 1, "name": "foo", "type": "string", "required": false}, '
        '{"id": 2, "name": "bar", "type": "int", "required": true}, '
        '{"id": 3, "name": "baz", "type": "boolean", "required": false}'
        '], "schema-id": 1, "identifier-field-ids": [1]}'
    )
    serialized = table_schema_simple.json()
    assert serialized == expected
| |
| |
def test_deserialize_schema(table_schema_simple: Schema):
    """Check that parsing the JSON form of the simple schema yields an equal Schema"""
    # One JSON document, split across adjacent literals purely for readability.
    payload = (
        '{"fields": ['
        '{"id": 1, "name": "foo", "type": "string", "required": false}, '
        '{"id": 2, "name": "bar", "type": "int", "required": true}, '
        '{"id": 3, "name": "baz", "type": "boolean", "required": false}'
        '], "schema-id": 1, "identifier-field-ids": [1]}'
    )
    parsed = Schema.parse_raw(payload)
    assert parsed == table_schema_simple