blob: 3e606825968cbbad13e88b7f52c6a1f990140574 [file]
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
import array
import dataclasses
import datetime
import enum
import logging
import os
import sys
import warnings
from dataclasses import dataclass
from typing import Dict, Tuple, TypeVar, Union, Iterable
from pyfury.lib import mmh3
from pyfury.buffer import Buffer
from pyfury.resolver import (
MapRefResolver,
NoRefResolver,
NULL_FLAG,
NOT_NULL_VALUE_FLAG,
)
from pyfury._serializer import (
Serializer,
SerializationContext,
NOT_SUPPORT_CROSS_LANGUAGE,
BufferObject,
PickleSerializer,
Numpy1DArraySerializer,
PyArraySerializer,
PYINT_CLASS_ID,
PYFLOAT_CLASS_ID,
PYBOOL_CLASS_ID,
STRING_CLASS_ID,
PICKLE_CLASS_ID,
USE_CLASSNAME,
USE_CLASS_ID,
NOT_NULL_STRING_FLAG,
NOT_NULL_PYINT_FLAG,
NOT_NULL_PYBOOL_FLAG,
NO_CLASS_ID,
NoneSerializer,
_PickleStub,
PickleStrongCacheStub,
PICKLE_STRONG_CACHE_CLASS_ID,
PICKLE_CACHE_CLASS_ID,
PickleCacheStub,
)
from pyfury.type import (
FuryType,
Int8Type,
Int16Type,
Int32Type,
Int64Type,
Float32Type,
Float64Type,
load_class,
)
from pyfury.util import is_little_endian, set_bit, get_bit, clear_bit
try:
import numpy as np
except ImportError:
np = None
from cloudpickle import Pickler
if sys.version_info[:2] < (3, 8): # pragma: no cover
from pickle5 import Unpickler
else:
from pickle import Unpickler
# Module logger; used, for example, to report classes serialized without registration.
logger = logging.getLogger(name=__name__)
# Sentinel stored in `EnumStringBytes.dynamic_write_string_id` meaning "this
# string has not been written yet in the current serialization".
DEFAULT_DYNAMIC_WRITE_STRING_ID = -1
class EnumStringBytes:
    """Encoded string (class name or type tag) together with its hash.

    Instances compare equal and hash by ``hashcode`` only, so they can be used
    as dict keys in the resolver caches. ``dynamic_write_string_id`` holds the
    id assigned the first time the string is written in the current stream,
    or ``DEFAULT_DYNAMIC_WRITE_STRING_ID`` if it has not been written yet.
    """

    __slots__ = (
        "data",
        "length",
        "hashcode",
        "dynamic_write_string_id",
    )

    def __init__(self, data, hashcode=None):
        """
        :param data: the encoded bytes.
        :param hashcode: precomputed hash (for example read from a stream); when
            omitted it is computed with ``mmh3.hash_buffer(data, 47)``.
        """
        self.data = data
        self.length = len(data)
        # Use `is None` rather than `or`: 0 is a legal hash value and must not
        # trigger a recomputation.
        if hashcode is None:
            hashcode = mmh3.hash_buffer(data, 47)[0]
        self.hashcode = hashcode
        self.dynamic_write_string_id = DEFAULT_DYNAMIC_WRITE_STRING_ID

    def __eq__(self, other):
        return type(other) is EnumStringBytes and other.hashcode == self.hashcode

    def __hash__(self):
        return self.hashcode
class ClassInfo:
    """Registration record of one class: id, serializer and encoded names."""

    __slots__ = (
        "cls",
        "class_id",
        "serializer",
        "class_name_bytes",
        "type_tag_bytes",
    )

    def __init__(
        self,
        cls: type = None,
        class_id: int = NO_CLASS_ID,
        serializer: Serializer = None,
        class_name_bytes: bytes = None,
        type_tag_bytes: bytes = None,
    ):
        """
        :param cls: the described class.
        :param class_id: registered id, or ``NO_CLASS_ID`` when the class is
            identified by name instead.
        :param serializer: serializer for ``cls``; may be filled in lazily.
        :param class_name_bytes: encoded ``module#qualname`` of ``cls``.
        :param type_tag_bytes: encoded cross-language type tag, if any.
        """
        self.cls = cls
        self.class_id = class_id
        self.serializer = serializer
        # Both names are optional. Previously a missing class name crashed with
        # `len(None)` inside EnumStringBytes despite the `None` default.
        self.class_name_bytes = (
            EnumStringBytes(class_name_bytes) if class_name_bytes is not None else None
        )
        self.type_tag_bytes = (
            EnumStringBytes(type_tag_bytes) if type_tag_bytes else None
        )

    def __repr__(self):
        return (
            f"ClassInfo(cls={self.cls}, class_id={self.class_id}, "
            f"serializer={self.serializer})"
        )
