#!/usr/bin/env python3
# -*- encoding: utf-8 -*-

#  Licensed to the Apache Software Foundation (ASF) under one
#  or more contributor license agreements.  See the NOTICE file
#  distributed with this work for additional information
#  regarding copyright ownership.  The ASF licenses this file
#  to you under the Apache License, Version 2.0 (the
#  "License"); you may not use this file except in compliance
#  with the License.  You may obtain a copy of the License at
#
#    http://www.apache.org/licenses/LICENSE-2.0
#
#  Unless required by applicable law or agreed to in writing,
#  software distributed under the License is distributed on an
#  "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
#  KIND, either express or implied.  See the License for the
#  specific language governing permissions and limitations
#  under the License.

"""Implementation of Heron's application level protocol for Python"""
import random
import socket
import struct

from heron.common.src.python.utils.log import Log

class HeronProtocol(object):
  """Heron's application level network protocol"""
  INT_PACK_FMT = ">I"
  HEADER_SIZE = 4

  @staticmethod
  def pack_int(i):
    """Packs int to bytestring"""
    return struct.pack(HeronProtocol.INT_PACK_FMT, i)

  @staticmethod
  def unpack_int(i):
    """Unpacks bytestring to int"""
    return struct.unpack(HeronProtocol.INT_PACK_FMT, i)[0]

  @staticmethod
  def get_size_to_pack_string(string):
    """Get size to pack string, four byte used for specifying length of the string"""
    return 4 + len(string)

  @staticmethod
  def get_size_to_pack_message(message):
    """Get size to pack message, four byte used for specifying length of the message"""
    return 4 + message.ByteSize()

  @staticmethod
  def read_new_packet(dispatcher):
    """Reads new packet and returns an IncomingPacket object

    :param dispatcher: asyncore's dispatcher class from which packets is read
    """
    packet = IncomingPacket()
    packet.read(dispatcher)
    return packet

  @staticmethod
  def decode_packet(packet):
    """Decodes an IncomingPacket object and returns (typename, reqid, serialized message)"""
    if not packet.is_complete:
      raise RuntimeError("In decode_packet(): Packet corrupted")

    data = packet.data

    len_typename = HeronProtocol.unpack_int(data[:4])
    data = data[4:]

    typename = data[:len_typename]
    data = data[len_typename:]

    reqid = REQID.unpack(data[:REQID.REQID_SIZE])
    data = data[REQID.REQID_SIZE:]

    len_msg = HeronProtocol.unpack_int(data[:4])
    data = data[4:]

    serialized_msg = data[:len_msg]

    return typename, reqid, serialized_msg

class OutgoingPacket(object):
  """Wrapper class for outgoing packet"""
  def __init__(self, raw_data):
    self.raw = str(raw_data)
    self.to_send = str(raw_data)

  def __len__(self):
    return len(self.raw)

  @staticmethod
  def create_packet(reqid, message):
    """Creates Outgoing Packet from a given reqid and message

    :param reqid: REQID object
    :param message: protocol buffer object
    """
    assert message.IsInitialized()
    packet = ''

    # calculate the totla size of the packet incl. header
    typename = message.DESCRIPTOR.full_name

    datasize = HeronProtocol.get_size_to_pack_string(typename) + \
               REQID.REQID_SIZE + HeronProtocol.get_size_to_pack_message(message)

    # first write out how much data is there as the header
    packet += HeronProtocol.pack_int(datasize)

    # next write the type string
    packet += HeronProtocol.pack_int(len(typename))
    packet += typename

    # reqid
    packet += reqid.pack()

    # add the proto
    packet += HeronProtocol.pack_int(message.ByteSize())
    packet += message.SerializeToString()
    return OutgoingPacket(packet)

  @property
  def sent_complete(self):
    """Indicates whether this packet is successfully sent"""
    return len(self.to_send) == 0

  def send(self, dispatcher):
    """Sends this outgoing packet to dispatcher's socket"""
    if self.sent_complete:
      return

    sent = dispatcher.send(self.to_send)
    self.to_send = self.to_send[sent:]

