blob: f84f5296b59b0e632cd67209841abe8be776e299 [file] [log] [blame]
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
from typing import Any, Dict, List, Tuple
from marshmallow import Schema
from sqlalchemy.orm import Session
from sqlalchemy.sql import select
from superset import db
from superset.charts.commands.importers.v1.utils import import_chart
from superset.charts.schemas import ImportV1ChartSchema
from superset.commands.exceptions import CommandException
from superset.commands.importers.v1 import ImportModelsCommand
from superset.dao.base import BaseDAO
from superset.dashboards.commands.importers.v1.utils import (
find_chart_uuids,
import_dashboard,
update_id_refs,
)
from superset.dashboards.schemas import ImportV1DashboardSchema
from superset.databases.commands.importers.v1.utils import import_database
from superset.databases.schemas import ImportV1DatabaseSchema
from superset.datasets.commands.importers.v1.utils import import_dataset
from superset.datasets.schemas import ImportV1DatasetSchema
from superset.models.core import Database
from superset.models.dashboard import dashboard_slices
class ImportExamplesCommand(ImportModelsCommand):
    """Import the example databases, datasets, charts and dashboards.

    The configs are dispatched on their file-name prefix (``databases/``,
    ``datasets/``, ``charts/``, ``dashboards/``) and imported in dependency
    order inside a single transaction: ``run`` commits on success and rolls
    back on any failure.
    """

    dao = BaseDAO
    model_name = "model"
    # Schema used for each config, keyed by file-name prefix.
    schemas: Dict[str, Schema] = {
        "charts/": ImportV1ChartSchema(),
        "dashboards/": ImportV1DashboardSchema(),
        "datasets/": ImportV1DatasetSchema(),
        "databases/": ImportV1DatabaseSchema(),
    }
    # Exception type raised by ``run`` when the import fails.
    import_error = CommandException

    def __init__(self, contents: Dict[str, str], *args: Any, **kwargs: Any):
        """Create the command.

        :param contents: forwarded unchanged to ``ImportModelsCommand``
        :param kwargs: forwarded to the parent; the optional ``force_data``
            keyword (default ``False``) is additionally stored on the command
            and later handed to ``import_dataset``
        """
        super().__init__(contents, *args, **kwargs)
        self.force_data = kwargs.get("force_data", False)

    def run(self) -> None:
        """Validate the contents, then import everything in one transaction.

        :raises CommandException: (``self.import_error``) if anything fails
            while importing or committing; the session is rolled back first
            and the original exception is attached as ``__cause__``.
        """
        self.validate()

        # rollback to prevent partial imports
        try:
            self._import(db.session, self._configs, self.overwrite, self.force_data)
            db.session.commit()
        except Exception as ex:
            db.session.rollback()
            # chain explicitly so the root cause is not lost in the traceback
            raise self.import_error() from ex

    # pylint: disable=too-many-locals, arguments-differ, too-many-branches
    @staticmethod
    def _import(
        session: Session,
        configs: Dict[str, Any],
        overwrite: bool = False,
        force_data: bool = False,
    ) -> None:
        """Import all configs into ``session`` without committing.

        Order matters: databases first, then datasets (which need a database
        ID), then charts (which need dataset info), then dashboards (which
        need chart IDs). Finally the missing dashboard/chart associations are
        inserted into ``dashboard_slices``.

        :param session: session used for every query and statement
        :param configs: config dictionaries keyed by file name
        :param overwrite: passed to each ``import_*`` helper
        :param force_data: passed to ``import_dataset``
        :raises Exception: if a dataset references an unknown database UUID
            and no database named ``examples`` exists
        """
        # import databases
        database_ids: Dict[str, int] = {}
        for file_name, config in configs.items():
            if file_name.startswith("databases/"):
                database = import_database(session, config, overwrite=overwrite)
                database_ids[str(database.uuid)] = database.id

        # import datasets
        # If database_uuid is not in the list of UUIDs it means that the examples
        # database was created before its UUID was frozen, so it has a random UUID.
        # We need to determine its ID so we can point the dataset to it.
        # Use the session we were given, like every other statement here.
        examples_db = (
            session.query(Database).filter_by(database_name="examples").first()
        )
        dataset_info: Dict[str, Dict[str, Any]] = {}
        for file_name, config in configs.items():
            if file_name.startswith("datasets/"):
                # find the ID of the corresponding database
                if config["database_uuid"] not in database_ids:
                    if examples_db is None:
                        raise Exception("Cannot find examples database")
                    config["database_id"] = examples_db.id
                else:
                    config["database_id"] = database_ids[config["database_uuid"]]

                dataset = import_dataset(
                    session, config, overwrite=overwrite, force_data=force_data
                )
                dataset_info[str(dataset.uuid)] = {
                    "datasource_id": dataset.id,
                    "datasource_type": "view" if dataset.is_sqllab_view else "table",
                    "datasource_name": dataset.table_name,
                }

        # import charts
        chart_ids: Dict[str, int] = {}
        for file_name, config in configs.items():
            if file_name.startswith("charts/"):
                # update datasource id, type, and name
                config.update(dataset_info[config["dataset_uuid"]])
                chart = import_chart(session, config, overwrite=overwrite)
                chart_ids[str(chart.uuid)] = chart.id

        # store the existing relationship between dashboards and charts as a
        # set of plain (dashboard_id, slice_id) tuples for O(1) membership tests
        existing_relationships = {
            tuple(row)
            for row in session.execute(
                select([dashboard_slices.c.dashboard_id, dashboard_slices.c.slice_id])
            ).fetchall()
        }

        # import dashboards
        dashboard_chart_ids: List[Tuple[int, int]] = []
        for file_name, config in configs.items():
            if file_name.startswith("dashboards/"):
                config = update_id_refs(config, chart_ids)
                dashboard = import_dashboard(session, config, overwrite=overwrite)
                dashboard.published = True

                for uuid in find_chart_uuids(config["position"]):
                    chart_id = chart_ids[uuid]
                    if (dashboard.id, chart_id) not in existing_relationships:
                        dashboard_chart_ids.append((dashboard.id, chart_id))

        # set ref in the dashboard_slices table
        values = [
            {"dashboard_id": dashboard_id, "slice_id": chart_id}
            for (dashboard_id, chart_id) in dashboard_chart_ids
        ]
        # Skip the statement when there is nothing new to link (the usual case
        # when examples are re-imported). NOTE(review): SQLAlchemy treats an
        # empty parameter list as "no parameters", which can emit a
        # parameterless INSERT and create a junk row - confirm against the
        # pinned SQLAlchemy version.
        if values:
            # pylint: disable=no-value-for-parameter # sqlalchemy/issues/4656
            session.execute(dashboard_slices.insert(), values)