# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements.  See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership.  The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License.  You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied.  See the License for the
# specific language governing permissions and limitations
# under the License.
import json
import textwrap
from typing import Dict, List, Tuple, Union

import pandas as pd
from flask_appbuilder.security.sqla.models import User
from sqlalchemy import DateTime, String
from sqlalchemy.sql import column

from superset import app, db, security_manager
from superset.connectors.base.models import BaseDatasource
from superset.connectors.sqla.models import SqlMetric, TableColumn
from superset.exceptions import NoDataException
from superset.models.core import Database
from superset.models.dashboard import Dashboard
from superset.models.slice import Slice
from superset.utils.core import get_example_database

from .helpers import (
    get_example_data,
    get_slice_json,
    get_table_connector_registry,
    merge_slice,
    misc_dash_slices,
    update_slice_ids,
)


def get_admin_user() -> User:
    """Return the user named ``admin`` as resolved by the security manager.

    :raises NoDataException: if ``security_manager.find_user("admin")``
        yields ``None``.
    """
    user = security_manager.find_user("admin")
    if user is not None:
        return user

    # No admin account: the examples need it as creator/owner of the objects.
    raise NoDataException(
        "Admin user does not exist. "
        "Please, check if test users are properly loaded "
        "(`superset load_test_users`)."
    )


def gen_filter(
    subject: str, comparator: str, operator: str = "=="
) -> Dict[str, Union[bool, str]]:
    """Build a simple ``WHERE`` ad-hoc filter dictionary.

    :param subject: stored under the ``"subject"`` key
    :param comparator: stored under the ``"comparator"`` key
    :param operator: stored under the ``"operator"`` key, ``"=="`` by default
    :returns: a dict with the keys ``clause``, ``comparator``,
        ``expressionType``, ``operator`` and ``subject`` (in that order)
    """
    adhoc_filter: Dict[str, Union[bool, str]] = dict(
        clause="WHERE",
        comparator=comparator,
        expressionType="SIMPLE",
        operator=operator,
        subject=subject,
    )
    return adhoc_filter


def load_data(tbl_name: str, database: Database, sample: bool = False) -> None:
    """Write the birth names example data into table ``tbl_name``.

    The data frame is read from the ``birth_names2.json.gz`` example file and
    written through ``DataFrame.to_sql`` with ``if_exists="replace"``, so an
    existing table of the same name is replaced.

    :param tbl_name: name of the table to write
    :param database: database whose SQLAlchemy engine receives the rows; its
        ``backend`` decides how the ``ds`` column is stored
    :param sample: when true, only the first 100 rows are written
    """
    pdf = pd.read_json(get_example_data("birth_names2.json.gz"))
    is_presto = database.backend == "presto"

    # ``ds`` is interpreted as epoch milliseconds on every backend.
    pdf.ds = pd.to_datetime(pdf.ds, unit="ms")
    # TODO(bkyryliuk): move load examples data into the pytest fixture
    if is_presto:
        # On presto the column is stored as a string (see ``dtype`` below), so
        # render it as ``YYYY-MM-DD HH:MM:SS``.
        pdf.ds = pdf.ds.dt.strftime("%Y-%m-%d %H:%M:%S")
    if sample:
        pdf = pdf.head(100)

    pdf.to_sql(
        tbl_name,
        database.get_sqla_engine(),
        if_exists="replace",
        chunksize=500,
        dtype={
            # TODO(bkyryliuk): use TIMESTAMP type for presto
            "ds": String(255) if is_presto else DateTime,
            "gender": String(16),
            "state": String(10),
            "name": String(255),
        },
        method="multi",
        index=False,
    )
    print("Done loading table!")
    print("-" * 80)