class IncomingPacket(object):
  """Helper class for incoming packet"""
  def __init__(self):
    """Initializes IncomingPacket object"""
    self.header = ''
    self.data = ''
    self.is_header_read = False
    self.is_complete = False
    # for debugging identification purposes
    self.id = random.getrandbits(8)

  @staticmethod
  def create_packet(header, data):
    """Creates an IncomingPacket object from header and data

    This method is for testing purposes
    """
    packet = IncomingPacket()
    packet.header = header
    packet.data = data

    if len(header) == HeronProtocol.HEADER_SIZE:
      packet.is_header_read = True
      if len(data) == packet.get_datasize():
        packet.is_complete = True

    return packet

  def convert_to_raw(self):
    """Converts this IncomingPacket object to raw string

    This method is for testing purposes
    """
    return self.header + self.data

  def get_datasize(self):
    """Returns the datasize of the packet

    :returns: (int) datasize of the packet, or -1 if header is incomplete
    """
    if not self.is_header_read:
      return -1
    return HeronProtocol.unpack_int(self.header)

  def get_pktsize(self):
    """Returns the size of this packet, including header and data"""
    return len(self.data) + len(self.header)

  def read(self, dispatcher):
    """Reads incoming data from asyncore.dispatcher"""
    try:
      if not self.is_header_read:
        # try reading header
        to_read = HeronProtocol.HEADER_SIZE - len(self.header)
        self.header += dispatcher.recv(to_read)
        if len(self.header) == HeronProtocol.HEADER_SIZE:
          self.is_header_read = True
        else:
          Log.debug("Header read incomplete; read %d bytes of header" % len(self.header))
          return

      if self.is_header_read and not self.is_complete:
        # try reading data
        to_read = self.get_datasize() - len(self.data)
        self.data += dispatcher.recv(to_read)
        if len(self.data) == self.get_datasize():
          self.is_complete = True
    except socket.error as e:
      if e.errno == socket.errno.EAGAIN or e.errno == socket.errno.EWOULDBLOCK:
        # Try again later -> call continue_read later
        Log.debug("Try again error")
      else:
        # Fatal error
        Log.debug("Fatal error when reading IncomingPacket")
        raise RuntimeError("Fatal error occured in IncomingPacket.read()")

  def __str__(self):
    return "Packet ID: %s, header: %s, complete: %s" % \
           (str(self.id), self.is_header_read, self.is_complete)


class REQID(object):
  """Helper class for REQID"""
  REQID_SIZE = 32

  def __init__(self, data_bytes):
    self.bytes = data_bytes

  @staticmethod
  def generate():
    """Generates a random REQID for request"""
    data_bytes = bytearray(random.getrandbits(8) for i in range(REQID.REQID_SIZE))
    return REQID(data_bytes)

  @staticmethod
  def generate_zero():
    """Generates a zero REQID for message"""
    data_bytes = bytearray(0 for i in range(REQID.REQID_SIZE))
    return REQID(data_bytes)

  def pack(self):
    """Packs this REQID to bytestring"""
    return self.bytes

  def is_zero(self):
    """Checks if this REQID is zero"""
    return self == REQID.generate_zero()

  @staticmethod
  def unpack(raw_data):
    """Unpacks a given bytestring and returns REQID object"""
    return REQID(bytearray(raw_data))

  def __eq__(self, another):
    return hasattr(another, 'bytes') and self.bytes == another.bytes

  def __hash__(self):
    return hash(self.__str__())

  def __str__(self):
    if self.is_zero():
      return "ZERO"
    else:
      return ''.join([str(i) for i in list(self.bytes)])

class StatusCode(object):
  """StatusCode for Response"""
  OK = 0
  WRITE_ERROR = 1
  READ_ERROR = 2
  INVALID_PACKET = 3
  CONNECT_ERROR = 4
  CLOSE_ERROR = 5
  TIMEOUT_ERROR = 6
