# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
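"""Tests for pyarrow's Python object serialization: pa.serialize /
pa.deserialize round trips for primitive, NumPy, PyTorch, datetime and
custom Python objects, plus SerializationContext type registration."""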
from __future__ import division
import pytest
import collections
import datetime
import os
import pickle
import subprocess
import string
import sys
import pyarrow as pa
import numpy as np
import pyarrow.tests.util as test_util
try:
import torch
except ImportError:
torch = None
    # Blacklist the module so that later `import torch` attempts fail
    # fast instead of repeating a costly import (ARROW-2071)
sys.modules['torch'] = None
def assert_equal(obj1, obj2):
if torch is not None and torch.is_tensor(obj1) and torch.is_tensor(obj2):
if obj1.is_sparse:
obj1 = obj1.to_dense()
if obj2.is_sparse:
obj2 = obj2.to_dense()
assert torch.equal(obj1, obj2)
return
module_numpy = (type(obj1).__module__ == np.__name__ or
type(obj2).__module__ == np.__name__)
if module_numpy:
empty_shape = ((hasattr(obj1, "shape") and obj1.shape == ()) or
(hasattr(obj2, "shape") and obj2.shape == ()))
if empty_shape:
            # Special case: np.testing.assert_equal currently fails here
            # because we do not properly handle different numerical types.
assert obj1 == obj2, ("Objects {} and {} are "
"different.".format(obj1, obj2))
else:
np.testing.assert_equal(obj1, obj2)
elif hasattr(obj1, "__dict__") and hasattr(obj2, "__dict__"):
special_keys = ["_pytype_"]
assert (set(list(obj1.__dict__.keys()) + special_keys) ==
set(list(obj2.__dict__.keys()) + special_keys)), ("Objects {} "
"and {} are "
"different."
.format(
obj1,
obj2))
try:
# Workaround to make comparison of OrderedDicts work on Python 2.7
if obj1 == obj2:
return
except Exception:
pass
if obj1.__dict__ == {}:
print("WARNING: Empty dict in ", obj1)
for key in obj1.__dict__.keys():
if key not in special_keys:
assert_equal(obj1.__dict__[key], obj2.__dict__[key])
elif type(obj1) is dict or type(obj2) is dict:
assert_equal(obj1.keys(), obj2.keys())
for key in obj1.keys():
assert_equal(obj1[key], obj2[key])
elif type(obj1) is list or type(obj2) is list:
assert len(obj1) == len(obj2), ("Objects {} and {} are lists with "
"different lengths."
.format(obj1, obj2))
for i in range(len(obj1)):
assert_equal(obj1[i], obj2[i])
elif type(obj1) is tuple or type(obj2) is tuple:
assert len(obj1) == len(obj2), ("Objects {} and {} are tuples with "
"different lengths."
.format(obj1, obj2))
for i in range(len(obj1)):
assert_equal(obj1[i], obj2[i])
elif (pa.lib.is_named_tuple(type(obj1)) or
pa.lib.is_named_tuple(type(obj2))):
assert len(obj1) == len(obj2), ("Objects {} and {} are named tuples "
"with different lengths."
.format(obj1, obj2))
for i in range(len(obj1)):
assert_equal(obj1[i], obj2[i])
elif isinstance(obj1, pa.Array) and isinstance(obj2, pa.Array):
assert obj1.equals(obj2)
elif isinstance(obj1, pa.Tensor) and isinstance(obj2, pa.Tensor):
assert obj1.equals(obj2)
elif isinstance(obj1, pa.RecordBatch) and isinstance(obj2, pa.RecordBatch):
assert obj1.equals(obj2)
elif isinstance(obj1, pa.Table) and isinstance(obj2, pa.Table):
assert obj1.equals(obj2)
else:
assert type(obj1) == type(obj2) and obj1 == obj2, \
"Objects {} and {} are different.".format(obj1, obj2)
PRIMITIVE_OBJECTS = [
0, 0.0, 0.9, 1 << 62, 1 << 999,
[1 << 100, [1 << 100]], "a", string.printable, "\u262F",
"hello world", u"hello world", u"\xff\xfe\x9c\x001\x000\x00",
None, True, False, [], (), {}, {(1, 2): 1}, {(): 2},
[1, "hello", 3.0], u"\u262F", 42.0, (1.0, "hi"),
[1, 2, 3, None], [(None,), 3, 1.0], ["h", "e", "l", "l", "o", None],
(None, None), ("hello", None), (True, False),
{True: "hello", False: "world"}, {"hello": "world", 1: 42, 2.5: 45},
{"hello": set([2, 3]), "world": set([42.0]), "this": None},
np.int8(3), np.int32(4), np.int64(5),
np.uint8(3), np.uint32(4), np.uint64(5),
np.float16(1.9), np.float32(1.9),
np.float64(1.9), np.zeros([8, 20]),
np.random.normal(size=[17, 10]), np.array(["hi", 3]),
np.array(["hi", 3], dtype=object),
np.random.normal(size=[15, 13]).T
]
if sys.version_info >= (3, 0):
PRIMITIVE_OBJECTS += [0, np.array([["hi", u"hi"], [1.3, 1]])]
else:
PRIMITIVE_OBJECTS += [long(42), long(1 << 62), long(0), # noqa
np.array([["hi", u"hi"],
[1.3, long(1)]])] # noqa
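# Deeply nested containers used to stress-test the serializer.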
COMPLEX_OBJECTS = [
[[[[[[[[[[[[]]]]]]]]]]]],
{"obj{}".format(i): np.random.normal(size=[4, 4]) for i in range(5)},
# {(): {(): {(): {(): {(): {(): {(): {(): {(): {(): {
# (): {(): {}}}}}}}}}}}}},
((((((((((),),),),),),),),),),
{"a": {"b": {"c": {"d": {}}}}},
]
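# Custom classes exercising SerializationContext.register_type, including a
# subclass registered with pickle=True (SubQuxPickle).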
class Foo(object):
def __init__(self, value=0):
self.value = value
def __hash__(self):
return hash(self.value)
def __eq__(self, other):
return other.value == self.value
class Bar(object):
def __init__(self):
for i, val in enumerate(COMPLEX_OBJECTS):
setattr(self, "field{}".format(i), val)
class Baz(object):
def __init__(self):
self.foo = Foo()
self.bar = Bar()
def method(self, arg):
pass
class Qux(object):
def __init__(self):
self.objs = [Foo(1), Foo(42)]
class SubQux(Qux):
def __init__(self):
Qux.__init__(self)
class SubQuxPickle(Qux):
def __init__(self):
Qux.__init__(self)
class CustomError(Exception):
pass
Point = collections.namedtuple("Point", ["x", "y"])
NamedTupleExample = collections.namedtuple(
"Example", "field1, field2, field3, field4, field5")
CUSTOM_OBJECTS = [Exception("Test object."), CustomError(), Point(11, y=22),
Foo(), Bar(), Baz(), Qux(), SubQux(), SubQuxPickle(),
NamedTupleExample(1, 1.0, "hi", np.zeros([3, 5]), [1, 2, 3]),
collections.OrderedDict([("hello", 1), ("world", 2)]),
collections.deque([1, 2, 3, "a", "b", "c", 3.5]),
collections.Counter([1, 1, 1, 2, 2, 3, "a", "b"])]
def make_serialization_context():
context = pa.default_serialization_context()
context.register_type(Foo, "Foo")
context.register_type(Bar, "Bar")
context.register_type(Baz, "Baz")
    context.register_type(Qux, "Qux")
context.register_type(SubQux, "SubQux")
context.register_type(SubQuxPickle, "SubQuxPickle", pickle=True)
context.register_type(Exception, "Exception")
context.register_type(CustomError, "CustomError")
context.register_type(Point, "Point")
context.register_type(NamedTupleExample, "NamedTupleExample")
return context
global_serialization_context = make_serialization_context()
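# Round-trip helpers: write `value` into a fixed-size scratch buffer with
# serialize_to, read it back with deserialize_from, and compare the result
# with assert_equal. A minimal standalone round trip with the shared context
# would look like:
#   buf = pa.serialize(Foo(1), context=global_serialization_context).to_buffer()
#   assert pa.deserialize(buf, context=global_serialization_context) == Foo(1)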
def serialization_roundtrip(value, scratch_buffer,
context=global_serialization_context):
writer = pa.FixedSizeBufferWriter(scratch_buffer)
pa.serialize_to(value, writer, context=context)
reader = pa.BufferReader(scratch_buffer)
result = pa.deserialize_from(reader, None, context=context)
assert_equal(value, result)
_check_component_roundtrip(value, context=context)
def _check_component_roundtrip(value, context=global_serialization_context):
# Test to/from components
serialized = pa.serialize(value, context=context)
components = serialized.to_components()
from_comp = pa.SerializedPyObject.from_components(components)
recons = from_comp.deserialize(context=context)
assert_equal(value, recons)
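# Session-scoped 32 MiB scratch buffer reused across the roundtrip tests.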
@pytest.yield_fixture(scope='session')
def large_buffer(size=32*1024*1024):
return pa.allocate_buffer(size)
def large_memory_map(tmpdir_factory, size=100*1024*1024):
path = (tmpdir_factory.mktemp('data')
.join('pyarrow-serialization-tmp-file').strpath)
# Create a large memory mapped file
with open(path, 'wb') as f:
f.write(np.random.randint(0, 256, size=size)
.astype('u1')
.tobytes()
[:size])
return path
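# Cloning a SerializationContext should preserve custom (de)serializers.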
def test_clone():
context = pa.SerializationContext()
class Foo(object):
pass
def custom_serializer(obj):
return 0
def custom_deserializer(serialized_obj):
return (serialized_obj, 'a')
context.register_type(Foo, 'Foo', custom_serializer=custom_serializer,
custom_deserializer=custom_deserializer)
new_context = context.clone()
f = Foo()
serialized = pa.serialize(f, context=context)
deserialized = serialized.deserialize(context=context)
assert deserialized == (0, 'a')
serialized = pa.serialize(f, context=new_context)
deserialized = serialized.deserialize(context=new_context)
assert deserialized == (0, 'a')
def test_primitive_serialization(large_buffer):
for obj in PRIMITIVE_OBJECTS:
serialization_roundtrip(obj, large_buffer)
def test_integer_limits(large_buffer):
# Check that Numpy scalars can be represented up to their limit values
# (except np.uint64 which is limited to 2**63 - 1)
    for dt in [np.int8, np.int16, np.int32, np.int64,
               np.uint8, np.uint16, np.uint32, np.uint64]:
scal = dt(np.iinfo(dt).min)
serialization_roundtrip(scal, large_buffer)
if dt is not np.uint64:
scal = dt(np.iinfo(dt).max)
serialization_roundtrip(scal, large_buffer)
else:
scal = dt(2**63 - 1)
serialization_roundtrip(scal, large_buffer)
for v in (2**63, 2**64 - 1):
scal = dt(v)
with pytest.raises(pa.ArrowInvalid):
pa.serialize(scal)
def test_serialize_to_buffer():
for nthreads in [1, 4]:
for value in COMPLEX_OBJECTS:
buf = pa.serialize(value).to_buffer(nthreads=nthreads)
result = pa.deserialize(buf)
assert_equal(value, result)
def test_complex_serialization(large_buffer):
for obj in COMPLEX_OBJECTS:
serialization_roundtrip(obj, large_buffer)
def test_custom_serialization(large_buffer):
for obj in CUSTOM_OBJECTS:
serialization_roundtrip(obj, large_buffer)
def test_default_dict_serialization(large_buffer):
pytest.importorskip("cloudpickle")
obj = collections.defaultdict(lambda: 0, [("hello", 1), ("world", 2)])
serialization_roundtrip(obj, large_buffer)
def test_numpy_serialization(large_buffer):
for t in ["bool", "int8", "uint8", "int16", "uint16", "int32",
"uint32", "float16", "float32", "float64", "<U1", "<U2", "<U3",
"<U4", "|S1", "|S2", "|S3", "|S4", "|O"]:
obj = np.random.randint(0, 10, size=(100, 100)).astype(t)
serialization_roundtrip(obj, large_buffer)
obj = obj[1:99, 10:90]
serialization_roundtrip(obj, large_buffer)
def test_datetime_serialization(large_buffer):
data = [
# Principia Mathematica published
datetime.datetime(year=1687, month=7, day=5),
# Some random date
datetime.datetime(year=1911, month=6, day=3, hour=4,
minute=55, second=44),
# End of WWI
datetime.datetime(year=1918, month=11, day=11),
# Beginning of UNIX time
datetime.datetime(year=1970, month=1, day=1),
# The Berlin wall falls
datetime.datetime(year=1989, month=11, day=9),
# Another random date
datetime.datetime(year=2011, month=6, day=3, hour=4,
minute=0, second=3),
# Another random date
datetime.datetime(year=1970, month=1, day=3, hour=4,
minute=0, second=0)
]
for d in data:
serialization_roundtrip(d, large_buffer)
def test_torch_serialization(large_buffer):
pytest.importorskip("torch")
serialization_context = pa.default_serialization_context()
pa.register_torch_serialization_handlers(serialization_context)
# Dense tensors:
# These are the only types that are supported for the
# PyTorch to NumPy conversion
for t in ["float32", "float64",
"uint8", "int16", "int32", "int64"]:
obj = torch.from_numpy(np.random.randn(1000).astype(t))
serialization_roundtrip(obj, large_buffer,
context=serialization_context)
tensor_requiring_grad = torch.randn(10, 10, requires_grad=True)
serialization_roundtrip(tensor_requiring_grad, large_buffer,
context=serialization_context)
# Sparse tensors:
# These are the only types that are supported for the
# PyTorch to NumPy conversion
for t in ["float32", "float64",
"uint8", "int16", "int32", "int64"]:
i = torch.LongTensor([[0, 2], [1, 0], [1, 2]])
v = torch.from_numpy(np.array([3, 4, 5]).astype(t))
obj = torch.sparse_coo_tensor(i.t(), v, torch.Size([2, 3]))
serialization_roundtrip(obj, large_buffer,
context=serialization_context)
@pytest.mark.skipif(not torch or not torch.cuda.is_available(),
reason="requires pytorch with CUDA")
def test_torch_cuda():
# ARROW-2920: This used to segfault if torch is not imported
# before pyarrow
# Note that this test will only catch the issue if it is run
# with a pyarrow that has been built in the manylinux1 environment
torch.nn.Conv2d(64, 2, kernel_size=3, stride=1,
padding=1, bias=False).cuda()
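# Deserialized NumPy arrays are backed by a read-only Arrow buffer, so
# in-place modification should raise.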
def test_numpy_immutable(large_buffer):
obj = np.zeros([10])
writer = pa.FixedSizeBufferWriter(large_buffer)
pa.serialize_to(obj, writer, global_serialization_context)
reader = pa.BufferReader(large_buffer)
result = pa.deserialize_from(reader, None, global_serialization_context)
with pytest.raises(ValueError):
result[0] = 1.0
def test_numpy_base_object(tmpdir):
# ARROW-2040: deserialized Numpy array should keep a reference to the
# owner of its memory
path = os.path.join(str(tmpdir), 'zzz.bin')
data = np.arange(12, dtype=np.int32)
with open(path, 'wb') as f:
f.write(pa.serialize(data).to_buffer())
serialized = pa.read_serialized(pa.OSFile(path))
result = serialized.deserialize()
assert_equal(result, data)
serialized = None
assert_equal(result, data)
assert result.base is not None
# see https://issues.apache.org/jira/browse/ARROW-1695
def test_serialization_callback_numpy():
class DummyClass(object):
pass
def serialize_dummy_class(obj):
x = np.zeros(4)
return x
def deserialize_dummy_class(serialized_obj):
return serialized_obj
context = pa.default_serialization_context()
context.register_type(DummyClass, "DummyClass",
custom_serializer=serialize_dummy_class,
custom_deserializer=deserialize_dummy_class)
pa.serialize(DummyClass(), context=context)
def test_numpy_subclass_serialization():
# Check that we can properly serialize subclasses of np.ndarray.
class CustomNDArray(np.ndarray):
def __new__(cls, input_array):
array = np.asarray(input_array).view(cls)
return array
def serializer(obj):
return {'numpy': obj.view(np.ndarray)}
def deserializer(data):
array = data['numpy'].view(CustomNDArray)
return array
context = pa.default_serialization_context()
context.register_type(CustomNDArray, 'CustomNDArray',
custom_serializer=serializer,
custom_deserializer=deserializer)
x = CustomNDArray(np.zeros(3))
serialized = pa.serialize(x, context=context).to_buffer()
new_x = pa.deserialize(serialized, context=context)
assert type(new_x) == CustomNDArray
assert np.alltrue(new_x.view(np.ndarray) == np.zeros(3))
def test_numpy_matrix_serialization(tmpdir):
class CustomType(object):
def __init__(self, val):
self.val = val
path = os.path.join(str(tmpdir), 'pyarrow_npmatrix_serialization_test.bin')
array = np.random.randint(low=-1, high=1, size=(2, 2))
for data_type in [str, int, float, CustomType]:
matrix = np.matrix(array.astype(data_type))
with open(path, 'wb') as f:
f.write(pa.serialize(matrix).to_buffer())
serialized = pa.read_serialized(pa.OSFile(path))
result = serialized.deserialize()
assert_equal(result, matrix)
assert_equal(result.dtype, matrix.dtype)
serialized = None
assert_equal(result, matrix)
assert result.base is not None
def test_pyarrow_objects_serialization(large_buffer):
    # NOTE: We have to construct these objects inside the test function,
    # or they would affect 'test_total_bytes_allocated'.
pyarrow_objects = [
pa.array([1, 2, 3, 4]), pa.array(['1', u'never U+1F631', '',
u"233 * U+1F600"]),
pa.array([1, None, 2, 3]),
pa.Tensor.from_numpy(np.random.rand(2, 3, 4)),
pa.RecordBatch.from_arrays(
[pa.array([1, None, 2, 3]),
pa.array(['1', u'never U+1F631', '', u"233 * u1F600"])],
['a', 'b']),
pa.Table.from_arrays([pa.array([1, None, 2, 3]),
pa.array(['1', u'never U+1F631', '',
u"233 * u1F600"])],
['a', 'b'])
]
for obj in pyarrow_objects:
serialization_roundtrip(obj, large_buffer)
def test_buffer_serialization():
class BufferClass(object):
pass
def serialize_buffer_class(obj):
return pa.py_buffer(b"hello")
def deserialize_buffer_class(serialized_obj):
return serialized_obj
context = pa.default_serialization_context()
context.register_type(
BufferClass, "BufferClass",
custom_serializer=serialize_buffer_class,
custom_deserializer=deserialize_buffer_class)
b = pa.serialize(BufferClass(), context=context).to_buffer()
assert pa.deserialize(b, context=context).to_pybytes() == b"hello"
@pytest.mark.skip(reason="extensive memory requirements")
def test_arrow_limits(tmpdir_factory):
    def huge_memory_map(temp_dir):
        return large_memory_map(temp_dir, 100 * 1024 * 1024 * 1024)
    with pa.memory_map(huge_memory_map(tmpdir_factory), mode="r+") as mmap:
# Test that objects that are too large for Arrow throw a Python
# exception. These tests give out of memory errors on Travis and need
# to be run on a machine with lots of RAM.
x = 2 ** 29 * [1.0]
serialization_roundtrip(x, mmap)
del x
x = 2 ** 29 * ["s"]
serialization_roundtrip(x, mmap)
del x
x = 2 ** 29 * [["1"], 2, 3, [{"s": 4}]]
serialization_roundtrip(x, mmap)
del x
x = 2 ** 29 * [{"s": 1}] + 2 ** 29 * [1.0]
serialization_roundtrip(x, mmap)
del x
x = np.zeros(2 ** 25)
serialization_roundtrip(x, mmap)
del x
x = [np.zeros(2 ** 18) for _ in range(2 ** 7)]
serialization_roundtrip(x, mmap)
del x
def test_serialization_callback_error():
class TempClass(object):
pass
# Pass a SerializationContext into serialize, but TempClass
# is not registered
serialization_context = pa.SerializationContext()
val = TempClass()
with pytest.raises(pa.SerializationCallbackError) as err:
serialized_object = pa.serialize(val, serialization_context)
assert err.value.example_object == val
serialization_context.register_type(TempClass, "TempClass")
serialized_object = pa.serialize(TempClass(), serialization_context)
deserialization_context = pa.SerializationContext()
# Pass a Serialization Context into deserialize, but TempClass
# is not registered
with pytest.raises(pa.DeserializationCallbackError) as err:
serialized_object.deserialize(deserialization_context)
assert err.value.type_id == "TempClass"
class TempClass2(object):
pass
# Make sure that we receive an error when we use an inappropriate value for
# the type_id argument.
with pytest.raises(TypeError):
serialization_context.register_type(TempClass2, 1)
def test_fallback_to_subclasses():
class SubFoo(Foo):
def __init__(self):
Foo.__init__(self)
# should be able to serialize/deserialize an instance
# if a base class has been registered
serialization_context = pa.SerializationContext()
serialization_context.register_type(Foo, "Foo")
subfoo = SubFoo()
    # should fall back to the Foo serializer
serialized_object = pa.serialize(subfoo, serialization_context)
reconstructed_object = serialized_object.deserialize(
serialization_context
)
assert type(reconstructed_object) == Foo
class Serializable(object):
pass
def serialize_serializable(obj):
return {"type": type(obj), "data": obj.__dict__}
def deserialize_serializable(obj):
val = obj["type"].__new__(obj["type"])
val.__dict__.update(obj["data"])
return val
class SerializableClass(Serializable):
def __init__(self):
self.value = 3
def test_serialize_subclasses():
# This test shows how subclasses can be handled in an idiomatic way
# by having only a serializer for the base class
# This technique should however be used with care, since pickling
    # type(obj) with cloudpickle will include the full class definition
# in the serialized representation.
# This means the class definition is part of every instance of the
# object, which in general is not desirable; registering all subclasses
# with register_type will result in faster and more memory
# efficient serialization.
context = pa.default_serialization_context()
context.register_type(
Serializable, "Serializable",
custom_serializer=serialize_serializable,
custom_deserializer=deserialize_serializable)
a = SerializableClass()
serialized = pa.serialize(a, context=context)
deserialized = serialized.deserialize(context=context)
assert type(deserialized).__name__ == SerializableClass.__name__
assert deserialized.value == 3
def test_serialize_to_components_invalid_cases():
buf = pa.py_buffer(b'hello')
components = {
'num_tensors': 0,
'num_ndarrays': 0,
'num_buffers': 1,
'data': [buf]
}
with pytest.raises(pa.ArrowInvalid):
pa.deserialize_components(components)
components = {
'num_tensors': 0,
'num_ndarrays': 1,
'num_buffers': 0,
'data': [buf, buf]
}
with pytest.raises(pa.ArrowInvalid):
pa.deserialize_components(components)
def test_deserialize_components_in_different_process():
arr = pa.array([1, 2, 5, 6], type=pa.int8())
ser = pa.serialize(arr)
data = pickle.dumps(ser.to_components(), protocol=-1)
code = """if 1:
import pickle
import pyarrow as pa
data = {0!r}
components = pickle.loads(data)
arr = pa.deserialize_components(components)
assert arr.to_pylist() == [1, 2, 5, 6], arr
""".format(data)
subprocess_env = test_util.get_modified_env_with_pythonpath()
print("** sys.path =", sys.path)
print("** setting PYTHONPATH to:", subprocess_env['PYTHONPATH'])
subprocess.check_call(["python", "-c", code], env=subprocess_env)
def test_serialize_read_concatenated_records():
# ARROW-1996 -- see stream alignment work in ARROW-2840, ARROW-3212
f = pa.BufferOutputStream()
pa.serialize_to(12, f)
pa.serialize_to(23, f)
buf = f.getvalue()
f = pa.BufferReader(buf)
pa.read_serialized(f).deserialize()
pa.read_serialized(f).deserialize()
@pytest.mark.skipif(os.name == 'nt', reason="deserialize_regex not pickleable")
def test_deserialize_in_different_process():
from multiprocessing import Process, Queue
import re
regex = re.compile(r"\d+\.\d*")
serialization_context = pa.SerializationContext()
serialization_context.register_type(type(regex), "Regex", pickle=True)
serialized = pa.serialize(regex, serialization_context)
serialized_bytes = serialized.to_buffer().to_pybytes()
def deserialize_regex(serialized, q):
import pyarrow as pa
q.put(pa.deserialize(serialized))
q = Queue()
p = Process(target=deserialize_regex, args=(serialized_bytes, q))
p.start()
assert q.get().pattern == regex.pattern
p.join()
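# Round-trip a pyarrow Buffer through a separate Python process; the child
# side of this test lives in the helper script deserialize_buffer.py.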
def test_deserialize_buffer_in_different_process():
import tempfile
f = tempfile.NamedTemporaryFile(delete=False)
b = pa.serialize(pa.py_buffer(b'hello')).to_buffer()
f.write(b.to_pybytes())
f.close()
subprocess_env = test_util.get_modified_env_with_pythonpath()
dir_path = os.path.dirname(os.path.realpath(__file__))
python_file = os.path.join(dir_path, 'deserialize_buffer.py')
subprocess.check_call([sys.executable, python_file, f.name],
env=subprocess_env)
def test_set_pickle():
# Use a custom type to trigger pickling.
class Foo(object):
pass
context = pa.SerializationContext()
context.register_type(Foo, 'Foo', pickle=True)
test_object = Foo()
# Define a custom serializer and deserializer to use in place of pickle.
def dumps1(obj):
return b'custom'
def loads1(serialized_obj):
return serialized_obj + b' serialization 1'
# Test that setting a custom pickler changes the behavior.
context.set_pickle(dumps1, loads1)
serialized = pa.serialize(test_object, context=context).to_buffer()
deserialized = pa.deserialize(serialized.to_pybytes(), context=context)
assert deserialized == b'custom serialization 1'
# Define another custom serializer and deserializer.
def dumps2(obj):
return b'custom'
def loads2(serialized_obj):
return serialized_obj + b' serialization 2'
# Test that setting another custom pickler changes the behavior again.
context.set_pickle(dumps2, loads2)
serialized = pa.serialize(test_object, context=context).to_buffer()
deserialized = pa.deserialize(serialized.to_pybytes(), context=context)
assert deserialized == b'custom serialization 2'
@pytest.mark.skipif(sys.version_info < (3, 6), reason="need Python 3.6")
def test_path_objects(tmpdir):
# Test compatibility with PEP 519 path-like objects
import pathlib
p = pathlib.Path(tmpdir) / 'zzz.bin'
obj = 1234
pa.serialize_to(obj, p)
res = pa.deserialize_from(p, None)
assert res == obj
def test_tensor_alignment():
# Deserialized numpy arrays should be 64-byte aligned.
x = np.random.normal(size=(10, 20, 30))
y = pa.deserialize(pa.serialize(x).to_buffer())
assert y.ctypes.data % 64 == 0
xs = [np.random.normal(size=i) for i in range(100)]
ys = pa.deserialize(pa.serialize(xs).to_buffer())
for y in ys:
assert y.ctypes.data % 64 == 0
xs = [np.random.normal(size=i * (1,)) for i in range(20)]
ys = pa.deserialize(pa.serialize(xs).to_buffer())
for y in ys:
assert y.ctypes.data % 64 == 0
xs = [np.random.normal(size=i * (5,)) for i in range(1, 8)]
xs = [xs[i][(i + 1) * (slice(1, 3),)] for i in range(len(xs))]
ys = pa.deserialize(pa.serialize(xs).to_buffer())
for y in ys:
assert y.ctypes.data % 64 == 0
def test_serialization_determinism():
for obj in COMPLEX_OBJECTS:
buf1 = pa.serialize(obj).to_buffer()
buf2 = pa.serialize(obj).to_buffer()
assert buf1.to_pybytes() == buf2.to_pybytes()
def test_serialize_recursive_objects():
class ClassA(object):
pass
# Make a list that contains itself.
lst = []
lst.append(lst)
# Make an object that contains itself as a field.
a1 = ClassA()
a1.field = a1
# Make two objects that contain each other as fields.
a2 = ClassA()
a3 = ClassA()
a2.field = a3
a3.field = a2
# Make a dictionary that contains itself.
d1 = {}
d1["key"] = d1
# Make a numpy array that contains itself.
arr = np.array([None], dtype=object)
arr[0] = arr
# Create a list of recursive objects.
recursive_objects = [lst, a1, a2, a3, d1, arr]
# Check that exceptions are thrown when we serialize the recursive
# objects.
for obj in recursive_objects:
with pytest.raises(Exception):
pa.serialize(obj).deserialize()