# Licensed to the Apache Software Foundation (ASF) under one or more
# contributor license agreements.  See the NOTICE file distributed with
# this work for additional information regarding copyright ownership.
# The ASF licenses this file to You under the Apache License, Version 2.0
# (the "License"); you may not use this file except in compliance with
# the License.  You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import unittest
import random
import requests
from libcloud.common.base import Response
from libcloud.http import LibcloudConnection
from libcloud.utils.py3 import PY2

if PY2:
    from StringIO import StringIO
else:
    from io import StringIO

import requests_mock

from libcloud.utils.py3 import httplib
from libcloud.utils.py3 import urlparse
from libcloud.utils.py3 import parse_qs
from libcloud.utils.py3 import parse_qsl
from libcloud.utils.py3 import urlquote


# Headers declaring an XML content type; presumably used by mock responses
# (or tests) that return XML bodies.
XML_HEADERS = {"content-type": "application/xml"}


class LibcloudTestCase(unittest.TestCase):
    """
    Base test case that keeps track of mock HTTP activity.

    A ``MockHttp`` whose ``test`` attribute is an instance of this class
    reports every visited URL and every executed mock method name back to
    it, so tests can assert how many mock methods ran.
    """

    def __init__(self, *args, **kwargs):
        # Tracking lists must exist before TestCase initialisation runs.
        self._visited_urls, self._executed_mock_methods = [], []
        super(LibcloudTestCase, self).__init__(*args, **kwargs)

    def setUp(self):
        # Start every test with empty tracking lists.
        self._visited_urls, self._executed_mock_methods = [], []

    def _add_visited_url(self, url):
        """Remember that ``url`` was requested through a mock connection."""
        self._visited_urls.append(url)

    def _add_executed_mock_method(self, method_name):
        """Remember that the mock method ``method_name`` was executed."""
        self._executed_mock_methods.append(method_name)

    def assertExecutedMethodCount(self, expected):
        """Fail unless exactly ``expected`` mock methods were executed."""
        executed = len(self._executed_mock_methods)
        message = "expected %d, but %d mock methods were executed" % (
            expected,
            executed,
        )
        self.assertEqual(executed, expected, message)


class multipleresponse(object):
    """
    A decorator that allows MockHttp objects to return multiple responses.

    The decorated method must return an indexable sequence of responses.
    The first call returns item 0, the next call item 1, and so on. The
    counter is stored on the decorator object, which is a class attribute,
    so it is shared by every instance of the class defining the method.
    """

    count = 0
    func = None

    def __init__(self, f):
        self.func = f

    def __get__(self, instance, owner=None):
        # Without this, looking the method up on an instance returned the
        # decorator itself and the wrapped function received ``function``
        # (the type of ``self.func``) as ``self``. Binding here passes the
        # real instance so the method can use instance state.
        if instance is None:
            return self

        def bound(*args, **kwargs):
            return self._respond(instance, *args, **kwargs)

        return bound

    def __call__(self, *args, **kwargs):
        # Direct (unbound) call: keep the historical placeholder ``self``.
        return self._respond(self.func.__class__, *args, **kwargs)

    def _respond(self, bound_self, *args, **kwargs):
        """Call the wrapped method and return its next response in order."""
        responses = self.func(bound_self, *args, **kwargs)
        try:
            response = responses[self.count]
        except IndexError:
            # Same exception type as before, but say what ran out.
            raise IndexError(
                "%s returned %d responses but was called %d times"
                % (self.func.__name__, len(responses), self.count + 1)
            )
        self.count = self.count + 1
        return response


class BodyStream(StringIO):
    """
    In-memory text stream used as a fake response body.

    Every reader method accepts, and ignores, a ``chunk_size`` argument:
    ``read`` always returns all of the remaining content. The base class is
    called explicitly rather than through ``super`` because on Python 2
    ``StringIO.StringIO`` is an old-style class.
    """

    def read(self, chunk_size=None):
        # chunk_size is deliberately ignored; return the remaining content.
        return StringIO.read(self)

    def __next__(self, chunk_size=None):
        # Python 3 iterator protocol.
        return StringIO.__next__(self)

    def next(self, chunk_size=None):
        # Python 2 iterator protocol.
        return StringIO.next(self)


