# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements.  See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership.  The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License.  You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied.  See the License for the
# specific language governing permissions and limitations
# under the License.
"""RPC proxy, allows both client/server to connect and match connection.

In normal RPC, the client connects directly to the server's IP address.
Sometimes this cannot be done because the server does not have a static address.
RPCProxy allows both the client and the server to connect to a proxy server,
which forwards the messages between the client and the server.
"""
# pylint: disable=unused-variable, unused-argument
import os
import asyncio
import logging
import socket
import threading
import errno
import struct
import time

try:
    import tornado
    from tornado import gen
    from tornado import websocket
    from tornado import ioloop
    from . import tornado_util
except ImportError as error_msg:
    raise ImportError(
        "RPCProxy module requires tornado package %s. Try 'pip install tornado'." % error_msg
    )

from tvm.contrib.popen_pool import PopenWorker
from . import _ffi_api
from . import base
from .base import TrackerCode
from .server import _server_env
from .._ffi.base import py_str


class ForwardHandler(object):
    """Forward handler to forward the message.

    Mixin shared by the TCP and websocket handlers.  It first parses the
    connection handshake (magic number, key length, key) and reports the
    handler to the proxy; once the proxy has paired it with a peer, every
    further chunk is relayed verbatim through ``forward_proxy``.

    Concrete handlers are expected to provide ``close()``, ``send_data()``
    and ``signal_close()``, which this mixin calls but does not define.
    """

    # Class-level default so the flag is readable even before _init_handler().
    _rejected = False

    def _init_handler(self):
        """Initialize handler."""
        self._init_message = bytes()
        self._init_req_nbytes = 4
        self._magic = None
        self.timeout = None
        self._rpc_key_length = None
        self._done = False
        # Set once the handshake has been refused; later input is then ignored.
        self._rejected = False
        self._proxy = ProxyServerHandler.current
        assert self._proxy
        self.rpc_key = None
        self.match_key = None
        self.forward_proxy = None
        self.alloc_time = None

    def __del__(self):
        logging.info("Delete %s...", self.name())

    def name(self):
        """Name of this connection."""
        return "RPCConnection"

    def _reject(self, log_format):
        """Refuse the handshake: log ``log_format % name`` and close the connection."""
        self._rejected = True
        logging.info(log_format, self.name())
        self.close()

    def _init_step(self, message):
        """Consume one complete handshake field.

        The fields arrive in order: 4-byte magic, 4-byte key length, key.
        ``message`` holds exactly the bytes requested through
        ``_init_req_nbytes``.  A field that cannot be accepted refuses the
        connection instead of letting the handshake continue.
        """
        if self._magic is None:
            assert len(message) == 4
            self._magic = struct.unpack("<i", message)[0]
            if self._magic != base.RPC_MAGIC:
                # Stop here: a peer with a wrong magic must never reach on_start().
                self._reject("Invalid RPC magic from %s")
                return
            self._init_req_nbytes = 4
        elif self._rpc_key_length is None:
            assert len(message) == 4
            self._rpc_key_length = struct.unpack("<i", message)[0]
            if self._rpc_key_length <= 0:
                # A non-positive length would never request the key bytes,
                # leaving the handshake stuck forever.
                self._reject("Invalid RPC key from %s")
                return
            self._init_req_nbytes = self._rpc_key_length
        elif self.rpc_key is None:
            assert len(message) == self._rpc_key_length
            self.rpc_key = py_str(message)
            # match key is used to do the matching; it is the first
            # whitespace-separated token after the 7-character role prefix.
            fields = self.rpc_key[7:].split()
            if not fields:
                # e.g. a bare "server:" key has nothing to match on.
                self._reject("Invalid RPC key from %s")
                return
            self.match_key = fields[0]
            self.on_start()
        else:
            assert False

    def on_start(self):
        """Event when the initialization is completed"""
        self._proxy.handler_ready(self)

    def on_data(self, message):
        """Handle an incoming chunk of bytes.

        When already paired, the chunk is forwarded to the peer unchanged.
        Otherwise it is fed to the handshake parser; bytes left over after
        the handshake are treated as a protocol violation.
        """
        assert isinstance(message, bytes)
        if self.forward_proxy:
            self.forward_proxy.send_data(message)
            return
        if self._rejected:
            # The connection has already been refused and closed.
            return
        while message and self._init_req_nbytes > len(self._init_message):
            nbytes = self._init_req_nbytes - len(self._init_message)
            self._init_message += message[:nbytes]
            message = message[nbytes:]
            if self._init_req_nbytes == len(self._init_message):
                temp = self._init_message
                # Reset before dispatching; _init_step sets the next request size.
                self._init_req_nbytes = 0
                self._init_message = bytes()
                self._init_step(temp)
                if self._rejected:
                    return
        if message:
            logging.info("Invalid RPC protocol, too many bytes %s", self.name())
            self.close()

    def on_error(self, err):
        """Log ``err`` and tear down this connection together with its peer."""
        logging.info("%s: Error in RPC %s", self.name(), err)
        self.close_pair()

    def close_pair(self):
        """Ask the paired peer (if any) to close, then close this connection."""
        if self.forward_proxy:
            self.forward_proxy.signal_close()
            self.forward_proxy = None
        self.close()

    def on_close_event(self):
        """on close event

        Removes this handler from whichever proxy pool still holds it under
        its match key and marks the handler as done.
        """
        assert not self._done
        logging.info("RPCProxy:on_close_event %s ...", self.name())
        if self.match_key:
            key = self.match_key
            # Identity check: only drop the entry if it is really this handler.
            if self._proxy._client_pool.get(key, None) is self:
                self._proxy._client_pool.pop(key)
            if self._proxy._server_pool.get(key, None) is self:
                self._proxy._server_pool.pop(key)
        self._done = True
        self.forward_proxy = None


