blob: 25a70014fa993852602640b034696cb700f30891 [file]
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
import array
import datetime
import logging
from abc import ABC, abstractmethod
from typing import Dict, Iterable, Any
from pyfury.buffer import Buffer
from pyfury.resolver import NOT_NULL_VALUE_FLAG, NULL_FLAG
from pyfury.type import (
FuryType,
is_primitive_type,
# Int8ArrayType,
Int16ArrayType,
Int32ArrayType,
Int64ArrayType,
Float32ArrayType,
Float64ArrayType,
)
try:
import numpy as np
except ImportError:
np = None
logger = logging.getLogger(__name__)

# Cross-language type id returned by serializers that do not support
# cross-language serialization (see `Serializer.get_xtype_id`).
NOT_SUPPORT_CROSS_LANGUAGE = 0

# NOTE(review): presumably markers for how a class is identified in the
# serialized data (by name vs. by numeric id) - confirm against the resolver.
USE_CLASSNAME = 0
USE_CLASS_ID = 1

# Class id 0 is reserved as the "class id not set" flag in `ClassInfo`.
NO_CLASS_ID = 0
PYINT_CLASS_ID = 1
PYFLOAT_CLASS_ID = 2
PYBOOL_CLASS_ID = 3
STRING_CLASS_ID = 4
PICKLE_CLASS_ID = 5
PICKLE_STRONG_CACHE_CLASS_ID = 6
PICKLE_CACHE_CLASS_ID = 7

# `NOT_NULL_VALUE_FLAG` + `CLASS_ID` in little-endian order: the low byte
# holds the flag and the class id sits in the bits above it, so the pair can
# be emitted with a single `write_int24` call.
NOT_NULL_PYINT_FLAG = (NOT_NULL_VALUE_FLAG & 0xFF) | (PYINT_CLASS_ID << 8)
NOT_NULL_PYFLOAT_FLAG = (NOT_NULL_VALUE_FLAG & 0xFF) | (PYFLOAT_CLASS_ID << 8)
NOT_NULL_PYBOOL_FLAG = (NOT_NULL_VALUE_FLAG & 0xFF) | (PYBOOL_CLASS_ID << 8)
NOT_NULL_STRING_FLAG = (NOT_NULL_VALUE_FLAG & 0xFF) | (STRING_CLASS_ID << 8)
class _PickleStub:
    """Empty marker class with no state or behaviour of its own.

    NOTE(review): presumably the type registered for `PICKLE_CLASS_ID` -
    confirm against the class resolver.
    """
class PickleStrongCacheStub:
    """Empty marker class with no state or behaviour of its own.

    NOTE(review): presumably the type registered for
    `PICKLE_STRONG_CACHE_CLASS_ID` - confirm against the class resolver.
    """
class PickleCacheStub:
    """Empty marker class with no state or behaviour of its own.

    NOTE(review): presumably the type registered for
    `PICKLE_CACHE_CLASS_ID` - confirm against the class resolver.
    """
class BufferObject(ABC):
    """
    Abstract Fury binary representation of an object.

    Note: this interface exists for zero-copy out-of-band serialization and
    shouldn't be used for any other purpose.
    """

    @abstractmethod
    def total_bytes(self) -> int:
        """Return the total size, in bytes, of the serialized object."""

    @abstractmethod
    def write_to(self, buffer: "Buffer"):
        """Write the serialized object into `buffer`."""

    @abstractmethod
    def to_buffer(self) -> "Buffer":
        """Return the serialized data as a `Buffer`."""
class BytesBufferObject(BufferObject):
    """`BufferObject` backed by an in-memory `bytes` value."""

    __slots__ = ("binary",)

    def __init__(self, binary: bytes):
        # The wrapped payload; stored as-is, no copy is made here.
        self.binary = binary

    def total_bytes(self) -> int:
        """Return the length of the wrapped payload."""
        payload = self.binary
        return len(payload)

    def write_to(self, buffer: "Buffer"):
        """Append the wrapped payload to `buffer` via `write_bytes`."""
        payload = self.binary
        buffer.write_bytes(payload)

    def to_buffer(self) -> "Buffer":
        """Wrap the payload in a new `Buffer`."""
        payload = self.binary
        return Buffer(payload)
