blob: 1324a52be473116b6338f1ae22a64e1ddcd0889d [file] [log] [blame]
"""Defines the templating context for SQL Lab"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from __future__ import unicode_literals
import inspect
from jinja2.sandbox import SandboxedEnvironment
from datetime import datetime, timedelta
from dateutil.relativedelta import relativedelta
import time
import textwrap
import uuid
import random
from superset import app
from superset.utils import SupersetTemplateException
config = app.config
BASE_CONTEXT = {
'datetime': datetime,
'random': random,
'relativedelta': relativedelta,
'time': time,
'timedelta': timedelta,
'uuid': uuid,
}
BASE_CONTEXT.update(config.get('JINJA_CONTEXT_ADDONS', {}))
class BaseTemplateProcessor(object):
"""Base class for database-specific jinja context
There's this bit of magic in ``process_template`` that instantiates only
the database context for the active database as a ``models.Database``
object binds it to the context object, so that object methods
have access to
that context. This way, {{ hive.latest_partition('mytable') }} just
knows about the database it is operating in.
This means that object methods are only available for the active database
and are given access to the ``models.Database`` object and schema
name. For globally available methods use ``@classmethod``.
"""
engine = None
def __init__(self, database=None, query=None, table=None):
self.database = database
self.query = query
self.schema = None
if query and query.schema:
self.schema = query.schema
elif table:
self.schema = table.schema
self.context = {}
self.context.update(BASE_CONTEXT)
if self.engine:
self.context[self.engine] = self
self.env = SandboxedEnvironment()
def process_template(self, sql):
"""Processes a sql template
>>> sql = "SELECT '{{ datetime(2017, 1, 1).isoformat() }}'"
>>> process_template(sql)
"SELECT '2017-01-01T00:00:00'"
"""
template = self.env.from_string(sql)
return template.render(self.context)
class PrestoTemplateProcessor(BaseTemplateProcessor):
"""Presto Jinja context
The methods described here are namespaced under ``presto`` in the
jinja context as in ``SELECT '{{ presto.some_macro_call() }}'``
"""
engine = 'presto'
@staticmethod
def _partition_query(table_name, limit=0, order_by=None, filters=None):
"""Returns a partition query
:param table_name: the name of the table to get partitions from
:type table_name: str
:param limit: the number of partitions to be returned
:type limit: int
:param order_by: a list of tuples of field name and a boolean
that determines if that field should be sorted in descending
order
:type order_by: list of (str, bool) tuples
:param filters: a list of filters to apply
:param filters: dict of field anme and filter value combinations
"""
limit_clause = "LIMIT {}".format(limit) if limit else ''
order_by_clause = ''
if order_by:
l = []
for field, desc in order_by:
l.append(field + ' DESC' if desc else '')
order_by_clause = 'ORDER BY ' + ', '.join(l)
where_clause = ''
if filters:
l = []
for field, value in filters.items():
l.append("{field} = '{value}'".format(**locals()))
where_clause = 'WHERE ' + ' AND '.join(l)
sql = textwrap.dedent("""\
SHOW PARTITIONS FROM {table_name}
{where_clause}
{order_by_clause}
{limit_clause}
""").format(**locals())
return sql
@staticmethod
def _schema_table(table_name, schema):
if '.' in table_name:
schema, table_name = table_name.split('.')
return table_name, schema
def latest_partition(self, table_name):
"""Returns the latest (max) partition value for a table
:param table_name: the name of the table, can be just the table
name or a fully qualified table name as ``schema_name.table_name``
:type table_name: str
>>> latest_partition('foo_table')
'2018-01-01'
"""
table_name, schema = self._schema_table(table_name, self.schema)
indexes = self.database.get_indexes(table_name, schema)
if len(indexes[0]['column_names']) < 1:
raise SupersetTemplateException(
"The table should have one partitioned field")
elif len(indexes[0]['column_names']) > 1:
raise SupersetTemplateException(
"The table should have a single partitioned field "
"to use this function. You may want to use "
"`presto.latest_sub_partition`")
part_field = indexes[0]['column_names'][0]
sql = self._partition_query(table_name, 1, [(part_field, True)])
df = self.database.get_df(sql, schema)
return df.to_records(index=False)[0][0]
def latest_sub_partition(self, table_name, **kwargs):
"""Returns the latest (max) partition value for a table
A filtering criteria should be passed for all fields that are
partitioned except for the field to be returned. For example,
if a table is partitioned by (``ds``, ``event_type`` and
``event_category``) and you want the latest ``ds``, you'll want
to provide a filter as keyword arguments for both
``event_type`` and ``event_category`` as in
``latest_sub_partition('my_table',
event_category='page', event_type='click')``
:param table_name: the name of the table, can be just the table
name or a fully qualified table name as ``schema_name.table_name``
:type table_name: str
:param kwargs: keyword arguments define the filtering criteria
on the partition list. There can be many of these.
:type kwargs: str
>>> latest_sub_partition('sub_partition_table', event_type='click')
'2018-01-01'
"""
table_name, schema = self._schema_table(table_name, self.schema)
indexes = self.database.get_indexes(table_name, schema)
part_fields = indexes[0]['column_names']
for k in kwargs.keys():
if k not in k in part_fields:
msg = "Field [{k}] is not part of the partionning key"
raise SupersetTemplateException(msg)
if len(kwargs.keys()) != len(part_fields) - 1:
msg = (
"A filter needs to be specified for {} out of the "
"{} fields."
).format(len(part_fields)-1, len(part_fields))
raise SupersetTemplateException(msg)
for field in part_fields:
if field not in kwargs.keys():
field_to_return = field
sql = self._partition_query(
table_name, 1, [(field_to_return, True)], kwargs)
df = self.database.get_df(sql, schema)
if df.empty:
return ''
return df.to_dict()[field_to_return][0]
template_processors = {}
keys = tuple(globals().keys())
for k in keys:
o = globals()[k]
if o and inspect.isclass(o) and issubclass(o, BaseTemplateProcessor):
template_processors[o.engine] = o
def get_template_processor(database, table=None, query=None):
TP = template_processors.get(database.backend, BaseTemplateProcessor)
return TP(database=database, table=table, query=query)