Implement table name extraction. (#1598)
* Implement table name extraction tests.
* Address comments.
* Fix tests and reimplement the token processing.
* Exclude aliases.
* Clean up print statements and code.
* Reverse select test.
* Fix failing test.
* Test JOINs.
* Refactor as a class.
* Check for permissions in SQL Lab.
* Implement permissions check for the datasources in sql_lab
* Address comments.
diff --git a/superset/models.py b/superset/models.py
index a8e3fdf..bf2e689 100644
--- a/superset/models.py
+++ b/superset/models.py
@@ -665,6 +665,7 @@
"""An ORM object that stores Database related information"""
__tablename__ = 'dbs'
+ type = "table"
id = Column(Integer, primary_key=True)
database_name = Column(String(250), unique=True)
@@ -1524,6 +1525,7 @@
"""ORM object referencing the Druid clusters"""
__tablename__ = 'clusters'
+ type = "druid"
id = Column(Integer, primary_key=True)
cluster_name = Column(String(250), unique=True)
diff --git a/superset/source_registry.py b/superset/source_registry.py
index 0705460..2c72157 100644
--- a/superset/source_registry.py
+++ b/superset/source_registry.py
@@ -41,6 +41,27 @@
return db_ds[0]
@classmethod
+ def query_datasources_by_name(
+ cls, session, database, datasource_name, schema=None):
+ datasource_class = SourceRegistry.sources[database.type]
+ if database.type == 'table':
+ query = (
+ session.query(datasource_class)
+ .filter_by(database_id=database.id)
+ .filter_by(table_name=datasource_name))
+ if schema:
+ query = query.filter_by(schema=schema)
+ return query.all()
+ if database.type == 'druid':
+ return (
+ session.query(datasource_class)
+ .filter_by(cluster_name=database.id)
+ .filter_by(datasource_name=datasource_name)
+ .all()
+ )
+ return None
+
+ @classmethod
def get_eager_datasource(cls, session, datasource_type, datasource_id):
"""Returns datasource with columns and metrics."""
datasource_class = SourceRegistry.sources[datasource_type]
diff --git a/superset/sql_lab.py b/superset/sql_lab.py
index 5ee6546..0e5b84c 100644
--- a/superset/sql_lab.py
+++ b/superset/sql_lab.py
@@ -11,7 +11,7 @@
from sqlalchemy.orm import sessionmaker
from superset import (
- app, db, models, utils, dataframe, results_backend)
+ app, db, models, utils, dataframe, results_backend, sql_parse, sm)
from superset.db_engine_specs import LimitMethod
from superset.jinja_context import get_template_processor
QueryStatus = models.QueryStatus
@@ -19,16 +19,12 @@
celery_app = celery.Celery(config_source=app.config.get('CELERY_CONFIG'))
-def is_query_select(sql):
- return sql.upper().startswith('SELECT')
-
-
def create_table_as(sql, table_name, schema=None, override=False):
"""Reformats the query into the create table as query.
Works only for the single select SQL statements, in all other cases
the sql query is not modified.
- :param sql: string, sql query that will be executed
+ :param sql: string, sql query that will be executed
:param table_name: string, will contain the results of the query execution
:param override, boolean, table table_name will be dropped if true
:return: string, create table as query
@@ -41,12 +37,9 @@
if schema:
table_name = schema + '.' + table_name
exec_sql = ''
- if is_query_select(sql):
- if override:
- exec_sql = 'DROP TABLE IF EXISTS {table_name};\n'
- exec_sql += "CREATE TABLE {table_name} AS \n{sql}"
- else:
- raise Exception("Could not generate CREATE TABLE statement")
+ if override:
+ exec_sql = 'DROP TABLE IF EXISTS {table_name};\n'
+ exec_sql += "CREATE TABLE {table_name} AS \n{sql}"
return exec_sql.format(**locals())
@@ -76,12 +69,12 @@
raise Exception(query.error_message)
# Limit enforced only for retrieving the data, not for the CTA queries.
- is_select = is_query_select(executed_sql);
- if not is_select and not database.allow_dml:
+ superset_query = sql_parse.SupersetQuery(executed_sql)
+ if not superset_query.is_select() and not database.allow_dml:
handle_error(
"Only `SELECT` statements are allowed against this database")
if query.select_as_cta:
- if not is_select:
+ if not superset_query.is_select():
handle_error(
"Only `SELECT` statements can be used with the CREATE TABLE "
"feature.")
@@ -94,7 +87,7 @@
executed_sql, query.tmp_table_name, database.force_ctas_schema)
query.select_as_cta_used = True
elif (
- query.limit and is_select and
+ query.limit and superset_query.is_select() and
db_engine_spec.limit_method == LimitMethod.WRAP_SQL):
executed_sql = database.wrap_sql_limit(executed_sql, query.limit)
query.limit_used = True
diff --git a/superset/sql_parse.py b/superset/sql_parse.py
new file mode 100644
index 0000000..8f2c6e0
--- /dev/null
+++ b/superset/sql_parse.py
@@ -0,0 +1,101 @@
+import sqlparse
+from sqlparse.sql import IdentifierList, Identifier
+from sqlparse.tokens import Keyword, Name
+
+RESULT_OPERATIONS = {'UNION', 'INTERSECT', 'EXCEPT'}
+PRECEDES_TABLE_NAME = {'FROM', 'JOIN', 'DESC', 'DESCRIBE', 'WITH'}
+
+
+# TODO: move some sql_lab logic here.
+class SupersetQuery(object):
+ def __init__(self, sql_statement):
+ self._tokens = []
+ self.sql = sql_statement
+ self._table_names = set()
+ self._alias_names = set()
+ # TODO: multistatement support
+ for statement in sqlparse.parse(self.sql):
+ self.__extract_from_token(statement)
+ self._table_names = self._table_names - self._alias_names
+
+ @property
+ def tables(self):
+ return self._table_names
+
+ # TODO: use sqlparse for this check.
+ def is_select(self):
+ return self.sql.upper().startswith('SELECT')
+
+ @staticmethod
+ def __precedes_table_name(token_value):
+ for keyword in PRECEDES_TABLE_NAME:
+ if keyword in token_value:
+ return True
+ return False
+
+ @staticmethod
+ def __get_full_name(identifier):
+ if len(identifier.tokens) > 1 and identifier.tokens[1].value == '.':
+ return "{}.{}".format(identifier.tokens[0].value,
+ identifier.tokens[2].value)
+ return identifier.get_real_name()
+
+ @staticmethod
+ def __is_result_operation(keyword):
+ for operation in RESULT_OPERATIONS:
+ if operation in keyword.upper():
+ return True
+ return False
+
+ @staticmethod
+ def __is_identifier(token):
+ return (
+ isinstance(token, IdentifierList) or isinstance(token, Identifier))
+
+ def __process_identifier(self, identifier):
+ # exclude subselects
+ if '(' not in '{}'.format(identifier):
+ self._table_names.add(SupersetQuery.__get_full_name(identifier))
+ return
+
+ # store aliases
+ if hasattr(identifier, 'get_alias'):
+ self._alias_names.add(identifier.get_alias())
+ if hasattr(identifier, 'tokens'):
+ # some aliases are not parsed properly
+ if identifier.tokens[0].ttype == Name:
+ self._alias_names.add(identifier.tokens[0].value)
+ self.__extract_from_token(identifier)
+
+ def __extract_from_token(self, token):
+ if not hasattr(token, 'tokens'):
+ return
+
+ table_name_preceding_token = False
+
+ for item in token.tokens:
+ if item.is_group and not self.__is_identifier(item):
+ self.__extract_from_token(item)
+
+ if item.ttype in Keyword:
+ if SupersetQuery.__precedes_table_name(item.value.upper()):
+ table_name_preceding_token = True
+ continue
+
+ if not table_name_preceding_token:
+ continue
+
+ if item.ttype in Keyword:
+ if SupersetQuery.__is_result_operation(item.value):
+ table_name_preceding_token = False
+ continue
+ # FROM clause is over
+ break
+
+ if isinstance(item, Identifier):
+ self.__process_identifier(item)
+
+ if isinstance(item, IdentifierList):
+ for token in item.tokens:
+ if SupersetQuery.__is_identifier(token):
+ self.__process_identifier(token)
diff --git a/superset/views.py b/superset/views.py
index f0458e5..8768234 100755
--- a/superset/views.py
+++ b/superset/views.py
@@ -36,7 +36,7 @@
import superset
from superset import (
appbuilder, cache, db, models, viz, utils, app,
- sm, sql_lab, results_backend, security,
+ sm, sql_lab, sql_parse, results_backend, security,
)
from superset.source_registry import SourceRegistry
from superset.models import DatasourceAccessRequest as DAR
@@ -74,6 +74,18 @@
self.can_access("datasource_access", datasource.perm)
)
+ def datasource_access_by_name(
+ self, database, datasource_name, schema=None):
+ if (self.database_access(database) or
+ self.all_datasource_access()):
+ return True
+ datasources = SourceRegistry.query_datasources_by_name(
+ db.session, database, datasource_name, schema=schema)
+ for datasource in datasources:
+ if self.can_access("datasource_access", datasource.perm):
+ return True
+ return False
+
class ListWidgetWithCheckboxes(ListWidget):
"""An alternative to list view that renders Boolean fields as checkboxes
@@ -2303,27 +2315,45 @@
@log_this
def sql_json(self):
"""Runs arbitrary sql and returns and json"""
+ def table_accessible(database, full_table_name, schema_name=None):
+ table_name_pieces = full_table_name.split(".")
+ if len(table_name_pieces) == 2:
+ table_schema = table_name_pieces[0]
+ table_name = table_name_pieces[1]
+ else:
+ table_schema = schema_name
+ table_name = table_name_pieces[0]
+ return self.datasource_access_by_name(
+ database, table_name, schema=table_schema)
+
async = request.form.get('runAsync') == 'true'
sql = request.form.get('sql')
database_id = request.form.get('database_id')
session = db.session()
- mydb = session.query(models.Database).filter_by(id=database_id).first()
+ mydb = session.query(models.Database).filter_by(id=database_id).one()
if not mydb:
json_error_response(
'Database with id {} is missing.'.format(database_id))
- if not self.database_access(mydb):
+ superset_query = sql_parse.SupersetQuery(sql)
+ schema = request.form.get('schema')
+ schema = schema if schema else None
+
+ rejected_tables = [
+ t for t in superset_query.tables if not
+ table_accessible(mydb, t, schema_name=schema)]
+ if rejected_tables:
json_error_response(
- get_database_access_error_msg(mydb.database_name))
+ get_datasource_access_error_msg('{}'.format(rejected_tables)))
session.commit()
query = models.Query(
database_id=int(database_id),
limit=int(app.config.get('SQL_MAX_ROW', None)),
sql=sql,
- schema=request.form.get('schema'),
+ schema=schema,
select_as_cta=request.form.get('select_as_cta') == 'true',
start_time=utils.now_as_float(),
tab_name=request.form.get('tab'),
@@ -2341,7 +2371,8 @@
if async:
# Ignore the celery future object and the request may time out.
sql_lab.get_sql_results.delay(
- query_id, return_results=False, store_results=not query.select_as_cta)
+ query_id, return_results=False,
+ store_results=not query.select_as_cta)
return Response(
json.dumps({'query': query.to_dict()},
default=utils.json_int_dttm_ser,
diff --git a/tests/sql_parse_tests.py b/tests/sql_parse_tests.py
new file mode 100644
index 0000000..284e168
--- /dev/null
+++ b/tests/sql_parse_tests.py
@@ -0,0 +1,295 @@
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+from __future__ import unicode_literals
+
+import unittest
+
+from superset import sql_parse
+
+
+class SupersetTestCase(unittest.TestCase):
+
+ def extract_tables(self, query):
+ sq = sql_parse.SupersetQuery(query)
+ return sq.tables
+
+ def test_simple_select(self):
+ query = "SELECT * FROM tbname"
+ self.assertEquals({"tbname"}, self.extract_tables(query))
+
+ # underscores
+ query = "SELECT * FROM tb_name"
+ self.assertEquals({"tb_name"},
+ self.extract_tables(query))
+
+ # quotes
+ query = 'SELECT * FROM "tbname"'
+ self.assertEquals({"tbname"}, self.extract_tables(query))
+
+ # schema
+ self.assertEquals(
+ {"schemaname.tbname"},
+ self.extract_tables("SELECT * FROM schemaname.tbname"))
+
+ # explicit field list
+ query = "SELECT field1, field2 FROM tb_name"
+ self.assertEquals({"tb_name"}, self.extract_tables(query))
+
+ query = "SELECT t1.f1, t2.f2 FROM t1, t2"
+ self.assertEquals({"t1", "t2"}, self.extract_tables(query))
+
+ def test_select_named_table(self):
+ query = "SELECT a.date, a.field FROM left_table a LIMIT 10"
+ self.assertEquals(
+ {"left_table"}, self.extract_tables(query))
+
+ def test_reverse_select(self):
+ query = "FROM t1 SELECT field"
+ self.assertEquals({"t1"}, self.extract_tables(query))
+
+ def test_subselect(self):
+ query = """
+ SELECT sub.*
+ FROM (
+ SELECT *
+ FROM s1.t1
+ WHERE day_of_week = 'Friday'
+ ) sub, s2.t2
+ WHERE sub.resolution = 'NONE'
+ """
+ self.assertEquals({"s1.t1", "s2.t2"},
+ self.extract_tables(query))
+
+ query = """
+ SELECT sub.*
+ FROM (
+ SELECT *
+ FROM s1.t1
+ WHERE day_of_week = 'Friday'
+ ) sub
+ WHERE sub.resolution = 'NONE'
+ """
+ self.assertEquals({"s1.t1"}, self.extract_tables(query))
+
+ query = """
+ SELECT * FROM t1
+ WHERE s11 > ANY
+ (SELECT COUNT(*) /* no hint */ FROM t2
+ WHERE NOT EXISTS
+ (SELECT * FROM t3
+ WHERE ROW(5*t2.s1,77)=
+ (SELECT 50,11*s1 FROM t4)));
+ """
+ self.assertEquals({"t1", "t2", "t3", "t4"},
+ self.extract_tables(query))
+
+ def test_select_in_expression(self):
+ query = "SELECT f1, (SELECT count(1) FROM t2) FROM t1"
+ self.assertEquals({"t1", "t2"}, self.extract_tables(query))
+
+ def test_union(self):
+ query = "SELECT * FROM t1 UNION SELECT * FROM t2"
+ self.assertEquals({"t1", "t2"}, self.extract_tables(query))
+
+ query = "SELECT * FROM t1 UNION ALL SELECT * FROM t2"
+ self.assertEquals({"t1", "t2"}, self.extract_tables(query))
+
+ query = "SELECT * FROM t1 INTERSECT ALL SELECT * FROM t2"
+ self.assertEquals({"t1", "t2"}, self.extract_tables(query))
+
+ def test_select_from_values(self):
+ query = "SELECT * FROM VALUES (13, 42)"
+ self.assertFalse(self.extract_tables(query))
+
+ def test_select_array(self):
+ query = """
+ SELECT ARRAY[1, 2, 3] AS my_array
+ FROM t1 LIMIT 10
+ """
+ self.assertEquals({"t1"}, self.extract_tables(query))
+
+ def test_select_if(self):
+ query = """
+ SELECT IF(CARDINALITY(my_array) >= 3, my_array[3], NULL)
+ FROM t1 LIMIT 10
+ """
+ self.assertEquals({"t1"}, self.extract_tables(query))
+
+ # SHOW TABLES ((FROM | IN) qualifiedName)? (LIKE pattern=STRING)?
+ def test_show_tables(self):
+ query = 'SHOW TABLES FROM s1 like "%order%"'
+ # TODO: figure out what the code should do here
+ self.assertEquals({"s1"}, self.extract_tables(query))
+
+ # SHOW COLUMNS (FROM | IN) qualifiedName
+ def test_show_columns(self):
+ query = "SHOW COLUMNS FROM t1"
+ self.assertEquals({"t1"}, self.extract_tables(query))
+
+ def test_where_subquery(self):
+ query = """
+ SELECT name
+ FROM t1
+ WHERE regionkey = (SELECT max(regionkey) FROM t2)
+ """
+ self.assertEquals({"t1", "t2"}, self.extract_tables(query))
+
+ query = """
+ SELECT name
+ FROM t1
+ WHERE regionkey IN (SELECT regionkey FROM t2)
+ """
+ self.assertEquals({"t1", "t2"}, self.extract_tables(query))
+
+ query = """
+ SELECT name
+ FROM t1
+ WHERE regionkey EXISTS (SELECT regionkey FROM t2)
+ """
+ self.assertEquals({"t1", "t2"}, self.extract_tables(query))
+
+ # DESCRIBE | DESC qualifiedName
+ def test_describe(self):
+ self.assertEquals({"t1"}, self.extract_tables("DESCRIBE t1"))
+ self.assertEquals({"t1"}, self.extract_tables("DESC t1"))
+
+ # SHOW PARTITIONS FROM qualifiedName (WHERE booleanExpression)?
+ # (ORDER BY sortItem (',' sortItem)*)? (LIMIT limit=(INTEGER_VALUE | ALL))?
+ def test_show_partitions(self):
+ query = """
+ SHOW PARTITIONS FROM orders
+ WHERE ds >= '2013-01-01' ORDER BY ds DESC;
+ """
+ self.assertEquals({"orders"}, self.extract_tables(query))
+
+ def test_join(self):
+ query = "SELECT t1.*, t2.* FROM t1 JOIN t2 ON t1.a = t2.a;"
+ self.assertEquals({"t1", "t2"}, self.extract_tables(query))
+
+ # subquery + join
+ query = """
+ SELECT a.date, b.name FROM
+ left_table a
+ JOIN (
+ SELECT
+ CAST((b.year) as VARCHAR) date,
+ name
+ FROM right_table
+ ) b
+ ON a.date = b.date
+ """
+ self.assertEquals({"left_table", "right_table"},
+ self.extract_tables(query))
+
+ query = """
+ SELECT a.date, b.name FROM
+ left_table a
+ LEFT INNER JOIN (
+ SELECT
+ CAST((b.year) as VARCHAR) date,
+ name
+ FROM right_table
+ ) b
+ ON a.date = b.date
+ """
+ self.assertEquals({"left_table", "right_table"},
+ self.extract_tables(query))
+
+ query = """
+ SELECT a.date, b.name FROM
+ left_table a
+ RIGHT OUTER JOIN (
+ SELECT
+ CAST((b.year) as VARCHAR) date,
+ name
+ FROM right_table
+ ) b
+ ON a.date = b.date
+ """
+ self.assertEquals({"left_table", "right_table"},
+ self.extract_tables(query))
+
+ query = """
+ SELECT a.date, b.name FROM
+ left_table a
+ FULL OUTER JOIN (
+ SELECT
+ CAST((b.year) as VARCHAR) date,
+ name
+ FROM right_table
+ ) b
+ ON a.date = b.date
+ """
+ self.assertEquals({"left_table", "right_table"},
+ self.extract_tables(query))
+
+ # TODO: add SEMI join support, SQL Parse does not handle it.
+ # query = """
+ # SELECT a.date, b.name FROM
+ # left_table a
+ # LEFT SEMI JOIN (
+ # SELECT
+ # CAST((b.year) as VARCHAR) date,
+ # name
+ # FROM right_table
+ # ) b
+ # ON a.date = b.date
+ # """
+ # self.assertEquals({"left_table", "right_table"},
+ # sql_parse.extract_tables(query))
+
+ def test_combinations(self):
+ query = """
+ SELECT * FROM t1
+ WHERE s11 > ANY
+ (SELECT * FROM t1 UNION ALL SELECT * FROM (
+ SELECT t6.*, t3.* FROM t6 JOIN t3 ON t6.a = t3.a) tmp_join
+ WHERE NOT EXISTS
+ (SELECT * FROM t3
+ WHERE ROW(5*t3.s1,77)=
+ (SELECT 50,11*s1 FROM t4)));
+ """
+ self.assertEquals({"t1", "t3", "t4", "t6"},
+ self.extract_tables(query))
+
+ query = """
+ SELECT * FROM (SELECT * FROM (SELECT * FROM (SELECT * FROM EmployeeS)
+ AS S1) AS S2) AS S3;
+ """
+ self.assertEquals({"EmployeeS"}, self.extract_tables(query))
+
+ def test_with(self):
+ query = """
+ WITH
+ x AS (SELECT a FROM t1),
+ y AS (SELECT a AS b FROM t2),
+ z AS (SELECT b AS c FROM t3)
+ SELECT c FROM z;
+ """
+ self.assertEquals({"t1", "t2", "t3"},
+ self.extract_tables(query))
+
+ query = """
+ WITH
+ x AS (SELECT a FROM t1),
+ y AS (SELECT a AS b FROM x),
+ z AS (SELECT b AS c FROM y)
+ SELECT c FROM z;
+ """
+ self.assertEquals({"t1"}, self.extract_tables(query))
+
+ def test_reusing_aliases(self):
+ query = """
+ with q1 as ( select key from q2 where key = '5'),
+ q2 as ( select key from src where key = '5')
+ select * from (select key from q1) a;
+ """
+ self.assertEquals({"src"}, self.extract_tables(query))
+
+ def multistatement(self):
+ query = "SELECT * FROM t1; SELECT * FROM t2"
+ self.assertEquals({"t1", "t2"}, self.extract_tables(query))
+
+ query = "SELECT * FROM t1; SELECT * FROM t2;"
+ self.assertEquals({"t1", "t2"}, self.extract_tables(query))