################################################################################
#  Licensed to the Apache Software Foundation (ASF) under one
#  or more contributor license agreements.  See the NOTICE file
#  distributed with this work for additional information
#  regarding copyright ownership.  The ASF licenses this file
#  to you under the Apache License, Version 2.0 (the
#  "License"); you may not use this file except in compliance
#  with the License.  You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
#  Unless required by applicable law or agreed to in writing, software
#  distributed under the License is distributed on an "AS IS" BASIS,
#  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
#  See the License for the specific language governing permissions and
# limitations under the License.
################################################################################

import pickle
import struct
from abc import ABCMeta, abstractmethod
from itertools import chain


class SpecialLengths:
    """
    Sentinel values that can appear where a record's length prefix is
    expected in the length-prefixed stream format.
    """

    # Signals that no further records follow in the current data section.
    END_OF_DATA_SECTION = -1
    # Signals a record whose value is None; no payload bytes follow it.
    NULL = -2


class Serializer(object, metaclass=ABCMeta):
    """
    Abstract base for objects that write an iterator of values to a binary
    stream and read them back again.

    Equality is meant to capture interchangeability: output produced by one
    serializer should be readable by any serializer that compares equal to it.
    The default rule below (compatible class plus identical instance
    attributes) covers the simple cases; subclasses should override
    ``__eq__`` when that rule does not fit.
    """

    def __eq__(self, other):
        if not isinstance(other, self.__class__):
            return False
        return vars(self) == vars(other)

    def __ne__(self, other):
        equal = self.__eq__(other)
        return not equal

    def __repr__(self):
        return "{}()".format(self.__class__.__name__)

    def __hash__(self):
        # Hash the string form; unless a subclass overrides __str__ this is
        # the repr defined above.
        return hash(str(self))

    @abstractmethod
    def dump_to_stream(self, iterator, stream):
        """
        Write every object produced by ``iterator`` to ``stream``.
        """
        pass

    @abstractmethod
    def load_from_stream(self, stream):
        """
        Return an iterator over the objects deserialized from ``stream``.
        """
        pass

    def _load_from_stream_without_unbatching(self, stream):
        """
        Return an iterator over deserialized batches (iterables) of objects.

        Serializers that do not batch fall back to this default, which wraps
        each deserialized object in its own single-element list.
        """
        return ([item] for item in self.load_from_stream(stream))


class VarLengthDataSerializer(Serializer):
    """
    Serializer that writes objects as a stream of (length, data) pairs,
    where length is a signed 32-bit big-endian integer and data is length bytes.

    Subclasses implement :meth:`dumps` / :meth:`loads` to convert a single
    object to and from bytes. Negative length values are reserved for the
    markers in :class:`SpecialLengths`.
    """

    def dump_to_stream(self, iterator, stream):
        for obj in iterator:
            self._write_with_length(obj, stream)

    def load_from_stream(self, stream):
        # Lazily yield objects until _read_with_length raises EOFError, which
        # happens at end of stream, on a truncated record, or when an
        # END_OF_DATA_SECTION marker is read.
        while True:
            try:
                yield self._read_with_length(stream)
            except EOFError:
                return

    def _write_with_length(self, obj, stream):
        """
        Serialize ``obj`` and write it as a length prefix followed by the payload.

        :raises ValueError: if ``dumps`` returned None, or the payload is too
            large for a signed 32-bit length prefix.
        """
        serialized = self.dumps(obj)
        if serialized is None:
            raise ValueError("Serialized value should not be None")
        size = len(serialized)
        # The prefix is a signed 32-bit int, so 2**31 - 1 is the largest
        # encodable length; 2**31 itself would overflow inside write_int.
        if size > (1 << 31) - 1:
            raise ValueError("Can not serialize object larger than 2G")
        write_int(size, stream)
        stream.write(serialized)

    def _read_with_length(self, stream):
        """
        Read one length-prefixed record from ``stream`` and deserialize it.

        :returns: the deserialized object, or None for a NULL marker.
        :raises EOFError: at end of stream, on an END_OF_DATA_SECTION marker,
            or if the payload is shorter than its length prefix.
        :raises ValueError: on a negative length that is not a known marker.
        """
        length = read_int(stream)
        if length == SpecialLengths.END_OF_DATA_SECTION:
            raise EOFError
        if length == SpecialLengths.NULL:
            return None
        if length < 0:
            # stream.read() treats a negative size as "read to EOF", which
            # would silently consume and misparse the rest of the stream.
            raise ValueError("Invalid length prefix: %d" % length)
        obj = stream.read(length)
        if len(obj) < length:
            raise EOFError
        return self.loads(obj)

    @abstractmethod
    def dumps(self, obj):
        """
        Serialize an object into a byte array.
        When batching is used, this will be called with a list of objects.
        """
        pass

    @abstractmethod
    def loads(self, obj):
        """
        Deserialize an object from a byte array.
        """
        pass


