blob: 40b7103c1b19e869b1bb484bddaff2d581770de4 [file] [log] [blame]
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
import json
import textwrap
from typing import Dict, List, Tuple, Union
import pandas as pd
from flask_appbuilder.security.sqla.models import User
from sqlalchemy import DateTime, String
from sqlalchemy.sql import column
from superset import app, db, security_manager
from superset.connectors.base.models import BaseDatasource
from superset.connectors.sqla.models import SqlMetric, TableColumn
from superset.exceptions import NoDataException
from superset.models.core import Database
from superset.models.dashboard import Dashboard
from superset.models.slice import Slice
from superset.utils.core import get_example_database
from .helpers import (
get_example_data,
get_slice_json,
get_table_connector_registry,
merge_slice,
misc_dash_slices,
update_slice_ids,
)
def get_admin_user() -> User:
    """Return the Flask-AppBuilder user named ``admin``.

    :returns: the ``admin`` user as found by ``security_manager.find_user``
    :raises NoDataException: if no user named ``admin`` exists
    """
    user = security_manager.find_user("admin")
    if user is not None:
        return user
    raise NoDataException(
        "Admin user does not exist. "
        "Please, check if test users are properly loaded "
        "(`superset load_test_users`)."
    )
def gen_filter(
    subject: str, comparator: str, operator: str = "=="
) -> Dict[str, Union[bool, str]]:
    """Build a simple ad-hoc ``WHERE`` filter definition for chart params.

    :param subject: column the filter applies to
    :param comparator: value the column is compared against
    :param operator: comparison operator, ``==`` by default
    :returns: a dict with the keys ``clause``, ``comparator``,
        ``expressionType``, ``operator`` and ``subject`` (in that order)
    """
    return dict(
        clause="WHERE",
        comparator=comparator,
        expressionType="SIMPLE",
        operator=operator,
        subject=subject,
    )
def load_data(tbl_name: str, database: Database, sample: bool = False) -> None:
    """Load the birth names example data into ``tbl_name`` on ``database``.

    The table is replaced if it already exists (``if_exists="replace"``).

    :param tbl_name: name of the table to (re)create
    :param database: target database; its ``backend`` decides whether ``ds``
        is written as a ``DateTime`` or as a string column
    :param sample: when True only the first 100 rows are loaded
    """
    pdf = pd.read_json(get_example_data("birth_names2.json.gz"))
    # ``ds`` is parsed as epoch milliseconds; this is needed for every backend.
    pdf.ds = pd.to_datetime(pdf.ds, unit="ms")
    # TODO(bkyryliuk): move load examples data into the pytest fixture
    if database.backend == "presto":
        # On Presto ``ds`` is stored as String(255) (see ``dtype`` below), so
        # render the timestamps as ``YYYY-MM-DD HH:MM:SS`` text.
        pdf.ds = pdf.ds.dt.strftime("%Y-%m-%d %H:%M:%S")
    pdf = pdf.head(100) if sample else pdf

    pdf.to_sql(
        tbl_name,
        database.get_sqla_engine(),
        if_exists="replace",
        chunksize=500,
        dtype={
            # TODO(bkyryliuk): use TIMESTAMP type for presto
            "ds": DateTime if database.backend != "presto" else String(255),
            "gender": String(16),
            "state": String(10),
            "name": String(255),
        },
        method="multi",
        index=False,
    )
    print("Done loading table!")
    print("-" * 80)
def load_birth_names(
    only_metadata: bool = False, force: bool = False, sample: bool = False
) -> None:
    """Load the birth names example: table data, dataset metadata, charts and
    the births dashboard.

    :param only_metadata: skip loading rows and only (re)build metadata
    :param force: reload the rows even when the table already exists
    :param sample: load only a sample of the rows (passed to ``load_data``)
    """
    # pylint: disable=too-many-locals
    tbl_name = "birth_names"
    database = get_example_database()
    table_exists = database.has_table_by_name(tbl_name)

    # Rows are loaded unless only metadata was requested, or the table is
    # already there and no reload was forced.
    if not (only_metadata or (table_exists and not force)):
        load_data(tbl_name, database, sample=sample)

    table_cls = get_table_connector_registry()
    dataset = db.session.query(table_cls).filter_by(table_name=tbl_name).first()
    if not dataset:
        print(f"Creating table [{tbl_name}] reference")
        dataset = table_cls(table_name=tbl_name)
        db.session.add(dataset)

    _set_table_metadata(dataset, database)
    _add_table_metrics(dataset)
    db.session.commit()

    charts, _ = create_slices(dataset, admin_owner=True)
    create_dashboard(charts)
