blob: 73c0ca58a0a14025bbc05f7ac8ca03599e87b29e [file] [log] [blame]
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
# pylint: disable=too-many-branches
import gzip
import json
import logging
import re
from typing import Any, Dict
from urllib import request
import pandas as pd
from flask import current_app
from sqlalchemy import BigInteger, Boolean, Date, DateTime, Float, String, Text
from sqlalchemy.orm import Session
from sqlalchemy.sql.visitors import VisitableType
from superset.connectors.sqla.models import SqlaTable
from superset.models.core import Database
from superset.utils.core import get_example_database
logger = logging.getLogger(__name__)

# Passed as ``chunksize`` to ``DataFrame.to_sql`` in ``load_data``.
CHUNKSIZE = 512

# Matches sized VARCHAR declarations such as ``VARCHAR(40)`` (case-insensitive)
# and captures the size; used by ``get_sqla_type``.
VARCHAR = re.compile(r"VARCHAR\((\d+)\)", re.IGNORECASE)

# Top-level config keys that ``import_dataset`` serializes to JSON strings
# before handing the config to ``SqlaTable.import_from_dict``.
JSON_KEYS = {"params", "template_params", "extra"}

# Upper-cased native type name -> SQLAlchemy type instance. ``get_sqla_type``
# upper-cases its argument before looking it up here, so keys must stay
# upper case. Sized ``VARCHAR(n)`` declarations are not listed: they are
# resolved through the ``VARCHAR`` pattern instead.
type_map = {
    "BOOLEAN": Boolean(),
    "VARCHAR": String(255),
    "STRING": String(255),
    "TEXT": Text(),
    "BIGINT": BigInteger(),
    "FLOAT": Float(),
    "FLOAT64": Float(),
    "DOUBLE PRECISION": Float(),
    "DATE": Date(),
    "DATETIME": DateTime(),
    "TIMESTAMP WITHOUT TIME ZONE": DateTime(timezone=False),
    "TIMESTAMP WITH TIME ZONE": DateTime(timezone=True),
}
def get_sqla_type(native_type: str) -> VisitableType:
    """
    Translate a native column type name into a SQLAlchemy type.

    The name is first looked up (case-insensitively) in ``type_map``. If it is
    not a known name, it is tried as a sized ``VARCHAR(n)`` declaration.

    :param native_type: type name as stored on the dataset column
    :returns: the SQLAlchemy type to use for the column
    :raises Exception: if the type is neither in ``type_map`` nor a VARCHAR
    """
    try:
        return type_map[native_type.upper()]
    except KeyError:
        # Not a plain known name; fall through to the sized-VARCHAR check.
        pass

    sized_varchar = VARCHAR.match(native_type)
    if sized_varchar is None:
        raise Exception(f"Unknown type: {native_type}")

    length = int(sized_varchar.group(1))
    return String(length)
def get_dtype(df: pd.DataFrame, dataset: SqlaTable) -> Dict[str, VisitableType]:
    """
    Build the column-name -> SQLAlchemy type mapping for a dataframe.

    Only dataset columns whose name also appears in ``df`` are included.

    :param df: dataframe whose columns are being typed
    :param dataset: dataset whose column definitions supply the native types
    :returns: mapping from column name to SQLAlchemy type
    :raises Exception: propagated from ``get_sqla_type`` for unknown types
    """
    dtypes: Dict[str, VisitableType] = {}
    for column in dataset.columns:
        name = column.column_name
        if name not in df.keys():
            # The dataset declares a column the data file does not carry.
            continue
        dtypes[name] = get_sqla_type(column.type)
    return dtypes
