blob: 2170741bdd04b2b7efa62824028252dcb85c4a84 [file] [log] [blame]
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
# pylint: disable=protected-access,unused-argument,redefined-outer-name
import logging
import os
import tempfile
import uuid
import warnings
from datetime import date, datetime, timezone
from pathlib import Path
from typing import Any
from unittest.mock import MagicMock, patch
from uuid import uuid4
import pyarrow
import pyarrow as pa
import pyarrow.orc as orc
import pyarrow.parquet as pq
import pytest
from packaging import version
from pyarrow.fs import AwsDefaultS3RetryStrategy, FileType, LocalFileSystem, S3FileSystem
from pyiceberg.exceptions import ResolveError
from pyiceberg.expressions import (
AlwaysFalse,
AlwaysTrue,
And,
BooleanExpression,
BoundEqualTo,
BoundGreaterThan,
BoundGreaterThanOrEqual,
BoundIn,
BoundIsNaN,
BoundIsNull,
BoundLessThan,
BoundLessThanOrEqual,
BoundNotEqualTo,
BoundNotIn,
BoundNotNaN,
BoundNotNull,
BoundNotStartsWith,
BoundReference,
BoundStartsWith,
GreaterThan,
Not,
Or,
)
from pyiceberg.expressions.literals import literal
from pyiceberg.io import S3_RETRY_STRATEGY_IMPL, InputStream, OutputStream, load_file_io
from pyiceberg.io.pyarrow import (
ICEBERG_SCHEMA,
PYARROW_PARQUET_FIELD_ID_KEY,
ArrowScan,
PyArrowFile,
PyArrowFileIO,
StatsAggregator,
_check_pyarrow_schema_compatible,
_ConvertToArrowSchema,
_determine_partitions,
_primitive_to_physical,
_read_deletes,
_task_to_record_batches,
_to_requested_schema,
bin_pack_arrow_table,
compute_statistics_plan,
data_file_statistics_from_parquet_metadata,
expression_to_pyarrow,
parquet_path_to_id_mapping,
schema_to_pyarrow,
write_file,
)
from pyiceberg.manifest import DataFile, DataFileContent, FileFormat
from pyiceberg.partitioning import PartitionField, PartitionSpec
from pyiceberg.schema import Schema, make_compatible_name, visit
from pyiceberg.table import FileScanTask, TableProperties
from pyiceberg.table.metadata import TableMetadataV2
from pyiceberg.table.name_mapping import create_mapping_from_schema
from pyiceberg.transforms import HourTransform, IdentityTransform
from pyiceberg.typedef import UTF8, Properties, Record, TableVersion
from pyiceberg.types import (
BinaryType,
BooleanType,
DateType,
DecimalType,
DoubleType,
FixedType,
FloatType,
GeographyType,
GeometryType,
IntegerType,
ListType,
LongType,
MapType,
NestedField,
PrimitiveType,
StringType,
StructType,
TimestampNanoType,
TimestampType,
TimestamptzType,
TimeType,
)
from tests.catalog.test_base import InMemoryCatalog
from tests.conftest import UNIFIED_AWS_SESSION_PROPERTIES
# Skip marker for tests that depend on functionality introduced in pyarrow 20.0.0.
_PYARROW_TOO_OLD = version.parse(pyarrow.__version__) < version.parse("20.0.0")
skip_if_pyarrow_too_old = pytest.mark.skipif(_PYARROW_TOO_OLD, reason="Requires pyarrow version >= 20.0.0")
def test_pyarrow_infer_local_fs_from_path() -> None:
    """Both a `file` scheme location and a scheme-less path resolve to a LocalFileSystem."""
    for location in ("file://tmp/warehouse", "/tmp/warehouse"):
        output_file = PyArrowFileIO().new_output(location)
        assert isinstance(output_file._filesystem, LocalFileSystem)
def test_pyarrow_local_fs_can_create_path_without_parent_dir() -> None:
    """Test LocalFileSystem can create a path without first creating the parent directories.

    Any exception from ``create``/``write`` propagates so pytest reports the real traceback,
    and the bytes are read back to prove the file was actually written.
    """
    with tempfile.TemporaryDirectory() as tmpdirname:
        file_path = f"{tmpdirname}/foo/bar/baz.txt"
        output_file = PyArrowFileIO().new_output(file_path)
        parent_path = os.path.dirname(file_path)
        # Precondition: the intermediate directories really do not exist yet.
        assert output_file._filesystem.get_file_info(parent_path).type == FileType.NotFound
        with output_file.create() as f:
            f.write(b"foo")
        assert os.path.isdir(parent_path)
        with open(file_path, "rb") as written:
            assert written.read() == b"foo"
def test_pyarrow_input_file() -> None:
    """Test reading a file using PyArrowFile opened as a non-seekable stream."""
    with tempfile.TemporaryDirectory() as tmpdirname:
        file_location = os.path.join(tmpdirname, "foo.txt")
        with open(file_location, "wb") as f:
            f.write(b"foo")
        # Confirm that the file initially exists
        assert os.path.exists(file_location)
        # Instantiate the input file
        absolute_file_location = os.path.abspath(file_location)
        input_file = PyArrowFileIO().new_input(location=f"{absolute_file_location}")
        # Open the stream in a context manager so the handle is released deterministically
        # before the temporary directory is cleaned up.
        with input_file.open(seekable=False) as r:
            assert isinstance(r, InputStream)  # Test that the file object abides by the InputStream protocol
            data = r.read()
            assert data == b"foo"
            assert len(input_file) == 3
            # A non-seekable stream must refuse to seek.
            with pytest.raises(OSError) as exc_info:
                r.seek(0, 0)
            assert "only valid on seekable files" in str(exc_info.value)
def test_pyarrow_input_file_seekable() -> None:
    """Test reading a file using PyArrowFile opened as a seekable stream, including a rewind."""
    with tempfile.TemporaryDirectory() as tmpdirname:
        file_location = os.path.join(tmpdirname, "foo.txt")
        with open(file_location, "wb") as f:
            f.write(b"foo")
        # Confirm that the file initially exists
        assert os.path.exists(file_location)
        # Instantiate the input file
        absolute_file_location = os.path.abspath(file_location)
        input_file = PyArrowFileIO().new_input(location=f"{absolute_file_location}")
        # Open the stream in a context manager so the handle is released deterministically
        # before the temporary directory is cleaned up.
        with input_file.open(seekable=True) as r:
            assert isinstance(r, InputStream)  # Test that the file object abides by the InputStream protocol
            data = r.read()
            assert data == b"foo"
            assert len(input_file) == 3
            # Rewinding a seekable stream must allow the content to be read again.
            r.seek(0, 0)
            data = r.read()
            assert data == b"foo"
            assert len(input_file) == 3
def test_pyarrow_output_file() -> None:
    """Test writing a file using PyArrowFile"""
    with tempfile.TemporaryDirectory() as tmpdirname:
        file_location = os.path.join(tmpdirname, "foo.txt")
        # Instantiate the output file
        absolute_file_location = os.path.abspath(file_location)
        output_file = PyArrowFileIO().new_output(location=f"{absolute_file_location}")
        # Create the output file and write to it. Closing the stream explicitly flushes the
        # buffered bytes before they are read back, rather than relying on garbage collection.
        with output_file.create() as out:
            assert isinstance(out, OutputStream)  # Test that the file object abides by the OutputStream protocol
            out.write(b"foo")
        # Confirm that bytes were written
        with open(file_location, "rb") as written:
            assert written.read() == b"foo"
        assert len(output_file) == 3
def test_pyarrow_invalid_scheme() -> None:
    """Both new_input and new_output reject a location whose scheme is unknown with a ValueError."""
    for factory_name in ("new_input", "new_output"):
        factory = getattr(PyArrowFileIO(), factory_name)
        with pytest.raises(ValueError) as exc_info:
            factory("foo://bar/baz.txt")
        assert "Unrecognized filesystem type in URI" in str(exc_info.value)
def test_pyarrow_violating_input_stream_protocol() -> None:
    """Test that an input stream lacking required members does not satisfy the InputStream protocol.

    No exception is expected; the object returned by ``open`` must simply fail the
    ``isinstance(..., InputStream)`` check.
    """
    # Missing seek, tell, closed, and close
    input_file_mock = MagicMock(spec=["read"])
    # Create a mocked filesystem that returns input_file_mock
    filesystem_mock = MagicMock()
    filesystem_mock.open_input_file.return_value = input_file_mock
    input_file = PyArrowFile("foo.txt", path="foo.txt", fs=filesystem_mock)
    f = input_file.open()
    assert not isinstance(f, InputStream)
def test_pyarrow_violating_output_stream_protocol() -> None:
    """Test that an output stream lacking required members does not satisfy the OutputStream protocol.

    No exception is expected; the object returned by ``create`` must simply fail the
    ``isinstance(..., OutputStream)`` check.
    """
    # Missing closed, and close
    output_file_mock = MagicMock(spec=["write", "exists"])
    output_file_mock.exists.return_value = False
    # Report the target as absent so create() does not see an existing file.
    file_info_mock = MagicMock()
    file_info_mock.type = FileType.NotFound
    # Create a mocked filesystem that returns output_file_mock
    filesystem_mock = MagicMock()
    filesystem_mock.open_output_stream.return_value = output_file_mock
    filesystem_mock.get_file_info.return_value = file_info_mock
    output_file = PyArrowFile("foo.txt", path="foo.txt", fs=filesystem_mock)
    f = output_file.create()
    assert not isinstance(f, OutputStream)
def test_raise_on_opening_a_local_file_not_found() -> None:
    """Opening a local path that does not exist raises FileNotFoundError with pyarrow's message."""
    with tempfile.TemporaryDirectory() as tmpdirname:
        missing_file = PyArrowFileIO().new_input(os.path.join(tmpdirname, "foo.txt"))
        with pytest.raises(FileNotFoundError) as exc_info:
            missing_file.open()
        message = str(exc_info.value)
        assert "[Errno 2] Failed to open local file" in message
def test_raise_on_opening_an_s3_file_no_permission() -> None:
    """An 'AWS Error [code 15]' OSError raised while opening is surfaced as a PermissionError."""
    mocked_fs = MagicMock()
    mocked_fs.open_input_file.side_effect = OSError("AWS Error [code 15]")
    s3_file = PyArrowFile("s3://foo/bar.txt", path="foo/bar.txt", fs=mocked_fs)
    with pytest.raises(PermissionError, match="Cannot open file, access denied:"):
        s3_file.open()
def test_raise_on_opening_an_s3_file_not_found() -> None:
    """A 'Path does not exist' OSError raised while opening is surfaced as a FileNotFoundError."""
    mocked_fs = MagicMock()
    mocked_fs.open_input_file.side_effect = OSError("Path does not exist")
    s3_file = PyArrowFile("s3://foo/bar.txt", path="foo/bar.txt", fs=mocked_fs)
    with pytest.raises(FileNotFoundError, match="Cannot open file, does not exist:"):
        s3_file.open()
@patch("pyiceberg.io.pyarrow.PyArrowFile.exists", return_value=False)
def test_raise_on_creating_an_s3_file_no_permission(_: Any) -> None:
    """An 'AWS Error [code 15]' OSError raised while creating is surfaced as a PermissionError.

    ``PyArrowFile.exists`` is patched to report the target as absent.
    """
    mocked_fs = MagicMock()
    mocked_fs.open_output_stream.side_effect = OSError("AWS Error [code 15]")
    s3_file = PyArrowFile("s3://foo/bar.txt", path="foo/bar.txt", fs=mocked_fs)
    with pytest.raises(PermissionError, match="Cannot create file, access denied:"):
        s3_file.create()
def test_deleting_s3_file_no_permission() -> None:
    """PyArrowFileIO.delete surfaces an 'AWS Error [code 15]' OSError as a PermissionError."""
    mocked_fs = MagicMock()
    mocked_fs.delete_file.side_effect = OSError("AWS Error [code 15]")
    with patch.object(PyArrowFileIO, "_initialize_fs", return_value=mocked_fs):
        with pytest.raises(PermissionError, match="Cannot delete file, access denied:"):
            PyArrowFileIO().delete("s3://foo/bar.txt")
def test_deleting_s3_file_not_found() -> None:
    """Test that PyArrowFileIO.delete raises a FileNotFoundError when the pyarrow error includes 'Path does not exist'"""
    s3fs_mock = MagicMock()
    s3fs_mock.delete_file.side_effect = OSError("Path does not exist")
    # Route filesystem construction to the mock so no real S3 client is created.
    with patch.object(PyArrowFileIO, "_initialize_fs") as submocked:
        submocked.return_value = s3fs_mock
        with pytest.raises(FileNotFoundError) as exc_info:
            PyArrowFileIO().delete("s3://foo/bar.txt")
        assert "Cannot delete file, does not exist:" in str(exc_info.value)
def test_deleting_hdfs_file_not_found() -> None:
    """Test that PyArrowFileIO.delete on an hdfs location raises a FileNotFoundError when the pyarrow error includes 'Path does not exist'"""
    hdfs_mock = MagicMock()
    hdfs_mock.delete_file.side_effect = OSError("Path does not exist")
    # Route filesystem construction to the mock so no real HDFS client is created.
    with patch.object(PyArrowFileIO, "_initialize_fs") as submocked:
        submocked.return_value = hdfs_mock
        with pytest.raises(FileNotFoundError) as exc_info:
            PyArrowFileIO().delete("hdfs://foo/bar.txt")
        assert "Cannot delete file, does not exist:" in str(exc_info.value)
def test_pyarrow_s3_session_properties() -> None:
    """Explicit s3.* properties win over the unified client.* AWS properties when building S3FileSystem."""
    properties: Properties = {
        "s3.endpoint": "http://localhost:9000",
        "s3.access-key-id": "admin",
        "s3.secret-access-key": "password",
        "s3.region": "us-east-1",
        "s3.session-token": "s3.session-token",
        **UNIFIED_AWS_SESSION_PROPERTIES,
    }
    expected_kwargs = {
        "endpoint_override": "http://localhost:9000",
        "access_key": "admin",
        "secret_key": "password",
        "region": "us-east-1",
        "session_token": "s3.session-token",
    }
    with patch("pyarrow.fs.S3FileSystem") as s3fs_cls, patch("pyarrow.fs.resolve_s3_region") as region_resolver:
        file_io = PyArrowFileIO(properties=properties)
        location = f"s3://warehouse/{uuid.uuid4()}"
        # Make bucket-region lookup fail so the configured s3.region is the one passed through.
        region_resolver.side_effect = OSError("S3 bucket is not found")
        file_io.new_input(location=location)
        s3fs_cls.assert_called_with(**expected_kwargs)
def test_pyarrow_s3_session_properties_with_anonymous() -> None:
    """s3.anonymous=true is forwarded to S3FileSystem as anonymous=True alongside the s3.* settings."""
    properties: Properties = {
        "s3.anonymous": "true",
        "s3.endpoint": "http://localhost:9000",
        "s3.access-key-id": "admin",
        "s3.secret-access-key": "password",
        "s3.region": "us-east-1",
        "s3.session-token": "s3.session-token",
        **UNIFIED_AWS_SESSION_PROPERTIES,
    }
    expected_kwargs = {
        "anonymous": True,
        "endpoint_override": "http://localhost:9000",
        "access_key": "admin",
        "secret_key": "password",
        "region": "us-east-1",
        "session_token": "s3.session-token",
    }
    with patch("pyarrow.fs.S3FileSystem") as s3fs_cls, patch("pyarrow.fs.resolve_s3_region") as region_resolver:
        file_io = PyArrowFileIO(properties=properties)
        location = f"s3://warehouse/{uuid.uuid4()}"
        # Make bucket-region lookup fail so the configured s3.region is the one passed through.
        region_resolver.side_effect = OSError("S3 bucket is not found")
        file_io.new_input(location=location)
        s3fs_cls.assert_called_with(**expected_kwargs)
def test_pyarrow_unified_session_properties() -> None:
    """Without s3.* credentials, the unified client.* AWS properties configure S3FileSystem."""
    properties: Properties = {
        "s3.endpoint": "http://localhost:9000",
        **UNIFIED_AWS_SESSION_PROPERTIES,
    }
    expected_kwargs = {
        "endpoint_override": "http://localhost:9000",
        "access_key": "client.access-key-id",
        "secret_key": "client.secret-access-key",
        "region": "client.region",
        "session_token": "client.session-token",
    }
    with patch("pyarrow.fs.S3FileSystem") as s3fs_cls, patch("pyarrow.fs.resolve_s3_region") as region_resolver:
        file_io = PyArrowFileIO(properties=properties)
        location = f"s3://warehouse/{uuid.uuid4()}"
        region_resolver.return_value = "client.region"
        file_io.new_input(location=location)
        s3fs_cls.assert_called_with(**expected_kwargs)
def test_schema_to_pyarrow_schema_include_field_ids(table_schema_nested: Schema) -> None:
    """schema_to_pyarrow attaches every Iceberg field id as PARQUET:field_id field metadata.

    The check compares against pyarrow's textual schema repr, so it covers the nested list,
    map and struct children as well as the top-level columns.
    """
    actual = schema_to_pyarrow(table_schema_nested)
    expected = """foo: large_string
-- field metadata --
PARQUET:field_id: '1'
bar: int32 not null
-- field metadata --
PARQUET:field_id: '2'
baz: bool
-- field metadata --
PARQUET:field_id: '3'
qux: large_list<element: large_string not null> not null
child 0, element: large_string not null
-- field metadata --
PARQUET:field_id: '5'
-- field metadata --
PARQUET:field_id: '4'
quux: map<large_string, map<large_string, int32>> not null
child 0, entries: struct<key: large_string not null, value: map<large_string, int32> not null> not null
child 0, key: large_string not null
-- field metadata --
PARQUET:field_id: '7'
child 1, value: map<large_string, int32> not null
child 0, entries: struct<key: large_string not null, value: int32 not null> not null
child 0, key: large_string not null
-- field metadata --
PARQUET:field_id: '9'
child 1, value: int32 not null
-- field metadata --
PARQUET:field_id: '10'
-- field metadata --
PARQUET:field_id: '8'
-- field metadata --
PARQUET:field_id: '6'
location: large_list<element: struct<latitude: float, longitude: float> not null> not null
child 0, element: struct<latitude: float, longitude: float> not null
child 0, latitude: float
-- field metadata --
PARQUET:field_id: '13'
child 1, longitude: float
-- field metadata --
PARQUET:field_id: '14'
-- field metadata --
PARQUET:field_id: '12'
-- field metadata --
PARQUET:field_id: '11'
person: struct<name: large_string, age: int32 not null>
child 0, name: large_string
-- field metadata --
PARQUET:field_id: '16'
child 1, age: int32 not null
-- field metadata --
PARQUET:field_id: '17'
-- field metadata --
PARQUET:field_id: '15'"""
    assert repr(actual) == expected
def test_schema_to_pyarrow_schema_exclude_field_ids(table_schema_nested: Schema) -> None:
    """With include_field_ids=False, the converted schema carries no PARQUET:field_id metadata.

    Same nested schema as the include-field-ids test; the expected repr lists only names and types.
    """
    actual = schema_to_pyarrow(table_schema_nested, include_field_ids=False)
    expected = """foo: large_string
bar: int32 not null
baz: bool
qux: large_list<element: large_string not null> not null
child 0, element: large_string not null
quux: map<large_string, map<large_string, int32>> not null
child 0, entries: struct<key: large_string not null, value: map<large_string, int32> not null> not null
child 0, key: large_string not null
child 1, value: map<large_string, int32> not null
child 0, entries: struct<key: large_string not null, value: int32 not null> not null
child 0, key: large_string not null
child 1, value: int32 not null
location: large_list<element: struct<latitude: float, longitude: float> not null> not null
child 0, element: struct<latitude: float, longitude: float> not null
child 0, latitude: float
child 1, longitude: float
person: struct<name: large_string, age: int32 not null>
child 0, name: large_string
child 1, age: int32 not null"""
    assert repr(actual) == expected
def test_fixed_type_to_pyarrow() -> None:
    """FixedType(n) converts to pyarrow's fixed-width binary of n bytes."""
    width = 22
    converted = visit(FixedType(width), _ConvertToArrowSchema())
    assert converted == pa.binary(width)
def test_decimal_type_to_pyarrow() -> None:
    """DecimalType(p, s) converts to pyarrow decimal128 with the same precision and scale."""
    precision, scale = 25, 19
    converted = visit(DecimalType(precision, scale), _ConvertToArrowSchema())
    assert converted == pa.decimal128(precision, scale)
def test_boolean_type_to_pyarrow() -> None:
    """BooleanType converts to pyarrow bool_."""
    assert visit(BooleanType(), _ConvertToArrowSchema()) == pa.bool_()
def test_integer_type_to_pyarrow() -> None:
    """IntegerType converts to pyarrow int32."""
    assert visit(IntegerType(), _ConvertToArrowSchema()) == pa.int32()
def test_long_type_to_pyarrow() -> None:
    """LongType converts to pyarrow int64."""
    assert visit(LongType(), _ConvertToArrowSchema()) == pa.int64()
def test_float_type_to_pyarrow() -> None:
    """FloatType converts to pyarrow float32."""
    assert visit(FloatType(), _ConvertToArrowSchema()) == pa.float32()
def test_double_type_to_pyarrow() -> None:
    """DoubleType converts to pyarrow float64."""
    assert visit(DoubleType(), _ConvertToArrowSchema()) == pa.float64()
def test_date_type_to_pyarrow() -> None:
    """DateType converts to pyarrow date32."""
    assert visit(DateType(), _ConvertToArrowSchema()) == pa.date32()
def test_time_type_to_pyarrow() -> None:
    """TimeType converts to pyarrow time64 with microsecond unit."""
    assert visit(TimeType(), _ConvertToArrowSchema()) == pa.time64("us")
def test_timestamp_type_to_pyarrow() -> None:
    """TimestampType converts to a microsecond pyarrow timestamp without a time zone."""
    assert visit(TimestampType(), _ConvertToArrowSchema()) == pa.timestamp(unit="us")
def test_timestamptz_type_to_pyarrow() -> None:
    """TimestamptzType converts to a microsecond pyarrow timestamp zoned to UTC."""
    assert visit(TimestamptzType(), _ConvertToArrowSchema()) == pa.timestamp(unit="us", tz="UTC")
def test_string_type_to_pyarrow() -> None:
    """StringType converts to pyarrow large_string."""
    assert visit(StringType(), _ConvertToArrowSchema()) == pa.large_string()
def test_binary_type_to_pyarrow() -> None:
    """BinaryType converts to pyarrow large_binary."""
    assert visit(BinaryType(), _ConvertToArrowSchema()) == pa.large_binary()
def test_geometry_type_to_pyarrow_without_geoarrow() -> None:
    """Test geometry type falls back to large_binary when geoarrow is not available."""
    import builtins
    import sys

    real_import = builtins.__import__

    def _blocking_import(name: str, *args: Any, **kwargs: Any) -> Any:
        # Behave as if no geoarrow package were installed.
        if name.startswith("geoarrow"):
            raise ImportError(f"No module named '{name}'")
        return real_import(name, *args, **kwargs)

    # Evict cached geoarrow modules so any import goes through the blocking hook.
    cached_names = [module_name for module_name in sys.modules if module_name.startswith("geoarrow")]
    stashed = {module_name: sys.modules.pop(module_name) for module_name in cached_names}
    try:
        builtins.__import__ = _blocking_import
        assert visit(GeometryType(), _ConvertToArrowSchema()) == pa.large_binary()
    finally:
        builtins.__import__ = real_import
        sys.modules.update(stashed)
def test_geography_type_to_pyarrow_without_geoarrow() -> None:
    """Test geography type falls back to large_binary when geoarrow is not available."""
    import builtins
    import sys

    real_import = builtins.__import__

    def _blocking_import(name: str, *args: Any, **kwargs: Any) -> Any:
        # Behave as if no geoarrow package were installed.
        if name.startswith("geoarrow"):
            raise ImportError(f"No module named '{name}'")
        return real_import(name, *args, **kwargs)

    # Evict cached geoarrow modules so any import goes through the blocking hook.
    cached_names = [module_name for module_name in sys.modules if module_name.startswith("geoarrow")]
    stashed = {module_name: sys.modules.pop(module_name) for module_name in cached_names}
    try:
        builtins.__import__ = _blocking_import
        assert visit(GeographyType(), _ConvertToArrowSchema()) == pa.large_binary()
    finally:
        builtins.__import__ = real_import
        sys.modules.update(stashed)
def test_geometry_type_to_pyarrow_with_geoarrow() -> None:
    """Test geometry type uses geoarrow WKB extension type (with its CRS) when available."""
    pytest.importorskip("geoarrow.pyarrow")
    import geoarrow.pyarrow as ga

    cases = [
        (GeometryType(), "OGC:CRS84"),  # default CRS
        (GeometryType("EPSG:4326"), "EPSG:4326"),  # custom CRS
    ]
    for geometry_type, crs in cases:
        assert visit(geometry_type, _ConvertToArrowSchema()) == ga.wkb().with_crs(crs)
