blob: b19082ba323c1f43ea2bbd36d47b7d1f1faf20ba [file] [log] [blame]
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
import pickle
import pytest
from pyfory import Fory
try:
import numpy as np
except ImportError:
np = None
try:
import pandas as pd
except ImportError:
pd = None
def test_pickle_buffer_serialization():
    """A ``pickle.PickleBuffer`` survives a Fory round-trip.

    The result must still be a ``PickleBuffer`` and expose the original bytes
    through ``raw()``.
    """
    payload = b"Hello, PickleBuffer!"
    serializer = Fory(xlang=False, ref=False, strict=False)
    restored = serializer.deserialize(serializer.serialize(pickle.PickleBuffer(payload)))
    assert isinstance(restored, pickle.PickleBuffer)
    assert bytes(restored.raw()) == payload
@pytest.mark.skipif(np is None, reason="Requires numpy")
def test_numpy_out_of_band_serialization():
    """Round-trip a 1-D float64 array through ``buffer_callback``.

    The buffer objects collected during ``serialize`` are turned into
    memoryviews and handed back to ``deserialize``.
    """
    fory = Fory(xlang=False, ref=False, strict=False)
    original = np.arange(10000, dtype=np.float64)
    collected = []
    payload = fory.serialize(original, buffer_callback=collected.append)
    restored = fory.deserialize(payload, buffers=[obj.getbuffer() for obj in collected])
    np.testing.assert_array_equal(original, restored)
@pytest.mark.skipif(pd is None or np is None, reason="Requires numpy and pandas")
def test_pandas_out_of_band_serialization():
    """Round-trip a mixed-dtype DataFrame with out-of-band buffer collection.

    The frame's numeric columns are built with ``np.arange``, so numpy is
    required as well as pandas. The skip condition checks both, matching the
    other combined numpy/pandas test in this module.
    """
    fory = Fory(xlang=False, ref=False, strict=False)
    df = pd.DataFrame(
        {
            "a": np.arange(1000, dtype=np.float64),
            "b": np.arange(1000, dtype=np.int64),
            "c": ["text"] * 1000,
        }
    )
    buffer_objects = []
    serialized = fory.serialize(df, buffer_callback=buffer_objects.append)
    # Every collected buffer object is supplied back to deserialize as a memoryview.
    buffers = [o.getbuffer() for o in buffer_objects]
    deserialized = fory.deserialize(serialized, buffers=buffers)
    pd.testing.assert_frame_equal(df, deserialized)
@pytest.mark.skipif(np is None, reason="Requires numpy")
def test_numpy_multiple_arrays_out_of_band():
    """Round-trip a list of three arrays of different dtypes and lengths.

    Reference tracking is enabled, and all buffers are collected out-of-band.
    """
    fory = Fory(xlang=False, ref=True, strict=False)
    originals = [
        np.arange(5000, dtype=np.float32),
        np.arange(3000, dtype=np.int32),
        np.arange(2000, dtype=np.float64),
    ]
    collected = []
    payload = fory.serialize(originals, buffer_callback=collected.append)
    restored = fory.deserialize(payload, buffers=[obj.getbuffer() for obj in collected])
    assert len(restored) == 3
    # Element-wise comparison, in list order.
    for expected, actual in zip(originals, restored):
        np.testing.assert_array_equal(expected, actual)
@pytest.mark.skipif(np is None, reason="Requires numpy")
def test_numpy_with_mixed_types():
    """Round-trip a dict mixing an ndarray with a str and an int.

    The array goes through the buffer callback. The scalar values must come
    back unchanged.
    """
    fory = Fory(xlang=False, ref=True, strict=False)
    arr = np.arange(1000, dtype=np.float64)
    payload = {"array": arr, "text": "some text", "number": 42}
    collected = []
    blob = fory.serialize(payload, buffer_callback=collected.append)
    result = fory.deserialize(blob, buffers=[obj.getbuffer() for obj in collected])
    assert result["text"] == payload["text"]
    assert result["number"] == payload["number"]
    np.testing.assert_array_equal(arr, result["array"])
