| # Licensed to the Apache Software Foundation (ASF) under one |
| # or more contributor license agreements. See the NOTICE file |
| # distributed with this work for additional information |
| # regarding copyright ownership. The ASF licenses this file |
| # to you under the Apache License, Version 2.0 (the |
| # "License"); you may not use this file except in compliance |
| # with the License. You may obtain a copy of the License at |
| # |
| # http://www.apache.org/licenses/LICENSE-2.0 |
| # |
| # Unless required by applicable law or agreed to in writing, |
| # software distributed under the License is distributed on an |
| # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY |
| # KIND, either express or implied. See the License for the |
| # specific language governing permissions and limitations |
| # under the License. |
| import unittest |
| |
| import sqlparse |
| |
| from superset.sql_parse import ParsedQuery, strip_comments_from_sql, Table |
| |
| |
| class TestSupersetSqlParse(unittest.TestCase): |
def extract_tables(self, query):
    """Parse ``query`` with ``ParsedQuery`` and return its ``tables`` attribute."""
    parsed_query = ParsedQuery(query)
    return parsed_query.tables
| |
def test_table(self):
    """``str(Table(...))`` renders ``catalog.schema.table`` using the parts given."""
    plain_cases = [
        (("tbname",), "tbname"),
        (("tbname", "schemaname"), "schemaname.tbname"),
        (
            ("tbname", "schemaname", "catalogname"),
            "catalogname.schemaname.tbname",
        ),
    ]
    for parts, rendered in plain_cases:
        self.assertEqual(str(Table(*parts)), rendered)

    # Special characters inside a part are expected to be percent-encoded.
    # Note that "\n" in "catalog\name" is a newline escape, which is why the
    # expected text contains "%0A" followed by "ame".
    awkward = Table("tb.name", "schema/name", "catalog\name")
    self.assertEqual(str(awkward), "catalog%0Aame.schema%2Fname.tb%2Ename")
| |
def test_simple_select(self):
    """Table extraction from simple single-statement ``SELECT`` queries."""
    # Bare name, aliases, underscores, quoted identifier, non-ASCII literal.
    single_table_cases = (
        ("SELECT * FROM tbname", "tbname"),
        ("SELECT * FROM tbname foo", "tbname"),
        ("SELECT * FROM tbname AS foo", "tbname"),
        ("SELECT * FROM tb_name", "tb_name"),
        ('SELECT * FROM "tbname"', "tbname"),
        ('SELECT * FROM "tb_name" WHERE city = "Lübeck"', "tb_name"),
    )
    for sql, table_name in single_table_cases:
        self.assertEqual({Table(table_name)}, self.extract_tables(sql))

    # Schema-qualified names, with and without quoting or aliases.
    schema_qualified = (
        "SELECT * FROM schemaname.tbname",
        'SELECT * FROM "schemaname"."tbname"',
        "SELECT * FROM schemaname.tbname foo",
        "SELECT * FROM schemaname.tbname AS foo",
    )
    for sql in schema_qualified:
        self.assertEqual(
            {Table("tbname", "schemaname")}, self.extract_tables(sql)
        )

    # Fully qualified catalog.schema.table name.
    self.assertEqual(
        {Table("tbname", "schemaname", "catalogname")},
        self.extract_tables("SELECT * FROM catalogname.schemaname.tbname"),
    )

    # Ill-defined cluster/schema/table references yield no tables.
    ill_defined = (
        "SELECT * FROM schemaname.",
        "SELECT * FROM catalogname.schemaname.",
        "SELECT * FROM catalogname..",
        "SELECT * FROM catalogname..tbname",
    )
    for sql in ill_defined:
        self.assertEqual(set(), self.extract_tables(sql))

    # Explicit column list.
    self.assertEqual(
        {Table("tb_name")},
        self.extract_tables("SELECT field1, field2 FROM tb_name"),
    )

    # Comma-separated tables in the FROM clause.
    self.assertEqual(
        {Table("t1"), Table("t2")},
        self.extract_tables("SELECT t1.f1, t2.f2 FROM t1, t2"),
    )
| |
def test_select_named_table(self):
    """An aliased table followed by ``LIMIT`` is reported under its real name."""
    extracted = self.extract_tables(
        "SELECT a.date, a.field FROM left_table a LIMIT 10"
    )
    self.assertEqual({Table("left_table")}, extracted)