def test_geography_type_to_pyarrow_with_geoarrow() -> None:
    """Test geography type uses geoarrow WKB extension type, with spherical edges unless planar."""
    pytest.importorskip("geoarrow.pyarrow")
    import geoarrow.pyarrow as ga

    spherical_default = ga.wkb().with_crs("OGC:CRS84").with_edge_type(ga.EdgeType.SPHERICAL)
    spherical_custom = ga.wkb().with_crs("EPSG:4326").with_edge_type(ga.EdgeType.SPHERICAL)
    # The planar algorithm leaves the edge type unset.
    planar = ga.wkb().with_crs("OGC:CRS84")
    cases = [
        (GeographyType(), spherical_default),
        (GeographyType("EPSG:4326", "spherical"), spherical_custom),
        (GeographyType("OGC:CRS84", "planar"), planar),
    ]
    for geography_type, expected in cases:
        assert visit(geography_type, _ConvertToArrowSchema()) == expected
def test_struct_type_to_pyarrow(table_schema_simple: Schema) -> None:
    """Test StructType conversion, including each child's PARQUET:field_id metadata.

    pyarrow type equality does not compare field metadata, so the field ids are asserted
    separately on every child field.
    """
    actual = visit(table_schema_simple.as_struct(), _ConvertToArrowSchema())
    expected = pa.struct(
        [
            pa.field("foo", pa.large_string(), nullable=True, metadata={PYARROW_PARQUET_FIELD_ID_KEY: "1"}),
            pa.field("bar", pa.int32(), nullable=False, metadata={PYARROW_PARQUET_FIELD_ID_KEY: "2"}),
            pa.field("baz", pa.bool_(), nullable=True, metadata={PYARROW_PARQUET_FIELD_ID_KEY: "3"}),
        ]
    )
    assert actual == expected
    assert [actual.field(i).metadata for i in range(actual.num_fields)] == [
        {b"PARQUET:field_id": b"1"},
        {b"PARQUET:field_id": b"2"},
        {b"PARQUET:field_id": b"3"},
    ]
def test_map_type_to_pyarrow() -> None:
    """Test MapType conversion, including the key/value PARQUET:field_id metadata.

    pyarrow type equality does not compare field metadata, so the key and value field ids
    are asserted explicitly.
    """
    iceberg_map = MapType(
        key_id=1,
        key_type=IntegerType(),
        value_id=2,
        value_type=StringType(),
        value_required=True,
    )
    actual = visit(iceberg_map, _ConvertToArrowSchema())
    assert actual == pa.map_(
        pa.field("key", pa.int32(), nullable=False, metadata={PYARROW_PARQUET_FIELD_ID_KEY: "1"}),
        pa.field("value", pa.large_string(), nullable=False, metadata={PYARROW_PARQUET_FIELD_ID_KEY: "2"}),
    )
    assert actual.key_field.metadata == {b"PARQUET:field_id": b"1"}
    assert actual.item_field.metadata == {b"PARQUET:field_id": b"2"}
def test_list_type_to_pyarrow() -> None:
    """Test ListType conversion to large_list, including the element's PARQUET:field_id metadata.

    pyarrow type equality does not compare field metadata, so the element field id is asserted
    explicitly.
    """
    iceberg_list = ListType(
        element_id=1,
        element_type=IntegerType(),
        element_required=True,
    )
    actual = visit(iceberg_list, _ConvertToArrowSchema())
    assert actual == pa.large_list(
        pa.field("element", pa.int32(), nullable=False, metadata={PYARROW_PARQUET_FIELD_ID_KEY: "1"})
    )
    assert actual.value_field.metadata == {b"PARQUET:field_id": b"1"}
@pytest.fixture
def bound_reference(table_schema_simple: Schema) -> BoundReference:
    """Reference bound to field id 1 (``foo``) of the simple table schema."""
    field_id = 1
    return BoundReference(table_schema_simple.find_field(field_id), table_schema_simple.accessor_for_field(field_id))
@pytest.fixture
def bound_double_reference() -> BoundReference:
    """Reference bound to an optional double column ``foo`` (field id 1), used by the NaN predicates."""
    double_schema = Schema(
        NestedField(field_id=1, name="foo", field_type=DoubleType(), required=False),
        schema_id=1,
        identifier_field_ids=[],
    )
    foo_field = double_schema.find_field(1)
    foo_accessor = double_schema.accessor_for_field(1)
    return BoundReference(foo_field, foo_accessor)
def test_expr_is_null_to_pyarrow(bound_reference: BoundReference) -> None:
    """IsNull converts to pyarrow is_null with NaN not treated as null."""
    converted = expression_to_pyarrow(BoundIsNull(bound_reference))
    assert repr(converted) == "<pyarrow.compute.Expression is_null(foo, {nan_is_null=false})>"
def test_expr_not_null_to_pyarrow(bound_reference: BoundReference) -> None:
    """NotNull converts to pyarrow is_valid."""
    converted = expression_to_pyarrow(BoundNotNull(bound_reference))
    assert repr(converted) == "<pyarrow.compute.Expression is_valid(foo)>"
def test_expr_is_nan_to_pyarrow(bound_double_reference: BoundReference) -> None:
    """IsNaN converts to pyarrow is_nan."""
    converted = expression_to_pyarrow(BoundIsNaN(bound_double_reference))
    assert repr(converted) == "<pyarrow.compute.Expression is_nan(foo)>"
def test_expr_not_nan_to_pyarrow(bound_double_reference: BoundReference) -> None:
    """NotNaN converts to an inverted pyarrow is_nan."""
    converted = expression_to_pyarrow(BoundNotNaN(bound_double_reference))
    assert repr(converted) == "<pyarrow.compute.Expression invert(is_nan(foo))>"
def test_expr_equal_to_pyarrow(bound_reference: BoundReference) -> None:
    """EqualTo converts to a pyarrow == comparison."""
    converted = expression_to_pyarrow(BoundEqualTo(bound_reference, literal("hello")))
    assert repr(converted) == '<pyarrow.compute.Expression (foo == "hello")>'
def test_expr_not_equal_to_pyarrow(bound_reference: BoundReference) -> None:
    """NotEqualTo converts to a pyarrow != comparison."""
    converted = expression_to_pyarrow(BoundNotEqualTo(bound_reference, literal("hello")))
    assert repr(converted) == '<pyarrow.compute.Expression (foo != "hello")>'
def test_expr_greater_than_or_equal_equal_to_pyarrow(bound_reference: BoundReference) -> None:
    """GreaterThanOrEqual converts to a pyarrow >= comparison."""
    converted = expression_to_pyarrow(BoundGreaterThanOrEqual(bound_reference, literal("hello")))
    assert repr(converted) == '<pyarrow.compute.Expression (foo >= "hello")>'
def test_expr_greater_than_to_pyarrow(bound_reference: BoundReference) -> None:
    """GreaterThan converts to a pyarrow > comparison."""
    converted = expression_to_pyarrow(BoundGreaterThan(bound_reference, literal("hello")))
    assert repr(converted) == '<pyarrow.compute.Expression (foo > "hello")>'
def test_expr_less_than_to_pyarrow(bound_reference: BoundReference) -> None:
    """LessThan converts to a pyarrow < comparison."""
    converted = expression_to_pyarrow(BoundLessThan(bound_reference, literal("hello")))
    assert repr(converted) == '<pyarrow.compute.Expression (foo < "hello")>'
def test_expr_less_than_or_equal_to_pyarrow(bound_reference: BoundReference) -> None:
    """LessThanOrEqual converts to a pyarrow <= comparison."""
    converted = expression_to_pyarrow(BoundLessThanOrEqual(bound_reference, literal("hello")))
    assert repr(converted) == '<pyarrow.compute.Expression (foo <= "hello")>'
def test_expr_in_to_pyarrow(bound_reference: BoundReference) -> None:
    """In converts to pyarrow is_in over a large_string value set.

    The literals are supplied as a Python set, whose iteration order is not fixed, so both
    orderings of the value set in the repr are accepted.
    """
    assert repr(expression_to_pyarrow(BoundIn(bound_reference, {literal("hello"), literal("world")}))) in (
        """<pyarrow.compute.Expression is_in(foo, {value_set=large_string:[
"hello",
"world"
], null_matching_behavior=MATCH})>""",
        """<pyarrow.compute.Expression is_in(foo, {value_set=large_string:[
"world",
"hello"
], null_matching_behavior=MATCH})>""",
    )
def test_expr_not_in_to_pyarrow(bound_reference: BoundReference) -> None:
    """NotIn converts to an inverted pyarrow is_in over a large_string value set.

    The literals are supplied as a Python set, whose iteration order is not fixed, so both
    orderings of the value set in the repr are accepted.
    """
    assert repr(expression_to_pyarrow(BoundNotIn(bound_reference, {literal("hello"), literal("world")}))) in (
        """<pyarrow.compute.Expression invert(is_in(foo, {value_set=large_string:[
"hello",
"world"
], null_matching_behavior=MATCH}))>""",
        """<pyarrow.compute.Expression invert(is_in(foo, {value_set=large_string:[
"world",
"hello"
], null_matching_behavior=MATCH}))>""",
    )
def test_expr_starts_with_to_pyarrow(bound_reference: BoundReference) -> None:
    """StartsWith converts to a case-sensitive pyarrow starts_with."""
    converted = expression_to_pyarrow(BoundStartsWith(bound_reference, literal("he")))
    assert repr(converted) == '<pyarrow.compute.Expression starts_with(foo, {pattern="he", ignore_case=false})>'
def test_expr_not_starts_with_to_pyarrow(bound_reference: BoundReference) -> None:
    """NotStartsWith converts to an inverted, case-sensitive pyarrow starts_with."""
    converted = expression_to_pyarrow(BoundNotStartsWith(bound_reference, literal("he")))
    assert repr(converted) == '<pyarrow.compute.Expression invert(starts_with(foo, {pattern="he", ignore_case=false}))>'
def test_and_to_pyarrow(bound_reference: BoundReference) -> None:
    """And joins both converted operands with pyarrow's `and`."""
    predicate = And(BoundEqualTo(bound_reference, literal("hello")), BoundIsNull(bound_reference))
    expected = '<pyarrow.compute.Expression ((foo == "hello") and is_null(foo, {nan_is_null=false}))>'
    assert repr(expression_to_pyarrow(predicate)) == expected
def test_or_to_pyarrow(bound_reference: BoundReference) -> None:
    """Or joins both converted operands with pyarrow's `or`."""
    predicate = Or(BoundEqualTo(bound_reference, literal("hello")), BoundIsNull(bound_reference))
    expected = '<pyarrow.compute.Expression ((foo == "hello") or is_null(foo, {nan_is_null=false}))>'
    assert repr(expression_to_pyarrow(predicate)) == expected
def test_not_to_pyarrow(bound_reference: BoundReference) -> None:
    """Not wraps the converted child expression in pyarrow's invert."""
    predicate = Not(BoundEqualTo(bound_reference, literal("hello")))
    expected = '<pyarrow.compute.Expression invert((foo == "hello"))>'
    assert repr(expression_to_pyarrow(predicate)) == expected
def test_always_true_to_pyarrow(bound_reference: BoundReference) -> None:
    """AlwaysTrue converts to the pyarrow literal true."""
    converted = expression_to_pyarrow(AlwaysTrue())
    assert repr(converted) == "<pyarrow.compute.Expression true>"
def test_always_false_to_pyarrow(bound_reference: BoundReference) -> None:
    """AlwaysFalse converts to the pyarrow literal false."""
    converted = expression_to_pyarrow(AlwaysFalse())
    assert repr(converted) == "<pyarrow.compute.Expression false>"
@pytest.fixture
def schema_int() -> Schema:
    """Schema with a single optional int column ``id`` (field id 1)."""
    id_field = NestedField(1, "id", IntegerType(), required=False)
    return Schema(id_field)
@pytest.fixture
def schema_int_str() -> Schema:
    """Schema with an optional int ``id`` (field id 1) and an optional string ``data`` (field id 2)."""
    id_field = NestedField(1, "id", IntegerType(), required=False)
    data_field = NestedField(2, "data", StringType(), required=False)
    return Schema(id_field, data_field)
@pytest.fixture
def schema_str() -> Schema:
    """Schema with a single optional string column ``data`` (field id 2)."""
    data_field = NestedField(2, "data", StringType(), required=False)
    return Schema(data_field)
@pytest.fixture
def schema_long() -> Schema:
    """Schema with a single optional long column ``id`` (field id 3)."""
    id_field = NestedField(3, "id", LongType(), required=False)
    return Schema(id_field)
@pytest.fixture
def schema_struct() -> Schema:
    """Schema with a ``location`` struct (field id 4) holding double members ``lat`` (41) and ``long`` (42)."""
    lat_field = NestedField(41, "lat", DoubleType())
    long_field = NestedField(42, "long", DoubleType())
    location_field = NestedField(4, "location", StructType(lat_field, long_field))
    return Schema(location_field)
@pytest.fixture
def schema_list() -> Schema:
    """Schema with an optional ``ids`` list (field id 5) of optional ints (element id 51)."""
    ids_type = ListType(51, IntegerType(), element_required=False)
    return Schema(NestedField(5, "ids", ids_type, required=False))
@pytest.fixture
def schema_list_of_structs() -> Schema:
    """Schema with an optional ``locations`` list (field id 5) of optional ``{lat, long}`` structs."""
    point_type = StructType(NestedField(511, "lat", DoubleType()), NestedField(512, "long", DoubleType()))
    locations_type = ListType(51, point_type, element_required=False)
    return Schema(NestedField(5, "locations", locations_type, required=False))
@pytest.fixture
def schema_map_of_structs() -> Schema:
    """Schema with an optional ``locations`` map (field id 5) from string keys to ``{lat, long}`` structs.

    ``MapType`` is configured through ``value_required`` (``element_required`` is the ``ListType``
    keyword), so value nullability is spelled out with the MapType keyword.
    """
    return Schema(
        NestedField(
            5,
            "locations",
            MapType(
                key_id=51,
                value_id=52,
                key_type=StringType(),
                value_type=StructType(
                    NestedField(511, "lat", DoubleType(), required=True), NestedField(512, "long", DoubleType(), required=True)
                ),
                # NOTE(review): True keeps MapType's default; switch to False if nullable struct values are intended.
                value_required=True,
            ),
            required=False,
        ),
    )
@pytest.fixture
def schema_map() -> Schema:
    """Schema with an optional ``properties`` map (field id 5) of string keys to required string values."""
    properties_type = MapType(
        key_id=51,
        key_type=StringType(),
        value_id=52,
        value_type=StringType(),
        value_required=True,
    )
    return Schema(NestedField(5, "properties", properties_type, required=False))
def _write_table_to_file(filepath: str, schema: pa.Schema, table: pa.Table) -> str:
    """Write ``table`` to a Parquet file at ``filepath`` using ``schema`` and return ``filepath``."""
    writer = pq.ParquetWriter(filepath, schema)
    try:
        writer.write_table(table)
    finally:
        # Closing finalizes the Parquet footer even if the write fails.
        writer.close()
    return filepath
def _write_table_to_data_file(filepath: str, schema: pa.Schema, table: pa.Table) -> DataFile:
    """Write ``table`` as Parquet and describe it as a DATA ``DataFile`` with an empty partition.

    ``file_size_in_bytes`` is a fixed placeholder, not the real size of the written file.
    """
    written_path = _write_table_to_file(filepath, schema, table)
    data_file_args = {
        "content": DataFileContent.DATA,
        "file_path": written_path,
        "file_format": FileFormat.PARQUET,
        "partition": {},
        "record_count": len(table),
        "file_size_in_bytes": 22,  # placeholder; not relevant for the tests using this helper
    }
    return DataFile.from_args(**data_file_args)
@pytest.fixture
def file_int(schema_int: Schema, tmpdir: str) -> str:
    """Parquet file with ``id`` values 0..2 that embeds the Iceberg schema JSON in its metadata."""
    arrow_schema = schema_to_pyarrow(schema_int, metadata={ICEBERG_SCHEMA: schema_int.model_dump_json().encode(UTF8)})
    rows = pa.Table.from_arrays([pa.array([0, 1, 2])], schema=arrow_schema)
    return _write_table_to_file(f"file:{tmpdir}/a.parquet", arrow_schema, rows)
@pytest.fixture
def file_int_str(schema_int_str: Schema, tmpdir: str) -> str:
    """Parquet file with int ``id`` and string ``data`` columns, embedding the Iceberg schema JSON."""
    arrow_schema = schema_to_pyarrow(
        schema_int_str, metadata={ICEBERG_SCHEMA: schema_int_str.model_dump_json().encode(UTF8)}
    )
    columns = [pa.array([0, 1, 2]), pa.array(["0", "1", "2"])]
    rows = pa.Table.from_arrays(columns, schema=arrow_schema)
    return _write_table_to_file(f"file:{tmpdir}/a.parquet", arrow_schema, rows)
@pytest.fixture
def file_string(schema_str: Schema, tmpdir: str) -> str:
    """Parquet file with string ``data`` values "0".."2", embedding the Iceberg schema JSON."""
    arrow_schema = schema_to_pyarrow(schema_str, metadata={ICEBERG_SCHEMA: schema_str.model_dump_json().encode(UTF8)})
    rows = pa.Table.from_arrays([pa.array(["0", "1", "2"])], schema=arrow_schema)
    return _write_table_to_file(f"file:{tmpdir}/b.parquet", arrow_schema, rows)
@pytest.fixture
def file_long(schema_long: Schema, tmpdir: str) -> str:
    """Parquet file with long ``id`` values 0..2, embedding the Iceberg schema JSON."""
    arrow_schema = schema_to_pyarrow(schema_long, metadata={ICEBERG_SCHEMA: schema_long.model_dump_json().encode(UTF8)})
    rows = pa.Table.from_arrays([pa.array([0, 1, 2])], schema=arrow_schema)
    return _write_table_to_file(f"file:{tmpdir}/c.parquet", arrow_schema, rows)
@pytest.fixture
def file_struct(schema_struct: Schema, tmpdir: str) -> str:
    """Parquet file ``d.parquet`` with three ``location`` structs of ``lat``/``long`` doubles."""
    schema_json = bytes(schema_struct.model_dump_json(), UTF8)
    arrow_schema = schema_to_pyarrow(schema_struct, metadata={ICEBERG_SCHEMA: schema_json})
    coordinates = [
        (52.371807, 4.896029),
        (52.387386, 4.646219),
        (52.078663, 4.288788),
    ]
    rows = [{"location": {"lat": lat, "long": lon}} for lat, lon in coordinates]
    return _write_table_to_file(
        f"file:{tmpdir}/d.parquet",
        arrow_schema,
        pa.Table.from_pylist(rows, schema=arrow_schema),
    )
@pytest.fixture
def file_list(schema_list: Schema, tmpdir: str) -> str:
    """Parquet file ``e.parquet`` with three ``ids`` lists: range(1, 10), range(2, 20), range(3, 30)."""
    schema_json = bytes(schema_list.model_dump_json(), UTF8)
    arrow_schema = schema_to_pyarrow(schema_list, metadata={ICEBERG_SCHEMA: schema_json})
    bounds = ((1, 10), (2, 20), (3, 30))
    rows = [{"ids": list(range(start, stop))} for start, stop in bounds]
    return _write_table_to_file(
        f"file:{tmpdir}/e.parquet",
        arrow_schema,
        pa.Table.from_pylist(rows, schema=arrow_schema),
    )
@pytest.fixture
def file_list_of_structs(schema_list_of_structs: Schema, tmpdir: str) -> str:
    """Parquet file ``e.parquet`` with three ``locations`` lists of lat/long structs (the middle one empty)."""
    schema_json = bytes(schema_list_of_structs.model_dump_json(), UTF8)
    arrow_schema = schema_to_pyarrow(schema_list_of_structs, metadata={ICEBERG_SCHEMA: schema_json})
    amsterdam = {"lat": 52.371807, "long": 4.896029}
    haarlem = {"lat": 52.387386, "long": 4.646219}
    the_hague = {"lat": 52.078663, "long": 4.288788}
    rows = [
        {"locations": [amsterdam, haarlem]},
        {"locations": []},
        {"locations": [the_hague, haarlem]},
    ]
    return _write_table_to_file(
        f"file:{tmpdir}/e.parquet",
        arrow_schema,
        pa.Table.from_pylist(rows, schema=arrow_schema),
    )
@pytest.fixture
def file_map_of_structs(schema_map_of_structs: Schema, tmpdir: str) -> str:
    """Parquet file ``e.parquet`` with three ``locations`` maps of string keys to lat/long structs."""
    schema_json = bytes(schema_map_of_structs.model_dump_json(), UTF8)
    arrow_schema = schema_to_pyarrow(schema_map_of_structs, metadata={ICEBERG_SCHEMA: schema_json})
    rows = [
        {
            "locations": {
                "1": {"lat": 52.371807, "long": 4.896029},
                "2": {"lat": 52.387386, "long": 4.646219},
            }
        },
        {"locations": {}},
        {
            "locations": {
                "3": {"lat": 52.078663, "long": 4.288788},
                "4": {"lat": 52.387386, "long": 4.646219},
            }
        },
    ]
    return _write_table_to_file(
        f"file:{tmpdir}/e.parquet",
        arrow_schema,
        pa.Table.from_pylist(rows, schema=arrow_schema),
    )
@pytest.fixture
def file_map(schema_map: Schema, tmpdir: str) -> str:
    """Parquet file ``e.parquet`` with three ``properties`` maps given as (key, value) pair lists."""
    schema_json = bytes(schema_map.model_dump_json(), UTF8)
    arrow_schema = schema_to_pyarrow(schema_map, metadata={ICEBERG_SCHEMA: schema_json})
    pair_lists = (
        [("a", "b")],
        [("c", "d")],
        [("e", "f"), ("g", "h")],
    )
    rows = [{"properties": pairs} for pairs in pair_lists]
    return _write_table_to_file(
        f"file:{tmpdir}/e.parquet",
        arrow_schema,
        pa.Table.from_pylist(rows, schema=arrow_schema),
    )
def project(
    schema: Schema, files: list[str], expr: BooleanExpression | None = None, table_schema: Schema | None = None
) -> pa.Table:
    """Read ``files`` through ``ArrowScan`` and return the rows projected onto ``schema``.

    Args:
        schema: The projected (read) schema.
        files: Parquet file paths; each becomes one ``FileScanTask`` whose data file has spec id 0.
        expr: Optional row filter; ``AlwaysTrue()`` is used when it is not given.
        table_schema: Schema recorded in the table metadata; falls back to ``schema``.

    Returns:
        The concatenated Arrow table produced by the scan.
    """
    metadata = TableMetadataV2(
        location="file://a/b/",
        last_column_id=1,
        format_version=2,
        schemas=[table_schema or schema],
        partition_specs=[PartitionSpec()],
    )
    scan = ArrowScan(
        table_metadata=metadata,
        io=PyArrowFileIO(),
        projected_schema=schema,
        row_filter=expr or AlwaysTrue(),
        case_sensitive=True,
    )

    def _to_data_file(path: str) -> DataFile:
        # record_count / file_size_in_bytes are fixed placeholders for these test files.
        data_file = DataFile.from_args(
            content=DataFileContent.DATA,
            file_path=path,
            file_format=FileFormat.PARQUET,
            partition={},
            record_count=3,
            file_size_in_bytes=3,
        )
        data_file.spec_id = 0
        return data_file

    return scan.to_table(tasks=[FileScanTask(_to_data_file(path)) for path in files])
def test_projection_add_column(file_int: str) -> None:
    # Every field id below is new, so none of the columns can be resolved from the file.
    list_type = ListType(21, IntegerType(), element_required=False)
    map_type = MapType(key_id=31, key_type=IntegerType(), value_id=32, value_type=StringType(), value_required=False)
    location_type = StructType(
        NestedField(41, "lat", DoubleType(), required=False), NestedField(42, "lon", DoubleType(), required=False)
    )
    schema = Schema(
        NestedField(10, "id", IntegerType(), required=False),
        NestedField(20, "list", list_type, required=False),
        NestedField(30, "map", map_type, required=False),
        NestedField(40, "location", location_type, required=False),
    )

    result_table = project(schema, [file_int])

    for column in result_table.columns:
        assert len(column) == 3
    all_nulls = [None, None, None]
    for position in range(4):
        for actual, expected in zip(result_table.columns[position], all_nulls, strict=True):
            assert actual.as_py() == expected

    expected_repr = """id: int32
list: large_list<element: int32>
  child 0, element: int32
map: map<int32, large_string>
  child 0, entries: struct<key: int32 not null, value: large_string> not null
      child 0, key: int32 not null
      child 1, value: large_string
location: struct<lat: double, lon: double>
  child 0, lat: double
  child 1, lon: double"""
    assert repr(result_table.schema) == expected_repr
def test_read_list(schema_list: Schema, file_list: str) -> None:
    result_table = project(schema_list, [file_list])
    ids_column = result_table.columns[0]
    assert len(ids_column) == 3

    expected_rows = [list(range(start, stop)) for start, stop in ((1, 10), (2, 20), (3, 30))]
    for actual, expected in zip(ids_column, expected_rows, strict=True):
        assert actual.as_py() == expected

    expected_repr = """ids: large_list<element: int32>
  child 0, element: int32"""
    assert repr(result_table.schema) == expected_repr
