| # Licensed to the Apache Software Foundation (ASF) under one |
| # or more contributor license agreements. See the NOTICE file |
| # distributed with this work for additional information |
| # regarding copyright ownership. The ASF licenses this file |
| # to you under the Apache License, Version 2.0 (the |
| # "License"); you may not use this file except in compliance |
| # with the License. You may obtain a copy of the License at |
| # |
| # http://www.apache.org/licenses/LICENSE-2.0 |
| # |
| # Unless required by applicable law or agreed to in writing, |
| # software distributed under the License is distributed on an |
| # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY |
| # KIND, either express or implied. See the License for the |
| # specific language governing permissions and limitations |
| # under the License. |
| # pylint:disable=protected-access |
| |
| import io |
| import struct |
| from _decimal import Decimal |
| |
| from pyiceberg.avro.encoder import BinaryEncoder |
| from pyiceberg.avro.resolver import construct_writer |
| from pyiceberg.avro.writer import ( |
| BinaryWriter, |
| BooleanWriter, |
| DateWriter, |
| DecimalWriter, |
| DoubleWriter, |
| FixedWriter, |
| FloatWriter, |
| IntegerWriter, |
| StringWriter, |
| TimestampNanoWriter, |
| TimestamptzNanoWriter, |
| TimestamptzWriter, |
| TimestampWriter, |
| TimeWriter, |
| UnknownWriter, |
| UUIDWriter, |
| ) |
| from pyiceberg.typedef import Record |
| from pyiceberg.types import ( |
| BinaryType, |
| BooleanType, |
| DateType, |
| DecimalType, |
| DoubleType, |
| FixedType, |
| FloatType, |
| IntegerType, |
| ListType, |
| LongType, |
| MapType, |
| NestedField, |
| StringType, |
| StructType, |
| TimestampNanoType, |
| TimestampType, |
| TimestamptzNanoType, |
| TimestamptzType, |
| TimeType, |
| UnknownType, |
| UUIDType, |
| ) |
| |
| |
def zigzag_encode(datum: int) -> bytes:
    """Return *datum* as a zigzag-mapped, variable-length byte string.

    The value is first zigzag-mapped so that small negative numbers become
    small non-negative ones, then emitted seven bits at a time, least
    significant group first. Every byte except the last has its high bit set
    to signal that more bytes follow.

    Args:
        datum: The integer to encode. The sign-fold uses a shift by 63, so the
            value is assumed to fit in a signed 64-bit integer.

    Returns:
        The encoded bytes.
    """
    # Zigzag step: fold the sign into the lowest bit (0 -> 0, -1 -> 1, 1 -> 2, ...).
    remaining = (datum << 1) ^ (datum >> 63)

    encoded = bytearray()
    # Emit 7-bit groups while anything is left above the low seven bits.
    while remaining & ~0x7F:
        encoded.append((remaining & 0x7F) | 0x80)  # continuation bit set
        remaining >>= 7
    # Final group: high bit clear marks the end of the value.
    encoded.append(remaining)
    return bytes(encoded)
| |
| |
def test_fixed_writer() -> None:
    """A FixedType of length 22 must resolve to a FixedWriter built with 22."""
    writer = construct_writer(FixedType(22))
    assert writer == FixedWriter(22)
| |
| |
def test_decimal_writer() -> None:
    """DecimalType(19, 25) must resolve to a DecimalWriter built with the same two arguments."""
    writer = construct_writer(DecimalType(19, 25))
    assert writer == DecimalWriter(19, 25)
| |
| |
def test_boolean_writer() -> None:
    """BooleanType must resolve to a BooleanWriter."""
    writer = construct_writer(BooleanType())
    assert writer == BooleanWriter()
| |
| |
def test_integer_writer() -> None:
    """IntegerType must resolve to an IntegerWriter."""
    writer = construct_writer(IntegerType())
    assert writer == IntegerWriter()
| |
| |
def test_long_writer() -> None:
    """LongType must resolve to an IntegerWriter."""
    writer = construct_writer(LongType())
    # The expected writer for longs is IntegerWriter, not a dedicated long writer.
    assert writer == IntegerWriter()
| |
| |
def test_float_writer() -> None:
    """FloatType must resolve to a FloatWriter."""
    writer = construct_writer(FloatType())
    assert writer == FloatWriter()
| |
| |
def test_double_writer() -> None:
    """DoubleType must resolve to a DoubleWriter."""
    writer = construct_writer(DoubleType())
    assert writer == DoubleWriter()
| |
| |
def test_date_writer() -> None:
    """DateType must resolve to a DateWriter."""
    writer = construct_writer(DateType())
    assert writer == DateWriter()
| |
| |
def test_time_writer() -> None:
    """TimeType must resolve to a TimeWriter."""
    writer = construct_writer(TimeType())
    assert writer == TimeWriter()
| |
| |
def test_timestamp_writer() -> None:
    """TimestampType must resolve to a TimestampWriter."""
    writer = construct_writer(TimestampType())
    assert writer == TimestampWriter()
