blob: 26b44e8e23761b2c6ea42a84b744dacc85c33a98 [file] [log] [blame]
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
from tempfile import TemporaryDirectory
from typing import Optional
import pytest
from pydantic import Field
from pyiceberg.avro.file import AvroFile
from pyiceberg.avro.reader import (
DecimalReader,
DefaultReader,
DoubleReader,
FloatReader,
IntegerReader,
MapReader,
OptionReader,
StringReader,
StructReader,
)
from pyiceberg.avro.resolver import resolve_reader, resolve_writer
from pyiceberg.avro.writer import (
BinaryWriter,
DefaultWriter,
DoubleWriter,
IntegerWriter,
ListWriter,
MapWriter,
OptionWriter,
StringWriter,
StructWriter,
)
from pyiceberg.exceptions import ResolveError
from pyiceberg.io.pyarrow import PyArrowFileIO
from pyiceberg.manifest import MANIFEST_ENTRY_SCHEMAS
from pyiceberg.schema import Schema
from pyiceberg.typedef import Record
from pyiceberg.types import (
BinaryType,
BooleanType,
DecimalType,
DoubleType,
FloatType,
IntegerType,
ListType,
LongType,
MapType,
NestedField,
StringType,
StructType,
)
def test_resolver() -> None:
write_schema = Schema(
NestedField(1, "id", LongType()),
NestedField(2, "data", StringType()),
NestedField(
3,
"location",
StructType(
NestedField(4, "lat", DoubleType()),
NestedField(5, "long", DoubleType()),
),
),
NestedField(6, "preferences", MapType(7, StringType(), 8, StringType())),
schema_id=1,
)
location_struct = StructType(
NestedField(4, "lat", DoubleType()),
NestedField(5, "long", DoubleType()),
)
read_schema = Schema(
NestedField(
3,
"location",
location_struct,
),
NestedField(1, "id", LongType(), required=False),
NestedField(6, "preferences", MapType(7, StringType(), 8, StringType())),
schema_id=1,
)
read_tree = resolve_reader(write_schema, read_schema)
assert read_tree == StructReader(
(
(1, OptionReader(option=IntegerReader())),
(None, OptionReader(option=StringReader())),
(
0,
OptionReader(
option=StructReader(
(
(0, OptionReader(option=DoubleReader())),
(1, OptionReader(option=DoubleReader())),
),
Record,
location_struct,
),
),
),
(2, OptionReader(option=MapReader(StringReader(), StringReader()))),
),
Record,
read_schema.as_struct(),
)
def test_resolver_new_required_field() -> None:
write_schema = Schema(
NestedField(1, "id", LongType()),
schema_id=1,
)
read_schema = Schema(
NestedField(1, "id", LongType()),
NestedField(2, "data", StringType(), required=True),
schema_id=1,
)
with pytest.raises(ResolveError) as exc_info:
resolve_reader(write_schema, read_schema)
assert "2: data: required string is non-optional, and not part of the file schema" in str(exc_info.value)
def test_resolver_invalid_evolution() -> None:
write_schema = Schema(
NestedField(1, "id", LongType()),
schema_id=1,
)
read_schema = Schema(
NestedField(1, "id", DoubleType()),
schema_id=1,
)
with pytest.raises(ResolveError) as exc_info:
resolve_reader(write_schema, read_schema)
assert "Cannot promote long to double" in str(exc_info.value)
def test_resolver_promotion_string_to_binary() -> None:
write_schema = Schema(
NestedField(1, "id", StringType()),
schema_id=1,
)
read_schema = Schema(
NestedField(1, "id", BinaryType()),
schema_id=1,
)
resolve_reader(write_schema, read_schema)
def test_resolver_promotion_binary_to_string() -> None:
write_schema = Schema(
NestedField(1, "id", BinaryType()),
schema_id=1,
)
read_schema = Schema(
NestedField(1, "id", StringType()),
schema_id=1,
)
resolve_reader(write_schema, read_schema)
def test_resolver_change_type() -> None:
write_schema = Schema(
NestedField(1, "properties", ListType(2, StringType())),
schema_id=1,
)
read_schema = Schema(
NestedField(1, "properties", MapType(2, StringType(), 3, StringType())),
schema_id=1,
)
with pytest.raises(ResolveError) as exc_info:
resolve_reader(write_schema, read_schema)
assert "File/read schema are not aligned for list, got map<string, string>" in str(exc_info.value)
def test_resolve_int_to_long() -> None:
assert resolve_reader(IntegerType(), LongType()) == IntegerReader()
def test_resolve_float_to_double() -> None:
# We should still read floats, because it is encoded in 4 bytes
assert resolve_reader(FloatType(), DoubleType()) == FloatReader()
def test_resolve_decimal_to_decimal() -> None:
# DecimalType(P, S) to DecimalType(P2, S) where P2 > P
assert resolve_reader(DecimalType(19, 25), DecimalType(22, 25)) == DecimalReader(19, 25)
def test_struct_not_aligned() -> None:
with pytest.raises(ResolveError):
assert resolve_reader(StructType(), StringType())
def test_map_not_aligned() -> None:
with pytest.raises(ResolveError):
assert resolve_reader(MapType(1, StringType(), 2, IntegerType()), StringType())
def test_primitive_not_aligned() -> None:
with pytest.raises(ResolveError):
assert resolve_reader(IntegerType(), MapType(1, StringType(), 2, IntegerType()))
def test_integer_not_aligned() -> None:
with pytest.raises(ResolveError):
assert resolve_reader(IntegerType(), StringType())
def test_float_not_aligned() -> None:
with pytest.raises(ResolveError):
assert resolve_reader(FloatType(), StringType())
def test_string_not_aligned() -> None:
with pytest.raises(ResolveError):
assert resolve_reader(StringType(), FloatType())
def test_binary_not_aligned() -> None:
with pytest.raises(ResolveError):
assert resolve_reader(BinaryType(), FloatType())
def test_decimal_not_aligned() -> None:
with pytest.raises(ResolveError):
assert resolve_reader(DecimalType(22, 19), StringType())
def test_resolve_decimal_to_decimal_reduce_precision() -> None:
# DecimalType(P, S) to DecimalType(P2, S) where P2 > P
with pytest.raises(ResolveError) as exc_info:
_ = resolve_reader(DecimalType(19, 25), DecimalType(10, 25)) == DecimalReader(22, 25)
assert "Cannot reduce precision from decimal(19, 25) to decimal(10, 25)" in str(exc_info.value)
def test_column_assignment() -> None:
int_schema = {
"type": "record",
"name": "ints",
"fields": [
{"name": "a", "type": "int", "field-id": 1},
{"name": "b", "type": "int", "field-id": 2},
{"name": "c", "type": "int", "field-id": 3},
],
}
from fastavro import parse_schema, writer
parsed_schema = parse_schema(int_schema)
int_records = [
{
"a": 1,
"b": 2,
"c": 3,
}
]
with TemporaryDirectory() as tmpdir:
tmp_avro_file = tmpdir + "/manifest.avro"
with open(tmp_avro_file, "wb") as out:
writer(out, parsed_schema, int_records)
class Ints(Record):
c: int = Field()
d: Optional[int] = Field()
ints_schema = Schema(
NestedField(3, "c", IntegerType(), required=True),
NestedField(4, "d", IntegerType(), required=False),
)
with AvroFile[Ints](PyArrowFileIO().new_input(tmp_avro_file), ints_schema, {-1: Ints}) as reader:
records = list(reader)
assert repr(records) == "[Ints[3, None]]"
def test_resolver_initial_value() -> None:
write_schema = Schema(
NestedField(1, "name", StringType()),
schema_id=1,
)
read_schema = Schema(
NestedField(2, "something", StringType(), required=False, initial_default="vo"),
schema_id=2,
)
assert resolve_reader(write_schema, read_schema) == StructReader(
(
(None, OptionReader(option=StringReader())), # The one we skip
(0, DefaultReader("vo")),
),
Record,
read_schema.as_struct(),
)
def test_resolve_writer() -> None:
actual = resolve_writer(record_schema=MANIFEST_ENTRY_SCHEMAS[2], file_schema=MANIFEST_ENTRY_SCHEMAS[1])
expected = StructWriter(
(
(0, IntegerWriter()),
(1, IntegerWriter()),
(
4,
StructWriter(
(
(1, StringWriter()),
(2, StringWriter()),
(3, StructWriter(())),
(4, IntegerWriter()),
(5, IntegerWriter()),
(None, DefaultWriter(writer=IntegerWriter(), value=67108864)),
(6, OptionWriter(option=MapWriter(key_writer=IntegerWriter(), value_writer=IntegerWriter()))),
(7, OptionWriter(option=MapWriter(key_writer=IntegerWriter(), value_writer=IntegerWriter()))),
(8, OptionWriter(option=MapWriter(key_writer=IntegerWriter(), value_writer=IntegerWriter()))),
(9, OptionWriter(option=MapWriter(key_writer=IntegerWriter(), value_writer=IntegerWriter()))),
(10, OptionWriter(option=MapWriter(key_writer=IntegerWriter(), value_writer=BinaryWriter()))),
(11, OptionWriter(option=MapWriter(key_writer=IntegerWriter(), value_writer=BinaryWriter()))),
(12, OptionWriter(option=BinaryWriter())),
(13, OptionWriter(option=ListWriter(element_writer=IntegerWriter()))),
(15, OptionWriter(option=IntegerWriter())),
)
),
),
)
)
assert actual == expected
def test_resolve_writer_promotion() -> None:
with pytest.raises(ResolveError) as exc_info:
_ = resolve_writer(
record_schema=Schema(NestedField(field_id=1, name="floating", type=DoubleType(), required=True)),
file_schema=Schema(NestedField(field_id=1, name="floating", type=FloatType(), required=True)),
)
assert "Cannot promote double to float" in str(exc_info.value)
def test_writer_ordering() -> None:
actual = resolve_writer(
record_schema=Schema(
NestedField(field_id=1, name="str", type=StringType(), required=True),
NestedField(field_id=2, name="dbl", type=DoubleType(), required=True),
),
file_schema=Schema(
NestedField(field_id=2, name="dbl", type=DoubleType(), required=True),
NestedField(field_id=1, name="str", type=StringType(), required=True),
),
)
expected = StructWriter(((1, DoubleWriter()), (0, StringWriter())))
assert actual == expected
def test_writer_one_more_field() -> None:
actual = resolve_writer(
record_schema=Schema(
NestedField(field_id=3, name="bool", type=BooleanType(), required=True),
NestedField(field_id=1, name="str", type=StringType(), required=True),
NestedField(field_id=2, name="dbl", type=DoubleType(), required=True),
),
file_schema=Schema(
NestedField(field_id=2, name="dbl", type=DoubleType(), required=True),
NestedField(field_id=1, name="str", type=StringType(), required=True),
),
)
expected = StructWriter(((2, DoubleWriter()), (1, StringWriter())))
assert actual == expected
def test_writer_missing_optional_in_read_schema() -> None:
actual = resolve_writer(
record_schema=Schema(),
file_schema=Schema(
NestedField(field_id=1, name="str", type=StringType(), required=False),
),
)
expected = StructWriter(field_writers=((None, OptionWriter(option=StringWriter())),))
assert actual == expected