Implement table name extraction. (#1598)

* Implement table name extraction tests.

* Address comments.

* Fix tests and reimplement the token processing.

* Exclude aliases.

* Clean up print statements and code.

* Reverse select test.

* Fix failing test.

* Test JOINs

* Refactor as a class

* Check for permissions in SQL Lab.

* Implement permissions check for the datasources in sql_lab

* Address comments.
diff --git a/superset/models.py b/superset/models.py
index a8e3fdf..bf2e689 100644
--- a/superset/models.py
+++ b/superset/models.py
@@ -665,6 +665,8 @@
     """An ORM object that stores Database related information"""
 
     __tablename__ = 'dbs'
+    # Source type key, used to look up SourceRegistry.sources.
+    type = 'table'
 
     id = Column(Integer, primary_key=True)
     database_name = Column(String(250), unique=True)
@@ -1524,6 +1526,8 @@
     """ORM object referencing the Druid clusters"""
 
     __tablename__ = 'clusters'
+    # Source type key, used to look up SourceRegistry.sources.
+    type = 'druid'
 
     id = Column(Integer, primary_key=True)
     cluster_name = Column(String(250), unique=True)
diff --git a/superset/source_registry.py b/superset/source_registry.py
index 0705460..2c72157 100644
--- a/superset/source_registry.py
+++ b/superset/source_registry.py
@@ -41,6 +41,35 @@
         return db_ds[0]
 
     @classmethod
+    def query_datasources_by_name(
+            cls, session, database, datasource_name, schema=None):
+        """Return the datasources named ``datasource_name`` in ``database``.
+
+        ``database`` is a Database (``type == 'table'``) or a Druid cluster
+        (``type == 'druid'``); ``schema`` only narrows table lookups. An
+        empty list is returned when nothing matches or when the registered
+        source type is handled by neither branch.
+        """
+        datasource_class = SourceRegistry.sources[database.type]
+        if database.type == 'table':
+            query = (
+                session.query(datasource_class)
+                .filter_by(database_id=database.id)
+                .filter_by(table_name=datasource_name))
+            if schema:
+                query = query.filter_by(schema=schema)
+            return query.all()
+        if database.type == 'druid':
+            return (
+                session.query(datasource_class)
+                # Druid datasources reference their cluster by name.
+                .filter_by(cluster_name=database.cluster_name)
+                .filter_by(datasource_name=datasource_name)
+                .all()
+            )
+        return []
+
+    @classmethod
     def get_eager_datasource(cls, session, datasource_type, datasource_id):
         """Returns datasource with columns and metrics."""
         datasource_class = SourceRegistry.sources[datasource_type]
diff --git a/superset/sql_lab.py b/superset/sql_lab.py
index 5ee6546..0e5b84c 100644
--- a/superset/sql_lab.py
+++ b/superset/sql_lab.py
@@ -11,7 +11,7 @@
 from sqlalchemy.orm import sessionmaker
 
 from superset import (
-    app, db, models, utils, dataframe, results_backend)
+    app, db, models, utils, dataframe, results_backend, sql_parse)
 from superset.db_engine_specs import LimitMethod
 from superset.jinja_context import get_template_processor
 QueryStatus = models.QueryStatus
@@ -19,16 +19,12 @@
 celery_app = celery.Celery(config_source=app.config.get('CELERY_CONFIG'))
 
 
