# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements.  See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership.  The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License.  You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# NOTE: this testing tool was written for *nix; Windows is only supported through the winpty/eventlet fallbacks below

import os
import sys
import re
import contextlib
import subprocess
import signal
import math
from time import time
from . import basecase
from os.path import join, normpath


def is_win():
    """Return True when running on native Windows or under Cygwin."""
    plat = sys.platform
    return plat == "win32" or plat == "cygwin"

if is_win():
    from winpty import WinPty
    # Text expected immediately before the prompt: nothing on the winpty path.
    DEFAULT_PREFIX = ''
else:
    import pty
    # On a real pty the prompt is expected to follow a line separator.
    DEFAULT_PREFIX = os.linesep

# Regex matching the interactive prompt, e.g. "cqlsh> " or "user@cqlsh:ks> ".
# Raw string so that "\S" is a regex escape rather than an invalid string
# escape (which newer Pythons warn about).
DEFAULT_CQLSH_PROMPT = DEFAULT_PREFIX + r'(\S+@)?cqlsh(:\S+)?> '
DEFAULT_CQLSH_TERM = 'xterm'

cqlshlog = basecase.cqlshlog

def set_controlling_pty(master, slave):
    """
    Child-side pty setup, intended to run as Popen's preexec_fn.

    Starts a new session, drops the parent's ``master`` end, and makes
    ``slave`` the child's stdin, stdout and stderr.
    """
    os.setsid()
    os.close(master)
    for std_fd in (0, 1, 2):
        os.dup2(slave, std_fd)
    # Only close the original slave fd if it isn't one of the std fds we
    # just populated.
    if slave > 2:
        os.close(slave)
    # Re-open the terminal by name without O_NOCTTY; presumably this is what
    # makes it the controlling terminal of the new session.
    reopened = os.open(os.ttyname(1), os.O_RDWR)
    os.close(reopened)

@contextlib.contextmanager
def raising_signal(signum, exc):
    """
    Temporarily turn signal ``signum`` into an exception.

    While the context is active, receiving ``signum`` raises a fresh instance
    of the exception class ``exc``, so a blocking call in the body can be
    interrupted. The handler that was installed before is put back on exit,
    whether the body finished normally or raised.
    """
    def _handler(received_signum, frame):
        raise exc()

    previous = signal.signal(signum, _handler)
    try:
        yield
    finally:
        signal.signal(signum, previous)

class TimeoutError(Exception):
    """Raised by the timing_out* helpers in this module when time runs out."""
    # NOTE: on Python 3 this name shadows the builtin TimeoutError within
    # this module.

@contextlib.contextmanager
def timing_out_itimer(seconds):
    """
    Raise TimeoutError inside the context if it runs longer than ``seconds``.

    Implemented with ITIMER_REAL / SIGALRM, so sub-second timeouts work.

    :param seconds: timeout in (possibly fractional) seconds; ``None``
        disables the timeout. A non-positive value times out immediately,
        because passing 0 to ``setitimer`` would *disarm* the timer and let
        the body block forever.
    :raises TimeoutError: when the timer fires (or ``seconds <= 0``).
    :raises RuntimeError: if ITIMER_REAL is already armed by someone else.
    """
    if seconds is None:
        yield
        return
    if seconds <= 0:
        raise TimeoutError()
    with raising_signal(signal.SIGALRM, TimeoutError):
        remaining, _interval = signal.getitimer(signal.ITIMER_REAL)
        if remaining != 0.0:
            raise RuntimeError("ITIMER_REAL already in use")
        signal.setitimer(signal.ITIMER_REAL, seconds)
        try:
            yield
        finally:
            # Disarm before raising_signal reinstates the previous handler.
            signal.setitimer(signal.ITIMER_REAL, 0)

