# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements.  See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership.  The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License.  You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied.  See the License for the
# specific language governing permissions and limitations
# under the License.
"""RPC server implementation.

Note
----
Server is TCP based with the following protocol:
- Initial handshake to the peer
  - [RPC_MAGIC, keysize(int32), key-bytes]
- The key is in format
   - {server|client}:device-type[:random-key] [-timeout=timeout]
"""
# pylint: disable=invalid-name
import os
import ctypes
import socket
import select
import struct
import logging
import multiprocessing
import subprocess
import time
import sys
import signal
import platform
import tvm._ffi

from tvm._ffi.base import py_str
from tvm._ffi.libinfo import find_lib_path
from tvm.runtime.module import load_module as _load_module
from tvm.contrib import util
from . import _ffi_api
from . import base
from .base import TrackerCode

logger = logging.getLogger("RPCServer")


def _server_env(load_library, work_path=None):
    """Set up the server environment and return its working directory.

    Registers (with ``override=True``) the global functions
    ``tvm.rpc.server.workpath``, ``tvm.rpc.server.load_module`` and
    ``tvm.rpc.server.download_linked_module``, all of which resolve file
    names relative to the working directory.

    Parameters
    ----------
    load_library : str or None
        Colon separated list of additional libraries. Each entry is resolved
        with ``find_lib_path`` and loaded through ``ctypes.CDLL`` with
        ``RTLD_GLOBAL``.

    work_path : optional
        Working directory object exposing ``relpath`` (for example the value
        returned by ``util.tempdir()``). A new ``util.tempdir()`` is created
        when this is not given.

    Returns
    -------
    temp
        The working directory object. The loaded library handles are stored
        on its ``libs`` attribute so they stay referenced as long as it does.
    """
    temp = work_path if work_path else util.tempdir()

    # pylint: disable=unused-variable
    @tvm._ffi.register_func("tvm.rpc.server.workpath", override=True)
    def get_workpath(path):
        """Return ``path`` resolved inside the working directory."""
        return temp.relpath(path)

    @tvm._ffi.register_func("tvm.rpc.server.load_module", override=True)
    def load_module(file_name):
        """Load module from remote side."""
        path = temp.relpath(file_name)
        m = _load_module(path)
        logger.info("load_module %s", path)
        return m

    @tvm._ffi.register_func("tvm.rpc.server.download_linked_module", override=True)
    def download_linked_module(file_name):
        """Link the given file into a shared library if needed and return its bytes."""
        # c++ compiler/linker
        cc = os.environ.get("CXX", "g++")

        # pylint: disable=import-outside-toplevel
        path = temp.relpath(file_name)

        if path.endswith(".o"):
            # Extra dependencies during runtime.
            from tvm.contrib import cc as _cc

            _cc.create_shared(path + ".so", path, cc=cc)
            path += ".so"
        elif path.endswith(".tar"):
            # Extra dependencies during runtime.
            from tvm.contrib import cc as _cc, tar as _tar

            # Strip only the trailing extension; str.replace would also
            # rewrite any ".tar" appearing earlier in the path.
            tar_temp = util.tempdir(custom_path=path[: -len(".tar")])
            _tar.untar(path, tar_temp.temp_dir)
            files = [tar_temp.relpath(x) for x in tar_temp.listdir()]
            _cc.create_shared(path + ".so", files, cc=cc)
            path += ".so"
        elif path.endswith((".dylib", ".so")):
            # Already a shared library, send it as is.
            pass
        else:
            raise RuntimeError("Do not know how to link %s" % file_name)
        logger.info("Send linked module %s to client", path)
        # Close the file deterministically instead of relying on garbage collection.
        with open(path, "rb") as linked_file:
            return bytearray(linked_file.read())

    libs = []
    load_library = load_library.split(":") if load_library else []
    for file_name in load_library:
        file_name = find_lib_path(file_name)[0]
        libs.append(ctypes.CDLL(file_name, ctypes.RTLD_GLOBAL))
        logger.info("Load additional library %s", file_name)
    temp.libs = libs
    return temp


