| # Licensed to the Apache Software Foundation (ASF) under one |
| # or more contributor license agreements. See the NOTICE file |
| # distributed with this work for additional information |
| # regarding copyright ownership. The ASF licenses this file |
| # to you under the Apache License, Version 2.0 (the |
| # "License"); you may not use this file except in compliance |
| # with the License. You may obtain a copy of the License at |
| # |
| # http://www.apache.org/licenses/LICENSE-2.0 |
| # |
| # Unless required by applicable law or agreed to in writing, |
| # software distributed under the License is distributed on an |
| # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY |
| # KIND, either express or implied. See the License for the |
| # specific language governing permissions and limitations |
| # under the License. |
| # pylint: disable=arguments-renamed,unused-argument |
| from enum import Enum |
| from typing import ( |
| Callable, |
| Dict, |
| List, |
| Optional, |
| Tuple, |
| Union, |
| ) |
| |
| from pyiceberg.avro.decoder import BinaryDecoder |
| from pyiceberg.avro.reader import ( |
| BinaryReader, |
| BooleanReader, |
| DateReader, |
| DecimalReader, |
| DefaultReader, |
| DoubleReader, |
| FixedReader, |
| FloatReader, |
| IntegerReader, |
| ListReader, |
| MapReader, |
| NoneReader, |
| OptionReader, |
| Reader, |
| StringReader, |
| StructReader, |
| TimeReader, |
| TimestampNanoReader, |
| TimestampReader, |
| TimestamptzNanoReader, |
| TimestamptzReader, |
| UnknownReader, |
| UUIDReader, |
| ) |
| from pyiceberg.avro.writer import ( |
| BinaryWriter, |
| BooleanWriter, |
| DateWriter, |
| DecimalWriter, |
| DefaultWriter, |
| DoubleWriter, |
| FixedWriter, |
| FloatWriter, |
| IntegerWriter, |
| ListWriter, |
| MapWriter, |
| OptionWriter, |
| StringWriter, |
| StructWriter, |
| TimestampNanoWriter, |
| TimestamptzNanoWriter, |
| TimestamptzWriter, |
| TimestampWriter, |
| TimeWriter, |
| UnknownWriter, |
| UUIDWriter, |
| Writer, |
| ) |
| from pyiceberg.exceptions import ResolveError |
| from pyiceberg.schema import ( |
| PartnerAccessor, |
| PrimitiveWithPartnerVisitor, |
| Schema, |
| SchemaVisitorPerPrimitiveType, |
| promote, |
| visit, |
| visit_with_partner, |
| ) |
| from pyiceberg.typedef import EMPTY_DICT, Record, StructProtocol |
| from pyiceberg.types import ( |
| BinaryType, |
| BooleanType, |
| DateType, |
| DecimalType, |
| DoubleType, |
| FixedType, |
| FloatType, |
| IcebergType, |
| IntegerType, |
| ListType, |
| LongType, |
| MapType, |
| NestedField, |
| PrimitiveType, |
| StringType, |
| StructType, |
| TimestampNanoType, |
| TimestampType, |
| TimestamptzNanoType, |
| TimestamptzType, |
| TimeType, |
| UnknownType, |
| UUIDType, |
| ) |
| |
# Dual-purpose value used by ReadSchemaResolver.struct: it is the index (-1, i.e. the
# last entry) into the resolver's field-id context stack, and it is also the id looked
# up in ``read_types`` when that stack is empty (no field is currently being visited).
STRUCT_ROOT = -1
| |
| |
def construct_reader(
    file_schema: Union[Schema, IcebergType], read_types: Dict[int, Callable[..., StructProtocol]] = EMPTY_DICT
) -> Reader:
    """Build a reader straight from the schema a file was written with.

    The file schema is used as the read schema as well, so it is handed to
    ``resolve_reader`` in both roles.

    Args:
        file_schema (Schema | IcebergType): The schema of the Avro file.
        read_types (Dict[int, Callable[..., StructProtocol]]): Constructors for structs for certain field-ids.

    Returns:
        Reader: The reader produced by ``resolve_reader``.

    Raises:
        NotImplementedError: If attempting to resolve an unrecognized object type.
    """
    return resolve_reader(file_schema=file_schema, read_schema=file_schema, read_types=read_types)