def test_read_map(schema_map: Schema, file_map: str) -> None:
    result_table = project(schema_map, [file_map])
    properties_column = result_table.columns[0]
    assert len(properties_column) == 3

    expected_rows = [[("a", "b")], [("c", "d")], [("e", "f"), ("g", "h")]]
    for actual, expected in zip(properties_column, expected_rows, strict=True):
        assert actual.as_py() == expected

    expected_repr = """properties: map<string, string>
  child 0, entries: struct<key: string not null, value: string not null> not null
      child 0, key: string not null
      child 1, value: string not null"""
    assert repr(result_table.schema) == expected_repr
def test_projection_add_column_struct(schema_int: Schema, file_int: str) -> None:
    # Field id 2 is new, so the projected map column is expected to be entirely null.
    map_type = MapType(key_id=3, key_type=IntegerType(), value_id=4, value_type=StringType(), value_required=False)
    schema = Schema(NestedField(2, "id", map_type, required=False))

    result_table = project(schema, [file_int])

    for value in result_table.columns[0]:
        assert value.as_py() is None
    expected_repr = """id: map<int32, large_string>
  child 0, entries: struct<key: int32 not null, value: large_string> not null
      child 0, key: int32 not null
      child 1, value: large_string"""
    assert repr(result_table.schema) == expected_repr
def test_projection_add_column_struct_required(file_int: str) -> None:
    # A new field id that is also required cannot be filled with nulls, so resolution must fail.
    missing_required = NestedField(2, "other_id", IntegerType(), required=True)
    schema = Schema(missing_required)

    with pytest.raises(ResolveError) as exc_info:
        _ = project(schema, [file_int])

    message = str(exc_info.value)
    assert "Field is required, and could not be found in the file: 2: other_id: required int" in message
def test_projection_rename_column(schema_int: Schema, file_int: str) -> None:
    # Field id 1 is reused under a new name; values are matched by id, not by name.
    renamed_schema = Schema(NestedField(1, "other_name", IntegerType(), required=True))

    result_table = project(renamed_schema, [file_int])

    values = result_table.columns[0]
    assert len(values) == 3
    for actual, expected in zip(values, [0, 1, 2], strict=True):
        assert actual.as_py() == expected
    assert repr(result_table.schema) == "other_name: int32 not null"
def test_projection_concat_files(schema_int: Schema, file_int: str) -> None:
    # Scanning the same file twice concatenates its rows back to back.
    result_table = project(schema_int, [file_int, file_int])
    ids = result_table.columns[0]
    expected_ids = [0, 1, 2] * 2
    for actual, expected in zip(ids, expected_ids, strict=True):
        assert actual.as_py() == expected
    assert len(ids) == 6
    assert repr(result_table.schema) == "id: int32"
def test_identity_transform_column_projection(tmp_path: str, catalog: InMemoryCatalog) -> None:
    # Adds a data file that lacks the partition column to a partitioned table and checks that the
    # partition value is projected from the manifest metadata instead.
    # TODO: Update to use a data file created by writing data to an unpartitioned table once add_files supports field IDs.
    # (context: https://github.com/apache/iceberg-python/pull/1443#discussion_r1901374875)
    iceberg_schema = Schema(
        NestedField(1, "other_field", StringType(), required=False), NestedField(2, "partition_id", IntegerType(), required=False)
    )
    spec = PartitionSpec(
        PartitionField(2, 1000, IdentityTransform(), "partition_id"),
    )
    name_mapping_json = create_mapping_from_schema(iceberg_schema).model_dump_json()

    catalog.create_namespace("default")
    table = catalog.create_table(
        "default.test_projection_partition",
        schema=iceberg_schema,
        partition_spec=spec,
        properties={TableProperties.DEFAULT_NAME_MAPPING: name_mapping_json},
    )

    # The physical file only contains "other_field"; "partition_id" must come from the manifest.
    parquet_path = f"{tmp_path}/test.parquet"
    pq.write_table(pa.table([pa.array(["foo", "bar", "baz"], type=pa.string())], names=["other_field"]), parquet_path)

    stats = data_file_statistics_from_parquet_metadata(
        parquet_metadata=pq.read_metadata(parquet_path),
        stats_columns=compute_statistics_plan(table.schema(), table.metadata.properties),
        parquet_column_mapping=parquet_path_to_id_mapping(table.schema()),
    )
    data_file = DataFile.from_args(
        content=DataFileContent.DATA,
        file_path=parquet_path,
        file_format=FileFormat.PARQUET,
        # The partition value that the scan is expected to project.
        partition=Record(1),
        file_size_in_bytes=os.path.getsize(parquet_path),
        sort_order_id=None,
        spec_id=table.metadata.default_spec_id,
        equality_ids=None,
        key_metadata=None,
        **stats.to_serialized_dict(),
    )

    with table.transaction() as transaction, transaction.update_snapshot().overwrite() as update:
        update.append_data_file(data_file)

    expected_schema = pa.schema([("other_field", pa.string()), ("partition_id", pa.int32())])
    expected_rows = {
        "other_field": ["foo", "bar", "baz"],
        "partition_id": [1, 1, 1],
    }
    assert table.scan().to_arrow() == pa.table(expected_rows, schema=expected_schema)

    # A row filter on the projected partition column selects every row of the matching partition...
    assert table.scan(row_filter="partition_id = 1").to_arrow() == pa.table(expected_rows, schema=expected_schema)

    # ...and no rows for a partition value that does not exist.
    assert len(table.scan(row_filter="partition_id = -1").to_arrow()) == 0
def test_identity_transform_columns_projection(tmp_path: str, catalog: InMemoryCatalog) -> None:
    # Adds a data file holding only the non-partition column to a table partitioned on two columns,
    # and checks that both partition values are projected from the manifest metadata.
    # TODO: Update to use a data file created by writing data to an unpartitioned table once add_files supports field IDs.
    # (context: https://github.com/apache/iceberg-python/pull/1443#discussion_r1901374875)
    iceberg_schema = Schema(
        NestedField(1, "field_1", StringType(), required=False),
        NestedField(2, "field_2", IntegerType(), required=False),
        NestedField(3, "field_3", IntegerType(), required=False),
    )
    spec = PartitionSpec(
        PartitionField(2, 1000, IdentityTransform(), "field_2"),
        PartitionField(3, 1001, IdentityTransform(), "field_3"),
    )
    name_mapping_json = create_mapping_from_schema(iceberg_schema).model_dump_json()

    catalog.create_namespace("default")
    table = catalog.create_table(
        "default.test_projection_partitions",
        schema=iceberg_schema,
        partition_spec=spec,
        properties={TableProperties.DEFAULT_NAME_MAPPING: name_mapping_json},
    )

    parquet_path = f"{tmp_path}/test.parquet"
    pq.write_table(pa.table([pa.array(["foo"], type=pa.string())], names=["field_1"]), parquet_path)

    stats = data_file_statistics_from_parquet_metadata(
        parquet_metadata=pq.read_metadata(parquet_path),
        stats_columns=compute_statistics_plan(table.schema(), table.metadata.properties),
        parquet_column_mapping=parquet_path_to_id_mapping(table.schema()),
    )
    data_file = DataFile.from_args(
        content=DataFileContent.DATA,
        file_path=parquet_path,
        file_format=FileFormat.PARQUET,
        # The partition values (field_2, field_3) that the scan is expected to project.
        partition=Record(2, 3),
        file_size_in_bytes=os.path.getsize(parquet_path),
        sort_order_id=None,
        spec_id=table.metadata.default_spec_id,
        equality_ids=None,
        key_metadata=None,
        **stats.to_serialized_dict(),
    )

    with table.transaction() as transaction, transaction.update_snapshot().overwrite() as update:
        update.append_data_file(data_file)

    expected_str = """pyarrow.Table
field_1: string
field_2: int32
field_3: int32
----
field_1: [["foo"]]
field_2: [[2]]
field_3: [[3]]"""
    assert str(table.scan().to_arrow()) == expected_str
@pytest.fixture
def catalog() -> InMemoryCatalog:
    """A new ``InMemoryCatalog`` for each test that requests it (function-scoped fixture)."""
    catalog_properties = {"test.key": "test.value"}
    return InMemoryCatalog("test.in_memory.catalog", **catalog_properties)
def test_projection_filter(schema_int: Schema, file_int: str) -> None:
    # The file only holds 0, 1, 2, so ``id > 4`` filters out every row.
    id_above_four = GreaterThan("id", 4)
    result_table = project(schema_int, [file_int], id_above_four)
    assert len(result_table.columns[0]) == 0
    assert repr(result_table.schema) == "id: int32"
def test_projection_filter_renamed_column(file_int: str) -> None:
    # Field id 1 is reused under a new name, and the filter refers to the new name.
    renamed_schema = Schema(NestedField(1, "other_id", IntegerType(), required=True))
    result_table = project(renamed_schema, [file_int], GreaterThan("other_id", 1))
    assert len(result_table.columns[0]) == 1
    assert repr(result_table.schema) == "other_id: int32 not null"
def test_projection_filter_add_column(schema_int: Schema, file_int: str, file_string: str) -> None:
    """One file has the projected column and the other does not; the missing one yields nulls."""
    result_table = project(schema_int, [file_int, file_string])
    ids = result_table.columns[0]
    expected_ids = [0, 1, 2] + [None] * 3
    for actual, expected in zip(ids, expected_ids, strict=True):
        assert actual.as_py() == expected
    assert len(ids) == 6
    assert repr(result_table.schema) == "id: int32"
def test_projection_filter_add_column_promote(file_int: str) -> None:
    # Reading an int column through a long schema promotes the values to int64.
    promoted_schema = Schema(NestedField(1, "id", LongType(), required=True))
    result_table = project(promoted_schema, [file_int])
    ids = result_table.columns[0]
    for actual, expected in zip(ids, [0, 1, 2], strict=True):
        assert actual.as_py() == expected
    assert len(ids) == 3
    assert repr(result_table.schema) == "id: int64 not null"
def test_projection_filter_add_column_demote(file_long: str) -> None:
    # Narrowing long to int is not a legal promotion, so projection must fail.
    narrowed_schema = Schema(NestedField(3, "id", IntegerType()))
    with pytest.raises(ResolveError) as exc_info:
        _ = project(narrowed_schema, [file_long])
    assert "Cannot promote long to int" in str(exc_info.value)
def test_projection_nested_struct_subset(file_struct: str) -> None:
    # Project only ``lat`` out of the ``location`` struct; ``long`` is deliberately left out.
    lat_only = StructType(NestedField(41, "lat", DoubleType(), required=True))
    schema = Schema(NestedField(4, "location", lat_only, required=True))

    result_table = project(schema, [file_struct])

    locations = result_table.columns[0]
    for actual, expected_lat in zip(locations, [52.371807, 52.387386, 52.078663], strict=True):
        assert actual.as_py() == {"lat": expected_lat}
    assert len(locations) == 3
    expected_repr = """location: struct<lat: double not null> not null
  child 0, lat: double not null"""
    assert repr(result_table.schema) == expected_repr
def test_projection_nested_new_field(file_struct: str) -> None:
    # Field id 43 does not exist in the file, so the nested "null" field is filled with nulls.
    unknown_only = StructType(NestedField(43, "null", DoubleType(), required=False))
    schema = Schema(NestedField(4, "location", unknown_only, required=True))

    result_table = project(schema, [file_struct])

    locations = result_table.columns[0]
    for actual, expected in zip(locations, [None] * 3, strict=True):
        assert actual.as_py() == {"null": expected}
    assert len(locations) == 3
    expected_repr = """location: struct<null: double> not null
  child 0, null: double"""
    assert repr(result_table.schema) == expected_repr
def test_projection_nested_struct(schema_struct: Schema, file_struct: str) -> None:
    # Mix existing fields (41, 42) with a new one (43) and reorder them inside the struct.
    location_type = StructType(
        NestedField(41, "lat", DoubleType(), required=False),
        NestedField(43, "null", DoubleType(), required=False),
        NestedField(42, "long", DoubleType(), required=False),
    )
    schema = Schema(NestedField(4, "location", location_type, required=True))

    result_table = project(schema, [file_struct])

    coordinates = [(52.371807, 4.896029), (52.387386, 4.646219), (52.078663, 4.288788)]
    expected_rows = [{"lat": lat, "long": lon, "null": None} for lat, lon in coordinates]
    locations = result_table.columns[0]
    for actual, expected in zip(locations, expected_rows, strict=True):
        assert actual.as_py() == expected
    assert len(locations) == 3
    expected_repr = """location: struct<lat: double, null: double, long: double> not null
  child 0, lat: double
  child 1, null: double
  child 2, long: double"""
    assert repr(result_table.schema) == expected_repr
def test_projection_list_of_structs(schema_list_of_structs: Schema, file_list_of_structs: str) -> None:
    # Rename the struct fields (matched by id) and add a new optional "altitude" field.
    element_type = StructType(
        NestedField(511, "latitude", DoubleType(), required=True),
        NestedField(512, "longitude", DoubleType(), required=True),
        NestedField(513, "altitude", DoubleType(), required=False),
    )
    schema = Schema(
        NestedField(5, "locations", ListType(51, element_type, element_required=False), required=False),
    )

    result_table = project(schema, [file_list_of_structs])

    assert len(result_table.columns) == 1
    assert len(result_table.columns[0]) == 3

    def _point(lat: float, lon: float) -> dict[str, Any]:
        return {"latitude": lat, "longitude": lon, "altitude": None}

    expected_rows = [
        [_point(52.371807, 4.896029), _point(52.387386, 4.646219)],
        [],
        [_point(52.078663, 4.288788), _point(52.387386, 4.646219)],
    ]
    assert [row.as_py() for row in result_table.columns[0]] == expected_rows

    expected_repr = """locations: large_list<element: struct<latitude: double not null, longitude: double not null, altitude: double>>
  child 0, element: struct<latitude: double not null, longitude: double not null, altitude: double>
      child 0, latitude: double not null
      child 1, longitude: double not null
      child 2, altitude: double"""
    assert repr(result_table.schema) == expected_repr
def test_projection_maps_of_structs(schema_map_of_structs: Schema, file_map_of_structs: str) -> None:
    # Rename the struct value fields (matched by id) and add a new "altitude" field.
    value_type = StructType(
        NestedField(511, "latitude", DoubleType(), required=True),
        NestedField(512, "longitude", DoubleType(), required=True),
        NestedField(513, "altitude", DoubleType()),
    )
    locations_type = MapType(
        key_id=51,
        value_id=52,
        key_type=StringType(),
        value_type=value_type,
        element_required=False,
    )
    schema = Schema(NestedField(5, "locations", locations_type, required=False))

    result_table = project(schema, [file_map_of_structs])

    assert len(result_table.columns) == 1
    assert len(result_table.columns[0]) == 3

    def _point(lat: float, lon: float) -> dict[str, Any]:
        return {"latitude": lat, "longitude": lon, "altitude": None}

    expected_rows = [
        [("1", _point(52.371807, 4.896029)), ("2", _point(52.387386, 4.646219))],
        [],
        [("3", _point(52.078663, 4.288788)), ("4", _point(52.387386, 4.646219))],
    ]
    for actual, expected in zip(result_table.columns[0], expected_rows, strict=True):
        assert actual.as_py() == expected

    # pyarrow truncates the long entries type in the repr, hence the "(... 25 chars omitted)" marker.
    expected_schema_repr = (
        "locations: map<string, struct<latitude: double not null, "
        "longitude: double not null, altitude: double>>\n"
        "  child 0, entries: struct<key: string not null, value: struct<latitude: double not null, "
        "longitude: double not null, al (... 25 chars omitted) not null\n"
        "      child 0, key: string not null\n"
        "      child 1, value: struct<latitude: double not null, longitude: double not null, "
        "altitude: double> not null\n"
        "          child 0, latitude: double not null\n"
        "          child 1, longitude: double not null\n"
        "          child 2, altitude: double"
    )
    assert repr(result_table.schema) == expected_schema_repr
def test_projection_nested_struct_different_parent_id(file_struct: str) -> None:
    # The parent struct uses id 5, whereas the file stores it as id 4, so nothing matches
    # and the whole optional struct is null.
    location_type = StructType(
        NestedField(41, "lat", DoubleType(), required=False), NestedField(42, "long", DoubleType(), required=False)
    )
    schema = Schema(NestedField(5, "location", location_type, required=False))

    result_table = project(schema, [file_struct])

    locations = result_table.columns[0]
    for actual, expected in zip(locations, [None] * 3, strict=True):
        assert actual.as_py() == expected
    assert len(locations) == 3
    expected_repr = """location: struct<lat: double, long: double>
  child 0, lat: double
  child 1, long: double"""
    assert repr(result_table.schema) == expected_repr
def test_projection_filter_on_unprojected_field(schema_int_str: Schema, file_int_str: str) -> None:
    # The filter refers to "data", which exists in the table schema but is not projected.
    projected = Schema(NestedField(1, "id", IntegerType(), required=True))
    data_filter = GreaterThan("data", "1")

    result_table = project(projected, [file_int_str], data_filter, schema_int_str)

    ids = result_table.columns[0]
    for actual, expected in zip(ids, [2], strict=True):
        assert actual.as_py() == expected
    assert len(ids) == 1
    assert repr(result_table.schema) == "id: int32 not null"
def test_projection_filter_on_unknown_field(schema_int_str: Schema, file_int_str: str) -> None:
    # A filter on a column that exists in neither schema is rejected during binding.
    projected = Schema(NestedField(1, "id", IntegerType()))
    bad_filter = GreaterThan("unknown_field", "1")
    with pytest.raises(ValueError) as exc_info:
        _ = project(projected, [file_int_str], bad_filter, schema_int_str)
    assert "Could not find field with name unknown_field, case_sensitive=True" in str(exc_info.value)
@pytest.fixture(params=["parquet", "orc"])
def deletes_file(tmp_path: str, request: pytest.FixtureRequest) -> str:
    """Write a positional-delete file for rows 1, 3 and 5 and return its path.

    Parametrized over Parquet and ORC. The data-file path recorded in the ``file_path`` column
    comes from the matching ``example_task`` (Parquet) or ``example_task_orc`` (ORC) fixture.
    The file extension is ``.parquet`` or ``.orc``, so tests can infer the format from the path.
    """
    # ``pq`` and ``orc`` are the module-level pyarrow imports; there is no need to re-import them here.
    if request.param == "parquet":
        example_task = request.getfixturevalue("example_task")
        write_func = pq.write_table
        file_ext = "parquet"
    else:  # orc
        example_task = request.getfixturevalue("example_task_orc")
        write_func = orc.write_table
        file_ext = "orc"

    path = example_task.file.file_path
    table = pa.table({"file_path": [path, path, path], "pos": [1, 3, 5]})

    deletes_file_path = f"{tmp_path}/deletes.{file_ext}"
    write_func(table, deletes_file_path)
    return deletes_file_path
def test_read_deletes(deletes_file: str, request: pytest.FixtureRequest) -> None:
    """``_read_deletes`` groups delete positions by the data-file path stored in the delete file."""
    # The format follows from the extension that the ``deletes_file`` fixture chose.
    file_format = FileFormat.PARQUET if deletes_file.endswith(".parquet") else FileFormat.ORC
    fixture_name = "example_task" if file_format == FileFormat.PARQUET else "example_task_orc"
    # deletes_file wrote this fixture's ``file.file_path`` into its ``file_path`` column, so that exact
    # path must be the only key. The previous check compared the keys with their own first element,
    # which could never fail on a wrong path.
    example_task = request.getfixturevalue(fixture_name)

    deletes = _read_deletes(PyArrowFileIO(), DataFile.from_args(file_path=deletes_file, file_format=file_format))

    assert set(deletes.keys()) == {example_task.file.file_path}
    assert list(deletes.values())[0] == pa.chunked_array([[1, 3, 5]])
def test_delete(deletes_file: str, request: pytest.FixtureRequest, table_schema_simple: Schema) -> None:
    """Rows listed in the positional-delete file are dropped from the scan result."""
    is_parquet = deletes_file.endswith(".parquet")
    file_format = FileFormat.PARQUET if is_parquet else FileFormat.ORC
    example_task = request.getfixturevalue("example_task" if file_format == FileFormat.PARQUET else "example_task_orc")

    delete_file = DataFile.from_args(content=DataFileContent.POSITION_DELETES, file_path=deletes_file, file_format=file_format)
    task_with_delete = FileScanTask(data_file=example_task.file, delete_files={delete_file})

    metadata = TableMetadataV2(
        location="file://a/b/c.json",
        last_column_id=1,
        format_version=2,
        current_schema_id=1,
        schemas=[table_schema_simple],
        partition_specs=[PartitionSpec()],
    )
    with_deletes = ArrowScan(
        table_metadata=metadata,
        io=load_file_io(),
        projected_schema=table_schema_simple,
        row_filter=AlwaysTrue(),
    ).to_table(tasks=[task_with_delete])

    # ORC reads strings as 'string', while Parquet reads them as 'large_string'.
    foo_type = "string" if file_format == FileFormat.ORC else "large_string"
    expected_str = f"""pyarrow.Table
foo: {foo_type}
bar: int32 not null
baz: bool
----
foo: [["a","c"]]
bar: [[1,3]]
baz: [[true,null]]"""
    assert str(with_deletes) == expected_str
def test_delete_duplicates(deletes_file: str, request: pytest.FixtureRequest, table_schema_simple: Schema) -> None:
    """Listing the same positional-delete file twice gives the same result as listing it once."""
    is_parquet = deletes_file.endswith(".parquet")
    file_format = FileFormat.PARQUET if is_parquet else FileFormat.ORC
    example_task = request.getfixturevalue("example_task" if file_format == FileFormat.PARQUET else "example_task_orc")

    # NOTE(review): delete_files is a set; if DataFile equality is value/path based, the two
    # entries below collapse into one. Confirm this test really exercises duplicates.
    task_with_delete = FileScanTask(
        data_file=example_task.file,
        delete_files={
            DataFile.from_args(content=DataFileContent.POSITION_DELETES, file_path=deletes_file, file_format=file_format),
            DataFile.from_args(content=DataFileContent.POSITION_DELETES, file_path=deletes_file, file_format=file_format),
        },
    )

    metadata = TableMetadataV2(
        location="file://a/b/c.json",
        last_column_id=1,
        format_version=2,
        current_schema_id=1,
        schemas=[table_schema_simple],
        partition_specs=[PartitionSpec()],
    )
    with_deletes = ArrowScan(
        table_metadata=metadata,
        io=load_file_io(),
        projected_schema=table_schema_simple,
        row_filter=AlwaysTrue(),
    ).to_table(tasks=[task_with_delete])

    # ORC reads strings as 'string', while Parquet reads them as 'large_string'.
    foo_type = "string" if file_format == FileFormat.ORC else "large_string"
    expected_str = f"""pyarrow.Table
foo: {foo_type}
bar: int32 not null
baz: bool
----
foo: [["a","c"]]
bar: [[1,3]]
baz: [[true,null]]"""
    assert str(with_deletes) == expected_str
def test_pyarrow_wrap_fsspec(example_task: FileScanTask, table_schema_simple: Schema) -> None:
    """ArrowScan can read through an ``FsspecFileIO`` selected via the ``py-io-impl`` property."""
    metadata_location = "file://a/b/c.json"
    fsspec_io = load_file_io(properties={"py-io-impl": "pyiceberg.io.fsspec.FsspecFileIO"}, location=metadata_location)
    metadata = TableMetadataV2(
        location=metadata_location,
        last_column_id=1,
        format_version=2,
        current_schema_id=1,
        schemas=[table_schema_simple],
        partition_specs=[PartitionSpec()],
    )
    scan = ArrowScan(
        table_metadata=metadata,
        io=fsspec_io,
        case_sensitive=True,
        projected_schema=table_schema_simple,
        row_filter=AlwaysTrue(),
    )
    projection = scan.to_table(tasks=[example_task])

    expected_str = """pyarrow.Table
foo: large_string
bar: int32 not null
baz: bool
----
foo: [["a","b","c"]]
bar: [[1,2,3]]
baz: [[true,false,null]]"""
    assert str(projection) == expected_str
@pytest.mark.gcs
def test_new_input_file_gcs(pyarrow_fileio_gcs: PyArrowFileIO) -> None:
    """new_input on the GCS-backed PyArrowFileIO returns a PyArrowFile bound to the requested location."""
    location = f"gs://warehouse/{uuid4()}"
    input_file = pyarrow_fileio_gcs.new_input(location)
    assert isinstance(input_file, PyArrowFile)
    assert input_file.location == location
@pytest.mark.gcs
def test_new_output_file_gcs(pyarrow_fileio_gcs: PyArrowFileIO) -> None:
    """new_output on the GCS-backed PyArrowFileIO returns a PyArrowFile bound to the requested location."""
    location = f"gs://warehouse/{uuid4()}"
    output_file = pyarrow_fileio_gcs.new_output(location)
    assert isinstance(output_file, PyArrowFile)
    assert output_file.location == location
@pytest.mark.gcs
@pytest.mark.skip(reason="Open issue on Arrow: https://github.com/apache/arrow/issues/36993")
def test_write_and_read_file_gcs(pyarrow_fileio_gcs: PyArrowFileIO) -> None:
    """Bytes written through a PyArrowFile on GCS can be read back unchanged."""
    location = f"gs://warehouse/{uuid4()}.txt"
    payload = b"foo"

    output_file = pyarrow_fileio_gcs.new_output(location=location)
    with output_file.create() as sink:
        assert sink.write(payload) == 3
    assert output_file.exists()

    input_file = pyarrow_fileio_gcs.new_input(location=location)
    with input_file.open() as source:
        assert source.read() == payload

    pyarrow_fileio_gcs.delete(input_file)
