| # Licensed to the Apache Software Foundation (ASF) under one |
| # or more contributor license agreements. See the NOTICE file |
| # distributed with this work for additional information |
| # regarding copyright ownership. The ASF licenses this file |
| # to you under the Apache License, Version 2.0 (the |
| # "License"); you may not use this file except in compliance |
| # with the License. You may obtain a copy of the License at |
| # |
| # http://www.apache.org/licenses/LICENSE-2.0 |
| # |
| # Unless required by applicable law or agreed to in writing, |
| # software distributed under the License is distributed on an |
| # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY |
| # KIND, either express or implied. See the License for the |
| # specific language governing permissions and limitations |
| # under the License. |
| # pylint: disable=protected-access,unused-argument,redefined-outer-name |
| import logging |
| import os |
| import tempfile |
| import uuid |
| import warnings |
| from datetime import date, datetime, timezone |
| from pathlib import Path |
| from typing import Any |
| from unittest.mock import MagicMock, patch |
| from uuid import uuid4 |
| |
| import pyarrow |
| import pyarrow as pa |
| import pyarrow.orc as orc |
| import pyarrow.parquet as pq |
| import pytest |
| from packaging import version |
| from pyarrow.fs import AwsDefaultS3RetryStrategy, FileType, LocalFileSystem, S3FileSystem |
| |
| from pyiceberg.exceptions import ResolveError |
| from pyiceberg.expressions import ( |
| AlwaysFalse, |
| AlwaysTrue, |
| And, |
| BooleanExpression, |
| BoundEqualTo, |
| BoundGreaterThan, |
| BoundGreaterThanOrEqual, |
| BoundIn, |
| BoundIsNaN, |
| BoundIsNull, |
| BoundLessThan, |
| BoundLessThanOrEqual, |
| BoundNotEqualTo, |
| BoundNotIn, |
| BoundNotNaN, |
| BoundNotNull, |
| BoundNotStartsWith, |
| BoundReference, |
| BoundStartsWith, |
| GreaterThan, |
| Not, |
| Or, |
| ) |
| from pyiceberg.expressions.literals import literal |
| from pyiceberg.io import S3_RETRY_STRATEGY_IMPL, InputStream, OutputStream, load_file_io |
| from pyiceberg.io.pyarrow import ( |
| ICEBERG_SCHEMA, |
| PYARROW_PARQUET_FIELD_ID_KEY, |
| ArrowScan, |
| PyArrowFile, |
| PyArrowFileIO, |
| StatsAggregator, |
| _check_pyarrow_schema_compatible, |
| _ConvertToArrowSchema, |
| _determine_partitions, |
| _primitive_to_physical, |
| _read_deletes, |
| _task_to_record_batches, |
| _to_requested_schema, |
| bin_pack_arrow_table, |
| compute_statistics_plan, |
| data_file_statistics_from_parquet_metadata, |
| expression_to_pyarrow, |
| parquet_path_to_id_mapping, |
| schema_to_pyarrow, |
| write_file, |
| ) |
| from pyiceberg.manifest import DataFile, DataFileContent, FileFormat |
| from pyiceberg.partitioning import PartitionField, PartitionSpec |
| from pyiceberg.schema import Schema, make_compatible_name, visit |
| from pyiceberg.table import FileScanTask, TableProperties |
| from pyiceberg.table.metadata import TableMetadataV2 |
| from pyiceberg.table.name_mapping import create_mapping_from_schema |
| from pyiceberg.transforms import HourTransform, IdentityTransform |
| from pyiceberg.typedef import UTF8, Properties, Record, TableVersion |
| from pyiceberg.types import ( |
| BinaryType, |
| BooleanType, |
| DateType, |
| DecimalType, |
| DoubleType, |
| FixedType, |
| FloatType, |
| GeographyType, |
| GeometryType, |
| IntegerType, |
| ListType, |
| LongType, |
| MapType, |
| NestedField, |
| PrimitiveType, |
| StringType, |
| StructType, |
| TimestampNanoType, |
| TimestampType, |
| TimestamptzType, |
| TimeType, |
| ) |
| from tests.catalog.test_base import InMemoryCatalog |
| from tests.conftest import UNIFIED_AWS_SESSION_PROPERTIES |
| |
| skip_if_pyarrow_too_old = pytest.mark.skipif( |
| version.parse(pyarrow.__version__) < version.parse("20.0.0"), |
| reason="Requires pyarrow version >= 20.0.0", |
| ) |
| |
| |
| def test_pyarrow_infer_local_fs_from_path() -> None: |
| """Test path with `file` scheme and no scheme both use LocalFileSystem""" |
| assert isinstance(PyArrowFileIO().new_output("file://tmp/warehouse")._filesystem, LocalFileSystem) |
| assert isinstance(PyArrowFileIO().new_output("/tmp/warehouse")._filesystem, LocalFileSystem) |
| |
| |
| def test_pyarrow_local_fs_can_create_path_without_parent_dir() -> None: |
| """Test LocalFileSystem can create path without first creating the parent directories""" |
| with tempfile.TemporaryDirectory() as tmpdirname: |
| file_path = f"{tmpdirname}/foo/bar/baz.txt" |
| output_file = PyArrowFileIO().new_output(file_path) |
| parent_path = os.path.dirname(file_path) |
| assert output_file._filesystem.get_file_info(parent_path).type == FileType.NotFound |
| try: |
| with output_file.create() as f: |
| f.write(b"foo") |
| except Exception: |
| pytest.fail("Failed to write to file without parent directory") |
| |
| |
| def test_pyarrow_input_file() -> None: |
| """Test reading a file using PyArrowFile""" |
| |
| with tempfile.TemporaryDirectory() as tmpdirname: |
| file_location = os.path.join(tmpdirname, "foo.txt") |
| with open(file_location, "wb") as f: |
| f.write(b"foo") |
| |
| # Confirm that the file initially exists |
| assert os.path.exists(file_location) |
| |
| # Instantiate the input file |
| absolute_file_location = os.path.abspath(file_location) |
| input_file = PyArrowFileIO().new_input(location=f"{absolute_file_location}") |
| |
| # Test opening and reading the file |
| r = input_file.open(seekable=False) |
| assert isinstance(r, InputStream) # Test that the file object abides by the InputStream protocol |
| data = r.read() |
| assert data == b"foo" |
| assert len(input_file) == 3 |
| with pytest.raises(OSError) as exc_info: |
| r.seek(0, 0) |
| assert "only valid on seekable files" in str(exc_info.value) |
| |
| |
| def test_pyarrow_input_file_seekable() -> None: |
| """Test reading a file using PyArrowFile""" |
| |
| with tempfile.TemporaryDirectory() as tmpdirname: |
| file_location = os.path.join(tmpdirname, "foo.txt") |
| with open(file_location, "wb") as f: |
| f.write(b"foo") |
| |
| # Confirm that the file initially exists |
| assert os.path.exists(file_location) |
| |
| # Instantiate the input file |
| absolute_file_location = os.path.abspath(file_location) |
| input_file = PyArrowFileIO().new_input(location=f"{absolute_file_location}") |
| |
| # Test opening and reading the file |
| r = input_file.open(seekable=True) |
| assert isinstance(r, InputStream) # Test that the file object abides by the InputStream protocol |
| data = r.read() |
| assert data == b"foo" |
| assert len(input_file) == 3 |
| r.seek(0, 0) |
| data = r.read() |
| assert data == b"foo" |
| assert len(input_file) == 3 |
| |
| |
| def test_pyarrow_output_file() -> None: |
| """Test writing a file using PyArrowFile""" |
| |
| with tempfile.TemporaryDirectory() as tmpdirname: |
| file_location = os.path.join(tmpdirname, "foo.txt") |
| |
| # Instantiate the output file |
| absolute_file_location = os.path.abspath(file_location) |
| output_file = PyArrowFileIO().new_output(location=f"{absolute_file_location}") |
| |
| # Create the output file and write to it |
| f = output_file.create() |
| assert isinstance(f, OutputStream) # Test that the file object abides by the OutputStream protocol |
| f.write(b"foo") |
| |
| # Confirm that bytes were written |
| with open(file_location, "rb") as f: |
| assert f.read() == b"foo" |
| |
| assert len(output_file) == 3 |
| |
| |
| def test_pyarrow_invalid_scheme() -> None: |
| """Test that a ValueError is raised if a location is provided with an invalid scheme""" |
| |
| with pytest.raises(ValueError) as exc_info: |
| PyArrowFileIO().new_input("foo://bar/baz.txt") |
| |
| assert "Unrecognized filesystem type in URI" in str(exc_info.value) |
| |
| with pytest.raises(ValueError) as exc_info: |
| PyArrowFileIO().new_output("foo://bar/baz.txt") |
| |
| assert "Unrecognized filesystem type in URI" in str(exc_info.value) |
| |
| |
| def test_pyarrow_violating_input_stream_protocol() -> None: |
| """Test that a TypeError is raised if an input file is provided that violates the InputStream protocol""" |
| |
| # Missing seek, tell, closed, and close |
| input_file_mock = MagicMock(spec=["read"]) |
| |
| # Create a mocked filesystem that returns input_file_mock |
| filesystem_mock = MagicMock() |
| filesystem_mock.open_input_file.return_value = input_file_mock |
| |
| input_file = PyArrowFile("foo.txt", path="foo.txt", fs=filesystem_mock) |
| |
| f = input_file.open() |
| assert not isinstance(f, InputStream) |
| |
| |
| def test_pyarrow_violating_output_stream_protocol() -> None: |
| """Test that a TypeError is raised if an output stream is provided that violates the OutputStream protocol""" |
| |
| # Missing closed, and close |
| output_file_mock = MagicMock(spec=["write", "exists"]) |
| output_file_mock.exists.return_value = False |
| |
| file_info_mock = MagicMock() |
| file_info_mock.type = FileType.NotFound |
| |
| # Create a mocked filesystem that returns output_file_mock |
| filesystem_mock = MagicMock() |
| filesystem_mock.open_output_stream.return_value = output_file_mock |
| filesystem_mock.get_file_info.return_value = file_info_mock |
| |
| output_file = PyArrowFile("foo.txt", path="foo.txt", fs=filesystem_mock) |
| |
| f = output_file.create() |
| |
| assert not isinstance(f, OutputStream) |
| |
| |
| def test_raise_on_opening_a_local_file_not_found() -> None: |
| """Test that a PyArrowFile raises appropriately when a local file is not found""" |
| |
| with tempfile.TemporaryDirectory() as tmpdirname: |
| file_location = os.path.join(tmpdirname, "foo.txt") |
| f = PyArrowFileIO().new_input(file_location) |
| |
| with pytest.raises(FileNotFoundError) as exc_info: |
| f.open() |
| |
| assert "[Errno 2] Failed to open local file" in str(exc_info.value) |
| |
| |
| def test_raise_on_opening_an_s3_file_no_permission() -> None: |
| """Test that opening a PyArrowFile raises a PermissionError when the pyarrow error includes 'AWS Error [code 15]'""" |
| |
| s3fs_mock = MagicMock() |
| s3fs_mock.open_input_file.side_effect = OSError("AWS Error [code 15]") |
| |
| f = PyArrowFile("s3://foo/bar.txt", path="foo/bar.txt", fs=s3fs_mock) |
| |
| with pytest.raises(PermissionError) as exc_info: |
| f.open() |
| |
| assert "Cannot open file, access denied:" in str(exc_info.value) |
| |
| |
| def test_raise_on_opening_an_s3_file_not_found() -> None: |
| """Test that a PyArrowFile raises a FileNotFoundError when the pyarrow error includes 'Path does not exist'""" |
| |
| s3fs_mock = MagicMock() |
| s3fs_mock.open_input_file.side_effect = OSError("Path does not exist") |
| |
| f = PyArrowFile("s3://foo/bar.txt", path="foo/bar.txt", fs=s3fs_mock) |
| |
| with pytest.raises(FileNotFoundError) as exc_info: |
| f.open() |
| |
| assert "Cannot open file, does not exist:" in str(exc_info.value) |
| |
| |
| @patch("pyiceberg.io.pyarrow.PyArrowFile.exists", return_value=False) |
| def test_raise_on_creating_an_s3_file_no_permission(_: Any) -> None: |
| """Test that creating a PyArrowFile raises a PermissionError when the pyarrow error includes 'AWS Error [code 15]'""" |
| |
| s3fs_mock = MagicMock() |
| s3fs_mock.open_output_stream.side_effect = OSError("AWS Error [code 15]") |
| |
| f = PyArrowFile("s3://foo/bar.txt", path="foo/bar.txt", fs=s3fs_mock) |
| |
| with pytest.raises(PermissionError) as exc_info: |
| f.create() |
| |
| assert "Cannot create file, access denied:" in str(exc_info.value) |
| |
| |
| def test_deleting_s3_file_no_permission() -> None: |
| """Test that a PyArrowFile raises a PermissionError when the pyarrow OSError includes 'AWS Error [code 15]'""" |
| |
| s3fs_mock = MagicMock() |
| s3fs_mock.delete_file.side_effect = OSError("AWS Error [code 15]") |
| |
| with patch.object(PyArrowFileIO, "_initialize_fs") as submocked: |
| submocked.return_value = s3fs_mock |
| |
| with pytest.raises(PermissionError) as exc_info: |
| PyArrowFileIO().delete("s3://foo/bar.txt") |
| |
| assert "Cannot delete file, access denied:" in str(exc_info.value) |
| |
| |
| def test_deleting_s3_file_not_found() -> None: |
| """Test that a PyArrowFile raises a PermissionError when the pyarrow error includes 'AWS Error [code 15]'""" |
| |
| s3fs_mock = MagicMock() |
| s3fs_mock.delete_file.side_effect = OSError("Path does not exist") |
| |
| with patch.object(PyArrowFileIO, "_initialize_fs") as submocked: |
| submocked.return_value = s3fs_mock |
| |
| with pytest.raises(FileNotFoundError) as exc_info: |
| PyArrowFileIO().delete("s3://foo/bar.txt") |
| |
| assert "Cannot delete file, does not exist:" in str(exc_info.value) |
| |
| |
| def test_deleting_hdfs_file_not_found() -> None: |
| """Test that a PyArrowFile raises a PermissionError when the pyarrow error includes 'No such file or directory'""" |
| |
| hdfs_mock = MagicMock() |
| hdfs_mock.delete_file.side_effect = OSError("Path does not exist") |
| |
| with patch.object(PyArrowFileIO, "_initialize_fs") as submocked: |
| submocked.return_value = hdfs_mock |
| |
| with pytest.raises(FileNotFoundError) as exc_info: |
| PyArrowFileIO().delete("hdfs://foo/bar.txt") |
| |
| assert "Cannot delete file, does not exist:" in str(exc_info.value) |
| |
| |
| def test_pyarrow_s3_session_properties() -> None: |
| session_properties: Properties = { |
| "s3.endpoint": "http://localhost:9000", |
| "s3.access-key-id": "admin", |
| "s3.secret-access-key": "password", |
| "s3.region": "us-east-1", |
| "s3.session-token": "s3.session-token", |
| **UNIFIED_AWS_SESSION_PROPERTIES, |
| } |
| |
| with patch("pyarrow.fs.S3FileSystem") as mock_s3fs, patch("pyarrow.fs.resolve_s3_region") as mock_s3_region_resolver: |
| s3_fileio = PyArrowFileIO(properties=session_properties) |
| filename = str(uuid.uuid4()) |
| |
| # Mock `resolve_s3_region` to prevent from the location used resolving to a different s3 region |
| mock_s3_region_resolver.side_effect = OSError("S3 bucket is not found") |
| s3_fileio.new_input(location=f"s3://warehouse/{filename}") |
| |
| mock_s3fs.assert_called_with( |
| endpoint_override="http://localhost:9000", |
| access_key="admin", |
| secret_key="password", |
| region="us-east-1", |
| session_token="s3.session-token", |
| ) |
| |
| |
| def test_pyarrow_s3_session_properties_with_anonymous() -> None: |
| session_properties: Properties = { |
| "s3.anonymous": "true", |
| "s3.endpoint": "http://localhost:9000", |
| "s3.access-key-id": "admin", |
| "s3.secret-access-key": "password", |
| "s3.region": "us-east-1", |
| "s3.session-token": "s3.session-token", |
| **UNIFIED_AWS_SESSION_PROPERTIES, |
| } |
| |
| with patch("pyarrow.fs.S3FileSystem") as mock_s3fs, patch("pyarrow.fs.resolve_s3_region") as mock_s3_region_resolver: |
| s3_fileio = PyArrowFileIO(properties=session_properties) |
| filename = str(uuid.uuid4()) |
| |
| # Mock `resolve_s3_region` to prevent from the location used resolving to a different s3 region |
| mock_s3_region_resolver.side_effect = OSError("S3 bucket is not found") |
| s3_fileio.new_input(location=f"s3://warehouse/{filename}") |
| |
| mock_s3fs.assert_called_with( |
| anonymous=True, |
| endpoint_override="http://localhost:9000", |
| access_key="admin", |
| secret_key="password", |
| region="us-east-1", |
| session_token="s3.session-token", |
| ) |
| |
| |
| def test_pyarrow_unified_session_properties() -> None: |
| session_properties: Properties = { |
| "s3.endpoint": "http://localhost:9000", |
| **UNIFIED_AWS_SESSION_PROPERTIES, |
| } |
| |
| with patch("pyarrow.fs.S3FileSystem") as mock_s3fs, patch("pyarrow.fs.resolve_s3_region") as mock_s3_region_resolver: |
| s3_fileio = PyArrowFileIO(properties=session_properties) |
| filename = str(uuid.uuid4()) |
| |
| mock_s3_region_resolver.return_value = "client.region" |
| s3_fileio.new_input(location=f"s3://warehouse/{filename}") |
| |
| mock_s3fs.assert_called_with( |
| endpoint_override="http://localhost:9000", |
| access_key="client.access-key-id", |
| secret_key="client.secret-access-key", |
| region="client.region", |
| session_token="client.session-token", |
| ) |
| |
| |
| def test_schema_to_pyarrow_schema_include_field_ids(table_schema_nested: Schema) -> None: |
| actual = schema_to_pyarrow(table_schema_nested) |
| expected = """foo: large_string |
| -- field metadata -- |
| PARQUET:field_id: '1' |
| bar: int32 not null |
| -- field metadata -- |
| PARQUET:field_id: '2' |
| baz: bool |
| -- field metadata -- |
| PARQUET:field_id: '3' |
| qux: large_list<element: large_string not null> not null |
| child 0, element: large_string not null |
| -- field metadata -- |
| PARQUET:field_id: '5' |
| -- field metadata -- |
| PARQUET:field_id: '4' |
| quux: map<large_string, map<large_string, int32>> not null |
| child 0, entries: struct<key: large_string not null, value: map<large_string, int32> not null> not null |
| child 0, key: large_string not null |
| -- field metadata -- |
| PARQUET:field_id: '7' |
| child 1, value: map<large_string, int32> not null |
| child 0, entries: struct<key: large_string not null, value: int32 not null> not null |
| child 0, key: large_string not null |
| -- field metadata -- |
| PARQUET:field_id: '9' |
| child 1, value: int32 not null |
| -- field metadata -- |
| PARQUET:field_id: '10' |
| -- field metadata -- |
| PARQUET:field_id: '8' |
| -- field metadata -- |
| PARQUET:field_id: '6' |
| location: large_list<element: struct<latitude: float, longitude: float> not null> not null |
| child 0, element: struct<latitude: float, longitude: float> not null |
| child 0, latitude: float |
| -- field metadata -- |
| PARQUET:field_id: '13' |
| child 1, longitude: float |
| -- field metadata -- |
| PARQUET:field_id: '14' |
| -- field metadata -- |
| PARQUET:field_id: '12' |
| -- field metadata -- |
| PARQUET:field_id: '11' |
| person: struct<name: large_string, age: int32 not null> |
| child 0, name: large_string |
| -- field metadata -- |
| PARQUET:field_id: '16' |
| child 1, age: int32 not null |
| -- field metadata -- |
| PARQUET:field_id: '17' |
| -- field metadata -- |
| PARQUET:field_id: '15'""" |
| assert repr(actual) == expected |
| |
| |
| def test_schema_to_pyarrow_schema_exclude_field_ids(table_schema_nested: Schema) -> None: |
| actual = schema_to_pyarrow(table_schema_nested, include_field_ids=False) |
| expected = """foo: large_string |
| bar: int32 not null |
| baz: bool |
| qux: large_list<element: large_string not null> not null |
| child 0, element: large_string not null |
| quux: map<large_string, map<large_string, int32>> not null |
| child 0, entries: struct<key: large_string not null, value: map<large_string, int32> not null> not null |
| child 0, key: large_string not null |
| child 1, value: map<large_string, int32> not null |
| child 0, entries: struct<key: large_string not null, value: int32 not null> not null |
| child 0, key: large_string not null |
| child 1, value: int32 not null |
| location: large_list<element: struct<latitude: float, longitude: float> not null> not null |
| child 0, element: struct<latitude: float, longitude: float> not null |
| child 0, latitude: float |
| child 1, longitude: float |
| person: struct<name: large_string, age: int32 not null> |
| child 0, name: large_string |
| child 1, age: int32 not null""" |
| assert repr(actual) == expected |
| |
| |
| def test_fixed_type_to_pyarrow() -> None: |
| length = 22 |
| iceberg_type = FixedType(length) |
| assert visit(iceberg_type, _ConvertToArrowSchema()) == pa.binary(length) |
| |
| |
| def test_decimal_type_to_pyarrow() -> None: |
| precision = 25 |
| scale = 19 |
| iceberg_type = DecimalType(precision, scale) |
| assert visit(iceberg_type, _ConvertToArrowSchema()) == pa.decimal128(precision, scale) |
| |
| |
| def test_boolean_type_to_pyarrow() -> None: |
| iceberg_type = BooleanType() |
| assert visit(iceberg_type, _ConvertToArrowSchema()) == pa.bool_() |
| |
| |
| def test_integer_type_to_pyarrow() -> None: |
| iceberg_type = IntegerType() |
| assert visit(iceberg_type, _ConvertToArrowSchema()) == pa.int32() |
| |
| |
| def test_long_type_to_pyarrow() -> None: |
| iceberg_type = LongType() |
| assert visit(iceberg_type, _ConvertToArrowSchema()) == pa.int64() |
| |
| |
| def test_float_type_to_pyarrow() -> None: |
| iceberg_type = FloatType() |
| assert visit(iceberg_type, _ConvertToArrowSchema()) == pa.float32() |
| |
| |
| def test_double_type_to_pyarrow() -> None: |
| iceberg_type = DoubleType() |
| assert visit(iceberg_type, _ConvertToArrowSchema()) == pa.float64() |
| |
| |
| def test_date_type_to_pyarrow() -> None: |
| iceberg_type = DateType() |
| assert visit(iceberg_type, _ConvertToArrowSchema()) == pa.date32() |
| |
| |
| def test_time_type_to_pyarrow() -> None: |
| iceberg_type = TimeType() |
| assert visit(iceberg_type, _ConvertToArrowSchema()) == pa.time64("us") |
| |
| |
| def test_timestamp_type_to_pyarrow() -> None: |
| iceberg_type = TimestampType() |
| assert visit(iceberg_type, _ConvertToArrowSchema()) == pa.timestamp(unit="us") |
| |
| |
| def test_timestamptz_type_to_pyarrow() -> None: |
| iceberg_type = TimestamptzType() |
| assert visit(iceberg_type, _ConvertToArrowSchema()) == pa.timestamp(unit="us", tz="UTC") |
| |
| |
| def test_string_type_to_pyarrow() -> None: |
| iceberg_type = StringType() |
| assert visit(iceberg_type, _ConvertToArrowSchema()) == pa.large_string() |
| |
| |
| def test_binary_type_to_pyarrow() -> None: |
| iceberg_type = BinaryType() |
| assert visit(iceberg_type, _ConvertToArrowSchema()) == pa.large_binary() |
| |
| |
| def test_geometry_type_to_pyarrow_without_geoarrow() -> None: |
| """Test geometry type falls back to large_binary when geoarrow is not available.""" |
| import sys |
| |
| iceberg_type = GeometryType() |
| |
| # Remove geoarrow from sys.modules if present and block re-import |
| saved_modules = {} |
| for mod_name in list(sys.modules.keys()): |
| if mod_name.startswith("geoarrow"): |
| saved_modules[mod_name] = sys.modules.pop(mod_name) |
| |
| import builtins |
| |
| original_import = builtins.__import__ |
| |
| def mock_import(name: str, *args: Any, **kwargs: Any) -> Any: |
| if name.startswith("geoarrow"): |
| raise ImportError(f"No module named '{name}'") |
| return original_import(name, *args, **kwargs) |
| |
| try: |
| builtins.__import__ = mock_import |
| result = visit(iceberg_type, _ConvertToArrowSchema()) |
| assert result == pa.large_binary() |
| finally: |
| builtins.__import__ = original_import |
| sys.modules.update(saved_modules) |
| |
| |
| def test_geography_type_to_pyarrow_without_geoarrow() -> None: |
| """Test geography type falls back to large_binary when geoarrow is not available.""" |
| import sys |
| |
| iceberg_type = GeographyType() |
| |
| # Remove geoarrow from sys.modules if present and block re-import |
| saved_modules = {} |
| for mod_name in list(sys.modules.keys()): |
| if mod_name.startswith("geoarrow"): |
| saved_modules[mod_name] = sys.modules.pop(mod_name) |
| |
| import builtins |
| |
| original_import = builtins.__import__ |
| |
| def mock_import(name: str, *args: Any, **kwargs: Any) -> Any: |
| if name.startswith("geoarrow"): |
| raise ImportError(f"No module named '{name}'") |
| return original_import(name, *args, **kwargs) |
| |
| try: |
| builtins.__import__ = mock_import |
| result = visit(iceberg_type, _ConvertToArrowSchema()) |
| assert result == pa.large_binary() |
| finally: |
| builtins.__import__ = original_import |
| sys.modules.update(saved_modules) |
| |
| |
| def test_geometry_type_to_pyarrow_with_geoarrow() -> None: |
| """Test geometry type uses geoarrow WKB extension type when available.""" |
| pytest.importorskip("geoarrow.pyarrow") |
| import geoarrow.pyarrow as ga |
| |
| # Test default CRS |
| iceberg_type = GeometryType() |
| result = visit(iceberg_type, _ConvertToArrowSchema()) |
| expected = ga.wkb().with_crs("OGC:CRS84") |
| assert result == expected |
| |
| # Test custom CRS |
| iceberg_type_custom = GeometryType("EPSG:4326") |
| result_custom = visit(iceberg_type_custom, _ConvertToArrowSchema()) |
| expected_custom = ga.wkb().with_crs("EPSG:4326") |
| assert result_custom == expected_custom |
| |
| |
| def test_geography_type_to_pyarrow_with_geoarrow() -> None: |
| """Test geography type uses geoarrow WKB extension type with edge type when available.""" |
| pytest.importorskip("geoarrow.pyarrow") |
| import geoarrow.pyarrow as ga |
| |
| # Test default (spherical algorithm) |
| iceberg_type = GeographyType() |
| result = visit(iceberg_type, _ConvertToArrowSchema()) |
| expected = ga.wkb().with_crs("OGC:CRS84").with_edge_type(ga.EdgeType.SPHERICAL) |
| assert result == expected |
| |
| # Test custom CRS with spherical |
| iceberg_type_custom = GeographyType("EPSG:4326", "spherical") |
| result_custom = visit(iceberg_type_custom, _ConvertToArrowSchema()) |
| expected_custom = ga.wkb().with_crs("EPSG:4326").with_edge_type(ga.EdgeType.SPHERICAL) |
| assert result_custom == expected_custom |
| |
| # Test planar algorithm (no edge type set, uses default) |
| iceberg_type_planar = GeographyType("OGC:CRS84", "planar") |
| result_planar = visit(iceberg_type_planar, _ConvertToArrowSchema()) |
| expected_planar = ga.wkb().with_crs("OGC:CRS84") |
| assert result_planar == expected_planar |
| |
| |
| def test_struct_type_to_pyarrow(table_schema_simple: Schema) -> None: |
| expected = pa.struct( |
| [ |
| pa.field("foo", pa.large_string(), nullable=True, metadata={"field_id": "1"}), |
| pa.field("bar", pa.int32(), nullable=False, metadata={"field_id": "2"}), |
| pa.field("baz", pa.bool_(), nullable=True, metadata={"field_id": "3"}), |
| ] |
| ) |
| assert visit(table_schema_simple.as_struct(), _ConvertToArrowSchema()) == expected |
| |
| |
| def test_map_type_to_pyarrow() -> None: |
| iceberg_map = MapType( |
| key_id=1, |
| key_type=IntegerType(), |
| value_id=2, |
| value_type=StringType(), |
| value_required=True, |
| ) |
| assert visit(iceberg_map, _ConvertToArrowSchema()) == pa.map_( |
| pa.field("key", pa.int32(), nullable=False, metadata={"field_id": "1"}), |
| pa.field("value", pa.large_string(), nullable=False, metadata={"field_id": "2"}), |
| ) |
| |
| |
| def test_list_type_to_pyarrow() -> None: |
| iceberg_map = ListType( |
| element_id=1, |
| element_type=IntegerType(), |
| element_required=True, |
| ) |
| assert visit(iceberg_map, _ConvertToArrowSchema()) == pa.large_list( |
| pa.field("element", pa.int32(), nullable=False, metadata={"field_id": "1"}) |
| ) |
| |
| |
| @pytest.fixture |
| def bound_reference(table_schema_simple: Schema) -> BoundReference: |
| return BoundReference(table_schema_simple.find_field(1), table_schema_simple.accessor_for_field(1)) |
| |
| |
| @pytest.fixture |
| def bound_double_reference() -> BoundReference: |
| schema = Schema( |
| NestedField(field_id=1, name="foo", field_type=DoubleType(), required=False), |
| schema_id=1, |
| identifier_field_ids=[], |
| ) |
| return BoundReference(schema.find_field(1), schema.accessor_for_field(1)) |
| |
| |
| def test_expr_is_null_to_pyarrow(bound_reference: BoundReference) -> None: |
| assert ( |
| repr(expression_to_pyarrow(BoundIsNull(bound_reference))) |
| == "<pyarrow.compute.Expression is_null(foo, {nan_is_null=false})>" |
| ) |
| |
| |
| def test_expr_not_null_to_pyarrow(bound_reference: BoundReference) -> None: |
| assert repr(expression_to_pyarrow(BoundNotNull(bound_reference))) == "<pyarrow.compute.Expression is_valid(foo)>" |
| |
| |
| def test_expr_is_nan_to_pyarrow(bound_double_reference: BoundReference) -> None: |
| assert repr(expression_to_pyarrow(BoundIsNaN(bound_double_reference))) == "<pyarrow.compute.Expression is_nan(foo)>" |
| |
| |
| def test_expr_not_nan_to_pyarrow(bound_double_reference: BoundReference) -> None: |
| assert repr(expression_to_pyarrow(BoundNotNaN(bound_double_reference))) == "<pyarrow.compute.Expression invert(is_nan(foo))>" |
| |
| |
| def test_expr_equal_to_pyarrow(bound_reference: BoundReference) -> None: |
| assert ( |
| repr(expression_to_pyarrow(BoundEqualTo(bound_reference, literal("hello")))) |
| == '<pyarrow.compute.Expression (foo == "hello")>' |
| ) |
| |
| |
| def test_expr_not_equal_to_pyarrow(bound_reference: BoundReference) -> None: |
| assert ( |
| repr(expression_to_pyarrow(BoundNotEqualTo(bound_reference, literal("hello")))) |
| == '<pyarrow.compute.Expression (foo != "hello")>' |
| ) |
| |
| |
| def test_expr_greater_than_or_equal_equal_to_pyarrow(bound_reference: BoundReference) -> None: |
| assert ( |
| repr(expression_to_pyarrow(BoundGreaterThanOrEqual(bound_reference, literal("hello")))) |
| == '<pyarrow.compute.Expression (foo >= "hello")>' |
| ) |
| |
| |
| def test_expr_greater_than_to_pyarrow(bound_reference: BoundReference) -> None: |
| assert ( |
| repr(expression_to_pyarrow(BoundGreaterThan(bound_reference, literal("hello")))) |
| == '<pyarrow.compute.Expression (foo > "hello")>' |
| ) |
| |
| |
| def test_expr_less_than_to_pyarrow(bound_reference: BoundReference) -> None: |
| assert ( |
| repr(expression_to_pyarrow(BoundLessThan(bound_reference, literal("hello")))) |
| == '<pyarrow.compute.Expression (foo < "hello")>' |
| ) |
| |
| |
| def test_expr_less_than_or_equal_to_pyarrow(bound_reference: BoundReference) -> None: |
| assert ( |
| repr(expression_to_pyarrow(BoundLessThanOrEqual(bound_reference, literal("hello")))) |
| == '<pyarrow.compute.Expression (foo <= "hello")>' |
| ) |
| |
| |
| def test_expr_in_to_pyarrow(bound_reference: BoundReference) -> None: |
| assert repr(expression_to_pyarrow(BoundIn(bound_reference, {literal("hello"), literal("world")}))) in ( |
| """<pyarrow.compute.Expression is_in(foo, {value_set=large_string:[ |
| "hello", |
| "world" |
| ], null_matching_behavior=MATCH})>""", |
| """<pyarrow.compute.Expression is_in(foo, {value_set=large_string:[ |
| "world", |
| "hello" |
| ], null_matching_behavior=MATCH})>""", |
| ) |
| |
| |
| def test_expr_not_in_to_pyarrow(bound_reference: BoundReference) -> None: |
| assert repr(expression_to_pyarrow(BoundNotIn(bound_reference, {literal("hello"), literal("world")}))) in ( |
| """<pyarrow.compute.Expression invert(is_in(foo, {value_set=large_string:[ |
| "hello", |
| "world" |
| ], null_matching_behavior=MATCH}))>""", |
| """<pyarrow.compute.Expression invert(is_in(foo, {value_set=large_string:[ |
| "world", |
| "hello" |
| ], null_matching_behavior=MATCH}))>""", |
| ) |
| |
| |
| def test_expr_starts_with_to_pyarrow(bound_reference: BoundReference) -> None: |
| assert ( |
| repr(expression_to_pyarrow(BoundStartsWith(bound_reference, literal("he")))) |
| == '<pyarrow.compute.Expression starts_with(foo, {pattern="he", ignore_case=false})>' |
| ) |
| |
| |
| def test_expr_not_starts_with_to_pyarrow(bound_reference: BoundReference) -> None: |
| assert ( |
| repr(expression_to_pyarrow(BoundNotStartsWith(bound_reference, literal("he")))) |
| == '<pyarrow.compute.Expression invert(starts_with(foo, {pattern="he", ignore_case=false}))>' |
| ) |
| |
| |
| def test_and_to_pyarrow(bound_reference: BoundReference) -> None: |
| assert ( |
| repr(expression_to_pyarrow(And(BoundEqualTo(bound_reference, literal("hello")), BoundIsNull(bound_reference)))) |
| == '<pyarrow.compute.Expression ((foo == "hello") and is_null(foo, {nan_is_null=false}))>' |
| ) |
| |
| |
| def test_or_to_pyarrow(bound_reference: BoundReference) -> None: |
| assert ( |
| repr(expression_to_pyarrow(Or(BoundEqualTo(bound_reference, literal("hello")), BoundIsNull(bound_reference)))) |
| == '<pyarrow.compute.Expression ((foo == "hello") or is_null(foo, {nan_is_null=false}))>' |
| ) |
| |
| |
| def test_not_to_pyarrow(bound_reference: BoundReference) -> None: |
| assert ( |
| repr(expression_to_pyarrow(Not(BoundEqualTo(bound_reference, literal("hello"))))) |
| == '<pyarrow.compute.Expression invert((foo == "hello"))>' |
| ) |
| |
| |
| def test_always_true_to_pyarrow(bound_reference: BoundReference) -> None: |
| assert repr(expression_to_pyarrow(AlwaysTrue())) == "<pyarrow.compute.Expression true>" |
| |
| |
| def test_always_false_to_pyarrow(bound_reference: BoundReference) -> None: |
| assert repr(expression_to_pyarrow(AlwaysFalse())) == "<pyarrow.compute.Expression false>" |
| |
| |
| @pytest.fixture |
| def schema_int() -> Schema: |
| return Schema(NestedField(1, "id", IntegerType(), required=False)) |
| |
| |
| @pytest.fixture |
| def schema_int_str() -> Schema: |
| return Schema(NestedField(1, "id", IntegerType(), required=False), NestedField(2, "data", StringType(), required=False)) |
| |
| |
| @pytest.fixture |
| def schema_str() -> Schema: |
| return Schema(NestedField(2, "data", StringType(), required=False)) |
| |
| |
| @pytest.fixture |
| def schema_long() -> Schema: |
| return Schema(NestedField(3, "id", LongType(), required=False)) |
| |
| |
| @pytest.fixture |
| def schema_struct() -> Schema: |
| return Schema( |
| NestedField( |
| 4, |
| "location", |
| StructType( |
| NestedField(41, "lat", DoubleType()), |
| NestedField(42, "long", DoubleType()), |
| ), |
| ) |
| ) |
| |
| |
| @pytest.fixture |
| def schema_list() -> Schema: |
| return Schema( |
| NestedField(5, "ids", ListType(51, IntegerType(), element_required=False), required=False), |
| ) |
| |
| |
| @pytest.fixture |
| def schema_list_of_structs() -> Schema: |
| return Schema( |
| NestedField( |
| 5, |
| "locations", |
| ListType( |
| 51, |
| StructType(NestedField(511, "lat", DoubleType()), NestedField(512, "long", DoubleType())), |
| element_required=False, |
| ), |
| required=False, |
| ), |
| ) |
| |
| |
| @pytest.fixture |
| def schema_map_of_structs() -> Schema: |
| return Schema( |
| NestedField( |
| 5, |
| "locations", |
| MapType( |
| key_id=51, |
| value_id=52, |
| key_type=StringType(), |
| value_type=StructType( |
| NestedField(511, "lat", DoubleType(), required=True), NestedField(512, "long", DoubleType(), required=True) |
| ), |
| element_required=False, |
| ), |
| required=False, |
| ), |
| ) |
| |
| |
| @pytest.fixture |
| def schema_map() -> Schema: |
| return Schema( |
| NestedField( |
| 5, |
| "properties", |
| MapType( |
| key_id=51, |
| key_type=StringType(), |
| value_id=52, |
| value_type=StringType(), |
| value_required=True, |
| ), |
| required=False, |
| ), |
| ) |
| |
| |
| def _write_table_to_file(filepath: str, schema: pa.Schema, table: pa.Table) -> str: |
| with pq.ParquetWriter(filepath, schema) as writer: |
| writer.write_table(table) |
| return filepath |
| |
| |
| def _write_table_to_data_file(filepath: str, schema: pa.Schema, table: pa.Table) -> DataFile: |
| filepath = _write_table_to_file(filepath, schema, table) |
| return DataFile.from_args( |
| content=DataFileContent.DATA, |
| file_path=filepath, |
| file_format=FileFormat.PARQUET, |
| partition={}, |
| record_count=len(table), |
| file_size_in_bytes=22, # This is not relevant for now |
| ) |
| |
| |
| @pytest.fixture |
| def file_int(schema_int: Schema, tmpdir: str) -> str: |
| pyarrow_schema = schema_to_pyarrow(schema_int, metadata={ICEBERG_SCHEMA: bytes(schema_int.model_dump_json(), UTF8)}) |
| return _write_table_to_file( |
| f"file:{tmpdir}/a.parquet", pyarrow_schema, pa.Table.from_arrays([pa.array([0, 1, 2])], schema=pyarrow_schema) |
| ) |
| |
| |
| @pytest.fixture |
| def file_int_str(schema_int_str: Schema, tmpdir: str) -> str: |
| pyarrow_schema = schema_to_pyarrow(schema_int_str, metadata={ICEBERG_SCHEMA: bytes(schema_int_str.model_dump_json(), UTF8)}) |
| return _write_table_to_file( |
| f"file:{tmpdir}/a.parquet", |
| pyarrow_schema, |
| pa.Table.from_arrays([pa.array([0, 1, 2]), pa.array(["0", "1", "2"])], schema=pyarrow_schema), |
| ) |
| |
| |
| @pytest.fixture |
| def file_string(schema_str: Schema, tmpdir: str) -> str: |
| pyarrow_schema = schema_to_pyarrow(schema_str, metadata={ICEBERG_SCHEMA: bytes(schema_str.model_dump_json(), UTF8)}) |
| return _write_table_to_file( |
| f"file:{tmpdir}/b.parquet", pyarrow_schema, pa.Table.from_arrays([pa.array(["0", "1", "2"])], schema=pyarrow_schema) |
| ) |
| |
| |
| @pytest.fixture |
| def file_long(schema_long: Schema, tmpdir: str) -> str: |
| pyarrow_schema = schema_to_pyarrow(schema_long, metadata={ICEBERG_SCHEMA: bytes(schema_long.model_dump_json(), UTF8)}) |
| return _write_table_to_file( |
| f"file:{tmpdir}/c.parquet", pyarrow_schema, pa.Table.from_arrays([pa.array([0, 1, 2])], schema=pyarrow_schema) |
| ) |
| |
| |
| @pytest.fixture |
| def file_struct(schema_struct: Schema, tmpdir: str) -> str: |
| pyarrow_schema = schema_to_pyarrow(schema_struct, metadata={ICEBERG_SCHEMA: bytes(schema_struct.model_dump_json(), UTF8)}) |
| return _write_table_to_file( |
| f"file:{tmpdir}/d.parquet", |
| pyarrow_schema, |
| pa.Table.from_pylist( |
| [ |
| {"location": {"lat": 52.371807, "long": 4.896029}}, |
| {"location": {"lat": 52.387386, "long": 4.646219}}, |
| {"location": {"lat": 52.078663, "long": 4.288788}}, |
| ], |
| schema=pyarrow_schema, |
| ), |
| ) |
| |
| |
| @pytest.fixture |
| def file_list(schema_list: Schema, tmpdir: str) -> str: |
| pyarrow_schema = schema_to_pyarrow(schema_list, metadata={ICEBERG_SCHEMA: bytes(schema_list.model_dump_json(), UTF8)}) |
| return _write_table_to_file( |
| f"file:{tmpdir}/e.parquet", |
| pyarrow_schema, |
| pa.Table.from_pylist( |
| [ |
| {"ids": list(range(1, 10))}, |
| {"ids": list(range(2, 20))}, |
| {"ids": list(range(3, 30))}, |
| ], |
| schema=pyarrow_schema, |
| ), |
| ) |
| |
| |
| @pytest.fixture |
| def file_list_of_structs(schema_list_of_structs: Schema, tmpdir: str) -> str: |
| pyarrow_schema = schema_to_pyarrow( |
| schema_list_of_structs, metadata={ICEBERG_SCHEMA: bytes(schema_list_of_structs.model_dump_json(), UTF8)} |
| ) |
| return _write_table_to_file( |
| f"file:{tmpdir}/e.parquet", |
| pyarrow_schema, |
| pa.Table.from_pylist( |
| [ |
| {"locations": [{"lat": 52.371807, "long": 4.896029}, {"lat": 52.387386, "long": 4.646219}]}, |
| {"locations": []}, |
| {"locations": [{"lat": 52.078663, "long": 4.288788}, {"lat": 52.387386, "long": 4.646219}]}, |
| ], |
| schema=pyarrow_schema, |
| ), |
| ) |
| |
| |
| @pytest.fixture |
| def file_map_of_structs(schema_map_of_structs: Schema, tmpdir: str) -> str: |
| pyarrow_schema = schema_to_pyarrow( |
| schema_map_of_structs, metadata={ICEBERG_SCHEMA: bytes(schema_map_of_structs.model_dump_json(), UTF8)} |
| ) |
| return _write_table_to_file( |
| f"file:{tmpdir}/e.parquet", |
| pyarrow_schema, |
| pa.Table.from_pylist( |
| [ |
| {"locations": {"1": {"lat": 52.371807, "long": 4.896029}, "2": {"lat": 52.387386, "long": 4.646219}}}, |
| {"locations": {}}, |
| {"locations": {"3": {"lat": 52.078663, "long": 4.288788}, "4": {"lat": 52.387386, "long": 4.646219}}}, |
| ], |
| schema=pyarrow_schema, |
| ), |
| ) |
| |
| |
| @pytest.fixture |
| def file_map(schema_map: Schema, tmpdir: str) -> str: |
| pyarrow_schema = schema_to_pyarrow(schema_map, metadata={ICEBERG_SCHEMA: bytes(schema_map.model_dump_json(), UTF8)}) |
| return _write_table_to_file( |
| f"file:{tmpdir}/e.parquet", |
| pyarrow_schema, |
| pa.Table.from_pylist( |
| [ |
| {"properties": [("a", "b")]}, |
| {"properties": [("c", "d")]}, |
| {"properties": [("e", "f"), ("g", "h")]}, |
| ], |
| schema=pyarrow_schema, |
| ), |
| ) |
| |
| |
| def project( |
| schema: Schema, files: list[str], expr: BooleanExpression | None = None, table_schema: Schema | None = None |
| ) -> pa.Table: |
| def _set_spec_id(datafile: DataFile) -> DataFile: |
| datafile.spec_id = 0 |
| return datafile |
| |
| return ArrowScan( |
| table_metadata=TableMetadataV2( |
| location="file://a/b/", |
| last_column_id=1, |
| format_version=2, |
| schemas=[table_schema or schema], |
| partition_specs=[PartitionSpec()], |
| ), |
| io=PyArrowFileIO(), |
| projected_schema=schema, |
| row_filter=expr or AlwaysTrue(), |
| case_sensitive=True, |
| ).to_table( |
| tasks=[ |
| FileScanTask( |
| _set_spec_id( |
| DataFile.from_args( |
| content=DataFileContent.DATA, |
| file_path=file, |
| file_format=FileFormat.PARQUET, |
| partition={}, |
| record_count=3, |
| file_size_in_bytes=3, |
| ) |
| ) |
| ) |
| for file in files |
| ] |
| ) |
| |
| |
| def test_projection_add_column(file_int: str) -> None: |
| schema = Schema( |
| # All new IDs |
| NestedField(10, "id", IntegerType(), required=False), |
| NestedField(20, "list", ListType(21, IntegerType(), element_required=False), required=False), |
| NestedField( |
| 30, |
| "map", |
| MapType(key_id=31, key_type=IntegerType(), value_id=32, value_type=StringType(), value_required=False), |
| required=False, |
| ), |
| NestedField( |
| 40, |
| "location", |
| StructType( |
| NestedField(41, "lat", DoubleType(), required=False), NestedField(42, "lon", DoubleType(), required=False) |
| ), |
| required=False, |
| ), |
| ) |
| result_table = project(schema, [file_int]) |
| |
| for col in result_table.columns: |
| assert len(col) == 3 |
| |
| for actual, expected in zip(result_table.columns[0], [None, None, None], strict=True): |
| assert actual.as_py() == expected |
| |
| for actual, expected in zip(result_table.columns[1], [None, None, None], strict=True): |
| assert actual.as_py() == expected |
| |
| for actual, expected in zip(result_table.columns[2], [None, None, None], strict=True): |
| assert actual.as_py() == expected |
| |
| for actual, expected in zip(result_table.columns[3], [None, None, None], strict=True): |
| assert actual.as_py() == expected |
| assert ( |
| repr(result_table.schema) |
| == """id: int32 |
| list: large_list<element: int32> |
| child 0, element: int32 |
| map: map<int32, large_string> |
| child 0, entries: struct<key: int32 not null, value: large_string> not null |
| child 0, key: int32 not null |
| child 1, value: large_string |
| location: struct<lat: double, lon: double> |
| child 0, lat: double |
| child 1, lon: double""" |
| ) |
| |
| |
| def test_read_list(schema_list: Schema, file_list: str) -> None: |
| result_table = project(schema_list, [file_list]) |
| |
| assert len(result_table.columns[0]) == 3 |
| for actual, expected in zip( |
| result_table.columns[0], [list(range(1, 10)), list(range(2, 20)), list(range(3, 30))], strict=True |
| ): |
| assert actual.as_py() == expected |
| |
| assert ( |
| repr(result_table.schema) |
| == """ids: large_list<element: int32> |
| child 0, element: int32""" |
| ) |
| |
| |
| def test_read_map(schema_map: Schema, file_map: str) -> None: |
| result_table = project(schema_map, [file_map]) |
| |
| assert len(result_table.columns[0]) == 3 |
| for actual, expected in zip(result_table.columns[0], [[("a", "b")], [("c", "d")], [("e", "f"), ("g", "h")]], strict=True): |
| assert actual.as_py() == expected |
| |
| assert ( |
| repr(result_table.schema) |
| == """properties: map<string, string> |
| child 0, entries: struct<key: string not null, value: string not null> not null |
| child 0, key: string not null |
| child 1, value: string not null""" |
| ) |
| |
| |
| def test_projection_add_column_struct(schema_int: Schema, file_int: str) -> None: |
| schema = Schema( |
| # A new ID |
| NestedField( |
| 2, |
| "id", |
| MapType(key_id=3, key_type=IntegerType(), value_id=4, value_type=StringType(), value_required=False), |
| required=False, |
| ) |
| ) |
| result_table = project(schema, [file_int]) |
| # Everything should be None |
| for r in result_table.columns[0]: |
| assert r.as_py() is None |
| assert ( |
| repr(result_table.schema) |
| == """id: map<int32, large_string> |
| child 0, entries: struct<key: int32 not null, value: large_string> not null |
| child 0, key: int32 not null |
| child 1, value: large_string""" |
| ) |
| |
| |
| def test_projection_add_column_struct_required(file_int: str) -> None: |
| schema = Schema( |
| # A new ID |
| NestedField( |
| 2, |
| "other_id", |
| IntegerType(), |
| required=True, |
| ) |
| ) |
| with pytest.raises(ResolveError) as exc_info: |
| _ = project(schema, [file_int]) |
| assert "Field is required, and could not be found in the file: 2: other_id: required int" in str(exc_info.value) |
| |
| |
| def test_projection_rename_column(schema_int: Schema, file_int: str) -> None: |
| schema = Schema( |
| # Reuses the id 1 |
| NestedField(1, "other_name", IntegerType(), required=True) |
| ) |
| result_table = project(schema, [file_int]) |
| assert len(result_table.columns[0]) == 3 |
| for actual, expected in zip(result_table.columns[0], [0, 1, 2], strict=True): |
| assert actual.as_py() == expected |
| |
| assert repr(result_table.schema) == "other_name: int32 not null" |
| |
| |
| def test_projection_concat_files(schema_int: Schema, file_int: str) -> None: |
| result_table = project(schema_int, [file_int, file_int]) |
| |
| for actual, expected in zip(result_table.columns[0], [0, 1, 2, 0, 1, 2], strict=True): |
| assert actual.as_py() == expected |
| assert len(result_table.columns[0]) == 6 |
| assert repr(result_table.schema) == "id: int32" |
| |
| |
| def test_identity_transform_column_projection(tmp_path: str, catalog: InMemoryCatalog) -> None: |
| # Test by adding a non-partitioned data file to a partitioned table, verifying partition value |
| # projection from manifest metadata. |
| # TODO: Update to use a data file created by writing data to an unpartitioned table once add_files supports field IDs. |
| # (context: https://github.com/apache/iceberg-python/pull/1443#discussion_r1901374875) |
| |
| schema = Schema( |
| NestedField(1, "other_field", StringType(), required=False), NestedField(2, "partition_id", IntegerType(), required=False) |
| ) |
| |
| partition_spec = PartitionSpec( |
| PartitionField(2, 1000, IdentityTransform(), "partition_id"), |
| ) |
| |
| catalog.create_namespace("default") |
| table = catalog.create_table( |
| "default.test_projection_partition", |
| schema=schema, |
| partition_spec=partition_spec, |
| properties={TableProperties.DEFAULT_NAME_MAPPING: create_mapping_from_schema(schema).model_dump_json()}, |
| ) |
| |
| file_data = pa.array(["foo", "bar", "baz"], type=pa.string()) |
| file_loc = f"{tmp_path}/test.parquet" |
| pq.write_table(pa.table([file_data], names=["other_field"]), file_loc) |
| |
| statistics = data_file_statistics_from_parquet_metadata( |
| parquet_metadata=pq.read_metadata(file_loc), |
| stats_columns=compute_statistics_plan(table.schema(), table.metadata.properties), |
| parquet_column_mapping=parquet_path_to_id_mapping(table.schema()), |
| ) |
| |
| unpartitioned_file = DataFile.from_args( |
| content=DataFileContent.DATA, |
| file_path=file_loc, |
| file_format=FileFormat.PARQUET, |
| # projected value |
| partition=Record(1), |
| file_size_in_bytes=os.path.getsize(file_loc), |
| sort_order_id=None, |
| spec_id=table.metadata.default_spec_id, |
| equality_ids=None, |
| key_metadata=None, |
| **statistics.to_serialized_dict(), |
| ) |
| |
| with table.transaction() as transaction: |
| with transaction.update_snapshot().overwrite() as update: |
| update.append_data_file(unpartitioned_file) |
| |
| schema = pa.schema([("other_field", pa.string()), ("partition_id", pa.int32())]) |
| assert table.scan().to_arrow() == pa.table( |
| { |
| "other_field": ["foo", "bar", "baz"], |
| "partition_id": [1, 1, 1], |
| }, |
| schema=schema, |
| ) |
| # Test that row filter works with partition value projection |
| assert table.scan(row_filter="partition_id = 1").to_arrow() == pa.table( |
| { |
| "other_field": ["foo", "bar", "baz"], |
| "partition_id": [1, 1, 1], |
| }, |
| schema=schema, |
| ) |
| # Test that row filter does not return any rows for a non-existing partition value |
| assert len(table.scan(row_filter="partition_id = -1").to_arrow()) == 0 |
| |
| |
| def test_identity_transform_columns_projection(tmp_path: str, catalog: InMemoryCatalog) -> None: |
| # Test by adding a non-partitioned data file to a multi-partitioned table, verifying partition value |
| # projection from manifest metadata. |
| # TODO: Update to use a data file created by writing data to an unpartitioned table once add_files supports field IDs. |
| # (context: https://github.com/apache/iceberg-python/pull/1443#discussion_r1901374875) |
| schema = Schema( |
| NestedField(1, "field_1", StringType(), required=False), |
| NestedField(2, "field_2", IntegerType(), required=False), |
| NestedField(3, "field_3", IntegerType(), required=False), |
| ) |
| |
| partition_spec = PartitionSpec( |
| PartitionField(2, 1000, IdentityTransform(), "field_2"), |
| PartitionField(3, 1001, IdentityTransform(), "field_3"), |
| ) |
| |
| catalog.create_namespace("default") |
| table = catalog.create_table( |
| "default.test_projection_partitions", |
| schema=schema, |
| partition_spec=partition_spec, |
| properties={TableProperties.DEFAULT_NAME_MAPPING: create_mapping_from_schema(schema).model_dump_json()}, |
| ) |
| |
| file_data = pa.array(["foo"], type=pa.string()) |
| file_loc = f"{tmp_path}/test.parquet" |
| pq.write_table(pa.table([file_data], names=["field_1"]), file_loc) |
| |
| statistics = data_file_statistics_from_parquet_metadata( |
| parquet_metadata=pq.read_metadata(file_loc), |
| stats_columns=compute_statistics_plan(table.schema(), table.metadata.properties), |
| parquet_column_mapping=parquet_path_to_id_mapping(table.schema()), |
| ) |
| |
| unpartitioned_file = DataFile.from_args( |
| content=DataFileContent.DATA, |
| file_path=file_loc, |
| file_format=FileFormat.PARQUET, |
| # projected value |
| partition=Record(2, 3), |
| file_size_in_bytes=os.path.getsize(file_loc), |
| sort_order_id=None, |
| spec_id=table.metadata.default_spec_id, |
| equality_ids=None, |
| key_metadata=None, |
| **statistics.to_serialized_dict(), |
| ) |
| |
| with table.transaction() as transaction: |
| with transaction.update_snapshot().overwrite() as update: |
| update.append_data_file(unpartitioned_file) |
| |
| assert ( |
| str(table.scan().to_arrow()) |
| == """pyarrow.Table |
| field_1: string |
| field_2: int32 |
| field_3: int32 |
| ---- |
| field_1: [["foo"]] |
| field_2: [[2]] |
| field_3: [[3]]""" |
| ) |
| |
| |
| @pytest.fixture |
| def catalog() -> InMemoryCatalog: |
| return InMemoryCatalog("test.in_memory.catalog", **{"test.key": "test.value"}) |
| |
| |
| def test_projection_filter(schema_int: Schema, file_int: str) -> None: |
| result_table = project(schema_int, [file_int], GreaterThan("id", 4)) |
| assert len(result_table.columns[0]) == 0 |
| assert repr(result_table.schema) == """id: int32""" |
| |
| |
| def test_projection_filter_renamed_column(file_int: str) -> None: |
| schema = Schema( |
| # Reuses the id 1 |
| NestedField(1, "other_id", IntegerType(), required=True) |
| ) |
| result_table = project(schema, [file_int], GreaterThan("other_id", 1)) |
| assert len(result_table.columns[0]) == 1 |
| assert repr(result_table.schema) == "other_id: int32 not null" |
| |
| |
| def test_projection_filter_add_column(schema_int: Schema, file_int: str, file_string: str) -> None: |
| """We have one file that has the column, and the other one doesn't""" |
| result_table = project(schema_int, [file_int, file_string]) |
| |
| for actual, expected in zip(result_table.columns[0], [0, 1, 2, None, None, None], strict=True): |
| assert actual.as_py() == expected |
| assert len(result_table.columns[0]) == 6 |
| assert repr(result_table.schema) == "id: int32" |
| |
| |
| def test_projection_filter_add_column_promote(file_int: str) -> None: |
| schema_long = Schema(NestedField(1, "id", LongType(), required=True)) |
| result_table = project(schema_long, [file_int]) |
| |
| for actual, expected in zip(result_table.columns[0], [0, 1, 2], strict=True): |
| assert actual.as_py() == expected |
| assert len(result_table.columns[0]) == 3 |
| assert repr(result_table.schema) == "id: int64 not null" |
| |
| |
| def test_projection_filter_add_column_demote(file_long: str) -> None: |
| schema_int = Schema(NestedField(3, "id", IntegerType())) |
| with pytest.raises(ResolveError) as exc_info: |
| _ = project(schema_int, [file_long]) |
| assert "Cannot promote long to int" in str(exc_info.value) |
| |
| |
| def test_projection_nested_struct_subset(file_struct: str) -> None: |
| schema = Schema( |
| NestedField( |
| 4, |
| "location", |
| StructType( |
| NestedField(41, "lat", DoubleType(), required=True), |
| # long is missing! |
| ), |
| required=True, |
| ) |
| ) |
| |
| result_table = project(schema, [file_struct]) |
| |
| for actual, expected in zip(result_table.columns[0], [52.371807, 52.387386, 52.078663], strict=True): |
| assert actual.as_py() == {"lat": expected} |
| |
| assert len(result_table.columns[0]) == 3 |
| assert ( |
| repr(result_table.schema) |
| == """location: struct<lat: double not null> not null |
| child 0, lat: double not null""" |
| ) |
| |
| |
| def test_projection_nested_new_field(file_struct: str) -> None: |
| schema = Schema( |
| NestedField( |
| 4, |
| "location", |
| StructType( |
| NestedField(43, "null", DoubleType(), required=False), # Whoa, this column doesn't exist in the file |
| ), |
| required=True, |
| ) |
| ) |
| |
| result_table = project(schema, [file_struct]) |
| |
| for actual, expected in zip(result_table.columns[0], [None, None, None], strict=True): |
| assert actual.as_py() == {"null": expected} |
| assert len(result_table.columns[0]) == 3 |
| assert ( |
| repr(result_table.schema) |
| == """location: struct<null: double> not null |
| child 0, null: double""" |
| ) |
| |
| |
| def test_projection_nested_struct(schema_struct: Schema, file_struct: str) -> None: |
| schema = Schema( |
| NestedField( |
| 4, |
| "location", |
| StructType( |
| NestedField(41, "lat", DoubleType(), required=False), |
| NestedField(43, "null", DoubleType(), required=False), |
| NestedField(42, "long", DoubleType(), required=False), |
| ), |
| required=True, |
| ) |
| ) |
| |
| result_table = project(schema, [file_struct]) |
| for actual, expected in zip( |
| result_table.columns[0], |
| [ |
| {"lat": 52.371807, "long": 4.896029, "null": None}, |
| {"lat": 52.387386, "long": 4.646219, "null": None}, |
| {"lat": 52.078663, "long": 4.288788, "null": None}, |
| ], |
| strict=True, |
| ): |
| assert actual.as_py() == expected |
| assert len(result_table.columns[0]) == 3 |
| assert ( |
| repr(result_table.schema) |
| == """location: struct<lat: double, null: double, long: double> not null |
| child 0, lat: double |
| child 1, null: double |
| child 2, long: double""" |
| ) |
| |
| |
| def test_projection_list_of_structs(schema_list_of_structs: Schema, file_list_of_structs: str) -> None: |
| schema = Schema( |
| NestedField( |
| 5, |
| "locations", |
| ListType( |
| 51, |
| StructType( |
| NestedField(511, "latitude", DoubleType(), required=True), |
| NestedField(512, "longitude", DoubleType(), required=True), |
| NestedField(513, "altitude", DoubleType(), required=False), |
| ), |
| element_required=False, |
| ), |
| required=False, |
| ), |
| ) |
| |
| result_table = project(schema, [file_list_of_structs]) |
| assert len(result_table.columns) == 1 |
| assert len(result_table.columns[0]) == 3 |
| results = [row.as_py() for row in result_table.columns[0]] |
| assert results == [ |
| [ |
| {"latitude": 52.371807, "longitude": 4.896029, "altitude": None}, |
| {"latitude": 52.387386, "longitude": 4.646219, "altitude": None}, |
| ], |
| [], |
| [ |
| {"latitude": 52.078663, "longitude": 4.288788, "altitude": None}, |
| {"latitude": 52.387386, "longitude": 4.646219, "altitude": None}, |
| ], |
| ] |
| assert ( |
| repr(result_table.schema) |
| == """locations: large_list<element: struct<latitude: double not null, longitude: double not null, altitude: double>> |
| child 0, element: struct<latitude: double not null, longitude: double not null, altitude: double> |
| child 0, latitude: double not null |
| child 1, longitude: double not null |
| child 2, altitude: double""" |
| ) |
| |
| |
| def test_projection_maps_of_structs(schema_map_of_structs: Schema, file_map_of_structs: str) -> None: |
| schema = Schema( |
| NestedField( |
| 5, |
| "locations", |
| MapType( |
| key_id=51, |
| value_id=52, |
| key_type=StringType(), |
| value_type=StructType( |
| NestedField(511, "latitude", DoubleType(), required=True), |
| NestedField(512, "longitude", DoubleType(), required=True), |
| NestedField(513, "altitude", DoubleType()), |
| ), |
| element_required=False, |
| ), |
| required=False, |
| ), |
| ) |
| |
| result_table = project(schema, [file_map_of_structs]) |
| assert len(result_table.columns) == 1 |
| assert len(result_table.columns[0]) == 3 |
| for actual, expected in zip( |
| result_table.columns[0], |
| [ |
| [ |
| ("1", {"latitude": 52.371807, "longitude": 4.896029, "altitude": None}), |
| ("2", {"latitude": 52.387386, "longitude": 4.646219, "altitude": None}), |
| ], |
| [], |
| [ |
| ("3", {"latitude": 52.078663, "longitude": 4.288788, "altitude": None}), |
| ("4", {"latitude": 52.387386, "longitude": 4.646219, "altitude": None}), |
| ], |
| ], |
| strict=True, |
| ): |
| assert actual.as_py() == expected |
| expected_schema_repr = ( |
| "locations: map<string, struct<latitude: double not null, " |
| "longitude: double not null, altitude: double>>\n" |
| " child 0, entries: struct<key: string not null, value: struct<latitude: double not null, " |
| "longitude: double not null, al (... 25 chars omitted) not null\n" |
| " child 0, key: string not null\n" |
| " child 1, value: struct<latitude: double not null, longitude: double not null, " |
| "altitude: double> not null\n" |
| " child 0, latitude: double not null\n" |
| " child 1, longitude: double not null\n" |
| " child 2, altitude: double" |
| ) |
| assert repr(result_table.schema) == expected_schema_repr |
| |
| |
| def test_projection_nested_struct_different_parent_id(file_struct: str) -> None: |
| schema = Schema( |
| NestedField( |
| 5, # 😱 this is 4 in the file, this will be fixed when projecting the file schema |
| "location", |
| StructType( |
| NestedField(41, "lat", DoubleType(), required=False), NestedField(42, "long", DoubleType(), required=False) |
| ), |
| required=False, |
| ) |
| ) |
| |
| result_table = project(schema, [file_struct]) |
| for actual, expected in zip(result_table.columns[0], [None, None, None], strict=True): |
| assert actual.as_py() == expected |
| assert len(result_table.columns[0]) == 3 |
| assert ( |
| repr(result_table.schema) |
| == """location: struct<lat: double, long: double> |
| child 0, lat: double |
| child 1, long: double""" |
| ) |
| |
| |
| def test_projection_filter_on_unprojected_field(schema_int_str: Schema, file_int_str: str) -> None: |
| schema = Schema(NestedField(1, "id", IntegerType(), required=True)) |
| |
| result_table = project(schema, [file_int_str], GreaterThan("data", "1"), schema_int_str) |
| |
| for actual, expected in zip(result_table.columns[0], [2], strict=True): |
| assert actual.as_py() == expected |
| assert len(result_table.columns[0]) == 1 |
| assert repr(result_table.schema) == "id: int32 not null" |
| |
| |
| def test_projection_filter_on_unknown_field(schema_int_str: Schema, file_int_str: str) -> None: |
| schema = Schema(NestedField(1, "id", IntegerType())) |
| |
| with pytest.raises(ValueError) as exc_info: |
| _ = project(schema, [file_int_str], GreaterThan("unknown_field", "1"), schema_int_str) |
| |
| assert "Could not find field with name unknown_field, case_sensitive=True" in str(exc_info.value) |
| |
| |
| @pytest.fixture(params=["parquet", "orc"]) |
| def deletes_file(tmp_path: str, request: pytest.FixtureRequest) -> str: |
| if request.param == "parquet": |
| example_task = request.getfixturevalue("example_task") |
| import pyarrow.parquet as pq |
| |
| write_func = pq.write_table |
| file_ext = "parquet" |
| else: # orc |
| example_task = request.getfixturevalue("example_task_orc") |
| import pyarrow.orc as orc |
| |
| write_func = orc.write_table |
| file_ext = "orc" |
| |
| path = example_task.file.file_path |
| table = pa.table({"file_path": [path, path, path], "pos": [1, 3, 5]}) |
| |
| deletes_file_path = f"{tmp_path}/deletes.{file_ext}" |
| write_func(table, deletes_file_path) |
| |
| return deletes_file_path |
| |
| |
| def test_read_deletes(deletes_file: str, request: pytest.FixtureRequest) -> None: |
| # Determine file format from the file extension |
| file_format = FileFormat.PARQUET if deletes_file.endswith(".parquet") else FileFormat.ORC |
| |
| # Get the appropriate example_task fixture based on file format |
| if file_format == FileFormat.PARQUET: |
| request.getfixturevalue("example_task") |
| else: |
| request.getfixturevalue("example_task_orc") |
| |
| deletes = _read_deletes(PyArrowFileIO(), DataFile.from_args(file_path=deletes_file, file_format=file_format)) |
| # Get the expected file path from the actual deletes keys since they might differ between formats |
| expected_file_path = list(deletes.keys())[0] |
| assert set(deletes.keys()) == {expected_file_path} |
| assert list(deletes.values())[0] == pa.chunked_array([[1, 3, 5]]) |
| |
| |
| def test_delete(deletes_file: str, request: pytest.FixtureRequest, table_schema_simple: Schema) -> None: |
| # Determine file format from the file extension |
| file_format = FileFormat.PARQUET if deletes_file.endswith(".parquet") else FileFormat.ORC |
| |
| # Get the appropriate example_task fixture based on file format |
| if file_format == FileFormat.PARQUET: |
| example_task = request.getfixturevalue("example_task") |
| else: |
| example_task = request.getfixturevalue("example_task_orc") |
| |
| metadata_location = "file://a/b/c.json" |
| example_task_with_delete = FileScanTask( |
| data_file=example_task.file, |
| delete_files={ |
| DataFile.from_args(content=DataFileContent.POSITION_DELETES, file_path=deletes_file, file_format=file_format) |
| }, |
| ) |
| with_deletes = ArrowScan( |
| table_metadata=TableMetadataV2( |
| location=metadata_location, |
| last_column_id=1, |
| format_version=2, |
| current_schema_id=1, |
| schemas=[table_schema_simple], |
| partition_specs=[PartitionSpec()], |
| ), |
| io=load_file_io(), |
| projected_schema=table_schema_simple, |
| row_filter=AlwaysTrue(), |
| ).to_table(tasks=[example_task_with_delete]) |
| |
| # ORC uses 'string' while Parquet uses 'large_string' for string columns |
| expected_foo_type = "string" if file_format == FileFormat.ORC else "large_string" |
| expected_str = f"""pyarrow.Table |
| foo: {expected_foo_type} |
| bar: int32 not null |
| baz: bool |
| ---- |
| foo: [["a","c"]] |
| bar: [[1,3]] |
| baz: [[true,null]]""" |
| |
| assert str(with_deletes) == expected_str |
| |
| |
| def test_delete_duplicates(deletes_file: str, request: pytest.FixtureRequest, table_schema_simple: Schema) -> None: |
| # Determine file format from the file extension |
| file_format = FileFormat.PARQUET if deletes_file.endswith(".parquet") else FileFormat.ORC |
| |
| # Get the appropriate example_task fixture based on file format |
| if file_format == FileFormat.PARQUET: |
| example_task = request.getfixturevalue("example_task") |
| else: |
| example_task = request.getfixturevalue("example_task_orc") |
| |
| metadata_location = "file://a/b/c.json" |
| example_task_with_delete = FileScanTask( |
| data_file=example_task.file, |
| delete_files={ |
| DataFile.from_args(content=DataFileContent.POSITION_DELETES, file_path=deletes_file, file_format=file_format), |
| DataFile.from_args(content=DataFileContent.POSITION_DELETES, file_path=deletes_file, file_format=file_format), |
| }, |
| ) |
| |
| with_deletes = ArrowScan( |
| table_metadata=TableMetadataV2( |
| location=metadata_location, |
| last_column_id=1, |
| format_version=2, |
| current_schema_id=1, |
| schemas=[table_schema_simple], |
| partition_specs=[PartitionSpec()], |
| ), |
| io=load_file_io(), |
| projected_schema=table_schema_simple, |
| row_filter=AlwaysTrue(), |
| ).to_table(tasks=[example_task_with_delete]) |
| |
| # ORC uses 'string' while Parquet uses 'large_string' for string columns |
| expected_foo_type = "string" if file_format == FileFormat.ORC else "large_string" |
| expected_str = f"""pyarrow.Table |
| foo: {expected_foo_type} |
| bar: int32 not null |
| baz: bool |
| ---- |
| foo: [["a","c"]] |
| bar: [[1,3]] |
| baz: [[true,null]]""" |
| |
| assert str(with_deletes) == expected_str |
| |
| |
| def test_pyarrow_wrap_fsspec(example_task: FileScanTask, table_schema_simple: Schema) -> None: |
| metadata_location = "file://a/b/c.json" |
| |
| projection = ArrowScan( |
| table_metadata=TableMetadataV2( |
| location=metadata_location, |
| last_column_id=1, |
| format_version=2, |
| current_schema_id=1, |
| schemas=[table_schema_simple], |
| partition_specs=[PartitionSpec()], |
| ), |
| io=load_file_io(properties={"py-io-impl": "pyiceberg.io.fsspec.FsspecFileIO"}, location=metadata_location), |
| case_sensitive=True, |
| projected_schema=table_schema_simple, |
| row_filter=AlwaysTrue(), |
| ).to_table(tasks=[example_task]) |
| |
| assert ( |
| str(projection) |
| == """pyarrow.Table |
| foo: large_string |
| bar: int32 not null |
| baz: bool |
| ---- |
| foo: [["a","b","c"]] |
| bar: [[1,2,3]] |
| baz: [[true,false,null]]""" |
| ) |
| |
| |
| @pytest.mark.gcs |
| def test_new_input_file_gcs(pyarrow_fileio_gcs: PyArrowFileIO) -> None: |
| """Test creating a new input file from a fsspec file-io""" |
| filename = str(uuid4()) |
| |
| input_file = pyarrow_fileio_gcs.new_input(f"gs://warehouse/{filename}") |
| |
| assert isinstance(input_file, PyArrowFile) |
| assert input_file.location == f"gs://warehouse/{filename}" |
| |
| |
| @pytest.mark.gcs |
| def test_new_output_file_gcs(pyarrow_fileio_gcs: PyArrowFileIO) -> None: |
| """Test creating a new output file from an fsspec file-io""" |
| filename = str(uuid4()) |
| |
| output_file = pyarrow_fileio_gcs.new_output(f"gs://warehouse/{filename}") |
| |
| assert isinstance(output_file, PyArrowFile) |
| assert output_file.location == f"gs://warehouse/{filename}" |
| |
| |
| @pytest.mark.gcs |
| @pytest.mark.skip(reason="Open issue on Arrow: https://github.com/apache/arrow/issues/36993") |
| def test_write_and_read_file_gcs(pyarrow_fileio_gcs: PyArrowFileIO) -> None: |
| """Test writing and reading a file using PyArrowFile""" |
| location = f"gs://warehouse/{uuid4()}.txt" |
| output_file = pyarrow_fileio_gcs.new_output(location=location) |
| with output_file.create() as f: |
| assert f.write(b"foo") == 3 |
| |
| assert output_file.exists() |
| |
| input_file = pyarrow_fileio_gcs.new_input(location=location) |
| with input_file.open() as f: |
| assert f.read() == b"foo" |
| |
| pyarrow_fileio_gcs.delete(input_file) |
| |
| |
| @pytest.mark.gcs |
| def test_getting_length_of_file_gcs(pyarrow_fileio_gcs: PyArrowFileIO) -> None: |
| """Test getting the length of PyArrowFile""" |
| filename = str(uuid4()) |
| |
| output_file = pyarrow_fileio_gcs.new_output(location=f"gs://warehouse/{filename}") |
| with output_file.create() as f: |
| f.write(b"foobar") |
| |
| assert len(output_file) == 6 |
| |
| input_file = pyarrow_fileio_gcs.new_input(location=f"gs://warehouse/{filename}") |
| assert len(input_file) == 6 |
| |
| pyarrow_fileio_gcs.delete(output_file) |
| |
| |
| @pytest.mark.gcs |
| @pytest.mark.skip(reason="Open issue on Arrow: https://github.com/apache/arrow/issues/36993") |
| def test_file_tell_gcs(pyarrow_fileio_gcs: PyArrowFileIO) -> None: |
| location = f"gs://warehouse/{uuid4()}" |
| |
| output_file = pyarrow_fileio_gcs.new_output(location=location) |
| with output_file.create() as write_file: |
| write_file.write(b"foobar") |
| |
| input_file = pyarrow_fileio_gcs.new_input(location=location) |
| with input_file.open() as f: |
| f.seek(0) |
| assert f.tell() == 0 |
| f.seek(1) |
| assert f.tell() == 1 |
| f.seek(3) |
| assert f.tell() == 3 |
| f.seek(0) |
| assert f.tell() == 0 |
| |
| |
| @pytest.mark.gcs |
| @pytest.mark.skip(reason="Open issue on Arrow: https://github.com/apache/arrow/issues/36993") |
| def test_read_specified_bytes_for_file_gcs(pyarrow_fileio_gcs: PyArrowFileIO) -> None: |
| location = f"gs://warehouse/{uuid4()}" |
| |
| output_file = pyarrow_fileio_gcs.new_output(location=location) |
| with output_file.create() as write_file: |
| write_file.write(b"foo") |
| |
| input_file = pyarrow_fileio_gcs.new_input(location=location) |
| with input_file.open() as f: |
| f.seek(0) |
| assert b"f" == f.read(1) |
| f.seek(0) |
| assert b"fo" == f.read(2) |
| f.seek(1) |
| assert b"o" == f.read(1) |
| f.seek(1) |
| assert b"oo" == f.read(2) |
| f.seek(0) |
| assert b"foo" == f.read(999) # test reading amount larger than entire content length |
| |
| pyarrow_fileio_gcs.delete(input_file) |
| |
| |
| @pytest.mark.gcs |
| @pytest.mark.skip(reason="Open issue on Arrow: https://github.com/apache/arrow/issues/36993") |
| def test_raise_on_opening_file_not_found_gcs(pyarrow_fileio_gcs: PyArrowFileIO) -> None: |
| """Test that PyArrowFile raises appropriately when the gcs file is not found""" |
| |
| filename = str(uuid4()) |
| input_file = pyarrow_fileio_gcs.new_input(location=f"gs://warehouse/{filename}") |
| with pytest.raises(FileNotFoundError) as exc_info: |
| input_file.open().read() |
| |
| assert filename in str(exc_info.value) |
| |
| |
| @pytest.mark.gcs |
| def test_checking_if_a_file_exists_gcs(pyarrow_fileio_gcs: PyArrowFileIO) -> None: |
| """Test checking if a file exists""" |
| non_existent_file = pyarrow_fileio_gcs.new_input(location="gs://warehouse/does-not-exist.txt") |
| assert not non_existent_file.exists() |
| |
| location = f"gs://warehouse/{uuid4()}" |
| output_file = pyarrow_fileio_gcs.new_output(location=location) |
| assert not output_file.exists() |
| with output_file.create() as f: |
| f.write(b"foo") |
| |
| existing_input_file = pyarrow_fileio_gcs.new_input(location=location) |
| assert existing_input_file.exists() |
| |
| existing_output_file = pyarrow_fileio_gcs.new_output(location=location) |
| assert existing_output_file.exists() |
| |
| pyarrow_fileio_gcs.delete(existing_output_file) |
| |
| |
| @pytest.mark.gcs |
| @pytest.mark.skip(reason="Open issue on Arrow: https://github.com/apache/arrow/issues/36993") |
| def test_closing_a_file_gcs(pyarrow_fileio_gcs: PyArrowFileIO) -> None: |
| """Test closing an output file and input file""" |
| filename = str(uuid4()) |
| output_file = pyarrow_fileio_gcs.new_output(location=f"gs://warehouse/{filename}") |
| with output_file.create() as write_file: |
| write_file.write(b"foo") |
| assert not write_file.closed # type: ignore |
| assert write_file.closed # type: ignore |
| |
| input_file = pyarrow_fileio_gcs.new_input(location=f"gs://warehouse/{filename}") |
| with input_file.open() as f: |
| assert not f.closed # type: ignore |
| assert f.closed # type: ignore |
| |
| pyarrow_fileio_gcs.delete(f"gs://warehouse/{filename}") |
| |
| |
| @pytest.mark.gcs |
| def test_converting_an_outputfile_to_an_inputfile_gcs(pyarrow_fileio_gcs: PyArrowFileIO) -> None: |
| """Test converting an output file to an input file""" |
| filename = str(uuid4()) |
| output_file = pyarrow_fileio_gcs.new_output(location=f"gs://warehouse/{filename}") |
| input_file = output_file.to_input_file() |
| assert input_file.location == output_file.location |
| |
| |
| @pytest.mark.gcs |
| @pytest.mark.skip(reason="Open issue on Arrow: https://github.com/apache/arrow/issues/36993") |
| def test_writing_avro_file_gcs(generated_manifest_entry_file: str, pyarrow_fileio_gcs: PyArrowFileIO) -> None: |
| """Test that bytes match when reading a local avro file, writing it using pyarrow file-io, and then reading it again""" |
| filename = str(uuid4()) |
| with PyArrowFileIO().new_input(location=generated_manifest_entry_file).open() as f: |
| b1 = f.read() |
| with pyarrow_fileio_gcs.new_output(location=f"gs://warehouse/{filename}").create() as out_f: |
| out_f.write(b1) |
| with pyarrow_fileio_gcs.new_input(location=f"gs://warehouse/{filename}").open() as in_f: |
| b2 = in_f.read() |
| assert b1 == b2 # Check that bytes of read from local avro file match bytes written to s3 |
| |
| pyarrow_fileio_gcs.delete(f"gs://warehouse/{filename}") |
| |
| |
| @pytest.mark.adls |
| @skip_if_pyarrow_too_old |
| def test_new_input_file_adls(pyarrow_fileio_adls: PyArrowFileIO, adls_scheme: str) -> None: |
| """Test creating a new input file from pyarrow file-io""" |
| filename = str(uuid4()) |
| |
| input_file = pyarrow_fileio_adls.new_input(f"{adls_scheme}://warehouse/{filename}") |
| |
| assert isinstance(input_file, PyArrowFile) |
| assert input_file.location == f"{adls_scheme}://warehouse/{filename}" |
| |
| |
| @pytest.mark.adls |
| @skip_if_pyarrow_too_old |
| def test_new_output_file_adls(pyarrow_fileio_adls: PyArrowFileIO, adls_scheme: str) -> None: |
| """Test creating a new output file from pyarrow file-io""" |
| filename = str(uuid4()) |
| |
| output_file = pyarrow_fileio_adls.new_output(f"{adls_scheme}://warehouse/{filename}") |
| |
| assert isinstance(output_file, PyArrowFile) |
| assert output_file.location == f"{adls_scheme}://warehouse/{filename}" |
| |
| |
| @pytest.mark.adls |
| @skip_if_pyarrow_too_old |
| def test_write_and_read_file_adls(pyarrow_fileio_adls: PyArrowFileIO, adls_scheme: str) -> None: |
| """Test writing and reading a file using PyArrowFile""" |
| location = f"{adls_scheme}://warehouse/{uuid4()}.txt" |
| output_file = pyarrow_fileio_adls.new_output(location=location) |
| with output_file.create() as f: |
| assert f.write(b"foo") == 3 |
| |
| assert output_file.exists() |
| |
| input_file = pyarrow_fileio_adls.new_input(location=location) |
| with input_file.open() as f: |
| assert f.read() == b"foo" |
| |
| pyarrow_fileio_adls.delete(input_file) |
| |
| |
| @pytest.mark.adls |
| @skip_if_pyarrow_too_old |
| def test_getting_length_of_file_adls(pyarrow_fileio_adls: PyArrowFileIO, adls_scheme: str) -> None: |
| """Test getting the length of PyArrowFile""" |
| filename = str(uuid4()) |
| |
| output_file = pyarrow_fileio_adls.new_output(location=f"{adls_scheme}://warehouse/{filename}") |
| with output_file.create() as f: |
| f.write(b"foobar") |
| |
| assert len(output_file) == 6 |
| |
| input_file = pyarrow_fileio_adls.new_input(location=f"{adls_scheme}://warehouse/{filename}") |
| assert len(input_file) == 6 |
| |
| pyarrow_fileio_adls.delete(output_file) |
| |
| |
| @pytest.mark.adls |
| @skip_if_pyarrow_too_old |
| def test_file_tell_adls(pyarrow_fileio_adls: PyArrowFileIO, adls_scheme: str) -> None: |
| location = f"{adls_scheme}://warehouse/{uuid4()}" |
| |
| output_file = pyarrow_fileio_adls.new_output(location=location) |
| with output_file.create() as write_file: |
| write_file.write(b"foobar") |
| |
| input_file = pyarrow_fileio_adls.new_input(location=location) |
| with input_file.open() as f: |
| f.seek(0) |
| assert f.tell() == 0 |
| f.seek(1) |
| assert f.tell() == 1 |
| f.seek(3) |
| assert f.tell() == 3 |
| f.seek(0) |
| assert f.tell() == 0 |
| |
| |
| @pytest.mark.adls |
| @skip_if_pyarrow_too_old |
| def test_read_specified_bytes_for_file_adls(pyarrow_fileio_adls: PyArrowFileIO) -> None: |
| location = f"abfss://warehouse/{uuid4()}" |
| |
| output_file = pyarrow_fileio_adls.new_output(location=location) |
| with output_file.create() as write_file: |
| write_file.write(b"foo") |
| |
| input_file = pyarrow_fileio_adls.new_input(location=location) |
| with input_file.open() as f: |
| f.seek(0) |
| assert b"f" == f.read(1) |
| f.seek(0) |
| assert b"fo" == f.read(2) |
| f.seek(1) |
| assert b"o" == f.read(1) |
| f.seek(1) |
| assert b"oo" == f.read(2) |
| f.seek(0) |
| assert b"foo" == f.read(999) # test reading amount larger than entire content length |
| |
| pyarrow_fileio_adls.delete(input_file) |
| |
| |
| @pytest.mark.adls |
| @skip_if_pyarrow_too_old |
| def test_raise_on_opening_file_not_found_adls(pyarrow_fileio_adls: PyArrowFileIO, adls_scheme: str) -> None: |
| """Test that PyArrowFile raises appropriately when the adls file is not found""" |
| |
| filename = str(uuid4()) |
| input_file = pyarrow_fileio_adls.new_input(location=f"{adls_scheme}://warehouse/{filename}") |
| with pytest.raises(FileNotFoundError) as exc_info: |
| input_file.open().read() |
| |
| assert filename in str(exc_info.value) |
| |
| |
| @pytest.mark.adls |
| @skip_if_pyarrow_too_old |
| def test_checking_if_a_file_exists_adls(pyarrow_fileio_adls: PyArrowFileIO, adls_scheme: str) -> None: |
| """Test checking if a file exists""" |
| non_existent_file = pyarrow_fileio_adls.new_input(location=f"{adls_scheme}://warehouse/does-not-exist.txt") |
| assert not non_existent_file.exists() |
| |
| location = f"{adls_scheme}://warehouse/{uuid4()}" |
| output_file = pyarrow_fileio_adls.new_output(location=location) |
| assert not output_file.exists() |
| with output_file.create() as f: |
| f.write(b"foo") |
| |
| existing_input_file = pyarrow_fileio_adls.new_input(location=location) |
| assert existing_input_file.exists() |
| |
| existing_output_file = pyarrow_fileio_adls.new_output(location=location) |
| assert existing_output_file.exists() |
| |
| pyarrow_fileio_adls.delete(existing_output_file) |
| |
| |
| @pytest.mark.adls |
| @skip_if_pyarrow_too_old |
| def test_closing_a_file_adls(pyarrow_fileio_adls: PyArrowFileIO, adls_scheme: str) -> None: |
| """Test closing an output file and input file""" |
| filename = str(uuid4()) |
| output_file = pyarrow_fileio_adls.new_output(location=f"{adls_scheme}://warehouse/{filename}") |
| with output_file.create() as write_file: |
| write_file.write(b"foo") |
| assert not write_file.closed # type: ignore |
| assert write_file.closed # type: ignore |
| |
| input_file = pyarrow_fileio_adls.new_input(location=f"{adls_scheme}://warehouse/{filename}") |
| with input_file.open() as f: |
| assert not f.closed # type: ignore |
| assert f.closed # type: ignore |
| |
| pyarrow_fileio_adls.delete(f"{adls_scheme}://warehouse/{filename}") |
| |
| |
| @pytest.mark.adls |
| @skip_if_pyarrow_too_old |
| def test_converting_an_outputfile_to_an_inputfile_adls(pyarrow_fileio_adls: PyArrowFileIO, adls_scheme: str) -> None: |
| """Test converting an output file to an input file""" |
| filename = str(uuid4()) |
| output_file = pyarrow_fileio_adls.new_output(location=f"{adls_scheme}://warehouse/{filename}") |
| input_file = output_file.to_input_file() |
| assert input_file.location == output_file.location |
| |
| |
| @pytest.mark.adls |
| @skip_if_pyarrow_too_old |
| def test_writing_avro_file_adls(generated_manifest_entry_file: str, pyarrow_fileio_adls: PyArrowFileIO, adls_scheme: str) -> None: |
| """Test that bytes match when reading a local avro file, writing it using pyarrow file-io, and then reading it again""" |
| filename = str(uuid4()) |
| with PyArrowFileIO().new_input(location=generated_manifest_entry_file).open() as f: |
| b1 = f.read() |
| with pyarrow_fileio_adls.new_output(location=f"{adls_scheme}://warehouse/{filename}").create() as out_f: |
| out_f.write(b1) |
| with pyarrow_fileio_adls.new_input(location=f"{adls_scheme}://warehouse/{filename}").open() as in_f: |
| b2 = in_f.read() |
| assert b1 == b2 # Check that bytes of read from local avro file match bytes written to s3 |
| |
| pyarrow_fileio_adls.delete(f"{adls_scheme}://warehouse/{filename}") |
| |
| |
| def test_parse_location() -> None: |
| def check_results(location: str, expected_schema: str, expected_netloc: str, expected_uri: str) -> None: |
| schema, netloc, uri = PyArrowFileIO.parse_location(location) |
| assert schema == expected_schema |
| assert netloc == expected_netloc |
| assert uri == expected_uri |
| |
| check_results("hdfs://127.0.0.1:9000/root/foo.txt", "hdfs", "127.0.0.1:9000", "/root/foo.txt") |
| check_results("hdfs://127.0.0.1/root/foo.txt", "hdfs", "127.0.0.1", "/root/foo.txt") |
| check_results("hdfs://clusterA/root/foo.txt", "hdfs", "clusterA", "/root/foo.txt") |
| |
| check_results("/root/foo.txt", "file", "", "/root/foo.txt") |
| check_results("/root/tmp/foo.txt", "file", "", "/root/tmp/foo.txt") |
| |
| |
| def test_make_compatible_name() -> None: |
| assert make_compatible_name("label/abc") == "label_x2Fabc" |
| assert make_compatible_name("label?abc") == "label_x3Fabc" |
| |
| |
| @pytest.mark.parametrize( |
| "vals, primitive_type, expected_result", |
| [ |
| ([None, 2, 1], IntegerType(), 1), |
| ([1, None, 2], IntegerType(), 1), |
| ([None, None, None], IntegerType(), None), |
| ([None, date(2024, 2, 4), date(2024, 1, 2)], DateType(), date(2024, 1, 2)), |
| ([date(2024, 1, 2), None, date(2024, 2, 4)], DateType(), date(2024, 1, 2)), |
| ([None, None, None], DateType(), None), |
| ], |
| ) |
| def test_stats_aggregator_update_min(vals: list[Any], primitive_type: PrimitiveType, expected_result: Any) -> None: |
| stats = StatsAggregator(primitive_type, _primitive_to_physical(primitive_type)) |
| |
| for val in vals: |
| stats.update_min(val) |
| |
| assert stats.current_min == expected_result |
| |
| |
| @pytest.mark.parametrize( |
| "vals, primitive_type, expected_result", |
| [ |
| ([None, 2, 1], IntegerType(), 2), |
| ([1, None, 2], IntegerType(), 2), |
| ([None, None, None], IntegerType(), None), |
| ([None, date(2024, 2, 4), date(2024, 1, 2)], DateType(), date(2024, 2, 4)), |
| ([date(2024, 1, 2), None, date(2024, 2, 4)], DateType(), date(2024, 2, 4)), |
| ([None, None, None], DateType(), None), |
| ], |
| ) |
| def test_stats_aggregator_update_max(vals: list[Any], primitive_type: PrimitiveType, expected_result: Any) -> None: |
| stats = StatsAggregator(primitive_type, _primitive_to_physical(primitive_type)) |
| |
| for val in vals: |
| stats.update_max(val) |
| |
| assert stats.current_max == expected_result |
| |
| |
| @pytest.mark.parametrize( |
| "iceberg_type, physical_type_string", |
| [ |
| # Exact match |
| (IntegerType(), "INT32"), |
| # Allowed INT32 -> INT64 promotion |
| (LongType(), "INT32"), |
| # Allowed FLOAT -> DOUBLE promotion |
| (DoubleType(), "FLOAT"), |
| # Allowed FIXED_LEN_BYTE_ARRAY -> INT32 |
| (DecimalType(precision=2, scale=2), "FIXED_LEN_BYTE_ARRAY"), |
| # Allowed FIXED_LEN_BYTE_ARRAY -> INT64 |
| (DecimalType(precision=12, scale=2), "FIXED_LEN_BYTE_ARRAY"), |
| ], |
| ) |
| def test_stats_aggregator_conditionally_allowed_types_pass(iceberg_type: PrimitiveType, physical_type_string: str) -> None: |
| stats = StatsAggregator(iceberg_type, physical_type_string) |
| |
| assert stats.primitive_type == iceberg_type |
| assert stats.current_min is None |
| assert stats.current_max is None |
| |
| |
| @pytest.mark.parametrize( |
| "iceberg_type, physical_type_string", |
| [ |
| # Fail case: INT64 cannot be cast to INT32 |
| (IntegerType(), "INT64"), |
| ], |
| ) |
| def test_stats_aggregator_physical_type_does_not_match_expected_raise_error( |
| iceberg_type: PrimitiveType, physical_type_string: str |
| ) -> None: |
| with pytest.raises(ValueError, match="Unexpected physical type"): |
| StatsAggregator(iceberg_type, physical_type_string) |
| |
| |
| def test_bin_pack_arrow_table(arrow_table_with_null: pa.Table) -> None: |
| # default packs to 1 bin since the table is small |
| bin_packed = bin_pack_arrow_table( |
| arrow_table_with_null, target_file_size=TableProperties.WRITE_TARGET_FILE_SIZE_BYTES_DEFAULT |
| ) |
| assert len(list(bin_packed)) == 1 |
| |
| # as long as table is smaller than default target size, it should pack to 1 bin |
| bigger_arrow_tbl = pa.concat_tables([arrow_table_with_null] * 10) |
| assert bigger_arrow_tbl.nbytes < TableProperties.WRITE_TARGET_FILE_SIZE_BYTES_DEFAULT |
| bin_packed = bin_pack_arrow_table(bigger_arrow_tbl, target_file_size=TableProperties.WRITE_TARGET_FILE_SIZE_BYTES_DEFAULT) |
| assert len(list(bin_packed)) == 1 |
| |
| # unless we override the target size to be smaller |
| bin_packed = bin_pack_arrow_table(bigger_arrow_tbl, target_file_size=arrow_table_with_null.nbytes) |
| assert len(list(bin_packed)) == 10 |
| |
| # and will produce half the number of files if we double the target size |
| bin_packed = bin_pack_arrow_table(bigger_arrow_tbl, target_file_size=arrow_table_with_null.nbytes * 2) |
| assert len(list(bin_packed)) == 5 |
| |
| |
| def test_bin_pack_arrow_table_target_size_smaller_than_row(arrow_table_with_null: pa.Table) -> None: |
| bin_packed = list(bin_pack_arrow_table(arrow_table_with_null, target_file_size=1)) |
| assert len(bin_packed) == arrow_table_with_null.num_rows |
| assert sum(batch.num_rows for bin_ in bin_packed for batch in bin_) == arrow_table_with_null.num_rows |
| |
| |
| def test_schema_mismatch_type(table_schema_simple: Schema) -> None: |
| other_schema = pa.schema( |
| ( |
| pa.field("foo", pa.string(), nullable=True), |
| pa.field("bar", pa.decimal128(18, 6), nullable=False), |
| pa.field("baz", pa.bool_(), nullable=True), |
| ) |
| ) |
| |
| expected = r"""Mismatch in fields: |
| ┏━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┓ |
| ┃ ┃ Table field ┃ Dataframe field ┃ |
| ┡━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┩ |
| │ ✅ │ 1: foo: optional string │ 1: foo: optional string │ |
| │ ❌ │ 2: bar: required int │ 2: bar: required decimal\(18, 6\) │ |
| │ ✅ │ 3: baz: optional boolean │ 3: baz: optional boolean │ |
| └────┴──────────────────────────┴─────────────────────────────────┘ |
| """ |
| |
| with pytest.raises(ValueError, match=expected): |
| _check_pyarrow_schema_compatible(table_schema_simple, other_schema) |
| |
| |
| def test_schema_mismatch_nullability(table_schema_simple: Schema) -> None: |
| other_schema = pa.schema( |
| ( |
| pa.field("foo", pa.string(), nullable=True), |
| pa.field("bar", pa.int32(), nullable=True), |
| pa.field("baz", pa.bool_(), nullable=True), |
| ) |
| ) |
| |
| expected = """Mismatch in fields: |
| ┏━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━┓ |
| ┃ ┃ Table field ┃ Dataframe field ┃ |
| ┡━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━┩ |
| │ ✅ │ 1: foo: optional string │ 1: foo: optional string │ |
| │ ❌ │ 2: bar: required int │ 2: bar: optional int │ |
| │ ✅ │ 3: baz: optional boolean │ 3: baz: optional boolean │ |
| └────┴──────────────────────────┴──────────────────────────┘ |
| """ |
| |
| with pytest.raises(ValueError, match=expected): |
| _check_pyarrow_schema_compatible(table_schema_simple, other_schema) |
| |
| |
| def test_schema_compatible_nullability_diff(table_schema_simple: Schema) -> None: |
| other_schema = pa.schema( |
| ( |
| pa.field("foo", pa.string(), nullable=True), |
| pa.field("bar", pa.int32(), nullable=False), |
| pa.field("baz", pa.bool_(), nullable=False), |
| ) |
| ) |
| |
| try: |
| _check_pyarrow_schema_compatible(table_schema_simple, other_schema) |
| except Exception: |
| pytest.fail("Unexpected Exception raised when calling `_check_pyarrow_schema_compatible`") |
| |
| |
| def test_schema_mismatch_missing_field(table_schema_simple: Schema) -> None: |
| other_schema = pa.schema( |
| ( |
| pa.field("foo", pa.string(), nullable=True), |
| pa.field("baz", pa.bool_(), nullable=True), |
| ) |
| ) |
| |
| expected = """Mismatch in fields: |
| ┏━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━┓ |
| ┃ ┃ Table field ┃ Dataframe field ┃ |
| ┡━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━┩ |
| │ ✅ │ 1: foo: optional string │ 1: foo: optional string │ |
| │ ❌ │ 2: bar: required int │ Missing │ |
| │ ✅ │ 3: baz: optional boolean │ 3: baz: optional boolean │ |
| └────┴──────────────────────────┴──────────────────────────┘ |
| """ |
| |
| with pytest.raises(ValueError, match=expected): |
| _check_pyarrow_schema_compatible(table_schema_simple, other_schema) |
| |
| |
| def test_schema_compatible_missing_nullable_field_nested(table_schema_nested: Schema) -> None: |
| schema = table_schema_nested.as_arrow() |
| schema = schema.remove(6).insert( |
| 6, |
| pa.field( |
| "person", |
| pa.struct( |
| [ |
| pa.field("age", pa.int32(), nullable=False), |
| ] |
| ), |
| nullable=True, |
| ), |
| ) |
| try: |
| _check_pyarrow_schema_compatible(table_schema_nested, schema) |
| except Exception: |
| pytest.fail("Unexpected Exception raised when calling `_check_pyarrow_schema_compatible`") |
| |
| |
| def test_schema_mismatch_missing_required_field_nested(table_schema_nested: Schema) -> None: |
| other_schema = table_schema_nested.as_arrow() |
| other_schema = other_schema.remove(6).insert( |
| 6, |
| pa.field( |
| "person", |
| pa.struct( |
| [ |
| pa.field("name", pa.string(), nullable=True), |
| ] |
| ), |
| nullable=True, |
| ), |
| ) |
| expected = """Mismatch in fields: |
| ┏━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┓ |
| ┃ ┃ Table field ┃ Dataframe field ┃ |
| ┡━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┩ |
| │ ✅ │ 1: foo: optional string │ 1: foo: optional string │ |
| │ ✅ │ 2: bar: required int │ 2: bar: required int │ |
| │ ✅ │ 3: baz: optional boolean │ 3: baz: optional boolean │ |
| │ ✅ │ 4: qux: required list<string> │ 4: qux: required list<string> │ |
| │ ✅ │ 5: element: required string │ 5: element: required string │ |
| │ ✅ │ 6: quux: required map<string, │ 6: quux: required map<string, │ |
| │ │ map<string, int>> │ map<string, int>> │ |
| │ ✅ │ 7: key: required string │ 7: key: required string │ |
| │ ✅ │ 8: value: required map<string, │ 8: value: required map<string, │ |
| │ │ int> │ int> │ |
| │ ✅ │ 9: key: required string │ 9: key: required string │ |
| │ ✅ │ 10: value: required int │ 10: value: required int │ |
| │ ✅ │ 11: location: required │ 11: location: required │ |
| │ │ list<struct<13: latitude: optional │ list<struct<13: latitude: optional │ |
| │ │ float, 14: longitude: optional │ float, 14: longitude: optional │ |
| │ │ float>> │ float>> │ |
| │ ✅ │ 12: element: required struct<13: │ 12: element: required struct<13: │ |
| │ │ latitude: optional float, 14: │ latitude: optional float, 14: │ |
| │ │ longitude: optional float> │ longitude: optional float> │ |
| │ ✅ │ 13: latitude: optional float │ 13: latitude: optional float │ |
| │ ✅ │ 14: longitude: optional float │ 14: longitude: optional float │ |
| │ ✅ │ 15: person: optional struct<16: │ 15: person: optional struct<16: │ |
| │ │ name: optional string, 17: age: │ name: optional string> │ |
| │ │ required int> │ │ |
| │ ✅ │ 16: name: optional string │ 16: name: optional string │ |
| │ ❌ │ 17: age: required int │ Missing │ |
| └────┴────────────────────────────────────┴────────────────────────────────────┘ |
| """ |
| |
| with pytest.raises(ValueError, match=expected): |
| _check_pyarrow_schema_compatible(table_schema_nested, other_schema) |
| |
| |
| def test_schema_compatible_nested(table_schema_nested: Schema) -> None: |
| try: |
| _check_pyarrow_schema_compatible(table_schema_nested, table_schema_nested.as_arrow()) |
| except Exception: |
| pytest.fail("Unexpected Exception raised when calling `_check_pyarrow_schema_compatible`") |
| |
| |
| def test_schema_mismatch_additional_field(table_schema_simple: Schema) -> None: |
| other_schema = pa.schema( |
| ( |
| pa.field("foo", pa.string(), nullable=True), |
| pa.field("bar", pa.int32(), nullable=False), |
| pa.field("baz", pa.bool_(), nullable=True), |
| pa.field("new_field", pa.date32(), nullable=True), |
| ) |
| ) |
| |
| with pytest.raises( |
| ValueError, match=r"PyArrow table contains more columns: new_field. Update the schema first \(hint, use union_by_name\)." |
| ): |
| _check_pyarrow_schema_compatible(table_schema_simple, other_schema) |
| |
| |
| def test_schema_compatible(table_schema_simple: Schema) -> None: |
| try: |
| _check_pyarrow_schema_compatible(table_schema_simple, table_schema_simple.as_arrow()) |
| except Exception: |
| pytest.fail("Unexpected Exception raised when calling `_check_pyarrow_schema_compatible`") |
| |
| |
| def test_schema_projection(table_schema_simple: Schema) -> None: |
| # remove optional `baz` field from `table_schema_simple` |
| other_schema = pa.schema( |
| ( |
| pa.field("foo", pa.string(), nullable=True), |
| pa.field("bar", pa.int32(), nullable=False), |
| ) |
| ) |
| try: |
| _check_pyarrow_schema_compatible(table_schema_simple, other_schema) |
| except Exception: |
| pytest.fail("Unexpected Exception raised when calling `_check_pyarrow_schema_compatible`") |
| |
| |
| def test_schema_downcast(table_schema_simple: Schema) -> None: |
| # large_string type is compatible with string type |
| other_schema = pa.schema( |
| ( |
| pa.field("foo", pa.large_string(), nullable=True), |
| pa.field("bar", pa.int32(), nullable=False), |
| pa.field("baz", pa.bool_(), nullable=True), |
| ) |
| ) |
| |
| try: |
| _check_pyarrow_schema_compatible(table_schema_simple, other_schema) |
| except Exception: |
| pytest.fail("Unexpected Exception raised when calling `_check_pyarrow_schema_compatible`") |
| |
| |
| def test_partition_for_demo() -> None: |
| test_pa_schema = pa.schema([("year", pa.int64()), ("n_legs", pa.int64()), ("animal", pa.string())]) |
| test_schema = Schema( |
| NestedField(field_id=1, name="year", field_type=StringType(), required=False), |
| NestedField(field_id=2, name="n_legs", field_type=IntegerType(), required=True), |
| NestedField(field_id=3, name="animal", field_type=StringType(), required=False), |
| schema_id=1, |
| ) |
| test_data = { |
| "year": [2020, 2022, 2022, 2022, 2021, 2022, 2022, 2019, 2021], |
| "n_legs": [2, 2, 2, 4, 4, 4, 4, 5, 100], |
| "animal": ["Flamingo", "Parrot", "Parrot", "Horse", "Dog", "Horse", "Horse", "Brittle stars", "Centipede"], |
| } |
| arrow_table = pa.Table.from_pydict(test_data, schema=test_pa_schema) |
| partition_spec = PartitionSpec( |
| PartitionField(source_id=2, field_id=1002, transform=IdentityTransform(), name="n_legs_identity"), |
| PartitionField(source_id=1, field_id=1001, transform=IdentityTransform(), name="year_identity"), |
| ) |
| result = list(_determine_partitions(partition_spec, test_schema, arrow_table)) |
| assert {table_partition.partition_key.partition for table_partition in result} == { |
| Record(2, 2020), |
| Record(100, 2021), |
| Record(4, 2021), |
| Record(4, 2022), |
| Record(2, 2022), |
| Record(5, 2019), |
| } |
| assert ( |
| pa.concat_tables([table_partition.arrow_table_partition for table_partition in result]).num_rows == arrow_table.num_rows |
| ) |
| |
| |
| def test_partition_for_nested_field() -> None: |
| schema = Schema( |
| NestedField(id=1, name="foo", field_type=StringType(), required=True), |
| NestedField( |
| id=2, |
| name="bar", |
| field_type=StructType( |
| NestedField(id=3, name="baz", field_type=TimestampType(), required=False), |
| NestedField(id=4, name="qux", field_type=IntegerType(), required=False), |
| ), |
| required=True, |
| ), |
| ) |
| |
| spec = PartitionSpec(PartitionField(source_id=3, field_id=1000, transform=HourTransform(), name="ts")) |
| |
| t1 = datetime(2025, 7, 11, 9, 30, 0) |
| t2 = datetime(2025, 7, 11, 10, 30, 0) |
| |
| test_data = [ |
| {"foo": "a", "bar": {"baz": t1, "qux": 1}}, |
| {"foo": "b", "bar": {"baz": t2, "qux": 2}}, |
| ] |
| |
| arrow_table = pa.Table.from_pylist(test_data, schema=schema.as_arrow()) |
| partitions = list(_determine_partitions(spec, schema, arrow_table)) |
| partition_values = {p.partition_key.partition[0] for p in partitions} |
| |
| assert partition_values == {486729, 486730} |
| |
| |
| def test_partition_for_deep_nested_field() -> None: |
| schema = Schema( |
| NestedField( |
| id=1, |
| name="foo", |
| field_type=StructType( |
| NestedField( |
| id=2, |
| name="bar", |
| field_type=StructType(NestedField(id=3, name="baz", field_type=StringType(), required=False)), |
| required=True, |
| ) |
| ), |
| required=True, |
| ) |
| ) |
| |
| spec = PartitionSpec(PartitionField(source_id=3, field_id=1000, transform=IdentityTransform(), name="qux")) |
| |
| test_data = [ |
| {"foo": {"bar": {"baz": "data-1"}}}, |
| {"foo": {"bar": {"baz": "data-2"}}}, |
| {"foo": {"bar": {"baz": "data-1"}}}, |
| ] |
| |
| arrow_table = pa.Table.from_pylist(test_data, schema=schema.as_arrow()) |
| partitions = list(_determine_partitions(spec, schema, arrow_table)) |
| |
| assert len(partitions) == 2 # 2 unique partitions |
| partition_values = {p.partition_key.partition[0] for p in partitions} |
| assert partition_values == {"data-1", "data-2"} |
| |
| |
| def test_inspect_partition_for_nested_field(catalog: InMemoryCatalog) -> None: |
| schema = Schema( |
| NestedField(id=1, name="foo", field_type=StringType(), required=True), |
| NestedField( |
| id=2, |
| name="bar", |
| field_type=StructType( |
| NestedField(id=3, name="baz", field_type=StringType(), required=False), |
| NestedField(id=4, name="qux", field_type=IntegerType(), required=False), |
| ), |
| required=True, |
| ), |
| ) |
| spec = PartitionSpec(PartitionField(source_id=3, field_id=1000, transform=IdentityTransform(), name="part")) |
| catalog.create_namespace("default") |
| table = catalog.create_table("default.test_partition_in_struct", schema=schema, partition_spec=spec) |
| test_data = [ |
| {"foo": "a", "bar": {"baz": "data-a", "qux": 1}}, |
| {"foo": "b", "bar": {"baz": "data-b", "qux": 2}}, |
| ] |
| |
| arrow_table = pa.Table.from_pylist(test_data, schema=table.schema().as_arrow()) |
| table.append(arrow_table) |
| partitions_table = table.inspect.partitions() |
| partitions = partitions_table["partition"].to_pylist() |
| |
| assert len(partitions) == 2 |
| assert {part["part"] for part in partitions} == {"data-a", "data-b"} |
| |
| |
| def test_inspect_partitions_respects_partition_evolution(catalog: InMemoryCatalog) -> None: |
| schema = Schema( |
| NestedField(id=1, name="dt", field_type=DateType(), required=False), |
| NestedField(id=2, name="category", field_type=StringType(), required=False), |
| ) |
| spec = PartitionSpec(PartitionField(source_id=1, field_id=1000, transform=IdentityTransform(), name="dt")) |
| catalog.create_namespace("default") |
| table = catalog.create_table( |
| "default.test_inspect_partitions_respects_partition_evolution", schema=schema, partition_spec=spec |
| ) |
| |
| old_spec_id = table.spec().spec_id |
| old_data = [{"dt": date(2025, 1, 1), "category": "old"}] |
| table.append(pa.Table.from_pylist(old_data, schema=table.schema().as_arrow())) |
| |
| table.update_spec().add_identity("category").commit() |
| new_spec_id = table.spec().spec_id |
| assert new_spec_id != old_spec_id |
| |
| partitions_table = table.inspect.partitions() |
| partitions = partitions_table["partition"].to_pylist() |
| assert all("category" not in partition for partition in partitions) |
| |
| new_data = [{"dt": date(2025, 1, 2), "category": "new"}] |
| table.append(pa.Table.from_pylist(new_data, schema=table.schema().as_arrow())) |
| |
| partitions_table = table.inspect.partitions() |
| assert set(partitions_table["spec_id"].to_pylist()) == {old_spec_id, new_spec_id} |
| |
| |
| def test_identity_partition_on_multi_columns() -> None: |
| test_pa_schema = pa.schema([("born_year", pa.int64()), ("n_legs", pa.int64()), ("animal", pa.string())]) |
| test_schema = Schema( |
| NestedField(field_id=1, name="born_year", field_type=StringType(), required=False), |
| NestedField(field_id=2, name="n_legs", field_type=IntegerType(), required=True), |
| NestedField(field_id=3, name="animal", field_type=StringType(), required=False), |
| schema_id=1, |
| ) |
| # 5 partitions, 6 unique row values, 12 rows |
| test_rows = [ |
| (2021, 4, "Dog"), |
| (2022, 4, "Horse"), |
| (2022, 4, "Another Horse"), |
| (2021, 100, "Centipede"), |
| (None, 4, "Kirin"), |
| (2021, None, "Fish"), |
| ] * 2 |
| expected = {Record(test_rows[i][1], test_rows[i][0]) for i in range(len(test_rows))} |
| partition_spec = PartitionSpec( |
| PartitionField(source_id=2, field_id=1002, transform=IdentityTransform(), name="n_legs_identity"), |
| PartitionField(source_id=1, field_id=1001, transform=IdentityTransform(), name="year_identity"), |
| ) |
| import random |
| |
| # there are 12! / ((2!)^6) = 7,484,400 permutations, too many to pick all |
| for _ in range(1000): |
| random.shuffle(test_rows) |
| test_data = { |
| "born_year": [row[0] for row in test_rows], |
| "n_legs": [row[1] for row in test_rows], |
| "animal": [row[2] for row in test_rows], |
| } |
| arrow_table = pa.Table.from_pydict(test_data, schema=test_pa_schema) |
| |
| result = list(_determine_partitions(partition_spec, test_schema, arrow_table)) |
| |
| assert {table_partition.partition_key.partition for table_partition in result} == expected |
| concatenated_arrow_table = pa.concat_tables([table_partition.arrow_table_partition for table_partition in result]) |
| assert concatenated_arrow_table.num_rows == arrow_table.num_rows |
| assert concatenated_arrow_table.sort_by( |
| [ |
| ("born_year", "ascending"), |
| ("n_legs", "ascending"), |
| ("animal", "ascending"), |
| ] |
| ) == arrow_table.sort_by([("born_year", "ascending"), ("n_legs", "ascending"), ("animal", "ascending")]) |
| |
| |
| def test_initial_value() -> None: |
| # Have some fake data, otherwise it will generate a table without records |
| data = pa.record_batch([pa.nulls(10, pa.int64())], names=["some_field"]) |
| result = _to_requested_schema( |
| Schema(NestedField(1, "we-love-22", LongType(), required=True, initial_default=22)), Schema(), data |
| ) |
| assert result.column_names == ["we-love-22"] |
| for val in result[0]: |
| assert val.as_py() == 22 |
| |
| |
| def test__to_requested_schema_timestamp_to_timestamptz_projection() -> None: |
| # file is written with timestamp without timezone |
| file_schema = Schema(NestedField(1, "ts_field", TimestampType(), required=False)) |
| batch = pa.record_batch( |
| [ |
| pa.array( |
| [ |
| datetime(2025, 8, 14, 12, 0, 0), |
| datetime(2025, 8, 14, 13, 0, 0), |
| ], |
| type=pa.timestamp("us"), |
| ) |
| ], |
| names=["ts_field"], |
| ) |
| |
| # table is written with timestamp with timezone |
| table_schema = Schema(NestedField(1, "ts_field", TimestamptzType(), required=False)) |
| |
| actual_result = _to_requested_schema(table_schema, file_schema, batch, downcast_ns_timestamp_to_us=True) |
| expected = pa.record_batch( |
| [ |
| pa.array( |
| [ |
| datetime(2025, 8, 14, 12, 0, 0), |
| datetime(2025, 8, 14, 13, 0, 0), |
| ], |
| type=pa.timestamp("us", tz=timezone.utc), |
| ) |
| ], |
| names=["ts_field"], |
| ) |
| |
| # expect actual_result to have timezone |
| assert expected.equals(actual_result) |
| |
| |
| def test__to_requested_schema_timestamptz_to_timestamp_projection() -> None: |
| # file is written with timestamp with timezone |
| file_schema = Schema(NestedField(1, "ts_field", TimestamptzType(), required=False)) |
| batch = pa.record_batch( |
| [ |
| pa.array( |
| [ |
| datetime(2025, 8, 14, 12, 0, 0, tzinfo=timezone.utc), |
| datetime(2025, 8, 14, 13, 0, 0, tzinfo=timezone.utc), |
| ], |
| type=pa.timestamp("us", tz="UTC"), |
| ) |
| ], |
| names=["ts_field"], |
| ) |
| |
| # table schema expects timestamp without timezone |
| table_schema = Schema(NestedField(1, "ts_field", TimestampType(), required=False)) |
| |
| # allow_timestamp_tz_mismatch=True enables reading timestamptz as timestamp |
| actual_result = _to_requested_schema( |
| table_schema, file_schema, batch, downcast_ns_timestamp_to_us=True, allow_timestamp_tz_mismatch=True |
| ) |
| expected = pa.record_batch( |
| [ |
| pa.array( |
| [ |
| datetime(2025, 8, 14, 12, 0, 0), |
| datetime(2025, 8, 14, 13, 0, 0), |
| ], |
| type=pa.timestamp("us"), |
| ) |
| ], |
| names=["ts_field"], |
| ) |
| |
| # expect actual_result to have no timezone |
| assert expected.equals(actual_result) |
| |
| |
| def test__to_requested_schema_timestamptz_to_timestamp_write_rejects() -> None: |
| """Test that the write path (default) rejects timestamptz to timestamp casting. |
| |
| This ensures we enforce the Iceberg spec distinction between timestamp and timestamptz on writes, |
| while the read path can be more permissive (like Spark) via allow_timestamp_tz_mismatch=True. |
| """ |
| # file is written with timestamp with timezone |
| file_schema = Schema(NestedField(1, "ts_field", TimestamptzType(), required=False)) |
| batch = pa.record_batch( |
| [ |
| pa.array( |
| [ |
| datetime(2025, 8, 14, 12, 0, 0, tzinfo=timezone.utc), |
| datetime(2025, 8, 14, 13, 0, 0, tzinfo=timezone.utc), |
| ], |
| type=pa.timestamp("us", tz="UTC"), |
| ) |
| ], |
| names=["ts_field"], |
| ) |
| |
| # table schema expects timestamp without timezone |
| table_schema = Schema(NestedField(1, "ts_field", TimestampType(), required=False)) |
| |
| # allow_timestamp_tz_mismatch=False (default, used in write path) should raise |
| with pytest.raises(ValueError, match="Unsupported schema projection"): |
| _to_requested_schema( |
| table_schema, file_schema, batch, downcast_ns_timestamp_to_us=True, allow_timestamp_tz_mismatch=False |
| ) |
| |
| |
| def test_write_file_rejects_timestamptz_to_timestamp(tmp_path: Path) -> None: |
| """Test that write_file rejects writing timestamptz data to a timestamp column.""" |
| from pyiceberg.table import WriteTask |
| |
| # Table expects timestamp (no tz), but data has timestamptz |
| table_schema = Schema(NestedField(1, "ts_field", TimestampType(), required=False)) |
| task_schema = Schema(NestedField(1, "ts_field", TimestamptzType(), required=False)) |
| |
| arrow_data = pa.table({"ts_field": [datetime(2025, 8, 14, 12, 0, 0, tzinfo=timezone.utc)]}) |
| |
| table_metadata = TableMetadataV2( |
| location=f"file://{tmp_path}", |
| last_column_id=1, |
| format_version=2, |
| schemas=[table_schema], |
| partition_specs=[PartitionSpec()], |
| ) |
| |
| task = WriteTask( |
| write_uuid=uuid.uuid4(), |
| task_id=0, |
| record_batches=arrow_data.to_batches(), |
| schema=task_schema, |
| ) |
| |
| with pytest.raises(ValueError, match="Unsupported schema projection"): |
| list(write_file(io=PyArrowFileIO(), table_metadata=table_metadata, tasks=iter([task]))) |
| |
| |
| def test__to_requested_schema_timestamps( |
| arrow_table_schema_with_all_timestamp_precisions: pa.Schema, |
| arrow_table_with_all_timestamp_precisions: pa.Table, |
| arrow_table_schema_with_all_microseconds_timestamp_precisions: pa.Schema, |
| table_schema_with_all_microseconds_timestamp_precision: Schema, |
| ) -> None: |
| requested_schema = table_schema_with_all_microseconds_timestamp_precision |
| file_schema = requested_schema |
| batch = arrow_table_with_all_timestamp_precisions.to_batches()[0] |
| result = _to_requested_schema(requested_schema, file_schema, batch, downcast_ns_timestamp_to_us=True, include_field_ids=False) |
| |
| expected = arrow_table_with_all_timestamp_precisions.cast( |
| arrow_table_schema_with_all_microseconds_timestamp_precisions, safe=False |
| ).to_batches()[0] |
| assert result == expected |
| |
| |
| def test__to_requested_schema_timestamps_without_downcast_raises_exception( |
| arrow_table_schema_with_all_timestamp_precisions: pa.Schema, |
| arrow_table_with_all_timestamp_precisions: pa.Table, |
| arrow_table_schema_with_all_microseconds_timestamp_precisions: pa.Schema, |
| table_schema_with_all_microseconds_timestamp_precision: Schema, |
| ) -> None: |
| requested_schema = table_schema_with_all_microseconds_timestamp_precision |
| file_schema = requested_schema |
| batch = arrow_table_with_all_timestamp_precisions.to_batches()[0] |
| with pytest.raises(ValueError) as exc_info: |
| _to_requested_schema(requested_schema, file_schema, batch, downcast_ns_timestamp_to_us=False, include_field_ids=False) |
| |
| assert "Unsupported schema projection from timestamp[ns] to timestamp[us]" in str(exc_info.value) |
| |
| |
| @pytest.mark.parametrize( |
| "arrow_type,iceberg_type,expected_arrow_type", |
| [ |
| (pa.uint8(), IntegerType(), pa.int32()), |
| (pa.int8(), IntegerType(), pa.int32()), |
| (pa.int16(), IntegerType(), pa.int32()), |
| (pa.uint16(), IntegerType(), pa.int32()), |
| (pa.uint32(), LongType(), pa.int64()), |
| (pa.int32(), LongType(), pa.int64()), |
| ], |
| ) |
| def test__to_requested_schema_integer_promotion( |
| arrow_type: pa.DataType, |
| iceberg_type: PrimitiveType, |
| expected_arrow_type: pa.DataType, |
| ) -> None: |
| """Test that smaller integer types are cast to target Iceberg type during write.""" |
| requested_schema = Schema(NestedField(1, "col", iceberg_type, required=False)) |
| file_schema = requested_schema |
| |
| arrow_schema = pa.schema([pa.field("col", arrow_type)]) |
| data = pa.array([1, 2, 3, None], type=arrow_type) |
| batch = pa.RecordBatch.from_arrays([data], schema=arrow_schema) |
| |
| result = _to_requested_schema( |
| requested_schema, file_schema, batch, downcast_ns_timestamp_to_us=False, include_field_ids=False |
| ) |
| |
| assert result.schema[0].type == expected_arrow_type |
| assert result.column(0).to_pylist() == [1, 2, 3, None] |
| |
| |
| def test_pyarrow_file_io_fs_by_scheme_cache() -> None: |
| # It's better to set up multi-region minio servers for an integration test once `endpoint_url` argument |
| # becomes available for `resolve_s3_region` |
| # Refer to: https://github.com/apache/arrow/issues/43713 |
| |
| pyarrow_file_io = PyArrowFileIO() |
| us_east_1_region = "us-east-1" |
| ap_southeast_2_region = "ap-southeast-2" |
| |
| with patch("pyarrow.fs.resolve_s3_region") as mock_s3_region_resolver: |
| # Call with new argument resolves region automatically |
| mock_s3_region_resolver.return_value = us_east_1_region |
| filesystem_us = pyarrow_file_io.fs_by_scheme("s3", "us-east-1-bucket") |
| assert filesystem_us.region == us_east_1_region |
| assert pyarrow_file_io.fs_by_scheme.cache_info().misses == 1 # type: ignore |
| assert pyarrow_file_io.fs_by_scheme.cache_info().currsize == 1 # type: ignore |
| |
| # Call with different argument also resolves region automatically |
| mock_s3_region_resolver.return_value = ap_southeast_2_region |
| filesystem_ap_southeast_2 = pyarrow_file_io.fs_by_scheme("s3", "ap-southeast-2-bucket") |
| assert filesystem_ap_southeast_2.region == ap_southeast_2_region |
| assert pyarrow_file_io.fs_by_scheme.cache_info().misses == 2 # type: ignore |
| assert pyarrow_file_io.fs_by_scheme.cache_info().currsize == 2 # type: ignore |
| |
| # Call with same argument hits cache |
| filesystem_us_cached = pyarrow_file_io.fs_by_scheme("s3", "us-east-1-bucket") |
| assert filesystem_us_cached.region == us_east_1_region |
| assert pyarrow_file_io.fs_by_scheme.cache_info().hits == 1 # type: ignore |
| |
| # Call with same argument hits cache |
| filesystem_ap_southeast_2_cached = pyarrow_file_io.fs_by_scheme("s3", "ap-southeast-2-bucket") |
| assert filesystem_ap_southeast_2_cached.region == ap_southeast_2_region |
| assert pyarrow_file_io.fs_by_scheme.cache_info().hits == 2 # type: ignore |
| |
| |
| def test_pyarrow_io_new_input_multi_region(caplog: Any) -> None: |
| # It's better to set up multi-region minio servers for an integration test once `endpoint_url` argument |
| # becomes available for `resolve_s3_region` |
| # Refer to: https://github.com/apache/arrow/issues/43713 |
| user_provided_region = "ap-southeast-1" |
| bucket_regions = [ |
| ("us-east-2-bucket", "us-east-2"), |
| ("ap-southeast-2-bucket", "ap-southeast-2"), |
| ] |
| |
| def _s3_region_map(bucket: str) -> str: |
| for bucket_region in bucket_regions: |
| if bucket_region[0] == bucket: |
| return bucket_region[1] |
| raise OSError("Unknown bucket") |
| |
| # For a pyarrow io instance with configured default s3 region |
| pyarrow_file_io = PyArrowFileIO({"s3.region": user_provided_region, "s3.resolve-region": "true"}) |
| with patch("pyarrow.fs.resolve_s3_region") as mock_s3_region_resolver: |
| mock_s3_region_resolver.side_effect = _s3_region_map |
| |
| # The region is set to provided region if bucket region cannot be resolved |
| with caplog.at_level(logging.WARNING): |
| assert pyarrow_file_io.new_input("s3://non-exist-bucket/path/to/file")._filesystem.region == user_provided_region |
| assert "Unable to resolve region for bucket non-exist-bucket" in caplog.text |
| |
| for bucket_region in bucket_regions: |
| # For s3 scheme, region is overwritten by resolved bucket region if different from user provided region |
| with caplog.at_level(logging.WARNING): |
| assert pyarrow_file_io.new_input(f"s3://{bucket_region[0]}/path/to/file")._filesystem.region == bucket_region[1] |
| assert ( |
| f"PyArrow FileIO overriding S3 bucket region for bucket {bucket_region[0]}: " |
| f"provided region {user_provided_region}, actual region {bucket_region[1]}" in caplog.text |
| ) |
| |
| # For oss scheme, user provided region is used instead |
| assert pyarrow_file_io.new_input(f"oss://{bucket_region[0]}/path/to/file")._filesystem.region == user_provided_region |
| |
| |
| def test_pyarrow_io_multi_fs() -> None: |
| pyarrow_file_io = PyArrowFileIO({"s3.region": "ap-southeast-1"}) |
| |
| with patch("pyarrow.fs.resolve_s3_region") as mock_s3_region_resolver: |
| mock_s3_region_resolver.return_value = None |
| |
| # The PyArrowFileIO instance resolves s3 file input to S3FileSystem |
| assert isinstance(pyarrow_file_io.new_input("s3://bucket/path/to/file")._filesystem, S3FileSystem) |
| |
| # Same PyArrowFileIO instance resolves local file input to LocalFileSystem |
| assert isinstance(pyarrow_file_io.new_input("file:///path/to/file")._filesystem, LocalFileSystem) |
| |
| |
| class SomeRetryStrategy(AwsDefaultS3RetryStrategy): |
| def __init__(self) -> None: |
| super().__init__() |
| warnings.warn("Initialized SomeRetryStrategy 👍", stacklevel=2) |
| |
| |
| def test_retry_strategy() -> None: |
| io = PyArrowFileIO(properties={S3_RETRY_STRATEGY_IMPL: "tests.io.test_pyarrow.SomeRetryStrategy"}) |
| with pytest.warns(UserWarning, match="Initialized SomeRetryStrategy.*"): |
| io.new_input("s3://bucket/path/to/file") |
| |
| |
| def test_retry_strategy_not_found() -> None: |
| io = PyArrowFileIO(properties={S3_RETRY_STRATEGY_IMPL: "pyiceberg.DoesNotExist"}) |
| with pytest.warns(UserWarning, match="Could not initialize S3 retry strategy: pyiceberg.DoesNotExist"): |
| io.new_input("s3://bucket/path/to/file") |
| |
| |
| @pytest.mark.parametrize("format_version", [1, 2, 3]) |
| def test_task_to_record_batches_nanos(format_version: TableVersion, tmpdir: str) -> None: |
| arrow_table = pa.table( |
| [ |
| pa.array( |
| [ |
| datetime(2025, 8, 14, 12, 0, 0), |
| datetime(2025, 8, 14, 13, 0, 0), |
| ], |
| type=pa.timestamp("ns"), |
| ) |
| ], |
| pa.schema((pa.field("ts_field", pa.timestamp("ns"), nullable=True, metadata={PYARROW_PARQUET_FIELD_ID_KEY: "1"}),)), |
| ) |
| |
| data_file = _write_table_to_data_file(f"{tmpdir}/test_task_to_record_batches_nanos.parquet", arrow_table.schema, arrow_table) |
| |
| if format_version <= 2: |
| table_schema = Schema(NestedField(1, "ts_field", TimestampType(), required=False)) |
| else: |
| table_schema = Schema(NestedField(1, "ts_field", TimestampNanoType(), required=False)) |
| |
| actual_result = list( |
| _task_to_record_batches( |
| PyArrowFileIO(), |
| FileScanTask(data_file), |
| bound_row_filter=AlwaysTrue(), |
| projected_schema=table_schema, |
| table_schema=table_schema, |
| projected_field_ids={1}, |
| positional_deletes=None, |
| case_sensitive=True, |
| format_version=format_version, |
| ) |
| )[0] |
| |
| def _expected_batch(unit: str) -> pa.RecordBatch: |
| return pa.record_batch( |
| [ |
| pa.array( |
| [ |
| datetime(2025, 8, 14, 12, 0, 0), |
| datetime(2025, 8, 14, 13, 0, 0), |
| ], |
| type=pa.timestamp(unit), |
| ) |
| ], |
| names=["ts_field"], |
| ) |
| |
| assert _expected_batch("ns" if format_version > 2 else "us").equals(actual_result) |
| |
| |
| def test_parse_location_defaults() -> None: |
| """Test that parse_location uses defaults.""" |
| |
| from pyiceberg.io.pyarrow import PyArrowFileIO |
| |
| # if no default scheme or netloc is provided, use file scheme and empty netloc |
| scheme, netloc, path = PyArrowFileIO.parse_location("/foo/bar") |
| assert scheme == "file" |
| assert netloc == "" |
| assert path == "/foo/bar" |
| |
| scheme, netloc, path = PyArrowFileIO.parse_location( |
| "/foo/bar", properties={"DEFAULT_SCHEME": "scheme", "DEFAULT_NETLOC": "netloc:8000"} |
| ) |
| assert scheme == "scheme" |
| assert netloc == "netloc:8000" |
| assert path == "/foo/bar" |
| |
| scheme, netloc, path = PyArrowFileIO.parse_location( |
| "/foo/bar", properties={"DEFAULT_SCHEME": "hdfs", "DEFAULT_NETLOC": "netloc:8000"} |
| ) |
| assert scheme == "hdfs" |
| assert netloc == "netloc:8000" |
| assert path == "/foo/bar" |
| |
| |
| def test_write_and_read_orc(tmp_path: Path) -> None: |
| """Test basic ORC write and read functionality.""" |
| # Create a simple Arrow table |
| data = pa.table({"a": [1, 2, 3], "b": ["x", "y", "z"]}) |
| orc_path = tmp_path / "test.orc" |
| orc.write_table(data, str(orc_path)) |
| # Read it back |
| orc_file = orc.ORCFile(str(orc_path)) |
| table_read = orc_file.read() |
| assert table_read.equals(data) |
| |
| |
| def test_orc_file_format_integration(tmp_path: Path) -> None: |
| """Test ORC file format integration with PyArrow dataset API.""" |
| # This test mimics a minimal integration with PyIceberg's FileFormat enum and pyarrow.orc |
| import pyarrow.dataset as ds |
| |
| data = pa.table({"a": [10, 20], "b": ["foo", "bar"]}) |
| orc_path = tmp_path / "iceberg.orc" |
| orc.write_table(data, str(orc_path)) |
| # Use PyArrow dataset API to read as ORC |
| dataset = ds.dataset(str(orc_path), format=ds.OrcFileFormat()) |
| table_read = dataset.to_table() |
| assert table_read.equals(data) |
| |
| |
| def test_iceberg_read_orc(tmp_path: Path) -> None: |
| """ |
| Integration test: Read ORC files via Iceberg API. |
| To run just this test: |
| pytest tests/io/test_pyarrow.py -k test_iceberg_read_orc |
| """ |
| import pyarrow as pa |
| import pyarrow.orc as orc |
| |
| from pyiceberg.expressions import AlwaysTrue |
| from pyiceberg.io.pyarrow import ArrowScan, PyArrowFileIO |
| from pyiceberg.manifest import DataFileContent, FileFormat |
| from pyiceberg.partitioning import PartitionSpec |
| from pyiceberg.schema import Schema |
| from pyiceberg.table import FileScanTask |
| from pyiceberg.table.metadata import TableMetadataV2 |
| from pyiceberg.types import IntegerType, StringType |
| |
| # Define schema and data |
| schema = Schema( |
| NestedField(1, "id", IntegerType(), required=True), |
| NestedField(2, "name", StringType(), required=False), |
| ) |
| data = pa.table({"id": pa.array([1, 2, 3], type=pa.int32()), "name": ["a", "b", "c"]}) |
| |
| # Create ORC file directly using PyArrow |
| orc_path = tmp_path / "test_data.orc" |
| orc.write_table(data, str(orc_path)) |
| |
| # Create table metadata |
| table_metadata = TableMetadataV2( |
| location=f"file://{tmp_path}/test_location", |
| last_column_id=2, |
| format_version=2, |
| schemas=[schema], |
| partition_specs=[PartitionSpec()], |
| properties={ |
| "write.format.default": "parquet", # This doesn't matter for reading |
| # Add name mapping for ORC files without field IDs |
| "schema.name-mapping.default": ('[{"field-id": 1, "names": ["id"]}, {"field-id": 2, "names": ["name"]}]'), |
| }, |
| ) |
| io = PyArrowFileIO() |
| |
| # Create a DataFile pointing to the ORC file |
| from pyiceberg.manifest import DataFile |
| from pyiceberg.typedef import Record |
| |
| data_file = DataFile.from_args( |
| content=DataFileContent.DATA, |
| file_path=str(orc_path), |
| file_format=FileFormat.ORC, |
| partition=Record(), |
| file_size_in_bytes=orc_path.stat().st_size, |
| sort_order_id=None, |
| spec_id=0, |
| equality_ids=None, |
| key_metadata=None, |
| record_count=3, |
| column_sizes={1: 12, 2: 12}, # Approximate sizes |
| value_counts={1: 3, 2: 3}, |
| null_value_counts={1: 0, 2: 0}, |
| nan_value_counts={1: 0, 2: 0}, |
| lower_bounds={1: b"\x01\x00\x00\x00", 2: b"a"}, # Approximate bounds |
| upper_bounds={1: b"\x03\x00\x00\x00", 2: b"c"}, |
| split_offsets=None, |
| ) |
| # Ensure spec_id is properly set |
| data_file.spec_id = 0 |
| |
| # Read back using ArrowScan |
| scan = ArrowScan( |
| table_metadata=table_metadata, |
| io=io, |
| projected_schema=schema, |
| row_filter=AlwaysTrue(), |
| case_sensitive=True, |
| ) |
| scan_task = FileScanTask(data_file=data_file) |
| table_read = scan.to_table([scan_task]) |
| |
| # Compare data ignoring schema metadata (like not null constraints) |
| assert table_read.num_rows == data.num_rows |
| assert table_read.num_columns == data.num_columns |
| assert table_read.column_names == data.column_names |
| |
| # Compare actual column data values |
| for col_name in data.column_names: |
| assert table_read.column(col_name).to_pylist() == data.column(col_name).to_pylist() |
| |
| |
| def test_orc_row_filtering_predicate_pushdown(tmp_path: Path) -> None: |
| """ |
| Test ORC row filtering and predicate pushdown functionality. |
| To run just this test: |
| pytest tests/io/test_pyarrow.py -k test_orc_row_filtering_predicate_pushdown |
| """ |
| import pyarrow as pa |
| import pyarrow.orc as orc |
| |
| from pyiceberg.expressions import And, EqualTo, GreaterThan, In, LessThan, Or |
| from pyiceberg.io.pyarrow import ArrowScan, PyArrowFileIO |
| from pyiceberg.manifest import DataFileContent, FileFormat |
| from pyiceberg.partitioning import PartitionSpec |
| from pyiceberg.schema import Schema |
| from pyiceberg.table import FileScanTask |
| from pyiceberg.table.metadata import TableMetadataV2 |
| from pyiceberg.types import BooleanType, IntegerType, StringType |
| |
| # Define schema and data with more complex data for filtering |
| schema = Schema( |
| NestedField(1, "id", IntegerType(), required=True), |
| NestedField(2, "name", StringType(), required=False), |
| NestedField(3, "age", IntegerType(), required=True), |
| NestedField(4, "active", BooleanType(), required=True), |
| ) |
| |
| # Create data with various values for filtering |
| data = pa.table( |
| { |
| "id": pa.array([1, 2, 3, 4, 5], type=pa.int32()), |
| "name": ["alice", "bob", "charlie", "david", "eve"], |
| "age": pa.array([25, 30, 35, 40, 45], type=pa.int32()), |
| "active": [True, False, True, True, False], |
| } |
| ) |
| |
| # Create ORC file |
| orc_path = tmp_path / "filter_test.orc" |
| orc.write_table(data, str(orc_path)) |
| |
| # Create table metadata |
| table_metadata = TableMetadataV2( |
| location=f"file://{tmp_path}/test_location", |
| last_column_id=4, |
| format_version=2, |
| schemas=[schema], |
| partition_specs=[PartitionSpec()], |
| properties={ |
| "schema.name-mapping.default": ( |
| '[{"field-id": 1, "names": ["id"]}, {"field-id": 2, "names": ["name"]}, ' |
| '{"field-id": 3, "names": ["age"]}, {"field-id": 4, "names": ["active"]}]' |
| ), |
| }, |
| ) |
| io = PyArrowFileIO() |
| |
| # Create DataFile |
| from pyiceberg.manifest import DataFile |
| from pyiceberg.typedef import Record |
| |
| data_file = DataFile.from_args( |
| content=DataFileContent.DATA, |
| file_path=str(orc_path), |
| file_format=FileFormat.ORC, |
| partition=Record(), |
| file_size_in_bytes=orc_path.stat().st_size, |
| sort_order_id=None, |
| spec_id=0, |
| equality_ids=None, |
| key_metadata=None, |
| record_count=5, |
| column_sizes={1: 20, 2: 50, 3: 20, 4: 10}, |
| value_counts={1: 5, 2: 5, 3: 5, 4: 5}, |
| null_value_counts={1: 0, 2: 0, 3: 0, 4: 0}, |
| nan_value_counts={1: 0, 2: 0, 3: 0, 4: 0}, |
| lower_bounds={1: b"\x01\x00\x00\x00", 2: b"alice", 3: b"\x19\x00\x00\x00", 4: b"\x00"}, |
| upper_bounds={1: b"\x05\x00\x00\x00", 2: b"eve", 3: b"\x2d\x00\x00\x00", 4: b"\x01"}, |
| split_offsets=None, |
| ) |
| data_file.spec_id = 0 |
| |
| scan_task = FileScanTask(data_file=data_file) |
| |
| # Test 1: Simple equality filter |
| scan = ArrowScan( |
| table_metadata=table_metadata, |
| io=io, |
| projected_schema=schema, |
| row_filter=EqualTo("id", 3), |
| case_sensitive=True, |
| ) |
| result = scan.to_table([scan_task]) |
| assert result.num_rows == 1 |
| assert result.column("id").to_pylist() == [3] |
| assert result.column("name").to_pylist() == ["charlie"] |
| |
| # Test 2: Range filter |
| scan = ArrowScan( |
| table_metadata=table_metadata, |
| io=io, |
| projected_schema=schema, |
| row_filter=And(GreaterThan("age", 30), LessThan("age", 45)), |
| case_sensitive=True, |
| ) |
| result = scan.to_table([scan_task]) |
| assert result.num_rows == 2 |
| assert set(result.column("id").to_pylist()) == {3, 4} |
| |
| # Test 3: String filter |
| scan = ArrowScan( |
| table_metadata=table_metadata, |
| io=io, |
| projected_schema=schema, |
| row_filter=EqualTo("name", "bob"), |
| case_sensitive=True, |
| ) |
| result = scan.to_table([scan_task]) |
| assert result.num_rows == 1 |
| assert result.column("name").to_pylist() == ["bob"] |
| |
| # Test 4: Boolean filter |
| scan = ArrowScan( |
| table_metadata=table_metadata, |
| io=io, |
| projected_schema=schema, |
| row_filter=EqualTo("active", True), |
| case_sensitive=True, |
| ) |
| result = scan.to_table([scan_task]) |
| assert result.num_rows == 3 |
| assert set(result.column("id").to_pylist()) == {1, 3, 4} |
| |
| # Test 5: IN filter |
| scan = ArrowScan( |
| table_metadata=table_metadata, |
| io=io, |
| projected_schema=schema, |
| row_filter=In("id", [1, 3, 5]), |
| case_sensitive=True, |
| ) |
| result = scan.to_table([scan_task]) |
| assert result.num_rows == 3 |
| assert set(result.column("id").to_pylist()) == {1, 3, 5} |
| |
| # Test 6: Complex AND/OR filter |
| scan = ArrowScan( |
| table_metadata=table_metadata, |
| io=io, |
| projected_schema=schema, |
| row_filter=Or(And(EqualTo("active", True), GreaterThan("age", 30)), EqualTo("name", "bob")), |
| case_sensitive=True, |
| ) |
| result = scan.to_table([scan_task]) |
| assert result.num_rows == 3 |
| assert set(result.column("id").to_pylist()) == {2, 3, 4} # bob, charlie, david |
| |
| |
| def test_orc_record_batching_streaming(tmp_path: Path) -> None: |
| """ |
| Test ORC record batching and streaming functionality with multiple files and fragments. |
| This test validates that we get the expected number of batches based on file scan tasks |
| and ORC fragments, providing end-to-end validation of the batching behavior. |
| To run just this test: |
| pytest tests/io/test_pyarrow.py -k test_orc_record_batching_streaming |
| """ |
| import pyarrow as pa |
| import pyarrow.orc as orc |
| |
| from pyiceberg.expressions import AlwaysTrue |
| from pyiceberg.io.pyarrow import ArrowScan, PyArrowFileIO |
| from pyiceberg.manifest import DataFileContent, FileFormat |
| from pyiceberg.partitioning import PartitionSpec |
| from pyiceberg.schema import Schema |
| from pyiceberg.table import FileScanTask |
| from pyiceberg.table.metadata import TableMetadataV2 |
| from pyiceberg.types import IntegerType, StringType |
| |
| # Define schema |
| schema = Schema( |
| NestedField(1, "id", IntegerType(), required=True), |
| NestedField(2, "value", StringType(), required=False), |
| ) |
| |
| # Create table metadata |
| table_metadata = TableMetadataV2( |
| location=f"file://{tmp_path}/test_location", |
| last_column_id=2, |
| format_version=2, |
| schemas=[schema], |
| partition_specs=[PartitionSpec()], |
| properties={ |
| "schema.name-mapping.default": '[{"field-id": 1, "names": ["id"]}, {"field-id": 2, "names": ["value"]}]', |
| }, |
| ) |
| io = PyArrowFileIO() |
| |
| # Test with larger files to better demonstrate batching behavior |
| # PyArrow default batch size is typically 1024 rows, so we'll create files larger than that |
| num_files = 2 |
| rows_per_file = 2000 # Larger than default batch size to ensure multiple batches per file |
| total_rows = num_files * rows_per_file |
| |
| scan_tasks = [] |
| for file_idx in range(num_files): |
| # Create data for this file |
| start_id = file_idx * rows_per_file + 1 |
| end_id = (file_idx + 1) * rows_per_file |
| data = pa.table( |
| { |
| "id": pa.array(range(start_id, end_id + 1), type=pa.int32()), |
| "value": [f"file_{file_idx}_value_{i}" for i in range(start_id, end_id + 1)], |
| } |
| ) |
| |
| # Create ORC file |
| orc_path = tmp_path / f"batch_test_{file_idx}.orc" |
| orc.write_table(data, str(orc_path)) |
| |
| # Create DataFile |
| from pyiceberg.manifest import DataFile |
| from pyiceberg.typedef import Record |
| |
| data_file = DataFile.from_args( |
| content=DataFileContent.DATA, |
| file_path=str(orc_path), |
| file_format=FileFormat.ORC, |
| partition=Record(), |
| file_size_in_bytes=orc_path.stat().st_size, |
| sort_order_id=None, |
| spec_id=0, |
| equality_ids=None, |
| key_metadata=None, |
| record_count=rows_per_file, |
| column_sizes={1: 8000, 2: 16000}, |
| value_counts={1: rows_per_file, 2: rows_per_file}, |
| null_value_counts={1: 0, 2: 0}, |
| nan_value_counts={1: 0, 2: 0}, |
| lower_bounds={1: start_id.to_bytes(4, "little"), 2: f"file_{file_idx}_value_{start_id}".encode()}, |
| upper_bounds={1: end_id.to_bytes(4, "little"), 2: f"file_{file_idx}_value_{end_id}".encode()}, |
| split_offsets=None, |
| ) |
| data_file.spec_id = 0 |
| scan_tasks.append(FileScanTask(data_file=data_file)) |
| |
| # Test 1: Multiple file batching - verify we get batches from all files |
| scan = ArrowScan( |
| table_metadata=table_metadata, |
| io=io, |
| projected_schema=schema, |
| row_filter=AlwaysTrue(), |
| case_sensitive=True, |
| ) |
| |
| batches = list(scan.to_record_batches(scan_tasks)) |
| |
| # Verify we get the expected number of batches |
| # Based on our testing, PyArrow creates 1 batch per file |
| expected_batches = num_files # 1 batch per file |
| assert len(batches) == expected_batches, f"Expected {expected_batches} batches (1 per file), got {len(batches)}" |
| |
| # Verify batch sizes are reasonable (not too large) |
| max_batch_size = max(batch.num_rows for batch in batches) |
| assert max_batch_size <= 2000, f"Batch size {max_batch_size} seems too large for ORC files" |
| assert max_batch_size > 0, "Batch should not be empty" |
| |
| # We shouldn't get more batches than total rows (one batch per row maximum) |
| assert len(batches) <= total_rows, f"Expected at most {total_rows} batches (one per row), got {len(batches)}" |
| |
| # Verify all batches are RecordBatch objects |
| for batch in batches: |
| assert isinstance(batch, pa.RecordBatch), f"Expected RecordBatch, got {type(batch)}" |
| assert batch.num_columns == 2, f"Expected 2 columns, got {batch.num_columns}" |
| assert "id" in batch.schema.names, "Missing 'id' column" |
| assert "value" in batch.schema.names, "Missing 'value' column" |
| |
| # Test 2: Verify data integrity across all batches from all files |
| total_rows = sum(batch.num_rows for batch in batches) |
| assert total_rows == total_rows, f"Expected {total_rows} rows total, got {total_rows}" |
| |
| # Collect all data from batches and verify it spans all files |
| all_ids = [] |
| all_values = [] |
| for batch in batches: |
| all_ids.extend(batch.column("id").to_pylist()) |
| all_values.extend(batch.column("value").to_pylist()) |
| |
| # Verify we have data from all files |
| expected_ids = list(range(1, total_rows + 1)) |
| assert sorted(all_ids) == expected_ids, f"ID data doesn't match expected range 1-{total_rows}" |
| |
| # Verify values contain data from all files |
| file_values = set() |
| for value in all_values: |
| if value.startswith("file_"): |
| file_idx = int(value.split("_")[1]) |
| file_values.add(file_idx) |
| assert file_values == set(range(num_files)), f"Expected values from all {num_files} files, got from files: {file_values}" |
| |
| # Test 3: Verify batch distribution across files |
| # Each file should contribute at least one batch |
| batch_sizes = [batch.num_rows for batch in batches] |
| total_batch_rows = sum(batch_sizes) |
| assert total_batch_rows == total_rows, f"Total batch rows {total_batch_rows} != expected {total_rows}" |
| |
| # Verify we have reasonable batch sizes (not too small, not too large) |
| for batch_size in batch_sizes: |
| assert batch_size > 0, "Batch should not be empty" |
| assert batch_size <= total_rows, f"Batch size {batch_size} should not exceed total rows {total_rows}" |
| |
| # Test 4: Streaming behavior with multiple files |
| scan = ArrowScan( |
| table_metadata=table_metadata, |
| io=io, |
| projected_schema=schema, |
| row_filter=AlwaysTrue(), |
| case_sensitive=True, |
| ) |
| |
| processed_rows = 0 |
| batch_count = 0 |
| file_data_counts = dict.fromkeys(range(num_files), 0) |
| |
| for batch in scan.to_record_batches(scan_tasks): |
| batch_count += 1 |
| processed_rows += batch.num_rows |
| |
| # Count rows per file in this batch |
| for value in batch.column("value").to_pylist(): |
| if value.startswith("file_"): |
| file_idx = int(value.split("_")[1]) |
| file_data_counts[file_idx] += 1 |
| |
| # PyArrow may optimize batching, so we just verify we get reasonable batching |
| assert batch_count >= 1, f"Expected at least 1 batch, got {batch_count}" |
| assert batch_count <= num_files, f"Expected at most {num_files} batches (1 per file), got {batch_count}" |
| assert processed_rows == total_rows, f"Processed {processed_rows} rows, expected {total_rows}" |
| |
| # Verify each file contributed data |
| for file_idx in range(num_files): |
| assert file_data_counts[file_idx] == rows_per_file, ( |
| f"File {file_idx} contributed {file_data_counts[file_idx]} rows, expected {rows_per_file}" |
| ) |
| |
| # Test 5: Column projection with multiple files |
| projected_schema = Schema( |
| NestedField(1, "id", IntegerType(), required=True), |
| ) |
| |
| scan = ArrowScan( |
| table_metadata=table_metadata, |
| io=io, |
| projected_schema=projected_schema, |
| row_filter=AlwaysTrue(), |
| case_sensitive=True, |
| ) |
| |
| batches = list(scan.to_record_batches(scan_tasks)) |
| assert len(batches) >= 1, f"Expected at least 1 batch for projected schema, got {len(batches)}" |
| |
| for batch in batches: |
| assert batch.num_columns == 1, f"Expected 1 column after projection, got {batch.num_columns}" |
| assert "id" in batch.schema.names, "Missing 'id' column after projection" |
| assert "value" not in batch.schema.names, "Should not have 'value' column after projection" |
| |
| # Test 6: Filtering with multiple files |
| from pyiceberg.expressions import GreaterThan |
| |
| scan = ArrowScan( |
| table_metadata=table_metadata, |
| io=io, |
| projected_schema=schema, |
| row_filter=GreaterThan("id", total_rows // 2), # Filter to second half of data |
| case_sensitive=True, |
| ) |
| |
| batches = list(scan.to_record_batches(scan_tasks)) |
| total_filtered_rows = sum(batch.num_rows for batch in batches) |
| expected_filtered = total_rows // 2 |
| assert total_filtered_rows == expected_filtered, ( |
| f"Expected {expected_filtered} rows after filtering, got {total_filtered_rows}" |
| ) |
| |
| # Verify all returned IDs are in the filtered range |
| for batch in batches: |
| ids = batch.column("id").to_pylist() |
| assert all(id_val > total_rows // 2 for id_val in ids), f"Found ID <= {total_rows // 2}: {ids}" |
| |
| # Test 7: Verify batch count matches expected pattern |
| # The number of batches should be >= number of files (one batch per file minimum) |
| # and could be more if ORC creates multiple fragments per file |
| # This validates the end-to-end batching behavior as requested in the PR comment |
| # We expect multiple batches based on file size and configured batch size |
| |
| # Verify we get reasonable batching behavior |
| assert len(batches) >= 1, f"Expected at least 1 batch, got {len(batches)}" |
| assert len(batches) <= total_rows, f"Expected at most {total_rows} batches (one per row), got {len(batches)}" |
| |
| |
| def test_orc_batching_exact_counts_single_file(tmp_path: Path) -> None: |
| """ |
| Test exact batch counts for single ORC files of different sizes. |
| This test explicitly verifies the number of batches PyArrow creates for different file sizes. |
| Note: Uses default ORC writing (1 stripe per file), so expects 1 batch per file. |
| To run just this test: |
| pytest tests/io/test_pyarrow.py -k test_orc_batching_exact_counts_single_file |
| """ |
| import pyarrow as pa |
| import pyarrow.orc as orc |
| |
| from pyiceberg.manifest import DataFileContent, FileFormat |
| from pyiceberg.partitioning import PartitionSpec |
| from pyiceberg.schema import Schema |
| from pyiceberg.table import FileScanTask |
| from pyiceberg.table.metadata import TableMetadataV2 |
| from pyiceberg.types import IntegerType, StringType |
| |
| # Define schema |
| schema = Schema( |
| NestedField(1, "id", IntegerType(), required=True), |
| NestedField(2, "value", StringType(), required=False), |
| ) |
| |
| # Create table metadata |
| table_metadata = TableMetadataV2( |
| location=f"file://{tmp_path}/test_location", |
| last_column_id=2, |
| format_version=2, |
| schemas=[schema], |
| partition_specs=[PartitionSpec()], |
| properties={ |
| "schema.name-mapping.default": '[{"field-id": 1, "names": ["id"]}, {"field-id": 2, "names": ["value"]}]', |
| }, |
| ) |
| io = PyArrowFileIO() |
| |
| # Test different file sizes to understand PyArrow's batching behavior |
| # Note: All files will have 1 stripe (default ORC writing), so 1 batch each |
| test_cases = [ |
| (500, "Small file (1 stripe)"), |
| (1000, "Medium file (1 stripe)"), |
| (2000, "Large file (1 stripe)"), |
| (5000, "Very large file (1 stripe)"), |
| ] |
| |
| for num_rows, _description in test_cases: |
| # Create data |
| data = pa.table( |
| {"id": pa.array(range(1, num_rows + 1), type=pa.int32()), "value": [f"value_{i}" for i in range(1, num_rows + 1)]} |
| ) |
| |
| # Create ORC file |
| orc_path = tmp_path / f"test_{num_rows}_rows.orc" |
| orc.write_table(data, str(orc_path)) |
| |
| # Create DataFile |
| from pyiceberg.manifest import DataFile |
| from pyiceberg.typedef import Record |
| |
| data_file = DataFile.from_args( |
| content=DataFileContent.DATA, |
| file_path=str(orc_path), |
| file_format=FileFormat.ORC, |
| partition=Record(), |
| file_size_in_bytes=orc_path.stat().st_size, |
| sort_order_id=None, |
| spec_id=0, |
| equality_ids=None, |
| key_metadata=None, |
| record_count=num_rows, |
| column_sizes={1: num_rows * 4, 2: num_rows * 8}, |
| value_counts={1: num_rows, 2: num_rows}, |
| null_value_counts={1: 0, 2: 0}, |
| nan_value_counts={1: 0, 2: 0}, |
| lower_bounds={1: b"\x01\x00\x00\x00", 2: b"value_1"}, |
| upper_bounds={1: num_rows.to_bytes(4, "little"), 2: f"value_{num_rows}".encode()}, |
| split_offsets=None, |
| ) |
| data_file.spec_id = 0 |
| |
| scan_task = FileScanTask(data_file=data_file) |
| scan = ArrowScan( |
| table_metadata=table_metadata, |
| io=io, |
| projected_schema=schema, |
| row_filter=AlwaysTrue(), |
| case_sensitive=True, |
| ) |
| batches = list(scan.to_record_batches([scan_task])) |
| |
| # Verify exact batch count and sizes |
| total_batch_rows = sum(batch.num_rows for batch in batches) |
| assert total_batch_rows == num_rows, f"Total rows mismatch: expected {num_rows}, got {total_batch_rows}" |
| |
| # Verify data integrity |
| all_ids = [] |
| for batch in batches: |
| all_ids.extend(batch.column("id").to_pylist()) |
| assert sorted(all_ids) == list(range(1, num_rows + 1)), f"Data integrity check failed for {num_rows} rows" |
| |
| |
| def test_orc_batching_exact_counts_multiple_files(tmp_path: Path) -> None: |
| """ |
| Test exact batch counts for multiple ORC files of different sizes and counts. |
| This test explicitly verifies the number of batches PyArrow creates for different file configurations. |
| Note: Uses default ORC writing (1 stripe per file), so expects 1 batch per file. |
| To run just this test: |
| pytest tests/io/test_pyarrow.py -k test_orc_batching_exact_counts_multiple_files |
| """ |
| import pyarrow as pa |
| import pyarrow.orc as orc |
| |
| from pyiceberg.expressions import AlwaysTrue |
| from pyiceberg.io.pyarrow import ArrowScan, PyArrowFileIO |
| from pyiceberg.manifest import DataFileContent, FileFormat |
| from pyiceberg.partitioning import PartitionSpec |
| from pyiceberg.schema import Schema |
| from pyiceberg.table import FileScanTask |
| from pyiceberg.table.metadata import TableMetadataV2 |
| from pyiceberg.types import IntegerType, StringType |
| |
| # Define schema |
| schema = Schema( |
| NestedField(1, "id", IntegerType(), required=True), |
| NestedField(2, "value", StringType(), required=False), |
| ) |
| |
| # Create table metadata |
| table_metadata = TableMetadataV2( |
| location=f"file://{tmp_path}/test_location", |
| last_column_id=2, |
| format_version=2, |
| schemas=[schema], |
| partition_specs=[PartitionSpec()], |
| properties={ |
| "schema.name-mapping.default": '[{"field-id": 1, "names": ["id"]}, {"field-id": 2, "names": ["value"]}]', |
| }, |
| ) |
| io = PyArrowFileIO() |
| |
| # Test different file configurations to understand PyArrow's batching behavior |
| # Note: All files will have 1 stripe each (default ORC writing), so 1 batch per file |
| test_cases = [ |
| (2, 500, "2 files, 500 rows each (1 stripe each)"), |
| (3, 1000, "3 files, 1000 rows each (1 stripe each)"), |
| (4, 750, "4 files, 750 rows each (1 stripe each)"), |
| (2, 2000, "2 files, 2000 rows each (1 stripe each)"), |
| ] |
| |
| for num_files, rows_per_file, description in test_cases: |
| total_rows = num_files * rows_per_file |
| scan_tasks = [] |
| |
| for file_idx in range(num_files): |
| # Create data for this file |
| start_id = file_idx * rows_per_file + 1 |
| end_id = (file_idx + 1) * rows_per_file |
| data = pa.table( |
| { |
| "id": pa.array(range(start_id, end_id + 1), type=pa.int32()), |
| "value": [f"file_{file_idx}_value_{i}" for i in range(start_id, end_id + 1)], |
| } |
| ) |
| |
| # Create ORC file |
| orc_path = tmp_path / f"multi_test_{file_idx}_{rows_per_file}_rows.orc" |
| orc.write_table(data, str(orc_path)) |
| |
| # Create DataFile |
| from pyiceberg.manifest import DataFile |
| from pyiceberg.typedef import Record |
| |
| data_file = DataFile.from_args( |
| content=DataFileContent.DATA, |
| file_path=str(orc_path), |
| file_format=FileFormat.ORC, |
| partition=Record(), |
| file_size_in_bytes=orc_path.stat().st_size, |
| sort_order_id=None, |
| spec_id=0, |
| equality_ids=None, |
| key_metadata=None, |
| record_count=rows_per_file, |
| column_sizes={1: rows_per_file * 4, 2: rows_per_file * 8}, |
| value_counts={1: rows_per_file, 2: rows_per_file}, |
| null_value_counts={1: 0, 2: 0}, |
| nan_value_counts={1: 0, 2: 0}, |
| lower_bounds={1: start_id.to_bytes(4, "little"), 2: f"file_{file_idx}_value_{start_id}".encode()}, |
| upper_bounds={1: end_id.to_bytes(4, "little"), 2: f"file_{file_idx}_value_{end_id}".encode()}, |
| split_offsets=None, |
| ) |
| data_file.spec_id = 0 |
| scan_tasks.append(FileScanTask(data_file=data_file)) |
| |
| # Test batching behavior across multiple files |
| scan = ArrowScan( |
| table_metadata=table_metadata, |
| io=io, |
| projected_schema=schema, |
| row_filter=AlwaysTrue(), |
| case_sensitive=True, |
| ) |
| |
| batches = list(scan.to_record_batches(scan_tasks)) |
| |
| # Verify exact batch count and sizes |
| total_batch_rows = sum(batch.num_rows for batch in batches) |
| assert total_batch_rows == total_rows, f"Total rows mismatch: expected {total_rows}, got {total_batch_rows}" |
| |
| # Verify data spans all files |
| all_ids = [] |
| file_data_counts = dict.fromkeys(range(num_files), 0) |
| |
| for batch in batches: |
| batch_ids = batch.column("id").to_pylist() |
| all_ids.extend(batch_ids) |
| |
| # Count rows per file in this batch |
| for value in batch.column("value").to_pylist(): |
| if value.startswith("file_"): |
| file_idx = int(value.split("_")[1]) |
| file_data_counts[file_idx] += 1 |
| |
| # Verify we have data from all files |
| assert sorted(all_ids) == list(range(1, total_rows + 1)), f"Data integrity check failed for {description}" |
| |
| # Verify each file contributed data |
| for file_idx in range(num_files): |
| assert file_data_counts[file_idx] == rows_per_file, ( |
| f"File {file_idx} contributed {file_data_counts[file_idx]} rows, expected {rows_per_file}" |
| ) |
| |
| |
| def test_orc_field_id_extraction() -> None: |
| """ |
| Test ORC field ID extraction from PyArrow field metadata. |
| To run just this test: |
| pytest tests/io/test_pyarrow.py -k test_orc_field_id_extraction |
| """ |
| import pyarrow as pa |
| |
| from pyiceberg.io.pyarrow import ORC_FIELD_ID_KEY, PYARROW_PARQUET_FIELD_ID_KEY, _get_field_id |
| |
| # Test 1: Parquet field ID extraction |
| field_parquet = pa.field("test_parquet", pa.string(), metadata={PYARROW_PARQUET_FIELD_ID_KEY: b"123"}) |
| field_id = _get_field_id(field_parquet) |
| assert field_id == 123, f"Expected Parquet field ID 123, got {field_id}" |
| |
| # Test 2: ORC field ID extraction |
| field_orc = pa.field("test_orc", pa.string(), metadata={ORC_FIELD_ID_KEY: b"456"}) |
| field_id = _get_field_id(field_orc) |
| assert field_id == 456, f"Expected ORC field ID 456, got {field_id}" |
| |
| # Test 3: No field ID |
| field_no_id = pa.field("test_no_id", pa.string()) |
| field_id = _get_field_id(field_no_id) |
| assert field_id is None, f"Expected None for field without ID, got {field_id}" |
| |
| # Test 4: Both field IDs present (should prefer Parquet) |
| field_both = pa.field("test_both", pa.string(), metadata={PYARROW_PARQUET_FIELD_ID_KEY: b"123", ORC_FIELD_ID_KEY: b"456"}) |
| field_id = _get_field_id(field_both) |
| assert field_id == 123, f"Expected Parquet field ID 123 (preferred), got {field_id}" |
| |
| # Test 5: Empty metadata |
| field_empty_metadata = pa.field("test_empty", pa.string(), metadata={}) |
| field_id = _get_field_id(field_empty_metadata) |
| assert field_id is None, f"Expected None for field with empty metadata, got {field_id}" |
| |
| # Test 6: Invalid field ID format |
| field_invalid = pa.field("test_invalid", pa.string(), metadata={ORC_FIELD_ID_KEY: b"not_a_number"}) |
| try: |
| field_id = _get_field_id(field_invalid) |
| raise AssertionError("Expected ValueError for invalid field ID format") |
| except ValueError: |
| pass # Expected behavior |
| |
| # Test 7: Different data types |
| field_int = pa.field("test_int", pa.int32(), metadata={ORC_FIELD_ID_KEY: b"789"}) |
| field_id = _get_field_id(field_int) |
| assert field_id == 789, f"Expected ORC field ID 789 for int field, got {field_id}" |
| |
| field_bool = pa.field("test_bool", pa.bool_(), metadata={ORC_FIELD_ID_KEY: b"101"}) |
| field_id = _get_field_id(field_bool) |
| assert field_id == 101, f"Expected ORC field ID 101 for bool field, got {field_id}" |
| |
| |
| def test_orc_schema_with_field_ids(tmp_path: Path) -> None: |
| """ |
| Test ORC reading with actual field IDs embedded in the schema. |
| This test creates an ORC file with field IDs and reads it without name mapping. |
| To run just this test: |
| pytest tests/io/test_pyarrow.py -k test_orc_schema_with_field_ids |
| """ |
| import pyarrow as pa |
| import pyarrow.orc as orc |
| |
| from pyiceberg.expressions import AlwaysTrue |
| from pyiceberg.io.pyarrow import ORC_FIELD_ID_KEY, ArrowScan, PyArrowFileIO |
| from pyiceberg.manifest import DataFileContent, FileFormat |
| from pyiceberg.partitioning import PartitionSpec |
| from pyiceberg.schema import Schema |
| from pyiceberg.table import FileScanTask |
| from pyiceberg.table.metadata import TableMetadataV2 |
| from pyiceberg.types import IntegerType, StringType |
| |
| # Define schema |
| schema = Schema( |
| NestedField(1, "id", IntegerType(), required=True), |
| NestedField(2, "name", StringType(), required=False), |
| ) |
| |
| # Create PyArrow schema with ORC field IDs |
| arrow_schema = pa.schema( |
| [ |
| pa.field("id", pa.int32(), metadata={ORC_FIELD_ID_KEY: b"1"}), |
| pa.field("name", pa.string(), metadata={ORC_FIELD_ID_KEY: b"2"}), |
| ] |
| ) |
| |
| # Create data with the schema that has field IDs |
| data = pa.table({"id": pa.array([1, 2, 3], type=pa.int32()), "name": ["alice", "bob", "charlie"]}, schema=arrow_schema) |
| |
| # Create ORC file |
| orc_path = tmp_path / "field_id_test.orc" |
| orc.write_table(data, str(orc_path)) |
| |
| # Create table metadata WITHOUT name mapping (should work with field IDs) |
| table_metadata = TableMetadataV2( |
| location=f"file://{tmp_path}/test_location", |
| last_column_id=2, |
| format_version=2, |
| schemas=[schema], |
| partition_specs=[PartitionSpec()], |
| properties={ |
| # No name mapping - should work with field IDs |
| }, |
| ) |
| io = PyArrowFileIO() |
| |
| # Create DataFile |
| from pyiceberg.manifest import DataFile |
| from pyiceberg.typedef import Record |
| |
| data_file = DataFile.from_args( |
| content=DataFileContent.DATA, |
| file_path=str(orc_path), |
| file_format=FileFormat.ORC, |
| partition=Record(), |
| file_size_in_bytes=orc_path.stat().st_size, |
| sort_order_id=None, |
| spec_id=0, |
| equality_ids=None, |
| key_metadata=None, |
| record_count=3, |
| column_sizes={1: 12, 2: 30}, |
| value_counts={1: 3, 2: 3}, |
| null_value_counts={1: 0, 2: 0}, |
| nan_value_counts={1: 0, 2: 0}, |
| lower_bounds={1: b"\x01\x00\x00\x00", 2: b"alice"}, |
| upper_bounds={1: b"\x03\x00\x00\x00", 2: b"charlie"}, |
| split_offsets=None, |
| ) |
| data_file.spec_id = 0 |
| |
| # Read back using ArrowScan - should work without name mapping |
| scan = ArrowScan( |
| table_metadata=table_metadata, |
| io=io, |
| projected_schema=schema, |
| row_filter=AlwaysTrue(), |
| case_sensitive=True, |
| ) |
| scan_task = FileScanTask(data_file=data_file) |
| table_read = scan.to_table([scan_task]) |
| |
| # Verify the data was read correctly |
| assert table_read.num_rows == 3 |
| assert table_read.num_columns == 2 |
| assert table_read.column_names == ["id", "name"] |
| |
| # Verify data matches |
| assert table_read.column("id").to_pylist() == [1, 2, 3] |
| assert table_read.column("name").to_pylist() == ["alice", "bob", "charlie"] |
| |
| |
| def test_orc_schema_conversion_with_field_ids() -> None: |
| """ |
| Test that schema_to_pyarrow correctly adds ORC field IDs when file_format is specified. |
| To run just this test: |
| pytest tests/io/test_pyarrow.py -k test_orc_schema_conversion_with_field_ids |
| """ |
| from pyiceberg.io.pyarrow import ORC_FIELD_ID_KEY, PYARROW_PARQUET_FIELD_ID_KEY, schema_to_pyarrow |
| from pyiceberg.manifest import FileFormat |
| from pyiceberg.schema import Schema |
| from pyiceberg.types import IntegerType, StringType |
| |
| # Define schema |
| schema = Schema( |
| NestedField(1, "id", IntegerType(), required=True), |
| NestedField(2, "name", StringType(), required=False), |
| ) |
| |
| # Test 1: Default behavior (should add Parquet field IDs) |
| arrow_schema_default = schema_to_pyarrow(schema, include_field_ids=True) |
| |
| id_field = arrow_schema_default.field(0) # id field |
| name_field = arrow_schema_default.field(1) # name field |
| |
| assert id_field.metadata[PYARROW_PARQUET_FIELD_ID_KEY] == b"1" |
| assert name_field.metadata[PYARROW_PARQUET_FIELD_ID_KEY] == b"2" |
| assert ORC_FIELD_ID_KEY not in id_field.metadata |
| assert ORC_FIELD_ID_KEY not in name_field.metadata |
| |
| # Test 2: Explicitly specify ORC format |
| arrow_schema_orc = schema_to_pyarrow(schema, include_field_ids=True, file_format=FileFormat.ORC) |
| |
| id_field_orc = arrow_schema_orc.field(0) # id field |
| name_field_orc = arrow_schema_orc.field(1) # name field |
| |
| assert id_field_orc.metadata[ORC_FIELD_ID_KEY] == b"1" |
| assert name_field_orc.metadata[ORC_FIELD_ID_KEY] == b"2" |
| assert PYARROW_PARQUET_FIELD_ID_KEY not in id_field_orc.metadata |
| assert PYARROW_PARQUET_FIELD_ID_KEY not in name_field_orc.metadata |
| |
| # Test 3: No field IDs |
| arrow_schema_no_ids = schema_to_pyarrow(schema, include_field_ids=False, file_format=FileFormat.ORC) |
| |
| id_field_no_ids = arrow_schema_no_ids.field(0) |
| name_field_no_ids = arrow_schema_no_ids.field(1) |
| |
| assert ORC_FIELD_ID_KEY not in id_field_no_ids.metadata |
| assert ORC_FIELD_ID_KEY not in name_field_no_ids.metadata |
| assert PYARROW_PARQUET_FIELD_ID_KEY not in id_field_no_ids.metadata |
| assert PYARROW_PARQUET_FIELD_ID_KEY not in name_field_no_ids.metadata |
| |
| |
| def test_orc_schema_conversion_with_required_attribute() -> None: |
| """ |
| Test that schema_to_pyarrow correctly adds ORC iceberg.required attribute. |
| To run just this test: |
| pytest tests/io/test_pyarrow.py -k test_orc_schema_conversion_with_required_attribute |
| """ |
| from pyiceberg.io.pyarrow import ORC_FIELD_REQUIRED_KEY, schema_to_pyarrow |
| from pyiceberg.manifest import FileFormat |
| from pyiceberg.schema import Schema |
| from pyiceberg.types import IntegerType, StringType |
| |
| # Define schema |
| schema = Schema( |
| NestedField(1, "id", IntegerType(), required=True), |
| NestedField(2, "name", StringType(), required=False), |
| ) |
| |
| # Test 1: Specify Parquet format |
| arrow_schema_default = schema_to_pyarrow(schema, file_format=FileFormat.PARQUET) |
| |
| id_field = arrow_schema_default.field(0) |
| name_field = arrow_schema_default.field(1) |
| |
| assert ORC_FIELD_REQUIRED_KEY not in id_field.metadata |
| assert ORC_FIELD_REQUIRED_KEY not in name_field.metadata |
| |
| # Test 2: Specify ORC format |
| arrow_schema_orc = schema_to_pyarrow(schema, file_format=FileFormat.ORC) |
| |
| id_field_orc = arrow_schema_orc.field(0) |
| name_field_orc = arrow_schema_orc.field(1) |
| |
| assert id_field_orc.metadata[ORC_FIELD_REQUIRED_KEY] == b"true" |
| assert name_field_orc.metadata[ORC_FIELD_REQUIRED_KEY] == b"false" |
| |
| |
| def test_orc_batching_behavior_documentation(tmp_path: Path) -> None: |
| """ |
| Document and verify PyArrow's exact batching behavior for ORC files. |
| This test serves as comprehensive documentation of how PyArrow batches ORC files. |
| |
| ORC BATCHING BEHAVIOR SUMMARY: |
| ============================= |
| |
| 1. STRIPE-BASED BATCHING: |
| - PyArrow creates exactly 1 batch per ORC stripe |
| - This is similar to how Parquet creates 1 batch per row group |
| - Number of batches = Number of stripes in the ORC file |
| |
| 2. DEFAULT BEHAVIOR: |
| - Default ORC writing creates 1 stripe per file (64MB default stripe size) |
| - Therefore, most ORC files have 1 batch per file by default |
| - This is why many tests show "1 batch per file" behavior |
| |
| 3. CONFIGURABLE BATCHING: |
| - ORC CAN have multiple batches per file when configured with multiple stripes |
| - Use stripe_size parameter when writing ORC files to control batching |
| - stripe_size < 200KB: PyArrow ignores the parameter, uses default 1024 rows per stripe |
| - stripe_size >= 200KB: PyArrow respects the parameter and creates stripes accordingly |
| |
| 4. PYARROW CONFIGURATION: |
| - PyIceberg sets buffer_size=8MB for both Parquet and ORC |
| - Parquet: Gets the 8MB buffer_size parameter (ds.ParquetFileFormat supports it) |
| - ORC: Ignores the 8MB buffer_size parameter (ds.OrcFileFormat doesn't support it) |
| - This means ORC uses PyArrow's default batching behavior (based on stripes) |
| |
| 5. KEY DIFFERENCES FROM PARQUET: |
| - Parquet: 1 FileScanTask → 1 File → 1 PyArrow Fragment → N Batches (based on row groups) |
| - ORC: 1 FileScanTask → 1 File → 1 PyArrow Fragment → N Batches (based on stripes) |
| - Both formats support multiple batches per file when configured properly |
| - The difference is in default configuration, not fundamental behavior |
| |
| 6. TESTING IMPLICATIONS: |
| - Tests using default ORC writing will show 1 batch per file |
| - Tests using custom stripe_size >= 200KB will show multiple batches per file |
| - Always verify the actual number of stripes in ORC files when testing batching |
| |
| To run just this test: |
| pytest tests/io/test_pyarrow.py -k test_orc_batching_behavior_documentation |
| """ |
| import pyarrow as pa |
| import pyarrow.orc as orc |
| |
| from pyiceberg.expressions import AlwaysTrue |
| from pyiceberg.io.pyarrow import ArrowScan, PyArrowFileIO |
| from pyiceberg.manifest import DataFileContent, FileFormat |
| from pyiceberg.partitioning import PartitionSpec |
| from pyiceberg.schema import Schema |
| from pyiceberg.table import FileScanTask |
| from pyiceberg.table.metadata import TableMetadataV2 |
| from pyiceberg.types import IntegerType, StringType |
| |
| # Define schema |
| schema = Schema( |
| NestedField(1, "id", IntegerType(), required=True), |
| NestedField(2, "value", StringType(), required=False), |
| ) |
| |
| # Create table metadata |
| table_metadata = TableMetadataV2( |
| location=f"file://{tmp_path}/test_location", |
| last_column_id=2, |
| format_version=2, |
| schemas=[schema], |
| partition_specs=[PartitionSpec()], |
| properties={ |
| "schema.name-mapping.default": '[{"field-id": 1, "names": ["id"]}, {"field-id": 2, "names": ["value"]}]', |
| }, |
| ) |
| io = PyArrowFileIO() |
| |
| # Test cases that document the exact behavior (using default ORC writing = 1 stripe per file) |
| test_cases = [ |
| # (file_count, rows_per_file, expected_batches, description) |
| (1, 100, 1, "Single small file (1 stripe)"), |
| (1, 1000, 1, "Single medium file (1 stripe)"), |
| (1, 5000, 1, "Single large file (1 stripe)"), |
| (2, 500, 2, "Two small files (1 stripe each)"), |
| (3, 1000, 3, "Three medium files (1 stripe each)"), |
| (4, 750, 4, "Four small files (1 stripe each)"), |
| (2, 2000, 2, "Two large files (1 stripe each)"), |
| ] |
| |
| for file_count, rows_per_file, expected_batches, description in test_cases: |
| total_rows = file_count * rows_per_file |
| scan_tasks = [] |
| |
| for file_idx in range(file_count): |
| # Create data for this file |
| start_id = file_idx * rows_per_file + 1 |
| end_id = (file_idx + 1) * rows_per_file |
| data = pa.table( |
| { |
| "id": pa.array(range(start_id, end_id + 1), type=pa.int32()), |
| "value": [f"file_{file_idx}_value_{i}" for i in range(start_id, end_id + 1)], |
| } |
| ) |
| |
| # Create ORC file |
| orc_path = tmp_path / f"doc_test_{file_idx}_{rows_per_file}_rows.orc" |
| orc.write_table(data, str(orc_path)) |
| |
| # Create DataFile |
| from pyiceberg.manifest import DataFile |
| from pyiceberg.typedef import Record |
| |
| data_file = DataFile.from_args( |
| content=DataFileContent.DATA, |
| file_path=str(orc_path), |
| file_format=FileFormat.ORC, |
| partition=Record(), |
| file_size_in_bytes=orc_path.stat().st_size, |
| sort_order_id=None, |
| spec_id=0, |
| equality_ids=None, |
| key_metadata=None, |
| record_count=rows_per_file, |
| column_sizes={1: rows_per_file * 4, 2: rows_per_file * 8}, |
| value_counts={1: rows_per_file, 2: rows_per_file}, |
| null_value_counts={1: 0, 2: 0}, |
| nan_value_counts={1: 0, 2: 0}, |
| lower_bounds={1: start_id.to_bytes(4, "little"), 2: f"file_{file_idx}_value_{start_id}".encode()}, |
| upper_bounds={1: end_id.to_bytes(4, "little"), 2: f"file_{file_idx}_value_{end_id}".encode()}, |
| split_offsets=None, |
| ) |
| data_file.spec_id = 0 |
| scan_tasks.append(FileScanTask(data_file=data_file)) |
| |
| # Test batching behavior |
| scan = ArrowScan( |
| table_metadata=table_metadata, |
| io=io, |
| projected_schema=schema, |
| row_filter=AlwaysTrue(), |
| case_sensitive=True, |
| ) |
| |
| batches = list(scan.to_record_batches(scan_tasks)) |
| |
| # Verify exact batch count |
| assert len(batches) == expected_batches, f"Expected {expected_batches} batches, got {len(batches)} for {description}" |
| |
| # Verify total rows |
| total_batch_rows = sum(batch.num_rows for batch in batches) |
| assert total_batch_rows == total_rows, f"Total rows mismatch: expected {total_rows}, got {total_batch_rows}" |
| |
| # Verify data integrity |
| all_ids = [] |
| for batch in batches: |
| all_ids.extend(batch.column("id").to_pylist()) |
| assert sorted(all_ids) == list(range(1, total_rows + 1)), f"Data integrity check failed for {description}" |
| |
| |
| def test_parquet_vs_orc_batching_behavior_comparison(tmp_path: Path) -> None: |
| """ |
| Compare Parquet vs ORC batching behavior to document the key differences. |
| |
| Key differences: |
| - PARQUET: 1 FileScanTask → 1 File → 1 PyArrow Fragment → N Batches (based on row groups) |
| - ORC: 1 FileScanTask → 1 File → 1 PyArrow Fragment → N Batches (based on stripes) |
| |
| To run just this test: |
| pytest tests/io/test_pyarrow.py -k test_parquet_vs_orc_batching_behavior_comparison |
| """ |
| import pyarrow as pa |
| import pyarrow.orc as orc |
| import pyarrow.parquet as pq |
| |
| from pyiceberg.expressions import AlwaysTrue |
| from pyiceberg.io.pyarrow import ArrowScan, PyArrowFileIO |
| from pyiceberg.manifest import DataFileContent, FileFormat |
| from pyiceberg.partitioning import PartitionSpec |
| from pyiceberg.schema import Schema |
| from pyiceberg.table import FileScanTask |
| from pyiceberg.table.metadata import TableMetadataV2 |
| from pyiceberg.types import IntegerType, NestedField, StringType |
| |
| # Define schema |
| schema = Schema( |
| NestedField(1, "id", IntegerType(), required=True), |
| NestedField(2, "value", StringType(), required=False), |
| ) |
| |
| # Create table metadata |
| table_metadata = TableMetadataV2( |
| location=f"file://{tmp_path}/test_location", |
| last_column_id=2, |
| format_version=2, |
| schemas=[schema], |
| partition_specs=[PartitionSpec()], |
| properties={ |
| "schema.name-mapping.default": '[{"field-id": 1, "names": ["id"]}, {"field-id": 2, "names": ["value"]}]', |
| }, |
| ) |
| io = PyArrowFileIO() |
| |
| # Test Parquet with different row group sizes |
| parquet_test_cases = [ |
| (1000, "Small row groups"), |
| (2000, "Medium row groups"), |
| (5000, "Large row groups"), |
| ] |
| |
| for row_group_size, _description in parquet_test_cases: |
| # Create data |
| data = pa.table({"id": pa.array(range(1, 10001), type=pa.int32()), "value": [f"value_{i}" for i in range(1, 10001)]}) |
| |
| # Create Parquet file with specific row group size |
| parquet_path = tmp_path / f"parquet_test_{row_group_size}.parquet" |
| pq.write_table(data, str(parquet_path), row_group_size=row_group_size) |
| |
| # Create DataFile |
| from pyiceberg.manifest import DataFile |
| from pyiceberg.typedef import Record |
| |
| data_file = DataFile.from_args( |
| content=DataFileContent.DATA, |
| file_path=str(parquet_path), |
| file_format=FileFormat.PARQUET, |
| partition=Record(), |
| file_size_in_bytes=parquet_path.stat().st_size, |
| sort_order_id=None, |
| spec_id=0, |
| equality_ids=None, |
| key_metadata=None, |
| record_count=10000, |
| column_sizes={1: 40000, 2: 80000}, |
| value_counts={1: 10000, 2: 10000}, |
| null_value_counts={1: 0, 2: 0}, |
| nan_value_counts={1: 0, 2: 0}, |
| lower_bounds={1: b"\x01\x00\x00\x00", 2: b"value_1"}, |
| upper_bounds={1: b"\x10'\x00\x00", 2: b"value_10000"}, |
| split_offsets=None, |
| ) |
| data_file.spec_id = 0 |
| scan_task = FileScanTask(data_file=data_file) |
| |
| # Test batching behavior |
| scan = ArrowScan( |
| table_metadata=table_metadata, |
| io=io, |
| projected_schema=schema, |
| row_filter=AlwaysTrue(), |
| case_sensitive=True, |
| ) |
| |
| batches = list(scan.to_record_batches([scan_task])) |
| expected_batches = 10000 // row_group_size # Number of row groups |
| |
| # Verify exact batch count based on row groups |
| assert len(batches) == expected_batches, ( |
| f"Expected {expected_batches} batches for row_group_size={row_group_size}, got {len(batches)}" |
| ) |
| |
| # Verify total rows |
| total_rows = sum(batch.num_rows for batch in batches) |
| assert total_rows == 10000, f"Expected 10000 total rows, got {total_rows}" |
| |
| orc_test_cases = [ |
| (1000, "Small file"), |
| (5000, "Medium file"), |
| (10000, "Large file"), |
| ] |
| |
| for file_size, _description in orc_test_cases: |
| # Create data |
| data = pa.table( |
| {"id": pa.array(range(1, file_size + 1), type=pa.int32()), "value": [f"value_{i}" for i in range(1, file_size + 1)]} |
| ) |
| |
| # Create ORC file |
| orc_path = tmp_path / f"orc_test_{file_size}.orc" |
| orc.write_table(data, str(orc_path)) |
| |
| # Create DataFile |
| from pyiceberg.manifest import DataFile |
| from pyiceberg.typedef import Record |
| |
| data_file = DataFile.from_args( |
| content=DataFileContent.DATA, |
| file_path=str(orc_path), |
| file_format=FileFormat.ORC, |
| partition=Record(), |
| file_size_in_bytes=orc_path.stat().st_size, |
| sort_order_id=None, |
| spec_id=0, |
| equality_ids=None, |
| key_metadata=None, |
| record_count=file_size, |
| column_sizes={1: file_size * 4, 2: file_size * 8}, |
| value_counts={1: file_size, 2: file_size}, |
| null_value_counts={1: 0, 2: 0}, |
| nan_value_counts={1: 0, 2: 0}, |
| lower_bounds={1: b"\x01\x00\x00\x00", 2: b"value_1"}, |
| upper_bounds={1: file_size.to_bytes(4, "little"), 2: f"value_{file_size}".encode()}, |
| split_offsets=None, |
| ) |
| data_file.spec_id = 0 |
| scan_task = FileScanTask(data_file=data_file) |
| |
| # Test batching behavior |
| scan = ArrowScan( |
| table_metadata=table_metadata, |
| io=io, |
| projected_schema=schema, |
| row_filter=AlwaysTrue(), |
| case_sensitive=True, |
| ) |
| |
| batches = list(scan.to_record_batches([scan_task])) |
| |
| # Verify ORC creates 1 batch per file (with default stripe configuration) |
| # Note: This is because default ORC writing creates 1 stripe per file |
| assert len(batches) == 1, ( |
| f"Expected 1 batch for ORC file with {file_size} rows (default stripe config), got {len(batches)}" |
| ) |
| |
| # Verify total rows |
| total_rows = sum(batch.num_rows for batch in batches) |
| assert total_rows == file_size, f"Expected {file_size} total rows, got {total_rows}" |
| |
| |
| def test_orc_stripe_size_batch_size_compression_interaction(tmp_path: Path) -> None: |
| """ |
| Test that demonstrates how stripe size, batch size, and compression interact |
| to affect ORC batching behavior. |
| |
| This test shows: |
| 1. How stripe_size affects the number of stripes (and therefore batches) |
| 2. How batch_size affects the number of stripes when stripe_size is small |
| 3. How compression affects both stripe count and file size |
| 4. The relationship between uncompressed target size and actual file size |
| |
| To run just this test: |
| pytest tests/io/test_pyarrow.py -k test_orc_stripe_size_batch_size_compression_interaction |
| """ |
| import pyarrow as pa |
| import pyarrow.orc as orc |
| |
| from pyiceberg.manifest import DataFileContent, FileFormat |
| from pyiceberg.partitioning import PartitionSpec |
| from pyiceberg.schema import Schema |
| from pyiceberg.table import FileScanTask |
| from pyiceberg.table.metadata import TableMetadataV2 |
| from pyiceberg.types import IntegerType, StringType |
| |
| # Define schema |
| schema = Schema( |
| NestedField(1, "id", IntegerType(), required=True), |
| NestedField(2, "value", StringType(), required=False), |
| ) |
| |
| # Create table metadata |
| table_metadata = TableMetadataV2( |
| location=f"file://{tmp_path}/test_location", |
| last_column_id=2, |
| format_version=2, |
| schemas=[schema], |
| partition_specs=[PartitionSpec()], |
| properties={ |
| "schema.name-mapping.default": '[{"field-id": 1, "names": ["id"]}, {"field-id": 2, "names": ["value"]}]', |
| }, |
| ) |
| io = PyArrowFileIO() |
| |
| # Create test data |
| data = pa.table({"id": pa.array(range(1, 10001), type=pa.int32()), "value": [f"value_{i}" for i in range(1, 10001)]}) |
| |
| # Test different combinations |
| test_cases = [ |
| # (stripe_size, batch_size, compression, description) |
| (200000, None, "uncompressed", "200KB stripe, no batch limit, uncompressed"), |
| (200000, 1000, "uncompressed", "200KB stripe, 1000 batch, uncompressed"), |
| (200000, None, "snappy", "200KB stripe, no batch limit, snappy"), |
| (100000, None, "uncompressed", "100KB stripe, no batch limit, uncompressed"), |
| (500000, None, "uncompressed", "500KB stripe, no batch limit, uncompressed"), |
| (None, 1000, "uncompressed", "No stripe limit, 1000 batch, uncompressed"), |
| (None, 2000, "uncompressed", "No stripe limit, 2000 batch, uncompressed"), |
| ] |
| |
| for stripe_size, batch_size, compression, description in test_cases: |
| # Create ORC file with specific parameters |
| orc_path = tmp_path / f"orc_test_{hash(description)}.orc" |
| |
| write_kwargs: dict[str, Any] = {"compression": compression} |
| if stripe_size is not None: |
| write_kwargs["stripe_size"] = stripe_size |
| if batch_size is not None: |
| write_kwargs["batch_size"] = batch_size |
| |
| orc.write_table(data, str(orc_path), **write_kwargs) |
| |
| # Analyze the ORC file |
| file_size = orc_path.stat().st_size |
| orc_file = orc.ORCFile(str(orc_path)) |
| actual_stripes = orc_file.nstripes |
| |
| # Assert basic file properties |
| assert file_size > 0, f"ORC file should have non-zero size for {description}" |
| assert actual_stripes > 0, f"ORC file should have at least one stripe for {description}" |
| |
| # Assert stripe count expectations based on stripe_size and compression |
| if stripe_size is not None: |
| # With stripe_size specified, we expect multiple stripes for small sizes |
| # But compression can make the data small enough to fit in one stripe |
| if stripe_size <= 200000 and compression == "uncompressed": # 200KB or smaller, uncompressed |
| assert actual_stripes > 1, ( |
| f"Expected multiple stripes for small stripe_size={stripe_size} in {description}, got {actual_stripes}" |
| ) |
| else: # Larger stripe sizes or compressed data might result in single stripe |
| assert actual_stripes >= 1, ( |
| f"Expected at least 1 stripe for stripe_size={stripe_size} in {description}, got {actual_stripes}" |
| ) |
| else: |
| # Without stripe_size, we expect at least 1 stripe |
| assert actual_stripes >= 1, f"Expected at least 1 stripe for {description}, got {actual_stripes}" |
| |
| # Test PyArrow batching |
| from pyiceberg.manifest import DataFile |
| from pyiceberg.typedef import Record |
| |
| data_file = DataFile.from_args( |
| content=DataFileContent.DATA, |
| file_path=str(orc_path), |
| file_format=FileFormat.ORC, |
| partition=Record(), |
| file_size_in_bytes=file_size, |
| sort_order_id=None, |
| spec_id=0, |
| equality_ids=None, |
| key_metadata=None, |
| record_count=10000, |
| column_sizes={1: 40000, 2: 80000}, |
| value_counts={1: 10000, 2: 10000}, |
| null_value_counts={1: 0, 2: 0}, |
| nan_value_counts={1: 0, 2: 0}, |
| lower_bounds={1: b"\x01\x00\x00\x00", 2: b"value_1"}, |
| upper_bounds={1: b"\x10'\x00\x00", 2: b"value_10000"}, |
| split_offsets=None, |
| ) |
| data_file.spec_id = 0 |
| |
| # Test batching behavior |
| scan_task = FileScanTask(data_file=data_file) |
| scan = ArrowScan( |
| table_metadata=table_metadata, |
| io=io, |
| projected_schema=schema, |
| row_filter=AlwaysTrue(), |
| case_sensitive=True, |
| ) |
| batches = list(scan.to_record_batches([scan_task])) |
| |
| # Assert batching behavior |
| assert len(batches) > 0, f"Should have at least one batch for {description}" |
| assert len(batches) == actual_stripes, ( |
| f"Number of batches should match number of stripes for {description}: " |
| f"{len(batches)} batches vs {actual_stripes} stripes" |
| ) |
| |
| # Assert data integrity |
| total_rows = sum(batch.num_rows for batch in batches) |
| assert total_rows == 10000, f"Total rows should be 10000 for {description}, got {total_rows}" |
| |
| # Assert compression effect |
| if compression == "snappy": |
| # Snappy compression should result in smaller file size than uncompressed |
| uncompressed_size = len(data) * 15 # Rough estimate of uncompressed size |
| assert file_size < uncompressed_size, ( |
| f"Snappy compression should reduce file size for {description}: {file_size} vs {uncompressed_size}" |
| ) |
| elif compression == "uncompressed": |
| # Uncompressed should be larger than snappy |
| assert file_size > 0, f"Uncompressed file should have size > 0 for {description}" |
| |
| |
| def test_orc_near_perfect_stripe_size_mapping(tmp_path: Path) -> None: |
| """ |
| Test that demonstrates near-perfect 1:1 mapping between stripe size and actual file size. |
| |
| This test shows how to achieve ratios of 0.9+ (actual/target) by using: |
| 1. Large stripe sizes (2-5MB) |
| 2. Large datasets (50K+ rows) |
| 3. Data that is hard to compress (large random strings) |
| 4. uncompressed compression setting |
| |
| This is the closest we can get to having stripe size directly map to number of batches |
| without significant ORC encoding overhead. |
| |
| To run just this test: |
| pytest tests/io/test_pyarrow.py -k test_orc_near_perfect_stripe_size_mapping |
| """ |
| import pyarrow as pa |
| import pyarrow.orc as orc |
| |
| from pyiceberg.expressions import AlwaysTrue |
| from pyiceberg.io.pyarrow import ArrowScan, PyArrowFileIO |
| from pyiceberg.manifest import DataFileContent, FileFormat |
| from pyiceberg.partitioning import PartitionSpec |
| from pyiceberg.schema import Schema |
| from pyiceberg.table import FileScanTask |
| from pyiceberg.table.metadata import TableMetadataV2 |
| from pyiceberg.types import NestedField, StringType |
| |
| # Define schema |
| schema = Schema( |
| NestedField(1, "id", StringType(), required=True), # Large string field |
| ) |
| |
| # Create table metadata |
| table_metadata = TableMetadataV2( |
| location=f"file://{tmp_path}/test_location", |
| last_column_id=1, |
| format_version=2, |
| schemas=[schema], |
| partition_specs=[PartitionSpec()], |
| properties={ |
| "schema.name-mapping.default": '[{"field-id": 1, "names": ["id"]}]', |
| }, |
| ) |
| io = PyArrowFileIO() |
| |
| # Create large dataset with hard-to-compress data |
| data = pa.table( |
| { |
| "id": pa.array( |
| [ |
| ( |
| f"very_long_string_value_{i:06d}_with_lots_of_padding_to_make_it_harder_to_compress_" |
| f"{i * 7919 % 100000:05d}_more_padding_{i * 7919 % 100000:05d}" |
| ) |
| for i in range(1, 50001) |
| ] |
| ) # 50K rows |
| } |
| ) |
| |
| # Test with large stripe sizes that should give us multiple stripes |
| test_cases = [ |
| (2000000, "2MB stripe size"), |
| (3000000, "3MB stripe size"), |
| (4000000, "4MB stripe size"), |
| (5000000, "5MB stripe size"), |
| ] |
| |
| for stripe_size, _description in test_cases: |
| # Create ORC file with specific stripe size |
| orc_path = tmp_path / f"orc_perfect_test_{stripe_size}.orc" |
| orc.write_table(data, str(orc_path), stripe_size=stripe_size, compression="uncompressed") |
| |
| # Analyze the ORC file |
| file_size = orc_path.stat().st_size |
| orc_file = orc.ORCFile(str(orc_path)) |
| actual_stripes = orc_file.nstripes |
| |
| # Assert basic file properties |
| assert file_size > 0, f"ORC file should have non-zero size for stripe_size={stripe_size}" |
| assert actual_stripes > 0, f"ORC file should have at least one stripe for stripe_size={stripe_size}" |
| |
| # Assert that larger stripe sizes result in fewer stripes |
| # With 50K rows of large strings, we expect multiple stripes for smaller stripe sizes |
| if stripe_size <= 3000000: # 3MB or smaller |
| assert actual_stripes > 1, f"Expected multiple stripes for stripe_size={stripe_size}, got {actual_stripes}" |
| else: # Larger stripe sizes might result in single stripe |
| assert actual_stripes >= 1, f"Expected at least 1 stripe for stripe_size={stripe_size}, got {actual_stripes}" |
| |
| # Test PyArrow batching |
| from pyiceberg.manifest import DataFile |
| from pyiceberg.typedef import Record |
| |
| data_file = DataFile.from_args( |
| content=DataFileContent.DATA, |
| file_path=str(orc_path), |
| file_format=FileFormat.ORC, |
| partition=Record(), |
| file_size_in_bytes=file_size, |
| sort_order_id=None, |
| spec_id=0, |
| equality_ids=None, |
| key_metadata=None, |
| record_count=50000, |
| column_sizes={1: 5450000}, |
| value_counts={1: 50000}, |
| null_value_counts={1: 0}, |
| nan_value_counts={1: 0}, |
| lower_bounds={1: b"very_long_string_value_000001_"}, |
| upper_bounds={1: b"very_long_string_value_050000_"}, |
| split_offsets=None, |
| ) |
| data_file.spec_id = 0 |
| |
| # Test batching behavior |
| scan_task = FileScanTask(data_file=data_file) |
| scan = ArrowScan( |
| table_metadata=table_metadata, |
| io=io, |
| projected_schema=schema, |
| row_filter=AlwaysTrue(), |
| case_sensitive=True, |
| ) |
| batches = list(scan.to_record_batches([scan_task])) |
| |
| # Assert batching behavior |
| assert len(batches) > 0, f"Should have at least one batch for stripe_size={stripe_size}" |
| assert len(batches) == actual_stripes, ( |
| f"Number of batches should match number of stripes for stripe_size={stripe_size}: " |
| f"{len(batches)} batches vs {actual_stripes} stripes" |
| ) |
| |
| # Assert data integrity |
| total_rows = sum(batch.num_rows for batch in batches) |
| assert total_rows == 50000, f"Total rows should be 50000 for stripe_size={stripe_size}, got {total_rows}" |
| |
| # Assert compression ratio is reasonable (uncompressed should be close to raw data size) |
| raw_data_size = data.nbytes |
| compression_ratio = raw_data_size / file_size if file_size > 0 else 0 |
| assert compression_ratio > 0.5, ( |
| f"Compression ratio should be reasonable for stripe_size={stripe_size}: {compression_ratio:.2f}" |
| ) |
| assert compression_ratio < 2.0, ( |
| f"Compression ratio should not be too high for stripe_size={stripe_size}: {compression_ratio:.2f}" |
| ) |
| |
| |
| def test_orc_stripe_based_batching(tmp_path: Path) -> None: |
| """ |
| Test ORC stripe-based batching to demonstrate that ORC can have multiple batches per file. |
| This corrects the previous understanding that ORC always has 1 batch per file. |
| |
| This test uses hardcoded expected values based on observed behavior with 10,000 rows: |
| - 200KB stripe size: 5 stripes of [2048, 2048, 2048, 2048, 1808] rows |
| - 400KB stripe size: 2 stripes of [7168, 2832] rows |
| - 600KB stripe size: 1 stripe of [10000] rows |
| |
| To run just this test: |
| pytest tests/io/test_pyarrow.py -k test_orc_stripe_based_batching |
| """ |
| import pyarrow as pa |
| import pyarrow.orc as orc |
| |
| from pyiceberg.expressions import AlwaysTrue |
| from pyiceberg.io.pyarrow import ArrowScan, PyArrowFileIO |
| from pyiceberg.manifest import DataFileContent, FileFormat |
| from pyiceberg.partitioning import PartitionSpec |
| from pyiceberg.schema import Schema |
| from pyiceberg.table import FileScanTask |
| from pyiceberg.table.metadata import TableMetadataV2 |
| from pyiceberg.typedef import Record |
| from pyiceberg.types import IntegerType, NestedField, StringType |
| |
| # Define schema |
| schema = Schema( |
| NestedField(1, "id", IntegerType(), required=True), |
| NestedField(2, "value", StringType(), required=False), |
| ) |
| |
| # Create table metadata |
| table_metadata = TableMetadataV2( |
| location=f"file://{tmp_path}/test_location", |
| last_column_id=2, |
| format_version=2, |
| schemas=[schema], |
| partition_specs=[PartitionSpec()], |
| properties={ |
| "schema.name-mapping.default": '[{"field-id": 1, "names": ["id"]}, {"field-id": 2, "names": ["value"]}]', |
| }, |
| ) |
| io = PyArrowFileIO() |
| |
| # Test ORC with different stripe configurations (stripe_size in bytes) |
| # Note: PyArrow ORC ignores stripe_size < 200KB, so we use larger values |
| # Expected values are hardcoded based on observed behavior with 10,000 rows |
| test_cases = [ |
| (200000, "Small stripes (200KB)", 5, [2048, 2048, 2048, 2048, 1808]), |
| (400000, "Medium stripes (400KB)", 2, [7168, 2832]), |
| (600000, "Large stripes (600KB)", 1, [10000]), |
| ] |
| |
| for stripe_size, _description, expected_stripes, expected_stripe_sizes in test_cases: |
| # Create data |
| data = pa.table({"id": pa.array(range(1, 10001), type=pa.int32()), "value": [f"value_{i}" for i in range(1, 10001)]}) |
| |
| # Create ORC file with specific stripe size (in bytes) |
| orc_path = tmp_path / f"orc_stripe_test_{stripe_size}.orc" |
| orc.write_table(data, str(orc_path), stripe_size=stripe_size) |
| |
| # Check ORC metadata |
| orc_file = orc.ORCFile(str(orc_path)) |
| actual_stripes = orc_file.nstripes |
| actual_stripe_sizes = [orc_file.read_stripe(i).num_rows for i in range(actual_stripes)] |
| |
| # Create DataFile |
| from pyiceberg.manifest import DataFile |
| |
| data_file = DataFile.from_args( |
| content=DataFileContent.DATA, |
| file_path=str(orc_path), |
| file_format=FileFormat.ORC, |
| partition=Record(), |
| file_size_in_bytes=orc_path.stat().st_size, |
| sort_order_id=None, |
| spec_id=0, |
| equality_ids=None, |
| key_metadata=None, |
| record_count=10000, |
| column_sizes={1: 40000, 2: 80000}, |
| value_counts={1: 10000, 2: 10000}, |
| null_value_counts={1: 0, 2: 0}, |
| nan_value_counts={1: 0, 2: 0}, |
| lower_bounds={1: b"\x01\x00\x00\x00", 2: b"value_1"}, |
| upper_bounds={1: b"\x10'\x00\x00", 2: b"value_10000"}, |
| split_offsets=None, |
| ) |
| data_file.spec_id = 0 |
| |
| # Test batching behavior |
| scan_task = FileScanTask(data_file=data_file) |
| scan = ArrowScan( |
| table_metadata=table_metadata, |
| io=io, |
| projected_schema=schema, |
| row_filter=AlwaysTrue(), |
| case_sensitive=True, |
| ) |
| batches = list(scan.to_record_batches([scan_task])) |
| |
| # CRITICAL: Verify we get multiple batches for a single file (when stripe size is small enough) |
| if expected_stripes > 1: |
| assert len(batches) > 1, f"Expected multiple batches for single file, got {len(batches)} batches" |
| assert actual_stripes > 1, f"Expected multiple stripes for single file, got {actual_stripes} stripes" |
| |
| # Verify exact batch count matches expected |
| assert len(batches) == expected_stripes, f"Expected {expected_stripes} batches, got {len(batches)}" |
| |
| # Verify batch sizes match expected stripe sizes |
| batch_sizes = [batch.num_rows for batch in batches] |
| assert batch_sizes == expected_stripe_sizes, ( |
| f"Batch sizes {batch_sizes} don't match expected stripe sizes {expected_stripe_sizes}" |
| ) |
| |
| # Verify actual ORC metadata matches expected |
| assert actual_stripes == expected_stripes, f"Expected {expected_stripes} stripes, got {actual_stripes}" |
| assert actual_stripe_sizes == expected_stripe_sizes, ( |
| f"Expected stripe sizes {expected_stripe_sizes}, got {actual_stripe_sizes}" |
| ) |
| |
| # Verify total rows |
| total_rows = sum(batch.num_rows for batch in batches) |
| assert total_rows == 10000, f"Expected 10000 total rows, got {total_rows}" |
| |
| |
| def test_partition_column_projection_with_schema_evolution(catalog: InMemoryCatalog) -> None: |
| """Test column projection on partitioned table after schema evolution (https://github.com/apache/iceberg-python/issues/2672).""" |
| initial_schema = Schema( |
| NestedField(1, "partition_date", DateType(), required=False), |
| NestedField(2, "id", IntegerType(), required=False), |
| NestedField(3, "name", StringType(), required=False), |
| NestedField(4, "value", IntegerType(), required=False), |
| ) |
| |
| partition_spec = PartitionSpec( |
| PartitionField(source_id=1, field_id=1000, transform=IdentityTransform(), name="partition_date"), |
| ) |
| |
| catalog.create_namespace("default") |
| table = catalog.create_table( |
| "default.test_schema_evolution_projection", |
| schema=initial_schema, |
| partition_spec=partition_spec, |
| ) |
| |
| data_v1 = pa.Table.from_pylist( |
| [ |
| {"partition_date": date(2024, 1, 1), "id": 1, "name": "Alice", "value": 100}, |
| {"partition_date": date(2024, 1, 1), "id": 2, "name": "Bob", "value": 200}, |
| ], |
| schema=pa.schema( |
| [ |
| ("partition_date", pa.date32()), |
| ("id", pa.int32()), |
| ("name", pa.string()), |
| ("value", pa.int32()), |
| ] |
| ), |
| ) |
| |
| table.append(data_v1) |
| |
| with table.update_schema() as update: |
| update.add_column("new_column", StringType()) |
| |
| table = catalog.load_table("default.test_schema_evolution_projection") |
| |
| data_v2 = pa.Table.from_pylist( |
| [ |
| {"partition_date": date(2024, 1, 2), "id": 3, "name": "Charlie", "value": 300, "new_column": "new1"}, |
| {"partition_date": date(2024, 1, 2), "id": 4, "name": "David", "value": 400, "new_column": "new2"}, |
| ], |
| schema=pa.schema( |
| [ |
| ("partition_date", pa.date32()), |
| ("id", pa.int32()), |
| ("name", pa.string()), |
| ("value", pa.int32()), |
| ("new_column", pa.string()), |
| ] |
| ), |
| ) |
| |
| table.append(data_v2) |
| |
| result = table.scan(selected_fields=("id", "name", "value", "new_column")).to_arrow() |
| |
| assert set(result.schema.names) == {"id", "name", "value", "new_column"} |
| assert result.num_rows == 4 |
| result_sorted = result.sort_by("name") |
| assert result_sorted["name"].to_pylist() == ["Alice", "Bob", "Charlie", "David"] |
| assert result_sorted["new_column"].to_pylist() == [None, None, "new1", "new2"] |