@contextlib.contextmanager
def timing_out_alarm(seconds):
    """
    Raise TimeoutError inside the context if it runs longer than ``seconds``.

    Fallback for platforms without ``setitimer``: uses ``signal.alarm``,
    which only has whole-second resolution, so ``seconds`` is rounded up.

    :param seconds: timeout in seconds; ``None`` disables the timeout. A
        non-positive value times out immediately, because ``alarm(0)``
        would *cancel* the alarm and let the body block forever.
    :raises TimeoutError: when the alarm fires (or ``seconds <= 0``).
    :raises RuntimeError: if an alarm was already pending (it is restored).
    """
    if seconds is None:
        yield
        return
    if seconds <= 0:
        raise TimeoutError()
    with raising_signal(signal.SIGALRM, TimeoutError):
        # Round up so we never time out earlier than requested.
        previous = signal.alarm(int(math.ceil(seconds)))
        if previous != 0:
            # Someone else's alarm was pending: put it back and refuse.
            signal.alarm(previous)
            raise RuntimeError("SIGALRM already in use")
        try:
            yield
        finally:
            # Cancel before raising_signal reinstates the previous handler.
            signal.alarm(0)

if is_win():
    try:
        import eventlet
    except ImportError:
        sys.exit("eventlet library required to run cqlshlib tests on Windows")

    def timing_out(seconds):
        """
        Return an ``eventlet.Timeout`` context that raises TimeoutError after
        ``seconds``. SIGALRM-based timers are not available on Windows.
        """
        return eventlet.Timeout(seconds, TimeoutError)
else:
    # Prefer setitimer (added in Python 2.6) for its sub-second resolution,
    # which can make tests faster; fall back to whole-second alarm().
    if hasattr(signal, 'setitimer'):
        timing_out = timing_out_itimer
    else:
        timing_out = timing_out_alarm

def noop(*a):
    """Accept and ignore any positional arguments; always return None."""
    return None