@pytest.mark.gcs
def test_getting_length_of_file_gcs(pyarrow_fileio_gcs: PyArrowFileIO) -> None:
    """len() on both the output and the input PyArrowFile reports the written byte count."""
    location = f"gs://warehouse/{uuid4()}"

    output_file = pyarrow_fileio_gcs.new_output(location=location)
    with output_file.create() as sink:
        sink.write(b"foobar")
    assert len(output_file) == 6

    input_file = pyarrow_fileio_gcs.new_input(location=location)
    assert len(input_file) == 6

    pyarrow_fileio_gcs.delete(output_file)
@pytest.mark.gcs
@pytest.mark.skip(reason="Open issue on Arrow: https://github.com/apache/arrow/issues/36993")
def test_file_tell_gcs(pyarrow_fileio_gcs: PyArrowFileIO) -> None:
    """tell() on a GCS input stream reports the offset set by the preceding seek()."""
    location = f"gs://warehouse/{uuid4()}"

    output_file = pyarrow_fileio_gcs.new_output(location=location)
    with output_file.create() as write_file:
        write_file.write(b"foobar")

    input_file = pyarrow_fileio_gcs.new_input(location=location)
    with input_file.open() as f:
        f.seek(0)
        assert f.tell() == 0
        f.seek(1)
        assert f.tell() == 1
        f.seek(3)
        assert f.tell() == 3
        f.seek(0)
        assert f.tell() == 0

    # Remove the test object, as the other GCS tests do, so runs do not leave files behind.
    pyarrow_fileio_gcs.delete(input_file)
@pytest.mark.gcs
@pytest.mark.skip(reason="Open issue on Arrow: https://github.com/apache/arrow/issues/36993")
def test_read_specified_bytes_for_file_gcs(pyarrow_fileio_gcs: PyArrowFileIO) -> None:
    """read(n) after seek() returns at most n bytes starting at the seek offset."""
    location = f"gs://warehouse/{uuid4()}"

    output_file = pyarrow_fileio_gcs.new_output(location=location)
    with output_file.create() as sink:
        sink.write(b"foo")

    input_file = pyarrow_fileio_gcs.new_input(location=location)
    with input_file.open() as source:
        source.seek(0)
        assert source.read(1) == b"f"
        source.seek(0)
        assert source.read(2) == b"fo"
        source.seek(1)
        assert source.read(1) == b"o"
        source.seek(1)
        assert source.read(2) == b"oo"
        source.seek(0)
        # Asking for more than the content length returns the whole content.
        assert source.read(999) == b"foo"

    pyarrow_fileio_gcs.delete(input_file)
@pytest.mark.gcs
@pytest.mark.skip(reason="Open issue on Arrow: https://github.com/apache/arrow/issues/36993")
def test_raise_on_opening_file_not_found_gcs(pyarrow_fileio_gcs: PyArrowFileIO) -> None:
    """Opening a missing GCS object raises FileNotFoundError that mentions the object name."""
    missing_name = str(uuid4())
    missing_file = pyarrow_fileio_gcs.new_input(location=f"gs://warehouse/{missing_name}")

    with pytest.raises(FileNotFoundError) as exc_info:
        missing_file.open().read()

    assert missing_name in str(exc_info.value)
@pytest.mark.gcs
def test_checking_if_a_file_exists_gcs(pyarrow_fileio_gcs: PyArrowFileIO) -> None:
    """exists() is False before an object is written and True afterwards, for input and output files."""
    assert not pyarrow_fileio_gcs.new_input(location="gs://warehouse/does-not-exist.txt").exists()

    location = f"gs://warehouse/{uuid4()}"
    output_file = pyarrow_fileio_gcs.new_output(location=location)
    assert not output_file.exists()

    with output_file.create() as sink:
        sink.write(b"foo")

    assert pyarrow_fileio_gcs.new_input(location=location).exists()
    written_output = pyarrow_fileio_gcs.new_output(location=location)
    assert written_output.exists()

    pyarrow_fileio_gcs.delete(written_output)
@pytest.mark.gcs
@pytest.mark.skip(reason="Open issue on Arrow: https://github.com/apache/arrow/issues/36993")
def test_closing_a_file_gcs(pyarrow_fileio_gcs: PyArrowFileIO) -> None:
    """Streams are open inside their ``with`` block and closed once it exits."""
    location = f"gs://warehouse/{uuid4()}"

    with pyarrow_fileio_gcs.new_output(location=location).create() as write_file:
        write_file.write(b"foo")
        assert not write_file.closed  # type: ignore
    assert write_file.closed  # type: ignore

    with pyarrow_fileio_gcs.new_input(location=location).open() as f:
        assert not f.closed  # type: ignore
    assert f.closed  # type: ignore

    pyarrow_fileio_gcs.delete(location)
@pytest.mark.gcs
def test_converting_an_outputfile_to_an_inputfile_gcs(pyarrow_fileio_gcs: PyArrowFileIO) -> None:
    """to_input_file() keeps the output file's location."""
    output_file = pyarrow_fileio_gcs.new_output(location=f"gs://warehouse/{uuid4()}")
    converted = output_file.to_input_file()
    assert converted.location == output_file.location
@pytest.mark.gcs
@pytest.mark.skip(reason="Open issue on Arrow: https://github.com/apache/arrow/issues/36993")
def test_writing_avro_file_gcs(generated_manifest_entry_file: str, pyarrow_fileio_gcs: PyArrowFileIO) -> None:
    """A local Avro file copied to GCS through PyArrowFileIO reads back byte-for-byte identical."""
    location = f"gs://warehouse/{uuid4()}"

    with PyArrowFileIO().new_input(location=generated_manifest_entry_file).open() as local_source:
        original_bytes = local_source.read()

    with pyarrow_fileio_gcs.new_output(location=location).create() as gcs_sink:
        gcs_sink.write(original_bytes)

    with pyarrow_fileio_gcs.new_input(location=location).open() as gcs_source:
        round_tripped = gcs_source.read()

    # The bytes read from the local Avro file must match the bytes written to GCS.
    assert original_bytes == round_tripped

    pyarrow_fileio_gcs.delete(location)
@pytest.mark.adls
@skip_if_pyarrow_too_old
def test_new_input_file_adls(pyarrow_fileio_adls: PyArrowFileIO, adls_scheme: str) -> None:
    """new_input on the ADLS-backed PyArrowFileIO returns a PyArrowFile bound to the requested location."""
    location = f"{adls_scheme}://warehouse/{uuid4()}"
    input_file = pyarrow_fileio_adls.new_input(location)
    assert isinstance(input_file, PyArrowFile)
    assert input_file.location == location
@pytest.mark.adls
@skip_if_pyarrow_too_old
def test_new_output_file_adls(pyarrow_fileio_adls: PyArrowFileIO, adls_scheme: str) -> None:
    """``new_output()`` on the ADLS file-io yields a PyArrowFile bound to the requested location."""
    location = f"{adls_scheme}://warehouse/{uuid4()}"
    created = pyarrow_fileio_adls.new_output(location)
    assert isinstance(created, PyArrowFile)
    assert created.location == location
@pytest.mark.adls
@skip_if_pyarrow_too_old
def test_write_and_read_file_adls(pyarrow_fileio_adls: PyArrowFileIO, adls_scheme: str) -> None:
    """Round-trip a small payload through ADLS: write it, check it exists, read it back, then clean up."""
    payload = b"foo"
    location = f"{adls_scheme}://warehouse/{uuid4()}.txt"

    destination = pyarrow_fileio_adls.new_output(location=location)
    with destination.create() as sink:
        assert sink.write(payload) == len(payload)
    assert destination.exists()

    source = pyarrow_fileio_adls.new_input(location=location)
    with source.open() as stream:
        assert stream.read() == payload
    pyarrow_fileio_adls.delete(source)
@pytest.mark.adls
@skip_if_pyarrow_too_old
def test_getting_length_of_file_adls(pyarrow_fileio_adls: PyArrowFileIO, adls_scheme: str) -> None:
    """``len()`` of the output file and of a fresh input file equals the number of bytes written."""
    location = f"{adls_scheme}://warehouse/{uuid4()}"
    output_file = pyarrow_fileio_adls.new_output(location=location)
    with output_file.create() as sink:
        sink.write(b"foobar")
    assert len(output_file) == 6
    assert len(pyarrow_fileio_adls.new_input(location=location)) == 6
    pyarrow_fileio_adls.delete(output_file)
@pytest.mark.adls
@skip_if_pyarrow_too_old
def test_file_tell_adls(pyarrow_fileio_adls: PyArrowFileIO, adls_scheme: str) -> None:
    """Test that tell() reports the position most recently set by seek() on an ADLS input stream"""
    location = f"{adls_scheme}://warehouse/{uuid4()}"
    output_file = pyarrow_fileio_adls.new_output(location=location)
    with output_file.create() as write_file:
        write_file.write(b"foobar")

    input_file = pyarrow_fileio_adls.new_input(location=location)
    with input_file.open() as f:
        # Includes a backwards seek (3 -> 0) to check the position is reset, not just advanced.
        for position in (0, 1, 3, 0):
            f.seek(position)
            assert f.tell() == position

    # Remove the object, like the other ADLS tests do, so repeated runs don't leave files behind.
    pyarrow_fileio_adls.delete(input_file)
@pytest.mark.adls
@skip_if_pyarrow_too_old
def test_read_specified_bytes_for_file_adls(pyarrow_fileio_adls: PyArrowFileIO, adls_scheme: str) -> None:
    """Test reading a specified number of bytes from an ADLS input stream after seeking"""
    # Build the location from the adls_scheme fixture, as the sibling ADLS tests do, instead of
    # hard-coding "abfss" -- otherwise this test never runs against the other ADLS scheme(s).
    location = f"{adls_scheme}://warehouse/{uuid4()}"
    output_file = pyarrow_fileio_adls.new_output(location=location)
    with output_file.create() as write_file:
        write_file.write(b"foo")

    input_file = pyarrow_fileio_adls.new_input(location=location)
    with input_file.open() as f:
        f.seek(0)
        assert b"f" == f.read(1)
        f.seek(0)
        assert b"fo" == f.read(2)
        f.seek(1)
        assert b"o" == f.read(1)
        f.seek(1)
        assert b"oo" == f.read(2)
        f.seek(0)
        assert b"foo" == f.read(999)  # test reading amount larger than entire content length

    pyarrow_fileio_adls.delete(input_file)
@pytest.mark.adls
@skip_if_pyarrow_too_old
def test_raise_on_opening_file_not_found_adls(pyarrow_fileio_adls: PyArrowFileIO, adls_scheme: str) -> None:
    """Reading a missing ADLS object raises FileNotFoundError whose message names the object."""
    missing_name = str(uuid4())
    missing = pyarrow_fileio_adls.new_input(location=f"{adls_scheme}://warehouse/{missing_name}")
    with pytest.raises(FileNotFoundError) as excinfo:
        missing.open().read()
    assert missing_name in str(excinfo.value)
@pytest.mark.adls
@skip_if_pyarrow_too_old
def test_checking_if_a_file_exists_adls(pyarrow_fileio_adls: PyArrowFileIO, adls_scheme: str) -> None:
    """``exists()`` is False before an ADLS object is written and True afterwards, for input and output files."""
    absent = pyarrow_fileio_adls.new_input(location=f"{adls_scheme}://warehouse/does-not-exist.txt")
    assert not absent.exists()

    location = f"{adls_scheme}://warehouse/{uuid4()}"
    writer = pyarrow_fileio_adls.new_output(location=location)
    assert not writer.exists()
    with writer.create() as sink:
        sink.write(b"foo")

    assert pyarrow_fileio_adls.new_input(location=location).exists()
    reopened_output = pyarrow_fileio_adls.new_output(location=location)
    assert reopened_output.exists()
    pyarrow_fileio_adls.delete(reopened_output)
@pytest.mark.adls
@skip_if_pyarrow_too_old
def test_closing_a_file_adls(pyarrow_fileio_adls: PyArrowFileIO, adls_scheme: str) -> None:
    """Streams from ``create()``/``open()`` are open inside their with-block and closed once it exits."""
    location = f"{adls_scheme}://warehouse/{uuid4()}"

    destination = pyarrow_fileio_adls.new_output(location=location)
    with destination.create() as sink:
        sink.write(b"foo")
        assert not sink.closed  # type: ignore
    assert sink.closed  # type: ignore

    source = pyarrow_fileio_adls.new_input(location=location)
    with source.open() as stream:
        assert not stream.closed  # type: ignore
    assert stream.closed  # type: ignore

    pyarrow_fileio_adls.delete(location)
@pytest.mark.adls
@skip_if_pyarrow_too_old
def test_converting_an_outputfile_to_an_inputfile_adls(pyarrow_fileio_adls: PyArrowFileIO, adls_scheme: str) -> None:
    """``to_input_file()`` preserves the ADLS location of the output file."""
    destination = pyarrow_fileio_adls.new_output(location=f"{adls_scheme}://warehouse/{uuid4()}")
    converted = destination.to_input_file()
    assert converted.location == destination.location
@pytest.mark.adls
@skip_if_pyarrow_too_old
def test_writing_avro_file_adls(generated_manifest_entry_file: str, pyarrow_fileio_adls: PyArrowFileIO, adls_scheme: str) -> None:
    """Test that bytes match when reading a local avro file, writing it to ADLS using pyarrow file-io, and reading it back"""
    location = f"{adls_scheme}://warehouse/{uuid4()}"
    with PyArrowFileIO().new_input(location=generated_manifest_entry_file).open() as f:
        b1 = f.read()
    with pyarrow_fileio_adls.new_output(location=location).create() as out_f:
        out_f.write(b1)
    with pyarrow_fileio_adls.new_input(location=location).open() as in_f:
        b2 = in_f.read()
    assert b1 == b2  # Check that the bytes read from the local avro file match the bytes written to ADLS
    pyarrow_fileio_adls.delete(location)
def test_parse_location() -> None:
    """``parse_location`` splits a location into (scheme, netloc, path); bare paths get the ``file`` scheme."""
    cases = [
        ("hdfs://127.0.0.1:9000/root/foo.txt", "hdfs", "127.0.0.1:9000", "/root/foo.txt"),
        ("hdfs://127.0.0.1/root/foo.txt", "hdfs", "127.0.0.1", "/root/foo.txt"),
        ("hdfs://clusterA/root/foo.txt", "hdfs", "clusterA", "/root/foo.txt"),
        ("/root/foo.txt", "file", "", "/root/foo.txt"),
        ("/root/tmp/foo.txt", "file", "", "/root/tmp/foo.txt"),
    ]
    for location, expected_scheme, expected_netloc, expected_path in cases:
        scheme, netloc, path = PyArrowFileIO.parse_location(location)
        assert scheme == expected_scheme
        assert netloc == expected_netloc
        assert path == expected_path
def test_make_compatible_name() -> None:
    """``make_compatible_name`` escapes '/' and '?' as ``_x`` followed by their hex code."""
    for raw, sanitized in (("label/abc", "label_x2Fabc"), ("label?abc", "label_x3Fabc")):
        assert make_compatible_name(raw) == sanitized
@pytest.mark.parametrize(
    "vals, primitive_type, expected_result",
    [
        ([None, 2, 1], IntegerType(), 1),
        ([1, None, 2], IntegerType(), 1),
        ([None, None, None], IntegerType(), None),
        ([None, date(2024, 2, 4), date(2024, 1, 2)], DateType(), date(2024, 1, 2)),
        ([date(2024, 1, 2), None, date(2024, 2, 4)], DateType(), date(2024, 1, 2)),
        ([None, None, None], DateType(), None),
    ],
)
def test_stats_aggregator_update_min(vals: list[Any], primitive_type: PrimitiveType, expected_result: Any) -> None:
    """Feeding values one by one keeps the smallest non-null value, or None when every value is null."""
    aggregator = StatsAggregator(primitive_type, _primitive_to_physical(primitive_type))
    for value in vals:
        aggregator.update_min(value)
    assert aggregator.current_min == expected_result
@pytest.mark.parametrize(
    "vals, primitive_type, expected_result",
    [
        ([None, 2, 1], IntegerType(), 2),
        ([1, None, 2], IntegerType(), 2),
        ([None, None, None], IntegerType(), None),
        ([None, date(2024, 2, 4), date(2024, 1, 2)], DateType(), date(2024, 2, 4)),
        ([date(2024, 1, 2), None, date(2024, 2, 4)], DateType(), date(2024, 2, 4)),
        ([None, None, None], DateType(), None),
    ],
)
def test_stats_aggregator_update_max(vals: list[Any], primitive_type: PrimitiveType, expected_result: Any) -> None:
    """Feeding values one by one keeps the largest non-null value, or None when every value is null."""
    aggregator = StatsAggregator(primitive_type, _primitive_to_physical(primitive_type))
    for value in vals:
        aggregator.update_max(value)
    assert aggregator.current_max == expected_result
@pytest.mark.parametrize(
    "iceberg_type, physical_type_string",
    [
        # Physical type is exactly the expected one
        (IntegerType(), "INT32"),
        # A long column may be backed by INT32 (INT32 -> INT64 promotion)
        (LongType(), "INT32"),
        # A double column may be backed by FLOAT (FLOAT -> DOUBLE promotion)
        (DoubleType(), "FLOAT"),
        # A decimal whose expected physical type is INT32 may be stored as FIXED_LEN_BYTE_ARRAY
        (DecimalType(precision=2, scale=2), "FIXED_LEN_BYTE_ARRAY"),
        # A decimal whose expected physical type is INT64 may be stored as FIXED_LEN_BYTE_ARRAY
        (DecimalType(precision=12, scale=2), "FIXED_LEN_BYTE_ARRAY"),
    ],
)
def test_stats_aggregator_conditionally_allowed_types_pass(iceberg_type: PrimitiveType, physical_type_string: str) -> None:
    """Compatible physical types are accepted and the aggregator starts with no min/max."""
    aggregator = StatsAggregator(iceberg_type, physical_type_string)
    assert aggregator.primitive_type == iceberg_type
    assert aggregator.current_min is None
    assert aggregator.current_max is None
@pytest.mark.parametrize(
    "iceberg_type, physical_type_string",
    [
        # An int column cannot be backed by INT64 (no INT64 -> INT32 narrowing)
        (IntegerType(), "INT64"),
    ],
)
def test_stats_aggregator_physical_type_does_not_match_expected_raise_error(
    iceberg_type: PrimitiveType, physical_type_string: str
) -> None:
    """An incompatible physical type is rejected when the aggregator is constructed."""
    expected_message = "Unexpected physical type"
    with pytest.raises(ValueError, match=expected_message):
        StatsAggregator(iceberg_type, physical_type_string)
def test_bin_pack_arrow_table(arrow_table_with_null: pa.Table) -> None:
    """``bin_pack_arrow_table`` groups the table's batches into bins sized by ``target_file_size``."""
    default_target = TableProperties.WRITE_TARGET_FILE_SIZE_BYTES_DEFAULT
    single_copy_size = arrow_table_with_null.nbytes

    # The small fixture table fits into one bin at the default target size.
    assert len(list(bin_pack_arrow_table(arrow_table_with_null, target_file_size=default_target))) == 1

    # Ten copies are still smaller than the default target, so they also pack into one bin.
    tenfold = pa.concat_tables([arrow_table_with_null] * 10)
    assert tenfold.nbytes < default_target
    assert len(list(bin_pack_arrow_table(tenfold, target_file_size=default_target))) == 1

    # Shrinking the target to one copy's size yields one bin per copy...
    assert len(list(bin_pack_arrow_table(tenfold, target_file_size=single_copy_size))) == 10
    # ...and doubling that target halves the number of bins.
    assert len(list(bin_pack_arrow_table(tenfold, target_file_size=single_copy_size * 2))) == 5
def test_bin_pack_arrow_table_target_size_smaller_than_row(arrow_table_with_null: pa.Table) -> None:
    """With a 1-byte target there are as many bins as rows, and no row is lost."""
    expected_rows = arrow_table_with_null.num_rows
    bins = list(bin_pack_arrow_table(arrow_table_with_null, target_file_size=1))
    assert len(bins) == expected_rows

    packed_rows = 0
    for bin_batches in bins:
        for batch in bin_batches:
            packed_rows += batch.num_rows
    assert packed_rows == expected_rows
def test_schema_mismatch_type(table_schema_simple: Schema) -> None:
    """Assert that a type mismatch on ``bar`` (table int vs. Arrow decimal) is reported in the field table."""
    # Same fields as ``table_schema_simple`` except that ``bar`` is a decimal instead of an int.
    other_schema = pa.schema(
        (
            pa.field("foo", pa.string(), nullable=True),
            pa.field("bar", pa.decimal128(18, 6), nullable=False),
            pa.field("baz", pa.bool_(), nullable=True),
        )
    )

    # ``match`` is treated as a regular expression, hence the raw string and the escaped parentheses.
    expected = r"""Mismatch in fields:
┏━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┓
┃    ┃ Table field              ┃ Dataframe field                 ┃
┡━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┩
│ ✅ │ 1: foo: optional string  │ 1: foo: optional string         │
│ ❌ │ 2: bar: required int     │ 2: bar: required decimal\(18, 6\) │
│ ✅ │ 3: baz: optional boolean │ 3: baz: optional boolean        │
└────┴──────────────────────────┴─────────────────────────────────┘
"""

    with pytest.raises(ValueError, match=expected):
        _check_pyarrow_schema_compatible(table_schema_simple, other_schema)
def test_schema_mismatch_nullability(table_schema_simple: Schema) -> None:
    """Assert that a nullable Arrow ``bar`` column is rejected for the required table field ``bar``."""
    # Only ``bar`` differs from ``table_schema_simple``: it is nullable here.
    other_schema = pa.schema(
        (
            pa.field("foo", pa.string(), nullable=True),
            pa.field("bar", pa.int32(), nullable=True),
            pa.field("baz", pa.bool_(), nullable=True),
        )
    )

    expected = """Mismatch in fields:
┏━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━┓
┃    ┃ Table field              ┃ Dataframe field          ┃
┡━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━┩
│ ✅ │ 1: foo: optional string  │ 1: foo: optional string  │
│ ❌ │ 2: bar: required int     │ 2: bar: optional int     │
│ ✅ │ 3: baz: optional boolean │ 3: baz: optional boolean │
└────┴──────────────────────────┴──────────────────────────┘
"""

    with pytest.raises(ValueError, match=expected):
        _check_pyarrow_schema_compatible(table_schema_simple, other_schema)
def test_schema_compatible_nullability_diff(table_schema_simple: Schema) -> None:
    """A non-nullable Arrow column is accepted for an optional table field (``baz``)."""
    arrow_schema = pa.schema(
        [
            pa.field("foo", pa.string(), nullable=True),
            pa.field("bar", pa.int32(), nullable=False),
            pa.field("baz", pa.bool_(), nullable=False),
        ]
    )

    try:
        _check_pyarrow_schema_compatible(table_schema_simple, arrow_schema)
    except Exception:
        pytest.fail("Unexpected Exception raised when calling `_check_pyarrow_schema_compatible`")
def test_schema_mismatch_missing_field(table_schema_simple: Schema) -> None:
    """Assert that omitting the required ``bar`` column is reported as ``Missing``."""
    # ``bar`` is deliberately left out of the Arrow schema.
    other_schema = pa.schema(
        (
            pa.field("foo", pa.string(), nullable=True),
            pa.field("baz", pa.bool_(), nullable=True),
        )
    )

    expected = """Mismatch in fields:
┏━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━┓
┃    ┃ Table field              ┃ Dataframe field          ┃
┡━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━┩
│ ✅ │ 1: foo: optional string  │ 1: foo: optional string  │
│ ❌ │ 2: bar: required int     │ Missing                  │
│ ✅ │ 3: baz: optional boolean │ 3: baz: optional boolean │
└────┴──────────────────────────┴──────────────────────────┘
"""

    with pytest.raises(ValueError, match=expected):
        _check_pyarrow_schema_compatible(table_schema_simple, other_schema)
def test_schema_compatible_missing_nullable_field_nested(table_schema_nested: Schema) -> None:
    """Dropping the optional ``name`` field from the nested ``person`` struct keeps the schema compatible."""
    arrow_schema = table_schema_nested.as_arrow()
    person_without_name = pa.field(
        "person",
        pa.struct([pa.field("age", pa.int32(), nullable=False)]),
        nullable=True,
    )
    # Index 6 is the top-level ``person`` struct; swap it for the reduced version.
    arrow_schema = arrow_schema.remove(6).insert(6, person_without_name)

    try:
        _check_pyarrow_schema_compatible(table_schema_nested, arrow_schema)
    except Exception:
        pytest.fail("Unexpected Exception raised when calling `_check_pyarrow_schema_compatible`")