class Serializer(ABC):
    """
    Base class of every serializer.

    `write`/`read` implement the Python-native format and raise
    `NotImplementedError` unless overridden; `xwrite`/`xread` implement the
    cross-language format and must be provided by every concrete subclass.
    """

    __slots__ = "fury", "type_", "need_to_write_ref"

    def __init__(self, fury, type_: type):
        self.fury = fury
        self.type_: type = type_
        # False for types reported as primitive by `is_primitive_type`,
        # True for everything else.
        is_primitive = is_primitive_type(type_)
        self.need_to_write_ref = not is_primitive

    def get_xtype_id(self):
        """
        Returns
        -------
            `NOT_SUPPORT_CROSS_LANGUAGE` when the serializer has no
            cross-language support.
            A number in range (0, 32767) when cross-language serialization
            is supported and the native serialized data is identical to the
            cross-language data.
            A negative short in range [-32768, 0) when cross-language
            serialization is supported but the native serialized data
            differs from the cross-language data.
        """
        return NOT_SUPPORT_CROSS_LANGUAGE

    def get_xtype_tag(self):
        """
        Returns
        -------
            a type tag used to set up the type mapping between languages.

        This base implementation always raises `RuntimeError`.
        """
        raise RuntimeError("Tag is only for struct.")

    def write(self, buffer, value):
        """Write `value` to `buffer` in the native format."""
        raise NotImplementedError

    def read(self, buffer):
        """Read a value from `buffer` in the native format."""
        raise NotImplementedError

    @abstractmethod
    def xwrite(self, buffer, value):
        """Write `value` to `buffer` in the cross-language format."""

    @abstractmethod
    def xread(self, buffer):
        """Read a value from `buffer` in the cross-language format."""

    @classmethod
    def support_subclass(cls) -> bool:
        """Return False by default; subclasses may override this."""
        return False
class CrossLanguageCompatibleSerializer(Serializer):
    """
    Serializer whose cross-language format is the same as its native one:
    `xwrite`/`xread` simply forward to `write`/`read`.
    """

    def __init__(self, fury, type_):
        super().__init__(fury, type_)

    def xwrite(self, buffer, value):
        """Forward to `write`."""
        self.write(buffer, value)

    def xread(self, buffer):
        """Forward to `read` and return its result."""
        result = self.read(buffer)
        return result
class NoneSerializer(Serializer):
    """Serializer for `None`: writes nothing and always reads `None`."""

    def write(self, buffer, value):
        """Write nothing; the buffer is left untouched."""

    def read(self, buffer):
        """Return `None` without consuming anything from `buffer`."""
        return None

    def xwrite(self, buffer, value):
        """Cross-language writing is not implemented."""
        raise NotImplementedError

    def xread(self, buffer):
        """Cross-language reading is not implemented."""
        raise NotImplementedError
class BooleanSerializer(CrossLanguageCompatibleSerializer):
    """Serializes values with `Buffer.write_bool` / `Buffer.read_bool`."""

    def get_xtype_id(self):
        """Return the id of `FuryType.BOOL`."""
        return FuryType.BOOL.value

    def write(self, buffer, value):
        """Write `value` as a bool."""
        buffer.write_bool(value)

    def read(self, buffer):
        """Read one bool from `buffer`."""
        result = buffer.read_bool()
        return result
class ByteSerializer(CrossLanguageCompatibleSerializer):
    """Serializes values with `Buffer.write_int8` / `Buffer.read_int8`."""

    def get_xtype_id(self):
        """Return the id of `FuryType.INT8`."""
        return FuryType.INT8.value

    def write(self, buffer, value):
        """Write `value` as an int8."""
        buffer.write_int8(value)

    def read(self, buffer):
        """Read one int8 from `buffer`."""
        result = buffer.read_int8()
        return result
class Int16Serializer(CrossLanguageCompatibleSerializer):
    """Serializes values with `Buffer.write_int16` / `Buffer.read_int16`."""

    def get_xtype_id(self):
        """Return the id of `FuryType.INT16`."""
        return FuryType.INT16.value

    def write(self, buffer, value):
        """Write `value` as an int16."""
        buffer.write_int16(value)

    def read(self, buffer):
        """Read one int16 from `buffer`."""
        result = buffer.read_int16()
        return result
