# Licensed to the Apache Software Foundation (ASF) under one or more
# contributor license agreements.  See the NOTICE file distributed with
# this work for additional information regarding copyright ownership.
# The ASF licenses this file to You under the Apache License, Version 2.0
# (the "License"); you may not use this file except in compliance with
# the License.  You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import collections
import json
import re
import socket
import threading
import traceback
import Queue
import StringIO
import wsgiref.simple_server

import bottle
from aria import modeling

from .. import exceptions


class CtxProxy(object):
    """
    Serves an operation context (``ctx``) over a local HTTP endpoint.

    Construction picks a free port, starts a bottle/wsgiref server on a separate thread and
    waits for that server to be created. Clients ``POST`` a JSON document of the form
    ``{"args": [...]}`` to ``/``; the args are resolved against ``ctx`` by
    ``_process_ctx_request`` and the outcome is returned as a JSON object of the form
    ``{"type": ..., "payload": ...}``.

    Instances are context managers: leaving the ``with`` block calls ``close()``.
    """

    def __init__(self, ctx, ctx_patcher=(lambda *args, **kwargs: None)):
        """
        :param ctx: context object to expose; this class reads ``ctx.model.instrument``,
         ``ctx.INSTRUMENTATION_FIELDS`` and ``ctx.model.log._session`` from it
        :param ctx_patcher: callable invoked as ``ctx_patcher(ctx)`` inside the server thread
         before the server is started; the default does nothing
        :raises Queue.Empty: if the server thread has not signalled start-up within 5 seconds
        """
        self.ctx = ctx
        self._ctx_patcher = ctx_patcher
        self.port = _get_unused_port()
        self.socket_url = 'http://localhost:{0}'.format(self.port)
        # Assigned by the server thread once the WSGI server object exists.
        self.server = None
        # One-slot queue used purely as a "server has been created" signal.
        self._started = Queue.Queue(1)
        self.thread = self._start_server()
        # Block until the server thread reports that ``self.server`` is set.
        self._started.get(timeout=5)

    def _start_server(self):
        """
        Starts the HTTP server on a new thread and returns that thread.

        The thread is started but not joined here; ``__init__`` waits on ``self._started``
        instead.
        """

        class BottleServerAdapter(bottle.ServerAdapter):
            # Gives the adapter (and the nested ``Server``) access to the owning CtxProxy.
            proxy = self

            def close_session(self):
                # Releases the session held by the context's log model storage.
                self.proxy.ctx.model.log._session.remove()

            def run(self, app):
                # Builds the wsgiref server, publishes it on the proxy, signals start-up and
                # then serves requests until ``shutdown()`` is called on the server.

                class Server(wsgiref.simple_server.WSGIServer):
                    allow_reuse_address = True
                    # Reference back to the adapter so the session can be closed on exit.
                    bottle_server = self

                    def handle_error(self, request, client_address):
                        # No-op override: suppresses the base class's error reporting for
                        # failures raised while handling a request.
                        pass

                    def serve_forever(self, poll_interval=0.5):
                        try:
                            wsgiref.simple_server.WSGIServer.serve_forever(self, poll_interval)
                        finally:
                            # Once shutdown is called, we need to close the session.
                            # If the session is not closed properly, it might raise warnings,
                            # or even lock the database.
                            self.bottle_server.close_session()

                class Handler(wsgiref.simple_server.WSGIRequestHandler):
                    def address_string(self):
                        # Returns the client's address as-is.
                        # NOTE(review): presumably overridden to avoid the base class's
                        # host-name lookup -- confirm.
                        return self.client_address[0]

                    def log_request(*args, **kwargs):  # pylint: disable=no-method-argument
                        # This method declares no ``self`` parameter, so ``self`` below is the
                        # enclosing ``run``'s ``self`` (the adapter); the handler instance is
                        # simply forwarded as ``args[0]``. Requests are logged only when the
                        # adapter is not quiet.
                        if not self.quiet:
                            return wsgiref.simple_server.WSGIRequestHandler.log_request(*args,
                                                                                        **kwargs)
                server = wsgiref.simple_server.make_server(
                    host=self.host,
                    port=self.port,
                    app=app,
                    server_class=Server,
                    handler_class=Handler)
                # Publish the server before signalling, so that ``close()`` can reach it as
                # soon as ``__init__`` returns.
                self.proxy.server = server
                self.proxy._started.put(True)
                server.serve_forever(poll_interval=0.1)

        def serve():
            # Since task is a thread_local object, we need to patch it inside the server thread.
            self._ctx_patcher(self.ctx)

            bottle_app = bottle.Bottle()
            # Single route: every request is a POST to the root path.
            bottle_app.post('/', callback=self._request_handler)
            bottle.run(
                app=bottle_app,
                host='localhost',
                port=self.port,
                quiet=True,
                server=BottleServerAdapter)
        thread = threading.Thread(target=serve)
        thread.start()
        return thread

    def close(self):
        """
        Stops the server loop and closes its listening socket; does nothing if the server was
        never created.
        """
        # NOTE(review): the server thread is not joined here, so the session clean-up in
        # ``Server.serve_forever``'s ``finally`` clause may still be running when this
        # returns -- confirm whether a ``self.thread.join()`` is wanted.
        if self.server:
            self.server.shutdown()
            self.server.server_close()

    def _request_handler(self):
        """
        Bottle callback for ``POST /``: processes the raw request body and wraps the result in
        a JSON response. The HTTP status is always 200; failures are reported in the body.
        """
        request = bottle.request.body.read()  # pylint: disable=no-member
        response = self._process(request)
        return bottle.LocalResponse(
            body=json.dumps(response, cls=modeling.utils.ModelJSONEncoder),
            status=200,
            headers={'content-type': 'application/json'}
        )

    def _process(self, request):
        """
        Decodes a JSON request, resolves its ``args`` against the context and builds the
        response dict.

        :param request: JSON text containing an ``args`` list
        :return: ``{'type': 'result', 'payload': <value>}`` on success,
         ``{'type': 'stop_operation', 'payload': {'message': ...}}`` when the resolved value
         is a ``ScriptException``, or ``{'type': 'error', 'payload': {...}}`` with the
         exception type name, message and traceback when anything raises
        """
        try:
            with self.ctx.model.instrument(*self.ctx.INSTRUMENTATION_FIELDS):
                typed_request = json.loads(request)
                args = typed_request['args']
                payload = _process_ctx_request(self.ctx, args)
                result_type = 'result'
                # A ScriptException that is *returned* (not raised) by the context call is
                # reported to the client as a request to stop the operation.
                if isinstance(payload, exceptions.ScriptException):
                    payload = dict(message=str(payload))
                    result_type = 'stop_operation'
                result = {'type': result_type, 'payload': payload}
        except Exception as e:
            # Any failure (bad JSON, missing 'args', errors raised while resolving the args)
            # is serialized and sent back instead of propagating out of the handler.
            traceback_out = StringIO.StringIO()
            traceback.print_exc(file=traceback_out)
            payload = {
                'type': type(e).__name__,
                'message': str(e),
                'traceback': traceback_out.getvalue()
            }
            result = {'type': 'error', 'payload': payload}

        return result

    def __enter__(self):
        return self

    def __exit__(self, *args, **kwargs):
        # Always shut the server down, whether or not the ``with`` body raised.
        self.close()