| |
def test_reverse_select(self):
    """A statement written ``FROM ... SELECT ...`` still yields the table."""
    extracted = self.extract_tables("FROM t1 SELECT field")
    self.assertEqual({Table("t1")}, extracted)
| |
def test_subselect(self):
    """Tables referenced inside sub-selects are extracted; aliases are not."""
    # Sub-select alongside a second, schema-qualified table.
    subselect_and_table = """
        SELECT sub.*
        FROM (
            SELECT *
            FROM s1.t1
            WHERE day_of_week = 'Friday'
        ) sub, s2.t2
        WHERE sub.resolution = 'NONE'
    """
    self.assertEqual(
        {Table("t1", "s1"), Table("t2", "s2")},
        self.extract_tables(subselect_and_table),
    )

    # Sub-select on its own.
    subselect_only = """
        SELECT sub.*
        FROM (
            SELECT *
            FROM s1.t1
            WHERE day_of_week = 'Friday'
        ) sub
        WHERE sub.resolution = 'NONE'
    """
    self.assertEqual({Table("t1", "s1")}, self.extract_tables(subselect_only))

    # Several levels of nesting, with a block comment in the middle.
    deeply_nested = """
        SELECT * FROM t1
        WHERE s11 > ANY
        (SELECT COUNT(*) /* no hint */ FROM t2
        WHERE NOT EXISTS
        (SELECT * FROM t3
        WHERE ROW(5*t2.s1,77)=
        (SELECT 50,11*s1 FROM t4)));
    """
    expected = {Table(name) for name in ("t1", "t2", "t3", "t4")}
    self.assertEqual(expected, self.extract_tables(deeply_nested))
| |
def test_select_in_expression(self):
    """A scalar sub-select in the column list contributes its table too."""
    statements = (
        "SELECT f1, (SELECT count(1) FROM t2) FROM t1",
        "SELECT f1, (SELECT count(1) FROM t2) as f2 FROM t1",
    )
    for sql in statements:
        self.assertEqual({Table("t1"), Table("t2")}, self.extract_tables(sql))
| |
def test_parentheses(self):
    """A parenthesised arithmetic expression is not mistaken for a table."""
    extracted = self.extract_tables("SELECT f1, (x + y) AS f2 FROM t1")
    self.assertEqual({Table("t1")}, extracted)
| |
def test_union(self):
    """Both sides of UNION / UNION ALL / INTERSECT ALL contribute tables."""
    set_operations = (
        "SELECT * FROM t1 UNION SELECT * FROM t2",
        "SELECT * FROM t1 UNION ALL SELECT * FROM t2",
        "SELECT * FROM t1 INTERSECT ALL SELECT * FROM t2",
    )
    for sql in set_operations:
        self.assertEqual({Table("t1"), Table("t2")}, self.extract_tables(sql))
| |
def test_select_from_values(self):
    """Selecting from a ``VALUES`` list yields no tables."""
    extracted = self.extract_tables("SELECT * FROM VALUES (13, 42)")
    self.assertFalse(extracted)
| |
def test_select_array(self):
    """An ``ARRAY[...]`` literal in the column list does not disturb extraction."""
    array_sql = """
        SELECT ARRAY[1, 2, 3] AS my_array
        FROM t1 LIMIT 10
    """
    extracted = self.extract_tables(array_sql)
    self.assertEqual({Table("t1")}, extracted)
| |
def test_select_if(self):
    """An ``IF(...)`` call with array indexing does not disturb extraction."""
    if_sql = """
        SELECT IF(CARDINALITY(my_array) >= 3, my_array[3], NULL)
        FROM t1 LIMIT 10
    """
    extracted = self.extract_tables(if_sql)
    self.assertEqual({Table("t1")}, extracted)
| |
| # SHOW TABLES ((FROM | IN) qualifiedName)? (LIKE pattern=STRING)? |
def test_show_tables(self):
    """Pin the current result for ``SHOW TABLES FROM <name> LIKE ...``."""
    show_sql = "SHOW TABLES FROM s1 like '%order%'"
    # TODO: decide what the code should actually do here; for now the name
    # following FROM is reported as a table.
    extracted = self.extract_tables(show_sql)
    self.assertEqual({Table("s1")}, extracted)
    # The expected behavior is the commented-out assertion below. According
    # to the original note it is fixed in "sqlparse>=3.1" (presumably the
    # 0.3.1 release -- confirm), but that release breaks some SQL formatting.
    # self.assertEqual(set(), self.extract_tables(query))
