"""A collection of ORM sqlalchemy models for Superset"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from __future__ import unicode_literals
from collections import OrderedDict
import functools
import json
import logging
import pickle
import re
import textwrap
from copy import deepcopy, copy
from datetime import timedelta, datetime, date
import humanize
import pandas as pd
import requests
import sqlalchemy as sqla
from sqlalchemy.engine.url import make_url
from sqlalchemy.orm import subqueryload
import sqlparse
from dateutil.parser import parse
from flask import escape, g, Markup, request
from flask_appbuilder import Model
from flask_appbuilder.models.mixins import AuditMixin
from flask_appbuilder.models.decorators import renders
from flask_babel import lazy_gettext as _
from pydruid.client import PyDruid
from pydruid.utils.aggregators import count
from pydruid.utils.filters import Dimension, Filter
from pydruid.utils.postaggregator import Postaggregator
from pydruid.utils.having import Aggregation
from six import string_types
from sqlalchemy import (
Column, Integer, String, ForeignKey, Text, Boolean,
DateTime, Date, Table, Numeric,
create_engine, MetaData, desc, asc, select, and_
)
from sqlalchemy.ext.compiler import compiles
from sqlalchemy.ext.declarative import declared_attr
from sqlalchemy.orm import backref, relationship
from sqlalchemy.orm.session import make_transient
from sqlalchemy.sql import table, literal_column, text, column
from sqlalchemy.sql.expression import ColumnClause, TextAsFrom
from sqlalchemy_utils import EncryptedType
from werkzeug.datastructures import ImmutableMultiDict
from superset import (
app, db, db_engine_specs, get_session, utils, sm, import_util
)
from superset.source_registry import SourceRegistry
from superset.viz import viz_types
from superset.jinja_context import get_template_processor
from superset.utils import (
flasher, MetricPermException, DimSelector, wrap_clause_in_parens,
DTTM_ALIAS, QueryStatus,
)
config = app.config
class QueryResult(object):
"""Object returned by the query interface"""
def __init__( # noqa
self,
df,
query,
duration,
status=QueryStatus.SUCCESS,
error_message=None):
self.df = df
self.query = query
self.duration = duration
self.status = status
self.error_message = error_message
FillterPattern = re.compile(r'''((?:[^,"']|"[^"]*"|'[^']*')+)''')
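# Illustrative sketch (not executed): FillterPattern is used in
# SqlaTable.query() to split the value of an 'in'/'not in' filter while
# keeping quoted segments intact, e.g.
#   FillterPattern.split("a, 'b, c', d")[1::2]
#   # -> ['a', " 'b, c'", ' d']
#   # which, once quotes/whitespace are stripped, yields ['a', 'b, c', 'd']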
def set_perm(mapper, connection, target): # noqa
if target.perm != target.get_perm():
link_table = target.__table__
connection.execute(
link_table.update()
.where(link_table.c.id == target.id)
.values(perm=target.get_perm())
)
def set_related_perm(mapper, connection, target): # noqa
src_class = target.cls_model
id_ = target.datasource_id
ds = db.session.query(src_class).filter_by(id=int(id_)).first()
target.perm = ds.perm
class JavascriptPostAggregator(Postaggregator):
def __init__(self, name, field_names, function):
self.post_aggregator = {
'type': 'javascript',
'fieldNames': field_names,
'name': name,
'function': function,
}
self.name = name
class ImportMixin(object):
def override(self, obj):
"""Overrides the plain fields of the dashboard."""
for field in obj.__class__.export_fields:
setattr(self, field, getattr(obj, field))
def copy(self):
"""Creates a copy of the dashboard without relationships."""
new_obj = self.__class__()
new_obj.override(self)
return new_obj
def alter_params(self, **kwargs):
d = self.params_dict
d.update(kwargs)
self.params = json.dumps(d)
@property
def params_dict(self):
if self.params:
params = re.sub(",[ \t\r\n]+}", "}", self.params)
params = re.sub(",[ \t\r\n]+\]", "]", params)
return json.loads(params)
else:
return {}
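    # Illustrative sketch (not executed): params_dict tolerates a trailing
    # comma (followed by whitespace) that strict JSON would reject, e.g. for
    # a Slice ``slc``:
    #   slc.params = '{"viz_type": "table", "metrics": ["count"], }'
    #   slc.params_dict                  # -> {'viz_type': 'table', 'metrics': ['count']}
    #   slc.alter_params(row_limit=100)  # merges the kwarg and re-serializes params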
class AuditMixinNullable(AuditMixin):
"""Altering the AuditMixin to use nullable fields
Allows creating objects programmatically outside of CRUD
"""
created_on = Column(DateTime, default=datetime.now, nullable=True)
changed_on = Column(
DateTime, default=datetime.now,
onupdate=datetime.now, nullable=True)
@declared_attr
def created_by_fk(cls): # noqa
return Column(Integer, ForeignKey('ab_user.id'),
default=cls.get_user_id, nullable=True)
@declared_attr
def changed_by_fk(cls): # noqa
return Column(
Integer, ForeignKey('ab_user.id'),
default=cls.get_user_id, onupdate=cls.get_user_id, nullable=True)
def _user_link(self, user):
if not user:
return ''
url = '/superset/profile/{}/'.format(user.username)
return Markup('<a href="{}">{}</a>'.format(url, escape(user) or ''))
@renders('created_by')
def creator(self): # noqa
return self._user_link(self.created_by)
@property
def changed_by_(self):
return self._user_link(self.changed_by)
@renders('changed_on')
def changed_on_(self):
return Markup(
'<span class="no-wrap">{}</span>'.format(self.changed_on))
@renders('changed_on')
def modified(self):
s = humanize.naturaltime(datetime.now() - self.changed_on)
return Markup('<span class="no-wrap">{}</span>'.format(s))
@property
def icons(self):
return """
<a
href="{self.datasource_edit_url}"
data-toggle="tooltip"
title="{self.datasource}">
<i class="fa fa-database"></i>
</a>
""".format(**locals())
class Url(Model, AuditMixinNullable):
"""Used for the short url feature"""
__tablename__ = 'url'
id = Column(Integer, primary_key=True)
url = Column(Text)
class CssTemplate(Model, AuditMixinNullable):
"""CSS templates for dashboards"""
__tablename__ = 'css_templates'
id = Column(Integer, primary_key=True)
template_name = Column(String(250))
css = Column(Text, default='')
slice_user = Table('slice_user', Model.metadata,
Column('id', Integer, primary_key=True),
Column('user_id', Integer, ForeignKey('ab_user.id')),
Column('slice_id', Integer, ForeignKey('slices.id'))
)
class Slice(Model, AuditMixinNullable, ImportMixin):
"""A slice is essentially a report or a view on data"""
__tablename__ = 'slices'
id = Column(Integer, primary_key=True)
slice_name = Column(String(250))
datasource_id = Column(Integer)
datasource_type = Column(String(200))
datasource_name = Column(String(2000))
viz_type = Column(String(250))
params = Column(Text)
description = Column(Text)
cache_timeout = Column(Integer)
perm = Column(String(1000))
owners = relationship("User", secondary=slice_user)
export_fields = ('slice_name', 'datasource_type', 'datasource_name',
'viz_type', 'params', 'cache_timeout')
def __repr__(self):
return self.slice_name
@property
def cls_model(self):
return SourceRegistry.sources[self.datasource_type]
@property
def datasource(self):
return self.get_datasource
@datasource.getter
@utils.memoized
def get_datasource(self):
ds = db.session.query(
self.cls_model).filter_by(
id=self.datasource_id).first()
return ds
@renders('datasource_name')
def datasource_link(self):
datasource = self.datasource
if datasource:
return self.datasource.link
@property
def datasource_edit_url(self):
        return self.datasource.url
@property
@utils.memoized
def viz(self):
d = json.loads(self.params)
viz_class = viz_types[self.viz_type]
return viz_class(self.datasource, form_data=d)
@property
def description_markeddown(self):
return utils.markdown(self.description)
@property
def data(self):
"""Data used to render slice in templates"""
d = {}
self.token = ''
try:
d = self.viz.data
self.token = d.get('token')
except Exception as e:
logging.exception(e)
d['error'] = str(e)
d['slice_id'] = self.id
d['slice_name'] = self.slice_name
d['description'] = self.description
d['slice_url'] = self.slice_url
d['edit_url'] = self.edit_url
d['description_markeddown'] = self.description_markeddown
return d
@property
def json_data(self):
return json.dumps(self.data)
@property
def slice_url(self):
"""Defines the url to access the slice"""
try:
slice_params = json.loads(self.params)
except Exception as e:
logging.exception(e)
slice_params = {}
slice_params['slice_id'] = self.id
slice_params['json'] = "false"
slice_params['slice_name'] = self.slice_name
from werkzeug.urls import Href
href = Href(
"/superset/explore/{obj.datasource_type}/"
"{obj.datasource_id}/".format(obj=self))
return href(slice_params)
@property
def slice_id_url(self):
return (
"/superset/{slc.datasource_type}/{slc.datasource_id}/{slc.id}/"
).format(slc=self)
@property
def edit_url(self):
return "/slicemodelview/edit/{}".format(self.id)
@property
def slice_link(self):
url = self.slice_url
name = escape(self.slice_name)
return Markup('<a href="{url}">{name}</a>'.format(**locals()))
def get_viz(self, url_params_multidict=None):
"""Creates :py:class:viz.BaseViz object from the url_params_multidict.
:param werkzeug.datastructures.MultiDict url_params_multidict:
Contains the visualization params, they override the self.params
stored in the database
:return: object of the 'viz_type' type that is taken from the
url_params_multidict or self.params.
:rtype: :py:class:viz.BaseViz
"""
slice_params = json.loads(self.params) # {}
slice_params['slice_id'] = self.id
slice_params['json'] = "false"
slice_params['slice_name'] = self.slice_name
slice_params['viz_type'] = self.viz_type if self.viz_type else "table"
if url_params_multidict:
slice_params.update(url_params_multidict)
to_del = [k for k in slice_params if k not in url_params_multidict]
for k in to_del:
del slice_params[k]
immutable_slice_params = ImmutableMultiDict(slice_params)
return viz_types[immutable_slice_params.get('viz_type')](
self.datasource,
form_data=immutable_slice_params,
slice_=self
)
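    # Illustrative sketch (not executed): building a viz from request args,
    # e.g. inside a Flask view where ``request.args`` is a MultiDict:
    #   viz_obj = slc.get_viz(url_params_multidict=request.args)
    #   chart_data = viz_obj.data   # same structure Slice.data relies on
    # Keys that are absent from request.args are dropped from the stored
    # params before the viz is instantiated.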
@classmethod
def import_obj(cls, slc_to_import, import_time=None):
"""Inserts or overrides slc in the database.
remote_id and import_time fields in params_dict are set to track the
slice origin and ensure correct overrides for multiple imports.
Slice.perm is used to find the datasources and connect them.
"""
session = db.session
make_transient(slc_to_import)
slc_to_import.dashboards = []
slc_to_import.alter_params(
remote_id=slc_to_import.id, import_time=import_time)
# find if the slice was already imported
slc_to_override = None
for slc in session.query(Slice).all():
if ('remote_id' in slc.params_dict and
slc.params_dict['remote_id'] == slc_to_import.id):
slc_to_override = slc
slc_to_import = slc_to_import.copy()
params = slc_to_import.params_dict
slc_to_import.datasource_id = SourceRegistry.get_datasource_by_name(
session, slc_to_import.datasource_type, params['datasource_name'],
params['schema'], params['database_name']).id
if slc_to_override:
slc_to_override.override(slc_to_import)
session.flush()
return slc_to_override.id
session.add(slc_to_import)
logging.info('Final slice: {}'.format(slc_to_import.to_json()))
session.flush()
return slc_to_import.id
sqla.event.listen(Slice, 'before_insert', set_related_perm)
sqla.event.listen(Slice, 'before_update', set_related_perm)
dashboard_slices = Table(
'dashboard_slices', Model.metadata,
Column('id', Integer, primary_key=True),
Column('dashboard_id', Integer, ForeignKey('dashboards.id')),
Column('slice_id', Integer, ForeignKey('slices.id')),
)
dashboard_user = Table(
'dashboard_user', Model.metadata,
Column('id', Integer, primary_key=True),
Column('user_id', Integer, ForeignKey('ab_user.id')),
Column('dashboard_id', Integer, ForeignKey('dashboards.id'))
)
class Dashboard(Model, AuditMixinNullable, ImportMixin):
"""The dashboard object!"""
__tablename__ = 'dashboards'
id = Column(Integer, primary_key=True)
dashboard_title = Column(String(500))
position_json = Column(Text)
description = Column(Text)
css = Column(Text)
json_metadata = Column(Text)
slug = Column(String(255), unique=True)
slices = relationship(
'Slice', secondary=dashboard_slices, backref='dashboards')
owners = relationship("User", secondary=dashboard_user)
export_fields = ('dashboard_title', 'position_json', 'json_metadata',
'description', 'css', 'slug')
def __repr__(self):
return self.dashboard_title
@property
def table_names(self):
return ", ".join(
{"{}".format(s.datasource.name) for s in self.slices})
@property
def url(self):
return "/superset/dashboard/{}/".format(self.slug or self.id)
@property
def datasources(self):
return {slc.datasource for slc in self.slices}
@property
def sqla_metadata(self):
metadata = MetaData(bind=self.get_sqla_engine())
return metadata.reflect()
def dashboard_link(self):
title = escape(self.dashboard_title)
return Markup(
'<a href="{self.url}">{title}</a>'.format(**locals()))
@property
def json_data(self):
positions = self.position_json
if positions:
positions = json.loads(positions)
d = {
'id': self.id,
'metadata': self.params_dict,
'css': self.css,
'dashboard_title': self.dashboard_title,
'slug': self.slug,
'slices': [slc.data for slc in self.slices],
'position_json': positions,
}
return json.dumps(d)
@property
def params(self):
return self.json_metadata
@params.setter
def params(self, value):
self.json_metadata = value
@property
def position_array(self):
if self.position_json:
return json.loads(self.position_json)
return []
@classmethod
def import_obj(cls, dashboard_to_import, import_time=None):
"""Imports the dashboard from the object to the database.
Once dashboard is imported, json_metadata field is extended and stores
remote_id and import_time. It helps to decide if the dashboard has to
be overridden or just copies over. Slices that belong to this
dashboard will be wired to existing tables. This function can be used
to import/export dashboards between multiple superset instances.
Audit metadata isn't copies over.
"""
def alter_positions(dashboard, old_to_new_slc_id_dict):
""" Updates slice_ids in the position json.
Sample position json:
[{
"col": 5,
"row": 10,
"size_x": 4,
"size_y": 2,
"slice_id": "3610"
}]
"""
position_array = dashboard.position_array
for position in position_array:
if 'slice_id' not in position:
continue
old_slice_id = int(position['slice_id'])
if old_slice_id in old_to_new_slc_id_dict:
position['slice_id'] = '{}'.format(
old_to_new_slc_id_dict[old_slice_id])
dashboard.position_json = json.dumps(position_array)
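        # Illustrative sketch (not executed): with
        #   old_to_new_slc_id_dict = {3610: 42}
        # the sample position above becomes {"slice_id": "42", ...}; entries
        # without a 'slice_id' key are left untouched.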
logging.info('Started import of the dashboard: {}'
.format(dashboard_to_import.to_json()))
session = db.session
logging.info('Dashboard has {} slices'
.format(len(dashboard_to_import.slices)))
        # copy slices object as Slice.import_obj will mutate the slice
# and will remove the existing dashboard - slice association
slices = copy(dashboard_to_import.slices)
old_to_new_slc_id_dict = {}
new_filter_immune_slices = []
new_expanded_slices = {}
i_params_dict = dashboard_to_import.params_dict
for slc in slices:
logging.info('Importing slice {} from the dashboard: {}'.format(
slc.to_json(), dashboard_to_import.dashboard_title))
new_slc_id = Slice.import_obj(slc, import_time=import_time)
old_to_new_slc_id_dict[slc.id] = new_slc_id
# update json metadata that deals with slice ids
new_slc_id_str = '{}'.format(new_slc_id)
old_slc_id_str = '{}'.format(slc.id)
if ('filter_immune_slices' in i_params_dict and
old_slc_id_str in i_params_dict['filter_immune_slices']):
new_filter_immune_slices.append(new_slc_id_str)
if ('expanded_slices' in i_params_dict and
old_slc_id_str in i_params_dict['expanded_slices']):
new_expanded_slices[new_slc_id_str] = (
i_params_dict['expanded_slices'][old_slc_id_str])
# override the dashboard
existing_dashboard = None
for dash in session.query(Dashboard).all():
if ('remote_id' in dash.params_dict and
dash.params_dict['remote_id'] ==
dashboard_to_import.id):
existing_dashboard = dash
dashboard_to_import.id = None
alter_positions(dashboard_to_import, old_to_new_slc_id_dict)
dashboard_to_import.alter_params(import_time=import_time)
if new_expanded_slices:
dashboard_to_import.alter_params(
expanded_slices=new_expanded_slices)
if new_filter_immune_slices:
dashboard_to_import.alter_params(
filter_immune_slices=new_filter_immune_slices)
new_slices = session.query(Slice).filter(
Slice.id.in_(old_to_new_slc_id_dict.values())).all()
if existing_dashboard:
existing_dashboard.override(dashboard_to_import)
existing_dashboard.slices = new_slices
session.flush()
return existing_dashboard.id
else:
            # session.add(dashboard_to_import) causes sqlalchemy failures
            # related to the attached users / slices. Creating a new object
            # allows to avoid conflicts in the sqlalchemy state.
copied_dash = dashboard_to_import.copy()
copied_dash.slices = new_slices
session.add(copied_dash)
session.flush()
return copied_dash.id
@classmethod
def export_dashboards(cls, dashboard_ids):
copied_dashboards = []
datasource_ids = set()
for dashboard_id in dashboard_ids:
# make sure that dashboard_id is an integer
dashboard_id = int(dashboard_id)
copied_dashboard = (
db.session.query(Dashboard)
.options(subqueryload(Dashboard.slices))
.filter_by(id=dashboard_id).first()
)
make_transient(copied_dashboard)
for slc in copied_dashboard.slices:
datasource_ids.add((slc.datasource_id, slc.datasource_type))
# add extra params for the import
slc.alter_params(
remote_id=slc.id,
datasource_name=slc.datasource.name,
schema=slc.datasource.name,
database_name=slc.datasource.database.name,
)
copied_dashboard.alter_params(remote_id=dashboard_id)
copied_dashboards.append(copied_dashboard)
eager_datasources = []
for dashboard_id, dashboard_type in datasource_ids:
eager_datasource = SourceRegistry.get_eager_datasource(
db.session, dashboard_type, dashboard_id)
eager_datasource.alter_params(
remote_id=eager_datasource.id,
database_name=eager_datasource.database.name,
)
make_transient(eager_datasource)
eager_datasources.append(eager_datasource)
return pickle.dumps({
'dashboards': copied_dashboards,
'datasources': eager_datasources,
})
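# Illustrative sketch (not executed): round-tripping dashboards between two
# Superset instances using the pickle payload produced above (the import_time
# value is an arbitrary epoch timestamp):
#   payload = Dashboard.export_dashboards([1, 2])
#   data = pickle.loads(payload)
#   for datasource in data['datasources']:
#       type(datasource).import_obj(datasource, import_time=1491000000)
#   for dash in data['dashboards']:
#       Dashboard.import_obj(dash, import_time=1491000000)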
class Queryable(object):
"""A common interface to objects that are queryable (tables and datasources)"""
@property
def column_names(self):
return sorted([c.column_name for c in self.columns])
@property
def main_dttm_col(self):
return "timestamp"
@property
def groupby_column_names(self):
return sorted([c.column_name for c in self.columns if c.groupby])
@property
def filterable_column_names(self):
return sorted([c.column_name for c in self.columns if c.filterable])
@property
def dttm_cols(self):
return []
@property
def url(self):
return '/{}/edit/{}'.format(self.baselink, self.id)
@property
def explore_url(self):
if self.default_endpoint:
return self.default_endpoint
else:
return "/superset/explore/{obj.type}/{obj.id}/".format(obj=self)
@property
def data(self):
"""data representation of the datasource sent to the frontend"""
gb_cols = [(col, col) for col in self.groupby_column_names]
all_cols = [(c, c) for c in self.column_names]
order_by_choices = []
for s in sorted(self.column_names):
order_by_choices.append((json.dumps([s, True]), s + ' [asc]'))
order_by_choices.append((json.dumps([s, False]), s + ' [desc]'))
d = {
'id': self.id,
'type': self.type,
'name': self.name,
'metrics_combo': self.metrics_combo,
'order_by_choices': order_by_choices,
'gb_cols': gb_cols,
'all_cols': all_cols,
'filterable_cols': self.filterable_column_names,
}
if (self.type == 'table'):
grains = self.database.grains() or []
if grains:
grains = [(g.name, g.name) for g in grains]
d['granularity_sqla'] = [(c, c) for c in self.dttm_cols]
d['time_grain_sqla'] = grains
return d
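    # Illustrative sketch (not executed): for a datasource with columns
    # ['ds', 'name'], order_by_choices contains entries such as
    #   ('["ds", true]', 'ds [asc]'), ('["ds", false]', 'ds [desc]'), ...
    # i.e. the value is a JSON-encoded [column, ascending] pair.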
class Database(Model, AuditMixinNullable):
"""An ORM object that stores Database related information"""
__tablename__ = 'dbs'
type = "table"
id = Column(Integer, primary_key=True)
database_name = Column(String(250), unique=True)
sqlalchemy_uri = Column(String(1024))
password = Column(EncryptedType(String(1024), config.get('SECRET_KEY')))
cache_timeout = Column(Integer)
select_as_create_table_as = Column(Boolean, default=False)
expose_in_sqllab = Column(Boolean, default=False)
allow_run_sync = Column(Boolean, default=True)
allow_run_async = Column(Boolean, default=False)
allow_ctas = Column(Boolean, default=False)
allow_dml = Column(Boolean, default=False)
force_ctas_schema = Column(String(250))
extra = Column(Text, default=textwrap.dedent("""\
{
"metadata_params": {},
"engine_params": {}
}
"""))
perm = Column(String(1000))
def __repr__(self):
return self.database_name
@property
def name(self):
return self.database_name
@property
def backend(self):
url = make_url(self.sqlalchemy_uri_decrypted)
return url.get_backend_name()
def set_sqlalchemy_uri(self, uri):
password_mask = "X" * 10
conn = sqla.engine.url.make_url(uri)
if conn.password != password_mask:
# do not over-write the password with the password mask
self.password = conn.password
conn.password = password_mask if conn.password else None
self.sqlalchemy_uri = str(conn) # hides the password
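    # Illustrative sketch (not executed): only a masked URI ends up in
    # ``sqlalchemy_uri``; the real password is kept in the encrypted
    # ``password`` column, e.g.
    #   db_obj = Database(database_name='main')
    #   db_obj.set_sqlalchemy_uri('mysql://scott:tiger@localhost/test')
    #   db_obj.sqlalchemy_uri            # 'mysql://scott:XXXXXXXXXX@localhost/test'
    #   db_obj.sqlalchemy_uri_decrypted  # 'mysql://scott:tiger@localhost/test'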
def get_sqla_engine(self, schema=None):
extra = self.get_extra()
url = make_url(self.sqlalchemy_uri_decrypted)
params = extra.get('engine_params', {})
if self.backend == 'presto' and schema:
if '/' in url.database:
url.database = url.database.split('/')[0] + '/' + schema
else:
url.database += '/' + schema
elif schema:
url.database = schema
return create_engine(url, **params)
def get_reserved_words(self):
return self.get_sqla_engine().dialect.preparer.reserved_words
def get_quoter(self):
return self.get_sqla_engine().dialect.identifier_preparer.quote
def get_df(self, sql, schema):
sql = sql.strip().strip(';')
eng = self.get_sqla_engine(schema=schema)
cur = eng.execute(sql, schema=schema)
cols = [col[0] for col in cur.cursor.description]
df = pd.DataFrame(cur.fetchall(), columns=cols)
return df
def compile_sqla_query(self, qry, schema=None):
eng = self.get_sqla_engine(schema=schema)
compiled = qry.compile(eng, compile_kwargs={"literal_binds": True})
return '{}'.format(compiled)
def select_star(
self, table_name, schema=None, limit=100, show_cols=False,
indent=True):
"""Generates a ``select *`` statement in the proper dialect"""
quote = self.get_quoter()
fields = '*'
table = self.get_table(table_name, schema=schema)
if show_cols:
fields = [quote(c.name) for c in table.columns]
if schema:
table_name = schema + '.' + table_name
qry = select(fields).select_from(text(table_name))
if limit:
qry = qry.limit(limit)
sql = self.compile_sqla_query(qry)
if indent:
sql = sqlparse.format(sql, reindent=True)
return sql
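    # Illustrative sketch (not executed): select_star('logs', schema='raw')
    # compiles, reindented, to something like
    #   SELECT *
    #   FROM raw.logs
    #   LIMIT 100
    # and lists quoted column names instead of * when show_cols=True.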
def wrap_sql_limit(self, sql, limit=1000):
qry = (
select('*')
.select_from(TextAsFrom(text(sql), ['*'])
.alias('inner_qry')).limit(limit)
)
return self.compile_sqla_query(qry)
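    # Illustrative sketch (not executed): wrap_sql_limit wraps arbitrary SQL in
    # a subquery so a row limit can always be enforced, e.g.
    #   wrap_sql_limit('SELECT a, b FROM logs', limit=10)
    # compiles to roughly
    #   SELECT * FROM (SELECT a, b FROM logs) AS inner_qry LIMIT 10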
def safe_sqlalchemy_uri(self):
return self.sqlalchemy_uri
@property
def inspector(self):
engine = self.get_sqla_engine()
return sqla.inspect(engine)
def all_table_names(self, schema=None):
return sorted(self.inspector.get_table_names(schema))
def all_view_names(self, schema=None):
views = []
try:
views = self.inspector.get_view_names(schema)
except Exception as e:
pass
return views
def all_schema_names(self):
return sorted(self.inspector.get_schema_names())
@property
def db_engine_spec(self):
engine_name = self.get_sqla_engine().name or 'base'
return db_engine_specs.engines.get(
engine_name, db_engine_specs.BaseEngineSpec)
def grains(self):
"""Defines time granularity database-specific expressions.
The idea here is to make it easy for users to change the time grain
form a datetime (maybe the source grain is arbitrary timestamps, daily
or 5 minutes increments) to another, "truncated" datetime. Since
each database has slightly different but similar datetime functions,
this allows a mapping between database engines and actual functions.
"""
return self.db_engine_spec.time_grains
def grains_dict(self):
return {grain.name: grain for grain in self.grains()}
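    # Illustrative sketch (not executed): each grain carries a SQL template
    # with a ``{col}`` placeholder supplied by the engine spec, e.g. something
    # along the lines of
    #   grain = db_obj.grains_dict()['day']
    #   grain.function.format(col='ds')   # -> e.g. "DATE(ds)" on MySQL
    # Exact grain names and expressions depend on the db_engine_spec in use.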
def get_extra(self):
extra = {}
if self.extra:
try:
extra = json.loads(self.extra)
except Exception as e:
logging.error(e)
return extra
def get_table(self, table_name, schema=None):
extra = self.get_extra()
meta = MetaData(**extra.get('metadata_params', {}))
return Table(
table_name, meta,
schema=schema or None,
autoload=True,
autoload_with=self.get_sqla_engine())
def get_columns(self, table_name, schema=None):
return self.inspector.get_columns(table_name, schema)
def get_indexes(self, table_name, schema=None):
return self.inspector.get_indexes(table_name, schema)
def get_pk_constraint(self, table_name, schema=None):
return self.inspector.get_pk_constraint(table_name, schema)
def get_foreign_keys(self, table_name, schema=None):
return self.inspector.get_foreign_keys(table_name, schema)
@property
def sqlalchemy_uri_decrypted(self):
conn = sqla.engine.url.make_url(self.sqlalchemy_uri)
conn.password = self.password
return str(conn)
@property
def sql_url(self):
return '/superset/sql/{}/'.format(self.id)
def get_perm(self):
return (
"[{obj.database_name}].(id:{obj.id})").format(obj=self)
sqla.event.listen(Database, 'after_insert', set_perm)
sqla.event.listen(Database, 'after_update', set_perm)
class TableColumn(Model, AuditMixinNullable, ImportMixin):
"""ORM object for table columns, each table can have multiple columns"""
__tablename__ = 'table_columns'
id = Column(Integer, primary_key=True)
table_id = Column(Integer, ForeignKey('tables.id'))
table = relationship(
'SqlaTable',
backref=backref('columns', cascade='all, delete-orphan'),
foreign_keys=[table_id])
column_name = Column(String(255))
verbose_name = Column(String(1024))
is_dttm = Column(Boolean, default=False)
is_active = Column(Boolean, default=True)
type = Column(String(32), default='')
groupby = Column(Boolean, default=False)
count_distinct = Column(Boolean, default=False)
sum = Column(Boolean, default=False)
avg = Column(Boolean, default=False)
max = Column(Boolean, default=False)
min = Column(Boolean, default=False)
filterable = Column(Boolean, default=False)
expression = Column(Text, default='')
description = Column(Text, default='')
python_date_format = Column(String(255))
database_expression = Column(String(255))
num_types = ('DOUBLE', 'FLOAT', 'INT', 'BIGINT', 'LONG')
date_types = ('DATE', 'TIME')
str_types = ('VARCHAR', 'STRING', 'CHAR')
export_fields = (
'table_id', 'column_name', 'verbose_name', 'is_dttm', 'is_active',
'type', 'groupby', 'count_distinct', 'sum', 'avg', 'max', 'min',
'filterable', 'expression', 'description', 'python_date_format',
'database_expression'
)
def __repr__(self):
return self.column_name
@property
def isnum(self):
return any([t in self.type.upper() for t in self.num_types])
@property
def is_time(self):
return any([t in self.type.upper() for t in self.date_types])
@property
def is_string(self):
return any([t in self.type.upper() for t in self.str_types])
@property
def sqla_col(self):
name = self.column_name
if not self.expression:
col = column(self.column_name).label(name)
else:
col = literal_column(self.expression).label(name)
return col
def get_time_filter(self, start_dttm, end_dttm):
col = self.sqla_col.label('__time')
return and_(
col >= text(self.dttm_sql_literal(start_dttm)),
col <= text(self.dttm_sql_literal(end_dttm)),
)
def get_timestamp_expression(self, time_grain):
"""Getting the time component of the query"""
expr = self.expression or self.column_name
if not self.expression and not time_grain:
return column(expr, type_=DateTime).label(DTTM_ALIAS)
if time_grain:
pdf = self.python_date_format
if pdf in ('epoch_s', 'epoch_ms'):
# if epoch, translate to DATE using db specific conf
db_spec = self.table.database.db_engine_spec
if pdf == 'epoch_s':
expr = db_spec.epoch_to_dttm().format(col=expr)
elif pdf == 'epoch_ms':
expr = db_spec.epoch_ms_to_dttm().format(col=expr)
            grain = self.table.database.grains_dict().get(time_grain)
            if grain:
                expr = grain.function.format(col=expr)
return literal_column(expr, type_=DateTime).label(DTTM_ALIAS)
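    # Illustrative sketch (not executed), assuming a Postgres-backed table and
    # an epoch_s column 'ds' with time_grain='day': expr is rewritten twice,
    # first through the engine spec's epoch_to_dttm() template and then
    # through the grain's template, ending up roughly as
    #   DATE_TRUNC('day', (timestamp 'epoch' + ds * interval '1 second'))
    # labeled as DTTM_ALIAS ('__timestamp').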
@classmethod
def import_obj(cls, i_column):
def lookup_obj(lookup_column):
return db.session.query(TableColumn).filter(
TableColumn.table_id == lookup_column.table_id,
TableColumn.column_name == lookup_column.column_name).first()
return import_util.import_simple_obj(db.session, i_column, lookup_obj)
def dttm_sql_literal(self, dttm):
"""Convert datetime object to a SQL expression string
If database_expression is empty, the internal dttm
will be parsed as the string with the pattern that
the user inputted (python_date_format)
If database_expression is not empty, the internal dttm
will be parsed as the sql sentence for the database to convert
"""
tf = self.python_date_format or '%Y-%m-%d %H:%M:%S.%f'
if self.database_expression:
return self.database_expression.format(dttm.strftime('%Y-%m-%d %H:%M:%S'))
elif tf == 'epoch_s':
return str((dttm - datetime(1970, 1, 1)).total_seconds())
elif tf == 'epoch_ms':
return str((dttm - datetime(1970, 1, 1)).total_seconds() * 1000.0)
else:
s = self.table.database.db_engine_spec.convert_dttm(
self.type, dttm)
return s or "'{}'".format(dttm.strftime(tf))
class SqlMetric(Model, AuditMixinNullable, ImportMixin):
"""ORM object for metrics, each table can have multiple metrics"""
__tablename__ = 'sql_metrics'
id = Column(Integer, primary_key=True)
metric_name = Column(String(512))
verbose_name = Column(String(1024))
metric_type = Column(String(32))
table_id = Column(Integer, ForeignKey('tables.id'))
table = relationship(
'SqlaTable',
backref=backref('metrics', cascade='all, delete-orphan'),
foreign_keys=[table_id])
expression = Column(Text)
description = Column(Text)
is_restricted = Column(Boolean, default=False, nullable=True)
d3format = Column(String(128))
export_fields = (
'metric_name', 'verbose_name', 'metric_type', 'table_id', 'expression',
'description', 'is_restricted', 'd3format')
@property
def sqla_col(self):
name = self.metric_name
return literal_column(self.expression).label(name)
@property
def perm(self):
return (
"{parent_name}.[{obj.metric_name}](id:{obj.id})"
).format(obj=self,
parent_name=self.table.full_name) if self.table else None
@classmethod
def import_obj(cls, i_metric):
def lookup_obj(lookup_metric):
return db.session.query(SqlMetric).filter(
SqlMetric.table_id == lookup_metric.table_id,
SqlMetric.metric_name == lookup_metric.metric_name).first()
return import_util.import_simple_obj(db.session, i_metric, lookup_obj)
class SqlaTable(Model, Queryable, AuditMixinNullable, ImportMixin):
"""An ORM object for SqlAlchemy table references"""
type = "table"
__tablename__ = 'tables'
id = Column(Integer, primary_key=True)
table_name = Column(String(250))
main_dttm_col = Column(String(250))
description = Column(Text)
default_endpoint = Column(Text)
database_id = Column(Integer, ForeignKey('dbs.id'), nullable=False)
is_featured = Column(Boolean, default=False)
filter_select_enabled = Column(Boolean, default=False)
user_id = Column(Integer, ForeignKey('ab_user.id'))
owner = relationship('User', backref='tables', foreign_keys=[user_id])
database = relationship(
'Database',
backref=backref('tables', cascade='all, delete-orphan'),
foreign_keys=[database_id])
offset = Column(Integer, default=0)
cache_timeout = Column(Integer)
schema = Column(String(255))
sql = Column(Text)
params = Column(Text)
perm = Column(String(1000))
baselink = "tablemodelview"
column_cls = TableColumn
metric_cls = SqlMetric
export_fields = (
'table_name', 'main_dttm_col', 'description', 'default_endpoint',
'database_id', 'is_featured', 'offset', 'cache_timeout', 'schema',
'sql', 'params')
__table_args__ = (
sqla.UniqueConstraint(
'database_id', 'schema', 'table_name',
name='_customer_location_uc'),)
def __repr__(self):
return self.name
@property
def description_markeddown(self):
return utils.markdown(self.description)
@property
def link(self):
name = escape(self.name)
return Markup(
'<a href="{self.explore_url}">{name}</a>'.format(**locals()))
@property
def schema_perm(self):
"""Returns schema permission if present, database one otherwise."""
return utils.get_schema_perm(self.database, self.schema)
def get_perm(self):
return (
"[{obj.database}].[{obj.table_name}]"
"(id:{obj.id})").format(obj=self)
@property
def name(self):
if not self.schema:
return self.table_name
return "{}.{}".format(self.schema, self.table_name)
@property
def full_name(self):
return utils.get_datasource_full_name(
self.database, self.table_name, schema=self.schema)
@property
def dttm_cols(self):
l = [c.column_name for c in self.columns if c.is_dttm]
if self.main_dttm_col not in l:
l.append(self.main_dttm_col)
return l
@property
def num_cols(self):
return [c.column_name for c in self.columns if c.isnum]
@property
def any_dttm_col(self):
cols = self.dttm_cols
if cols:
return cols[0]
@property
def html(self):
t = ((c.column_name, c.type) for c in self.columns)
df = pd.DataFrame(t)
df.columns = ['field', 'type']
return df.to_html(
index=False,
classes=(
"dataframe table table-striped table-bordered "
"table-condensed"))
@property
def metrics_combo(self):
return sorted(
[
(m.metric_name, m.verbose_name or m.metric_name)
for m in self.metrics],
key=lambda x: x[1])
@property
def sql_url(self):
return self.database.sql_url + "?table_name=" + str(self.table_name)
@property
def time_column_grains(self):
return {
"time_columns": self.dttm_cols,
"time_grains": [grain.name for grain in self.database.grains()]
}
def get_col(self, col_name):
columns = self.columns
for col in columns:
if col_name == col.column_name:
return col
def values_for_column(self,
column_name,
from_dttm,
to_dttm,
limit=500):
"""Runs query against sqla to retrieve some
sample values for the given column.
"""
granularity = self.main_dttm_col
cols = {col.column_name: col for col in self.columns}
target_col = cols[column_name]
tbl = table(self.table_name)
qry = select([target_col.sqla_col])
qry = qry.select_from(tbl)
qry = qry.distinct(column_name)
qry = qry.limit(limit)
if granularity:
dttm_col = cols[granularity]
timestamp = dttm_col.sqla_col.label('timestamp')
time_filter = [
timestamp >= text(dttm_col.dttm_sql_literal(from_dttm)),
timestamp <= text(dttm_col.dttm_sql_literal(to_dttm)),
]
qry = qry.where(and_(*time_filter))
engine = self.database.get_sqla_engine()
sql = "{}".format(
qry.compile(
engine, compile_kwargs={"literal_binds": True}, ),
)
return pd.read_sql_query(
sql=sql,
con=engine
)
def query( # sqla
self, groupby, metrics,
granularity,
from_dttm, to_dttm,
filter=None, # noqa
is_timeseries=True,
timeseries_limit=15,
timeseries_limit_metric=None,
row_limit=None,
inner_from_dttm=None,
inner_to_dttm=None,
orderby=None,
extras=None,
columns=None):
"""Querying any sqla table from this common interface"""
template_processor = get_template_processor(
table=self, database=self.database)
# For backward compatibility
if granularity not in self.dttm_cols:
granularity = self.main_dttm_col
cols = {col.column_name: col for col in self.columns}
metrics_dict = {m.metric_name: m for m in self.metrics}
qry_start_dttm = datetime.now()
if not granularity and is_timeseries:
raise Exception(_(
"Datetime column not provided as part table configuration "
"and is required by this type of chart"))
for m in metrics:
if m not in metrics_dict:
raise Exception(_("Metric '{}' is not valid".format(m)))
metrics_exprs = [metrics_dict.get(m).sqla_col for m in metrics]
timeseries_limit_metric = metrics_dict.get(timeseries_limit_metric)
timeseries_limit_metric_expr = None
if timeseries_limit_metric:
timeseries_limit_metric_expr = \
timeseries_limit_metric.sqla_col
if metrics:
main_metric_expr = metrics_exprs[0]
else:
main_metric_expr = literal_column("COUNT(*)").label("ccount")
select_exprs = []
groupby_exprs = []
if groupby:
select_exprs = []
inner_select_exprs = []
inner_groupby_exprs = []
for s in groupby:
col = cols[s]
outer = col.sqla_col
inner = col.sqla_col.label(col.column_name + '__')
groupby_exprs.append(outer)
select_exprs.append(outer)
inner_groupby_exprs.append(inner)
inner_select_exprs.append(inner)
elif columns:
for s in columns:
select_exprs.append(cols[s].sqla_col)
metrics_exprs = []
if granularity:
@compiles(ColumnClause)
def visit_column(element, compiler, **kw):
"""Patch for sqlalchemy bug
TODO: sqlalchemy 1.2 release should be doing this on its own.
Patch only if the column clause is specific for DateTime
set and granularity is selected.
"""
text = compiler.visit_column(element, **kw)
try:
if (
element.is_literal and
hasattr(element.type, 'python_type') and
type(element.type) is DateTime
):
text = text.replace('%%', '%')
except NotImplementedError:
# Some elements raise NotImplementedError for python_type
pass
return text
dttm_col = cols[granularity]
time_grain = extras.get('time_grain_sqla')
if is_timeseries:
timestamp = dttm_col.get_timestamp_expression(time_grain)
select_exprs += [timestamp]
groupby_exprs += [timestamp]
time_filter = dttm_col.get_time_filter(from_dttm, to_dttm)
select_exprs += metrics_exprs
qry = select(select_exprs)
tbl = table(self.table_name)
if self.schema:
tbl.schema = self.schema
# Supporting arbitrary SQL statements in place of tables
if self.sql:
tbl = TextAsFrom(sqla.text(self.sql), []).alias('expr_qry')
if not columns:
qry = qry.group_by(*groupby_exprs)
where_clause_and = []
having_clause_and = []
for col, op, eq in filter:
col_obj = cols[col]
if op in ('in', 'not in'):
splitted = FillterPattern.split(eq)[1::2]
values = [types.replace("'", '').strip() for types in splitted]
cond = col_obj.sqla_col.in_(values)
if op == 'not in':
cond = ~cond
where_clause_and.append(cond)
if extras:
where = extras.get('where')
if where:
where_clause_and += [wrap_clause_in_parens(
template_processor.process_template(where))]
having = extras.get('having')
if having:
having_clause_and += [wrap_clause_in_parens(
template_processor.process_template(having))]
if granularity:
qry = qry.where(and_(*([time_filter] + where_clause_and)))
else:
qry = qry.where(and_(*where_clause_and))
qry = qry.having(and_(*having_clause_and))
if groupby:
qry = qry.order_by(desc(main_metric_expr))
elif orderby:
for col, ascending in orderby:
direction = asc if ascending else desc
qry = qry.order_by(direction(col))
qry = qry.limit(row_limit)
if is_timeseries and timeseries_limit and groupby:
            # some sql dialects require order by expressions
            # to also be present in the select clause
inner_select_exprs += [main_metric_expr]
subq = select(inner_select_exprs)
subq = subq.select_from(tbl)
inner_time_filter = dttm_col.get_time_filter(
inner_from_dttm or from_dttm,
inner_to_dttm or to_dttm,
)
subq = subq.where(and_(*(where_clause_and + [inner_time_filter])))
subq = subq.group_by(*inner_groupby_exprs)
ob = main_metric_expr
if timeseries_limit_metric_expr is not None:
ob = timeseries_limit_metric_expr
subq = subq.order_by(desc(ob))
subq = subq.limit(timeseries_limit)
on_clause = []
for i, gb in enumerate(groupby):
on_clause.append(
groupby_exprs[i] == column(gb + '__'))
tbl = tbl.join(subq.alias(), and_(*on_clause))
qry = qry.select_from(tbl)
engine = self.database.get_sqla_engine()
sql = "{}".format(
qry.compile(
engine, compile_kwargs={"literal_binds": True},),
)
sql = sqlparse.format(sql, reindent=True)
status = QueryStatus.SUCCESS
error_message = None
df = None
try:
df = pd.read_sql_query(sql, con=engine)
except Exception as e:
status = QueryStatus.FAILED
error_message = str(e)
return QueryResult(
status=status,
df=df,
duration=datetime.now() - qry_start_dttm,
query=sql,
error_message=error_message)
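    # Illustrative sketch (not executed): a minimal call against this common
    # interface, assuming ``tbl`` is a SqlaTable with a 'ds' datetime column,
    # a 'gender' dimension and a 'sum__num' metric:
    #   result = tbl.query(
    #       groupby=['gender'], metrics=['sum__num'], granularity='ds',
    #       from_dttm=datetime(2017, 1, 1), to_dttm=datetime(2017, 2, 1),
    #       filter=[], is_timeseries=False, row_limit=50, extras={})
    #   result.df     # pandas DataFrame with the result set
    #   result.query  # the generated, pretty-printed SQL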
def get_sqla_table_object(self):
return self.database.get_table(self.table_name, schema=self.schema)
def fetch_metadata(self):
"""Fetches the metadata for the table and merges it in"""
try:
table = self.get_sqla_table_object()
except Exception:
raise Exception(
"Table doesn't seem to exist in the specified database, "
"couldn't fetch column information")
TC = TableColumn # noqa shortcut to class
M = SqlMetric # noqa
metrics = []
any_date_col = None
for col in table.columns:
try:
datatype = "{}".format(col.type).upper()
except Exception as e:
datatype = "UNKNOWN"
logging.error(
"Unrecognized data type in {}.{}".format(table, col.name))
logging.exception(e)
dbcol = (
db.session
.query(TC)
.filter(TC.table == self)
.filter(TC.column_name == col.name)
.first()
)
db.session.flush()
if not dbcol:
dbcol = TableColumn(column_name=col.name, type=datatype)
dbcol.groupby = dbcol.is_string
dbcol.filterable = dbcol.is_string
dbcol.sum = dbcol.isnum
dbcol.avg = dbcol.isnum
dbcol.is_dttm = dbcol.is_time
db.session.merge(self)
self.columns.append(dbcol)
if not any_date_col and dbcol.is_time:
any_date_col = col.name
quoted = "{}".format(
column(dbcol.column_name).compile(dialect=db.engine.dialect))
if dbcol.sum:
metrics.append(M(
metric_name='sum__' + dbcol.column_name,
verbose_name='sum__' + dbcol.column_name,
metric_type='sum',
expression="SUM({})".format(quoted)
))
if dbcol.avg:
metrics.append(M(
metric_name='avg__' + dbcol.column_name,
verbose_name='avg__' + dbcol.column_name,
metric_type='avg',
expression="AVG({})".format(quoted)
))
if dbcol.max:
metrics.append(M(
metric_name='max__' + dbcol.column_name,
verbose_name='max__' + dbcol.column_name,
metric_type='max',
expression="MAX({})".format(quoted)
))
if dbcol.min:
metrics.append(M(
metric_name='min__' + dbcol.column_name,
verbose_name='min__' + dbcol.column_name,
metric_type='min',
expression="MIN({})".format(quoted)
))
if dbcol.count_distinct:
metrics.append(M(
metric_name='count_distinct__' + dbcol.column_name,
verbose_name='count_distinct__' + dbcol.column_name,
metric_type='count_distinct',
expression="COUNT(DISTINCT {})".format(quoted)
))
dbcol.type = datatype
db.session.merge(self)
db.session.commit()
metrics.append(M(
metric_name='count',
verbose_name='COUNT(*)',
metric_type='count',
expression="COUNT(*)"
))
for metric in metrics:
m = (
db.session.query(M)
.filter(M.metric_name == metric.metric_name)
.filter(M.table_id == self.id)
.first()
)
metric.table_id = self.id
if not m:
db.session.add(metric)
db.session.commit()
if not self.main_dttm_col:
self.main_dttm_col = any_date_col
@classmethod
def import_obj(cls, i_datasource, import_time=None):
"""Imports the datasource from the object to the database.
Metrics and columns and datasource will be overrided if exists.
This function can be used to import/export dashboards between multiple
superset instances. Audit metadata isn't copies over.
"""
def lookup_sqlatable(table):
return db.session.query(SqlaTable).join(Database).filter(
SqlaTable.table_name == table.table_name,
SqlaTable.schema == table.schema,
Database.id == table.database_id,
).first()
def lookup_database(table):
return db.session.query(Database).filter_by(
database_name=table.params_dict['database_name']).one()
return import_util.import_datasource(
db.session, i_datasource, lookup_database, lookup_sqlatable,
import_time)
sqla.event.listen(SqlaTable, 'after_insert', set_perm)
sqla.event.listen(SqlaTable, 'after_update', set_perm)
class DruidCluster(Model, AuditMixinNullable):
"""ORM object referencing the Druid clusters"""
__tablename__ = 'clusters'
type = "druid"
id = Column(Integer, primary_key=True)
cluster_name = Column(String(250), unique=True)
coordinator_host = Column(String(255))
coordinator_port = Column(Integer)
coordinator_endpoint = Column(
String(255), default='druid/coordinator/v1/metadata')
broker_host = Column(String(255))
broker_port = Column(Integer)
broker_endpoint = Column(String(255), default='druid/v2')
metadata_last_refreshed = Column(DateTime)
cache_timeout = Column(Integer)
def __repr__(self):
return self.cluster_name
def get_pydruid_client(self):
cli = PyDruid(
"http://{0}:{1}/".format(self.broker_host, self.broker_port),
self.broker_endpoint)
return cli
def get_datasources(self):
endpoint = (
"http://{obj.coordinator_host}:{obj.coordinator_port}/"
"{obj.coordinator_endpoint}/datasources"
).format(obj=self)
return json.loads(requests.get(endpoint).text)
def get_druid_version(self):
endpoint = (
"http://{obj.coordinator_host}:{obj.coordinator_port}/status"
).format(obj=self)
return json.loads(requests.get(endpoint).text)['version']
def refresh_datasources(self, datasource_name=None, merge_flag=False):
"""Refresh metadata of all datasources in the cluster
If ``datasource_name`` is specified, only that datasource is updated
"""
self.druid_version = self.get_druid_version()
for datasource in self.get_datasources():
if datasource not in config.get('DRUID_DATA_SOURCE_BLACKLIST'):
if not datasource_name or datasource_name == datasource:
DruidDatasource.sync_to_db(datasource, self, merge_flag)
@property
def perm(self):
return "[{obj.cluster_name}].(id:{obj.id})".format(obj=self)
@property
def name(self):
return self.cluster_name
class DruidColumn(Model, AuditMixinNullable, ImportMixin):
"""ORM model for storing Druid datasource column metadata"""
__tablename__ = 'columns'
id = Column(Integer, primary_key=True)
datasource_name = Column(
String(255),
ForeignKey('datasources.datasource_name'))
# Setting enable_typechecks=False disables polymorphic inheritance.
datasource = relationship(
'DruidDatasource',
backref=backref('columns', cascade='all, delete-orphan'),
enable_typechecks=False)
column_name = Column(String(255))
is_active = Column(Boolean, default=True)
type = Column(String(32))
groupby = Column(Boolean, default=False)
count_distinct = Column(Boolean, default=False)
sum = Column(Boolean, default=False)
avg = Column(Boolean, default=False)
max = Column(Boolean, default=False)
min = Column(Boolean, default=False)
filterable = Column(Boolean, default=False)
description = Column(Text)
dimension_spec_json = Column(Text)
export_fields = (
'datasource_name', 'column_name', 'is_active', 'type', 'groupby',
'count_distinct', 'sum', 'avg', 'max', 'min', 'filterable',
'description', 'dimension_spec_json'
)
def __repr__(self):
return self.column_name
@property
def isnum(self):
return self.type in ('LONG', 'DOUBLE', 'FLOAT', 'INT')
@property
def dimension_spec(self):
if self.dimension_spec_json:
return json.loads(self.dimension_spec_json)
def generate_metrics(self):
"""Generate metrics based on the column metadata"""
M = DruidMetric # noqa
metrics = []
metrics.append(DruidMetric(
metric_name='count',
verbose_name='COUNT(*)',
metric_type='count',
json=json.dumps({'type': 'count', 'name': 'count'})
))
# Somehow we need to reassign this for UDAFs
if self.type in ('DOUBLE', 'FLOAT'):
corrected_type = 'DOUBLE'
else:
corrected_type = self.type
if self.sum and self.isnum:
mt = corrected_type.lower() + 'Sum'
name = 'sum__' + self.column_name
metrics.append(DruidMetric(
metric_name=name,
metric_type='sum',
verbose_name='SUM({})'.format(self.column_name),
json=json.dumps({
'type': mt, 'name': name, 'fieldName': self.column_name})
))
if self.avg and self.isnum:
mt = corrected_type.lower() + 'Avg'
name = 'avg__' + self.column_name
metrics.append(DruidMetric(
metric_name=name,
metric_type='avg',
verbose_name='AVG({})'.format(self.column_name),
json=json.dumps({
'type': mt, 'name': name, 'fieldName': self.column_name})
))
if self.min and self.isnum:
mt = corrected_type.lower() + 'Min'
name = 'min__' + self.column_name
metrics.append(DruidMetric(
metric_name=name,
metric_type='min',
verbose_name='MIN({})'.format(self.column_name),
json=json.dumps({
'type': mt, 'name': name, 'fieldName': self.column_name})
))
if self.max and self.isnum:
mt = corrected_type.lower() + 'Max'
name = 'max__' + self.column_name
metrics.append(DruidMetric(
metric_name=name,
metric_type='max',
verbose_name='MAX({})'.format(self.column_name),
json=json.dumps({
'type': mt, 'name': name, 'fieldName': self.column_name})
))
if self.count_distinct:
name = 'count_distinct__' + self.column_name
if self.type == 'hyperUnique' or self.type == 'thetaSketch':
metrics.append(DruidMetric(
metric_name=name,
verbose_name='COUNT(DISTINCT {})'.format(self.column_name),
metric_type=self.type,
json=json.dumps({
'type': self.type,
'name': name,
'fieldName': self.column_name
})
))
else:
mt = 'count_distinct'
metrics.append(DruidMetric(
metric_name=name,
verbose_name='COUNT(DISTINCT {})'.format(self.column_name),
metric_type='count_distinct',
json=json.dumps({
'type': 'cardinality',
'name': name,
'fieldNames': [self.column_name]})
))
session = get_session()
new_metrics = []
for metric in metrics:
m = (
session.query(M)
.filter(M.metric_name == metric.metric_name)
.filter(M.datasource_name == self.datasource_name)
.filter(DruidCluster.cluster_name == self.datasource.cluster_name)
.first()
)
metric.datasource_name = self.datasource_name
if not m:
new_metrics.append(metric)
session.add(metric)
session.flush()
@classmethod
def import_obj(cls, i_column):
def lookup_obj(lookup_column):
return db.session.query(DruidColumn).filter(
DruidColumn.datasource_name == lookup_column.datasource_name,
DruidColumn.column_name == lookup_column.column_name).first()
return import_util.import_simple_obj(db.session, i_column, lookup_obj)
class DruidMetric(Model, AuditMixinNullable, ImportMixin):
"""ORM object referencing Druid metrics for a datasource"""
__tablename__ = 'metrics'
id = Column(Integer, primary_key=True)
metric_name = Column(String(512))
verbose_name = Column(String(1024))
metric_type = Column(String(32))
datasource_name = Column(
String(255),
ForeignKey('datasources.datasource_name'))
# Setting enable_typechecks=False disables polymorphic inheritance.
datasource = relationship(
'DruidDatasource',
backref=backref('metrics', cascade='all, delete-orphan'),
enable_typechecks=False)
json = Column(Text)
description = Column(Text)
is_restricted = Column(Boolean, default=False, nullable=True)
d3format = Column(String(128))
export_fields = (
'metric_name', 'verbose_name', 'metric_type', 'datasource_name',
'json', 'description', 'is_restricted', 'd3format'
)
@property
def json_obj(self):
try:
obj = json.loads(self.json)
except Exception:
obj = {}
return obj
@property
def perm(self):
return (
"{parent_name}.[{obj.metric_name}](id:{obj.id})"
).format(obj=self,
parent_name=self.datasource.full_name
) if self.datasource else None
@classmethod
def import_obj(cls, i_metric):
def lookup_obj(lookup_metric):
return db.session.query(DruidMetric).filter(
DruidMetric.datasource_name == lookup_metric.datasource_name,
DruidMetric.metric_name == lookup_metric.metric_name).first()
return import_util.import_simple_obj(db.session, i_metric, lookup_obj)
class DruidDatasource(Model, AuditMixinNullable, Queryable, ImportMixin):
"""ORM object referencing Druid datasources (tables)"""
type = "druid"
baselink = "druiddatasourcemodelview"
__tablename__ = 'datasources'
id = Column(Integer, primary_key=True)
datasource_name = Column(String(255), unique=True)
is_featured = Column(Boolean, default=False)
is_hidden = Column(Boolean, default=False)
filter_select_enabled = Column(Boolean, default=False)
description = Column(Text)
default_endpoint = Column(Text)
user_id = Column(Integer, ForeignKey('ab_user.id'))
owner = relationship(
'User',
backref=backref('datasources', cascade='all, delete-orphan'),
foreign_keys=[user_id])
cluster_name = Column(
String(250), ForeignKey('clusters.cluster_name'))
cluster = relationship(
'DruidCluster', backref='datasources', foreign_keys=[cluster_name])
offset = Column(Integer, default=0)
cache_timeout = Column(Integer)
params = Column(String(1000))
perm = Column(String(1000))
metric_cls = DruidMetric
column_cls = DruidColumn
export_fields = (
'datasource_name', 'is_hidden', 'description', 'default_endpoint',
'cluster_name', 'is_featured', 'offset', 'cache_timeout', 'params'
)
@property
def database(self):
return self.cluster
@property
def metrics_combo(self):
return sorted(
[(m.metric_name, m.verbose_name) for m in self.metrics],
key=lambda x: x[1])
@property
def num_cols(self):
return [c.column_name for c in self.columns if c.isnum]
@property
def name(self):
return self.datasource_name
@property
def schema(self):
name_pieces = self.datasource_name.split('.')
if len(name_pieces) > 1:
return name_pieces[0]
else:
return None
@property
def schema_perm(self):
"""Returns schema permission if present, cluster one otherwise."""
return utils.get_schema_perm(self.cluster, self.schema)
def get_perm(self):
return (
"[{obj.cluster_name}].[{obj.datasource_name}]"
"(id:{obj.id})").format(obj=self)
@property
def link(self):
name = escape(self.datasource_name)
return Markup('<a href="{self.url}">{name}</a>').format(**locals())
@property
def full_name(self):
return utils.get_datasource_full_name(
self.cluster_name, self.datasource_name)
@property
def time_column_grains(self):
return {
"time_columns": [
'all', '5 seconds', '30 seconds', '1 minute',
'5 minutes', '1 hour', '6 hour', '1 day', '7 days',
'week', 'week_starting_sunday', 'week_ending_saturday',
'month',
],
"time_grains": ['now']
}
def __repr__(self):
return self.datasource_name
@renders('datasource_name')
def datasource_link(self):
url = "/superset/explore/{obj.type}/{obj.id}/".format(obj=self)
name = escape(self.datasource_name)
return Markup('<a href="{url}">{name}</a>'.format(**locals()))
def get_metric_obj(self, metric_name):
return [
m.json_obj for m in self.metrics
if m.metric_name == metric_name
][0]
@classmethod
def import_obj(cls, i_datasource, import_time=None):
"""Imports the datasource from the object to the database.
Metrics and columns and datasource will be overrided if exists.
This function can be used to import/export dashboards between multiple
superset instances. Audit metadata isn't copies over.
"""
def lookup_datasource(d):
return db.session.query(DruidDatasource).join(DruidCluster).filter(
DruidDatasource.datasource_name == d.datasource_name,
DruidCluster.cluster_name == d.cluster_name,
).first()
def lookup_cluster(d):
return db.session.query(DruidCluster).filter_by(
cluster_name=d.cluster_name).one()
return import_util.import_datasource(
db.session, i_datasource, lookup_cluster, lookup_datasource,
import_time)
@staticmethod
def version_higher(v1, v2):
"""is v1 higher than v2
>>> DruidDatasource.version_higher('0.8.2', '0.9.1')
False
>>> DruidDatasource.version_higher('0.8.2', '0.6.1')
True
>>> DruidDatasource.version_higher('0.8.2', '0.8.2')
False
>>> DruidDatasource.version_higher('0.8.2', '0.9.BETA')
False
>>> DruidDatasource.version_higher('0.8.2', '0.9')
False
"""
def int_or_0(v):
try:
v = int(v)
except (TypeError, ValueError):
v = 0
return v
v1nums = [int_or_0(n) for n in v1.split('.')]
v2nums = [int_or_0(n) for n in v2.split('.')]
v1nums = (v1nums + [0, 0, 0])[:3]
v2nums = (v2nums + [0, 0, 0])[:3]
return v1nums[0] > v2nums[0] or \
(v1nums[0] == v2nums[0] and v1nums[1] > v2nums[1]) or \
(v1nums[0] == v2nums[0] and v1nums[1] == v2nums[1] and v1nums[2] > v2nums[2])
def latest_metadata(self):
"""Returns segment metadata from the latest segment"""
client = self.cluster.get_pydruid_client()
results = client.time_boundary(datasource=self.datasource_name)
if not results:
return
max_time = results[0]['result']['maxTime']
max_time = parse(max_time)
        # Query segmentMetadata for the past 7 days. However, we need to set
        # the interval's end to more than 1 day ago to exclude realtime
        # segments, which triggered a bug (fixed in druid 0.8.2).
        # https://groups.google.com/forum/#!topic/druid-user/gVCqqspHqOQ
lbound = (max_time - timedelta(days=7)).isoformat()
rbound = max_time.isoformat()
if not self.version_higher(self.cluster.druid_version, '0.8.2'):
rbound = (max_time - timedelta(1)).isoformat()
segment_metadata = None
try:
segment_metadata = client.segment_metadata(
datasource=self.datasource_name,
intervals=lbound + '/' + rbound,
merge=self.merge_flag,
analysisTypes=config.get('DRUID_ANALYSIS_TYPES'))
except Exception as e:
logging.warning("Failed first attempt to get latest segment")
logging.exception(e)
if not segment_metadata:
# if no segments in the past 7 days, look at all segments
lbound = datetime(1901, 1, 1).isoformat()[:10]
rbound = datetime(2050, 1, 1).isoformat()[:10]
if not self.version_higher(self.cluster.druid_version, '0.8.2'):
rbound = datetime.now().isoformat()[:10]
try:
segment_metadata = client.segment_metadata(
datasource=self.datasource_name,
intervals=lbound + '/' + rbound,
merge=self.merge_flag,
analysisTypes=config.get('DRUID_ANALYSIS_TYPES'))
except Exception as e:
logging.warning("Failed 2nd attempt to get latest segment")
logging.exception(e)
if segment_metadata:
return segment_metadata[-1]['columns']
def generate_metrics(self):
for col in self.columns:
col.generate_metrics()
@classmethod
def sync_to_db_from_config(cls, druid_config, user, cluster):
"""Merges the ds config from druid_config into one stored in the db."""
session = db.session()
datasource = (
session.query(DruidDatasource)
.filter_by(
datasource_name=druid_config['name'])
).first()
# Create a new datasource.
if not datasource:
datasource = DruidDatasource(
datasource_name=druid_config['name'],
cluster=cluster,
owner=user,
changed_by_fk=user.id,
created_by_fk=user.id,
)
session.add(datasource)
dimensions = druid_config['dimensions']
for dim in dimensions:
col_obj = (
session.query(DruidColumn)
.filter_by(
datasource_name=druid_config['name'],
column_name=dim)
).first()
if not col_obj:
col_obj = DruidColumn(
datasource_name=druid_config['name'],
column_name=dim,
groupby=True,
filterable=True,
# TODO: fetch type from Hive.
type="STRING",
datasource=datasource
)
session.add(col_obj)
# Import Druid metrics
for metric_spec in druid_config["metrics_spec"]:
metric_name = metric_spec["name"]
metric_type = metric_spec["type"]
metric_json = json.dumps(metric_spec)
if metric_type == "count":
metric_type = "longSum"
metric_json = json.dumps({
"type": "longSum",
"name": metric_name,
"fieldName": metric_name,
})
metric_obj = (
session.query(DruidMetric)
.filter_by(
datasource_name=druid_config['name'],
metric_name=metric_name)
).first()
if not metric_obj:
metric_obj = DruidMetric(
metric_name=metric_name,
metric_type=metric_type,
verbose_name="%s(%s)" % (metric_type, metric_name),
datasource=datasource,
json=metric_json,
description=(
"Imported from the airolap config dir for %s" %
druid_config['name']),
)
session.add(metric_obj)
session.commit()
@classmethod
def sync_to_db(cls, name, cluster, merge):
"""Fetches metadata for that datasource and merges the Superset db"""
logging.info("Syncing Druid datasource [{}]".format(name))
session = get_session()
datasource = session.query(cls).filter_by(datasource_name=name).first()
if not datasource:
datasource = cls(datasource_name=name)
session.add(datasource)
flasher("Adding new datasource [{}]".format(name), "success")
else:
flasher("Refreshing datasource [{}]".format(name), "info")
session.flush()
datasource.cluster = cluster
datasource.merge_flag = merge
session.flush()
cols = datasource.latest_metadata()
if not cols:
logging.error("Failed at fetching the latest segment")
return
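# Upsert a DruidColumn for every column reported by the latest segment metadata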
for col in cols:
col_obj = (
session
.query(DruidColumn)
.filter_by(datasource_name=name, column_name=col)
.first()
)
datatype = cols[col]['type']
if not col_obj:
col_obj = DruidColumn(datasource_name=name, column_name=col)
session.add(col_obj)
if datatype == "STRING":
col_obj.groupby = True
col_obj.filterable = True
if datatype == "hyperUnique" or datatype == "thetaSketch":
col_obj.count_distinct = True
if col_obj:
col_obj.type = cols[col]['type']
session.flush()
col_obj.datasource = datasource
col_obj.generate_metrics()
session.flush()
@staticmethod
def time_offset(granularity):
if granularity == 'week_ending_saturday':
return 6 * 24 * 3600 * 1000 # 6 days
return 0
# uses https://en.wikipedia.org/wiki/ISO_8601
# http://druid.io/docs/0.8.0/querying/granularities.html
# TODO: pass origin from the UI
@staticmethod
def granularity(period_name, timezone=None, origin=None):
if not period_name or period_name == 'all':
return 'all'
iso_8601_dict = {
'5 seconds': 'PT5S',
'30 seconds': 'PT30S',
'1 minute': 'PT1M',
'5 minutes': 'PT5M',
'1 hour': 'PT1H',
'6 hour': 'PT6H',
'one day': 'P1D',
'1 day': 'P1D',
'7 days': 'P7D',
'week': 'P1W',
'week_starting_sunday': 'P1W',
'week_ending_saturday': 'P1W',
'month': 'P1M',
}
granularity = {'type': 'period'}
if timezone:
granularity['timezone'] = timezone
if origin:
dttm = utils.parse_human_datetime(origin)
granularity['origin'] = dttm.isoformat()
if period_name in iso_8601_dict:
granularity['period'] = iso_8601_dict[period_name]
if period_name in ('week_ending_saturday', 'week_starting_sunday'):
# use Sunday as start of the week
granularity['origin'] = '2016-01-03T00:00:00'
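# a numeric period_name is treated as a Druid duration granularity, in milliseconds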
elif not isinstance(period_name, string_types):
granularity['type'] = 'duration'
granularity['duration'] = period_name
elif period_name.startswith('P'):
# the string is already an ISO 8601 period
granularity['period'] = period_name
else:
granularity['type'] = 'duration'
granularity['duration'] = utils.parse_human_timedelta(
period_name).total_seconds() * 1000
return granularity
def values_for_column(self,
column_name,
from_dttm,
to_dttm,
limit=500):
"""Retrieve some values for the given column"""
# TODO: use a lexicographic TopNMetricSpec once supported by PyDruid
from_dttm = from_dttm.replace(tzinfo=config.get("DRUID_TZ"))
to_dttm = to_dttm.replace(tzinfo=config.get("DRUID_TZ"))
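# topN query over the interval, grouped on the column and ordered by count,
# to sample the most frequent values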
qry = dict(
datasource=self.datasource_name,
granularity="all",
intervals=from_dttm.isoformat() + '/' + to_dttm.isoformat(),
aggregations=dict(count=count("count")),
dimension=column_name,
metric="count",
threshold=limit,
)
client = self.cluster.get_pydruid_client()
client.topn(**qry)
df = client.export_pandas()
if df is None or df.size == 0:
raise Exception(_("No data was returned."))
return df
def query( # druid
self, groupby, metrics,
granularity,
from_dttm, to_dttm,
filter=None, # noqa
is_timeseries=True,
timeseries_limit=None,
timeseries_limit_metric=None,
row_limit=None,
inner_from_dttm=None, inner_to_dttm=None,
orderby=None,
extras=None, # noqa
select=None, # noqa
columns=None, ):
"""Runs a query against Druid and returns a dataframe.
This query interface is common to SqlAlchemy and Druid
"""
# TODO refactor into using a TBD Query object
qry_start_dttm = datetime.now()
if not is_timeseries:
granularity = 'all'
inner_from_dttm = inner_from_dttm or from_dttm
inner_to_dttm = inner_to_dttm or to_dttm
# add tzinfo to native datetime with config
from_dttm = from_dttm.replace(tzinfo=config.get("DRUID_TZ"))
to_dttm = to_dttm.replace(tzinfo=config.get("DRUID_TZ"))
timezone = from_dttm.tzname()
query_str = ""
metrics_dict = {m.metric_name: m for m in self.metrics}
all_metrics = []
post_aggs = {}
columns_dict = {c.column_name: c for c in self.columns}
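# Collect the field names a post-aggregator depends on, recursing into
# nested arithmetic post-aggregations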
def recursive_get_fields(_conf):
_fields = _conf.get('fields', [])
field_names = []
for _f in _fields:
_type = _f.get('type')
if _type in ['fieldAccess', 'hyperUniqueCardinality']:
field_names.append(_f.get('fieldName'))
elif _type == 'arithmetic':
field_names += recursive_get_fields(_f)
return list(set(field_names))
for metric_name in metrics:
metric = metrics_dict[metric_name]
if metric.metric_type != 'postagg':
all_metrics.append(metric_name)
else:
conf = metric.json_obj
all_metrics += recursive_get_fields(conf)
all_metrics += conf.get('fieldNames', [])
if conf.get('type') == 'javascript':
post_aggs[metric_name] = JavascriptPostAggregator(
name=conf.get('name'),
field_names=conf.get('fieldNames'),
function=conf.get('function'))
else:
post_aggs[metric_name] = Postaggregator(
conf.get('fn', "/"),
conf.get('fields', []),
conf.get('name', ''))
aggregations = OrderedDict()
for m in self.metrics:
if m.metric_name in all_metrics:
aggregations[m.metric_name] = m.json_obj
rejected_metrics = [
m.metric_name for m in self.metrics
if m.is_restricted and
m.metric_name in aggregations.keys() and
not sm.has_access('metric_access', m.perm)
]
if rejected_metrics:
raise MetricPermException(
"Access to the metrics denied: " + ', '.join(rejected_metrics)
)
# the dimensions list with dimensionSpecs expanded
dimensions = []
groupby = [gb for gb in groupby if gb in columns_dict]
for column_name in groupby:
col = columns_dict.get(column_name)
dim_spec = col.dimension_spec
if dim_spec:
dimensions.append(dim_spec)
else:
dimensions.append(column_name)
qry = dict(
datasource=self.datasource_name,
dimensions=dimensions,
aggregations=aggregations,
granularity=DruidDatasource.granularity(
granularity,
timezone=timezone,
origin=extras.get('druid_time_origin'),
),
post_aggregations=post_aggs,
intervals=from_dttm.isoformat() + '/' + to_dttm.isoformat(),
)
filters = self.get_filters(filter)
if filters:
qry['filter'] = filters
having_filters = self.get_having_filters(extras.get('having_druid'))
if having_filters:
qry['having'] = having_filters
client = self.cluster.get_pydruid_client()
orig_filters = filters
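# Pick the Druid query type from the number of dimensions: timeseries for
# none, topN for one, groupBy otherwise (or when a having filter is used)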
if len(groupby) == 0:
del qry['dimensions']
client.timeseries(**qry)
if len(groupby) == 1:
qry['threshold'] = timeseries_limit or 1000
if row_limit and granularity == 'all':
qry['threshold'] = row_limit
qry['dimension'] = list(qry.get('dimensions'))[0]
del qry['dimensions']
qry['metric'] = list(qry['aggregations'].keys())[0]
client.topn(**qry)
elif len(groupby) > 1 or having_filters:
# If grouping on multiple fields or using a having filter
# we have to force a groupby query
if timeseries_limit and is_timeseries:
order_by = metrics[0] if metrics else self.metrics[0].metric_name
if timeseries_limit_metric:
order_by = timeseries_limit_metric
# Limit the number of time series by issuing a two-phase query
pre_qry = deepcopy(qry)
pre_qry['granularity'] = "all"
pre_qry['limit_spec'] = {
"type": "default",
"limit": timeseries_limit,
'intervals': (
inner_from_dttm.isoformat() + '/' +
inner_to_dttm.isoformat()),
"columns": [{
"dimension": order_by,
"direction": "descending",
}],
}
client.groupby(**pre_qry)
query_str += "// Two phase query\n// Phase 1\n"
query_str += json.dumps(
client.query_builder.last_query.query_dict, indent=2)
query_str += "\n"
query_str += (
"//\nPhase 2 (built based on phase one's results)\n")
df = client.export_pandas()
if df is not None and not df.empty:
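# Build an OR filter from the phase-one results so phase two only scans
# the top series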
dims = qry['dimensions']
filters = []
for unused, row in df.iterrows():
fields = []
for dim in dims:
f = Dimension(dim) == row[dim]
fields.append(f)
if len(fields) > 1:
filt = Filter(type="and", fields=fields)
filters.append(filt)
elif fields:
filters.append(fields[0])
if filters:
ff = Filter(type="or", fields=filters)
if not orig_filters:
qry['filter'] = ff
else:
qry['filter'] = Filter(type="and", fields=[
ff,
orig_filters])
qry['limit_spec'] = None
if row_limit:
qry['limit_spec'] = {
"type": "default",
"limit": row_limit,
"columns": [{
"dimension": (
metrics[0] if metrics else self.metrics[0].metric_name),
"direction": "descending",
}],
}
client.groupby(**qry)
query_str += json.dumps(
client.query_builder.last_query.query_dict, indent=2)
df = client.export_pandas()
if df is None or df.size == 0:
raise Exception(_("No data was returned."))
df.columns = [
DTTM_ALIAS if c == 'timestamp' else c for c in df.columns]
if (
not is_timeseries and
granularity == "all" and
DTTM_ALIAS in df.columns):
del df[DTTM_ALIAS]
# Reordering columns
cols = []
if DTTM_ALIAS in df.columns:
cols += [DTTM_ALIAS]
cols += [col for col in groupby if col in df.columns]
cols += [col for col in metrics if col in df.columns]
df = df[cols]
time_offset = DruidDatasource.time_offset(granularity)
def increment_timestamp(ts):
dt = utils.parse_human_datetime(ts).replace(
tzinfo=config.get("DRUID_TZ"))
return dt + timedelta(milliseconds=time_offset)
if DTTM_ALIAS in df.columns and time_offset:
df[DTTM_ALIAS] = df[DTTM_ALIAS].apply(increment_timestamp)
return QueryResult(
df=df,
query=query_str,
duration=datetime.now() - qry_start_dttm)
@staticmethod
def get_filters(raw_filters):
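# Translate the (col, op, value) filter tuples from the UI into a pydruid
# Filter tree, AND-ing the individual conditions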
filters = None
for col, op, eq in raw_filters:
cond = None
if op == '==':
cond = Dimension(col) == eq
elif op == '!=':
cond = ~(Dimension(col) == eq)
elif op in ('in', 'not in'):
fields = []
# Distinguish quoted values from regular value types
splitted = FillterPattern.split(eq)[1::2]
values = [types.replace("'", '') for types in splitted]
if len(values) > 1:
for s in values:
s = s.strip()
fields.append(Dimension(col) == s)
cond = Filter(type="or", fields=fields)
else:
cond = Dimension(col) == eq
if op == 'not in':
cond = ~cond
elif op == 'regex':
cond = Filter(type="regex", pattern=eq, dimension=col)
if filters:
filters = Filter(type="and", fields=[
cond,
filters
])
else:
filters = cond
return filters
def _get_having_obj(self, col, op, eq):
cond = None
if op == '==':
if col in self.column_names:
cond = DimSelector(dimension=col, value=eq)
else:
cond = Aggregation(col) == eq
elif op == '>':
cond = Aggregation(col) > eq
elif op == '<':
cond = Aggregation(col) < eq
return cond
def get_having_filters(self, raw_filters):
filters = None
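# Operators without a direct having clause are expressed by negating the
# complementary comparison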
reversed_op_map = {
'!=': '==',
'>=': '<',
'<=': '>'
}
for col, op, eq in raw_filters:
cond = None
if op in ['==', '>', '<']:
cond = self._get_having_obj(col, op, eq)
elif op in reversed_op_map:
cond = ~self._get_having_obj(col, reversed_op_map[op], eq)
if filters:
filters = filters & cond
else:
filters = cond
return filters
sqla.event.listen(DruidDatasource, 'after_insert', set_perm)
sqla.event.listen(DruidDatasource, 'after_update', set_perm)
class Log(Model):
"""ORM object used to log Superset actions to the database"""
__tablename__ = 'logs'
id = Column(Integer, primary_key=True)
action = Column(String(512))
user_id = Column(Integer, ForeignKey('ab_user.id'))
dashboard_id = Column(Integer)
slice_id = Column(Integer)
json = Column(Text)
user = relationship('User', backref='logs', foreign_keys=[user_id])
dttm = Column(DateTime, default=datetime.utcnow)
dt = Column(Date, default=date.today)  # pass the callable so the date is evaluated per insert
duration_ms = Column(Integer)
referrer = Column(String(1024))
@classmethod
def log_this(cls, f):
"""Decorator to log user actions"""
@functools.wraps(f)
def wrapper(*args, **kwargs):
start_dttm = datetime.now()
user_id = None
if g.user:
user_id = g.user.get_id()
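# Merge query-string args, form data and view kwargs into the logged payload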
d = request.args.to_dict()
post_data = request.form or {}
d.update(post_data)
d.update(kwargs)
slice_id = d.get('slice_id', 0)
try:
slice_id = int(slice_id) if slice_id else 0
except ValueError:
slice_id = 0
params = ""
try:
params = json.dumps(d)
except Exception:
pass
value = f(*args, **kwargs)
sesh = db.session()
log = cls(
action=f.__name__,
json=params,
dashboard_id=d.get('dashboard_id') or None,
slice_id=slice_id,
duration_ms=(
datetime.now() - start_dttm).total_seconds() * 1000,
referrer=request.referrer[:1000] if request.referrer else None,
user_id=user_id)
sesh.add(log)
sesh.commit()
return value
return wrapper
class FavStar(Model):
__tablename__ = 'favstar'
id = Column(Integer, primary_key=True)
user_id = Column(Integer, ForeignKey('ab_user.id'))
class_name = Column(String(50))
obj_id = Column(Integer)
dttm = Column(DateTime, default=datetime.utcnow)
class Query(Model):
"""ORM model for SQL query"""
__tablename__ = 'query'
id = Column(Integer, primary_key=True)
client_id = Column(String(11), unique=True, nullable=False)
database_id = Column(Integer, ForeignKey('dbs.id'), nullable=False)
# Store the tmp table into the DB only if the user asks for it.
tmp_table_name = Column(String(256))
user_id = Column(
Integer, ForeignKey('ab_user.id'), nullable=True)
status = Column(String(16), default=QueryStatus.PENDING)
tab_name = Column(String(256))
sql_editor_id = Column(String(256))
schema = Column(String(256))
sql = Column(Text)
# Query to retrieve the results,
# used only when select_as_cta_used is true.
select_sql = Column(Text)
executed_sql = Column(Text)
# Could be configured in the superset config.
limit = Column(Integer)
limit_used = Column(Boolean, default=False)
# limit_reached is computed by the property defined below, so it is not
# stored as a column
select_as_cta = Column(Boolean)
select_as_cta_used = Column(Boolean, default=False)
progress = Column(Integer, default=0)  # 0..100
# # of rows in the result set or rows modified.
rows = Column(Integer)
error_message = Column(Text)
# key used to store the results in the results backend
results_key = Column(String(64))
# Using Numeric in place of DateTime for sub-second precision
# stored as seconds since epoch, allowing for milliseconds
start_time = Column(Numeric(precision=3))
end_time = Column(Numeric(precision=3))
changed_on = Column(
DateTime, default=datetime.utcnow, onupdate=datetime.utcnow, nullable=True)
database = relationship(
'Database',
foreign_keys=[database_id],
backref=backref('queries', cascade='all, delete-orphan')
)
user = relationship(
'User',
backref=backref('queries', cascade='all, delete-orphan'),
foreign_keys=[user_id])
__table_args__ = (
sqla.Index('ti_user_id_changed_on', user_id, changed_on),
)
@property
def limit_reached(self):
return self.rows == self.limit if self.limit_used else False
def to_dict(self):
return {
'changedOn': self.changed_on,
'changed_on': self.changed_on.isoformat(),
'dbId': self.database_id,
'db': self.database.database_name,
'endDttm': self.end_time,
'errorMessage': self.error_message,
'executedSql': self.executed_sql,
'id': self.client_id,
'limit': self.limit,
'progress': self.progress,
'rows': self.rows,
'schema': self.schema,
'ctas': self.select_as_cta,
'serverId': self.id,
'sql': self.sql,
'sqlEditorId': self.sql_editor_id,
'startDttm': self.start_time,
'state': self.status.lower(),
'tab': self.tab_name,
'tempTable': self.tmp_table_name,
'userId': self.user_id,
'user': self.user.username,
'limit_reached': self.limit_reached,
'resultsKey': self.results_key,
}
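# Name used for the CTAS temporary table, e.g. sqllab_<tab>_<timestamp>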
@property
def name(self):
ts = datetime.now().isoformat()
ts = ts.replace('-', '').replace(':', '').split('.')[0]
tab = self.tab_name.replace(' ', '_').lower() if self.tab_name else 'notab'
tab = re.sub(r'\W+', '', tab)
return "sqllab_{tab}_{ts}".format(**locals())
class DatasourceAccessRequest(Model, AuditMixinNullable):
"""ORM model for the access requests for datasources and dbs."""
__tablename__ = 'access_request'
id = Column(Integer, primary_key=True)
datasource_id = Column(Integer)
datasource_type = Column(String(200))
ROLES_BLACKLIST = set(config.get('ROBOT_PERMISSION_ROLES', []))
@property
def cls_model(self):
return SourceRegistry.sources[self.datasource_type]
@property
def username(self):
return self.creator()
@property
def datasource(self):
return self.get_datasource
@datasource.getter
@utils.memoized
def get_datasource(self):
ds = db.session.query(self.cls_model).filter_by(
id=self.datasource_id).first()
return ds
@property
def datasource_link(self):
return self.datasource.link
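# HTML list of 'Grant <role>' links for the roles that already have access
# to the datasource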
@property
def roles_with_datasource(self):
action_list = ''
pv = sm.find_permission_view_menu(
'datasource_access', self.datasource.perm)
for r in pv.role:
if r.name in self.ROLES_BLACKLIST:
continue
url = (
'/superset/approve?datasource_type={self.datasource_type}&'
'datasource_id={self.datasource_id}&'
'created_by={self.created_by.username}&role_to_grant={r.name}'
.format(**locals())
)
href = '<a href="{}">Grant {} Role</a>'.format(url, r.name)
action_list = action_list + '<li>' + href + '</li>'
return '<ul>' + action_list + '</ul>'
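# HTML list of 'Extend <role>' links for the roles of the requesting user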
@property
def user_roles(self):
action_list = ''
for r in self.created_by.roles:
url = (
'/superset/approve?datasource_type={self.datasource_type}&'
'datasource_id={self.datasource_id}&'
'created_by={self.created_by.username}&role_to_extend={r.name}'
.format(**locals())
)
href = '<a href="{}">Extend {} Role</a>'.format(url, r.name)
if r.name in self.ROLES_BLACKLIST:
href = "{} Role".format(r.name)
action_list = action_list + '<li>' + href + '</li>'
return '<ul>' + action_list + '</ul>'