def _process_ctx_request(ctx, args):
    current = ctx
    num_args = len(args)
    index = 0
    while index < num_args:
        arg = args[index]
        attr = _desugar_attr(current, arg)
        if attr:
            current = getattr(current, attr)
        elif isinstance(current, collections.MutableMapping):
            key = arg
            path_dict = _PathDictAccess(current)
            if index + 1 == num_args:
                # read dict prop by path
                value = path_dict.get(key)
                current = value
            elif index + 2 == num_args:
                # set dict prop by path
                value = args[index + 1]
                path_dict.set(key, value)
                current = None
            else:
                raise RuntimeError('Illegal argument while accessing dict')
            break
        elif callable(current):
            kwargs = {}
            remaining_args = args[index:]
            if isinstance(remaining_args[-1], collections.MutableMapping):
                kwargs = remaining_args[-1]
                remaining_args = remaining_args[:-1]
            current = current(*remaining_args, **kwargs)
            break
        else:
            raise RuntimeError('{0} cannot be processed in {1}'.format(arg, args))
        index += 1
    if callable(current):
        current = current()
    return current


def _desugar_attr(obj, attr):
    if not isinstance(attr, basestring):
        return None
    if hasattr(obj, attr):
        return attr
    attr = attr.replace('-', '_')
    if hasattr(obj, attr):
        return attr
    return None


class _PathDictAccess(object):
    pattern = re.compile(r"(.+)\[(\d+)\]")

    def __init__(self, obj):
        self.obj = obj

    def set(self, prop_path, value):
        obj, prop_name = self._get_parent_obj_prop_name_by_path(prop_path)
        obj[prop_name] = value

    def get(self, prop_path):
        value = self._get_object_by_path(prop_path)
        return value

    def _get_object_by_path(self, prop_path, fail_on_missing=True):
        # when setting a nested object, make sure to also set all the
        # intermediate path objects
        current = self.obj
        for prop_segment in prop_path.split('.'):
            match = self.pattern.match(prop_segment)
            if match:
                index = int(match.group(2))
                property_name = match.group(1)
                if property_name not in current:
                    self._raise_illegal(prop_path)
                if not isinstance(current[property_name], list):
                    self._raise_illegal(prop_path)
                current = current[property_name][index]
            else:
                if prop_segment not in current:
                    if fail_on_missing:
                        self._raise_illegal(prop_path)
                    else:
                        current[prop_segment] = {}
                current = current[prop_segment]
        return current

    def _get_parent_obj_prop_name_by_path(self, prop_path):
        split = prop_path.split('.')
        if len(split) == 1:
            return self.obj, prop_path
        parent_path = '.'.join(split[:-1])
        parent_obj = self._get_object_by_path(parent_path, fail_on_missing=False)
        prop_name = split[-1]
        return parent_obj, prop_name

    @staticmethod
    def _raise_illegal(prop_path):
        raise RuntimeError('illegal path: {0}'.format(prop_path))


def _get_unused_port():
    sock = socket.socket()
    sock.bind(('127.0.0.1', 0))
    _, port = sock.getsockname()
    sock.close()
    return port
