blob: e01520198e2cbf1d80ba8d1151a5844fadecba45 [file] [log] [blame]
################################################################################
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
################################################################################
import pickle
import struct
from abc import ABCMeta, abstractmethod
from itertools import chain
class SpecialLengths(object):
    """
    Sentinel length values used in the (length, data) wire framing.

    Real payload lengths are always non-negative, so negative values are
    free to serve as control markers in the stream.
    """

    # Marks the end of the data section of the stream.
    END_OF_DATA_SECTION = -1
    # Marks a null (None) record.
    NULL = -2
class Serializer(object, metaclass=ABCMeta):
    """
    Abstract base class for serializers that convert between iterators of
    objects and binary streams.

    Note: our notion of "equality" is that output generated by equal
    serializers can be deserialized using the same serializer.  The default
    implementation treats two instances of the same class with identical
    attributes as equal; subclasses should override ``__eq__`` where that
    notion is too coarse.
    """

    def __eq__(self, other):
        # Same concrete class + same configuration => interchangeable.
        if not isinstance(other, self.__class__):
            return False
        return self.__dict__ == other.__dict__

    def __ne__(self, other):
        return not self.__eq__(other)

    def __repr__(self):
        return "%s()" % self.__class__.__name__

    def __hash__(self):
        # Hash the repr so that equal serializers hash alike.
        return hash(str(self))

    @abstractmethod
    def dump_to_stream(self, iterator, stream):
        """
        Serializes an iterator of objects to the output stream.
        """
        pass

    @abstractmethod
    def load_from_stream(self, stream):
        """
        Returns an iterator of deserialized objects from the input stream.
        """
        pass

    def _load_from_stream_without_unbatching(self, stream):
        """
        Returns an iterator of deserialized batches (iterables) of objects
        from the input stream.  Serializers that do not operate on batches
        fall back to wrapping each object in a single-element list.
        """
        return ([obj] for obj in self.load_from_stream(stream))
class VarLengthDataSerializer(Serializer):
    """
    Serializer that writes objects as a stream of (length, data) pairs,
    where length is a 32-bit big-endian integer and data is the length
    bytes produced by the subclass's ``dumps``.
    """

    def dump_to_stream(self, iterator, stream):
        """Serializes each object in ``iterator`` as a length-prefixed record."""
        for obj in iterator:
            self._write_with_length(obj, stream)

    def load_from_stream(self, stream):
        """Yields deserialized objects until the stream is exhausted."""
        while True:
            try:
                yield self._read_with_length(stream)
            except EOFError:
                return

    def _write_with_length(self, obj, stream):
        """Writes one serialized object, prefixed with its 32-bit length.

        :raises ValueError: if ``dumps`` returned None or the payload does
            not fit in a signed 32-bit length field.
        """
        serialized = self.dumps(obj)
        if serialized is None:
            raise ValueError("Serialized value should not be None")
        # A signed 32-bit length field holds at most 2**31 - 1 bytes; the
        # original check used ">" which let a length of exactly 2**31 slip
        # through and overflow struct.pack("!i") inside write_int.
        if len(serialized) >= (1 << 31):
            raise ValueError("Can not serialize object larger than 2G")
        write_int(len(serialized), stream)
        stream.write(serialized)

    def _read_with_length(self, stream):
        """Reads one length-prefixed record and deserializes it.

        :raises EOFError: on end-of-data marker or a truncated record.
        """
        length = read_int(stream)
        if length == SpecialLengths.END_OF_DATA_SECTION:
            raise EOFError
        elif length == SpecialLengths.NULL:
            return None
        obj = stream.read(length)
        if len(obj) < length:
            # Truncated record: fewer bytes available than the header promised.
            raise EOFError
        return self.loads(obj)

    @abstractmethod
    def dumps(self, obj):
        """
        Serialize an object into a byte array.
        When batching is used, this will be called with an array of objects.
        """
        pass

    @abstractmethod
    def loads(self, obj):
        """
        Deserialize an object from a byte array.
        """
        pass
class PickleSerializer(VarLengthDataSerializer):
    """
    Serializes objects using Python's pickle serializer:
    http://docs.python.org/3/library/pickle.html
    This serializer supports nearly any Python object, but may
    not be as fast as more specialized serializers.
    """

    def dumps(self, obj):
        # Pickle protocol 3 keeps the wire format stable across the
        # Python 3 interpreters this module targets.
        return pickle.dumps(obj, 3)

    def loads(self, obj):
        # NOTE(review): unpickling can execute arbitrary code — the bytes
        # fed here must come from a trusted peer, never untrusted input.
        return pickle.loads(obj, encoding="bytes")
class BatchedSerializer(Serializer):
    """
    Serializes a stream of objects in batches by calling its wrapped
    Serializer with streams (lists) of objects.

    :param serializer: the wrapped Serializer applied to each batch.
    :param batch_size: number of objects per batch; ``UNLIMITED_BATCH_SIZE``
                       places the whole iterator into a single batch.
    """

    UNLIMITED_BATCH_SIZE = -1
    UNKNOWN_BATCH_SIZE = 0

    def __init__(self, serializer, batch_size=UNLIMITED_BATCH_SIZE):
        self.serializer = serializer
        self.batch_size = batch_size

    def __repr__(self):
        return "BatchedSerializer(%s, %d)" % (str(self.serializer), self.batch_size)

    def _batched(self, iterator):
        """Yields lists of up to ``batch_size`` items from ``iterator``."""
        if self.batch_size == self.UNLIMITED_BATCH_SIZE:
            yield list(iterator)
        elif self.batch_size > 0 and isinstance(iterator, (list, tuple)):
            # Fast path for sized, sliceable inputs.  The original checked
            # hasattr(iterator, "__getslice__"), but __getslice__ no longer
            # exists in Python 3, so this branch was unreachable dead code.
            # The batch_size > 0 guard keeps range() from getting a 0 step
            # when batch_size is UNKNOWN_BATCH_SIZE.
            n = len(iterator)
            for i in range(0, n, self.batch_size):
                yield iterator[i: i + self.batch_size]
        else:
            items = []
            count = 0
            for item in iterator:
                items.append(item)
                count += 1
                if count == self.batch_size:
                    yield items
                    items = []
                    count = 0
            if items:
                yield items

    def dump_to_stream(self, iterator, stream):
        """Serializes the iterator as batches via the wrapped serializer."""
        self.serializer.dump_to_stream(self._batched(iterator), stream)

    def load_from_stream(self, stream):
        """Returns a flat iterator of objects, unbatching as it goes."""
        return chain.from_iterable(self._load_from_stream_without_unbatching(stream))

    def _load_from_stream_without_unbatching(self, stream):
        """Returns the wrapped serializer's iterator of batches, unflattened."""
        return self.serializer.load_from_stream(stream)
def read_int(stream):
    """
    Reads a big-endian signed 32-bit integer from ``stream``.

    :raises EOFError: if the stream ends before 4 bytes could be read.
        The original only raised EOFError on an empty read; a short read of
        1-3 bytes fell through to struct.unpack and surfaced as a confusing
        struct.error instead of the EOFError the callers here expect.
    """
    buf = stream.read(4)
    if len(buf) < 4:
        raise EOFError
    return struct.unpack("!i", buf)[0]
def write_int(value, stream):
    """Writes ``value`` to ``stream`` as a big-endian signed 32-bit integer."""
    packed = struct.pack("!i", value)
    stream.write(packed)