def test_schema_mismatch_missing_required_field_nested(table_schema_nested: Schema) -> None:
    """Assert that dropping the required ``age`` field from the nested ``person`` struct is reported as ``Missing``."""
    other_schema = table_schema_nested.as_arrow()
    # Replace the top-level ``person`` struct (index 6) with one that only has ``name``.
    other_schema = other_schema.remove(6).insert(
        6,
        pa.field(
            "person",
            pa.struct(
                [
                    pa.field("name", pa.string(), nullable=True),
                ]
            ),
            nullable=True,
        ),
    )

    # Long nested types wrap across several rows of the rendered table; the continuation rows
    # have an empty status column.
    expected = """Mismatch in fields:
┏━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┓
┃    ┃ Table field                        ┃ Dataframe field                    ┃
┡━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┩
│ ✅ │ 1: foo: optional string            │ 1: foo: optional string            │
│ ✅ │ 2: bar: required int               │ 2: bar: required int               │
│ ✅ │ 3: baz: optional boolean           │ 3: baz: optional boolean           │
│ ✅ │ 4: qux: required list<string>      │ 4: qux: required list<string>      │
│ ✅ │ 5: element: required string        │ 5: element: required string        │
│ ✅ │ 6: quux: required map<string,      │ 6: quux: required map<string,      │
│    │ map<string, int>>                  │ map<string, int>>                  │
│ ✅ │ 7: key: required string            │ 7: key: required string            │
│ ✅ │ 8: value: required map<string,     │ 8: value: required map<string,     │
│    │ int>                               │ int>                               │
│ ✅ │ 9: key: required string            │ 9: key: required string            │
│ ✅ │ 10: value: required int            │ 10: value: required int            │
│ ✅ │ 11: location: required             │ 11: location: required             │
│    │ list<struct<13: latitude: optional │ list<struct<13: latitude: optional │
│    │ float, 14: longitude: optional     │ float, 14: longitude: optional     │
│    │ float>>                            │ float>>                            │
│ ✅ │ 12: element: required struct<13:   │ 12: element: required struct<13:   │
│    │ latitude: optional float, 14:      │ latitude: optional float, 14:      │
│    │ longitude: optional float>         │ longitude: optional float>         │
│ ✅ │ 13: latitude: optional float       │ 13: latitude: optional float       │
│ ✅ │ 14: longitude: optional float      │ 14: longitude: optional float      │
│ ✅ │ 15: person: optional struct<16:    │ 15: person: optional struct<16:    │
│    │ name: optional string, 17: age:    │ name: optional string>             │
│    │ required int>                      │                                    │
│ ✅ │ 16: name: optional string          │ 16: name: optional string          │
│ ❌ │ 17: age: required int              │ Missing                            │
└────┴────────────────────────────────────┴────────────────────────────────────┘
"""

    with pytest.raises(ValueError, match=expected):
        _check_pyarrow_schema_compatible(table_schema_nested, other_schema)
def test_schema_compatible_nested(table_schema_nested: Schema) -> None:
    """A nested table schema is compatible with its own Arrow conversion."""
    try:
        arrow_schema = table_schema_nested.as_arrow()
        _check_pyarrow_schema_compatible(table_schema_nested, arrow_schema)
    except Exception:
        pytest.fail("Unexpected Exception raised when calling `_check_pyarrow_schema_compatible`")
def test_schema_mismatch_additional_field(table_schema_simple: Schema) -> None:
    """An Arrow column unknown to the table is rejected with a hint to evolve the schema first."""
    arrow_schema = pa.schema(
        [
            pa.field("foo", pa.string(), nullable=True),
            pa.field("bar", pa.int32(), nullable=False),
            pa.field("baz", pa.bool_(), nullable=True),
            pa.field("new_field", pa.date32(), nullable=True),
        ]
    )
    expected_message = r"PyArrow table contains more columns: new_field. Update the schema first \(hint, use union_by_name\)."

    with pytest.raises(ValueError, match=expected_message):
        _check_pyarrow_schema_compatible(table_schema_simple, arrow_schema)
def test_schema_compatible(table_schema_simple: Schema) -> None:
    """A flat table schema is compatible with its own Arrow conversion."""
    try:
        arrow_schema = table_schema_simple.as_arrow()
        _check_pyarrow_schema_compatible(table_schema_simple, arrow_schema)
    except Exception:
        pytest.fail("Unexpected Exception raised when calling `_check_pyarrow_schema_compatible`")
def test_schema_projection(table_schema_simple: Schema) -> None:
    """Leaving out an optional table field (``baz``) is still compatible."""
    arrow_schema = pa.schema(
        [
            pa.field("foo", pa.string(), nullable=True),
            pa.field("bar", pa.int32(), nullable=False),
        ]
    )

    try:
        _check_pyarrow_schema_compatible(table_schema_simple, arrow_schema)
    except Exception:
        pytest.fail("Unexpected Exception raised when calling `_check_pyarrow_schema_compatible`")
def test_schema_downcast(table_schema_simple: Schema) -> None:
    """An Arrow ``large_string`` column is accepted for a table ``string`` field."""
    arrow_schema = pa.schema(
        [
            pa.field("foo", pa.large_string(), nullable=True),
            pa.field("bar", pa.int32(), nullable=False),
            pa.field("baz", pa.bool_(), nullable=True),
        ]
    )

    try:
        _check_pyarrow_schema_compatible(table_schema_simple, arrow_schema)
    except Exception:
        pytest.fail("Unexpected Exception raised when calling `_check_pyarrow_schema_compatible`")
def test_partition_for_demo() -> None:
    """Rows are split into (n_legs, year) identity partitions and no rows are lost."""
    arrow_schema = pa.schema([("year", pa.int64()), ("n_legs", pa.int64()), ("animal", pa.string())])
    iceberg_schema = Schema(
        NestedField(field_id=1, name="year", field_type=StringType(), required=False),
        NestedField(field_id=2, name="n_legs", field_type=IntegerType(), required=True),
        NestedField(field_id=3, name="animal", field_type=StringType(), required=False),
        schema_id=1,
    )
    columns = {
        "year": [2020, 2022, 2022, 2022, 2021, 2022, 2022, 2019, 2021],
        "n_legs": [2, 2, 2, 4, 4, 4, 4, 5, 100],
        "animal": ["Flamingo", "Parrot", "Parrot", "Horse", "Dog", "Horse", "Horse", "Brittle stars", "Centipede"],
    }
    arrow_table = pa.Table.from_pydict(columns, schema=arrow_schema)
    spec = PartitionSpec(
        PartitionField(source_id=2, field_id=1002, transform=IdentityTransform(), name="n_legs_identity"),
        PartitionField(source_id=1, field_id=1001, transform=IdentityTransform(), name="year_identity"),
    )

    partitions = list(_determine_partitions(spec, iceberg_schema, arrow_table))

    expected_keys = {
        Record(2, 2020),
        Record(100, 2021),
        Record(4, 2021),
        Record(4, 2022),
        Record(2, 2022),
        Record(5, 2019),
    }
    assert {partition.partition_key.partition for partition in partitions} == expected_keys
    recombined = pa.concat_tables([partition.arrow_table_partition for partition in partitions])
    assert recombined.num_rows == arrow_table.num_rows
def test_partition_for_nested_field() -> None:
    """An hour transform on a timestamp nested inside a struct produces one partition per hour."""
    bar_struct = StructType(
        NestedField(id=3, name="baz", field_type=TimestampType(), required=False),
        NestedField(id=4, name="qux", field_type=IntegerType(), required=False),
    )
    schema = Schema(
        NestedField(id=1, name="foo", field_type=StringType(), required=True),
        NestedField(id=2, name="bar", field_type=bar_struct, required=True),
    )
    spec = PartitionSpec(PartitionField(source_id=3, field_id=1000, transform=HourTransform(), name="ts"))

    first_ts = datetime(2025, 7, 11, 9, 30, 0)
    second_ts = datetime(2025, 7, 11, 10, 30, 0)
    rows = [
        {"foo": "a", "bar": {"baz": first_ts, "qux": 1}},
        {"foo": "b", "bar": {"baz": second_ts, "qux": 2}},
    ]
    arrow_table = pa.Table.from_pylist(rows, schema=schema.as_arrow())

    hours = {partition.partition_key.partition[0] for partition in _determine_partitions(spec, schema, arrow_table)}
    assert hours == {486729, 486730}
def test_partition_for_deep_nested_field() -> None:
    """An identity partition on a field two structs deep groups rows by that field's value."""
    baz_field = NestedField(id=3, name="baz", field_type=StringType(), required=False)
    bar_field = NestedField(id=2, name="bar", field_type=StructType(baz_field), required=True)
    schema = Schema(NestedField(id=1, name="foo", field_type=StructType(bar_field), required=True))
    spec = PartitionSpec(PartitionField(source_id=3, field_id=1000, transform=IdentityTransform(), name="qux"))

    rows = [{"foo": {"bar": {"baz": value}}} for value in ("data-1", "data-2", "data-1")]
    arrow_table = pa.Table.from_pylist(rows, schema=schema.as_arrow())

    partitions = list(_determine_partitions(spec, schema, arrow_table))
    assert len(partitions) == 2  # "data-1" appears twice, so only two distinct partitions
    assert {partition.partition_key.partition[0] for partition in partitions} == {"data-1", "data-2"}
def test_inspect_partition_for_nested_field(catalog: InMemoryCatalog) -> None:
    """``inspect.partitions()`` lists each value of an identity partition on a nested struct field."""
    bar_struct = StructType(
        NestedField(id=3, name="baz", field_type=StringType(), required=False),
        NestedField(id=4, name="qux", field_type=IntegerType(), required=False),
    )
    schema = Schema(
        NestedField(id=1, name="foo", field_type=StringType(), required=True),
        NestedField(id=2, name="bar", field_type=bar_struct, required=True),
    )
    spec = PartitionSpec(PartitionField(source_id=3, field_id=1000, transform=IdentityTransform(), name="part"))

    catalog.create_namespace("default")
    table = catalog.create_table("default.test_partition_in_struct", schema=schema, partition_spec=spec)
    rows = [
        {"foo": "a", "bar": {"baz": "data-a", "qux": 1}},
        {"foo": "b", "bar": {"baz": "data-b", "qux": 2}},
    ]
    table.append(pa.Table.from_pylist(rows, schema=table.schema().as_arrow()))

    partition_rows = table.inspect.partitions()["partition"].to_pylist()
    assert len(partition_rows) == 2
    assert {row["part"] for row in partition_rows} == {"data-a", "data-b"}
def test_inspect_partitions_respects_partition_evolution(catalog: InMemoryCatalog) -> None:
    """After the spec gains a field, ``inspect.partitions()`` reports rows from both the old and the new spec."""
    schema = Schema(
        NestedField(id=1, name="dt", field_type=DateType(), required=False),
        NestedField(id=2, name="category", field_type=StringType(), required=False),
    )
    spec = PartitionSpec(PartitionField(source_id=1, field_id=1000, transform=IdentityTransform(), name="dt"))
    catalog.create_namespace("default")
    table = catalog.create_table(
        "default.test_inspect_partitions_respects_partition_evolution", schema=schema, partition_spec=spec
    )

    def append_rows(rows: list[dict[str, Any]]) -> None:
        table.append(pa.Table.from_pylist(rows, schema=table.schema().as_arrow()))

    old_spec_id = table.spec().spec_id
    append_rows([{"dt": date(2025, 1, 1), "category": "old"}])

    table.update_spec().add_identity("category").commit()
    new_spec_id = table.spec().spec_id
    assert new_spec_id != old_spec_id

    # Data written under the old spec must not suddenly carry the new partition field.
    old_partitions = table.inspect.partitions()["partition"].to_pylist()
    assert all("category" not in partition for partition in old_partitions)

    append_rows([{"dt": date(2025, 1, 2), "category": "new"}])
    assert set(table.inspect.partitions()["spec_id"].to_pylist()) == {old_spec_id, new_spec_id}
def test_identity_partition_on_multi_columns() -> None:
    """Identity partitioning on two columns (with nulls) yields the same partitions for any row order."""
    import random

    test_pa_schema = pa.schema([("born_year", pa.int64()), ("n_legs", pa.int64()), ("animal", pa.string())])
    test_schema = Schema(
        NestedField(field_id=1, name="born_year", field_type=StringType(), required=False),
        NestedField(field_id=2, name="n_legs", field_type=IntegerType(), required=True),
        NestedField(field_id=3, name="animal", field_type=StringType(), required=False),
        schema_id=1,
    )
    # 5 partitions, 6 unique row values, 12 rows
    test_rows = [
        (2021, 4, "Dog"),
        (2022, 4, "Horse"),
        (2022, 4, "Another Horse"),
        (2021, 100, "Centipede"),
        (None, 4, "Kirin"),
        (2021, None, "Fish"),
    ] * 2
    # Partition keys are ordered (n_legs, born_year), matching the spec below.
    expected = {Record(n_legs, born_year) for born_year, n_legs, _ in test_rows}
    partition_spec = PartitionSpec(
        PartitionField(source_id=2, field_id=1002, transform=IdentityTransform(), name="n_legs_identity"),
        PartitionField(source_id=1, field_id=1001, transform=IdentityTransform(), name="year_identity"),
    )
    sort_keys = [("born_year", "ascending"), ("n_legs", "ascending"), ("animal", "ascending")]

    # There are 12! / ((2!)^6) = 7,484,400 permutations, too many to try them all, so sample 1000.
    # A dedicated, seeded generator makes any failing permutation reproducible and avoids
    # touching the global `random` state shared with other tests.
    rng = random.Random(0)
    for _ in range(1000):
        rng.shuffle(test_rows)
        test_data = {
            "born_year": [row[0] for row in test_rows],
            "n_legs": [row[1] for row in test_rows],
            "animal": [row[2] for row in test_rows],
        }
        arrow_table = pa.Table.from_pydict(test_data, schema=test_pa_schema)
        result = list(_determine_partitions(partition_spec, test_schema, arrow_table))
        assert {table_partition.partition_key.partition for table_partition in result} == expected

        concatenated_arrow_table = pa.concat_tables([table_partition.arrow_table_partition for table_partition in result])
        assert concatenated_arrow_table.num_rows == arrow_table.num_rows
        assert concatenated_arrow_table.sort_by(sort_keys) == arrow_table.sort_by(sort_keys)
def test_initial_value() -> None:
    """A required column absent from the (empty) file schema is filled with its ``initial_default``."""
    # The batch has ten rows so the projection produces a non-empty result to inspect.
    batch = pa.record_batch([pa.nulls(10, pa.int64())], names=["some_field"])
    requested = Schema(NestedField(1, "we-love-22", LongType(), required=True, initial_default=22))

    projected = _to_requested_schema(requested, Schema(), batch)

    assert projected.column_names == ["we-love-22"]
    for cell in projected[0]:
        assert cell.as_py() == 22
def test__to_requested_schema_timestamp_to_timestamptz_projection() -> None:
    """Projecting a naive timestamp column onto a timestamptz field attaches the UTC zone."""
    wall_clock_values = [
        datetime(2025, 8, 14, 12, 0, 0),
        datetime(2025, 8, 14, 13, 0, 0),
    ]

    # the file stores timestamps without a timezone
    file_schema = Schema(NestedField(1, "ts_field", TimestampType(), required=False))
    batch = pa.record_batch([pa.array(wall_clock_values, type=pa.timestamp("us"))], names=["ts_field"])

    # the table declares the column as timestamp with timezone
    table_schema = Schema(NestedField(1, "ts_field", TimestamptzType(), required=False))
    projected = _to_requested_schema(table_schema, file_schema, batch, downcast_ns_timestamp_to_us=True)

    expected = pa.record_batch(
        [pa.array(wall_clock_values, type=pa.timestamp("us", tz=timezone.utc))],
        names=["ts_field"],
    )
    # the projected batch now carries the timezone
    assert expected.equals(projected)
def test__to_requested_schema_timestamptz_to_timestamp_projection() -> None:
    """With ``allow_timestamp_tz_mismatch=True``, a timestamptz column is read as a naive timestamp."""
    aware_values = [
        datetime(2025, 8, 14, 12, 0, 0, tzinfo=timezone.utc),
        datetime(2025, 8, 14, 13, 0, 0, tzinfo=timezone.utc),
    ]
    naive_values = [
        datetime(2025, 8, 14, 12, 0, 0),
        datetime(2025, 8, 14, 13, 0, 0),
    ]

    # the file stores timestamps with a timezone
    file_schema = Schema(NestedField(1, "ts_field", TimestamptzType(), required=False))
    batch = pa.record_batch([pa.array(aware_values, type=pa.timestamp("us", tz="UTC"))], names=["ts_field"])

    # the table declares the column as timestamp without timezone; the flag permits the read
    table_schema = Schema(NestedField(1, "ts_field", TimestampType(), required=False))
    projected = _to_requested_schema(
        table_schema, file_schema, batch, downcast_ns_timestamp_to_us=True, allow_timestamp_tz_mismatch=True
    )

    expected = pa.record_batch([pa.array(naive_values, type=pa.timestamp("us"))], names=["ts_field"])
    # the projected batch no longer carries a timezone
    assert expected.equals(projected)
def test__to_requested_schema_timestamptz_to_timestamp_write_rejects() -> None:
    """The default (write-path) projection refuses to cast timestamptz to timestamp.

    Writes must keep the Iceberg spec's timestamp/timestamptz distinction, whereas reads may opt
    into a Spark-like leniency with ``allow_timestamp_tz_mismatch=True``.
    """
    aware_values = [
        datetime(2025, 8, 14, 12, 0, 0, tzinfo=timezone.utc),
        datetime(2025, 8, 14, 13, 0, 0, tzinfo=timezone.utc),
    ]
    # the file stores timestamps with a timezone
    file_schema = Schema(NestedField(1, "ts_field", TimestamptzType(), required=False))
    batch = pa.record_batch([pa.array(aware_values, type=pa.timestamp("us", tz="UTC"))], names=["ts_field"])

    # the table declares the column as timestamp without timezone
    table_schema = Schema(NestedField(1, "ts_field", TimestampType(), required=False))

    # allow_timestamp_tz_mismatch=False is the write-path default and must raise
    with pytest.raises(ValueError, match="Unsupported schema projection"):
        _to_requested_schema(
            table_schema, file_schema, batch, downcast_ns_timestamp_to_us=True, allow_timestamp_tz_mismatch=False
        )
def test_write_file_rejects_timestamptz_to_timestamp(tmp_path: Path) -> None:
    """``write_file`` refuses to write timestamptz data into a table column declared as timestamp."""
    from pyiceberg.table import WriteTask

    # The table column is a naive timestamp, while the task's data is timezone-aware.
    table_schema = Schema(NestedField(1, "ts_field", TimestampType(), required=False))
    task_schema = Schema(NestedField(1, "ts_field", TimestamptzType(), required=False))
    aware_data = pa.table({"ts_field": [datetime(2025, 8, 14, 12, 0, 0, tzinfo=timezone.utc)]})

    metadata = TableMetadataV2(
        location=f"file://{tmp_path}",
        last_column_id=1,
        format_version=2,
        schemas=[table_schema],
        partition_specs=[PartitionSpec()],
    )
    write_task = WriteTask(
        write_uuid=uuid.uuid4(),
        task_id=0,
        record_batches=aware_data.to_batches(),
        schema=task_schema,
    )

    # write_file is lazy, so drain it with list() to trigger the projection.
    with pytest.raises(ValueError, match="Unsupported schema projection"):
        list(write_file(io=PyArrowFileIO(), table_metadata=metadata, tasks=iter([write_task])))
def test__to_requested_schema_timestamps(
    arrow_table_schema_with_all_timestamp_precisions: pa.Schema,
    arrow_table_with_all_timestamp_precisions: pa.Table,
    arrow_table_schema_with_all_microseconds_timestamp_precisions: pa.Schema,
    table_schema_with_all_microseconds_timestamp_precision: Schema,
) -> None:
    """With downcasting enabled, every timestamp precision is projected to microseconds."""
    us_schema = table_schema_with_all_microseconds_timestamp_precision
    first_batch = arrow_table_with_all_timestamp_precisions.to_batches()[0]

    projected = _to_requested_schema(us_schema, us_schema, first_batch, downcast_ns_timestamp_to_us=True, include_field_ids=False)

    # safe=False lets the reference cast truncate sub-microsecond precision the same way.
    reference_table = arrow_table_with_all_timestamp_precisions.cast(
        arrow_table_schema_with_all_microseconds_timestamp_precisions, safe=False
    )
    assert projected == reference_table.to_batches()[0]
def test__to_requested_schema_timestamps_without_downcast_raises_exception(
    arrow_table_schema_with_all_timestamp_precisions: pa.Schema,
    arrow_table_with_all_timestamp_precisions: pa.Table,
    arrow_table_schema_with_all_microseconds_timestamp_precisions: pa.Schema,
    table_schema_with_all_microseconds_timestamp_precision: Schema,
) -> None:
    """Without downcasting, nanosecond timestamps cannot be projected to microseconds."""
    us_schema = table_schema_with_all_microseconds_timestamp_precision
    first_batch = arrow_table_with_all_timestamp_precisions.to_batches()[0]

    with pytest.raises(ValueError) as excinfo:
        _to_requested_schema(us_schema, us_schema, first_batch, downcast_ns_timestamp_to_us=False, include_field_ids=False)

    assert "Unsupported schema projection from timestamp[ns] to timestamp[us]" in str(excinfo.value)
@pytest.mark.parametrize(
    "arrow_type,iceberg_type,expected_arrow_type",
    [
        (pa.uint8(), IntegerType(), pa.int32()),
        (pa.int8(), IntegerType(), pa.int32()),
        (pa.int16(), IntegerType(), pa.int32()),
        (pa.uint16(), IntegerType(), pa.int32()),
        (pa.uint32(), LongType(), pa.int64()),
        (pa.int32(), LongType(), pa.int64()),
    ],
)
def test__to_requested_schema_integer_promotion(
    arrow_type: pa.DataType,
    iceberg_type: PrimitiveType,
    expected_arrow_type: pa.DataType,
) -> None:
    """Narrower integer columns are widened to the Arrow type of the target Iceberg type, values unchanged."""
    values = [1, 2, 3, None]
    iceberg_schema = Schema(NestedField(1, "col", iceberg_type, required=False))
    batch = pa.RecordBatch.from_arrays(
        [pa.array(values, type=arrow_type)],
        schema=pa.schema([pa.field("col", arrow_type)]),
    )

    projected = _to_requested_schema(
        iceberg_schema, iceberg_schema, batch, downcast_ns_timestamp_to_us=False, include_field_ids=False
    )

    assert projected.schema[0].type == expected_arrow_type
    assert projected.column(0).to_pylist() == values
def test_pyarrow_file_io_fs_by_scheme_cache() -> None:
    """fs_by_scheme memoizes one filesystem per (scheme, bucket), each keeping its own resolved region."""
    # It's better to set up multi-region minio servers for an integration test once `endpoint_url` argument
    # becomes available for `resolve_s3_region`
    # Refer to: https://github.com/apache/arrow/issues/43713
    file_io = PyArrowFileIO()
    us_region = "us-east-1"
    ap_region = "ap-southeast-2"

    with patch("pyarrow.fs.resolve_s3_region") as resolver:
        # First lookup for a bucket misses the cache and resolves its region.
        resolver.return_value = us_region
        assert file_io.fs_by_scheme("s3", "us-east-1-bucket").region == us_region
        stats = file_io.fs_by_scheme.cache_info()  # type: ignore
        assert stats.misses == 1
        assert stats.currsize == 1

        # A different bucket is another miss with its own resolved region.
        resolver.return_value = ap_region
        assert file_io.fs_by_scheme("s3", "ap-southeast-2-bucket").region == ap_region
        stats = file_io.fs_by_scheme.cache_info()  # type: ignore
        assert stats.misses == 2
        assert stats.currsize == 2

        # Repeated lookups are served from the cache, keeping their original region.
        assert file_io.fs_by_scheme("s3", "us-east-1-bucket").region == us_region
        assert file_io.fs_by_scheme.cache_info().hits == 1  # type: ignore

        assert file_io.fs_by_scheme("s3", "ap-southeast-2-bucket").region == ap_region
        assert file_io.fs_by_scheme.cache_info().hits == 2  # type: ignore
def test_pyarrow_io_new_input_multi_region(caplog: Any) -> None:
    """With s3.resolve-region enabled, S3 inputs use the bucket's real region; other schemes keep the configured one."""
    # It's better to set up multi-region minio servers for an integration test once `endpoint_url` argument
    # becomes available for `resolve_s3_region`
    # Refer to: https://github.com/apache/arrow/issues/43713
    configured_region = "ap-southeast-1"
    region_by_bucket = {
        "us-east-2-bucket": "us-east-2",
        "ap-southeast-2-bucket": "ap-southeast-2",
    }

    def _resolve(bucket: str) -> str:
        # Mimic resolve_s3_region: unknown buckets raise OSError.
        region = region_by_bucket.get(bucket)
        if region is None:
            raise OSError("Unknown bucket")
        return region

    file_io = PyArrowFileIO({"s3.region": configured_region, "s3.resolve-region": "true"})
    with patch("pyarrow.fs.resolve_s3_region") as resolver:
        resolver.side_effect = _resolve

        # An unresolvable bucket falls back to the configured region and logs a warning.
        with caplog.at_level(logging.WARNING):
            fallback_fs = file_io.new_input("s3://non-exist-bucket/path/to/file")._filesystem
            assert fallback_fs.region == configured_region
            assert "Unable to resolve region for bucket non-exist-bucket" in caplog.text

        for bucket, actual_region in region_by_bucket.items():
            # For s3, the resolved bucket region overrides the configured one (with a warning).
            with caplog.at_level(logging.WARNING):
                s3_fs = file_io.new_input(f"s3://{bucket}/path/to/file")._filesystem
                assert s3_fs.region == actual_region
                assert (
                    f"PyArrow FileIO overriding S3 bucket region for bucket {bucket}: "
                    f"provided region {configured_region}, actual region {actual_region}" in caplog.text
                )
            # For oss, no resolution happens: the configured region is used as-is.
            oss_fs = file_io.new_input(f"oss://{bucket}/path/to/file")._filesystem
            assert oss_fs.region == configured_region
