| # Licensed to the Apache Software Foundation (ASF) under one |
| # or more contributor license agreements. See the NOTICE file |
| # distributed with this work for additional information |
| # regarding copyright ownership. The ASF licenses this file |
| # to you under the Apache License, Version 2.0 (the |
| # "License"); you may not use this file except in compliance |
| # with the License. You may obtain a copy of the License at |
| # |
| # http://www.apache.org/licenses/LICENSE-2.0 |
| # |
| # Unless required by applicable law or agreed to in writing, |
| # software distributed under the License is distributed on an |
| # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY |
| # KIND, either express or implied. See the License for the |
| # specific language governing permissions and limitations |
| # under the License. |
| # isort:skip_file |
| """Unit tests for Superset""" |
| |
| import datetime |
| import random |
| import csv |
| import pandas as pd |
| import io |
| |
| import pytest |
| import prison |
| from sqlalchemy.sql import func # noqa: F401 |
| from unittest import mock |
| |
| from flask_appbuilder.security.sqla.models import Role |
| from tests.integration_tests.test_app import app |
| from superset import db, sql_lab |
| from superset.common.db_query_status import QueryStatus |
| from superset.models.core import Database # noqa: F401 |
| from superset.utils.database import get_example_database, get_main_database # noqa: F401 |
| from superset.utils import core as utils, json |
| from superset.models.sql_lab import Query |
| |
| from tests.integration_tests.base_tests import SupersetTestCase |
| from tests.integration_tests.constants import ( |
| ADMIN_USERNAME, |
| GAMMA_SQLLAB_NO_DATA_USERNAME, |
| ) |
| from tests.integration_tests.fixtures.birth_names_dashboard import load_birth_names_data # noqa: F401 |
| from tests.integration_tests.fixtures.users import create_gamma_sqllab_no_data # noqa: F401 |
| |
| QUERIES_FIXTURE_COUNT = 10 |
| |
| |
| class TestSqlLabApi(SupersetTestCase): |
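    """Integration tests for the SQL Lab REST API (/api/v1/sqllab/)."""
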
| @pytest.mark.usefixtures("create_gamma_sqllab_no_data") |
| @mock.patch.dict( |
| "superset.extensions.feature_flag_manager._feature_flags", |
| {"SQLLAB_BACKEND_PERSISTENCE": False}, |
| clear=True, |
| ) |
| def test_get_from_empty_bootstrap_data(self): |
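        """Bootstrap data is empty for a user with no database access."""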
| if utils.backend() == "postgresql": |
            # this test currently fails on PostgreSQL, so skip that backend
| return |
| |
| self.login(GAMMA_SQLLAB_NO_DATA_USERNAME) |
| resp = self.client.get("/api/v1/sqllab/") |
| assert resp.status_code == 200 |
| data = json.loads(resp.data.decode("utf-8")) |
| result = data.get("result") |
| assert result["active_tab"] is None # noqa: E711 |
| assert result["tab_state_ids"] == [] |
| assert len(result["databases"]) == 0 |
| |
| @mock.patch.dict( |
| "superset.extensions.feature_flag_manager._feature_flags", |
| {"SQLLAB_BACKEND_PERSISTENCE": False}, |
| clear=True, |
| ) |
| def test_get_from_bootstrap_data_for_non_persisted_tab_state(self): |
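        """Tabs created with persistence disabled do not appear in bootstrap data."""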
| self.login(ADMIN_USERNAME) |
| # create a tab |
| data = { |
| "queryEditor": json.dumps( |
| { |
| "title": "Untitled Query 1", |
| "dbId": 1, |
| "schema": None, |
| "autorun": False, |
| "sql": "SELECT ...", |
| "queryLimit": 1000, |
| } |
| ) |
| } |
| self.get_json_resp("/tabstateview/", data=data) |
| resp = self.client.get("/api/v1/sqllab/") |
| assert resp.status_code == 200 |
| data = json.loads(resp.data.decode("utf-8")) |
| result = data.get("result") |
| assert result["active_tab"] is None # noqa: E711 |
| assert result["tab_state_ids"] == [] |
| |
| @pytest.mark.usefixtures("load_birth_names_data") |
| @mock.patch.dict( |
| "superset.extensions.feature_flag_manager._feature_flags", |
| {"SQLLAB_BACKEND_PERSISTENCE": True}, |
| clear=True, |
| ) |
| def test_get_from_bootstrap_data_with_latest_query(self): |
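        """With persistence enabled, the new tab is returned as the active tab."""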
| self.login(ADMIN_USERNAME) |
| |
| # create a tab |
| data = { |
| "queryEditor": json.dumps( |
| { |
| "title": "Untitled Query 1", |
| "dbId": 1, |
| "schema": None, |
| "autorun": False, |
| "sql": "SELECT ...", |
| "queryLimit": 1000, |
| } |
| ) |
| } |
| resp = self.get_json_resp("/tabstateview/", data=data) |
| tab_state_id = resp["id"] |
| |
        # the newly created tab should be returned as the active tab
| resp = self.get_json_resp("/api/v1/sqllab/") |
| result = resp["result"] |
| assert result["active_tab"]["id"] == tab_state_id |
| |
| @mock.patch.dict( |
| "superset.extensions.feature_flag_manager._feature_flags", |
| {"SQLLAB_BACKEND_PERSISTENCE": True}, |
| clear=True, |
| ) |
| def test_deleted_tab(self): |
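        """Deleting a tab makes subsequent GET and PUT requests for it return 404."""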
| username = "admin" |
| self.login(username) |
| data = { |
| "queryEditor": json.dumps( |
| { |
| "title": "Untitled Query 2", |
| "dbId": 1, |
| "schema": None, |
| "autorun": False, |
| "sql": "SELECT ...", |
| "queryLimit": 1000, |
| } |
| ) |
| } |
| resp = self.get_json_resp("/tabstateview/", data=data) |
| tab_state_id = resp["id"] |
| resp = self.client.delete("/tabstateview/" + str(tab_state_id)) |
| assert resp.status_code == 200 |
| resp = self.client.get("/tabstateview/" + str(tab_state_id)) |
| assert resp.status_code == 404 |
| resp = self.client.put( |
| "/tabstateview/" + str(tab_state_id), |
| json=data, |
| ) |
| assert resp.status_code == 404 |
| |
| @mock.patch.dict( |
| "superset.extensions.feature_flag_manager._feature_flags", |
| {"SQLLAB_BACKEND_PERSISTENCE": True}, |
| clear=True, |
| ) |
| def test_delete_tab_already_removed(self): |
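        """Deleting a tab that was already removed returns 404."""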
| username = "admin" |
| self.login(username) |
| data = { |
| "queryEditor": json.dumps( |
| { |
| "title": "Untitled Query 3", |
| "dbId": 1, |
| "schema": None, |
| "autorun": False, |
| "sql": "SELECT ...", |
| "queryLimit": 1000, |
| } |
| ) |
| } |
| resp = self.get_json_resp("/tabstateview/", data=data) |
| tab_state_id = resp["id"] |
| resp = self.client.delete("/tabstateview/" + str(tab_state_id)) |
| assert resp.status_code == 200 |
| resp = self.client.delete("/tabstateview/" + str(tab_state_id)) |
| assert resp.status_code == 404 |
| |
| def test_get_access_denied(self): |
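        """A user without SQL Lab permissions gets a 403 from the bootstrap endpoint."""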
| new_role = Role(name="Dummy Role", permissions=[]) |
| db.session.add(new_role) |
| db.session.commit() |
| unauth_user = self.create_user( |
| "unauth_user1", |
| "password", |
| "Dummy Role", |
| email="unauth_user1@superset.org", # noqa: F541 |
| ) |
| self.login(username="unauth_user1", password="password") # noqa: S106 |
| rv = self.client.get("/api/v1/sqllab/") |
| |
| assert rv.status_code == 403 |
| |
| db.session.delete(unauth_user) |
| db.session.delete(new_role) |
| db.session.commit() |
| |
| def test_estimate_required_params(self): |
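        """Estimate requests missing sql or database_id are rejected with 400."""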
| self.login(ADMIN_USERNAME) |
| |
| rv = self.client.post( |
| "/api/v1/sqllab/estimate/", |
| json={}, |
| ) |
| failed_resp = { |
| "message": { |
| "sql": ["Missing data for required field."], |
| "database_id": ["Missing data for required field."], |
| } |
| } |
| resp_data = json.loads(rv.data.decode("utf-8")) |
| self.assertDictEqual(resp_data, failed_resp) # noqa: PT009 |
| assert rv.status_code == 400 |
| |
| data = {"sql": "SELECT 1"} |
| rv = self.client.post( |
| "/api/v1/sqllab/estimate/", |
| json=data, |
| ) |
| failed_resp = {"message": {"database_id": ["Missing data for required field."]}} |
| resp_data = json.loads(rv.data.decode("utf-8")) |
| self.assertDictEqual(resp_data, failed_resp) # noqa: PT009 |
| assert rv.status_code == 400 |
| |
| data = {"database_id": 1} |
| rv = self.client.post( |
| "/api/v1/sqllab/estimate/", |
| json=data, |
| ) |
| failed_resp = {"message": {"sql": ["Missing data for required field."]}} |
| resp_data = json.loads(rv.data.decode("utf-8")) |
| self.assertDictEqual(resp_data, failed_resp) # noqa: PT009 |
| assert rv.status_code == 400 |
| |
| def test_estimate_valid_request(self): |
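        """A valid estimate request returns the formatted cost from the engine spec."""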
| self.login(ADMIN_USERNAME) |
| |
| formatter_response = [ |
| { |
| "value": 100, |
| } |
| ] |
| |
| db_mock = mock.Mock() |
| db_mock.db_engine_spec = mock.Mock() |
| db_mock.db_engine_spec.estimate_query_cost = mock.Mock(return_value=100) |
| db_mock.db_engine_spec.query_cost_formatter = mock.Mock( |
| return_value=formatter_response |
| ) |
| |
| with mock.patch("superset.commands.sql_lab.estimate.db") as mock_superset_db: |
| mock_superset_db.session.query().get.return_value = db_mock |
| |
| data = {"database_id": 1, "sql": "SELECT 1"} |
| rv = self.client.post( |
| "/api/v1/sqllab/estimate/", |
| json=data, |
| ) |
| |
| success_resp = {"result": formatter_response} |
| resp_data = json.loads(rv.data.decode("utf-8")) |
| self.assertDictEqual(resp_data, success_resp) # noqa: PT009 |
| assert rv.status_code == 200 |
| |
| def test_format_sql_request(self): |
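        """The format_sql endpoint returns the prettified SQL statement."""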
| self.login(ADMIN_USERNAME) |
| |
| data = {"sql": "select 1 from my_table"} |
| rv = self.client.post( |
| "/api/v1/sqllab/format_sql/", |
| json=data, |
| ) |
| success_resp = {"result": "SELECT 1\nFROM my_table"} |
| resp_data = json.loads(rv.data.decode("utf-8")) |
| self.assertDictEqual(resp_data, success_resp) # noqa: PT009 |
| assert rv.status_code == 200 |
| |
| @mock.patch("superset.commands.sql_lab.results.results_backend_use_msgpack", False) |
| def test_execute_required_params(self): |
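        """Execute requests missing sql or database_id are rejected with 400."""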
| self.login(ADMIN_USERNAME) |
| client_id = f"{random.getrandbits(64)}"[:10] |
| |
| data = {"client_id": client_id} |
| rv = self.client.post( |
| "/api/v1/sqllab/execute/", |
| json=data, |
| ) |
| failed_resp = { |
| "message": { |
| "sql": ["Missing data for required field."], |
| "database_id": ["Missing data for required field."], |
| } |
| } |
| resp_data = json.loads(rv.data.decode("utf-8")) |
| self.assertDictEqual(resp_data, failed_resp) # noqa: PT009 |
| assert rv.status_code == 400 |
| |
| data = {"sql": "SELECT 1", "client_id": client_id} |
| rv = self.client.post( |
| "/api/v1/sqllab/execute/", |
| json=data, |
| ) |
| failed_resp = {"message": {"database_id": ["Missing data for required field."]}} |
| resp_data = json.loads(rv.data.decode("utf-8")) |
| self.assertDictEqual(resp_data, failed_resp) # noqa: PT009 |
| assert rv.status_code == 400 |
| |
| data = {"database_id": 1, "client_id": client_id} |
| rv = self.client.post( |
| "/api/v1/sqllab/execute/", |
| json=data, |
| ) |
| failed_resp = {"message": {"sql": ["Missing data for required field."]}} |
| resp_data = json.loads(rv.data.decode("utf-8")) |
| self.assertDictEqual(resp_data, failed_resp) # noqa: PT009 |
| assert rv.status_code == 400 |
| |
| @mock.patch("superset.commands.sql_lab.results.results_backend_use_msgpack", False) |
| def test_execute_valid_request(self) -> None: |
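        """A valid execute request returns a success status."""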
| from superset import sql_lab as core |
| |
| core.results_backend = mock.Mock() |
| core.results_backend.get.return_value = {} |
| |
| self.login(ADMIN_USERNAME) |
| client_id = f"{random.getrandbits(64)}"[:10] |
| |
| data = {"sql": "SELECT 1", "database_id": 1, "client_id": client_id} |
| rv = self.client.post( |
| "/api/v1/sqllab/execute/", |
| json=data, |
| ) |
| resp_data = json.loads(rv.data.decode("utf-8")) |
| assert resp_data.get("status") == "success" |
| assert rv.status_code == 200 |
| |
| @mock.patch( |
| "tests.integration_tests.superset_test_custom_template_processors.datetime" |
| ) |
| @mock.patch("superset.sqllab.api.get_sql_results") |
| def test_execute_custom_templated(self, sql_lab_mock, mock_dt) -> None: |
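        """Custom template macros are rendered before the query is executed."""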
| mock_dt.utcnow = mock.Mock(return_value=datetime.datetime(1970, 1, 1)) |
| self.login(ADMIN_USERNAME) |
| sql = "SELECT '$DATE()' as test" |
| resp = { |
| "status": QueryStatus.SUCCESS, |
| "query": {"rows": 1}, |
| "data": [{"test": "'1970-01-01'"}], |
| } |
| sql_lab_mock.return_value = resp |
| |
| dbobj = self.create_fake_db_for_macros() |
        json_payload = {"database_id": dbobj.id, "sql": sql}
| self.get_json_resp( |
| "/api/v1/sqllab/execute/", raise_on_error=False, json_=json_payload |
| ) |
| assert sql_lab_mock.called |
| assert sql_lab_mock.call_args[0][1] == "SELECT '1970-01-01' as test" |
| |
| self.delete_fake_db_for_macros() |
| |
| @mock.patch("superset.commands.sql_lab.results.results_backend_use_msgpack", False) |
| def test_get_results_with_display_limit(self): |
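        """The rows argument limits returned data and sets displayLimitReached."""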
| from superset.commands.sql_lab import results as command |
| |
| command.results_backend = mock.Mock() |
| self.login(ADMIN_USERNAME) |
| |
| data = [{"col_0": i} for i in range(100)] |
| payload = { |
| "status": QueryStatus.SUCCESS, |
| "query": {"rows": 100}, |
| "data": data, |
| } |
        # expected payload when all rows are returned
        expected_key = {"status": "success", "query": {"rows": 100}, "data": data}
        # expected payload when the display limit restricts results to one row
        limited_data = data[:1]
| expected_limited = { |
| "status": "success", |
| "query": {"rows": 100}, |
| "data": limited_data, |
| "displayLimitReached": True, |
| } |
| |
| query_mock = mock.Mock() |
| query_mock.sql = "SELECT *" |
| query_mock.database = 1 |
| query_mock.schema = "superset" |
| |
| # do not apply msgpack serialization |
| use_msgpack = app.config["RESULTS_BACKEND_USE_MSGPACK"] |
| app.config["RESULTS_BACKEND_USE_MSGPACK"] = False |
| serialized_payload = sql_lab._serialize_payload(payload, False) |
| compressed = utils.zlib_compress(serialized_payload) |
| command.results_backend.get.return_value = compressed |
| |
| with mock.patch("superset.commands.sql_lab.results.db") as mock_superset_db: |
| mock_superset_db.session.query().filter_by().one_or_none.return_value = ( |
| query_mock |
| ) |
| # get all results |
| arguments = {"key": "key"} |
| result_key = json.loads( |
| self.get_resp(f"/api/v1/sqllab/results/?q={prison.dumps(arguments)}") |
| ) |
| arguments = {"key": "key", "rows": 1} |
| result_limited = json.loads( |
| self.get_resp(f"/api/v1/sqllab/results/?q={prison.dumps(arguments)}") |
| ) |
| |
| assert result_key == expected_key |
| assert result_limited == expected_limited |
| |
| app.config["RESULTS_BACKEND_USE_MSGPACK"] = use_msgpack |
| |
| @mock.patch("superset.models.sql_lab.Query.raise_for_access", lambda _: None) # noqa: PT008 |
| @mock.patch("superset.models.core.Database.get_df") |
| def test_export_results(self, get_df_mock: mock.Mock) -> None: |
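        """Exporting query results returns the data as CSV."""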
| self.login(ADMIN_USERNAME) |
| |
| database = get_example_database() |
| query_obj = Query( |
| client_id="test", |
| database=database, |
| tab_name="test_tab", |
| sql_editor_id="test_editor_id", |
| sql="select * from bar", |
| select_sql=None, |
| executed_sql="select * from bar limit 2", |
| limit=100, |
| select_as_cta=False, |
| rows=104, |
| error_message="none", |
| results_key="test_abc", |
| ) |
| |
| db.session.add(query_obj) |
| db.session.commit() |
| |
| get_df_mock.return_value = pd.DataFrame({"foo": [1, 2, 3]}) |
| |
| resp = self.get_resp("/api/v1/sqllab/export/test/") |
| data = csv.reader(io.StringIO(resp)) |
| expected_data = csv.reader(io.StringIO("foo\n1\n2")) |
| |
| assert list(expected_data) == list(data) |
| db.session.delete(query_obj) |
| db.session.commit() |