class ClassResolver:
    """Resolves classes to ids, serializers and encoded names for a ``Fury``.

    It holds:

    * the class registry (one ``ClassInfo`` per class);
    * the cross-language lookup tables (type id -> class/serializer, and
      type tag -> class);
    * the per-stream state, which writes each class name or tag string in
      full only once and refers to it by a small dynamic id afterwards.
    """

    __slots__ = (
        "fury",
        "_type_id_to_class",
        "_type_id_to_serializer",
        "_type_id_and_cls_to_serializer",
        "_type_tag_to_class_x_lang_map",
        "_enum_str_to_str",
        "_class_id_counter",
        "_used_classes_id",
        "_classes_info",
        "_registered_id2_class_info",
        "_hash_to_enum_string",
        "_enum_str_to_class",
        "_hash_to_classinfo",
        "_dynamic_id_to_classinfo_list",
        "_dynamic_id_to_enum_str_list",
        "_serializer",
        "_dynamic_write_string_id",
        "_dynamic_written_enum_string",
    )

    _type_id_to_class: Dict[int, type]
    _type_id_to_serializer: Dict[int, Serializer]
    _type_id_and_cls_to_serializer: Dict[Tuple[int, type], Serializer]
    _classes_info: Dict[type, "ClassInfo"]

    def __init__(self, fury):
        self.fury = fury
        # Cross-language lookup tables.
        self._type_id_to_class = dict()
        self._type_id_to_serializer = dict()
        self._type_id_and_cls_to_serializer = dict()
        self._type_tag_to_class_x_lang_map = dict()
        # Counter for auto-assigned ids; `_next_class_id` increments it before use.
        self._class_id_counter = PICKLE_CACHE_CLASS_ID + 1
        # Every class id handed out so far, explicit or automatic.
        self._used_classes_id = set()
        self._classes_info = dict()
        # Index = class id, value = ClassInfo (None for unused ids).
        self._registered_id2_class_info = []
        # Read-side caches, keyed by the string bytes or by the hash on the wire.
        self._enum_str_to_str = dict()
        self._enum_str_to_class = dict()
        self._hash_to_enum_string = dict()
        self._hash_to_classinfo = dict()
        # Per-stream read state: dynamic id -> previously read entry.
        self._dynamic_id_to_classinfo_list = list()
        self._dynamic_id_to_enum_str_list = list()
        self._serializer = None
        # Per-stream write state, undone by `reset_write`.
        self._dynamic_write_string_id = 0
        self._dynamic_written_enum_string = []

    def initialize(self):
        """Register the built-in classes and default serializers."""
        self.register_class(int, class_id=PYINT_CLASS_ID)
        self.register_class(float, class_id=PYFLOAT_CLASS_ID)
        self.register_class(bool, class_id=PYBOOL_CLASS_ID)
        self.register_class(str, class_id=STRING_CLASS_ID)
        self.register_class(_PickleStub, class_id=PICKLE_CLASS_ID)
        self.register_class(
            PickleStrongCacheStub, class_id=PICKLE_STRONG_CACHE_CLASS_ID
        )
        self.register_class(PickleCacheStub, class_id=PICKLE_CACHE_CLASS_ID)
        self._add_default_serializers()

    # `Union[type, TypeVar]` is not supported in py3.6
    def register_serializer(self, cls, serializer):
        """Register ``serializer`` for ``cls`` (a class or ``TypeVar``).

        Serializers that support cross-language serialization are also added
        to the cross-language tables.
        """
        assert isinstance(cls, (type, TypeVar)), cls
        type_id = serializer.get_xtype_id()
        if type_id != NOT_SUPPORT_CROSS_LANGUAGE:
            self._add_x_lang_serializer(cls, serializer=serializer)
        else:
            self.register_class(cls)
            self._classes_info[cls].serializer = serializer

    # `Union[type, TypeVar]` is not supported in py3.6
    def register_class(self, cls, *, class_id: int = None, type_tag: str = None):
        """Register class with given type id or tag, if tag is not None, it will be used for
        cross-language serialization.

        :param cls: the class (or ``TypeVar``) to register.
        :param class_id: explicit id; the next free id is allocated when omitted.
            The id must not already be used by another class (asserted).
        :param type_tag: when given, a ``ComplexObjectSerializer`` is registered
            for ``cls``; mutually exclusive with ``class_id``.
        :raises ValueError: if ``cls`` already has a different class id.
        """
        if type_tag is not None:
            assert class_id is None, (
                f"Type tag {type_tag} has been set already, "
                f"set class id at the same time is not allowed."
            )
            from pyfury._struct import ComplexObjectSerializer

            self.register_serializer(
                cls, ComplexObjectSerializer(self.fury, cls, type_tag)
            )
            return
        classinfo = self._classes_info.get(cls)
        if classinfo is None:
            if isinstance(cls, TypeVar):
                class_name_bytes = (cls.__module__ + "#" + cls.__name__).encode("utf-8")
            else:
                class_name_bytes = (cls.__module__ + "#" + cls.__qualname__).encode(
                    "utf-8"
                )
            class_id = self._allocate_class_id(class_id)
            classinfo = ClassInfo(
                cls=cls, class_name_bytes=class_name_bytes, class_id=class_id
            )
            self._classes_info[cls] = classinfo
            self._set_registered_classinfo(class_id, classinfo)
        elif classinfo.class_id == NO_CLASS_ID:
            # Known by name only so far (e.g. created lazily); give it an id now.
            class_id = self._allocate_class_id(class_id)
            classinfo.class_id = class_id
            self._set_registered_classinfo(class_id, classinfo)
        elif class_id is not None and classinfo.class_id != class_id:
            raise ValueError(
                f"Inconsistent class id {class_id} vs {classinfo.class_id} "
                f"for class {cls}"
            )

    def _allocate_class_id(self, class_id):
        """Return ``class_id`` (or the next free id when None) and mark it used."""
        if class_id is None:
            class_id = self._next_class_id()
        assert class_id not in self._used_classes_id, (
            self._used_classes_id,
            self._classes_info,
        )
        # Record the id. Previously the set was never filled, so the assert
        # above and the skip loop in `_next_class_id` could never trigger, and
        # a clashing id silently replaced the earlier registration.
        self._used_classes_id.add(class_id)
        return class_id

    def _set_registered_classinfo(self, class_id, classinfo):
        """Store ``classinfo`` at index ``class_id``, growing the table if needed."""
        table = self._registered_id2_class_info
        if len(table) <= class_id:
            table.extend([None] * (class_id + 1 - len(table)))
        table[class_id] = classinfo

    def _next_class_id(self):
        """Return the next auto-assigned class id that is not already used."""
        class_id = self._class_id_counter = self._class_id_counter + 1
        while class_id in self._used_classes_id:
            class_id = self._class_id_counter = self._class_id_counter + 1
        return class_id

    def _add_serializer(self, cls: type, serializer=None, serializer_cls=None):
        """Register ``serializer`` (or a new ``serializer_cls``) for ``cls``."""
        if serializer_cls:
            serializer = serializer_cls(self.fury, cls)
        self.register_serializer(cls, serializer)

    def _add_x_lang_serializer(self, cls: type, serializer=None, serializer_cls=None):
        """Register a cross-language ``serializer`` for ``cls``.

        The serializer is indexed by ``(type_id, cls)``. It is also indexed
        either by its type tag (for ``FURY_TYPE_TAG``) or by its type id.
        """
        if serializer_cls:
            serializer = serializer_cls(self.fury, cls)
        type_id = serializer.get_xtype_id()
        assert type_id != NOT_SUPPORT_CROSS_LANGUAGE
        self._type_id_and_cls_to_serializer[(type_id, cls)] = serializer
        self.register_class(cls)
        classinfo = self._classes_info[cls]
        classinfo.serializer = serializer
        if type_id == FuryType.FURY_TYPE_TAG.value:
            type_tag = serializer.get_xtype_tag()
            assert type(type_tag) is str
            assert type_tag not in self._type_tag_to_class_x_lang_map
            classinfo.type_tag_bytes = EnumStringBytes(type_tag.encode("utf-8"))
            self._type_tag_to_class_x_lang_map[type_tag] = cls
        else:
            self._type_id_to_serializer[type_id] = serializer
            if type_id > NOT_SUPPORT_CROSS_LANGUAGE:
                self._type_id_to_class[type_id] = cls

    def _add_default_serializers(self):
        """Register serializers for builtin, datetime, array, arrow and numpy types."""
        import pyfury.serializer as serializers

        self._add_x_lang_serializer(int, serializer_cls=serializers.ByteSerializer)
        self._add_x_lang_serializer(int, serializer_cls=serializers.Int16Serializer)
        self._add_x_lang_serializer(int, serializer_cls=serializers.Int32Serializer)
        self._add_x_lang_serializer(int, serializer_cls=serializers.Int64Serializer)
        self._add_x_lang_serializer(float, serializer_cls=serializers.FloatSerializer)
        self._add_x_lang_serializer(float, serializer_cls=serializers.DoubleSerializer)
        self._add_serializer(type(None), serializer_cls=NoneSerializer)
        self._add_serializer(bool, serializer_cls=serializers.BooleanSerializer)
        self._add_serializer(Int8Type, serializer_cls=serializers.ByteSerializer)
        self._add_serializer(Int16Type, serializer_cls=serializers.Int16Serializer)
        self._add_serializer(Int32Type, serializer_cls=serializers.Int32Serializer)
        self._add_serializer(Int64Type, serializer_cls=serializers.Int64Serializer)
        self._add_serializer(Float32Type, serializer_cls=serializers.FloatSerializer)
        self._add_serializer(Float64Type, serializer_cls=serializers.DoubleSerializer)
        self._add_serializer(str, serializer_cls=serializers.StringSerializer)
        self._add_serializer(datetime.date, serializer_cls=serializers.DateSerializer)
        self._add_serializer(
            datetime.datetime, serializer_cls=serializers.TimestampSerializer
        )
        self._add_serializer(bytes, serializer_cls=serializers.BytesSerializer)
        self._add_serializer(list, serializer_cls=serializers.ListSerializer)
        self._add_serializer(tuple, serializer_cls=serializers.TupleSerializer)
        self._add_serializer(dict, serializer_cls=serializers.MapSerializer)
        self._add_serializer(set, serializer_cls=serializers.SetSerializer)
        self._add_serializer(enum.Enum, serializer_cls=serializers.EnumSerializer)
        self._add_serializer(slice, serializer_cls=serializers.SliceSerializer)
        from pyfury import PickleCacheSerializer, PickleStrongCacheSerializer

        self._add_serializer(
            PickleStrongCacheStub, serializer=PickleStrongCacheSerializer(self.fury)
        )
        self._add_serializer(
            PickleCacheStub, serializer=PickleCacheSerializer(self.fury)
        )
        try:
            import pyarrow as pa
            from pyfury.format.serializer import (
                ArrowRecordBatchSerializer,
                ArrowTableSerializer,
            )

            self._add_serializer(
                pa.RecordBatch, serializer_cls=ArrowRecordBatchSerializer
            )
            self._add_serializer(pa.Table, serializer_cls=ArrowTableSerializer)
        except Exception as e:
            # Arrow support is optional (best effort): skip it, but leave a trace
            # instead of swallowing the reason silently.
            logger.debug("Arrow serializers are not registered: %s", e)
        for typecode in PyArraySerializer.typecode_dict:
            self._add_serializer(
                array.array,
                serializer=PyArraySerializer(self.fury, array.array, typecode),
            )
            self._add_serializer(
                PyArraySerializer.typecodearray_type[typecode],
                serializer=PyArraySerializer(self.fury, array.array, typecode),
            )
        if np:
            for dtype in Numpy1DArraySerializer.dtypes_dict:
                # The serializer is for `np.ndarray`; it was previously
                # constructed with `array.array` (copy-paste slip).
                self._add_serializer(
                    np.ndarray,
                    serializer=Numpy1DArraySerializer(self.fury, np.ndarray, dtype),
                )

    def get_serializer(self, cls: type = None, type_id: int = None, obj=None):
        """
        Returns
        -------
        Returns or create serializer for the provided class

        When ``obj`` is given, its class and (for ints in the int64 range,
        floats, 1-D arrays and ndarrays) its cross-language type id are
        derived from it. A lookup by type id raises ``KeyError`` if no
        serializer is registered for it.
        """
        assert cls is not None or type_id is not None or obj is not None
        if obj is not None:
            cls = type(obj)
            if cls is int and 2**63 - 1 >= obj >= -(2**63):
                type_id = FuryType.INT64.value
            elif cls is float:
                type_id = FuryType.DOUBLE.value
            elif cls is array.array:
                info = PyArraySerializer.typecode_dict.get(obj.typecode)
                if info is not None:
                    type_id = info[1]
            elif np and cls is np.ndarray and obj.ndim == 1:
                info = Numpy1DArraySerializer.dtypes_dict.get(obj.dtype)
                if info:
                    type_id = info[2]
        if type_id is not None:
            if cls is not None:
                serializer_ = self._type_id_and_cls_to_serializer[(type_id, cls)]
            else:
                serializer_ = self._type_id_to_serializer[type_id]
        else:
            class_info = self._classes_info.get(cls)
            if class_info is None or class_info.serializer is None:
                # Unknown class, or one registered via `register_class` that has
                # no serializer yet (previously `None` was returned here and the
                # caller crashed): create a serializer and register it.
                self._add_serializer(cls, serializer=self.get_or_create_serializer(cls))
                class_info = self._classes_info[cls]
            serializer_ = class_info.serializer
        self._serializer = serializer_
        return serializer_

    def get_or_create_serializer(self, cls):
        """Return the serializer for ``cls``, creating it if needed."""
        return self.get_or_create_classinfo(cls).serializer

    def get_or_create_classinfo(self, cls):
        """Return the ``ClassInfo`` for ``cls`` with a serializer attached.

        Classes seen for the first time are not given a registered id: they get
        ``PICKLE_CLASS_ID`` when handled by ``PickleSerializer``, otherwise
        ``NO_CLASS_ID`` (written by name).
        """
        class_info = self._classes_info.get(cls)
        if class_info is not None:
            if class_info.serializer is None:
                class_info.serializer = self._create_serializer(cls)
            return class_info
        serializer = self._create_serializer(cls)
        class_id = (
            NO_CLASS_ID
            if type(serializer) is not PickleSerializer
            else PICKLE_CLASS_ID
        )
        class_name_bytes = (cls.__module__ + "#" + cls.__qualname__).encode("utf-8")
        class_info = ClassInfo(
            cls=cls,
            class_name_bytes=class_name_bytes,
            serializer=serializer,
            class_id=class_id,
        )
        self._classes_info[cls] = class_info
        return class_info

    def _create_serializer(self, cls):
        """Create a serializer for ``cls``.

        Walks the MRO looking for a base whose serializer supports subclasses.
        If none is found, a dataclass serializer is used for dataclasses and
        ``PickleSerializer`` otherwise.
        """
        mro = cls.__mro__
        classinfo_ = self._classes_info.get(cls)
        not_registered = classinfo_ is None or classinfo_.class_id == NO_CLASS_ID
        for clz in mro:
            class_info = self._classes_info.get(clz)
            if (
                class_info
                and class_info.serializer
                and class_info.serializer.support_subclass()
            ):
                if not_registered:
                    logger.info("Class %s not registered", cls)
                serializer = type(class_info.serializer)(self.fury, cls)
                break
        else:
            if dataclasses.is_dataclass(cls):
                # Log once and only when unregistered; a second, unconditional
                # log line used to follow here.
                if not_registered:
                    logger.info("Class %s not registered", cls)
                from pyfury import DataClassSerializer

                serializer = DataClassSerializer(self.fury, cls)
            else:
                serializer = PickleSerializer(self.fury, cls)
        return serializer

    def write_classinfo(self, buffer: Buffer, classinfo: ClassInfo):
        """Write the registered class id, or ``NO_CLASS_ID`` plus the class name."""
        class_id = classinfo.class_id
        if class_id != NO_CLASS_ID:
            buffer.write_int16(class_id)
            return
        buffer.write_int16(NO_CLASS_ID)
        self.write_enum_string_bytes(buffer, classinfo.class_name_bytes)

    def read_classinfo(self, buffer):
        """Read what `write_classinfo` wrote and return the matching ``ClassInfo``."""
        class_id = buffer.read_int16()
        if (
            class_id > NO_CLASS_ID
        ):  # registered class id are greater than `NO_CLASS_ID`.
            classinfo = self._registered_id2_class_info[class_id]
            if classinfo.serializer is None:
                classinfo.serializer = self._create_serializer(classinfo.cls)
            return classinfo
        if buffer.read_int8() == USE_CLASS_ID:
            return self._dynamic_id_to_classinfo_list[buffer.read_int16()]
        class_name_bytes_hash = buffer.read_int64()
        class_name_bytes_length = buffer.read_int16()
        reader_index = buffer.reader_index
        buffer.check_bound(reader_index, class_name_bytes_length)
        buffer.reader_index = reader_index + class_name_bytes_length
        classinfo = self._hash_to_classinfo.get(class_name_bytes_hash)
        if classinfo is None:
            classname_bytes = buffer.get_bytes(reader_index, class_name_bytes_length)
            full_class_name = classname_bytes.decode(encoding="utf-8")
            # NOTE(review): the class name comes from the stream and is loaded
            # even when class registration is required; with untrusted input
            # this can import arbitrary modules. Consider restricting it.
            cls = load_class(full_class_name)
            classinfo = self.get_or_create_classinfo(cls)
            self._hash_to_classinfo[class_name_bytes_hash] = classinfo
        # Dynamic ids are per stream, so append even when the hash was cached.
        self._dynamic_id_to_classinfo_list.append(classinfo)
        return classinfo

    def write_enum_string_bytes(
        self, buffer: Buffer, enum_string_bytes: EnumStringBytes
    ):
        """Write the string in full the first time in a stream, later only its id."""
        dynamic_write_string_id = enum_string_bytes.dynamic_write_string_id
        if dynamic_write_string_id == DEFAULT_DYNAMIC_WRITE_STRING_ID:
            dynamic_write_string_id = self._dynamic_write_string_id
            enum_string_bytes.dynamic_write_string_id = dynamic_write_string_id
            self._dynamic_write_string_id += 1
            # Remembered so that `reset_write` can clear the assigned id.
            self._dynamic_written_enum_string.append(enum_string_bytes)
            buffer.write_int8(USE_CLASSNAME)
            buffer.write_int64(enum_string_bytes.hashcode)
            buffer.write_int16(enum_string_bytes.length)
            buffer.write_bytes(enum_string_bytes.data)
        else:
            buffer.write_int8(USE_CLASS_ID)
            buffer.write_int16(dynamic_write_string_id)

    def read_enum_string_bytes(self, buffer: Buffer) -> EnumStringBytes:
        """Read what `write_enum_string_bytes` wrote."""
        if buffer.read_int8() != USE_CLASSNAME:
            return self._dynamic_id_to_enum_str_list[buffer.read_int16()]
        hashcode = buffer.read_int64()
        length = buffer.read_int16()
        reader_index = buffer.reader_index
        buffer.check_bound(reader_index, length)
        buffer.reader_index = reader_index + length
        enum_str = self._hash_to_enum_string.get(hashcode)
        if enum_str is None:
            str_bytes = buffer.get_bytes(reader_index, length)
            enum_str = EnumStringBytes(str_bytes, hashcode=hashcode)
            self._hash_to_enum_string[hashcode] = enum_str
        self._dynamic_id_to_enum_str_list.append(enum_str)
        return enum_str

    def xwrite_class(self, buffer, cls):
        """Write the encoded class name of the registered class ``cls``."""
        class_name_bytes = self._classes_info[cls].class_name_bytes
        self.write_enum_string_bytes(buffer, class_name_bytes)

    def xwrite_type_tag(self, buffer, cls):
        """Write the encoded type tag of the registered class ``cls``."""
        type_tag_bytes = self._classes_info[cls].type_tag_bytes
        self.write_enum_string_bytes(buffer, type_tag_bytes)

    def read_class_by_type_tag(self, buffer):
        """Read a type tag and return the class registered for it."""
        tag = self.xread_classname(buffer)
        return self._type_tag_to_class_x_lang_map[tag]

    def xread_class(self, buffer):
        """Read an encoded class name and load the class, caching the result."""
        class_name_bytes = self.read_enum_string_bytes(buffer)
        cls = self._enum_str_to_class.get(class_name_bytes)
        if cls is None:
            full_class_name = class_name_bytes.data.decode(encoding="utf-8")
            cls = load_class(full_class_name)
            self._enum_str_to_class[class_name_bytes] = cls
        return cls

    def xread_classname(self, buffer) -> str:
        """Read an encoded string and return it decoded as UTF-8 (cached)."""
        str_bytes = self.read_enum_string_bytes(buffer)
        str_ = self._enum_str_to_str.get(str_bytes)
        if str_ is None:
            str_ = str_bytes.data.decode(encoding="utf-8")
            self._enum_str_to_str[str_bytes] = str_
        return str_

    def get_class_by_type_id(self, type_id: int):
        """Return the class registered for a cross-language ``type_id``."""
        return self._type_id_to_class[type_id]

    def reset(self):
        """Reset both the write-side and read-side per-stream state."""
        self.reset_write()
        self.reset_read()

    def reset_read(self):
        """Forget the dynamic ids read in the current stream."""
        self._dynamic_id_to_classinfo_list.clear()
        self._dynamic_id_to_enum_str_list.clear()

    def reset_write(self):
        """Forget the dynamic ids assigned while writing the current stream."""
        if self._dynamic_write_string_id != 0:
            self._dynamic_write_string_id = 0
            for enum_str in self._dynamic_written_enum_string:
                enum_str.dynamic_write_string_id = DEFAULT_DYNAMIC_WRITE_STRING_ID
            self._dynamic_written_enum_string.clear()