-def is_query_select(sql):
-    return sql.upper().startswith('SELECT')
-
-
 def create_table_as(sql, table_name, schema=None, override=False):
     """Reformats the query into the create table as query.
 
-    Works only for the single select SQL statements, in all other cases
-    the sql query is not modified.
+    The caller must ensure sql is a single SELECT statement; it is
+    wrapped as-is, without further validation.
     :param sql: string, sql query that will be executed
     :param table_name: string, will contain the results of the query execution
     :param override, boolean, table table_name will be dropped if true
     :return: string, create table as query
@@ -41,12 +37,9 @@
     if schema:
         table_name = schema + '.' + table_name
     exec_sql = ''
-    if is_query_select(sql):
-        if override:
-            exec_sql = 'DROP TABLE IF EXISTS {table_name};\n'
-        exec_sql += "CREATE TABLE {table_name} AS \n{sql}"
-    else:
-        raise Exception("Could not generate CREATE TABLE statement")
+    if override:
+        exec_sql = 'DROP TABLE IF EXISTS {table_name};\n'
+    exec_sql += "CREATE TABLE {table_name} AS \n{sql}"
     return exec_sql.format(**locals())
 
 
@@ -76,12 +69,12 @@
         raise Exception(query.error_message)
 
     # Limit enforced only for retrieving the data, not for the CTA queries.
-    is_select = is_query_select(executed_sql);
+    is_select = sql_parse.SupersetQuery(executed_sql).is_select()
     if not is_select and not database.allow_dml:
         handle_error(
             "Only `SELECT` statements are allowed against this database")
     if query.select_as_cta:
         if not is_select:
             handle_error(
                 "Only `SELECT` statements can be used with the CREATE TABLE "
                 "feature.")
diff --git a/superset/sql_parse.py b/superset/sql_parse.py
new file mode 100644
index 0000000..8f2c6e0
--- /dev/null
+++ b/superset/sql_parse.py
@@ -0,0 +1,115 @@
+import sqlparse
+from sqlparse.sql import IdentifierList, Identifier
+from sqlparse.tokens import Keyword, Name
+
+RESULT_OPERATIONS = {'UNION', 'INTERSECT', 'EXCEPT'}
+PRECEDES_TABLE_NAME = {'FROM', 'JOIN', 'DESC', 'DESCRIBE', 'WITH'}
+
+
+# TODO: some sql_lab logic here.
+class SupersetQuery(object):
+    """Wraps a SQL string and extracts the table names it references.
+
+    Table names are taken from identifiers that follow FROM, JOIN,
+    DESC/DESCRIBE or WITH keywords, descending into sub-selects. Names
+    recorded as aliases (e.g. sub-select or CTE names) are removed from
+    the result.
+    """
+
+    def __init__(self, sql_statement):
+        self.sql = sql_statement
+        self._table_names = set()
+        self._alias_names = set()
+        # TODO: multistatement support
+        for statement in sqlparse.parse(self.sql):
+            self.__extract_from_token(statement)
+        self._table_names = self._table_names - self._alias_names
+
+    @property
+    def tables(self):
+        """Set of referenced table names (``table`` or ``schema.table``)."""
+        return self._table_names
+
+    # TODO: use sqlparse for this check.
+    def is_select(self):
+        """Return True if the SQL text starts with SELECT.
+
+        The check is case-insensitive and ignores leading whitespace.
+        """
+        return self.sql.lstrip().upper().startswith('SELECT')
+
+    @staticmethod
+    def __precedes_table_name(token_value):
+        """Return True if a table name may follow this keyword."""
+        # Substring match, so any keyword containing e.g. "JOIN" counts;
+        # NOTE(review): this also matches the ORDER BY keyword DESC.
+        return any(keyword in token_value for keyword in PRECEDES_TABLE_NAME)
+
+    @staticmethod
+    def __get_full_name(identifier):
+        """Return ``schema.table`` for qualified names, else the real name."""
+        if len(identifier.tokens) > 1 and identifier.tokens[1].value == '.':
+            return "{}.{}".format(identifier.tokens[0].value,
+                                  identifier.tokens[2].value)
+        return identifier.get_real_name()
+
+    @staticmethod
+    def __is_result_operation(keyword):
+        """Return True for set operations such as UNION ALL or INTERSECT."""
+        return any(
+            operation in keyword.upper() for operation in RESULT_OPERATIONS)
+
+    @staticmethod
+    def __is_identifier(token):
+        """Return True for Identifier and IdentifierList tokens."""
+        return isinstance(token, (IdentifierList, Identifier))
+
+    def __process_identifier(self, identifier):
+        """Record a table name, or the aliases/tables of a sub-select."""
+        # exclude subselects
+        if '(' not in '{}'.format(identifier):
+            self._table_names.add(SupersetQuery.__get_full_name(identifier))
+            return
+
+        # store aliases
+        if hasattr(identifier, 'get_alias'):
+            self._alias_names.add(identifier.get_alias())
+        if hasattr(identifier, 'tokens'):
+            # some aliases are not parsed properly
+            if identifier.tokens[0].ttype == Name:
+                self._alias_names.add(identifier.tokens[0].value)
+        self.__extract_from_token(identifier)
+
+    def __extract_from_token(self, token):
+        """Walk ``token`` recursively, collecting table names and aliases."""
+        if not hasattr(token, 'tokens'):
+            return
+
+        table_name_preceding_token = False
+
+        for item in token.tokens:
+            if item.is_group and not self.__is_identifier(item):
+                self.__extract_from_token(item)
+
+            if item.ttype in Keyword:
+                if SupersetQuery.__precedes_table_name(item.value.upper()):
+                    table_name_preceding_token = True
+                    continue
+
+            if not table_name_preceding_token:
+                continue
+
+            if item.ttype in Keyword:
+                if SupersetQuery.__is_result_operation(item.value):
+                    table_name_preceding_token = False
+                    continue
+                # FROM clause is over
+                break
+
+            if isinstance(item, Identifier):
+                self.__process_identifier(item)
+
+            if isinstance(item, IdentifierList):
+                for sub_token in item.tokens:
+                    if SupersetQuery.__is_identifier(sub_token):
+                        self.__process_identifier(sub_token)
diff --git a/superset/views.py b/superset/views.py
index f0458e5..8768234 100755
--- a/superset/views.py
+++ b/superset/views.py
@@ -36,7 +36,7 @@
 import superset
 from superset import (
     appbuilder, cache, db, models, viz, utils, app,
-    sm, sql_lab, results_backend, security,
+    sm, sql_lab, sql_parse, results_backend, security,
 )
 from superset.source_registry import SourceRegistry
 from superset.models import DatasourceAccessRequest as DAR
@@ -74,6 +74,25 @@
             self.can_access("datasource_access", datasource.perm)
         )
 
+    def datasource_access_by_name(
+            self, database, datasource_name, schema=None):
+        """Return True if the user may query the named datasource.
+
+        Access is granted when the user can access the whole database or
+        all datasources, or when any datasource registered under
+        ``datasource_name`` (in ``schema``, if given) grants
+        ``datasource_access``.
+        """
+        if (self.database_access(database) or
+                self.all_datasource_access()):
+            return True
+        # ``or []`` tolerates a None result from the registry lookup.
+        datasources = SourceRegistry.query_datasources_by_name(
+            db.session, database, datasource_name, schema=schema) or []
+        return any(
+            self.can_access("datasource_access", datasource.perm)
+            for datasource in datasources)
+
 
 class ListWidgetWithCheckboxes(ListWidget):
     """An alternative to list view that renders Boolean fields as checkboxes