def _set_table_metadata(datasource: "BaseDatasource", database: "Database") -> None:
    """Bind ``datasource`` to ``database`` and refresh its metadata.

    Sets ``ds`` as the main datetime column and enables filter-select before
    calling ``datasource.fetch_metadata()``.
    """
    settings = {
        "main_dttm_col": "ds",
        "database": database,
        "filter_select_enabled": True,
    }
    for attr_name, value in settings.items():
        setattr(datasource, attr_name, value)
    datasource.fetch_metadata()
def _add_table_metrics(datasource: "BaseDatasource") -> None:
    """Ensure the birth names dataset has its derived column and saved metric.

    Adds a ``num_california`` computed column and a ``sum__num`` metric when
    they are missing, then flags the first ``ds`` column as temporal.
    Column names inside the SQL expressions are compiled against ``db.engine``.
    """
    column_names = {tbl_col.column_name for tbl_col in datasource.columns}
    if "num_california" not in column_names:
        state_sql = str(column("state").compile(db.engine))
        num_sql = str(column("num").compile(db.engine))
        datasource.columns.append(
            TableColumn(
                column_name="num_california",
                expression=f"CASE WHEN {state_sql} = 'CA' THEN {num_sql} ELSE 0 END",
            )
        )

    metric_names = {saved.metric_name for saved in datasource.metrics}
    if "sum__num" not in metric_names:
        num_sql = str(column("num").compile(db.engine))
        datasource.metrics.append(
            SqlMetric(metric_name="sum__num", expression=f"SUM({num_sql})")
        )

    # Only the first column named ``ds`` is marked as a datetime column.
    ds_column = next(
        (tbl_col for tbl_col in datasource.columns if tbl_col.column_name == "ds"),
        None,
    )
    if ds_column is not None:
        ds_column.is_dttm = True  # type: ignore