def test_pyarrow_io_multi_fs() -> None:
    """A single PyArrowFileIO hands out a filesystem matching each location's scheme."""
    file_io = PyArrowFileIO({"s3.region": "ap-southeast-1"})
    expected_fs_types = [
        ("s3://bucket/path/to/file", S3FileSystem),
        ("file:///path/to/file", LocalFileSystem),
    ]
    with patch("pyarrow.fs.resolve_s3_region") as resolver:
        resolver.return_value = None
        for location, fs_type in expected_fs_types:
            assert isinstance(file_io.new_input(location)._filesystem, fs_type)
class SomeRetryStrategy(AwsDefaultS3RetryStrategy):
    """Retry strategy stub that announces its construction with a ``UserWarning``.

    It is referenced by dotted path through ``S3_RETRY_STRATEGY_IMPL`` in ``test_retry_strategy``.
    """

    def __init__(self) -> None:
        super().__init__()
        # stacklevel=2 attributes the warning to the code that instantiated the strategy.
        warnings.warn("Initialized SomeRetryStrategy 👍", UserWarning, stacklevel=2)
def test_retry_strategy() -> None:
    """A retry strategy configured by dotted path is instantiated when an S3 input is created."""
    strategy_path = "tests.io.test_pyarrow.SomeRetryStrategy"
    file_io = PyArrowFileIO(properties={S3_RETRY_STRATEGY_IMPL: strategy_path})
    with pytest.warns(UserWarning, match="Initialized SomeRetryStrategy.*"):
        file_io.new_input("s3://bucket/path/to/file")
def test_retry_strategy_not_found() -> None:
    """An unimportable retry strategy path is reported as a warning instead of failing the input."""
    missing_path = "pyiceberg.DoesNotExist"
    file_io = PyArrowFileIO(properties={S3_RETRY_STRATEGY_IMPL: missing_path})
    with pytest.warns(UserWarning, match="Could not initialize S3 retry strategy: pyiceberg.DoesNotExist"):
        file_io.new_input("s3://bucket/path/to/file")
@pytest.mark.parametrize("format_version", [1, 2, 3])
def test_task_to_record_batches_nanos(format_version: TableVersion, tmpdir: str) -> None:
    """Nanosecond timestamps are read back as ns on format v3 and as us on v1/v2."""
    timestamps = [
        datetime(2025, 8, 14, 12, 0, 0),
        datetime(2025, 8, 14, 13, 0, 0),
    ]
    ns_field = pa.field("ts_field", pa.timestamp("ns"), nullable=True, metadata={PYARROW_PARQUET_FIELD_ID_KEY: "1"})
    source_table = pa.table([pa.array(timestamps, type=pa.timestamp("ns"))], pa.schema((ns_field,)))
    data_file = _write_table_to_data_file(
        f"{tmpdir}/test_task_to_record_batches_nanos.parquet", source_table.schema, source_table
    )

    # Only format v3 has a nanosecond timestamp type.
    supports_nanos = format_version > 2
    ts_type = TimestampNanoType() if supports_nanos else TimestampType()
    table_schema = Schema(NestedField(1, "ts_field", ts_type, required=False))

    batches = list(
        _task_to_record_batches(
            PyArrowFileIO(),
            FileScanTask(data_file),
            bound_row_filter=AlwaysTrue(),
            projected_schema=table_schema,
            table_schema=table_schema,
            projected_field_ids={1},
            positional_deletes=None,
            case_sensitive=True,
            format_version=format_version,
        )
    )
    first_batch = batches[0]

    expected_unit = "ns" if supports_nanos else "us"
    expected = pa.record_batch(
        [pa.array(timestamps, type=pa.timestamp(expected_unit))],
        names=["ts_field"],
    )
    assert expected.equals(first_batch)
def test_parse_location_defaults() -> None:
    """parse_location falls back to DEFAULT_SCHEME / DEFAULT_NETLOC for scheme-less paths."""
    from pyiceberg.io.pyarrow import PyArrowFileIO

    # Without configured defaults, a bare path maps to the file scheme with an empty netloc.
    assert tuple(PyArrowFileIO.parse_location("/foo/bar")) == ("file", "", "/foo/bar")

    # Configured defaults are honoured for both custom and well-known schemes.
    for default_scheme in ("scheme", "hdfs"):
        overrides = {"DEFAULT_SCHEME": default_scheme, "DEFAULT_NETLOC": "netloc:8000"}
        parsed = tuple(PyArrowFileIO.parse_location("/foo/bar", properties=overrides))
        assert parsed == (default_scheme, "netloc:8000", "/foo/bar")
def test_write_and_read_orc(tmp_path: Path) -> None:
    """Round-trip a small Arrow table through an ORC file."""
    original = pa.table({"a": [1, 2, 3], "b": ["x", "y", "z"]})
    target = str(tmp_path / "test.orc")

    orc.write_table(original, target)
    round_tripped = orc.ORCFile(target).read()

    assert round_tripped.equals(original)
def test_orc_file_format_integration(tmp_path: Path) -> None:
    """Read an ORC file back through the PyArrow dataset API using ``OrcFileFormat``."""
    import pyarrow.dataset as ds

    expected = pa.table({"a": [10, 20], "b": ["foo", "bar"]})
    orc_location = str(tmp_path / "iceberg.orc")
    orc.write_table(expected, orc_location)

    actual = ds.dataset(orc_location, format=ds.OrcFileFormat()).to_table()
    assert actual.equals(expected)
def test_iceberg_read_orc(tmp_path: Path) -> None:
    """
    Integration test: read an ORC data file end-to-end through ArrowScan.

    The ORC file is written directly by PyArrow without Iceberg field IDs, so the columns
    are resolved through the table's default name mapping.

    To run just this test:
        pytest tests/io/test_pyarrow.py -k test_iceberg_read_orc
    """
    import pyarrow as pa
    import pyarrow.orc as orc
    from pyiceberg.expressions import AlwaysTrue
    from pyiceberg.io.pyarrow import ArrowScan, PyArrowFileIO
    from pyiceberg.manifest import DataFile, DataFileContent, FileFormat
    from pyiceberg.partitioning import PartitionSpec
    from pyiceberg.schema import Schema
    from pyiceberg.table import FileScanTask
    from pyiceberg.table.metadata import TableMetadataV2
    from pyiceberg.typedef import Record
    from pyiceberg.types import IntegerType, StringType

    iceberg_schema = Schema(
        NestedField(1, "id", IntegerType(), required=True),
        NestedField(2, "name", StringType(), required=False),
    )
    expected = pa.table({"id": pa.array([1, 2, 3], type=pa.int32()), "name": ["a", "b", "c"]})

    # Write the ORC file directly with PyArrow (no Iceberg field IDs embedded).
    orc_file_path = tmp_path / "test_data.orc"
    orc.write_table(expected, str(orc_file_path))

    metadata = TableMetadataV2(
        location=f"file://{tmp_path}/test_location",
        last_column_id=2,
        format_version=2,
        schemas=[iceberg_schema],
        partition_specs=[PartitionSpec()],
        properties={
            "write.format.default": "parquet",  # irrelevant for reading
            # The name mapping is how field IDs are recovered for this ID-less ORC file.
            "schema.name-mapping.default": '[{"field-id": 1, "names": ["id"]}, {"field-id": 2, "names": ["name"]}]',
        },
    )
    file_io = PyArrowFileIO()

    orc_data_file = DataFile.from_args(
        content=DataFileContent.DATA,
        file_path=str(orc_file_path),
        file_format=FileFormat.ORC,
        partition=Record(),
        file_size_in_bytes=orc_file_path.stat().st_size,
        sort_order_id=None,
        spec_id=0,
        equality_ids=None,
        key_metadata=None,
        record_count=3,
        column_sizes={1: 12, 2: 12},  # approximate sizes
        value_counts={1: 3, 2: 3},
        null_value_counts={1: 0, 2: 0},
        nan_value_counts={1: 0, 2: 0},
        lower_bounds={1: b"\x01\x00\x00\x00", 2: b"a"},  # approximate bounds
        upper_bounds={1: b"\x03\x00\x00\x00", 2: b"c"},
        split_offsets=None,
    )
    # Set spec_id explicitly on the DataFile as well.
    orc_data_file.spec_id = 0

    arrow_scan = ArrowScan(
        table_metadata=metadata,
        io=file_io,
        projected_schema=iceberg_schema,
        row_filter=AlwaysTrue(),
        case_sensitive=True,
    )
    actual = arrow_scan.to_table([FileScanTask(data_file=orc_data_file)])

    # Compare shape and values only; schema metadata (e.g. NOT NULL constraints) may differ.
    assert (actual.num_rows, actual.num_columns) == (expected.num_rows, expected.num_columns)
    assert actual.column_names == expected.column_names
    assert actual.to_pydict() == expected.to_pydict()
def test_orc_row_filtering_predicate_pushdown(tmp_path: Path) -> None:
    """
    Test ORC row filtering and predicate pushdown functionality.

    Six row filters (equality, range, string, boolean, IN and a nested AND/OR) are applied
    to the same five-row ORC file, and the surviving rows are checked.

    To run just this test:
        pytest tests/io/test_pyarrow.py -k test_orc_row_filtering_predicate_pushdown
    """
    import pyarrow as pa
    import pyarrow.orc as orc
    from pyiceberg.expressions import And, EqualTo, GreaterThan, In, LessThan, Or
    from pyiceberg.io.pyarrow import ArrowScan, PyArrowFileIO
    from pyiceberg.manifest import DataFile, DataFileContent, FileFormat
    from pyiceberg.partitioning import PartitionSpec
    from pyiceberg.schema import Schema
    from pyiceberg.table import FileScanTask
    from pyiceberg.table.metadata import TableMetadataV2
    from pyiceberg.typedef import Record
    from pyiceberg.types import BooleanType, IntegerType, StringType

    schema = Schema(
        NestedField(1, "id", IntegerType(), required=True),
        NestedField(2, "name", StringType(), required=False),
        NestedField(3, "age", IntegerType(), required=True),
        NestedField(4, "active", BooleanType(), required=True),
    )
    rows = pa.table(
        {
            "id": pa.array([1, 2, 3, 4, 5], type=pa.int32()),
            "name": ["alice", "bob", "charlie", "david", "eve"],
            "age": pa.array([25, 30, 35, 40, 45], type=pa.int32()),
            "active": [True, False, True, True, False],
        }
    )
    orc_file_path = tmp_path / "filter_test.orc"
    orc.write_table(rows, str(orc_file_path))

    metadata = TableMetadataV2(
        location=f"file://{tmp_path}/test_location",
        last_column_id=4,
        format_version=2,
        schemas=[schema],
        partition_specs=[PartitionSpec()],
        properties={
            "schema.name-mapping.default": (
                '[{"field-id": 1, "names": ["id"]}, {"field-id": 2, "names": ["name"]}, '
                '{"field-id": 3, "names": ["age"]}, {"field-id": 4, "names": ["active"]}]'
            ),
        },
    )
    file_io = PyArrowFileIO()

    orc_data_file = DataFile.from_args(
        content=DataFileContent.DATA,
        file_path=str(orc_file_path),
        file_format=FileFormat.ORC,
        partition=Record(),
        file_size_in_bytes=orc_file_path.stat().st_size,
        sort_order_id=None,
        spec_id=0,
        equality_ids=None,
        key_metadata=None,
        record_count=5,
        column_sizes={1: 20, 2: 50, 3: 20, 4: 10},
        value_counts={1: 5, 2: 5, 3: 5, 4: 5},
        null_value_counts={1: 0, 2: 0, 3: 0, 4: 0},
        nan_value_counts={1: 0, 2: 0, 3: 0, 4: 0},
        lower_bounds={1: b"\x01\x00\x00\x00", 2: b"alice", 3: b"\x19\x00\x00\x00", 4: b"\x00"},
        upper_bounds={1: b"\x05\x00\x00\x00", 2: b"eve", 3: b"\x2d\x00\x00\x00", 4: b"\x01"},
        split_offsets=None,
    )
    orc_data_file.spec_id = 0
    task = FileScanTask(data_file=orc_data_file)

    def read_filtered(row_filter: BooleanExpression) -> pa.Table:
        """Scan the single ORC task with ``row_filter`` applied and return the resulting table."""
        arrow_scan = ArrowScan(
            table_metadata=metadata,
            io=file_io,
            projected_schema=schema,
            row_filter=row_filter,
            case_sensitive=True,
        )
        return arrow_scan.to_table([task])

    # Test 1: Simple equality filter
    filtered = read_filtered(EqualTo("id", 3))
    assert filtered.num_rows == 1
    assert filtered.column("id").to_pylist() == [3]
    assert filtered.column("name").to_pylist() == ["charlie"]

    # Test 2: Range filter (30 < age < 45)
    filtered = read_filtered(And(GreaterThan("age", 30), LessThan("age", 45)))
    assert filtered.num_rows == 2
    assert set(filtered.column("id").to_pylist()) == {3, 4}

    # Test 3: String filter
    filtered = read_filtered(EqualTo("name", "bob"))
    assert filtered.num_rows == 1
    assert filtered.column("name").to_pylist() == ["bob"]

    # Test 4: Boolean filter
    filtered = read_filtered(EqualTo("active", True))
    assert filtered.num_rows == 3
    assert set(filtered.column("id").to_pylist()) == {1, 3, 4}

    # Test 5: IN filter
    filtered = read_filtered(In("id", [1, 3, 5]))
    assert filtered.num_rows == 3
    assert set(filtered.column("id").to_pylist()) == {1, 3, 5}

    # Test 6: Complex AND/OR filter -> bob, charlie, david
    filtered = read_filtered(Or(And(EqualTo("active", True), GreaterThan("age", 30)), EqualTo("name", "bob")))
    assert filtered.num_rows == 3
    assert set(filtered.column("id").to_pylist()) == {2, 3, 4}
