# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements.  See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership.  The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License.  You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied.  See the License for the
# specific language governing permissions and limitations
# under the License.
# isort:skip_file
import textwrap
import unittest
from unittest import mock
from tests.fixtures.birth_names_dashboard import load_birth_names_dashboard_with_slices

import pandas
import pytest
from sqlalchemy.engine.url import make_url

import tests.test_app
from superset import app, db as metadata_db
from superset.models.core import Database
from superset.models.slice import Slice
from superset.utils.core import get_example_database, QueryStatus

from .base_tests import SupersetTestCase
from .fixtures.energy_dashboard import load_energy_table_with_slice


class TestDatabaseModel(SupersetTestCase):
    """Tests for ``Database`` engine URLs, user impersonation and ``select_star``."""

    # ``extra`` payload shared by the Presto and Hive impersonation tests; its
    # engine params carry explicit ``connect_args`` that the tests inspect.
    _IMPERSONATION_EXTRA = """
                {
                    "metadata_params": {},
                    "engine_params": {
                               "connect_args":{
                                  "protocol": "https",
                                  "username":"original_user",
                                  "password":"original_user_password"
                               }
                    },
                    "metadata_cache_timeout": {},
                    "schemas_allowed_for_csv_upload": []
                }
                """

    @staticmethod
    def _url_database(model, **engine_kwargs):
        """Return the database part of the URL of ``model``'s SQLA engine."""
        engine = model.get_sqla_engine(**engine_kwargs)
        return make_url(engine.url).database

    @unittest.skipUnless(
        SupersetTestCase.is_module_installed("requests"), "requests not installed"
    )
    def test_database_schema_presto(self):
        # (URI, expected database component when no schema is requested)
        scenarios = (
            ("presto://presto.airbnb.io:8080/hive/default", "hive/default"),
            ("presto://presto.airbnb.io:8080/hive", "hive"),
        )
        for uri, default_database in scenarios:
            model = Database(database_name="test_database", sqlalchemy_uri=uri)
            self.assertEqual(default_database, self._url_database(model))
            # A requested schema replaces the schema part after the catalog.
            self.assertEqual(
                "hive/core_db", self._url_database(model, schema="core_db")
            )

    def test_database_schema_postgres(self):
        model = Database(
            database_name="test_database",
            sqlalchemy_uri="postgresql+psycopg2://postgres.airbnb.io:5439/prod",
        )
        self.assertEqual("prod", self._url_database(model))
        # The requested schema does not alter the URL's database component.
        self.assertEqual("prod", self._url_database(model, schema="foo"))

    @unittest.skipUnless(
        SupersetTestCase.is_module_installed("thrift"), "thrift not installed"
    )
    @unittest.skipUnless(
        SupersetTestCase.is_module_installed("pyhive"), "pyhive not installed"
    )
    def test_database_schema_hive(self):
        model = Database(
            database_name="test_database",
            sqlalchemy_uri="hive://hive@hive.airbnb.io:10000/default?auth=NOSASL",
        )
        self.assertEqual("default", self._url_database(model))
        self.assertEqual("core_db", self._url_database(model, schema="core_db"))

    @unittest.skipUnless(
        SupersetTestCase.is_module_installed("MySQLdb"), "mysqlclient not installed"
    )
    def test_database_schema_mysql(self):
        model = Database(
            database_name="test_database",
            sqlalchemy_uri="mysql://root@localhost/superset",
        )
        self.assertEqual("superset", self._url_database(model))
        self.assertEqual("staging", self._url_database(model, schema="staging"))

    @unittest.skipUnless(
        SupersetTestCase.is_module_installed("MySQLdb"), "mysqlclient not installed"
    )
    def test_database_impersonate_user(self):
        example_user = "giuseppe"
        model = Database(
            database_name="test_database", sqlalchemy_uri="mysql://root@localhost"
        )

        def engine_username():
            engine = model.get_sqla_engine(user_name=example_user)
            return make_url(engine.url).username

        # Impersonation on: the requesting user becomes the URL username.
        model.impersonate_user = True
        self.assertEqual(example_user, engine_username())

        # Impersonation off: the requesting user must not appear in the URL.
        model.impersonate_user = False
        self.assertNotEqual(example_user, engine_username())

    @mock.patch("superset.models.core.create_engine")
    def test_impersonate_user_presto(self, mocked_create_engine):
        principal_user = "logged_in_user"
        model = Database(
            database_name="test_database",
            sqlalchemy_uri="presto://localhost",
            extra=self._IMPERSONATION_EXTRA,
        )
        original_connect_args = {
            "protocol": "https",
            "username": "original_user",
            "password": "original_user_password",
        }

        # Presto impersonation injects the user into the URL and adds a
        # ``principal_username`` connect arg.
        model.impersonate_user = True
        model.get_sqla_engine(user_name=principal_user)
        args, kwargs = mocked_create_engine.call_args
        assert str(args[0]) == "presto://logged_in_user@localhost"
        assert kwargs["connect_args"] == {
            **original_connect_args,
            "principal_username": "logged_in_user",
        }

        # Without impersonation both URL and connect args stay untouched.
        model.impersonate_user = False
        model.get_sqla_engine(user_name=principal_user)
        args, kwargs = mocked_create_engine.call_args
        assert str(args[0]) == "presto://localhost"
        assert kwargs["connect_args"] == original_connect_args

    @mock.patch("superset.models.core.create_engine")
    def test_impersonate_user_hive(self, mocked_create_engine):
        principal_user = "logged_in_user"
        model = Database(
            database_name="test_database",
            sqlalchemy_uri="hive://localhost",
            extra=self._IMPERSONATION_EXTRA,
        )
        original_connect_args = {
            "protocol": "https",
            "username": "original_user",
            "password": "original_user_password",
        }

        # Hive impersonation keeps the URL and passes the user through the
        # ``hive.server2.proxy.user`` configuration entry instead.
        model.impersonate_user = True
        model.get_sqla_engine(user_name=principal_user)
        args, kwargs = mocked_create_engine.call_args
        assert str(args[0]) == "hive://localhost"
        assert kwargs["connect_args"] == {
            **original_connect_args,
            "configuration": {"hive.server2.proxy.user": "logged_in_user"},
        }

        # Without impersonation both URL and connect args stay untouched.
        model.impersonate_user = False
        model.get_sqla_engine(user_name=principal_user)
        args, kwargs = mocked_create_engine.call_args
        assert str(args[0]) == "hive://localhost"
        assert kwargs["connect_args"] == original_connect_args

    @pytest.mark.usefixtures("load_energy_table_with_slice")
    def test_select_star(self):
        database = get_example_database()
        table = "energy_usage"
        sql = database.select_star(table, show_cols=False, latest_partition=False)
        preparer = database.inspector.engine.dialect.identifier_preparer
        # Presto and Hive emit a quoted table name; other backends a bare one.
        if database.backend in {"presto", "hive"}:
            from_target = preparer.quote_identifier(table)
        else:
            from_target = table
        assert f"SELECT *\nFROM {from_target}\nLIMIT 100" in sql

        sql = database.select_star(table, show_cols=True, latest_partition=False)
        # TODO(bkyryliuk): unify sql generation
        exact_sql_by_backend = {
            "presto": textwrap.dedent(
                """\
                SELECT "source" AS "source",
                       "target" AS "target",
                       "value" AS "value"
                FROM "energy_usage"
                LIMIT 100"""
            ),
            "hive": textwrap.dedent(
                """\
                SELECT `source`,
                       `target`,
                       `value`
                FROM `energy_usage`
                LIMIT 100"""
            ),
        }
        if database.backend in exact_sql_by_backend:
            assert sql == exact_sql_by_backend[database.backend]
        else:
            generic_sql = textwrap.dedent(
                """\
                SELECT source,
                       target,
                       value
                FROM energy_usage
                LIMIT 100"""
            )
            assert generic_sql in sql

    def test_select_star_fully_qualified_names(self):
        database = get_example_database()
        schema = "schema.name"
        table = "table/name"
        sql = database.select_star(
            table, schema=schema, show_cols=False, latest_partition=False
        )
        expected_names = {
            "sqlite": '"schema.name"."table/name"',
            "mysql": "`schema.name`.`table/name`",
            "postgres": '"schema.name"."table/name"',
        }
        qualified_name = expected_names.get(database.db_engine_spec.engine)
        if not qualified_name:
            # Engines without a known quoting convention are not checked.
            return
        assert sql.startswith(f"SELECT *\nFROM {qualified_name}\nLIMIT 100")

    def test_single_statement(self):
        main_db = get_example_database()
        if main_db.backend != "mysql":
            return
        # A trailing semicolon must not change the single-statement result.
        for statement in ("SELECT 1", "SELECT 1;"):
            frame = main_db.get_df(statement, None)
            self.assertEqual(frame.iat[0, 0], 1)

    def test_multi_statement(self):
        main_db = get_example_database()
        if main_db.backend != "mysql":
            return
        # The result comes from the last statement; a quoted ';' is not a
        # statement separator.
        expectations = (
            ("USE superset; SELECT 1", 1),
            ("USE superset; SELECT ';';", ";"),
        )
        for statement, expected in expectations:
            frame = main_db.get_df(statement, None)
            self.assertEqual(frame.iat[0, 0], expected)