def load_birth_names(
    only_metadata: bool = False, force: bool = False, sample: bool = False
) -> None:
    """Load the birth names example: data, table reference, charts, dashboard.

    :param only_metadata: when true the table data is never (re)loaded
    :param force: reload the data even if the table already exists
    :param sample: forwarded to ``load_data``
    """
    tbl_name = "birth_names"
    database = get_example_database()
    table_exists = database.has_table_by_name(tbl_name)

    # Data is written only when requested and when it is missing or forced.
    needs_data = force or not table_exists
    if needs_data and not only_metadata:
        load_data(tbl_name, database, sample=sample)

    table_cls = get_table_connector_registry()
    datasource = db.session.query(table_cls).filter_by(table_name=tbl_name).first()
    if not datasource:
        print(f"Creating table [{tbl_name}] reference")
        datasource = table_cls(table_name=tbl_name)
        db.session.add(datasource)

    _set_table_metadata(datasource, database)
    _add_table_metrics(datasource)

    # Commit before the charts are built from this datasource.
    db.session.commit()

    dashboard_slices = create_slices(datasource, admin_owner=True)[0]
    create_dashboard(dashboard_slices)


def _set_table_metadata(datasource: "BaseDatasource", database: "Database") -> None:
    """Point ``datasource`` at ``database`` and refresh its metadata.

    Sets ``ds`` as the main datetime column, enables filter select and then
    calls ``datasource.fetch_metadata()``.
    """
    datasource.database = database
    datasource.filter_select_enabled = True
    datasource.main_dttm_col = "ds"  # type: ignore
    # Must run last: all attributes above are set before metadata is fetched.
    datasource.fetch_metadata()


def _add_table_metrics(datasource: "BaseDatasource") -> None:
    """Ensure the example's extra column and metric exist on ``datasource``.

    Appends the ``num_california`` expression column and the ``sum__num``
    metric when they are not present yet, and flags the ``ds`` column (if
    any) as a datetime column.
    """
    existing_columns = [c.column_name for c in datasource.columns]
    if "num_california" not in existing_columns:
        # Column names are compiled against ``db.engine`` before being embedded
        # in the SQL expression.
        state_sql = str(column("state").compile(db.engine))
        num_sql = str(column("num").compile(db.engine))
        california_column = TableColumn(
            column_name="num_california",
            expression=f"CASE WHEN {state_sql} = 'CA' THEN {num_sql} ELSE 0 END",
        )
        datasource.columns.append(california_column)

    existing_metrics = [m.metric_name for m in datasource.metrics]
    if "sum__num" not in existing_metrics:
        num_sql = str(column("num").compile(db.engine))
        sum_metric = SqlMetric(metric_name="sum__num", expression=f"SUM({num_sql})")
        datasource.metrics.append(sum_metric)

    # Only the first column named ``ds`` is flagged.
    ds_column = next((c for c in datasource.columns if c.column_name == "ds"), None)
    if ds_column is not None:
        ds_column.is_dttm = True