@pytest.mark.skipif(pd is None or np is None, reason="Requires numpy and pandas")
def test_mixed_numpy_pandas_out_of_band():
    """Round-trip a dict holding both a bare ndarray and a DataFrame.

    All buffers are collected out-of-band.
    """
    fory = Fory(xlang=False, ref=True, strict=False)
    arr = np.arange(500, dtype=np.float64)
    frame = pd.DataFrame(
        {
            "x": np.arange(500, dtype=np.int64),
            "y": np.arange(500, dtype=np.float32),
        }
    )
    collected = []
    blob = fory.serialize({"array": arr, "dataframe": frame}, buffer_callback=collected.append)
    result = fory.deserialize(blob, buffers=[obj.getbuffer() for obj in collected])
    np.testing.assert_array_equal(arr, result["array"])
    pd.testing.assert_frame_equal(frame, result["dataframe"])
@pytest.mark.skipif(np is None, reason="Requires numpy")
def test_selective_out_of_band_serialization():
    """Only some buffers go out-of-band when the callback is selective.

    The callback keeps every second buffer object it receives and returns
    ``False`` for it. Those kept objects are the only buffers supplied to
    ``deserialize``. Every other buffer gets ``True`` and is not collected,
    so it must travel in-band. This mirrors pickle protocol 5's
    ``buffer_callback`` convention.
    """
    fory = Fory(xlang=False, ref=True, strict=False)
    arr1 = np.arange(1000, dtype=np.float64)
    arr2 = np.arange(1000, dtype=np.float64)
    data = [arr1, arr2]
    buffer_objects = []
    counter = 0

    def selective_buffer_callback(buffer_object):
        """Collect even-numbered buffers (out-of-band); leave the rest in-band."""
        nonlocal counter
        counter += 1
        if counter % 2 == 0:
            buffer_objects.append(buffer_object)
            return False
        return True

    serialized = fory.serialize(data, buffer_callback=selective_buffer_callback)
    # Without this check the test passes vacuously when the callback runs
    # fewer than twice: then no buffer is taken out-of-band, and nothing
    # "selective" has been exercised.
    assert 0 < len(buffer_objects) < counter, (
        f"expected a mix of in-band and out-of-band buffers, got "
        f"{len(buffer_objects)} out-of-band of {counter} callback call(s)"
    )
    buffers = [o.getbuffer() for o in buffer_objects]
    deserialized = fory.deserialize(serialized, buffers=buffers)
    assert len(deserialized) == 2
    np.testing.assert_array_equal(arr1, deserialized[0])
    np.testing.assert_array_equal(arr2, deserialized[1])
@pytest.mark.skipif(np is None, reason="Requires numpy")
def test_buffer_object_write_to_stream():
    """Check ``NDArrayBufferObject.write_to()`` and ``getbuffer()`` sizes.

    Each collected buffer object must write exactly ``total_bytes()`` bytes
    into an ``io.BytesIO`` stream. Its ``getbuffer()`` memoryview must report
    the same ``nbytes``. The collected buffers must also deserialize back to
    the original array.
    """
    import io

    from pyfory.serializer import NDArrayBufferObject

    fory = Fory(xlang=False, ref=False, strict=False)
    matrix = np.arange(100).reshape(10, 10).astype(np.float64)
    collected = []
    blob = fory.serialize(matrix, buffer_callback=collected.append)
    assert len(collected) > 0, "Should have collected out-of-band buffers"

    for obj in collected:
        assert isinstance(obj, NDArrayBufferObject), f"Expected NDArrayBufferObject, got {type(obj)}"

    # write_to() output size must match total_bytes().
    for obj in collected:
        sink = io.BytesIO()
        obj.write_to(sink)
        assert len(sink.getvalue()) == obj.total_bytes()

    # getbuffer() must expose the same byte count as a memoryview.
    for obj in collected:
        view = obj.getbuffer()
        assert isinstance(view, memoryview)
        assert view.nbytes == obj.total_bytes()

    restored = fory.deserialize(blob, buffers=[obj.getbuffer() for obj in collected])
    np.testing.assert_array_equal(matrix, restored)