class Language(enum.Enum):
    """Language identifiers written into / read from serialized streams.

    ``XLANG`` selects the cross-language format. In cross-language streams
    the writer records its own language as a single byte (its ``value``).
    """

    XLANG = 0  # cross-language format
    JAVA = 1
    PYTHON = 2  # also selects the native Python format in `Fury`
    CPP = 3
    GO = 4
@dataclasses.dataclass
class OpaqueObject:
    """Placeholder for an object another language wrote with its native path.

    Cross-language deserialization returns it when the peer is not Python and
    the value was not cross-language serializable.
    """

    # Language that wrote the object.
    language: Language
    # Class name exactly as written by the peer.
    classname: str
    # Position of the object in the peer's native-object list.
    ordinal: int
class Fury:
    """Serializes objects in the native Python format or the cross-language format.

    Objects that need pickling are handed to ``pickler`` (see
    ``handle_unsupported_write``). When class registration is required, that
    pickler is a stub which refuses to pickle.
    """

    __slots__ = (
        "language",
        "ref_tracking",
        "ref_resolver",
        "class_resolver",
        "serialization_context",
        "require_class_registration",
        "buffer",
        "pickler",
        "unpickler",
        "_buffer_callback",
        "_buffers",
        "_unsupported_callback",
        "_unsupported_objects",
        "_peer_language",
        "_native_objects",
    )
    serialization_context: "SerializationContext"

    def __init__(
        self,
        language=Language.XLANG,
        ref_tracking: bool = False,
        require_class_registration: bool = True,
    ):
        """
        :param language:
            ``Language.PYTHON`` writes the native Python format; other values
            (default ``Language.XLANG``) write the cross-language format.
        :param ref_tracking:
            Use ``MapRefResolver`` (reference tracking) when True, otherwise
            ``NoRefResolver``.
        :param require_class_registration:
            Whether to require registering classes for serialization, enabled by default.
            If disabled, unknown insecure classes can be deserialized, which can be
            insecure and cause remote code execution attack if the classes
            `__new__`/`__init__`/`__eq__`/`__hash__` method contain malicious code.
            Do not disable class registration if you can't ensure your environment are
            *indeed secure*. We are not responsible for security risks if
            you disable this option.
            Setting the ``ENABLE_CLASS_REGISTRATION_FORCIBLY`` environment variable
            forces this option on regardless of the argument.
        """
        self.language = language
        self.require_class_registration = (
            _ENABLE_CLASS_REGISTRATION_FORCIBLY or require_class_registration
        )
        self.ref_tracking = ref_tracking
        if self.ref_tracking:
            self.ref_resolver = MapRefResolver()
        else:
            self.ref_resolver = NoRefResolver()
        self.class_resolver = ClassResolver(self)
        self.class_resolver.initialize()
        self.serialization_context = SerializationContext()
        self.buffer = Buffer.allocate(32)
        # Use the effective setting: the raw argument ignored the forcing env
        # variable, which installed a real pickler and warned spuriously.
        if not self.require_class_registration:
            warnings.warn(
                "Class registration is disabled, unknown classes can be deserialized "
                "which may be insecure.",
                RuntimeWarning,
                stacklevel=2,
            )
        self._bind_pickler(self.buffer)
        self.unpickler = None
        self._buffer_callback = None
        self._buffers = None
        self._unsupported_callback = None
        self._unsupported_objects = None
        self._peer_language = None
        self._native_objects = []

    def _bind_pickler(self, buffer):
        """Point ``self.pickler`` at ``buffer``, honoring class registration.

        A stub is installed when registration is required, so passing an
        explicit output buffer cannot bypass that policy.
        """
        if self.require_class_registration:
            self.pickler = _PicklerStub(buffer)
        else:
            self.pickler = Pickler(buffer)

    def register_serializer(self, cls: type, serializer):
        """Register ``serializer`` for ``cls``."""
        self.class_resolver.register_serializer(cls, serializer)

    # `Union[type, TypeVar]` is not supported in py3.6
    def register_class(self, cls, *, class_id: int = None, type_tag: str = None):
        """Register ``cls`` with an optional class id or cross-language type tag."""
        self.class_resolver.register_class(cls, class_id=class_id, type_tag=type_tag)

    def serialize(
        self,
        obj,
        buffer: Buffer = None,
        buffer_callback=None,
        unsupported_callback=None,
    ) -> Union[Buffer, bytes]:
        """Serialize ``obj``.

        :param obj: the object to serialize.
        :param buffer: optional output buffer. Data is written from its current
            writer index and the buffer itself is returned. Without it, an
            internal buffer is reused and the result is returned as ``bytes``.
        :param buffer_callback: called with each ``BufferObject``. A falsy
            result keeps that buffer out of band (only a marker is written).
        :param unsupported_callback: called with each object that would be
            pickled. A falsy result keeps that object out of band.
        """
        try:
            return self._serialize(
                obj,
                buffer,
                buffer_callback=buffer_callback,
                unsupported_callback=unsupported_callback,
            )
        finally:
            self.reset_write()
            if buffer is not None:
                # `_serialize` bound the pickler to the caller's buffer. Rebind it
                # to the internal buffer; otherwise later calls without a buffer
                # would pickle into that stale external buffer.
                self._bind_pickler(self.buffer)

    def _serialize(
        self,
        obj,
        buffer: Buffer = None,
        buffer_callback=None,
        unsupported_callback=None,
    ) -> Union[Buffer, bytes]:
        self._buffer_callback = buffer_callback
        self._unsupported_callback = unsupported_callback
        if buffer is not None:
            self._bind_pickler(buffer)
        else:
            self.buffer.writer_index = 0
            buffer = self.buffer
        mask_index = buffer.writer_index
        # 1byte used for bit mask
        buffer.grow(1)
        buffer.writer_index = mask_index + 1
        # bit 0: null object.
        if obj is None:
            set_bit(buffer, mask_index, 0)
        else:
            clear_bit(buffer, mask_index, 0)
        # set endian
        if is_little_endian:
            set_bit(buffer, mask_index, 1)
        else:
            clear_bit(buffer, mask_index, 1)
        if self.language == Language.XLANG:
            # set reader as x_lang.
            set_bit(buffer, mask_index, 2)
            # set writer language.
            buffer.write_int8(Language.PYTHON.value)
        else:
            # set reader as native.
            clear_bit(buffer, mask_index, 2)
        # bit 3: out-of-band buffers are in use.
        if self._buffer_callback is not None:
            set_bit(buffer, mask_index, 3)
        else:
            clear_bit(buffer, mask_index, 3)
        if self.language == Language.PYTHON:
            self.serialize_ref(buffer, obj)
        else:
            start_offset = buffer.writer_index
            buffer.write_int32(-1)  # preserve 4-byte for nativeObjects start offsets.
            buffer.write_int32(-1)  # preserve 4-byte for nativeObjects size
            self.xserialize_ref(buffer, obj)
            buffer.put_int32(start_offset, buffer.writer_index)
            buffer.put_int32(start_offset + 4, len(self._native_objects))
            self.ref_resolver.reset_write()
            # fury write opaque object classname which cause later write of classname
            # only write an id.
            self.class_resolver.reset_write()
            for native_object in self._native_objects:
                self.serialize_ref(buffer, native_object)
        self.reset_write()
        if buffer is not self.buffer:
            return buffer
        else:
            return buffer.to_bytes(0, buffer.writer_index)

    def serialize_ref(self, buffer, obj, classinfo=None):
        """Write ``obj`` in native format with reference/null handling.

        ``str``, ``int`` and ``bool`` take a fast path that writes a combined
        flag and class id.
        """
        cls = type(obj)
        if cls is str:
            buffer.write_int24(NOT_NULL_STRING_FLAG)
            buffer.write_string(obj)
            return
        elif cls is int:
            buffer.write_int24(NOT_NULL_PYINT_FLAG)
            buffer.write_varint64(obj)
            return
        elif cls is bool:
            buffer.write_int24(NOT_NULL_PYBOOL_FLAG)
            buffer.write_bool(obj)
            return
        if self.ref_resolver.write_ref_or_null(buffer, obj):
            return
        if classinfo is None:
            classinfo = self.class_resolver.get_or_create_classinfo(cls)
        self.class_resolver.write_classinfo(buffer, classinfo)
        classinfo.serializer.write(buffer, obj)

    def serialize_nonref(self, buffer, obj):
        """Write a non-null ``obj`` in native format without a reference flag."""
        cls = type(obj)
        if cls is str:
            buffer.write_int16(STRING_CLASS_ID)
            buffer.write_string(obj)
            return
        elif cls is int:
            buffer.write_int16(PYINT_CLASS_ID)
            buffer.write_varint64(obj)
            return
        elif cls is bool:
            buffer.write_int16(PYBOOL_CLASS_ID)
            buffer.write_bool(obj)
            return
        else:
            classinfo = self.class_resolver.get_or_create_classinfo(cls)
            self.class_resolver.write_classinfo(buffer, classinfo)
            classinfo.serializer.write(buffer, obj)

    def xserialize_ref(self, buffer, obj, serializer=None):
        """Write ``obj`` in cross-language format with a ref/null flag."""
        if serializer is None or serializer.need_to_write_ref:
            if not self.ref_resolver.write_ref_or_null(buffer, obj):
                self.xserialize_nonref(buffer, obj, serializer=serializer)
        else:
            if obj is None:
                buffer.write_int8(NULL_FLAG)
            else:
                buffer.write_int8(NOT_NULL_VALUE_FLAG)
                self.xserialize_nonref(buffer, obj, serializer=serializer)

    def xserialize_nonref(self, buffer, obj, serializer=None):
        """Write type id, optional class name/tag, and ``obj`` in xlang format.

        Objects without cross-language support are stashed as native objects
        and written as an ordinal into that list.
        """
        cls = type(obj)
        serializer = serializer or self.class_resolver.get_serializer(obj=obj)
        type_id = serializer.get_xtype_id()
        buffer.write_int16(type_id)
        if type_id != NOT_SUPPORT_CROSS_LANGUAGE:
            if type_id == FuryType.FURY_TYPE_TAG.value:
                self.class_resolver.xwrite_type_tag(buffer, cls)
            if type_id < NOT_SUPPORT_CROSS_LANGUAGE:
                self.class_resolver.xwrite_class(buffer, cls)
            serializer.xwrite(buffer, obj)
        else:
            # Write classname so it can be used for debugging which object doesn't
            # support cross-language.
            # TODO add a config to disable this to reduce space cost.
            self.class_resolver.xwrite_class(buffer, cls)
            # serializer may increase reference id multi times internally, thus peer
            # cross-language later fields/objects deserialization will use wrong
            # reference id since we skip opaque objects deserialization.
            # So we stash native objects and serialize all those object at the last.
            buffer.write_varint32(len(self._native_objects))
            self._native_objects.append(obj)

    def deserialize(
        self,
        buffer: Union[Buffer, bytes],
        buffers: Iterable = None,
        unsupported_objects: Iterable = None,
    ):
        """Deserialize an object written by ``serialize``.

        :param buffer: serialized data as ``bytes`` or a ``Buffer``.
        :param buffers: out-of-band buffers. Required when, and only when, the
            data was written with a ``buffer_callback``.
        :param unsupported_objects: objects that were kept out of band by an
            ``unsupported_callback``, in write order.
        """
        try:
            return self._deserialize(buffer, buffers, unsupported_objects)
        finally:
            self.reset_read()

    def _deserialize(
        self,
        buffer: Union[Buffer, bytes],
        buffers: Iterable = None,
        unsupported_objects: Iterable = None,
    ):
        if type(buffer) == bytes:
            buffer = Buffer(buffer)
        if self.require_class_registration:
            self.unpickler = _UnpicklerStub(buffer)
        else:
            self.unpickler = Unpickler(buffer)
        if unsupported_objects is not None:
            self._unsupported_objects = iter(unsupported_objects)
        reader_index = buffer.reader_index
        buffer.reader_index = reader_index + 1
        if get_bit(buffer, reader_index, 0):
            return None
        is_little_endian_ = get_bit(buffer, reader_index, 1)
        assert is_little_endian_, (
            "Big endian is not supported for now, "
            "please ensure peer machine is little endian."
        )
        is_target_x_lang = get_bit(buffer, reader_index, 2)
        if is_target_x_lang:
            self._peer_language = Language(buffer.read_int8())
        else:
            self._peer_language = Language.PYTHON
        is_out_of_band_serialization_enabled = get_bit(buffer, reader_index, 3)
        if is_out_of_band_serialization_enabled:
            assert buffers is not None, (
                "buffers shouldn't be null when the serialized stream is "
                "produced with buffer_callback not null."
            )
            self._buffers = iter(buffers)
        else:
            assert buffers is None, (
                "buffers should be null when the serialized stream is "
                "produced with buffer_callback null."
            )
        if is_target_x_lang:
            native_objects_start_offset = buffer.read_int32()
            native_objects_size = buffer.read_int32()
            if self._peer_language == Language.PYTHON:
                # Native objects were written after the xlang payload; read them
                # first so xlang ordinals can refer to them.
                native_objects_buffer = buffer.slice(native_objects_start_offset)
                for _ in range(native_objects_size):
                    self._native_objects.append(
                        self.deserialize_ref(native_objects_buffer)
                    )
                self.ref_resolver.reset_read()
                self.class_resolver.reset_read()
            obj = self.xdeserialize_ref(buffer)
        else:
            obj = self.deserialize_ref(buffer)
        return obj

    def deserialize_ref(self, buffer):
        """Read a native-format object with reference/null handling."""
        ref_resolver = self.ref_resolver
        ref_id = ref_resolver.try_preserve_ref_id(buffer)
        # indicates that the object is first read.
        if ref_id >= NOT_NULL_VALUE_FLAG:
            classinfo = self.class_resolver.read_classinfo(buffer)
            o = classinfo.serializer.read(buffer)
            ref_resolver.set_read_object(ref_id, o)
            return o
        else:
            return ref_resolver.get_read_object()

    def deserialize_nonref(self, buffer):
        """Deserialize not-null and non-reference object from buffer."""
        classinfo = self.class_resolver.read_classinfo(buffer)
        return classinfo.serializer.read(buffer)

    def xdeserialize_ref(self, buffer, serializer=None):
        """Read a cross-language object written by `xserialize_ref`."""
        if serializer is None or serializer.need_to_write_ref:
            ref_resolver = self.ref_resolver
            ref_id = ref_resolver.try_preserve_ref_id(buffer)
            # indicates that the object is first read.
            if ref_id >= NOT_NULL_VALUE_FLAG:
                o = self.xdeserialize_nonref(buffer, serializer=serializer)
                ref_resolver.set_read_object(ref_id, o)
                return o
            else:
                return ref_resolver.get_read_object()
        head_flag = buffer.read_int8()
        if head_flag == NULL_FLAG:
            return None
        return self.xdeserialize_nonref(buffer, serializer=serializer)

    def xdeserialize_nonref(self, buffer, serializer=None):
        """Read what `xserialize_nonref` wrote.

        For values without cross-language support, returns an `OpaqueObject`
        when the peer is not Python, otherwise the stashed native object.
        """
        type_id = buffer.read_int16()
        cls = None
        if type_id != NOT_SUPPORT_CROSS_LANGUAGE:
            if type_id == FuryType.FURY_TYPE_TAG.value:
                cls = self.class_resolver.read_class_by_type_tag(buffer)
            if type_id < NOT_SUPPORT_CROSS_LANGUAGE:
                if self._peer_language is not Language.PYTHON:
                    self.class_resolver.read_enum_string_bytes(buffer)
                    cls = self.class_resolver.get_class_by_type_id(-type_id)
                    serializer = serializer or self.class_resolver.get_serializer(
                        type_id=-type_id
                    )
                else:
                    cls = self.class_resolver.xread_class(buffer)
                    serializer = serializer or self.class_resolver.get_serializer(
                        cls=cls, type_id=type_id
                    )
            else:
                if type_id != FuryType.FURY_TYPE_TAG.value:
                    cls = self.class_resolver.get_class_by_type_id(type_id)
                serializer = serializer or self.class_resolver.get_serializer(
                    cls=cls, type_id=type_id
                )
            assert cls is not None
            return serializer.xread(buffer)
        else:
            class_name = self.class_resolver.xread_classname(buffer)
            ordinal = buffer.read_varint32()
            if self._peer_language != Language.PYTHON:
                return OpaqueObject(self._peer_language, class_name, ordinal)
            else:
                return self._native_objects[ordinal]

    def write_buffer_object(self, buffer, buffer_object: BufferObject):
        """Write ``buffer_object`` in band, or just a marker if the callback declines."""
        if self._buffer_callback is None or self._buffer_callback(buffer_object):
            buffer.write_bool(True)
            size = buffer_object.total_bytes()
            # writer length.
            buffer.write_varint32(size)
            writer_index = buffer.writer_index
            buffer.ensure(writer_index + size)
            buf = buffer.slice(buffer.writer_index, size)
            buffer_object.write_to(buf)
            buffer.writer_index += size
        else:
            buffer.write_bool(False)

    def read_buffer_object(self, buffer) -> Buffer:
        """Return an in-band buffer slice, or the next out-of-band buffer."""
        in_band = buffer.read_bool()
        if in_band:
            size = buffer.read_varint32()
            buf = buffer.slice(buffer.reader_index, size)
            buffer.reader_index += size
            return buf
        else:
            assert self._buffers is not None
            return next(self._buffers)

    def handle_unsupported_write(self, buffer, obj):
        """Pickle ``obj`` in band, or write a marker if the callback declines."""
        if self._unsupported_callback is None or self._unsupported_callback(obj):
            buffer.write_bool(True)
            self.pickler.dump(obj)
        else:
            buffer.write_bool(False)

    def handle_unsupported_read(self, buffer):
        """Unpickle an in-band object, or return the next out-of-band object."""
        in_band = buffer.read_bool()
        if in_band:
            return self.unpickler.load()
        else:
            assert self._unsupported_objects is not None
            return next(self._unsupported_objects)

    def write_ref_pyobject(self, buffer, value, classinfo=None):
        """Write ``value`` with reference/null handling and its class info."""
        if self.ref_resolver.write_ref_or_null(buffer, value):
            return
        if classinfo is None:
            classinfo = self.class_resolver.get_or_create_classinfo(type(value))
        self.class_resolver.write_classinfo(buffer, classinfo)
        classinfo.serializer.write(buffer, value)

    def read_ref_pyobject(self, buffer):
        """Read what `write_ref_pyobject` wrote."""
        return self.deserialize_ref(buffer)

    def reset_write(self):
        """Clear all per-call write state."""
        self.ref_resolver.reset_write()
        self.class_resolver.reset_write()
        self.serialization_context.reset()
        self._native_objects.clear()
        self.pickler.clear_memo()
        self._buffer_callback = None
        self._unsupported_callback = None

    def reset_read(self):
        """Clear all per-call read state."""
        self.ref_resolver.reset_read()
        self.class_resolver.reset_read()
        self.serialization_context.reset()
        self._native_objects.clear()
        self.unpickler = None
        self._buffers = None
        self._unsupported_objects = None

    def reset(self):
        """Clear both write and read state."""
        self.reset_write()
        self.reset_read()