class MockHttp(LibcloudConnection):
    """
    A mock HTTP client/server suitable for testing purposes. This replaces
    `HTTPConnection` by implementing its API and returning a mock response.

    Define methods by request path, replacing slashes (/) with underscores (_).
    Each of these mock methods should return a tuple of:

        (int status, str body, dict headers, str reason)

    See ``_get_method_name`` for the exact naming rules: ``.`` and ``-`` also
    become ``_``, ``~`` becomes ``_7E``, and the ``type`` attribute and the
    value of the ``use_param`` query parameter (when set) are appended as
    ``_<suffix>``.
    """

    type = None
    use_param = None  # will use this param to namespace the request function
    test = None  # TestCase instance which is using this mock
    proxy_url = None

    def __init__(self, *args, **kwargs):
        # Load assertion methods into the class, in case people want to assert
        # within a response
        if isinstance(self, unittest.TestCase):
            unittest.TestCase.__init__(self, "__init__")
        super(MockHttp, self).__init__(*args, **kwargs)

    def _get_request(self, method, url, body=None, headers=None):
        """
        Dispatch a request to the matching mock method and return its result.

        When ``self.test`` is a ``LibcloudTestCase``, the URL and the resolved
        method name are recorded on it.

        :raises AttributeError: if no mock method matches the request path.
        """
        # Find a method we can use for this request
        parsed = urlparse.urlparse(url)
        _, _, path, _, query, _ = parsed
        qs = parse_qs(query)
        if path.endswith("/"):
            path = path[:-1]
        meth_name = self._get_method_name(
            type=self.type, use_param=self.use_param, qs=qs, path=path
        )
        # "%" (from an escaped "~") cannot appear in a Python identifier.
        meth = getattr(self, meth_name.replace("%", "_"))

        if self.test and isinstance(self.test, LibcloudTestCase):
            self.test._add_visited_url(url=url)
            self.test._add_executed_mock_method(method_name=meth_name)
        return meth(method, url, body, headers)

    def request(self, method, url, body=None, headers=None, raw=False, stream=False):
        """
        Serve the request from the matching mock method via ``requests_mock``.

        :raises AttributeError: if ``requests_mock`` reports that no mock is
            registered for the (quoted) URL.
        """
        headers = self._normalize_headers(headers=headers)
        r_status, r_body, r_headers, r_reason = self._get_request(
            method, url, body, headers
        )
        if r_body is None:
            r_body = ""
        # Quote the URL so special characters in the request URL (e.g. ~)
        # match the address registered with requests_mock.
        url = urlquote(url)

        with requests_mock.mock() as m:
            m.register_uri(
                method,
                url,
                text=r_body,
                reason=r_reason,
                headers=r_headers,
                status_code=r_status,
            )
            try:
                super(MockHttp, self).request(
                    method=method,
                    url=url,
                    body=body,
                    headers=headers,
                    raw=raw,
                    stream=stream,
                )
            except requests_mock.exceptions.NoMockAddress as nma:
                raise AttributeError(
                    "Failed to mock out URL {0} - {1}".format(url, nma.request.url)
                )

    def prepared_request(
        self, method, url, body=None, headers=None, raw=False, stream=False
    ):
        """Serve a prepared request from the matching mock method."""
        headers = self._normalize_headers(headers=headers)
        r_status, r_body, r_headers, r_reason = self._get_request(
            method, url, body, headers
        )

        with requests_mock.mock() as m:
            m.register_uri(
                method,
                url,
                text=r_body,
                reason=r_reason,
                headers=r_headers,
                status_code=r_status,
            )
            super(MockHttp, self).prepared_request(
                method=method,
                url=url,
                body=body,
                headers=headers,
                raw=raw,
                stream=stream,
            )

    # Mock request/response example
    def _example(self, method, url, body, headers):
        """
        Return a simple message and header, regardless of input.
        """
        return (
            httplib.OK,
            "Hello World!",
            {"X-Foo": "libcloud"},
            httplib.responses[httplib.OK],
        )

    def _example_fail(self, method, url, body, headers):
        """Return a 403 response, regardless of input."""
        return (
            httplib.FORBIDDEN,
            "Oh Noes!",
            {"X-Foo": "fail"},
            httplib.responses[httplib.FORBIDDEN],
        )

    def _get_method_name(self, type, use_param, qs, path):
        """
        Build the name of the mock method that handles ``path``.

        :param type: Optional suffix, appended as ``_<type>``.
        :param use_param: Name of a query parameter whose first value (with
            ``.`` and ``-`` replaced by ``_``) is appended as a further
            suffix when present in ``qs``.
        :param qs: Parsed query string, as returned by ``parse_qs``.
        :param path: Request path.
        :return: The method name, or ``"root"`` if it would otherwise be empty.
        """
        path = path.split("?")[0]
        meth_name = (
            path.replace("/", "_")
            .replace(".", "_")
            .replace("-", "_")
            .replace("~", "%7E")
        )  # Python 3.7 no longer quotes ~

        if type:
            # Use the argument (not ``self.type``) so an explicitly passed
            # suffix is honoured.
            meth_name = "%s_%s" % (meth_name, type)

        if use_param and use_param in qs:
            param = qs[use_param][0].replace(".", "_").replace("-", "_")
            meth_name = "%s_%s" % (meth_name, param)

        if meth_name == "":
            meth_name = "root"

        return meth_name

    def assertUrlContainsQueryParams(self, url, expected_params, strict=False):
        """
        Assert that provided url contains provided query parameters.

        Explicit ``AssertionError``s are raised (instead of ``assert``
        statements) so the check still runs under ``python -O`` and failures
        say what differed.

        :param url: URL to assert.
        :type url: ``str``

        :param expected_params: Dictionary of expected query parameters.
        :type expected_params: ``dict``

        :param strict: Assert that provided url contains only expected_params.
                       (defaults to ``False``)
        :type strict: ``bool``

        :raises AssertionError: if the query parameters do not match.
        """
        question_mark_index = url.find("?")

        if question_mark_index != -1:
            url = url[question_mark_index + 1 :]

        params = dict(parse_qsl(url))

        if strict:
            if params != expected_params:
                raise AssertionError(
                    "Expected query params %r, got %r" % (expected_params, params)
                )
            return

        for key, value in expected_params.items():
            if key not in params:
                raise AssertionError(
                    "Query param %r missing from %r" % (key, params)
                )
            if params[key] != value:
                raise AssertionError(
                    "Query param %r is %r, expected %r" % (key, params[key], value)
                )
                assert params[key] == value