| |
| |
def construct_writer(file_schema: Union[Schema, IcebergType]) -> Writer:
    """Build a writer straight from the schema a file will be written with.

    The schema is walked with the module-level ``CONSTRUCT_WRITER_VISITOR``.

    Args:
        file_schema (Schema | IcebergType): The schema of the Avro file.

    Returns:
        Writer: The writer produced by visiting ``file_schema``.

    Raises:
        NotImplementedError: If attempting to resolve an unrecognized object type.
    """
    writer: Writer = visit(file_schema, CONSTRUCT_WRITER_VISITOR)
    return writer
| |
| |
class ConstructWriter(SchemaVisitorPerPrimitiveType[Writer]):
    """Construct a writer tree from an Iceberg schema.

    Optional struct fields, optional list elements and optional map values are
    wrapped in an ``OptionWriter``. This matches what ``WriteSchemaResolver``
    builds and mirrors the ``OptionReader`` wrapping applied by
    ``ReadSchemaResolver``, so that ``resolve_writer`` yields the same writer
    tree for identical schemas as it does when going through schema resolution.
    """

    def schema(self, schema: Schema, struct_result: Writer) -> Writer:
        """Return the writer of the schema's root struct unchanged."""
        return struct_result

    def struct(self, struct: StructType, field_results: List[Writer]) -> Writer:
        """Combine the field writers into a StructWriter, pairing each with its position."""
        return StructWriter(tuple(enumerate(field_results)))

    def field(self, field: NestedField, field_result: Writer) -> Writer:
        """Wrap the writer of an optional field in an OptionWriter."""
        return field_result if field.required else OptionWriter(field_result)

    def list(self, list_type: ListType, element_result: Writer) -> Writer:
        """Build a ListWriter; optional elements get an OptionWriter around their writer."""
        # Previously the element writer was passed through as-is, which disagreed with
        # WriteSchemaResolver.list and with the OptionReader used by ReadSchemaResolver.list.
        return ListWriter(element_result if list_type.element_required else OptionWriter(element_result))

    def map(self, map_type: MapType, key_result: Writer, value_result: Writer) -> Writer:
        """Build a MapWriter; optional values get an OptionWriter around their writer."""
        # Same reasoning as for list elements: keep writer and reader trees symmetric.
        return MapWriter(key_result, value_result if map_type.value_required else OptionWriter(value_result))

    def visit_fixed(self, fixed_type: FixedType) -> Writer:
        """Return a FixedWriter sized by ``len(fixed_type)``."""
        return FixedWriter(len(fixed_type))

    def visit_decimal(self, decimal_type: DecimalType) -> Writer:
        """Return a DecimalWriter using the type's precision and scale."""
        return DecimalWriter(decimal_type.precision, decimal_type.scale)

    def visit_boolean(self, boolean_type: BooleanType) -> Writer:
        """Return a BooleanWriter."""
        return BooleanWriter()

    def visit_integer(self, integer_type: IntegerType) -> Writer:
        """Return an IntegerWriter."""
        return IntegerWriter()

    def visit_long(self, long_type: LongType) -> Writer:
        """Return an IntegerWriter; longs share the writer used for integers."""
        return IntegerWriter()

    def visit_float(self, float_type: FloatType) -> Writer:
        """Return a FloatWriter."""
        return FloatWriter()

    def visit_double(self, double_type: DoubleType) -> Writer:
        """Return a DoubleWriter."""
        return DoubleWriter()

    def visit_date(self, date_type: DateType) -> Writer:
        """Return a DateWriter."""
        return DateWriter()

    def visit_time(self, time_type: TimeType) -> Writer:
        """Return a TimeWriter."""
        return TimeWriter()

    def visit_timestamp(self, timestamp_type: TimestampType) -> Writer:
        """Return a TimestampWriter."""
        return TimestampWriter()

    def visit_timestamp_ns(self, timestamp_ns_type: TimestampNanoType) -> Writer:
        """Return a TimestampNanoWriter."""
        return TimestampNanoWriter()

    def visit_timestamptz(self, timestamptz_type: TimestamptzType) -> Writer:
        """Return a TimestamptzWriter."""
        return TimestamptzWriter()

    def visit_timestamptz_ns(self, timestamptz_ns_type: TimestamptzNanoType) -> Writer:
        """Return a TimestamptzNanoWriter."""
        return TimestamptzNanoWriter()

    def visit_string(self, string_type: StringType) -> Writer:
        """Return a StringWriter."""
        return StringWriter()

    def visit_uuid(self, uuid_type: UUIDType) -> Writer:
        """Return a UUIDWriter."""
        return UUIDWriter()

    def visit_binary(self, binary_type: BinaryType) -> Writer:
        """Return a BinaryWriter."""
        return BinaryWriter()

    def visit_unknown(self, unknown_type: UnknownType) -> Writer:
        """Return an UnknownWriter."""
        return UnknownWriter()
