blob: fb949ca93ca5216873afdee96574cacfdd4c7040 [file] [log] [blame]
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
import pickle
import weakref
import pyarrow as pa
import pytest
class UuidType(pa.ExtensionType):
def __init__(self):
pa.ExtensionType.__init__(self, pa.binary(16))
def __reduce__(self):
return UuidType, ()
class ParamExtType(pa.ExtensionType):
def __init__(self, width):
self.width = width
pa.ExtensionType.__init__(self, pa.binary(width))
def __reduce__(self):
return ParamExtType, (self.width,)
def ipc_write_batch(batch):
stream = pa.BufferOutputStream()
writer = pa.RecordBatchStreamWriter(stream, batch.schema)
writer.write_batch(batch)
writer.close()
return stream.getvalue()
def ipc_read_batch(buf):
reader = pa.RecordBatchStreamReader(buf)
return reader.read_next_batch()
def test_ext_type_basics():
ty = UuidType()
assert ty.extension_name == "arrow.py_extension_type"
def test_ext_type__lifetime():
ty = UuidType()
wr = weakref.ref(ty)
del ty
assert wr() is None
def test_ext_type__storage_type():
ty = UuidType()
assert ty.storage_type == pa.binary(16)
assert ty.__class__ is UuidType
ty = ParamExtType(5)
assert ty.storage_type == pa.binary(5)
assert ty.__class__ is ParamExtType
def test_uuid_type_pickle():
for proto in range(0, pickle.HIGHEST_PROTOCOL + 1):
ty = UuidType()
ser = pickle.dumps(ty, protocol=proto)
del ty
ty = pickle.loads(ser)
wr = weakref.ref(ty)
assert ty.extension_name == "arrow.py_extension_type"
del ty
assert wr() is None
def test_ext_type_equality():
a = ParamExtType(5)
b = ParamExtType(6)
c = ParamExtType(6)
assert a != b
assert b == c
d = UuidType()
e = UuidType()
assert a != d
assert d == e
def test_ext_array_basics():
ty = ParamExtType(3)
storage = pa.array([b"foo", b"bar"], type=pa.binary(3))
arr = pa.ExtensionArray.from_storage(ty, storage)
arr.validate()
assert arr.type is ty
assert arr.storage.equals(storage)
def test_ext_array_lifetime():
ty = ParamExtType(3)
storage = pa.array([b"foo", b"bar"], type=pa.binary(3))
arr = pa.ExtensionArray.from_storage(ty, storage)
refs = [weakref.ref(ty), weakref.ref(arr), weakref.ref(storage)]
del ty, storage, arr
for ref in refs:
assert ref() is None
def test_ext_array_errors():
ty = ParamExtType(4)
storage = pa.array([b"foo", b"bar"], type=pa.binary(3))
with pytest.raises(TypeError, match="Incompatible storage type"):
pa.ExtensionArray.from_storage(ty, storage)
def test_ext_array_equality():
storage1 = pa.array([b"0123456789abcdef"], type=pa.binary(16))
storage2 = pa.array([b"0123456789abcdef"], type=pa.binary(16))
storage3 = pa.array([], type=pa.binary(16))
ty1 = UuidType()
ty2 = ParamExtType(16)
a = pa.ExtensionArray.from_storage(ty1, storage1)
b = pa.ExtensionArray.from_storage(ty1, storage2)
assert a.equals(b)
c = pa.ExtensionArray.from_storage(ty1, storage3)
assert not a.equals(c)
d = pa.ExtensionArray.from_storage(ty2, storage1)
assert not a.equals(d)
e = pa.ExtensionArray.from_storage(ty2, storage2)
assert d.equals(e)
f = pa.ExtensionArray.from_storage(ty2, storage3)
assert not d.equals(f)
def test_ext_array_pickling():
for proto in range(0, pickle.HIGHEST_PROTOCOL + 1):
ty = ParamExtType(3)
storage = pa.array([b"foo", b"bar"], type=pa.binary(3))
arr = pa.ExtensionArray.from_storage(ty, storage)
ser = pickle.dumps(arr, protocol=proto)
del ty, storage, arr
arr = pickle.loads(ser)
arr.validate()
assert isinstance(arr, pa.ExtensionArray)
assert arr.type == ParamExtType(3)
assert arr.type.storage_type == pa.binary(3)
assert arr.storage.type == pa.binary(3)
assert arr.storage.to_pylist() == [b"foo", b"bar"]
def example_batch():
ty = ParamExtType(3)
storage = pa.array([b"foo", b"bar"], type=pa.binary(3))
arr = pa.ExtensionArray.from_storage(ty, storage)
return pa.RecordBatch.from_arrays([arr], ["exts"])
def check_example_batch(batch):
arr = batch.column(0)
assert isinstance(arr, pa.ExtensionArray)
assert arr.type.storage_type == pa.binary(3)
assert arr.storage.to_pylist() == [b"foo", b"bar"]
return arr
def test_ipc():
batch = example_batch()
buf = ipc_write_batch(batch)
del batch
batch = ipc_read_batch(buf)
arr = check_example_batch(batch)
assert arr.type == ParamExtType(3)
def test_ipc_unknown_type():
batch = example_batch()
buf = ipc_write_batch(batch)
del batch
orig_type = ParamExtType
try:
# Simulate the original Python type being unavailable.
# Deserialization should not fail but return a placeholder type.
del globals()['ParamExtType']
batch = ipc_read_batch(buf)
arr = check_example_batch(batch)
assert isinstance(arr.type, pa.UnknownExtensionType)
# Can be serialized again
buf2 = ipc_write_batch(batch)
del batch, arr
batch = ipc_read_batch(buf2)
arr = check_example_batch(batch)
assert isinstance(arr.type, pa.UnknownExtensionType)
finally:
globals()['ParamExtType'] = orig_type
# Deserialize again with the type restored
batch = ipc_read_batch(buf2)
arr = check_example_batch(batch)
assert arr.type == ParamExtType(3)