class Int32Serializer(CrossLanguageCompatibleSerializer):
    """Serializes values with `Buffer.write_int32` / `Buffer.read_int32`."""

    def get_xtype_id(self):
        """Return the id of `FuryType.INT32`."""
        return FuryType.INT32.value

    def write(self, buffer, value):
        """Write `value` as an int32."""
        buffer.write_int32(value)

    def read(self, buffer):
        """Read one int32 from `buffer`."""
        result = buffer.read_int32()
        return result
class Int64Serializer(Serializer):
    """
    Serializer for 64-bit integers.

    The two formats differ: the native format uses
    `write_varint64`/`read_varint64`, while the cross-language format uses
    `write_int64`/`read_int64`.
    """

    def get_xtype_id(self):
        """Return the id of `FuryType.INT64`."""
        return FuryType.INT64.value

    def write(self, buffer, value):
        """Native format: write `value` as a varint64."""
        buffer.write_varint64(value)

    def read(self, buffer):
        """Native format: read one varint64."""
        result = buffer.read_varint64()
        return result

    def xwrite(self, buffer, value):
        """Cross-language format: write `value` as an int64."""
        buffer.write_int64(value)

    def xread(self, buffer):
        """Cross-language format: read one int64."""
        result = buffer.read_int64()
        return result
class FloatSerializer(CrossLanguageCompatibleSerializer):
    """Serializes values with `Buffer.write_float` / `Buffer.read_float`."""

    def get_xtype_id(self):
        """Return the id of `FuryType.FLOAT`."""
        return FuryType.FLOAT.value

    def write(self, buffer, value):
        """Write `value` with `write_float`."""
        buffer.write_float(value)

    def read(self, buffer):
        """Read one value with `read_float`."""
        result = buffer.read_float()
        return result
class DoubleSerializer(CrossLanguageCompatibleSerializer):
    """Serializes values with `Buffer.write_double` / `Buffer.read_double`."""

    def get_xtype_id(self):
        """Return the id of `FuryType.DOUBLE`."""
        return FuryType.DOUBLE.value

    def write(self, buffer, value):
        """Write `value` with `write_double`."""
        buffer.write_double(value)

    def read(self, buffer):
        """Read one value with `read_double`."""
        result = buffer.read_double()
        return result
class StringSerializer(CrossLanguageCompatibleSerializer):
    """Serializes `str` with `Buffer.write_string` / `Buffer.read_string`."""

    def get_xtype_id(self):
        """Return the id of `FuryType.STRING`."""
        return FuryType.STRING.value

    def write(self, buffer, value: str):
        """Write `value` with `write_string`."""
        buffer.write_string(value)

    def read(self, buffer):
        """Read one string with `read_string`."""
        text = buffer.read_string()
        return text
# Day zero of the encoding used by `DateSerializer`.
_base_date = datetime.date(1970, 1, 1)


class DateSerializer(CrossLanguageCompatibleSerializer):
    """Serializes `datetime.date` as an int32 day count since 1970-01-01."""

    def get_xtype_id(self):
        """Return the id of `FuryType.DATE32`."""
        return FuryType.DATE32.value

    def write(self, buffer, value: datetime.date):
        """
        Write the number of days between `value` and 1970-01-01 as an int32.

        Raises
        ------
        TypeError
            if `value` is not a `datetime.date` instance.
        """
        if not isinstance(value, datetime.date):
            raise TypeError(
                "{} should be {} instead of {}".format(
                    value, datetime.date, type(value)
                )
            )
        elapsed = value - _base_date
        buffer.write_int32(elapsed.days)

    def read(self, buffer):
        """Read an int32 day count and return the matching `datetime.date`."""
        offset = datetime.timedelta(days=buffer.read_int32())
        return _base_date + offset