| |
| |
# Module-level ConstructWriter instance; construct_writer passes it to ``visit``
# instead of creating a new visitor on every call.
CONSTRUCT_WRITER_VISITOR = ConstructWriter()
| |
| |
def resolve_writer(
    record_schema: Union[Schema, IcebergType],
    file_schema: Union[Schema, IcebergType],
) -> Writer:
    """Resolve the record and file schema to produce a writer.

    When both schemas compare equal the writer is built directly from the file
    schema; otherwise the file schema is visited with the record schema as its
    partner.

    Args:
        record_schema (Schema | IcebergType): The schema of the record in memory.
        file_schema (Schema | IcebergType): The schema of the file that will be written.

    Returns:
        Writer: A writer for ``file_schema``.

    Raises:
        NotImplementedError: If attempting to resolve an unrecognized object type.
    """
    schemas_are_identical = record_schema == file_schema
    if not schemas_are_identical:
        return visit_with_partner(file_schema, record_schema, WriteSchemaResolver(), SchemaPartnerAccessor())  # type: ignore
    return construct_writer(file_schema)
| |
| |
def resolve_reader(
    file_schema: Union[Schema, IcebergType],
    read_schema: Union[Schema, IcebergType],
    read_types: Dict[int, Callable[..., StructProtocol]] = EMPTY_DICT,
    read_enums: Dict[int, Callable[..., Enum]] = EMPTY_DICT,
) -> Reader:
    """Resolve the file and read schema to produce a reader.

    The file schema is visited with the read schema as its partner.

    Args:
        file_schema (Schema | IcebergType): The schema of the Avro file.
        read_schema (Schema | IcebergType): The requested read schema which is equal, subset or superset of the file schema.
        read_types (Dict[int, Callable[..., StructProtocol]]): A dict of types to use for struct data.
        read_enums (Dict[int, Callable[..., Enum]]): A dict of fields that have to be converted to an enum.

    Returns:
        Reader: The reader produced by the visit.

    Raises:
        NotImplementedError: If attempting to resolve an unrecognized object type.
    """
    resolver = ReadSchemaResolver(read_types, read_enums)
    accessor = SchemaPartnerAccessor()
    return visit_with_partner(file_schema, read_schema, resolver, accessor)  # type: ignore
| |
| |
class EnumReader(Reader):
    """An Enum reader to wrap primitive values into an Enum.

    Reading delegates to the wrapped reader and passes the result to ``enum``;
    skipping delegates to the wrapped reader as well.
    """

    __slots__ = ("enum", "reader")

    # Callable applied to every value produced by ``reader``.
    enum: Callable[..., Enum]
    # Reader that decodes the underlying value.
    reader: Reader

    def __init__(self, enum: Callable[..., Enum], reader: Reader) -> None:
        """Store the enum constructor and the reader it wraps.

        Args:
            enum (Callable[..., Enum]): Callable that turns a decoded value into an Enum.
            reader (Reader): Reader for the underlying value.
        """
        self.enum = enum
        self.reader = reader

    def read(self, decoder: BinaryDecoder) -> Enum:
        """Read a value with the wrapped reader and convert it through ``enum``."""
        return self.enum(self.reader.read(decoder))

    def skip(self, decoder: BinaryDecoder) -> None:
        """Skip the value by delegating to the wrapped reader.

        This used to be a no-op, which left the encoded value unconsumed in the
        decoder whenever an enum-wrapped field was skipped instead of read.
        """
        self.reader.skip(decoder)