class TCPHandler(tornado_util.TCPHandler, ForwardHandler):
    """Proxy endpoint for a raw TCP connection, driven by socket events."""

    def __init__(self, sock, addr):
        super().__init__(sock)
        self._init_handler()
        # Peer address as returned by accept(); its first item is used in name().
        self.addr = addr

    def name(self):
        """Return a label built from the peer address and the RPC key."""
        peer_host = str(self.addr[0])
        return "TCPSocketProxy:%s:%s" % (peer_host, self.rpc_key)

    def send_data(self, message, binary=True):
        """Write ``message`` to the socket; it is always sent as binary."""
        self.write_message(message, True)

    def on_message(self, message):
        """Hand received bytes to the shared forwarding logic."""
        self.on_data(message)

    def on_close(self):
        """Propagate the close to the paired peer and unregister this handler."""
        logging.info("RPCProxy: on_close %s ...", self.name())
        self._close_process = True

        peer = self.forward_proxy
        if peer:
            peer.signal_close()
            self.forward_proxy = None
        self.on_close_event()


class WebSocketHandler(websocket.WebSocketHandler, ForwardHandler):
    """Proxy endpoint for a websocket connection."""

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self._init_handler()

    def name(self):
        """Return a label built from the RPC key."""
        return "WebSocketProxy:%s" % (self.rpc_key,)

    def on_message(self, message):
        """Hand a received frame to the shared forwarding logic."""
        self.on_data(message)

    def data_received(self, _):
        """Streamed request bodies are not supported by this handler."""
        raise NotImplementedError()

    def send_data(self, message):
        """Send ``message`` as a binary frame; a closed socket is routed to on_error."""
        try:
            self.write_message(message, True)
        except websocket.WebSocketClosedError as closed_err:
            self.on_error(closed_err)

    def on_close(self):
        """Propagate the close to the paired peer and unregister this handler."""
        logging.info("RPCProxy: on_close %s ...", self.name())
        peer = self.forward_proxy
        if peer:
            peer.signal_close()
            self.forward_proxy = None
        self.on_close_event()

    def signal_close(self):
        """Close this websocket on request of the paired peer."""
        self.close()