| |
| # SHOW COLUMNS (FROM | IN) qualifiedName |
def test_show_columns(self):
    """``SHOW COLUMNS FROM <table>`` reports the named table."""
    extracted = self.extract_tables("SHOW COLUMNS FROM t1")
    self.assertEqual({Table("t1")}, extracted)
| |
def test_where_subquery(self):
    """Sub-queries in a ``WHERE`` clause (``=``, ``IN``, ``EXISTS``) are followed."""
    where_variants = (
        """
        SELECT name
        FROM t1
        WHERE regionkey = (SELECT max(regionkey) FROM t2)
        """,
        """
        SELECT name
        FROM t1
        WHERE regionkey IN (SELECT regionkey FROM t2)
        """,
        """
        SELECT name
        FROM t1
        WHERE regionkey EXISTS (SELECT regionkey FROM t2)
        """,
    )
    for sql in where_variants:
        self.assertEqual({Table("t1"), Table("t2")}, self.extract_tables(sql))
| |
| # DESCRIBE | DESC qualifiedName |
def test_describe(self):
    """``DESCRIBE <table>`` reports the named table."""
    extracted = self.extract_tables("DESCRIBE t1")
    self.assertEqual({Table("t1")}, extracted)
| |
| # SHOW PARTITIONS FROM qualifiedName (WHERE booleanExpression)? |
| # (ORDER BY sortItem (',' sortItem)*)? (LIMIT limit=(INTEGER_VALUE | ALL))? |
def test_show_partitions(self):
    """``SHOW PARTITIONS FROM <table>`` with WHERE / ORDER BY reports the table."""
    partitions_sql = """
        SHOW PARTITIONS FROM orders
        WHERE ds >= '2013-01-01' ORDER BY ds DESC;
    """
    extracted = self.extract_tables(partitions_sql)
    self.assertEqual({Table("orders")}, extracted)
| |
def test_join(self):
    """Tables on both sides of the various JOIN flavours are extracted."""
    # Plain join between two tables.
    self.assertEqual(
        {Table("t1"), Table("t2")},
        self.extract_tables("SELECT t1.*, t2.* FROM t1 JOIN t2 ON t1.a = t2.a;"),
    )

    # A table joined against a sub-query, once per join flavour.
    subquery_joins = (
        """
        SELECT a.date, b.name FROM
        left_table a
        JOIN (
            SELECT
            CAST((b.year) as VARCHAR) date,
            name
            FROM right_table
        ) b
        ON a.date = b.date
        """,
        """
        SELECT a.date, b.name FROM
        left_table a
        LEFT INNER JOIN (
            SELECT
            CAST((b.year) as VARCHAR) date,
            name
            FROM right_table
        ) b
        ON a.date = b.date
        """,
        """
        SELECT a.date, b.name FROM
        left_table a
        RIGHT OUTER JOIN (
            SELECT
            CAST((b.year) as VARCHAR) date,
            name
            FROM right_table
        ) b
        ON a.date = b.date
        """,
        """
        SELECT a.date, b.name FROM
        left_table a
        FULL OUTER JOIN (
            SELECT
            CAST((b.year) as VARCHAR) date,
            name
            FROM right_table
        ) b
        ON a.date = b.date
        """,
    )
    for sql in subquery_joins:
        self.assertEqual(
            {Table("left_table"), Table("right_table")},
            self.extract_tables(sql),
        )

    # TODO: add SEMI join support, SQL Parse does not handle it.
    # Sample query kept for when it does:
    # query = """
    #     SELECT a.date, b.name FROM
    #     left_table a
    #     LEFT SEMI JOIN (
    #         SELECT
    #         CAST((b.year) as VARCHAR) date,
    #         name
    #         FROM right_table
    #     ) b
    #     ON a.date = b.date
    # """
    # self.assertEqual({'left_table', 'right_table'},
    #                  sql_parse.extract_tables(query))