| |
| |
class WriteSchemaResolver(PrimitiveWithPartnerVisitor[IcebergType, Writer]):
    """Build a writer for the file schema, using the record schema as partner."""

    def schema(self, file_schema: Schema, record_schema: Optional[IcebergType], result: Writer) -> Writer:
        """Return the writer of the root struct unchanged."""
        return result

    def struct(self, file_schema: StructType, record_struct: Optional[IcebergType], file_writers: List[Writer]) -> Writer:
        """Pair each file field's writer with the position of that field in the record.

        Raises:
            ResolveError: If the partner is not a struct.
            ValueError: If a required file field is absent from the record and has no write default.
        """
        if not isinstance(record_struct, StructType):
            raise ResolveError(f"File/write schema are not aligned for struct, got {record_struct}")

        # field-id -> index of that field inside the record struct
        position_by_field_id: Dict[int, int] = {}
        for index, record_field in enumerate(record_struct.fields):
            position_by_field_id[record_field.field_id] = index

        field_writers: List[Tuple[Optional[int], Writer]] = []
        for file_writer, file_field in zip(file_writers, file_schema.fields):
            record_position = position_by_field_id.get(file_field.field_id)
            if record_position is not None:
                # The record carries this field: write it from its record position
                field_writers.append((record_position, file_writer))
            elif not file_field.required:
                # Optional and not in the record: no position to read from
                field_writers.append((None, file_writer))
            elif file_field.write_default is not None:
                # The field is not in the record, but there is a write default value
                field_writers.append((None, DefaultWriter(writer=file_writer, value=file_field.write_default)))
            else:
                raise ValueError(f"Field is required, and there is no write default: {file_field}")

        return StructWriter(field_writers=tuple(field_writers))

    def field(self, file_field: NestedField, record_type: Optional[IcebergType], field_writer: Writer) -> Writer:
        """Wrap the writer of an optional file field in an OptionWriter."""
        if file_field.required:
            return field_writer
        return OptionWriter(field_writer)

    def list(self, file_list_type: ListType, file_list: Optional[IcebergType], element_writer: Writer) -> Writer:
        """Build a ListWriter; optional elements get an OptionWriter around their writer."""
        if file_list_type.element_required:
            return ListWriter(element_writer)
        return ListWriter(OptionWriter(element_writer))

    def map(
        self, file_map_type: MapType, file_primitive: Optional[IcebergType], key_writer: Writer, value_writer: Writer
    ) -> Writer:
        """Build a MapWriter; optional values get an OptionWriter around their writer."""
        if file_map_type.value_required:
            return MapWriter(key_writer, value_writer)
        return MapWriter(key_writer, OptionWriter(value_writer))

    def primitive(self, file_primitive: PrimitiveType, record_primitive: Optional[IcebergType]) -> Writer:
        """Check the record type against the file type, then dispatch on the file type."""
        # ensure that the type can be projected to the expected
        if record_primitive is not None and file_primitive != record_primitive:
            promote(record_primitive, file_primitive)

        # The file type is passed as its own partner
        return super().primitive(file_primitive, file_primitive)

    def visit_boolean(self, boolean_type: BooleanType, partner: Optional[IcebergType]) -> Writer:
        """Return a BooleanWriter."""
        return BooleanWriter()

    def visit_integer(self, integer_type: IntegerType, partner: Optional[IcebergType]) -> Writer:
        """Return an IntegerWriter."""
        return IntegerWriter()

    def visit_long(self, long_type: LongType, partner: Optional[IcebergType]) -> Writer:
        """Return an IntegerWriter; longs share the writer used for integers."""
        return IntegerWriter()

    def visit_float(self, float_type: FloatType, partner: Optional[IcebergType]) -> Writer:
        """Return a FloatWriter."""
        return FloatWriter()

    def visit_double(self, double_type: DoubleType, partner: Optional[IcebergType]) -> Writer:
        """Return a DoubleWriter."""
        return DoubleWriter()

    def visit_decimal(self, decimal_type: DecimalType, partner: Optional[IcebergType]) -> Writer:
        """Return a DecimalWriter using the type's precision and scale."""
        return DecimalWriter(decimal_type.precision, decimal_type.scale)

    def visit_date(self, date_type: DateType, partner: Optional[IcebergType]) -> Writer:
        """Return a DateWriter."""
        return DateWriter()

    def visit_time(self, time_type: TimeType, partner: Optional[IcebergType]) -> Writer:
        """Return a TimeWriter."""
        return TimeWriter()

    def visit_timestamp(self, timestamp_type: TimestampType, partner: Optional[IcebergType]) -> Writer:
        """Return a TimestampWriter."""
        return TimestampWriter()

    def visit_timestamp_ns(self, timestamp_ns_type: TimestampNanoType, partner: Optional[IcebergType]) -> Writer:
        """Return a TimestampNanoWriter."""
        return TimestampNanoWriter()

    def visit_timestamptz(self, timestamptz_type: TimestamptzType, partner: Optional[IcebergType]) -> Writer:
        """Return a TimestamptzWriter."""
        return TimestamptzWriter()

    def visit_timestamptz_ns(self, timestamptz_ns_type: TimestamptzNanoType, partner: Optional[IcebergType]) -> Writer:
        """Return a TimestamptzNanoWriter."""
        return TimestamptzNanoWriter()

    def visit_string(self, string_type: StringType, partner: Optional[IcebergType]) -> Writer:
        """Return a StringWriter."""
        return StringWriter()

    def visit_uuid(self, uuid_type: UUIDType, partner: Optional[IcebergType]) -> Writer:
        """Return a UUIDWriter."""
        return UUIDWriter()

    def visit_fixed(self, fixed_type: FixedType, partner: Optional[IcebergType]) -> Writer:
        """Return a FixedWriter sized by ``len(fixed_type)``."""
        return FixedWriter(len(fixed_type))

    def visit_binary(self, binary_type: BinaryType, partner: Optional[IcebergType]) -> Writer:
        """Return a BinaryWriter."""
        return BinaryWriter()

    def visit_unknown(self, unknown_type: UnknownType, partner: Optional[IcebergType]) -> Writer:
        """Return an UnknownWriter."""
        return UnknownWriter()