class MockConnection(object):
    """Minimal stand-in connection that only stores the given ``action``."""

    def __init__(self, action):
        super(MockConnection, self).__init__()
        self.action = action


# Alias of MockHttp (the same class object); presumably kept so tests that
# import the name ``StorageMockHttp`` keep working.
StorageMockHttp = MockHttp


def make_response(status=200, headers=None, connection=None):
    """
    Build a libcloud ``Response`` wrapping a bare ``requests.Response``.

    :param status: HTTP status code set on the underlying response.
    :type status: ``int``

    :param headers: Headers set on the underlying response. ``None`` (the
                    default) means a fresh empty dict for each call.
    :type headers: ``dict``

    :param connection: Connection object passed through to ``Response``.

    :rtype: ``Response``
    """
    if headers is None:
        # A fresh dict per call: the previous shared ``{}`` default was stored
        # on every response, so a mutation in one test leaked into others.
        headers = {}
    response = requests.Response()
    response.status_code = status
    response.headers = headers
    return Response(response, connection)


def generate_random_data(size):
    """
    Return a string of ``size`` random decimal digits.

    Digits are drawn one at a time with ``random.randint(0, 9)``, so seeding
    ``random`` makes the output reproducible. A non-positive ``size`` yields
    an empty string.

    :param size: Number of characters to generate.
    :type size: ``int``

    :rtype: ``str``
    """
    digits = []
    # Each digit is exactly one character, so the list length equals the
    # length of the final string; joining once avoids repeated concatenation.
    while len(digits) < size:
        digits.append(str(random.randint(0, 9)))
    return "".join(digits)


if __name__ == "__main__":
    # When run directly, execute any doctests in this module's docstrings.
    import doctest

    doctest.testmod()
