| # Licensed to the Apache Software Foundation (ASF) under one |
| # or more contributor license agreements. See the NOTICE file |
| # distributed with this work for additional information |
| # regarding copyright ownership. The ASF licenses this file |
| # to you under the Apache License, Version 2.0 (the |
| # "License"); you may not use this file except in compliance |
| # with the License. You may obtain a copy of the License at |
| # |
| # http://www.apache.org/licenses/LICENSE-2.0 |
| # |
| # Unless required by applicable law or agreed to in writing, |
| # software distributed under the License is distributed on an |
| # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY |
| # KIND, either express or implied. See the License for the |
| # specific language governing permissions and limitations |
| # under the License. |
| # pylint: disable=protected-access,unused-argument,redefined-outer-name |
| import logging |
| import os |
| import tempfile |
| import uuid |
| import warnings |
| from datetime import date |
| from typing import Any, List, Optional |
| from unittest.mock import MagicMock, patch |
| from uuid import uuid4 |
| |
| import pyarrow |
| import pyarrow as pa |
| import pyarrow.parquet as pq |
| import pytest |
| from packaging import version |
| from pyarrow.fs import AwsDefaultS3RetryStrategy, FileType, LocalFileSystem, S3FileSystem |
| |
| from pyiceberg.exceptions import ResolveError |
| from pyiceberg.expressions import ( |
| AlwaysFalse, |
| AlwaysTrue, |
| And, |
| BooleanExpression, |
| BoundEqualTo, |
| BoundGreaterThan, |
| BoundGreaterThanOrEqual, |
| BoundIn, |
| BoundIsNaN, |
| BoundIsNull, |
| BoundLessThan, |
| BoundLessThanOrEqual, |
| BoundNotEqualTo, |
| BoundNotIn, |
| BoundNotNaN, |
| BoundNotNull, |
| BoundNotStartsWith, |
| BoundReference, |
| BoundStartsWith, |
| GreaterThan, |
| Not, |
| Or, |
| ) |
| from pyiceberg.expressions.literals import literal |
| from pyiceberg.io import S3_RETRY_STRATEGY_IMPL, InputStream, OutputStream, load_file_io |
| from pyiceberg.io.pyarrow import ( |
| ICEBERG_SCHEMA, |
| ArrowScan, |
| PyArrowFile, |
| PyArrowFileIO, |
| StatsAggregator, |
| _check_pyarrow_schema_compatible, |
| _ConvertToArrowSchema, |
| _determine_partitions, |
| _primitive_to_physical, |
| _read_deletes, |
| _to_requested_schema, |
| bin_pack_arrow_table, |
| compute_statistics_plan, |
| data_file_statistics_from_parquet_metadata, |
| expression_to_pyarrow, |
| parquet_path_to_id_mapping, |
| schema_to_pyarrow, |
| ) |
| from pyiceberg.manifest import DataFile, DataFileContent, FileFormat |
| from pyiceberg.partitioning import PartitionField, PartitionSpec |
| from pyiceberg.schema import Schema, make_compatible_name, visit |
| from pyiceberg.table import FileScanTask, TableProperties |
| from pyiceberg.table.metadata import TableMetadataV2 |
| from pyiceberg.table.name_mapping import create_mapping_from_schema |
| from pyiceberg.transforms import HourTransform, IdentityTransform |
| from pyiceberg.typedef import UTF8, Properties, Record |
| from pyiceberg.types import ( |
| BinaryType, |
| BooleanType, |
| DateType, |
| DecimalType, |
| DoubleType, |
| FixedType, |
| FloatType, |
| IntegerType, |
| ListType, |
| LongType, |
| MapType, |
| NestedField, |
| PrimitiveType, |
| StringType, |
| StructType, |
| TimestampType, |
| TimestamptzType, |
| TimeType, |
| ) |
| from tests.catalog.test_base import InMemoryCatalog |
| from tests.conftest import UNIFIED_AWS_SESSION_PROPERTIES |
| |
| skip_if_pyarrow_too_old = pytest.mark.skipif( |
| version.parse(pyarrow.__version__) < version.parse("20.0.0"), |
| reason="Requires pyarrow version >= 20.0.0", |
| ) |
| |
| |
| def test_pyarrow_infer_local_fs_from_path() -> None: |
| """Test path with `file` scheme and no scheme both use LocalFileSystem""" |
| assert isinstance(PyArrowFileIO().new_output("file://tmp/warehouse")._filesystem, LocalFileSystem) |
| assert isinstance(PyArrowFileIO().new_output("/tmp/warehouse")._filesystem, LocalFileSystem) |
| |
| |
| def test_pyarrow_local_fs_can_create_path_without_parent_dir() -> None: |
| """Test LocalFileSystem can create path without first creating the parent directories""" |
| with tempfile.TemporaryDirectory() as tmpdirname: |
| file_path = f"{tmpdirname}/foo/bar/baz.txt" |
| output_file = PyArrowFileIO().new_output(file_path) |
| parent_path = os.path.dirname(file_path) |
| assert output_file._filesystem.get_file_info(parent_path).type == FileType.NotFound |
| try: |
| with output_file.create() as f: |
| f.write(b"foo") |
| except Exception: |
| pytest.fail("Failed to write to file without parent directory") |
| |
| |
| def test_pyarrow_input_file() -> None: |
| """Test reading a file using PyArrowFile""" |
| |
| with tempfile.TemporaryDirectory() as tmpdirname: |
| file_location = os.path.join(tmpdirname, "foo.txt") |
| with open(file_location, "wb") as f: |
| f.write(b"foo") |
| |
| # Confirm that the file initially exists |
| assert os.path.exists(file_location) |
| |
| # Instantiate the input file |
| absolute_file_location = os.path.abspath(file_location) |
| input_file = PyArrowFileIO().new_input(location=f"{absolute_file_location}") |
| |
| # Test opening and reading the file |
| r = input_file.open(seekable=False) |
| assert isinstance(r, InputStream) # Test that the file object abides by the InputStream protocol |
| data = r.read() |
| assert data == b"foo" |
| assert len(input_file) == 3 |
| with pytest.raises(OSError) as exc_info: |
| r.seek(0, 0) |
| assert "only valid on seekable files" in str(exc_info.value) |
| |
| |
| def test_pyarrow_input_file_seekable() -> None: |
| """Test reading a file using PyArrowFile""" |
| |
| with tempfile.TemporaryDirectory() as tmpdirname: |
| file_location = os.path.join(tmpdirname, "foo.txt") |
| with open(file_location, "wb") as f: |
| f.write(b"foo") |
| |
| # Confirm that the file initially exists |
| assert os.path.exists(file_location) |
| |
| # Instantiate the input file |
| absolute_file_location = os.path.abspath(file_location) |
| input_file = PyArrowFileIO().new_input(location=f"{absolute_file_location}") |
| |
| # Test opening and reading the file |
| r = input_file.open(seekable=True) |
| assert isinstance(r, InputStream) # Test that the file object abides by the InputStream protocol |
| data = r.read() |
| assert data == b"foo" |
| assert len(input_file) == 3 |
| r.seek(0, 0) |
| data = r.read() |
| assert data == b"foo" |
| assert len(input_file) == 3 |
| |
| |
| def test_pyarrow_output_file() -> None: |
| """Test writing a file using PyArrowFile""" |
| |
| with tempfile.TemporaryDirectory() as tmpdirname: |
| file_location = os.path.join(tmpdirname, "foo.txt") |
| |
| # Instantiate the output file |
| absolute_file_location = os.path.abspath(file_location) |
| output_file = PyArrowFileIO().new_output(location=f"{absolute_file_location}") |
| |
| # Create the output file and write to it |
| f = output_file.create() |
| assert isinstance(f, OutputStream) # Test that the file object abides by the OutputStream protocol |
| f.write(b"foo") |
| |
| # Confirm that bytes were written |
| with open(file_location, "rb") as f: |
| assert f.read() == b"foo" |
| |
| assert len(output_file) == 3 |
| |
| |
| def test_pyarrow_invalid_scheme() -> None: |
| """Test that a ValueError is raised if a location is provided with an invalid scheme""" |
| |
| with pytest.raises(ValueError) as exc_info: |
| PyArrowFileIO().new_input("foo://bar/baz.txt") |
| |
| assert "Unrecognized filesystem type in URI" in str(exc_info.value) |
| |
| with pytest.raises(ValueError) as exc_info: |
| PyArrowFileIO().new_output("foo://bar/baz.txt") |
| |
| assert "Unrecognized filesystem type in URI" in str(exc_info.value) |
| |
| |
| def test_pyarrow_violating_input_stream_protocol() -> None: |
| """Test that a TypeError is raised if an input file is provided that violates the InputStream protocol""" |
| |
| # Missing seek, tell, closed, and close |
| input_file_mock = MagicMock(spec=["read"]) |
| |
| # Create a mocked filesystem that returns input_file_mock |
| filesystem_mock = MagicMock() |
| filesystem_mock.open_input_file.return_value = input_file_mock |
| |
| input_file = PyArrowFile("foo.txt", path="foo.txt", fs=filesystem_mock) |
| |
| f = input_file.open() |
| assert not isinstance(f, InputStream) |
| |
| |
| def test_pyarrow_violating_output_stream_protocol() -> None: |
| """Test that a TypeError is raised if an output stream is provided that violates the OutputStream protocol""" |
| |
| # Missing closed, and close |
| output_file_mock = MagicMock(spec=["write", "exists"]) |
| output_file_mock.exists.return_value = False |
| |
| file_info_mock = MagicMock() |
| file_info_mock.type = FileType.NotFound |
| |
| # Create a mocked filesystem that returns output_file_mock |
| filesystem_mock = MagicMock() |
| filesystem_mock.open_output_stream.return_value = output_file_mock |
| filesystem_mock.get_file_info.return_value = file_info_mock |
| |
| output_file = PyArrowFile("foo.txt", path="foo.txt", fs=filesystem_mock) |
| |
| f = output_file.create() |
| |
| assert not isinstance(f, OutputStream) |
| |
| |
| def test_raise_on_opening_a_local_file_not_found() -> None: |
| """Test that a PyArrowFile raises appropriately when a local file is not found""" |
| |
| with tempfile.TemporaryDirectory() as tmpdirname: |
| file_location = os.path.join(tmpdirname, "foo.txt") |
| f = PyArrowFileIO().new_input(file_location) |
| |
| with pytest.raises(FileNotFoundError) as exc_info: |
| f.open() |
| |
| assert "[Errno 2] Failed to open local file" in str(exc_info.value) |
| |
| |
| def test_raise_on_opening_an_s3_file_no_permission() -> None: |
| """Test that opening a PyArrowFile raises a PermissionError when the pyarrow error includes 'AWS Error [code 15]'""" |
| |
| s3fs_mock = MagicMock() |
| s3fs_mock.open_input_file.side_effect = OSError("AWS Error [code 15]") |
| |
| f = PyArrowFile("s3://foo/bar.txt", path="foo/bar.txt", fs=s3fs_mock) |
| |
| with pytest.raises(PermissionError) as exc_info: |
| f.open() |
| |
| assert "Cannot open file, access denied:" in str(exc_info.value) |
| |
| |
| def test_raise_on_opening_an_s3_file_not_found() -> None: |
| """Test that a PyArrowFile raises a FileNotFoundError when the pyarrow error includes 'Path does not exist'""" |
| |
| s3fs_mock = MagicMock() |
| s3fs_mock.open_input_file.side_effect = OSError("Path does not exist") |
| |
| f = PyArrowFile("s3://foo/bar.txt", path="foo/bar.txt", fs=s3fs_mock) |
| |
| with pytest.raises(FileNotFoundError) as exc_info: |
| f.open() |
| |
| assert "Cannot open file, does not exist:" in str(exc_info.value) |
| |
| |
| @patch("pyiceberg.io.pyarrow.PyArrowFile.exists", return_value=False) |
| def test_raise_on_creating_an_s3_file_no_permission(_: Any) -> None: |
| """Test that creating a PyArrowFile raises a PermissionError when the pyarrow error includes 'AWS Error [code 15]'""" |
| |
| s3fs_mock = MagicMock() |
| s3fs_mock.open_output_stream.side_effect = OSError("AWS Error [code 15]") |
| |
| f = PyArrowFile("s3://foo/bar.txt", path="foo/bar.txt", fs=s3fs_mock) |
| |
| with pytest.raises(PermissionError) as exc_info: |
| f.create() |
| |
| assert "Cannot create file, access denied:" in str(exc_info.value) |
| |
| |
| def test_deleting_s3_file_no_permission() -> None: |
| """Test that a PyArrowFile raises a PermissionError when the pyarrow OSError includes 'AWS Error [code 15]'""" |
| |
| s3fs_mock = MagicMock() |
| s3fs_mock.delete_file.side_effect = OSError("AWS Error [code 15]") |
| |
| with patch.object(PyArrowFileIO, "_initialize_fs") as submocked: |
| submocked.return_value = s3fs_mock |
| |
| with pytest.raises(PermissionError) as exc_info: |
| PyArrowFileIO().delete("s3://foo/bar.txt") |
| |
| assert "Cannot delete file, access denied:" in str(exc_info.value) |
| |
| |
| def test_deleting_s3_file_not_found() -> None: |
| """Test that a PyArrowFile raises a PermissionError when the pyarrow error includes 'AWS Error [code 15]'""" |
| |
| s3fs_mock = MagicMock() |
| s3fs_mock.delete_file.side_effect = OSError("Path does not exist") |
| |
| with patch.object(PyArrowFileIO, "_initialize_fs") as submocked: |
| submocked.return_value = s3fs_mock |
| |
| with pytest.raises(FileNotFoundError) as exc_info: |
| PyArrowFileIO().delete("s3://foo/bar.txt") |
| |
| assert "Cannot delete file, does not exist:" in str(exc_info.value) |
| |
| |
| def test_deleting_hdfs_file_not_found() -> None: |
| """Test that a PyArrowFile raises a PermissionError when the pyarrow error includes 'No such file or directory'""" |
| |
| hdfs_mock = MagicMock() |
| hdfs_mock.delete_file.side_effect = OSError("Path does not exist") |
| |
| with patch.object(PyArrowFileIO, "_initialize_fs") as submocked: |
| submocked.return_value = hdfs_mock |
| |
| with pytest.raises(FileNotFoundError) as exc_info: |
| PyArrowFileIO().delete("hdfs://foo/bar.txt") |
| |
| assert "Cannot delete file, does not exist:" in str(exc_info.value) |
| |
| |
| def test_pyarrow_s3_session_properties() -> None: |
| session_properties: Properties = { |
| "s3.endpoint": "http://localhost:9000", |
| "s3.access-key-id": "admin", |
| "s3.secret-access-key": "password", |
| "s3.region": "us-east-1", |
| "s3.session-token": "s3.session-token", |
| **UNIFIED_AWS_SESSION_PROPERTIES, |
| } |
| |
| with patch("pyarrow.fs.S3FileSystem") as mock_s3fs, patch("pyarrow.fs.resolve_s3_region") as mock_s3_region_resolver: |
| s3_fileio = PyArrowFileIO(properties=session_properties) |
| filename = str(uuid.uuid4()) |
| |
        # Make `resolve_s3_region` raise so the bucket location cannot resolve to a different S3 region
| mock_s3_region_resolver.side_effect = OSError("S3 bucket is not found") |
| s3_fileio.new_input(location=f"s3://warehouse/{filename}") |
| |
| mock_s3fs.assert_called_with( |
| endpoint_override="http://localhost:9000", |
| access_key="admin", |
| secret_key="password", |
| region="us-east-1", |
| session_token="s3.session-token", |
| ) |
| |
| |
| def test_pyarrow_unified_session_properties() -> None: |
| session_properties: Properties = { |
| "s3.endpoint": "http://localhost:9000", |
| **UNIFIED_AWS_SESSION_PROPERTIES, |
| } |
| |
| with patch("pyarrow.fs.S3FileSystem") as mock_s3fs, patch("pyarrow.fs.resolve_s3_region") as mock_s3_region_resolver: |
| s3_fileio = PyArrowFileIO(properties=session_properties) |
| filename = str(uuid.uuid4()) |
| |
| mock_s3_region_resolver.return_value = "client.region" |
| s3_fileio.new_input(location=f"s3://warehouse/{filename}") |
| |
| mock_s3fs.assert_called_with( |
| endpoint_override="http://localhost:9000", |
| access_key="client.access-key-id", |
| secret_key="client.secret-access-key", |
| region="client.region", |
| session_token="client.session-token", |
| ) |
| |
| |
| def test_schema_to_pyarrow_schema_include_field_ids(table_schema_nested: Schema) -> None: |
| actual = schema_to_pyarrow(table_schema_nested) |
| expected = """foo: large_string |
| -- field metadata -- |
| PARQUET:field_id: '1' |
| bar: int32 not null |
| -- field metadata -- |
| PARQUET:field_id: '2' |
| baz: bool |
| -- field metadata -- |
| PARQUET:field_id: '3' |
| qux: large_list<element: large_string not null> not null |
| child 0, element: large_string not null |
| -- field metadata -- |
| PARQUET:field_id: '5' |
| -- field metadata -- |
| PARQUET:field_id: '4' |
| quux: map<large_string, map<large_string, int32>> not null |
| child 0, entries: struct<key: large_string not null, value: map<large_string, int32> not null> not null |
| child 0, key: large_string not null |
| -- field metadata -- |
| PARQUET:field_id: '7' |
| child 1, value: map<large_string, int32> not null |
| child 0, entries: struct<key: large_string not null, value: int32 not null> not null |
| child 0, key: large_string not null |
| -- field metadata -- |
| PARQUET:field_id: '9' |
| child 1, value: int32 not null |
| -- field metadata -- |
| PARQUET:field_id: '10' |
| -- field metadata -- |
| PARQUET:field_id: '8' |
| -- field metadata -- |
| PARQUET:field_id: '6' |
| location: large_list<element: struct<latitude: float, longitude: float> not null> not null |
| child 0, element: struct<latitude: float, longitude: float> not null |
| child 0, latitude: float |
| -- field metadata -- |
| PARQUET:field_id: '13' |
| child 1, longitude: float |
| -- field metadata -- |
| PARQUET:field_id: '14' |
| -- field metadata -- |
| PARQUET:field_id: '12' |
| -- field metadata -- |
| PARQUET:field_id: '11' |
| person: struct<name: large_string, age: int32 not null> |
| child 0, name: large_string |
| -- field metadata -- |
| PARQUET:field_id: '16' |
| child 1, age: int32 not null |
| -- field metadata -- |
| PARQUET:field_id: '17' |
| -- field metadata -- |
| PARQUET:field_id: '15'""" |
| assert repr(actual) == expected |
| |
| |
| def test_schema_to_pyarrow_schema_exclude_field_ids(table_schema_nested: Schema) -> None: |
| actual = schema_to_pyarrow(table_schema_nested, include_field_ids=False) |
| expected = """foo: large_string |
| bar: int32 not null |
| baz: bool |
| qux: large_list<element: large_string not null> not null |
| child 0, element: large_string not null |
| quux: map<large_string, map<large_string, int32>> not null |
| child 0, entries: struct<key: large_string not null, value: map<large_string, int32> not null> not null |
| child 0, key: large_string not null |
| child 1, value: map<large_string, int32> not null |
| child 0, entries: struct<key: large_string not null, value: int32 not null> not null |
| child 0, key: large_string not null |
| child 1, value: int32 not null |
| location: large_list<element: struct<latitude: float, longitude: float> not null> not null |
| child 0, element: struct<latitude: float, longitude: float> not null |
| child 0, latitude: float |
| child 1, longitude: float |
| person: struct<name: large_string, age: int32 not null> |
| child 0, name: large_string |
| child 1, age: int32 not null""" |
| assert repr(actual) == expected |
| |
| |
| def test_fixed_type_to_pyarrow() -> None: |
| length = 22 |
| iceberg_type = FixedType(length) |
| assert visit(iceberg_type, _ConvertToArrowSchema()) == pa.binary(length) |
| |
| |
| def test_decimal_type_to_pyarrow() -> None: |
| precision = 25 |
| scale = 19 |
| iceberg_type = DecimalType(precision, scale) |
| assert visit(iceberg_type, _ConvertToArrowSchema()) == pa.decimal128(precision, scale) |
| |
| |
| def test_boolean_type_to_pyarrow() -> None: |
| iceberg_type = BooleanType() |
| assert visit(iceberg_type, _ConvertToArrowSchema()) == pa.bool_() |
| |
| |
| def test_integer_type_to_pyarrow() -> None: |
| iceberg_type = IntegerType() |
| assert visit(iceberg_type, _ConvertToArrowSchema()) == pa.int32() |
| |
| |
| def test_long_type_to_pyarrow() -> None: |
| iceberg_type = LongType() |
| assert visit(iceberg_type, _ConvertToArrowSchema()) == pa.int64() |
| |
| |
| def test_float_type_to_pyarrow() -> None: |
| iceberg_type = FloatType() |
| assert visit(iceberg_type, _ConvertToArrowSchema()) == pa.float32() |
| |
| |
| def test_double_type_to_pyarrow() -> None: |
| iceberg_type = DoubleType() |
| assert visit(iceberg_type, _ConvertToArrowSchema()) == pa.float64() |
| |
| |
| def test_date_type_to_pyarrow() -> None: |
| iceberg_type = DateType() |
| assert visit(iceberg_type, _ConvertToArrowSchema()) == pa.date32() |
| |
| |
| def test_time_type_to_pyarrow() -> None: |
| iceberg_type = TimeType() |
| assert visit(iceberg_type, _ConvertToArrowSchema()) == pa.time64("us") |
| |
| |
| def test_timestamp_type_to_pyarrow() -> None: |
| iceberg_type = TimestampType() |
| assert visit(iceberg_type, _ConvertToArrowSchema()) == pa.timestamp(unit="us") |
| |
| |
| def test_timestamptz_type_to_pyarrow() -> None: |
| iceberg_type = TimestamptzType() |
| assert visit(iceberg_type, _ConvertToArrowSchema()) == pa.timestamp(unit="us", tz="UTC") |
| |
| |
| def test_string_type_to_pyarrow() -> None: |
| iceberg_type = StringType() |
| assert visit(iceberg_type, _ConvertToArrowSchema()) == pa.large_string() |
| |
| |
| def test_binary_type_to_pyarrow() -> None: |
| iceberg_type = BinaryType() |
| assert visit(iceberg_type, _ConvertToArrowSchema()) == pa.large_binary() |
| |
| |
| def test_struct_type_to_pyarrow(table_schema_simple: Schema) -> None: |
| expected = pa.struct( |
| [ |
| pa.field("foo", pa.large_string(), nullable=True, metadata={"field_id": "1"}), |
| pa.field("bar", pa.int32(), nullable=False, metadata={"field_id": "2"}), |
| pa.field("baz", pa.bool_(), nullable=True, metadata={"field_id": "3"}), |
| ] |
| ) |
| assert visit(table_schema_simple.as_struct(), _ConvertToArrowSchema()) == expected |
| |
| |
| def test_map_type_to_pyarrow() -> None: |
| iceberg_map = MapType( |
| key_id=1, |
| key_type=IntegerType(), |
| value_id=2, |
| value_type=StringType(), |
| value_required=True, |
| ) |
| assert visit(iceberg_map, _ConvertToArrowSchema()) == pa.map_( |
| pa.field("key", pa.int32(), nullable=False, metadata={"field_id": "1"}), |
| pa.field("value", pa.large_string(), nullable=False, metadata={"field_id": "2"}), |
| ) |
| |
| |
| def test_list_type_to_pyarrow() -> None: |
    iceberg_list = ListType(
| element_id=1, |
| element_type=IntegerType(), |
| element_required=True, |
| ) |
    assert visit(iceberg_list, _ConvertToArrowSchema()) == pa.large_list(
| pa.field("element", pa.int32(), nullable=False, metadata={"field_id": "1"}) |
| ) |
| |
| |
| @pytest.fixture |
| def bound_reference(table_schema_simple: Schema) -> BoundReference[str]: |
| return BoundReference(table_schema_simple.find_field(1), table_schema_simple.accessor_for_field(1)) |
| |
| |
| @pytest.fixture |
| def bound_double_reference() -> BoundReference[float]: |
| schema = Schema( |
| NestedField(field_id=1, name="foo", field_type=DoubleType(), required=False), |
| schema_id=1, |
| identifier_field_ids=[], |
| ) |
| return BoundReference(schema.find_field(1), schema.accessor_for_field(1)) |
| |
| |
| def test_expr_is_null_to_pyarrow(bound_reference: BoundReference[str]) -> None: |
| assert ( |
| repr(expression_to_pyarrow(BoundIsNull(bound_reference))) |
| == "<pyarrow.compute.Expression is_null(foo, {nan_is_null=false})>" |
| ) |
| |
| |
| def test_expr_not_null_to_pyarrow(bound_reference: BoundReference[str]) -> None: |
| assert repr(expression_to_pyarrow(BoundNotNull(bound_reference))) == "<pyarrow.compute.Expression is_valid(foo)>" |
| |
| |
| def test_expr_is_nan_to_pyarrow(bound_double_reference: BoundReference[str]) -> None: |
| assert repr(expression_to_pyarrow(BoundIsNaN(bound_double_reference))) == "<pyarrow.compute.Expression is_nan(foo)>" |
| |
| |
| def test_expr_not_nan_to_pyarrow(bound_double_reference: BoundReference[str]) -> None: |
| assert repr(expression_to_pyarrow(BoundNotNaN(bound_double_reference))) == "<pyarrow.compute.Expression invert(is_nan(foo))>" |
| |
| |
| def test_expr_equal_to_pyarrow(bound_reference: BoundReference[str]) -> None: |
| assert ( |
| repr(expression_to_pyarrow(BoundEqualTo(bound_reference, literal("hello")))) |
| == '<pyarrow.compute.Expression (foo == "hello")>' |
| ) |
| |
| |
| def test_expr_not_equal_to_pyarrow(bound_reference: BoundReference[str]) -> None: |
| assert ( |
| repr(expression_to_pyarrow(BoundNotEqualTo(bound_reference, literal("hello")))) |
| == '<pyarrow.compute.Expression (foo != "hello")>' |
| ) |
| |
| |
| def test_expr_greater_than_or_equal_equal_to_pyarrow(bound_reference: BoundReference[str]) -> None: |
| assert ( |
| repr(expression_to_pyarrow(BoundGreaterThanOrEqual(bound_reference, literal("hello")))) |
| == '<pyarrow.compute.Expression (foo >= "hello")>' |
| ) |
| |
| |
| def test_expr_greater_than_to_pyarrow(bound_reference: BoundReference[str]) -> None: |
| assert ( |
| repr(expression_to_pyarrow(BoundGreaterThan(bound_reference, literal("hello")))) |
| == '<pyarrow.compute.Expression (foo > "hello")>' |
| ) |
| |
| |
| def test_expr_less_than_to_pyarrow(bound_reference: BoundReference[str]) -> None: |
| assert ( |
| repr(expression_to_pyarrow(BoundLessThan(bound_reference, literal("hello")))) |
| == '<pyarrow.compute.Expression (foo < "hello")>' |
| ) |
| |
| |
| def test_expr_less_than_or_equal_to_pyarrow(bound_reference: BoundReference[str]) -> None: |
| assert ( |
| repr(expression_to_pyarrow(BoundLessThanOrEqual(bound_reference, literal("hello")))) |
| == '<pyarrow.compute.Expression (foo <= "hello")>' |
| ) |
| |
| |
| def test_expr_in_to_pyarrow(bound_reference: BoundReference[str]) -> None: |
| assert repr(expression_to_pyarrow(BoundIn(bound_reference, {literal("hello"), literal("world")}))) in ( |
| """<pyarrow.compute.Expression is_in(foo, {value_set=large_string:[ |
| "hello", |
| "world" |
| ], null_matching_behavior=MATCH})>""", |
| """<pyarrow.compute.Expression is_in(foo, {value_set=large_string:[ |
| "world", |
| "hello" |
| ], null_matching_behavior=MATCH})>""", |
| ) |
| |
| |
| def test_expr_not_in_to_pyarrow(bound_reference: BoundReference[str]) -> None: |
| assert repr(expression_to_pyarrow(BoundNotIn(bound_reference, {literal("hello"), literal("world")}))) in ( |
| """<pyarrow.compute.Expression invert(is_in(foo, {value_set=large_string:[ |
| "hello", |
| "world" |
| ], null_matching_behavior=MATCH}))>""", |
| """<pyarrow.compute.Expression invert(is_in(foo, {value_set=large_string:[ |
| "world", |
| "hello" |
| ], null_matching_behavior=MATCH}))>""", |
| ) |
| |
| |
| def test_expr_starts_with_to_pyarrow(bound_reference: BoundReference[str]) -> None: |
| assert ( |
| repr(expression_to_pyarrow(BoundStartsWith(bound_reference, literal("he")))) |
| == '<pyarrow.compute.Expression starts_with(foo, {pattern="he", ignore_case=false})>' |
| ) |
| |
| |
| def test_expr_not_starts_with_to_pyarrow(bound_reference: BoundReference[str]) -> None: |
| assert ( |
| repr(expression_to_pyarrow(BoundNotStartsWith(bound_reference, literal("he")))) |
| == '<pyarrow.compute.Expression invert(starts_with(foo, {pattern="he", ignore_case=false}))>' |
| ) |
| |
| |
| def test_and_to_pyarrow(bound_reference: BoundReference[str]) -> None: |
| assert ( |
| repr(expression_to_pyarrow(And(BoundEqualTo(bound_reference, literal("hello")), BoundIsNull(bound_reference)))) |
| == '<pyarrow.compute.Expression ((foo == "hello") and is_null(foo, {nan_is_null=false}))>' |
| ) |
| |
| |
| def test_or_to_pyarrow(bound_reference: BoundReference[str]) -> None: |
| assert ( |
| repr(expression_to_pyarrow(Or(BoundEqualTo(bound_reference, literal("hello")), BoundIsNull(bound_reference)))) |
| == '<pyarrow.compute.Expression ((foo == "hello") or is_null(foo, {nan_is_null=false}))>' |
| ) |
| |
| |
| def test_not_to_pyarrow(bound_reference: BoundReference[str]) -> None: |
| assert ( |
| repr(expression_to_pyarrow(Not(BoundEqualTo(bound_reference, literal("hello"))))) |
| == '<pyarrow.compute.Expression invert((foo == "hello"))>' |
| ) |
| |
| |
| def test_always_true_to_pyarrow(bound_reference: BoundReference[str]) -> None: |
| assert repr(expression_to_pyarrow(AlwaysTrue())) == "<pyarrow.compute.Expression true>" |
| |
| |
| def test_always_false_to_pyarrow(bound_reference: BoundReference[str]) -> None: |
| assert repr(expression_to_pyarrow(AlwaysFalse())) == "<pyarrow.compute.Expression false>" |
| |
| |
| @pytest.fixture |
| def schema_int() -> Schema: |
| return Schema(NestedField(1, "id", IntegerType(), required=False)) |
| |
| |
| @pytest.fixture |
| def schema_int_str() -> Schema: |
| return Schema(NestedField(1, "id", IntegerType(), required=False), NestedField(2, "data", StringType(), required=False)) |
| |
| |
| @pytest.fixture |
| def schema_str() -> Schema: |
| return Schema(NestedField(2, "data", StringType(), required=False)) |
| |
| |
| @pytest.fixture |
| def schema_long() -> Schema: |
| return Schema(NestedField(3, "id", LongType(), required=False)) |
| |
| |
| @pytest.fixture |
| def schema_struct() -> Schema: |
| return Schema( |
| NestedField( |
| 4, |
| "location", |
| StructType( |
| NestedField(41, "lat", DoubleType()), |
| NestedField(42, "long", DoubleType()), |
| ), |
| ) |
| ) |
| |
| |
| @pytest.fixture |
| def schema_list() -> Schema: |
| return Schema( |
| NestedField(5, "ids", ListType(51, IntegerType(), element_required=False), required=False), |
| ) |
| |
| |
| @pytest.fixture |
| def schema_list_of_structs() -> Schema: |
| return Schema( |
| NestedField( |
| 5, |
| "locations", |
| ListType( |
| 51, |
| StructType(NestedField(511, "lat", DoubleType()), NestedField(512, "long", DoubleType())), |
| element_required=False, |
| ), |
| required=False, |
| ), |
| ) |
| |
| |
| @pytest.fixture |
| def schema_map_of_structs() -> Schema: |
| return Schema( |
| NestedField( |
| 5, |
| "locations", |
| MapType( |
| key_id=51, |
| value_id=52, |
| key_type=StringType(), |
| value_type=StructType( |
| NestedField(511, "lat", DoubleType(), required=True), NestedField(512, "long", DoubleType(), required=True) |
| ), |
                value_required=True,
| ), |
| required=False, |
| ), |
| ) |
| |
| |
| @pytest.fixture |
| def schema_map() -> Schema: |
| return Schema( |
| NestedField( |
| 5, |
| "properties", |
| MapType( |
| key_id=51, |
| key_type=StringType(), |
| value_id=52, |
| value_type=StringType(), |
| value_required=True, |
| ), |
| required=False, |
| ), |
| ) |
| |
| |
| def _write_table_to_file(filepath: str, schema: pa.Schema, table: pa.Table) -> str: |
| with pq.ParquetWriter(filepath, schema) as writer: |
| writer.write_table(table) |
| return filepath |
| |
| |
| @pytest.fixture |
| def file_int(schema_int: Schema, tmpdir: str) -> str: |
| pyarrow_schema = schema_to_pyarrow(schema_int, metadata={ICEBERG_SCHEMA: bytes(schema_int.model_dump_json(), UTF8)}) |
| return _write_table_to_file( |
| f"file:{tmpdir}/a.parquet", pyarrow_schema, pa.Table.from_arrays([pa.array([0, 1, 2])], schema=pyarrow_schema) |
| ) |
| |
| |
| @pytest.fixture |
| def file_int_str(schema_int_str: Schema, tmpdir: str) -> str: |
| pyarrow_schema = schema_to_pyarrow(schema_int_str, metadata={ICEBERG_SCHEMA: bytes(schema_int_str.model_dump_json(), UTF8)}) |
| return _write_table_to_file( |
| f"file:{tmpdir}/a.parquet", |
| pyarrow_schema, |
| pa.Table.from_arrays([pa.array([0, 1, 2]), pa.array(["0", "1", "2"])], schema=pyarrow_schema), |
| ) |
| |
| |
| @pytest.fixture |
| def file_string(schema_str: Schema, tmpdir: str) -> str: |
| pyarrow_schema = schema_to_pyarrow(schema_str, metadata={ICEBERG_SCHEMA: bytes(schema_str.model_dump_json(), UTF8)}) |
| return _write_table_to_file( |
| f"file:{tmpdir}/b.parquet", pyarrow_schema, pa.Table.from_arrays([pa.array(["0", "1", "2"])], schema=pyarrow_schema) |
| ) |
| |
| |
| @pytest.fixture |
| def file_long(schema_long: Schema, tmpdir: str) -> str: |
| pyarrow_schema = schema_to_pyarrow(schema_long, metadata={ICEBERG_SCHEMA: bytes(schema_long.model_dump_json(), UTF8)}) |
| return _write_table_to_file( |
| f"file:{tmpdir}/c.parquet", pyarrow_schema, pa.Table.from_arrays([pa.array([0, 1, 2])], schema=pyarrow_schema) |
| ) |
| |
| |
| @pytest.fixture |
| def file_struct(schema_struct: Schema, tmpdir: str) -> str: |
| pyarrow_schema = schema_to_pyarrow(schema_struct, metadata={ICEBERG_SCHEMA: bytes(schema_struct.model_dump_json(), UTF8)}) |
| return _write_table_to_file( |
| f"file:{tmpdir}/d.parquet", |
| pyarrow_schema, |
| pa.Table.from_pylist( |
| [ |
| {"location": {"lat": 52.371807, "long": 4.896029}}, |
| {"location": {"lat": 52.387386, "long": 4.646219}}, |
| {"location": {"lat": 52.078663, "long": 4.288788}}, |
| ], |
| schema=pyarrow_schema, |
| ), |
| ) |
| |
| |
| @pytest.fixture |
| def file_list(schema_list: Schema, tmpdir: str) -> str: |
| pyarrow_schema = schema_to_pyarrow(schema_list, metadata={ICEBERG_SCHEMA: bytes(schema_list.model_dump_json(), UTF8)}) |
| return _write_table_to_file( |
| f"file:{tmpdir}/e.parquet", |
| pyarrow_schema, |
| pa.Table.from_pylist( |
| [ |
| {"ids": list(range(1, 10))}, |
| {"ids": list(range(2, 20))}, |
| {"ids": list(range(3, 30))}, |
| ], |
| schema=pyarrow_schema, |
| ), |
| ) |
| |
| |
| @pytest.fixture |
| def file_list_of_structs(schema_list_of_structs: Schema, tmpdir: str) -> str: |
| pyarrow_schema = schema_to_pyarrow( |
| schema_list_of_structs, metadata={ICEBERG_SCHEMA: bytes(schema_list_of_structs.model_dump_json(), UTF8)} |
| ) |
| return _write_table_to_file( |
| f"file:{tmpdir}/e.parquet", |
| pyarrow_schema, |
| pa.Table.from_pylist( |
| [ |
| {"locations": [{"lat": 52.371807, "long": 4.896029}, {"lat": 52.387386, "long": 4.646219}]}, |
| {"locations": []}, |
| {"locations": [{"lat": 52.078663, "long": 4.288788}, {"lat": 52.387386, "long": 4.646219}]}, |
| ], |
| schema=pyarrow_schema, |
| ), |
| ) |
| |
| |
| @pytest.fixture |
| def file_map_of_structs(schema_map_of_structs: Schema, tmpdir: str) -> str: |
| pyarrow_schema = schema_to_pyarrow( |
| schema_map_of_structs, metadata={ICEBERG_SCHEMA: bytes(schema_map_of_structs.model_dump_json(), UTF8)} |
| ) |
| return _write_table_to_file( |
| f"file:{tmpdir}/e.parquet", |
| pyarrow_schema, |
| pa.Table.from_pylist( |
| [ |
| {"locations": {"1": {"lat": 52.371807, "long": 4.896029}, "2": {"lat": 52.387386, "long": 4.646219}}}, |
| {"locations": {}}, |
| {"locations": {"3": {"lat": 52.078663, "long": 4.288788}, "4": {"lat": 52.387386, "long": 4.646219}}}, |
| ], |
| schema=pyarrow_schema, |
| ), |
| ) |
| |
| |
| @pytest.fixture |
| def file_map(schema_map: Schema, tmpdir: str) -> str: |
| pyarrow_schema = schema_to_pyarrow(schema_map, metadata={ICEBERG_SCHEMA: bytes(schema_map.model_dump_json(), UTF8)}) |
| return _write_table_to_file( |
| f"file:{tmpdir}/e.parquet", |
| pyarrow_schema, |
| pa.Table.from_pylist( |
| [ |
| {"properties": [("a", "b")]}, |
| {"properties": [("c", "d")]}, |
| {"properties": [("e", "f"), ("g", "h")]}, |
| ], |
| schema=pyarrow_schema, |
| ), |
| ) |
| |
| |
| def project( |
| schema: Schema, files: List[str], expr: Optional[BooleanExpression] = None, table_schema: Optional[Schema] = None |
| ) -> pa.Table: |
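    """Scan the given Parquet files with ArrowScan and return the projected pyarrow Table.

    Each file is wrapped in a minimal FileScanTask against an in-memory TableMetadataV2; the optional
    row filter is bound against `table_schema` when given, otherwise against the projected schema.
    """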
| return ArrowScan( |
| table_metadata=TableMetadataV2( |
| location="file://a/b/", |
| last_column_id=1, |
| format_version=2, |
| schemas=[table_schema or schema], |
| partition_specs=[PartitionSpec()], |
| ), |
| io=PyArrowFileIO(), |
| projected_schema=schema, |
| row_filter=expr or AlwaysTrue(), |
| case_sensitive=True, |
| ).to_table( |
| tasks=[ |
| FileScanTask( |
| DataFile.from_args( |
| content=DataFileContent.DATA, |
| file_path=file, |
| file_format=FileFormat.PARQUET, |
| partition={}, |
| record_count=3, |
| file_size_in_bytes=3, |
| ) |
| ) |
| for file in files |
| ] |
| ) |
| |
| |
| def test_projection_add_column(file_int: str) -> None: |
| schema = Schema( |
| # All new IDs |
| NestedField(10, "id", IntegerType(), required=False), |
| NestedField(20, "list", ListType(21, IntegerType(), element_required=False), required=False), |
| NestedField( |
| 30, |
| "map", |
| MapType(key_id=31, key_type=IntegerType(), value_id=32, value_type=StringType(), value_required=False), |
| required=False, |
| ), |
| NestedField( |
| 40, |
| "location", |
| StructType( |
| NestedField(41, "lat", DoubleType(), required=False), NestedField(42, "lon", DoubleType(), required=False) |
| ), |
| required=False, |
| ), |
| ) |
| result_table = project(schema, [file_int]) |
| |
| for col in result_table.columns: |
| assert len(col) == 3 |
| |
| for actual, expected in zip(result_table.columns[0], [None, None, None]): |
| assert actual.as_py() == expected |
| |
| for actual, expected in zip(result_table.columns[1], [None, None, None]): |
| assert actual.as_py() == expected |
| |
| for actual, expected in zip(result_table.columns[2], [None, None, None]): |
| assert actual.as_py() == expected |
| |
| for actual, expected in zip(result_table.columns[3], [None, None, None]): |
| assert actual.as_py() == expected |
| assert ( |
| repr(result_table.schema) |
| == """id: int32 |
| list: large_list<element: int32> |
| child 0, element: int32 |
| map: map<int32, large_string> |
| child 0, entries: struct<key: int32 not null, value: large_string> not null |
| child 0, key: int32 not null |
| child 1, value: large_string |
| location: struct<lat: double, lon: double> |
| child 0, lat: double |
| child 1, lon: double""" |
| ) |
| |
| |
| def test_read_list(schema_list: Schema, file_list: str) -> None: |
| result_table = project(schema_list, [file_list]) |
| |
| assert len(result_table.columns[0]) == 3 |
| for actual, expected in zip(result_table.columns[0], [list(range(1, 10)), list(range(2, 20)), list(range(3, 30))]): |
| assert actual.as_py() == expected |
| |
| assert ( |
| repr(result_table.schema) |
| == """ids: large_list<element: int32> |
| child 0, element: int32""" |
| ) |
| |
| |
| def test_read_map(schema_map: Schema, file_map: str) -> None: |
| result_table = project(schema_map, [file_map]) |
| |
| assert len(result_table.columns[0]) == 3 |
| for actual, expected in zip(result_table.columns[0], [[("a", "b")], [("c", "d")], [("e", "f"), ("g", "h")]]): |
| assert actual.as_py() == expected |
| |
| assert ( |
| repr(result_table.schema) |
| == """properties: map<string, string> |
| child 0, entries: struct<key: string not null, value: string not null> not null |
| child 0, key: string not null |
| child 1, value: string not null""" |
| ) |
| |
| |
| def test_projection_add_column_struct(schema_int: Schema, file_int: str) -> None: |
| schema = Schema( |
| # A new ID |
| NestedField( |
| 2, |
| "id", |
| MapType(key_id=3, key_type=IntegerType(), value_id=4, value_type=StringType(), value_required=False), |
| required=False, |
| ) |
| ) |
| result_table = project(schema, [file_int]) |
| # Everything should be None |
| for r in result_table.columns[0]: |
| assert r.as_py() is None |
| assert ( |
| repr(result_table.schema) |
| == """id: map<int32, large_string> |
| child 0, entries: struct<key: int32 not null, value: large_string> not null |
| child 0, key: int32 not null |
| child 1, value: large_string""" |
| ) |
| |
| |
| def test_projection_add_column_struct_required(file_int: str) -> None: |
| schema = Schema( |
| # A new ID |
| NestedField( |
| 2, |
| "other_id", |
| IntegerType(), |
| required=True, |
| ) |
| ) |
| with pytest.raises(ResolveError) as exc_info: |
| _ = project(schema, [file_int]) |
| assert "Field is required, and could not be found in the file: 2: other_id: required int" in str(exc_info.value) |
| |
| |
| def test_projection_rename_column(schema_int: Schema, file_int: str) -> None: |
| schema = Schema( |
| # Reuses the id 1 |
| NestedField(1, "other_name", IntegerType(), required=True) |
| ) |
| result_table = project(schema, [file_int]) |
| assert len(result_table.columns[0]) == 3 |
| for actual, expected in zip(result_table.columns[0], [0, 1, 2]): |
| assert actual.as_py() == expected |
| |
| assert repr(result_table.schema) == "other_name: int32 not null" |
| |
| |
| def test_projection_concat_files(schema_int: Schema, file_int: str) -> None: |
| result_table = project(schema_int, [file_int, file_int]) |
| |
| for actual, expected in zip(result_table.columns[0], [0, 1, 2, 0, 1, 2]): |
| assert actual.as_py() == expected |
| assert len(result_table.columns[0]) == 6 |
| assert repr(result_table.schema) == "id: int32" |
| |
| |
| def test_identity_transform_column_projection(tmp_path: str, catalog: InMemoryCatalog) -> None: |
| # Test by adding a non-partitioned data file to a partitioned table, verifying partition value projection from manifest metadata. |
| # TODO: Update to use a data file created by writing data to an unpartitioned table once add_files supports field IDs. |
| # (context: https://github.com/apache/iceberg-python/pull/1443#discussion_r1901374875) |
| |
| schema = Schema( |
| NestedField(1, "other_field", StringType(), required=False), NestedField(2, "partition_id", IntegerType(), required=False) |
| ) |
| |
| partition_spec = PartitionSpec( |
| PartitionField(2, 1000, IdentityTransform(), "partition_id"), |
| ) |
| |
| catalog.create_namespace("default") |
| table = catalog.create_table( |
| "default.test_projection_partition", |
| schema=schema, |
| partition_spec=partition_spec, |
| properties={TableProperties.DEFAULT_NAME_MAPPING: create_mapping_from_schema(schema).model_dump_json()}, |
| ) |
| |
| file_data = pa.array(["foo", "bar", "baz"], type=pa.string()) |
| file_loc = f"{tmp_path}/test.parquet" |
| pq.write_table(pa.table([file_data], names=["other_field"]), file_loc) |
| |
| statistics = data_file_statistics_from_parquet_metadata( |
| parquet_metadata=pq.read_metadata(file_loc), |
| stats_columns=compute_statistics_plan(table.schema(), table.metadata.properties), |
| parquet_column_mapping=parquet_path_to_id_mapping(table.schema()), |
| ) |
| |
| unpartitioned_file = DataFile.from_args( |
| content=DataFileContent.DATA, |
| file_path=file_loc, |
| file_format=FileFormat.PARQUET, |
| # projected value |
| partition=Record(1), |
| file_size_in_bytes=os.path.getsize(file_loc), |
| sort_order_id=None, |
| spec_id=table.metadata.default_spec_id, |
| equality_ids=None, |
| key_metadata=None, |
| **statistics.to_serialized_dict(), |
| ) |
| |
| with table.transaction() as transaction: |
| with transaction.update_snapshot().overwrite() as update: |
| update.append_data_file(unpartitioned_file) |
| |
| schema = pa.schema([("other_field", pa.string()), ("partition_id", pa.int64())]) |
| assert table.scan().to_arrow() == pa.table( |
| { |
| "other_field": ["foo", "bar", "baz"], |
| "partition_id": [1, 1, 1], |
| }, |
| schema=schema, |
| ) |
| |
| |
| def test_identity_transform_columns_projection(tmp_path: str, catalog: InMemoryCatalog) -> None: |
| # Test by adding a non-partitioned data file to a multi-partitioned table, verifying partition value projection from manifest metadata. |
| # TODO: Update to use a data file created by writing data to an unpartitioned table once add_files supports field IDs. |
| # (context: https://github.com/apache/iceberg-python/pull/1443#discussion_r1901374875) |
| schema = Schema( |
| NestedField(1, "field_1", StringType(), required=False), |
| NestedField(2, "field_2", IntegerType(), required=False), |
| NestedField(3, "field_3", IntegerType(), required=False), |
| ) |
| |
| partition_spec = PartitionSpec( |
| PartitionField(2, 1000, IdentityTransform(), "field_2"), |
| PartitionField(3, 1001, IdentityTransform(), "field_3"), |
| ) |
| |
| catalog.create_namespace("default") |
| table = catalog.create_table( |
| "default.test_projection_partitions", |
| schema=schema, |
| partition_spec=partition_spec, |
| properties={TableProperties.DEFAULT_NAME_MAPPING: create_mapping_from_schema(schema).model_dump_json()}, |
| ) |
| |
| file_data = pa.array(["foo"], type=pa.string()) |
| file_loc = f"{tmp_path}/test.parquet" |
| pq.write_table(pa.table([file_data], names=["field_1"]), file_loc) |
| |
| statistics = data_file_statistics_from_parquet_metadata( |
| parquet_metadata=pq.read_metadata(file_loc), |
| stats_columns=compute_statistics_plan(table.schema(), table.metadata.properties), |
| parquet_column_mapping=parquet_path_to_id_mapping(table.schema()), |
| ) |
| |
| unpartitioned_file = DataFile.from_args( |
| content=DataFileContent.DATA, |
| file_path=file_loc, |
| file_format=FileFormat.PARQUET, |
| # projected value |
| partition=Record(2, 3), |
| file_size_in_bytes=os.path.getsize(file_loc), |
| sort_order_id=None, |
| spec_id=table.metadata.default_spec_id, |
| equality_ids=None, |
| key_metadata=None, |
| **statistics.to_serialized_dict(), |
| ) |
| |
| with table.transaction() as transaction: |
| with transaction.update_snapshot().overwrite() as update: |
| update.append_data_file(unpartitioned_file) |
| |
| assert ( |
| str(table.scan().to_arrow()) |
| == """pyarrow.Table |
| field_1: string |
| field_2: int64 |
| field_3: int64 |
| ---- |
| field_1: [["foo"]] |
| field_2: [[2]] |
| field_3: [[3]]""" |
| ) |
| |
| |
| @pytest.fixture |
| def catalog() -> InMemoryCatalog: |
| return InMemoryCatalog("test.in_memory.catalog", **{"test.key": "test.value"}) |
| |
| |
| def test_projection_filter(schema_int: Schema, file_int: str) -> None: |
| result_table = project(schema_int, [file_int], GreaterThan("id", 4)) |
| assert len(result_table.columns[0]) == 0 |
| assert repr(result_table.schema) == """id: int32""" |
| |
| |
| def test_projection_filter_renamed_column(file_int: str) -> None: |
| schema = Schema( |
| # Reuses the id 1 |
| NestedField(1, "other_id", IntegerType(), required=True) |
| ) |
| result_table = project(schema, [file_int], GreaterThan("other_id", 1)) |
| assert len(result_table.columns[0]) == 1 |
| assert repr(result_table.schema) == "other_id: int32 not null" |
| |
| |
| def test_projection_filter_add_column(schema_int: Schema, file_int: str, file_string: str) -> None: |
| """We have one file that has the column, and the other one doesn't""" |
| result_table = project(schema_int, [file_int, file_string]) |
| |
| for actual, expected in zip(result_table.columns[0], [0, 1, 2, None, None, None]): |
| assert actual.as_py() == expected |
| assert len(result_table.columns[0]) == 6 |
| assert repr(result_table.schema) == "id: int32" |
| |
| |
| def test_projection_filter_add_column_promote(file_int: str) -> None: |
| schema_long = Schema(NestedField(1, "id", LongType(), required=True)) |
| result_table = project(schema_long, [file_int]) |
| |
| for actual, expected in zip(result_table.columns[0], [0, 1, 2]): |
| assert actual.as_py() == expected |
| assert len(result_table.columns[0]) == 3 |
| assert repr(result_table.schema) == "id: int64 not null" |
| |
| |
| def test_projection_filter_add_column_demote(file_long: str) -> None: |
| schema_int = Schema(NestedField(3, "id", IntegerType())) |
| with pytest.raises(ResolveError) as exc_info: |
| _ = project(schema_int, [file_long]) |
| assert "Cannot promote long to int" in str(exc_info.value) |
| |
| |
| def test_projection_nested_struct_subset(file_struct: str) -> None: |
| schema = Schema( |
| NestedField( |
| 4, |
| "location", |
| StructType( |
| NestedField(41, "lat", DoubleType(), required=True), |
| # long is missing! |
| ), |
| required=True, |
| ) |
| ) |
| |
| result_table = project(schema, [file_struct]) |
| |
| for actual, expected in zip(result_table.columns[0], [52.371807, 52.387386, 52.078663]): |
| assert actual.as_py() == {"lat": expected} |
| |
| assert len(result_table.columns[0]) == 3 |
| assert ( |
| repr(result_table.schema) |
| == """location: struct<lat: double not null> not null |
| child 0, lat: double not null""" |
| ) |
| |
| |
| def test_projection_nested_new_field(file_struct: str) -> None: |
| schema = Schema( |
| NestedField( |
| 4, |
| "location", |
| StructType( |
| NestedField(43, "null", DoubleType(), required=False), # Whoa, this column doesn't exist in the file |
| ), |
| required=True, |
| ) |
| ) |
| |
| result_table = project(schema, [file_struct]) |
| |
| for actual, expected in zip(result_table.columns[0], [None, None, None]): |
| assert actual.as_py() == {"null": expected} |
| assert len(result_table.columns[0]) == 3 |
| assert ( |
| repr(result_table.schema) |
| == """location: struct<null: double> not null |
| child 0, null: double""" |
| ) |
| |
| |
| def test_projection_nested_struct(schema_struct: Schema, file_struct: str) -> None: |
| schema = Schema( |
| NestedField( |
| 4, |
| "location", |
| StructType( |
| NestedField(41, "lat", DoubleType(), required=False), |
| NestedField(43, "null", DoubleType(), required=False), |
| NestedField(42, "long", DoubleType(), required=False), |
| ), |
| required=True, |
| ) |
| ) |
| |
| result_table = project(schema, [file_struct]) |
| for actual, expected in zip( |
| result_table.columns[0], |
| [ |
| {"lat": 52.371807, "long": 4.896029, "null": None}, |
| {"lat": 52.387386, "long": 4.646219, "null": None}, |
| {"lat": 52.078663, "long": 4.288788, "null": None}, |
| ], |
| ): |
| assert actual.as_py() == expected |
| assert len(result_table.columns[0]) == 3 |
| assert ( |
| repr(result_table.schema) |
| == """location: struct<lat: double, null: double, long: double> not null |
| child 0, lat: double |
| child 1, null: double |
| child 2, long: double""" |
| ) |
| |
| |
| def test_projection_list_of_structs(schema_list_of_structs: Schema, file_list_of_structs: str) -> None: |
| schema = Schema( |
| NestedField( |
| 5, |
| "locations", |
| ListType( |
| 51, |
| StructType( |
| NestedField(511, "latitude", DoubleType(), required=True), |
| NestedField(512, "longitude", DoubleType(), required=True), |
| NestedField(513, "altitude", DoubleType(), required=False), |
| ), |
| element_required=False, |
| ), |
| required=False, |
| ), |
| ) |
| |
| result_table = project(schema, [file_list_of_structs]) |
| assert len(result_table.columns) == 1 |
| assert len(result_table.columns[0]) == 3 |
| results = [row.as_py() for row in result_table.columns[0]] |
| assert results == [ |
| [ |
| {"latitude": 52.371807, "longitude": 4.896029, "altitude": None}, |
| {"latitude": 52.387386, "longitude": 4.646219, "altitude": None}, |
| ], |
| [], |
| [ |
| {"latitude": 52.078663, "longitude": 4.288788, "altitude": None}, |
| {"latitude": 52.387386, "longitude": 4.646219, "altitude": None}, |
| ], |
| ] |
| assert ( |
| repr(result_table.schema) |
| == """locations: large_list<element: struct<latitude: double not null, longitude: double not null, altitude: double>> |
| child 0, element: struct<latitude: double not null, longitude: double not null, altitude: double> |
| child 0, latitude: double not null |
| child 1, longitude: double not null |
| child 2, altitude: double""" |
| ) |
| |
| |
| def test_projection_maps_of_structs(schema_map_of_structs: Schema, file_map_of_structs: str) -> None: |
| schema = Schema( |
| NestedField( |
| 5, |
| "locations", |
| MapType( |
| key_id=51, |
| value_id=52, |
| key_type=StringType(), |
| value_type=StructType( |
| NestedField(511, "latitude", DoubleType(), required=True), |
| NestedField(512, "longitude", DoubleType(), required=True), |
| NestedField(513, "altitude", DoubleType()), |
| ), |
| element_required=False, |
| ), |
| required=False, |
| ), |
| ) |
| |
| result_table = project(schema, [file_map_of_structs]) |
| assert len(result_table.columns) == 1 |
| assert len(result_table.columns[0]) == 3 |
| for actual, expected in zip( |
| result_table.columns[0], |
| [ |
| [ |
| ("1", {"latitude": 52.371807, "longitude": 4.896029, "altitude": None}), |
| ("2", {"latitude": 52.387386, "longitude": 4.646219, "altitude": None}), |
| ], |
| [], |
| [ |
| ("3", {"latitude": 52.078663, "longitude": 4.288788, "altitude": None}), |
| ("4", {"latitude": 52.387386, "longitude": 4.646219, "altitude": None}), |
| ], |
| ], |
| ): |
| assert actual.as_py() == expected |
| assert ( |
| repr(result_table.schema) |
| == """locations: map<string, struct<latitude: double not null, longitude: double not null, altitude: double>> |
| child 0, entries: struct<key: string not null, value: struct<latitude: double not null, longitude: double not null, altitude: double> not null> not null |
| child 0, key: string not null |
| child 1, value: struct<latitude: double not null, longitude: double not null, altitude: double> not null |
| child 0, latitude: double not null |
| child 1, longitude: double not null |
| child 2, altitude: double""" |
| ) |
| |
| |
| def test_projection_nested_struct_different_parent_id(file_struct: str) -> None: |
| schema = Schema( |
| NestedField( |
| 5, # 😱 this is 4 in the file, this will be fixed when projecting the file schema |
| "location", |
| StructType( |
| NestedField(41, "lat", DoubleType(), required=False), NestedField(42, "long", DoubleType(), required=False) |
| ), |
| required=False, |
| ) |
| ) |
| |
| result_table = project(schema, [file_struct]) |
| for actual, expected in zip(result_table.columns[0], [None, None, None]): |
| assert actual.as_py() == expected |
| assert len(result_table.columns[0]) == 3 |
| assert ( |
| repr(result_table.schema) |
| == """location: struct<lat: double, long: double> |
| child 0, lat: double |
| child 1, long: double""" |
| ) |
| |
| |
| def test_projection_filter_on_unprojected_field(schema_int_str: Schema, file_int_str: str) -> None: |
| schema = Schema(NestedField(1, "id", IntegerType(), required=True)) |
| |
| result_table = project(schema, [file_int_str], GreaterThan("data", "1"), schema_int_str) |
| |
| for actual, expected in zip( |
| result_table.columns[0], |
| [2], |
| ): |
| assert actual.as_py() == expected |
| assert len(result_table.columns[0]) == 1 |
| assert repr(result_table.schema) == "id: int32 not null" |
| |
| |
| def test_projection_filter_on_unknown_field(schema_int_str: Schema, file_int_str: str) -> None: |
| schema = Schema(NestedField(1, "id", IntegerType())) |
| |
| with pytest.raises(ValueError) as exc_info: |
| _ = project(schema, [file_int_str], GreaterThan("unknown_field", "1"), schema_int_str) |
| |
| assert "Could not find field with name unknown_field, case_sensitive=True" in str(exc_info.value) |
| |
| |
| @pytest.fixture |
| def deletes_file(tmp_path: str, example_task: FileScanTask) -> str: |
| path = example_task.file.file_path |
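    # Position delete files list (data file path, row position) pairs; here rows 1, 3, and 5
    # of the example data file are marked as deleted.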
| table = pa.table({"file_path": [path, path, path], "pos": [1, 3, 5]}) |
| |
| deletes_file_path = f"{tmp_path}/deletes.parquet" |
| pq.write_table(table, deletes_file_path) |
| |
| return deletes_file_path |
| |
| |
| def test_read_deletes(deletes_file: str, example_task: FileScanTask) -> None: |
| deletes = _read_deletes(PyArrowFileIO(), DataFile.from_args(file_path=deletes_file, file_format=FileFormat.PARQUET)) |
| assert set(deletes.keys()) == {example_task.file.file_path} |
| assert list(deletes.values())[0] == pa.chunked_array([[1, 3, 5]]) |
| |
| |
| def test_delete(deletes_file: str, example_task: FileScanTask, table_schema_simple: Schema) -> None: |
| metadata_location = "file://a/b/c.json" |
| example_task_with_delete = FileScanTask( |
| data_file=example_task.file, |
| delete_files={ |
| DataFile.from_args(content=DataFileContent.POSITION_DELETES, file_path=deletes_file, file_format=FileFormat.PARQUET) |
| }, |
| ) |
| with_deletes = ArrowScan( |
| table_metadata=TableMetadataV2( |
| location=metadata_location, |
| last_column_id=1, |
| format_version=2, |
| current_schema_id=1, |
| schemas=[table_schema_simple], |
| partition_specs=[PartitionSpec()], |
| ), |
| io=load_file_io(), |
| projected_schema=table_schema_simple, |
| row_filter=AlwaysTrue(), |
| ).to_table(tasks=[example_task_with_delete]) |
| |
| assert ( |
| str(with_deletes) |
| == """pyarrow.Table |
| foo: large_string |
| bar: int32 not null |
| baz: bool |
| ---- |
| foo: [["a","c"]] |
| bar: [[1,3]] |
| baz: [[true,null]]""" |
| ) |
| |
| |
| def test_delete_duplicates(deletes_file: str, example_task: FileScanTask, table_schema_simple: Schema) -> None: |
| metadata_location = "file://a/b/c.json" |
| example_task_with_delete = FileScanTask( |
| data_file=example_task.file, |
| delete_files={ |
| DataFile.from_args(content=DataFileContent.POSITION_DELETES, file_path=deletes_file, file_format=FileFormat.PARQUET), |
| DataFile.from_args(content=DataFileContent.POSITION_DELETES, file_path=deletes_file, file_format=FileFormat.PARQUET), |
| }, |
| ) |
| |
| with_deletes = ArrowScan( |
| table_metadata=TableMetadataV2( |
| location=metadata_location, |
| last_column_id=1, |
| format_version=2, |
| current_schema_id=1, |
| schemas=[table_schema_simple], |
| partition_specs=[PartitionSpec()], |
| ), |
| io=load_file_io(), |
| projected_schema=table_schema_simple, |
| row_filter=AlwaysTrue(), |
| ).to_table(tasks=[example_task_with_delete]) |
| |
| assert ( |
| str(with_deletes) |
| == """pyarrow.Table |
| foo: large_string |
| bar: int32 not null |
| baz: bool |
| ---- |
| foo: [["a","c"]] |
| bar: [[1,3]] |
| baz: [[true,null]]""" |
| ) |
| |
| |
| def test_pyarrow_wrap_fsspec(example_task: FileScanTask, table_schema_simple: Schema) -> None: |
| metadata_location = "file://a/b/c.json" |
| |
| projection = ArrowScan( |
| table_metadata=TableMetadataV2( |
| location=metadata_location, |
| last_column_id=1, |
| format_version=2, |
| current_schema_id=1, |
| schemas=[table_schema_simple], |
| partition_specs=[PartitionSpec()], |
| ), |
| io=load_file_io(properties={"py-io-impl": "pyiceberg.io.fsspec.FsspecFileIO"}, location=metadata_location), |
| case_sensitive=True, |
| projected_schema=table_schema_simple, |
| row_filter=AlwaysTrue(), |
| ).to_table(tasks=[example_task]) |
| |
| assert ( |
| str(projection) |
| == """pyarrow.Table |
| foo: large_string |
| bar: int32 not null |
| baz: bool |
| ---- |
| foo: [["a","b","c"]] |
| bar: [[1,2,3]] |
| baz: [[true,false,null]]""" |
| ) |
| |
| |
| @pytest.mark.gcs |
| def test_new_input_file_gcs(pyarrow_fileio_gcs: PyArrowFileIO) -> None: |
| """Test creating a new input file from a fsspec file-io""" |
| filename = str(uuid4()) |
| |
| input_file = pyarrow_fileio_gcs.new_input(f"gs://warehouse/{filename}") |
| |
| assert isinstance(input_file, PyArrowFile) |
| assert input_file.location == f"gs://warehouse/{filename}" |
| |
| |
| @pytest.mark.gcs |
| def test_new_output_file_gcs(pyarrow_fileio_gcs: PyArrowFileIO) -> None: |
| """Test creating a new output file from an fsspec file-io""" |
| filename = str(uuid4()) |
| |
| output_file = pyarrow_fileio_gcs.new_output(f"gs://warehouse/{filename}") |
| |
| assert isinstance(output_file, PyArrowFile) |
| assert output_file.location == f"gs://warehouse/{filename}" |
| |
| |
| @pytest.mark.gcs |
| @pytest.mark.skip(reason="Open issue on Arrow: https://github.com/apache/arrow/issues/36993") |
| def test_write_and_read_file_gcs(pyarrow_fileio_gcs: PyArrowFileIO) -> None: |
| """Test writing and reading a file using PyArrowFile""" |
| location = f"gs://warehouse/{uuid4()}.txt" |
| output_file = pyarrow_fileio_gcs.new_output(location=location) |
| with output_file.create() as f: |
| assert f.write(b"foo") == 3 |
| |
| assert output_file.exists() |
| |
| input_file = pyarrow_fileio_gcs.new_input(location=location) |
| with input_file.open() as f: |
| assert f.read() == b"foo" |
| |
| pyarrow_fileio_gcs.delete(input_file) |
| |
| |
| @pytest.mark.gcs |
| def test_getting_length_of_file_gcs(pyarrow_fileio_gcs: PyArrowFileIO) -> None: |
| """Test getting the length of PyArrowFile""" |
| filename = str(uuid4()) |
| |
| output_file = pyarrow_fileio_gcs.new_output(location=f"gs://warehouse/{filename}") |
| with output_file.create() as f: |
| f.write(b"foobar") |
| |
| assert len(output_file) == 6 |
| |
| input_file = pyarrow_fileio_gcs.new_input(location=f"gs://warehouse/{filename}") |
| assert len(input_file) == 6 |
| |
| pyarrow_fileio_gcs.delete(output_file) |
| |
| |
| @pytest.mark.gcs |
| @pytest.mark.skip(reason="Open issue on Arrow: https://github.com/apache/arrow/issues/36993") |
| def test_file_tell_gcs(pyarrow_fileio_gcs: PyArrowFileIO) -> None: |
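    """Test seek and tell on a PyArrowFile backed by GCS."""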
| location = f"gs://warehouse/{uuid4()}" |
| |
| output_file = pyarrow_fileio_gcs.new_output(location=location) |
| with output_file.create() as write_file: |
| write_file.write(b"foobar") |
| |
| input_file = pyarrow_fileio_gcs.new_input(location=location) |
| with input_file.open() as f: |
| f.seek(0) |
| assert f.tell() == 0 |
| f.seek(1) |
| assert f.tell() == 1 |
| f.seek(3) |
| assert f.tell() == 3 |
| f.seek(0) |
| assert f.tell() == 0 |
| |
| |
| @pytest.mark.gcs |
| @pytest.mark.skip(reason="Open issue on Arrow: https://github.com/apache/arrow/issues/36993") |
| def test_read_specified_bytes_for_file_gcs(pyarrow_fileio_gcs: PyArrowFileIO) -> None: |
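    """Test reading specific byte ranges from a PyArrowFile backed by GCS."""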
| location = f"gs://warehouse/{uuid4()}" |
| |
| output_file = pyarrow_fileio_gcs.new_output(location=location) |
| with output_file.create() as write_file: |
| write_file.write(b"foo") |
| |
| input_file = pyarrow_fileio_gcs.new_input(location=location) |
| with input_file.open() as f: |
| f.seek(0) |
| assert b"f" == f.read(1) |
| f.seek(0) |
| assert b"fo" == f.read(2) |
| f.seek(1) |
| assert b"o" == f.read(1) |
| f.seek(1) |
| assert b"oo" == f.read(2) |
| f.seek(0) |
| assert b"foo" == f.read(999) # test reading amount larger than entire content length |
| |
| pyarrow_fileio_gcs.delete(input_file) |
| |
| |
| @pytest.mark.gcs |
| @pytest.mark.skip(reason="Open issue on Arrow: https://github.com/apache/arrow/issues/36993") |
| def test_raise_on_opening_file_not_found_gcs(pyarrow_fileio_gcs: PyArrowFileIO) -> None: |
| """Test that PyArrowFile raises appropriately when the gcs file is not found""" |
| |
| filename = str(uuid4()) |
| input_file = pyarrow_fileio_gcs.new_input(location=f"gs://warehouse/{filename}") |
| with pytest.raises(FileNotFoundError) as exc_info: |
| input_file.open().read() |
| |
| assert filename in str(exc_info.value) |
| |
| |
| @pytest.mark.gcs |
| def test_checking_if_a_file_exists_gcs(pyarrow_fileio_gcs: PyArrowFileIO) -> None: |
| """Test checking if a file exists""" |
| non_existent_file = pyarrow_fileio_gcs.new_input(location="gs://warehouse/does-not-exist.txt") |
| assert not non_existent_file.exists() |
| |
| location = f"gs://warehouse/{uuid4()}" |
| output_file = pyarrow_fileio_gcs.new_output(location=location) |
| assert not output_file.exists() |
| with output_file.create() as f: |
| f.write(b"foo") |
| |
| existing_input_file = pyarrow_fileio_gcs.new_input(location=location) |
| assert existing_input_file.exists() |
| |
| existing_output_file = pyarrow_fileio_gcs.new_output(location=location) |
| assert existing_output_file.exists() |
| |
| pyarrow_fileio_gcs.delete(existing_output_file) |
| |
| |
| @pytest.mark.gcs |
| @pytest.mark.skip(reason="Open issue on Arrow: https://github.com/apache/arrow/issues/36993") |
| def test_closing_a_file_gcs(pyarrow_fileio_gcs: PyArrowFileIO) -> None: |
| """Test closing an output file and input file""" |
| filename = str(uuid4()) |
| output_file = pyarrow_fileio_gcs.new_output(location=f"gs://warehouse/{filename}") |
| with output_file.create() as write_file: |
| write_file.write(b"foo") |
| assert not write_file.closed # type: ignore |
| assert write_file.closed # type: ignore |
| |
| input_file = pyarrow_fileio_gcs.new_input(location=f"gs://warehouse/{filename}") |
| with input_file.open() as f: |
| assert not f.closed # type: ignore |
| assert f.closed # type: ignore |
| |
| pyarrow_fileio_gcs.delete(f"gs://warehouse/{filename}") |
| |
| |
| @pytest.mark.gcs |
| def test_converting_an_outputfile_to_an_inputfile_gcs(pyarrow_fileio_gcs: PyArrowFileIO) -> None: |
| """Test converting an output file to an input file""" |
| filename = str(uuid4()) |
| output_file = pyarrow_fileio_gcs.new_output(location=f"gs://warehouse/{filename}") |
| input_file = output_file.to_input_file() |
| assert input_file.location == output_file.location |
| |
| |
| @pytest.mark.gcs |
| @pytest.mark.skip(reason="Open issue on Arrow: https://github.com/apache/arrow/issues/36993") |
| def test_writing_avro_file_gcs(generated_manifest_entry_file: str, pyarrow_fileio_gcs: PyArrowFileIO) -> None: |
| """Test that bytes match when reading a local avro file, writing it using pyarrow file-io, and then reading it again""" |
| filename = str(uuid4()) |
| with PyArrowFileIO().new_input(location=generated_manifest_entry_file).open() as f: |
| b1 = f.read() |
| with pyarrow_fileio_gcs.new_output(location=f"gs://warehouse/{filename}").create() as out_f: |
| out_f.write(b1) |
| with pyarrow_fileio_gcs.new_input(location=f"gs://warehouse/{filename}").open() as in_f: |
| b2 = in_f.read() |
            assert b1 == b2  # Check that the bytes read from the local avro file match the bytes written to GCS
| |
| pyarrow_fileio_gcs.delete(f"gs://warehouse/{filename}") |
| |
| |
| @pytest.mark.adls |
| @skip_if_pyarrow_too_old |
| def test_new_input_file_adls(pyarrow_fileio_adls: PyArrowFileIO, adls_scheme: str) -> None: |
| """Test creating a new input file from pyarrow file-io""" |
| filename = str(uuid4()) |
| |
| input_file = pyarrow_fileio_adls.new_input(f"{adls_scheme}://warehouse/{filename}") |
| |
| assert isinstance(input_file, PyArrowFile) |
| assert input_file.location == f"{adls_scheme}://warehouse/{filename}" |
| |
| |
| @pytest.mark.adls |
| @skip_if_pyarrow_too_old |
| def test_new_output_file_adls(pyarrow_fileio_adls: PyArrowFileIO, adls_scheme: str) -> None: |
| """Test creating a new output file from pyarrow file-io""" |
| filename = str(uuid4()) |
| |
| output_file = pyarrow_fileio_adls.new_output(f"{adls_scheme}://warehouse/{filename}") |
| |
| assert isinstance(output_file, PyArrowFile) |
| assert output_file.location == f"{adls_scheme}://warehouse/{filename}" |
| |
| |
| @pytest.mark.adls |
| @skip_if_pyarrow_too_old |
| def test_write_and_read_file_adls(pyarrow_fileio_adls: PyArrowFileIO, adls_scheme: str) -> None: |
| """Test writing and reading a file using PyArrowFile""" |
| location = f"{adls_scheme}://warehouse/{uuid4()}.txt" |
| output_file = pyarrow_fileio_adls.new_output(location=location) |
| with output_file.create() as f: |
| assert f.write(b"foo") == 3 |
| |
| assert output_file.exists() |
| |
| input_file = pyarrow_fileio_adls.new_input(location=location) |
| with input_file.open() as f: |
| assert f.read() == b"foo" |
| |
| pyarrow_fileio_adls.delete(input_file) |
| |
| |
| @pytest.mark.adls |
| @skip_if_pyarrow_too_old |
| def test_getting_length_of_file_adls(pyarrow_fileio_adls: PyArrowFileIO, adls_scheme: str) -> None: |
| """Test getting the length of PyArrowFile""" |
| filename = str(uuid4()) |
| |
| output_file = pyarrow_fileio_adls.new_output(location=f"{adls_scheme}://warehouse/{filename}") |
| with output_file.create() as f: |
| f.write(b"foobar") |
| |
| assert len(output_file) == 6 |
| |
| input_file = pyarrow_fileio_adls.new_input(location=f"{adls_scheme}://warehouse/{filename}") |
| assert len(input_file) == 6 |
| |
| pyarrow_fileio_adls.delete(output_file) |
| |
| |
| @pytest.mark.adls |
| @skip_if_pyarrow_too_old |
| def test_file_tell_adls(pyarrow_fileio_adls: PyArrowFileIO, adls_scheme: str) -> None: |
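    """Test seek and tell on a PyArrowFile backed by ADLS."""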
| location = f"{adls_scheme}://warehouse/{uuid4()}" |
| |
| output_file = pyarrow_fileio_adls.new_output(location=location) |
| with output_file.create() as write_file: |
| write_file.write(b"foobar") |
| |
| input_file = pyarrow_fileio_adls.new_input(location=location) |
| with input_file.open() as f: |
| f.seek(0) |
| assert f.tell() == 0 |
| f.seek(1) |
| assert f.tell() == 1 |
| f.seek(3) |
| assert f.tell() == 3 |
| f.seek(0) |
| assert f.tell() == 0 |
| |
| |
| @pytest.mark.adls |
| @skip_if_pyarrow_too_old |
| def test_read_specified_bytes_for_file_adls(pyarrow_fileio_adls: PyArrowFileIO) -> None: |
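    """Test reading specific byte ranges from a PyArrowFile backed by ADLS."""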
| location = f"abfss://warehouse/{uuid4()}" |
| |
| output_file = pyarrow_fileio_adls.new_output(location=location) |
| with output_file.create() as write_file: |
| write_file.write(b"foo") |
| |
| input_file = pyarrow_fileio_adls.new_input(location=location) |
| with input_file.open() as f: |
| f.seek(0) |
| assert b"f" == f.read(1) |
| f.seek(0) |
| assert b"fo" == f.read(2) |
| f.seek(1) |
| assert b"o" == f.read(1) |
| f.seek(1) |
| assert b"oo" == f.read(2) |
| f.seek(0) |
| assert b"foo" == f.read(999) # test reading amount larger than entire content length |
| |
| pyarrow_fileio_adls.delete(input_file) |
| |
| |
| @pytest.mark.adls |
| @skip_if_pyarrow_too_old |
| def test_raise_on_opening_file_not_found_adls(pyarrow_fileio_adls: PyArrowFileIO, adls_scheme: str) -> None: |
| """Test that PyArrowFile raises appropriately when the adls file is not found""" |
| |
| filename = str(uuid4()) |
| input_file = pyarrow_fileio_adls.new_input(location=f"{adls_scheme}://warehouse/{filename}") |
| with pytest.raises(FileNotFoundError) as exc_info: |
| input_file.open().read() |
| |
| assert filename in str(exc_info.value) |
| |
| |
| @pytest.mark.adls |
| @skip_if_pyarrow_too_old |
| def test_checking_if_a_file_exists_adls(pyarrow_fileio_adls: PyArrowFileIO, adls_scheme: str) -> None: |
| """Test checking if a file exists""" |
| non_existent_file = pyarrow_fileio_adls.new_input(location=f"{adls_scheme}://warehouse/does-not-exist.txt") |
| assert not non_existent_file.exists() |
| |
| location = f"{adls_scheme}://warehouse/{uuid4()}" |
| output_file = pyarrow_fileio_adls.new_output(location=location) |
| assert not output_file.exists() |
| with output_file.create() as f: |
| f.write(b"foo") |
| |
| existing_input_file = pyarrow_fileio_adls.new_input(location=location) |
| assert existing_input_file.exists() |
| |
| existing_output_file = pyarrow_fileio_adls.new_output(location=location) |
| assert existing_output_file.exists() |
| |
| pyarrow_fileio_adls.delete(existing_output_file) |
| |
| |
| @pytest.mark.adls |
| @skip_if_pyarrow_too_old |
| def test_closing_a_file_adls(pyarrow_fileio_adls: PyArrowFileIO, adls_scheme: str) -> None: |
| """Test closing an output file and input file""" |
| filename = str(uuid4()) |
| output_file = pyarrow_fileio_adls.new_output(location=f"{adls_scheme}://warehouse/{filename}") |
| with output_file.create() as write_file: |
| write_file.write(b"foo") |
| assert not write_file.closed # type: ignore |
| assert write_file.closed # type: ignore |
| |
| input_file = pyarrow_fileio_adls.new_input(location=f"{adls_scheme}://warehouse/{filename}") |
| with input_file.open() as f: |
| assert not f.closed # type: ignore |
| assert f.closed # type: ignore |
| |
| pyarrow_fileio_adls.delete(f"{adls_scheme}://warehouse/{filename}") |
| |
| |
| @pytest.mark.adls |
| @skip_if_pyarrow_too_old |
| def test_converting_an_outputfile_to_an_inputfile_adls(pyarrow_fileio_adls: PyArrowFileIO, adls_scheme: str) -> None: |
| """Test converting an output file to an input file""" |
| filename = str(uuid4()) |
| output_file = pyarrow_fileio_adls.new_output(location=f"{adls_scheme}://warehouse/{filename}") |
| input_file = output_file.to_input_file() |
| assert input_file.location == output_file.location |
| |
| |
| @pytest.mark.adls |
| @skip_if_pyarrow_too_old |
| def test_writing_avro_file_adls(generated_manifest_entry_file: str, pyarrow_fileio_adls: PyArrowFileIO, adls_scheme: str) -> None: |
| """Test that bytes match when reading a local avro file, writing it using pyarrow file-io, and then reading it again""" |
| filename = str(uuid4()) |
| with PyArrowFileIO().new_input(location=generated_manifest_entry_file).open() as f: |
| b1 = f.read() |
| with pyarrow_fileio_adls.new_output(location=f"{adls_scheme}://warehouse/{filename}").create() as out_f: |
| out_f.write(b1) |
| with pyarrow_fileio_adls.new_input(location=f"{adls_scheme}://warehouse/{filename}").open() as in_f: |
| b2 = in_f.read() |
            assert b1 == b2  # Check that the bytes read from the local avro file match the bytes written to ADLS
| |
| pyarrow_fileio_adls.delete(f"{adls_scheme}://warehouse/{filename}") |
| |
| |
| def test_parse_location() -> None: |
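    """Test splitting a location string into its scheme, netloc, and path components."""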
    def check_results(location: str, expected_scheme: str, expected_netloc: str, expected_path: str) -> None:
        scheme, netloc, path = PyArrowFileIO.parse_location(location)
        assert scheme == expected_scheme
        assert netloc == expected_netloc
        assert path == expected_path
| |
| check_results("hdfs://127.0.0.1:9000/root/foo.txt", "hdfs", "127.0.0.1:9000", "/root/foo.txt") |
| check_results("hdfs://127.0.0.1/root/foo.txt", "hdfs", "127.0.0.1", "/root/foo.txt") |
| check_results("hdfs://clusterA/root/foo.txt", "hdfs", "clusterA", "/root/foo.txt") |
| |
| check_results("/root/foo.txt", "file", "", "/root/foo.txt") |
| check_results("/root/tmp/foo.txt", "file", "", "/root/tmp/foo.txt") |
| |
| |
| def test_make_compatible_name() -> None: |
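    """Test that unsupported characters in a field name are replaced with their hex escape."""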
| assert make_compatible_name("label/abc") == "label_x2Fabc" |
| assert make_compatible_name("label?abc") == "label_x3Fabc" |
| |
| |
| @pytest.mark.parametrize( |
| "vals, primitive_type, expected_result", |
| [ |
| ([None, 2, 1], IntegerType(), 1), |
| ([1, None, 2], IntegerType(), 1), |
| ([None, None, None], IntegerType(), None), |
| ([None, date(2024, 2, 4), date(2024, 1, 2)], DateType(), date(2024, 1, 2)), |
| ([date(2024, 1, 2), None, date(2024, 2, 4)], DateType(), date(2024, 1, 2)), |
| ([None, None, None], DateType(), None), |
| ], |
| ) |
| def test_stats_aggregator_update_min(vals: List[Any], primitive_type: PrimitiveType, expected_result: Any) -> None: |
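    """Test that StatsAggregator keeps track of the minimum value and ignores None."""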
| stats = StatsAggregator(primitive_type, _primitive_to_physical(primitive_type)) |
| |
| for val in vals: |
| stats.update_min(val) |
| |
| assert stats.current_min == expected_result |
| |
| |
| @pytest.mark.parametrize( |
| "vals, primitive_type, expected_result", |
| [ |
| ([None, 2, 1], IntegerType(), 2), |
| ([1, None, 2], IntegerType(), 2), |
| ([None, None, None], IntegerType(), None), |
| ([None, date(2024, 2, 4), date(2024, 1, 2)], DateType(), date(2024, 2, 4)), |
| ([date(2024, 1, 2), None, date(2024, 2, 4)], DateType(), date(2024, 2, 4)), |
| ([None, None, None], DateType(), None), |
| ], |
| ) |
| def test_stats_aggregator_update_max(vals: List[Any], primitive_type: PrimitiveType, expected_result: Any) -> None: |
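    """Test that StatsAggregator keeps track of the maximum value and ignores None."""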
| stats = StatsAggregator(primitive_type, _primitive_to_physical(primitive_type)) |
| |
| for val in vals: |
| stats.update_max(val) |
| |
| assert stats.current_max == expected_result |
| |
| |
| def test_bin_pack_arrow_table(arrow_table_with_null: pa.Table) -> None: |
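    """Test that bin_pack_arrow_table splits an Arrow table into bins according to the target file size."""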
| # default packs to 1 bin since the table is small |
| bin_packed = bin_pack_arrow_table( |
| arrow_table_with_null, target_file_size=TableProperties.WRITE_TARGET_FILE_SIZE_BYTES_DEFAULT |
| ) |
| assert len(list(bin_packed)) == 1 |
| |
    # as long as the table is smaller than the default target size, it should pack into 1 bin
| bigger_arrow_tbl = pa.concat_tables([arrow_table_with_null] * 10) |
| assert bigger_arrow_tbl.nbytes < TableProperties.WRITE_TARGET_FILE_SIZE_BYTES_DEFAULT |
| bin_packed = bin_pack_arrow_table(bigger_arrow_tbl, target_file_size=TableProperties.WRITE_TARGET_FILE_SIZE_BYTES_DEFAULT) |
| assert len(list(bin_packed)) == 1 |
| |
| # unless we override the target size to be smaller |
| bin_packed = bin_pack_arrow_table(bigger_arrow_tbl, target_file_size=arrow_table_with_null.nbytes) |
| assert len(list(bin_packed)) == 10 |
| |
| # and will produce half the number of files if we double the target size |
| bin_packed = bin_pack_arrow_table(bigger_arrow_tbl, target_file_size=arrow_table_with_null.nbytes * 2) |
| assert len(list(bin_packed)) == 5 |
| |
| |
| def test_schema_mismatch_type(table_schema_simple: Schema) -> None: |
| other_schema = pa.schema( |
| ( |
| pa.field("foo", pa.string(), nullable=True), |
| pa.field("bar", pa.decimal128(18, 6), nullable=False), |
| pa.field("baz", pa.bool_(), nullable=True), |
| ) |
| ) |
| |
| expected = r"""Mismatch in fields: |
| ┏━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┓ |
| ┃ ┃ Table field ┃ Dataframe field ┃ |
| ┡━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┩ |
| │ ✅ │ 1: foo: optional string │ 1: foo: optional string │ |
| │ ❌ │ 2: bar: required int │ 2: bar: required decimal\(18, 6\) │ |
| │ ✅ │ 3: baz: optional boolean │ 3: baz: optional boolean │ |
| └────┴──────────────────────────┴─────────────────────────────────┘ |
| """ |
| |
| with pytest.raises(ValueError, match=expected): |
| _check_pyarrow_schema_compatible(table_schema_simple, other_schema) |
| |
| |
| def test_schema_mismatch_nullability(table_schema_simple: Schema) -> None: |
| other_schema = pa.schema( |
| ( |
| pa.field("foo", pa.string(), nullable=True), |
| pa.field("bar", pa.int32(), nullable=True), |
| pa.field("baz", pa.bool_(), nullable=True), |
| ) |
| ) |
| |
| expected = """Mismatch in fields: |
| ┏━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━┓ |
| ┃ ┃ Table field ┃ Dataframe field ┃ |
| ┡━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━┩ |
| │ ✅ │ 1: foo: optional string │ 1: foo: optional string │ |
| │ ❌ │ 2: bar: required int │ 2: bar: optional int │ |
| │ ✅ │ 3: baz: optional boolean │ 3: baz: optional boolean │ |
| └────┴──────────────────────────┴──────────────────────────┘ |
| """ |
| |
| with pytest.raises(ValueError, match=expected): |
| _check_pyarrow_schema_compatible(table_schema_simple, other_schema) |
| |
| |
| def test_schema_compatible_nullability_diff(table_schema_simple: Schema) -> None: |
| other_schema = pa.schema( |
| ( |
| pa.field("foo", pa.string(), nullable=True), |
| pa.field("bar", pa.int32(), nullable=False), |
| pa.field("baz", pa.bool_(), nullable=False), |
| ) |
| ) |
| |
| try: |
| _check_pyarrow_schema_compatible(table_schema_simple, other_schema) |
| except Exception: |
| pytest.fail("Unexpected Exception raised when calling `_check_pyarrow_schema_compatible`") |
| |
| |
| def test_schema_mismatch_missing_field(table_schema_simple: Schema) -> None: |
| other_schema = pa.schema( |
| ( |
| pa.field("foo", pa.string(), nullable=True), |
| pa.field("baz", pa.bool_(), nullable=True), |
| ) |
| ) |
| |
| expected = """Mismatch in fields: |
| ┏━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━┓ |
| ┃ ┃ Table field ┃ Dataframe field ┃ |
| ┡━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━┩ |
| │ ✅ │ 1: foo: optional string │ 1: foo: optional string │ |
| │ ❌ │ 2: bar: required int │ Missing │ |
| │ ✅ │ 3: baz: optional boolean │ 3: baz: optional boolean │ |
| └────┴──────────────────────────┴──────────────────────────┘ |
| """ |
| |
| with pytest.raises(ValueError, match=expected): |
| _check_pyarrow_schema_compatible(table_schema_simple, other_schema) |
| |
| |
| def test_schema_compatible_missing_nullable_field_nested(table_schema_nested: Schema) -> None: |
| schema = table_schema_nested.as_arrow() |
| schema = schema.remove(6).insert( |
| 6, |
| pa.field( |
| "person", |
| pa.struct( |
| [ |
| pa.field("age", pa.int32(), nullable=False), |
| ] |
| ), |
| nullable=True, |
| ), |
| ) |
| try: |
| _check_pyarrow_schema_compatible(table_schema_nested, schema) |
| except Exception: |
| pytest.fail("Unexpected Exception raised when calling `_check_pyarrow_schema_compatible`") |
| |
| |
| def test_schema_mismatch_missing_required_field_nested(table_schema_nested: Schema) -> None: |
| other_schema = table_schema_nested.as_arrow() |
| other_schema = other_schema.remove(6).insert( |
| 6, |
| pa.field( |
| "person", |
| pa.struct( |
| [ |
| pa.field("name", pa.string(), nullable=True), |
| ] |
| ), |
| nullable=True, |
| ), |
| ) |
| expected = """Mismatch in fields: |
| ┏━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┓ |
| ┃ ┃ Table field ┃ Dataframe field ┃ |
| ┡━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┩ |
| │ ✅ │ 1: foo: optional string │ 1: foo: optional string │ |
| │ ✅ │ 2: bar: required int │ 2: bar: required int │ |
| │ ✅ │ 3: baz: optional boolean │ 3: baz: optional boolean │ |
| │ ✅ │ 4: qux: required list<string> │ 4: qux: required list<string> │ |
| │ ✅ │ 5: element: required string │ 5: element: required string │ |
| │ ✅ │ 6: quux: required map<string, │ 6: quux: required map<string, │ |
| │ │ map<string, int>> │ map<string, int>> │ |
| │ ✅ │ 7: key: required string │ 7: key: required string │ |
| │ ✅ │ 8: value: required map<string, │ 8: value: required map<string, │ |
| │ │ int> │ int> │ |
| │ ✅ │ 9: key: required string │ 9: key: required string │ |
| │ ✅ │ 10: value: required int │ 10: value: required int │ |
| │ ✅ │ 11: location: required │ 11: location: required │ |
| │ │ list<struct<13: latitude: optional │ list<struct<13: latitude: optional │ |
| │ │ float, 14: longitude: optional │ float, 14: longitude: optional │ |
| │ │ float>> │ float>> │ |
| │ ✅ │ 12: element: required struct<13: │ 12: element: required struct<13: │ |
| │ │ latitude: optional float, 14: │ latitude: optional float, 14: │ |
| │ │ longitude: optional float> │ longitude: optional float> │ |
| │ ✅ │ 13: latitude: optional float │ 13: latitude: optional float │ |
| │ ✅ │ 14: longitude: optional float │ 14: longitude: optional float │ |
| │ ✅ │ 15: person: optional struct<16: │ 15: person: optional struct<16: │ |
| │ │ name: optional string, 17: age: │ name: optional string> │ |
| │ │ required int> │ │ |
| │ ✅ │ 16: name: optional string │ 16: name: optional string │ |
| │ ❌ │ 17: age: required int │ Missing │ |
| └────┴────────────────────────────────────┴────────────────────────────────────┘ |
| """ |
| |
| with pytest.raises(ValueError, match=expected): |
| _check_pyarrow_schema_compatible(table_schema_nested, other_schema) |
| |
| |
| def test_schema_compatible_nested(table_schema_nested: Schema) -> None: |
| try: |
| _check_pyarrow_schema_compatible(table_schema_nested, table_schema_nested.as_arrow()) |
| except Exception: |
| pytest.fail("Unexpected Exception raised when calling `_check_pyarrow_schema_compatible`") |
| |
| |
| def test_schema_mismatch_additional_field(table_schema_simple: Schema) -> None: |
| other_schema = pa.schema( |
| ( |
| pa.field("foo", pa.string(), nullable=True), |
| pa.field("bar", pa.int32(), nullable=False), |
| pa.field("baz", pa.bool_(), nullable=True), |
| pa.field("new_field", pa.date32(), nullable=True), |
| ) |
| ) |
| |
| with pytest.raises( |
| ValueError, match=r"PyArrow table contains more columns: new_field. Update the schema first \(hint, use union_by_name\)." |
| ): |
| _check_pyarrow_schema_compatible(table_schema_simple, other_schema) |
| |
| |
| def test_schema_compatible(table_schema_simple: Schema) -> None: |
| try: |
| _check_pyarrow_schema_compatible(table_schema_simple, table_schema_simple.as_arrow()) |
| except Exception: |
| pytest.fail("Unexpected Exception raised when calling `_check_pyarrow_schema_compatible`") |
| |
| |
| def test_schema_projection(table_schema_simple: Schema) -> None: |
| # remove optional `baz` field from `table_schema_simple` |
| other_schema = pa.schema( |
| ( |
| pa.field("foo", pa.string(), nullable=True), |
| pa.field("bar", pa.int32(), nullable=False), |
| ) |
| ) |
| try: |
| _check_pyarrow_schema_compatible(table_schema_simple, other_schema) |
| except Exception: |
| pytest.fail("Unexpected Exception raised when calling `_check_pyarrow_schema_compatible`") |
| |
| |
| def test_schema_downcast(table_schema_simple: Schema) -> None: |
| # large_string type is compatible with string type |
| other_schema = pa.schema( |
| ( |
| pa.field("foo", pa.large_string(), nullable=True), |
| pa.field("bar", pa.int32(), nullable=False), |
| pa.field("baz", pa.bool_(), nullable=True), |
| ) |
| ) |
| |
| try: |
| _check_pyarrow_schema_compatible(table_schema_simple, other_schema) |
| except Exception: |
| pytest.fail("Unexpected Exception raised when calling `_check_pyarrow_schema_compatible`") |
| |
| |
| def test_partition_for_demo() -> None: |
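    """Test determining identity partitions over two columns of a small demo table."""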
| test_pa_schema = pa.schema([("year", pa.int64()), ("n_legs", pa.int64()), ("animal", pa.string())]) |
| test_schema = Schema( |
| NestedField(field_id=1, name="year", field_type=StringType(), required=False), |
| NestedField(field_id=2, name="n_legs", field_type=IntegerType(), required=True), |
| NestedField(field_id=3, name="animal", field_type=StringType(), required=False), |
| schema_id=1, |
| ) |
| test_data = { |
| "year": [2020, 2022, 2022, 2022, 2021, 2022, 2022, 2019, 2021], |
| "n_legs": [2, 2, 2, 4, 4, 4, 4, 5, 100], |
| "animal": ["Flamingo", "Parrot", "Parrot", "Horse", "Dog", "Horse", "Horse", "Brittle stars", "Centipede"], |
| } |
| arrow_table = pa.Table.from_pydict(test_data, schema=test_pa_schema) |
| partition_spec = PartitionSpec( |
| PartitionField(source_id=2, field_id=1002, transform=IdentityTransform(), name="n_legs_identity"), |
| PartitionField(source_id=1, field_id=1001, transform=IdentityTransform(), name="year_identity"), |
| ) |
| result = _determine_partitions(partition_spec, test_schema, arrow_table) |
| assert {table_partition.partition_key.partition for table_partition in result} == { |
| Record(2, 2020), |
| Record(100, 2021), |
| Record(4, 2021), |
| Record(4, 2022), |
| Record(2, 2022), |
| Record(5, 2019), |
| } |
| assert ( |
| pa.concat_tables([table_partition.arrow_table_partition for table_partition in result]).num_rows == arrow_table.num_rows |
| ) |
| |
| |
| def test_partition_for_nested_field() -> None: |
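    """Test partitioning with an hour transform on a timestamp field nested inside a struct.

    The expected partition values are hours since the Unix epoch: 2025-07-11T09:30
    and 2025-07-11T10:30 fall into hours 486729 and 486730, respectively.
    """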
| schema = Schema( |
| NestedField(id=1, name="foo", field_type=StringType(), required=True), |
| NestedField( |
| id=2, |
| name="bar", |
| field_type=StructType( |
| NestedField(id=3, name="baz", field_type=TimestampType(), required=False), |
| NestedField(id=4, name="qux", field_type=IntegerType(), required=False), |
| ), |
| required=True, |
| ), |
| ) |
| |
| spec = PartitionSpec(PartitionField(source_id=3, field_id=1000, transform=HourTransform(), name="ts")) |
| |
| from datetime import datetime |
| |
| t1 = datetime(2025, 7, 11, 9, 30, 0) |
| t2 = datetime(2025, 7, 11, 10, 30, 0) |
| |
| test_data = [ |
| {"foo": "a", "bar": {"baz": t1, "qux": 1}}, |
| {"foo": "b", "bar": {"baz": t2, "qux": 2}}, |
| ] |
| |
| arrow_table = pa.Table.from_pylist(test_data, schema=schema.as_arrow()) |
| partitions = _determine_partitions(spec, schema, arrow_table) |
| partition_values = {p.partition_key.partition[0] for p in partitions} |
| |
| assert partition_values == {486729, 486730} |
| |
| |
| def test_partition_for_deep_nested_field() -> None: |
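    """Test identity partitioning on a field nested two structs deep."""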
| schema = Schema( |
| NestedField( |
| id=1, |
| name="foo", |
| field_type=StructType( |
| NestedField( |
| id=2, |
| name="bar", |
| field_type=StructType(NestedField(id=3, name="baz", field_type=StringType(), required=False)), |
| required=True, |
| ) |
| ), |
| required=True, |
| ) |
| ) |
| |
| spec = PartitionSpec(PartitionField(source_id=3, field_id=1000, transform=IdentityTransform(), name="qux")) |
| |
| test_data = [ |
| {"foo": {"bar": {"baz": "data-1"}}}, |
| {"foo": {"bar": {"baz": "data-2"}}}, |
| {"foo": {"bar": {"baz": "data-1"}}}, |
| ] |
| |
| arrow_table = pa.Table.from_pylist(test_data, schema=schema.as_arrow()) |
| partitions = _determine_partitions(spec, schema, arrow_table) |
| |
| assert len(partitions) == 2 # 2 unique partitions |
| partition_values = {p.partition_key.partition[0] for p in partitions} |
| assert partition_values == {"data-1", "data-2"} |
| |
| |
| def test_inspect_partition_for_nested_field(catalog: InMemoryCatalog) -> None: |
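    """Test that inspect.partitions() reports partition values sourced from a field nested in a struct."""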
| schema = Schema( |
| NestedField(id=1, name="foo", field_type=StringType(), required=True), |
| NestedField( |
| id=2, |
| name="bar", |
| field_type=StructType( |
| NestedField(id=3, name="baz", field_type=StringType(), required=False), |
| NestedField(id=4, name="qux", field_type=IntegerType(), required=False), |
| ), |
| required=True, |
| ), |
| ) |
| spec = PartitionSpec(PartitionField(source_id=3, field_id=1000, transform=IdentityTransform(), name="part")) |
| catalog.create_namespace("default") |
| table = catalog.create_table("default.test_partition_in_struct", schema=schema, partition_spec=spec) |
| test_data = [ |
| {"foo": "a", "bar": {"baz": "data-a", "qux": 1}}, |
| {"foo": "b", "bar": {"baz": "data-b", "qux": 2}}, |
| ] |
| |
| arrow_table = pa.Table.from_pylist(test_data, schema=table.schema().as_arrow()) |
| table.append(arrow_table) |
| partitions_table = table.inspect.partitions() |
| partitions = partitions_table["partition"].to_pylist() |
| |
| assert len(partitions) == 2 |
| assert {part["part"] for part in partitions} == {"data-a", "data-b"} |
| |
| |
| def test_identity_partition_on_multi_columns() -> None: |
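    """Test that multi-column identity partitioning produces the same partitions and rows regardless of row order."""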
| test_pa_schema = pa.schema([("born_year", pa.int64()), ("n_legs", pa.int64()), ("animal", pa.string())]) |
| test_schema = Schema( |
| NestedField(field_id=1, name="born_year", field_type=StringType(), required=False), |
| NestedField(field_id=2, name="n_legs", field_type=IntegerType(), required=True), |
| NestedField(field_id=3, name="animal", field_type=StringType(), required=False), |
| schema_id=1, |
| ) |
| # 5 partitions, 6 unique row values, 12 rows |
| test_rows = [ |
| (2021, 4, "Dog"), |
| (2022, 4, "Horse"), |
| (2022, 4, "Another Horse"), |
| (2021, 100, "Centipede"), |
| (None, 4, "Kirin"), |
| (2021, None, "Fish"), |
| ] * 2 |
| expected = {Record(test_rows[i][1], test_rows[i][0]) for i in range(len(test_rows))} |
| partition_spec = PartitionSpec( |
| PartitionField(source_id=2, field_id=1002, transform=IdentityTransform(), name="n_legs_identity"), |
| PartitionField(source_id=1, field_id=1001, transform=IdentityTransform(), name="year_identity"), |
| ) |
| import random |
| |
    # there are 12! / ((2!)^6) = 7,484,400 distinct row orderings, too many to test exhaustively, so sample 1,000 random shuffles
| for _ in range(1000): |
| random.shuffle(test_rows) |
| test_data = { |
| "born_year": [row[0] for row in test_rows], |
| "n_legs": [row[1] for row in test_rows], |
| "animal": [row[2] for row in test_rows], |
| } |
| arrow_table = pa.Table.from_pydict(test_data, schema=test_pa_schema) |
| |
| result = _determine_partitions(partition_spec, test_schema, arrow_table) |
| |
| assert {table_partition.partition_key.partition for table_partition in result} == expected |
| concatenated_arrow_table = pa.concat_tables([table_partition.arrow_table_partition for table_partition in result]) |
| assert concatenated_arrow_table.num_rows == arrow_table.num_rows |
| assert concatenated_arrow_table.sort_by( |
| [ |
| ("born_year", "ascending"), |
| ("n_legs", "ascending"), |
| ("animal", "ascending"), |
| ] |
| ) == arrow_table.sort_by([("born_year", "ascending"), ("n_legs", "ascending"), ("animal", "ascending")]) |
| |
| |
| def test_initial_value() -> None: |
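    """Test that a projected field that is missing from the file schema is populated with its initial_default."""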
    # Use some placeholder data, otherwise the projection would generate a table without records
| data = pa.record_batch([pa.nulls(10, pa.int64())], names=["some_field"]) |
| result = _to_requested_schema( |
| Schema(NestedField(1, "we-love-22", LongType(), required=True, initial_default=22)), Schema(), data |
| ) |
| assert result.column_names == ["we-love-22"] |
| for val in result[0]: |
| assert val.as_py() == 22 |
| |
| |
| def test__to_requested_schema_timestamps( |
| arrow_table_schema_with_all_timestamp_precisions: pa.Schema, |
| arrow_table_with_all_timestamp_precisions: pa.Table, |
| arrow_table_schema_with_all_microseconds_timestamp_precisions: pa.Schema, |
| table_schema_with_all_microseconds_timestamp_precision: Schema, |
| ) -> None: |
| requested_schema = table_schema_with_all_microseconds_timestamp_precision |
| file_schema = requested_schema |
| batch = arrow_table_with_all_timestamp_precisions.to_batches()[0] |
| result = _to_requested_schema(requested_schema, file_schema, batch, downcast_ns_timestamp_to_us=True, include_field_ids=False) |
| |
| expected = arrow_table_with_all_timestamp_precisions.cast( |
| arrow_table_schema_with_all_microseconds_timestamp_precisions, safe=False |
| ).to_batches()[0] |
| assert result == expected |
| |
| |
| def test__to_requested_schema_timestamps_without_downcast_raises_exception( |
| arrow_table_schema_with_all_timestamp_precisions: pa.Schema, |
| arrow_table_with_all_timestamp_precisions: pa.Table, |
| arrow_table_schema_with_all_microseconds_timestamp_precisions: pa.Schema, |
| table_schema_with_all_microseconds_timestamp_precision: Schema, |
| ) -> None: |
| requested_schema = table_schema_with_all_microseconds_timestamp_precision |
| file_schema = requested_schema |
| batch = arrow_table_with_all_timestamp_precisions.to_batches()[0] |
| with pytest.raises(ValueError) as exc_info: |
| _to_requested_schema(requested_schema, file_schema, batch, downcast_ns_timestamp_to_us=False, include_field_ids=False) |
| |
| assert "Unsupported schema projection from timestamp[ns] to timestamp[us]" in str(exc_info.value) |
| |
| |
| def test_pyarrow_file_io_fs_by_scheme_cache() -> None: |
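    """Test that fs_by_scheme caches the resolved filesystem per bucket and only resolves the S3 region on a cache miss."""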
    # It would be better to set up multi-region minio servers for an integration test once the `endpoint_url` argument becomes available for `resolve_s3_region`
    # Refer to: https://github.com/apache/arrow/issues/43713
| |
| pyarrow_file_io = PyArrowFileIO() |
| us_east_1_region = "us-east-1" |
| ap_southeast_2_region = "ap-southeast-2" |
| |
| with patch("pyarrow.fs.resolve_s3_region") as mock_s3_region_resolver: |
| # Call with new argument resolves region automatically |
| mock_s3_region_resolver.return_value = us_east_1_region |
| filesystem_us = pyarrow_file_io.fs_by_scheme("s3", "us-east-1-bucket") |
| assert filesystem_us.region == us_east_1_region |
| assert pyarrow_file_io.fs_by_scheme.cache_info().misses == 1 # type: ignore |
| assert pyarrow_file_io.fs_by_scheme.cache_info().currsize == 1 # type: ignore |
| |
| # Call with different argument also resolves region automatically |
| mock_s3_region_resolver.return_value = ap_southeast_2_region |
| filesystem_ap_southeast_2 = pyarrow_file_io.fs_by_scheme("s3", "ap-southeast-2-bucket") |
| assert filesystem_ap_southeast_2.region == ap_southeast_2_region |
| assert pyarrow_file_io.fs_by_scheme.cache_info().misses == 2 # type: ignore |
| assert pyarrow_file_io.fs_by_scheme.cache_info().currsize == 2 # type: ignore |
| |
| # Call with same argument hits cache |
| filesystem_us_cached = pyarrow_file_io.fs_by_scheme("s3", "us-east-1-bucket") |
| assert filesystem_us_cached.region == us_east_1_region |
| assert pyarrow_file_io.fs_by_scheme.cache_info().hits == 1 # type: ignore |
| |
| # Call with same argument hits cache |
| filesystem_ap_southeast_2_cached = pyarrow_file_io.fs_by_scheme("s3", "ap-southeast-2-bucket") |
| assert filesystem_ap_southeast_2_cached.region == ap_southeast_2_region |
| assert pyarrow_file_io.fs_by_scheme.cache_info().hits == 2 # type: ignore |
| |
| |
| def test_pyarrow_io_new_input_multi_region(caplog: Any) -> None: |
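    """Test that the resolved bucket region overrides the user-provided region for the s3 scheme, falling back to the provided region when resolution fails."""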
    # It would be better to set up multi-region minio servers for an integration test once the `endpoint_url` argument becomes available for `resolve_s3_region`
    # Refer to: https://github.com/apache/arrow/issues/43713
| user_provided_region = "ap-southeast-1" |
| bucket_regions = [ |
| ("us-east-2-bucket", "us-east-2"), |
| ("ap-southeast-2-bucket", "ap-southeast-2"), |
| ] |
| |
| def _s3_region_map(bucket: str) -> str: |
| for bucket_region in bucket_regions: |
| if bucket_region[0] == bucket: |
| return bucket_region[1] |
| raise OSError("Unknown bucket") |
| |
| # For a pyarrow io instance with configured default s3 region |
| pyarrow_file_io = PyArrowFileIO({"s3.region": user_provided_region, "s3.resolve-region": "true"}) |
| with patch("pyarrow.fs.resolve_s3_region") as mock_s3_region_resolver: |
| mock_s3_region_resolver.side_effect = _s3_region_map |
| |
| # The region is set to provided region if bucket region cannot be resolved |
| with caplog.at_level(logging.WARNING): |
| assert pyarrow_file_io.new_input("s3://non-exist-bucket/path/to/file")._filesystem.region == user_provided_region |
| assert "Unable to resolve region for bucket non-exist-bucket" in caplog.text |
| |
| for bucket_region in bucket_regions: |
| # For s3 scheme, region is overwritten by resolved bucket region if different from user provided region |
| with caplog.at_level(logging.WARNING): |
| assert pyarrow_file_io.new_input(f"s3://{bucket_region[0]}/path/to/file")._filesystem.region == bucket_region[1] |
| assert ( |
| f"PyArrow FileIO overriding S3 bucket region for bucket {bucket_region[0]}: " |
| f"provided region {user_provided_region}, actual region {bucket_region[1]}" in caplog.text |
| ) |
| |
| # For oss scheme, user provided region is used instead |
| assert pyarrow_file_io.new_input(f"oss://{bucket_region[0]}/path/to/file")._filesystem.region == user_provided_region |
| |
| |
| def test_pyarrow_io_multi_fs() -> None: |
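    """Test that a single PyArrowFileIO instance resolves s3 and local paths to their respective filesystems."""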
| pyarrow_file_io = PyArrowFileIO({"s3.region": "ap-southeast-1"}) |
| |
| with patch("pyarrow.fs.resolve_s3_region") as mock_s3_region_resolver: |
| mock_s3_region_resolver.return_value = None |
| |
| # The PyArrowFileIO instance resolves s3 file input to S3FileSystem |
| assert isinstance(pyarrow_file_io.new_input("s3://bucket/path/to/file")._filesystem, S3FileSystem) |
| |
| # Same PyArrowFileIO instance resolves local file input to LocalFileSystem |
| assert isinstance(pyarrow_file_io.new_input("file:///path/to/file")._filesystem, LocalFileSystem) |
| |
| |
| class SomeRetryStrategy(AwsDefaultS3RetryStrategy): |
| def __init__(self) -> None: |
| super().__init__() |
| warnings.warn("Initialized SomeRetryStrategy 👍") |
| |
| |
| def test_retry_strategy() -> None: |
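    """Test that a custom S3 retry strategy configured via S3_RETRY_STRATEGY_IMPL is instantiated."""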
| io = PyArrowFileIO(properties={S3_RETRY_STRATEGY_IMPL: "tests.io.test_pyarrow.SomeRetryStrategy"}) |
| with pytest.warns(UserWarning, match="Initialized SomeRetryStrategy.*"): |
| io.new_input("s3://bucket/path/to/file") |
| |
| |
| def test_retry_strategy_not_found() -> None: |
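    """Test that a warning is emitted when the configured S3 retry strategy cannot be loaded."""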
| io = PyArrowFileIO(properties={S3_RETRY_STRATEGY_IMPL: "pyiceberg.DoesNotExist"}) |
| with pytest.warns(UserWarning, match="Could not initialize S3 retry strategy: pyiceberg.DoesNotExist"): |
| io.new_input("s3://bucket/path/to/file") |