| |
| |
class ReadSchemaResolver(PrimitiveWithPartnerVisitor[IcebergType, Reader]):
    """Build a reader for the file schema, using the read schema as partner."""

    __slots__ = ("read_types", "read_enums", "context")
    # Struct constructors keyed by field-id (STRUCT_ROOT when no field is being visited)
    read_types: Dict[int, Callable[..., StructProtocol]]
    # Enum constructors keyed by field-id
    read_enums: Dict[int, Callable[..., Enum]]
    # Stack of the field-ids currently being visited
    context: List[int]

    def __init__(
        self,
        read_types: Dict[int, Callable[..., StructProtocol]] = EMPTY_DICT,
        read_enums: Dict[int, Callable[..., Enum]] = EMPTY_DICT,
    ) -> None:
        """Store the struct/enum constructors and start with an empty field stack."""
        self.read_types = read_types
        self.read_enums = read_enums
        self.context = []

    def schema(self, schema: Schema, expected_schema: Optional[IcebergType], result: Reader) -> Reader:
        """Return the reader of the root struct unchanged."""
        return result

    def before_field(self, field: NestedField, field_partner: Optional[NestedField]) -> None:
        """Push the id of the field that is about to be visited."""
        self.context.append(field.field_id)

    def after_field(self, field: NestedField, field_partner: Optional[NestedField]) -> None:
        """Pop the id pushed by ``before_field``."""
        self.context.pop()

    def struct(self, struct: StructType, expected_struct: Optional[IcebergType], field_readers: List[Reader]) -> Reader:
        """Build a StructReader that places file fields at their read-schema positions.

        Raises:
            ResolveError: If the partner is not a struct, or a required read field
                without an initial default is missing from the file.
        """
        # The enclosing field id (top of the stack) selects the struct constructor
        if self.context:
            read_struct_id = self.context[STRUCT_ROOT]
        else:
            read_struct_id = STRUCT_ROOT
        struct_callable = self.read_types.get(read_struct_id, Record)

        if not expected_struct:
            return StructReader(tuple(enumerate(field_readers)), struct_callable, struct)

        if not isinstance(expected_struct, StructType):
            raise ResolveError(f"File/read schema are not aligned for struct, got {expected_struct}")

        # field-id -> index of that field inside the read struct
        expected_positions: Dict[int, int] = {}
        for index, expected_field in enumerate(expected_struct.fields):
            expected_positions[expected_field.field_id] = index

        # first, add readers for the file fields that must be in order
        results: List[Tuple[Optional[int], Reader]] = []
        for file_field, file_reader in zip(struct.fields, field_readers):
            target_position = expected_positions.get(file_field.field_id)
            # Check if we need to convert it to an Enum
            enum_type = self.read_enums.get(file_field.field_id)
            resolved_reader = EnumReader(enum_type, file_reader) if enum_type else file_reader
            results.append((target_position, resolved_reader))

        # then, cover the read fields that the file does not contain
        file_field_ids = {file_field.field_id for file_field in struct.fields}
        for pos, read_field in enumerate(expected_struct.fields):
            if read_field.field_id in file_field_ids:
                continue

            if isinstance(read_field, NestedField) and read_field.initial_default is not None:
                # The field is not in the file, but there is a default value
                # and that one can be required
                results.append((pos, DefaultReader(read_field.initial_default)))
            elif read_field.required:
                raise ResolveError(f"{read_field} is non-optional, and not part of the file schema")
            else:
                # Just set the new field to None
                results.append((pos, NoneReader()))

        return StructReader(tuple(results), struct_callable, expected_struct)

    def field(self, field: NestedField, expected_field: Optional[IcebergType], field_reader: Reader) -> Reader:
        """Wrap the reader of an optional file field in an OptionReader."""
        if field.required:
            return field_reader
        return OptionReader(field_reader)

    def list(self, list_type: ListType, expected_list: Optional[IcebergType], element_reader: Reader) -> Reader:
        """Build a ListReader; optional elements get an OptionReader around their reader."""
        if expected_list and not isinstance(expected_list, ListType):
            raise ResolveError(f"File/read schema are not aligned for list, got {expected_list}")

        if list_type.element_required:
            return ListReader(element_reader)
        return ListReader(OptionReader(element_reader))

    def map(self, map_type: MapType, expected_map: Optional[IcebergType], key_reader: Reader, value_reader: Reader) -> Reader:
        """Build a MapReader; optional values get an OptionReader around their reader."""
        if expected_map and not isinstance(expected_map, MapType):
            raise ResolveError(f"File/read schema are not aligned for map, got {expected_map}")

        if map_type.value_required:
            return MapReader(key_reader, value_reader)
        return MapReader(key_reader, OptionReader(value_reader))

    def primitive(self, primitive: PrimitiveType, expected_primitive: Optional[IcebergType]) -> Reader:
        """Check the file type against the read type, then dispatch on the file type."""
        if expected_primitive is None:
            return super().primitive(primitive, expected_primitive)

        if not isinstance(expected_primitive, PrimitiveType):
            raise ResolveError(f"File/read schema are not aligned for {primitive}, got {expected_primitive}")

        # ensure that the type can be projected to the expected
        if primitive != expected_primitive:
            promote(primitive, expected_primitive)

        return super().primitive(primitive, expected_primitive)

    def visit_boolean(self, boolean_type: BooleanType, partner: Optional[IcebergType]) -> Reader:
        """Return a BooleanReader."""
        return BooleanReader()

    def visit_integer(self, integer_type: IntegerType, partner: Optional[IcebergType]) -> Reader:
        """Return an IntegerReader."""
        return IntegerReader()

    def visit_long(self, long_type: LongType, partner: Optional[IcebergType]) -> Reader:
        """Return an IntegerReader; longs share the reader used for integers."""
        return IntegerReader()

    def visit_float(self, float_type: FloatType, partner: Optional[IcebergType]) -> Reader:
        """Return a FloatReader."""
        return FloatReader()

    def visit_double(self, double_type: DoubleType, partner: Optional[IcebergType]) -> Reader:
        """Return a DoubleReader."""
        return DoubleReader()

    def visit_decimal(self, decimal_type: DecimalType, partner: Optional[IcebergType]) -> Reader:
        """Return a DecimalReader using the type's precision and scale."""
        return DecimalReader(decimal_type.precision, decimal_type.scale)

    def visit_date(self, date_type: DateType, partner: Optional[IcebergType]) -> Reader:
        """Return a DateReader."""
        return DateReader()

    def visit_time(self, time_type: TimeType, partner: Optional[IcebergType]) -> Reader:
        """Return a TimeReader."""
        return TimeReader()

    def visit_timestamp(self, timestamp_type: TimestampType, partner: Optional[IcebergType]) -> Reader:
        """Return a TimestampReader."""
        return TimestampReader()

    def visit_timestamp_ns(self, timestamp_ns_type: TimestampNanoType, partner: Optional[IcebergType]) -> Reader:
        """Return a TimestampNanoReader."""
        return TimestampNanoReader()

    def visit_timestamptz(self, timestamptz_type: TimestamptzType, partner: Optional[IcebergType]) -> Reader:
        """Return a TimestamptzReader."""
        return TimestamptzReader()

    def visit_timestamptz_ns(self, timestamptz_ns_type: TimestamptzNanoType, partner: Optional[IcebergType]) -> Reader:
        """Return a TimestamptzNanoReader."""
        return TimestamptzNanoReader()

    def visit_string(self, string_type: StringType, partner: Optional[IcebergType]) -> Reader:
        """Return a StringReader."""
        return StringReader()

    def visit_uuid(self, uuid_type: UUIDType, partner: Optional[IcebergType]) -> Reader:
        """Return a UUIDReader."""
        return UUIDReader()

    def visit_fixed(self, fixed_type: FixedType, partner: Optional[IcebergType]) -> Reader:
        """Return a FixedReader sized by ``len(fixed_type)``."""
        return FixedReader(len(fixed_type))

    def visit_binary(self, binary_type: BinaryType, partner: Optional[IcebergType]) -> Reader:
        """Return a BinaryReader."""
        return BinaryReader()

    def visit_unknown(self, unknown_type: UnknownType, partner: Optional[IcebergType]) -> Reader:
        """Return an UnknownReader."""
        return UnknownReader()
