| # Licensed to the Apache Software Foundation (ASF) under one |
| # or more contributor license agreements. See the NOTICE file |
| # distributed with this work for additional information |
| # regarding copyright ownership. The ASF licenses this file |
| # to you under the Apache License, Version 2.0 (the |
| # "License"); you may not use this file except in compliance |
| # with the License. You may obtain a copy of the License at |
| # |
| # http://www.apache.org/licenses/LICENSE-2.0 |
| # |
| # Unless required by applicable law or agreed to in writing, |
| # software distributed under the License is distributed on an |
| # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY |
| # KIND, either express or implied. See the License for the |
| # specific language governing permissions and limitations |
| # under the License. |
| # pylint: disable=redefined-outer-name,arguments-renamed,fixme |
| """FileIO implementation for reading and writing table files that uses pyarrow.fs. |
| |
This file contains a FileIO implementation that relies on the filesystem interface provided
by PyArrow. The filesystem to use is inferred from the scheme of each file location, so in
theory the supported storage types can grow naturally with the pyarrow library.
| """ |
| |
| from __future__ import annotations |
| |
| import concurrent.futures |
| import fnmatch |
| import itertools |
| import logging |
| import os |
| import re |
| from abc import ABC, abstractmethod |
| from concurrent.futures import Future |
| from dataclasses import dataclass |
| from enum import Enum |
| from functools import lru_cache, singledispatch |
| from typing import ( |
| TYPE_CHECKING, |
| Any, |
| Callable, |
| Dict, |
| Generic, |
| Iterable, |
| Iterator, |
| List, |
| Optional, |
| Set, |
| Tuple, |
| TypeVar, |
| Union, |
| cast, |
| ) |
| from urllib.parse import urlparse |
| |
| import numpy as np |
| import pyarrow as pa |
| import pyarrow.compute as pc |
| import pyarrow.dataset as ds |
| import pyarrow.lib |
| import pyarrow.parquet as pq |
| from pyarrow import ChunkedArray |
| from pyarrow.fs import ( |
| FileInfo, |
| FileSystem, |
| FileType, |
| FSSpecHandler, |
| ) |
| from sortedcontainers import SortedList |
| |
| from pyiceberg.avro.resolver import ResolveError |
| from pyiceberg.conversions import to_bytes |
| from pyiceberg.expressions import ( |
| AlwaysTrue, |
| BooleanExpression, |
| BoundTerm, |
| ) |
| from pyiceberg.expressions.literals import Literal |
| from pyiceberg.expressions.visitors import ( |
| BoundBooleanExpressionVisitor, |
| bind, |
| extract_field_ids, |
| translate_column_names, |
| ) |
| from pyiceberg.expressions.visitors import visit as boolean_expression_visit |
| from pyiceberg.io import ( |
| GCS_DEFAULT_LOCATION, |
| GCS_ENDPOINT, |
| GCS_TOKEN, |
| GCS_TOKEN_EXPIRES_AT_MS, |
| HDFS_HOST, |
| HDFS_KERB_TICKET, |
| HDFS_PORT, |
| HDFS_USER, |
| S3_ACCESS_KEY_ID, |
| S3_CONNECT_TIMEOUT, |
| S3_ENDPOINT, |
| S3_PROXY_URI, |
| S3_REGION, |
| S3_SECRET_ACCESS_KEY, |
| S3_SESSION_TOKEN, |
| FileIO, |
| InputFile, |
| InputStream, |
| OutputFile, |
| OutputStream, |
| ) |
| from pyiceberg.manifest import ( |
| DataFile, |
| DataFileContent, |
| FileFormat, |
| ) |
| from pyiceberg.schema import ( |
| PartnerAccessor, |
| PreOrderSchemaVisitor, |
| Schema, |
| SchemaVisitorPerPrimitiveType, |
| SchemaWithPartnerVisitor, |
| pre_order_visit, |
| promote, |
| prune_columns, |
| visit, |
| visit_with_partner, |
| ) |
| from pyiceberg.table import PropertyUtil, TableProperties, WriteTask |
| from pyiceberg.table.name_mapping import NameMapping |
| from pyiceberg.transforms import TruncateTransform |
| from pyiceberg.typedef import EMPTY_DICT, Properties, Record |
| from pyiceberg.types import ( |
| BinaryType, |
| BooleanType, |
| DateType, |
| DecimalType, |
| DoubleType, |
| FixedType, |
| FloatType, |
| IcebergType, |
| IntegerType, |
| ListType, |
| LongType, |
| MapType, |
| NestedField, |
| PrimitiveType, |
| StringType, |
| StructType, |
| TimestampType, |
| TimestamptzType, |
| TimeType, |
| UUIDType, |
| ) |
| from pyiceberg.utils.concurrent import ExecutorFactory |
| from pyiceberg.utils.datetime import millis_to_datetime |
| from pyiceberg.utils.singleton import Singleton |
| from pyiceberg.utils.truncate import truncate_upper_bound_binary_string, truncate_upper_bound_text_string |
| |
| if TYPE_CHECKING: |
| from pyiceberg.table import FileScanTask, Table |
| |
| logger = logging.getLogger(__name__) |
| |
| ONE_MEGABYTE = 1024 * 1024 |
| BUFFER_SIZE = "buffer-size" |
| ICEBERG_SCHEMA = b"iceberg.schema" |
# The "PARQUET:" prefix marks the key as Parquet-specific; in this case it carries the field_id
| PYARROW_PARQUET_FIELD_ID_KEY = b"PARQUET:field_id" |
| PYARROW_FIELD_DOC_KEY = b"doc" |
| LIST_ELEMENT_NAME = "element" |
| MAP_KEY_NAME = "key" |
| MAP_VALUE_NAME = "value" |
| DOC = "doc" |
| |
| T = TypeVar("T") |
| |
| |
| class PyArrowLocalFileSystem(pyarrow.fs.LocalFileSystem): |
| def open_output_stream(self, path: str, *args: Any, **kwargs: Any) -> pyarrow.NativeFile: |
        # In LocalFileSystem, parent directories must first be created before opening an output stream
| self.create_dir(os.path.dirname(path), recursive=True) |
| return super().open_output_stream(path, *args, **kwargs) |
| |
| |
| class PyArrowFile(InputFile, OutputFile): |
| """A combined InputFile and OutputFile implementation that uses a pyarrow filesystem to generate pyarrow.lib.NativeFile instances. |
| |
| Args: |
| location (str): A URI or a path to a local file. |
| |
| Attributes: |
        location (str): The URI or path to a local file for a PyArrowFile instance.
| |
| Examples: |
| >>> from pyiceberg.io.pyarrow import PyArrowFile |
| >>> # input_file = PyArrowFile("s3://foo/bar.txt") |
| >>> # Read the contents of the PyArrowFile instance |
| >>> # Make sure that you have permissions to read/write |
| >>> # file_content = input_file.open().read() |
| |
| >>> # output_file = PyArrowFile("s3://baz/qux.txt") |
| >>> # Write bytes to a file |
| >>> # Make sure that you have permissions to read/write |
| >>> # output_file.create().write(b'foobytes') |
| """ |
| |
    _filesystem: FileSystem
| _path: str |
| _buffer_size: int |
| |
| def __init__(self, location: str, path: str, fs: FileSystem, buffer_size: int = ONE_MEGABYTE): |
| self._filesystem = fs |
| self._path = path |
| self._buffer_size = buffer_size |
| super().__init__(location=location) |
| |
| def _file_info(self) -> FileInfo: |
| """Retrieve a pyarrow.fs.FileInfo object for the location. |
| |
| Raises: |
| PermissionError: If the file at self.location cannot be accessed due to a permission error such as |
| an AWS error code 15. |
| """ |
| try: |
| file_info = self._filesystem.get_file_info(self._path) |
| except OSError as e: |
| if e.errno == 13 or "AWS Error [code 15]" in str(e): |
| raise PermissionError(f"Cannot get file info, access denied: {self.location}") from e |
| raise # pragma: no cover - If some other kind of OSError, raise the raw error |
| |
| if file_info.type == FileType.NotFound: |
| raise FileNotFoundError(f"Cannot get file info, file not found: {self.location}") |
| return file_info |
| |
| def __len__(self) -> int: |
| """Return the total length of the file, in bytes.""" |
| file_info = self._file_info() |
| return file_info.size |
| |
| def exists(self) -> bool: |
| """Check whether the location exists.""" |
| try: |
| self._file_info() # raises FileNotFoundError if it does not exist |
| return True |
| except FileNotFoundError: |
| return False |
| |
| def open(self, seekable: bool = True) -> InputStream: |
| """Open the location using a PyArrow FileSystem inferred from the location. |
| |
| Args: |
            seekable: If the stream should support seek, or if it is consumed sequentially.
| |
| Returns: |
| pyarrow.lib.NativeFile: A NativeFile instance for the file located at `self.location`. |
| |
| Raises: |
| FileNotFoundError: If the file at self.location does not exist. |
| PermissionError: If the file at self.location cannot be accessed due to a permission error such as |
| an AWS error code 15. |
| """ |
| try: |
| if seekable: |
| input_file = self._filesystem.open_input_file(self._path) |
| else: |
| input_file = self._filesystem.open_input_stream(self._path, buffer_size=self._buffer_size) |
| except FileNotFoundError: |
| raise |
| except PermissionError: |
| raise |
| except OSError as e: |
| if e.errno == 2 or "Path does not exist" in str(e): |
| raise FileNotFoundError(f"Cannot open file, does not exist: {self.location}") from e |
| elif e.errno == 13 or "AWS Error [code 15]" in str(e): |
| raise PermissionError(f"Cannot open file, access denied: {self.location}") from e |
| raise # pragma: no cover - If some other kind of OSError, raise the raw error |
| return input_file |
| |
| def create(self, overwrite: bool = False) -> OutputStream: |
| """Create a writable pyarrow.lib.NativeFile for this PyArrowFile's location. |
| |
| Args: |
| overwrite (bool): Whether to overwrite the file if it already exists. |
| |
| Returns: |
| pyarrow.lib.NativeFile: A NativeFile instance for the file located at self.location. |
| |
| Raises: |
| FileExistsError: If the file already exists at `self.location` and `overwrite` is False. |
| |
| Note: |
| This retrieves a pyarrow NativeFile by opening an output stream. If overwrite is set to False, |
| a check is first performed to verify that the file does not exist. This is not thread-safe and |
| a possibility does exist that the file can be created by a concurrent process after the existence |
| check yet before the output stream is created. In such a case, the default pyarrow behavior will |
| truncate the contents of the existing file when opening the output stream. |
| """ |
| try: |
| if not overwrite and self.exists() is True: |
| raise FileExistsError(f"Cannot create file, already exists: {self.location}") |
| output_file = self._filesystem.open_output_stream(self._path, buffer_size=self._buffer_size) |
| except PermissionError: |
| raise |
| except OSError as e: |
| if e.errno == 13 or "AWS Error [code 15]" in str(e): |
| raise PermissionError(f"Cannot create file, access denied: {self.location}") from e |
| raise # pragma: no cover - If some other kind of OSError, raise the raw error |
| return output_file |
| |
| def to_input_file(self) -> PyArrowFile: |
| """Return a new PyArrowFile for the location of an existing PyArrowFile instance. |
| |
| This method is included to abide by the OutputFile abstract base class. Since this implementation uses a single |
| PyArrowFile class (as opposed to separate InputFile and OutputFile implementations), this method effectively returns |
| a copy of the same instance. |
| """ |
| return self |
| |
| |
| class PyArrowFileIO(FileIO): |
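    """A FileIO implementation that delegates to the filesystems provided by pyarrow.fs.

    Filesystem instances are created per (scheme, netloc) pair and cached, configured from the
    FileIO properties (for example S3, GCS or HDFS credentials and endpoints).

    Examples:
        >>> from pyiceberg.io import S3_ENDPOINT
        >>> from pyiceberg.io.pyarrow import PyArrowFileIO
        >>> # Illustrative sketch; the endpoint and bucket below are placeholders
        >>> # file_io = PyArrowFileIO(properties={S3_ENDPOINT: "http://localhost:9000"})
        >>> # input_file = file_io.new_input("s3://foo/bar.txt")
    """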
| fs_by_scheme: Callable[[str, Optional[str]], FileSystem] |
| |
| def __init__(self, properties: Properties = EMPTY_DICT): |
| self.fs_by_scheme: Callable[[str, Optional[str]], FileSystem] = lru_cache(self._initialize_fs) |
| super().__init__(properties=properties) |
| |
| @staticmethod |
| def parse_location(location: str) -> Tuple[str, str, str]: |
| """Return the path without the scheme.""" |
| uri = urlparse(location) |
| if not uri.scheme: |
| return "file", uri.netloc, os.path.abspath(location) |
| elif uri.scheme == "hdfs": |
| return uri.scheme, uri.netloc, location |
| else: |
| return uri.scheme, uri.netloc, f"{uri.netloc}{uri.path}" |
| |
| def _initialize_fs(self, scheme: str, netloc: Optional[str] = None) -> FileSystem: |
| if scheme in {"s3", "s3a", "s3n"}: |
| from pyarrow.fs import S3FileSystem |
| |
| client_kwargs: Dict[str, Any] = { |
| "endpoint_override": self.properties.get(S3_ENDPOINT), |
| "access_key": self.properties.get(S3_ACCESS_KEY_ID), |
| "secret_key": self.properties.get(S3_SECRET_ACCESS_KEY), |
| "session_token": self.properties.get(S3_SESSION_TOKEN), |
| "region": self.properties.get(S3_REGION), |
| } |
| |
| if proxy_uri := self.properties.get(S3_PROXY_URI): |
| client_kwargs["proxy_options"] = proxy_uri |
| |
| if connect_timeout := self.properties.get(S3_CONNECT_TIMEOUT): |
| client_kwargs["connect_timeout"] = float(connect_timeout) |
| |
| return S3FileSystem(**client_kwargs) |
| elif scheme == "hdfs": |
| from pyarrow.fs import HadoopFileSystem |
| |
| hdfs_kwargs: Dict[str, Any] = {} |
| if netloc: |
| return HadoopFileSystem.from_uri(f"hdfs://{netloc}") |
| if host := self.properties.get(HDFS_HOST): |
| hdfs_kwargs["host"] = host |
| if port := self.properties.get(HDFS_PORT): |
| # port should be an integer type |
| hdfs_kwargs["port"] = int(port) |
| if user := self.properties.get(HDFS_USER): |
| hdfs_kwargs["user"] = user |
| if kerb_ticket := self.properties.get(HDFS_KERB_TICKET): |
| hdfs_kwargs["kerb_ticket"] = kerb_ticket |
| |
| return HadoopFileSystem(**hdfs_kwargs) |
| elif scheme in {"gs", "gcs"}: |
| from pyarrow.fs import GcsFileSystem |
| |
| gcs_kwargs: Dict[str, Any] = {} |
| if access_token := self.properties.get(GCS_TOKEN): |
| gcs_kwargs["access_token"] = access_token |
| if expiration := self.properties.get(GCS_TOKEN_EXPIRES_AT_MS): |
| gcs_kwargs["credential_token_expiration"] = millis_to_datetime(int(expiration)) |
| if bucket_location := self.properties.get(GCS_DEFAULT_LOCATION): |
| gcs_kwargs["default_bucket_location"] = bucket_location |
| if endpoint := self.properties.get(GCS_ENDPOINT): |
| url_parts = urlparse(endpoint) |
| gcs_kwargs["scheme"] = url_parts.scheme |
| gcs_kwargs["endpoint_override"] = url_parts.netloc |
| |
| return GcsFileSystem(**gcs_kwargs) |
| elif scheme == "file": |
| return PyArrowLocalFileSystem() |
| else: |
| raise ValueError(f"Unrecognized filesystem type in URI: {scheme}") |
| |
| def new_input(self, location: str) -> PyArrowFile: |
| """Get a PyArrowFile instance to read bytes from the file at the given location. |
| |
| Args: |
| location (str): A URI or a path to a local file. |
| |
| Returns: |
| PyArrowFile: A PyArrowFile instance for the given location. |
| """ |
| scheme, netloc, path = self.parse_location(location) |
| return PyArrowFile( |
| fs=self.fs_by_scheme(scheme, netloc), |
| location=location, |
| path=path, |
| buffer_size=int(self.properties.get(BUFFER_SIZE, ONE_MEGABYTE)), |
| ) |
| |
| def new_output(self, location: str) -> PyArrowFile: |
| """Get a PyArrowFile instance to write bytes to the file at the given location. |
| |
| Args: |
| location (str): A URI or a path to a local file. |
| |
| Returns: |
| PyArrowFile: A PyArrowFile instance for the given location. |
| """ |
| scheme, netloc, path = self.parse_location(location) |
| return PyArrowFile( |
| fs=self.fs_by_scheme(scheme, netloc), |
| location=location, |
| path=path, |
| buffer_size=int(self.properties.get(BUFFER_SIZE, ONE_MEGABYTE)), |
| ) |
| |
| def delete(self, location: Union[str, InputFile, OutputFile]) -> None: |
| """Delete the file at the given location. |
| |
| Args: |
            location (Union[str, InputFile, OutputFile]): The URI to the file. If an InputFile or OutputFile
                instance is provided, its location attribute is used as the location to delete.
| |
| Raises: |
| FileNotFoundError: When the file at the provided location does not exist. |
| PermissionError: If the file at the provided location cannot be accessed due to a permission error such as |
| an AWS error code 15. |
| """ |
| str_location = location.location if isinstance(location, (InputFile, OutputFile)) else location |
| scheme, netloc, path = self.parse_location(str_location) |
| fs = self.fs_by_scheme(scheme, netloc) |
| |
| try: |
| fs.delete_file(path) |
| except FileNotFoundError: |
| raise |
| except PermissionError: |
| raise |
| except OSError as e: |
| if e.errno == 2 or "Path does not exist" in str(e): |
| raise FileNotFoundError(f"Cannot delete file, does not exist: {location}") from e |
| elif e.errno == 13 or "AWS Error [code 15]" in str(e): |
| raise PermissionError(f"Cannot delete file, access denied: {location}") from e |
| raise # pragma: no cover - If some other kind of OSError, raise the raw error |
| |
| |
| def schema_to_pyarrow(schema: Union[Schema, IcebergType], metadata: Dict[bytes, bytes] = EMPTY_DICT) -> pa.schema: |
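    """Convert an Iceberg schema or type into the equivalent pyarrow schema or data type.

    For example, an Iceberg LongType maps to pa.int64() and a StringType maps to pa.string(),
    as defined by the _ConvertToArrowSchema visitor below.
    """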
| return visit(schema, _ConvertToArrowSchema(metadata)) |
| |
| |
| class _ConvertToArrowSchema(SchemaVisitorPerPrimitiveType[pa.DataType]): |
| _metadata: Dict[bytes, bytes] |
| |
| def __init__(self, metadata: Dict[bytes, bytes] = EMPTY_DICT) -> None: |
| self._metadata = metadata |
| |
| def schema(self, _: Schema, struct_result: pa.StructType) -> pa.schema: |
| return pa.schema(list(struct_result), metadata=self._metadata) |
| |
| def struct(self, _: StructType, field_results: List[pa.DataType]) -> pa.DataType: |
| return pa.struct(field_results) |
| |
| def field(self, field: NestedField, field_result: pa.DataType) -> pa.Field: |
| return pa.field( |
| name=field.name, |
| type=field_result, |
| nullable=field.optional, |
| metadata={PYARROW_FIELD_DOC_KEY: field.doc, PYARROW_PARQUET_FIELD_ID_KEY: str(field.field_id)} |
| if field.doc |
| else {PYARROW_PARQUET_FIELD_ID_KEY: str(field.field_id)}, |
| ) |
| |
| def list(self, list_type: ListType, element_result: pa.DataType) -> pa.DataType: |
| element_field = self.field(list_type.element_field, element_result) |
| return pa.list_(value_type=element_field) |
| |
| def map(self, map_type: MapType, key_result: pa.DataType, value_result: pa.DataType) -> pa.DataType: |
| key_field = self.field(map_type.key_field, key_result) |
| value_field = self.field(map_type.value_field, value_result) |
| return pa.map_(key_type=key_field, item_type=value_field) |
| |
| def visit_fixed(self, fixed_type: FixedType) -> pa.DataType: |
| return pa.binary(len(fixed_type)) |
| |
| def visit_decimal(self, decimal_type: DecimalType) -> pa.DataType: |
| return pa.decimal128(decimal_type.precision, decimal_type.scale) |
| |
| def visit_boolean(self, _: BooleanType) -> pa.DataType: |
| return pa.bool_() |
| |
| def visit_integer(self, _: IntegerType) -> pa.DataType: |
| return pa.int32() |
| |
| def visit_long(self, _: LongType) -> pa.DataType: |
| return pa.int64() |
| |
| def visit_float(self, _: FloatType) -> pa.DataType: |
| # 32-bit IEEE 754 floating point |
| return pa.float32() |
| |
| def visit_double(self, _: DoubleType) -> pa.DataType: |
| # 64-bit IEEE 754 floating point |
| return pa.float64() |
| |
| def visit_date(self, _: DateType) -> pa.DataType: |
| # Date encoded as an int |
| return pa.date32() |
| |
| def visit_time(self, _: TimeType) -> pa.DataType: |
| return pa.time64("us") |
| |
| def visit_timestamp(self, _: TimestampType) -> pa.DataType: |
| return pa.timestamp(unit="us") |
| |
| def visit_timestamptz(self, _: TimestamptzType) -> pa.DataType: |
| return pa.timestamp(unit="us", tz="UTC") |
| |
| def visit_string(self, _: StringType) -> pa.DataType: |
| return pa.string() |
| |
| def visit_uuid(self, _: UUIDType) -> pa.DataType: |
| return pa.binary(16) |
| |
| def visit_binary(self, _: BinaryType) -> pa.DataType: |
| return pa.large_binary() |
| |
| |
| def _convert_scalar(value: Any, iceberg_type: IcebergType) -> pa.scalar: |
| if not isinstance(iceberg_type, PrimitiveType): |
| raise ValueError(f"Expected primitive type, got: {iceberg_type}") |
| return pa.scalar(value=value, type=schema_to_pyarrow(iceberg_type)) |
| |
| |
| class _ConvertToArrowExpression(BoundBooleanExpressionVisitor[pc.Expression]): |
| def visit_in(self, term: BoundTerm[pc.Expression], literals: Set[Any]) -> pc.Expression: |
| pyarrow_literals = pa.array(literals, type=schema_to_pyarrow(term.ref().field.field_type)) |
| return pc.field(term.ref().field.name).isin(pyarrow_literals) |
| |
| def visit_not_in(self, term: BoundTerm[pc.Expression], literals: Set[Any]) -> pc.Expression: |
| pyarrow_literals = pa.array(literals, type=schema_to_pyarrow(term.ref().field.field_type)) |
| return ~pc.field(term.ref().field.name).isin(pyarrow_literals) |
| |
| def visit_is_nan(self, term: BoundTerm[Any]) -> pc.Expression: |
| ref = pc.field(term.ref().field.name) |
| return pc.is_nan(ref) |
| |
| def visit_not_nan(self, term: BoundTerm[Any]) -> pc.Expression: |
| ref = pc.field(term.ref().field.name) |
| return ~pc.is_nan(ref) |
| |
| def visit_is_null(self, term: BoundTerm[Any]) -> pc.Expression: |
| return pc.field(term.ref().field.name).is_null(nan_is_null=False) |
| |
| def visit_not_null(self, term: BoundTerm[Any]) -> pc.Expression: |
| return pc.field(term.ref().field.name).is_valid() |
| |
| def visit_equal(self, term: BoundTerm[Any], literal: Literal[Any]) -> pc.Expression: |
| return pc.field(term.ref().field.name) == _convert_scalar(literal.value, term.ref().field.field_type) |
| |
| def visit_not_equal(self, term: BoundTerm[Any], literal: Literal[Any]) -> pc.Expression: |
| return pc.field(term.ref().field.name) != _convert_scalar(literal.value, term.ref().field.field_type) |
| |
| def visit_greater_than_or_equal(self, term: BoundTerm[Any], literal: Literal[Any]) -> pc.Expression: |
| return pc.field(term.ref().field.name) >= _convert_scalar(literal.value, term.ref().field.field_type) |
| |
| def visit_greater_than(self, term: BoundTerm[Any], literal: Literal[Any]) -> pc.Expression: |
| return pc.field(term.ref().field.name) > _convert_scalar(literal.value, term.ref().field.field_type) |
| |
| def visit_less_than(self, term: BoundTerm[Any], literal: Literal[Any]) -> pc.Expression: |
| return pc.field(term.ref().field.name) < _convert_scalar(literal.value, term.ref().field.field_type) |
| |
| def visit_less_than_or_equal(self, term: BoundTerm[Any], literal: Literal[Any]) -> pc.Expression: |
| return pc.field(term.ref().field.name) <= _convert_scalar(literal.value, term.ref().field.field_type) |
| |
| def visit_starts_with(self, term: BoundTerm[Any], literal: Literal[Any]) -> pc.Expression: |
| return pc.starts_with(pc.field(term.ref().field.name), literal.value) |
| |
| def visit_not_starts_with(self, term: BoundTerm[Any], literal: Literal[Any]) -> pc.Expression: |
| return ~pc.starts_with(pc.field(term.ref().field.name), literal.value) |
| |
| def visit_true(self) -> pc.Expression: |
| return pc.scalar(True) |
| |
| def visit_false(self) -> pc.Expression: |
| return pc.scalar(False) |
| |
| def visit_not(self, child_result: pc.Expression) -> pc.Expression: |
| return ~child_result |
| |
| def visit_and(self, left_result: pc.Expression, right_result: pc.Expression) -> pc.Expression: |
| return left_result & right_result |
| |
| def visit_or(self, left_result: pc.Expression, right_result: pc.Expression) -> pc.Expression: |
| return left_result | right_result |
| |
| |
| def expression_to_pyarrow(expr: BooleanExpression) -> pc.Expression: |
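    """Convert a bound Iceberg BooleanExpression into a pyarrow.compute.Expression.

    Illustrative sketch: a bound equality predicate on a string column "foo" with value "bar"
    becomes the equivalent of pc.field("foo") == pa.scalar("bar", type=pa.string()), which
    Arrow can push down when scanning a Parquet fragment.
    """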
| return boolean_expression_visit(expr, _ConvertToArrowExpression()) |
| |
| |
| @lru_cache |
| def _get_file_format(file_format: FileFormat, **kwargs: Dict[str, Any]) -> ds.FileFormat: |
| if file_format == FileFormat.PARQUET: |
| return ds.ParquetFileFormat(**kwargs) |
| else: |
| raise ValueError(f"Unsupported file format: {file_format}") |
| |
| |
| def _construct_fragment(fs: FileSystem, data_file: DataFile, file_format_kwargs: Dict[str, Any] = EMPTY_DICT) -> ds.Fragment: |
| _, _, path = PyArrowFileIO.parse_location(data_file.file_path) |
| return _get_file_format(data_file.file_format, **file_format_kwargs).make_fragment(path, fs) |
| |
| |
| def _read_deletes(fs: FileSystem, data_file: DataFile) -> Dict[str, pa.ChunkedArray]: |
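    """Read a positional delete file and group the deleted row positions by data file path."""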
| delete_fragment = _construct_fragment( |
| fs, data_file, file_format_kwargs={"dictionary_columns": ("file_path",), "pre_buffer": True, "buffer_size": ONE_MEGABYTE} |
| ) |
| table = ds.Scanner.from_fragment(fragment=delete_fragment).to_table() |
| table = table.unify_dictionaries() |
| return { |
| file.as_py(): table.filter(pc.field("file_path") == file).column("pos") |
| for file in table.column("file_path").chunks[0].dictionary |
| } |
| |
| |
| def _combine_positional_deletes(positional_deletes: List[pa.ChunkedArray], rows: int) -> pa.Array: |
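    """Return the sorted indices of the rows that are not positionally deleted.

    The delete arrays hold the row positions to drop; the result is their complement within
    range(rows), suitable for passing to Scanner.take.
    """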
| if len(positional_deletes) == 1: |
| all_chunks = positional_deletes[0] |
| else: |
| all_chunks = pa.chunked_array(itertools.chain(*[arr.chunks for arr in positional_deletes])) |
| return np.setdiff1d(np.arange(rows), all_chunks, assume_unique=False) |
| |
| |
| def pyarrow_to_schema(schema: pa.Schema, name_mapping: Optional[NameMapping] = None) -> Schema: |
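    """Convert a PyArrow schema into an Iceberg schema.

    Field IDs are taken from the PARQUET:field_id metadata keys when the schema carries them;
    otherwise they are resolved through the provided name mapping. A ValueError is raised when
    neither source of field IDs is available.
    """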
| has_ids = visit_pyarrow(schema, _HasIds()) |
| if has_ids: |
| visitor = _ConvertToIceberg() |
| elif name_mapping is not None: |
| visitor = _ConvertToIceberg(name_mapping=name_mapping) |
| else: |
| raise ValueError( |
| "Parquet file does not have field-ids and the Iceberg table does not have 'schema.name-mapping.default' defined" |
| ) |
| return visit_pyarrow(schema, visitor) |
| |
| |
| def _pyarrow_to_schema_without_ids(schema: pa.Schema) -> Schema: |
| return visit_pyarrow(schema, _ConvertToIcebergWithoutIDs()) |
| |
| |
| @singledispatch |
| def visit_pyarrow(obj: Union[pa.DataType, pa.Schema], visitor: PyArrowSchemaVisitor[T]) -> T: |
| """Apply a pyarrow schema visitor to any point within a schema. |
| |
| The function traverses the schema in post-order fashion. |
| |
| Args: |
        obj (Union[pa.DataType, pa.Schema]): An instance of a pyarrow Schema or DataType to visit.
        visitor (PyArrowSchemaVisitor[T]): An instance of an implementation of the generic PyArrowSchemaVisitor base class.
| |
| Raises: |
| NotImplementedError: If attempting to visit an unrecognized object type. |
| """ |
| raise NotImplementedError(f"Cannot visit non-type: {obj}") |
| |
| |
| @visit_pyarrow.register(pa.Schema) |
| def _(obj: pa.Schema, visitor: PyArrowSchemaVisitor[T]) -> T: |
| return visitor.schema(obj, visit_pyarrow(pa.struct(obj), visitor)) |
| |
| |
| @visit_pyarrow.register(pa.StructType) |
| def _(obj: pa.StructType, visitor: PyArrowSchemaVisitor[T]) -> T: |
| results = [] |
| |
| for field in obj: |
| visitor.before_field(field) |
| result = visit_pyarrow(field.type, visitor) |
| results.append(visitor.field(field, result)) |
| visitor.after_field(field) |
| |
| return visitor.struct(obj, results) |
| |
| |
| @visit_pyarrow.register(pa.ListType) |
| def _(obj: pa.ListType, visitor: PyArrowSchemaVisitor[T]) -> T: |
| visitor.before_list_element(obj.value_field) |
| result = visit_pyarrow(obj.value_type, visitor) |
| visitor.after_list_element(obj.value_field) |
| |
| return visitor.list(obj, result) |
| |
| |
| @visit_pyarrow.register(pa.MapType) |
| def _(obj: pa.MapType, visitor: PyArrowSchemaVisitor[T]) -> T: |
| visitor.before_map_key(obj.key_field) |
| key_result = visit_pyarrow(obj.key_type, visitor) |
| visitor.after_map_key(obj.key_field) |
| |
| visitor.before_map_value(obj.item_field) |
| value_result = visit_pyarrow(obj.item_type, visitor) |
| visitor.after_map_value(obj.item_field) |
| |
| return visitor.map(obj, key_result, value_result) |
| |
| |
| @visit_pyarrow.register(pa.DataType) |
| def _(obj: pa.DataType, visitor: PyArrowSchemaVisitor[T]) -> T: |
| if pa.types.is_nested(obj): |
| raise TypeError(f"Expected primitive type, got: {type(obj)}") |
| return visitor.primitive(obj) |
| |
| |
| class PyArrowSchemaVisitor(Generic[T], ABC): |
| def before_field(self, field: pa.Field) -> None: |
| """Override this method to perform an action immediately before visiting a field.""" |
| |
| def after_field(self, field: pa.Field) -> None: |
| """Override this method to perform an action immediately after visiting a field.""" |
| |
| def before_list_element(self, element: pa.Field) -> None: |
| """Override this method to perform an action immediately before visiting an element within a ListType.""" |
| |
| def after_list_element(self, element: pa.Field) -> None: |
| """Override this method to perform an action immediately after visiting an element within a ListType.""" |
| |
| def before_map_key(self, key: pa.Field) -> None: |
| """Override this method to perform an action immediately before visiting a key within a MapType.""" |
| |
| def after_map_key(self, key: pa.Field) -> None: |
| """Override this method to perform an action immediately after visiting a key within a MapType.""" |
| |
| def before_map_value(self, value: pa.Field) -> None: |
| """Override this method to perform an action immediately before visiting a value within a MapType.""" |
| |
| def after_map_value(self, value: pa.Field) -> None: |
| """Override this method to perform an action immediately after visiting a value within a MapType.""" |
| |
| @abstractmethod |
| def schema(self, schema: pa.Schema, struct_result: T) -> T: |
| """Visit a schema.""" |
| |
| @abstractmethod |
| def struct(self, struct: pa.StructType, field_results: List[T]) -> T: |
| """Visit a struct.""" |
| |
| @abstractmethod |
| def field(self, field: pa.Field, field_result: T) -> T: |
| """Visit a field.""" |
| |
| @abstractmethod |
| def list(self, list_type: pa.ListType, element_result: T) -> T: |
| """Visit a list.""" |
| |
| @abstractmethod |
| def map(self, map_type: pa.MapType, key_result: T, value_result: T) -> T: |
| """Visit a map.""" |
| |
| @abstractmethod |
| def primitive(self, primitive: pa.DataType) -> T: |
| """Visit a primitive type.""" |
| |
| |
| def _get_field_id(field: pa.Field) -> Optional[int]: |
| return ( |
| int(field_id_str.decode()) |
| if (field.metadata and (field_id_str := field.metadata.get(PYARROW_PARQUET_FIELD_ID_KEY))) |
| else None |
| ) |
| |
| |
| class _HasIds(PyArrowSchemaVisitor[bool]): |
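    """Check whether every field of a PyArrow schema, including list elements and map entries, carries a field ID."""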
| def schema(self, schema: pa.Schema, struct_result: bool) -> bool: |
| return struct_result |
| |
| def struct(self, struct: pa.StructType, field_results: List[bool]) -> bool: |
| return all(field_results) |
| |
| def field(self, field: pa.Field, field_result: bool) -> bool: |
| return all([_get_field_id(field) is not None, field_result]) |
| |
| def list(self, list_type: pa.ListType, element_result: bool) -> bool: |
| element_field = list_type.value_field |
| element_id = _get_field_id(element_field) |
| return element_result and element_id is not None |
| |
| def map(self, map_type: pa.MapType, key_result: bool, value_result: bool) -> bool: |
| key_field = map_type.key_field |
| key_id = _get_field_id(key_field) |
| value_field = map_type.item_field |
| value_id = _get_field_id(value_field) |
| return all([key_id is not None, value_id is not None, key_result, value_result]) |
| |
| def primitive(self, primitive: pa.DataType) -> bool: |
| return True |
| |
| |
| class _ConvertToIceberg(PyArrowSchemaVisitor[Union[IcebergType, Schema]]): |
| """Converts PyArrowSchema to Iceberg Schema. Applies the IDs from name_mapping if provided.""" |
| |
| _field_names: List[str] |
| _name_mapping: Optional[NameMapping] |
| |
| def __init__(self, name_mapping: Optional[NameMapping] = None) -> None: |
| self._field_names = [] |
| self._name_mapping = name_mapping |
| |
| def _field_id(self, field: pa.Field) -> int: |
| if self._name_mapping: |
| return self._name_mapping.find(*self._field_names).field_id |
| elif (field_id := _get_field_id(field)) is not None: |
| return field_id |
| else: |
| raise ValueError(f"Cannot convert {field} to Iceberg Field as field_id is empty.") |
| |
| def schema(self, schema: pa.Schema, struct_result: StructType) -> Schema: |
| return Schema(*struct_result.fields) |
| |
| def struct(self, struct: pa.StructType, field_results: List[NestedField]) -> StructType: |
| return StructType(*field_results) |
| |
| def field(self, field: pa.Field, field_result: IcebergType) -> NestedField: |
| field_id = self._field_id(field) |
| field_doc = doc_str.decode() if (field.metadata and (doc_str := field.metadata.get(PYARROW_FIELD_DOC_KEY))) else None |
| field_type = field_result |
| return NestedField(field_id, field.name, field_type, required=not field.nullable, doc=field_doc) |
| |
| def list(self, list_type: pa.ListType, element_result: IcebergType) -> ListType: |
| element_field = list_type.value_field |
| self._field_names.append(LIST_ELEMENT_NAME) |
| element_id = self._field_id(element_field) |
| self._field_names.pop() |
| return ListType(element_id, element_result, element_required=not element_field.nullable) |
| |
| def map(self, map_type: pa.MapType, key_result: IcebergType, value_result: IcebergType) -> MapType: |
| key_field = map_type.key_field |
| self._field_names.append(MAP_KEY_NAME) |
| key_id = self._field_id(key_field) |
| self._field_names.pop() |
| value_field = map_type.item_field |
| self._field_names.append(MAP_VALUE_NAME) |
| value_id = self._field_id(value_field) |
| self._field_names.pop() |
| return MapType(key_id, key_result, value_id, value_result, value_required=not value_field.nullable) |
| |
| def primitive(self, primitive: pa.DataType) -> PrimitiveType: |
| if pa.types.is_boolean(primitive): |
| return BooleanType() |
| elif pa.types.is_integer(primitive): |
| width = primitive.bit_width |
| if width <= 32: |
| return IntegerType() |
| elif width <= 64: |
| return LongType() |
| else: |
| # Does not exist (yet) |
| raise TypeError(f"Unsupported integer type: {primitive}") |
| elif pa.types.is_float32(primitive): |
| return FloatType() |
| elif pa.types.is_float64(primitive): |
| return DoubleType() |
| elif isinstance(primitive, pa.Decimal128Type): |
| primitive = cast(pa.Decimal128Type, primitive) |
| return DecimalType(primitive.precision, primitive.scale) |
| elif pa.types.is_string(primitive) or pa.types.is_large_string(primitive): |
| return StringType() |
| elif pa.types.is_date32(primitive): |
| return DateType() |
| elif isinstance(primitive, pa.Time64Type) and primitive.unit == "us": |
| return TimeType() |
| elif pa.types.is_timestamp(primitive): |
| primitive = cast(pa.TimestampType, primitive) |
| if primitive.unit == "us": |
| if primitive.tz == "UTC" or primitive.tz == "+00:00": |
| return TimestamptzType() |
| elif primitive.tz is None: |
| return TimestampType() |
| elif pa.types.is_binary(primitive) or pa.types.is_large_binary(primitive): |
| return BinaryType() |
| elif pa.types.is_fixed_size_binary(primitive): |
| primitive = cast(pa.FixedSizeBinaryType, primitive) |
| return FixedType(primitive.byte_width) |
| |
| raise TypeError(f"Unsupported type: {primitive}") |
| |
| def before_field(self, field: pa.Field) -> None: |
| self._field_names.append(field.name) |
| |
| def after_field(self, field: pa.Field) -> None: |
| self._field_names.pop() |
| |
| def before_list_element(self, element: pa.Field) -> None: |
| self._field_names.append(LIST_ELEMENT_NAME) |
| |
| def after_list_element(self, element: pa.Field) -> None: |
| self._field_names.pop() |
| |
| def before_map_key(self, key: pa.Field) -> None: |
| self._field_names.append(MAP_KEY_NAME) |
| |
| def after_map_key(self, element: pa.Field) -> None: |
| self._field_names.pop() |
| |
| def before_map_value(self, value: pa.Field) -> None: |
| self._field_names.append(MAP_VALUE_NAME) |
| |
| def after_map_value(self, element: pa.Field) -> None: |
| self._field_names.pop() |
| |
| |
| class _ConvertToIcebergWithoutIDs(_ConvertToIceberg): |
| """ |
    Convert a PyArrow schema to an Iceberg schema with all field IDs set to -1.

    The schema generated through this visitor should always be
    used in conjunction with the `new_table_metadata` function to
    assign new field IDs in order. This is currently used only
| when creating an Iceberg Schema from a PyArrow schema when |
| creating a new Iceberg table. |
| """ |
| |
| def _field_id(self, field: pa.Field) -> int: |
| return -1 |
| |
| |
| def _task_to_table( |
| fs: FileSystem, |
| task: FileScanTask, |
| bound_row_filter: BooleanExpression, |
| projected_schema: Schema, |
| projected_field_ids: Set[int], |
| positional_deletes: Optional[List[ChunkedArray]], |
| case_sensitive: bool, |
| limit: Optional[int] = None, |
| name_mapping: Optional[NameMapping] = None, |
| ) -> Optional[pa.Table]: |
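    """Read a single file scan task into a pyarrow.Table, or None when no rows match.

    The row filter is rebound against the file schema and pushed down to Arrow when there are no
    positional deletes; otherwise the deletes are applied first and the filter is applied to the
    resulting table. An optional limit caps the number of rows returned.
    """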
| _, _, path = PyArrowFileIO.parse_location(task.file.file_path) |
| arrow_format = ds.ParquetFileFormat(pre_buffer=True, buffer_size=(ONE_MEGABYTE * 8)) |
| with fs.open_input_file(path) as fin: |
| fragment = arrow_format.make_fragment(fin) |
| physical_schema = fragment.physical_schema |
| file_schema = pyarrow_to_schema(physical_schema, name_mapping) |
| |
| pyarrow_filter = None |
| if bound_row_filter is not AlwaysTrue(): |
| translated_row_filter = translate_column_names(bound_row_filter, file_schema, case_sensitive=case_sensitive) |
| bound_file_filter = bind(file_schema, translated_row_filter, case_sensitive=case_sensitive) |
| pyarrow_filter = expression_to_pyarrow(bound_file_filter) |
| |
| file_project_schema = prune_columns(file_schema, projected_field_ids, select_full_types=False) |
| |
| if file_schema is None: |
| raise ValueError(f"Missing Iceberg schema in Metadata for file: {path}") |
| |
| fragment_scanner = ds.Scanner.from_fragment( |
| fragment=fragment, |
| schema=physical_schema, |
| # This will push down the query to Arrow. |
| # But in case there are positional deletes, we have to apply them first |
| filter=pyarrow_filter if not positional_deletes else None, |
| columns=[col.name for col in file_project_schema.columns], |
| ) |
| |
| if positional_deletes: |
| # Create the mask of indices that we're interested in |
| indices = _combine_positional_deletes(positional_deletes, fragment.count_rows()) |
| |
| if limit: |
| if pyarrow_filter is not None: |
| # In case of the filter, we don't exactly know how many rows |
| # we need to fetch upfront, can be optimized in the future: |
| # https://github.com/apache/arrow/issues/35301 |
| arrow_table = fragment_scanner.take(indices) |
| arrow_table = arrow_table.filter(pyarrow_filter) |
| arrow_table = arrow_table.slice(0, limit) |
| else: |
| arrow_table = fragment_scanner.take(indices[0:limit]) |
| else: |
| arrow_table = fragment_scanner.take(indices) |
| # Apply the user filter |
| if pyarrow_filter is not None: |
| arrow_table = arrow_table.filter(pyarrow_filter) |
| else: |
| # If there are no deletes, we can just take the head |
| # and the user-filter is already applied |
| if limit: |
| arrow_table = fragment_scanner.head(limit) |
| else: |
| arrow_table = fragment_scanner.to_table() |
| |
| if len(arrow_table) < 1: |
| return None |
| |
| return to_requested_schema(projected_schema, file_project_schema, arrow_table) |
| |
| |
| def _read_all_delete_files(fs: FileSystem, tasks: Iterable[FileScanTask]) -> Dict[str, List[ChunkedArray]]: |
| deletes_per_file: Dict[str, List[ChunkedArray]] = {} |
| unique_deletes = set(itertools.chain.from_iterable([task.delete_files for task in tasks])) |
| if len(unique_deletes) > 0: |
| executor = ExecutorFactory.get_or_create() |
| deletes_per_files: Iterator[Dict[str, ChunkedArray]] = executor.map( |
| lambda args: _read_deletes(*args), [(fs, delete) for delete in unique_deletes] |
| ) |
| for delete in deletes_per_files: |
| for file, arr in delete.items(): |
| if file in deletes_per_file: |
| deletes_per_file[file].append(arr) |
| else: |
| deletes_per_file[file] = [arr] |
| |
| return deletes_per_file |
| |
| |
| def project_table( |
| tasks: Iterable[FileScanTask], |
| table: Table, |
| row_filter: BooleanExpression, |
| projected_schema: Schema, |
| case_sensitive: bool = True, |
| limit: Optional[int] = None, |
| ) -> pa.Table: |
| """Resolve the right columns based on the identifier. |
| |
| Args: |
        tasks (Iterable[FileScanTask]): The file scan tasks to read.
| table (Table): The table that's being queried. |
| row_filter (BooleanExpression): The expression for filtering rows. |
| projected_schema (Schema): The output schema. |
| case_sensitive (bool): Case sensitivity when looking up column names. |
| limit (Optional[int]): Limit the number of records. |
| |
| Raises: |
| ResolveError: When an incompatible query is done. |
| """ |
| scheme, netloc, _ = PyArrowFileIO.parse_location(table.location()) |
| if isinstance(table.io, PyArrowFileIO): |
| fs = table.io.fs_by_scheme(scheme, netloc) |
| else: |
| try: |
| from pyiceberg.io.fsspec import FsspecFileIO |
| |
| if isinstance(table.io, FsspecFileIO): |
| from pyarrow.fs import PyFileSystem |
| |
| fs = PyFileSystem(FSSpecHandler(table.io.get_fs(scheme))) |
| else: |
| raise ValueError(f"Expected PyArrowFileIO or FsspecFileIO, got: {table.io}") |
| except ModuleNotFoundError as e: |
| # When FsSpec is not installed |
| raise ValueError(f"Expected PyArrowFileIO or FsspecFileIO, got: {table.io}") from e |
| |
| bound_row_filter = bind(table.schema(), row_filter, case_sensitive=case_sensitive) |
| |
| projected_field_ids = { |
| id for id in projected_schema.field_ids if not isinstance(projected_schema.find_type(id), (MapType, ListType)) |
| }.union(extract_field_ids(bound_row_filter)) |
| |
| deletes_per_file = _read_all_delete_files(fs, tasks) |
| executor = ExecutorFactory.get_or_create() |
| futures = [ |
| executor.submit( |
| _task_to_table, |
| fs, |
| task, |
| bound_row_filter, |
| projected_schema, |
| projected_field_ids, |
| deletes_per_file.get(task.file.file_path), |
| case_sensitive, |
| limit, |
| table.name_mapping(), |
| ) |
| for task in tasks |
| ] |
| total_row_count = 0 |
| # for consistent ordering, we need to maintain future order |
| futures_index = {f: i for i, f in enumerate(futures)} |
| completed_futures: SortedList[Future[pa.Table]] = SortedList(iterable=[], key=lambda f: futures_index[f]) |
| for future in concurrent.futures.as_completed(futures): |
| completed_futures.add(future) |
| if table_result := future.result(): |
| total_row_count += len(table_result) |
| # stop early if limit is satisfied |
| if limit is not None and total_row_count >= limit: |
| break |
| |
| # by now, we've either completed all tasks or satisfied the limit |
| if limit is not None: |
| _ = [f.cancel() for f in futures if not f.done()] |
| |
| tables = [f.result() for f in completed_futures if f.result()] |
| |
| if len(tables) < 1: |
| return pa.Table.from_batches([], schema=schema_to_pyarrow(projected_schema)) |
| |
| result = pa.concat_tables(tables) |
| |
| if limit is not None: |
| return result.slice(0, limit) |
| |
| return result |
| |
| |
| def to_requested_schema(requested_schema: Schema, file_schema: Schema, table: pa.Table) -> pa.Table: |
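    """Project a pyarrow.Table read with the file schema onto the requested Iceberg schema."""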
| struct_array = visit_with_partner(requested_schema, table, ArrowProjectionVisitor(file_schema), ArrowAccessor(file_schema)) |
| |
| arrays = [] |
| fields = [] |
| for pos, field in enumerate(requested_schema.fields): |
| array = struct_array.field(pos) |
| arrays.append(array) |
| fields.append(pa.field(field.name, array.type, field.optional)) |
| return pa.Table.from_arrays(arrays, schema=pa.schema(fields)) |
| |
| |
| class ArrowProjectionVisitor(SchemaWithPartnerVisitor[pa.Array, Optional[pa.Array]]): |
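    """Project Arrow arrays read with the file schema onto the requested Iceberg schema.

    Primitive values are cast when the file type can be promoted to the requested type,
    missing optional fields are filled with nulls, and a missing required field raises a
    ResolveError.
    """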
| file_schema: Schema |
| |
| def __init__(self, file_schema: Schema): |
| self.file_schema = file_schema |
| |
| def _cast_if_needed(self, field: NestedField, values: pa.Array) -> pa.Array: |
| file_field = self.file_schema.find_field(field.field_id) |
| if field.field_type.is_primitive and field.field_type != file_field.field_type: |
| return values.cast(schema_to_pyarrow(promote(file_field.field_type, field.field_type))) |
| return values |
| |
| def _construct_field(self, field: NestedField, arrow_type: pa.DataType) -> pa.Field: |
| return pa.field( |
| name=field.name, |
| type=arrow_type, |
| nullable=field.optional, |
| metadata={DOC: field.doc} if field.doc is not None else None, |
| ) |
| |
| def schema(self, schema: Schema, schema_partner: Optional[pa.Array], struct_result: Optional[pa.Array]) -> Optional[pa.Array]: |
| return struct_result |
| |
| def struct( |
| self, struct: StructType, struct_array: Optional[pa.Array], field_results: List[Optional[pa.Array]] |
| ) -> Optional[pa.Array]: |
| if struct_array is None: |
| return None |
| field_arrays: List[pa.Array] = [] |
| fields: List[pa.Field] = [] |
| for field, field_array in zip(struct.fields, field_results): |
| if field_array is not None: |
| array = self._cast_if_needed(field, field_array) |
| field_arrays.append(array) |
| fields.append(self._construct_field(field, array.type)) |
| elif field.optional: |
| arrow_type = schema_to_pyarrow(field.field_type) |
| field_arrays.append(pa.nulls(len(struct_array), type=arrow_type)) |
| fields.append(self._construct_field(field, arrow_type)) |
| else: |
| raise ResolveError(f"Field is required, and could not be found in the file: {field}") |
| |
| return pa.StructArray.from_arrays(arrays=field_arrays, fields=pa.struct(fields)) |
| |
| def field(self, field: NestedField, _: Optional[pa.Array], field_array: Optional[pa.Array]) -> Optional[pa.Array]: |
| return field_array |
| |
| def list(self, list_type: ListType, list_array: Optional[pa.Array], value_array: Optional[pa.Array]) -> Optional[pa.Array]: |
| if isinstance(list_array, pa.ListArray) and value_array is not None: |
| if isinstance(value_array, pa.StructArray): |
| # This can be removed once this has been fixed: |
| # https://github.com/apache/arrow/issues/38809 |
| list_array = pa.ListArray.from_arrays(list_array.offsets, value_array) |
| |
| arrow_field = pa.list_(self._construct_field(list_type.element_field, value_array.type)) |
| return list_array.cast(arrow_field) |
| else: |
| return None |
| |
| def map( |
| self, map_type: MapType, map_array: Optional[pa.Array], key_result: Optional[pa.Array], value_result: Optional[pa.Array] |
| ) -> Optional[pa.Array]: |
| if isinstance(map_array, pa.MapArray) and key_result is not None and value_result is not None: |
| arrow_field = pa.map_( |
| self._construct_field(map_type.key_field, key_result.type), |
| self._construct_field(map_type.value_field, value_result.type), |
| ) |
| if isinstance(value_result, pa.StructArray): |
| # Arrow does not allow reordering of fields, therefore we have to copy the array :( |
| return pa.MapArray.from_arrays(map_array.offsets, key_result, value_result, arrow_field) |
| else: |
| return map_array.cast(arrow_field) |
| else: |
| return None |
| |
| def primitive(self, _: PrimitiveType, array: Optional[pa.Array]) -> Optional[pa.Array]: |
| return array |
| |
| |
| class ArrowAccessor(PartnerAccessor[pa.Array]): |
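    """Access the partner Arrow structure while visiting the requested schema.

    Fields are looked up by the name they carry in the file schema, from either a pa.Table or a
    pa.StructArray; list values and map keys/items are unwrapped from their Arrow array types.
    """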
| file_schema: Schema |
| |
| def __init__(self, file_schema: Schema): |
| self.file_schema = file_schema |
| |
| def schema_partner(self, partner: Optional[pa.Array]) -> Optional[pa.Array]: |
| return partner |
| |
| def field_partner(self, partner_struct: Optional[pa.Array], field_id: int, _: str) -> Optional[pa.Array]: |
| if partner_struct: |
| # use the field name from the file schema |
| try: |
| name = self.file_schema.find_field(field_id).name |
| except ValueError: |
| return None |
| |
| if isinstance(partner_struct, pa.StructArray): |
| return partner_struct.field(name) |
| elif isinstance(partner_struct, pa.Table): |
| return partner_struct.column(name).combine_chunks() |
| |
| return None |
| |
| def list_element_partner(self, partner_list: Optional[pa.Array]) -> Optional[pa.Array]: |
| return partner_list.values if isinstance(partner_list, pa.ListArray) else None |
| |
| def map_key_partner(self, partner_map: Optional[pa.Array]) -> Optional[pa.Array]: |
| return partner_map.keys if isinstance(partner_map, pa.MapArray) else None |
| |
| def map_value_partner(self, partner_map: Optional[pa.Array]) -> Optional[pa.Array]: |
| return partner_map.items if isinstance(partner_map, pa.MapArray) else None |
| |
| |
| def _primitive_to_physical(iceberg_type: PrimitiveType) -> str: |
| return visit(iceberg_type, _PRIMITIVE_TO_PHYSICAL_TYPE_VISITOR) |
| |
| |
| class PrimitiveToPhysicalType(SchemaVisitorPerPrimitiveType[str]): |
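    """Map an Iceberg primitive type to its Parquet physical type name (e.g. INT32, INT64, BYTE_ARRAY)."""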
| def schema(self, schema: Schema, struct_result: str) -> str: |
| raise ValueError(f"Expected primitive-type, got: {schema}") |
| |
| def struct(self, struct: StructType, field_results: List[str]) -> str: |
| raise ValueError(f"Expected primitive-type, got: {struct}") |
| |
| def field(self, field: NestedField, field_result: str) -> str: |
| raise ValueError(f"Expected primitive-type, got: {field}") |
| |
| def list(self, list_type: ListType, element_result: str) -> str: |
| raise ValueError(f"Expected primitive-type, got: {list_type}") |
| |
| def map(self, map_type: MapType, key_result: str, value_result: str) -> str: |
| raise ValueError(f"Expected primitive-type, got: {map_type}") |
| |
| def visit_fixed(self, fixed_type: FixedType) -> str: |
| return "FIXED_LEN_BYTE_ARRAY" |
| |
| def visit_decimal(self, decimal_type: DecimalType) -> str: |
| return "FIXED_LEN_BYTE_ARRAY" |
| |
| def visit_boolean(self, boolean_type: BooleanType) -> str: |
| return "BOOLEAN" |
| |
| def visit_integer(self, integer_type: IntegerType) -> str: |
| return "INT32" |
| |
| def visit_long(self, long_type: LongType) -> str: |
| return "INT64" |
| |
| def visit_float(self, float_type: FloatType) -> str: |
| return "FLOAT" |
| |
| def visit_double(self, double_type: DoubleType) -> str: |
| return "DOUBLE" |
| |
| def visit_date(self, date_type: DateType) -> str: |
| return "INT32" |
| |
| def visit_time(self, time_type: TimeType) -> str: |
| return "INT64" |
| |
| def visit_timestamp(self, timestamp_type: TimestampType) -> str: |
| return "INT64" |
| |
| def visit_timestamptz(self, timestamptz_type: TimestamptzType) -> str: |
| return "INT64" |
| |
| def visit_string(self, string_type: StringType) -> str: |
| return "BYTE_ARRAY" |
| |
| def visit_uuid(self, uuid_type: UUIDType) -> str: |
| return "FIXED_LEN_BYTE_ARRAY" |
| |
| def visit_binary(self, binary_type: BinaryType) -> str: |
| return "BYTE_ARRAY" |
| |
| |
| _PRIMITIVE_TO_PHYSICAL_TYPE_VISITOR = PrimitiveToPhysicalType() |
| |
| |
| class StatsAggregator: |
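    """Track the running lower and upper bound of a single column and serialize them to bytes.

    Bounds may be truncated to `trunc_length`: the lower bound with `TruncateTransform`, and the
    upper bound of string/binary columns with `truncate_upper_bound_text_string` /
    `truncate_upper_bound_binary_string` so that the truncated value still bounds the column.
    """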
| current_min: Any |
| current_max: Any |
| trunc_length: Optional[int] |
| |
| def __init__(self, iceberg_type: PrimitiveType, physical_type_string: str, trunc_length: Optional[int] = None) -> None: |
| self.current_min = None |
| self.current_max = None |
| self.trunc_length = trunc_length |
| |
| expected_physical_type = _primitive_to_physical(iceberg_type) |
| if expected_physical_type != physical_type_string: |
| raise ValueError( |
| f"Unexpected physical type {physical_type_string} for {iceberg_type}, expected {expected_physical_type}" |
| ) |
| |
| self.primitive_type = iceberg_type |
| |
| def serialize(self, value: Any) -> bytes: |
| return to_bytes(self.primitive_type, value) |
| |
| def update_min(self, val: Optional[Any]) -> None: |
| if self.current_min is None: |
| self.current_min = val |
| elif val is not None: |
| self.current_min = min(val, self.current_min) |
| |
| def update_max(self, val: Optional[Any]) -> None: |
| if self.current_max is None: |
| self.current_max = val |
| elif val is not None: |
| self.current_max = max(val, self.current_max) |
| |
| def min_as_bytes(self) -> Optional[bytes]: |
| if self.current_min is None: |
| return None |
| |
| return self.serialize( |
| self.current_min |
| if self.trunc_length is None |
| else TruncateTransform(width=self.trunc_length).transform(self.primitive_type)(self.current_min) |
| ) |
| |
| def max_as_bytes(self) -> Optional[bytes]: |
| if self.current_max is None: |
| return None |
| |
| if self.primitive_type == StringType(): |
| if not isinstance(self.current_max, str): |
| raise ValueError("Expected the current_max to be a string") |
| s_result = truncate_upper_bound_text_string(self.current_max, self.trunc_length) |
| return self.serialize(s_result) if s_result is not None else None |
| elif self.primitive_type == BinaryType(): |
| if not isinstance(self.current_max, bytes): |
| raise ValueError("Expected the current_max to be bytes") |
| b_result = truncate_upper_bound_binary_string(self.current_max, self.trunc_length) |
| return self.serialize(b_result) if b_result is not None else None |
| else: |
| if self.trunc_length is not None: |
| raise ValueError(f"{self.primitive_type} cannot be truncated") |
| return self.serialize(self.current_max) |
| |
| |
| DEFAULT_TRUNCATION_LENGTH = 16 |
| TRUNCATION_EXPR = r"^truncate\((\d+)\)$" |
| |
| |
| class MetricModeTypes(Enum): |
| TRUNCATE = "truncate" |
| NONE = "none" |
| COUNTS = "counts" |
| FULL = "full" |
| |
| |
| @dataclass(frozen=True) |
| class MetricsMode(Singleton): |
| type: MetricModeTypes |
| length: Optional[int] = None |
| |
| |
| def match_metrics_mode(mode: str) -> MetricsMode: |
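    """Parse a metrics mode property value into a MetricsMode.

    For example, "none", "counts" and "full" map to the corresponding MetricModeTypes, and
    "truncate(16)" maps to MetricsMode(MetricModeTypes.TRUNCATE, 16). Matching is
    case-insensitive and surrounding whitespace is ignored.
    """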
| sanitized_mode = mode.strip().lower() |
| if sanitized_mode.startswith("truncate"): |
| m = re.match(TRUNCATION_EXPR, sanitized_mode) |
| if m: |
| length = int(m[1]) |
| if length < 1: |
| raise ValueError("Truncation length must be larger than 0") |
| return MetricsMode(MetricModeTypes.TRUNCATE, int(m[1])) |
| else: |
| raise ValueError(f"Malformed truncate: {mode}") |
| elif sanitized_mode == "none": |
| return MetricsMode(MetricModeTypes.NONE) |
| elif sanitized_mode == "counts": |
| return MetricsMode(MetricModeTypes.COUNTS) |
| elif sanitized_mode == "full": |
| return MetricsMode(MetricModeTypes.FULL) |
| else: |
| raise ValueError(f"Unsupported metrics mode: {mode}") |
| |
| |
| @dataclass(frozen=True) |
| class StatisticsCollector: |
| field_id: int |
| iceberg_type: PrimitiveType |
| mode: MetricsMode |
| column_name: str |
| |
| |
| class PyArrowStatisticsCollector(PreOrderSchemaVisitor[List[StatisticsCollector]]): |
| _field_id: int = 0 |
| _schema: Schema |
| _properties: Dict[str, str] |
| _default_mode: str |
| |
| def __init__(self, schema: Schema, properties: Dict[str, str]): |
| self._schema = schema |
| self._properties = properties |
| self._default_mode = self._properties.get( |
| TableProperties.DEFAULT_WRITE_METRICS_MODE, TableProperties.DEFAULT_WRITE_METRICS_MODE_DEFAULT |
| ) |
| |
| def schema(self, schema: Schema, struct_result: Callable[[], List[StatisticsCollector]]) -> List[StatisticsCollector]: |
| return struct_result() |
| |
| def struct( |
| self, struct: StructType, field_results: List[Callable[[], List[StatisticsCollector]]] |
| ) -> List[StatisticsCollector]: |
| return list(itertools.chain(*[result() for result in field_results])) |
| |
| def field(self, field: NestedField, field_result: Callable[[], List[StatisticsCollector]]) -> List[StatisticsCollector]: |
| self._field_id = field.field_id |
| return field_result() |
| |
| def list(self, list_type: ListType, element_result: Callable[[], List[StatisticsCollector]]) -> List[StatisticsCollector]: |
| self._field_id = list_type.element_id |
| return element_result() |
| |
| def map( |
| self, |
| map_type: MapType, |
| key_result: Callable[[], List[StatisticsCollector]], |
| value_result: Callable[[], List[StatisticsCollector]], |
| ) -> List[StatisticsCollector]: |
| self._field_id = map_type.key_id |
| k = key_result() |
| self._field_id = map_type.value_id |
| v = value_result() |
| return k + v |
| |
| def primitive(self, primitive: PrimitiveType) -> List[StatisticsCollector]: |
| column_name = self._schema.find_column_name(self._field_id) |
| if column_name is None: |
| return [] |
| |
| metrics_mode = match_metrics_mode(self._default_mode) |
| |
| col_mode = self._properties.get(f"{TableProperties.METRICS_MODE_COLUMN_CONF_PREFIX}.{column_name}") |
| if col_mode: |
| metrics_mode = match_metrics_mode(col_mode) |
| |
| if ( |
| not (isinstance(primitive, StringType) or isinstance(primitive, BinaryType)) |
| and metrics_mode.type == MetricModeTypes.TRUNCATE |
| ): |
| metrics_mode = MetricsMode(MetricModeTypes.FULL) |
| |
| is_nested = column_name.find(".") >= 0 |
| |
| if is_nested and metrics_mode.type in [MetricModeTypes.TRUNCATE, MetricModeTypes.FULL]: |
| metrics_mode = MetricsMode(MetricModeTypes.COUNTS) |
| |
| return [StatisticsCollector(field_id=self._field_id, iceberg_type=primitive, mode=metrics_mode, column_name=column_name)] |
| |
| |
| def compute_statistics_plan( |
| schema: Schema, |
| table_properties: Dict[str, str], |
| ) -> Dict[int, StatisticsCollector]: |
| """ |
| Compute the statistics plan for all columns. |
| |
    The result maps the Iceberg field ID of each leaf column to its StatisticsCollector.
    For each column, the metrics collection mode that was provided by the user in the configuration
    is looked up and then adjusted according to the data type of the column: for nested columns the
    minimum and maximum values are not computed, and truncation is only applied to text or binary strings.

    Args:
        schema (pyiceberg.schema.Schema): The schema for which the statistics plan is computed.
        table_properties (from pyiceberg.table.metadata.TableMetadata): The Iceberg table metadata properties.
            They are used to set the mode for column metrics collection.
| """ |
| stats_cols = pre_order_visit(schema, PyArrowStatisticsCollector(schema, table_properties)) |
| result: Dict[int, StatisticsCollector] = {} |
| for stats_col in stats_cols: |
| result[stats_col.field_id] = stats_col |
| return result |
| |
| |
| @dataclass(frozen=True) |
| class ID2ParquetPath: |
| field_id: int |
| parquet_path: str |
| |
| |
| class ID2ParquetPathVisitor(PreOrderSchemaVisitor[List[ID2ParquetPath]]): |
| _field_id: int = 0 |
| _path: List[str] |
| |
| def __init__(self) -> None: |
| self._path = [] |
| |
| def schema(self, schema: Schema, struct_result: Callable[[], List[ID2ParquetPath]]) -> List[ID2ParquetPath]: |
| return struct_result() |
| |
| def struct(self, struct: StructType, field_results: List[Callable[[], List[ID2ParquetPath]]]) -> List[ID2ParquetPath]: |
| return list(itertools.chain(*[result() for result in field_results])) |
| |
| def field(self, field: NestedField, field_result: Callable[[], List[ID2ParquetPath]]) -> List[ID2ParquetPath]: |
| self._field_id = field.field_id |
| self._path.append(field.name) |
| result = field_result() |
| self._path.pop() |
| return result |
| |
| def list(self, list_type: ListType, element_result: Callable[[], List[ID2ParquetPath]]) -> List[ID2ParquetPath]: |
| self._field_id = list_type.element_id |
| self._path.append("list.element") |
| result = element_result() |
| self._path.pop() |
| return result |
| |
| def map( |
| self, |
| map_type: MapType, |
| key_result: Callable[[], List[ID2ParquetPath]], |
| value_result: Callable[[], List[ID2ParquetPath]], |
| ) -> List[ID2ParquetPath]: |
| self._field_id = map_type.key_id |
| self._path.append("key_value.key") |
| k = key_result() |
| self._path.pop() |
| self._field_id = map_type.value_id |
| self._path.append("key_value.value") |
| v = value_result() |
| self._path.pop() |
| return k + v |
| |
| def primitive(self, primitive: PrimitiveType) -> List[ID2ParquetPath]: |
| return [ID2ParquetPath(field_id=self._field_id, parquet_path=".".join(self._path))] |
| |
| |
| def parquet_path_to_id_mapping( |
| schema: Schema, |
| ) -> Dict[str, int]: |
| """ |
| Compute the mapping of parquet column path to Iceberg ID. |
| |
| For each column, the parquet file metadata has a path_in_schema attribute that follows |
| a specific naming scheme for nested columns. This function computes a mapping of |
| the full paths to the corresponding Iceberg field IDs. |
| |
| Args: |
| schema (pyiceberg.schema.Schema): The current table schema. |
| """ |
| result: Dict[str, int] = {} |
| for pair in pre_order_visit(schema, ID2ParquetPathVisitor()): |
| result[pair.parquet_path] = pair.field_id |
| return result |
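| |
| # Illustrative sketch (hypothetical schema): for a long field "id" (field ID 1) and a |
| # list<string> field "tags" (field ID 2, element ID 3), the mapping follows parquet's |
| # path_in_schema naming: |
| # |
| #   parquet_path_to_id_mapping(schema) |
| #   # -> {"id": 1, "tags.list.element": 3} |
| # |
| # Map fields contribute "<name>.key_value.key" and "<name>.key_value.value" entries in |
| # the same way. |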
| |
| |
| def fill_parquet_file_metadata( |
| data_file: DataFile, |
| parquet_metadata: pq.FileMetaData, |
| stats_columns: Dict[int, StatisticsCollector], |
| parquet_column_mapping: Dict[str, int], |
| ) -> None: |
| """ |
| Compute and fill the following fields of the DataFile object. |
| |
| - record_count |
| - column_sizes |
| - value_counts |
| - null_value_counts |
| - nan_value_counts |
| - lower_bounds |
| - upper_bounds |
| - split_offsets |
| |
| Args: |
| data_file (DataFile): A DataFile object representing the Parquet file for which metadata is to be filled. |
| parquet_metadata (pyarrow.parquet.FileMetaData): A pyarrow metadata object. |
| stats_columns (Dict[int, StatisticsCollector]): The statistics gathering plan. It is required to |
| set the mode for column metrics collection. |
| parquet_column_mapping (Dict[str, int]): The mapping of parquet column path to Iceberg field ID, |
| as produced by parquet_path_to_id_mapping(). |
| """ |
| if parquet_metadata.num_columns != len(stats_columns): |
| raise ValueError( |
| f"Number of columns in statistics configuration ({len(stats_columns)}) is different from the number of columns in pyarrow table ({parquet_metadata.num_columns})" |
| ) |
| |
| if parquet_metadata.num_columns != len(parquet_column_mapping): |
| raise ValueError( |
| f"Number of columns in column mapping ({len(parquet_column_mapping)}) is different from the number of columns in pyarrow table ({parquet_metadata.num_columns})" |
| ) |
| |
| column_sizes: Dict[int, int] = {} |
| value_counts: Dict[int, int] = {} |
| split_offsets: List[int] = [] |
| |
| null_value_counts: Dict[int, int] = {} |
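| # NaN value counts are not derived from parquet statistics, so this mapping stays empty here. |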
| nan_value_counts: Dict[int, int] = {} |
| |
| col_aggs: Dict[int, StatsAggregator] = {} |
| |
| for r in range(parquet_metadata.num_row_groups): |
| # References: |
| # https://github.com/apache/iceberg/blob/fc381a81a1fdb8f51a0637ca27cd30673bd7aad3/parquet/src/main/java/org/apache/iceberg/parquet/ParquetUtil.java#L232 |
| # https://github.com/apache/parquet-mr/blob/ac29db4611f86a07cc6877b416aa4b183e09b353/parquet-hadoop/src/main/java/org/apache/parquet/hadoop/metadata/ColumnChunkMetaData.java#L184 |
| |
| row_group = parquet_metadata.row_group(r) |
| |
| data_offset = row_group.column(0).data_page_offset |
| dictionary_offset = row_group.column(0).dictionary_page_offset |
| |
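| # The split offset of a row group is where its first page starts: the dictionary page |
| # if one exists and physically precedes the data pages, otherwise the first data page. |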
| if row_group.column(0).has_dictionary_page and dictionary_offset < data_offset: |
| split_offsets.append(dictionary_offset) |
| else: |
| split_offsets.append(data_offset) |
| |
| invalidate_col: Set[int] = set() |
| |
| for pos in range(parquet_metadata.num_columns): |
| column = row_group.column(pos) |
| field_id = parquet_column_mapping[column.path_in_schema] |
| |
| stats_col = stats_columns[field_id] |
| |
| column_sizes.setdefault(field_id, 0) |
| column_sizes[field_id] += column.total_compressed_size |
| |
| if stats_col.mode == MetricsMode(MetricModeTypes.NONE): |
| continue |
| |
| value_counts[field_id] = value_counts.get(field_id, 0) + column.num_values |
| |
| if column.is_stats_set: |
| try: |
| statistics = column.statistics |
| |
| if statistics.has_null_count: |
| null_value_counts[field_id] = null_value_counts.get(field_id, 0) + statistics.null_count |
| |
| if stats_col.mode == MetricsMode(MetricModeTypes.COUNTS): |
| continue |
| |
| if field_id not in col_aggs: |
| col_aggs[field_id] = StatsAggregator( |
| stats_col.iceberg_type, statistics.physical_type, stats_col.mode.length |
| ) |
| |
| col_aggs[field_id].update_min(statistics.min) |
| col_aggs[field_id].update_max(statistics.max) |
| |
| except pyarrow.lib.ArrowNotImplementedError as e: |
| invalidate_col.add(field_id) |
| logger.warning(e) |
| else: |
| invalidate_col.add(field_id) |
| logger.warning("PyArrow statistics missing for column %d when writing file", pos) |
| |
| split_offsets.sort() |
| |
| lower_bounds = {} |
| upper_bounds = {} |
| |
| for k, agg in col_aggs.items(): |
| _min = agg.min_as_bytes() |
| if _min is not None: |
| lower_bounds[k] = _min |
| _max = agg.max_as_bytes() |
| if _max is not None: |
| upper_bounds[k] = _max |
| |
| for field_id in invalidate_col: |
| # Use pop() rather than del: a column whose statistics were missing or unreadable |
| # may never have produced bounds or null counts in the first place. |
| lower_bounds.pop(field_id, None) |
| upper_bounds.pop(field_id, None) |
| null_value_counts.pop(field_id, None) |
| |
| data_file.record_count = parquet_metadata.num_rows |
| data_file.column_sizes = column_sizes |
| data_file.value_counts = value_counts |
| data_file.null_value_counts = null_value_counts |
| data_file.nan_value_counts = nan_value_counts |
| data_file.lower_bounds = lower_bounds |
| data_file.upper_bounds = upper_bounds |
| data_file.split_offsets = split_offsets |
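| |
| # Illustrative sketch (assumes `table` is a loaded Table, `path` is a locally readable |
| # parquet file and `data_file` is a DataFile constructed for it; when writing, the |
| # metadata comes from the ParquetWriter instead, as in write_file below): |
| # |
| #   fill_parquet_file_metadata( |
| #       data_file=data_file, |
| #       parquet_metadata=pq.read_metadata(path), |
| #       stats_columns=compute_statistics_plan(table.schema(), table.properties), |
| #       parquet_column_mapping=parquet_path_to_id_mapping(table.schema()), |
| #   ) |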
| |
| |
| def write_file(table: Table, tasks: Iterator[WriteTask]) -> Iterator[DataFile]: |
| task = next(tasks) |
| |
| try: |
| # A second task would mean a partitioned write, which is not supported yet |
| _ = next(tasks) |
| raise NotImplementedError("Only unpartitioned writes are supported: https://github.com/apache/iceberg-python/issues/208") |
| except StopIteration: |
| pass |
| |
| parquet_writer_kwargs = _get_parquet_writer_kwargs(table.properties) |
| |
| file_path = f'{table.location()}/data/{task.generate_data_file_filename("parquet")}' |
| file_schema = schema_to_pyarrow(table.schema()) |
| |
| fo = table.io.new_output(file_path) |
| row_group_size = PropertyUtil.property_as_int( |
| properties=table.properties, |
| property_name=TableProperties.PARQUET_ROW_GROUP_SIZE_BYTES, |
| default=TableProperties.PARQUET_ROW_GROUP_SIZE_BYTES_DEFAULT, |
| ) |
| with fo.create(overwrite=True) as fos: |
| with pq.ParquetWriter(fos, schema=file_schema, **parquet_writer_kwargs) as writer: |
| writer.write_table(task.df, row_group_size=row_group_size) |
| |
| data_file = DataFile( |
| content=DataFileContent.DATA, |
| file_path=file_path, |
| file_format=FileFormat.PARQUET, |
| partition=Record(), |
| file_size_in_bytes=len(fo), |
| # After this has been fixed: |
| # https://github.com/apache/iceberg-python/issues/271 |
| # sort_order_id=task.sort_order_id, |
| sort_order_id=None, |
| # Just copy these from the table for now |
| spec_id=table.spec().spec_id, |
| equality_ids=None, |
| key_metadata=None, |
| ) |
| |
| fill_parquet_file_metadata( |
| data_file=data_file, |
| parquet_metadata=writer.writer.metadata, |
| stats_columns=compute_statistics_plan(table.schema(), table.properties), |
| parquet_column_mapping=parquet_path_to_id_mapping(table.schema()), |
| ) |
| return iter([data_file]) |
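| |
| # Illustrative note: write_file is normally driven by the table API rather than called |
| # directly, e.g. (assuming `catalog` is a loaded catalog and `df` is a pyarrow Table |
| # whose schema matches the Iceberg schema; the table identifier is hypothetical): |
| # |
| #   tbl = catalog.load_table("default.events") |
| #   tbl.append(df)  # builds the WriteTask(s) and calls write_file under the hood |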
| |
| |
| ICEBERG_UNCOMPRESSED_CODEC = "uncompressed" |
| PYARROW_UNCOMPRESSED_CODEC = "none" |
| |
| |
| def _get_parquet_writer_kwargs(table_properties: Properties) -> Dict[str, Any]: |
| for key_pattern in [ |
| TableProperties.PARQUET_ROW_GROUP_SIZE_BYTES, |
| TableProperties.PARQUET_PAGE_ROW_LIMIT, |
| TableProperties.PARQUET_BLOOM_FILTER_MAX_BYTES, |
| f"{TableProperties.PARQUET_BLOOM_FILTER_COLUMN_ENABLED_PREFIX}.*", |
| ]: |
| if unsupported_keys := fnmatch.filter(table_properties, key_pattern): |
| raise NotImplementedError(f"Parquet writer option(s) {unsupported_keys} not implemented") |
| |
| compression_codec = table_properties.get(TableProperties.PARQUET_COMPRESSION, TableProperties.PARQUET_COMPRESSION_DEFAULT) |
| compression_level = PropertyUtil.property_as_int( |
| properties=table_properties, |
| property_name=TableProperties.PARQUET_COMPRESSION_LEVEL, |
| default=TableProperties.PARQUET_COMPRESSION_LEVEL_DEFAULT, |
| ) |
| if compression_codec == ICEBERG_UNCOMPRESSED_CODEC: |
| compression_codec = PYARROW_UNCOMPRESSED_CODEC |
| |
| return { |
| "compression": compression_codec, |
| "compression_level": compression_level, |
| "data_page_size": PropertyUtil.property_as_int( |
| properties=table_properties, |
| property_name=TableProperties.PARQUET_PAGE_SIZE_BYTES, |
| default=TableProperties.PARQUET_PAGE_SIZE_BYTES_DEFAULT, |
| ), |
| "dictionary_pagesize_limit": PropertyUtil.property_as_int( |
| properties=table_properties, |
| property_name=TableProperties.PARQUET_DICT_SIZE_BYTES, |
| default=TableProperties.PARQUET_DICT_SIZE_BYTES_DEFAULT, |
| ), |
| "write_batch_size": PropertyUtil.property_as_int( |
| properties=table_properties, |
| property_name=TableProperties.PARQUET_PAGE_ROW_LIMIT, |
| default=TableProperties.PARQUET_PAGE_ROW_LIMIT_DEFAULT, |
| ), |
| } |
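| |
| # Illustrative sketch (property values are hypothetical): with table properties that set |
| # TableProperties.PARQUET_COMPRESSION to "zstd" and TableProperties.PARQUET_COMPRESSION_LEVEL |
| # to "3", the returned kwargs include compression="zstd" and compression_level=3, plus the |
| # data_page_size, dictionary_pagesize_limit and write_batch_size values resolved from their |
| # defaults; an Iceberg codec of "uncompressed" is mapped to pyarrow's "none". The kwargs are |
| # passed straight to pq.ParquetWriter in write_file above: |
| # |
| #   pq.ParquetWriter(where, schema=arrow_schema, **_get_parquet_writer_kwargs(table.properties)) |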