| |
| |
def test_timestamp_ns_writer() -> None:
    """TimestampNanoType must resolve to a TimestampNanoWriter."""
    writer = construct_writer(TimestampNanoType())
    assert writer == TimestampNanoWriter()
| |
| |
def test_timestamptz_writer() -> None:
    """TimestamptzType must resolve to a TimestamptzWriter."""
    writer = construct_writer(TimestamptzType())
    assert writer == TimestamptzWriter()
| |
| |
def test_timestamptz_ns_writer() -> None:
    """TimestamptzNanoType must resolve to a TimestamptzNanoWriter."""
    writer = construct_writer(TimestamptzNanoType())
    assert writer == TimestamptzNanoWriter()
| |
| |
def test_string_writer() -> None:
    """StringType must resolve to a StringWriter."""
    writer = construct_writer(StringType())
    assert writer == StringWriter()
| |
| |
def test_binary_writer() -> None:
    """BinaryType must resolve to a BinaryWriter."""
    writer = construct_writer(BinaryType())
    assert writer == BinaryWriter()
| |
| |
def test_unknown_type() -> None:
    """UnknownType must resolve to an UnknownWriter."""
    writer = construct_writer(UnknownType())
    assert writer == UnknownWriter()
| |
| |
def test_uuid_writer() -> None:
    """UUIDType must resolve to a UUIDWriter."""
    writer = construct_writer(UUIDType())
    assert writer == UUIDWriter()
| |
| |
def test_write_simple_struct() -> None:
    """Write a record with an int field and a string field and check the exact bytes.

    Expected layout: the id 12 as the single byte 0x18, then the string's
    byte length (variable-length encoded), then the string bytes themselves.
    """
    output = io.BytesIO()
    encoder = BinaryEncoder(output)

    schema = StructType(
        NestedField(1, "id", IntegerType(), required=True),
        NestedField(2, "property", StringType(), required=True),
    )
    # Named `record` rather than `struct` so the imported `struct` module is not shadowed.
    record = Record(12, "awesome")

    enc_str = b"awesome"

    construct_writer(schema).write(encoder, record)

    # getvalue() yields plain bytes, so a failing assertion prints the payload
    # instead of an opaque memoryview repr.
    expected = b"".join([b"\x18", zigzag_encode(len(enc_str)), enc_str])
    assert output.getvalue() == expected
| |
| |
def test_write_struct_with_dict() -> None:
    """Write a record holding an int and an int-to-int map and check the exact bytes.

    Expected layout: the id 12 as 0x18, the number of map entries, each key
    followed by its value in insertion order, then a trailing zero byte.
    """
    output = io.BytesIO()
    encoder = BinaryEncoder(output)

    schema = StructType(
        NestedField(1, "id", IntegerType(), required=True),
        NestedField(2, "properties", MapType(3, IntegerType(), 4, IntegerType()), required=True),
    )

    # Named `record` rather than `struct` so the imported `struct` module is not shadowed.
    record = Record(12, {1: 2, 3: 4})
    construct_writer(schema).write(encoder, record)

    expected = b"".join(
        [
            b"\x18",
            zigzag_encode(len(record[1])),  # entry count of the map field
            zigzag_encode(1),
            zigzag_encode(2),
            zigzag_encode(3),
            zigzag_encode(4),
            b"\x00",
        ]
    )
    # getvalue() yields plain bytes, so a failing assertion prints the payload
    # instead of an opaque memoryview repr.
    assert output.getvalue() == expected
| |
| |
def test_write_struct_with_list() -> None:
    """Write a record holding an int and a list of ints and check the exact bytes.

    Expected layout: the id 12 as 0x18, the number of list elements, each
    element in order, then a trailing zero byte.
    """
    output = io.BytesIO()
    encoder = BinaryEncoder(output)

    schema = StructType(
        NestedField(1, "id", IntegerType(), required=True),
        NestedField(2, "properties", ListType(3, IntegerType()), required=True),
    )

    # Named `record` rather than `struct` so the imported `struct` module is not shadowed.
    record = Record(12, [1, 2, 3, 4])

    construct_writer(schema).write(encoder, record)

    expected = b"".join(
        [
            b"\x18",
            zigzag_encode(len(record[1])),  # element count of the list field
            zigzag_encode(1),
            zigzag_encode(2),
            zigzag_encode(3),
            zigzag_encode(4),
            b"\x00",
        ]
    )
    # getvalue() yields plain bytes, so a failing assertion prints the payload
    # instead of an opaque memoryview repr.
    assert output.getvalue() == expected
| |
| |
def test_write_decimal() -> None:
    """Write a record with one DecimalType(10, 2) field and check the exact bytes."""

    class MyStruct(Record):
        decimal: Decimal

    schema = StructType(
        NestedField(1, "decimal", DecimalType(10, 2), required=True),
    )
    writer = construct_writer(schema)

    output = io.BytesIO()
    record = MyStruct(Decimal("1000.12"))
    writer.write(BinaryEncoder(output), record)

    # 0x0186AC == 100012, i.e. the digits of 1000.12 without the decimal point,
    # left-padded with zero bytes to five bytes.
    expected = b"\x00\x00\x01\x86\xac"
    assert output.getvalue() == expected