def test_orc_record_batching_streaming(tmp_path: Path) -> None:
    """
    Test ORC record batching and streaming functionality with multiple files and fragments.

    This test validates that we get the expected number of batches based on file scan tasks
    and ORC fragments, providing end-to-end validation of the batching behavior.

    To run just this test:
        pytest tests/io/test_pyarrow.py -k test_orc_record_batching_streaming
    """
    import pyarrow as pa
    import pyarrow.orc as orc
    from pyiceberg.expressions import AlwaysTrue, GreaterThan
    from pyiceberg.io.pyarrow import ArrowScan, PyArrowFileIO
    from pyiceberg.manifest import DataFile, DataFileContent, FileFormat
    from pyiceberg.partitioning import PartitionSpec
    from pyiceberg.schema import Schema
    from pyiceberg.table import FileScanTask
    from pyiceberg.table.metadata import TableMetadataV2
    from pyiceberg.typedef import Record
    from pyiceberg.types import IntegerType, StringType

    # Define schema
    schema = Schema(
        NestedField(1, "id", IntegerType(), required=True),
        NestedField(2, "value", StringType(), required=False),
    )

    # Create table metadata; the name mapping supplies field IDs for the ID-less ORC files
    table_metadata = TableMetadataV2(
        location=f"file://{tmp_path}/test_location",
        last_column_id=2,
        format_version=2,
        schemas=[schema],
        partition_specs=[PartitionSpec()],
        properties={
            "schema.name-mapping.default": '[{"field-id": 1, "names": ["id"]}, {"field-id": 2, "names": ["value"]}]',
        },
    )
    io = PyArrowFileIO()

    # Several files of 2000 rows each; IDs are contiguous across files
    # (file 0 holds 1..2000, file 1 holds 2001..4000, ...).
    num_files = 2
    rows_per_file = 2000
    total_rows = num_files * rows_per_file

    scan_tasks = []
    for file_idx in range(num_files):
        start_id = file_idx * rows_per_file + 1
        end_id = (file_idx + 1) * rows_per_file
        data = pa.table(
            {
                "id": pa.array(range(start_id, end_id + 1), type=pa.int32()),
                "value": [f"file_{file_idx}_value_{i}" for i in range(start_id, end_id + 1)],
            }
        )

        orc_path = tmp_path / f"batch_test_{file_idx}.orc"
        orc.write_table(data, str(orc_path))

        data_file = DataFile.from_args(
            content=DataFileContent.DATA,
            file_path=str(orc_path),
            file_format=FileFormat.ORC,
            partition=Record(),
            file_size_in_bytes=orc_path.stat().st_size,
            sort_order_id=None,
            spec_id=0,
            equality_ids=None,
            key_metadata=None,
            record_count=rows_per_file,
            column_sizes={1: 8000, 2: 16000},
            value_counts={1: rows_per_file, 2: rows_per_file},
            null_value_counts={1: 0, 2: 0},
            nan_value_counts={1: 0, 2: 0},
            lower_bounds={1: start_id.to_bytes(4, "little"), 2: f"file_{file_idx}_value_{start_id}".encode()},
            upper_bounds={1: end_id.to_bytes(4, "little"), 2: f"file_{file_idx}_value_{end_id}".encode()},
            split_offsets=None,
        )
        data_file.spec_id = 0
        scan_tasks.append(FileScanTask(data_file=data_file))

    def make_scan(row_filter: BooleanExpression, projected: Schema = schema) -> ArrowScan:
        """Build an ArrowScan over the test table with the given row filter and projection."""
        return ArrowScan(
            table_metadata=table_metadata,
            io=io,
            projected_schema=projected,
            row_filter=row_filter,
            case_sensitive=True,
        )

    # Test 1: Multiple file batching - verify we get batches from all files
    batches = list(make_scan(AlwaysTrue()).to_record_batches(scan_tasks))

    # Each single-stripe file is currently read back as exactly one batch
    expected_batches = num_files
    assert len(batches) == expected_batches, f"Expected {expected_batches} batches (1 per file), got {len(batches)}"

    max_batch_size = max(batch.num_rows for batch in batches)
    assert max_batch_size <= rows_per_file, f"Batch size {max_batch_size} seems too large for ORC files"
    assert max_batch_size > 0, "Batch should not be empty"

    for batch in batches:
        assert isinstance(batch, pa.RecordBatch), f"Expected RecordBatch, got {type(batch)}"
        assert batch.num_columns == 2, f"Expected 2 columns, got {batch.num_columns}"
        assert "id" in batch.schema.names, "Missing 'id' column"
        assert "value" in batch.schema.names, "Missing 'value' column"

    # Test 2: Verify data integrity across all batches from all files.
    # Use a separate name so the expected total is not overwritten by the observed one.
    rows_in_batches = sum(batch.num_rows for batch in batches)
    assert rows_in_batches == total_rows, f"Expected {total_rows} rows total, got {rows_in_batches}"

    all_ids = []
    all_values = []
    for batch in batches:
        all_ids.extend(batch.column("id").to_pylist())
        all_values.extend(batch.column("value").to_pylist())

    expected_ids = list(range(1, total_rows + 1))
    assert sorted(all_ids) == expected_ids, f"ID data doesn't match expected range 1-{total_rows}"

    # Values are prefixed "file_<idx>_", so the prefix identifies the source file
    file_values = set()
    for value in all_values:
        if value.startswith("file_"):
            file_values.add(int(value.split("_")[1]))
    assert file_values == set(range(num_files)), f"Expected values from all {num_files} files, got from files: {file_values}"

    # Test 3: Verify batch sizes.
    # Batches are produced per scan task, so no batch may hold more rows than a single file.
    batch_sizes = [batch.num_rows for batch in batches]
    total_batch_rows = sum(batch_sizes)
    assert total_batch_rows == total_rows, f"Total batch rows {total_batch_rows} != expected {total_rows}"
    for batch_size in batch_sizes:
        assert batch_size > 0, "Batch should not be empty"
        assert batch_size <= rows_per_file, f"Batch size {batch_size} should not exceed rows per file {rows_per_file}"

    # Test 4: Streaming behavior with multiple files
    processed_rows = 0
    batch_count = 0
    file_data_counts = dict.fromkeys(range(num_files), 0)
    for batch in make_scan(AlwaysTrue()).to_record_batches(scan_tasks):
        batch_count += 1
        processed_rows += batch.num_rows
        # Count rows per file in this batch
        for value in batch.column("value").to_pylist():
            if value.startswith("file_"):
                file_data_counts[int(value.split("_")[1])] += 1

    assert batch_count >= 1, f"Expected at least 1 batch, got {batch_count}"
    assert batch_count <= num_files, f"Expected at most {num_files} batches (1 per file), got {batch_count}"
    assert processed_rows == total_rows, f"Processed {processed_rows} rows, expected {total_rows}"
    for file_idx in range(num_files):
        assert file_data_counts[file_idx] == rows_per_file, (
            f"File {file_idx} contributed {file_data_counts[file_idx]} rows, expected {rows_per_file}"
        )

    # Test 5: Column projection with multiple files
    projected_schema = Schema(
        NestedField(1, "id", IntegerType(), required=True),
    )
    batches = list(make_scan(AlwaysTrue(), projected_schema).to_record_batches(scan_tasks))
    assert len(batches) >= 1, f"Expected at least 1 batch for projected schema, got {len(batches)}"
    for batch in batches:
        assert batch.num_columns == 1, f"Expected 1 column after projection, got {batch.num_columns}"
        assert "id" in batch.schema.names, "Missing 'id' column after projection"
        assert "value" not in batch.schema.names, "Should not have 'value' column after projection"

    # Test 6: Filtering with multiple files - keep only the second half of the IDs
    batches = list(make_scan(GreaterThan("id", total_rows // 2)).to_record_batches(scan_tasks))
    total_filtered_rows = sum(batch.num_rows for batch in batches)
    expected_filtered = total_rows // 2
    assert total_filtered_rows == expected_filtered, (
        f"Expected {expected_filtered} rows after filtering, got {total_filtered_rows}"
    )
    for batch in batches:
        ids = batch.column("id").to_pylist()
        assert all(id_val > total_rows // 2 for id_val in ids), f"Found ID <= {total_rows // 2}: {ids}"

    # Test 7: Batch count bounds for the filtered scan.
    # A file whose rows are all filtered out may contribute no batches, so the only
    # guaranteed lower bound is one batch (some rows survive), and the upper bound is one batch per row.
    assert len(batches) >= 1, f"Expected at least 1 batch, got {len(batches)}"
    assert len(batches) <= total_rows, f"Expected at most {total_rows} batches (one per row), got {len(batches)}"
def test_orc_batching_exact_counts_single_file(tmp_path: Path) -> None:
    """
    Test exact batch counts for single ORC files of different sizes.

    Files are written with default ORC settings (one stripe per file at these sizes), so
    each file is expected to be read back as exactly one record batch.

    To run just this test:
        pytest tests/io/test_pyarrow.py -k test_orc_batching_exact_counts_single_file
    """
    import pyarrow as pa
    import pyarrow.orc as orc
    from pyiceberg.manifest import DataFile, DataFileContent, FileFormat
    from pyiceberg.partitioning import PartitionSpec
    from pyiceberg.schema import Schema
    from pyiceberg.table import FileScanTask
    from pyiceberg.table.metadata import TableMetadataV2
    from pyiceberg.typedef import Record
    from pyiceberg.types import IntegerType, StringType

    # Define schema
    schema = Schema(
        NestedField(1, "id", IntegerType(), required=True),
        NestedField(2, "value", StringType(), required=False),
    )

    # Create table metadata; the name mapping supplies field IDs for the ID-less ORC files
    table_metadata = TableMetadataV2(
        location=f"file://{tmp_path}/test_location",
        last_column_id=2,
        format_version=2,
        schemas=[schema],
        partition_specs=[PartitionSpec()],
        properties={
            "schema.name-mapping.default": '[{"field-id": 1, "names": ["id"]}, {"field-id": 2, "names": ["value"]}]',
        },
    )
    io = PyArrowFileIO()

    # The scan does not depend on which file is read, so build it once.
    scan = ArrowScan(
        table_metadata=table_metadata,
        io=io,
        projected_schema=schema,
        row_filter=AlwaysTrue(),
        case_sensitive=True,
    )

    # All files have 1 stripe (default ORC writing), so 1 batch each
    test_cases = [
        (500, "Small file (1 stripe)"),
        (1000, "Medium file (1 stripe)"),
        (2000, "Large file (1 stripe)"),
        (5000, "Very large file (1 stripe)"),
    ]

    for num_rows, description in test_cases:
        data = pa.table(
            {"id": pa.array(range(1, num_rows + 1), type=pa.int32()), "value": [f"value_{i}" for i in range(1, num_rows + 1)]}
        )

        orc_path = tmp_path / f"test_{num_rows}_rows.orc"
        orc.write_table(data, str(orc_path))

        data_file = DataFile.from_args(
            content=DataFileContent.DATA,
            file_path=str(orc_path),
            file_format=FileFormat.ORC,
            partition=Record(),
            file_size_in_bytes=orc_path.stat().st_size,
            sort_order_id=None,
            spec_id=0,
            equality_ids=None,
            key_metadata=None,
            record_count=num_rows,
            column_sizes={1: num_rows * 4, 2: num_rows * 8},
            value_counts={1: num_rows, 2: num_rows},
            null_value_counts={1: 0, 2: 0},
            nan_value_counts={1: 0, 2: 0},
            lower_bounds={1: b"\x01\x00\x00\x00", 2: b"value_1"},
            upper_bounds={1: num_rows.to_bytes(4, "little"), 2: f"value_{num_rows}".encode()},
            split_offsets=None,
        )
        data_file.spec_id = 0

        batches = list(scan.to_record_batches([FileScanTask(data_file=data_file)]))

        # Verify exact batch count and sizes (the docstring's "1 batch per file" claim)
        assert len(batches) == 1, f"{description}: expected exactly 1 batch, got {len(batches)}"
        total_batch_rows = sum(batch.num_rows for batch in batches)
        assert total_batch_rows == num_rows, f"Total rows mismatch: expected {num_rows}, got {total_batch_rows}"

        # Verify data integrity
        all_ids = []
        for batch in batches:
            all_ids.extend(batch.column("id").to_pylist())
        assert sorted(all_ids) == list(range(1, num_rows + 1)), f"Data integrity check failed for {num_rows} rows"
def test_orc_batching_exact_counts_multiple_files(tmp_path: Path) -> None:
    """
    Test exact batch counts for multiple ORC files of different sizes and counts.

    Files are written with default ORC settings (one stripe per file at these sizes), so
    each scan is expected to yield exactly one record batch per file.

    To run just this test:
        pytest tests/io/test_pyarrow.py -k test_orc_batching_exact_counts_multiple_files
    """
    import pyarrow as pa
    import pyarrow.orc as orc
    from pyiceberg.expressions import AlwaysTrue
    from pyiceberg.io.pyarrow import ArrowScan, PyArrowFileIO
    from pyiceberg.manifest import DataFile, DataFileContent, FileFormat
    from pyiceberg.partitioning import PartitionSpec
    from pyiceberg.schema import Schema
    from pyiceberg.table import FileScanTask
    from pyiceberg.table.metadata import TableMetadataV2
    from pyiceberg.typedef import Record
    from pyiceberg.types import IntegerType, StringType

    # Define schema
    schema = Schema(
        NestedField(1, "id", IntegerType(), required=True),
        NestedField(2, "value", StringType(), required=False),
    )

    # Create table metadata; the name mapping supplies field IDs for the ID-less ORC files
    table_metadata = TableMetadataV2(
        location=f"file://{tmp_path}/test_location",
        last_column_id=2,
        format_version=2,
        schemas=[schema],
        partition_specs=[PartitionSpec()],
        properties={
            "schema.name-mapping.default": '[{"field-id": 1, "names": ["id"]}, {"field-id": 2, "names": ["value"]}]',
        },
    )
    io = PyArrowFileIO()

    # The scan does not depend on which files are read, so build it once.
    scan = ArrowScan(
        table_metadata=table_metadata,
        io=io,
        projected_schema=schema,
        row_filter=AlwaysTrue(),
        case_sensitive=True,
    )

    # All files have 1 stripe each (default ORC writing), so 1 batch per file
    test_cases = [
        (2, 500, "2 files, 500 rows each (1 stripe each)"),
        (3, 1000, "3 files, 1000 rows each (1 stripe each)"),
        (4, 750, "4 files, 750 rows each (1 stripe each)"),
        (2, 2000, "2 files, 2000 rows each (1 stripe each)"),
    ]

    for num_files, rows_per_file, description in test_cases:
        total_rows = num_files * rows_per_file
        scan_tasks = []

        for file_idx in range(num_files):
            # IDs are contiguous across files within a test case
            start_id = file_idx * rows_per_file + 1
            end_id = (file_idx + 1) * rows_per_file
            data = pa.table(
                {
                    "id": pa.array(range(start_id, end_id + 1), type=pa.int32()),
                    "value": [f"file_{file_idx}_value_{i}" for i in range(start_id, end_id + 1)],
                }
            )

            # Include num_files in the name so cases sharing rows_per_file never overwrite each other
            orc_path = tmp_path / f"multi_test_{num_files}_{file_idx}_{rows_per_file}_rows.orc"
            orc.write_table(data, str(orc_path))

            data_file = DataFile.from_args(
                content=DataFileContent.DATA,
                file_path=str(orc_path),
                file_format=FileFormat.ORC,
                partition=Record(),
                file_size_in_bytes=orc_path.stat().st_size,
                sort_order_id=None,
                spec_id=0,
                equality_ids=None,
                key_metadata=None,
                record_count=rows_per_file,
                column_sizes={1: rows_per_file * 4, 2: rows_per_file * 8},
                value_counts={1: rows_per_file, 2: rows_per_file},
                null_value_counts={1: 0, 2: 0},
                nan_value_counts={1: 0, 2: 0},
                lower_bounds={1: start_id.to_bytes(4, "little"), 2: f"file_{file_idx}_value_{start_id}".encode()},
                upper_bounds={1: end_id.to_bytes(4, "little"), 2: f"file_{file_idx}_value_{end_id}".encode()},
                split_offsets=None,
            )
            data_file.spec_id = 0
            scan_tasks.append(FileScanTask(data_file=data_file))

        batches = list(scan.to_record_batches(scan_tasks))

        # Verify exact batch count and sizes (the docstring's "1 batch per file" claim)
        assert len(batches) == num_files, f"{description}: expected {num_files} batches, got {len(batches)}"
        total_batch_rows = sum(batch.num_rows for batch in batches)
        assert total_batch_rows == total_rows, f"Total rows mismatch: expected {total_rows}, got {total_batch_rows}"

        # Verify data spans all files; the "file_<idx>_" prefix identifies each row's source file
        all_ids = []
        file_data_counts = dict.fromkeys(range(num_files), 0)
        for batch in batches:
            all_ids.extend(batch.column("id").to_pylist())
            for value in batch.column("value").to_pylist():
                if value.startswith("file_"):
                    file_data_counts[int(value.split("_")[1])] += 1

        assert sorted(all_ids) == list(range(1, total_rows + 1)), f"Data integrity check failed for {description}"
        for file_idx in range(num_files):
            assert file_data_counts[file_idx] == rows_per_file, (
                f"File {file_idx} contributed {file_data_counts[file_idx]} rows, expected {rows_per_file}"
            )
def test_orc_field_id_extraction() -> None:
    """
    Test ORC field ID extraction from PyArrow field metadata.

    ``_get_field_id`` is expected to:
    - read the ID stored under the Parquet field-ID metadata key,
    - read the ID stored under the ORC field-ID metadata key,
    - prefer the Parquet key when both keys are present,
    - return ``None`` when the field has no metadata or empty metadata,
    - raise ``ValueError`` when the stored ID is not an integer.

    To run just this test:
        pytest tests/io/test_pyarrow.py -k test_orc_field_id_extraction
    """
    import pyarrow as pa

    from pyiceberg.io.pyarrow import ORC_FIELD_ID_KEY, PYARROW_PARQUET_FIELD_ID_KEY, _get_field_id

    # Test 1: Parquet field ID extraction
    field_parquet = pa.field("test_parquet", pa.string(), metadata={PYARROW_PARQUET_FIELD_ID_KEY: b"123"})
    field_id = _get_field_id(field_parquet)
    assert field_id == 123, f"Expected Parquet field ID 123, got {field_id}"

    # Test 2: ORC field ID extraction
    field_orc = pa.field("test_orc", pa.string(), metadata={ORC_FIELD_ID_KEY: b"456"})
    field_id = _get_field_id(field_orc)
    assert field_id == 456, f"Expected ORC field ID 456, got {field_id}"

    # Test 3: No field ID
    field_no_id = pa.field("test_no_id", pa.string())
    field_id = _get_field_id(field_no_id)
    assert field_id is None, f"Expected None for field without ID, got {field_id}"

    # Test 4: Both field IDs present (should prefer Parquet)
    field_both = pa.field("test_both", pa.string(), metadata={PYARROW_PARQUET_FIELD_ID_KEY: b"123", ORC_FIELD_ID_KEY: b"456"})
    field_id = _get_field_id(field_both)
    assert field_id == 123, f"Expected Parquet field ID 123 (preferred), got {field_id}"

    # Test 5: Empty metadata
    field_empty_metadata = pa.field("test_empty", pa.string(), metadata={})
    field_id = _get_field_id(field_empty_metadata)
    assert field_id is None, f"Expected None for field with empty metadata, got {field_id}"

    # Test 6: Invalid field ID format. pytest.raises fails the test itself when no
    # ValueError is raised, so no manual try/raise/except bookkeeping is needed.
    field_invalid = pa.field("test_invalid", pa.string(), metadata={ORC_FIELD_ID_KEY: b"not_a_number"})
    with pytest.raises(ValueError):
        _get_field_id(field_invalid)

    # Test 7: Different data types
    field_int = pa.field("test_int", pa.int32(), metadata={ORC_FIELD_ID_KEY: b"789"})
    field_id = _get_field_id(field_int)
    assert field_id == 789, f"Expected ORC field ID 789 for int field, got {field_id}"

    field_bool = pa.field("test_bool", pa.bool_(), metadata={ORC_FIELD_ID_KEY: b"101"})
    field_id = _get_field_id(field_bool)
    assert field_id == 101, f"Expected ORC field ID 101 for bool field, got {field_id}"
def test_orc_schema_with_field_ids(tmp_path: Path) -> None:
    """
    Read an ORC file whose Arrow schema carries Iceberg field IDs, without a name mapping.

    Every column is written with ORC field-ID metadata. The table metadata deliberately
    has no ``schema.name-mapping.default`` property, so the scan has to resolve the
    columns through the embedded field IDs.

    To run just this test:
        pytest tests/io/test_pyarrow.py -k test_orc_schema_with_field_ids
    """
    import pyarrow as pa
    import pyarrow.orc as orc

    from pyiceberg.expressions import AlwaysTrue
    from pyiceberg.io.pyarrow import ORC_FIELD_ID_KEY, ArrowScan, PyArrowFileIO
    from pyiceberg.manifest import DataFile, DataFileContent, FileFormat
    from pyiceberg.partitioning import PartitionSpec
    from pyiceberg.schema import Schema
    from pyiceberg.table import FileScanTask
    from pyiceberg.table.metadata import TableMetadataV2
    from pyiceberg.typedef import Record
    from pyiceberg.types import IntegerType, StringType

    iceberg_schema = Schema(
        NestedField(1, "id", IntegerType(), required=True),
        NestedField(2, "name", StringType(), required=False),
    )

    # Arrow fields tagged with ORC field IDs that match the Iceberg field IDs above.
    id_field = pa.field("id", pa.int32(), metadata={ORC_FIELD_ID_KEY: b"1"})
    name_field = pa.field("name", pa.string(), metadata={ORC_FIELD_ID_KEY: b"2"})
    arrow_schema = pa.schema([id_field, name_field])

    expected = {"id": [1, 2, 3], "name": ["alice", "bob", "charlie"]}
    source_table = pa.table(
        {"id": pa.array([1, 2, 3], type=pa.int32()), "name": ["alice", "bob", "charlie"]},
        schema=arrow_schema,
    )

    orc_path = tmp_path / "field_id_test.orc"
    orc.write_table(source_table, str(orc_path))

    # Empty properties: there is no name mapping, so only the embedded field IDs can be used.
    table_metadata = TableMetadataV2(
        location=f"file://{tmp_path}/test_location",
        last_column_id=2,
        format_version=2,
        schemas=[iceberg_schema],
        partition_specs=[PartitionSpec()],
        properties={},
    )

    data_file = DataFile.from_args(
        content=DataFileContent.DATA,
        file_path=str(orc_path),
        file_format=FileFormat.ORC,
        partition=Record(),
        file_size_in_bytes=orc_path.stat().st_size,
        sort_order_id=None,
        spec_id=0,
        equality_ids=None,
        key_metadata=None,
        record_count=3,
        column_sizes={1: 12, 2: 30},
        value_counts={1: 3, 2: 3},
        null_value_counts={1: 0, 2: 0},
        nan_value_counts={1: 0, 2: 0},
        lower_bounds={1: b"\x01\x00\x00\x00", 2: b"alice"},
        upper_bounds={1: b"\x03\x00\x00\x00", 2: b"charlie"},
        split_offsets=None,
    )
    data_file.spec_id = 0

    scan = ArrowScan(
        table_metadata=table_metadata,
        io=PyArrowFileIO(),
        projected_schema=iceberg_schema,
        row_filter=AlwaysTrue(),
        case_sensitive=True,
    )
    table_read = scan.to_table([FileScanTask(data_file=data_file)])

    # Shape first, then the contents column by column.
    assert table_read.num_rows == 3
    assert table_read.num_columns == 2
    assert table_read.column_names == list(expected)
    for column_name, values in expected.items():
        assert table_read.column(column_name).to_pylist() == values
def test_orc_schema_conversion_with_field_ids() -> None:
    """
    Check which field-ID metadata key ``schema_to_pyarrow`` emits for each file format.

    Cases:
    1. ``include_field_ids=True`` and no ``file_format``: only the Parquet key is set.
    2. ``include_field_ids=True`` and ``file_format=FileFormat.ORC``: only the ORC key is set.
    3. ``include_field_ids=False`` and ``file_format=FileFormat.ORC``: neither key is set.

    To run just this test:
        pytest tests/io/test_pyarrow.py -k test_orc_schema_conversion_with_field_ids
    """
    from pyiceberg.io.pyarrow import ORC_FIELD_ID_KEY, PYARROW_PARQUET_FIELD_ID_KEY, schema_to_pyarrow
    from pyiceberg.manifest import FileFormat
    from pyiceberg.schema import Schema
    from pyiceberg.types import IntegerType, StringType

    schema = Schema(
        NestedField(1, "id", IntegerType(), required=True),
        NestedField(2, "name", StringType(), required=False),
    )
    # Expected metadata value of the field ID, by field position ("id", then "name").
    expected_ids = (b"1", b"2")

    # Case 1: default format -> Parquet field-ID key only
    parquet_style = schema_to_pyarrow(schema, include_field_ids=True)
    for position, expected_id in enumerate(expected_ids):
        metadata = parquet_style.field(position).metadata
        assert metadata[PYARROW_PARQUET_FIELD_ID_KEY] == expected_id
        assert ORC_FIELD_ID_KEY not in metadata

    # Case 2: explicit ORC format -> ORC field-ID key only
    orc_style = schema_to_pyarrow(schema, include_field_ids=True, file_format=FileFormat.ORC)
    for position, expected_id in enumerate(expected_ids):
        metadata = orc_style.field(position).metadata
        assert metadata[ORC_FIELD_ID_KEY] == expected_id
        assert PYARROW_PARQUET_FIELD_ID_KEY not in metadata

    # Case 3: field IDs disabled -> neither key
    without_ids = schema_to_pyarrow(schema, include_field_ids=False, file_format=FileFormat.ORC)
    for position in range(len(expected_ids)):
        metadata = without_ids.field(position).metadata
        assert ORC_FIELD_ID_KEY not in metadata
        assert PYARROW_PARQUET_FIELD_ID_KEY not in metadata
def test_orc_schema_conversion_with_required_attribute() -> None:
    """
    Check that ``schema_to_pyarrow`` writes the ORC ``iceberg.required`` attribute only for ORC.

    With ``file_format=FileFormat.PARQUET``, no field carries ``ORC_FIELD_REQUIRED_KEY``.
    With ``file_format=FileFormat.ORC``, each field carries ``b"true"`` or ``b"false"``,
    matching the Iceberg field's ``required`` flag.

    To run just this test:
        pytest tests/io/test_pyarrow.py -k test_orc_schema_conversion_with_required_attribute
    """
    from pyiceberg.io.pyarrow import ORC_FIELD_REQUIRED_KEY, schema_to_pyarrow
    from pyiceberg.manifest import FileFormat
    from pyiceberg.schema import Schema
    from pyiceberg.types import IntegerType, StringType

    schema = Schema(
        NestedField(1, "id", IntegerType(), required=True),
        NestedField(2, "name", StringType(), required=False),
    )
    # Expected ORC "required" value by field position: "id" is required, "name" is optional.
    expected_required = (b"true", b"false")

    # Parquet output must not carry the ORC-only attribute.
    parquet_schema = schema_to_pyarrow(schema, file_format=FileFormat.PARQUET)
    for position in range(len(expected_required)):
        assert ORC_FIELD_REQUIRED_KEY not in parquet_schema.field(position).metadata

    # ORC output records each field's required flag.
    orc_schema = schema_to_pyarrow(schema, file_format=FileFormat.ORC)
    for position, flag in enumerate(expected_required):
        assert orc_schema.field(position).metadata[ORC_FIELD_REQUIRED_KEY] == flag
def test_orc_batching_behavior_documentation(tmp_path: Path) -> None:
    """
    Document and verify how ArrowScan batches ORC files written with default settings.

    Observations that this test, and the other ORC batching tests in this module, rely on:

    1. STRIPE-BASED BATCHING:
       Reading an ORC file yields one record batch per stripe. This is analogous to
       Parquet, which yields one batch per row group.
    2. DEFAULT WRITING:
       ``orc.write_table`` with default options produced a single stripe for every file
       used here (up to 5,000 rows). Each file therefore yields exactly one batch, and
       N files yield N batches.
    3. CONFIGURABLE STRIPES:
       Passing ``stripe_size`` to ``orc.write_table`` can split a file into several
       stripes, and therefore several batches. Whether it does depends on the data
       volume: for 10,000 small rows, 200KB and 400KB targets gave multiple stripes,
       while 600KB gave a single stripe (see test_orc_stripe_based_batching).
       NOTE(review): earlier notes said PyArrow ignores ``stripe_size`` values below
       roughly 200KB and falls back to a row-count-based stripe size. Confirm this
       against the PyArrow/ORC version in use.
    4. NOTE(review): earlier notes also said PyIceberg passes an 8MB ``buffer_size`` to
       Parquet scans, and that ORC scans do not accept it. Confirm in ``pyiceberg.io.pyarrow``.

    TESTING IMPLICATIONS:
    - Tests that write ORC with default options should expect one batch per file.
    - Tests that need several batches per file must set ``stripe_size`` and check the
      resulting stripe count with ``orc.ORCFile(...).nstripes``.

    To run just this test:
        pytest tests/io/test_pyarrow.py -k test_orc_batching_behavior_documentation
    """
    import pyarrow as pa
    import pyarrow.orc as orc

    from pyiceberg.expressions import AlwaysTrue
    from pyiceberg.io.pyarrow import ArrowScan, PyArrowFileIO
    from pyiceberg.manifest import DataFile, DataFileContent, FileFormat
    from pyiceberg.partitioning import PartitionSpec
    from pyiceberg.schema import Schema
    from pyiceberg.table import FileScanTask
    from pyiceberg.table.metadata import TableMetadataV2
    from pyiceberg.typedef import Record
    from pyiceberg.types import IntegerType, StringType

    # Define schema
    schema = Schema(
        NestedField(1, "id", IntegerType(), required=True),
        NestedField(2, "value", StringType(), required=False),
    )

    # Create table metadata
    table_metadata = TableMetadataV2(
        location=f"file://{tmp_path}/test_location",
        last_column_id=2,
        format_version=2,
        schemas=[schema],
        partition_specs=[PartitionSpec()],
        properties={
            "schema.name-mapping.default": '[{"field-id": 1, "names": ["id"]}, {"field-id": 2, "names": ["value"]}]',
        },
    )

    io = PyArrowFileIO()

    # Test cases that document the exact behavior (default ORC writing = 1 stripe per file)
    test_cases = [
        # (file_count, rows_per_file, expected_batches, description)
        (1, 100, 1, "Single small file (1 stripe)"),
        (1, 1000, 1, "Single medium file (1 stripe)"),
        (1, 5000, 1, "Single large file (1 stripe)"),
        (2, 500, 2, "Two small files (1 stripe each)"),
        (3, 1000, 3, "Three medium files (1 stripe each)"),
        (4, 750, 4, "Four small files (1 stripe each)"),
        (2, 2000, 2, "Two large files (1 stripe each)"),
    ]

    for case_idx, (file_count, rows_per_file, expected_batches, description) in enumerate(test_cases):
        total_rows = file_count * rows_per_file
        scan_tasks = []

        for file_idx in range(file_count):
            # Each file holds a contiguous, non-overlapping id range.
            start_id = file_idx * rows_per_file + 1
            end_id = (file_idx + 1) * rows_per_file
            ids = range(start_id, end_id + 1)
            values = [f"file_{file_idx}_value_{i}" for i in ids]
            data = pa.table({"id": pa.array(ids, type=pa.int32()), "value": values})

            # The case index keeps each case's files distinct. Without it, cases with the
            # same rows_per_file (e.g. 1x1000 and 3x1000) would reuse the same path.
            orc_path = tmp_path / f"doc_test_case{case_idx}_{file_idx}_{rows_per_file}_rows.orc"
            orc.write_table(data, str(orc_path))

            data_file = DataFile.from_args(
                content=DataFileContent.DATA,
                file_path=str(orc_path),
                file_format=FileFormat.ORC,
                partition=Record(),
                file_size_in_bytes=orc_path.stat().st_size,
                sort_order_id=None,
                spec_id=0,
                equality_ids=None,
                key_metadata=None,
                record_count=rows_per_file,
                column_sizes={1: rows_per_file * 4, 2: rows_per_file * 8},
                value_counts={1: rows_per_file, 2: rows_per_file},
                null_value_counts={1: 0, 2: 0},
                nan_value_counts={1: 0, 2: 0},
                # String bounds are lexicographic ("file_1_value_1000" < "file_1_value_501"),
                # so take min/max of the actual values rather than the first/last generated one.
                lower_bounds={1: start_id.to_bytes(4, "little"), 2: min(values).encode()},
                upper_bounds={1: end_id.to_bytes(4, "little"), 2: max(values).encode()},
                split_offsets=None,
            )
            data_file.spec_id = 0
            scan_tasks.append(FileScanTask(data_file=data_file))

        # Test batching behavior
        scan = ArrowScan(
            table_metadata=table_metadata,
            io=io,
            projected_schema=schema,
            row_filter=AlwaysTrue(),
            case_sensitive=True,
        )
        batches = list(scan.to_record_batches(scan_tasks))

        # Verify exact batch count
        assert len(batches) == expected_batches, f"Expected {expected_batches} batches, got {len(batches)} for {description}"

        # Every file is a single stripe, so every batch should hold exactly one file's rows.
        batch_sizes = [batch.num_rows for batch in batches]
        assert batch_sizes == [rows_per_file] * expected_batches, (
            f"Expected {expected_batches} batches of {rows_per_file} rows, got {batch_sizes} for {description}"
        )

        # Verify total rows
        total_batch_rows = sum(batch_sizes)
        assert total_batch_rows == total_rows, f"Total rows mismatch: expected {total_rows}, got {total_batch_rows}"

        # Verify data integrity: every id from every file is present exactly once
        all_ids = [row_id for batch in batches for row_id in batch.column("id").to_pylist()]
        assert sorted(all_ids) == list(range(1, total_rows + 1)), f"Data integrity check failed for {description}"
def test_parquet_vs_orc_batching_behavior_comparison(tmp_path: Path) -> None:
    """
    Compare how ArrowScan batches Parquet files and ORC files.

    - PARQUET: 1 FileScanTask -> 1 file -> one batch per row group. A 10,000-row file
      written with ``row_group_size=R`` is expected to yield ``10000 // R`` batches.
    - ORC: 1 FileScanTask -> 1 file -> one batch per stripe. Files written with default
      ``orc.write_table`` options (up to 10,000 rows here) contain a single stripe, so
      they are expected to yield exactly one batch.

    To run just this test:
        pytest tests/io/test_pyarrow.py -k test_parquet_vs_orc_batching_behavior_comparison
    """
    import pyarrow as pa
    import pyarrow.orc as orc
    import pyarrow.parquet as pq

    from pyiceberg.expressions import AlwaysTrue
    from pyiceberg.io.pyarrow import ArrowScan, PyArrowFileIO
    from pyiceberg.manifest import DataFile, DataFileContent, FileFormat
    from pyiceberg.partitioning import PartitionSpec
    from pyiceberg.schema import Schema
    from pyiceberg.table import FileScanTask
    from pyiceberg.table.metadata import TableMetadataV2
    from pyiceberg.typedef import Record
    from pyiceberg.types import IntegerType, NestedField, StringType

    # Define schema
    schema = Schema(
        NestedField(1, "id", IntegerType(), required=True),
        NestedField(2, "value", StringType(), required=False),
    )

    # Create table metadata
    table_metadata = TableMetadataV2(
        location=f"file://{tmp_path}/test_location",
        last_column_id=2,
        format_version=2,
        schemas=[schema],
        partition_specs=[PartitionSpec()],
        properties={
            "schema.name-mapping.default": '[{"field-id": 1, "names": ["id"]}, {"field-id": 2, "names": ["value"]}]',
        },
    )

    io = PyArrowFileIO()

    # --- Parquet: batches follow row groups ---
    parquet_rows = 10000
    parquet_values = [f"value_{i}" for i in range(1, parquet_rows + 1)]
    # The input table is identical for every row-group size, so build it once.
    parquet_data = pa.table({"id": pa.array(range(1, parquet_rows + 1), type=pa.int32()), "value": parquet_values})

    parquet_test_cases = [
        (1000, "Small row groups"),
        (2000, "Medium row groups"),
        (5000, "Large row groups"),
    ]

    for row_group_size, _description in parquet_test_cases:
        # Create Parquet file with specific row group size
        parquet_path = tmp_path / f"parquet_test_{row_group_size}.parquet"
        pq.write_table(parquet_data, str(parquet_path), row_group_size=row_group_size)

        data_file = DataFile.from_args(
            content=DataFileContent.DATA,
            file_path=str(parquet_path),
            file_format=FileFormat.PARQUET,
            partition=Record(),
            file_size_in_bytes=parquet_path.stat().st_size,
            sort_order_id=None,
            spec_id=0,
            equality_ids=None,
            key_metadata=None,
            record_count=parquet_rows,
            column_sizes={1: 40000, 2: 80000},
            value_counts={1: parquet_rows, 2: parquet_rows},
            null_value_counts={1: 0, 2: 0},
            nan_value_counts={1: 0, 2: 0},
            # String bounds are lexicographic: the max of value_1..value_10000 is "value_9999".
            lower_bounds={1: (1).to_bytes(4, "little"), 2: min(parquet_values).encode()},
            upper_bounds={1: parquet_rows.to_bytes(4, "little"), 2: max(parquet_values).encode()},
            split_offsets=None,
        )
        data_file.spec_id = 0
        scan_task = FileScanTask(data_file=data_file)

        # Test batching behavior
        scan = ArrowScan(
            table_metadata=table_metadata,
            io=io,
            projected_schema=schema,
            row_filter=AlwaysTrue(),
            case_sensitive=True,
        )
        batches = list(scan.to_record_batches([scan_task]))

        expected_batches = parquet_rows // row_group_size  # Number of row groups

        # Verify exact batch count based on row groups
        assert len(batches) == expected_batches, (
            f"Expected {expected_batches} batches for row_group_size={row_group_size}, got {len(batches)}"
        )

        # Verify total rows
        total_rows = sum(batch.num_rows for batch in batches)
        assert total_rows == 10000, f"Expected 10000 total rows, got {total_rows}"

    # --- ORC: default writer settings give a single stripe, hence a single batch ---
    orc_test_cases = [
        (1000, "Small file"),
        (5000, "Medium file"),
        (10000, "Large file"),
    ]

    # num_rows is the row count of each ORC file (not a size in bytes).
    for num_rows, _description in orc_test_cases:
        values = [f"value_{i}" for i in range(1, num_rows + 1)]
        data = pa.table({"id": pa.array(range(1, num_rows + 1), type=pa.int32()), "value": values})

        # Create ORC file
        orc_path = tmp_path / f"orc_test_{num_rows}.orc"
        orc.write_table(data, str(orc_path))

        data_file = DataFile.from_args(
            content=DataFileContent.DATA,
            file_path=str(orc_path),
            file_format=FileFormat.ORC,
            partition=Record(),
            file_size_in_bytes=orc_path.stat().st_size,
            sort_order_id=None,
            spec_id=0,
            equality_ids=None,
            key_metadata=None,
            record_count=num_rows,
            column_sizes={1: num_rows * 4, 2: num_rows * 8},
            value_counts={1: num_rows, 2: num_rows},
            null_value_counts={1: 0, 2: 0},
            nan_value_counts={1: 0, 2: 0},
            # Lexicographic string bounds, derived from the data (e.g. "value_999" for 1000 rows).
            lower_bounds={1: (1).to_bytes(4, "little"), 2: min(values).encode()},
            upper_bounds={1: num_rows.to_bytes(4, "little"), 2: max(values).encode()},
            split_offsets=None,
        )
        data_file.spec_id = 0
        scan_task = FileScanTask(data_file=data_file)

        # Test batching behavior
        scan = ArrowScan(
            table_metadata=table_metadata,
            io=io,
            projected_schema=schema,
            row_filter=AlwaysTrue(),
            case_sensitive=True,
        )
        batches = list(scan.to_record_batches([scan_task]))

        # Default ORC writing creates 1 stripe per file here, hence exactly 1 batch.
        assert len(batches) == 1, (
            f"Expected 1 batch for ORC file with {num_rows} rows (default stripe config), got {len(batches)}"
        )

        # Verify total rows
        total_rows = sum(batch.num_rows for batch in batches)
        assert total_rows == num_rows, f"Expected {num_rows} total rows, got {total_rows}"
def test_orc_stripe_size_batch_size_compression_interaction(tmp_path: Path) -> None:
    """
    Show how ``stripe_size``, ``batch_size`` and ``compression`` passed to
    ``orc.write_table`` affect the stripe layout, and therefore ArrowScan batching.

    For each writer configuration the test checks that:
    1. the file is non-empty and has at least one stripe;
    2. uncompressed files with a ``stripe_size`` of 200KB or less are split into
       several stripes;
    3. ArrowScan yields exactly one batch per stripe, covering all 10,000 rows;
    4. snappy-compressed files are smaller than a rough raw-size estimate, and smaller
       than the uncompressed file written with the same stripe/batch settings.

    To run just this test:
        pytest tests/io/test_pyarrow.py -k test_orc_stripe_size_batch_size_compression_interaction
    """
    import pyarrow as pa
    import pyarrow.orc as orc

    from pyiceberg.manifest import DataFile, DataFileContent, FileFormat
    from pyiceberg.partitioning import PartitionSpec
    from pyiceberg.schema import Schema
    from pyiceberg.table import FileScanTask
    from pyiceberg.table.metadata import TableMetadataV2
    from pyiceberg.typedef import Record
    from pyiceberg.types import IntegerType, StringType

    # Define schema
    schema = Schema(
        NestedField(1, "id", IntegerType(), required=True),
        NestedField(2, "value", StringType(), required=False),
    )

    # Create table metadata
    table_metadata = TableMetadataV2(
        location=f"file://{tmp_path}/test_location",
        last_column_id=2,
        format_version=2,
        schemas=[schema],
        partition_specs=[PartitionSpec()],
        properties={
            "schema.name-mapping.default": '[{"field-id": 1, "names": ["id"]}, {"field-id": 2, "names": ["value"]}]',
        },
    )

    io = PyArrowFileIO()

    # Create test data
    data = pa.table({"id": pa.array(range(1, 10001), type=pa.int32()), "value": [f"value_{i}" for i in range(1, 10001)]})

    # Test different combinations
    test_cases = [
        # (stripe_size, batch_size, compression, description)
        (200000, None, "uncompressed", "200KB stripe, no batch limit, uncompressed"),
        (200000, 1000, "uncompressed", "200KB stripe, 1000 batch, uncompressed"),
        (200000, None, "snappy", "200KB stripe, no batch limit, snappy"),
        (100000, None, "uncompressed", "100KB stripe, no batch limit, uncompressed"),
        (500000, None, "uncompressed", "500KB stripe, no batch limit, uncompressed"),
        (None, 1000, "uncompressed", "No stripe limit, 1000 batch, uncompressed"),
        (None, 2000, "uncompressed", "No stripe limit, 2000 batch, uncompressed"),
    ]

    # On-disk size of each written configuration. Used to compare a snappy file with the
    # uncompressed file that shares its stripe/batch settings.
    file_sizes: dict[tuple[Any, Any, str], int] = {}

    for case_idx, (stripe_size, batch_size, compression, description) in enumerate(test_cases):
        # Index-based file name: unlike hash(description), it is stable across runs
        # (str hashing is randomized per process) and cannot collide between cases.
        orc_path = tmp_path / f"orc_test_{case_idx}.orc"
        write_kwargs: dict[str, Any] = {"compression": compression}
        if stripe_size is not None:
            write_kwargs["stripe_size"] = stripe_size
        if batch_size is not None:
            write_kwargs["batch_size"] = batch_size
        orc.write_table(data, str(orc_path), **write_kwargs)

        # Analyze the ORC file
        file_size = orc_path.stat().st_size
        file_sizes[(stripe_size, batch_size, compression)] = file_size
        actual_stripes = orc.ORCFile(str(orc_path)).nstripes

        # Assert basic file properties
        assert file_size > 0, f"ORC file should have non-zero size for {description}"
        assert actual_stripes > 0, f"ORC file should have at least one stripe for {description}"

        # Small uncompressed stripe targets must split the file. Other configurations
        # (larger targets, compression, or no stripe_size) are only required to have
        # at least one stripe, which is asserted above.
        if stripe_size is not None and stripe_size <= 200000 and compression == "uncompressed":
            assert actual_stripes > 1, (
                f"Expected multiple stripes for small stripe_size={stripe_size} in {description}, got {actual_stripes}"
            )

        data_file = DataFile.from_args(
            content=DataFileContent.DATA,
            file_path=str(orc_path),
            file_format=FileFormat.ORC,
            partition=Record(),
            file_size_in_bytes=file_size,
            sort_order_id=None,
            spec_id=0,
            equality_ids=None,
            key_metadata=None,
            record_count=10000,
            column_sizes={1: 40000, 2: 80000},
            value_counts={1: 10000, 2: 10000},
            null_value_counts={1: 0, 2: 0},
            nan_value_counts={1: 0, 2: 0},
            # String bounds are lexicographic: the max of value_1..value_10000 is "value_9999".
            lower_bounds={1: (1).to_bytes(4, "little"), 2: b"value_1"},
            upper_bounds={1: (10000).to_bytes(4, "little"), 2: b"value_9999"},
            split_offsets=None,
        )
        data_file.spec_id = 0

        # Test batching behavior
        scan_task = FileScanTask(data_file=data_file)
        scan = ArrowScan(
            table_metadata=table_metadata,
            io=io,
            projected_schema=schema,
            row_filter=AlwaysTrue(),
            case_sensitive=True,
        )
        batches = list(scan.to_record_batches([scan_task]))

        # Assert batching behavior: one batch per stripe
        assert len(batches) > 0, f"Should have at least one batch for {description}"
        assert len(batches) == actual_stripes, (
            f"Number of batches should match number of stripes for {description}: "
            f"{len(batches)} batches vs {actual_stripes} stripes"
        )

        # Assert data integrity
        total_rows = sum(batch.num_rows for batch in batches)
        assert total_rows == 10000, f"Total rows should be 10000 for {description}, got {total_rows}"

        # Assert compression effect
        if compression == "snappy":
            # Rough raw-size estimate: about 15 bytes per row (int32 id plus a short string).
            estimated_raw_size = len(data) * 15
            assert file_size < estimated_raw_size, (
                f"Snappy compression should reduce file size for {description}: {file_size} vs {estimated_raw_size}"
            )
            # Compare with the uncompressed file written with the same stripe/batch settings, if any.
            uncompressed_size = file_sizes.get((stripe_size, batch_size, "uncompressed"))
            if uncompressed_size is not None:
                assert file_size < uncompressed_size, (
                    f"Snappy file should be smaller than the uncompressed file with the same settings for "
                    f"{description}: {file_size} vs {uncompressed_size}"
                )
def test_orc_near_perfect_stripe_size_mapping(tmp_path: Path) -> None:
    """
    Check how closely large ORC stripe-size targets track the written data.

    The setup is chosen to minimise ORC encoding and compression effects:
    1. large stripe sizes (2-5MB),
    2. a large dataset (50K rows),
    3. long, varied string values,
    4. ``compression="uncompressed"``.

    For each stripe size the test checks that:
    - targets of 3MB or less split the file into several stripes;
    - ArrowScan yields exactly one batch per stripe, covering all 50,000 rows;
    - the ratio ``data.nbytes / file_size`` lies between 0.5 and 2.0, i.e. the
      uncompressed ORC file is roughly the size of the in-memory Arrow data.

    To run just this test:
        pytest tests/io/test_pyarrow.py -k test_orc_near_perfect_stripe_size_mapping
    """
    import pyarrow as pa
    import pyarrow.orc as orc

    from pyiceberg.expressions import AlwaysTrue
    from pyiceberg.io.pyarrow import ArrowScan, PyArrowFileIO
    from pyiceberg.manifest import DataFile, DataFileContent, FileFormat
    from pyiceberg.partitioning import PartitionSpec
    from pyiceberg.schema import Schema
    from pyiceberg.table import FileScanTask
    from pyiceberg.table.metadata import TableMetadataV2
    from pyiceberg.typedef import Record
    from pyiceberg.types import NestedField, StringType

    # Define schema
    schema = Schema(
        NestedField(1, "id", StringType(), required=True),  # Large string field
    )

    # Create table metadata
    table_metadata = TableMetadataV2(
        location=f"file://{tmp_path}/test_location",
        last_column_id=1,
        format_version=2,
        schemas=[schema],
        partition_specs=[PartitionSpec()],
        properties={
            "schema.name-mapping.default": '[{"field-id": 1, "names": ["id"]}]',
        },
    )

    io = PyArrowFileIO()

    # Create a large dataset (50K rows) of long strings that are harder to compress
    values = [
        (
            f"very_long_string_value_{i:06d}_with_lots_of_padding_to_make_it_harder_to_compress_"
            f"{i * 7919 % 100000:05d}_more_padding_{i * 7919 % 100000:05d}"
        )
        for i in range(1, 50001)
    ]
    data = pa.table({"id": pa.array(values)})
    raw_data_size = data.nbytes

    # Use the full min/max values as bounds. A truncated prefix is a valid lower bound,
    # but it sorts *below* the real maximum, so it cannot serve as an upper bound.
    lower_bound = min(values).encode()
    upper_bound = max(values).encode()

    # Test with large stripe sizes that should give us multiple stripes
    test_cases = [
        (2000000, "2MB stripe size"),
        (3000000, "3MB stripe size"),
        (4000000, "4MB stripe size"),
        (5000000, "5MB stripe size"),
    ]

    for stripe_size, _description in test_cases:
        # Create ORC file with specific stripe size
        orc_path = tmp_path / f"orc_perfect_test_{stripe_size}.orc"
        orc.write_table(data, str(orc_path), stripe_size=stripe_size, compression="uncompressed")

        # Analyze the ORC file
        file_size = orc_path.stat().st_size
        actual_stripes = orc.ORCFile(str(orc_path)).nstripes

        # Assert basic file properties
        assert file_size > 0, f"ORC file should have non-zero size for stripe_size={stripe_size}"
        assert actual_stripes > 0, f"ORC file should have at least one stripe for stripe_size={stripe_size}"

        # With 50K rows of long strings, targets of 3MB or less must produce several
        # stripes. Larger targets may produce a single stripe, which the assert above allows.
        if stripe_size <= 3000000:
            assert actual_stripes > 1, f"Expected multiple stripes for stripe_size={stripe_size}, got {actual_stripes}"

        data_file = DataFile.from_args(
            content=DataFileContent.DATA,
            file_path=str(orc_path),
            file_format=FileFormat.ORC,
            partition=Record(),
            file_size_in_bytes=file_size,
            sort_order_id=None,
            spec_id=0,
            equality_ids=None,
            key_metadata=None,
            record_count=50000,
            column_sizes={1: 5450000},
            value_counts={1: 50000},
            null_value_counts={1: 0},
            nan_value_counts={1: 0},
            lower_bounds={1: lower_bound},
            upper_bounds={1: upper_bound},
            split_offsets=None,
        )
        data_file.spec_id = 0

        # Test batching behavior
        scan_task = FileScanTask(data_file=data_file)
        scan = ArrowScan(
            table_metadata=table_metadata,
            io=io,
            projected_schema=schema,
            row_filter=AlwaysTrue(),
            case_sensitive=True,
        )
        batches = list(scan.to_record_batches([scan_task]))

        # Assert batching behavior: one batch per stripe
        assert len(batches) > 0, f"Should have at least one batch for stripe_size={stripe_size}"
        assert len(batches) == actual_stripes, (
            f"Number of batches should match number of stripes for stripe_size={stripe_size}: "
            f"{len(batches)} batches vs {actual_stripes} stripes"
        )

        # Assert data integrity
        total_rows = sum(batch.num_rows for batch in batches)
        assert total_rows == 50000, f"Total rows should be 50000 for stripe_size={stripe_size}, got {total_rows}"

        # An uncompressed file should be roughly the size of the raw data (file_size > 0 asserted above)
        compression_ratio = raw_data_size / file_size
        assert compression_ratio > 0.5, (
            f"Compression ratio should be reasonable for stripe_size={stripe_size}: {compression_ratio:.2f}"
        )
        assert compression_ratio < 2.0, (
            f"Compression ratio should not be too high for stripe_size={stripe_size}: {compression_ratio:.2f}"
        )
def test_orc_stripe_based_batching(tmp_path: Path) -> None:
    """
    Verify that a single ORC file can produce multiple batches, one per stripe.

    ORC files do not always map to one batch: with a small enough ``stripe_size`` a
    single file is split into several stripes, and ArrowScan yields one batch per stripe.

    The expected values are hard-coded from observed behavior when writing 10,000 rows
    with ``orc.write_table(..., stripe_size=...)``:
    - 200KB stripe size: 5 stripes of [2048, 2048, 2048, 2048, 1808] rows
    - 400KB stripe size: 2 stripes of [7168, 2832] rows
    - 600KB stripe size: 1 stripe of [10000] rows

    To run just this test:
        pytest tests/io/test_pyarrow.py -k test_orc_stripe_based_batching
    """
    import pyarrow as pa
    import pyarrow.orc as orc

    from pyiceberg.expressions import AlwaysTrue
    from pyiceberg.io.pyarrow import ArrowScan, PyArrowFileIO
    from pyiceberg.manifest import DataFile, DataFileContent, FileFormat
    from pyiceberg.partitioning import PartitionSpec
    from pyiceberg.schema import Schema
    from pyiceberg.table import FileScanTask
    from pyiceberg.table.metadata import TableMetadataV2
    from pyiceberg.typedef import Record
    from pyiceberg.types import IntegerType, NestedField, StringType

    # Define schema
    schema = Schema(
        NestedField(1, "id", IntegerType(), required=True),
        NestedField(2, "value", StringType(), required=False),
    )

    # Create table metadata
    table_metadata = TableMetadataV2(
        location=f"file://{tmp_path}/test_location",
        last_column_id=2,
        format_version=2,
        schemas=[schema],
        partition_specs=[PartitionSpec()],
        properties={
            "schema.name-mapping.default": '[{"field-id": 1, "names": ["id"]}, {"field-id": 2, "names": ["value"]}]',
        },
    )

    io = PyArrowFileIO()

    # The input table is the same for every stripe configuration, so build it once.
    num_rows = 10000
    values = [f"value_{i}" for i in range(1, num_rows + 1)]
    data = pa.table({"id": pa.array(range(1, num_rows + 1), type=pa.int32()), "value": values})

    # Test ORC with different stripe configurations (stripe_size in bytes).
    # NOTE(review): earlier notes say PyArrow ORC ignores stripe_size < 200KB, hence the larger values.
    # Expected values are hard-coded from observed behavior with 10,000 rows.
    test_cases = [
        (200000, "Small stripes (200KB)", 5, [2048, 2048, 2048, 2048, 1808]),
        (400000, "Medium stripes (400KB)", 2, [7168, 2832]),
        (600000, "Large stripes (600KB)", 1, [10000]),
    ]

    for stripe_size, _description, expected_stripes, expected_stripe_sizes in test_cases:
        # Create ORC file with specific stripe size (in bytes)
        orc_path = tmp_path / f"orc_stripe_test_{stripe_size}.orc"
        orc.write_table(data, str(orc_path), stripe_size=stripe_size)

        # Check ORC metadata
        orc_file = orc.ORCFile(str(orc_path))
        actual_stripes = orc_file.nstripes
        actual_stripe_sizes = [orc_file.read_stripe(i).num_rows for i in range(actual_stripes)]

        data_file = DataFile.from_args(
            content=DataFileContent.DATA,
            file_path=str(orc_path),
            file_format=FileFormat.ORC,
            partition=Record(),
            file_size_in_bytes=orc_path.stat().st_size,
            sort_order_id=None,
            spec_id=0,
            equality_ids=None,
            key_metadata=None,
            record_count=num_rows,
            column_sizes={1: 40000, 2: 80000},
            value_counts={1: num_rows, 2: num_rows},
            null_value_counts={1: 0, 2: 0},
            nan_value_counts={1: 0, 2: 0},
            # String bounds are lexicographic: the max of value_1..value_10000 is "value_9999".
            lower_bounds={1: (1).to_bytes(4, "little"), 2: min(values).encode()},
            upper_bounds={1: num_rows.to_bytes(4, "little"), 2: max(values).encode()},
            split_offsets=None,
        )
        data_file.spec_id = 0

        # Test batching behavior
        scan_task = FileScanTask(data_file=data_file)
        scan = ArrowScan(
            table_metadata=table_metadata,
            io=io,
            projected_schema=schema,
            row_filter=AlwaysTrue(),
            case_sensitive=True,
        )
        batches = list(scan.to_record_batches([scan_task]))

        # CRITICAL: a single file yields multiple batches when it has multiple stripes
        if expected_stripes > 1:
            assert len(batches) > 1, f"Expected multiple batches for single file, got {len(batches)} batches"
            assert actual_stripes > 1, f"Expected multiple stripes for single file, got {actual_stripes} stripes"

        # Verify exact batch count matches expected
        assert len(batches) == expected_stripes, f"Expected {expected_stripes} batches, got {len(batches)}"

        # Verify batch sizes match expected stripe sizes
        batch_sizes = [batch.num_rows for batch in batches]
        assert batch_sizes == expected_stripe_sizes, (
            f"Batch sizes {batch_sizes} don't match expected stripe sizes {expected_stripe_sizes}"
        )

        # Verify actual ORC metadata matches expected
        assert actual_stripes == expected_stripes, f"Expected {expected_stripes} stripes, got {actual_stripes}"
        assert actual_stripe_sizes == expected_stripe_sizes, (
            f"Expected stripe sizes {expected_stripe_sizes}, got {actual_stripe_sizes}"
        )

        # Verify total rows
        total_rows = sum(batch_sizes)
        assert total_rows == num_rows, f"Expected {num_rows} total rows, got {total_rows}"
def test_partition_column_projection_with_schema_evolution(catalog: InMemoryCatalog) -> None:
    """Project non-partition columns from a partitioned table whose schema evolved between appends.

    Regression test for https://github.com/apache/iceberg-python/issues/2672. The first data file
    is written before ``new_column`` exists and the second one after it is added. A scan that
    selects only non-partition columns must return the rows of both files, with ``new_column``
    reading as null for the rows from the older file.
    """
    identifier = "default.test_schema_evolution_projection"
    # Arrow column layout shared by both appends; the second append extends it with new_column.
    base_arrow_fields = [
        ("partition_date", pa.date32()),
        ("id", pa.int32()),
        ("name", pa.string()),
        ("value", pa.int32()),
    ]

    original_schema = Schema(
        NestedField(1, "partition_date", DateType(), required=False),
        NestedField(2, "id", IntegerType(), required=False),
        NestedField(3, "name", StringType(), required=False),
        NestedField(4, "value", IntegerType(), required=False),
    )
    # Identity-partition on partition_date, the column the scan below deliberately leaves out.
    date_spec = PartitionSpec(
        PartitionField(source_id=1, field_id=1000, transform=IdentityTransform(), name="partition_date"),
    )

    catalog.create_namespace("default")
    tbl = catalog.create_table(
        identifier,
        schema=original_schema,
        partition_spec=date_spec,
    )

    first_rows = [
        {"partition_date": date(2024, 1, 1), "id": 1, "name": "Alice", "value": 100},
        {"partition_date": date(2024, 1, 1), "id": 2, "name": "Bob", "value": 200},
    ]
    tbl.append(pa.Table.from_pylist(first_rows, schema=pa.schema(base_arrow_fields)))

    # Evolve the schema so the next data file carries a column that the first file lacks.
    with tbl.update_schema() as update:
        update.add_column("new_column", StringType())
    tbl = catalog.load_table(identifier)

    second_rows = [
        {"partition_date": date(2024, 1, 2), "id": 3, "name": "Charlie", "value": 300, "new_column": "new1"},
        {"partition_date": date(2024, 1, 2), "id": 4, "name": "David", "value": 400, "new_column": "new2"},
    ]
    evolved_arrow_schema = pa.schema([*base_arrow_fields, ("new_column", pa.string())])
    tbl.append(pa.Table.from_pylist(second_rows, schema=evolved_arrow_schema))

    projected = tbl.scan(selected_fields=("id", "name", "value", "new_column")).to_arrow()

    # The partition source column was not selected, so it must not appear in the projection.
    assert set(projected.schema.names) == {"id", "name", "value", "new_column"}
    assert projected.num_rows == 4

    by_name = projected.sort_by("name")
    assert by_name["name"].to_pylist() == ["Alice", "Bob", "Charlie", "David"]
    # Rows from the pre-evolution file have no new_column values and read back as null.
    assert by_name["new_column"].to_pylist() == [None, None, "new1", "new2"]