class TimestampSerializer(CrossLanguageCompatibleSerializer):
    """
    Serializes `datetime.datetime` as an int64 count of microseconds since
    the Unix epoch.

    Values are always read back as naive datetimes in local time; timezone
    information is not preserved.
    """

    def get_xtype_id(self):
        """Return the id of `FuryType.TIMESTAMP`."""
        return FuryType.TIMESTAMP.value

    def write(self, buffer, value: datetime.datetime):
        """
        Write `value` as microseconds since the epoch (int64).

        Raises
        ------
        TypeError
            if `value` is not a `datetime.datetime` instance.
        """
        if not isinstance(value, datetime.datetime):
            # Report the expected class; the previous message formatted the
            # `datetime` module object instead.
            raise TypeError(
                "{} should be {} instead of {}".format(
                    value, datetime.datetime, type(value)
                )
            )
        # TimestampType represent micro seconds.
        # `value.timestamp() * 1000000` goes through a float and can be
        # truncated one microsecond too low. Take the whole seconds from the
        # microsecond-stripped value and add the microsecond field with
        # integer arithmetic instead.
        whole_seconds = int(value.replace(microsecond=0).timestamp())
        buffer.write_int64(whole_seconds * 1000000 + value.microsecond)

    def read(self, buffer):
        """Read an int64 microsecond timestamp and return a naive datetime."""
        # divmod floors, so `micros` is always in [0, 999999], also for
        # timestamps before the epoch.
        whole_seconds, micros = divmod(buffer.read_int64(), 1000000)
        # TODO support timezone
        # `replace` (unlike adding a timedelta) keeps the `fold` chosen by
        # `fromtimestamp` for ambiguous local times.
        return datetime.datetime.fromtimestamp(whole_seconds).replace(
            microsecond=micros
        )
class BytesSerializer(CrossLanguageCompatibleSerializer):
    """Serializes `bytes` through fury's buffer-object read/write hooks."""

    def get_xtype_id(self):
        """Return the id of `FuryType.BINARY`."""
        return FuryType.BINARY.value

    def write(self, buffer, value: bytes):
        """Wrap `value` in a `BytesBufferObject` and hand it to fury."""
        assert isinstance(value, bytes)
        wrapped = BytesBufferObject(value)
        self.fury.write_buffer_object(buffer, wrapped)

    def read(self, buffer):
        """Read a buffer object through fury and convert it with `to_pybytes`."""
        fury_buf = self.fury.read_buffer_object(buffer)
        payload = fury_buf.to_pybytes()
        return payload
# Use numpy array or python array module.
# Maps an `array.array` typecode to `(item size in bytes, cross-language
# type id)`.
typecode_dict = {
    # use bytes serializer for byte array.
    "h": (2, FuryType.FURY_PRIMITIVE_SHORT_ARRAY.value),
    "i": (4, FuryType.FURY_PRIMITIVE_INT_ARRAY.value),
    "l": (8, FuryType.FURY_PRIMITIVE_LONG_ARRAY.value),
    "f": (4, FuryType.FURY_PRIMITIVE_FLOAT_ARRAY.value),
    "d": (8, FuryType.FURY_PRIMITIVE_DOUBLE_ARRAY.value),
}
if np is not None:
    # When numpy is importable every type id is negated.
    # NOTE(review): presumably because the positive ids are then claimed by
    # the numpy array serializers - confirm against the registration code.
    typecode_dict = {
        code: (size, -xtype_id) for code, (size, xtype_id) in typecode_dict.items()
    }
class PyArraySerializer(CrossLanguageCompatibleSerializer):
    """
    Serializer for `array.array` values of one fixed typecode.

    The cross-language format stores only the raw bytes (the typecode is
    implied by the serializer); the native format also stores the typecode
    string in front of the bytes.
    """

    # typecode -> (item size in bytes, cross-language type id)
    typecode_dict = typecode_dict
    # typecode -> array type marker imported from `pyfury.type`
    typecodearray_type = {
        "h": Int16ArrayType,
        "i": Int32ArrayType,
        "l": Int64ArrayType,
        "f": Float32ArrayType,
        "d": Float64ArrayType,
    }

    def __init__(self, fury, type_, typecode):
        super().__init__(fury, type_)
        self.typecode = typecode
        # Raises KeyError for a typecode missing from `typecode_dict`.
        item_info = PyArraySerializer.typecode_dict[self.typecode]
        self.itemsize, self.type_id = item_info

    def get_xtype_id(self):
        """Return the type id looked up for this serializer's typecode."""
        return self.type_id

    def xwrite(self, buffer, value):
        """Write the byte length as a varint32, then the raw array data."""
        expected_size = self.itemsize
        assert value.itemsize == expected_size
        view = memoryview(value)
        assert view.format == self.typecode
        assert view.itemsize == expected_size
        assert view.c_contiguous  # TODO handle contiguous
        buffer.write_varint32(len(value) * expected_size)
        buffer.write_buffer(value)

    def xread(self, buffer):
        """Read size-prefixed bytes and load them into a new array."""
        raw = buffer.read_bytes_and_size()
        result = array.array(self.typecode, [])
        result.frombytes(raw)
        return result

    def write(self, buffer, value: array.array):
        """Write the typecode string, the byte length and the raw data."""
        byte_count = len(value) * value.itemsize
        buffer.write_string(value.typecode)
        buffer.write_varint32(byte_count)
        buffer.write_buffer(value)

    def read(self, buffer):
        """Read a typecode string and size-prefixed bytes into a new array."""
        code = buffer.read_string()
        raw = buffer.read_bytes_and_size()
        result = array.array(code, [])
        result.frombytes(raw)
        return result
