# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements.  See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership.  The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License.  You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied.  See the License for the
# specific language governing permissions and limitations
# under the License.
import json
import logging
from typing import Any, Dict, List, Optional

from flask_appbuilder.models.sqla.interface import SQLAInterface
from sqlalchemy.exc import SQLAlchemyError
from sqlalchemy.orm import contains_eager

from superset.dao.base import BaseDAO
from superset.dashboards.commands.exceptions import DashboardNotFoundError
from superset.dashboards.filters import DashboardFilter
from superset.extensions import db
from superset.models.core import FavStar, FavStarClassName
from superset.models.dashboard import Dashboard, id_or_slug_filter
from superset.models.slice import Slice
from superset.utils import core
from superset.utils.dashboard_filter_scopes_converter import copy_filter_scopes

logger = logging.getLogger(__name__)


class DashboardDAO(BaseDAO):
    """Data-access helpers for :class:`Dashboard` models."""

    model_cls = Dashboard
    base_filter = DashboardFilter

    @staticmethod
    def get_by_id_or_slug(id_or_slug: str) -> Dashboard:
        """
        Fetch a single dashboard by id or slug.

        ``DashboardFilter`` is applied to the query before it runs, so a
        dashboard excluded by that filter is reported as not found.

        :param id_or_slug: dashboard id or slug, matched via ``id_or_slug_filter``
        :returns: the matching dashboard
        :raises DashboardNotFoundError: if no dashboard matches
        """
        query = (
            db.session.query(Dashboard)
            .filter(id_or_slug_filter(id_or_slug))
            .outerjoin(Slice, Dashboard.slices)
            .outerjoin(Slice.table)
            .outerjoin(Dashboard.owners)
            .outerjoin(Dashboard.roles)
        )
        # Apply dashboard base filters
        query = DashboardFilter("id", SQLAInterface(Dashboard, db.session)).apply(
            query, None
        )
        dashboard = query.one_or_none()
        if not dashboard:
            raise DashboardNotFoundError()
        return dashboard

    @staticmethod
    def get_datasets_for_dashboard(id_or_slug: str) -> List[Any]:
        """
        Return datasource data for every datasource used by a dashboard's charts.

        The dashboard's charts are grouped by their ``datasource`` attribute
        (via ``core.indexed``); groups whose datasource is falsy are skipped.
        Each list entry is the result of ``datasource.data_for_slices(slices)``.

        :param id_or_slug: dashboard id or slug, matched via ``id_or_slug_filter``
        :returns: one entry per distinct non-empty datasource
        :raises DashboardNotFoundError: if no dashboard matches after
            ``DashboardFilter`` is applied
        """
        query = (
            db.session.query(Dashboard)
            .filter(id_or_slug_filter(id_or_slug))
            .outerjoin(Slice, Dashboard.slices)
            .outerjoin(Slice.table)
        )
        # Apply dashboard base filters
        query = DashboardFilter("id", SQLAInterface(Dashboard, db.session)).apply(
            query, None
        )
        dashboard = query.one_or_none()
        if not dashboard:
            raise DashboardNotFoundError()
        datasource_slices = core.indexed(dashboard.slices, "datasource")
        return [
            datasource.data_for_slices(slices)
            for datasource, slices in datasource_slices.items()
            if datasource
        ]

    @staticmethod
    def get_charts_for_dashboard(dashboard_id: int) -> List[Slice]:
        """
        Return the charts (slices) of the dashboard with id ``dashboard_id``.

        ``contains_eager`` fills ``Dashboard.slices`` from the outer join in
        the same query instead of lazy-loading it afterwards.

        :param dashboard_id: primary key of the dashboard
        :returns: the dashboard's ``slices``
        :raises DashboardNotFoundError: if no dashboard matches after
            ``DashboardFilter`` is applied
        """
        query = (
            db.session.query(Dashboard)
            .outerjoin(Slice, Dashboard.slices)
            .outerjoin(Slice.table)
            .filter(Dashboard.id == dashboard_id)
            .options(contains_eager(Dashboard.slices))
        )
        # Apply dashboard base filters
        query = DashboardFilter("id", SQLAInterface(Dashboard, db.session)).apply(
            query, None
        )

        dashboard = query.one_or_none()
        if not dashboard:
            raise DashboardNotFoundError()
        return dashboard.slices

    @staticmethod
    def validate_slug_uniqueness(slug: str) -> bool:
        """
        Check that no dashboard already uses ``slug``.

        :param slug: candidate slug; a falsy slug is always considered unique
        :returns: ``True`` if the slug is falsy or unused
        """
        if not slug:
            return True
        dashboard_query = db.session.query(Dashboard).filter(Dashboard.slug == slug)
        return not db.session.query(dashboard_query.exists()).scalar()

    @staticmethod
    def validate_update_slug_uniqueness(dashboard_id: int, slug: Optional[str]) -> bool:
        """
        Check that no *other* dashboard uses ``slug``.

        :param dashboard_id: id of the dashboard being updated (excluded from
            the check)
        :param slug: candidate slug; ``None`` is always considered unique
        :returns: ``True`` if ``slug`` is ``None`` or not used by another
            dashboard
        """
        if slug is None:
            return True
        dashboard_query = db.session.query(Dashboard).filter(
            Dashboard.slug == slug, Dashboard.id != dashboard_id
        )
        return not db.session.query(dashboard_query.exists()).scalar()

    @staticmethod
    def update_charts_owners(model: Dashboard, commit: bool = True) -> Dashboard:
        """
        Add the dashboard's owners to the owners of each of its charts.

        Existing chart owners are kept (set union).

        :param model: dashboard whose owners are propagated
        :param commit: commit the session afterwards
        :returns: ``model``
        """
        owners = list(model.owners)
        for slc in model.slices:
            slc.owners = list(set(owners) | set(slc.owners))
        if commit:
            db.session.commit()
        return model

    @staticmethod
    def bulk_delete(models: Optional[List[Dashboard]], commit: bool = True) -> None:
        """
        Delete the given dashboards.

        Chart and owner associations are cleared (and merged into the session)
        first; the dashboards are then removed with one bulk ``DELETE`` by id.

        :param models: dashboards to delete; ``None`` or empty deletes nothing
        :param commit: commit on success and roll back on ``SQLAlchemyError``
        :raises SQLAlchemyError: re-raised from the bulk delete / commit
        """
        models = models or []
        item_ids = [model.id for model in models]
        # bulk delete, first delete related data
        for model in models:
            model.slices = []
            model.owners = []
            db.session.merge(model)
        # bulk delete itself
        try:
            db.session.query(Dashboard).filter(Dashboard.id.in_(item_ids)).delete(
                synchronize_session="fetch"
            )
            if commit:
                db.session.commit()
        except SQLAlchemyError:
            if commit:
                db.session.rollback()
            # bare raise keeps the original traceback unchanged
            raise

    @staticmethod
    def set_dash_metadata(
        dashboard: Dashboard,
        data: Dict[Any, Any],
        old_to_new_slice_ids: Optional[Dict[int, int]] = None,
    ) -> None:
        """
        Apply a dashboard save payload to ``dashboard`` in place (no commit).

        Updates ``dashboard.slices`` from the chart ids found in
        ``data["positions"]``, writes ``position_json`` (with each chart's UUID
        added to its ``meta``), and updates ``css``, ``dashboard_title`` and
        ``json_metadata``.

        :param dashboard: dashboard to mutate
        :param data: payload; ``positions`` and ``dashboard_title`` are
            required. ``filter_scopes`` and ``default_filters``, when present,
            are JSON strings (``None``/empty is treated as ``"{}"``).
        :param old_to_new_slice_ids: optional mapping of old chart id -> new
            chart id used to remap ``filter_scopes``; only entries whose new id
            is on the dashboard are kept
        """
        positions = data["positions"]
        # find slices in the position data
        slice_ids = [
            value.get("meta", {}).get("chartId")
            for value in positions.values()
            if isinstance(value, dict)
        ]
        # set for O(1) membership tests in the comprehensions below
        slice_id_set = set(slice_ids)

        session = db.session()
        current_slices = session.query(Slice).filter(Slice.id.in_(slice_ids)).all()

        dashboard.slices = current_slices

        # add UUID to positions
        uuid_map = {slc.id: str(slc.uuid) for slc in current_slices}
        for obj in positions.values():
            if not isinstance(obj, dict) or obj.get("type") != "CHART":
                continue
            # tolerate missing keys, consistent with the chart-id scan above
            meta = obj.get("meta") or {}
            chart_id = meta.get("chartId")
            if chart_id:
                meta["uuid"] = uuid_map.get(chart_id)

        # compact separators: no whitespace between tokens in the dumped json
        dashboard.position_json = json.dumps(
            positions, indent=None, separators=(",", ":"), sort_keys=True
        )
        md = dashboard.params_dict
        dashboard.css = data.get("css")
        dashboard.dashboard_title = data["dashboard_title"]

        md.setdefault("timed_refresh_immune_slices", [])
        new_filter_scopes = {}
        if "filter_scopes" in data:
            # replace filter_id and immune ids from old slice id to new slice id,
            # and remove slice ids that are not in the dashboard anymore
            slc_id_dict: Dict[int, int] = {}
            if old_to_new_slice_ids:
                slc_id_dict = {
                    old: new
                    for old, new in old_to_new_slice_ids.items()
                    if new in slice_id_set
                }
            else:
                slc_id_dict = {sid: sid for sid in slice_ids}
            new_filter_scopes = copy_filter_scopes(
                old_to_new_slc_id_dict=slc_id_dict,
                old_filter_scopes=json.loads(data["filter_scopes"] or "{}"),
            )
        if new_filter_scopes:
            md["filter_scopes"] = new_filter_scopes
        else:
            md.pop("filter_scopes", None)
        md["expanded_slices"] = data.get("expanded_slices", {})
        md["refresh_frequency"] = data.get("refresh_frequency", 0)
        # ``or`` also covers an explicit None, matching filter_scopes handling
        default_filters_data = json.loads(data.get("default_filters") or "{}")
        # keep only default filters whose chart is still on the dashboard
        applicable_filters = {
            key: v
            for key, v in default_filters_data.items()
            if int(key) in slice_id_set
        }
        md["default_filters"] = json.dumps(applicable_filters)
        md["color_scheme"] = data.get("color_scheme")
        if data.get("color_namespace"):
            md["color_namespace"] = data.get("color_namespace")
        if data.get("label_colors"):
            md["label_colors"] = data.get("label_colors")
        dashboard.json_metadata = json.dumps(md)

    @staticmethod
    def favorited_ids(
        dashboards: List[Dashboard], current_user_id: int
    ) -> List[int]:
        """
        Return the ids of ``dashboards`` that the given user has favorited.

        :param dashboards: dashboards to check
        :param current_user_id: id of the user whose ``FavStar`` rows are read
        :returns: ``FavStar.obj_id`` values (dashboard ids), not ``FavStar`` rows
        """
        ids = [dash.id for dash in dashboards]
        if not ids:
            # nothing to look up; skip the round trip and the empty ``IN ()``
            return []
        return [
            star.obj_id
            for star in db.session.query(FavStar.obj_id)
            .filter(
                FavStar.class_name == FavStarClassName.DASHBOARD,
                FavStar.obj_id.in_(ids),
                FavStar.user_id == current_user_id,
            )
            .all()
        ]