class RequestHandler(tornado.web.RequestHandler):
    """Handles html request.

    Serves the content of a single file.  The handler is configured through
    two keyword options supplied in the URL spec:

    file_path : str
        File to serve.  Paths ending in ``html`` are read as text, everything
        else is read as bytes.

    rpc_web_port : int, optional
        When truthy, the default websocket URL embedded in an html page is
        rewritten to point at this port.
    """

    def __init__(self, *args, **kwargs):
        file_path = kwargs.pop("file_path")
        # Always consume our own option so that it is never forwarded to the
        # tornado base class, whatever the type of the served file.
        web_port = kwargs.pop("rpc_web_port", None)
        if file_path.endswith("html"):
            # Context managers make sure the file handle is released right away.
            with open(file_path) as page_file:
                self.page = page_file.read()
            if web_port:
                self.page = self.page.replace(
                    "ws://localhost:9190/ws", "ws://localhost:%d/ws" % web_port
                )
        else:
            with open(file_path, "rb") as page_file:
                self.page = page_file.read()
        super(RequestHandler, self).__init__(*args, **kwargs)

    def data_received(self, _):
        """Streamed request bodies are ignored."""

    def get(self, *args, **kwargs):
        """Reply with the file content loaded at construction time."""
        self.write(self.page)