def _serve_loop(sock, addr, load_library, work_path=None):
    """Serve a single RPC session on an already connected socket.

    Parameters
    ----------
    sock : socket.socket
        Connected socket; its file descriptor is passed to
        ``_ffi_api.ServerLoop``.

    addr
        Peer address, only used for logging.

    load_library : str or None
        Forwarded to ``_server_env``.

    work_path : optional
        Working directory forwarded to ``_server_env``. When it is not given,
        the directory created by ``_server_env`` is removed once the loop
        returns; a caller-provided directory is left in place.
    """
    fd = sock.fileno()
    env_dir = _server_env(load_library, work_path)
    _ffi_api.ServerLoop(fd)
    # Only clean up a directory that was created for this session.
    owns_dir = not work_path
    if owns_dir:
        env_dir.remove()
    logger.info("Finish serving %s", addr)


def _parse_server_opt(opts):
    """Parse the option tokens a client appends to its handshake key.

    Parameters
    ----------
    opts : iterable of str
        Option tokens. Only tokens starting with ``-timeout=`` are
        recognised; everything else is ignored.

    Returns
    -------
    dict
        Contains ``"timeout"`` (float) when a timeout option is present; the
        last occurrence wins. Empty otherwise.
    """
    prefix = "-timeout="
    parsed = {}
    for token in opts:
        if not token.startswith(prefix):
            continue
        parsed["timeout"] = float(token[len(prefix) :])
    return parsed


