# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements.  See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership.  The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License.  You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied.  See the License for the
# specific language governing permissions and limitations
# under the License.
# isort:skip_file
"""Unit tests for Superset"""

from datetime import datetime, timedelta
from unittest import mock
import random
import string

import pytest
import prison
from sqlalchemy.sql import func

import tests.integration_tests.test_app  # noqa: F401
from superset import db, security_manager
from superset.common.db_query_status import QueryStatus
from superset.models.core import Database
from superset.utils.database import get_example_database, get_main_database
from superset.utils import json
from superset.models.sql_lab import Query

from tests.integration_tests.base_tests import SupersetTestCase
from tests.integration_tests.constants import ADMIN_USERNAME, GAMMA_SQLLAB_USERNAME

# Number of ``Query`` rows seeded by the ``create_queries`` fixture; the list
# endpoint tests compare the API's reported ``count`` against this value.
QUERIES_FIXTURE_COUNT = 10


class TestQueryApi(SupersetTestCase):
    def insert_query(
        self,
        database_id: int,
        user_id: int,
        client_id: str,
        sql: str = "",
        select_sql: str = "",
        executed_sql: str = "",
        limit: int = 100,
        progress: int = 100,
        rows: int = 100,
        tab_name: str = "",
        status: str = "success",
        changed_on: datetime = datetime(2020, 1, 1),
    ) -> Query:
        """
        Persist a ``Query`` row and return it.

        The database and the owning user are looked up by primary key; all
        remaining arguments are handed to the ``Query`` constructor unchanged.
        The row is committed right away, so callers have to delete it
        themselves once they are done with it.
        """
        session = db.session
        new_query = Query(
            database=session.query(Database).get(database_id),
            user=session.query(security_manager.user_model).get(user_id),
            client_id=client_id,
            sql=sql,
            select_sql=select_sql,
            executed_sql=executed_sql,
            limit=limit,
            progress=progress,
            rows=rows,
            tab_name=tab_name,
            status=status,
            changed_on=changed_on,
        )
        session.add(new_query)
        session.commit()
        return new_query

    @pytest.fixture
    def create_queries(self):
        """
        Seed ``QUERIES_FIXTURE_COUNT`` queries, yield them, then delete them.

        All but one of the queries belong to ``admin`` on the example database,
        with the status alternating between SUCCESS (even index) and RUNNING
        (odd index). The final query belongs to ``alpha`` on the main database.
        """
        with self.create_app().app_context():
            admin_id = self.get_user("admin").id
            alpha_id = self.get_user("alpha").id
            example_database_id = get_example_database().id
            main_database_id = get_main_database().id

            # Admin-owned queries against the example database.
            queries = [
                self.insert_query(
                    example_database_id,
                    admin_id,
                    self.get_random_string(),
                    sql=f"SELECT col1, col2 from table{idx}",  # noqa: S608
                    rows=idx,
                    status=QueryStatus.RUNNING if idx % 2 else QueryStatus.SUCCESS,
                )
                for idx in range(QUERIES_FIXTURE_COUNT - 1)
            ]

            # A single alpha-owned query against the main database.
            alpha_query = self.insert_query(
                main_database_id,
                alpha_id,
                self.get_random_string(),
                sql=f"SELECT col1, col2 from table{QUERIES_FIXTURE_COUNT}",  # noqa: S608
                rows=QUERIES_FIXTURE_COUNT,
                status=QueryStatus.SUCCESS,
            )
            queries.append(alpha_query)

            yield queries

            # rollback changes
            for seeded in queries:
                db.session.delete(seeded)
            db.session.commit()

    @staticmethod
    def get_random_string(length: int = 10):
        letters = string.ascii_letters
        return "".join(random.choice(letters) for i in range(length))  # noqa: S311

    def test_get_query(self):
        """
        Query API: Test get query

        Inserts a single admin-owned query, fetches it through
        ``api/v1/query/<id>`` and compares every non-volatile field of the
        response with the values that were inserted.
        """
        admin = self.get_user("admin")
        client_id = self.get_random_string()
        example_db = get_example_database()
        query = self.insert_query(
            example_db.id,
            admin.id,
            client_id,
            sql="SELECT col1, col2 from table1",
            select_sql="SELECT col1, col2 from table1",
            executed_sql="SELECT col1, col2 from table1 LIMIT 100",
        )
        try:
            self.login(ADMIN_USERNAME)
            uri = f"api/v1/query/{query.id}"
            rv = self.client.get(uri)
            assert rv.status_code == 200

            expected_result = {
                "database": {"id": example_db.id},
                "client_id": client_id,
                "end_result_backend_time": None,
                "error_message": None,
                "executed_sql": "SELECT col1, col2 from table1 LIMIT 100",
                "limit": 100,
                "progress": 100,
                "results_key": None,
                "rows": 100,
                "schema": None,
                "select_as_cta": None,
                "select_as_cta_used": False,
                "select_sql": "SELECT col1, col2 from table1",
                "sql": "SELECT col1, col2 from table1",
                "sql_editor_id": None,
                "status": "success",
                "tab_name": "",
                "tmp_schema_name": None,
                "tmp_table_name": None,
                "tracking_url": None,
            }
            data = json.loads(rv.data.decode("utf-8"))
            assert "changed_on" in data["result"]
            for key, value in data["result"].items():
                # We can't assert timestamp
                if key not in (
                    "changed_on",
                    "end_time",
                    "start_running_time",
                    "start_time",
                    "id",
                ):
                    assert value == expected_result[key]
        finally:
            # rollback changes; done in ``finally`` so a failing assertion does
            # not leave the row behind for tests that assert exact counts
            db.session.delete(query)
            db.session.commit()

    def test_get_query_not_found(self):
        """
        Query API: Test get query not found

        Requests an id one past the current maximum ``Query.id`` and expects
        a 404.
        """
        admin = self.get_user("admin")
        client_id = self.get_random_string()
        # Insert one row so that ``max(Query.id)`` is never NULL.
        query = self.insert_query(get_example_database().id, admin.id, client_id)
        try:
            max_id = db.session.query(func.max(Query.id)).scalar()
            self.login(ADMIN_USERNAME)
            uri = f"api/v1/query/{max_id + 1}"
            rv = self.client.get(uri)
            assert rv.status_code == 404
        finally:
            # rollback changes, also when the assertion above fails
            db.session.delete(query)
            db.session.commit()

    def test_get_query_no_data_access(self):
        """
        Query API: Test get query without data access

        Two Gamma users (with the ``sql_lab`` role added) each own one query.
        Each must get a 404 for the other user's query and a 200 for their own,
        while the admin gets a 200 for both.
        """
        gamma1 = self.create_user(
            "gamma_1", "password", "Gamma", email="gamma1@superset.org"
        )
        gamma2 = self.create_user(
            "gamma_2", "password", "Gamma", email="gamma2@superset.org"
        )
        # Add SQLLab role to these gamma users, so they have access to queries
        sqllab_role = self.get_role("sql_lab")
        gamma1.roles.append(sqllab_role)
        gamma2.roles.append(sqllab_role)

        gamma1_client_id = self.get_random_string()
        gamma2_client_id = self.get_random_string()
        query_gamma1 = self.insert_query(
            get_example_database().id, gamma1.id, gamma1_client_id
        )
        query_gamma2 = self.insert_query(
            get_example_database().id, gamma2.id, gamma2_client_id
        )

        try:
            # Gamma1 user, only sees their own queries
            self.login(username="gamma_1", password="password")  # noqa: S106
            uri = f"api/v1/query/{query_gamma2.id}"
            rv = self.client.get(uri)
            assert rv.status_code == 404
            uri = f"api/v1/query/{query_gamma1.id}"
            rv = self.client.get(uri)
            assert rv.status_code == 200

            # Gamma2 user, only sees their own queries
            self.logout()
            self.login(username="gamma_2", password="password")  # noqa: S106
            uri = f"api/v1/query/{query_gamma1.id}"
            rv = self.client.get(uri)
            assert rv.status_code == 404
            uri = f"api/v1/query/{query_gamma2.id}"
            rv = self.client.get(uri)
            assert rv.status_code == 200

            # Admin's have the "all query access" permission
            self.logout()
            self.login(ADMIN_USERNAME)
            uri = f"api/v1/query/{query_gamma1.id}"
            rv = self.client.get(uri)
            assert rv.status_code == 200
            uri = f"api/v1/query/{query_gamma2.id}"
            rv = self.client.get(uri)
            assert rv.status_code == 200
        finally:
            # rollback changes; done in ``finally`` so a failing assertion does
            # not leave the temporary users and queries behind
            db.session.delete(query_gamma1)
            db.session.delete(query_gamma2)
            db.session.delete(gamma1)
            db.session.delete(gamma2)
            db.session.commit()

    @pytest.mark.usefixtures("create_queries")
    def test_get_list_query(self):
        """
        Query API: Test get list query

        Checks the total count reported for the seeded queries and the set of
        columns exposed on a list item, including its nested ``user`` and
        ``database`` objects.
        """
        self.login(ADMIN_USERNAME)
        rv = self.client.get("api/v1/query/")
        assert rv.status_code == 200

        data = json.loads(rv.data.decode("utf-8"))
        assert data["count"] == QUERIES_FIXTURE_COUNT

        # check expected columns
        first_row = data["result"][0]
        expected_columns = [
            "changed_on",
            "database",
            "end_time",
            "executed_sql",
            "id",
            "rows",
            "schema",
            "sql",
            "sql_tables",
            "start_time",
            "status",
            "tab_name",
            "tmp_table_name",
            "tracking_url",
            "user",
        ]
        assert sorted(first_row) == expected_columns
        assert sorted(first_row["user"]) == ["first_name", "id", "last_name"]
        assert list(first_row["database"]) == ["database_name"]

    @pytest.mark.usefixtures("create_queries")
    def test_get_list_query_filter_sql(self):
        """
        Query API: Test get list query filter

        Filters the list on ``sql`` containing "table2"; exactly one seeded
        query is expected to match.
        """
        self.login(ADMIN_USERNAME)
        sql_filter = {"col": "sql", "opr": "ct", "value": "table2"}
        rison_query = prison.dumps({"filters": [sql_filter]})
        rv = self.client.get(f"api/v1/query/?q={rison_query}")
        assert rv.status_code == 200
        payload = json.loads(rv.data.decode("utf-8"))
        assert payload["count"] == 1

    @pytest.mark.usefixtures("create_queries")
    def test_get_list_query_filter_database(self):
        """
        Query API: Test get list query filter database

        Filters the list on the main database; only one seeded query was run
        against it.
        """
        self.login(ADMIN_USERNAME)
        database_filter = {
            "col": "database",
            "opr": "rel_o_m",
            "value": get_main_database().id,
        }
        rison_query = prison.dumps({"filters": [database_filter]})
        rv = self.client.get(f"api/v1/query/?q={rison_query}")
        assert rv.status_code == 200
        payload = json.loads(rv.data.decode("utf-8"))
        assert payload["count"] == 1

    @pytest.mark.usefixtures("create_queries")
    def test_get_list_query_filter_user(self):
        """
        Query API: Test get list query filter user

        Filters the list on the ``alpha`` user, who owns exactly one of the
        seeded queries.
        """
        self.login(ADMIN_USERNAME)
        user_filter = {
            "col": "user",
            "opr": "rel_o_m",
            "value": self.get_user("alpha").id,
        }
        rison_query = prison.dumps({"filters": [user_filter]})
        rv = self.client.get(f"api/v1/query/?q={rison_query}")
        assert rv.status_code == 200
        payload = json.loads(rv.data.decode("utf-8"))
        assert payload["count"] == 1

    @pytest.mark.usefixtures("create_queries")
    def test_get_list_query_filter_changed_on(self):
        """
        Query API: Test get list query filter changed_on

        Applies a ``changed_on`` window from 2019-12-30 to 2020-02-01; every
        seeded query is expected to fall inside it.
        """
        self.login(ADMIN_USERNAME)
        upper_bound = {"col": "changed_on", "opr": "lt", "value": "2020-02-01T00:00:00Z"}
        lower_bound = {"col": "changed_on", "opr": "gt", "value": "2019-12-30T00:00:00Z"}
        rison_query = prison.dumps({"filters": [upper_bound, lower_bound]})
        rv = self.client.get(f"api/v1/query/?q={rison_query}")
        assert rv.status_code == 200
        payload = json.loads(rv.data.decode("utf-8"))
        assert payload["count"] == QUERIES_FIXTURE_COUNT

    @pytest.mark.usefixtures("create_queries")
    def test_get_list_query_order(self):
        """
        Query API: Test get list query order

        Requests the list once per supported order column (ascending) and
        expects every request to succeed.
        """
        self.login(ADMIN_USERNAME)
        order_columns = [
            "changed_on",
            "database.database_name",
            "rows",
            "schema",
            "sql",
            "tab_name",
            "user.first_name",
        ]

        for order_column in order_columns:
            arguments = {"order_column": order_column, "order_direction": "asc"}
            uri = f"api/v1/query/?q={prison.dumps(arguments)}"
            rv = self.client.get(uri)
            # Name the column in the failure message; otherwise a failure
            # inside the loop does not say which ordering broke.
            assert rv.status_code == 200, f"ordering by {order_column!r} failed"

    def test_get_list_query_no_data_access(self):
        """
        Query API: Test get queries no data access

        An admin-owned query must not show up in the list returned to the
        gamma SQL Lab user, even when the filter matches its SQL.
        """
        admin = self.get_user("admin")
        client_id = self.get_random_string()
        query = self.insert_query(
            get_example_database().id,
            admin.id,
            client_id,
            sql="SELECT col1, col2 from table1",
        )

        try:
            self.login(GAMMA_SQLLAB_USERNAME)
            arguments = {
                "filters": [{"col": "sql", "opr": "sw", "value": "SELECT col1"}]
            }
            uri = f"api/v1/query/?q={prison.dumps(arguments)}"
            rv = self.client.get(uri)
            assert rv.status_code == 200
            data = json.loads(rv.data.decode("utf-8"))
            assert data["count"] == 0
        finally:
            # rollback changes, also when an assertion above fails
            db.session.delete(query)
            db.session.commit()

    def test_get_updated_since(self):
        """
        Query API: Test get queries updated since timestamp

        Inserts one query changed three days ago and one changed a day ago,
        asks for everything updated since two days ago and expects only the
        more recent query to be returned.
        """
        now = datetime.utcnow()
        client_id = self.get_random_string()

        admin = self.get_user("admin")
        example_db = get_example_database()

        old_query = self.insert_query(
            example_db.id,
            admin.id,
            self.get_random_string(),
            sql="SELECT col1, col2 from table1",
            select_sql="SELECT col1, col2 from table1",
            executed_sql="SELECT col1, col2 from table1 LIMIT 100",
            changed_on=now - timedelta(days=3),
        )
        updated_query = self.insert_query(
            example_db.id,
            admin.id,
            client_id,
            sql="SELECT col1, col2 from table1",
            select_sql="SELECT col1, col2 from table1",
            executed_sql="SELECT col1, col2 from table1 LIMIT 100",
            changed_on=now - timedelta(days=1),
        )

        try:
            self.login(ADMIN_USERNAME)
            # NOTE(review): ``now`` is naive, and ``timestamp()`` treats naive
            # datetimes as local time, so the cutoff may be shifted by the
            # machine's UTC offset. The one-day margin on either side of the
            # cutoff absorbs that shift.
            timestamp = datetime.timestamp(now - timedelta(days=2)) * 1000
            uri = f"api/v1/query/updated_since?q={prison.dumps({'last_updated_ms': timestamp})}"  # noqa: E501
            rv = self.client.get(uri)
            assert rv.status_code == 200

            expected_result = updated_query.to_dict()
            data = json.loads(rv.data.decode("utf-8"))
            assert len(data["result"]) == 1
            for key, value in data["result"][0].items():
                # We can't assert timestamp
                if key not in (
                    "changed_on",
                    "end_time",
                    "start_running_time",
                    "start_time",
                    "id",
                ):
                    assert value == expected_result[key]
        finally:
            # rollback changes; done in ``finally`` so a failing assertion does
            # not leave the rows behind for tests that assert exact counts
            db.session.delete(old_query)
            db.session.delete(updated_query)
            db.session.commit()

    @mock.patch("superset.sql_lab.cancel_query")
    @mock.patch("superset.views.core.db.session")
    def test_stop_query_not_found(
        self, mock_superset_db_session, mock_sql_lab_cancel_query
    ):
        """
        Handles stop query when the DB engine spec does not
        have a cancel query method (with invalid client_id).

        The mocked session lookup returns ``None``, so the endpoint is
        expected to answer with a 404 and a "not found" message.
        """
        self.login(ADMIN_USERNAME)

        # Make ``query().filter_by().one_or_none()`` come back empty.
        lookup = mock.Mock(return_value=None)
        mock_superset_db_session.query().filter_by().one_or_none = lookup
        mock_sql_lab_cancel_query.return_value = True

        payload = json.dumps({"client_id": "foo2"})
        rv = self.client.post(
            "/api/v1/query/stop",
            data=payload,
            content_type="application/json",
        )

        assert rv.status_code == 404
        body = json.loads(rv.data.decode("utf-8"))
        assert body["message"] == "Query with client_id foo2 not found"

    @mock.patch("superset.sql_lab.cancel_query")
    @mock.patch("superset.views.core.db.session")
    def test_stop_query(self, mock_superset_db_session, mock_sql_lab_cancel_query):
        """
        Handles stop query when the DB engine spec does not
        have a cancel query method.

        The mocked session lookup returns a running query and the mocked
        ``cancel_query`` reports success, so the endpoint is expected to
        answer with a 200 and ``"OK"``.
        """
        form_data = {"client_id": "foo"}
        query_mock = mock.Mock()
        query_mock.client_id = "foo"
        query_mock.status = QueryStatus.RUNNING
        self.login(ADMIN_USERNAME)
        # Configure ``one_or_none`` itself, not the result of calling it.
        # Setting ``one_or_none().return_value`` would leave the lookup
        # returning an auto-created Mock, and ``query_mock`` would go unused.
        mock_superset_db_session.query().filter_by().one_or_none.return_value = (
            query_mock
        )
        mock_sql_lab_cancel_query.return_value = True
        rv = self.client.post(
            "/api/v1/query/stop",
            data=json.dumps(form_data),
            content_type="application/json",
        )

        assert rv.status_code == 200
        data = json.loads(rv.data.decode("utf-8"))
        assert data["result"] == "OK"