# When the environment variable is "1" or "true" (case-insensitive, surrounding
# whitespace ignored), every `Fury` instance requires class registration,
# whatever its `require_class_registration` argument says. Values such as
# "True" used to be silently ignored, which defeated this security switch.
_ENABLE_CLASS_REGISTRATION_FORCIBLY = (
    os.getenv("ENABLE_CLASS_REGISTRATION_FORCIBLY", "0").strip().lower()
    in {"1", "true"}
)
class _PicklerStub:
    """Pickler replacement installed when class registration is required.

    It refuses to pickle anything, so objects of unregistered classes cannot
    silently fall back to pickle during serialization.
    """

    def __init__(self, buf):
        self.buf = buf

    def dump(self, o):
        # The message previously lacked a space ("registerthe class").
        raise ValueError(
            f"Class {type(o)} is not registered, "
            "pickle is not allowed when class registration enabled, Please register "
            "the class or pass unsupported_callback"
        )

    def clear_memo(self):
        """No-op: the stub keeps no memo."""
class _UnpicklerStub:
    """Unpickler replacement installed when class registration is required.

    It refuses to unpickle, so pickled payloads in a stream cannot be loaded.
    """

    def __init__(self, buf):
        self.buf = buf

    def load(self):
        # The message previously lacked a space ("registerthe class").
        raise ValueError(
            "pickle is not allowed when class registration enabled, Please register "
            "the class or pass unsupported_callback"
        )