@@ -2303,27 +2322,56 @@
     @log_this
     def sql_json(self):
         """Runs arbitrary sql and returns and json"""
+        def table_accessible(database, full_table_name, schema_name=None):
+            """Return True if the user may query ``full_table_name``.
+
+            ``full_table_name`` is ``schema.table`` or a bare ``table``;
+            bare names are looked up in ``schema_name``.
+            NOTE(review): names with more than one dot are also looked up
+            in ``schema_name``, using only their first piece.
+            """
+            table_name_pieces = full_table_name.split(".")
+            if len(table_name_pieces) == 2:
+                table_schema = table_name_pieces[0]
+                table_name = table_name_pieces[1]
+            else:
+                table_schema = schema_name
+                table_name = table_name_pieces[0]
+            return self.datasource_access_by_name(
+                database, table_name, schema=table_schema)
+
         async = request.form.get('runAsync') == 'true'
         sql = request.form.get('sql')
         database_id = request.form.get('database_id')
 
         session = db.session()
         mydb = session.query(models.Database).filter_by(id=database_id).first()
 
         if not mydb:
-            json_error_response(
+            return json_error_response(
                 'Database with id {} is missing.'.format(database_id))
 
-        if not self.database_access(mydb):
-            json_error_response(
-                get_database_access_error_msg(mydb.database_name))
+        superset_query = sql_parse.SupersetQuery(sql)
+        schema = request.form.get('schema')
+        schema = schema if schema else None
+
+        # NOTE(review): only names following FROM/JOIN/DESC[RIBE]/WITH are
+        # checked, so e.g. the target of INSERT INTO / DROP TABLE, or a
+        # query without tables, passes even without database access.
+        rejected_tables = [
+            t for t in superset_query.tables if not
+            table_accessible(mydb, t, schema_name=schema)]
+        if rejected_tables:
+            # Return the error response so the query is not executed.
+            return json_error_response(
+                get_datasource_access_error_msg(', '.join(rejected_tables)))
         session.commit()
 
         query = models.Query(
             database_id=int(database_id),
             limit=int(app.config.get('SQL_MAX_ROW', None)),
             sql=sql,
-            schema=request.form.get('schema'),
+            schema=schema,
             select_as_cta=request.form.get('select_as_cta') == 'true',
             start_time=utils.now_as_float(),
             tab_name=request.form.get('tab'),
@@ -2341,7 +2389,8 @@
         if async:
             # Ignore the celery future object and the request may time out.
             sql_lab.get_sql_results.delay(
-                query_id, return_results=False, store_results=not query.select_as_cta)
+                query_id, return_results=False,
+                store_results=not query.select_as_cta)
             return Response(
                 json.dumps({'query': query.to_dict()},
                            default=utils.json_int_dttm_ser,
diff --git a/tests/sql_parse_tests.py b/tests/sql_parse_tests.py
new file mode 100644
index 0000000..284e168
--- /dev/null
+++ b/tests/sql_parse_tests.py
@@ -0,0 +1,303 @@
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+from __future__ import unicode_literals
+
+import unittest
+
+from superset import sql_parse
+
+
+class SupersetTestCase(unittest.TestCase):
+
+    def extract_tables(self, query):
+        sq = sql_parse.SupersetQuery(query)
+        return sq.tables
+
+    def test_simple_select(self):
+        query = "SELECT * FROM tbname"
+        self.assertEqual({"tbname"}, self.extract_tables(query))
+
+        # underscores
+        query = "SELECT * FROM tb_name"
+        self.assertEqual({"tb_name"},
+                         self.extract_tables(query))
+
+        # quotes
+        query = 'SELECT * FROM "tbname"'
+        self.assertEqual({"tbname"}, self.extract_tables(query))
+
+        # schema
+        self.assertEqual(
+            {"schemaname.tbname"},
+            self.extract_tables("SELECT * FROM schemaname.tbname"))
+
+        # multiple fields
+        query = "SELECT field1, field2 FROM tb_name"
+        self.assertEqual({"tb_name"}, self.extract_tables(query))
+
+        query = "SELECT t1.f1, t2.f2 FROM t1, t2"
+        self.assertEqual({"t1", "t2"}, self.extract_tables(query))
+
+    def test_select_named_table(self):
+        query = "SELECT a.date, a.field FROM left_table a LIMIT 10"
+        self.assertEqual(
+            {"left_table"}, self.extract_tables(query))
+
+    def test_reverse_select(self):
+        query = "FROM t1 SELECT field"
+        self.assertEqual({"t1"}, self.extract_tables(query))
+
+    def test_subselect(self):
+        query = """
+          SELECT sub.*
+              FROM (
+                    SELECT *
+                      FROM s1.t1
+                     WHERE day_of_week = 'Friday'
+                   ) sub, s2.t2
+          WHERE sub.resolution = 'NONE'
+        """
+        self.assertEqual({"s1.t1", "s2.t2"},
+                         self.extract_tables(query))
+
+        query = """
+          SELECT sub.*
+              FROM (
+                    SELECT *
+                      FROM s1.t1
+                     WHERE day_of_week = 'Friday'
+                   ) sub
+          WHERE sub.resolution = 'NONE'
+        """
+        self.assertEqual({"s1.t1"}, self.extract_tables(query))
+
+        query = """
+            SELECT * FROM t1
+            WHERE s11 > ANY
+             (SELECT COUNT(*) /* no hint */ FROM t2
+               WHERE NOT EXISTS
+                (SELECT * FROM t3
+                  WHERE ROW(5*t2.s1,77)=
+                    (SELECT 50,11*s1 FROM t4)));
+        """
+        self.assertEqual({"t1", "t2", "t3", "t4"},
+                         self.extract_tables(query))
+
+    def test_select_in_expression(self):
+        query = "SELECT f1, (SELECT count(1) FROM t2) FROM t1"
+        self.assertEqual({"t1", "t2"}, self.extract_tables(query))
+
+    def test_union(self):
+        query = "SELECT * FROM t1 UNION SELECT * FROM t2"
+        self.assertEqual({"t1", "t2"}, self.extract_tables(query))
+
+        query = "SELECT * FROM t1 UNION ALL SELECT * FROM t2"
+        self.assertEqual({"t1", "t2"}, self.extract_tables(query))
+
+        query = "SELECT * FROM t1 INTERSECT ALL SELECT * FROM t2"
+        self.assertEqual({"t1", "t2"}, self.extract_tables(query))
+
+    def test_select_from_values(self):
+        query = "SELECT * FROM VALUES (13, 42)"
+        self.assertFalse(self.extract_tables(query))
+
+    def test_select_array(self):
+        query = """
+            SELECT ARRAY[1, 2, 3] AS my_array
+            FROM t1 LIMIT 10
+        """
+        self.assertEqual({"t1"}, self.extract_tables(query))
+
+    def test_select_if(self):
+        query = """
+            SELECT IF(CARDINALITY(my_array) >= 3, my_array[3], NULL)
+            FROM t1 LIMIT 10
+        """
+        self.assertEqual({"t1"}, self.extract_tables(query))
+
+    # SHOW TABLES ((FROM | IN) qualifiedName)? (LIKE pattern=STRING)?
+    def test_show_tables(self):
+        query = 'SHOW TABLES FROM s1 like "%order%"'
+        # TODO: figure out what should code do here
+        self.assertEqual({"s1"}, self.extract_tables(query))
+
+    # SHOW COLUMNS (FROM | IN) qualifiedName
+    def test_show_columns(self):
+        query = "SHOW COLUMNS FROM t1"
+        self.assertEqual({"t1"}, self.extract_tables(query))
+
+    def test_where_subquery(self):
+        query = """
+          SELECT name
+            FROM t1
+            WHERE regionkey = (SELECT max(regionkey) FROM t2)
+        """
+        self.assertEqual({"t1", "t2"}, self.extract_tables(query))
+
+        query = """
+          SELECT name
+            FROM t1
+            WHERE regionkey IN (SELECT regionkey FROM t2)
+        """
+        self.assertEqual({"t1", "t2"}, self.extract_tables(query))
+
+        query = """
+          SELECT name
+            FROM t1
+            WHERE regionkey EXISTS (SELECT regionkey FROM t2)
+        """
+        self.assertEqual({"t1", "t2"}, self.extract_tables(query))
+
+    # DESCRIBE | DESC qualifiedName
+    def test_describe(self):
+        self.assertEqual({"t1"}, self.extract_tables("DESCRIBE t1"))
+        self.assertEqual({"t1"}, self.extract_tables("DESC t1"))
+
+    # SHOW PARTITIONS FROM qualifiedName (WHERE booleanExpression)?
+    # (ORDER BY sortItem (',' sortItem)*)? (LIMIT limit=(INTEGER_VALUE | ALL))?
+    def test_show_partitions(self):
+        query = """
+            SHOW PARTITIONS FROM orders
+            WHERE ds >= '2013-01-01' ORDER BY ds DESC;
+        """
+        self.assertEqual({"orders"}, self.extract_tables(query))
+
+    def test_join(self):
+        query = "SELECT t1.*, t2.* FROM t1 JOIN t2 ON t1.a = t2.a;"
+        self.assertEqual({"t1", "t2"}, self.extract_tables(query))
+
+        # subquery + join
+        query = """
+            SELECT a.date, b.name FROM
+                left_table a
+                JOIN (
+                  SELECT
+                    CAST((b.year) as VARCHAR) date,
+                    name
+                  FROM right_table
+                ) b
+                ON a.date = b.date
+        """
+        self.assertEqual({"left_table", "right_table"},
+                         self.extract_tables(query))
+
+        query = """
+            SELECT a.date, b.name FROM
+                left_table a
+                LEFT INNER JOIN (
+                  SELECT
+                    CAST((b.year) as VARCHAR) date,
+                    name
+                  FROM right_table
+                ) b
+                ON a.date = b.date
+        """
+        self.assertEqual({"left_table", "right_table"},
+                         self.extract_tables(query))
+
+        query = """
+            SELECT a.date, b.name FROM
+                left_table a
+                RIGHT OUTER JOIN (
+                  SELECT
+                    CAST((b.year) as VARCHAR) date,
+                    name
+                  FROM right_table
+                ) b
+                ON a.date = b.date
+        """
+        self.assertEqual({"left_table", "right_table"},
+                         self.extract_tables(query))
+
+        query = """
+            SELECT a.date, b.name FROM
+                left_table a
+                FULL OUTER JOIN (
+                  SELECT
+                    CAST((b.year) as VARCHAR) date,
+                    name
+                  FROM right_table
+                ) b
+                ON a.date = b.date
+        """
+        self.assertEqual({"left_table", "right_table"},
+                         self.extract_tables(query))
+
+        # TODO: add SEMI join support, SQL Parse does not handle it.
+        # query = """
+        #     SELECT a.date, b.name FROM
+        #         left_table a
+        #         LEFT SEMI JOIN (
+        #           SELECT
+        #             CAST((b.year) as VARCHAR) date,
+        #             name
+        #           FROM right_table
+        #         ) b
+        #         ON a.date = b.date
+        # """
+        # self.assertEqual({"left_table", "right_table"},
+        #                  self.extract_tables(query))
+
+    def test_combinations(self):
+        query = """
+            SELECT * FROM t1
+            WHERE s11 > ANY
+             (SELECT * FROM t1 UNION ALL SELECT * FROM (
+               SELECT t6.*, t3.* FROM t6 JOIN t3 ON t6.a = t3.a) tmp_join
+               WHERE NOT EXISTS
+                (SELECT * FROM t3
+                  WHERE ROW(5*t3.s1,77)=
+                    (SELECT 50,11*s1 FROM t4)));
+        """
+        self.assertEqual({"t1", "t3", "t4", "t6"},
+                         self.extract_tables(query))
+
+        query = """
+        SELECT * FROM (SELECT * FROM (SELECT * FROM (SELECT * FROM EmployeeS)
+            AS S1) AS S2) AS S3;
+        """
+        self.assertEqual({"EmployeeS"}, self.extract_tables(query))
+
+    def test_with(self):
+        query = """
+            WITH
+              x AS (SELECT a FROM t1),
+              y AS (SELECT a AS b FROM t2),
+              z AS (SELECT b AS c FROM t3)
+            SELECT c FROM z;
+        """
+        self.assertEqual({"t1", "t2", "t3"},
+                         self.extract_tables(query))
+
+        query = """
+            WITH
+              x AS (SELECT a FROM t1),
+              y AS (SELECT a AS b FROM x),
+              z AS (SELECT b AS c FROM y)
+            SELECT c FROM z;
+        """
+        self.assertEqual({"t1"}, self.extract_tables(query))
+
+    def test_reusing_aliases(self):
+        query = """
+            with q1 as ( select key from q2 where key = '5'),
+            q2 as ( select key from src where key = '5')
+            select * from (select key from q1) a;
+        """
+        self.assertEqual({"src"}, self.extract_tables(query))
+
+    def test_is_select(self):
+        sq = sql_parse.SupersetQuery("SELECT * FROM t1")
+        self.assertTrue(sq.is_select())
+        sq = sql_parse.SupersetQuery("select * from t1")
+        self.assertTrue(sq.is_select())
+        sq = sql_parse.SupersetQuery("INSERT INTO t1 VALUES (1)")
+        self.assertFalse(sq.is_select())
+
+    def test_multistatement(self):
+        query = "SELECT * FROM t1; SELECT * FROM t2"
+        self.assertEqual({"t1", "t2"}, self.extract_tables(query))
+
+        query = "SELECT * FROM t1; SELECT * FROM t2;"
+        self.assertEqual({"t1", "t2"}, self.extract_tables(query))