class ProcRunner:
    """
    Drive a subprocess interactively.

    With ``tty`` true on a non-Windows platform, the child runs on a real
    pseudo-terminal. Otherwise it runs over pipes, with stderr merged into
    stdout; with ``tty`` true on Windows, reads go through WinPty.

    ``self.send`` and ``self.read`` are bound by ``start_proc`` to the
    matching transport methods. Output read past a match in ``read_until``
    is kept in ``self.readbuf`` and consumed by the next read_* call.
    """

    # re._pattern_type no longer exists on Python 3.7+; this spelling yields
    # the compiled-pattern type on every version.
    _pattern_type = type(re.compile(''))

    def __init__(self, path, tty=True, env=None, args=()):
        self.exe_path = path
        self.args = args
        self.tty = bool(tty)
        # A real pty is only used off Windows; on Windows tty mode goes
        # through WinPty over a pipe instead.
        self.realtty = self.tty and not is_win()
        if env is None:
            env = {}
        self.env = env
        self.readbuf = ''

        self.start_proc()

    def start_proc(self):
        """Spawn the subprocess and bind ``self.send`` / ``self.read``."""
        cqlshlog.info("Spawning %r subprocess with args: %r and env: %r"
                      % (self.exe_path, self.args, self.env))
        argv = (self.exe_path,) + tuple(self.args)
        if self.realtty:
            masterfd, slavefd = pty.openpty()
            try:
                self.proc = subprocess.Popen(
                    argv, env=self.env,
                    preexec_fn=lambda: set_controlling_pty(masterfd, slavefd),
                    close_fds=False)
            except BaseException:
                # Spawning failed: don't leak our end of the pty.
                os.close(masterfd)
                raise
            finally:
                # The child got its own copy of the slave end; we never use it.
                os.close(slavefd)
            self.childpty = masterfd
            self.send = self.send_tty
            self.read = self.read_tty
        else:
            self.proc = subprocess.Popen(
                argv, env=self.env, stdin=subprocess.PIPE,
                stdout=subprocess.PIPE, stderr=subprocess.STDOUT,
                bufsize=0, close_fds=False)
            self.send = self.send_pipe
            if self.tty:
                # Only reachable on Windows (see realtty), where WinPty is
                # imported.
                self.winpty = WinPty(self.proc.stdout)
                self.read = self.read_winpty
            else:
                self.read = self.read_pipe

    def close(self):
        """
        Close our end of the pty (or the child's stdin pipe) and wait for
        the process to exit.

        :returns: the process's return code.
        """
        cqlshlog.info("Closing %r subprocess." % (self.exe_path,))
        if self.realtty:
            os.close(self.childpty)
        else:
            self.proc.stdin.close()
        cqlshlog.debug("Waiting for exit")
        return self.proc.wait()

    def send_tty(self, data):
        os.write(self.childpty, data)

    def send_pipe(self, data):
        self.proc.stdin.write(data)

    def read_tty(self, blksize, timeout=None):
        # ``timeout`` is accepted for interface parity; only WinPty uses it.
        return os.read(self.childpty, blksize)

    def read_pipe(self, blksize, timeout=None):
        # ``timeout`` is accepted for interface parity; only WinPty uses it.
        return self.proc.stdout.read(blksize)

    def read_winpty(self, blksize, timeout=None):
        return self.winpty.read(blksize, timeout)

    def read_until(self, until, blksize=4096, timeout=None,
                   flags=0, ptty_timeout=None):
        """
        Read output until regex ``until`` matches the accumulated text.

        :param until: a pattern string (compiled with ``flags``) or an
            already-compiled pattern.
        :param timeout: overall timeout in seconds, or None for none.
        :param ptty_timeout: passed through to ``self.read`` (only the
            WinPty reader uses it).
        :returns: everything read up to and including the match; any text
            after the match is kept in ``self.readbuf``.
        :raises EOFError: if the subprocess output ends before a match.
        :raises TimeoutError: if ``timeout`` expires first.
        """
        if not isinstance(until, self._pattern_type):
            until = re.compile(until, flags)

        cqlshlog.debug("Searching for %r" % (until.pattern,))
        got = self.readbuf
        self.readbuf = ''
        with timing_out(timeout):
            while True:
                val = self.read(blksize, ptty_timeout)
                cqlshlog.debug("read %r from subproc" % (val,))
                if val == '':
                    raise EOFError("'until' pattern %r not found" % (until.pattern,))
                got += val
                m = until.search(got)
                if m is not None:
                    end = m.end()
                    self.readbuf = got[end:]
                    return got[:end]

    def read_lines(self, numlines, blksize=4096, timeout=None):
        """Read ``numlines`` newline-terminated chunks under one shared timeout."""
        lines = []
        with timing_out(timeout):
            for _ in range(numlines):
                lines.append(self.read_until('\n', blksize=blksize))
        return lines

    def read_up_to_timeout(self, timeout, blksize=4096):
        """
        Collect output for up to ``timeout`` seconds, stopping early at EOF.

        :returns: everything buffered plus everything read (never raises on
            timeout).
        """
        got = self.readbuf
        self.readbuf = ''
        curtime = time()
        stoptime = curtime + timeout
        while curtime < stoptime:
            try:
                with timing_out(stoptime - curtime):
                    stuff = self.read(blksize)
            except TimeoutError:
                break
            cqlshlog.debug("read %r from subproc" % (stuff,))
            if stuff == '':
                break
            got += stuff
            curtime = time()
        return got