class TestSqlaTableModel(SupersetTestCase):
    """Tests for ``SqlaTable`` query generation on the example ``birth_names`` table.

    Several tests temporarily mutate shared state: column attributes, the
    engine spec's ``allows_joins`` flag, and app config. Each restores that
    state in a ``finally`` block so a failing assertion cannot leak into later
    tests.
    """

    @pytest.mark.usefixtures("load_birth_names_dashboard_with_slices")
    def test_get_timestamp_expression(self):
        tbl = self.get_table_by_name("birth_names")
        ds_col = tbl.get_column("ds")
        sqla_literal = ds_col.get_timestamp_expression(None)
        self.assertEqual(str(sqla_literal.compile()), "ds")

        sqla_literal = ds_col.get_timestamp_expression("P1D")
        compiled = str(sqla_literal.compile())
        if tbl.database.backend == "mysql":
            self.assertEqual(compiled, "DATE(ds)")

        prev_ds_expr = ds_col.expression
        ds_col.expression = "DATE_ADD(ds, 1)"
        try:
            sqla_literal = ds_col.get_timestamp_expression("P1D")
            compiled = str(sqla_literal.compile())
            if tbl.database.backend == "mysql":
                self.assertEqual(compiled, "DATE(DATE_ADD(ds, 1))")
        finally:
            # Restore even when the assertion fails so other tests see the
            # original column definition.
            ds_col.expression = prev_ds_expr

    @pytest.mark.usefixtures("load_birth_names_dashboard_with_slices")
    def test_get_timestamp_expression_epoch(self):
        tbl = self.get_table_by_name("birth_names")
        ds_col = tbl.get_column("ds")
        # Snapshot the original settings *before* mutating them. Previously the
        # expression was captured after being cleared, and
        # ``python_date_format`` was never restored.
        prev_expression = ds_col.expression
        prev_date_format = ds_col.python_date_format
        try:
            ds_col.expression = None
            ds_col.python_date_format = "epoch_s"
            sqla_literal = ds_col.get_timestamp_expression(None)
            compiled = str(sqla_literal.compile())
            if tbl.database.backend == "mysql":
                self.assertEqual(compiled, "from_unixtime(ds)")

            sqla_literal = ds_col.get_timestamp_expression("P1D")
            compiled = str(sqla_literal.compile())
            if tbl.database.backend == "mysql":
                self.assertEqual(compiled, "DATE(from_unixtime(ds))")

            ds_col.expression = "DATE_ADD(ds, 1)"
            sqla_literal = ds_col.get_timestamp_expression("P1D")
            compiled = str(sqla_literal.compile())
            if tbl.database.backend == "mysql":
                self.assertEqual(compiled, "DATE(from_unixtime(DATE_ADD(ds, 1)))")
        finally:
            ds_col.expression = prev_expression
            ds_col.python_date_format = prev_date_format

    def query_with_expr_helper(self, is_timeseries, inner_join=True):
        """Run a ``birth_names`` query grouped by an arbitrary SQL expression.

        For the duration of the query, the engine spec's ``allows_joins`` flag
        is set to ``inner_join``. The ``ds`` column's ``expression`` and
        ``python_date_format`` are cleared. All of this is restored before
        returning, even when an assertion fails.

        :param is_timeseries: whether to issue a timeseries query
        :param inner_join: value forced onto ``allows_joins`` for the query;
            a timeseries query is expected to contain a JOIN only when True
        :returns: the non-empty result DataFrame, or ``None`` when
            ``inner_join`` is requested but the engine does not allow joins
        """
        tbl = self.get_table_by_name("birth_names")
        ds_col = tbl.get_column("ds")
        spec = self.get_database_by_id(tbl.database_id).db_engine_spec
        if not spec.allows_joins and inner_join:
            # if the db does not support inner joins, we cannot force it so
            return None

        arbitrary_gby = "state || gender || '_test'"
        arbitrary_metric = dict(
            label="arbitrary", expressionType="SQL", sqlExpression="SUM(num_boys)"
        )
        query_obj = dict(
            groupby=[arbitrary_gby, "name"],
            metrics=[arbitrary_metric],
            filter=[],
            is_timeseries=is_timeseries,
            columns=[],
            granularity="ds",
            from_dttm=None,
            to_dttm=None,
            extras=dict(time_grain_sqla="P1Y"),
        )

        prev_expression = ds_col.expression
        prev_date_format = ds_col.python_date_format
        old_inner_join = spec.allows_joins
        ds_col.expression = None
        ds_col.python_date_format = None
        # ``allows_joins`` lives on the engine spec, which may be shared with
        # other tests, so it is restored in ``finally``.
        spec.allows_joins = inner_join
        try:
            qr = tbl.query(query_obj)
            self.assertEqual(qr.status, QueryStatus.SUCCESS)
            sql = qr.query
            self.assertIn(arbitrary_gby, sql)
            self.assertIn("name", sql)
            if inner_join and is_timeseries:
                self.assertIn("JOIN", sql.upper())
            else:
                self.assertNotIn("JOIN", sql.upper())
        finally:
            spec.allows_joins = old_inner_join
            ds_col.expression = prev_expression
            ds_col.python_date_format = prev_date_format
        self.assertFalse(qr.df.empty)
        return qr.df

    @pytest.mark.usefixtures("load_birth_names_dashboard_with_slices")
    def test_query_with_expr_groupby_timeseries(self):
        if get_example_database().backend == "presto":
            # TODO(bkyryliuk): make it work for presto.
            return

        def canonicalize_df(df):
            # Sort by every column so row order does not affect the comparison.
            ret = df.sort_values(by=list(df.columns.values), inplace=False)
            ret.reset_index(inplace=True, drop=True)
            return ret

        df1 = self.query_with_expr_helper(is_timeseries=True, inner_join=True)
        if df1 is None:
            # The engine cannot perform the inner join; nothing to compare.
            return
        name_list1 = canonicalize_df(df1).name.values.tolist()
        df2 = self.query_with_expr_helper(is_timeseries=True, inner_join=False)
        self.assertFalse(df2.empty)
        # Compare against df2. Previously df1 was canonicalized twice, so this
        # assertion could never fail.
        name_list2 = canonicalize_df(df2).name.values.tolist()

        assert name_list2 == name_list1

    @pytest.mark.usefixtures("load_birth_names_dashboard_with_slices")
    def test_query_with_expr_groupby(self):
        self.query_with_expr_helper(is_timeseries=False)

    @pytest.mark.usefixtures("load_birth_names_dashboard_with_slices")
    def test_sql_mutator(self):
        tbl = self.get_table_by_name("birth_names")
        query_obj = dict(
            groupby=[],
            metrics=None,
            filter=[],
            is_timeseries=False,
            columns=["name"],
            granularity=None,
            from_dttm=None,
            to_dttm=None,
            extras={},
        )
        sql = tbl.get_query_str(query_obj)
        self.assertNotIn("-- COMMENT", sql)

        def mutator(*args):
            return "-- COMMENT\n" + args[0]

        # Restore whatever mutator was configured before, even on failure,
        # since app config is global across tests.
        prev_mutator = app.config.get("SQL_QUERY_MUTATOR")
        app.config["SQL_QUERY_MUTATOR"] = mutator
        try:
            sql = tbl.get_query_str(query_obj)
            self.assertIn("-- COMMENT", sql)
        finally:
            app.config["SQL_QUERY_MUTATOR"] = prev_mutator

    @pytest.mark.usefixtures("load_birth_names_dashboard_with_slices")
    def test_query_with_non_existent_metrics(self):
        tbl = self.get_table_by_name("birth_names")

        query_obj = dict(
            groupby=[],
            metrics=["invalid"],
            filter=[],
            is_timeseries=False,
            columns=["name"],
            granularity=None,
            from_dttm=None,
            to_dttm=None,
            extras={},
        )

        with self.assertRaises(Exception) as context:
            tbl.get_query_str(query_obj)

        # Check the message explicitly. ``assertTrue("...", exc)`` always
        # passed, and a missing table would otherwise be accepted through an
        # unrelated AttributeError.
        self.assertIn("Metric 'invalid' does not exist", str(context.exception))

    @pytest.mark.usefixtures("load_birth_names_dashboard_with_slices")
    def test_data_for_slices(self):
        tbl = self.get_table_by_name("birth_names")
        slc = (
            metadata_db.session.query(Slice)
            .filter_by(
                datasource_id=tbl.id,
                datasource_type=tbl.type,
                slice_name="Participants",
            )
            .first()
        )
        self.assertIsNotNone(slc)
        data_for_slices = tbl.data_for_slices([slc])
        self.assertEqual(len(data_for_slices["columns"]), 0)
        self.assertEqual(len(data_for_slices["metrics"]), 1)
        self.assertEqual(len(data_for_slices["verbose_map"].keys()), 2)
