| # Licensed to the Apache Software Foundation (ASF) under one |
| # or more contributor license agreements. See the NOTICE file |
| # distributed with this work for additional information |
| # regarding copyright ownership. The ASF licenses this file |
| # to you under the Apache License, Version 2.0 (the |
| # "License"); you may not use this file except in compliance |
| # with the License. You may obtain a copy of the License at |
| # |
| # http://www.apache.org/licenses/LICENSE-2.0 |
| # |
| # Unless required by applicable law or agreed to in writing, |
| # software distributed under the License is distributed on an |
| # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY |
| # KIND, either express or implied. See the License for the |
| # specific language governing permissions and limitations |
| # under the License. |
| # isort:skip_file |
| """Unit tests for Superset""" |
| |
| from datetime import datetime, timedelta |
| from unittest import mock |
| import random |
| import string |
| |
| import pytest |
| import prison |
| from sqlalchemy.sql import func |
| |
| import tests.integration_tests.test_app # noqa: F401 |
| from superset import db, security_manager |
| from superset.common.db_query_status import QueryStatus |
| from superset.models.core import Database |
| from superset.utils.database import get_example_database, get_main_database |
| from superset.utils import json |
| from superset.models.sql_lab import Query |
| |
| from tests.integration_tests.base_tests import SupersetTestCase |
| from tests.integration_tests.constants import ADMIN_USERNAME, GAMMA_SQLLAB_USERNAME |
| |
# Number of Query rows inserted by the ``create_queries`` fixture.
QUERIES_FIXTURE_COUNT = 10