def create_slices(
    tbl: BaseDatasource, admin_owner: bool
) -> Tuple[List[Slice], List[Slice]]:
    """Build the birth names example charts and merge them into the session.

    :param tbl: datasource every chart is bound to (``datasource_type="table"``)
    :param admin_owner: when True the admin user owns every chart; otherwise
        the charts have no owners (the admin is always recorded as creator)
    :returns: ``(slices, misc_slices)``; every chart is passed to
        ``merge_slice`` and the names of ``misc_slices`` are also added to
        ``misc_dash_slices``
    :raises NoDataException: propagated from ``get_admin_user`` when the admin
        user is missing
    """
    # Ad-hoc SUM(num) metric labelled "Births", shared by several charts.
    metrics = [
        {
            "expressionType": "SIMPLE",
            "column": {"column_name": "num", "type": "BIGINT"},
            "aggregate": "SUM",
            "label": "Births",
            "optionName": "metric_11",
        }
    ]
    # Saved metric that ``_add_table_metrics`` registers on the datasource.
    metric = "sum__num"
    # Ad-hoc SUM over the ``num_california`` expression; reused by several
    # charts below in the same way ``metrics`` is.
    num_california_metric = {
        "expressionType": "SIMPLE",
        "column": {
            "column_name": "num_california",
            "expression": "CASE WHEN state = 'CA' THEN num ELSE 0 END",
        },
        "aggregate": "SUM",
        "label": "SUM(num_california)",
    }

    # Base params merged into every chart via ``get_slice_json``.
    defaults = {
        "compare_lag": "10",
        "compare_suffix": "o10Y",
        "limit": "25",
        "time_range": "No filter",
        "time_range_endpoints": ["inclusive", "exclusive"],
        "granularity_sqla": "ds",
        "groupby": [],
        "row_limit": app.config["ROW_LIMIT"],
        "since": "100 years ago",
        "until": "now",
        "viz_type": "table",
        "markup_type": "markdown",
    }

    admin = get_admin_user()
    # The admin is always the creator; ownership depends on ``admin_owner``.
    slice_props = dict(
        datasource_id=tbl.id,
        datasource_type="table",
        owners=[admin] if admin_owner else [],
        created_by=admin,
    )

    print("Creating some slices")
    slices = [
        Slice(
            **slice_props,
            slice_name="Participants",
            viz_type="big_number",
            params=get_slice_json(
                defaults,
                viz_type="big_number",
                granularity_sqla="ds",
                compare_lag="5",
                compare_suffix="over 5Y",
                metric=metric,
            ),
        ),
        Slice(
            **slice_props,
            slice_name="Genders",
            viz_type="pie",
            params=get_slice_json(
                defaults, viz_type="pie", groupby=["gender"], metric=metric
            ),
        ),
        Slice(
            **slice_props,
            slice_name="Trends",
            viz_type="line",
            params=get_slice_json(
                defaults,
                viz_type="line",
                groupby=["name"],
                granularity_sqla="ds",
                rich_tooltip=True,
                show_legend=True,
                metrics=metrics,
            ),
        ),
        Slice(
            **slice_props,
            slice_name="Genders by State",
            viz_type="dist_bar",
            params=get_slice_json(
                defaults,
                adhoc_filters=[
                    {
                        "clause": "WHERE",
                        "expressionType": "SIMPLE",
                        "filterOptionName": "2745eae5",
                        "comparator": ["other"],
                        "operator": "NOT IN",
                        "subject": "state",
                    }
                ],
                viz_type="dist_bar",
                metrics=[
                    {
                        "expressionType": "SIMPLE",
                        "column": {"column_name": "num_boys", "type": "BIGINT(20)"},
                        "aggregate": "SUM",
                        "label": "Boys",
                        "optionName": "metric_11",
                    },
                    {
                        "expressionType": "SIMPLE",
                        "column": {"column_name": "num_girls", "type": "BIGINT(20)"},
                        "aggregate": "SUM",
                        "label": "Girls",
                        "optionName": "metric_12",
                    },
                ],
                groupby=["state"],
            ),
        ),
        Slice(
            **slice_props,
            slice_name="Girls",
            viz_type="table",
            params=get_slice_json(
                defaults,
                groupby=["name"],
                adhoc_filters=[gen_filter("gender", "girl")],
                row_limit=50,
                timeseries_limit_metric=metric,
                metrics=[metric],
            ),
        ),
        Slice(
            **slice_props,
            slice_name="Girl Name Cloud",
            viz_type="word_cloud",
            params=get_slice_json(
                defaults,
                viz_type="word_cloud",
                size_from="10",
                series="name",
                size_to="70",
                rotation="square",
                limit="100",
                adhoc_filters=[gen_filter("gender", "girl")],
                metric=metric,
            ),
        ),
        Slice(
            **slice_props,
            slice_name="Boys",
            viz_type="table",
            params=get_slice_json(
                defaults,
                groupby=["name"],
                adhoc_filters=[gen_filter("gender", "boy")],
                row_limit=50,
                timeseries_limit_metric=metric,
                metrics=[metric],
            ),
        ),
        Slice(
            **slice_props,
            slice_name="Boy Name Cloud",
            viz_type="word_cloud",
            params=get_slice_json(
                defaults,
                viz_type="word_cloud",
                size_from="10",
                series="name",
                size_to="70",
                rotation="square",
                limit="100",
                adhoc_filters=[gen_filter("gender", "boy")],
                metric=metric,
            ),
        ),
        Slice(
            **slice_props,
            slice_name="Top 10 Girl Name Share",
            viz_type="area",
            params=get_slice_json(
                defaults,
                adhoc_filters=[gen_filter("gender", "girl")],
                comparison_type="values",
                groupby=["name"],
                limit=10,
                stacked_style="expand",
                time_grain_sqla="P1D",
                viz_type="area",
                x_axis_format="smart_date",
                metrics=metrics,
            ),
        ),
        Slice(
            **slice_props,
            slice_name="Top 10 Boy Name Share",
            viz_type="area",
            params=get_slice_json(
                defaults,
                adhoc_filters=[gen_filter("gender", "boy")],
                comparison_type="values",
                groupby=["name"],
                limit=10,
                stacked_style="expand",
                time_grain_sqla="P1D",
                viz_type="area",
                x_axis_format="smart_date",
                metrics=metrics,
            ),
        ),
    ]
    misc_slices = [
        Slice(
            **slice_props,
            slice_name="Average and Sum Trends",
            viz_type="dual_line",
            params=get_slice_json(
                defaults,
                viz_type="dual_line",
                metric={
                    "expressionType": "SIMPLE",
                    "column": {"column_name": "num", "type": "BIGINT(20)"},
                    "aggregate": "AVG",
                    "label": "AVG(num)",
                    "optionName": "metric_vgops097wej_g8uff99zhk7",
                },
                metric_2="sum__num",
                granularity_sqla="ds",
                metrics=metrics,
            ),
        ),
        Slice(
            **slice_props,
            slice_name="Num Births Trend",
            viz_type="line",
            params=get_slice_json(defaults, viz_type="line", metrics=metrics),
        ),
        Slice(
            **slice_props,
            slice_name="Daily Totals",
            viz_type="table",
            params=get_slice_json(
                defaults,
                groupby=["ds"],
                since="40 years ago",
                until="now",
                viz_type="table",
                metrics=metrics,
            ),
        ),
        Slice(
            **slice_props,
            slice_name="Number of California Births",
            viz_type="big_number_total",
            params=get_slice_json(
                defaults,
                metric=num_california_metric,
                viz_type="big_number_total",
                granularity_sqla="ds",
            ),
        ),
        Slice(
            **slice_props,
            slice_name="Top 10 California Names Timeseries",
            viz_type="line",
            params=get_slice_json(
                defaults,
                metrics=[num_california_metric],
                viz_type="line",
                granularity_sqla="ds",
                groupby=["name"],
                timeseries_limit_metric=num_california_metric,
                limit="10",
            ),
        ),
        Slice(
            **slice_props,
            slice_name="Names Sorted by Num in California",
            viz_type="table",
            params=get_slice_json(
                defaults,
                metrics=metrics,
                groupby=["name"],
                row_limit=50,
                timeseries_limit_metric=num_california_metric,
            ),
        ),
        Slice(
            **slice_props,
            slice_name="Number of Girls",
            viz_type="big_number_total",
            params=get_slice_json(
                defaults,
                metric=metric,
                viz_type="big_number_total",
                granularity_sqla="ds",
                adhoc_filters=[gen_filter("gender", "girl")],
                subheader="total female participants",
            ),
        ),
        Slice(
            **slice_props,
            slice_name="Pivot Table",
            viz_type="pivot_table",
            params=get_slice_json(
                defaults,
                viz_type="pivot_table",
                groupby=["name"],
                columns=["state"],
                metrics=metrics,
            ),
        ),
    ]
    for slc in slices:
        merge_slice(slc)

    for slc in misc_slices:
        merge_slice(slc)
        misc_dash_slices.add(slc.slice_name)

    return slices, misc_slices
