#!/usr/bin/env python3
# -*- encoding: utf-8 -*-

#  Licensed to the Apache Software Foundation (ASF) under one
#  or more contributor license agreements.  See the NOTICE file
#  distributed with this work for additional information
#  regarding copyright ownership.  The ASF licenses this file
#  to you under the Apache License, Version 2.0 (the
#  "License"); you may not use this file except in compliance
#  with the License.  You may obtain a copy of the License at
#
#    http://www.apache.org/licenses/LICENSE-2.0
#
#  Unless required by applicable law or agreed to in writing,
#  software distributed under the License is distributed on an
#  "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
#  KIND, either express or implied.  See the License for the
#  specific language governing permissions and limitations
#  under the License.

'''heron_client.py'''

import asyncore
import socket
import traceback
from abc import abstractmethod

import time
from heron.common.src.python.utils.log import Log
import heron.instance.src.python.utils.system_constants as constants
from heron.instance.src.python.network import HeronProtocol, REQID, StatusCode, OutgoingPacket

# pylint: disable=too-many-instance-attributes
# pylint: disable=fixme
class HeronClient(asyncore.dispatcher):
  """Python implementation of HeronClient, using asyncore module"""
  def __init__(self, looper, hostname, port, socket_map, socket_options):
    """Initializes HeronClient

    :type looper: ``GatewayLooper`` (heron.instance.src.python.network)
    :param looper: looper object
    :type hostname: str
    :param hostname: endpoint hostname
    :type port: int
    :param port: endpoint port
    :type socket_map: dict
    :param socket_map: socket map used for asyncore.dispatcher
    :type socket_options: ``SocketOptions`` (heron.common.src.python.network)
    :param socket_options: options for the socket and this client
    """
    asyncore.dispatcher.__init__(self, map=socket_map)
    self.looper = looper
    self.hostname = hostname
    self.port = int(port)
    self.endpoint = (self.hostname, self.port)
    self.out_buffer = []
    self.socket_options = socket_options

    # map <message name -> message.Message object>
    self.registered_message_map = dict()
    self.response_message_map = dict()
    self.context_map = dict()
    self.incomplete_pkt = None

    self.total_bytes_written = 0
    self.total_pkt_written = 0
    self.total_bytes_received = 0
    self.total_pkt_received = 0

    # for compatibility with 2.7.3
    self._connecting = False

    Log.debug("Initializing %s with endpoint: %s, \nsocket_map: %s, \nsocket_options: %s"
              % (self._get_classname(), str(self.endpoint),
                 str(socket_map), str(self.socket_options)))


  ##################################
  # asyncore.dispatcher override
  ##################################

  # called when connect is ready
  def handle_connect(self):
    Log.info("Connected to %s:%d" % (self.hostname, self.port))
    self._connecting = False
    self.on_connect(StatusCode.OK)

  # called when close is ready
  def handle_close(self):
    Log.info("%s: handle_close() called" % self._get_classname())
    self._handle_close()
    self.on_error()

  def _handle_close(self):
    self._clean_up_state()
    self.close()

  def _clean_up_state(self):
    self.out_buffer = []
    self.total_bytes_written = 0
    self.total_pkt_written = 0
    self.total_bytes_received = 0
    self.total_pkt_received = 0

    self.registered_message_map = dict()
    self.response_message_map = dict()
    self.context_map = dict()
    self.incomplete_pkt = None
    self._connecting = False

  # read bytes stream from socket and convert them into a list of IncomingPacket
  def handle_read(self):
    start_cycle_time = time.time()
    bytes_read = 0
    num_pkt_read = 0
    read_pkt_list = []

    read_batch_time_sec = self.socket_options.nw_read_batch_time_ms * constants.MS_TO_SEC
    read_batch_size_bytes = self.socket_options.nw_read_batch_size_bytes

    while (time.time() - start_cycle_time - read_batch_time_sec) < 0 and \
            bytes_read < read_batch_size_bytes:
      if self.incomplete_pkt is None:
        # incomplete packet doesn't exist
        pkt = HeronProtocol.read_new_packet(self)
        pkt.read(self)
      else:
        # continue reading into the incomplete packet
        Log.debug("In handle_read(): Continue reading")
        pkt = self.incomplete_pkt
        pkt.read(self)

      if pkt.is_complete:
        num_pkt_read += 1
        bytes_read += pkt.get_pktsize()
        Log.debug("Read a complete packet of size %d" % bytes_read)
        self.incomplete_pkt = None
        read_pkt_list.append(pkt)
      else:
        Log.debug("In handle_read(): Packet read not yet complete")
        self.incomplete_pkt = pkt
        break

    self.total_bytes_received += bytes_read
    self.total_pkt_received += num_pkt_read

    for pkt in read_pkt_list:
      self._handle_packet(pkt)

  def handle_write(self):
    if len(self.out_buffer) == 0:
      return
    start_cycle_time = time.time()
    bytes_written = 0
    num_pkt_written = 0

    write_batch_time_sec = self.socket_options.nw_write_batch_time_ms * constants.MS_TO_SEC
    write_batch_size_bytes = self.socket_options.nw_write_batch_size_bytes

    while (time.time() - start_cycle_time - write_batch_time_sec) < 0 and \
            bytes_written < write_batch_size_bytes and len(self.out_buffer) > 0:
      outgoing_pkt = self.out_buffer[0]
      outgoing_pkt.send(self)

      if outgoing_pkt.sent_complete:
        num_pkt_written += 1
        bytes_written += len(outgoing_pkt)
        self.out_buffer.remove(outgoing_pkt)
      else:
        # sending this packet not complete yet, will continue later
        break

    self.total_bytes_written += bytes_written
    self.total_pkt_written += num_pkt_written

  def writable(self):
    if self._connecting:
      return True
    return len(self.out_buffer) != 0

  def readable(self):
    return True

  #################################

  def start_connect(self):
    """Tries to connect to the Heron Server

    ``loop()`` method needs to be called after this.
    """
    Log.debug("In start_connect() of %s" % self._get_classname())
    # TODO: specify buffer size, exception handling
    self.create_socket(socket.AF_INET, socket.SOCK_STREAM)

    # when ready, handle_connect is called
    self._connecting = True
    self.connect(self.endpoint)

  def stop(self):
    """Disconnects and stops the client"""
    # TODO: cleanup things and close the connection
    self._handle_close()

  def register_on_message(self, msg_builder):
    """Registers protobuf message builders that this client wants to receive

    :param msg_builder: callable to create a protobuf message that this client wants to receive
    """
    message = msg_builder()
    Log.debug("In register_on_message(): %s" % message.DESCRIPTOR.full_name)
    self.registered_message_map[message.DESCRIPTOR.full_name] = msg_builder

  def send_request(self, request, context, response_type, timeout_sec):
    """Sends a request message (REQID is non-zero)"""
    # generates a unique request id
    reqid = REQID.generate()
    Log.debug("%s: In send_request() with REQID: %s" % (self._get_classname(), str(reqid)))
    # register response message type
    self.response_message_map[reqid] = response_type
    self.context_map[reqid] = context

    # Add timeout for this request if necessary
    if timeout_sec > 0:
      def timeout_task():
        self.handle_timeout(reqid)
      self.looper.register_timer_task_in_sec(timeout_task, timeout_sec)

    outgoing_pkt = OutgoingPacket.create_packet(reqid, request)
    self._send_packet(outgoing_pkt)

  def send_message(self, message):
    """Sends a message (REQID is zero)"""
    Log.debug("In send_message() of %s" % self._get_classname())
    outgoing_pkt = OutgoingPacket.create_packet(REQID.generate_zero(), message)
    self._send_packet(outgoing_pkt)

  def handle_timeout(self, reqid):
    """Handles timeout"""
    if reqid in self.context_map:
      context = self.context_map.pop(reqid)
      self.response_message_map.pop(reqid)
      self.on_response(StatusCode.TIMEOUT_ERROR, context, None)

  def handle_error(self):
    _, t, v, tbinfo = asyncore.compact_traceback()

    self_msg = "%s failed for object at %0x" % (self._get_classname(), id(self))
    Log.error("Uncaptured python exception, closing channel %s (%s:%s %s)" %
              (self_msg, t, v, tbinfo))

    if self._connecting:
      # Error when trying to connect
      # first cleanup by handle_close(), and tells a subclass about this error.
      # the subclass can then call start_connect() again, if appropriate
      self._handle_close()
      self.on_connect(StatusCode.CONNECT_ERROR)
    else:
      self._handle_close()
      self.on_error()

  def _handle_packet(self, packet):
    # only called when packet.is_complete is True
    # otherwise, it's just an message -- call on_incoming_message()
    typename, reqid, serialized_msg = HeronProtocol.decode_packet(packet)
    if reqid in self.context_map:
      # this incoming packet has the response of a request
      context = self.context_map.pop(reqid)
      response_msg = self.response_message_map.pop(reqid)

      try:
        response_msg.ParseFromString(serialized_msg)
      except Exception as e:
        Log.error("Invalid Packet Error: %s" % str(e))
        self._handle_close()
        self.on_error()
        return

      if response_msg.IsInitialized():
        self.on_response(StatusCode.OK, context, response_msg)
      else:
        Log.error("Response not initialized")
        self._handle_close()
        self.on_error()
    elif reqid.is_zero():
      # this is a Message -- no need to send back response
      try:
        if typename not in self.registered_message_map:
          raise ValueError("%s is not registered in message map" % typename)
        msg_builder = self.registered_message_map[typename]
        message = msg_builder()
        message.ParseFromString(serialized_msg)
        if message.IsInitialized():
          self.on_incoming_message(message)
        else:
          raise RuntimeError("Message not initialized")
      except Exception as e:
        Log.error("Error when handling message packet: %s" % str(e))
        Log.error(traceback.format_exc())
        raise RuntimeError("Problem reading message")
    else:
      # might be a timeout response
      Log.info("In handle_packet(): Received message whose REQID is not registered: %s"
               % str(reqid))

  def _send_packet(self, pkt):
    """Pushes a packet to a send buffer, the content of which will be send when available"""
    self.out_buffer.append(pkt)

  def _get_classname(self):
    return self.__class__.__name__

  ############################################################
  # Below are the interfaces to be implemented by a subclass #
  ############################################################

  @abstractmethod
  def on_connect(self, status):
    """Called when the client is connected

    Should be implemented by a subclass.
    """
    pass

  @abstractmethod
  def on_response(self, status, context, response):
    """Called when the client receives a response

    Should be implemented by a subclass.
    """
    pass

  @abstractmethod
  def on_incoming_message(self, message):
    """Called when the client receives a message

    Should be implemented by a subclass.
    """
    pass

  @abstractmethod
  def on_error(self):
    """Called when the client meets errors

    Note that this method is not called when a connection is not yet established.
    In such a case, ``on_connect()`` with status == StatusCode.CONNECT_ERROR is called.
    """
    pass