class ProxyServerHandler(object):
    """Internal proxy server handler class.

    Accepts raw TCP connections on ``sock`` (and websocket connections on
    ``/ws`` when ``web_port`` is set), waits until each connection has
    completed its handshake and then pairs a server with the client that
    presents the same match key.  When ``tracker_addr`` is given, server keys
    are also reported to an RPC tracker and refreshed periodically.

    Parameters
    ----------
    sock : socket.socket
        Listening socket on which TCP connections are accepted.

    listen_port : int
        Port number reported to the tracker.

    web_port : int
        Port for the http/websocket application; a falsy value disables it.

    timeout_client : float
        Delay after which an unmatched client connection is dropped.

    timeout_server : float
        Delay after which an unmatched server connection is dropped.

    tracker_addr : tuple
        Address passed to ``socket.connect`` to reach the tracker, or a falsy
        value to run without a tracker.

    index_page : str, optional
        File served at ``/`` by the web application.

    resource_files : list of str, optional
        Files served under ``/<basename>`` by the web application.
    """

    # The single live instance; handlers look it up when they are created.
    current = None

    def __init__(
        self,
        sock,
        listen_port,
        web_port,
        timeout_client,
        timeout_server,
        tracker_addr,
        index_page=None,
        resource_files=None,
    ):
        assert ProxyServerHandler.current is None
        ProxyServerHandler.current = self
        if web_port:
            handlers = [
                (r"/ws", WebSocketHandler),
            ]
            if index_page:
                handlers.append(
                    (r"/", RequestHandler, {"file_path": index_page, "rpc_web_port": web_port})
                )
                logging.info("Serving RPC index html page at http://localhost:%d", web_port)
            resource_files = resource_files if resource_files else []
            for fname in resource_files:
                basename = os.path.basename(fname)
                pair = (r"/%s" % basename, RequestHandler, {"file_path": fname})
                handlers.append(pair)
                logging.info(pair)
            self.app = tornado.web.Application(handlers)
            self.app.listen(web_port)

        self.sock = sock
        self.sock.setblocking(0)
        self.loop = ioloop.IOLoop.current()

        def event_handler(_, events):
            self._on_event(events)

        self.loop.add_handler(self.sock.fileno(), event_handler, self.loop.READ)
        # Handlers waiting for a partner, indexed by match key.
        self._client_pool = {}
        self._server_pool = {}
        # Seconds a server key may stay unknown to the tracker before it is regenerated.
        self.timeout_alloc = 5
        self.timeout_client = timeout_client
        self.timeout_server = timeout_server
        # tracker information
        self._listen_port = listen_port
        self._tracker_addr = tracker_addr
        self._tracker_conn = None
        self._tracker_pending_puts = []
        self._key_set = set()
        self.update_tracker_period = 2
        if tracker_addr:
            logging.info("Tracker address:%s", str(tracker_addr))

            def _callback():
                self._update_tracker(True)

            self.loop.call_later(self.update_tracker_period, _callback)
        logging.info("RPCProxy: Websock port bind to %d", web_port)

    def _on_event(self, _):
        """Accept every connection currently pending on the listening socket."""
        # NOTE(review): errors other than EAGAIN/EWOULDBLOCK are swallowed and
        # the loop retries immediately; a persistent accept error would spin
        # here. Confirm whether that is intended.
        while True:
            try:
                conn, addr = self.sock.accept()
                TCPHandler(conn, addr)
            except socket.error as err:
                if err.args[0] in (errno.EAGAIN, errno.EWOULDBLOCK):
                    break

    def _pair_up(self, lhs, rhs):
        """Connect two handlers and send each one the success code and the peer's key."""
        lhs.forward_proxy = rhs
        rhs.forward_proxy = lhs

        # The length prefix counts bytes on the wire, so measure the encoded key.
        rhs_key = rhs.rpc_key.encode("utf-8")
        lhs_key = lhs.rpc_key.encode("utf-8")

        lhs.send_data(struct.pack("<i", base.RPC_CODE_SUCCESS))
        lhs.send_data(struct.pack("<i", len(rhs_key)))
        lhs.send_data(rhs_key)

        rhs.send_data(struct.pack("<i", base.RPC_CODE_SUCCESS))
        rhs.send_data(struct.pack("<i", len(lhs_key)))
        rhs.send_data(lhs_key)
        logging.info("Pairup connect %s  and %s", lhs.name(), rhs.name())

    def _regenerate_server_keys(self, keys):
        """Regenerate keys for server pool

        Parameters
        ----------
        keys : iterable of str
            Keys currently present in the server pool.

        Returns
        -------
        new_keys : list of str
            The replacement keys, in the order of ``keys``.
        """
        keyset = set(self._server_pool.keys())
        new_keys = []
        # re-generate the server match key, so old information is invalidated.
        # Iterate over a snapshot: the pool is modified inside the loop and
        # callers may pass a live view of it.
        for key in list(keys):
            # Strip only the trailing random part, so a key that itself
            # contains ":" does not break the unpacking.
            rpc_key = key.rsplit(":", 1)[0]
            handle = self._server_pool[key]
            del self._server_pool[key]
            new_key = base.random_key(rpc_key + ":", keyset)
            self._server_pool[new_key] = handle
            # Keep the handler in sync so that its close event can still find
            # and remove its pool entry.
            handle.match_key = new_key
            keyset.add(new_key)
            new_keys.append(new_key)
        return new_keys

    def _update_tracker(self, period_update=False):
        """Update information on tracker.

        Connects to the tracker when needed, reports the pending server keys
        and, on a periodic update, regenerates keys the tracker no longer
        lists as pending.  Socket errors drop the tracker connection; it is
        re-established on a later call.

        Parameters
        ----------
        period_update : bool
            True when invoked from the periodic timer; the call then also
            re-schedules itself.
        """
        try:
            if self._tracker_conn is None:
                self._tracker_conn = socket.socket(
                    base.get_addr_family(self._tracker_addr), socket.SOCK_STREAM
                )
                self._tracker_conn.connect(self._tracker_addr)
                self._tracker_conn.sendall(struct.pack("<i", base.RPC_TRACKER_MAGIC))
                magic = struct.unpack("<i", base.recvall(self._tracker_conn, 4))[0]
                if magic != base.RPC_TRACKER_MAGIC:
                    self.loop.stop()
                    raise RuntimeError("%s is not RPC Tracker" % str(self._tracker_addr))
                # just connect to tracker, need to update all keys
                # (a real list, not a dict view: it is extended/appended later)
                self._tracker_pending_puts = list(self._server_pool)

            if self._tracker_conn and period_update:
                # periodically update tracker information
                # regenerate key if the key is not in tracker anymore
                # and there is no in-coming connection after timeout_alloc
                base.sendjson(self._tracker_conn, [TrackerCode.GET_PENDING_MATCHKEYS])
                pending_keys = set(base.recvjson(self._tracker_conn))
                update_keys = []
                for k, v in self._server_pool.items():
                    if k not in pending_keys:
                        if v.alloc_time is None:
                            v.alloc_time = time.time()
                        elif time.time() - v.alloc_time > self.timeout_alloc:
                            update_keys.append(k)
                            v.alloc_time = None
                if update_keys:
                    logging.info(
                        "RPCProxy: No incoming conn on %s, regenerate keys...", str(update_keys)
                    )
                    new_keys = self._regenerate_server_keys(update_keys)
                    # Report the fresh keys in place of the ones just
                    # invalidated, which no longer exist in the pool.
                    stale_keys = set(update_keys)
                    self._tracker_pending_puts = [
                        key for key in self._tracker_pending_puts if key not in stale_keys
                    ] + new_keys

            need_update_info = False
            # report new connections
            for key in self._tracker_pending_puts:
                rpc_key = key.split(":")[0]
                base.sendjson(
                    self._tracker_conn, [TrackerCode.PUT, rpc_key, (self._listen_port, key), None]
                )
                assert base.recvjson(self._tracker_conn) == TrackerCode.SUCCESS
                if rpc_key not in self._key_set:
                    self._key_set.add(rpc_key)
                    need_update_info = True

            if need_update_info:
                keylist = "[" + ",".join(self._key_set) + "]"
                cinfo = {"key": "server:proxy" + keylist, "addr": [None, self._listen_port]}
                base.sendjson(self._tracker_conn, [TrackerCode.UPDATE_INFO, cinfo])
                assert base.recvjson(self._tracker_conn) == TrackerCode.SUCCESS
            self._tracker_pending_puts = []
        except (socket.error, IOError) as err:
            logging.info(
                "Lost tracker connection: %s, try reconnect in %g sec",
                str(err),
                self.update_tracker_period,
            )
            # The failure may have happened before a socket object existed.
            if self._tracker_conn is not None:
                self._tracker_conn.close()
            self._tracker_conn = None
            # Snapshot the keys: the pool is rewritten while regenerating.
            self._regenerate_server_keys(list(self._server_pool))

        if period_update:

            def _callback():
                self._update_tracker(True)

            self.loop.call_later(self.update_tracker_period, _callback)

    def _handler_ready_tracker_mode(self, handler):
        """tracker mode to handle handler ready.

        Servers are stored under a freshly generated key that is reported to
        the tracker; clients must present a key that is in the server pool,
        otherwise they receive a mismatch code and are closed.
        """
        if handler.rpc_key.startswith("server:"):
            key = base.random_key(handler.match_key + ":", self._server_pool)
            handler.match_key = key
            self._server_pool[key] = handler
            self._tracker_pending_puts.append(key)
            self._update_tracker()
        else:
            if handler.match_key in self._server_pool:
                self._pair_up(self._server_pool.pop(handler.match_key), handler)
            else:
                handler.send_data(struct.pack("<i", base.RPC_CODE_MISMATCH))
                handler.signal_close()

    def _handler_ready_proxy_mode(self, handler):
        """Normal proxy mode when handler is ready.

        Pairs the handler with a waiting counterpart that has the same match
        key, or parks it in its own pool until the timeout fires.  A second
        connection of the same role with the same key is refused.
        """
        if handler.rpc_key.startswith("server:"):
            pool_src, pool_dst = self._client_pool, self._server_pool
            timeout = self.timeout_server
        else:
            pool_src, pool_dst = self._server_pool, self._client_pool
            timeout = self.timeout_client

        key = handler.match_key
        if key in pool_src:
            self._pair_up(pool_src.pop(key), handler)
            return
        if key not in pool_dst:
            pool_dst[key] = handler

            def cleanup():
                """Cleanup client connection if timeout"""
                # Only act when this very handler is still the one waiting.
                if pool_dst.get(key, None) == handler:
                    logging.info(
                        "Timeout client connection %s, cannot find match key=%s",
                        handler.name(),
                        key,
                    )
                    pool_dst.pop(key)
                    handler.send_data(struct.pack("<i", base.RPC_CODE_MISMATCH))
                    handler.signal_close()

            self.loop.call_later(timeout, cleanup)
        else:
            logging.info("Duplicate connection with same key=%s", key)
            handler.send_data(struct.pack("<i", base.RPC_CODE_DUPLICATE))
            handler.signal_close()

    def handler_ready(self, handler):
        """Report handler to be ready."""
        logging.info("Handler ready %s", handler.name())
        if self._tracker_addr:
            self._handler_ready_tracker_mode(handler)
        else:
            self._handler_ready_proxy_mode(handler)

    def run(self):
        """Run the proxy server"""
        ioloop.IOLoop.current().start()