def import_dataset(
    session: Session,
    config: Dict[str, Any],
    overwrite: bool = False,
    force_data: bool = False,
) -> SqlaTable:
    """
    Import a dataset from its configuration dictionary.

    If a dataset with the same UUID already exists it is returned unchanged,
    unless ``overwrite`` is set, in which case it is re-imported under its
    existing id with ``sync=["columns", "metrics"]``.

    The caller's ``config`` is left untouched: both the top-level dictionary
    and the metric dictionaries are copied before any field is rewritten, so
    importing the same configuration twice does not JSON-encode ``extra``
    twice.

    :param session: SQLAlchemy session used for the lookup and the import
    :param config: dataset configuration; must contain a ``uuid`` key
    :param overwrite: re-import an existing dataset instead of returning it
    :param force_data: load the data referenced by ``config["data"]`` even if
        the table already exists in the examples database
    :returns: the existing or freshly imported dataset
    """
    existing = session.query(SqlaTable).filter_by(uuid=config["uuid"]).first()
    if existing and not overwrite:
        return existing

    # Copy before mutating anything. ``dict.copy`` is shallow, so the metric
    # dictionaries are copied as well because their ``extra`` is rewritten.
    config = config.copy()
    if existing:
        config["id"] = existing.id
    if "metrics" in config:
        config["metrics"] = [dict(metric) for metric in config["metrics"]]

    # TODO (betodealmeida): move this logic to import_from_dict
    for key in JSON_KEYS:
        if config.get(key):
            try:
                config[key] = json.dumps(config[key])
            except TypeError:
                logger.info("Unable to encode `%s` field: %s", key, config[key])
    for metric in config.get("metrics", []):
        if metric.get("extra"):
            try:
                metric["extra"] = json.dumps(metric["extra"])
            except TypeError:
                logger.info("Unable to encode `extra` field: %s", metric["extra"])

    # should we delete columns and metrics not present in the current import?
    sync = ["columns", "metrics"] if overwrite else []

    # should we also load data into the dataset?
    data_uri = config.get("data")

    # import recursively to include columns and metrics
    dataset = SqlaTable.import_from_dict(session, config, recursive=True, sync=sync)
    if dataset.id is None:
        # Flush when the dataset has no id yet.
        session.flush()

    example_database = get_example_database()
    table_exists = example_database.has_table_by_name(dataset.table_name)
    if data_uri and (not table_exists or force_data):
        load_data(data_uri, dataset, example_database, session)

    return dataset
def load_data(
    data_uri: str, dataset: SqlaTable, example_database: Database, session: Session
) -> None:
    """
    Fetch a CSV file and write its contents to the dataset's table.

    The file is read with pandas, temporal columns are converted to datetimes
    and the frame is written with ``if_exists="replace"``. The remote stream
    (and the gzip wrapper, if any) is closed once the CSV has been parsed.

    :param data_uri: location of the CSV file; a ``.gz`` suffix means the
        payload is gzip-compressed
    :param dataset: dataset providing table name, schema and column types
    :param example_database: database the table is written to
    :param session: session of the surrounding import, reused for the write
        when the data goes to the same database as the import
    """
    # NOTE(review): ``data_uri`` comes from the import configuration and is
    # opened as-is; ``urlopen`` also accepts non-HTTP schemes such as
    # ``file://`` -- confirm that callers validate it.
    with request.urlopen(data_uri) as response:
        if data_uri.endswith(".gz"):
            with gzip.open(response) as stream:
                df = pd.read_csv(stream, encoding="utf-8")
        else:
            df = pd.read_csv(response, encoding="utf-8")

    dtype = get_dtype(df, dataset)

    # convert temporal columns
    for column_name, sqla_type in dtype.items():
        if isinstance(sqla_type, (Date, DateTime)):
            df[column_name] = pd.to_datetime(df[column_name])

    # reuse session when loading data if possible, to make import atomic
    if example_database.sqlalchemy_uri == current_app.config.get(
        "SQLALCHEMY_DATABASE_URI"
    ) or not current_app.config.get("SQLALCHEMY_EXAMPLES_URI"):
        logger.info("Loading data inside the import transaction")
        connection = session.connection()
    else:
        logger.warning("Loading data outside the import transaction")
        connection = example_database.get_sqla_engine()

    df.to_sql(
        dataset.table_name,
        con=connection,
        schema=dataset.schema,
        if_exists="replace",
        chunksize=CHUNKSIZE,
        dtype=dtype,
        index=False,
        method="multi",
    )