# Maps a numpy dtype to `(item size in bytes, buffer format code,
# cross-language type id)`. Stays empty when numpy is not installed.
_np_dtypes_dict = {}
if np is not None:
    _np_dtypes_dict = {
        # use bytes serializer for byte array.
        np.dtype(np.bool_): (1, "?", FuryType.FURY_PRIMITIVE_BOOL_ARRAY.value),
        np.dtype(np.int16): (2, "h", FuryType.FURY_PRIMITIVE_SHORT_ARRAY.value),
        np.dtype(np.int32): (4, "i", FuryType.FURY_PRIMITIVE_INT_ARRAY.value),
        np.dtype(np.int64): (8, "l", FuryType.FURY_PRIMITIVE_LONG_ARRAY.value),
        np.dtype(np.float32): (4, "f", FuryType.FURY_PRIMITIVE_FLOAT_ARRAY.value),
        np.dtype(np.float64): (8, "d", FuryType.FURY_PRIMITIVE_DOUBLE_ARRAY.value),
    }
class Numpy1DArraySerializer(CrossLanguageCompatibleSerializer):
    """
    Serializer for numpy arrays of one fixed dtype.

    Only the cross-language format is implemented here; the native
    `write`/`read` defer to fury's handlers for unsupported types.
    """

    # dtype -> (item size in bytes, buffer format code, type id)
    dtypes_dict = _np_dtypes_dict

    def __init__(self, fury, type_, dtype):
        super().__init__(fury, type_)
        self.dtype = dtype
        # Raises KeyError for a dtype missing from `_np_dtypes_dict`.
        dtype_info = _np_dtypes_dict[self.dtype]
        self.itemsize, self.typecode, self.type_id = dtype_info

    def get_xtype_id(self):
        """Return the type id looked up for this serializer's dtype."""
        return self.type_id

    def xwrite(self, buffer, value):
        """Write the byte length as a varint32, then the array data."""
        expected_size = self.itemsize
        assert value.itemsize == expected_size
        view = memoryview(value)
        assert view.format == self.typecode
        assert view.itemsize == expected_size
        buffer.write_varint32(len(value) * expected_size)
        # Bool arrays and non-contiguous arrays are copied to bytes first;
        # everything else is written straight from the array's buffer.
        must_copy = self.dtype == np.dtype("bool") or not view.c_contiguous
        if must_copy:
            buffer.write_bytes(value.tobytes())
        else:
            buffer.write_buffer(value)

    def xread(self, buffer):
        """Read size-prefixed bytes and view them as an array of `dtype`."""
        raw = buffer.read_bytes_and_size()
        return np.frombuffer(raw, dtype=self.dtype)

    def write(self, buffer, value):
        """Delegate to `fury.handle_unsupported_write`."""
        self.fury.handle_unsupported_write(buffer, value)

    def read(self, buffer):
        """Delegate to `fury.handle_unsupported_read`."""
        result = self.fury.handle_unsupported_read(buffer)
        return result