class CqlshRunner(ProcRunner):
    """
    ProcRunner that launches cqlsh against the test cluster.

    Builds the command line from the host and port. It adds
    ``--cqlversion`` / ``--keyspace`` when given, and on Windows in tty mode
    adds ``--tty``, ``--encoding utf-8`` and optionally ``--color``. It also
    builds the child environment. Unless ``prompt`` is None, startup output
    up to the first prompt is consumed into ``self.output_header``.
    """

    # Artifacts from readline "trying to be friendly", stripped from output.
    _READLINE_ARTIFACTS = (' \r', '\r')

    def __init__(self, path=None, host=None, port=None, keyspace=None, cqlver=None,
                 args=(), prompt=DEFAULT_CQLSH_PROMPT, env=None,
                 win_force_colors=True, tty=True, **kwargs):
        if path is None:
            cqlsh_bin = 'cqlsh.bat' if is_win() else 'cqlsh'
            path = normpath(join(basecase.cqlshdir, cqlsh_bin))
        if host is None:
            host = basecase.TEST_HOST
        if port is None:
            port = basecase.TEST_PORT
        # Work on a copy so a dict passed by the caller is never mutated.
        env = {} if env is None else dict(env)
        if is_win():
            # Inherit the parent environment, but let entries the caller
            # passed explicitly take precedence over it.
            merged = os.environ.copy()
            merged['PYTHONUNBUFFERED'] = '1'
            merged.update(env)
            env = merged
        env.setdefault('TERM', DEFAULT_CQLSH_TERM)
        env.setdefault('CQLSH_NO_BUNDLED', os.environ.get('CQLSH_NO_BUNDLED', ''))
        env.setdefault('PYTHONPATH', os.environ.get('PYTHONPATH', ''))
        args = tuple(args) + (host, str(port))
        if cqlver is not None:
            args += ('--cqlversion', str(cqlver))
        if keyspace is not None:
            args += ('--keyspace', keyspace)
        if tty and is_win():
            args += ('--tty',)
            args += ('--encoding', 'utf-8')
            if win_force_colors:
                args += ('--color',)
        self.keyspace = keyspace
        ProcRunner.__init__(self, path, tty=tty, args=args, env=env, **kwargs)
        self.prompt = prompt
        if self.prompt is None:
            self.output_header = ''
        else:
            self.output_header = self.read_to_next_prompt()

    @staticmethod
    def _strip_artifacts(output, artifacts):
        """Return ``output`` with each string in ``artifacts`` removed, in order."""
        for artifact in artifacts:
            output = output.replace(artifact, '')
        return output

    def read_to_next_prompt(self):
        """Read output up to and including the next prompt (10 s timeout)."""
        return self.read_until(self.prompt, timeout=10.0, ptty_timeout=3)

    def read_up_to_timeout(self, timeout, blksize=4096):
        output = ProcRunner.read_up_to_timeout(self, timeout, blksize=blksize)
        return self._strip_artifacts(output, self._READLINE_ARTIFACTS)

    def cmd_and_response(self, cmd):
        """
        Send ``cmd`` and return cqlsh's response text.

        The echoed command line (on a real tty) and the trailing prompt line
        are stripped; both are asserted to look as expected.

        :returns: the response body followed by a newline.
        """
        self.send(cmd + '\n')
        output = self.read_to_next_prompt()
        output = self._strip_artifacts(output, self._READLINE_ARTIFACTS + (' \b',))
        if self.realtty:
            echo, output = output.split('\n', 1)
            assert echo == cmd, "unexpected echo %r instead of %r" % (echo, cmd)
        try:
            output, promptline = output.rsplit('\n', 1)
        except ValueError:
            # No newline: the whole remaining text is the prompt line.
            promptline = output
            output = ''
        assert re.match(self.prompt, DEFAULT_PREFIX + promptline), \
                'last line of output %r does not match %r?' % (promptline, self.prompt)
        return output + '\n'

def run_cqlsh(**kwargs):
    """
    Start an interactive cqlsh session.

    Keyword arguments are forwarded to CqlshRunner. The runner is wrapped in
    ``contextlib.closing``, so ``with run_cqlsh(...) as c:`` closes the
    subprocess when the block exits.
    """
    runner = CqlshRunner(**kwargs)
    return contextlib.closing(runner)

def call_cqlsh(**kwargs):
    """
    Run cqlsh non-interactively over pipes, feeding it ``input`` on stdin.

    The other keyword arguments are forwarded to CqlshRunner. ``prompt``
    defaults to None, so no startup banner is consumed, and ``tty`` is
    always forced off.

    :returns: ``(output, exit_status)``, where ``output`` is the process's
        stdout with stderr merged in.
    """
    proginput = kwargs.pop('input', '')
    kwargs.setdefault('prompt', None)
    kwargs['tty'] = False
    runner = CqlshRunner(**kwargs)
    output = runner.proc.communicate(proginput)[0]
    return output, runner.close()