| |
def test_combinations(self):
    """Mixes of sub-selects, UNION ALL and JOIN; aliases are not reported."""
    mixed_sql = """
        SELECT * FROM t1
        WHERE s11 > ANY
        (SELECT * FROM t1 UNION ALL SELECT * FROM (
        SELECT t6.*, t3.* FROM t6 JOIN t3 ON t6.a = t3.a) tmp_join
        WHERE NOT EXISTS
        (SELECT * FROM t3
        WHERE ROW(5*t3.s1,77)=
        (SELECT 50,11*s1 FROM t4)));
    """
    expected = {Table(name) for name in ("t1", "t3", "t4", "t6")}
    self.assertEqual(expected, self.extract_tables(mixed_sql))

    # Only the innermost real table survives a chain of aliased sub-selects.
    nested_aliases_sql = """
        SELECT * FROM (SELECT * FROM (SELECT * FROM (SELECT * FROM EmployeeS)
        AS S1) AS S2) AS S3;
    """
    self.assertEqual(
        {Table("EmployeeS")}, self.extract_tables(nested_aliases_sql)
    )
| |
def test_with(self):
    """CTE names are not reported as tables; the tables they read from are."""
    # Each CTE reads from a distinct real table.
    independent_ctes = """
        WITH
        x AS (SELECT a FROM t1),
        y AS (SELECT a AS b FROM t2),
        z AS (SELECT b AS c FROM t3)
        SELECT c FROM z;
    """
    self.assertEqual(
        {Table("t1"), Table("t2"), Table("t3")},
        self.extract_tables(independent_ctes),
    )

    # The CTEs chain into each other, so only the first real table remains.
    chained_ctes = """
        WITH
        x AS (SELECT a FROM t1),
        y AS (SELECT a AS b FROM x),
        z AS (SELECT b AS c FROM y)
        SELECT c FROM z;
    """
    self.assertEqual({Table("t1")}, self.extract_tables(chained_ctes))
| |
def test_reusing_aliases(self):
    """A CTE that refers to a CTE defined after it is still treated as an alias."""
    forward_reference_sql = """
        with q1 as ( select key from q2 where key = '5'),
        q2 as ( select key from src where key = '5')
        select * from (select key from q1) a;
    """
    extracted = self.extract_tables(forward_reference_sql)
    self.assertEqual({Table("src")}, extracted)
| |
def test_multistatement(self):
    """Tables from every statement are collected, with or without a final ``;``."""
    scripts = (
        "SELECT * FROM t1; SELECT * FROM t2",
        "SELECT * FROM t1; SELECT * FROM t2;",
    )
    for sql in scripts:
        self.assertEqual({Table("t1"), Table("t2")}, self.extract_tables(sql))
| |
def test_update_not_select(self):
    """An ``UPDATE`` statement must not be classified as a ``SELECT``."""
    parsed = ParsedQuery("UPDATE t1 SET col1 = NULL")
    self.assertEqual(False, parsed.is_select())
| |
def test_set(self):
    """``is_set()`` recognises SET statements, even behind a leading comment."""
    commented_set = """
        -- comment
        SET hivevar:desc='Legislators';
    """
    parsed = ParsedQuery(commented_set)

    self.assertEqual(True, parsed.is_set())
    self.assertEqual(False, parsed.is_select())

    # Lower-case keyword is accepted; a SELECT is not a SET.
    self.assertEqual(True, ParsedQuery("set hivevar:desc='bla'").is_set())
    self.assertEqual(False, ParsedQuery("SELECT 1").is_set())
| |
def test_show(self):
    """``is_show()`` recognises SHOW statements regardless of case or comments."""
    commented_show = """
        -- comment
        SHOW LOCKS test EXTENDED;
        -- comment
    """
    parsed = ParsedQuery(commented_show)

    self.assertEqual(True, parsed.is_show())
    self.assertEqual(False, parsed.is_select())

    # The keyword match is case-insensitive.
    for show_sql in ("SHOW TABLES", "shOw TABLES", "show TABLES"):
        self.assertEqual(True, ParsedQuery(show_sql).is_show())
    self.assertEqual(False, ParsedQuery("SELECT 1").is_show())
| |
def test_explain(self):
    """``EXPLAIN SELECT ...`` counts as an explain statement, not a select."""
    parsed = ParsedQuery("EXPLAIN SELECT 1")

    self.assertEqual(True, parsed.is_explain())
    self.assertEqual(False, parsed.is_select())