def _proxy_server(
    listen_sock,
    listen_port,
    web_port,
    timeout_client,
    timeout_server,
    tracker_addr,
    index_page,
    resource_files,
):
    """Thread entry point: install a fresh event loop and run the proxy on it."""
    # The calling thread gets its own asyncio loop before the handler is built.
    fresh_loop = asyncio.new_event_loop()
    asyncio.set_event_loop(fresh_loop)
    server = ProxyServerHandler(
        sock=listen_sock,
        listen_port=listen_port,
        web_port=web_port,
        timeout_client=timeout_client,
        timeout_server=timeout_server,
        tracker_addr=tracker_addr,
        index_page=index_page,
        resource_files=resource_files,
    )
    server.run()


class PopenProxyServerState(object):
    """Internal PopenProxy State for Popen

    Binds a listening socket to the first free port in ``[port, port_end)``
    and starts the proxy server on a separate thread, so that the chosen port
    is available in ``self.port`` as soon as the constructor returns.

    Raises
    ------
    ValueError
        If no port in the range could be bound.
    socket.error
        If binding fails for a reason other than the address being in use.
    """

    current = None

    def __init__(
        self,
        host,
        port=9091,
        port_end=9199,
        web_port=0,
        timeout_client=600,
        timeout_server=600,
        tracker_addr=None,
        index_page=None,
        resource_files=None,
    ):

        sock = socket.socket(base.get_addr_family((host, port)), socket.SOCK_STREAM)
        self.port = None
        try:
            for my_port in range(port, port_end):
                try:
                    sock.bind((host, my_port))
                except socket.error as sock_err:
                    if sock_err.errno in [errno.EADDRINUSE]:
                        # Port is taken, try the next one.
                        continue
                    raise
                self.port = my_port
                break
            if not self.port:
                raise ValueError("cannot bind to any port in [%d, %d)" % (port, port_end))
            logging.info("RPCProxy: client port bind to %s:%d", host, self.port)
            sock.listen(1)
        except Exception:
            # Do not leak the socket when the server cannot be set up.
            sock.close()
            raise
        self.thread = threading.Thread(
            target=_proxy_server,
            args=(
                sock,
                self.port,
                web_port,
                timeout_client,
                timeout_server,
                tracker_addr,
                index_page,
                resource_files,
            ),
        )
        # start the server in a different thread
        # so we can return the port directly
        self.thread.start()