class CollectionSerializer(Serializer):
    """
    Base serializer for collections; reads into a `list` by default.

    Subclasses choose the target container by overriding `new_instance`
    and `handle_read_elem`.
    """

    __slots__ = "class_resolver", "ref_resolver", "elem_serializer"

    def __init__(self, fury, type_, elem_serializer=None):
        """
        Parameters
        ----------
        fury
            owning Fury instance; its class and ref resolvers are cached here.
        type_
            collection type handled by this serializer.
        elem_serializer
            optional serializer passed along for every element in the
            cross-language format.
        """
        super().__init__(fury, type_)
        self.class_resolver = fury.class_resolver
        self.ref_resolver = fury.ref_resolver
        self.elem_serializer = elem_serializer

    def get_xtype_id(self):
        """Return the negated id of `FuryType.LIST`."""
        return -FuryType.LIST.value

    def write(self, buffer, value: Iterable[Any]):
        """
        Native format: a varint32 element count followed by each element.

        Elements whose exact type is `str`, `int` or `bool` take a fast path:
        a combined not-null/class-id header written with `write_int24`, then
        the value. All other elements go through the ref resolver and, when
        that did not already write a reference or null, are written as class
        info followed by the payload from that class's serializer.
        """
        buffer.write_varint32(len(value))
        for s in value:
            cls = type(s)
            if cls is str:
                buffer.write_int24(NOT_NULL_STRING_FLAG)
                buffer.write_string(s)
            elif cls is int:
                buffer.write_int24(NOT_NULL_PYINT_FLAG)
                buffer.write_varint64(s)
            elif cls is bool:
                buffer.write_int24(NOT_NULL_PYBOOL_FLAG)
                buffer.write_bool(s)
            elif not self.ref_resolver.write_ref_or_null(buffer, s):
                classinfo = self.class_resolver.get_or_create_classinfo(cls)
                self.class_resolver.write_classinfo(buffer, classinfo)
                classinfo.serializer.write(buffer, s)

    def read(self, buffer):
        """Native format: read the count, then that many elements."""
        len_ = buffer.read_varint32()
        collection_ = self.new_instance(self.type_)
        for _ in range(len_):
            self.handle_read_elem(self.fury.deserialize_ref(buffer), collection_)
        return collection_

    def new_instance(self, type_):
        """Create the empty result container and register it for ref tracking."""
        # TODO support iterable subclass
        instance = []
        self.fury.ref_resolver.reference(instance)
        return instance

    def handle_read_elem(self, elem, collection_):
        """Add one deserialized element to the result container."""
        collection_.append(elem)

    def xwrite(self, buffer, value):
        """
        Cross-language format: a varint32 element count followed by each
        element written through `fury.xserialize_ref`.

        Iterables without a length (e.g. generators) are materialised into a
        list first so that the count can be written before the elements.
        """
        try:
            len_ = len(value)
        except (TypeError, AttributeError):
            # `len()` raises TypeError for objects without `__len__`;
            # AttributeError stays in the tuple for backward compatibility.
            value = list(value)
            len_ = len(value)
        buffer.write_varint32(len_)
        for s in value:
            self.fury.xserialize_ref(buffer, s, serializer=self.elem_serializer)

    def xread(self, buffer):
        """Cross-language format: read the count, then that many elements."""
        len_ = buffer.read_varint32()
        collection_ = self.new_instance(self.type_)
        for _ in range(len_):
            self.handle_read_elem(
                self.fury.xdeserialize_ref(buffer, serializer=self.elem_serializer),
                collection_,
            )
        return collection_
class ListSerializer(CollectionSerializer):
    """Collection serializer that always reads into a plain `list`."""

    def get_xtype_id(self):
        """Return the id of `FuryType.LIST`."""
        return FuryType.LIST.value

    def read(self, buffer):
        """Read a varint32 count and then that many elements into a list."""
        count = buffer.read_varint32()
        result = []
        # The list is registered with the ref resolver before any element
        # is read.
        self.fury.ref_resolver.reference(result)
        for _ in range(count):
            item = self.fury.deserialize_ref(buffer)
            result.append(item)
        return result
class TupleSerializer(CollectionSerializer):
    """Collection serializer that reads into a `tuple`."""

    def read(self, buffer):
        """
        Read a varint32 count and then that many elements into a tuple.

        Unlike the list reader, nothing is registered with the ref resolver.
        """
        count = buffer.read_varint32()
        return tuple(self.fury.deserialize_ref(buffer) for _ in range(count))
class StringArraySerializer(ListSerializer):
    """List serializer whose element serializer is a `StringSerializer`."""

    def __init__(self, fury, type_):
        string_serializer = StringSerializer(fury, str)
        super().__init__(fury, type_, string_serializer)

    def get_xtype_id(self):
        """Return the id of `FuryType.FURY_STRING_ARRAY`."""
        return FuryType.FURY_STRING_ARRAY.value