| |
def test_complex_extract_tables(self):
    """Joins plus an ``IN`` sub-query inside an aliased sub-select."""
    complex_sql = """SELECT sum(m_examples) AS "sum__m_example"
        FROM
        (SELECT COUNT(DISTINCT id_userid) AS m_examples,
        some_more_info
        FROM my_b_table b
        JOIN my_t_table t ON b.ds=t.ds
        JOIN my_l_table l ON b.uid=l.uid
        WHERE b.rid IN
        (SELECT other_col
        FROM inner_table)
        AND l.bla IN ('x', 'y')
        GROUP BY 2
        ORDER BY 2 ASC) AS "meh"
        ORDER BY "sum__m_example" DESC
        LIMIT 10;"""
    table_names = ("my_l_table", "my_b_table", "my_t_table", "inner_table")
    expected = {Table(name) for name in table_names}
    self.assertEqual(expected, self.extract_tables(complex_sql))
| |
def test_complex_extract_tables2(self):
    """Three comma-separated, aliased tables in one ``FROM`` clause."""
    aliased_sql = """SELECT *
        FROM table_a AS a, table_b AS b, table_c as c
        WHERE a.id = b.id and b.id = c.id"""
    expected = {Table(name) for name in ("table_a", "table_b", "table_c")}
    self.assertEqual(expected, self.extract_tables(aliased_sql))
| |
def test_mixed_from_clause(self):
    """A ``FROM`` clause mixing plain tables with an aliased sub-select."""
    mixed_sql = """SELECT *
        FROM table_a AS a, (select * from table_b) AS b, table_c as c
        WHERE a.id = b.id and b.id = c.id"""
    expected = {Table(name) for name in ("table_a", "table_b", "table_c")}
    self.assertEqual(expected, self.extract_tables(mixed_sql))
| |
def test_nested_selects(self):
    """Selects buried inside nested function calls are still inspected."""
    probes = (
        """
        select (extractvalue(1,concat(0x7e,(select GROUP_CONCAT(TABLE_NAME)
        from INFORMATION_SCHEMA.COLUMNS
        WHERE TABLE_SCHEMA like "%bi%"),0x7e)));
        """,
        """
        select (extractvalue(1,concat(0x7e,(select GROUP_CONCAT(COLUMN_NAME)
        from INFORMATION_SCHEMA.COLUMNS
        WHERE TABLE_NAME="bi_achivement_daily"),0x7e)));
        """,
    )
    for sql in probes:
        self.assertEqual(
            {Table("COLUMNS", "INFORMATION_SCHEMA")}, self.extract_tables(sql)
        )
| |
def test_complex_extract_tables3(self):
    """CTEs inside a sub-select, combined with joins and ``UNION ALL``."""
    # The statement is deliberately kept as-is, including its odd trailing
    # comma and repeated WHERE clauses.
    cte_in_subselect_sql = """SELECT somecol AS somecol
        FROM
        (WITH bla AS
        (SELECT col_a
        FROM a
        WHERE 1=1
        AND column_of_choice NOT IN
        ( SELECT interesting_col
        FROM b ) ),
        rb AS
        ( SELECT yet_another_column
        FROM
        ( SELECT a
        FROM c
        GROUP BY the_other_col ) not_table
        LEFT JOIN bla foo ON foo.prop = not_table.bad_col0
        WHERE 1=1
        GROUP BY not_table.bad_col1 ,
        not_table.bad_col2 ,
        ORDER BY not_table.bad_col_3 DESC ,
        not_table.bad_col4 ,
        not_table.bad_col5) SELECT random_col
        FROM d
        WHERE 1=1
        UNION ALL SELECT even_more_cols
        FROM e
        WHERE 1=1
        UNION ALL SELECT lets_go_deeper
        FROM f
        WHERE 1=1
        WHERE 2=2
        GROUP BY last_col
        LIMIT 50000;"""
    expected = {Table(name) for name in ("a", "b", "c", "d", "e", "f")}
    self.assertEqual(expected, self.extract_tables(cte_in_subselect_sql))
| |
def test_complex_cte_with_prefix(self):
    """A CTE declared with an explicit column list is not reported as a table."""
    cte_sql = """
        WITH CTE__test (SalesPersonID, SalesOrderID, SalesYear)
        AS (
            SELECT SalesPersonID, SalesOrderID, YEAR(OrderDate) AS SalesYear
            FROM SalesOrderHeader
            WHERE SalesPersonID IS NOT NULL
        )
        SELECT SalesPersonID, COUNT(SalesOrderID) AS TotalSales, SalesYear
        FROM CTE__test
        GROUP BY SalesYear, SalesPersonID
        ORDER BY SalesPersonID, SalesYear;
    """
    extracted = self.extract_tables(cte_sql)
    self.assertEqual({Table("SalesOrderHeader")}, extracted)