| |
| |
class SchemaPartnerAccessor(PartnerAccessor[IcebergType]):
    """Look up the partner type that corresponds to each part of the visited schema.

    Every accessor raises ``ResolveError`` when the partner is not of the kind
    needed to descend into it.
    """

    def schema_partner(self, partner: Optional[IcebergType]) -> Optional[IcebergType]:
        """Return the struct of a partner Schema."""
        if not isinstance(partner, Schema):
            raise ResolveError(f"File/read schema are not aligned for schema, got {partner}")

        return partner.as_struct()

    def field_partner(self, partner: Optional[IcebergType], field_id: int, field_name: str) -> Optional[IcebergType]:
        """Return the type of the partner struct's field with ``field_id``, or None if the lookup is falsy."""
        if not isinstance(partner, StructType):
            raise ResolveError(f"File/read schema are not aligned for struct, got {partner}")

        matching_field = partner.field(field_id)
        if matching_field:
            return matching_field.field_type
        return None

    def list_element_partner(self, partner_list: Optional[IcebergType]) -> Optional[IcebergType]:
        """Return the element type of a partner list."""
        if not isinstance(partner_list, ListType):
            raise ResolveError(f"File/read schema are not aligned for list, got {partner_list}")

        return partner_list.element_type

    def map_key_partner(self, partner_map: Optional[IcebergType]) -> Optional[IcebergType]:
        """Return the key type of a partner map."""
        if not isinstance(partner_map, MapType):
            raise ResolveError(f"File/read schema are not aligned for map, got {partner_map}")

        return partner_map.key_type

    def map_value_partner(self, partner_map: Optional[IcebergType]) -> Optional[IcebergType]:
        """Return the value type of a partner map."""
        if not isinstance(partner_map, MapType):
            raise ResolveError(f"File/read schema are not aligned for map, got {partner_map}")

        return partner_map.value_type