def _listen_loop(sock, port, rpc_key, tracker_addr, load_library, custom_addr):
    """Listening loop of the server master.

    Repeatedly (1) connects to the tracker when ``tracker_addr`` is given,
    (2) waits for a client that completes the handshake, and (3) serves that
    client in a separate process, killing it when the client supplied
    timeout expires.

    Parameters
    ----------
    sock : socket.socket
        The listening socket.

    port : int
        The port reported to the tracker.

    rpc_key : str
        Key of this server. Without a tracker it is also the key clients
        must present.

    tracker_addr : tuple or None
        Address passed to ``base.connect_with_retry``; no tracker is used
        when this is falsy.

    load_library : str or None
        Forwarded to ``_serve_loop``.

    custom_addr : str or None
        Sent to the tracker together with the port and match key.
    """

    def _accept_conn(listen_sock, tracker_conn, ping_period=2):
        """Accept connection from the other places.

        Parameters
        ----------
        listen_sock: Socket
            The socket used by listening process.

        tracker_conn : connection to tracker
            Tracker connection

        ping_period : float, optional
            ping tracker every k seconds if no connection is accepted.

        Returns
        -------
        (conn, addr, opts)
            The accepted socket, the peer address and the options parsed
            from the client key by ``_parse_server_opt``.
        """
        old_keyset = set()
        # Report resource to tracker
        if tracker_conn:
            matchkey = base.random_key(rpc_key + ":")
            base.sendjson(tracker_conn, [TrackerCode.PUT, rpc_key, (port, matchkey), custom_addr])
            assert base.recvjson(tracker_conn) == TrackerCode.SUCCESS
        else:
            matchkey = rpc_key

        unmatch_period_count = 0
        unmatch_timeout = 4
        # Wait until we get a valid connection
        while True:
            if tracker_conn:
                trigger = select.select([listen_sock], [], [], ping_period)
                if listen_sock not in trigger[0]:
                    base.sendjson(tracker_conn, [TrackerCode.GET_PENDING_MATCHKEYS])
                    pending_keys = base.recvjson(tracker_conn)
                    old_keyset.add(matchkey)
                    # if match key not in pending key set
                    # it means the key is acquired by a client but not used.
                    if matchkey not in pending_keys:
                        unmatch_period_count += 1
                    else:
                        unmatch_period_count = 0
                    # regenerate match key if key is acquired but not used for a while
                    if unmatch_period_count * ping_period > unmatch_timeout + ping_period:
                        logger.info("no incoming connections, regenerate key ...")
                        matchkey = base.random_key(rpc_key + ":", old_keyset)
                        base.sendjson(
                            tracker_conn, [TrackerCode.PUT, rpc_key, (port, matchkey), custom_addr]
                        )
                        assert base.recvjson(tracker_conn) == TrackerCode.SUCCESS
                        unmatch_period_count = 0
                    continue
            conn, addr = listen_sock.accept()
            magic = struct.unpack("<i", base.recvall(conn, 4))[0]
            if magic != base.RPC_MAGIC:
                conn.close()
                continue
            keylen = struct.unpack("<i", base.recvall(conn, 4))[0]
            key = py_str(base.recvall(conn, keylen))
            arr = key.split()
            expect_header = "client:" + matchkey
            server_key = "server:" + rpc_key
            # An empty or whitespace-only key splits into an empty list; treat
            # it as a mismatch instead of letting arr[0] raise IndexError and
            # take the whole listening loop down.
            if not arr or arr[0] != expect_header:
                conn.sendall(struct.pack("<i", base.RPC_CODE_MISMATCH))
                conn.close()
                logger.warning("mismatch key from %s", addr)
                continue
            conn.sendall(struct.pack("<i", base.RPC_CODE_SUCCESS))
            conn.sendall(struct.pack("<i", len(server_key)))
            conn.sendall(server_key.encode("utf-8"))
            return conn, addr, _parse_server_opt(arr[1:])

    # Server logic
    tracker_conn = None
    while True:
        try:
            # step 1: setup tracker and report to tracker
            if tracker_addr and tracker_conn is None:
                tracker_conn = base.connect_with_retry(tracker_addr)
                tracker_conn.sendall(struct.pack("<i", base.RPC_TRACKER_MAGIC))
                magic = struct.unpack("<i", base.recvall(tracker_conn, 4))[0]
                if magic != base.RPC_TRACKER_MAGIC:
                    raise RuntimeError("%s is not RPC Tracker" % str(tracker_addr))
                # report status of current queue
                cinfo = {"key": "server:" + rpc_key}
                base.sendjson(tracker_conn, [TrackerCode.UPDATE_INFO, cinfo])
                assert base.recvjson(tracker_conn) == TrackerCode.SUCCESS

            # step 2: wait for in-coming connections
            conn, addr, opts = _accept_conn(sock, tracker_conn)
        except (socket.error, IOError):
            # retry when tracker is dropped
            if tracker_conn:
                tracker_conn.close()
                tracker_conn = None
            continue
        # Any other exception (e.g. the RuntimeError raised above) is not
        # retried and propagates to the caller.

        # step 3: serving
        work_path = util.tempdir()
        logger.info("connection from %s", addr)
        server_proc = multiprocessing.Process(
            target=_serve_loop, args=(conn, addr, load_library, work_path)
        )
        # NOTE(review): `deamon` is a misspelling of `daemon`, so this only sets
        # an unused attribute and the worker stays non-daemonic. Left unchanged
        # on purpose: confirm the worker never needs to create multiprocessing
        # children before correcting the spelling.
        server_proc.deamon = True
        server_proc.start()
        # close from our side.
        conn.close()
        # wait until server process finish or timeout
        server_proc.join(opts.get("timeout", None))
        if server_proc.is_alive():
            logger.info("Timeout in RPC session, kill..")
            # pylint: disable=import-outside-toplevel
            import psutil

            parent = psutil.Process(server_proc.pid)
            # terminate worker childs
            for child in parent.children(recursive=True):
                child.terminate()
            # terminate the worker
            server_proc.terminate()
        work_path.remove()


