| # Licensed to the Apache Software Foundation (ASF) under one |
| # or more contributor license agreements. See the NOTICE file |
| # distributed with this work for additional information |
| # regarding copyright ownership. The ASF licenses this file |
| # to you under the Apache License, Version 2.0 (the |
| # "License"); you may not use this file except in compliance |
| # with the License. You may obtain a copy of the License at |
| # |
| # http://www.apache.org/licenses/LICENSE-2.0 |
| # |
| # Unless required by applicable law or agreed to in writing, |
| # software distributed under the License is distributed on an |
| # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY |
| # KIND, either express or implied. See the License for the |
| # specific language governing permissions and limitations |
| # under the License. |
| """ |
| Classes for building the Writer tree. |
| |
| Constructing a writer tree from the schema makes it easy |
| to decouple the writing implementation from the schema. |
| """ |
| |
| from __future__ import annotations |
| |
| from abc import abstractmethod |
| from dataclasses import dataclass |
| from dataclasses import field as dataclassfield |
| from typing import ( |
| Any, |
| Dict, |
| List, |
| Optional, |
| Tuple, |
| ) |
| from uuid import UUID |
| |
| from pyiceberg.avro.encoder import BinaryEncoder |
| from pyiceberg.typedef import Record |
| from pyiceberg.utils.decimal import decimal_required_bytes, decimal_to_bytes |
| from pyiceberg.utils.singleton import Singleton |
| |
| |
| @dataclass(frozen=True) |
| class Writer(Singleton): |
| @abstractmethod |
| def write(self, encoder: BinaryEncoder, val: Any) -> Any: ... |
| |
| def __repr__(self) -> str: |
| """Return string representation of this object.""" |
| return f"{self.__class__.__name__}()" |
| |
| |
| @dataclass(frozen=True) |
| class BooleanWriter(Writer): |
| def write(self, encoder: BinaryEncoder, val: bool) -> None: |
| encoder.write_boolean(val) |
| |
| |
| @dataclass(frozen=True) |
| class IntegerWriter(Writer): |
| """Longs and ints are encoded the same way, and there is no long in Python.""" |
| |
| def write(self, encoder: BinaryEncoder, val: int) -> None: |
| encoder.write_int(val) |
| |
| |
| @dataclass(frozen=True) |
| class FloatWriter(Writer): |
| def write(self, encoder: BinaryEncoder, val: float) -> None: |
| encoder.write_float(val) |
| |
| |
| @dataclass(frozen=True) |
| class DoubleWriter(Writer): |
| def write(self, encoder: BinaryEncoder, val: float) -> None: |
| encoder.write_double(val) |
| |
| |
| @dataclass(frozen=True) |
| class DateWriter(Writer): |
| def write(self, encoder: BinaryEncoder, val: int) -> None: |
| encoder.write_int(val) |
| |
| |
| @dataclass(frozen=True) |
| class TimeWriter(Writer): |
| def write(self, encoder: BinaryEncoder, val: int) -> None: |
| encoder.write_int(val) |
| |
| |
| @dataclass(frozen=True) |
| class TimestampWriter(Writer): |
| def write(self, encoder: BinaryEncoder, val: int) -> None: |
| encoder.write_int(val) |
| |
| |
| @dataclass(frozen=True) |
| class TimestamptzWriter(Writer): |
| def write(self, encoder: BinaryEncoder, val: int) -> None: |
| encoder.write_int(val) |
| |
| |
| @dataclass(frozen=True) |
| class StringWriter(Writer): |
| def write(self, encoder: BinaryEncoder, val: Any) -> None: |
| encoder.write_utf8(val) |
| |
| |
| @dataclass(frozen=True) |
| class UUIDWriter(Writer): |
| def write(self, encoder: BinaryEncoder, val: UUID) -> None: |
| encoder.write(val.bytes) |
| |
| |
| @dataclass(frozen=True) |
| class FixedWriter(Writer): |
| _len: int = dataclassfield() |
| |
| def write(self, encoder: BinaryEncoder, val: bytes) -> None: |
| if len(val) != self._len: |
| raise ValueError(f"Expected {self._len} bytes, got {len(val)}") |
| encoder.write(val) |
| |
| def __len__(self) -> int: |
| """Return the length of this object.""" |
| return self._len |
| |
| def __repr__(self) -> str: |
| """Return string representation of this object.""" |
| return f"FixedWriter({self._len})" |
| |
| |
| @dataclass(frozen=True) |
| class BinaryWriter(Writer): |
| """Variable byte length writer.""" |
| |
| def write(self, encoder: BinaryEncoder, val: Any) -> None: |
| encoder.write_bytes(val) |
| |
| |
| @dataclass(frozen=True) |
| class DecimalWriter(Writer): |
| precision: int = dataclassfield() |
| scale: int = dataclassfield() |
| |
| def write(self, encoder: BinaryEncoder, val: Any) -> None: |
| return encoder.write(decimal_to_bytes(val, byte_length=decimal_required_bytes(self.precision))) |
| |
| def __repr__(self) -> str: |
| """Return string representation of this object.""" |
| return f"DecimalWriter({self.precision}, {self.scale})" |
| |
| |
| @dataclass(frozen=True) |
| class OptionWriter(Writer): |
| option: Writer = dataclassfield() |
| |
| def write(self, encoder: BinaryEncoder, val: Any) -> None: |
| if val is not None: |
| encoder.write_int(1) |
| self.option.write(encoder, val) |
| else: |
| encoder.write_int(0) |
| |
| |
| @dataclass(frozen=True) |
| class StructWriter(Writer): |
| field_writers: Tuple[Tuple[Optional[int], Writer], ...] = dataclassfield() |
| |
| def write(self, encoder: BinaryEncoder, val: Record) -> None: |
| for pos, writer in self.field_writers: |
| # When pos is None, then it is a default value |
| writer.write(encoder, val[pos] if pos is not None else None) |
| |
| def __eq__(self, other: Any) -> bool: |
| """Implement the equality operator for this object.""" |
| return self.field_writers == other.field_writers if isinstance(other, StructWriter) else False |
| |
| def __repr__(self) -> str: |
| """Return string representation of this object.""" |
| return f"StructWriter(tuple(({','.join(repr(field) for field in self.field_writers)})))" |
| |
| def __hash__(self) -> int: |
| """Return the hash of the writer as hash of this object.""" |
| return hash(self.field_writers) |
| |
| |
| @dataclass(frozen=True) |
| class ListWriter(Writer): |
| element_writer: Writer |
| |
| def write(self, encoder: BinaryEncoder, val: List[Any]) -> None: |
| encoder.write_int(len(val)) |
| for v in val: |
| self.element_writer.write(encoder, v) |
| if len(val) > 0: |
| encoder.write_int(0) |
| |
| |
| @dataclass(frozen=True) |
| class MapWriter(Writer): |
| key_writer: Writer |
| value_writer: Writer |
| |
| def write(self, encoder: BinaryEncoder, val: Dict[Any, Any]) -> None: |
| encoder.write_int(len(val)) |
| for k, v in val.items(): |
| self.key_writer.write(encoder, k) |
| self.value_writer.write(encoder, v) |
| if len(val) > 0: |
| encoder.write_int(0) |
| |
| |
| @dataclass(frozen=True) |
| class DefaultWriter(Writer): |
| writer: Writer |
| value: Any |
| |
| def write(self, encoder: BinaryEncoder, _: Any) -> None: |
| self.writer.write(encoder, self.value) |