class SetSerializer(CollectionSerializer):
    """Collection serializer that reads into a `set`."""

    def get_xtype_id(self):
        """Return the id of `FuryType.FURY_SET`."""
        return FuryType.FURY_SET.value

    def new_instance(self, type_):
        """Create an empty set and register it with the ref resolver."""
        result = set()
        self.fury.ref_resolver.reference(result)
        return result

    def handle_read_elem(self, elem, set_: set):
        """Add one deserialized element to the set."""
        set_.add(elem)
class MapSerializer(Serializer):
    """
    Serializer for mappings.

    Both formats start with a varint32 entry count followed by alternating
    keys and values.
    """

    __slots__ = (
        "class_resolver",
        "ref_resolver",
        "key_serializer",
        "value_serializer",
    )

    def __init__(self, fury, type_, key_serializer=None, value_serializer=None):
        """
        Parameters
        ----------
        fury
            owning Fury instance; its class and ref resolvers are cached here.
        type_
            mapping type; called with no arguments by `read` to build the
            result.
        key_serializer, value_serializer
            optional serializers passed along for keys/values in the
            cross-language format.
        """
        super().__init__(fury, type_)
        self.class_resolver = fury.class_resolver
        self.ref_resolver = fury.ref_resolver
        self.key_serializer = key_serializer
        self.value_serializer = value_serializer

    def get_xtype_id(self):
        """Return the id of `FuryType.MAP`."""
        return FuryType.MAP.value

    def write(self, buffer, value: Dict):
        """
        Native format writer.

        `str` keys, and `str`/`int` values, take a fast path: a combined
        not-null/class-id header written with `write_int24`, then the value.
        Everything else goes through the ref resolver and, when that did not
        already write a reference or null, is written as class info followed
        by the payload from that class's serializer.
        """
        ref_resolver = self.ref_resolver
        class_resolver = self.class_resolver
        buffer.write_varint32(len(value))
        for key, item in value.items():
            # --- key ---
            key_type = type(key)
            if key_type is str:
                buffer.write_int24(NOT_NULL_STRING_FLAG)
                buffer.write_string(key)
            elif not ref_resolver.write_ref_or_null(buffer, key):
                key_info = class_resolver.get_or_create_classinfo(key_type)
                class_resolver.write_classinfo(buffer, key_info)
                key_info.serializer.write(buffer, key)
            # --- value ---
            item_type = type(item)
            if item_type is str:
                buffer.write_int24(NOT_NULL_STRING_FLAG)
                buffer.write_string(item)
            elif item_type is int:
                buffer.write_int24(NOT_NULL_PYINT_FLAG)
                buffer.write_varint64(item)
            elif not ref_resolver.write_ref_or_null(buffer, item):
                item_info = class_resolver.get_or_create_classinfo(item_type)
                class_resolver.write_classinfo(buffer, item_info)
                item_info.serializer.write(buffer, item)

    def read(self, buffer):
        """Native format reader; builds an instance of `self.type_`."""
        count = buffer.read_varint32()
        result = self.type_()
        self.fury.ref_resolver.reference(result)
        for _ in range(count):
            key = self.fury.deserialize_ref(buffer)
            item = self.fury.deserialize_ref(buffer)
            result[key] = item
        return result

    def xwrite(self, buffer, value: Dict):
        """Cross-language writer; every key/value goes through fury."""
        buffer.write_varint32(len(value))
        for key, item in value.items():
            self.fury.xserialize_ref(buffer, key, serializer=self.key_serializer)
            self.fury.xserialize_ref(buffer, item, serializer=self.value_serializer)

    def xread(self, buffer):
        """Cross-language reader; always builds a plain `dict`."""
        count = buffer.read_varint32()
        result = {}
        self.fury.ref_resolver.reference(result)
        for _ in range(count):
            key = self.fury.xdeserialize_ref(buffer, serializer=self.key_serializer)
            item = self.fury.xdeserialize_ref(buffer, serializer=self.value_serializer)
            result[key] = item
        return result