def _popen_start_proxy_server(
    host,
    port=9091,
    port_end=9199,
    web_port=0,
    timeout_client=600,
    timeout_server=600,
    tracker_addr=None,
    index_page=None,
    resource_files=None,
):
    """Start the proxy inside the Popen worker process and return its port.

    This function is shipped to the worker process. Building the state
    object starts the server on a background thread; the state is stored on
    the class so it remains referenced after this call returns.
    """
    server_state = PopenProxyServerState(
        host,
        port,
        port_end,
        web_port,
        timeout_client,
        timeout_server,
        tracker_addr,
        index_page,
        resource_files,
    )
    PopenProxyServerState.current = server_state
    # The parent process reads this value to learn which port was bound.
    return server_state.port


class Proxy(object):
    """Start RPC proxy server on a separate process.

    Python implementation based on PopenWorker.

    Parameters
    ----------
    host : str
        The host url of the server.

    port : int
        The TCP port to be bind to

    port_end : int, optional
        The end TCP port to search

    web_port : int, optional
        The http/websocket port of the server.

    timeout_client : float, optional
        Timeout of client until it sees a matching connection.

    timeout_server : float, optional
        Timeout of server until it sees a matching connection.

    tracker_addr: Tuple (str, int) , optional
        The address of RPC Tracker in tuple (host, port) format.
        If is not None, the server will register itself to the tracker.

    index_page : str, optional
        Path to an index page that can be used to display at proxy index.

    resource_files : list of str, optional
        Paths to local resources that can be included in the http request
    """

    def __init__(
        self,
        host,
        port=9091,
        port_end=9199,
        web_port=0,
        timeout_client=600,
        timeout_server=600,
        tracker_addr=None,
        index_page=None,
        resource_files=None,
    ):
        self.proc = PopenWorker()
        # send the function
        self.proc.send(
            _popen_start_proxy_server,
            [
                host,
                port,
                port_end,
                web_port,
                timeout_client,
                timeout_server,
                tracker_addr,
                index_page,
                resource_files,
            ],
        )
        # receive the port
        self.port = self.proc.recv()
        self.host = host

    def terminate(self):
        """Terminate the server process

        Safe to call more than once, and on an instance whose constructor
        failed before the worker process was created.
        """
        # __del__ also runs for partially constructed objects, where the
        # ``proc`` attribute may not exist yet.
        proc = getattr(self, "proc", None)
        if proc:
            logging.info("Terminating Proxy Server...")
            proc.kill()
            self.proc = None

    def __del__(self):
        self.terminate()