class PickleSerializer(VarLengthDataSerializer):
    """
    Serializes objects with Python's :mod:`pickle` module:

        http://docs.python.org/3/library/pickle.html

    Nearly any Python object is supported, though a more specialized
    serializer may be faster.
    """

    def dumps(self, obj):
        # Pin protocol 3 explicitly so the produced bytes do not depend on
        # the running interpreter's default pickle protocol.
        return pickle.dumps(obj, protocol=3)

    def loads(self, obj):
        # SECURITY NOTE: unpickling can execute arbitrary code, so the bytes
        # must come from a trusted peer. encoding="bytes" makes 8-bit strings
        # pickled by Python 2 load as bytes objects instead of being decoded.
        return pickle.loads(obj, encoding="bytes")


class BatchedSerializer(Serializer):
    """
    Serializes a stream of objects in batches by calling its wrapped
    Serializer with streams of objects.

    Each batch is a list of consecutive input objects, and the wrapped
    serializer receives one item per batch. With ``UNLIMITED_BATCH_SIZE``
    everything goes into a single batch. A positive ``batch_size`` caps the
    batch length. Any other non-positive value also produces at most one
    batch, holding all items.
    """

    UNLIMITED_BATCH_SIZE = -1
    UNKNOWN_BATCH_SIZE = 0

    def __init__(self, serializer, batch_size=UNLIMITED_BATCH_SIZE):
        self.serializer = serializer
        self.batch_size = batch_size

    def __repr__(self):
        return "BatchedSerializer(%s, %d)" % (str(self.serializer), self.batch_size)

    def _batched(self, iterator):
        """
        Yield the objects of ``iterator`` grouped into lists of at most
        ``batch_size`` items (see the class docstring for special sizes).
        """
        batch_size = self.batch_size
        if batch_size == self.UNLIMITED_BATCH_SIZE:
            yield list(iterator)
        elif batch_size > 0 and isinstance(iterator, list):
            # Fast path: list slices are exactly the batches the generic loop
            # below would build, without re-appending every item. Python 3
            # lists have no __getslice__, so it cannot be used to detect this.
            for start in range(0, len(iterator), batch_size):
                yield iterator[start:start + batch_size]
        else:
            items = []
            for item in iterator:
                items.append(item)
                if len(items) == batch_size:
                    yield items
                    items = []
            if items:
                yield items

    def dump_to_stream(self, iterator, stream):
        self.serializer.dump_to_stream(self._batched(iterator), stream)

    def load_from_stream(self, stream):
        # Flatten the batches back into a single stream of objects.
        return chain.from_iterable(self._load_from_stream_without_unbatching(stream))

    def _load_from_stream_without_unbatching(self, stream):
        # Each item the wrapped serializer yields is one batch.
        return self.serializer.load_from_stream(stream)


def read_int(stream):
    """
    Read a 4-byte signed big-endian integer from ``stream``.

    Short reads (e.g. from raw, unbuffered streams) are retried until all four
    bytes have arrived.

    :raises EOFError: if the stream is already exhausted or ends partway
        through the 4-byte value.
    """
    data = stream.read(4)
    if not data:
        raise EOFError
    while len(data) < 4:
        chunk = stream.read(4 - len(data))
        if not chunk:
            # The stream ended inside the header: treat it as truncated input,
            # just like a short payload read.
            raise EOFError
        data += chunk
    return struct.unpack("!i", data)[0]


def write_int(value, stream):
    """
    Write ``value`` to ``stream`` as a 4-byte signed big-endian integer.

    :raises struct.error: if ``value`` does not fit in a signed 32-bit int.
    """
    packed = struct.pack(">i", value)
    stream.write(packed)