def create_dashboard(slices: List[Slice]) -> Dashboard:
    """Create (or refresh) the "USA Births Names" dashboard with slug ``births``.

    A new dashboard is owned and created by the admin user. Existing and new
    dashboards alike are published, receive the label colors and v2 layout
    below, and are attached to every non-``markup`` chart in ``slices``.

    :param slices: charts to place on the dashboard
    :returns: the committed dashboard
    :raises NoDataException: propagated from ``get_admin_user`` when the admin
        user is missing
    """
    print("Creating a dashboard")
    admin = get_admin_user()
    dashboard = db.session.query(Dashboard).filter_by(slug="births").first()

    if not dashboard:
        dashboard = Dashboard()
        dashboard.owners = [admin]
        dashboard.created_by = admin
        db.session.add(dashboard)

    dashboard.published = True
    dashboard.json_metadata = textwrap.dedent(
        """\
    {
        "label_colors": {
            "Girls": "#FF69B4",
            "Boys": "#ADD8E6",
            "girl": "#FF69B4",
            "boy": "#ADD8E6"
        }
    }"""
    )

    layout_json = textwrap.dedent(
        # pylint: disable=line-too-long
        """\
        {
          "CHART-6GdlekVise": {
            "children": [],
            "id": "CHART-6GdlekVise",
            "meta": {
              "chartId": 5547,
              "height": 50,
              "sliceName": "Top 10 Girl Name Share",
              "width": 5
            },
            "parents": [
              "ROOT_ID",
              "GRID_ID",
              "ROW-eh0w37bWbR"
            ],
            "type": "CHART"
          },
          "CHART-6n9jxb30JG": {
            "children": [],
            "id": "CHART-6n9jxb30JG",
            "meta": {
              "chartId": 5540,
              "height": 36,
              "sliceName": "Genders by State",
              "width": 5
            },
            "parents": [
              "ROOT_ID",
              "GRID_ID",
              "ROW--EyBZQlDi"
            ],
            "type": "CHART"
          },
          "CHART-Jj9qh1ol-N": {
            "children": [],
            "id": "CHART-Jj9qh1ol-N",
            "meta": {
              "chartId": 5545,
              "height": 50,
              "sliceName": "Boy Name Cloud",
              "width": 4
            },
            "parents": [
              "ROOT_ID",
              "GRID_ID",
              "ROW-kzWtcvo8R1"
            ],
            "type": "CHART"
          },
          "CHART-ODvantb_bF": {
            "children": [],
            "id": "CHART-ODvantb_bF",
            "meta": {
              "chartId": 5548,
              "height": 50,
              "sliceName": "Top 10 Boy Name Share",
              "width": 5
            },
            "parents": [
              "ROOT_ID",
              "GRID_ID",
              "ROW-kzWtcvo8R1"
            ],
            "type": "CHART"
          },
          "CHART-PAXUUqwmX9": {
            "children": [],
            "id": "CHART-PAXUUqwmX9",
            "meta": {
              "chartId": 5538,
              "height": 34,
              "sliceName": "Genders",
              "width": 3
            },
            "parents": [
              "ROOT_ID",
              "GRID_ID",
              "ROW-2n0XgiHDgs"
            ],
            "type": "CHART"
          },
          "CHART-_T6n_K9iQN": {
            "children": [],
            "id": "CHART-_T6n_K9iQN",
            "meta": {
              "chartId": 5539,
              "height": 36,
              "sliceName": "Trends",
              "width": 7
            },
            "parents": [
              "ROOT_ID",
              "GRID_ID",
              "ROW--EyBZQlDi"
            ],
            "type": "CHART"
          },
          "CHART-eNY0tcE_ic": {
            "children": [],
            "id": "CHART-eNY0tcE_ic",
            "meta": {
              "chartId": 5537,
              "height": 34,
              "sliceName": "Participants",
              "width": 3
            },
            "parents": [
              "ROOT_ID",
              "GRID_ID",
              "ROW-2n0XgiHDgs"
            ],
            "type": "CHART"
          },
          "CHART-g075mMgyYb": {
            "children": [],
            "id": "CHART-g075mMgyYb",
            "meta": {
              "chartId": 5541,
              "height": 50,
              "sliceName": "Girls",
              "width": 3
            },
            "parents": [
              "ROOT_ID",
              "GRID_ID",
              "ROW-eh0w37bWbR"
            ],
            "type": "CHART"
          },
          "CHART-n-zGGE6S1y": {
            "children": [],
            "id": "CHART-n-zGGE6S1y",
            "meta": {
              "chartId": 5542,
              "height": 50,
              "sliceName": "Girl Name Cloud",
              "width": 4
            },
            "parents": [
              "ROOT_ID",
              "GRID_ID",
              "ROW-eh0w37bWbR"
            ],
            "type": "CHART"
          },
          "CHART-vJIPjmcbD3": {
            "children": [],
            "id": "CHART-vJIPjmcbD3",
            "meta": {
              "chartId": 5543,
              "height": 50,
              "sliceName": "Boys",
              "width": 3
            },
            "parents": [
              "ROOT_ID",
              "GRID_ID",
              "ROW-kzWtcvo8R1"
            ],
            "type": "CHART"
          },
          "DASHBOARD_VERSION_KEY": "v2",
          "GRID_ID": {
            "children": [
              "ROW-2n0XgiHDgs",
              "ROW--EyBZQlDi",
              "ROW-eh0w37bWbR",
              "ROW-kzWtcvo8R1"
            ],
            "id": "GRID_ID",
            "parents": [
              "ROOT_ID"
            ],
            "type": "GRID"
          },
          "HEADER_ID": {
            "id": "HEADER_ID",
            "meta": {
              "text": "Births"
            },
            "type": "HEADER"
          },
          "MARKDOWN-zaflB60tbC": {
            "children": [],
            "id": "MARKDOWN-zaflB60tbC",
            "meta": {
              "code": "<div style=\\"text-align:center\\"> <h1>Birth Names Dashboard</h1> <img src=\\"/static/assets/images/babies.png\\" style=\\"width:50%;\\"></div>",
              "height": 34,
              "width": 6
            },
            "parents": [
              "ROOT_ID",
              "GRID_ID",
              "ROW-2n0XgiHDgs"
            ],
            "type": "MARKDOWN"
          },
          "ROOT_ID": {
            "children": [
              "GRID_ID"
            ],
            "id": "ROOT_ID",
            "type": "ROOT"
          },
          "ROW--EyBZQlDi": {
            "children": [
              "CHART-_T6n_K9iQN",
              "CHART-6n9jxb30JG"
            ],
            "id": "ROW--EyBZQlDi",
            "meta": {
              "background": "BACKGROUND_TRANSPARENT"
            },
            "parents": [
              "ROOT_ID",
              "GRID_ID"
            ],
            "type": "ROW"
          },
          "ROW-2n0XgiHDgs": {
            "children": [
              "CHART-eNY0tcE_ic",
              "MARKDOWN-zaflB60tbC",
              "CHART-PAXUUqwmX9"
            ],
            "id": "ROW-2n0XgiHDgs",
            "meta": {
              "background": "BACKGROUND_TRANSPARENT"
            },
            "parents": [
              "ROOT_ID",
              "GRID_ID"
            ],
            "type": "ROW"
          },
          "ROW-eh0w37bWbR": {
            "children": [
              "CHART-g075mMgyYb",
              "CHART-n-zGGE6S1y",
              "CHART-6GdlekVise"
            ],
            "id": "ROW-eh0w37bWbR",
            "meta": {
              "background": "BACKGROUND_TRANSPARENT"
            },
            "parents": [
              "ROOT_ID",
              "GRID_ID"
            ],
            "type": "ROW"
          },
          "ROW-kzWtcvo8R1": {
            "children": [
              "CHART-vJIPjmcbD3",
              "CHART-Jj9qh1ol-N",
              "CHART-ODvantb_bF"
            ],
            "id": "ROW-kzWtcvo8R1",
            "meta": {
              "background": "BACKGROUND_TRANSPARENT"
            },
            "parents": [
              "ROOT_ID",
              "GRID_ID"
            ],
            "type": "ROW"
          }
        }
        """  # pylint: enable=line-too-long
    )
    position = json.loads(layout_json)

    # dashboard v2 doesn't allow add markup slice
    dashboard.slices = list(filter(lambda chart: chart.viz_type != "markup", slices))
    # Presumably rewrites the layout's chart ids to match the attached charts —
    # confirm against helpers.update_slice_ids.
    update_slice_ids(position, dashboard.slices)

    dashboard.dashboard_title = "USA Births Names"
    dashboard.position_json = json.dumps(position, indent=4)
    dashboard.slug = "births"
    db.session.commit()
    return dashboard