# Alias: the same class is exposed under a second name.
SubMapSerializer = MapSerializer
class EnumSerializer(Serializer):
    """
    Serializes a value by its `name` attribute and restores it by looking
    that name up on `self.type_`. Only the native format is implemented.
    """

    @classmethod
    def support_subclass(cls) -> bool:
        """Return True, overriding the base-class default of False."""
        return True

    def write(self, buffer, value):
        """Write `value.name` as a string."""
        member_name = value.name
        buffer.write_string(member_name)

    def read(self, buffer):
        """Read a name and return the attribute of `self.type_` it names."""
        member_name = buffer.read_string()
        return getattr(self.type_, member_name)

    def xwrite(self, buffer, value):
        """Cross-language writing is not implemented."""
        raise NotImplementedError

    def xread(self, buffer):
        """Cross-language reading is not implemented."""
        raise NotImplementedError
class SliceSerializer(Serializer):
    """
    Serializes a `slice` as its `start`, `stop` and `step` components, in
    that order. Only the native format is implemented.
    """

    def _write_component(self, buffer, component):
        """Write one slice component using the encoding described in `write`."""
        if type(component) is int:
            # TODO support varint128
            buffer.write_int24(NOT_NULL_PYINT_FLAG)
            buffer.write_varint64(component)
        elif component is None:
            buffer.write_int8(NULL_FLAG)
        else:
            buffer.write_int8(NOT_NULL_VALUE_FLAG)
            self.fury.serialize_nonref(buffer, component)

    def _read_component(self, buffer):
        """Read one slice component: `None` for a null flag, else a value."""
        if buffer.read_int8() == NULL_FLAG:
            return None
        return self.fury.deserialize_nonref(buffer)

    def write(self, buffer, value: slice):
        """
        Write the three components of `value`.

        Each component is encoded as one of:
        - exact `int`: combined not-null/class-id header (`write_int24`)
          followed by a varint64;
        - `None`: a single `NULL_FLAG` byte;
        - anything else: a `NOT_NULL_VALUE_FLAG` byte followed by
          `fury.serialize_nonref`.
        """
        components = (value.start, value.stop, value.step)
        for component in components:
            self._write_component(buffer, component)

    def read(self, buffer):
        """Read start, stop and step in order and rebuild the slice."""
        start = self._read_component(buffer)
        stop = self._read_component(buffer)
        step = self._read_component(buffer)
        return slice(start, stop, step)

    def xwrite(self, buffer, value):
        """Cross-language writing is not implemented."""
        raise NotImplementedError

    def xread(self, buffer):
        """Cross-language reading is not implemented."""
        raise NotImplementedError
class PickleSerializer(Serializer):
    """
    Fallback serializer: native reads and writes are handed to fury's
    handlers for unsupported types. Cross-language use is not implemented.
    """

    def write(self, buffer, value):
        """Delegate to `fury.handle_unsupported_write`."""
        self.fury.handle_unsupported_write(buffer, value)

    def read(self, buffer):
        """Delegate to `fury.handle_unsupported_read` and return its result."""
        result = self.fury.handle_unsupported_read(buffer)
        return result

    def xwrite(self, buffer, value):
        """Cross-language writing is not implemented."""
        raise NotImplementedError

    def xread(self, buffer):
        """Cross-language reading is not implemented."""
        raise NotImplementedError
class SerializationContext:
    """
    Holds context-related information so that serializers can relate the
    different objects they serialize to one another.
    The context is meant to be reset once serializing/deserializing an
    object tree has finished.

    Entries are stored under `id(key)`, i.e. by the identity of the key
    object rather than by its value.
    """

    __slots__ = ("objects",)

    def __init__(self):
        # id(key) -> associated object
        self.objects = {}

    def add(self, key, obj):
        """Associate `obj` with the identity of `key`."""
        slot = id(key)
        self.objects[slot] = obj

    def __contains__(self, key):
        """Return True if an object was added for this exact `key` object."""
        slot = id(key)
        return slot in self.objects

    def __getitem__(self, key):
        """Return the object added for `key`; raise KeyError if absent."""
        slot = id(key)
        return self.objects[slot]

    def get(self, key):
        """Return the object added for `key`, or None if absent."""
        slot = id(key)
        return self.objects.get(slot)

    def reset(self):
        """Drop every stored entry."""
        if self.objects:
            self.objects.clear()