| |
def test_get_query_with_new_limit_comment(self):
    """With a trailing comment, the new LIMIT is appended on its own line."""
    original_sql = "SELECT * FROM birth_names -- SOME COMMENT"
    limited_sql = ParsedQuery(original_sql).set_or_update_query_limit(1000)
    self.assertEqual(limited_sql, original_sql + "\nLIMIT 1000")
| |
def test_get_query_with_new_limit_comment_with_limit(self):
    """A ``LIMIT`` mentioned only inside a comment is not treated as a limit."""
    original_sql = "SELECT * FROM birth_names -- SOME COMMENT WITH LIMIT 555"
    limited_sql = ParsedQuery(original_sql).set_or_update_query_limit(1000)
    self.assertEqual(limited_sql, original_sql + "\nLIMIT 1000")
| |
def test_get_query_with_new_limit_lower(self):
    """An existing limit below the requested one is left untouched."""
    original_sql = "SELECT * FROM birth_names LIMIT 555"
    limited_sql = ParsedQuery(original_sql).set_or_update_query_limit(1000)
    # The requested limit (1000) is higher than the existing one (555), so
    # the statement must come back unchanged.
    self.assertEqual(limited_sql, "SELECT * FROM birth_names LIMIT 555")
| |
def test_get_query_with_new_limit_upper(self):
    """An existing limit above the requested one is replaced by it."""
    original_sql = "SELECT * FROM birth_names LIMIT 1555"
    limited_sql = ParsedQuery(original_sql).set_or_update_query_limit(1000)
    # The requested limit (1000) is lower than the existing one (1555), so
    # it replaces the existing value.
    self.assertEqual(limited_sql, "SELECT * FROM birth_names LIMIT 1000")
| |
def test_basic_breakdown_statements(self):
    """``get_statements()`` splits a two-statement script into clean strings."""
    script = """
    SELECT * FROM birth_names;
    SELECT * FROM birth_names LIMIT 1;
    """
    pieces = ParsedQuery(script).get_statements()
    self.assertEqual(len(pieces), 2)
    self.assertEqual(
        pieces,
        ["SELECT * FROM birth_names", "SELECT * FROM birth_names LIMIT 1"],
    )
| |
def test_messy_breakdown_statements(self):
    """Stray whitespace and repeated semicolons do not produce extra statements."""
    script = """
    SELECT 1;\t\n\n\n \t
    \t\nSELECT 2;
    SELECT * FROM birth_names;;;
    SELECT * FROM birth_names LIMIT 1
    """
    pieces = ParsedQuery(script).get_statements()
    self.assertEqual(len(pieces), 4)
    self.assertEqual(
        pieces,
        [
            "SELECT 1",
            "SELECT 2",
            "SELECT * FROM birth_names",
            "SELECT * FROM birth_names LIMIT 1",
        ],
    )
| |
def test_identifier_list_with_keyword_as_alias(self):
    """A CTE named ``match`` is still handled as an alias, not a table."""
    keyword_alias_sql = """
        WITH
        f AS (SELECT * FROM foo),
        match AS (SELECT * FROM f)
        SELECT * FROM match
    """
    extracted = self.extract_tables(keyword_alias_sql)
    self.assertEqual({Table("foo")}, extracted)
| |
def test_sqlparse_formatting(self):
    """Check that ``sqlparse.format(..., reindent=True)`` keeps ``from from_unixtime`` apart."""
    # sqlparse 0.3.1 has a bug and removes space between from and from_unixtime while formatting:
    # SELECT extract(HOUR\n fromfrom_unixtime(hour_ts)
    # AT TIME ZONE 'America/Los_Angeles')\nfrom table
    # NOTE(review): this assertion depends on the exact whitespace that
    # follows "\n" in the expected literal -- confirm it against the output of
    # the pinned sqlparse version before editing either string.
    self.assertEqual(
        "SELECT extract(HOUR\n from from_unixtime(hour_ts) "
        "AT TIME ZONE 'America/Los_Angeles')\nfrom table",
        sqlparse.format(
            "SELECT extract(HOUR from from_unixtime(hour_ts) AT TIME ZONE 'America/Los_Angeles') from table",
            reindent=True,
        ),
    )