def create_slices(
    tbl: BaseDatasource, admin_owner: bool
) -> Tuple[List[Slice], List[Slice]]:
    """Define the birth names example charts and hand them to ``merge_slice``.

    :param tbl: datasource the charts are attached to (through ``tbl.id``)
    :param admin_owner: when true the admin user owns every chart, otherwise
        the charts have no owners; the admin user is ``created_by`` either way
    :returns: ``(slices, misc_slices)``; the first list is what
        ``load_birth_names`` passes to ``create_dashboard``, the names of the
        second list are added to ``misc_dash_slices``
    :raises NoDataException: if the admin user does not exist
    """
    metrics = [
        {
            "expressionType": "SIMPLE",
            "column": {"column_name": "num", "type": "BIGINT"},
            "aggregate": "SUM",
            "label": "Births",
            "optionName": "metric_11",
        }
    ]
    metric = "sum__num"

    # Ad-hoc SUM over the ``num_california`` expression column; several of the
    # misc charts use it as metric and/or sort metric.
    num_california_metric = {
        "expressionType": "SIMPLE",
        "column": {
            "column_name": "num_california",
            "expression": "CASE WHEN state = 'CA' THEN num ELSE 0 END",
        },
        "aggregate": "SUM",
        "label": "SUM(num_california)",
    }

    # Base form data; each chart overrides/extends it through get_slice_json.
    defaults = {
        "compare_lag": "10",
        "compare_suffix": "o10Y",
        "limit": "25",
        "time_range": "No filter",
        "time_range_endpoints": ["inclusive", "exclusive"],
        "granularity_sqla": "ds",
        "groupby": [],
        "row_limit": app.config["ROW_LIMIT"],
        "since": "100 years ago",
        "until": "now",
        "viz_type": "table",
        "markup_type": "markdown",
    }

    admin = get_admin_user()
    slice_props = dict(
        datasource_id=tbl.id,
        datasource_type="table",
        owners=[admin] if admin_owner else [],
        created_by=admin,
    )

    print("Creating some slices")
    slices = [
        Slice(
            **slice_props,
            slice_name="Participants",
            viz_type="big_number",
            params=get_slice_json(
                defaults,
                viz_type="big_number",
                granularity_sqla="ds",
                compare_lag="5",
                compare_suffix="over 5Y",
                metric=metric,
            ),
        ),
        Slice(
            **slice_props,
            slice_name="Genders",
            viz_type="pie",
            params=get_slice_json(
                defaults, viz_type="pie", groupby=["gender"], metric=metric
            ),
        ),
        Slice(
            **slice_props,
            slice_name="Trends",
            viz_type="line",
            params=get_slice_json(
                defaults,
                viz_type="line",
                groupby=["name"],
                granularity_sqla="ds",
                rich_tooltip=True,
                show_legend=True,
                metrics=metrics,
            ),
        ),
        Slice(
            **slice_props,
            slice_name="Genders by State",
            viz_type="dist_bar",
            params=get_slice_json(
                defaults,
                adhoc_filters=[
                    {
                        "clause": "WHERE",
                        "expressionType": "SIMPLE",
                        "filterOptionName": "2745eae5",
                        "comparator": ["other"],
                        "operator": "NOT IN",
                        "subject": "state",
                    }
                ],
                viz_type="dist_bar",
                metrics=[
                    {
                        "expressionType": "SIMPLE",
                        "column": {"column_name": "num_boys", "type": "BIGINT(20)"},
                        "aggregate": "SUM",
                        "label": "Boys",
                        "optionName": "metric_11",
                    },
                    {
                        "expressionType": "SIMPLE",
                        "column": {"column_name": "num_girls", "type": "BIGINT(20)"},
                        "aggregate": "SUM",
                        "label": "Girls",
                        "optionName": "metric_12",
                    },
                ],
                groupby=["state"],
            ),
        ),
        Slice(
            **slice_props,
            slice_name="Girls",
            viz_type="table",
            params=get_slice_json(
                defaults,
                groupby=["name"],
                adhoc_filters=[gen_filter("gender", "girl")],
                row_limit=50,
                timeseries_limit_metric=metric,
                metrics=[metric],
            ),
        ),
        Slice(
            **slice_props,
            slice_name="Girl Name Cloud",
            viz_type="word_cloud",
            params=get_slice_json(
                defaults,
                viz_type="word_cloud",
                size_from="10",
                series="name",
                size_to="70",
                rotation="square",
                limit="100",
                adhoc_filters=[gen_filter("gender", "girl")],
                metric=metric,
            ),
        ),
        Slice(
            **slice_props,
            slice_name="Boys",
            viz_type="table",
            params=get_slice_json(
                defaults,
                groupby=["name"],
                adhoc_filters=[gen_filter("gender", "boy")],
                row_limit=50,
                timeseries_limit_metric=metric,
                metrics=[metric],
            ),
        ),
        Slice(
            **slice_props,
            slice_name="Boy Name Cloud",
            viz_type="word_cloud",
            params=get_slice_json(
                defaults,
                viz_type="word_cloud",
                size_from="10",
                series="name",
                size_to="70",
                rotation="square",
                limit="100",
                adhoc_filters=[gen_filter("gender", "boy")],
                metric=metric,
            ),
        ),
        Slice(
            **slice_props,
            slice_name="Top 10 Girl Name Share",
            viz_type="area",
            params=get_slice_json(
                defaults,
                adhoc_filters=[gen_filter("gender", "girl")],
                comparison_type="values",
                groupby=["name"],
                limit=10,
                stacked_style="expand",
                time_grain_sqla="P1D",
                viz_type="area",
                # was misspelled ``x_axis_forma``, so the setting was stored
                # under the wrong key
                x_axis_format="smart_date",
                metrics=metrics,
            ),
        ),
        Slice(
            **slice_props,
            slice_name="Top 10 Boy Name Share",
            viz_type="area",
            params=get_slice_json(
                defaults,
                adhoc_filters=[gen_filter("gender", "boy")],
                comparison_type="values",
                groupby=["name"],
                limit=10,
                stacked_style="expand",
                time_grain_sqla="P1D",
                viz_type="area",
                # was misspelled ``x_axis_forma`` (see the girl chart above)
                x_axis_format="smart_date",
                metrics=metrics,
            ),
        ),
    ]
    misc_slices = [
        Slice(
            **slice_props,
            slice_name="Average and Sum Trends",
            viz_type="dual_line",
            params=get_slice_json(
                defaults,
                viz_type="dual_line",
                metric={
                    "expressionType": "SIMPLE",
                    "column": {"column_name": "num", "type": "BIGINT(20)"},
                    "aggregate": "AVG",
                    "label": "AVG(num)",
                    "optionName": "metric_vgops097wej_g8uff99zhk7",
                },
                metric_2="sum__num",
                granularity_sqla="ds",
                metrics=metrics,
            ),
        ),
        Slice(
            **slice_props,
            slice_name="Num Births Trend",
            viz_type="line",
            params=get_slice_json(defaults, viz_type="line", metrics=metrics),
        ),
        Slice(
            **slice_props,
            slice_name="Daily Totals",
            viz_type="table",
            params=get_slice_json(
                defaults,
                groupby=["ds"],
                since="40 years ago",
                until="now",
                viz_type="table",
                metrics=metrics,
            ),
        ),
        Slice(
            **slice_props,
            slice_name="Number of California Births",
            viz_type="big_number_total",
            params=get_slice_json(
                defaults,
                metric=num_california_metric,
                viz_type="big_number_total",
                granularity_sqla="ds",
            ),
        ),
        Slice(
            **slice_props,
            slice_name="Top 10 California Names Timeseries",
            viz_type="line",
            params=get_slice_json(
                defaults,
                metrics=[num_california_metric],
                viz_type="line",
                granularity_sqla="ds",
                groupby=["name"],
                timeseries_limit_metric=num_california_metric,
                limit="10",
            ),
        ),
        Slice(
            **slice_props,
            slice_name="Names Sorted by Num in California",
            viz_type="table",
            params=get_slice_json(
                defaults,
                metrics=metrics,
                groupby=["name"],
                row_limit=50,
                timeseries_limit_metric=num_california_metric,
            ),
        ),
        Slice(
            **slice_props,
            slice_name="Number of Girls",
            viz_type="big_number_total",
            params=get_slice_json(
                defaults,
                metric=metric,
                viz_type="big_number_total",
                granularity_sqla="ds",
                adhoc_filters=[gen_filter("gender", "girl")],
                subheader="total female participants",
            ),
        ),
        Slice(
            **slice_props,
            slice_name="Pivot Table",
            viz_type="pivot_table",
            params=get_slice_json(
                defaults,
                viz_type="pivot_table",
                groupby=["name"],
                columns=["state"],
                metrics=metrics,
            ),
        ),
    ]
    for slc in slices:
        merge_slice(slc)

    for slc in misc_slices:
        merge_slice(slc)
        # Record the name in the shared ``misc_dash_slices`` set.
        misc_dash_slices.add(slc.slice_name)

    return slices, misc_slices