@pytest.mark.skipif(np is None, reason="Requires numpy")
def test_multidimensional_numpy_array_out_of_band():
    """Out-of-band round-trip preserves shape and contents of 2-D, 3-D and 4-D arrays."""
    from pyfory.serializer import NDArrayBufferObject

    fory = Fory(xlang=False, ref=False, strict=False)
    originals = [
        np.arange(100).reshape(10, 10).astype(np.float64),
        np.arange(1000).reshape(10, 10, 10).astype(np.int64),
        np.arange(256).reshape(4, 4, 4, 4).astype(np.float32),
    ]
    expected_shapes = [(10, 10), (10, 10, 10), (4, 4, 4, 4)]

    collected = []
    blob = fory.serialize(originals, buffer_callback=collected.append)
    assert len(collected) > 0, "Should have collected out-of-band buffers"
    for obj in collected:
        assert isinstance(obj, NDArrayBufferObject), f"Expected NDArrayBufferObject, got {type(obj)}"
        view = obj.getbuffer()
        assert isinstance(view, memoryview), "getbuffer() should return memoryview"
        assert len(view) > 0, "Buffer should contain data"

    restored = fory.deserialize(blob, buffers=[obj.getbuffer() for obj in collected])
    assert len(restored) == 3
    # Check every shape first, then the contents, in list order.
    for shape, actual in zip(expected_shapes, restored):
        assert actual.shape == shape
    for expected, actual in zip(originals, restored):
        np.testing.assert_array_equal(expected, actual)
@pytest.mark.skipif(np is None, reason="Requires numpy")
def test_numpy_array_different_dtypes_out_of_band():
    """Out-of-band round-trip preserves both values and dtype.

    Covers 10x10 arrays of several numeric dtypes plus bool.
    """
    from pyfory.serializer import NDArrayBufferObject

    fory = Fory(xlang=False, ref=False, strict=False)
    grid = np.arange(100).reshape(10, 10)
    numeric = [
        ("float32", np.float32),
        ("float64", np.float64),
        ("int8", np.int8),
        ("int16", np.int16),
        ("int32", np.int32),
        ("int64", np.int64),
        ("uint8", np.uint8),
    ]
    # astype() copies by default, so each entry is an independent array.
    arrays = {name: grid.astype(dtype) for name, dtype in numeric}
    arrays["bool"] = np.array([True, False] * 50, dtype=np.bool_).reshape(10, 10)

    collected = []
    blob = fory.serialize(arrays, buffer_callback=collected.append)
    assert len(collected) > 0, "Should have collected out-of-band buffers"
    for obj in collected:
        assert isinstance(obj, NDArrayBufferObject), f"Expected NDArrayBufferObject, got {type(obj)}"
        assert isinstance(obj.getbuffer(), memoryview), "getbuffer() should return memoryview"

    restored = fory.deserialize(blob, buffers=[obj.getbuffer() for obj in collected])
    for key, expected in arrays.items():
        actual = restored[key]
        np.testing.assert_array_equal(expected, actual)
        assert expected.dtype == actual.dtype, f"dtype mismatch for {key}: {expected.dtype} != {actual.dtype}"
@pytest.mark.skipif(np is None, reason="Requires numpy")
def test_large_numpy_arrays_verify_buffer_collection():
    """Large multi-dimensional arrays go through the buffer callback.

    Each collected buffer object must be an ``NDArrayBufferObject`` whose
    memoryview is non-empty, and the arrays must round-trip intact.
    """
    from pyfory.serializer import NDArrayBufferObject

    fory = Fory(xlang=False, ref=False, strict=False)
    originals = [
        np.arange(10000).reshape(100, 100).astype(np.int64),
        np.arange(27000).reshape(30, 30, 30).astype(np.float32),
        np.arange(10000).reshape(10, 10, 10, 10).astype(np.float64),
    ]

    collected = []
    blob = fory.serialize(originals, buffer_callback=collected.append)
    assert len(collected) > 0, "Should have collected out-of-band buffers"
    for idx, obj in enumerate(collected):
        assert isinstance(obj, NDArrayBufferObject), f"Buffer {idx}: Expected NDArrayBufferObject, got {type(obj)}"
        view = obj.getbuffer()
        assert isinstance(view, memoryview), "getbuffer() should return memoryview"
        assert len(view) > 0, f"Buffer {idx} should contain data"

    restored = fory.deserialize(blob, buffers=[obj.getbuffer() for obj in collected])
    assert len(restored) == 3
    for expected, actual in zip(originals, restored):
        np.testing.assert_array_equal(expected, actual)