| |
def test_is_explain(self):
    """``is_explain()`` looks past leading comments to find ``EXPLAIN``."""
    cases = (
        # Single EXPLAIN wrapped in line comments.
        (
            """
            -- comment
            EXPLAIN select * from table
            -- comment 2
            """,
            True,
        ),
        # Two EXPLAIN blocks separated by comments and a blank line.
        (
            """
            -- comment
            EXPLAIN select * from table
            where col1 = 'something'
            -- comment 2

            -- comment 3
            EXPLAIN select * from table
            where col1 = 'something'
            -- comment 4
            """,
            True,
        ),
        # Consecutive line comments, the second one indented.
        (
            """
            -- This is a comment
                -- this is another comment but with a space in the front
            EXPLAIN SELECT * FROM TABLE
            """,
            True,
        ),
        # Block comment spanning two lines.
        (
            """
            /* This is a comment
            with stars instead */
            EXPLAIN SELECT * FROM TABLE
            """,
            True,
        ),
        # A plain SELECT is not an EXPLAIN.
        (
            """
            -- comment
            select * from table
            where col1 = 'something'
            -- comment 2
            """,
            False,
        ),
    )
    for sql, expected in cases:
        self.assertEqual(ParsedQuery(sql).is_explain(), expected)
| |
def test_is_valid_ctas(self):
    """A valid CTAS has a SELECT as its last statement"""
    accepted = (
        "SELECT * FROM table",
        """
        -- comment
        SELECT * FROM table
        -- comment 2
        """,
        # Earlier statements are fine as long as the last one is a SELECT.
        """
        -- comment
        SET @value = 42;
        SELECT @value as foo;
        -- comment 2
        """,
    )
    for sql in accepted:
        assert ParsedQuery(sql, strip_comments=True).is_valid_ctas()

    rejected = (
        """
        -- comment
        EXPLAIN SELECT * FROM table
        -- comment 2
        """,
        """
        SELECT * FROM table;
        INSERT INTO TABLE (foo) VALUES (42);
        """,
    )
    for sql in rejected:
        assert not ParsedQuery(sql, strip_comments=True).is_valid_ctas()
| |
def test_is_valid_cvas(self):
    """A valid CVAS has a single SELECT statement

    Every case is checked with ``is_valid_cvas()``. The last two cases
    previously called ``is_valid_ctas()`` (a copy-paste slip), so they never
    exercised the CVAS check at all.
    """
    query = "SELECT * FROM table"
    parsed = ParsedQuery(query, strip_comments=True)
    assert parsed.is_valid_cvas()

    # Surrounding comments do not make a lone SELECT invalid.
    query = """
        -- comment
        SELECT * FROM table
        -- comment 2
    """
    parsed = ParsedQuery(query, strip_comments=True)
    assert parsed.is_valid_cvas()

    # More than one statement is not a valid CVAS, even if the last one is a
    # SELECT.
    query = """
        -- comment
        SET @value = 42;
        SELECT @value as foo;
        -- comment 2
    """
    parsed = ParsedQuery(query, strip_comments=True)
    assert not parsed.is_valid_cvas()

    # A single statement that is not a SELECT.
    query = """
        -- comment
        EXPLAIN SELECT * FROM table
        -- comment 2
    """
    parsed = ParsedQuery(query, strip_comments=True)
    assert not parsed.is_valid_cvas()

    # Two statements, the last of which is not a SELECT.
    query = """
        SELECT * FROM table;
        INSERT INTO TABLE (foo) VALUES (42);
    """
    parsed = ParsedQuery(query, strip_comments=True)
    assert not parsed.is_valid_cvas()
| |
def test_strip_comments_from_sql(self):
    """Test that we are able to strip comments out of SQL stmts"""
    cases = (
        # Nothing to strip.
        ("SELECT col1, col2 FROM table1", "SELECT col1, col2 FROM table1"),
        # A trailing line comment is removed; the newline before it stays.
        (
            "SELECT col1, col2 FROM table1\n-- comment",
            "SELECT col1, col2 FROM table1\n",
        ),
        # "--" inside a string literal is not a comment.
        (
            "SELECT '--abc' as abc, col2 FROM table1\n",
            "SELECT '--abc' as abc, col2 FROM table1",
        ),
    )
    for raw_sql, stripped_sql in cases:
        assert strip_comments_from_sql(raw_sql) == stripped_sql