def _connect_proxy_loop(addr, key, load_library):
    """Keep connecting to an RPC proxy and serve the sessions it hands out.

    Each iteration opens a connection to ``addr``, performs the handshake
    with the key ``"server:" + key`` and serves the session in a separate
    process, which is terminated when the timeout announced in the remote
    key expires.

    Parameters
    ----------
    addr : tuple
        Address of the proxy, passed to ``socket.connect``.

    key : str
        Key of this server (sent with a ``server:`` prefix).

    load_library : str or None
        Forwarded to ``_serve_loop``.

    Raises
    ------
    RuntimeError
        If the proxy reports a duplicate key, answers with an unexpected
        code, or more than ``max_retry`` consecutive socket errors occur.
    """
    key = "server:" + key
    retry_count = 0
    max_retry = 5
    retry_period = 5
    while True:
        try:
            sock = socket.socket(base.get_addr_family(addr), socket.SOCK_STREAM)
            try:
                sock.connect(addr)
                sock.sendall(struct.pack("<i", base.RPC_MAGIC))
                sock.sendall(struct.pack("<i", len(key)))
                sock.sendall(key.encode("utf-8"))
                magic = struct.unpack("<i", base.recvall(sock, 4))[0]
                if magic == base.RPC_CODE_DUPLICATE:
                    raise RuntimeError("key: %s has already been used in proxy" % key)

                if magic == base.RPC_CODE_MISMATCH:
                    logger.warning("RPCProxy do not have matching client key %s", key)
                elif magic != base.RPC_CODE_SUCCESS:
                    raise RuntimeError("%s is not RPC Proxy" % str(addr))
                keylen = struct.unpack("<i", base.recvall(sock, 4))[0]
                remote_key = py_str(base.recvall(sock, keylen))
                opts = _parse_server_opt(remote_key.split()[1:])
                logger.info("connected to %s", str(addr))
                process = multiprocessing.Process(
                    target=_serve_loop, args=(sock, addr, load_library)
                )
                # NOTE(review): `deamon` is a misspelling of `daemon`, so this
                # assignment has no effect on the process. Left unchanged on
                # purpose: confirm the worker never needs multiprocessing
                # children before correcting the spelling.
                process.deamon = True
                process.start()
            finally:
                # Close our end in every case: after a successful start the
                # worker has been handed the socket, and on a failed connect or
                # handshake the socket would otherwise leak on each retry.
                sock.close()
            process.join(opts.get("timeout", None))
            if process.is_alive():
                logger.info("Timeout in RPC session, kill..")
                process.terminate()
            retry_count = 0
        except (socket.error, IOError) as err:
            retry_count += 1
            logger.warning("Error encountered %s, retry in %g sec", str(err), retry_period)
            if retry_count > max_retry:
                raise RuntimeError("Maximum retry error: last error: %s" % str(err)) from err
            time.sleep(retry_period)


def _popen(cmd):
    """Run ``cmd`` to completion and raise if it fails.

    Parameters
    ----------
    cmd : list of str
        The command line, passed to ``subprocess.Popen``. The child inherits
        ``os.environ`` and its stderr is merged into stdout.

    Raises
    ------
    RuntimeError
        If the command exits with a non-zero status; the message contains
        the captured output.
    """
    proc = subprocess.Popen(cmd, stdout=subprocess.PIPE, stderr=subprocess.STDOUT, env=os.environ)
    (out, _) = proc.communicate()
    if proc.returncode != 0:
        msg = "Server invoke error:\n"
        # The pipe is opened in binary mode, so `out` is bytes and must be
        # decoded before it can be appended to the message.
        msg += out.decode("utf-8", errors="replace")
        raise RuntimeError(msg)