def create_dashboard(slices: List[Slice]) -> Dashboard:
    """Create or refresh the example dashboard with slug ``births``.

    An existing dashboard with that slug is reused; otherwise a new one is
    created with the admin user as owner and creator. The dashboard is then
    published, given its label colors, layout, title and slug, and the session
    is committed.

    :param slices: charts to place on the dashboard; those whose ``viz_type``
        is ``"markup"`` are left out
    :returns: the dashboard that was created or updated
    :raises NoDataException: if the admin user does not exist
    """
    print("Creating a dashboard")
    admin = get_admin_user()
    dashboard = db.session.query(Dashboard).filter_by(slug="births").first()
    if not dashboard:
        dashboard = Dashboard()
        dashboard.owners = [admin]
        dashboard.created_by = admin
        db.session.add(dashboard)

    # Fixed colors for the gender labels used by the charts.
    label_colors_json = textwrap.dedent(
        """\
    {
        "label_colors": {
            "Girls": "#FF69B4",
            "Boys": "#ADD8E6",
            "girl": "#FF69B4",
            "boy": "#ADD8E6"
        }
    }"""
    )
    # Dashboard layout: four rows of charts plus one markdown header cell.
    # pylint: disable=line-too-long
    layout_json = textwrap.dedent(
        """\
        {
          "CHART-6GdlekVise": {
            "children": [],
            "id": "CHART-6GdlekVise",
            "meta": {
              "chartId": 5547,
              "height": 50,
              "sliceName": "Top 10 Girl Name Share",
              "width": 5
            },
            "parents": [
              "ROOT_ID",
              "GRID_ID",
              "ROW-eh0w37bWbR"
            ],
            "type": "CHART"
          },
          "CHART-6n9jxb30JG": {
            "children": [],
            "id": "CHART-6n9jxb30JG",
            "meta": {
              "chartId": 5540,
              "height": 36,
              "sliceName": "Genders by State",
              "width": 5
            },
            "parents": [
              "ROOT_ID",
              "GRID_ID",
              "ROW--EyBZQlDi"
            ],
            "type": "CHART"
          },
          "CHART-Jj9qh1ol-N": {
            "children": [],
            "id": "CHART-Jj9qh1ol-N",
            "meta": {
              "chartId": 5545,
              "height": 50,
              "sliceName": "Boy Name Cloud",
              "width": 4
            },
            "parents": [
              "ROOT_ID",
              "GRID_ID",
              "ROW-kzWtcvo8R1"
            ],
            "type": "CHART"
          },
          "CHART-ODvantb_bF": {
            "children": [],
            "id": "CHART-ODvantb_bF",
            "meta": {
              "chartId": 5548,
              "height": 50,
              "sliceName": "Top 10 Boy Name Share",
              "width": 5
            },
            "parents": [
              "ROOT_ID",
              "GRID_ID",
              "ROW-kzWtcvo8R1"
            ],
            "type": "CHART"
          },
          "CHART-PAXUUqwmX9": {
            "children": [],
            "id": "CHART-PAXUUqwmX9",
            "meta": {
              "chartId": 5538,
              "height": 34,
              "sliceName": "Genders",
              "width": 3
            },
            "parents": [
              "ROOT_ID",
              "GRID_ID",
              "ROW-2n0XgiHDgs"
            ],
            "type": "CHART"
          },
          "CHART-_T6n_K9iQN": {
            "children": [],
            "id": "CHART-_T6n_K9iQN",
            "meta": {
              "chartId": 5539,
              "height": 36,
              "sliceName": "Trends",
              "width": 7
            },
            "parents": [
              "ROOT_ID",
              "GRID_ID",
              "ROW--EyBZQlDi"
            ],
            "type": "CHART"
          },
          "CHART-eNY0tcE_ic": {
            "children": [],
            "id": "CHART-eNY0tcE_ic",
            "meta": {
              "chartId": 5537,
              "height": 34,
              "sliceName": "Participants",
              "width": 3
            },
            "parents": [
              "ROOT_ID",
              "GRID_ID",
              "ROW-2n0XgiHDgs"
            ],
            "type": "CHART"
          },
          "CHART-g075mMgyYb": {
            "children": [],
            "id": "CHART-g075mMgyYb",
            "meta": {
              "chartId": 5541,
              "height": 50,
              "sliceName": "Girls",
              "width": 3
            },
            "parents": [
              "ROOT_ID",
              "GRID_ID",
              "ROW-eh0w37bWbR"
            ],
            "type": "CHART"
          },
          "CHART-n-zGGE6S1y": {
            "children": [],
            "id": "CHART-n-zGGE6S1y",
            "meta": {
              "chartId": 5542,
              "height": 50,
              "sliceName": "Girl Name Cloud",
              "width": 4
            },
            "parents": [
              "ROOT_ID",
              "GRID_ID",
              "ROW-eh0w37bWbR"
            ],
            "type": "CHART"
          },
          "CHART-vJIPjmcbD3": {
            "children": [],
            "id": "CHART-vJIPjmcbD3",
            "meta": {
              "chartId": 5543,
              "height": 50,
              "sliceName": "Boys",
              "width": 3
            },
            "parents": [
              "ROOT_ID",
              "GRID_ID",
              "ROW-kzWtcvo8R1"
            ],
            "type": "CHART"
          },
          "DASHBOARD_VERSION_KEY": "v2",
          "GRID_ID": {
            "children": [
              "ROW-2n0XgiHDgs",
              "ROW--EyBZQlDi",
              "ROW-eh0w37bWbR",
              "ROW-kzWtcvo8R1"
            ],
            "id": "GRID_ID",
            "parents": [
              "ROOT_ID"
            ],
            "type": "GRID"
          },
          "HEADER_ID": {
            "id": "HEADER_ID",
            "meta": {
              "text": "Births"
            },
            "type": "HEADER"
          },
          "MARKDOWN-zaflB60tbC": {
            "children": [],
            "id": "MARKDOWN-zaflB60tbC",
            "meta": {
              "code": "<div style=\\"text-align:center\\">  <h1>Birth Names Dashboard</h1>  <img src=\\"/static/assets/images/babies.png\\" style=\\"width:50%;\\"></div>",
              "height": 34,
              "width": 6
            },
            "parents": [
              "ROOT_ID",
              "GRID_ID",
              "ROW-2n0XgiHDgs"
            ],
            "type": "MARKDOWN"
          },
          "ROOT_ID": {
            "children": [
              "GRID_ID"
            ],
            "id": "ROOT_ID",
            "type": "ROOT"
          },
          "ROW--EyBZQlDi": {
            "children": [
              "CHART-_T6n_K9iQN",
              "CHART-6n9jxb30JG"
            ],
            "id": "ROW--EyBZQlDi",
            "meta": {
              "background": "BACKGROUND_TRANSPARENT"
            },
            "parents": [
              "ROOT_ID",
              "GRID_ID"
            ],
            "type": "ROW"
          },
          "ROW-2n0XgiHDgs": {
            "children": [
              "CHART-eNY0tcE_ic",
              "MARKDOWN-zaflB60tbC",
              "CHART-PAXUUqwmX9"
            ],
            "id": "ROW-2n0XgiHDgs",
            "meta": {
              "background": "BACKGROUND_TRANSPARENT"
            },
            "parents": [
              "ROOT_ID",
              "GRID_ID"
            ],
            "type": "ROW"
          },
          "ROW-eh0w37bWbR": {
            "children": [
              "CHART-g075mMgyYb",
              "CHART-n-zGGE6S1y",
              "CHART-6GdlekVise"
            ],
            "id": "ROW-eh0w37bWbR",
            "meta": {
              "background": "BACKGROUND_TRANSPARENT"
            },
            "parents": [
              "ROOT_ID",
              "GRID_ID"
            ],
            "type": "ROW"
          },
          "ROW-kzWtcvo8R1": {
            "children": [
              "CHART-vJIPjmcbD3",
              "CHART-Jj9qh1ol-N",
              "CHART-ODvantb_bF"
            ],
            "id": "ROW-kzWtcvo8R1",
            "meta": {
              "background": "BACKGROUND_TRANSPARENT"
            },
            "parents": [
              "ROOT_ID",
              "GRID_ID"
            ],
            "type": "ROW"
          }
        }
        """
    )
    # pylint: enable=line-too-long
    layout = json.loads(layout_json)

    dashboard.published = True
    dashboard.json_metadata = label_colors_json
    # dashboard v2 doesn't allow add markup slice
    dashboard.slices = [chart for chart in slices if chart.viz_type != "markup"]
    # NOTE(review): presumably swaps the hard-coded chartId values in the
    # layout for the ids of the given slices -- confirm in helpers.
    update_slice_ids(layout, dashboard.slices)
    dashboard.dashboard_title = "USA Births Names"
    dashboard.position_json = json.dumps(layout, indent=4)
    dashboard.slug = "births"
    db.session.commit()
    return dashboard