def websocket_proxy_server(url, key=""):
    """Create a RPC server that uses an websocket that connects to a proxy.

    Blocks in the tornado IOLoop until the websocket connection ends.  The
    loop is stopped and the server environment is removed both on a regular
    close and when the connection or the handshake fails.

    Parameters
    ----------
    url : str
        The url to be connected.

    key : str
        The key to identify the server.
    """

    def create_on_message(conn):
        def _fsend(data):
            data = bytes(data)
            conn.write_message(data, binary=True)
            return len(data)

        on_message = _ffi_api.CreateEventDrivenServer(_fsend, "WebSocketProxyServer", "%toinit")
        return on_message

    @gen.coroutine
    def _connect(key):
        temp = None
        try:
            conn = yield websocket.websocket_connect(url)
            on_message = create_on_message(conn)
            temp = _server_env(None)
            # Start connection
            conn.write_message(struct.pack("<i", base.RPC_MAGIC), binary=True)
            key = "server:" + key
            # The length prefix counts bytes on the wire, so measure the
            # encoded key rather than the number of characters.
            key_bytes = key.encode("utf-8")
            conn.write_message(struct.pack("<i", len(key_bytes)), binary=True)
            conn.write_message(key_bytes, binary=True)
            msg = yield conn.read_message()
            if msg is None:
                # read_message() yields None once the connection is closed.
                raise RuntimeError("RPC Proxy %s closed the connection during handshake" % url)
            assert len(msg) >= 4
            magic = struct.unpack("<i", msg[:4])[0]
            if magic == base.RPC_CODE_DUPLICATE:
                raise RuntimeError("key: %s has already been used in proxy" % key)
            if magic == base.RPC_CODE_MISMATCH:
                logging.info("RPCProxy do not have matching client key %s", key)
            elif magic != base.RPC_CODE_SUCCESS:
                raise RuntimeError("%s is not RPC Proxy" % url)
            msg = msg[4:]

            logging.info("Connection established with remote")

            if msg:
                on_message(bytearray(msg), 3)

            while True:
                try:
                    msg = yield conn.read_message()
                    if msg is None:
                        break
                    on_message(bytearray(msg), 3)
                except websocket.WebSocketClosedError:
                    break
            logging.info("WebSocketProxyServer closed...")
        finally:
            # Always release the server environment and stop the loop;
            # otherwise a failure above would leave the caller blocked in
            # IOLoop.start() forever.
            if temp is not None:
                temp.remove()
            ioloop.IOLoop.current().stop()

    ioloop.IOLoop.current().spawn_callback(_connect, key)
    ioloop.IOLoop.current().start()
