| # -*- coding: utf-8 -*- |
| # Licensed to the Apache Software Foundation (ASF) under one |
| # or more contributor license agreements. See the NOTICE file |
| # distributed with this work for additional information |
| # regarding copyright ownership. The ASF licenses this file |
| # to you under the Apache License, Version 2.0 (the |
| # "License"); you may not use this file except in compliance |
| # with the License. You may obtain a copy of the License at |
| # |
| # http://www.apache.org/licenses/LICENSE-2.0 |
| # |
| # Unless required by applicable law or agreed to in writing, |
| # software distributed under the License is distributed on an |
| # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY |
| # KIND, either express or implied. See the License for the |
| # specific language governing permissions and limitations |
| # under the License. |
| |
| import contextlib |
| import datetime |
| import decimal |
| |
| import pytest |
| import pyarrow as pa |
| import pytz |
| |
| import arrow_pyarrow_integration_testing as rust |
| |
| PYARROW_PRE_14 = int(pa.__version__.split('.')[0]) < 14 |
| |
| |
| @contextlib.contextmanager |
| def no_pyarrow_leak(): |
    # `pa.total_allocated_bytes()` reports bytes allocated from pyarrow's default
    # memory pool, so any imbalance after the block indicates leaked C++ memory.
| old_allocation = pa.total_allocated_bytes() |
| try: |
| yield |
| finally: |
| assert pa.total_allocated_bytes() == old_allocation |
| |
| |
| @pytest.fixture(autouse=True) |
| def assert_pyarrow_leak(): |
| # automatically applied to all test cases |
| with no_pyarrow_leak(): |
| yield |
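
# For illustration (a sketch, not exercised by the tests, which rely on the autouse
# fixture above): the context manager can also be used explicitly around a block
# suspected of leaking, e.g.
#
#   with no_pyarrow_leak():
#       arr = pa.array([1, 2, 3])
#       del arr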
| |
| |
| _supported_pyarrow_types = [ |
| pa.null(), |
| pa.bool_(), |
| pa.int32(), |
| pa.time32("s"), |
| pa.time64("us"), |
| pa.date32(), |
| pa.timestamp("us"), |
| pa.timestamp("us", tz="UTC"), |
| pa.timestamp("us", tz="Europe/Paris"), |
| pa.duration("s"), |
| pa.duration("ms"), |
| pa.duration("us"), |
| pa.duration("ns"), |
| pa.float16(), |
| pa.float32(), |
| pa.float64(), |
| pa.decimal128(19, 4), |
| pa.decimal256(76, 38), |
| pa.string(), |
| pa.binary(), |
| pa.binary(10), |
| pa.large_string(), |
| pa.large_binary(), |
| pa.list_(pa.int32()), |
| pa.list_(pa.int32(), 2), |
| pa.large_list(pa.uint16()), |
| pa.struct( |
| [ |
| pa.field("a", pa.int32()), |
| pa.field("b", pa.int8()), |
| pa.field("c", pa.string()), |
| ] |
| ), |
| pa.struct( |
| [ |
| pa.field("a", pa.int32(), nullable=False), |
| pa.field("b", pa.int8(), nullable=False), |
| pa.field("c", pa.string()), |
| ] |
| ), |
| pa.dictionary(pa.int8(), pa.string()), |
| pa.map_(pa.string(), pa.int32()), |
| pa.union( |
| [pa.field("a", pa.binary(10)), pa.field("b", pa.string())], |
| mode=pa.lib.UnionMode_DENSE, |
| ), |
| pa.union( |
| [pa.field("a", pa.binary(10)), pa.field("b", pa.string())], |
| mode=pa.lib.UnionMode_DENSE, |
| type_codes=[4, 8], |
| ), |
| pa.union( |
| [pa.field("a", pa.binary(10)), pa.field("b", pa.string())], |
| mode=pa.lib.UnionMode_SPARSE, |
| ), |
| pa.union( |
| [ |
| pa.field("a", pa.binary(10), nullable=False), |
| pa.field("b", pa.string()), |
| ], |
| mode=pa.lib.UnionMode_SPARSE, |
| ), |
| ] |
| |
| _unsupported_pyarrow_types = [ |
| ] |
| |
# As of pyarrow 14, pyarrow implements the Arrow PyCapsule interface
# (https://arrow.apache.org/docs/format/CDataInterface/PyCapsuleInterface.html).
# It specifies that Arrow consumers should accept any object exposing the relevant
# "dunder" methods (`__arrow_c_schema__`, `__arrow_c_array__`, `__arrow_c_stream__`).
# These wrapper classes ensure that arrow-rs can handle _any_ such class, without
# pyarrow-specific handling. An illustrative sketch follows the wrapper classes below.
| class SchemaWrapper: |
| def __init__(self, schema): |
| self.schema = schema |
| |
| def __arrow_c_schema__(self): |
| return self.schema.__arrow_c_schema__() |
| |
| |
| class ArrayWrapper: |
| def __init__(self, array): |
| self.array = array |
| |
    def __arrow_c_array__(self, requested_schema=None):
        # Forward the optional requested_schema so the wrapper stays protocol-compliant
        return self.array.__arrow_c_array__(requested_schema)
| |
| |
| class StreamWrapper: |
| def __init__(self, stream): |
| self.stream = stream |
| |
    def __arrow_c_stream__(self, requested_schema=None):
        # Forward the optional requested_schema so the wrapper stays protocol-compliant
        return self.stream.__arrow_c_stream__(requested_schema)
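
# For illustration only (a minimal sketch, not exercised by the tests below): each of
# these dunders returns PyCapsule object(s) wrapping the corresponding Arrow C Data /
# C Stream Interface structs, which is all arrow-rs needs in order to import the data:
#
#   capsule = pa.int32().__arrow_c_schema__()
#   assert type(capsule).__name__ == "PyCapsule"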
| |
| |
| @pytest.mark.parametrize("pyarrow_type", _supported_pyarrow_types, ids=str) |
| def test_type_roundtrip(pyarrow_type): |
| restored = rust.round_trip_type(pyarrow_type) |
| assert restored == pyarrow_type |
| assert restored is not pyarrow_type |
| |
| @pytest.mark.skipif(PYARROW_PRE_14, reason="requires pyarrow 14") |
| @pytest.mark.parametrize("pyarrow_type", _supported_pyarrow_types, ids=str) |
| def test_type_roundtrip_pycapsule(pyarrow_type): |
| wrapped = SchemaWrapper(pyarrow_type) |
| restored = rust.round_trip_type(wrapped) |
| assert restored == pyarrow_type |
| assert restored is not pyarrow_type |
| |
| |
| @pytest.mark.parametrize("pyarrow_type", _unsupported_pyarrow_types, ids=str) |
| def test_type_roundtrip_raises(pyarrow_type): |
| with pytest.raises(pa.ArrowException): |
| rust.round_trip_type(pyarrow_type) |
| |
| @pytest.mark.parametrize('pyarrow_type', _supported_pyarrow_types, ids=str) |
| def test_field_roundtrip(pyarrow_type): |
| pyarrow_field = pa.field("test", pyarrow_type, nullable=True) |
| field = rust.round_trip_field(pyarrow_field) |
| assert field == pyarrow_field |
| |
| if pyarrow_type != pa.null(): |
| # A null type field may not be non-nullable |
| pyarrow_field = pa.field("test", pyarrow_type, nullable=False) |
| field = rust.round_trip_field(pyarrow_field) |
| assert field == pyarrow_field |
| |
| @pytest.mark.skipif(PYARROW_PRE_14, reason="requires pyarrow 14") |
| @pytest.mark.parametrize('pyarrow_type', _supported_pyarrow_types, ids=str) |
| def test_field_roundtrip_pycapsule(pyarrow_type): |
| pyarrow_field = pa.field("test", pyarrow_type, nullable=True) |
| wrapped = SchemaWrapper(pyarrow_field) |
| field = rust.round_trip_field(wrapped) |
    assert field == pyarrow_field

    if pyarrow_type != pa.null():
        # A null type field may not be non-nullable
        pyarrow_field = pa.field("test", pyarrow_type, nullable=False)
        # Re-wrap the non-nullable field so it is the one actually round-tripped
        wrapped = SchemaWrapper(pyarrow_field)
        field = rust.round_trip_field(wrapped)
        assert field == pyarrow_field
| |
| def test_field_metadata_roundtrip(): |
| metadata = {"hello": "World! 😊", "x": "2"} |
| pyarrow_field = pa.field("test", pa.int32(), metadata=metadata) |
| field = rust.round_trip_field(pyarrow_field) |
| assert field == pyarrow_field |
| assert field.metadata == pyarrow_field.metadata |
| |
| def test_schema_roundtrip(): |
    # Name fields by index rather than zipping against a 26-letter alphabet, which
    # would silently drop any types beyond the 26th from the schema.
    pyarrow_fields = [
        (f"field_{i}", datatype) for i, datatype in enumerate(_supported_pyarrow_types)
    ]
    pyarrow_schema = pa.schema(pyarrow_fields, metadata={b'key1': b'value1'})
| schema = rust.round_trip_schema(pyarrow_schema) |
| assert schema == pyarrow_schema |
| assert schema.metadata == pyarrow_schema.metadata |
| |
| |
| def test_primitive_python(): |
| """ |
| Python -> Rust -> Python |
| """ |
| a = pa.array([1, 2, 3]) |
| b = rust.double(a) |
| assert b == pa.array([2, 4, 6]) |
| del a |
| del b |
| |
| |
| @pytest.mark.skipif(PYARROW_PRE_14, reason="requires pyarrow 14") |
| def test_primitive_python_pycapsule(): |
| """ |
| Python -> Rust -> Python |
| """ |
| a = pa.array([1, 2, 3]) |
| wrapped = ArrayWrapper(a) |
| b = rust.double(wrapped) |
| assert b == pa.array([2, 4, 6]) |
| |
| |
| def test_primitive_rust(): |
| """ |
| Rust -> Python -> Rust |
| """ |
| |
| def double(array): |
| array = array.to_pylist() |
| return pa.array([x * 2 if x is not None else None for x in array]) |
| |
| is_correct = rust.double_py(double) |
| assert is_correct |
| |
| |
| def test_string_python(): |
| """ |
| Python -> Rust -> Python |
| """ |
| a = pa.array(["a", None, "ccc"]) |
| b = rust.substring(a, 1) |
| assert b == pa.array(["", None, "cc"]) |
| del a |
| del b |
| |
| |
| def test_time32_python(): |
| """ |
| Python -> Rust -> Python |
| """ |
| a = pa.array([None, 1, 2], pa.time32("s")) |
| b = rust.concatenate(a) |
| expected = pa.array([None, 1, 2] + [None, 1, 2], pa.time32("s")) |
| assert b == expected |
| del a |
| del b |
| del expected |
| |
| |
| @pytest.mark.parametrize("datatype", _supported_pyarrow_types, ids=str) |
| def test_empty_array_python(datatype): |
| """ |
| Python -> Rust -> Python |
| """ |
| if datatype == pa.float16(): |
| pytest.skip("Float 16 is not implemented in Rust") |
| |
| if type(datatype) is pa.lib.DenseUnionType or type(datatype) is pa.lib.SparseUnionType: |
| pytest.skip("Union is not implemented in Python") |
| |
| a = pa.array([], datatype) |
| b = rust.round_trip_array(a) |
| b.validate(full=True) |
| assert a.to_pylist() == b.to_pylist() |
| assert a.type == b.type |
| del a |
| del b |
| |
| |
| @pytest.mark.parametrize("datatype", _supported_pyarrow_types, ids=str) |
| def test_empty_array_rust(datatype): |
| """ |
| Rust -> Python |
| """ |
| if type(datatype) is pa.lib.DenseUnionType or type(datatype) is pa.lib.SparseUnionType: |
| pytest.skip("Union is not implemented in Python") |
| |
| a = pa.array([], type=datatype) |
| b = rust.make_empty_array(datatype) |
| b.validate(full=True) |
| assert a.to_pylist() == b.to_pylist() |
| assert a.type == b.type |
| del a |
| del b |
| |
| |
| def test_binary_array(): |
| """ |
| Python -> Rust -> Python |
| """ |
| a = pa.array(["a", None, "bb", "ccc"], pa.binary()) |
| b = rust.round_trip_array(a) |
| b.validate(full=True) |
| assert a.to_pylist() == b.to_pylist() |
| assert a.type == b.type |
| del a |
| del b |
| |
| def test_fixed_len_binary_array(): |
| """ |
| Python -> Rust -> Python |
| """ |
| a = pa.array(["aaa", None, "bbb", "ccc"], pa.binary(3)) |
| b = rust.round_trip_array(a) |
| b.validate(full=True) |
| assert a.to_pylist() == b.to_pylist() |
| assert a.type == b.type |
| del a |
| del b |
| |
| def test_list_array(): |
| """ |
| Python -> Rust -> Python |
| """ |
| a = pa.array([[], None, [1, 2], [4, 5, 6]], pa.list_(pa.int64())) |
| b = rust.round_trip_array(a) |
| b.validate(full=True) |
| assert a.to_pylist() == b.to_pylist() |
| assert a.type == b.type |
| del a |
| del b |
| |
| def test_map_array(): |
| """ |
| Python -> Rust -> Python |
| """ |
| data = [ |
| [{'key': "a", 'value': 1}, {'key': "b", 'value': 2}], |
| [{'key': "c", 'value': 3}, {'key': "d", 'value': 4}] |
| ] |
| a = pa.array(data, pa.map_(pa.string(), pa.int32())) |
| b = rust.round_trip_array(a) |
| b.validate(full=True) |
| assert a.to_pylist() == b.to_pylist() |
| assert a.type == b.type |
| del a |
| del b |
| |
| def test_fixed_len_list_array(): |
| """ |
| Python -> Rust -> Python |
| """ |
| a = pa.array([[1, 2], None, [3, 4], [5, 6]], pa.list_(pa.int64(), 2)) |
| b = rust.round_trip_array(a) |
| b.validate(full=True) |
| assert a.to_pylist() == b.to_pylist() |
| assert a.type == b.type |
| del a |
| del b |
| |
| def test_timestamp_python(): |
| """ |
| Python -> Rust -> Python |
| """ |
| data = [ |
| None, |
| datetime.datetime(2021, 1, 1, 1, 1, 1, 1), |
| datetime.datetime(2020, 3, 9, 1, 1, 1, 1), |
| ] |
| a = pa.array(data, pa.timestamp("us")) |
| b = rust.concatenate(a) |
| expected = pa.array(data + data, pa.timestamp("us")) |
| assert b == expected |
| del a |
| del b |
| del expected |
| |
| |
| def test_timestamp_tz_python(): |
| """ |
| Python -> Rust -> Python |
| """ |
| tzinfo = pytz.timezone("America/New_York") |
| pyarrow_type = pa.timestamp("us", tz="America/New_York") |
| data = [ |
| None, |
| datetime.datetime(2021, 1, 1, 1, 1, 1, 1, tzinfo=tzinfo), |
| datetime.datetime(2020, 3, 9, 1, 1, 1, 1, tzinfo=tzinfo), |
| ] |
| a = pa.array(data, type=pyarrow_type) |
| b = rust.concatenate(a) |
| expected = pa.array(data * 2, type=pyarrow_type) |
| assert b == expected |
| del a |
| del b |
| del expected |
| |
| |
| def test_decimal_python(): |
| """ |
| Python -> Rust -> Python |
| """ |
    data = [
        decimal.Decimal("123.45"),
        decimal.Decimal("-123.45"),
        None,
    ]
| a = pa.array(data, pa.decimal128(6, 2)) |
| b = rust.round_trip_array(a) |
| assert a == b |
| del a |
| del b |
| |
| def test_dictionary_python(): |
| """ |
| Python -> Rust -> Python |
| """ |
| a = pa.array(["a", None, "b", None, "a"], type=pa.dictionary(pa.int8(), pa.string())) |
| b = rust.round_trip_array(a) |
| assert a == b |
| del a |
| del b |
| |
| def test_dense_union_python(): |
| """ |
| Python -> Rust -> Python |
| """ |
| xs = pa.array([5, 6, 7]) |
| ys = pa.array([False, True]) |
| types = pa.array([0, 1, 1, 0, 0], type=pa.int8()) |
| offsets = pa.array([0, 0, 1, 1, 2], type=pa.int32()) |
| a = pa.UnionArray.from_dense(types, offsets, [xs, ys]) |
| |
| b = rust.round_trip_array(a) |
| assert a == b |
| del a |
| del b |
| |
| def test_sparse_union_python(): |
| """ |
| Python -> Rust -> Python |
| """ |
| xs = pa.array([5, 6, 7]) |
| ys = pa.array([False, False, True]) |
| types = pa.array([0, 1, 1], type=pa.int8()) |
| a = pa.UnionArray.from_sparse(types, [xs, ys]) |
| |
| b = rust.round_trip_array(a) |
| assert a == b |
| del a |
| del b |
| |
| def test_tensor_array(): |
| tensor_type = pa.fixed_shape_tensor(pa.float32(), [2, 3]) |
| inner = pa.array([float(x) for x in range(1, 7)] + [None] * 12, pa.float32()) |
| storage = pa.FixedSizeListArray.from_arrays(inner, 6) |
| f32_array = pa.ExtensionArray.from_storage(tensor_type, storage) |
| |
    # Round-tripping as an array gives back the storage type, because arrow-rs has
    # no notion of extension types.
| b = rust.round_trip_array(f32_array) |
| assert b == f32_array.storage |
| |
| batch = pa.record_batch([f32_array], ["tensor"], metadata={b'key1': b'value1'}) |
| b = rust.round_trip_record_batch(batch) |
| assert b == batch |
| assert b.schema == batch.schema |
| assert b.schema.metadata == batch.schema.metadata |
| |
| del b |
| |
| |
| def test_empty_recordbatch_with_row_count(): |
| """ |
| A pyarrow.RecordBatch with no columns but with `num_rows` set. |
| |
| `datafusion-python` gets this as the result of a `count(*)` query. |
| """ |
| |
    # Build a RecordBatch that has rows but zero columns by selecting no columns
| batch = pa.RecordBatch.from_pydict({"a": [1, 2, 3, 4]}).select([]) |
| num_rows = 4 |
| assert batch.num_rows == num_rows |
| assert batch.num_columns == 0 |
| |
| b = rust.round_trip_record_batch(batch) |
| assert b == batch |
| assert b.schema == batch.schema |
| assert b.schema.metadata == batch.schema.metadata |
| |
| assert b.num_rows == batch.num_rows |
| |
| del b |
| |
| def test_record_batch_reader(): |
| """ |
| Python -> Rust -> Python |
| """ |
| schema = pa.schema([('ints', pa.list_(pa.int32()))], metadata={b'key1': b'value1'}) |
| batches = [ |
| pa.record_batch([[[1], [2, 42]]], schema), |
| pa.record_batch([[None, [], [5, 6]]], schema), |
| ] |
| a = pa.RecordBatchReader.from_batches(schema, batches) |
| b = rust.round_trip_record_batch_reader(a) |
| |
| assert b.schema == schema |
| assert b.schema.metadata == schema.metadata |
| got_batches = list(b) |
| assert got_batches == batches |
| |
| # Also try the boxed reader variant |
| a = pa.RecordBatchReader.from_batches(schema, batches) |
| b = rust.boxed_reader_roundtrip(a) |
| assert b.schema == schema |
| assert b.schema.metadata == schema.metadata |
| got_batches = list(b) |
| assert got_batches == batches |
| |
| @pytest.mark.skipif(PYARROW_PRE_14, reason="requires pyarrow 14") |
| def test_record_batch_reader_pycapsule(): |
| """ |
| Python -> Rust -> Python |
| """ |
| schema = pa.schema([('ints', pa.list_(pa.int32()))], metadata={b'key1': b'value1'}) |
| batches = [ |
| pa.record_batch([[[1], [2, 42]]], schema), |
| pa.record_batch([[None, [], [5, 6]]], schema), |
| ] |
| a = pa.RecordBatchReader.from_batches(schema, batches) |
| wrapped = StreamWrapper(a) |
| b = rust.round_trip_record_batch_reader(wrapped) |
| |
| assert b.schema == schema |
| assert b.schema.metadata == schema.metadata |
| got_batches = list(b) |
| assert got_batches == batches |
| |
| # Also try the boxed reader variant |
| a = pa.RecordBatchReader.from_batches(schema, batches) |
| wrapped = StreamWrapper(a) |
| b = rust.boxed_reader_roundtrip(wrapped) |
| assert b.schema == schema |
| assert b.schema.metadata == schema.metadata |
| got_batches = list(b) |
| assert got_batches == batches |
| |
| |
| def test_record_batch_reader_error(): |
| schema = pa.schema([('ints', pa.list_(pa.int32()))]) |
| |
| def iter_batches(): |
| yield pa.record_batch([[[1], [2, 42]]], schema) |
| raise ValueError("test error") |
| |
| reader = pa.RecordBatchReader.from_batches(schema, iter_batches()) |
| |
| with pytest.raises(ValueError, match="test error"): |
| rust.reader_return_errors(reader) |
| |
    # Due to a long-standing oversight, PyArrow allows binary values in schema
    # metadata that are not valid UTF-8. Such metadata cannot be represented as Rust
    # strings, so make sure we raise an error here rather than panic.
| schema = schema.with_metadata({"key": b"\xff"}) |
| reader = pa.RecordBatchReader.from_batches(schema, iter_batches()) |
| with pytest.raises(ValueError, match="invalid utf-8"): |
| rust.round_trip_record_batch_reader(reader) |
| |
| |
| @pytest.mark.skipif(PYARROW_PRE_14, reason="requires pyarrow 14") |
| def test_record_batch_pycapsule(): |
| """ |
| Python -> Rust -> Python |
| """ |
| schema = pa.schema([('ints', pa.list_(pa.int32()))], metadata={b'key1': b'value1'}) |
| batch = pa.record_batch([[[1], [2, 42]]], schema) |
| wrapped = StreamWrapper(batch) |
| b = rust.round_trip_record_batch_reader(wrapped) |
| new_table = b.read_all() |
| new_batches = new_table.to_batches() |
| |
| assert len(new_batches) == 1 |
| new_batch = new_batches[0] |
| |
| assert batch == new_batch |
| assert batch.schema == new_batch.schema |
| |
| |
| @pytest.mark.skipif(PYARROW_PRE_14, reason="requires pyarrow 14") |
| def test_table_pycapsule(): |
| """ |
| Python -> Rust -> Python |
| """ |
| schema = pa.schema([('ints', pa.list_(pa.int32()))], metadata={b'key1': b'value1'}) |
| batches = [ |
| pa.record_batch([[[1], [2, 42]]], schema), |
| pa.record_batch([[None, [], [5, 6]]], schema), |
| ] |
| table = pa.Table.from_batches(batches) |
| wrapped = StreamWrapper(table) |
| b = rust.round_trip_record_batch_reader(wrapped) |
| new_table = b.read_all() |
| |
| assert table.schema == new_table.schema |
| assert table == new_table |
| assert len(table.to_batches()) == len(new_table.to_batches()) |
| |
| |
| def test_reject_other_classes(): |
| # Arbitrary type that is not a PyArrow type |
| not_pyarrow = ["hello"] |
| |
| with pytest.raises(TypeError, match="Expected instance of pyarrow.lib.Array, got builtins.list"): |
| rust.round_trip_array(not_pyarrow) |
| |
| with pytest.raises(TypeError, match="Expected instance of pyarrow.lib.Schema, got builtins.list"): |
| rust.round_trip_schema(not_pyarrow) |
| |
| with pytest.raises(TypeError, match="Expected instance of pyarrow.lib.Field, got builtins.list"): |
| rust.round_trip_field(not_pyarrow) |
| |
| with pytest.raises(TypeError, match="Expected instance of pyarrow.lib.DataType, got builtins.list"): |
| rust.round_trip_type(not_pyarrow) |
| |
| with pytest.raises(TypeError, match="Expected instance of pyarrow.lib.RecordBatch, got builtins.list"): |
| rust.round_trip_record_batch(not_pyarrow) |
| |
| with pytest.raises(TypeError, match="Expected instance of pyarrow.lib.RecordBatchReader, got builtins.list"): |
| rust.round_trip_record_batch_reader(not_pyarrow) |