class Server(object):
    """Start RPC server on a separate process.

    This is a simple python implementation based on multi-processing.
    It is also possible to implement a similar C based server with
    TVM runtime which does not depend on the python.

    Parameters
    ----------
    host : str
        The host url of the server.

    port : int
        The port to be bind to

    port_end : int, optional
        The end port to search (exclusive).

    is_proxy : bool, optional
        Whether the address specified is a proxy.
        If this is true, the host and port actually corresponds to the
        address of the proxy server.

    use_popen : bool, optional
        Whether to use Popen to start a fresh new process instead of fork.
        This is recommended to switch on if we want to do local RPC demonstration
        for GPU devices to avoid fork safety issues.

    tracker_addr: Tuple (str, int) , optional
        The address of RPC Tracker in tuple(host, port) format.
        If is not None, the server will register itself to the tracker.

    key : str, optional
        The key used to identify the device type in tracker.

    load_library : str, optional
        List of additional libraries to be loaded during execution.

    custom_addr: str, optional
        Custom IP Address to Report to RPC Tracker

    silent: bool, optional
        Whether run this server in silent mode.

    utvm_dev_id : optional
        Only used together with ``use_popen``: forwarded to the spawned
        server as ``--utvm-dev-id``. Requires ``utvm_dev_config_args``.

    utvm_dev_config_args : optional
        Only used together with ``use_popen``: forwarded to the spawned
        server as ``--utvm-dev-config-args``.
    """

    def __init__(
        self,
        host,
        port=9091,
        port_end=9199,
        is_proxy=False,
        use_popen=False,
        tracker_addr=None,
        key="",
        load_library=None,
        custom_addr=None,
        silent=False,
        utvm_dev_id=None,
        utvm_dev_config_args=None,
    ):
        # terminate() (and therefore __del__) reads these two attributes, so
        # they must exist even when construction fails further down.
        self.proc = None
        self.use_popen = use_popen
        try:
            if _ffi_api.ServerLoop is None:
                raise RuntimeError("Please compile with USE_RPC=1")
        except NameError:
            raise RuntimeError("Please compile with USE_RPC=1")
        self.host = host
        self.port = port
        self.libs = []
        self.custom_addr = custom_addr

        if silent:
            logger.setLevel(logging.ERROR)

        if use_popen:
            cmd = [
                sys.executable,
                "-m",
                "tvm.exec.rpc_server",
                "--host=%s" % host,
                "--port=%s" % port,
                "--port-end=%s" % port_end,
            ]
            if tracker_addr:
                assert key
                cmd += ["--tracker=%s:%d" % tracker_addr, "--key=%s" % key]
            if load_library:
                cmd += ["--load-library", load_library]
            if custom_addr:
                cmd += ["--custom-addr", custom_addr]
            if silent:
                cmd += ["--silent"]
            if utvm_dev_id is not None:
                assert utvm_dev_config_args is not None
                cmd += [f"--utvm-dev-id={utvm_dev_id}"]
                cmd += [f"--utvm-dev-config-args={utvm_dev_config_args}"]

            # preexec_fn is not thread safe and may result in deadlock.
            # python 3.2 introduced the start_new_session parameter as
            # an alternative to the common use case of
            # preexec_fn=os.setsid.  Once the minimum version of python
            # supported by TVM reaches python 3.2 this code can be
            # rewritten in favour of start_new_session.  In the
            # interim, stop the pylint diagnostic.
            #
            # pylint: disable=subprocess-popen-preexec-fn
            if platform.system() == "Windows":
                self.proc = subprocess.Popen(cmd, creationflags=subprocess.CREATE_NEW_PROCESS_GROUP)
            else:
                # A new session makes the child a process group leader, which
                # is what terminate() relies on when it calls os.killpg.
                self.proc = subprocess.Popen(cmd, preexec_fn=os.setsid)
            time.sleep(0.5)
        elif not is_proxy:
            sock = socket.socket(base.get_addr_family((host, port)), socket.SOCK_STREAM)
            self.port = None
            try:
                for my_port in range(port, port_end):
                    try:
                        sock.bind((host, my_port))
                    except socket.error as sock_err:
                        # 98 and 48 are the "address already in use" errno
                        # values on Linux and macOS: try the next port.
                        if sock_err.errno in [98, 48]:
                            continue
                        raise
                    self.port = my_port
                    break
                if not self.port:
                    raise ValueError("cannot bind to any port in [%d, %d)" % (port, port_end))
                logger.info("bind to %s:%d", host, self.port)
                sock.listen(1)
            except Exception:
                # Do not leak the socket when no port could be bound.
                sock.close()
                raise
            self.sock = sock
            self.proc = multiprocessing.Process(
                target=_listen_loop,
                args=(self.sock, self.port, key, tracker_addr, load_library, self.custom_addr),
            )
            # NOTE(review): `deamon` is a misspelling of `daemon`, so this
            # assignment has no effect. Do not simply correct it: a daemonic
            # process is not allowed to create children, and _listen_loop
            # starts a worker process for every session.
            self.proc.deamon = True
            self.proc.start()
        else:
            self.proc = multiprocessing.Process(
                target=_connect_proxy_loop, args=((host, port), key, load_library)
            )
            # NOTE(review): same misspelling as above; _connect_proxy_loop also
            # starts worker processes, so it must stay non-daemonic.
            self.proc.deamon = True
            self.proc.start()

    def terminate(self):
        """Terminate the server process.

        Safe to call more than once and on an instance whose construction
        did not complete; it does nothing when there is no process.
        """
        proc = getattr(self, "proc", None)
        if not proc:
            return
        if getattr(self, "use_popen", False):
            if platform.system() == "Windows":
                os.kill(proc.pid, signal.CTRL_C_EVENT)
            else:
                # Signal the whole process group so the children of the
                # spawned server are terminated as well.
                os.killpg(proc.pid, signal.SIGTERM)
        else:
            proc.terminate()
        self.proc = None

    def __del__(self):
        self.terminate()