class TestQueryApi(SupersetTestCase):
    """Integration tests for the ``api/v1/query`` REST endpoints."""

    # Response keys that are skipped when comparing payloads because their
    # values (timestamps, generated ids) differ from run to run.
    _VOLATILE_KEYS = (
        "changed_on",
        "end_time",
        "start_running_time",
        "start_time",
        "id",
    )

    @staticmethod
    def _load_json(rv):
        """Decode the UTF-8 body of a test-client response as JSON."""
        return json.loads(rv.data.decode("utf-8"))

    @staticmethod
    def _delete_and_commit(*objects) -> None:
        """Delete ``objects`` from the session in the given order, then commit."""
        for obj in objects:
            db.session.delete(obj)
        db.session.commit()

    def _assert_stable_fields(self, actual: dict, expected: dict) -> None:
        """
        Assert that every non-volatile key of ``actual`` equals ``expected[key]``.

        Only keys present in ``actual`` are checked; a returned key that is
        missing from ``expected`` raises ``KeyError``.
        """
        for key, value in actual.items():
            # We can't assert timestamp
            if key not in self._VOLATILE_KEYS:
                assert value == expected[key]

    def insert_query(
        self,
        database_id: int,
        user_id: int,
        client_id: str,
        sql: str = "",
        select_sql: str = "",
        executed_sql: str = "",
        limit: int = 100,
        progress: int = 100,
        rows: int = 100,
        tab_name: str = "",
        status: str = "success",
        changed_on: datetime = datetime(2020, 1, 1),
    ) -> Query:
        """
        Create a ``Query`` row, commit it and return it.

        The database and user are looked up by primary key; every other
        argument is passed straight to the ``Query`` constructor.
        """
        database = db.session.query(Database).get(database_id)
        user = db.session.query(security_manager.user_model).get(user_id)
        query = Query(
            database=database,
            user=user,
            client_id=client_id,
            sql=sql,
            select_sql=select_sql,
            executed_sql=executed_sql,
            limit=limit,
            progress=progress,
            rows=rows,
            tab_name=tab_name,
            status=status,
            changed_on=changed_on,
        )
        db.session.add(query)
        db.session.commit()
        return query

    @pytest.fixture
    def create_queries(self):
        """
        Insert ``QUERIES_FIXTURE_COUNT`` queries and delete them on teardown.

        All but one are owned by "admin" on the example database, with
        alternating SUCCESS/RUNNING status; the last one is owned by "alpha"
        on the main database.
        """
        with self.create_app().app_context():
            queries = []
            # try/finally so rows inserted before a setup failure are removed too.
            try:
                admin_id = self.get_user("admin").id
                alpha_id = self.get_user("alpha").id
                example_database_id = get_example_database().id
                main_database_id = get_main_database().id
                for cx in range(QUERIES_FIXTURE_COUNT - 1):
                    queries.append(
                        self.insert_query(
                            example_database_id,
                            admin_id,
                            self.get_random_string(),
                            sql=f"SELECT col1, col2 from table{cx}",  # noqa: S608
                            rows=cx,
                            status=QueryStatus.SUCCESS
                            if (cx % 2) == 0
                            else QueryStatus.RUNNING,
                        )
                    )
                queries.append(
                    self.insert_query(
                        main_database_id,
                        alpha_id,
                        self.get_random_string(),
                        sql=f"SELECT col1, col2 from table{QUERIES_FIXTURE_COUNT}",  # noqa: S608
                        rows=QUERIES_FIXTURE_COUNT,
                        status=QueryStatus.SUCCESS,
                    )
                )

                yield queries
            finally:
                # rollback changes
                self._delete_and_commit(*queries)

    @staticmethod
    def get_random_string(length: int = 10) -> str:
        """Return ``length`` random ASCII letters (not cryptographically secure)."""
        letters = string.ascii_letters
        return "".join(random.choice(letters) for _ in range(length))  # noqa: S311

    def test_get_query(self):
        """
        Query API: Test get query
        """
        admin = self.get_user("admin")
        client_id = self.get_random_string()
        example_db = get_example_database()
        query = self.insert_query(
            example_db.id,
            admin.id,
            client_id,
            sql="SELECT col1, col2 from table1",
            select_sql="SELECT col1, col2 from table1",
            executed_sql="SELECT col1, col2 from table1 LIMIT 100",
        )
        try:
            self.login(ADMIN_USERNAME)
            uri = f"api/v1/query/{query.id}"
            rv = self.client.get(uri)
            assert rv.status_code == 200

            expected_result = {
                "database": {"id": example_db.id},
                "client_id": client_id,
                "end_result_backend_time": None,
                "error_message": None,
                "executed_sql": "SELECT col1, col2 from table1 LIMIT 100",
                "limit": 100,
                "progress": 100,
                "results_key": None,
                "rows": 100,
                "schema": None,
                "select_as_cta": None,
                "select_as_cta_used": False,
                "select_sql": "SELECT col1, col2 from table1",
                "sql": "SELECT col1, col2 from table1",
                "sql_editor_id": None,
                "status": "success",
                "tab_name": "",
                "tmp_schema_name": None,
                "tmp_table_name": None,
                "tracking_url": None,
            }
            data = self._load_json(rv)
            assert "changed_on" in data["result"]
            self._assert_stable_fields(data["result"], expected_result)
        finally:
            # rollback changes, even when an assertion above fails
            self._delete_and_commit(query)

    def test_get_query_not_found(self):
        """
        Query API: Test get query not found
        """
        admin = self.get_user("admin")
        client_id = self.get_random_string()
        query = self.insert_query(get_example_database().id, admin.id, client_id)
        try:
            # One past the highest existing id cannot belong to any query.
            max_id = db.session.query(func.max(Query.id)).scalar()
            self.login(ADMIN_USERNAME)
            uri = f"api/v1/query/{max_id + 1}"
            rv = self.client.get(uri)
            assert rv.status_code == 404
        finally:
            # rollback changes, even when an assertion above fails
            self._delete_and_commit(query)

    def test_get_query_no_data_access(self):
        """
        Query API: Test get query without data access
        """
        gamma1 = self.create_user(
            "gamma_1", "password", "Gamma", email="gamma1@superset.org"
        )
        gamma2 = self.create_user(
            "gamma_2", "password", "Gamma", email="gamma2@superset.org"
        )
        # Add SQLLab role to these gamma users, so they have access to queries
        sqllab_role = self.get_role("sql_lab")
        gamma1.roles.append(sqllab_role)
        gamma2.roles.append(sqllab_role)

        gamma1_client_id = self.get_random_string()
        gamma2_client_id = self.get_random_string()
        query_gamma1 = self.insert_query(
            get_example_database().id, gamma1.id, gamma1_client_id
        )
        query_gamma2 = self.insert_query(
            get_example_database().id, gamma2.id, gamma2_client_id
        )

        try:
            # Gamma1 user, only sees their own queries
            self.login(username="gamma_1", password="password")  # noqa: S106
            uri = f"api/v1/query/{query_gamma2.id}"
            rv = self.client.get(uri)
            assert rv.status_code == 404
            uri = f"api/v1/query/{query_gamma1.id}"
            rv = self.client.get(uri)
            assert rv.status_code == 200

            # Gamma2 user, only sees their own queries
            self.logout()
            self.login(username="gamma_2", password="password")  # noqa: S106
            uri = f"api/v1/query/{query_gamma1.id}"
            rv = self.client.get(uri)
            assert rv.status_code == 404
            uri = f"api/v1/query/{query_gamma2.id}"
            rv = self.client.get(uri)
            assert rv.status_code == 200

            # Admins have the "all query access" permission
            self.logout()
            self.login(ADMIN_USERNAME)
            uri = f"api/v1/query/{query_gamma1.id}"
            rv = self.client.get(uri)
            assert rv.status_code == 200
            uri = f"api/v1/query/{query_gamma2.id}"
            rv = self.client.get(uri)
            assert rv.status_code == 200
        finally:
            # rollback changes, even when an assertion above fails
            self._delete_and_commit(query_gamma1, query_gamma2, gamma1, gamma2)

    @pytest.mark.usefixtures("create_queries")
    def test_get_list_query(self):
        """
        Query API: Test get list query
        """
        self.login(ADMIN_USERNAME)
        uri = "api/v1/query/"
        rv = self.client.get(uri)
        assert rv.status_code == 200
        data = self._load_json(rv)
        assert data["count"] == QUERIES_FIXTURE_COUNT
        first_row = data["result"][0]
        # check expected columns
        assert sorted(first_row) == [
            "changed_on",
            "database",
            "end_time",
            "executed_sql",
            "id",
            "rows",
            "schema",
            "sql",
            "sql_tables",
            "start_time",
            "status",
            "tab_name",
            "tmp_table_name",
            "tracking_url",
            "user",
        ]
        assert sorted(first_row["user"]) == [
            "first_name",
            "id",
            "last_name",
        ]
        assert list(first_row["database"]) == [
            "database_name",
        ]

    @pytest.mark.usefixtures("create_queries")
    def test_get_list_query_filter_sql(self):
        """
        Query API: Test get list query filter
        """
        self.login(ADMIN_USERNAME)
        arguments = {"filters": [{"col": "sql", "opr": "ct", "value": "table2"}]}
        uri = f"api/v1/query/?q={prison.dumps(arguments)}"
        rv = self.client.get(uri)
        assert rv.status_code == 200
        data = self._load_json(rv)
        assert data["count"] == 1

    @pytest.mark.usefixtures("create_queries")
    def test_get_list_query_filter_database(self):
        """
        Query API: Test get list query filter database
        """
        self.login(ADMIN_USERNAME)
        database_id = get_main_database().id
        arguments = {
            "filters": [{"col": "database", "opr": "rel_o_m", "value": database_id}]
        }
        uri = f"api/v1/query/?q={prison.dumps(arguments)}"
        rv = self.client.get(uri)
        assert rv.status_code == 200
        data = self._load_json(rv)
        assert data["count"] == 1

    @pytest.mark.usefixtures("create_queries")
    def test_get_list_query_filter_user(self):
        """
        Query API: Test get list query filter user
        """
        self.login(ADMIN_USERNAME)
        alpha_id = self.get_user("alpha").id
        arguments = {"filters": [{"col": "user", "opr": "rel_o_m", "value": alpha_id}]}
        uri = f"api/v1/query/?q={prison.dumps(arguments)}"
        rv = self.client.get(uri)
        assert rv.status_code == 200
        data = self._load_json(rv)
        assert data["count"] == 1

    @pytest.mark.usefixtures("create_queries")
    def test_get_list_query_filter_changed_on(self):
        """
        Query API: Test get list query filter changed_on
        """
        self.login(ADMIN_USERNAME)
        # The window brackets the default ``changed_on`` used by ``insert_query``.
        arguments = {
            "filters": [
                {"col": "changed_on", "opr": "lt", "value": "2020-02-01T00:00:00Z"},
                {"col": "changed_on", "opr": "gt", "value": "2019-12-30T00:00:00Z"},
            ]
        }
        uri = f"api/v1/query/?q={prison.dumps(arguments)}"
        rv = self.client.get(uri)
        assert rv.status_code == 200
        data = self._load_json(rv)
        assert data["count"] == QUERIES_FIXTURE_COUNT

    @pytest.mark.usefixtures("create_queries")
    def test_get_list_query_order(self):
        """
        Query API: Test get list query order
        """
        self.login(ADMIN_USERNAME)
        order_columns = [
            "changed_on",
            "database.database_name",
            "rows",
            "schema",
            "sql",
            "tab_name",
            "user.first_name",
        ]

        # Only the status code is checked; the resulting order is not asserted.
        for order_column in order_columns:
            arguments = {"order_column": order_column, "order_direction": "asc"}
            uri = f"api/v1/query/?q={prison.dumps(arguments)}"
            rv = self.client.get(uri)
            assert rv.status_code == 200

    def test_get_list_query_no_data_access(self):
        """
        Query API: Test get queries no data access
        """
        admin = self.get_user("admin")
        client_id = self.get_random_string()
        query = self.insert_query(
            get_example_database().id,
            admin.id,
            client_id,
            sql="SELECT col1, col2 from table1",
        )

        try:
            self.login(GAMMA_SQLLAB_USERNAME)
            arguments = {
                "filters": [{"col": "sql", "opr": "sw", "value": "SELECT col1"}]
            }
            uri = f"api/v1/query/?q={prison.dumps(arguments)}"
            rv = self.client.get(uri)
            assert rv.status_code == 200
            data = self._load_json(rv)
            assert data["count"] == 0
        finally:
            # rollback changes, even when an assertion above fails
            self._delete_and_commit(query)

    def test_get_updated_since(self):
        """
        Query API: Test get queries updated since timestamp
        """
        now = datetime.utcnow()
        client_id = self.get_random_string()

        admin = self.get_user("admin")
        example_db = get_example_database()

        old_query = self.insert_query(
            example_db.id,
            admin.id,
            self.get_random_string(),
            sql="SELECT col1, col2 from table1",
            select_sql="SELECT col1, col2 from table1",
            executed_sql="SELECT col1, col2 from table1 LIMIT 100",
            changed_on=now - timedelta(days=3),
        )
        updated_query = self.insert_query(
            example_db.id,
            admin.id,
            client_id,
            sql="SELECT col1, col2 from table1",
            select_sql="SELECT col1, col2 from table1",
            executed_sql="SELECT col1, col2 from table1 LIMIT 100",
            changed_on=now - timedelta(days=1),
        )

        try:
            self.login(ADMIN_USERNAME)
            # NOTE(review): ``now`` is naive, so ``timestamp()`` interprets it as
            # local time; the one-day margin on each side presumably absorbs
            # any UTC offset -- confirm before tightening the window.
            timestamp = datetime.timestamp(now - timedelta(days=2)) * 1000
            uri = f"api/v1/query/updated_since?q={prison.dumps({'last_updated_ms': timestamp})}"  # noqa: E501
            rv = self.client.get(uri)
            assert rv.status_code == 200

            expected_result = updated_query.to_dict()
            data = self._load_json(rv)
            assert len(data["result"]) == 1
            self._assert_stable_fields(data["result"][0], expected_result)
        finally:
            # rollback changes, even when an assertion above fails
            self._delete_and_commit(old_query, updated_query)

    @mock.patch("superset.sql_lab.cancel_query")
    @mock.patch("superset.views.core.db.session")
    def test_stop_query_not_found(
        self, mock_superset_db_session, mock_sql_lab_cancel_query
    ):
        """
        Query API: Test stop query with an unknown client_id

        The mocked session lookup returns ``None`` and the endpoint is
        expected to answer 404.
        """
        form_data = {"client_id": "foo2"}
        query_mock = mock.Mock()
        query_mock.return_value = None
        self.login(ADMIN_USERNAME)
        mock_superset_db_session.query().filter_by().one_or_none = query_mock
        mock_sql_lab_cancel_query.return_value = True
        rv = self.client.post(
            "/api/v1/query/stop",
            data=json.dumps(form_data),
            content_type="application/json",
        )

        assert rv.status_code == 404
        data = self._load_json(rv)
        assert data["message"] == "Query with client_id foo2 not found"

    @mock.patch("superset.sql_lab.cancel_query")
    @mock.patch("superset.views.core.db.session")
    def test_stop_query(self, mock_superset_db_session, mock_sql_lab_cancel_query):
        """
        Query API: Test stop query for a running query

        The mocked session lookup returns a RUNNING query and
        ``cancel_query`` is mocked to return ``True``; the endpoint is
        expected to answer 200 with result "OK".
        """
        form_data = {"client_id": "foo"}
        query_mock = mock.Mock()
        query_mock.client_id = "foo"
        query_mock.status = QueryStatus.RUNNING
        self.login(ADMIN_USERNAME)
        # Configure the return value of ``one_or_none`` itself (not of the
        # object it returns) so the lookup actually yields ``query_mock``.
        mock_superset_db_session.query().filter_by().one_or_none.return_value = (
            query_mock
        )
        mock_sql_lab_cancel_query.return_value = True
        rv = self.client.post(
            "/api/v1/query/stop",
            data=json.dumps(form_data),
            content_type="application/json",
        )

        assert rv.status_code == 200
        data = self._load_json(rv)
        assert data["result"] == "OK"